Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] Annotation of /branches/charisee/src/compiler/high-to-mid/expand-integrate.sml
ViewVC logotype

Annotation of /branches/charisee/src/compiler/high-to-mid/expand-integrate.sml

Parent Directory Parent Directory | Revision Log Revision Log


Revision 2603 - (view) (download)

1 : cchiw 2498 (* examples.sml
2 :     *
3 :     * COPYRIGHT (c) 2012 The Diderot Project (http://diderot-language.cs.uchicago.edu)
4 :     * All rights reserved.
5 :     *)
6 :    
7 :    
8 : cchiw 2510 (*
9 :     A couple of different approaches.
10 :     One approach is to find all the Probe(Conv). Gerenerate exp for it
11 :     Then use Subst function to sub in. That takes care for index matching and
12 :    
13 :     *)
14 :    
15 :     (*This approach creates probe expanded terms, and adds params to the end. *)
16 :    
17 :    
18 : cchiw 2498 structure Expand = struct
19 :    
20 :     local
21 :    
22 :     structure E = Ein
23 : cchiw 2510 structure mk= mkOperators
24 : cchiw 2498
25 : cchiw 2510
26 :     structure SrcIL = HighIL
27 :     structure SrcTy = HighILTypes
28 :     structure SrcOp = HighOps
29 :     structure SrcSV = SrcIL.StateVar
30 :     structure VTbl = SrcIL.Var.Tbl
31 :     structure DstIL = MidIL
32 :     structure DstTy = MidILTypes
33 :     structure DstOp = MidOps
34 : cchiw 2515 structure DstV = DstIL.Var
35 :     structure SrcV = SrcIL.Var
36 : cchiw 2510 structure P=Printer
37 : cchiw 2555 structure shiftHtM=shiftHtM
38 : cchiw 2584 structure split=splitHtM
39 : cchiw 2603 structure F=Filter
40 : cchiw 2502
41 : cchiw 2515 datatype peanut= O of DstOp.rator | E of Ein.ein|C of DstTy.ty|S of int
42 : cchiw 2522 datatype peanut2= O2 of SrcOp.rator | E2 of Ein.ein|C2 of SrcTy.ty|S2 of int
43 : cchiw 2498 in
44 :    
45 :    
46 : cchiw 2502 fun assign (x, rator, args) = (x, DstIL.OP(rator, args))
47 : cchiw 2584 fun assignEin (x, rator, args) = ((x, DstIL.EINAPP(rator, args)))
48 : cchiw 2510
49 :    
50 : cchiw 2498
51 : cchiw 2499
52 : cchiw 2510
53 : cchiw 2603
54 : cchiw 2522 fun getRHS2 x = (case SrcIL.Var.binding x
55 :     of SrcIL.VB_RHS(SrcIL.OP(rator, args)) => (O2 rator, args)
56 :     | SrcIL.VB_RHS(SrcIL.VAR x') => getRHS2 x'
57 :     | SrcIL.VB_RHS(SrcIL.EINAPP (e,args))=>(E2 e,args)
58 :     | SrcIL.VB_RHS(SrcIL.CONS (ty,args))=>(C2 ty,args)
59 :     | SrcIL.VB_NONE=>(S2 2,[])
60 :     | vb => raise Fail(concat[
61 :     "expected rhs operator for ", SrcIL.Var.toString x,
62 :     "but found ", SrcIL.vbToString vb])
63 :     (* end case *))
64 : cchiw 2510
65 :    
66 : cchiw 2515
67 : cchiw 2522
68 :    
69 : cchiw 2502 (*Create fractional, and integer position vectors*)
70 : cchiw 2585 fun transformToImgSpace (dim,v,posx)=let
71 : cchiw 2584
72 : cchiw 2522 val translate=DstOp.Translate v
73 :     val transform=DstOp.Transform v
74 :     val M = DstV.new ("M", DstTy.tensorTy [dim,dim]) (*transform dim by dim?*)
75 : cchiw 2523 val T = DstV.new ("T", DstTy.tensorTy [dim,dim]) (*translate*)
76 : cchiw 2522 val x = DstV.new ("x", DstTy.vecTy dim)
77 :     val f = DstV.new ("f", DstTy.vecTy dim) (*fractional*)
78 :     val nd = DstV.new ("nd", DstTy.vecTy dim) (*real position*)
79 :     val n = DstV.new ("n", DstTy.iVecTy dim) (*interger position*)
80 : cchiw 2525
81 : cchiw 2522
82 :     val PosToImgSpace=mk.transform(dim,dim)
83 :     val code=[
84 :     assign(M, transform, []),
85 :     assign(T, translate, []),
86 :    
87 :     assignEin(x, PosToImgSpace,[M,posx,T]) , (* MX+T*)
88 :     assign(nd, DstOp.Floor dim, [x]), (*nd *)
89 :     assignEin(f, mk.subTen([dim]),[x,nd]), (*fractional*)
90 :     assign(n, DstOp.RealToInt dim, [nd]) (*real to Int*)
91 :     ]
92 :    
93 : cchiw 2584 in ([n,f],code)
94 : cchiw 2522 end
95 :    
96 : cchiw 2585 fun replaceH(kvar, place,args)=let
97 :     val l1=List.take(args, place)
98 :     val l2=List.drop(args,place+1)
99 :     in l1@[kvar]@l2 end
100 : cchiw 2522
101 : cchiw 2525 (*Get Img, and Kern Args*)
102 : cchiw 2585 fun getArgs(hid,hArg,V,imgArg,args,lift)=
103 : cchiw 2525 case (getRHS2 hArg,getRHS2 imgArg)
104 :     of ((O2(SrcOp.Kernel(h, i)),arg),(O2(SrcOp.LoadImage img),_))=> let
105 :     val hvar=DstV.new ("KNL", DstTy.KernelTy)
106 :     val imgvar=DstV.new ("IMG", DstTy.ImageTy img)
107 : cchiw 2585 val args2= (case lift
108 :     of 0=> let
109 :     val args1=replaceH(hvar, hid,args)
110 :     in replaceH(imgvar, V,args1) end
111 :     | _ => [imgvar, hvar]
112 :     (* end case *))
113 :     in
114 : cchiw 2525 (Kernel.support h ,img, [assign (hvar, DstOp.Kernel(h, i), []), assign(imgvar,DstOp.LoadImage img,[])],args2)
115 :     end
116 :     | ((O2(SrcOp.Kernel(h, i)),arg),_)=> raise Fail "Not an img Argument"
117 :     | _ => raise Fail "Not a kernel argument"
118 : cchiw 2522
119 :    
120 :    
121 : cchiw 2585 fun handleArgs(E.Probe(E.Conv(V,shape,h, deltas),E.Tensor(t,alpha)),(params,args),origargs,lift)=let
122 :     val E.IMG(dim)=List.nth(params,V)
123 :     val kArg=List.nth(origargs,h)
124 :     val imgArg=List.nth(origargs,V)
125 :     val newposArg=List.nth(args, t)
126 :     val (s,img,argcode,argsVH) =getArgs(h,kArg,V,imgArg,args,lift)
127 :     val (argsT,code')=transformToImgSpace(dim,img,newposArg)
128 :     in (dim,argsVH@argsT,argcode@code', s)
129 :     end
130 : cchiw 2510
131 : cchiw 2522
132 : cchiw 2585
133 :     (*createDels=> creates the kronecker deltas for each Kernel*)
134 :     fun createDels([],_)= []
135 :     | createDels(d::ds,dim)= [( E.C dim,d)]@createDels(ds,dim)
136 :    
137 :    
138 :     (*Created new body for probe*)
139 : cchiw 2584 fun createBody(dim, s,sx,shape,deltas,V, h, nid, fid)=let
140 : cchiw 2515
141 : cchiw 2555 (*sumIndex creating summaiton Index for body*)
142 :     fun sumIndex(0)=[]
143 : cchiw 2584 |sumIndex(dim)= sumIndex(dim-1)@[(E.V (dim+sx-1),1-s,s)]
144 : cchiw 2498
145 : cchiw 2555 (*createKRN Image field and kernels *)
146 :     fun createKRN(0,imgpos,rest)=E.Prod ([E.Img(V,shape,imgpos)] @rest)
147 : cchiw 2584 | createKRN(dim,imgpos,rest)=let
148 : cchiw 2525
149 : cchiw 2555 val dim'=dim-1
150 :     val sum=sx+dim'
151 :     val dels=createDels(deltas,dim')
152 :     val pos=[E.Add[E.Tensor(fid,[E.C dim']),E.Value(sum)]]
153 :     val rest'= E.Krn(h,dels,E.Sub(E.Tensor(nid,[E.C dim']),E.Value(sum)))
154 :     in
155 :     createKRN(dim',pos@imgpos,[rest']@rest)
156 :     end
157 : cchiw 2498
158 : cchiw 2555 val exp=createKRN(dim, [],[])
159 :     val esum=sumIndex (dim)
160 : cchiw 2584 in E.Sum(esum, exp)
161 :     end
162 : cchiw 2522
163 : cchiw 2555 fun ShapeConv([],n)=[]
164 :     | ShapeConv(E.C c::es, n)=ShapeConv(es, n)
165 :     | ShapeConv(E.V v::es, n)=
166 :     if(n>v) then [E.V v] @ ShapeConv(es, n)
167 :     else ShapeConv(es,n)
168 :    
169 :    
170 : cchiw 2584 (*Lift probe*)
171 : cchiw 2585 fun liftProbe(b,(params,args),index, sumIndex,origargs)=let
172 : cchiw 2584 val _=print "\n Lift Probe\n"
173 : cchiw 2555 val E.Probe(E.Conv(_,alpha,_,dx),pos)=b
174 :    
175 :     val newId=length(params)
176 :     val n=length(index)
177 :    
178 :     (*Create new tensor replacement*)
179 :     val shape=ShapeConv(alpha@dx, n)
180 :     val newB=E.Tensor(newId,shape)
181 :    
182 :     (* Create new Param*)
183 :     val shape'=List.map (fn E.V v=>(List.nth(index, v))) shape
184 :     val newP= E.TEN(1,shape')
185 :    
186 :     (*Create new Arg*)
187 :     val newArg = DstV.new ("PC", DstTy.tensorTy shape')
188 :    
189 :     (*Expand Probe*)
190 :     val ns=length sumIndex
191 : cchiw 2585
192 :     val (dim,args',code,s) = handleArgs(b,(params,args), origargs,1)
193 :    
194 :     val body' =(case ns
195 :     of 0=> createBody(dim, s,n,alpha,dx,0, 1, 3, 2)
196 : cchiw 2584 |_=>let
197 : cchiw 2555 val (E.V v,_,_)=List.nth(sumIndex, ns-1)
198 : cchiw 2585 val body'=createBody(dim, s,v+1,alpha,dx,0, 1, 3, 2)
199 :     in E.Sum(sumIndex ,body')
200 : cchiw 2555 end
201 :     (* end case *))
202 :    
203 : cchiw 2585 val params'=[E.IMG(dim),E.KRN,E.TEN(3,[dim]),E.TEN(1,[dim])]
204 : cchiw 2555 val (p',i',b',a')=shiftHtM.clean(params', index,body', args')
205 : cchiw 2585
206 : cchiw 2555 val newbie'=Ein.EIN{params=p', index=i', body=b'}
207 : cchiw 2585 val _ = print(String.concat["\n ", split.printA(newArg, newbie', a'),"\n"])
208 : cchiw 2555 val data=assignEin (newArg, newbie', a')
209 :     in (newB, (params@[newP],args@[newArg]) ,code@[data])
210 :     end
211 : cchiw 2584
212 :    
213 :     (* Expand probe in place *)
214 : cchiw 2585 fun replaceProbe(b,(params,args),index, sumIndex,origargs)=let
215 : cchiw 2584 val _=print "\n Don't replace probe \n"
216 : cchiw 2585 val E.Probe(E.Conv(V,alpha,h,dx),pos)=b
217 : cchiw 2584
218 : cchiw 2585 val fid=length(params)
219 : cchiw 2584 val n=length(index)
220 : cchiw 2585
221 : cchiw 2584 (*Expand Probe*)
222 :     val ns=length sumIndex
223 : cchiw 2585
224 :     val (dim,args',code,s) = handleArgs(b,(params,args), origargs,0)
225 :     val nid=fid+1
226 :    
227 :    
228 :     val body' =(case ns
229 :     of 0=> createBody(dim, s,n,alpha,dx,V, h, nid, fid)
230 : cchiw 2584 |_=>let
231 :     val (E.V v,_,_)=List.nth(sumIndex, ns-1)
232 : cchiw 2585 in createBody(dim, s,v+1,alpha,dx,V, h, nid, fid)
233 : cchiw 2584 end
234 :     (* end case *))
235 : cchiw 2555
236 : cchiw 2585 val params'=params@[E.TEN(3,[dim]),E.TEN(1,[dim])]
237 :     val UUU=Ein.EIN{params=params', index=index, body=body'}
238 :     val _ =print(String.concat["\n $$$ new sub-expression $$$ \n",P.printerE(UUU),"\n"])
239 : cchiw 2555
240 : cchiw 2584 in (body',(params',args') ,code)
241 : cchiw 2585 end
242 : cchiw 2584
243 : cchiw 2555
244 : cchiw 2584 fun flatten []=[]
245 :     | flatten(e1::es)=e1@(flatten es)
246 :    
247 :    
248 :     (* sx-[] then move out, otherwise keep in *)
249 : cchiw 2522 fun expandEinOp ( Ein.EIN{params, index, body}, origargs,args) = let
250 : cchiw 2498
251 : cchiw 2555 val dummy=E.Const 0
252 :     val sumIndex=ref []
253 :    
254 : cchiw 2603
255 :    
256 : cchiw 2555 (*b-current body, info-original ein op, data-new assigments*)
257 :     fun rewriteBody(b,info)= let
258 : cchiw 2603
259 :     fun callfn(c1,body)=let
260 :     val ref x=sumIndex
261 :     val c'=[c1]@x
262 :     val (bodyK,infoK,dataK)= (sumIndex:=c';rewriteBody(body ,info))
263 :     val ref s=sumIndex
264 :     val z=hd(s)
265 :     in
266 :     (sumIndex:=tl(s);(E.Sum(z,bodyK),infoK,dataK))
267 :     end
268 :    
269 :     (*val t=print(String.concat["\n\n ***:",Printer.printbody(b),"$ \n"])*)
270 : cchiw 2555 in (case b
271 :     of E.Sum(c, E.Probe(E.Conv v, E.Tensor t)) =>let
272 :     val ref sx=sumIndex
273 : cchiw 2584 in (case sx
274 : cchiw 2585 of [] => liftProbe(E.Probe(E.Conv v, E.Tensor t ), info,index, c,origargs)
275 :     | _ => replaceProbe(E.Probe(E.Conv v, E.Tensor t ), info,index, (flatten sx)@c,origargs)
276 : cchiw 2584 (* end case*))
277 : cchiw 2555 end
278 :     | E.Probe(E.Conv _, E.Tensor _) =>let
279 :     val ref sx=sumIndex
280 : cchiw 2584 in (case sx
281 : cchiw 2585 of []=> liftProbe(b, info,index, [],origargs)
282 :     | _=> replaceProbe(b, info,index, flatten sx,origargs)
283 : cchiw 2584 (* end case*))
284 : cchiw 2555 end
285 :     | E.Probe _=> (dummy,info,[])
286 :     | E.Conv _=> (dummy,info,[])
287 : cchiw 2603 | E.Lift _=> (dummy,info,[])
288 : cchiw 2555 | E.Field _ => (dummy,info,[])
289 :     | E.Apply _ => (dummy,info,[])
290 : cchiw 2498 | E.Neg e=> let
291 : cchiw 2555 val (body',info',data')=rewriteBody(e,info)
292 : cchiw 2498 in
293 : cchiw 2555 (E.Neg(body'),info',data')
294 : cchiw 2498 end
295 : cchiw 2603 | E.Sum([c1],E.Prod p1) => (case F.pushSum2(c1,p1)
296 :     of ([],[])=>raise Fail"Empty summation product"
297 :     | ([],[keep])=> callfn([c1],keep)
298 :     | ([],keep)=> callfn([c1],E.Prod(keep))
299 :     | ([s], [])=> rewriteBody(s ,info)
300 :     | (s, [])=> rewriteBody(E.Prod s ,info)
301 :     | (s,keep)=>let
302 :     val (bodyS,infoS,dataS)= rewriteBody(E.Prod s ,info)
303 :     val ref x=sumIndex
304 :     val c'=[[c1]]@x
305 :     val (bodyK,infoK,dataK)= (sumIndex:=c';rewriteBody(E.Prod keep ,infoS))
306 :     val ref s=sumIndex
307 :     val z=hd(s)
308 :     in (sumIndex:=tl(s);(E.Prod[bodyS,E.Sum(z,bodyK)],infoK, dataS@dataK))
309 :     end
310 :     (*end case*))
311 :     | E.Sum([(v, lb, ub)],e1)=>(case F.foundSx(v,e1)
312 :     of NONE => rewriteBody(e1, info)
313 :     | SOME _ => callfn([(v, lb, ub)],e1)
314 :     (* end case *))
315 :     | E.Sum (c,e)=> callfn(c,e)
316 : cchiw 2498 | E.Sub(a,b)=>let
317 : cchiw 2555 val (bodyA,infoA,dataA)= rewriteBody(a,info)
318 :     val (bodyB, infoB, dataB)= rewriteBody(b,infoA)
319 :     in (E.Sub(bodyA, bodyB),infoB,dataA@dataB)
320 : cchiw 2498 end
321 :     | E.Div(a,b)=>let
322 : cchiw 2555 val (bodyA,infoA,dataA)= rewriteBody(a,info)
323 :     val (bodyB, infoB,dataB)= rewriteBody(b,infoA)
324 :     in (E.Div(bodyA, bodyB),infoB,dataA@dataB) end
325 : cchiw 2498 | E.Add es=> let
326 : cchiw 2555 fun filter([], done, info', data)= (E.Add done, info',data)
327 :     | filter(e::es, done, info',data)= let
328 :     val (body', info'',data')= rewriteBody(e,info')
329 :     in filter(es, done@[body'], info'',data@data') end
330 :     in filter(es, [],info,[]) end
331 : cchiw 2510
332 : cchiw 2498 | E.Prod es=> let
333 : cchiw 2555 fun filter([], done, info',data)= (E.Prod done,info', data)
334 :     | filter(e::es, done, info',data)= let
335 :     val (body', info'',data')= rewriteBody(e, info')
336 :     in filter(es, done@[body'], info'',data@data') end
337 :     in filter(es, [],info,[]) end
338 :     | _=> (b,info,[])
339 : cchiw 2498 (* end case *))
340 : cchiw 2510 end
341 : cchiw 2498
342 : cchiw 2510 val empty =fn key =>NONE
343 : cchiw 2584 val mm=print "\n ************************** \n Starting Exapnd"
344 : cchiw 2555 val (body',(params',args'),newbies)=rewriteBody(body,(params,args))
345 :     val e'=Ein.EIN{params=params', index=index, body=body'}
346 : cchiw 2584
347 :     val rr=print (String.concat[P.printerE(e'),"\n DONE expand ************************** \n "])
348 : cchiw 2555 in ((e',args'),newbies) end
349 : cchiw 2498
350 :     end; (* local *)
351 :    
352 :     end (* local *)

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0