Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] Annotation of /branches/charisee/src/compiler/high-to-mid/ProbeEin.sml
ViewVC logotype

Annotation of /branches/charisee/src/compiler/high-to-mid/ProbeEin.sml

Parent Directory Parent Directory | Revision Log Revision Log


Revision 2680 - (view) (download)

1 : cchiw 2608 (* Currently under construction
2 : cchiw 2606 *
3 :     * COPYRIGHT (c) 2012 The Diderot Project (http://diderot-language.cs.uchicago.edu)
4 :     * All rights reserved.
5 :     *)
6 :    
7 :    
8 :     (*
9 :     A couple of different approaches.
10 :     One approach is to find all the Probe(Conv). Gerenerate exp for it
11 :     Then use Subst function to sub in. That takes care for index matching and
12 :    
13 :     *)
14 :    
15 :     (*This approach creates probe expanded terms, and adds params to the end. *)
16 :    
17 :    
18 :     structure ProbeEin = struct
19 :    
20 :     local
21 :    
22 :     structure E = Ein
23 :     structure mk= mkOperators
24 :     structure SrcIL = HighIL
25 :     structure SrcTy = HighILTypes
26 :     structure SrcOp = HighOps
27 :     structure SrcSV = SrcIL.StateVar
28 :     structure VTbl = SrcIL.Var.Tbl
29 :     structure DstIL = MidIL
30 :     structure DstTy = MidILTypes
31 :     structure DstOp = MidOps
32 :     structure DstV = DstIL.Var
33 :     structure SrcV = SrcIL.Var
34 :     structure P=Printer
35 :     structure shift=ShiftEin
36 :     structure split=SplitEin
37 :     structure F=Filter
38 : cchiw 2611 structure T=TransformEin
39 : cchiw 2606
40 : cchiw 2613 val testing=0
41 : cchiw 2606
42 :     datatype peanut= O of DstOp.rator | E of Ein.ein|C of DstTy.ty|S of int
43 :     datatype peanut2= O2 of SrcOp.rator | E2 of Ein.ein|C2 of SrcTy.ty|S2 of int
44 :     in
45 :    
46 :    
47 :     fun assign (x, rator, args) = (x, DstIL.OP(rator, args))
48 :     fun assignEin (x, rator, args) = ((x, DstIL.EINAPP(rator, args)))
49 :    
50 :     fun getRHS x = (case SrcIL.Var.binding x
51 :     of SrcIL.VB_RHS(SrcIL.OP(rator, args)) => (O2 rator, args)
52 :     | SrcIL.VB_RHS(SrcIL.VAR x') => getRHS x'
53 :     | SrcIL.VB_RHS(SrcIL.EINAPP (e,args))=>(E2 e,args)
54 :     | SrcIL.VB_RHS(SrcIL.CONS (ty,args))=>(C2 ty,args)
55 :     | SrcIL.VB_NONE=>(S2 2,[])
56 :     | vb => raise Fail(concat[
57 :     "expected rhs operator for ", SrcIL.Var.toString x,
58 :     "but found ", SrcIL.vbToString vb])
59 :     (* end case *))
60 :    
61 : cchiw 2680 fun getRHSDst x = (case DstIL.Var.binding x
62 :     of DstIL.VB_RHS(DstIL.OP(rator, args)) => (O rator, args)
63 :     | DstIL.VB_RHS(DstIL.VAR x') => getRHSDst x'
64 :     | DstIL.VB_RHS(DstIL.EINAPP (e,args))=>(E e,args)
65 :     | DstIL.VB_RHS(DstIL.CONS (ty,args))=>(C ty,args)
66 :     | DstIL.VB_NONE=>(S 2,[])
67 :     | vb => raise Fail(concat[
68 :     "expected rhs operator for ", DstIL.Var.toString x,
69 :     "but found ", DstIL.vbToString vb])
70 :     (* end case *))
71 : cchiw 2608
72 :    
73 : cchiw 2680
74 :    
75 :    
76 :    
77 : cchiw 2606 (*Create fractional, and integer position vectors*)
78 : cchiw 2680 fun transformToImgSpace (dim,v,posx,varII)=let
79 : cchiw 2606
80 :     val translate=DstOp.Translate v
81 :     val transform=DstOp.Transform v
82 :     val M = DstV.new ("M", DstTy.tensorTy [dim,dim]) (*transform dim by dim?*)
83 : cchiw 2608
84 : cchiw 2680 val T = DstV.new ("T", DstTy.tensorTy [dim]) (*translate*)
85 : cchiw 2606 val x = DstV.new ("x", DstTy.vecTy dim) (*Image-Space position*)
86 : cchiw 2680
87 :     val xA = DstV.new ("xA", DstTy.vecTy dim) (*Image-Space position*)
88 :     (* val xB = DstV.new ("xB", DstTy.vecTy dim) (*Image-Space position*)*)
89 : cchiw 2606 val f = DstV.new ("f", DstTy.vecTy dim) (*fractional*)
90 :     val nd = DstV.new ("nd", DstTy.vecTy dim) (*real position*)
91 :     val n = DstV.new ("n", DstTy.iVecTy dim) (*integer position*)
92 :     val PosToImgSpace=mk.transform(dim,dim)
93 : cchiw 2680
94 :     val PosToImgSpaceA=mk.transformA(dim,dim)
95 :     val PosToImgSpaceB=mk.transformB(dim,dim)
96 : cchiw 2608 val P = DstV.new ("P", DstTy.tensorTy [dim,dim]) (*transform dim by dim?*)
97 : cchiw 2606
98 :     val code=[
99 : cchiw 2680 assign(M, transform, [varII]),
100 :     assign(T, translate, [varII]),
101 : cchiw 2606 assignEin(x, PosToImgSpace,[M,posx,T]) , (* MX+T*)
102 : cchiw 2680
103 :     (* assignEin(xA, PosToImgSpaceA,[M,posx,T]) , (* MX*)
104 :     assignEin(x, PosToImgSpaceB,[xA,T]) , (* ^+T*)*)
105 : cchiw 2606 assign(nd, DstOp.Floor dim, [x]), (*nd *)
106 :     assignEin(f, mk.subTen([dim]),[x,nd]), (*fractional*)
107 : cchiw 2608 assign(n, DstOp.RealToInt dim, [nd]), (*real to Int*)
108 :     assignEin(P, mk.transpose([dim,dim]), [M])
109 :     ]
110 :     in ([n,f],P,code)
111 : cchiw 2606 end
112 :    
113 :    
114 :     fun replaceH(kvar, place,args)=let
115 :     val l1=List.take(args, place)
116 :     val l2=List.drop(args,place+1)
117 :     in l1@[kvar]@l2 end
118 :    
119 : cchiw 2680 fun getImageSrc x = (case SrcIL.Var.binding x
120 :     of SrcIL.VB_RHS(SrcIL.OP(SrcOp.LoadImage(img),[ivar])) =>
121 :     (print(String.concat["\n MOUSE-TOP::","---",SrcIL.Var.toString(ivar)]))
122 :     | vb => print "not imgae op"
123 :     (* end case *))
124 : cchiw 2606
125 : cchiw 2680 fun getImageDst x = (case DstIL.Var.binding x
126 :     of DstIL.VB_RHS(DstIL.OP(DstOp.LoadImage(img),[ivar])) =>
127 :     (print(String.concat["\n MOUSE-Orig:","---",DstIL.Var.toString(x),"\n MOUSE-BOT:","---",DstIL.Var.toString(ivar)]);ivar)
128 :     | vb => raise Fail "not load op"
129 :     (* end case *))
130 :    
131 :    
132 :    
133 : cchiw 2606 (*Get Img, and Kern Args*)
134 : cchiw 2680 fun getArgs(hid,hArg,V,imgArg,args,lift,varI)=case (getRHS hArg,getRHS imgArg)
135 :     of ((O2(SrcOp.Kernel(h, i)),argK),(O2(SrcOp.LoadImage img),[yy]))=> let
136 : cchiw 2606 val hvar=DstV.new ("KNL", DstTy.KernelTy)
137 :     val imgvar=DstV.new ("IMG", DstTy.ImageTy img)
138 :     val argsVK= (case lift
139 :     of 0=> let
140 : cchiw 2680 val _=print "non lift"
141 : cchiw 2606 val argsN=replaceH(hvar, hid,args)
142 : cchiw 2680 in (*replaceH(imgvar, V,argsN)*) replaceH(varI, V,argsN) end
143 :     | _ =>(* [imgvar, hvar]*) [varI, hvar]
144 : cchiw 2606 (* end case *))
145 : cchiw 2680
146 :     (*val Vimg= assign(imgvar,DstOp.LoadImage img, varI)*)
147 :    
148 :     val assigments=[assign (hvar, DstOp.Kernel(h, i), [])]
149 : cchiw 2606 in
150 :     (Kernel.support h ,img, assigments,argsVK)
151 :     end
152 :     | ((O2(SrcOp.Kernel(h, i)),arg),_)=> raise Fail "Not an img Argument"
153 :     | _ => raise Fail "Not a kernel argument"
154 :    
155 :    
156 : cchiw 2680 fun handleArgs(V,h,t,(params,args),origargs,lift,dstargs)=let
157 : cchiw 2606 val E.IMG(dim)=List.nth(params,V)
158 :     val kArg=List.nth(origargs,h)
159 :     val imgArg=List.nth(origargs,V)
160 :     val newposArg=List.nth(args, t)
161 : cchiw 2680 val imgArgDst=List.nth(dstargs,V)
162 :     val (s,img,argcode,argsVH) =getArgs(h,kArg,V,imgArg,args,lift,imgArgDst)
163 :     val varII =getImageDst imgArgDst
164 :    
165 :     val (argsT,P,code')=transformToImgSpace(dim,img,newposArg,varII)
166 : cchiw 2608 in (dim,argsVH@argsT,argcode@code', s,P)
167 : cchiw 2606 end
168 :    
169 : cchiw 2608
170 : cchiw 2606 (*createDels=> creates the kronecker deltas for each Kernel*)
171 :     fun createDels([],_)= []
172 :     | createDels(d::ds,dim)= [( E.C dim,d)]@createDels(ds,dim)
173 :    
174 :     (*Created new body for probe*)
175 :     fun createBody(dim, s,sx,shape,deltas,V, h, nid, fid)=let
176 :    
177 :     (*sumIndex creating summaiton Index for body*)
178 :     fun sumIndex(0)=[]
179 :     |sumIndex(dim)= sumIndex(dim-1)@[(E.V (dim+sx-1),1-s,s)]
180 :    
181 :     (*createKRN Image field and kernels *)
182 :     fun createKRN(0,imgpos,rest)=E.Prod ([E.Img(V,shape,imgpos)] @rest)
183 :     | createKRN(dim,imgpos,rest)=let
184 :     val dim'=dim-1
185 :     val sum=sx+dim'
186 :     val dels=createDels(deltas,dim')
187 :     val pos=[E.Add[E.Tensor(fid,[E.C dim']),E.Value(sum)]]
188 :     val rest'= E.Krn(h,dels,E.Sub(E.Tensor(nid,[E.C dim']),E.Value(sum)))
189 :     in
190 :     createKRN(dim',pos@imgpos,[rest']@rest)
191 :     end
192 :    
193 :     val exp=createKRN(dim, [],[])
194 :     val esum=sumIndex (dim)
195 :     in E.Sum(esum, exp)
196 :     end
197 :    
198 :    
199 :     fun ShapeConv([],n)=[]
200 :     | ShapeConv(E.C c::es, n)=ShapeConv(es, n)
201 :     | ShapeConv(E.V v::es, n)=
202 :     if(n>v) then [E.V v] @ ShapeConv(es, n)
203 :     else ShapeConv(es,n)
204 :    
205 :    
206 :     fun mapIndex([],_)=[]
207 :     | mapIndex(E.V v::es,index) = [List.nth(index, v)]@ mapIndex(es,index)
208 :     | mapIndex(E.C c::es,index) = mapIndex(es,index)
209 :    
210 :    
211 : cchiw 2680 (*
212 : cchiw 2611 (*Lift probe and Multiply by P*)
213 :     fun liftProbe(E.Probe(E.Conv(V,alpha,H,dx),E.Tensor(t,_)),(params,args),index, sumIndex,origargs)=let
214 :     val _ =print "Lift Probe"
215 : cchiw 2606
216 :     val n=length(index)
217 : cchiw 2611 val ns=length sumIndex
218 :     val nshift=length(dx)
219 :     val np=length(params)
220 :     val nsumshift =(case ns
221 :     of 0=> n
222 :     |_=>let val (E.V v,_,_)=List.nth(sumIndex, ns-1)
223 :     in
224 :     v+1
225 :     end
226 :     (* end case *))
227 : cchiw 2606
228 :    
229 : cchiw 2611 (*Outer Index-id Of Probe*)
230 :     val VShape=ShapeConv(alpha, n)
231 :     val HShape=ShapeConv(dx, n)
232 :     val shape=VShape@HShape
233 : cchiw 2606
234 : cchiw 2611 (* Bindings for Shape*)
235 :     val shapebind= mapIndex(shape,index)
236 :     val Vshapebind= mapIndex(VShape,index)
237 : cchiw 2606
238 : cchiw 2611 (*Look at Args and get dim, mid-il ops, support, and Arg for transformation matrix P*)
239 :     val (dim,args',code,support,PArg) = handleArgs(V,H,t,(params,args), origargs,1)
240 : cchiw 2608
241 : cchiw 2611 (*New transformations:params, sx, rest, will be empty if no transformation is made*)
242 :     val (oldArg,newArg,dx, paramsT,sxT,restT,ixT,dataT) = T.Transform(dx,shapebind,Vshapebind,dim,PArg,nsumshift,ns,4)
243 : cchiw 2606
244 : cchiw 2611 (*rewriteBody*)
245 :     val bodyExpanded = createBody(dim, support,nsumshift+nshift,alpha,dx,0, 1, 3, 2)
246 : cchiw 2608
247 : cchiw 2611 val sx=sumIndex@sxT
248 :     val body'=(case sx
249 :     of [] =>E.Prod(restT@[bodyExpanded])
250 :     | _ => E.Sum(sx, E.Prod(restT@[bodyExpanded]))
251 :     (*end case*))
252 : cchiw 2606
253 : cchiw 2611 (*create new EIN OPerator*)
254 :     val _ =print("Found this many args ")
255 :     val _ =print(Int.toString(length(args')))
256 :    
257 :     val params'=[E.IMG(dim),E.KRN,E.TEN(3,[dim]),E.TEN(1,[dim])]@paramsT
258 :     val (p',i',b',a')=shift.clean(params', index@ixT, body', args'@[PArg])
259 : cchiw 2606 val newbie'=Ein.EIN{params=p', index=i', body=b'}
260 : cchiw 2608 val data=assignEin (oldArg, newbie', a')
261 :    
262 : cchiw 2606 val _ = (case testing
263 :     of 0 => 1
264 : cchiw 2611 | _ => (print(String.concat["\n Lift Probe\n", split.printA(newArg, newbie', a'),"\n"]);1)
265 :     (*end case *))
266 :     in
267 :     (E.Tensor(np,shape), (params@[E.TEN(1,shapebind)],args@[newArg]),code@[data]@dataT)
268 : cchiw 2606 end
269 : cchiw 2615 |liftProbe _ =raise Fail"Incorrect body for Probe"
270 : cchiw 2608
271 : cchiw 2680 *)
272 : cchiw 2608
273 :    
274 : cchiw 2680
275 : cchiw 2608 (*Does not yet do transformation*)
276 : cchiw 2606 (* Expand probe in place *)
277 : cchiw 2680 fun replaceProbe(b,(params,args),index, sumIndex,origargs,dstargs)=let
278 : cchiw 2606
279 : cchiw 2608 val E.Probe(E.Conv(V,alpha,h,dx),E.Tensor(t,_))=b
280 : cchiw 2606 val fid=length(params)
281 : cchiw 2611 val nid=fid+1
282 : cchiw 2606 val n=length(index)
283 :     val ns=length sumIndex
284 : cchiw 2611 val nshift=length(dx)
285 :     val nsumshift =(case ns
286 :     of 0=> n
287 :     | _=>let
288 : cchiw 2606 val (E.V v,_,_)=List.nth(sumIndex, ns-1)
289 : cchiw 2611 in v+1
290 : cchiw 2606 end
291 : cchiw 2611 (* end case *))
292 :    
293 :     (*Outer Index-id Of Probe*)
294 :     val VShape=ShapeConv(alpha, n)
295 :     val HShape=ShapeConv(dx, n)
296 :     val shape=VShape@HShape
297 :     (* Bindings for Shape*)
298 :     val shapebind= mapIndex(shape,index)
299 :     val Vshapebind= mapIndex(VShape,index)
300 :    
301 :    
302 : cchiw 2680 val (dim,argsA,code,s,PArg) = handleArgs(V,h,t,(params,args), origargs,0,dstargs)
303 : cchiw 2611 val (_,_,dx, _,sxT,restT,_,_) = T.Transform(dx,shapebind,Vshapebind,dim,PArg,nsumshift,1,nid+1)
304 :    
305 :     val params'=params@[E.TEN(3,[dim]),E.TEN(1,[dim]),E.TEN(1,[dim,dim])]
306 :     val body'' = createBody(dim, s,nsumshift+nshift,alpha,dx,V, h, nid, fid)
307 :     val body' =(case nshift
308 :     of 0=> body''
309 :     | _ => E.Sum(sxT, E.Prod(restT@[body'']))
310 :     (*end case*))
311 :     val args'=argsA@[PArg]
312 : cchiw 2606 val _ =(case testing
313 :     of 0=> 1
314 :     | _ => let
315 :     val subexp=Ein.EIN{params=params', index=index, body=body'}
316 :     val _= print(String.concat["\n Don't replace probe \n $$$ new sub-expression $$$ \n",P.printerE(subexp),"\n"])
317 :     in 1 end
318 :     (* end case *))
319 :    
320 : cchiw 2611
321 : cchiw 2606 in (body',(params',args') ,code)
322 :     end
323 :    
324 : cchiw 2680 (*
325 : cchiw 2612 (*Checks if (1) Summation variable occurs just once (2) it matches n.
326 :     Then we lift otherwise expand in place *)
327 :     fun checkSum(sx,b,info,index,origargs)=(case sx
328 :     of [(E.V i,lb,ub)]=> let
329 :     val E.Probe(E.Conv(V,alpha,h,dx), E.Tensor(id,beta))=b
330 :     val n=length(index)
331 :     val _=(case testing
332 :     of 1=> (print(String.concat["in check Sum\n " ,P.printbody(E.Sum([(E.V i,lb,ub)],b))]);1)
333 :     |_ => 1)
334 :     in
335 :     if (i=n) then (case F.countSx(sx,b)
336 :     of (1,ixx) => liftProbe(b,info,index@[ub], [],origargs)
337 :     | _ => replaceProbe(b, info,index,sx,origargs)
338 :     (*end case*))
339 :     else replaceProbe(b, info,index, sx,origargs)
340 :     end
341 :     | _ =>replaceProbe(b, info,index, sx,origargs)
342 :     (*end case*))
343 : cchiw 2680 *)
344 : cchiw 2611
345 : cchiw 2606 fun flatten []=[]
346 :     | flatten(e1::es)=e1@(flatten es)
347 :    
348 :    
349 :     (* sx-[] then move out, otherwise keep in *)
350 :     fun expandEinOp ( Ein.EIN{params, index, body}, origargs,args) = let
351 :    
352 :     val dummy=E.Const 0
353 :     val sumIndex=ref []
354 :    
355 :     (*b-current body, info-original ein op, data-new assigments*)
356 :     fun rewriteBody(b,info)= let
357 :    
358 :     fun callfn(c1,body)=let
359 :     val ref x=sumIndex
360 :     val c'=[c1]@x
361 :     val (bodyK,infoK,dataK)= (sumIndex:=c';rewriteBody(body ,info))
362 :     val ref s=sumIndex
363 :     val z=hd(s)
364 :     val e'=( case bodyK
365 :     of E.Const _ =>bodyK
366 :     | _ => E.Sum(z,bodyK)
367 :     (*end case*))
368 :     in
369 :     (sumIndex:=tl(s);(e',infoK,dataK))
370 :     end
371 : cchiw 2611
372 : cchiw 2671
373 :     (*Nothing liftProbe and checkSum are commented out.
374 :     Some mistake underestimating size of dimension*)
375 :    
376 : cchiw 2611 fun filter es=let
377 :     fun filterApply([], doneB, infoB, dataB)= (doneB, infoB,dataB)
378 :     | filterApply(B::es, doneA, infoA,dataA)= let
379 :     val (bodyB, infoB,dataB)= rewriteBody(B,infoA)
380 :     in
381 :     filterApply(es, doneA@[bodyB], infoB,dataA@dataB)
382 :     end
383 :     in filterApply(es, [],info,[])
384 :     end
385 : cchiw 2606 in (case b
386 :     of E.Sum(c, E.Probe(E.Conv v, E.Tensor t)) =>let
387 :     val ref sx=sumIndex
388 : cchiw 2611 in (case sx
389 : cchiw 2671 of (* [] => liftProbe(E.Probe(E.Conv v, E.Tensor t ), info,index, c,origargs)
390 : cchiw 2611 | [i]=> checkSum(i,b, info,index,origargs)
391 : cchiw 2671 |*) _ => let
392 : cchiw 2680 val (b,m,code)=replaceProbe(E.Probe(E.Conv v, E.Tensor t ), info,index, (flatten sx)@c,origargs,args)
393 : cchiw 2611 in (E.Sum(c,b),m,code)
394 :     end
395 : cchiw 2606 (* end case*))
396 :     end
397 :     | E.Probe(E.Conv _, E.Tensor _) =>let
398 :     val ref sx=sumIndex
399 :     in (case sx
400 : cchiw 2671 of (* []=> liftProbe(b, info,index, [],origargs)
401 : cchiw 2611 | [i]=> checkSum(i,b, info,index,origargs)
402 : cchiw 2680 |*) _ => replaceProbe(b, info,index, flatten sx,origargs,args)
403 : cchiw 2606 (* end case*))
404 :     end
405 :     | E.Probe _=> (dummy,info,[])
406 :     | E.Conv _=> (dummy,info,[])
407 :     | E.Lift _=> (dummy,info,[])
408 :     | E.Field _ => (dummy,info,[])
409 :     | E.Apply _ => (dummy,info,[])
410 :     | E.Neg e=> let
411 :     val (body',info',data')=rewriteBody(e,info)
412 :     in
413 :     (E.Neg(body'),info',data')
414 :     end
415 :     | E.Sum (c,e)=> callfn(c,e)
416 :     | E.Sub(a,b)=>let
417 :     val (bodyA,infoA,dataA)= rewriteBody(a,info)
418 :     val (bodyB, infoB, dataB)= rewriteBody(b,infoA)
419 :     in (E.Sub(bodyA, bodyB),infoB,dataA@dataB)
420 :     end
421 :     | E.Div(a,b)=>let
422 :     val (bodyA,infoA,dataA)= rewriteBody(a,info)
423 :     val (bodyB, infoB,dataB)= rewriteBody(b,infoA)
424 :     in (E.Div(bodyA, bodyB),infoB,dataA@dataB) end
425 :     | E.Add es=> let
426 : cchiw 2611 val (done, info',data')= filter es
427 :     val (_, e)=F.mkAdd done
428 :     in (e, info',data')
429 :     end
430 : cchiw 2606 | E.Prod es=> let
431 : cchiw 2611 val (done, info',data')= filter es
432 :     val (_, e)=F.mkProd done
433 :     in (e, info',data')
434 :     end
435 : cchiw 2606 | _=> (b,info,[])
436 :     (* end case *))
437 :     end
438 :    
439 :     val empty =fn key =>NONE
440 :     val _ =(case testing
441 :     of 0 => 1
442 :     | _ => (print "\n ************************** \n Starting Expand";1)
443 :     (*end case*))
444 :    
445 :     val (body',(params',args'),newbies)=rewriteBody(body,(params,args))
446 :     val e'=Ein.EIN{params=params', index=index, body=body'}
447 : cchiw 2608 (*val _ =(case testing
448 : cchiw 2606 of 0 => 1
449 :     | _ => (String.concat[P.printerE(e'),"\n DONE expand ************************** \n "];1)
450 : cchiw 2608 (*end case*))*)
451 : cchiw 2606 in
452 :     ((e',args'),newbies)
453 :     end
454 :    
455 :     end; (* local *)
456 :    
457 :     end (* local *)

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0