Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] Diff of /branches/charisee_dev/src/compiler/high-to-mid/ProbeEin.sml
ViewVC logotype

Diff of /branches/charisee_dev/src/compiler/high-to-mid/ProbeEin.sml

Parent Directory Parent Directory | Revision Log Revision Log | View Patch Patch

branches/charisee/src/compiler/high-to-mid/ProbeEin.sml revision 2613, Wed May 7 04:35:38 2014 UTC branches/charisee_dev/src/compiler/high-to-mid/ProbeEin.sml revision 3503, Thu Dec 17 23:13:57 2015 UTC
# Line 1  Line 1 
1  (* Currently under construction  (* Expands probe ein
2   *   *
3   * COPYRIGHT (c) 2012 The Diderot Project (http://diderot-language.cs.uchicago.edu)   * This code is part of the Diderot Project (http://diderot-language.cs.uchicago.edu)
4     *
5     * COPYRIGHT (c) 2015 The University of Chicago
6   * All rights reserved.   * All rights reserved.
7   *)   *)
8    
   
 (*  
 A couple of different approaches.  
 One approach is to find all the Probe(Conv). Gerenerate exp for it  
 Then use Subst function to sub in. That takes care for index matching and  
   
 *)  
   
 (*This approach creates probe expanded terms, and adds params to the end. *)  
   
   
9  structure ProbeEin = struct  structure ProbeEin = struct
10    
11      local      local
12    
13      structure E = Ein      structure E = Ein
     structure mk= mkOperators  
     structure SrcIL = HighIL  
     structure SrcTy = HighILTypes  
     structure SrcOp = HighOps  
     structure SrcSV = SrcIL.StateVar  
     structure VTbl = SrcIL.Var.Tbl  
14      structure DstIL = MidIL      structure DstIL = MidIL
     structure DstTy = MidILTypes  
15      structure DstOp = MidOps      structure DstOp = MidOps
     structure DstV = DstIL.Var  
     structure SrcV = SrcIL.Var  
16      structure P=Printer      structure P=Printer
     structure shift=ShiftEin  
     structure split=SplitEin  
     structure F=Filter  
17      structure T=TransformEin      structure T=TransformEin
18        structure MidToS = MidToString
19        structure DstV = DstIL.Var
20        structure DstTy = MidILTypes
21    
 val testing=0  
   
 datatype peanut=    O of  DstOp.rator | E of Ein.ein|C of DstTy.ty|S of int  
 datatype peanut2=    O2 of  SrcOp.rator | E2 of Ein.ein|C2 of SrcTy.ty|S2 of int  
22      in      in
23    
24        (* This file expands probed fields
25        * Take a look at ProbeEin tex file for examples
26        *Note that the original field is an EIN operator in the form <V_alpha * H^(deltas)>(midIL.var list )
27        * Param_ids are used to note the placement of the argument in the midIL.var list
28        * Index_ids  keep track of the shape of an Image or differentiation.
29        * Mu  bind Index_id
30        * Generally, we will refer to the following
31        *dim:dimension of field V
32        * s: support of kernel H
33        * alpha: The alpha in <V_alpha * H^(deltas)>
34        * deltas: The deltas in <V_alpha * H^(deltas)>
35        * Vid:param_id for V
36        * hid:param_id for H
37        * nid: integer position param_id
38        * fid :fractional position param_id
39        * img-imginfo about V
40        *)
41    
42  fun assign (x, rator, args) = (x, DstIL.OP(rator, args))      val testing=0
43  fun assignEin (x, rator, args) = ((x, DstIL.EINAPP(rator, args)))      val valnumflag=true
44        val tsplitvar=true
45        val fieldliftflag=true
46        val constflag=true
47        val detflag =true
48        val detsumflag=true
49        fun incUse (DstIL.V{useCnt, ...}) = (useCnt := !useCnt + 1)
50        fun decUse (DstIL.V{useCnt, ...}) = (useCnt := !useCnt - 1)
51    
52        val cnt = ref 0
53        fun transformToIndexSpace e=T.transformToIndexSpace e
54        fun transformToImgSpace  e=T.transformToImgSpace  e
55        fun toStringBind e=(MidToString.toStringBind e)
56        fun mkEin e=Ein.mkEin e
57        fun mkEinApp(rator,args)=DstIL.EINAPP(rator,args)
58        fun setConst e = E.setConst e
59        fun setNeg e  =  E.setNeg e
60        fun setExp e  =  E.setExp e
61        fun setDiv e= E.setDiv e
62        fun setSub e= E.setSub e
63        fun setProd e= E.setProd e
64        fun setAdd e= E.setAdd e
65        fun mkCx es =List.map (fn c => E.C (c,true)) es
66        fun mkCxSingle c = E.C (c,true)
67    
68  fun getRHS x  = (case SrcIL.Var.binding x      fun testp n=(case testing
69      of SrcIL.VB_RHS(SrcIL.OP(rator, args)) => (O2 rator, args)          of 0=> 1
70      | SrcIL.VB_RHS(SrcIL.VAR x') => getRHS x'          | _ =>(print(String.concat n);1)
     | SrcIL.VB_RHS(SrcIL.EINAPP (e,args))=>(E2 e,args)  
     | SrcIL.VB_RHS(SrcIL.CONS (ty,args))=>(C2 ty,args)  
     | SrcIL.VB_NONE=>(S2 2,[])  
     | vb => raise Fail(concat[  
     "expected rhs operator for ", SrcIL.Var.toString x,  
     "but found ", SrcIL.vbToString vb])  
71      (* end case *))      (* end case *))
72    
73    
74        fun getRHSDst x  = (case DstIL.Var.binding x
75            of DstIL.VB_RHS(DstIL.OP(rator, args)) => (rator, args)
76            | DstIL.VB_RHS(DstIL.VAR x') => getRHSDst x'
77            | vb => raise Fail(concat[ "expected rhs operator for ", DstIL.Var.toString x, "but found ", DstIL.vbToString vb])
78        (* end case *))
79    
 (*Create fractional, and integer position vectors*)  
 fun transformToImgSpace  (dim,v,posx)=let  
   
     val translate=DstOp.Translate v  
     val transform=DstOp.Transform v  
     val M  = DstV.new ("M", DstTy.tensorTy [dim,dim])   (*transform dim by dim?*)  
   
     val T  = DstV.new ("T", DstTy.tensorTy [dim,dim])   (*translate*)  
     val x  = DstV.new ("x", DstTy.vecTy dim)            (*Image-Space position*)  
     val f  = DstV.new ("f", DstTy.vecTy dim)            (*fractional*)  
     val nd = DstV.new ("nd", DstTy.vecTy dim)           (*real position*)  
     val n  = DstV.new ("n", DstTy.iVecTy dim)           (*integer position*)  
     val PosToImgSpace=mk.transform(dim,dim)  
     val P  = DstV.new ("P", DstTy.tensorTy [dim,dim])   (*transform dim by dim?*)  
80    
81      val code=[      (* getArgsDst:MidIL.Var* MidIL.Var->int, ImageInfo, int
82          assign(M, transform, []),          uses the Param_ids for the image, kernel,
83          assign(T, translate, []),          and position tensor to get the Mid-IL arguments
84          assignEin(x, PosToImgSpace,[M,posx,T]) ,  (* MX+T*)      returns the support of ther kernel, and image
85          assign(nd, DstOp.Floor dim, [x]),   (*nd *)      *)
86          assignEin(f, mk.subTen([dim]),[x,nd]),           (*fractional*)      fun getArgsDst(hArg,imgArg,args) = (case (getRHSDst hArg, getRHSDst imgArg)
87          assign(n, DstOp.RealToInt dim, [nd]), (*real to Int*)          of ((DstOp.Kernel(h, i), _ ), (DstOp.LoadImage(_, _, img), _ ))=> let
88          assignEin(P, mk.transpose([dim,dim]), [M])              in
89          ]                ((Kernel.support h) ,img,ImageInfo.dim img)
     in ([n,f],P,code)  
90      end      end
91                |  ((k,_),(i,_)) => raise Fail (String.concat["Expected kernel:", (DstOp.toString k ),"Expected Image:", (DstOp.toString i)])
   
 fun replaceH(kvar, place,args)=let  
     val l1=List.take(args, place)  
     val l2=List.drop(args,place+1)  
     in l1@[kvar]@l2 end  
   
   
 (*Get Img, and Kern Args*)  
 fun getArgs(hid,hArg,V,imgArg,args,lift)=case (getRHS hArg,getRHS imgArg)  
     of ((O2(SrcOp.Kernel(h, i)),argK),(O2(SrcOp.LoadImage img),_))=> let  
         val hvar=DstV.new ("KNL", DstTy.KernelTy)  
         val imgvar=DstV.new ("IMG", DstTy.ImageTy img)  
         val argsVK= (case lift  
             of 0=> let  
                 val argsN=replaceH(hvar, hid,args)  
                 in replaceH(imgvar, V,argsN) end  
             | _ => [imgvar, hvar]  
92          (* end case *))          (* end case *))
         val assigments=[assign (hvar, DstOp.Kernel(h, i), []), assign(imgvar,DstOp.LoadImage img,[])]  
93    
94    
95        (*handleArgs():int*int*int*Mid IL.Var list
96            ->int*Mid.ILVars list* code*int* low-il-var
97            * uses the Param_ids for the image, kernel, and tensor
98            * and gets the mid-IL vars for each.
99            *Transforms the position to index space
100            *P is the mid-il var for the (transformation matrix)transpose
101        *)
102        fun handleArgs(Vid,hid,tid,args)=let
103            val imgArg=List.nth(args,Vid)
104            val hArg=List.nth(args,hid)
105            val newposArg=List.nth(args,tid)
106            val (s,img,dim) =getArgsDst(hArg,imgArg,args)
107            val (argsT,P,code)=transformToImgSpace(dim,img,newposArg,imgArg)
108          in          in
109              (Kernel.support h ,img, assigments,argsVK)              (dim,args@argsT,code, s,P)
110          end          end
     | ((O2(SrcOp.Kernel(h, i)),arg),_)=> raise Fail "Not an img Argument"  
     |  _ => raise Fail "Not a kernel argument"  
111    
112        (*createBody:int*int*int,mu list, param_id, param_id, param_id, param_id
113  fun handleArgs(V,h,t,(params,args),origargs,lift)=let      * expands the body for the probed field
114      val E.IMG(dim)=List.nth(params,V)      *)
115      val kArg=List.nth(origargs,h)      fun createBody(dim, s,sx,alpha,deltas,Vid, hid, nid, fid)=let
116      val imgArg=List.nth(origargs,V)          (*1-d fields*)
117      val newposArg=List.nth(args, t)          fun createKRND1 ()=let
118      val (s,img,argcode,argsVH) =getArgs(h,kArg,V,imgArg,args,lift)              val sum=sx
119      val (argsT,P,code')=transformToImgSpace(dim,img,newposArg)              val dels=List.map (fn e=>(mkCxSingle 0,e)) deltas
120      in (dim,argsVH@argsT,argcode@code', s,P)              val pos=[setAdd[E.Tensor(fid,[]),E.Value(sum)]]
121                val rest= E.Krn(hid,dels,setSub(E.Tensor(nid,[]),E.Value(sum)))
122                in
123                   setProd[E.Img(Vid,alpha,pos),rest]
124      end      end
   
   
 (*createDels=> creates the kronecker deltas for each Kernel*)  
 fun createDels([],_)= []  
     | createDels(d::ds,dim)= [( E.C dim,d)]@createDels(ds,dim)  
   
 (*Created new body for probe*)  
 fun createBody(dim, s,sx,shape,deltas,V, h, nid, fid)=let  
   
     (*sumIndex creating summaiton Index for body*)  
     fun sumIndex(0)=[]  
     |sumIndex(dim)= sumIndex(dim-1)@[(E.V (dim+sx-1),1-s,s)]  
   
125      (*createKRN Image field and kernels *)      (*createKRN Image field and kernels *)
126      fun createKRN(0,imgpos,rest)=E.Prod ([E.Img(V,shape,imgpos)] @rest)          fun createKRN(0,imgpos,rest)=setProd ([E.Img(Vid,alpha,imgpos)] @rest)
127      | createKRN(dim,imgpos,rest)=let      | createKRN(dim,imgpos,rest)=let
128          val dim'=dim-1          val dim'=dim-1
129          val sum=sx+dim'          val sum=sx+dim'
130          val dels=createDels(deltas,dim')              val dels=List.map (fn e=>(mkCxSingle  dim',e)) deltas
131          val pos=[E.Add[E.Tensor(fid,[E.C dim']),E.Value(sum)]]              val pos=[setAdd[E.Tensor(fid,[mkCxSingle dim']),E.Value(sum)]]
132          val rest'= E.Krn(h,dels,E.Sub(E.Tensor(nid,[E.C dim']),E.Value(sum)))              val rest'= E.Krn(hid,dels,setSub(E.Tensor(nid,[mkCxSingle dim']),E.Value(sum)))
133          in          in
134              createKRN(dim',pos@imgpos,[rest']@rest)              createKRN(dim',pos@imgpos,[rest']@rest)
135          end          end
136            val exp=(case dim
137      val exp=createKRN(dim, [],[])              of 1 => createKRND1()
138      val esum=sumIndex (dim)              | _=> createKRN(dim, [],[])
139      in E.Sum(esum, exp)              (*end case*))
140            (*sumIndex creating summaiton Index for body*)
141            val slb=1-s
142            val _=List.tabulate(dim, (fn dim=> (String.concat[" sx:",Int.toString(sx)," dim:",Int.toString(dim),"esum",Int.toString(sx+dim) ]) ))
143            val esum=List.tabulate(dim, (fn dim=>(E.V (dim+sx),slb,s)))
144        in
145            E.Sum(esum, exp)
146      end      end
147    
148        (*getsumshift:sum_indexid list* int list-> int
149  fun ShapeConv([],n)=[]      *get fresh/unused index_id, returns int
150      | ShapeConv(E.C c::es, n)=ShapeConv(es, n)      *)
151      | ShapeConv(E.V v::es, n)=      fun getsumshift(sx,n) =let
152          if(n>v) then [E.V v] @ ShapeConv(es, n)          val nsumshift= (case sx
153          else ShapeConv(es,n)              of []=> n
154                | _=>let
155                    val (E.V v,_,_)=List.hd(List.rev sx)
156  fun mapIndex([],_)=[]                  in v+1
     | mapIndex(E.V v::es,index) = [List.nth(index, v)]@ mapIndex(es,index)  
     | mapIndex(E.C c::es,index) = mapIndex(es,index)  
   
   
 (*Lift probe and Multiply by P*)  
 fun liftProbe(E.Probe(E.Conv(V,alpha,H,dx),E.Tensor(t,_)),(params,args),index, sumIndex,origargs)=let  
     val _ =print "Lift Probe"  
   
     val n=length(index)  
     val ns=length sumIndex  
     val nshift=length(dx)  
     val np=length(params)  
     val nsumshift =(case ns  
         of 0=>   n  
         |_=>let  val (E.V v,_,_)=List.nth(sumIndex, ns-1)  
             in  
                 v+1  
157              end              end
158          (* end case *))          (* end case *))
159    
160            val aa=List.map (fn (E.V v,_,_)=>Int.toString v) sx
161            val _ =(String.concat["\n", "SumIndex:" ,(String.concatWith"," aa),
162            "\n\t Index length:",Int.toString n,
163            "\n\t Freshindex: ", Int.toString nsumshift])
164            in
165                nsumshift
166            end
167    
168      (*Outer Index-id Of Probe*)      (*formBody:ein_exp->ein_exp
169      val VShape=ShapeConv(alpha, n)      *just does a quick rewrite
170      val HShape=ShapeConv(dx, n)      *)
171      val shape=VShape@HShape      fun formBody(E.Sum([],e))=formBody e
172        | formBody(E.Sum(sx,e))= E.Sum(sx,formBody e)
173      (* Bindings for Shape*)      | formBody(E.Opn(E.Prod, [e]))=e
174      val shapebind= mapIndex(shape,index)      | formBody e=e
     val Vshapebind= mapIndex(VShape,index)  
175    
176      (*Look at Args and get dim, mid-il ops, support, and Arg for transformation matrix P*)      (* silly change in order of the product to match vis branch WorldtoSpace functions*)
177      val (dim,args',code,support,PArg) = handleArgs(V,H,t,(params,args), origargs,1)      fun multiPs([P0,P1,P2],sx,body)= formBody(E.Sum(sx, setProd[P0,P1,P2,body]))
178        (*
179          | multiPs([P0,P1],sx,body)=formBody(E.Sum(sx, setProd([P0,body,P1])))
180          *)
181          | multiPs([P0,P1,P2,P3],sx,body)= formBody(E.Sum(sx, setProd[P0,P1,P2,P3,body]))
182          | multiPs(Ps,sx,body)=formBody(E.Sum(sx,setProd([body]@Ps)))
183    
     (*New transformations:params, sx, rest, will be empty if no transformation is made*)  
     val (oldArg,newArg,dx, paramsT,sxT,restT,ixT,dataT) = T.Transform(dx,shapebind,Vshapebind,dim,PArg,nsumshift,ns,4)  
184    
185      (*rewriteBody*)      fun multiMergePs([P0,P1],[sx0,sx1],body)=E.Sum([sx0],setProd[P0,E.Sum([sx1],setProd[P1,body])])
186      val bodyExpanded = createBody(dim, support,nsumshift+nshift,alpha,dx,0, 1, 3, 2)        | multiMergePs e=multiPs e
187    
     val sx=sumIndex@sxT  
     val body'=(case sx  
         of [] =>E.Prod(restT@[bodyExpanded])  
         | _ => E.Sum(sx, E.Prod(restT@[bodyExpanded]))  
         (*end case*))  
188    
189      (*create new EIN OPerator*)      (* *******************************************  Replace probe *******************************************  *)
190      val _ =print("Found this many args ")      (* replaceProbe
191      val _ =print(Int.toString(length(args')))      * Transforms position to world space
192        * transforms result back to index_space
193        * rewrites body
194        * replace probe with expanded version
195        *)
196         fun replaceProbe((y, DstIL.EINAPP(e,args)),p ,sx)
197            =let
198            val originalb=Ein.body e
199            val params=Ein.params e
200            val index=Ein.index e
201            val _ = testp["\n***************** \n Replace ************ \n"]
202            val _=  toStringBind (y, DstIL.EINAPP(e,args))
203    
204      val params'=[E.IMG(dim),E.KRN,E.TEN(3,[dim]),E.TEN(1,[dim])]@paramsT          val E.Probe(E.Conv(Vid,alpha,hid,dx),E.Tensor(tid,_))=p
205      val (p',i',b',a')=shift.clean(params', index@ixT, body', args'@[PArg])          val fid=length(params)
206      val newbie'=Ein.EIN{params=p', index=i', body=b'}          val nid=fid+1
207      val data=assignEin (oldArg, newbie', a')          val Pid=nid+1
208            val nshift=length(dx)
209            val (dim,argsA,code,s,PArg) = handleArgs(Vid,hid,tid,args)
210            val freshIndex=getsumshift(sx,length(index))
211            val (dx,newsx1,Ps)=transformToIndexSpace(freshIndex,dim,dx,Pid)
212            val params'=params@[E.TEN(3,[dim]),E.TEN(1,[dim]),E.TEN(1,[dim,dim])]
213            val body' = createBody(dim, s,freshIndex+nshift,alpha,dx,Vid, hid, nid, fid)
214            val body' = multiPs(Ps,newsx1,body')
215    
216      val _ = (case testing          val body'=(case originalb
217          of 0 => 1              of E.Sum(sx, E.Probe _)              => E.Sum(sx,body')
218          | _ => (print(String.concat["\n Lift Probe\n", split.printA(newArg, newbie', a'),"\n"]);1)              | E.Sum(sx,E.Opn(E.Prod,[eps0,E.Probe _ ]))  => E.Sum(sx,setProd[eps0,body'])
219                | _                                  => body'
220          (*end case *))          (*end case *))
221    
222            val args'=argsA@[PArg]
223            val einapp=(y,mkEinApp(mkEin(params',index,body'),args'))
224      in      in
225          (E.Tensor(np,shape), (params@[E.TEN(1,shapebind)],args@[newArg]),code@[data]@dataT)              code@[einapp]
226      end      end
227    
228        (* ******************************************* Lift probe *******************************************  *)
229        fun createEinApp(originalb,alpha,index,freshIndex,dim,dx,sx)= let
230            val Pid=0
231            val tid=1
232    
233            (*Assumes body is already clean*)
234            val (newdx,newsx,Ps)=transformToIndexSpace(freshIndex,dim,dx,Pid)
235    
236  (*Does not yet do transformation*)          (*need to rewrite dx*)
237   (* Expand probe in place *)          val (_,sizes,e as E.Conv(_,alpha',_,dx))=(case sx@newsx
238   fun replaceProbe(b,(params,args),index, sumIndex,origargs)=let              of []=> ([],index,E.Conv(9,alpha,7,newdx))
239                | _ => cleanIndex.cleanIndex(E.Conv(9,alpha,7,newdx),index,sx@newsx)
     val E.Probe(E.Conv(V,alpha,h,dx),E.Tensor(t,_))=b  
     val fid=length(params)  
     val nid=fid+1  
     val n=length(index)  
     val ns=length sumIndex  
     val nshift=length(dx)  
     val nsumshift =(case ns  
         of 0=> n  
         | _=>let  
             val (E.V v,_,_)=List.nth(sumIndex, ns-1)  
             in v+1  
             end  
240      (* end case *))      (* end case *))
241    
242      (*Outer Index-id Of Probe*)          val params=[E.TEN(1,[dim,dim]),E.TEN(1,sizes)]
243      val VShape=ShapeConv(alpha, n)          fun filterAlpha []=[]
244      val HShape=ShapeConv(dx, n)            | filterAlpha(E.C _::es)= filterAlpha es
245      val shape=VShape@HShape            | filterAlpha(e1::es)=[e1]@(filterAlpha es)
     (* Bindings for Shape*)  
     val shapebind= mapIndex(shape,index)  
     val Vshapebind= mapIndex(VShape,index)  
246    
247            val tshape=filterAlpha(alpha')@newdx
248            val t=E.Tensor(tid,tshape)
249    
250      val (dim,argsA,code,s,PArg) = handleArgs(V,h,t,(params,args), origargs,0)          val (splitvar,body)=(case originalb
251      val (_,_,dx, _,sxT,restT,_,_) = T.Transform(dx,shapebind,Vshapebind,dim,PArg,nsumshift,1,nid+1)              of E.Sum(sx, E.Probe _)              => (true,multiPs(Ps,sx@newsx,t))
252                | E.Sum(sx,E.Opn(E.Prod,[eps0,E.Probe _ ]))  => (false,E.Sum(sx,setProd[eps0,multiPs(Ps,newsx,t)]))
253      val params'=params@[E.TEN(3,[dim]),E.TEN(1,[dim]),E.TEN(1,[dim,dim])]              | _                                  => (case tsplitvar
254      val body'' = createBody(dim, s,nsumshift+nshift,alpha,dx,V, h, nid, fid)                of(* true =>   (true,multiMergePs(Ps,newsx,t))  (*pushes summations in place*)
255      val body' =(case nshift                  | false*) _ =>   (true,multiPs(Ps,newsx,t))
         of 0=> body''  
         | _ => E.Sum(sxT, E.Prod(restT@[body'']))  
256          (*end case*))          (*end case*))
     val args'=argsA@[PArg]  
     val _ =(case testing  
         of 0=> 1  
         | _ =>  let  
             val subexp=Ein.EIN{params=params', index=index, body=body'}  
             val _= print(String.concat["\n Don't replace probe  \n $$$ new sub-expression $$$ \n",P.printerE(subexp),"\n"])  
             in 1 end  
257          (* end case *))          (* end case *))
258    
259            val _ =(case splitvar
260            of true=> (String.concat["splitvar is true", P.printbody body])
261            | _ => (String.concat["splitvar is false",P.printbody body])
262            (*end case*))
263    
     in (body',(params',args') ,code)  
     end  
264    
265  (*Checks if (1) Summation variable occurs just once (2) it matches n.          val ein0=mkEin(params,index,body)
 Then we lift otherwise expand in place *)  
 fun checkSum(sx,b,info,index,origargs)=(case sx  
     of [(E.V i,lb,ub)]=>  let  
         val E.Probe(E.Conv(V,alpha,h,dx), E.Tensor(id,beta))=b  
         val n=length(index)  
         val _=(case testing  
             of 1=> (print(String.concat["in check Sum\n " ,P.printbody(E.Sum([(E.V i,lb,ub)],b))]);1)  
             |_ => 1)  
266          in          in
267              if (i=n) then (case F.countSx(sx,b)              (splitvar,ein0,sizes,dx,alpha')
                 of (1,ixx) => liftProbe(b,info,index@[ub], [],origargs)  
                 | _ => replaceProbe(b, info,index,sx,origargs)  
                 (*end case*))  
             else replaceProbe(b, info,index, sx,origargs)  
268          end          end
     | _ =>replaceProbe(b, info,index, sx,origargs)  
     (*end case*))  
   
269    
270  fun flatten []=[]      fun liftProbe((y, DstIL.EINAPP(e,args)),p ,sx)=let
271      | flatten(e1::es)=e1@(flatten es)          val _=testp["\n******* Lift Geneirc Probe ***\n"]
272            val originalb=Ein.body e
273            val params=Ein.params e
274            val index=Ein.index e
275            val _ =  (toStringBind (y, DstIL.EINAPP(e,args)))
276    
277            val E.Probe(E.Conv(Vid,alpha,hid,dx),E.Tensor(tid,_))=p
278            val fid=length(params)
279            val nid=fid+1
280            val nshift=length(dx)
281            val (dim,args',code,s,PArg) = handleArgs(Vid,hid,tid,args)
282            val freshIndex=getsumshift(sx,length(index))
283    
284   (* sx-[] then move out, otherwise keep in *)          (*transform T*P*P..Ps*)
285  fun expandEinOp ( Ein.EIN{params, index, body}, origargs,args) = let          val (splitvar,ein0,sizes,dx,alpha')= createEinApp(originalb,alpha,index,freshIndex,dim,dx,sx)
   
     val dummy=E.Const 0  
     val sumIndex=ref []  
   
     (*b-current body, info-original ein op, data-new assigments*)  
     fun rewriteBody(b,info)= let  
286    
287          fun callfn(c1,body)=let          val FArg  = DstV.new ("F", DstTy.TensorTy(sizes))
288              val ref x=sumIndex          val einApp0=mkEinApp(ein0,[PArg,FArg])
289              val c'=[c1]@x          val rtn0=(case splitvar
290              val (bodyK,infoK,dataK)= (sumIndex:=c';rewriteBody(body ,info))              of false => [(y,mkEinApp(ein0,[PArg,FArg]))]
291              val ref s=sumIndex              | _      => let
292              val z=hd(s)                   val bind3 = (y,DstIL.EINAPP(SummationEin.main ein0,[PArg,FArg]))
293              val e'=( case bodyK                   in Split.splitEinApp bind3
                 of E.Const _ =>bodyK  
                 | _ => E.Sum(z,bodyK)  
                 (*end case*))  
             in  
                 (sumIndex:=tl(s);(e',infoK,dataK))  
294              end              end
295                (*end case*))
296    
297          fun filter es=let          (*lifted probe*)
298              fun filterApply([], doneB, infoB, dataB)= (doneB, infoB,dataB)          val params'=params@[E.TEN(3,[dim]),E.TEN(1,[dim])]
299              | filterApply(B::es, doneA, infoA,dataA)= let          val freshIndex'= length(sizes)
300                  val (bodyB, infoB,dataB)= rewriteBody(B,infoA)  
301            val body' = createBody(dim, s,freshIndex',alpha',dx,Vid, hid, nid, fid)
302            val ein1=mkEin(params',sizes,body')
303            val einApp1=mkEinApp(ein1,args')
304            val rtn1=(FArg,einApp1)
305            val rtn=code@[rtn1]@rtn0
306            val _= List.map toStringBind ([rtn1]@rtn0)
307             val _=(String.concat["\n* end  Lift Geneirc Probe  ******** \n"])
308                  in                  in
309                      filterApply(es, doneA@[bodyB], infoB,dataA@dataB)              rtn
310                  end                  end
311              in filterApply(es, [],info,[])  
312              end      (* ******************************************* Reconstruction -> Lift|Replace probe *******************************************  *)
313          in (case b      (* scans dx for contant
314              of  E.Sum(c,  E.Probe(E.Conv v, E.Tensor t)) =>let       * arg:(1,code1, body1,[])
315                  val ref sx=sumIndex       *)
316                  in (case sx      fun reconstruction([],arg)= replaceProbe arg
317                      of   [] => liftProbe(E.Probe(E.Conv v, E.Tensor t ), info,index, c,origargs)       | reconstruction(dx,arg)=(case (constflag,fieldliftflag)
318                        | [i]=> checkSum(i,b, info,index,origargs)          of (true,true) => liftProbe arg
319            | (_,false)    => replaceProbe arg
320                        | _ => let                        | _ => let
321                          val (b,m,code)=replaceProbe(E.Probe(E.Conv v, E.Tensor t ), info,index, (flatten sx)@c,origargs)              fun fConst [] = liftProbe arg
322                          in (E.Sum(c,b),m,code)              | fConst (E.C _::_) = replaceProbe arg
323                          end              | fConst (_ ::es)= fConst es
324                in fConst dx end
325            (* end case*))
326    
327        (* **************************************************** Index Tensor **************************************************** *)
328        (*Push constant indices to tensor replacement*)
329        fun getF (e,fieldset,dim,newvx)= let
330            val (y, DstIL.EINAPP(ein,args))=e
331            val index0=Ein.index ein
332            val index1 = index0@dim
333            val b=Ein.body ein
334    
335            val (c1,dx,body1)=(case b
336                of  E.Sum([(vsum,0,n)],E.Probe(E.Conv(V,[c1,v0],h,dx),pos))=>let
337                    val shiftdx=List.tabulate(length(dx),fn n=>E.V (n+2))
338                    val b=E.Probe(E.Conv(V,[E.V 0,E.V 1],h,shiftdx),pos)
339                    in (c1,dx,b) end
340                | E.Probe(E.Conv(V,[c1,v0],h,dx),pos)=> let
341                    val body1_unshifted= E.Probe(E.Conv(V,[E.V newvx, v0],h,dx),pos)
342                    (* clean to get body indices in order *)
343                    val ( _ , _, body1)= cleanIndex.cleanIndex(body1_unshifted,index1,[])
344                    in (c1,dx,body1) end
345                |  E.Probe(E.Conv(V,[c1],h,dx),pos)=> let
346                   val body1_unshifted= E.Probe(E.Conv(V,[E.V newvx],h,dx),pos)
347                   val ( _ , _, body1)= cleanIndex.cleanIndex(body1_unshifted,index1,[])
348                   in (c1,dx,body1) end
349                (*end case*))
350    
351            val lhs1=DstV.new ("L", DstTy.TensorTy(index1))
352            val ein1 = mkEin(Ein.params ein,index1,body1)
353            val code1= (lhs1,mkEinApp(ein1,args))
354    
355            val (_,(lhs0,codeAll))= (case valnumflag
356                of false    => (fieldset,(lhs1, reconstruction(dx,(code1,body1,[]))))
357                | true      => (case  (einVarSet.rtnVarN(fieldset,code1))
358                    of (fieldset,NONE)     => (fieldset,(lhs1, reconstruction(dx,(code1,body1,[]))))
359                    | (fieldset,SOME m)   =>  (fieldset,(m,[]))
360                  (* end case*))                  (* end case*))
             end  
         | E.Probe(E.Conv _, E.Tensor _) =>let  
             val ref sx=sumIndex  
             in (case sx  
                 of []=> liftProbe(b, info,index, [],origargs)  
                 | [i]=> checkSum(i,b, info,index,origargs)  
                 | _ => replaceProbe(b, info,index, flatten sx,origargs)  
361               (* end case*))               (* end case*))
362              end  
363          | E.Probe _=> (dummy,info,[])          (*Probe that tensor at a constant position  c1*)
364          | E.Conv _=>  (dummy,info,[])          val param0 = [E.TEN(1,index1)]
365          | E.Lift _=> (dummy,info,[])          val nx=List.tabulate(newvx,fn n=>E.V n)
366          | E.Field _ => (dummy,info,[])          val body0 =  (case b
367          | E.Apply _ => (dummy,info,[])              of E.Sum([(vsum,0,n)],_)=>  E.Sum([(vsum,0,n)],E.Tensor(0,[vsum,vsum]@nx))
368          | E.Neg e=> let              | _ => E.Tensor(0,[c1]@nx)
369              val (body',info',data')=rewriteBody(e,info)              (*end case*))
370            val ein0 = mkEin(param0,index0,body0)
371            val einApp0 = mkEinApp(ein0,[lhs0])
372            val code0 = (y,einApp0)
373            val _= toStringBind code0
374              in              in
375                  (E.Neg(body'),info',data')              codeAll@[code0]
             end  
         | E.Sum (c,e)=> callfn(c,e)  
         | E.Sub(a,b)=>let  
             val (bodyA,infoA,dataA)= rewriteBody(a,info)  
             val (bodyB, infoB, dataB)= rewriteBody(b,infoA)  
             in   (E.Sub(bodyA, bodyB),infoB,dataA@dataB)  
             end  
         | E.Div(a,b)=>let  
             val (bodyA,infoA,dataA)= rewriteBody(a,info)  
             val (bodyB, infoB,dataB)= rewriteBody(b,infoA)  
             in  (E.Div(bodyA, bodyB),infoB,dataA@dataB) end  
         | E.Add es=> let  
             val (done, info',data')= filter es  
             val (_, e)=F.mkAdd done  
             in (e, info',data')  
376              end              end
377          | E.Prod es=> let      (* **************************************************** General Fn **************************************************** *)
378              val (done, info',data')= filter es      (* expandEinOp: code->  code list
379              val (_, e)=F.mkProd done      * A this point we only have simple ein ops
380              in (e, info',data')      * Looks to see if the expression has a probe. If so, replaces it.
381              end      * Note how we keeps eps expressions so only generate pieces that are used
382          | _=>  (b,info,[])      *)
383        fun expandEinOp( e0 as (y, DstIL.EINAPP(ein,args)),fieldset,varset)=let
384            fun rewriteBody(e,p as E.Probe(E.Conv(_,alpha,_,dx),_))= (case (detflag,alpha,dx)
385                of (true,[E.C(_,true), E.V 0],[])            => getF(e,fieldset,[3],1)
386                | (true,[E.C(_,true), E.V 0],[E.V 1])        => getF(e,fieldset,[3],2)
387                | (true,[E.C(_,true), E.V 0],[E.V 1,E.V 2])  => getF(e,fieldset,[3],3)
388                | (true,[E.C(_,true)],[])                    => getF(e,fieldset,[3],0)
389                | (true,[E.C(_,true)],[E.V 0])               => getF(e,fieldset,[3],1)
390                | (true,[E.C(_,true)],[E.V 0,E.V 1])         => getF(e,fieldset,[3],2)
391                | (true,[E.C(_,true)],[E.V 0,E.V 1,E.V 2])   => getF(e,fieldset,[3],3)
392                | _                                          => reconstruction(dx,(e,p,[]))
393                (*end case*))
394            | rewriteBody(e,E.Sum(sx,p as E.Probe(E.Conv(_,alpha,_,dx),_)))= (case (detsumflag,sx,alpha,dx)
395                of (true,[(E.V 0,0,_)],[E.V 0 ,E.V 0],[])              => getF(e,fieldset,[3,3],0)
396                | (true,[(E.V 1,0,_)],[E.V 1 ,E.V 1],[E.V 0])          => getF(e,fieldset,[3,3],1)
397                | (true,[(E.V 2,0,_)],[E.V 2 ,E.V 2],[E.V 0,E.V 1])    => getF(e,fieldset,[3,3],2)
398                | (_,_,_,[])                                => replaceProbe(e,p, sx)  (*no dx*)
399                | (_,_,[],_)                                => reconstruction(dx,(e,p,sx))
400                | _                                         => replaceProbe(e,p, sx)
401                (* end case *))
402            | rewriteBody(e,E.Sum(sx,E.Opn(E.Prod,[eps,E.Probe p])))     = replaceProbe(e,E.Probe p,sx)
403            | rewriteBody (e,_)  = [e]
404    
405            val b=Ein.body ein
406            fun matchField()=(case b
407                of E.Probe _ => 1
408                | E.Sum (_, E.Probe _)=>1
409                | E.Sum(_, E.Opn(E.Prod,[ _ ,E.Probe _]))=>1
410                | _ =>0
411                (*end case*))
412            val (fieldset,varset,code,flag) = (case valnumflag
413                of true => (case (einVarSet.rtnVarN(fieldset,e0))
414                   of  (fldset,NONE)     => (fldset,varset,rewriteBody(e0,b),0)
415                  | (fldset,SOME v)    => (fldset,varset,[(y,DstIL.VAR v)],1)
416                     (*of (fldset, NONE)      => (fldset,varset,rewriteBody((y,DstIL.EINAPP(ein,List.map (fn a=>einVarSet.replaceArg(varset,a)) args)), b),0)
417                     | (fldset,SOME v)    => (fldset,einVarSet.VarSet.add(varset,einVarSet.VAR(v,y)),[],1)*)
418          (* end case *))          (* end case *))
419          end              | _     => (fieldset,varset,rewriteBody(e0, b),0)
   
      val empty =fn key =>NONE  
      val _ =(case testing  
         of 0 => 1  
         | _ => (print "\n ************************** \n Starting Expand";1)  
420          (*end case*))          (*end case*))
421            val m=matchField()
422      val (body',(params',args'),newbies)=rewriteBody(body,(params,args))          in  (code,fieldset,varset,m,flag) end
     val e'=Ein.EIN{params=params', index=index, body=body'}  
     (*val _ =(case testing  
         of 0 => 1  
         | _ => (String.concat[P.printerE(e'),"\n DONE expand ************************** \n "];1)  
         (*end case*))*)  
     in  
         ((e',args'),newbies)  
     end  
423    
424    end; (* local *)    end; (* local *)
425    

Legend:
Removed from v.2613  
changed lines
  Added in v.3503

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0