Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] Annotation of /branches/charisee_dev/src/compiler/high-to-mid/ProbeEin.sml
ViewVC logotype

Annotation of /branches/charisee_dev/src/compiler/high-to-mid/ProbeEin.sml

Parent Directory Parent Directory | Revision Log Revision Log


Revision 2606 - (view) (download)
Original Path: branches/charisee/src/compiler/high-to-mid/ProbeEin.sml

1 : cchiw 2606 (* examples.sml
2 :     *
3 :     * COPYRIGHT (c) 2012 The Diderot Project (http://diderot-language.cs.uchicago.edu)
4 :     * All rights reserved.
5 :     *)
6 :    
7 :    
8 :     (*
9 :     A couple of different approaches.
10 :     One approach is to find all the Probe(Conv). Gerenerate exp for it
11 :     Then use Subst function to sub in. That takes care for index matching and
12 :    
13 :     *)
14 :    
15 :     (*This approach creates probe expanded terms, and adds params to the end. *)
16 :    
17 :    
18 :     structure ProbeEin = struct
19 :    
20 :     local
21 :    
22 :     structure E = Ein
23 :     structure mk= mkOperators
24 :     structure SrcIL = HighIL
25 :     structure SrcTy = HighILTypes
26 :     structure SrcOp = HighOps
27 :     structure SrcSV = SrcIL.StateVar
28 :     structure VTbl = SrcIL.Var.Tbl
29 :     structure DstIL = MidIL
30 :     structure DstTy = MidILTypes
31 :     structure DstOp = MidOps
32 :     structure DstV = DstIL.Var
33 :     structure SrcV = SrcIL.Var
34 :     structure P=Printer
35 :     structure shift=ShiftEin
36 :     structure split=SplitEin
37 :     structure F=Filter
38 :    
39 :     val testing=0
40 :    
41 :     datatype peanut= O of DstOp.rator | E of Ein.ein|C of DstTy.ty|S of int
42 :     datatype peanut2= O2 of SrcOp.rator | E2 of Ein.ein|C2 of SrcTy.ty|S2 of int
43 :     in
44 :    
45 :    
46 :     fun assign (x, rator, args) = (x, DstIL.OP(rator, args))
47 :     fun assignEin (x, rator, args) = ((x, DstIL.EINAPP(rator, args)))
48 :    
49 :     fun getRHS x = (case SrcIL.Var.binding x
50 :     of SrcIL.VB_RHS(SrcIL.OP(rator, args)) => (O2 rator, args)
51 :     | SrcIL.VB_RHS(SrcIL.VAR x') => getRHS x'
52 :     | SrcIL.VB_RHS(SrcIL.EINAPP (e,args))=>(E2 e,args)
53 :     | SrcIL.VB_RHS(SrcIL.CONS (ty,args))=>(C2 ty,args)
54 :     | SrcIL.VB_NONE=>(S2 2,[])
55 :     | vb => raise Fail(concat[
56 :     "expected rhs operator for ", SrcIL.Var.toString x,
57 :     "but found ", SrcIL.vbToString vb])
58 :     (* end case *))
59 :    
60 :     (*Create fractional, and integer position vectors*)
61 :     fun transformToImgSpace (dim,v,posx)=let
62 :    
63 :     val translate=DstOp.Translate v
64 :     val transform=DstOp.Transform v
65 :     val M = DstV.new ("M", DstTy.tensorTy [dim,dim]) (*transform dim by dim?*)
66 :     val T = DstV.new ("T", DstTy.tensorTy [dim,dim]) (*translate*)
67 :     val x = DstV.new ("x", DstTy.vecTy dim) (*Image-Space position*)
68 :     val f = DstV.new ("f", DstTy.vecTy dim) (*fractional*)
69 :     val nd = DstV.new ("nd", DstTy.vecTy dim) (*real position*)
70 :     val n = DstV.new ("n", DstTy.iVecTy dim) (*integer position*)
71 :     val PosToImgSpace=mk.transform(dim,dim)
72 :    
73 :     val code=[
74 :     assign(M, transform, []),
75 :     assign(T, translate, []),
76 :     assignEin(x, PosToImgSpace,[M,posx,T]) , (* MX+T*)
77 :     assign(nd, DstOp.Floor dim, [x]), (*nd *)
78 :     assignEin(f, mk.subTen([dim]),[x,nd]), (*fractional*)
79 :     assign(n, DstOp.RealToInt dim, [nd]) (*real to Int*)
80 :     ]
81 :     in ([n,f],code)
82 :     end
83 :    
84 :    
85 :     fun replaceH(kvar, place,args)=let
86 :     val l1=List.take(args, place)
87 :     val l2=List.drop(args,place+1)
88 :     in l1@[kvar]@l2 end
89 :    
90 :    
91 :     (*Get Img, and Kern Args*)
92 :     fun getArgs(hid,hArg,V,imgArg,args,lift)=case (getRHS hArg,getRHS imgArg)
93 :     of ((O2(SrcOp.Kernel(h, i)),argK),(O2(SrcOp.LoadImage img),_))=> let
94 :     val hvar=DstV.new ("KNL", DstTy.KernelTy)
95 :     val imgvar=DstV.new ("IMG", DstTy.ImageTy img)
96 :     val argsVK= (case lift
97 :     of 0=> let
98 :     val argsN=replaceH(hvar, hid,args)
99 :     in replaceH(imgvar, V,argsN) end
100 :     | _ => [imgvar, hvar]
101 :     (* end case *))
102 :     val assigments=[assign (hvar, DstOp.Kernel(h, i), []), assign(imgvar,DstOp.LoadImage img,[])]
103 :    
104 :     in
105 :     (Kernel.support h ,img, assigments,argsVK)
106 :     end
107 :     | ((O2(SrcOp.Kernel(h, i)),arg),_)=> raise Fail "Not an img Argument"
108 :     | _ => raise Fail "Not a kernel argument"
109 :    
110 :    
111 :     fun handleArgs(E.Probe(E.Conv(V,shape,h, deltas),E.Tensor(t,alpha)),(params,args),origargs,lift)=let
112 :     val E.IMG(dim)=List.nth(params,V)
113 :     val kArg=List.nth(origargs,h)
114 :     val imgArg=List.nth(origargs,V)
115 :     val newposArg=List.nth(args, t)
116 :     val (s,img,argcode,argsVH) =getArgs(h,kArg,V,imgArg,args,lift)
117 :     val (argsT,code')=transformToImgSpace(dim,img,newposArg)
118 :     in (dim,argsVH@argsT,argcode@code', s)
119 :     end
120 :     | handleArgs _ =raise Fail"Expression is wrong for handleArgs"
121 :    
122 :     (*createDels=> creates the kronecker deltas for each Kernel*)
123 :     fun createDels([],_)= []
124 :     | createDels(d::ds,dim)= [( E.C dim,d)]@createDels(ds,dim)
125 :    
126 :     (*Created new body for probe*)
127 :     fun createBody(dim, s,sx,shape,deltas,V, h, nid, fid)=let
128 :    
129 :     (*sumIndex creating summaiton Index for body*)
130 :     fun sumIndex(0)=[]
131 :     |sumIndex(dim)= sumIndex(dim-1)@[(E.V (dim+sx-1),1-s,s)]
132 :    
133 :     (*createKRN Image field and kernels *)
134 :     fun createKRN(0,imgpos,rest)=E.Prod ([E.Img(V,shape,imgpos)] @rest)
135 :     | createKRN(dim,imgpos,rest)=let
136 :     val dim'=dim-1
137 :     val sum=sx+dim'
138 :     val dels=createDels(deltas,dim')
139 :     val pos=[E.Add[E.Tensor(fid,[E.C dim']),E.Value(sum)]]
140 :     val rest'= E.Krn(h,dels,E.Sub(E.Tensor(nid,[E.C dim']),E.Value(sum)))
141 :     in
142 :     createKRN(dim',pos@imgpos,[rest']@rest)
143 :     end
144 :    
145 :     val exp=createKRN(dim, [],[])
146 :     val esum=sumIndex (dim)
147 :     in E.Sum(esum, exp)
148 :     end
149 :    
150 :    
151 :     fun ShapeConv([],n)=[]
152 :     | ShapeConv(E.C c::es, n)=ShapeConv(es, n)
153 :     | ShapeConv(E.V v::es, n)=
154 :     if(n>v) then [E.V v] @ ShapeConv(es, n)
155 :     else ShapeConv(es,n)
156 :    
157 :    
158 :     fun mapIndex([],_)=[]
159 :     | mapIndex(E.V v::es,index) = [List.nth(index, v)]@ mapIndex(es,index)
160 :     | mapIndex(E.C c::es,index) = mapIndex(es,index)
161 :    
162 :    
163 :    
164 :     (*Lift probe*)
165 :     fun liftProbe(b,(params,args),index, sumIndex,origargs)=let
166 :    
167 :     val E.Probe(E.Conv(_,alpha,_,dx),pos)=b
168 :     val newId=length(params)
169 :     val n=length(index)
170 :    
171 :     (*Create new tensor replacement*)
172 :     val shape=ShapeConv(alpha@dx, n)
173 :     val newB=E.Tensor(newId,shape)
174 :    
175 :     (* Create new Param*)
176 :     (* val shape'=List.map (fn E.V v=>(List.nth(index, v))) shape*)
177 :     val shape'= mapIndex(shape,index)
178 :     val newP= E.TEN(1,shape')
179 :    
180 :     (*Create new Arg*)
181 :     val newArg = DstV.new ("PC", DstTy.tensorTy shape')
182 :    
183 :     (*Expand Probe*)
184 :     val ns=length sumIndex
185 :    
186 :     val (dim,args',code,s) = handleArgs(b,(params,args), origargs,1)
187 :    
188 :     val body' =(case ns
189 :     of 0=> createBody(dim, s,n,alpha,dx,0, 1, 3, 2)
190 :     |_=>let
191 :     val (E.V v,_,_)=List.nth(sumIndex, ns-1)
192 :     val body'=createBody(dim, s,v+1,alpha,dx,0, 1, 3, 2)
193 :     in E.Sum(sumIndex ,body')
194 :     end
195 :     (* end case *))
196 :    
197 :     val params'=[E.IMG(dim),E.KRN,E.TEN(3,[dim]),E.TEN(1,[dim])]
198 :     val (p',i',b',a')=shift.clean(params', index,body', args')
199 :     val newbie'=Ein.EIN{params=p', index=i', body=b'}
200 :     val data=assignEin (newArg, newbie', a')
201 :    
202 :     val _ = (case testing
203 :     of 0 => 1
204 :     | _ => (print(String.concat["\n Lift Probe\n", split.printA(newArg, newbie', a'),"\n"]);1)
205 :     (*end case *))
206 :     in (newB, (params@[newP],args@[newArg]) ,code@[data])
207 :     end
208 :    
209 :    
210 :     (* Expand probe in place *)
211 :     fun replaceProbe(b,(params,args),index, sumIndex,origargs)=let
212 :    
213 :     val E.Probe(E.Conv(V,alpha,h,dx),pos)=b
214 :     val fid=length(params)
215 :     val n=length(index)
216 :    
217 :     (*Expand Probe*)
218 :     val ns=length sumIndex
219 :     val (dim,args',code,s) = handleArgs(b,(params,args), origargs,0)
220 :     val nid=fid+1
221 :     val params'=params@[E.TEN(3,[dim]),E.TEN(1,[dim])]
222 :     val body' =(case ns
223 :     of 0=> createBody(dim, s,n,alpha,dx,V, h, nid, fid)
224 :     |_=>let
225 :     val (E.V v,_,_)=List.nth(sumIndex, ns-1)
226 :     in createBody(dim, s,v+1,alpha,dx,V, h, nid, fid)
227 :     end
228 :     (* end case *))
229 :     val _ =(case testing
230 :     of 0=> 1
231 :     | _ => let
232 :     val subexp=Ein.EIN{params=params', index=index, body=body'}
233 :     val _= print(String.concat["\n Don't replace probe \n $$$ new sub-expression $$$ \n",P.printerE(subexp),"\n"])
234 :     in 1 end
235 :     (* end case *))
236 :    
237 :     in (body',(params',args') ,code)
238 :     end
239 :    
240 :    
241 :     fun flatten []=[]
242 :     | flatten(e1::es)=e1@(flatten es)
243 :    
244 :    
245 :     (* sx-[] then move out, otherwise keep in *)
246 :     fun expandEinOp ( Ein.EIN{params, index, body}, origargs,args) = let
247 :    
248 :     val dummy=E.Const 0
249 :     val sumIndex=ref []
250 :    
251 :     (*b-current body, info-original ein op, data-new assigments*)
252 :     fun rewriteBody(b,info)= let
253 :    
254 :     fun callfn(c1,body)=let
255 :     val ref x=sumIndex
256 :     val c'=[c1]@x
257 :     val (bodyK,infoK,dataK)= (sumIndex:=c';rewriteBody(body ,info))
258 :     val ref s=sumIndex
259 :     val z=hd(s)
260 :     val e'=( case bodyK
261 :     of E.Const _ =>bodyK
262 :     | _ => E.Sum(z,bodyK)
263 :     (*end case*))
264 :     in
265 :     (sumIndex:=tl(s);(e',infoK,dataK))
266 :     end
267 :    
268 :     in (case b
269 :     of E.Sum(c, E.Probe(E.Conv v, E.Tensor t)) =>let
270 :     val ref sx=sumIndex
271 :     in (case sx
272 :     of [] => liftProbe(E.Probe(E.Conv v, E.Tensor t ), info,index, c,origargs)
273 :     | _ => replaceProbe(E.Probe(E.Conv v, E.Tensor t ), info,index, (flatten sx)@c,origargs)
274 :     (* end case*))
275 :     end
276 :     | E.Probe(E.Conv _, E.Tensor _) =>let
277 :     val ref sx=sumIndex
278 :     in (case sx
279 :     of []=> liftProbe(b, info,index, [],origargs)
280 :     | _=> replaceProbe(b, info,index, flatten sx,origargs)
281 :     (* end case*))
282 :     end
283 :     | E.Probe _=> (dummy,info,[])
284 :     | E.Conv _=> (dummy,info,[])
285 :     | E.Lift _=> (dummy,info,[])
286 :     | E.Field _ => (dummy,info,[])
287 :     | E.Apply _ => (dummy,info,[])
288 :     | E.Neg e=> let
289 :     val (body',info',data')=rewriteBody(e,info)
290 :     in
291 :     (E.Neg(body'),info',data')
292 :     end
293 :     | E.Sum (c,e)=> callfn(c,e)
294 :     | E.Sub(a,b)=>let
295 :     val (bodyA,infoA,dataA)= rewriteBody(a,info)
296 :     val (bodyB, infoB, dataB)= rewriteBody(b,infoA)
297 :     in (E.Sub(bodyA, bodyB),infoB,dataA@dataB)
298 :     end
299 :     | E.Div(a,b)=>let
300 :     val (bodyA,infoA,dataA)= rewriteBody(a,info)
301 :     val (bodyB, infoB,dataB)= rewriteBody(b,infoA)
302 :     in (E.Div(bodyA, bodyB),infoB,dataA@dataB) end
303 :     | E.Add es=> let
304 :     fun filter([], done, info', data)= let
305 :     val (_, e)=F.mkAdd done
306 :     in (e, info',data)
307 :     end
308 :     | filter(e::es, done, info',data)= let
309 :     val (body', info'',data')= rewriteBody(e,info')
310 :     in filter(es, done@[body'], info'',data@data') end
311 :     in filter(es, [],info,[]) end
312 :    
313 :     | E.Prod es=> let
314 :     fun filter([], done, info',data)= let
315 :     val (_, e)=F.mkProd done
316 :     in (e,info', data)
317 :     end
318 :     | filter(e::es, done, info',data)= let
319 :     val (body', info'',data')= rewriteBody(e, info')
320 :     in filter(es, done@[body'], info'',data@data') end
321 :     in filter(es, [],info,[]) end
322 :     | _=> (b,info,[])
323 :     (* end case *))
324 :     end
325 :    
326 :     val empty =fn key =>NONE
327 :     val _ =(case testing
328 :     of 0 => 1
329 :     | _ => (print "\n ************************** \n Starting Expand";1)
330 :     (*end case*))
331 :    
332 :     val (body',(params',args'),newbies)=rewriteBody(body,(params,args))
333 :     val e'=Ein.EIN{params=params', index=index, body=body'}
334 :     val _ =(case testing
335 :     of 0 => 1
336 :     | _ => (String.concat[P.printerE(e'),"\n DONE expand ************************** \n "];1)
337 :     (*end case*))
338 :     in
339 :     ((e',args'),newbies)
340 :     end
341 :    
342 :     end; (* local *)
343 :    
344 :     end (* local *)

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0