Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] Annotation of /branches/charisee_dev/src/compiler/high-to-mid/ProbeEin.sml
ViewVC logotype

Annotation of /branches/charisee_dev/src/compiler/high-to-mid/ProbeEin.sml

Parent Directory Parent Directory | Revision Log Revision Log


Revision 3730 - (view) (download)

1 : cchiw 2845 (* Expands probe ein
2 : cchiw 2606 *
3 : jhr 3349 * This code is part of the Diderot Project (http://diderot-language.cs.uchicago.edu)
4 :     *
5 :     * COPYRIGHT (c) 2015 The University of Chicago
6 : cchiw 2606 * All rights reserved.
7 :     *)
8 :    
9 :     structure ProbeEin = struct
10 :    
11 :     local
12 :    
13 :     structure E = Ein
14 :     structure DstIL = MidIL
15 :     structure DstOp = MidOps
16 : jhr 3060 structure P = Printer
17 :     structure T = TransformEin
18 :     structure MidToS = MidToString
19 : cchiw 2976 structure DstV = DstIL.Var
20 :     structure DstTy = MidILTypes
21 :    
22 : cchiw 2606 in
23 :    
24 : cchiw 2870 (* This file expands probed fields
25 :     * Take a look at ProbeEin tex file for examples
26 :     *Note that the original field is an EIN operator in the form <V_alpha * H^(deltas)>(midIL.var list )
27 :     * Param_ids are used to note the placement of the argument in the midIL.var list
28 :     * Index_ids keep track of the shape of an Image or differentiation.
29 :     * Mu bind Index_id
30 :     * Generally, we will refer to the following
31 :     *dim:dimension of field V
32 :     * s: support of kernel H
33 :     * alpha: The alpha in <V_alpha * H^(deltas)>
34 :     * deltas: The deltas in <V_alpha * H^(deltas)>
35 :     * Vid:param_id for V
36 :     * hid:param_id for H
37 :     * nid: integer position param_id
38 :     * fid :fractional position param_id
39 : cchiw 3166 * img-imginfo about V
40 : cchiw 2870 *)
41 : cchiw 3033
42 : cchiw 2923 val testing=0
43 : cchiw 3672 val valnumflag= true
44 :     val tsplitvar = true
45 : cchiw 3679 val fieldliftflag= true
46 :     val constflag = true
47 : cchiw 3672 val detflag = true
48 :     val detsumflag= true
49 : cchiw 3503 fun incUse (DstIL.V{useCnt, ...}) = (useCnt := !useCnt + 1)
50 :     fun decUse (DstIL.V{useCnt, ...}) = (useCnt := !useCnt - 1)
51 : cchiw 3307
52 : cchiw 3677 val liftimgflag = false
53 : cchiw 3730 val pullKrn= true
54 : cchiw 3540
55 : cchiw 2845 val cnt = ref 0
56 :     fun transformToIndexSpace e=T.transformToIndexSpace e
57 :     fun transformToImgSpace e=T.transformToImgSpace e
58 : cchiw 3557 fun transformToImgSpaceF e=T.transformToImgSpaceF e
59 : cchiw 3268 fun toStringBind e=(MidToString.toStringBind e)
60 : cchiw 3686 fun toStringBindp e=print(MidToString.toStringBind e)
61 : cchiw 3260 fun mkEin e=Ein.mkEin e
62 : cchiw 3033 fun mkEinApp(rator,args)=DstIL.EINAPP(rator,args)
63 : cchiw 3441 fun setConst e = E.setConst e
64 :     fun setNeg e = E.setNeg e
65 :     fun setExp e = E.setExp e
66 :     fun setDiv e= E.setDiv e
67 :     fun setSub e= E.setSub e
68 :     fun setProd e= E.setProd e
69 :     fun setAdd e= E.setAdd e
70 : cchiw 3472 fun mkCx es =List.map (fn c => E.C (c,true)) es
71 :     fun mkCxSingle c = E.C (c,true)
72 : cchiw 3441
73 : cchiw 2845 fun testp n=(case testing
74 :     of 0=> 1
75 : cchiw 3472 | _ =>(print(String.concat n);1)
76 : cchiw 2845 (*end case*))
77 : cchiw 3260
78 :    
79 : cchiw 2845 fun getRHSDst x = (case DstIL.Var.binding x
80 :     of DstIL.VB_RHS(DstIL.OP(rator, args)) => (rator, args)
81 :     | DstIL.VB_RHS(DstIL.VAR x') => getRHSDst x'
82 :     | vb => raise Fail(concat[ "expected rhs operator for ", DstIL.Var.toString x, "but found ", DstIL.vbToString vb])
83 :     (* end case *))
84 : cchiw 2838
85 : cchiw 2606
86 : cchiw 2845 (* getArgsDst:MidIL.Var* MidIL.Var->int, ImageInfo, int
87 :     uses the Param_ids for the image, kernel,
88 :     and position tensor to get the Mid-IL arguments
89 :     returns the support of ther kernel, and image
90 :     *)
91 : jhr 3060 fun getArgsDst(hArg,imgArg,args) = (case (getRHSDst hArg, getRHSDst imgArg)
92 : cchiw 3531 of ((DstOp.Kernel(h, i), _ ), (DstOp.LoadImage(_, _, img), _ ))=> ((Kernel.support h) ,img,ImageInfo.dim img)
93 :     | ((k,_),(i,_)) => raise Fail (String.concat["Expected kernel:", (DstOp.toString k ),"Expected Image:", (DstOp.toString i)])
94 : cchiw 2845 (*end case*))
95 : cchiw 2606
96 :    
97 : cchiw 2845 (*handleArgs():int*int*int*Mid IL.Var list
98 :     ->int*Mid.ILVars list* code*int* low-il-var
99 :     * uses the Param_ids for the image, kernel, and tensor
100 :     * and gets the mid-IL vars for each.
101 :     *Transforms the position to index space
102 :     *P is the mid-il var for the (transformation matrix)transpose
103 :     *)
104 :     fun handleArgs(Vid,hid,tid,args)=let
105 :     val imgArg=List.nth(args,Vid)
106 :     val hArg=List.nth(args,hid)
107 :     val newposArg=List.nth(args,tid)
108 :     val (s,img,dim) =getArgsDst(hArg,imgArg,args)
109 :     val (argsT,P,code)=transformToImgSpace(dim,img,newposArg,imgArg)
110 : cchiw 2606 in
111 : cchiw 2845 (dim,args@argsT,code, s,P)
112 : cchiw 2606 end
113 : cchiw 3531
114 : cchiw 3557 fun handleArgsF(fieldset,Vid,hid,tid,args)=let
115 :     val imgArg=List.nth(args,Vid)
116 :     val hArg=List.nth(args,hid)
117 :     val newposArg=List.nth(args,tid)
118 :     val (s,img,dim) =getArgsDst(hArg,imgArg,args)
119 :     val (fieldset,argsT,P,code)=transformToImgSpaceF(fieldset,dim,img,newposArg,imgArg)
120 :     in
121 :     (fieldset,dim,args@argsT,code, s,P)
122 :     end
123 :    
124 : cchiw 2838
125 : cchiw 3540
126 :    
127 : cchiw 2845 (*createBody:int*int*int,mu list, param_id, param_id, param_id, param_id
128 :     * expands the body for the probed field
129 :     *)
130 :     fun createBody(dim, s,sx,alpha,deltas,Vid, hid, nid, fid)=let
131 :     (*1-d fields*)
132 :     fun createKRND1 ()=let
133 :     val sum=sx
134 : cchiw 3472 val dels=List.map (fn e=>(mkCxSingle 0,e)) deltas
135 : cchiw 3441 val pos=[setAdd[E.Tensor(fid,[]),E.Value(sum)]]
136 :     val rest= E.Krn(hid,dels,setSub(E.Tensor(nid,[]),E.Value(sum)))
137 : cchiw 2845 in
138 : cchiw 3441 setProd[E.Img(Vid,alpha,pos),rest]
139 : cchiw 2843 end
140 : cchiw 3531
141 :     fun mkImg(imgpos)=E.Img(Vid,alpha,imgpos)
142 :    
143 : cchiw 2845 (*createKRN Image field and kernels *)
144 : cchiw 3531 fun createKRN(0,imgpos,rest)=setProd ([mkImg(imgpos)] @rest)
145 : cchiw 2845 | createKRN(dim,imgpos,rest)=let
146 :     val dim'=dim-1
147 :     val sum=sx+dim'
148 : cchiw 3472 val dels=List.map (fn e=>(mkCxSingle dim',e)) deltas
149 :     val pos=[setAdd[E.Tensor(fid,[mkCxSingle dim']),E.Value(sum)]]
150 :     val rest'= E.Krn(hid,dels,setSub(E.Tensor(nid,[mkCxSingle dim']),E.Value(sum)))
151 : cchiw 2845 in
152 :     createKRN(dim',pos@imgpos,[rest']@rest)
153 :     end
154 :     val exp=(case dim
155 :     of 1 => createKRND1()
156 :     | _=> createKRN(dim, [],[])
157 :     (*end case*))
158 :     (*sumIndex creating summaiton Index for body*)
159 :     val slb=1-s
160 : cchiw 3383 val _=List.tabulate(dim, (fn dim=> (String.concat[" sx:",Int.toString(sx)," dim:",Int.toString(dim),"esum",Int.toString(sx+dim) ]) ))
161 : cchiw 2845 val esum=List.tabulate(dim, (fn dim=>(E.V (dim+sx),slb,s)))
162 : cchiw 2843 in
163 : cchiw 2845 E.Sum(esum, exp)
164 : cchiw 2606 end
165 : cchiw 3540
166 : cchiw 2606
167 : cchiw 3540
168 :     (* build position *)
169 :     fun buildPos (dir,dim,argsA,hid,nid,s) =let
170 :     val vA = DstV.new ("kernel_pos", DstTy.TensorTy([]))
171 : cchiw 3730 val a=[List.nth(argsA,hid),List.nth(argsA,nid)]
172 :     (*
173 : cchiw 3540 val p=[E.KRN,E.TEN(1,[dim])]
174 :     val pos=setSub(E.Tensor(1,[mkCxSingle dir]),E.Value(0))
175 :     val exp= E.BuildPos(s,pos)
176 : cchiw 3730 val A=(vA,mkEinApp(mkEin(p,[],exp),a))*)
177 :    
178 :     val A= (vA,DstOp.OP(Op.BuildPos(s, dir),a))
179 : cchiw 3540 in (vA,A) end
180 :    
181 :     (* apply differentiation *)
182 :     fun getKrn1Del(dx,dim,args,slb,s)= let
183 :     val n=Int.toString(dx)
184 :     val vA = DstV.new ("kernel_del"^n, DstTy.TensorTy([]))
185 :     val p=[E.KRN,E.TEN(1,[dim])]
186 :     val exp = E.EvalKrn dx
187 :     val A = (vA,mkEinApp(mkEin(p,[],exp),args))
188 :     in (vA,A) end
189 :    
190 :     (*create holder expression*)
191 :     fun mkHolder(dim,args) =let
192 :     val n=List.length(args)
193 :     val vA = DstV.new ("kernel_cons", DstTy.TensorTy([n]))
194 :     val p=[E.KRN,E.TEN(1,[dim])]
195 :     val A= (vA,mkEinApp(mkEin(p,[],E.Holder n),args))
196 :     in (vA,A) end
197 :    
198 :     (*lifted Kernel expressions*)
199 :     fun liftKrn(dx,dir,dim,argsA,hid,nid,slb,s)=let
200 :     val (vA,A)=buildPos(dir,dim,argsA,hid,nid,s)
201 :     val args=[List.nth(argsA,hid),vA]
202 :     fun iter(0,vBs,Bs)=let
203 :     val (vA,A)=getKrn1Del(0,dim,args,slb,s)
204 :     in (vA::vBs,A::Bs) end
205 :     | iter (n,vBs,Bs)= let
206 :     val (vA,A)=getKrn1Del(n,dim,args,slb,s)
207 :     in iter(n-1,vA::vBs,A::Bs) end
208 :     val (vBs,Bs)=iter(length(dx),[],[])
209 :     val (vC,C) =mkHolder(dim,vBs)
210 :     in (vC,(A::Bs)@[C]) end
211 :    
212 :    
213 :    
214 :     fun createBody2(dim, s,sx,alpha,deltas,Vid, hid, nid, fid,params,argsA)=let
215 :     (*1-d fields*)
216 :     val slb=1-s
217 :    
218 :    
219 :     (*making image*)
220 :     val tid=(case liftimgflag
221 :     of true => length(params)-1
222 :     | _ => length(params)-1
223 :     (*end case*))
224 :     fun mkImg imgpos =(case liftimgflag
225 :     of true=>(E.Tensor(Vid,alpha),SOME(E.Sum(List.tabulate(dim, (fn dim=>(E.V (dim),slb,s))),E.Img(Vid,alpha,imgpos))))
226 :     | _ =>let
227 :     val imgpos= List.tabulate(dim,fn e=> setAdd[E.Tensor(fid,[mkCxSingle e]),E.Value(e+sx)])
228 :     in (E.Img(Vid,alpha,imgpos),NONE) end
229 :     (*end case*))
230 :    
231 :     fun createKRND1 ()=let
232 :     val sum=sx
233 :     val dels=List.map (fn e=>(mkCxSingle 0,e)) deltas
234 :     val imgpos=[setAdd[E.Tensor(fid,[]),E.Value(sum)]]
235 :     val rest= E.Krn(hid,dels,setSub(E.Tensor(nid,[]),E.Value(sum)))
236 :     val (talpha,iexp)= mkImg imgpos
237 :     in (setProd[talpha,rest],iexp,NONE,NONE)end
238 :    
239 :     (*createKRN Image field and kernels *)
240 :     fun createKRN(0,orig,imgpos,vAs,krnpos)= let
241 :     val (talpha,iexp)= mkImg imgpos
242 :     in (setProd ([talpha]@orig),iexp,SOME vAs,SOME krnpos) end
243 :     | createKRN(d,orig,imgpos,vAs,krnpos)=let
244 :     val dim'=d-1
245 :     val sum=sx+dim'
246 :     val dels=List.map (fn e=>(mkCxSingle dim',e)) deltas
247 :     val ipos=setAdd[E.Tensor(fid,[mkCxSingle dim']),E.Value(dim')]
248 :     val opos= E.Krn(hid,dels,E.Tensor(tid+d,[]))
249 :     val (vA,A)= liftKrn(dels,dim',dim,argsA,hid,nid,slb,s)
250 :     in
251 :     createKRN(dim',[opos]@orig,[ipos]@imgpos,[vA]@vAs,A@krnpos)
252 :     end
253 :    
254 :     val (oexp,iexp,vAs,keinapp)=(case dim
255 :     of 1 => createKRND1()
256 :     | _=> createKRN(dim, [],[],[],[])
257 :     (*end case*))
258 :    
259 :     val oexp=E.Sum(List.tabulate(dim, (fn dim=>(E.V (dim+sx),slb,s))), oexp)
260 :     in (oexp,iexp,vAs,keinapp) end
261 :    
262 :     fun createBody3(dim, s,sx,[],deltas,Vid, hid, nid, fid,params,argsA)=
263 :     createBody2(dim, s,sx,[],deltas,Vid, hid, nid, fid,params,argsA)
264 :     | createBody3(dim, s,sx,alpha,deltas,Vid, hid, nid, fid,params,argsA)=
265 :     (createBody(dim, s,sx,alpha,deltas,Vid, hid, nid, fid),NONE,NONE,NONE)
266 :    
267 : cchiw 2845 (*getsumshift:sum_indexid list* int list-> int
268 : cchiw 3540 *get fresh/unused index_id, returns int
269 : cchiw 2845 *)
270 : cchiw 3383 fun getsumshift(sx,n) =let
271 : cchiw 2845 val nsumshift= (case sx
272 : cchiw 3383 of []=> n
273 : cchiw 2845 | _=>let
274 :     val (E.V v,_,_)=List.hd(List.rev sx)
275 :     in v+1
276 :     end
277 :     (* end case *))
278 : cchiw 3383
279 : cchiw 2845 val aa=List.map (fn (E.V v,_,_)=>Int.toString v) sx
280 : cchiw 3383 val _ =(String.concat["\n", "SumIndex:" ,(String.concatWith"," aa),
281 :     "\n\t Index length:",Int.toString n,
282 :     "\n\t Freshindex: ", Int.toString nsumshift])
283 : cchiw 2845 in
284 :     nsumshift
285 : cchiw 3540 end
286 : cchiw 2611
287 : cchiw 2845 (*formBody:ein_exp->ein_exp
288 :     *just does a quick rewrite
289 :     *)
290 :     fun formBody(E.Sum([],e))=formBody e
291 :     | formBody(E.Sum(sx,e))= E.Sum(sx,formBody e)
292 : cchiw 3441 | formBody(E.Opn(E.Prod, [e]))=e
293 : cchiw 2845 | formBody e=e
294 : cchiw 2606
295 : cchiw 2976 (* silly change in order of the product to match vis branch WorldtoSpace functions*)
296 : cchiw 3441 fun multiPs([P0,P1,P2],sx,body)= formBody(E.Sum(sx, setProd[P0,P1,P2,body]))
297 : cchiw 3353 (*
298 : cchiw 3441 | multiPs([P0,P1],sx,body)=formBody(E.Sum(sx, setProd([P0,body,P1])))
299 : cchiw 3353 *)
300 : cchiw 3441 | multiPs([P0,P1,P2,P3],sx,body)= formBody(E.Sum(sx, setProd[P0,P1,P2,P3,body]))
301 :     | multiPs(Ps,sx,body)=formBody(E.Sum(sx,setProd([body]@Ps)))
302 : cchiw 3195
303 : cchiw 2976
304 : cchiw 3441 fun multiMergePs([P0,P1],[sx0,sx1],body)=E.Sum([sx0],setProd[P0,E.Sum([sx1],setProd[P1,body])])
305 : cchiw 3195 | multiMergePs e=multiPs e
306 : cchiw 3540
307 :     (* ******************************************* setImage ******************************************* *)
308 :     fun replaceImgA(es,vid,newbie)=List.take(es,vid)@[newbie]@List.drop(es,vid+1)
309 :     fun setImage(params',argsA,code,vexp2,index,alpha,paraminstant,Vid,s)=
310 :     (case vexp2
311 :     of NONE =>(params',argsA,code)
312 :     | SOME vexp => let
313 :     val iArg = DstV.new ("Img", DstTy.TensorTy([]))
314 :     val alphax=List.map (fn (E.V i)=>List.nth(index,i)) alpha
315 :     val ieinapp=(iArg,mkEinApp(mkEin(paraminstant,alphax,vexp),argsA))
316 :     (*
317 :     val _ =print(String.concat["\n****\n Image (",Int.toString(length(argsA)),")"])
318 :     val _ =print(String.concat(List.map (fn e=> DstIL.Var.toString(e)^",") argsA))
319 :     val _ =print(String.concat["\n replace at ",Int.toString Vid ," with " , DstIL.Var.toString iArg ,"\n"])*)
320 :     val argsA=replaceImgA(argsA,Vid,iArg)
321 :     val params'=replaceImgA(params',Vid,E.TEN(2,[(s-(1-s)+1)*(s-(1-s)+1),(s-(1-s)+1)]))
322 :     val code=code@[ieinapp]
323 :     (*
324 :     val _ =print(String.concat(List.map (fn e=> DstIL.Var.toString(e)^",") argsA))
325 :     val _ =print(String.concat["\n****\n Image(",Int.toString(length(argsA)),")"])*)
326 :     in (params',argsA,code) end
327 :     (*end case*))
328 :    
329 :     (*kernels*)
330 :     fun setKernel(params',args',code,vAs2,keinapp2,dim)=
331 :     (case (vAs2,keinapp2)
332 :     of (NONE,NONE)=> (params',args',code)
333 :     | (SOME vAs,SOME keinapp) => let
334 :     (*
335 :     val _ =print"\n****\n Kernels\n"
336 :     val _ =print(String.concat["\n****\n Kernel(",Int.toString(length(args')),")"])
337 :     val _ =print(String.concat(List.map (fn e=> DstIL.Var.toString(e)^",") args'))*)
338 :     val args'= args'@vAs
339 :     val params'= params'@(List.tabulate(dim,fn _=> E.TEN(2,[])))
340 :     val code=code@keinapp
341 :     (*
342 :     val _ =print"\n"
343 :     val _ =print(String.concat(List.map (fn e=> DstIL.Var.toString(e)^",") args'))
344 :     val _ =print(String.concat["\n****\n Kernel(",Int.toString(length(args')),")"])*)
345 :     in (params',args',code) end
346 :     (*end case*))
347 : cchiw 3195
348 : cchiw 3540 fun setImageKernel(params',args',code,vexp2,vAs2,keinapp2,dim,index,alpha,paraminstant,Vid,s)=let
349 :     val (params',args',code)=setImage(params',args',code,vexp2,index,alpha,paraminstant,Vid,s)
350 :     in setKernel(params',args',code,vAs2,keinapp2,dim) end
351 :    
352 : cchiw 3557
353 : cchiw 3472 (* ******************************************* Replace probe ******************************************* *)
354 :     (* replaceProbe
355 : cchiw 2845 * Transforms position to world space
356 :     * transforms result back to index_space
357 :     * rewrites body
358 :     * replace probe with expanded version
359 :     *)
360 : cchiw 3557 fun replaceProbe0(fieldset,(y, DstIL.EINAPP(e,args)),p ,sx)
361 : cchiw 3048 =let
362 :     val originalb=Ein.body e
363 :     val params=Ein.params e
364 :     val index=Ein.index e
365 : cchiw 3686 val _ =print (String.concat["\n***************** \n Replace ************ \n"])
366 : cchiw 3557 val _= toStringBindp (y, DstIL.EINAPP(e,args))
367 : cchiw 3048
368 :     val E.Probe(E.Conv(Vid,alpha,hid,dx),E.Tensor(tid,_))=p
369 : cchiw 2845 val fid=length(params)
370 :     val nid=fid+1
371 :     val Pid=nid+1
372 :     val nshift=length(dx)
373 :     val (dim,argsA,code,s,PArg) = handleArgs(Vid,hid,tid,args)
374 : cchiw 3383 val freshIndex=getsumshift(sx,length(index))
375 : cchiw 2845 val (dx,newsx1,Ps)=transformToIndexSpace(freshIndex,dim,dx,Pid)
376 :     val params'=params@[E.TEN(3,[dim]),E.TEN(1,[dim]),E.TEN(1,[dim,dim])]
377 :     val body' = createBody(dim, s,freshIndex+nshift,alpha,dx,Vid, hid, nid, fid)
378 : cchiw 2976 val body' = multiPs(Ps,newsx1,body')
379 : cchiw 3033
380 :     val body'=(case originalb
381 :     of E.Sum(sx, E.Probe _) => E.Sum(sx,body')
382 : cchiw 3441 | E.Sum(sx,E.Opn(E.Prod,[eps0,E.Probe _ ])) => E.Sum(sx,setProd[eps0,body'])
383 : cchiw 3033 | _ => body'
384 :     (*end case*))
385 :    
386 : cchiw 2845 val args'=argsA@[PArg]
387 : cchiw 3686
388 : cchiw 3033 val einapp=(y,mkEinApp(mkEin(params',index,body'),args'))
389 : cchiw 3686
390 :     (*
391 :    
392 :     val e2=SummationEin.main (mkEin(params',index,body'))
393 :     val einapp=(y,mkEinApp(e2,args'))
394 :     val _ = print("\n shifted:=>"^P.printerE(e2))
395 :     *)
396 : cchiw 3557 val _= List.map toStringBindp(code@[einapp])
397 : cchiw 3033 in
398 : cchiw 3557 (fieldset,code@[einapp])
399 : cchiw 2845 end
400 : cchiw 2976
401 : cchiw 3540
402 :    
403 : cchiw 3557 fun replaceProbe3(fieldset,(y, DstIL.EINAPP(e,args)),p ,sx) = let
404 : cchiw 3531 val originalb=Ein.body e
405 :     val params=Ein.params e
406 :     val index=Ein.index e
407 :     val _ = testp["\n***************** \n Replace ************ \n"]
408 :     val _= toStringBind (y, DstIL.EINAPP(e,args))
409 :    
410 :     val E.Probe(E.Conv(Vid,alpha,hid,dx),E.Tensor(tid,_))=p
411 :     val fid=length(params)
412 :     val nid=fid+1
413 :     val Pid=nid+1
414 :     val nshift=length(dx)
415 :     val (dim,argsA,code,s,PArg) = handleArgs(Vid,hid,tid,args)
416 :     val freshIndex=getsumshift(sx,length(index))
417 :     val (dx,newsx1,Ps)=transformToIndexSpace(freshIndex,dim,dx,Pid)
418 :    
419 :     val paraminstant=params@[E.TEN(3,[dim]),E.TEN(1,[dim])]
420 :     val params'=paraminstant@[E.TEN(1,[dim,dim])]
421 : cchiw 3540
422 : cchiw 3531
423 : cchiw 3557
424 : cchiw 3540 val (body',vexp2,vAs2,keinapp2) = createBody3(dim, s,freshIndex+nshift,alpha,dx,Vid, hid, nid, fid,paraminstant,argsA)
425 : cchiw 3531 val body' = multiPs(Ps,newsx1,body')
426 : cchiw 3557
427 : cchiw 3531 val body'=(case originalb
428 :     of E.Sum(sx, E.Probe _) => E.Sum(sx,body')
429 :     | E.Sum(sx,E.Opn(E.Prod,[eps0,E.Probe _ ])) => E.Sum(sx,setProd[eps0,body'])
430 :     | _ => body'
431 :     (*end case*))
432 : cchiw 3540
433 :     (*images and kernels*)
434 :     val (params',argsA,code)=setImageKernel(params',argsA,code,vexp2,vAs2,keinapp2,dim,index,alpha,paraminstant,Vid,s)
435 : cchiw 3531
436 :    
437 :     (*replace term*)
438 :     val args'=argsA@[PArg]
439 :     val einapp=(y,mkEinApp(mkEin(params',index,body'),args'))
440 :     val _= List.map toStringBindp(code@[einapp])
441 :     in
442 : cchiw 3557 (fieldset,code@[einapp])
443 : cchiw 3531 end
444 :    
445 :    
446 : cchiw 3672 (*to call from avail-rhs. replacing field in place*)
447 :     fun replaceProbeF(params,index,sx,p,args)= let
448 :     val E.Probe(E.Conv(Vid,alpha,hid,dx),E.Tensor(tid,_))=p
449 :     val fid=length(params)
450 :     val nid=fid+1
451 :     val Pid=nid+1
452 :     val nshift=length(dx)
453 :     val (dim,argsA,code,s,PArg) = handleArgs(Vid,hid,tid,args)
454 :     val freshIndex=getsumshift(sx,length(index))
455 :     val (dx,newsx1,Ps)=transformToIndexSpace(freshIndex,dim,dx,Pid)
456 :     val params'=params@[E.TEN(3,[dim]),E.TEN(1,[dim]),E.TEN(1,[dim,dim])]
457 :     val body' = createBody(dim, s,freshIndex+nshift,alpha,dx,Vid, hid, nid, fid)
458 :     val body' = multiPs(Ps,newsx1,body')
459 :     val args'=argsA@[PArg]
460 :     in
461 :     (params',body',args',code)
462 :     end
463 :    
464 : cchiw 3472 (* ******************************************* Lift probe ******************************************* *)
465 : cchiw 3048 fun createEinApp(originalb,alpha,index,freshIndex,dim,dx,sx)= let
466 : cchiw 2976 val Pid=0
467 :     val tid=1
468 : cchiw 3260
469 :     (*Assumes body is already clean*)
470 : cchiw 3048 val (newdx,newsx,Ps)=transformToIndexSpace(freshIndex,dim,dx,Pid)
471 :    
472 :     (*need to rewrite dx*)
473 : cchiw 3260 val (_,sizes,e as E.Conv(_,alpha',_,dx))=(case sx@newsx
474 : cchiw 3048 of []=> ([],index,E.Conv(9,alpha,7,newdx))
475 : cchiw 3260 | _ => cleanIndex.cleanIndex(E.Conv(9,alpha,7,newdx),index,sx@newsx)
476 : cchiw 3048 (*end case*))
477 :    
478 :     val params=[E.TEN(1,[dim,dim]),E.TEN(1,sizes)]
479 : cchiw 3260 fun filterAlpha []=[]
480 :     | filterAlpha(E.C _::es)= filterAlpha es
481 :     | filterAlpha(e1::es)=[e1]@(filterAlpha es)
482 :    
483 :     val tshape=filterAlpha(alpha')@newdx
484 : cchiw 3033 val t=E.Tensor(tid,tshape)
485 : cchiw 3441
486 : cchiw 3195 val (splitvar,body)=(case originalb
487 : cchiw 3441 of E.Sum(sx, E.Probe _) => (true,multiPs(Ps,sx@newsx,t))
488 :     | E.Sum(sx,E.Opn(E.Prod,[eps0,E.Probe _ ])) => (false,E.Sum(sx,setProd[eps0,multiPs(Ps,newsx,t)]))
489 : cchiw 3195 | _ => (case tsplitvar
490 : cchiw 3441 of(* true => (true,multiMergePs(Ps,newsx,t)) (*pushes summations in place*)
491 : cchiw 3259 | false*) _ => (true,multiPs(Ps,newsx,t))
492 : cchiw 3195 (*end case*))
493 : cchiw 3441 (*end case*))
494 : cchiw 3048
495 : cchiw 3324 val _ =(case splitvar
496 : cchiw 3540 of true=> (String.concat["splitvar is true", P.printbody body])
497 :     | _ => (String.concat["splitvar is false",P.printbody body])
498 : cchiw 3324 (*end case*))
499 :    
500 :    
501 : cchiw 3048 val ein0=mkEin(params,index,body)
502 : cchiw 2976 in
503 : cchiw 3260 (splitvar,ein0,sizes,dx,alpha')
504 : cchiw 2976 end
505 : cchiw 3048
506 : cchiw 3557 fun liftProbe0(fieldset,(y, DstIL.EINAPP(e,args)),p ,sx)=let
507 : cchiw 3686 val _= print(String.concat["\n******* Lift Geneirc Probe ***\n"])
508 : cchiw 3048 val originalb=Ein.body e
509 :     val params=Ein.params e
510 : cchiw 3189 val index=Ein.index e
511 : cchiw 3557 val _ = (toStringBindp (y, DstIL.EINAPP(e,args)))
512 : cchiw 2976
513 : cchiw 3048 val E.Probe(E.Conv(Vid,alpha,hid,dx),E.Tensor(tid,_))=p
514 : cchiw 2976 val fid=length(params)
515 :     val nid=fid+1
516 :     val nshift=length(dx)
517 : cchiw 3557 val (fieldset,dim,args',code,s,PArg) = handleArgsF(fieldset,Vid,hid,tid,args)
518 : cchiw 3383 val freshIndex=getsumshift(sx,length(index))
519 : cchiw 2976
520 :     (*transform T*P*P..Ps*)
521 : cchiw 3260 val (splitvar,ein0,sizes,dx,alpha')= createEinApp(originalb,alpha,index,freshIndex,dim,dx,sx)
522 : cchiw 3557 val FArg = DstV.new ("F", DstTy.TensorTy(sizes))
523 :    
524 :    
525 :    
526 :     (*addedhere*)
527 :     val ein9=mkEin(params,sizes,E.Conv(Vid,alpha',hid,dx))
528 :     val einApp9=mkEinApp(ein9,args)
529 :     val rtn9=(FArg,einApp9)
530 :     val (fieldset,FArg,rtn1)= (case (einVarSet.rtnVarN(fieldset,rtn9))
531 :     of (fieldset,SOME v) => let
532 :     val _ = (" \n did find"^toStringBind(rtn9))
533 :     in (fieldset, v,[]) end
534 :     | (fieldset,NONE) => let
535 :     (*lifted probe*)
536 :     val _ =(" \n did not find"^toStringBind(rtn9))
537 :     val params'=params@[E.TEN(3,[dim]),E.TEN(1,[dim])]
538 :     val freshIndex'= length(sizes)
539 :     val body' = createBody(dim, s,freshIndex',alpha',dx,Vid, hid, nid, fid)
540 :     val ein1=mkEin(params',sizes,body')
541 :     val einApp1=mkEinApp(ein1,args')
542 :     val rtn1=(FArg,einApp1)
543 :     in (fieldset,FArg ,[rtn1]) end
544 :     (*end case*))
545 : cchiw 3441
546 : cchiw 3048 val einApp0=mkEinApp(ein0,[PArg,FArg])
547 : cchiw 3195 val rtn0=(case splitvar
548 : cchiw 3324 of false => [(y,mkEinApp(ein0,[PArg,FArg]))]
549 :     | _ => let
550 : cchiw 3686 val e2=SummationEin.main ein0
551 :     val _ = print("\n shifted:"^P.printerE(ein0)^"=>"^P.printerE(e2))
552 :     val bind3 = (y,DstIL.EINAPP(e2,[PArg,FArg]))
553 : cchiw 3557 in Split.splitEinApp bind3
554 :     end
555 :     (*end case*))
556 :    
557 :     val rtn=code@rtn1@rtn0
558 :     val _= List.map toStringBindp (code@rtn1)
559 :     val _ ="\n**** split code **\n"
560 :     val _= List.map toStringBindp rtn0
561 : cchiw 2976 in
562 : cchiw 3557 (fieldset,rtn)
563 : cchiw 2976 end
564 : cchiw 3540
565 : cchiw 3557 fun liftProbe3(fieldset,(y, DstIL.EINAPP(e,args)),p ,sx)=let
566 : cchiw 3531 val _=testp["\n******* Lift Geneirc Probe ***\n"]
567 :     val originalb=Ein.body e
568 :     val params=Ein.params e
569 :     val index=Ein.index e
570 :     val _ = (toStringBind (y, DstIL.EINAPP(e,args)))
571 :    
572 :     val E.Probe(E.Conv(Vid,alpha,hid,dx),E.Tensor(tid,_))=p
573 :     val fid=length(params)
574 :     val nid=fid+1
575 :     val nshift=length(dx)
576 :     val (dim,args',code,s,PArg) = handleArgs(Vid,hid,tid,args)
577 :     val freshIndex=getsumshift(sx,length(index))
578 :    
579 :     (*transform T*P*P..Ps*)
580 :     val (splitvar,ein0,sizes,dx,alpha')= createEinApp(originalb,alpha,index,freshIndex,dim,dx,sx)
581 :    
582 :     val FArg = DstV.new ("F", DstTy.TensorTy(sizes))
583 :     val einApp0=mkEinApp(ein0,[PArg,FArg])
584 :     val rtn0=(case splitvar
585 : cchiw 3540 of false => [(y,mkEinApp(ein0,[PArg,FArg]))]
586 :     | _ => let
587 :     val bind3 = (y,DstIL.EINAPP(SummationEin.main ein0,[PArg,FArg]))
588 :     in Split.splitEinApp bind3
589 :     end
590 :     (*end case*))
591 : cchiw 3531
592 :     (*lifted probe*)
593 :     val params'=params@[E.TEN(3,[dim]),E.TEN(1,[dim])]
594 :     val freshIndex'= length(sizes)
595 :    
596 :     (*val body' = createBody(dim, s,freshIndex',alpha',dx,Vid, hid, nid, fid)*)
597 : cchiw 3540
598 :     val (body',vexp2,vAs2,keinapp2) = createBody3(dim, s,freshIndex',alpha',dx,Vid, hid, nid, fid,params',args')
599 : cchiw 3531
600 : cchiw 3540 (*set image and kernel*)
601 :     val (params',args',code)=setImageKernel(params',args',code,vexp2,vAs2,keinapp2,dim,index,alpha,params',Vid,s)
602 : cchiw 3531
603 :     val ein1=mkEin(params',sizes,body')
604 :     val einApp1=mkEinApp(ein1,args')
605 :     val rtn1=(FArg,einApp1)
606 :     val rtn=code@[rtn1]@rtn0
607 :     val _= List.map toStringBind ([rtn1]@rtn0)
608 :     val _=(String.concat["\n* end Lift Geneirc Probe ******** \n"])
609 : cchiw 3540 val _= List.map toStringBindp(rtn)
610 : cchiw 3531 in
611 : cchiw 3557 (fieldset,rtn)
612 : cchiw 3531 end
613 : cchiw 3540
614 :     fun replaceProbe e= (case pullKrn
615 :     of true=>replaceProbe3 e
616 :     | false => replaceProbe0 e
617 :     (*end case*))
618 :     fun liftProbe e=(case pullKrn
619 :     of true=>liftProbe3 e
620 :     | false => liftProbe0 e
621 :     (*end case*))
622 :    
623 : cchiw 3531
624 : cchiw 3472 (* ******************************************* Reconstruction -> Lift|Replace probe ******************************************* *)
625 :     (* scans dx for contant
626 :     * arg:(1,code1, body1,[])
627 :     *)
628 :     fun reconstruction([],arg)= replaceProbe arg
629 :     | reconstruction(dx,arg)=(case (constflag,fieldliftflag)
630 :     of (true,true) => liftProbe arg
631 :     | (_,false) => replaceProbe arg
632 :     | _ => let
633 :     fun fConst [] = liftProbe arg
634 :     | fConst (E.C _::_) = replaceProbe arg
635 :     | fConst (_ ::es)= fConst es
636 :     in fConst dx end
637 :     (* end case*))
638 : cchiw 3383
639 : cchiw 3472 (* **************************************************** Index Tensor **************************************************** *)
640 :     (*Push constant indices to tensor replacement*)
641 :     fun getF (e,fieldset,dim,newvx)= let
642 : cchiw 3324 val (y, DstIL.EINAPP(ein,args))=e
643 :     val index0=Ein.index ein
644 : cchiw 3472 val index1 = index0@dim
645 :     val b=Ein.body ein
646 : cchiw 3324
647 : cchiw 3472 val (c1,dx,body1)=(case b
648 :     of E.Sum([(vsum,0,n)],E.Probe(E.Conv(V,[c1,v0],h,dx),pos))=>let
649 :     val shiftdx=List.tabulate(length(dx),fn n=>E.V (n+2))
650 :     val b=E.Probe(E.Conv(V,[E.V 0,E.V 1],h,shiftdx),pos)
651 :     in (c1,dx,b) end
652 :     | E.Probe(E.Conv(V,[c1,v0],h,dx),pos)=> let
653 :     val body1_unshifted= E.Probe(E.Conv(V,[E.V newvx, v0],h,dx),pos)
654 :     (* clean to get body indices in order *)
655 :     val ( _ , _, body1)= cleanIndex.cleanIndex(body1_unshifted,index1,[])
656 :     in (c1,dx,body1) end
657 :     | E.Probe(E.Conv(V,[c1],h,dx),pos)=> let
658 :     val body1_unshifted= E.Probe(E.Conv(V,[E.V newvx],h,dx),pos)
659 :     val ( _ , _, body1)= cleanIndex.cleanIndex(body1_unshifted,index1,[])
660 :     in (c1,dx,body1) end
661 :     (*end case*))
662 :    
663 : cchiw 3324 val lhs1=DstV.new ("L", DstTy.TensorTy(index1))
664 :     val ein1 = mkEin(Ein.params ein,index1,body1)
665 :     val code1= (lhs1,mkEinApp(ein1,args))
666 : cchiw 3472
667 : cchiw 3557 val (lhs0,(fieldset,codeAll))= (case valnumflag
668 :     of false => (lhs1, reconstruction(dx,(fieldset,code1,body1,[])))
669 : cchiw 3503 | true => (case (einVarSet.rtnVarN(fieldset,code1))
670 : cchiw 3557 of (fieldset,NONE) => (lhs1, reconstruction(dx,(fieldset,code1,body1,[])))
671 :     | (fieldset,SOME m) => (m,(fieldset,[]))
672 : cchiw 3472 (*end case*))
673 : cchiw 3383 (*end case*))
674 : cchiw 3472
675 : cchiw 3324 (*Probe that tensor at a constant position c1*)
676 :     val param0 = [E.TEN(1,index1)]
677 : cchiw 3472 val nx=List.tabulate(newvx,fn n=>E.V n)
678 :     val body0 = (case b
679 :     of E.Sum([(vsum,0,n)],_)=> E.Sum([(vsum,0,n)],E.Tensor(0,[vsum,vsum]@nx))
680 :     | _ => E.Tensor(0,[c1]@nx)
681 :     (*end case*))
682 : cchiw 3324 val ein0 = mkEin(param0,index0,body0)
683 : cchiw 3472 val einApp0 = mkEinApp(ein0,[lhs0])
684 : cchiw 3324 val code0 = (y,einApp0)
685 : cchiw 3472 val _= toStringBind code0
686 : cchiw 3324 in
687 : cchiw 3557 (fieldset,codeAll@[code0])
688 : cchiw 3472 end
689 :     (* **************************************************** General Fn **************************************************** *)
690 : cchiw 2845 (* expandEinOp: code-> code list
691 : cchiw 3259 * A this point we only have simple ein ops
692 :     * Looks to see if the expression has a probe. If so, replaces it.
693 : cchiw 2845 * Note how we keeps eps expressions so only generate pieces that are used
694 :     *)
695 : cchiw 3540 fun expandEinOp(e0 as (y, DstIL.EINAPP(ein,args)),fieldset,varset)=let
696 : cchiw 3557 fun rewriteBody(fieldset,e,p as E.Probe(E.Conv(_,alpha,_,dx),_))= (case (detflag,alpha,dx)
697 : cchiw 3472 of (true,[E.C(_,true), E.V 0],[]) => getF(e,fieldset,[3],1)
698 :     | (true,[E.C(_,true), E.V 0],[E.V 1]) => getF(e,fieldset,[3],2)
699 :     | (true,[E.C(_,true), E.V 0],[E.V 1,E.V 2]) => getF(e,fieldset,[3],3)
700 :     | (true,[E.C(_,true)],[]) => getF(e,fieldset,[3],0)
701 :     | (true,[E.C(_,true)],[E.V 0]) => getF(e,fieldset,[3],1)
702 :     | (true,[E.C(_,true)],[E.V 0,E.V 1]) => getF(e,fieldset,[3],2)
703 :     | (true,[E.C(_,true)],[E.V 0,E.V 1,E.V 2]) => getF(e,fieldset,[3],3)
704 : cchiw 3557 | _ => reconstruction(dx,(fieldset,e,p,[]))
705 : cchiw 3472 (*end case*))
706 : cchiw 3557 | rewriteBody(fieldset,e,E.Sum(sx,p as E.Probe(E.Conv(_,alpha,_,dx),_)))= (case (detsumflag,sx,alpha,dx)
707 : cchiw 3472 of (true,[(E.V 0,0,_)],[E.V 0 ,E.V 0],[]) => getF(e,fieldset,[3,3],0)
708 :     | (true,[(E.V 1,0,_)],[E.V 1 ,E.V 1],[E.V 0]) => getF(e,fieldset,[3,3],1)
709 :     | (true,[(E.V 2,0,_)],[E.V 2 ,E.V 2],[E.V 0,E.V 1]) => getF(e,fieldset,[3,3],2)
710 : cchiw 3557 | (_,_,_,[]) => replaceProbe(fieldset,e,p, sx) (*no dx*)
711 :     | (_,_,[],_) => reconstruction(dx,(fieldset,e,p,sx))
712 :     | _ => replaceProbe(fieldset,e,p, sx)
713 : cchiw 2845 (* end case *))
714 : cchiw 3557 | rewriteBody(fieldset,e,E.Sum(sx,E.Opn(E.Prod,[eps,E.Probe p]))) = replaceProbe(fieldset,e,E.Probe p,sx)
715 :     | rewriteBody (fieldset,e,_) = (fieldset,[e])
716 : cchiw 3174
717 : cchiw 3472 val b=Ein.body ein
718 : cchiw 3557 fun pf()=("\n **************************** starting **************************** \n"^(P.printerE(ein)))
719 : cchiw 3472 fun matchField()=(case b
720 : cchiw 3557 of E.Probe _ => (pf();1)
721 : cchiw 3540 | E.Sum (_, E.Probe _)=> (pf();1)
722 :     | E.Sum(_, E.Opn(E.Prod,[ _ ,E.Probe _]))=> (pf();1)
723 : cchiw 3271 | _ =>0
724 :     (*end case*))
725 : cchiw 3557 val m=matchField()
726 : cchiw 3503 val (fieldset,varset,code,flag) = (case valnumflag
727 :     of true => (case (einVarSet.rtnVarN(fieldset,e0))
728 : cchiw 3557 of (fieldset,NONE) => let
729 :     val(fieldset,code)=rewriteBody(fieldset,e0,b)
730 :     in (fieldset,varset,code,0) end
731 :     | (fieldset,SOME v) => (fieldset,varset,[(y,DstIL.VAR v)],1)
732 : cchiw 3472 (*end case*))
733 : cchiw 3557 | _ => let
734 :     val(fieldset,code)=rewriteBody(fieldset,e0, b)
735 :     in (fieldset,varset,code,0) end
736 : cchiw 3327 (*end case*))
737 : cchiw 3557
738 : cchiw 3503 in (code,fieldset,varset,m,flag) end
739 : cchiw 2843
740 : cchiw 2606 end; (* local *)
741 :    
742 : cchiw 2845 end (* local *)

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0