Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] Annotation of /branches/charisee/src/compiler/high-to-mid/ProbeEin.sml
ViewVC logotype

Annotation of /branches/charisee/src/compiler/high-to-mid/ProbeEin.sml

Parent Directory Parent Directory | Revision Log Revision Log


Revision 2830 - (view) (download)

1 : cchiw 2608 (* Currently under construction
2 : cchiw 2606 *
3 :     * COPYRIGHT (c) 2012 The Diderot Project (http://diderot-language.cs.uchicago.edu)
4 :     * All rights reserved.
5 :     *)
6 :    
7 :    
8 :     (*
9 :     A couple of different approaches.
10 :     One approach is to find all the Probe(Conv). Gerenerate exp for it
11 :     Then use Subst function to sub in. That takes care for index matching and
12 :    
13 :     *)
14 :    
15 :     (*This approach creates probe expanded terms, and adds params to the end. *)
16 :    
17 :    
18 :     structure ProbeEin = struct
19 :    
20 :     local
21 :    
22 :     structure E = Ein
23 :     structure mk= mkOperators
24 :     structure SrcIL = HighIL
25 :     structure SrcTy = HighILTypes
26 :     structure SrcOp = HighOps
27 :     structure SrcSV = SrcIL.StateVar
28 :     structure VTbl = SrcIL.Var.Tbl
29 :     structure DstIL = MidIL
30 :     structure DstTy = MidILTypes
31 :     structure DstOp = MidOps
32 :     structure DstV = DstIL.Var
33 :     structure SrcV = SrcIL.Var
34 :     structure P=Printer
35 :     structure shift=ShiftEin
36 :     structure split=SplitEin
37 :     structure F=Filter
38 : cchiw 2611 structure T=TransformEin
39 : cchiw 2606
40 : cchiw 2827 val testing=0
41 : cchiw 2606
42 : cchiw 2827
43 : cchiw 2606 in
44 :    
45 :    
46 :     fun assign (x, rator, args) = (x, DstIL.OP(rator, args))
47 :     fun assignEin (x, rator, args) = ((x, DstIL.EINAPP(rator, args)))
48 : cchiw 2829 fun testp n=(case testing
49 :     of 0=> 1
50 :     | _ =>(print(String.concat n);1)
51 :     (*end case*))
52 : cchiw 2606
53 : cchiw 2830 (*transform image-space position x to world space position*)
54 :     fun WorldToImagespace(dim,v,posx,imgArgDst)=let
55 : cchiw 2606 val translate=DstOp.Translate v
56 :     val transform=DstOp.Transform v
57 :     val M = DstV.new ("M", DstTy.tensorTy [dim,dim]) (*transform dim by dim?*)
58 : cchiw 2680 val T = DstV.new ("T", DstTy.tensorTy [dim]) (*translate*)
59 : cchiw 2606 val x = DstV.new ("x", DstTy.vecTy dim) (*Image-Space position*)
60 : cchiw 2829 val x0 = DstV.new ("x0", DstTy.vecTy dim)
61 :     val PosToImgSpaceA=mk.transformA(dim,dim)
62 :     val PosToImgSpaceB=mk.transformB dim
63 : cchiw 2606 val code=[
64 : cchiw 2827 assign(M, transform, [imgArgDst]),
65 :     assign(T, translate, [imgArgDst]),
66 : cchiw 2830 assignEin(x0, PosToImgSpaceA,[M,posx]) , (*xo=MX*)
67 :     assignEin(x, PosToImgSpaceB,[x0,T]) (*x=x0+T*)
68 :     ]
69 :     in (M,x,code)
70 :     end
71 : cchiw 2829
72 : cchiw 2830
73 :     (*Create fractional, and integer position vectors*)
74 :     fun transformToImgSpace (dim,v,posx,imgArgDst)=let
75 :    
76 :     val f = DstV.new ("f", DstTy.vecTy dim) (*fractional*)
77 :     val nd = DstV.new ("nd", DstTy.vecTy dim) (*real position*)
78 :     val n = DstV.new ("n", DstTy.iVecTy dim) (*integer position*)
79 :     val P = DstV.new ("P", DstTy.tensorTy [dim,dim]) (*transform dim by dim?*)
80 :    
81 :     val (M,x,code1)=WorldToImagespace(dim,v,posx,imgArgDst)
82 :     val code=[
83 : cchiw 2606 assign(nd, DstOp.Floor dim, [x]), (*nd *)
84 :     assignEin(f, mk.subTen([dim]),[x,nd]), (*fractional*)
85 : cchiw 2608 assign(n, DstOp.RealToInt dim, [nd]), (*real to Int*)
86 :     assignEin(P, mk.transpose([dim,dim]), [M])
87 :     ]
88 : cchiw 2830 in ([n,f],P,code1@code)
89 : cchiw 2606 end
90 :    
91 : cchiw 2827 fun getRHS x = (case SrcIL.Var.binding x
92 :     of SrcIL.VB_RHS(SrcIL.OP(rator, args)) => (rator, args)
93 :     | SrcIL.VB_RHS(SrcIL.VAR x') => getRHS x'
94 :     | vb => raise Fail(concat[ "expected rhs operator for ", SrcIL.Var.toString x, "but found ", SrcIL.vbToString vb])
95 :     (* end case *))
96 : cchiw 2606
97 :     (*Get Img, and Kern Args*)
98 : cchiw 2827 fun getArgs(hid,hArg,V,imgArg,args,lift,varI)=(case (getRHS hArg,getRHS imgArg)
99 :     of ((SrcOp.Kernel(h, i), _ ),(SrcOp.LoadImage img, _ ))=> let
100 : cchiw 2606 val hvar=DstV.new ("KNL", DstTy.KernelTy)
101 :     val imgvar=DstV.new ("IMG", DstTy.ImageTy img)
102 :     val argsVK= (case lift
103 :     of 0=> let
104 : cchiw 2680 val _=print "non lift"
105 : cchiw 2827 val l1=List.take(args, hid)
106 :     val l2=List.drop(args,hid+1)
107 :     in
108 :     l1@[hvar]@l2
109 :     end
110 :     | _ => [varI, hvar]
111 : cchiw 2606 (* end case *))
112 : cchiw 2827
113 : cchiw 2680 val assigments=[assign (hvar, DstOp.Kernel(h, i), [])]
114 : cchiw 2606 in
115 : cchiw 2827 ((Kernel.support h) ,img, assigments,argsVK)
116 : cchiw 2606 end
117 : cchiw 2827 | _ => raise Fail "Expected Image and kernel argument"
118 :     (*end case*))
119 : cchiw 2606
120 :    
121 : cchiw 2830 fun handleArgs(V,h,t,args,origargs,lift,dstargs)=let
122 :    
123 : cchiw 2606 val kArg=List.nth(origargs,h)
124 :     val imgArg=List.nth(origargs,V)
125 :     val newposArg=List.nth(args, t)
126 : cchiw 2680 val imgArgDst=List.nth(dstargs,V)
127 :     val (s,img,argcode,argsVH) =getArgs(h,kArg,V,imgArg,args,lift,imgArgDst)
128 : cchiw 2830 val dim=ImageInfo.dim img
129 : cchiw 2827 val (argsT,P,code')=transformToImgSpace(dim,img,newposArg,imgArgDst)
130 : cchiw 2608 in (dim,argsVH@argsT,argcode@code', s,P)
131 : cchiw 2606 end
132 :    
133 : cchiw 2608
134 : cchiw 2606 (*Created new body for probe*)
135 :     fun createBody(dim, s,sx,shape,deltas,V, h, nid, fid)=let
136 :    
137 :     (*sumIndex creating summaiton Index for body*)
138 :     fun sumIndex(0)=[]
139 :     |sumIndex(dim)= sumIndex(dim-1)@[(E.V (dim+sx-1),1-s,s)]
140 :    
141 :     (*createKRN Image field and kernels *)
142 :     fun createKRN(0,imgpos,rest)=E.Prod ([E.Img(V,shape,imgpos)] @rest)
143 :     | createKRN(dim,imgpos,rest)=let
144 :     val dim'=dim-1
145 :     val sum=sx+dim'
146 : cchiw 2830 val dels=List.map (fn e=>(E.C dim',e)) deltas
147 : cchiw 2606 val pos=[E.Add[E.Tensor(fid,[E.C dim']),E.Value(sum)]]
148 :     val rest'= E.Krn(h,dels,E.Sub(E.Tensor(nid,[E.C dim']),E.Value(sum)))
149 :     in
150 :     createKRN(dim',pos@imgpos,[rest']@rest)
151 :     end
152 :    
153 :     val exp=createKRN(dim, [],[])
154 :     val esum=sumIndex (dim)
155 :     in E.Sum(esum, exp)
156 :     end
157 :    
158 :    
159 :     fun ShapeConv([],n)=[]
160 :     | ShapeConv(E.C c::es, n)=ShapeConv(es, n)
161 :     | ShapeConv(E.V v::es, n)=
162 :     if(n>v) then [E.V v] @ ShapeConv(es, n)
163 :     else ShapeConv(es,n)
164 :    
165 :    
166 :     fun mapIndex([],_)=[]
167 :     | mapIndex(E.V v::es,index) = [List.nth(index, v)]@ mapIndex(es,index)
168 :     | mapIndex(E.C c::es,index) = mapIndex(es,index)
169 :    
170 :    
171 :    
172 :     (* Expand probe in place *)
173 : cchiw 2680 fun replaceProbe(b,(params,args),index, sumIndex,origargs,dstargs)=let
174 : cchiw 2606
175 : cchiw 2608 val E.Probe(E.Conv(V,alpha,h,dx),E.Tensor(t,_))=b
176 : cchiw 2606 val fid=length(params)
177 : cchiw 2611 val nid=fid+1
178 : cchiw 2606 val n=length(index)
179 :     val ns=length sumIndex
180 : cchiw 2611 val nshift=length(dx)
181 :     val nsumshift =(case ns
182 :     of 0=> n
183 :     | _=>let
184 : cchiw 2606 val (E.V v,_,_)=List.nth(sumIndex, ns-1)
185 : cchiw 2611 in v+1
186 : cchiw 2606 end
187 : cchiw 2611 (* end case *))
188 :    
189 :     (*Outer Index-id Of Probe*)
190 :     val VShape=ShapeConv(alpha, n)
191 :     val HShape=ShapeConv(dx, n)
192 :     val shape=VShape@HShape
193 :     (* Bindings for Shape*)
194 :     val shapebind= mapIndex(shape,index)
195 :     val Vshapebind= mapIndex(VShape,index)
196 :    
197 :    
198 : cchiw 2830 val (dim,argsA,code,s,PArg) = handleArgs(V,h,t,args, origargs,0,dstargs)
199 : cchiw 2829 val _ =testp["\nSupport is ",Int.toString s]
200 : cchiw 2611 val (_,_,dx, _,sxT,restT,_,_) = T.Transform(dx,shapebind,Vshapebind,dim,PArg,nsumshift,1,nid+1)
201 :    
202 :     val params'=params@[E.TEN(3,[dim]),E.TEN(1,[dim]),E.TEN(1,[dim,dim])]
203 :     val body'' = createBody(dim, s,nsumshift+nshift,alpha,dx,V, h, nid, fid)
204 :     val body' =(case nshift
205 :     of 0=> body''
206 :     | _ => E.Sum(sxT, E.Prod(restT@[body'']))
207 :     (*end case*))
208 :     val args'=argsA@[PArg]
209 : cchiw 2829 val subexp=Ein.EIN{params=params', index=index, body=body'}
210 :     val _ = testp["\n Don't replace probe \n $$$ new sub-expression $$$ \n",P.printerE(subexp),"\n"]
211 : cchiw 2830
212 : cchiw 2606 in (body',(params',args') ,code)
213 :     end
214 :    
215 : cchiw 2680 (*
216 : cchiw 2612 (*Checks if (1) Summation variable occurs just once (2) it matches n.
217 :     Then we lift otherwise expand in place *)
218 :     fun checkSum(sx,b,info,index,origargs)=(case sx
219 :     of [(E.V i,lb,ub)]=> let
220 :     val E.Probe(E.Conv(V,alpha,h,dx), E.Tensor(id,beta))=b
221 :     val n=length(index)
222 :     val _=(case testing
223 :     of 1=> (print(String.concat["in check Sum\n " ,P.printbody(E.Sum([(E.V i,lb,ub)],b))]);1)
224 :     |_ => 1)
225 :     in
226 :     if (i=n) then (case F.countSx(sx,b)
227 :     of (1,ixx) => liftProbe(b,info,index@[ub], [],origargs)
228 :     | _ => replaceProbe(b, info,index,sx,origargs)
229 :     (*end case*))
230 :     else replaceProbe(b, info,index, sx,origargs)
231 :     end
232 :     | _ =>replaceProbe(b, info,index, sx,origargs)
233 :     (*end case*))
234 : cchiw 2680 *)
235 : cchiw 2611
236 : cchiw 2606 fun flatten []=[]
237 :     | flatten(e1::es)=e1@(flatten es)
238 :    
239 :    
240 :     (* sx-[] then move out, otherwise keep in *)
241 :     fun expandEinOp ( Ein.EIN{params, index, body}, origargs,args) = let
242 :    
243 :     val dummy=E.Const 0
244 :     val sumIndex=ref []
245 :    
246 :     (*b-current body, info-original ein op, data-new assigments*)
247 :     fun rewriteBody(b,info)= let
248 :    
249 :     fun callfn(c1,body)=let
250 :     val ref x=sumIndex
251 :     val c'=[c1]@x
252 :     val (bodyK,infoK,dataK)= (sumIndex:=c';rewriteBody(body ,info))
253 :     val ref s=sumIndex
254 :     val z=hd(s)
255 :     val e'=( case bodyK
256 :     of E.Const _ =>bodyK
257 :     | _ => E.Sum(z,bodyK)
258 :     (*end case*))
259 :     in
260 :     (sumIndex:=tl(s);(e',infoK,dataK))
261 :     end
262 : cchiw 2611
263 : cchiw 2671
264 :     (*Nothing liftProbe and checkSum are commented out.
265 :     Some mistake underestimating size of dimension*)
266 :    
267 : cchiw 2611 fun filter es=let
268 :     fun filterApply([], doneB, infoB, dataB)= (doneB, infoB,dataB)
269 :     | filterApply(B::es, doneA, infoA,dataA)= let
270 :     val (bodyB, infoB,dataB)= rewriteBody(B,infoA)
271 :     in
272 :     filterApply(es, doneA@[bodyB], infoB,dataA@dataB)
273 :     end
274 :     in filterApply(es, [],info,[])
275 :     end
276 : cchiw 2606 in (case b
277 :     of E.Sum(c, E.Probe(E.Conv v, E.Tensor t)) =>let
278 :     val ref sx=sumIndex
279 : cchiw 2611 in (case sx
280 : cchiw 2671 of (* [] => liftProbe(E.Probe(E.Conv v, E.Tensor t ), info,index, c,origargs)
281 : cchiw 2611 | [i]=> checkSum(i,b, info,index,origargs)
282 : cchiw 2671 |*) _ => let
283 : cchiw 2680 val (b,m,code)=replaceProbe(E.Probe(E.Conv v, E.Tensor t ), info,index, (flatten sx)@c,origargs,args)
284 : cchiw 2611 in (E.Sum(c,b),m,code)
285 :     end
286 : cchiw 2606 (* end case*))
287 :     end
288 :     | E.Probe(E.Conv _, E.Tensor _) =>let
289 :     val ref sx=sumIndex
290 :     in (case sx
291 : cchiw 2671 of (* []=> liftProbe(b, info,index, [],origargs)
292 : cchiw 2611 | [i]=> checkSum(i,b, info,index,origargs)
293 : cchiw 2680 |*) _ => replaceProbe(b, info,index, flatten sx,origargs,args)
294 : cchiw 2606 (* end case*))
295 :     end
296 :     | E.Probe _=> (dummy,info,[])
297 :     | E.Conv _=> (dummy,info,[])
298 :     | E.Lift _=> (dummy,info,[])
299 :     | E.Field _ => (dummy,info,[])
300 :     | E.Apply _ => (dummy,info,[])
301 :     | E.Neg e=> let
302 :     val (body',info',data')=rewriteBody(e,info)
303 :     in
304 :     (E.Neg(body'),info',data')
305 :     end
306 :     | E.Sum (c,e)=> callfn(c,e)
307 :     | E.Sub(a,b)=>let
308 :     val (bodyA,infoA,dataA)= rewriteBody(a,info)
309 :     val (bodyB, infoB, dataB)= rewriteBody(b,infoA)
310 :     in (E.Sub(bodyA, bodyB),infoB,dataA@dataB)
311 :     end
312 :     | E.Div(a,b)=>let
313 :     val (bodyA,infoA,dataA)= rewriteBody(a,info)
314 :     val (bodyB, infoB,dataB)= rewriteBody(b,infoA)
315 :     in (E.Div(bodyA, bodyB),infoB,dataA@dataB) end
316 :     | E.Add es=> let
317 : cchiw 2611 val (done, info',data')= filter es
318 :     val (_, e)=F.mkAdd done
319 :     in (e, info',data')
320 :     end
321 : cchiw 2606 | E.Prod es=> let
322 : cchiw 2611 val (done, info',data')= filter es
323 :     val (_, e)=F.mkProd done
324 :     in (e, info',data')
325 :     end
326 : cchiw 2606 | _=> (b,info,[])
327 :     (* end case *))
328 :     end
329 :    
330 :     val empty =fn key =>NONE
331 :     val _ =(case testing
332 :     of 0 => 1
333 :     | _ => (print "\n ************************** \n Starting Expand";1)
334 :     (*end case*))
335 :    
336 :     val (body',(params',args'),newbies)=rewriteBody(body,(params,args))
337 :     val e'=Ein.EIN{params=params', index=index, body=body'}
338 : cchiw 2608 (*val _ =(case testing
339 : cchiw 2606 of 0 => 1
340 :     | _ => (String.concat[P.printerE(e'),"\n DONE expand ************************** \n "];1)
341 : cchiw 2608 (*end case*))*)
342 : cchiw 2606 in
343 :     ((e',args'),newbies)
344 :     end
345 :    
346 :     end; (* local *)
347 :    
348 :     end (* local *)

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0