Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] Annotation of /branches/charisee/src/compiler/high-to-mid/expand-integrate.sml
ViewVC logotype

Annotation of /branches/charisee/src/compiler/high-to-mid/expand-integrate.sml

Parent Directory Parent Directory | Revision Log Revision Log


Revision 2498 - (view) (download)

1 : cchiw 2498 (* examples.sml
2 :     *
3 :     * COPYRIGHT (c) 2012 The Diderot Project (http://diderot-language.cs.uchicago.edu)
4 :     * All rights reserved.
5 :     *)
6 :    
7 :    
8 :     (*Expand ProebConv to Probe of individual field *)
9 :     structure Expand = struct
10 :    
11 :     local
12 :    
13 :     structure E = Ein
14 :    
15 :     in
16 :    
17 :    
18 :     (*
19 :     Dictionary created for Tensor positions
20 :     Tensor X=> fractional and integer tensors in mid-il operators
21 :     Save in variable names.
22 :    
23 :     *)
24 :    
25 :    
26 :     fun insert (key, value) d =fn s =>
27 :     if s = key then SOME value
28 :     else d s
29 :    
30 :     fun lookup k d = d k
31 :    
32 :     (*
33 :     createDels=> creates the kronecker deltas for each Kernel
34 :     For each dimesnion a, and each index in derivative b create element (a,b)
35 :     *)
36 :     fun createDels([],_)= []
37 :     | createDels(d::ds,dim)= [( E.C dim,d)]@createDels(ds,dim)
38 :    
39 :    
40 :    
41 :    
42 :     fun Position(idt,dict,params,args)=let
43 :     val l=lookup idt dict
44 :     in (case l
45 :     of NONE =>let
46 :     val pos1=length params
47 :     val pos2=pos1+1
48 :     val dict'=insert(idt,(pos1,pos2)) dict
49 :     val params'=params@[E.TEN,E.TEN]
50 :     in (pos1,pos2, dict',params',args)
51 :     end
52 :     (*Create new fractional, and n variables,and returns fresh ids*)
53 :     | SOME (fid,nid)=>(fid,nid, dict,params,args)
54 :     (*end case*))
55 :     end
56 :    
57 :    
58 :    
59 :     fun expandEinProbe(params,body,index,d,args)=(case body
60 :     of E.Probe(E.Conv(id,alpha,kid,deltas),E.Tensor(idt,alphat)) =>
61 :     if(id+1>length params) then (print "not enough params" ;(params,body,index,d,args))
62 :     else (case List.nth(params,id)
63 :     of E.FLD(dim)=>
64 :     let
65 :     val s=2 (*support*)
66 :     val (fid,nid,d',params',args')=Position(idt,d,params,args)
67 :     val shift=length index
68 :    
69 :    
70 :    
71 :     (*createIndex creates summation Index for Index*)
72 :     fun createIndex(0)= []
73 :     | createIndex(dim)=[E.SX(1-s,s)]@ createIndex(dim-1)
74 :    
75 :     (*sumIndex creating summaiton Index for body*)
76 :     fun sumIndex(0)=[]
77 :     |sumIndex(dim)= sumIndex(dim-1)@[E.V (dim+shift-1)]
78 :    
79 :     (*createKRN Image field and kernels *)
80 :     fun createKRN(0,imgpos,rest)=E.Prod ([E.Img(id,alpha,imgpos)] @rest)
81 :     | createKRN(dim,imgpos,rest)=
82 :     let
83 :     val dim'=dim-1
84 :     val sum=dim'+shift
85 :     val dels=createDels(deltas,dim')
86 :     val pos=[E.Add[E.Tensor(fid,[E.C dim']),E.Value(sum)]]
87 :     val rest'= E.Krn(kid,dels,E.Sub(E.Tensor(nid,[E.C dim']),E.Value(sum)))
88 :     in
89 :     createKRN(dim',pos@imgpos,[rest']@rest)
90 :     end
91 :    
92 :    
93 :     val i=createKRN(dim, [],[])
94 :     val esum=sumIndex dim
95 :     val index'=index@ createIndex dim
96 :    
97 :     in (params', E.Sum(esum, i),index', d',args') end
98 :     | _=>(print "err: non field in param spot";(params,E.Const(0.0) ,index,d,args))
99 :     (*end case*))
100 :     |_=>(print "unexpected body" ;(params,body,index,d,args))
101 :     (*end case*))
102 :    
103 :    
104 :    
105 :     (********** trasform to world space***********)
106 :    
107 :     (* for gradients, etc. we have to transform back to world space *)
108 :     (*
109 :     val probeCode = if (*(k > 0)*)
110 :     length(deltas)>0
111 :     then let
112 :     (* for gradients, etc. we have to transform back to world space *)
113 :     val ty = DstV.ty result
114 :     val tensor = DstV.new("tensor", ty)
115 :    
116 :    
117 :    
118 :     val alpha= List.take(dim, length(dim)-1)
119 :     val ilist= [alpha, tl(tensor), hd(tensor)]
120 :     val TensorToWorldSpace=S.tranform(EinOp.innerProduct, ilist,[])
121 :     val xform = assignEin(result, TensorToWorldSpace, [img, tensor])
122 :    
123 :    
124 :     *)
125 :    
126 :    
127 :    
128 :     (*create set of positions
129 :     val dim = ImageInfo.dim v
130 :     val s = Kernel.support h
131 :     val vecsTy =createVec(2*s)
132 :     val vecDimTy = createVec(dim)
133 :     val translate=DstOp.Translate v
134 :     val transform=DstOp.Transform v
135 :    
136 :     (* generate the transform code *)
137 :     val x = DstV.new ("x", vecDimTy) (* image-space position *)
138 :     val f = DstV.new ("f", vecDimTy)
139 :     val nd = DstV.new ("nd", vecDimTy)
140 :     val n = DstV.new ("n", DstTy.iVecTy dim)
141 :     val M = DstV.new ("M", transform)
142 :     val T = DstV.new ("T", translate)
143 :    
144 :     val sub= S.transform(EinOp.subTensor,[dim],[])
145 :    
146 :    
147 :     (* M_ij x_i*)
148 :     val MXop=S.transform(EinOp.innerProduct,[[dim],[],[dim]],[])
149 :     val MX = DstV.new ("MX", MXop)
150 :    
151 :    
152 :    
153 :     val PosToImgSpace=S.transform(EinOp.addTensor,[[dim]],[])
154 :    
155 :    
156 :     val toImgSpaceCode = [
157 :     assignEin(MX, Mxop, [M,pos]), (*Just added*)
158 :     assignEin(x, PosToImgSpace,[MX,T])
159 :     assign(nd, DstOp.Floor dim, [x]),
160 :     assignEin(f, sub,[x,nd]),
161 :     assign(n, DstOp.RealToInt dim, [nd])
162 :     ]
163 :    
164 :     *)
165 :    
166 :     (*copied from high-to-mid.sml*)
167 :     fun expandEinOp ( Ein.EIN{params, index, body}, args) = let
168 :    
169 :     val dummy=E.Const 0.0 (*tmp variables*)
170 :    
171 :     fun rewriteBody exp= let
172 :     val (p,body,ix,d,args')= exp
173 :     in (case body
174 :     of E.Const _=>exp
175 :     | E.Tensor _=>exp
176 :     | E.Krn _=>exp
177 :     | E.Delta _=>exp
178 :     | E.Value _ =>exp
179 :     | E.Epsilon _=>exp
180 :     | E.Partial _=>exp
181 :     | E.Img _=> exp
182 :     | E.Conv _=>(p,dummy,ix,d,args')
183 :     | E.Field _ =>(p,dummy,ix,d,args')
184 :     | E.Apply _ =>(p,dummy,ix,d,args')
185 :     | E.Neg e=> let
186 :     val (p',body',ix',d',args'')=rewriteBody (p,e,ix,d,args')
187 :     in
188 :     (p',E.Neg body',ix',d',args'')
189 :     end
190 :     | E.Sum (c,e)=> let
191 :     val (p',body',ix',d',args'')=rewriteBody (p,e,ix,d,args')
192 :     in
193 :     (p',E.Sum(c,body'),ix',d',args'')
194 :     end
195 :     | E.Probe(E.Conv _, _) =>expandEinProbe exp
196 :     | E.Sub(a,b)=>let
197 :     val(pa,a',ax,da,args'')= rewriteBody (p,a,ix,d,args')
198 :     val(pb,b',bx,db,args''')= rewriteBody (pa,b,ax,da,args'')
199 :     in (pb,E.Sub( a', b'),bx,db,args''')
200 :     end
201 :     | E.Div(a,b)=>let
202 :     val(pa,a',ax,da,args'')= rewriteBody (p,a,ix,d,args')
203 :     val(pb,b',bx,db,args''')= rewriteBody (pa,b,ax,da,args'')
204 :     in (pb,E.Div( a', b'),bx,db,args''')
205 :     end
206 :     | E.Add es=> let
207 :     fun addFilter(p1,ix1,d1,[],done,args')=(p1,E.Add done, ix1,d1,args')
208 :     | addFilter(p1,ix1,d1, e::es,done,args')=let
209 :     val(p2,e2,ix2,d2,args'')= rewriteBody(p1,e,ix1,d1,args')
210 :     in addFilter(p2, ix2, d2, es, done@[e2],args'')
211 :     end
212 :     in
213 :     addFilter(p,ix,d,es,[],args')
214 :     end
215 :     | E.Prod es=> let
216 :     fun addFilter(p1,ix1,d1,[],done,args')=(p1,E.Prod done, ix1,d1,args')
217 :     | addFilter(p1,ix1,d1, e::es,done,args')=let
218 :     val(p2,e2,ix2,d2,args'')= rewriteBody(p1,e,ix1,d1,args')
219 :     in addFilter(p2, ix2, d2, es, done@[e2],args'')
220 :     end
221 :     in
222 :     addFilter(p,ix,d,es,[],args')
223 :     end
224 :     | E.Probe _=> exp
225 :    
226 :     (* end case *))
227 :     end
228 :    
229 :    
230 :     val empty =fn key =>NONE
231 :     val (params',body',ix',_,args')=rewriteBody(params,body,index,empty,args)
232 :     val newbie=Ein.EIN{params=params', index=ix', body=body'}
233 :     in (newbie,args') end
234 :    
235 :     end; (* local *)
236 :    
237 :     end (* local *)

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0