Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] Annotation of /branches/charisee_dev/src/compiler/high-to-mid/ProbeEin.sml
ViewVC logotype

Annotation of /branches/charisee_dev/src/compiler/high-to-mid/ProbeEin.sml

Parent Directory Parent Directory | Revision Log Revision Log


Revision 3312 - (view) (download)

1 : cchiw 2845 (* Expands probe ein
2 : cchiw 2606 *
3 :     * COPYRIGHT (c) 2012 The Diderot Project (http://diderot-language.cs.uchicago.edu)
4 :     * All rights reserved.
5 :     *)
6 :    
7 :     structure ProbeEin = struct
8 :    
9 :     local
10 :    
11 :     structure E = Ein
12 :     structure DstIL = MidIL
13 :     structure DstOp = MidOps
14 : jhr 3060 structure P = Printer
15 :     structure T = TransformEin
16 :     structure MidToS = MidToString
17 : cchiw 2976 structure DstV = DstIL.Var
18 :     structure DstTy = MidILTypes
19 :    
20 : cchiw 2606 in
21 :    
22 : cchiw 2870 (* This file expands probed fields
23 :     * Take a look at ProbeEin tex file for examples
24 :     *Note that the original field is an EIN operator in the form <V_alpha * H^(deltas)>(midIL.var list )
25 :     * Param_ids are used to note the placement of the argument in the midIL.var list
26 :     * Index_ids keep track of the shape of an Image or differentiation.
27 :     * Mu bind Index_id
28 :     * Generally, we will refer to the following
29 :     *dim:dimension of field V
30 :     * s: support of kernel H
31 :     * alpha: The alpha in <V_alpha * H^(deltas)>
32 :     * deltas: The deltas in <V_alpha * H^(deltas)>
33 :     * Vid:param_id for V
34 :     * hid:param_id for H
35 :     * nid: integer position param_id
36 :     * fid :fractional position param_id
37 : cchiw 3166 * img-imginfo about V
38 : cchiw 2870 *)
39 : cchiw 3033
40 : cchiw 2923 val testing=0
41 : cchiw 3094 val testlift=0
42 : cchiw 3312 val detflag =true
43 : cchiw 3307
44 :    
45 : cchiw 2845 val cnt = ref 0
46 :     fun transformToIndexSpace e=T.transformToIndexSpace e
47 :     fun transformToImgSpace e=T.transformToImgSpace e
48 : cchiw 3268 fun toStringBind e=(MidToString.toStringBind e)
49 : cchiw 3260 fun mkEin e=Ein.mkEin e
50 : cchiw 3033 fun mkEinApp(rator,args)=DstIL.EINAPP(rator,args)
51 : cchiw 3048
52 : cchiw 2845 fun testp n=(case testing
53 :     of 0=> 1
54 :     | _ =>(print(String.concat n);1)
55 :     (*end case*))
56 : cchiw 3260
57 :    
58 : cchiw 2845 fun getRHSDst x = (case DstIL.Var.binding x
59 :     of DstIL.VB_RHS(DstIL.OP(rator, args)) => (rator, args)
60 :     | DstIL.VB_RHS(DstIL.VAR x') => getRHSDst x'
61 :     | vb => raise Fail(concat[ "expected rhs operator for ", DstIL.Var.toString x, "but found ", DstIL.vbToString vb])
62 :     (* end case *))
63 : cchiw 2838
64 : cchiw 2606
65 : cchiw 2845 (* getArgsDst:MidIL.Var* MidIL.Var->int, ImageInfo, int
66 :     uses the Param_ids for the image, kernel,
67 :     and position tensor to get the Mid-IL arguments
68 :     returns the support of ther kernel, and image
69 :     *)
70 : jhr 3060 fun getArgsDst(hArg,imgArg,args) = (case (getRHSDst hArg, getRHSDst imgArg)
71 :     of ((DstOp.Kernel(h, i), _ ), (DstOp.LoadImage(_, _, img), _ ))=> let
72 : cchiw 2845 in
73 : jhr 3060 ((Kernel.support h) ,img,ImageInfo.dim img)
74 : cchiw 2845 end
75 :     | _ => raise Fail "Expected Image and kernel arguments"
76 :     (*end case*))
77 : cchiw 2606
78 :    
79 : cchiw 2845 (*handleArgs():int*int*int*Mid IL.Var list
80 :     ->int*Mid.ILVars list* code*int* low-il-var
81 :     * uses the Param_ids for the image, kernel, and tensor
82 :     * and gets the mid-IL vars for each.
83 :     *Transforms the position to index space
84 :     *P is the mid-il var for the (transformation matrix)transpose
85 :     *)
86 :     fun handleArgs(Vid,hid,tid,args)=let
87 :     val imgArg=List.nth(args,Vid)
88 :     val hArg=List.nth(args,hid)
89 :     val newposArg=List.nth(args,tid)
90 :     val (s,img,dim) =getArgsDst(hArg,imgArg,args)
91 :     val (argsT,P,code)=transformToImgSpace(dim,img,newposArg,imgArg)
92 : cchiw 2606 in
93 : cchiw 2845 (dim,args@argsT,code, s,P)
94 : cchiw 2606 end
95 : cchiw 2838
96 : cchiw 2845 (*createBody:int*int*int,mu list, param_id, param_id, param_id, param_id
97 :     * expands the body for the probed field
98 :     *)
99 :     fun createBody(dim, s,sx,alpha,deltas,Vid, hid, nid, fid)=let
100 :     (*1-d fields*)
101 :     fun createKRND1 ()=let
102 :     val sum=sx
103 :     val dels=List.map (fn e=>(E.C 0,e)) deltas
104 :     val pos=[E.Add[E.Tensor(fid,[]),E.Value(sum)]]
105 :     val rest= E.Krn(hid,dels,E.Sub(E.Tensor(nid,[]),E.Value(sum)))
106 :     in
107 :     E.Prod [E.Img(Vid,alpha,pos),rest]
108 : cchiw 2843 end
109 : cchiw 2845 (*createKRN Image field and kernels *)
110 :     fun createKRN(0,imgpos,rest)=E.Prod ([E.Img(Vid,alpha,imgpos)] @rest)
111 :     | createKRN(dim,imgpos,rest)=let
112 :     val dim'=dim-1
113 :     val sum=sx+dim'
114 :     val dels=List.map (fn e=>(E.C dim',e)) deltas
115 :     val pos=[E.Add[E.Tensor(fid,[E.C dim']),E.Value(sum)]]
116 :     val rest'= E.Krn(hid,dels,E.Sub(E.Tensor(nid,[E.C dim']),E.Value(sum)))
117 :     in
118 :     createKRN(dim',pos@imgpos,[rest']@rest)
119 :     end
120 :     val exp=(case dim
121 :     of 1 => createKRND1()
122 :     | _=> createKRN(dim, [],[])
123 :     (*end case*))
124 :     (*sumIndex creating summaiton Index for body*)
125 :     val slb=1-s
126 :     val esum=List.tabulate(dim, (fn dim=>(E.V (dim+sx),slb,s)))
127 : cchiw 2843 in
128 : cchiw 2845 E.Sum(esum, exp)
129 : cchiw 2606 end
130 :    
131 : cchiw 2845 (*getsumshift:sum_indexid list* int list-> int
132 :     *get fresh/unused index_id, returns int
133 :     *)
134 :     fun getsumshift(sx,index) =let
135 :     val nsumshift= (case sx
136 :     of []=> length(index)
137 :     | _=>let
138 :     val (E.V v,_,_)=List.hd(List.rev sx)
139 :     in v+1
140 :     end
141 :     (* end case *))
142 :     val aa=List.map (fn (E.V v,_,_)=>Int.toString v) sx
143 : cchiw 3267 val _ =testp["\n", "SumIndex" ,(String.concatWith"," aa),
144 :     "\nThink nshift is ", Int.toString nsumshift]
145 : cchiw 2845 in
146 :     nsumshift
147 :     end
148 : cchiw 2611
149 : cchiw 2845 (*formBody:ein_exp->ein_exp
150 :     *just does a quick rewrite
151 :     *)
152 :     fun formBody(E.Sum([],e))=formBody e
153 :     | formBody(E.Sum(sx,e))= E.Sum(sx,formBody e)
154 :     | formBody(E.Prod [e])=e
155 :     | formBody e=e
156 : cchiw 2606
157 : cchiw 2976 (* silly change in order of the product to match vis branch WorldtoSpace functions*)
158 :     fun multiPs([P0,P1,P2],sx,body)= formBody(E.Sum(sx, E.Prod([P0,P1,P2,body])))
159 :     | multiPs(Ps,sx,body)=formBody(E.Sum(sx, E.Prod([body]@Ps)))
160 : cchiw 3195
161 : cchiw 2976
162 : cchiw 3195 fun multiMergePs([P0,P1],[sx0,sx1],body)=E.Sum([sx0],E.Prod[P0,E.Sum([sx1],E.Prod[P1,body])])
163 :     | multiMergePs e=multiPs e
164 :    
165 :    
166 : cchiw 2845 (* replaceProbe:ein_exp* params *midIL.var list * int list* sum_id list
167 :     -> ein_exp* *code
168 :     * Transforms position to world space
169 :     * transforms result back to index_space
170 :     * rewrites body
171 :     * replace probe with expanded version
172 :     *)
173 : cchiw 3048 (* fun replaceProbe(testN,y,originalb,b,params,args,index, sx)*)
174 :    
175 :     fun replaceProbe(testN,(y, DstIL.EINAPP(e,args)),p ,sx)
176 :     =let
177 :     val originalb=Ein.body e
178 :     val params=Ein.params e
179 :     val index=Ein.index e
180 : cchiw 3267 val _ = testp["\n***************** \n Replace ************ \n"]
181 : cchiw 3260 val _= toStringBind (y, DstIL.EINAPP(e,args))
182 : cchiw 3048
183 :     val E.Probe(E.Conv(Vid,alpha,hid,dx),E.Tensor(tid,_))=p
184 : cchiw 2845 val fid=length(params)
185 :     val nid=fid+1
186 :     val Pid=nid+1
187 :     val nshift=length(dx)
188 :     val (dim,argsA,code,s,PArg) = handleArgs(Vid,hid,tid,args)
189 :     val freshIndex=getsumshift(sx,index)
190 :     val (dx,newsx1,Ps)=transformToIndexSpace(freshIndex,dim,dx,Pid)
191 :     val params'=params@[E.TEN(3,[dim]),E.TEN(1,[dim]),E.TEN(1,[dim,dim])]
192 :     val body' = createBody(dim, s,freshIndex+nshift,alpha,dx,Vid, hid, nid, fid)
193 : cchiw 2976 val body' = multiPs(Ps,newsx1,body')
194 : cchiw 3033
195 :     val body'=(case originalb
196 :     of E.Sum(sx, E.Probe _) => E.Sum(sx,body')
197 :     | E.Sum(sx,E.Prod[eps0,E.Probe _ ]) => E.Sum(sx,E.Prod[eps0,body'])
198 :     | _ => body'
199 :     (*end case*))
200 : cchiw 3260
201 : cchiw 3033
202 : cchiw 2845 val args'=argsA@[PArg]
203 : cchiw 3033 val einapp=(y,mkEinApp(mkEin(params',index,body'),args'))
204 :     in
205 :     code@[einapp]
206 : cchiw 2845 end
207 : cchiw 2976
208 : cchiw 3195 val tsplitvar=true
209 : cchiw 3048 fun createEinApp(originalb,alpha,index,freshIndex,dim,dx,sx)= let
210 : cchiw 2976 val Pid=0
211 :     val tid=1
212 : cchiw 3260
213 :     (*Assumes body is already clean*)
214 : cchiw 3048 val (newdx,newsx,Ps)=transformToIndexSpace(freshIndex,dim,dx,Pid)
215 :    
216 :     (*need to rewrite dx*)
217 : cchiw 3260 val (_,sizes,e as E.Conv(_,alpha',_,dx))=(case sx@newsx
218 : cchiw 3048 of []=> ([],index,E.Conv(9,alpha,7,newdx))
219 : cchiw 3260 | _ => cleanIndex.cleanIndex(E.Conv(9,alpha,7,newdx),index,sx@newsx)
220 : cchiw 3048 (*end case*))
221 :    
222 :     val params=[E.TEN(1,[dim,dim]),E.TEN(1,sizes)]
223 : cchiw 3260 fun filterAlpha []=[]
224 :     | filterAlpha(E.C _::es)= filterAlpha es
225 :     | filterAlpha(e1::es)=[e1]@(filterAlpha es)
226 :    
227 :     val tshape=filterAlpha(alpha')@newdx
228 : cchiw 3033 val t=E.Tensor(tid,tshape)
229 : cchiw 3195 val (splitvar,body)=(case originalb
230 :     of E.Sum(sx, E.Probe _) => (false,E.Sum(sx,multiPs(Ps,newsx,t)))
231 :     | E.Sum(sx,E.Prod[eps0,E.Probe _ ]) => (false,E.Sum(sx,E.Prod[eps0,multiPs(Ps,newsx,t)]))
232 :     | _ => (case tsplitvar
233 : cchiw 3196 of(* true => (true,multiMergePs(Ps,newsx,t)) (*pushes summations in place*)
234 : cchiw 3259 | false*) _ => (true,multiPs(Ps,newsx,t))
235 : cchiw 3195 (*end case*))
236 : cchiw 3048 (*end case*))
237 :    
238 :     val ein0=mkEin(params,index,body)
239 : cchiw 2976 in
240 : cchiw 3260 (splitvar,ein0,sizes,dx,alpha')
241 : cchiw 2976 end
242 : cchiw 3048
243 : cchiw 3260 fun liftProbe(printStrings,(y, DstIL.EINAPP(e,args)),p ,sx)=let
244 : cchiw 3267 val _=testp["\n******* Lift ******** \n"]
245 : cchiw 3048 val originalb=Ein.body e
246 :     val params=Ein.params e
247 : cchiw 3189 val index=Ein.index e
248 : cchiw 3260 val _= toStringBind (y, DstIL.EINAPP(e,args))
249 : cchiw 2976
250 : cchiw 3048 val E.Probe(E.Conv(Vid,alpha,hid,dx),E.Tensor(tid,_))=p
251 : cchiw 2976 val fid=length(params)
252 :     val nid=fid+1
253 :     val nshift=length(dx)
254 : cchiw 3048 val (dim,args',code,s,PArg) = handleArgs(Vid,hid,tid,args)
255 : cchiw 2976 val freshIndex=getsumshift(sx,index)
256 :    
257 :     (*transform T*P*P..Ps*)
258 : cchiw 3260 val (splitvar,ein0,sizes,dx,alpha')= createEinApp(originalb,alpha,index,freshIndex,dim,dx,sx)
259 : cchiw 3048 val FArg = DstV.new ("F", DstTy.TensorTy(sizes))
260 :     val einApp0=mkEinApp(ein0,[PArg,FArg])
261 : cchiw 3195 val rtn0=(case splitvar
262 :     of false => [(y,einApp0)]
263 :     | _ => Split.splitEinApp (y,einApp0)
264 :     (*end case*))
265 : cchiw 2976
266 :     (*lifted probe*)
267 :     val params'=params@[E.TEN(3,[dim]),E.TEN(1,[dim])]
268 : cchiw 3260 val body' = createBody(dim, s,freshIndex+nshift,alpha',dx,Vid, hid, nid, fid)
269 : cchiw 3033 val ein1=mkEin(params',sizes,body')
270 : cchiw 2976 val einApp1=mkEinApp(ein1,args')
271 : cchiw 3048 val rtn1=(FArg,einApp1)
272 : cchiw 3195 val rtn=code@[rtn1]@rtn0
273 : cchiw 3260 val _= List.map toStringBind ([rtn1]@rtn0)
274 : cchiw 3259
275 : cchiw 2976 in
276 :     rtn
277 :     end
278 :    
279 : cchiw 2845 (* expandEinOp: code-> code list
280 : cchiw 3259 * A this point we only have simple ein ops
281 :     * Looks to see if the expression has a probe. If so, replaces it.
282 : cchiw 2845 * Note how we keeps eps expressions so only generate pieces that are used
283 :     *)
284 : cchiw 3229 fun expandEinOp( e as (y, DstIL.EINAPP(ein,args)),fieldset)=let
285 : cchiw 3261
286 :     fun checkConst ([],a) = liftProbe a
287 : cchiw 3260 | checkConst ((E.C _::_),a) = replaceProbe a
288 : cchiw 3166 | checkConst ((_ ::es),a)= checkConst(es,a)
289 : cchiw 3260
290 :     fun liftFieldMat(newvx,E.Probe(E.Conv(V,[E.C c1,E.V 0],h,dx),pos))=
291 :     let
292 :    
293 :     val _= toStringBind e
294 :     val index0=Ein.index ein
295 :     val index1 = index0@[3]
296 :     val body1_unshifted = E.Probe(E.Conv(V,[E.V newvx, E.V 0],h,dx),pos)
297 :     (* clean to get body indices in order *)
298 :     val ( _ , _, body1)= cleanIndex.cleanIndex(body1_unshifted,index1,[])
299 : cchiw 3267 val _ = testp ["\n Shifted ",P.printbody body1_unshifted,"=>",P.printbody body1]
300 : cchiw 3260
301 :     val lhs1=DstV.new ("L", DstTy.TensorTy(index1))
302 :     val ein1 = mkEin(Ein.params ein,index1,body1)
303 :     val code1= (lhs1,mkEinApp(ein1,args))
304 : cchiw 3262 val codeAll= (case dx
305 :     of []=> replaceProbe(1,code1,body1,[])
306 :     | _ =>liftProbe(1,code1,body1,[])
307 :     (*end case*))
308 : cchiw 3260
309 :     (*Probe that tensor at a constant position E.C c1*)
310 :     val param0 = [E.TEN(1,index1)]
311 : cchiw 3261 val nx=List.tabulate(length(dx)+1,fn n=>E.V n)
312 :     val body0 = E.Tensor(0,[E.C c1]@nx)
313 : cchiw 3260 val ein0 = mkEin(param0,index0,body0)
314 :     val einApp0 = mkEinApp(ein0,[lhs1])
315 :     val code0 = (y,einApp0)
316 :     val _= toStringBind code0
317 :     in
318 :     codeAll@[code0]
319 :     end
320 : cchiw 3259
321 : cchiw 3307 fun rewriteBody b=(case (detflag,b)
322 :     of (true,E.Probe(E.Conv(_,[E.C _ ,E.V 0],_,[]),pos))
323 : cchiw 3262 => liftFieldMat (1,b)
324 : cchiw 3307 | (true,E.Probe(E.Conv(_,[E.C _ ,E.V 0],_,[E.V 1]),pos))
325 : cchiw 3260 => liftFieldMat (2,b)
326 : cchiw 3307 | (true,E.Probe(E.Conv(_,[E.C _ ,E.V 0],_,[E.V 1,E.V 2] ),pos))
327 : cchiw 3260 => liftFieldMat (3,b)
328 : cchiw 3307 | (_,E.Probe(E.Conv(_,_,_,[]),_))
329 : cchiw 3262 => replaceProbe(0,e,b,[])
330 : cchiw 3307 | (_,E.Probe(E.Conv (_,alpha,_,dx),_))
331 : cchiw 3259 => checkConst(dx,(0,e,b,[])) (*scans dx for contant*)
332 : cchiw 3307 | (_,E.Sum(sx,p as E.Probe(E.Conv(_,_,_,[]),_)))
333 : cchiw 3260 => replaceProbe(0,e,p, sx) (*no dx*)
334 : cchiw 3307 | (_,E.Sum(sx,p as E.Probe(E.Conv(_,[],_,dx),_)))
335 : cchiw 3260 => checkConst(dx,(0,e,p,sx)) (*scalar field*)
336 : cchiw 3307 | (_,E.Sum(sx,E.Probe p))
337 : cchiw 3094 => replaceProbe(0,e,E.Probe p, sx)
338 : cchiw 3307 | (_,E.Sum(sx,E.Prod[eps,E.Probe p]))
339 : cchiw 3094 => replaceProbe(0,e,E.Probe p,sx)
340 : cchiw 3307 | (_,_) => [e]
341 : cchiw 2845 (* end case *))
342 : cchiw 3174
343 :     val (fieldset,var) = einSet.rtnVar(fieldset,y,DstIL.EINAPP(ein,args))
344 : cchiw 3271
345 :     fun matchField b=(case b
346 :     of E.Probe _ => 1
347 :     | E.Sum (_, E.Probe _)=>1
348 :     | E.Sum(_, E.Prod[ _ ,E.Probe _])=>1
349 :     | _ =>0
350 :     (*end case*))
351 :    
352 : cchiw 3174 in (case var
353 : cchiw 3271 of NONE=> (("\n \n mapp_not_replacing:"^(P.printerE ein)^":");(rewriteBody(Ein.body ein),fieldset,matchField(Ein.body ein),0))
354 :     | SOME v=> (("\n mapp_replacing"^(P.printerE ein)^":");( [(y,DstIL.VAR v)],fieldset, matchField(Ein.body ein),1))
355 : cchiw 3174 (*end case*))
356 : cchiw 2845 end
357 : cchiw 2843
358 : cchiw 2606 end; (* local *)
359 :    
360 : cchiw 2845 end (* local *)

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0