Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] Annotation of /branches/charisee/src/compiler/high-to-mid/ProbeEin2.sml
ViewVC logotype

Annotation of /branches/charisee/src/compiler/high-to-mid/ProbeEin2.sml

Parent Directory Parent Directory | Revision Log Revision Log


Revision 2843 - (view) (download)

1 : cchiw 2843 (* Currently under construction
2 :     *
3 :     * COPYRIGHT (c) 2012 The Diderot Project (http://diderot-language.cs.uchicago.edu)
4 :     * All rights reserved.
5 :     *)
6 :    
7 :     structure ProbeEin = struct
8 :    
9 :     local
10 :    
11 :     structure E = Ein
12 :     structure mk= mkOperators
13 :     structure SrcIL = HighIL
14 :     structure SrcTy = HighILTypes
15 :     structure SrcOp = HighOps
16 :     structure SrcSV = SrcIL.StateVar
17 :     structure VTbl = SrcIL.Var.Tbl
18 :     structure DstIL = MidIL
19 :     structure DstTy = MidILTypes
20 :     structure DstOp = MidOps
21 :     structure DstV = DstIL.Var
22 :     structure SrcV = SrcIL.Var
23 :     structure P=Printer
24 :     structure F=Filter
25 :     structure T=TransformEin
26 :     structure split=Split
27 :     structure cleanI=cleanIndex
28 :    
29 :    
30 :     val testing=1
31 :    
32 :    
33 :     in
34 :    
35 :    
36 :     (* This file expands probed fields
37 :     *Note that the original field is an EIN operator in the form <V_alpha * H^(deltas)>(midIL.var list )
38 :     * Param_ids are used to note the placement of the argument in the midIL.var list
39 :     * Index_ids bind the shape of an Image or differentiation.
40 :     * Generally, we will refer to the following
41 :     *dim:dimension of field V
42 :     * s: support of kernel H
43 :     * alpha: The alpha in <V_alpha * H^(deltas)>
44 :     * deltas: The deltas in <V_alpha * H^(deltas)>
45 :     * Vid:param_id for V
46 :     * hid:param_id for H
47 :     * nid: integer position param_id
48 :     * fid :fractional position param_id
49 :     *img-imginfo about V
50 :     *)
51 :    
52 :    
53 :     val cnt = ref 0
54 :     fun genName prefix = let
55 :     val n = !cnt
56 :     in
57 :     cnt := n+1;
58 :     String.concat[prefix, "_", Int.toString n]
59 :     end
60 :    
61 :    
62 :     fun iterSx e=F.iterSx e
63 :     fun transformToIndexSpace e=T.transformToIndexSpace e
64 :     fun transformToImgSpace e=T.transformToImgSpace e
65 :     fun assign (x, rator, args) = (x, DstIL.OP(rator, args))
66 :     fun assignEin (x, rator, args) = ((x, DstIL.EINAPP(rator, args)))
67 :     fun testp n=(case testing
68 :     of 0=> 1
69 :     | _ =>(print(String.concat n);1)
70 :     (*end case*))
71 :     fun getRHSDst x = (case DstIL.Var.binding x
72 :     of DstIL.VB_RHS(DstIL.OP(rator, args)) => (rator, args)
73 :     | DstIL.VB_RHS(DstIL.VAR x') => getRHSDst x'
74 :     | vb => raise Fail(concat[ "expected rhs operator for ", DstIL.Var.toString x, "but found ", DstIL.vbToString vb])
75 :     (* end case *))
76 :    
77 :    
78 :     (* getArgsDst:MidIL.Var* MidIL.Var->int, ImageInfo, int
79 :     uses the Param_ids for the image, kernel, and position tensor to get the Mid-IL arguments
80 :     returns the support of ther kernel, and image
81 :     *)
82 :     fun getArgsDst(hArg,imgArg,args)=(case (getRHSDst hArg,getRHSDst imgArg)
83 :     of ((DstOp.Kernel(h, i), _ ),(DstOp.LoadImage img, _ ))=> let
84 :     in
85 :     ((Kernel.support h) ,img,ImageInfo.dim img)
86 :     end
87 :     | _ => raise Fail "Expected Image and kernel arguments"
88 :     (*end case*))
89 :    
90 :    
91 :     (*handleArgs():int*int*int*Mid IL.Var list ->int*Mid.ILVars list* code*int* low-il-var
92 :     * uses the Param_ids for the image, kernel, and tensor and gets the mid-IL vars for each
93 :     *Transforms the position to index space
94 :     *P-mid-il var for the (transformation matrix)transpose
95 :     *)
96 :     fun handleArgs(Vid,hid,tid,args)=let
97 :     val imgArg=List.nth(args,Vid)
98 :     val hArg=List.nth(args,hid)
99 :     val newposArg=List.nth(args,tid)
100 :     val (s,img,dim) =getArgsDst(hArg,imgArg,args)
101 :     val (argsT,P,code)=transformToImgSpace(dim,img,newposArg,imgArg)
102 :     in (dim,args@argsT,code, s,P)
103 :     end
104 :    
105 :    
106 :     (*createBody:int*int*int, index_id list, param_id, param_id, param_id, param_id
107 :     * expands the body for the probed field
108 :     *)
109 :     fun createBody(dim, s,sx,alpha,deltas,Vid, hid, nid, fid)=let
110 :    
111 :     (*1-d fields*)
112 :     fun createKRND1 ()=let
113 :     val sum=sx
114 :     val dels=List.map (fn e=>(E.C 0,e)) deltas
115 :     val pos=[E.Add[E.Tensor(fid,[]),E.Value(sum)]]
116 :     val rest= E.Krn(hid,dels,E.Sub(E.Tensor(nid,[]),E.Value(sum)))
117 :     in
118 :     E.Prod [E.Img(Vid,alpha,pos),rest]
119 :    
120 :     end
121 :     (*createKRN Image field and kernels *)
122 :     fun createKRN(0,imgpos,rest)=E.Prod ([E.Img(Vid,alpha,imgpos)] @rest)
123 :     | createKRN(dim,imgpos,rest)=let
124 :     val dim'=dim-1
125 :     val sum=sx+dim'
126 :     val dels=List.map (fn e=>(E.C dim',e)) deltas
127 :     val pos=[E.Add[E.Tensor(fid,[E.C dim']),E.Value(sum)]]
128 :     val rest'= E.Krn(hid,dels,E.Sub(E.Tensor(nid,[E.C dim']),E.Value(sum)))
129 :     in
130 :     createKRN(dim',pos@imgpos,[rest']@rest)
131 :     end
132 :     val exp=(case dim
133 :     of 1 => createKRND1()
134 :     | _=> createKRN(dim, [],[])
135 :     (*end case*))
136 :    
137 :     (*sumIndex creating summaiton Index for body*)
138 :     val slb=1-s
139 :     val esum=List.tabulate(dim, (fn dim=>(E.V (dim+sx),slb,s)))
140 :     in
141 :     E.Sum(esum, exp)
142 :     end
143 :    
144 :     (*getsumshift:sum_index_id list* index_id list-> int
145 :     *get fresh/unused index_id, returns int
146 :     *)
147 :     fun getsumshift(sx,index) =let
148 :     val nsumshift= (case sx
149 :     of []=> length(index)
150 :     | _=>let
151 :     val (E.V v,_,_)=List.hd(List.rev sx)
152 :     in v+1
153 :     end
154 :     (* end case *))
155 :     val aa=List.map (fn (E.V v,_,_)=>Int.toString v) sx
156 :     val _ =testp["\n", "SumIndex" ,(String.concatWith"," aa),"\nThink nshift is ", Int.toString nsumshift]
157 :     in
158 :     nsumshift
159 :     end
160 :    
161 :     (*formBody:ein_exp->ein_exp
162 :     *just does a quick rewrite
163 :     *)
164 :     fun formBody(E.Sum([],e))=formBody e
165 :     | formBody(E.Sum(sx,e))= E.Sum(sx,formBody e)
166 :     | formBody(E.Prod [e])=e
167 :     | formBody e=e
168 :    
169 :    
170 :     (* replaceProbe:ein_exp* params *midIL.var list * int list* sum_id list :ein_exp* *code
171 :     * Transforms position to world space
172 :     * transforms result back to index_space
173 :     * rewrites body
174 :     * replace probe with expanded version
175 :     *)
176 :     fun replaceProbe(b,params,args,index, sx)=let
177 :    
178 :     val E.Probe(E.Conv(Vid,alpha,hid,dx),E.Tensor(tid,_))=b
179 :     val fid=length(params)
180 :     val nid=fid+1
181 :     val Pid=nid+1
182 :     val nshift=length(dx)
183 :     val (dim,argsA,code,s,PArg) = handleArgs(Vid,hid,tid,args)
184 :     val freshIndex=getsumshift(sx,index)
185 :     val (dx,newsx1,Ps)=transformToIndexSpace(freshIndex,dim,dx,Pid)
186 :     val params'=params@[E.TEN(3,[dim]),E.TEN(1,[dim]),E.TEN(1,[dim,dim])]
187 :     val body' = createBody(dim, s,freshIndex+nshift,alpha,dx,Vid, hid, nid, fid)
188 :     val body' =formBody(E.Sum(newsx1, E.Prod(Ps@[body'])))
189 :     val args'=argsA@[PArg]
190 :     in
191 :     (body',params',args' ,code)
192 :     end
193 :    
194 :    
195 :    
196 :     (* liftedProbe:e:ein_exp* params *midIL.var list * int list* sum_id list :ein_exp* *code
197 :     * Same as above except it does not transforms result back to index_space
198 :     * Also returns P arg.
199 :     *)
200 :     fun liftedProbe(b,params,args,index, sumIndex)=let
201 :     val E.Probe(E.Conv(V,alpha,h,dx),E.Tensor(t,_))=b
202 :     val fid=length(params)
203 :     val nid=fid+1
204 :     val nshift=length(dx)
205 :     val freshIndex = getsumshift(sumIndex,index)
206 :     val (dim,argsA,code,s,PArg) = handleArgs(V,h,t,args)
207 :     val params=params@[E.TEN(3,[dim]),E.TEN(1,[dim])]
208 :     val body = createBody(dim, s, freshIndex+nshift,alpha,dx,V, h, nid, fid)
209 :     in
210 :     (body,params,argsA ,PArg,code)
211 :     end
212 :    
213 :     (* expandEinOp: code-> code list
214 :     *Looks to see if the expression has a probe. If so, replaces it.
215 :     * Note how we keeps eps type expressions so we have less time in mid-to-low-il stage
216 :     *)
217 :     fun expandEinOp99( e as (y, DstIL.EINAPP(Ein.EIN{params, index, body}, args))) = let
218 :     fun printResult code=testp["\nINSIDE PROBEEIN","\nbody",split.printEINAPP e, "\n=>\n",
219 :     (String.concatWith",\t"(List.map split.printEINAPP code))]
220 :    
221 :     fun rewriteBody b=(case b
222 :     of E.Probe e =>let
223 :     val (body',params',args',newbies)=replaceProbe(E.Probe e,params,args, index, [])
224 :     val einapp=(y,DstIL.EINAPP(Ein.EIN{params=params', index=index, body=body'},args'))
225 :     val code=newbies@[einapp]
226 :     in
227 :     code
228 :     end
229 :     | E.Sum(sx,E.Probe e) =>let
230 :     val (body',params',args',newbies)=replaceProbe(E.Probe e,params,args, index, sx)
231 :     val body'=E.Sum(sx,body')
232 :     val einapp=(y,DstIL.EINAPP(Ein.EIN{params=params', index=index, body=body'},args'))
233 :     val code=newbies@[einapp]
234 :     in
235 :     code
236 :     end
237 :     | E.Sum(sx,E.Prod[eps,E.Probe e]) =>let
238 :     val (body',params',args',newbies)=replaceProbe(E.Probe e,params,args, index, sx)
239 :     val body'=E.Sum(sx,E.Prod[eps,body'])
240 :     val einapp=(y,DstIL.EINAPP(Ein.EIN{params=params', index=index, body=body'},args'))
241 :     val code=newbies@[einapp]
242 :     in
243 :     code
244 :     end
245 :     | _=> [e]
246 :     (* end case *))
247 :     in
248 :     rewriteBody body
249 :     end
250 :    
251 :    
252 :     (* expandEinOp: code-> code list* Arg List
253 :     *same as above but uses lifted probe()
254 :     *)
255 :    
256 :     fun expandEinOp2( e as (y, DstIL.EINAPP(Ein.EIN{params, index, body}, args))) = let
257 :     fun rewriteBody b=(case b
258 :     of E.Probe e =>let
259 :     val (body',params',args',PArg,newbies)=liftedProbe(E.Probe e,params,args, index, [])
260 :     val einapp=(y,DstIL.EINAPP(Ein.EIN{params=params', index=index, body=body'},args'))
261 :     val code=newbies@[einapp]
262 :    
263 :    
264 :     in
265 :     (code,[PArg])
266 :     end
267 :     | E.Sum(sx,E.Probe e) =>let
268 :     val (body',params',args',PArg,newbies)=liftedProbe(E.Probe e,params,args, index, sx)
269 :     val body'=E.Sum(sx,body')
270 :     val einapp=(y,DstIL.EINAPP(Ein.EIN{params=params', index=index, body=body'},args'))
271 :     val code=newbies@[einapp]
272 :     in
273 :     (code,[PArg])
274 :     end
275 :     | E.Sum(sx,E.Prod[eps,E.Probe e]) =>let
276 :     val (body',params',args',PArg,newbies)=liftedProbe(E.Probe e,params,args, index, sx)
277 :     val body'=E.Sum(sx,E.Prod[eps,body'])
278 :     val einapp=(y,DstIL.EINAPP(Ein.EIN{params=params', index=index, body=body'},args'))
279 :     val code=newbies@[einapp]
280 :     in
281 :     (code,[PArg])
282 :     end
283 :     | _=> ([e],[])
284 :     (* end case *))
285 :     in
286 :     rewriteBody body
287 :     end
288 :    
289 :     (* mkBoody:index_id list, ein_exp: sum_id list, ein_exp
290 :     *rewrite ein_exp. replaces dx with new delta
291 :     * Was sx bound to the probed field's alpha or delta?
292 :     * If it was delta then it needs to be moved to the original ein_app, not in the subexpression
293 :     *)
294 :     fun mkbody(beta ,body)=(case body
295 :     of E.Probe(E.Conv (v,alpha,h,_), E.Tensor t)=>([], E.Probe(E.Conv (v,alpha,h,beta), E.Tensor t))
296 :     | E.Sum(sx,E.Probe(E.Conv (v,alpha,h,dx), E.Tensor t))=>let
297 :     val(pre,post)=F.iterSx(sx,dx)
298 :     val body=E.Sum(post,E.Probe(E.Conv (v,alpha,h,beta), E.Tensor t))
299 :     in (pre,body)
300 :     end
301 :     | E.Sum(sx,E.Prod[eps,E.Probe(E.Conv(v,alpha,h,dx), E.Tensor t)])=>let
302 :     val(pre,post)=F.iterSx(sx,dx)
303 :     val body=E.Sum(post,E.Prod[eps,E.Probe(E.Conv(v,alpha,h,beta), E.Tensor t)])
304 :     in (pre,body)
305 :     end
306 :    
307 :     (*end case*))
308 :    
309 :     (*getT:sum_index_id list* int list* int*index_id list*ein_exp*Param_ids*mid_il var list*mid_id var
310 :     *This goal of this function is to create simple EINAPPs
311 :     *When differentiation is involved the deltas in a probed field are rewritten
312 :     *and there is a multiplication by P for each index.
313 :     *This function lifts the probed field out and multiplies it's replacment tensor with the Ps
314 :     *The probed field is rewritten with the new indices then it is cleaned with "split.lift".
315 :     *The result is two einapps that hopefully produce less loops for mid-to-low.sml
316 :     *)
317 :     fun geT(sx,index,dim,dx,body,params,args,y)=let
318 :     val Pid=length(args)+1
319 :     val freshIndex=getsumshift(sx,index)
320 :     val (newdx,newsx1,Ps)=transformToIndexSpace(freshIndex,dim,dx,Pid)
321 :     val (newsx2,e')= mkbody(newdx,body)
322 :     val newsx=newsx1@newsx2
323 :     val (Re,Rparams,Rargs,[einappN])=split.lift(e',params,index,newsx,args)
324 :     val (einappN,Parg)=expandEinOp2 einappN
325 :    
326 :     val body'=E.Sum(newsx,E.Prod(Ps@[Re]))
327 :     val einappO=(y,DstIL.EINAPP(Ein.EIN{params=Rparams@[E.TEN(1,[dim,dim])], index=index, body=body'},Rargs@Parg))
328 :     val code=einappN@[einappO]
329 :     in code
330 :     end
331 :    
332 :    
333 :     (*liftTransform: code->code
334 :     this function is called when we are testing lifting transformations
335 :     analyzes body of ein_exp and sets up the arguments to next function
336 :     Note, I hardcoded the dimension to be 2.
337 :     *)
338 :     fun liftTransform( e as (y, DstIL.EINAPP(Ein.EIN{params, index, body}, args))) = let
339 :     fun printResult code=testp["\nINSIDE PROBEEIN","\nbody",split.printEINAPP e, "\n=>\n",
340 :     (String.concatWith",\t"(List.map split.printEINAPP code))]
341 :     fun rewriteBody b=(case b
342 :     of E.Probe(E.Conv (v,alpha,h,dx), E.Tensor t) =>let
343 :     val code=geT([],index,2,dx,b,params,args,y)
344 :     val _ = printResult code
345 :     in
346 :     code
347 :     end
348 :     | E.Sum(sx,E.Probe(E.Conv (v,alpha,h,dx), E.Tensor t)) =>let
349 :     val code=geT(sx,index,2,dx,b,params,args,y)
350 :     in
351 :     code
352 :     end
353 :     | E.Sum(sx,E.Prod[eps,E.Probe(E.Conv (v,alpha,h,dx), E.Tensor t)]) =>let
354 :     val code=geT(sx,index,2,dx,b,params,args,y)
355 :     in
356 :     code
357 :     end
358 :     | _=> [e]
359 :     (* end case *))
360 :     in
361 :     rewriteBody body
362 :     end
363 :    
364 :     val testlift=0
365 :     fun expandEinOp e=(case testlift
366 :     of 1=>liftTransform e
367 :     | _ =>let
368 :     val code= expandEinOp99 e
369 :     in code
370 :     end
371 :     (*end case*))
372 :    
373 :    
374 :    
375 :     end; (* local *)
376 :    
377 :     end (* local *)

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0