Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] Diff of /branches/charisee_dev/src/compiler/high-to-mid/ProbeEin.sml
ViewVC logotype

Diff of /branches/charisee_dev/src/compiler/high-to-mid/ProbeEin.sml

Parent Directory Parent Directory | Revision Log Revision Log | View Patch Patch

branches/charisee/src/compiler/high-to-mid/ProbeEin.sml revision 2870, Wed Feb 25 21:47:43 2015 UTC branches/charisee_dev/src/compiler/high-to-mid/ProbeEin.sml revision 3540, Mon Jan 4 18:03:22 2016 UTC
# Line 1  Line 1 
1  (* Expands probe ein  (* Expands probe ein
2   *   *
3   * COPYRIGHT (c) 2012 The Diderot Project (http://diderot-language.cs.uchicago.edu)   * This code is part of the Diderot Project (http://diderot-language.cs.uchicago.edu)
4     *
5     * COPYRIGHT (c) 2015 The University of Chicago
6   * All rights reserved.   * All rights reserved.
7   *)   *)
8    
# Line 14  Line 16 
16      structure P=Printer      structure P=Printer
17      structure T=TransformEin      structure T=TransformEin
18      structure MidToS=MidToString      structure MidToS=MidToString
19        structure DstV = DstIL.Var
20        structure DstTy = MidILTypes
21    
22      in      in
23    
24      (* This file expands probed fields      (* This file expands probed fields
# Line 35  Line 40 
40      *)      *)
41    
42      val testing=0      val testing=0
43      val cnt = ref 0      val valnumflag=true
44        val tsplitvar=true
45        val fieldliftflag=true
46        val constflag=true
47        val detflag =true
48        val detsumflag=true
49        fun incUse (DstIL.V{useCnt, ...}) = (useCnt := !useCnt + 1)
50        fun decUse (DstIL.V{useCnt, ...}) = (useCnt := !useCnt - 1)
51    
52        val liftimgflag =false
53        val pullKrn= false
54    
55        val cnt = ref 0
56      fun transformToIndexSpace e=T.transformToIndexSpace e      fun transformToIndexSpace e=T.transformToIndexSpace e
57      fun transformToImgSpace  e=T.transformToImgSpace  e      fun transformToImgSpace  e=T.transformToImgSpace  e
58        fun toStringBind e=(MidToString.toStringBind e)
59        fun toStringBindp e=(MidToString.toStringBind e)
60        fun mkEin e=Ein.mkEin e
61        fun mkEinApp(rator,args)=DstIL.EINAPP(rator,args)
62        fun setConst e = E.setConst e
63        fun setNeg e  =  E.setNeg e
64        fun setExp e  =  E.setExp e
65        fun setDiv e= E.setDiv e
66        fun setSub e= E.setSub e
67        fun setProd e= E.setProd e
68        fun setAdd e= E.setAdd e
69        fun mkCx es =List.map (fn c => E.C (c,true)) es
70        fun mkCxSingle c = E.C (c,true)
71    
72      fun testp n=(case testing      fun testp n=(case testing
73          of 0=> 1          of 0=> 1
74          | _ =>(print(String.concat n);1)          | _ =>(print(String.concat n);1)
75          (*end case*))          (*end case*))
76    
77    
78      fun getRHSDst x  = (case DstIL.Var.binding x      fun getRHSDst x  = (case DstIL.Var.binding x
79          of DstIL.VB_RHS(DstIL.OP(rator, args)) => (rator, args)          of DstIL.VB_RHS(DstIL.OP(rator, args)) => (rator, args)
80          | DstIL.VB_RHS(DstIL.VAR x') => getRHSDst x'          | DstIL.VB_RHS(DstIL.VAR x') => getRHSDst x'
# Line 56  Line 88 
88      returns the support of ther kernel, and image      returns the support of ther kernel, and image
89      *)      *)
90      fun getArgsDst(hArg,imgArg,args)=(case (getRHSDst hArg,getRHSDst imgArg)      fun getArgsDst(hArg,imgArg,args)=(case (getRHSDst hArg,getRHSDst imgArg)
91          of ((DstOp.Kernel(h, i), _ ),(DstOp.LoadImage img, _ ))=> let          of ((DstOp.Kernel(h, i), _ ), (DstOp.LoadImage(_, _, img), _ ))=> ((Kernel.support h) ,img,ImageInfo.dim img)
92              in          |  ((k,_),(i,_)) => raise Fail (String.concat["Expected kernel:", (DstOp.toString k ),"Expected Image:", (DstOp.toString i)])
                 ((Kernel.support h) ,img,ImageInfo.dim img)  
             end  
         |  _ => raise Fail "Expected Image and kernel arguments"  
93          (*end case*))          (*end case*))
94    
95    
# Line 81  Line 110 
110              (dim,args@argsT,code, s,P)              (dim,args@argsT,code, s,P)
111          end          end
112    
113    
114    
115    
116      (*createBody:int*int*int,mu list, param_id, param_id, param_id, param_id      (*createBody:int*int*int,mu list, param_id, param_id, param_id, param_id
117      * expands the body for the probed field      * expands the body for the probed field
118      *)      *)
# Line 88  Line 120 
120          (*1-d fields*)          (*1-d fields*)
121          fun createKRND1 ()=let          fun createKRND1 ()=let
122              val sum=sx              val sum=sx
123              val dels=List.map (fn e=>(E.C 0,e)) deltas              val dels=List.map (fn e=>(mkCxSingle 0,e)) deltas
124              val pos=[E.Add[E.Tensor(fid,[]),E.Value(sum)]]              val pos=[setAdd[E.Tensor(fid,[]),E.Value(sum)]]
125              val rest= E.Krn(hid,dels,E.Sub(E.Tensor(nid,[]),E.Value(sum)))              val rest= E.Krn(hid,dels,setSub(E.Tensor(nid,[]),E.Value(sum)))
126              in              in
127                  E.Prod [E.Img(Vid,alpha,pos),rest]                 setProd[E.Img(Vid,alpha,pos),rest]
128              end              end
129    
130            fun mkImg(imgpos)=E.Img(Vid,alpha,imgpos)
131    
132          (*createKRN Image field and kernels *)          (*createKRN Image field and kernels *)
133          fun createKRN(0,imgpos,rest)=E.Prod ([E.Img(Vid,alpha,imgpos)] @rest)          fun createKRN(0,imgpos,rest)=setProd ([mkImg(imgpos)] @rest)
134          | createKRN(dim,imgpos,rest)=let          | createKRN(dim,imgpos,rest)=let
135              val dim'=dim-1              val dim'=dim-1
136              val sum=sx+dim'              val sum=sx+dim'
137              val dels=List.map (fn e=>(E.C dim',e)) deltas              val dels=List.map (fn e=>(mkCxSingle  dim',e)) deltas
138              val pos=[E.Add[E.Tensor(fid,[E.C dim']),E.Value(sum)]]              val pos=[setAdd[E.Tensor(fid,[mkCxSingle dim']),E.Value(sum)]]
139              val rest'= E.Krn(hid,dels,E.Sub(E.Tensor(nid,[E.C dim']),E.Value(sum)))              val rest'= E.Krn(hid,dels,setSub(E.Tensor(nid,[mkCxSingle dim']),E.Value(sum)))
140              in              in
141                  createKRN(dim',pos@imgpos,[rest']@rest)                  createKRN(dim',pos@imgpos,[rest']@rest)
142              end              end
# Line 111  Line 146 
146              (*end case*))              (*end case*))
147          (*sumIndex creating summaiton Index for body*)          (*sumIndex creating summaiton Index for body*)
148          val slb=1-s          val slb=1-s
149            val _=List.tabulate(dim, (fn dim=> (String.concat[" sx:",Int.toString(sx)," dim:",Int.toString(dim),"esum",Int.toString(sx+dim) ]) ))
150          val esum=List.tabulate(dim, (fn dim=>(E.V (dim+sx),slb,s)))          val esum=List.tabulate(dim, (fn dim=>(E.V (dim+sx),slb,s)))
151      in      in
152          E.Sum(esum, exp)          E.Sum(esum, exp)
153      end      end
154    
155    
156    
157        (* build position *)
158        fun buildPos (dir,dim,argsA,hid,nid,s) =let
159            val vA  = DstV.new ("kernel_pos", DstTy.TensorTy([]))
160            val p=[E.KRN,E.TEN(1,[dim])]
161            val pos=setSub(E.Tensor(1,[mkCxSingle dir]),E.Value(0))
162            val exp= E.BuildPos(s,pos)
163            (*val exp = E.Sum([(E.V 0, slb, s)],E.Krn(0,[],pos))*)
164            val a=[List.nth(argsA,hid),List.nth(argsA,nid)]
165            val A=(vA,mkEinApp(mkEin(p,[],exp),a))
166            in (vA,A) end
167    
168        (* apply differentiation *)
169        fun getKrn1Del(dx,dim,args,slb,s)= let
170            val n=Int.toString(dx)
171            val vA  = DstV.new ("kernel_del"^n, DstTy.TensorTy([]))
172            val p=[E.KRN,E.TEN(1,[dim])]
173            val exp = E.EvalKrn dx
174            val A = (vA,mkEinApp(mkEin(p,[],exp),args))
175            in (vA,A) end
176    
177        (*create holder expression*)
178        fun mkHolder(dim,args) =let
179            val n=List.length(args)
180            val vA  = DstV.new ("kernel_cons", DstTy.TensorTy([n]))
181            val p=[E.KRN,E.TEN(1,[dim])]
182            val A= (vA,mkEinApp(mkEin(p,[],E.Holder n),args))
183            in (vA,A) end
184    
185        (*lifted Kernel expressions*)
186        fun liftKrn(dx,dir,dim,argsA,hid,nid,slb,s)=let
187            val  (vA,A)=buildPos(dir,dim,argsA,hid,nid,s)
188            val args=[List.nth(argsA,hid),vA]
189            fun iter(0,vBs,Bs)=let
190                val (vA,A)=getKrn1Del(0,dim,args,slb,s)
191                in (vA::vBs,A::Bs) end
192              | iter (n,vBs,Bs)= let
193                val (vA,A)=getKrn1Del(n,dim,args,slb,s)
194                in iter(n-1,vA::vBs,A::Bs) end
195            val (vBs,Bs)=iter(length(dx),[],[])
196            val (vC,C)  =mkHolder(dim,vBs)
197            in (vC,(A::Bs)@[C]) end
198    
199    
200    
201        fun createBody2(dim, s,sx,alpha,deltas,Vid, hid, nid, fid,params,argsA)=let
202            (*1-d fields*)
203            val slb=1-s
204    
205    
206            (*making image*)
207            val tid=(case liftimgflag
208                of true => length(params)-1
209                | _     => length(params)-1
210            (*end case*))
211            fun mkImg imgpos =(case liftimgflag
212                of true=>(E.Tensor(Vid,alpha),SOME(E.Sum(List.tabulate(dim, (fn dim=>(E.V (dim),slb,s))),E.Img(Vid,alpha,imgpos))))
213                |  _ =>let
214                    val imgpos= List.tabulate(dim,fn e=> setAdd[E.Tensor(fid,[mkCxSingle e]),E.Value(e+sx)])
215                    in (E.Img(Vid,alpha,imgpos),NONE) end
216            (*end case*))
217    
218            fun createKRND1 ()=let
219                val sum=sx
220                val dels=List.map (fn e=>(mkCxSingle 0,e)) deltas
221                val imgpos=[setAdd[E.Tensor(fid,[]),E.Value(sum)]]
222                val rest= E.Krn(hid,dels,setSub(E.Tensor(nid,[]),E.Value(sum)))
223                val (talpha,iexp)= mkImg imgpos
224                in (setProd[talpha,rest],iexp,NONE,NONE)end
225    
226            (*createKRN Image field and kernels *)
227            fun createKRN(0,orig,imgpos,vAs,krnpos)= let
228                val (talpha,iexp)= mkImg imgpos
229                in  (setProd ([talpha]@orig),iexp,SOME vAs,SOME krnpos) end
230            | createKRN(d,orig,imgpos,vAs,krnpos)=let
231                val dim'=d-1
232                val sum=sx+dim'
233                val dels=List.map (fn e=>(mkCxSingle  dim',e)) deltas
234                val ipos=setAdd[E.Tensor(fid,[mkCxSingle dim']),E.Value(dim')]
235                val opos= E.Krn(hid,dels,E.Tensor(tid+d,[]))
236                val (vA,A)= liftKrn(dels,dim',dim,argsA,hid,nid,slb,s)
237                in
238                    createKRN(dim',[opos]@orig,[ipos]@imgpos,[vA]@vAs,A@krnpos)
239                end
240    
241            val (oexp,iexp,vAs,keinapp)=(case dim
242                of 1 => createKRND1()
243                | _=> createKRN(dim, [],[],[],[])
244            (*end case*))
245    
246            val oexp=E.Sum(List.tabulate(dim, (fn dim=>(E.V (dim+sx),slb,s))), oexp)
247            in (oexp,iexp,vAs,keinapp) end
248    
249        fun createBody3(dim, s,sx,[],deltas,Vid, hid, nid, fid,params,argsA)=
250                createBody2(dim, s,sx,[],deltas,Vid, hid, nid, fid,params,argsA)
251        |   createBody3(dim, s,sx,alpha,deltas,Vid, hid, nid, fid,params,argsA)=
252                (createBody(dim, s,sx,alpha,deltas,Vid, hid, nid, fid),NONE,NONE,NONE)
253    
254      (*getsumshift:sum_indexid list* int list-> int      (*getsumshift:sum_indexid list* int list-> int
255      *get fresh/unused index_id, returns int      *get fresh/unused index_id, returns int
256      *)      *)
257      fun getsumshift(sx,index) =let      fun getsumshift(sx,n) =let
258          val nsumshift= (case sx          val nsumshift= (case sx
259              of []=> length(index)              of []=> n
260              | _=>let              | _=>let
261                  val (E.V v,_,_)=List.hd(List.rev sx)                  val (E.V v,_,_)=List.hd(List.rev sx)
262                  in v+1                  in v+1
263                  end                  end
264              (* end case *))              (* end case *))
265    
266          val aa=List.map (fn (E.V v,_,_)=>Int.toString v) sx          val aa=List.map (fn (E.V v,_,_)=>Int.toString v) sx
267          val _ =testp["\n", "SumIndex" ,(String.concatWith"," aa),          val _ =(String.concat["\n", "SumIndex:" ,(String.concatWith"," aa),
268              "\nThink nshift is ", Int.toString nsumshift]          "\n\t Index length:",Int.toString n,
269            "\n\t Freshindex: ", Int.toString nsumshift])
270          in          in
271              nsumshift              nsumshift
272          end          end
# Line 139  Line 276 
276      *)      *)
277      fun formBody(E.Sum([],e))=formBody e      fun formBody(E.Sum([],e))=formBody e
278      | formBody(E.Sum(sx,e))= E.Sum(sx,formBody e)      | formBody(E.Sum(sx,e))= E.Sum(sx,formBody e)
279      | formBody(E.Prod [e])=e      | formBody(E.Opn(E.Prod, [e]))=e
280      | formBody e=e      | formBody e=e
281    
282      (* replaceProbe:ein_exp* params *midIL.var list * int list* sum_id list      (* silly change in order of the product to match vis branch WorldtoSpace functions*)
283              -> ein_exp* *code      fun multiPs([P0,P1,P2],sx,body)= formBody(E.Sum(sx, setProd[P0,P1,P2,body]))
284        (*
285          | multiPs([P0,P1],sx,body)=formBody(E.Sum(sx, setProd([P0,body,P1])))
286          *)
287          | multiPs([P0,P1,P2,P3],sx,body)= formBody(E.Sum(sx, setProd[P0,P1,P2,P3,body]))
288          | multiPs(Ps,sx,body)=formBody(E.Sum(sx,setProd([body]@Ps)))
289    
290    
291        fun multiMergePs([P0,P1],[sx0,sx1],body)=E.Sum([sx0],setProd[P0,E.Sum([sx1],setProd[P1,body])])
292          | multiMergePs e=multiPs e
293    
294        (* *******************************************  setImage *******************************************  *)
295        fun replaceImgA(es,vid,newbie)=List.take(es,vid)@[newbie]@List.drop(es,vid+1)
296        fun setImage(params',argsA,code,vexp2,index,alpha,paraminstant,Vid,s)=
297            (case vexp2
298                of NONE =>(params',argsA,code)
299                | SOME vexp  =>    let
300                    val iArg  = DstV.new ("Img", DstTy.TensorTy([]))
301                    val alphax=List.map (fn (E.V i)=>List.nth(index,i)) alpha
302                    val ieinapp=(iArg,mkEinApp(mkEin(paraminstant,alphax,vexp),argsA))
303                    (*
304                     val _ =print(String.concat["\n****\n Image (",Int.toString(length(argsA)),")"])
305                    val _ =print(String.concat(List.map (fn e=> DstIL.Var.toString(e)^",") argsA))
306                    val _ =print(String.concat["\n replace at ",Int.toString Vid ," with " , DstIL.Var.toString iArg ,"\n"])*)
307                    val argsA=replaceImgA(argsA,Vid,iArg)
308                    val params'=replaceImgA(params',Vid,E.TEN(2,[(s-(1-s)+1)*(s-(1-s)+1),(s-(1-s)+1)]))
309                    val code=code@[ieinapp]
310                    (*
311                    val _ =print(String.concat(List.map (fn e=> DstIL.Var.toString(e)^",") argsA))
312                    val _ =print(String.concat["\n****\n Image(",Int.toString(length(argsA)),")"])*)
313                    in (params',argsA,code) end
314            (*end case*))
315    
316          (*kernels*)
317          fun setKernel(params',args',code,vAs2,keinapp2,dim)=
318            (case (vAs2,keinapp2)
319                of (NONE,NONE)=> (params',args',code)
320                | (SOME vAs,SOME keinapp) => let
321                (*
322                    val _ =print"\n****\n Kernels\n"
323                    val _ =print(String.concat["\n****\n Kernel(",Int.toString(length(args')),")"])
324                    val _ =print(String.concat(List.map (fn e=> DstIL.Var.toString(e)^",") args'))*)
325                    val args'=   args'@vAs
326                    val params'= params'@(List.tabulate(dim,fn _=> E.TEN(2,[])))
327                    val code=code@keinapp
328                    (*
329                    val _ =print"\n"
330                    val _ =print(String.concat(List.map (fn e=> DstIL.Var.toString(e)^",") args'))
331                    val _ =print(String.concat["\n****\n Kernel(",Int.toString(length(args')),")"])*)
332                    in (params',args',code) end
333            (*end case*))
334    
335          fun setImageKernel(params',args',code,vexp2,vAs2,keinapp2,dim,index,alpha,paraminstant,Vid,s)=let
336            val (params',args',code)=setImage(params',args',code,vexp2,index,alpha,paraminstant,Vid,s)
337            in setKernel(params',args',code,vAs2,keinapp2,dim) end
338    
339    
340        (* *******************************************  Replace probe *******************************************  *)
341        (* replaceProbe
342      * Transforms position to world space      * Transforms position to world space
343      * transforms result back to index_space      * transforms result back to index_space
344      * rewrites body      * rewrites body
345      * replace probe with expanded version      * replace probe with expanded version
346      *)      *)
347      fun replaceProbe(b,params,args,index, sx)=let       fun replaceProbe0((y, DstIL.EINAPP(e,args)),p ,sx)
348          val E.Probe(E.Conv(Vid,alpha,hid,dx),E.Tensor(tid,_))=b          =let
349            val originalb=Ein.body e
350            val params=Ein.params e
351            val index=Ein.index e
352            val _ = testp["\n***************** \n Replace ************ \n"]
353            val _=  toStringBind (y, DstIL.EINAPP(e,args))
354    
355            val E.Probe(E.Conv(Vid,alpha,hid,dx),E.Tensor(tid,_))=p
356          val fid=length(params)          val fid=length(params)
357          val nid=fid+1          val nid=fid+1
358          val Pid=nid+1          val Pid=nid+1
359          val nshift=length(dx)          val nshift=length(dx)
360          val (dim,argsA,code,s,PArg) = handleArgs(Vid,hid,tid,args)          val (dim,argsA,code,s,PArg) = handleArgs(Vid,hid,tid,args)
361          val freshIndex=getsumshift(sx,index)          val freshIndex=getsumshift(sx,length(index))
362          val (dx,newsx1,Ps)=transformToIndexSpace(freshIndex,dim,dx,Pid)          val (dx,newsx1,Ps)=transformToIndexSpace(freshIndex,dim,dx,Pid)
363          val params'=params@[E.TEN(3,[dim]),E.TEN(1,[dim]),E.TEN(1,[dim,dim])]          val params'=params@[E.TEN(3,[dim]),E.TEN(1,[dim]),E.TEN(1,[dim,dim])]
364          val body' = createBody(dim, s,freshIndex+nshift,alpha,dx,Vid, hid, nid, fid)          val body' = createBody(dim, s,freshIndex+nshift,alpha,dx,Vid, hid, nid, fid)
365            val body' = multiPs(Ps,newsx1,body')
366    
367          (*silly change in order of product to match vis branch WorldtoSpace functions*)          val body'=(case originalb
368          val body' =(case Ps              of E.Sum(sx, E.Probe _)              => E.Sum(sx,body')
369              of [_,_,_]=>        formBody(E.Sum(newsx1, E.Prod(Ps@[body'])))              | E.Sum(sx,E.Opn(E.Prod,[eps0,E.Probe _ ]))  => E.Sum(sx,setProd[eps0,body'])
370              | _ =>  formBody(E.Sum(newsx1, E.Prod([body']@Ps)))              | _                                  => body'
371              (*end case*))              (*end case*))
372    
373          val args'=argsA@[PArg]          val args'=argsA@[PArg]
374            val einapp=(y,mkEinApp(mkEin(params',index,body'),args'))
375            val _= List.map toStringBind(List.rev (einapp::code))
376          in          in
377              (body',params',args' ,code)              code@[einapp]
378          end          end
379    
     (* expandEinOp: code->  code list  
     *Looks to see if the expression has a probe. If so, replaces it.  
     * Note how we keeps eps expressions so only generate pieces that are used  
     *)  
     fun expandEinOp( e as (y, DstIL.EINAPP(ein as Ein.EIN{params, index, body}, args))) = let  
         fun rewriteBody b=(case b  
             of E.Probe(E.Field _,_)=> raise Fail"Poorly formed EIN operator. Argument needs to be applied in High-IL"  
             | E.Probe e =>let  
380    
                 val (body',params',args',newbies)=replaceProbe(E.Probe e,params,args, index, [])  
                 val einapp=(y,DstIL.EINAPP(Ein.EIN{params=params', index=index, body=body'},args'))  
                 val code=newbies@[einapp]  
381    
382        fun replaceProbe3((y, DstIL.EINAPP(e,args)),p ,sx) = let
383            val originalb=Ein.body e
384            val params=Ein.params e
385            val index=Ein.index e
386            val _ = testp["\n***************** \n Replace ************ \n"]
387            val _=  toStringBind (y, DstIL.EINAPP(e,args))
388    
389            val E.Probe(E.Conv(Vid,alpha,hid,dx),E.Tensor(tid,_))=p
390            val fid=length(params)
391            val nid=fid+1
392            val Pid=nid+1
393            val nshift=length(dx)
394            val (dim,argsA,code,s,PArg) = handleArgs(Vid,hid,tid,args)
395            val freshIndex=getsumshift(sx,length(index))
396            val (dx,newsx1,Ps)=transformToIndexSpace(freshIndex,dim,dx,Pid)
397    
398            val paraminstant=params@[E.TEN(3,[dim]),E.TEN(1,[dim])]
399            val params'=paraminstant@[E.TEN(1,[dim,dim])]
400    
401    
402            val (body',vexp2,vAs2,keinapp2) = createBody3(dim, s,freshIndex+nshift,alpha,dx,Vid, hid, nid, fid,paraminstant,argsA)
403            val body' = multiPs(Ps,newsx1,body')
404            val body'=(case originalb
405                of E.Sum(sx, E.Probe _)              => E.Sum(sx,body')
406                | E.Sum(sx,E.Opn(E.Prod,[eps0,E.Probe _ ]))  => E.Sum(sx,setProd[eps0,body'])
407                | _                                  => body'
408                (*end case*))
409    
410            (*images and kernels*)
411            val (params',argsA,code)=setImageKernel(params',argsA,code,vexp2,vAs2,keinapp2,dim,index,alpha,paraminstant,Vid,s)
412    
413    
414            (*replace term*)
415            val args'=argsA@[PArg]
416            val einapp=(y,mkEinApp(mkEin(params',index,body'),args'))
417            val _= List.map toStringBindp(code@[einapp])
418                  in                  in
419                      code              code@[einapp]
420                  end                  end
             | E.Sum(sx,E.Probe e)  =>let  
421    
                 val (body',params',args',newbies)=replaceProbe(E.Probe e,params,args, index, sx)  
                 val  body'=E.Sum(sx,body')  
                 val einapp=(y,DstIL.EINAPP(Ein.EIN{params=params', index=index, body=body'},args'))  
                 val code=newbies@[einapp]  
422    
423        (* ******************************************* Lift probe *******************************************  *)
424        fun createEinApp(originalb,alpha,index,freshIndex,dim,dx,sx)= let
425            val Pid=0
426            val tid=1
427    
428            (*Assumes body is already clean*)
429            val (newdx,newsx,Ps)=transformToIndexSpace(freshIndex,dim,dx,Pid)
430    
431            (*need to rewrite dx*)
432            val (_,sizes,e as E.Conv(_,alpha',_,dx))=(case sx@newsx
433                of []=> ([],index,E.Conv(9,alpha,7,newdx))
434                | _ => cleanIndex.cleanIndex(E.Conv(9,alpha,7,newdx),index,sx@newsx)
435                (*end case*))
436    
437            val params=[E.TEN(1,[dim,dim]),E.TEN(1,sizes)]
438            fun filterAlpha []=[]
439              | filterAlpha(E.C _::es)= filterAlpha es
440              | filterAlpha(e1::es)=[e1]@(filterAlpha es)
441    
442            val tshape=filterAlpha(alpha')@newdx
443            val t=E.Tensor(tid,tshape)
444    
445            val (splitvar,body)=(case originalb
446                of E.Sum(sx, E.Probe _)              => (true,multiPs(Ps,sx@newsx,t))
447                | E.Sum(sx,E.Opn(E.Prod,[eps0,E.Probe _ ]))  => (false,E.Sum(sx,setProd[eps0,multiPs(Ps,newsx,t)]))
448                | _                                  => (case tsplitvar
449                  of(* true =>   (true,multiMergePs(Ps,newsx,t))  (*pushes summations in place*)
450                    | false*) _ =>   (true,multiPs(Ps,newsx,t))
451                    (*end case*))
452            (*end case*))
453    
454            val _ =(case splitvar
455                of true=> (String.concat["splitvar is true", P.printbody body])
456                | _ => (String.concat["splitvar is false",P.printbody body])
457            (*end case*))
458    
459    
460            val ein0=mkEin(params,index,body)
461                  in                  in
462                      code              (splitvar,ein0,sizes,dx,alpha')
463                  end                  end
             | E.Sum(sx,E.Prod[eps,E.Probe e]) =>let  
464    
465                  val (body',params',args',newbies)=replaceProbe(E.Probe e,params,args, index, sx)      fun liftProbe0((y, DstIL.EINAPP(e,args)),p ,sx)=let
466                  val  body'=E.Sum(sx,E.Prod[eps,body'])          val _=testp["\n******* Lift Geneirc Probe ***\n"]
467                  val einapp=(y,DstIL.EINAPP(Ein.EIN{params=params', index=index, body=body'},args'))          val originalb=Ein.body e
468                  val code=newbies@[einapp]          val params=Ein.params e
469            val index=Ein.index e
470            val _ =  (toStringBind (y, DstIL.EINAPP(e,args)))
471    
472            val E.Probe(E.Conv(Vid,alpha,hid,dx),E.Tensor(tid,_))=p
473            val fid=length(params)
474            val nid=fid+1
475            val nshift=length(dx)
476            val (dim,args',code,s,PArg) = handleArgs(Vid,hid,tid,args)
477            val freshIndex=getsumshift(sx,length(index))
478    
479            (*transform T*P*P..Ps*)
480            val (splitvar,ein0,sizes,dx,alpha')= createEinApp(originalb,alpha,index,freshIndex,dim,dx,sx)
481    
482            val FArg  = DstV.new ("F", DstTy.TensorTy(sizes))
483            val einApp0=mkEinApp(ein0,[PArg,FArg])
484            val rtn0=(case splitvar
485                of false => [(y,mkEinApp(ein0,[PArg,FArg]))]
486                | _      => let
487                     val bind3 = (y,DstIL.EINAPP(SummationEin.main ein0,[PArg,FArg]))
488                     in Split.splitEinApp bind3
489                     end
490                (*end case*))
491    
492            (*lifted probe*)
493            val params'=params@[E.TEN(3,[dim]),E.TEN(1,[dim])]
494            val freshIndex'= length(sizes)
495    
496            val body' = createBody(dim, s,freshIndex',alpha',dx,Vid, hid, nid, fid)
497            val ein1=mkEin(params',sizes,body')
498            val einApp1=mkEinApp(ein1,args')
499            val rtn1=(FArg,einApp1)
500            val rtn=code@[rtn1]@rtn0
501            val _= List.map toStringBind ([rtn1]@rtn0)
502             val _=(String.concat["\n* end  Lift Geneirc Probe  ******** \n"])
503              val _= List.map toStringBind rtn
504            in
505                rtn
506            end
507    
508        fun liftProbe3((y, DstIL.EINAPP(e,args)),p ,sx)=let
509            val _=testp["\n******* Lift Geneirc Probe ***\n"]
510            val originalb=Ein.body e
511            val params=Ein.params e
512            val index=Ein.index e
513            val _ =  (toStringBind (y, DstIL.EINAPP(e,args)))
514    
515            val E.Probe(E.Conv(Vid,alpha,hid,dx),E.Tensor(tid,_))=p
516            val fid=length(params)
517            val nid=fid+1
518            val nshift=length(dx)
519            val (dim,args',code,s,PArg) = handleArgs(Vid,hid,tid,args)
520            val freshIndex=getsumshift(sx,length(index))
521    
522            (*transform T*P*P..Ps*)
523            val (splitvar,ein0,sizes,dx,alpha')= createEinApp(originalb,alpha,index,freshIndex,dim,dx,sx)
524    
525            val FArg  = DstV.new ("F", DstTy.TensorTy(sizes))
526            val einApp0=mkEinApp(ein0,[PArg,FArg])
527            val rtn0=(case splitvar
528                of false => [(y,mkEinApp(ein0,[PArg,FArg]))]
529                | _      => let
530                    val bind3 = (y,DstIL.EINAPP(SummationEin.main ein0,[PArg,FArg]))
531                    in Split.splitEinApp bind3
532                    end
533                    (*end case*))
534    
535            (*lifted probe*)
536            val params'=params@[E.TEN(3,[dim]),E.TEN(1,[dim])]
537            val freshIndex'= length(sizes)
538    
539            (*val body' = createBody(dim, s,freshIndex',alpha',dx,Vid, hid, nid, fid)*)
540    
541            val (body',vexp2,vAs2,keinapp2) = createBody3(dim, s,freshIndex',alpha',dx,Vid, hid, nid, fid,params',args')
542    
543            (*set image and kernel*)
544            val (params',args',code)=setImageKernel(params',args',code,vexp2,vAs2,keinapp2,dim,index,alpha,params',Vid,s)
545    
546            val ein1=mkEin(params',sizes,body')
547            val einApp1=mkEinApp(ein1,args')
548            val rtn1=(FArg,einApp1)
549            val rtn=code@[rtn1]@rtn0
550            val _= List.map toStringBind ([rtn1]@rtn0)
551            val _=(String.concat["\n* end  Lift Geneirc Probe  ******** \n"])
552              val _= List.map toStringBindp(rtn)
553                  in                  in
554                      code              rtn
555                  end                  end
556              | _=> [e]  
557        fun replaceProbe e= (case pullKrn
558            of true=>replaceProbe3 e
559            | false => replaceProbe0 e
560            (*end case*))
561        fun liftProbe e=(case pullKrn
562            of true=>liftProbe3 e
563            | false => liftProbe0 e
564            (*end case*))
565    
566    
567        (* ******************************************* Reconstruction -> Lift|Replace probe *******************************************  *)
568        (* scans dx for contant
569         * arg:(1,code1, body1,[])
570         *)
571        fun reconstruction([],arg)= replaceProbe arg
572         | reconstruction(dx,arg)=(case (constflag,fieldliftflag)
573            of (true,true) => liftProbe arg
574            | (_,false)    => replaceProbe arg
575            | _ => let
576                fun fConst [] = liftProbe arg
577                | fConst (E.C _::_) = replaceProbe arg
578                | fConst (_ ::es)= fConst es
579                in fConst dx end
580            (* end case*))
581    
582        (* **************************************************** Index Tensor **************************************************** *)
583        (*Push constant indices to tensor replacement*)
584        fun getF (e,fieldset,dim,newvx)= let
585            val (y, DstIL.EINAPP(ein,args))=e
586            val index0=Ein.index ein
587            val index1 = index0@dim
588            val b=Ein.body ein
589    
590            val (c1,dx,body1)=(case b
591                of  E.Sum([(vsum,0,n)],E.Probe(E.Conv(V,[c1,v0],h,dx),pos))=>let
592                    val shiftdx=List.tabulate(length(dx),fn n=>E.V (n+2))
593                    val b=E.Probe(E.Conv(V,[E.V 0,E.V 1],h,shiftdx),pos)
594                    in (c1,dx,b) end
595                | E.Probe(E.Conv(V,[c1,v0],h,dx),pos)=> let
596                    val body1_unshifted= E.Probe(E.Conv(V,[E.V newvx, v0],h,dx),pos)
597                    (* clean to get body indices in order *)
598                    val ( _ , _, body1)= cleanIndex.cleanIndex(body1_unshifted,index1,[])
599                    in (c1,dx,body1) end
600                |  E.Probe(E.Conv(V,[c1],h,dx),pos)=> let
601                   val body1_unshifted= E.Probe(E.Conv(V,[E.V newvx],h,dx),pos)
602                   val ( _ , _, body1)= cleanIndex.cleanIndex(body1_unshifted,index1,[])
603                   in (c1,dx,body1) end
604                (*end case*))
605    
606            val lhs1=DstV.new ("L", DstTy.TensorTy(index1))
607            val ein1 = mkEin(Ein.params ein,index1,body1)
608            val code1= (lhs1,mkEinApp(ein1,args))
609    
610            val (_,(lhs0,codeAll))= (case valnumflag
611                of false    => (fieldset,(lhs1, reconstruction(dx,(code1,body1,[]))))
612                | true      => (case  (einVarSet.rtnVarN(fieldset,code1))
613                    of (fieldset,NONE)     => (fieldset,(lhs1, reconstruction(dx,(code1,body1,[]))))
614                    | (fieldset,SOME m)   =>  (fieldset,(m,[]))
615                    (*end case*))
616              (* end case *))              (* end case *))
617    
618            (*Probe that tensor at a constant position  c1*)
619            val param0 = [E.TEN(1,index1)]
620            val nx=List.tabulate(newvx,fn n=>E.V n)
621            val body0 =  (case b
622                of E.Sum([(vsum,0,n)],_)=>  E.Sum([(vsum,0,n)],E.Tensor(0,[vsum,vsum]@nx))
623                | _ => E.Tensor(0,[c1]@nx)
624                (*end case*))
625            val ein0 = mkEin(param0,index0,body0)
626            val einApp0 = mkEinApp(ein0,[lhs0])
627            val code0 = (y,einApp0)
628            val _= toStringBind code0
629          in          in
630              rewriteBody body              codeAll@[code0]
631          end          end
632        (* **************************************************** General Fn **************************************************** *)
633        (* expandEinOp: code->  code list
634        * A this point we only have simple ein ops
635        * Looks to see if the expression has a probe. If so, replaces it.
636        * Note how we keeps eps expressions so only generate pieces that are used
637        *)
638        fun expandEinOp(e0 as (y, DstIL.EINAPP(ein,args)),fieldset,varset)=let
639            fun rewriteBody(e,p as E.Probe(E.Conv(_,alpha,_,dx),_))= (case (detflag,alpha,dx)
640                of (true,[E.C(_,true), E.V 0],[])            => getF(e,fieldset,[3],1)
641                | (true,[E.C(_,true), E.V 0],[E.V 1])        => getF(e,fieldset,[3],2)
642                | (true,[E.C(_,true), E.V 0],[E.V 1,E.V 2])  => getF(e,fieldset,[3],3)
643                | (true,[E.C(_,true)],[])                    => getF(e,fieldset,[3],0)
644                | (true,[E.C(_,true)],[E.V 0])               => getF(e,fieldset,[3],1)
645                | (true,[E.C(_,true)],[E.V 0,E.V 1])         => getF(e,fieldset,[3],2)
646                | (true,[E.C(_,true)],[E.V 0,E.V 1,E.V 2])   => getF(e,fieldset,[3],3)
647                | _                                          => reconstruction(dx,(e,p,[]))
648                (*end case*))
649            | rewriteBody(e,E.Sum(sx,p as E.Probe(E.Conv(_,alpha,_,dx),_)))= (case (detsumflag,sx,alpha,dx)
650                of (true,[(E.V 0,0,_)],[E.V 0 ,E.V 0],[])              => getF(e,fieldset,[3,3],0)
651                | (true,[(E.V 1,0,_)],[E.V 1 ,E.V 1],[E.V 0])          => getF(e,fieldset,[3,3],1)
652                | (true,[(E.V 2,0,_)],[E.V 2 ,E.V 2],[E.V 0,E.V 1])    => getF(e,fieldset,[3,3],2)
653                | (_,_,_,[])                                => replaceProbe(e,p, sx)  (*no dx*)
654                | (_,_,[],_)                                => reconstruction(dx,(e,p,sx))
655                | _                                         => replaceProbe(e,p, sx)
656                (* end case *))
657            | rewriteBody(e,E.Sum(sx,E.Opn(E.Prod,[eps,E.Probe p])))     = replaceProbe(e,E.Probe p,sx)
658            | rewriteBody (e,_)  = [e]
659    
660            val b=Ein.body ein
661            fun pf()=(P.printbody b)
662            fun matchField()=(case b
663    
664            of E.Probe _ => (pf();1)
665                | E.Sum (_, E.Probe _)=> (pf();1)
666                | E.Sum(_, E.Opn(E.Prod,[ _ ,E.Probe _]))=> (pf();1)
667                | _ =>0
668                (*end case*))
669            val (fieldset,varset,code,flag) = (case valnumflag
670                of true => (case (einVarSet.rtnVarN(fieldset,e0))
671                   of  (fldset,NONE)     => (fldset,varset,rewriteBody(e0,b),0)
672                  | (fldset,SOME v)    => (fldset,varset,[(y,DstIL.VAR v)],1)
673                     (*of (fldset, NONE)      => (fldset,varset,rewriteBody((y,DstIL.EINAPP(ein,List.map (fn a=>einVarSet.replaceArg(varset,a)) args)), b),0)
674                     | (fldset,SOME v)    => (fldset,einVarSet.VarSet.add(varset,einVarSet.VAR(v,y)),[],1)*)
675                    (*end case*))
676                | _     => (fieldset,varset,rewriteBody(e0, b),0)
677                (*end case*))
678            val m=matchField()
679            in  (code,fieldset,varset,m,flag) end
680    
681    end; (* local *)    end; (* local *)
682    

Legend:
Removed from v.2870  
changed lines
  Added in v.3540

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0