Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] Diff of /branches/charisee_dev/src/compiler/high-to-mid/ProbeEin.sml
ViewVC logotype

Diff of /branches/charisee_dev/src/compiler/high-to-mid/ProbeEin.sml

Parent Directory Parent Directory | Revision Log Revision Log | View Patch Patch

branches/charisee/src/compiler/high-to-mid/ProbeEin.sml revision 2611, Mon May 5 21:21:12 2014 UTC branches/charisee_dev/src/compiler/high-to-mid/ProbeEin.sml revision 3325, Tue Oct 20 17:04:54 2015 UTC
# Line 1  Line 1 
1  (* Currently under construction  (* Expands probe ein
2   *   *
3   * COPYRIGHT (c) 2012 The Diderot Project (http://diderot-language.cs.uchicago.edu)   * COPYRIGHT (c) 2012 The Diderot Project (http://diderot-language.cs.uchicago.edu)
4   * All rights reserved.   * All rights reserved.
5   *)   *)
6    
   
 (*  
 A couple of different approaches.  
 One approach is to find all the Probe(Conv). Gerenerate exp for it  
 Then use Subst function to sub in. That takes care for index matching and  
   
 *)  
   
 (*This approach creates probe expanded terms, and adds params to the end. *)  
   
   
7  structure ProbeEin = struct  structure ProbeEin = struct
8    
9      local      local
10    
11      structure E = Ein      structure E = Ein
     structure mk= mkOperators  
     structure SrcIL = HighIL  
     structure SrcTy = HighILTypes  
     structure SrcOp = HighOps  
     structure SrcSV = SrcIL.StateVar  
     structure VTbl = SrcIL.Var.Tbl  
12      structure DstIL = MidIL      structure DstIL = MidIL
     structure DstTy = MidILTypes  
13      structure DstOp = MidOps      structure DstOp = MidOps
     structure DstV = DstIL.Var  
     structure SrcV = SrcIL.Var  
14      structure P=Printer      structure P=Printer
     structure shift=ShiftEin  
     structure split=SplitEin  
     structure F=Filter  
15      structure T=TransformEin      structure T=TransformEin
16        structure MidToS = MidToString
17        structure DstV = DstIL.Var
18        structure DstTy = MidILTypes
19    
 val testing=1  
   
 datatype peanut=    O of  DstOp.rator | E of Ein.ein|C of DstTy.ty|S of int  
 datatype peanut2=    O2 of  SrcOp.rator | E2 of Ein.ein|C2 of SrcTy.ty|S2 of int  
20      in      in
21    
22        (* This file expands probed fields
23        * Take a look at ProbeEin tex file for examples
24        *Note that the original field is an EIN operator in the form <V_alpha * H^(deltas)>(midIL.var list )
25        * Param_ids are used to note the placement of the argument in the midIL.var list
26        * Index_ids  keep track of the shape of an Image or differentiation.
27        * Mu  bind Index_id
28        * Generally, we will refer to the following
29        *dim:dimension of field V
30        * s: support of kernel H
31        * alpha: The alpha in <V_alpha * H^(deltas)>
32        * deltas: The deltas in <V_alpha * H^(deltas)>
33        * Vid:param_id for V
34        * hid:param_id for H
35        * nid: integer position param_id
36        * fid :fractional position param_id
37        * img-imginfo about V
38        *)
39    
40  fun assign (x, rator, args) = (x, DstIL.OP(rator, args))      val testing=0
41  fun assignEin (x, rator, args) = ((x, DstIL.EINAPP(rator, args)))      val testlift=1
42        val detflag =true
43        val fieldliftflag=true
44        val valnumflag=true
45    
 fun getRHS x  = (case SrcIL.Var.binding x  
     of SrcIL.VB_RHS(SrcIL.OP(rator, args)) => (O2 rator, args)  
     | SrcIL.VB_RHS(SrcIL.VAR x') => getRHS x'  
     | SrcIL.VB_RHS(SrcIL.EINAPP (e,args))=>(E2 e,args)  
     | SrcIL.VB_RHS(SrcIL.CONS (ty,args))=>(C2 ty,args)  
     | SrcIL.VB_NONE=>(S2 2,[])  
     | vb => raise Fail(concat[  
     "expected rhs operator for ", SrcIL.Var.toString x,  
     "but found ", SrcIL.vbToString vb])  
     (* end case *))  
46    
47        val cnt = ref 0
48        fun transformToIndexSpace e=T.transformToIndexSpace e
49        fun transformToImgSpace  e=T.transformToImgSpace  e
50        fun toStringBind e=(MidToString.toStringBind e)
51        fun mkEin e=Ein.mkEin e
52        fun mkEinApp(rator,args)=DstIL.EINAPP(rator,args)
53    
54        fun testp n=(case testing
55            of 0=> 1
56            | _ =>(print(String.concat n);1)
57            (*end case*))
58    
 (*Create fractional, and integer position vectors*)  
 fun transformToImgSpace  (dim,v,posx)=let  
59    
60      val translate=DstOp.Translate v      fun getRHSDst x  = (case DstIL.Var.binding x
61      val transform=DstOp.Transform v          of DstIL.VB_RHS(DstIL.OP(rator, args)) => (rator, args)
62      val M  = DstV.new ("M", DstTy.tensorTy [dim,dim])   (*transform dim by dim?*)          | DstIL.VB_RHS(DstIL.VAR x') => getRHSDst x'
63            | vb => raise Fail(concat[ "expected rhs operator for ", DstIL.Var.toString x, "but found ", DstIL.vbToString vb])
64        (* end case *))
65    
     val T  = DstV.new ("T", DstTy.tensorTy [dim,dim])   (*translate*)  
     val x  = DstV.new ("x", DstTy.vecTy dim)            (*Image-Space position*)  
     val f  = DstV.new ("f", DstTy.vecTy dim)            (*fractional*)  
     val nd = DstV.new ("nd", DstTy.vecTy dim)           (*real position*)  
     val n  = DstV.new ("n", DstTy.iVecTy dim)           (*integer position*)  
     val PosToImgSpace=mk.transform(dim,dim)  
     val P  = DstV.new ("P", DstTy.tensorTy [dim,dim])   (*transform dim by dim?*)  
66    
67      val code=[      (* getArgsDst:MidIL.Var* MidIL.Var->int, ImageInfo, int
68          assign(M, transform, []),          uses the Param_ids for the image, kernel,
69          assign(T, translate, []),          and position tensor to get the Mid-IL arguments
70          assignEin(x, PosToImgSpace,[M,posx,T]) ,  (* MX+T*)      returns the support of ther kernel, and image
71          assign(nd, DstOp.Floor dim, [x]),   (*nd *)      *)
72          assignEin(f, mk.subTen([dim]),[x,nd]),           (*fractional*)      fun getArgsDst(hArg,imgArg,args) = (case (getRHSDst hArg, getRHSDst imgArg)
73          assign(n, DstOp.RealToInt dim, [nd]), (*real to Int*)          of ((DstOp.Kernel(h, i), _ ), (DstOp.LoadImage(_, _, img), _ ))=> let
74          assignEin(P, mk.transpose([dim,dim]), [M])              in
75          ]                ((Kernel.support h) ,img,ImageInfo.dim img)
     in ([n,f],P,code)  
76      end      end
77            |  _ => raise Fail "Expected Image and kernel arguments"
   
 fun replaceH(kvar, place,args)=let  
     val l1=List.take(args, place)  
     val l2=List.drop(args,place+1)  
     in l1@[kvar]@l2 end  
   
   
 (*Get Img, and Kern Args*)  
 fun getArgs(hid,hArg,V,imgArg,args,lift)=case (getRHS hArg,getRHS imgArg)  
     of ((O2(SrcOp.Kernel(h, i)),argK),(O2(SrcOp.LoadImage img),_))=> let  
         val hvar=DstV.new ("KNL", DstTy.KernelTy)  
         val imgvar=DstV.new ("IMG", DstTy.ImageTy img)  
         val argsVK= (case lift  
             of 0=> let  
                 val argsN=replaceH(hvar, hid,args)  
                 in replaceH(imgvar, V,argsN) end  
             | _ => [imgvar, hvar]  
78          (* end case *))          (* end case *))
         val assigments=[assign (hvar, DstOp.Kernel(h, i), []), assign(imgvar,DstOp.LoadImage img,[])]  
79    
80    
81        (*handleArgs():int*int*int*Mid IL.Var list
82            ->int*Mid.ILVars list* code*int* low-il-var
83            * uses the Param_ids for the image, kernel, and tensor
84            * and gets the mid-IL vars for each.
85            *Transforms the position to index space
86            *P is the mid-il var for the (transformation matrix)transpose
87        *)
88        fun handleArgs(Vid,hid,tid,args)=let
89            val imgArg=List.nth(args,Vid)
90            val hArg=List.nth(args,hid)
91            val newposArg=List.nth(args,tid)
92            val (s,img,dim) =getArgsDst(hArg,imgArg,args)
93            val (argsT,P,code)=transformToImgSpace(dim,img,newposArg,imgArg)
94          in          in
95              (Kernel.support h ,img, assigments,argsVK)              (dim,args@argsT,code, s,P)
96          end          end
     | ((O2(SrcOp.Kernel(h, i)),arg),_)=> raise Fail "Not an img Argument"  
     |  _ => raise Fail "Not a kernel argument"  
   
97    
98  fun handleArgs(V,h,t,(params,args),origargs,lift)=let      (*createBody:int*int*int,mu list, param_id, param_id, param_id, param_id
99      val E.IMG(dim)=List.nth(params,V)      * expands the body for the probed field
100      val kArg=List.nth(origargs,h)      *)
101      val imgArg=List.nth(origargs,V)      fun createBody(dim, s,sx,alpha,deltas,Vid, hid, nid, fid)=let
102      val newposArg=List.nth(args, t)          (*1-d fields*)
103      val (s,img,argcode,argsVH) =getArgs(h,kArg,V,imgArg,args,lift)          fun createKRND1 ()=let
104      val (argsT,P,code')=transformToImgSpace(dim,img,newposArg)              val sum=sx
105      in (dim,argsVH@argsT,argcode@code', s,P)              val dels=List.map (fn e=>(E.C 0,e)) deltas
106                val pos=[E.Add[E.Tensor(fid,[]),E.Value(sum)]]
107                val rest= E.Krn(hid,dels,E.Sub(E.Tensor(nid,[]),E.Value(sum)))
108                in
109                    E.Prod [E.Img(Vid,alpha,pos),rest]
110      end      end
   
   
 (*createDels=> creates the kronecker deltas for each Kernel*)  
 fun createDels([],_)= []  
     | createDels(d::ds,dim)= [( E.C dim,d)]@createDels(ds,dim)  
   
 (*Created new body for probe*)  
 fun createBody(dim, s,sx,shape,deltas,V, h, nid, fid)=let  
   
     (*sumIndex creating summaiton Index for body*)  
     fun sumIndex(0)=[]  
     |sumIndex(dim)= sumIndex(dim-1)@[(E.V (dim+sx-1),1-s,s)]  
   
111      (*createKRN Image field and kernels *)      (*createKRN Image field and kernels *)
112      fun createKRN(0,imgpos,rest)=E.Prod ([E.Img(V,shape,imgpos)] @rest)          fun createKRN(0,imgpos,rest)=E.Prod ([E.Img(Vid,alpha,imgpos)] @rest)
113      | createKRN(dim,imgpos,rest)=let      | createKRN(dim,imgpos,rest)=let
114          val dim'=dim-1          val dim'=dim-1
115          val sum=sx+dim'          val sum=sx+dim'
116          val dels=createDels(deltas,dim')              val dels=List.map (fn e=>(E.C dim',e)) deltas
117          val pos=[E.Add[E.Tensor(fid,[E.C dim']),E.Value(sum)]]          val pos=[E.Add[E.Tensor(fid,[E.C dim']),E.Value(sum)]]
118          val rest'= E.Krn(h,dels,E.Sub(E.Tensor(nid,[E.C dim']),E.Value(sum)))              val rest'= E.Krn(hid,dels,E.Sub(E.Tensor(nid,[E.C dim']),E.Value(sum)))
119          in          in
120              createKRN(dim',pos@imgpos,[rest']@rest)              createKRN(dim',pos@imgpos,[rest']@rest)
121          end          end
122            val exp=(case dim
123      val exp=createKRN(dim, [],[])              of 1 => createKRND1()
124      val esum=sumIndex (dim)              | _=> createKRN(dim, [],[])
125      in E.Sum(esum, exp)              (*end case*))
126            (*sumIndex creating summaiton Index for body*)
127            val slb=1-s
128            val esum=List.tabulate(dim, (fn dim=>(E.V (dim+sx),slb,s)))
129        in
130            E.Sum(esum, exp)
131      end      end
132    
133        (*getsumshift:sum_indexid list* int list-> int
134        *get fresh/unused index_id, returns int
135        *)
136        fun getsumshift(sx,index) =let
137            val nsumshift= (case sx
138                of []=> length(index)
139                | _=>let
140                    val (E.V v,_,_)=List.hd(List.rev sx)
141                    in v+1
142                    end
143                (* end case *))
144            val aa=List.map (fn (E.V v,_,_)=>Int.toString v) sx
145            val _ =testp["\n", "SumIndex" ,(String.concatWith"," aa),
146                "\nThink nshift is ", Int.toString nsumshift]
147            in
148                nsumshift
149            end
150    
151  fun ShapeConv([],n)=[]      (*formBody:ein_exp->ein_exp
152      | ShapeConv(E.C c::es, n)=ShapeConv(es, n)      *just does a quick rewrite
153      | ShapeConv(E.V v::es, n)=      *)
154          if(n>v) then [E.V v] @ ShapeConv(es, n)      fun formBody(E.Sum([],e))=formBody e
155          else ShapeConv(es,n)      | formBody(E.Sum(sx,e))= E.Sum(sx,formBody e)
156        | formBody(E.Prod [e])=e
157        | formBody e=e
158    
159        (* silly change in order of the product to match vis branch WorldtoSpace functions*)
160        fun multiPs([P0,P1,P2],sx,body)= formBody(E.Sum(sx, E.Prod([P0,P1,P2,body])))
161          | multiPs([P0,P1],sx,body)=formBody(E.Sum(sx, E.Prod([P0,body,P1])))
162          | multiPs(Ps,sx,body)=formBody(E.Sum(sx, E.Prod([body]@Ps)))
163    
164    
165        fun multiMergePs([P0,P1],[sx0,sx1],body)=E.Sum([sx0],E.Prod[P0,E.Sum([sx1],E.Prod[P1,body])])
166          | multiMergePs e=multiPs e
167    
168    
169        (* replaceProbe:ein_exp* params *midIL.var list * int list* sum_id list
170                -> ein_exp* *code
171        * Transforms position to world space
172        * transforms result back to index_space
173        * rewrites body
174        * replace probe with expanded version
175        *)
176    (*    fun replaceProbe(testN,y,originalb,b,params,args,index, sx)*)
177    
178         fun replaceProbe(testN,(y, DstIL.EINAPP(e,args)),p ,sx)
179            =let
180            val originalb=Ein.body e
181            val params=Ein.params e
182            val index=Ein.index e
183            val _ = testp["\n***************** \n Replace ************ \n"]
184            val _=  toStringBind (y, DstIL.EINAPP(e,args))
185    
186  fun mapIndex([],_)=[]          val E.Probe(E.Conv(Vid,alpha,hid,dx),E.Tensor(tid,_))=p
187      | mapIndex(E.V v::es,index) = [List.nth(index, v)]@ mapIndex(es,index)          val fid=length(params)
188      | mapIndex(E.C c::es,index) = mapIndex(es,index)          val nid=fid+1
189            val Pid=nid+1
190            val nshift=length(dx)
191            val (dim,argsA,code,s,PArg) = handleArgs(Vid,hid,tid,args)
192            val freshIndex=getsumshift(sx,index)
193            val (dx,newsx1,Ps)=transformToIndexSpace(freshIndex,dim,dx,Pid)
194            val params'=params@[E.TEN(3,[dim]),E.TEN(1,[dim]),E.TEN(1,[dim,dim])]
195            val body' = createBody(dim, s,freshIndex+nshift,alpha,dx,Vid, hid, nid, fid)
196            val body' = multiPs(Ps,newsx1,body')
197    
198            val body'=(case originalb
199                of E.Sum(sx, E.Probe _)              => E.Sum(sx,body')
200                | E.Sum(sx,E.Prod[eps0,E.Probe _ ])  => E.Sum(sx,E.Prod[eps0,body'])
201                | _                                  => body'
202                (*end case*))
203    
 (*Lift probe and Multiply by P*)  
 fun liftProbe(E.Probe(E.Conv(V,alpha,H,dx),E.Tensor(t,_)),(params,args),index, sumIndex,origargs)=let  
     val _ =print "Lift Probe"  
204    
205      val n=length(index)          val args'=argsA@[PArg]
206      val ns=length sumIndex          val einapp=(y,mkEinApp(mkEin(params',index,body'),args'))
     val nshift=length(dx)  
     val np=length(params)  
     val nsumshift =(case ns  
         of 0=>   n  
         |_=>let  val (E.V v,_,_)=List.nth(sumIndex, ns-1)  
207              in              in
208                  v+1              code@[einapp]
209              end              end
         (* end case *))  
   
   
     (*Outer Index-id Of Probe*)  
     val VShape=ShapeConv(alpha, n)  
     val HShape=ShapeConv(dx, n)  
     val shape=VShape@HShape  
210    
211      (* Bindings for Shape*)      val tsplitvar=true
212      val shapebind= mapIndex(shape,index)      fun createEinApp(originalb,alpha,index,freshIndex,dim,dx,sx)= let
213      val Vshapebind= mapIndex(VShape,index)          val Pid=0
214            val tid=1
215    
216      (*Look at Args and get dim, mid-il ops, support, and Arg for transformation matrix P*)          (*Assumes body is already clean*)
217      val (dim,args',code,support,PArg) = handleArgs(V,H,t,(params,args), origargs,1)          val (newdx,newsx,Ps)=transformToIndexSpace(freshIndex,dim,dx,Pid)
218    
219      (*New transformations:params, sx, rest, will be empty if no transformation is made*)          (*need to rewrite dx*)
220      val (oldArg,newArg,dx, paramsT,sxT,restT,ixT,dataT) = T.Transform(dx,shapebind,Vshapebind,dim,PArg,nsumshift,ns,4)          val (_,sizes,e as E.Conv(_,alpha',_,dx))=(case sx@newsx
221                of []=> ([],index,E.Conv(9,alpha,7,newdx))
222                | _ => cleanIndex.cleanIndex(E.Conv(9,alpha,7,newdx),index,sx@newsx)
223                (*end case*))
224    
225      (*rewriteBody*)          val params=[E.TEN(1,[dim,dim]),E.TEN(1,sizes)]
226      val bodyExpanded = createBody(dim, support,nsumshift+nshift,alpha,dx,0, 1, 3, 2)          fun filterAlpha []=[]
227              | filterAlpha(E.C _::es)= filterAlpha es
228              | filterAlpha(e1::es)=[e1]@(filterAlpha es)
229    
230      val sx=sumIndex@sxT          val tshape=filterAlpha(alpha')@newdx
231      val body'=(case sx          val t=E.Tensor(tid,tshape)
232          of [] =>E.Prod(restT@[bodyExpanded])          val (splitvar,body)=(case originalb
233          | _ => E.Sum(sx, E.Prod(restT@[bodyExpanded]))              of E.Sum(sx, E.Probe _)              => (false,E.Sum(sx,multiPs(Ps,newsx,t)))
234                | E.Sum(sx,E.Prod[eps0,E.Probe _ ])  => (false,E.Sum(sx,E.Prod[eps0,multiPs(Ps,newsx,t)]))
235                | _                                  => (case tsplitvar
236                    of(* true =>   (true,multiMergePs(Ps,newsx,t))  (*pushes summations in place*)
237                    | false*) _ =>   (true,multiPs(Ps,newsx,t))
238                    (*end case*))
239          (*end case*))          (*end case*))
240    
241      (*create new EIN OPerator*)          val _ =(case splitvar
242      val _ =print("Found this many args ")          of true=> (String.concat["splitvar is true", P.printbody body])
243      val _ =print(Int.toString(length(args')))          | _ => (String.concat["splitvar is false",P.printbody body])
244            (*end case*))
245    
     val params'=[E.IMG(dim),E.KRN,E.TEN(3,[dim]),E.TEN(1,[dim])]@paramsT  
     val (p',i',b',a')=shift.clean(params', index@ixT, body', args'@[PArg])  
     val newbie'=Ein.EIN{params=p', index=i', body=b'}  
     val data=assignEin (oldArg, newbie', a')  
246    
247      val _ = (case testing          val ein0=mkEin(params,index,body)
         of 0 => 1  
         | _ => (print(String.concat["\n Lift Probe\n", split.printA(newArg, newbie', a'),"\n"]);1)  
         (*end case *))  
248      in      in
249          (E.Tensor(np,shape), (params@[E.TEN(1,shapebind)],args@[newArg]),code@[data]@dataT)              (splitvar,ein0,sizes,dx,alpha')
250      end      end
251    
252        fun liftProbe(printStrings,(y, DstIL.EINAPP(e,args)),p ,sx)=let
253            val _=testp["\n******* Lift ******** \n"]
254            val originalb=Ein.body e
255            val params=Ein.params e
256            val index=Ein.index e
257            val _=  toStringBind (y, DstIL.EINAPP(e,args))
258    
259            val E.Probe(E.Conv(Vid,alpha,hid,dx),E.Tensor(tid,_))=p
 (*Does not yet do transformation*)  
  (* Expand probe in place *)  
  fun replaceProbe(b,(params,args),index, sumIndex,origargs)=let  
   
     val E.Probe(E.Conv(V,alpha,h,dx),E.Tensor(t,_))=b  
260      val fid=length(params)      val fid=length(params)
261      val nid=fid+1      val nid=fid+1
     val n=length(index)  
     val ns=length sumIndex  
262      val nshift=length(dx)      val nshift=length(dx)
263      val nsumshift =(case ns          val (dim,args',code,s,PArg) = handleArgs(Vid,hid,tid,args)
264          of 0=> n          val freshIndex=getsumshift(sx,index)
         | _=>let  
             val (E.V v,_,_)=List.nth(sumIndex, ns-1)  
             in v+1  
             end  
     (* end case *))  
   
     (*Outer Index-id Of Probe*)  
     val VShape=ShapeConv(alpha, n)  
     val HShape=ShapeConv(dx, n)  
     val shape=VShape@HShape  
     (* Bindings for Shape*)  
     val shapebind= mapIndex(shape,index)  
     val Vshapebind= mapIndex(VShape,index)  
   
265    
266      val (dim,argsA,code,s,PArg) = handleArgs(V,h,t,(params,args), origargs,0)          (*transform T*P*P..Ps*)
267      val (_,_,dx, _,sxT,restT,_,_) = T.Transform(dx,shapebind,Vshapebind,dim,PArg,nsumshift,1,nid+1)          val (splitvar,ein0,sizes,dx,alpha')= createEinApp(originalb,alpha,index,freshIndex,dim,dx,sx)
268            val FArg  = DstV.new ("F", DstTy.TensorTy(sizes))
269      val params'=params@[E.TEN(3,[dim]),E.TEN(1,[dim]),E.TEN(1,[dim,dim])]          val einApp0=mkEinApp(ein0,[PArg,FArg])
270      val body'' = createBody(dim, s,nsumshift+nshift,alpha,dx,V, h, nid, fid)          val rtn0=(case splitvar
271      val body' =(case nshift              of false => [(y,mkEinApp(ein0,[PArg,FArg]))]
         of 0=> body''  
         | _ => E.Sum(sxT, E.Prod(restT@[body'']))  
         (*end case*))  
     val args'=argsA@[PArg]  
     val _ =(case testing  
         of 0=> 1  
272          | _ =>  let          | _ =>  let
273              val subexp=Ein.EIN{params=params', index=index, body=body'}                   val bind3 = (y,DstIL.EINAPP(SummationEin.main ein0,[PArg,FArg]))
274              val _= print(String.concat["\n Don't replace probe  \n $$$ new sub-expression $$$ \n",P.printerE(subexp),"\n"])                   in Split.splitEinApp bind3
275              in 1 end                   end
276          (* end case *))          (* end case *))
277    
278            (*lifted probe*)
279            val params'=params@[E.TEN(3,[dim]),E.TEN(1,[dim])]
280            val body' = createBody(dim, s,freshIndex+nshift,alpha',dx,Vid, hid, nid, fid)
281            val ein1=mkEin(params',sizes,body')
282            val einApp1=mkEinApp(ein1,args')
283            val rtn1=(FArg,einApp1)
284            val rtn=code@[rtn1]@rtn0
285            val _= List.map toStringBind ([rtn1]@rtn0)
286    
     in (body',(params',args') ,code)  
     end  
   
 (*Checks if vairable occurs just once. If it does then we can lift *)  
 fun checkSum([(E.V i,lb,ub)],E.Probe(E.Conv(V,alpha,h,dx), E.Tensor(id,beta)),info,index,origargs)=let  
     val n=length(index)  
     val _=print "in check Sum\n "  
     val _=print(P.printbody(E.Sum([(E.V i,lb,ub)],E.Probe(E.Conv(V,alpha,h,dx),E.Tensor(id,beta)))))  
287      in      in
288       if (i=n) then (case F.findOcc(i,alpha@dx)              rtn
             of  1 => liftProbe(E.Probe(E.Conv(V,alpha,h,dx),E.Tensor(id,beta)),info,index@[ub], [],origargs)  
             | _ =>  replaceProbe(E.Probe(E.Conv(V,alpha,h,dx),E.Tensor(id,beta)), info,index, [(E.V i,lb,ub)],origargs)  
             (*end case*))  
     else replaceProbe(E.Probe(E.Conv(V,alpha,h,dx),E.Tensor(id,beta)), info,index, [(E.V i,lb,ub)],origargs)  
289      end      end
290    
 fun flatten []=[]  
     | flatten(e1::es)=e1@(flatten es)  
291    
292        fun liftFieldMat(newvx,e)=
293   (* sx-[] then move out, otherwise keep in *)          let
294  fun expandEinOp ( Ein.EIN{params, index, body}, origargs,args) = let              val (y, DstIL.EINAPP(ein,args))=e
295                val E.Probe(E.Conv(V,[c1,v0],h,dx),pos)=Ein.body ein
296      val dummy=E.Const 0              val index0=Ein.index ein
297      val sumIndex=ref []              val index1 = index0@[3]
298                val body1_unshifted = E.Probe(E.Conv(V,[E.V newvx, v0],h,dx),pos)
299      (*b-current body, info-original ein op, data-new assigments*)              (* clean to get body indices in order *)
300      fun rewriteBody(b,info)= let              val ( _ , _, body1)= cleanIndex.cleanIndex(body1_unshifted,index1,[])
301                val _ = testp ["\n Shifted ",P.printbody body1_unshifted,"=>",P.printbody body1]
302          fun callfn(c1,body)=let  
303              val ref x=sumIndex              val lhs1=DstV.new ("L", DstTy.TensorTy(index1))
304              val c'=[c1]@x              val ein1 = mkEin(Ein.params ein,index1,body1)
305              val (bodyK,infoK,dataK)= (sumIndex:=c';rewriteBody(body ,info))              val code1= (lhs1,mkEinApp(ein1,args))
306              val ref s=sumIndex              val codeAll= (case dx
307              val z=hd(s)              of []=> replaceProbe(1,code1,body1,[])
308              val e'=( case bodyK              | _ =>liftProbe(1,code1,body1,[])
309                  of E.Const _ =>bodyK              (*end case*))
310                  | _ => E.Sum(z,bodyK)  
311                  (*end case*))              (*Probe that tensor at a constant position  c1*)
312                val param0 = [E.TEN(1,index1)]
313                val nx=List.tabulate(length(dx)+1,fn n=>E.V n)
314                val body0 =  E.Tensor(0,[c1]@nx)
315                val ein0 = mkEin(param0,index0,body0)
316                val einApp0 = mkEinApp(ein0,[lhs1])
317                val code0 = (y,einApp0)
318                val _= toStringBind code0
319              in              in
320                  (sumIndex:=tl(s);(e',infoK,dataK))              codeAll@[code0]
321              end              end
322    
323          fun filter es=let      fun liftFieldSum e =
324              fun filterApply([], doneB, infoB, dataB)= (doneB, infoB,dataB)      let
325              | filterApply(B::es, doneA, infoA,dataA)= let          val _=print"\n*************************************\n"
326                  val (bodyB, infoB,dataB)= rewriteBody(B,infoA)          val (y, DstIL.EINAPP(ein,args))=e
327                  in          val E.Sum([(vsum,0,n)],E.Probe(E.Conv(V,[c1,v0],h,dx),pos))=Ein.body ein
328                      filterApply(es, doneA@[bodyB], infoB,dataA@dataB)          val index0=Ein.index ein
329                  end          val index1 = index0@[3]@[3]
330              in filterApply(es, [],info,[])          val shiftdx=List.tabulate(length(dx),fn n=>E.V (n+2))
331              end          val body1 = E.Probe(E.Conv(V,[E.V 0,E.V 1],h,shiftdx),pos)
332          in (case b  
333              of  E.Sum(c,  E.Probe(E.Conv v, E.Tensor t)) =>let  
334                  val ref sx=sumIndex          val lhs1=DstV.new ("L", DstTy.TensorTy(index1))
335                  in (case sx          val ein1 = mkEin(Ein.params ein,index1,body1)
336                      of   [] => liftProbe(E.Probe(E.Conv v, E.Tensor t ), info,index, c,origargs)          val code1= (lhs1,mkEinApp(ein1,args))
337                        | [i]=> checkSum(i,b, info,index,origargs)          val codeAll= (case dx
338                       | _ => let          of []=> replaceProbe(1,code1,body1,[])
339                          val (b,m,code)=replaceProbe(E.Probe(E.Conv v, E.Tensor t ), info,index, (flatten sx)@c,origargs)          | _ =>liftProbe(1,code1,body1,[])
340                          in (E.Sum(c,b),m,code)          (*end case*))
341                          end  
342                  (* end case*))          (*Probe that tensor at a constant position  c1*)
343              end          val param0 = [E.TEN(1,index1)]
344          | E.Probe(E.Conv _, E.Tensor _) =>let          val nx=List.tabulate(length(dx),fn n=>E.V n)
345              val ref sx=sumIndex          val body0 =  E.Sum([(vsum,0,n)],E.Tensor(0,[vsum,vsum]@nx))
346              in (case sx          val ein0 = mkEin(param0,index0,body0)
347                  of []=> liftProbe(b, info,index, [],origargs)          val einApp0 = mkEinApp(ein0,[lhs1])
348                  | [i]=> checkSum(i,b, info,index,origargs)          val code0 = (y,einApp0)
349                  | _ => replaceProbe(b, info,index, flatten sx,origargs)          val _= toStringBind  e
350               (* end case*))          val _ =toStringBind code0
351              end         val _ = (String.concat  ["\norig",P.printbody(Ein.body ein),"\n replace i  ",P.printbody body1,"\nfreshtensor",P.printbody body0])
352          | E.Probe _=> (dummy,info,[])         val _ =(String.concat(List.map toStringBind (codeAll@[code0])))
353          | E.Conv _=>  (dummy,info,[])                 val _=print"\n*************************************\n"
         | E.Lift _=> (dummy,info,[])  
         | E.Field _ => (dummy,info,[])  
         | E.Apply _ => (dummy,info,[])  
         | E.Neg e=> let  
             val (body',info',data')=rewriteBody(e,info)  
354              in              in
355                  (E.Neg(body'),info',data')          codeAll@[code0]
             end  
         | E.Sum (c,e)=> callfn(c,e)  
         | E.Sub(a,b)=>let  
             val (bodyA,infoA,dataA)= rewriteBody(a,info)  
             val (bodyB, infoB, dataB)= rewriteBody(b,infoA)  
             in   (E.Sub(bodyA, bodyB),infoB,dataA@dataB)  
             end  
         | E.Div(a,b)=>let  
             val (bodyA,infoA,dataA)= rewriteBody(a,info)  
             val (bodyB, infoB,dataB)= rewriteBody(b,infoA)  
             in  (E.Div(bodyA, bodyB),infoB,dataA@dataB) end  
         | E.Add es=> let  
             val (done, info',data')= filter es  
             val (_, e)=F.mkAdd done  
             in (e, info',data')  
             end  
         | E.Prod es=> let  
             val (done, info',data')= filter es  
             val (_, e)=F.mkProd done  
             in (e, info',data')  
             end  
         | _=>  (b,info,[])  
         (* end case *))  
356          end          end
357    
      val empty =fn key =>NONE  
      val _ =(case testing  
         of 0 => 1  
         | _ => (print "\n ************************** \n Starting Expand";1)  
         (*end case*))  
358    
359      val (body',(params',args'),newbies)=rewriteBody(body,(params,args))      (* expandEinOp: code->  code list
360      val e'=Ein.EIN{params=params', index=index, body=body'}      * A this point we only have simple ein ops
361      (*val _ =(case testing      * Looks to see if the expression has a probe. If so, replaces it.
362          of 0 => 1      * Note how we keeps eps expressions so only generate pieces that are used
363          | _ => (String.concat[P.printerE(e'),"\n DONE expand ************************** \n "];1)      *)
364          (*end case*))*)     fun expandEinOp( e as (y, DstIL.EINAPP(ein,args)),fieldset)=let
365      in  
366          ((e',args'),newbies)          fun checkConst ([],a) =
367                (case fieldliftflag
368                    of true => liftProbe a
369                    | _ => replaceProbe a
370                (*end case*))
371            | checkConst ((E.C _::_),a) = replaceProbe a
372            | checkConst ((_ ::es),a)= checkConst(es,a)
373    
374            fun rewriteBody b=(case (detflag,b)
375                of (true,E.Probe(E.Conv(_,[E.C _ ,E.V 0],_,[]),pos))
376                    => liftFieldMat (1,e)
377                | (true,E.Probe(E.Conv(_,[E.C _ ,E.V 0],_,[E.V 1]),pos))
378                    => liftFieldMat (2,e)
379                | (true,E.Probe(E.Conv(_,[E.C _ ,E.V 0],_,[E.V 1,E.V 2] ),pos))
380                    => liftFieldMat (3,e)
381                | (true, E.Sum([(E.V 0,0,_)],E.Probe(E.Conv(_,[E.V 0 ,E.V 0],_,[]),pos)))
382                    => liftFieldSum e
383                | (true, E.Sum([(E.V 1,0,_)],E.Probe(E.Conv(_,[E.V 1 ,E.V 1],_,[E.V 0]),pos)))
384                    => liftFieldSum e
385                | (true, E.Sum([(E.V 2,0,_)],E.Probe(E.Conv(_,[E.V 2 ,E.V 2],_,[E.V 0,E.V 1]),pos)))
386                    => liftFieldSum e
387    
388    
389                | (_,E.Probe(E.Conv(_,_,_,[]),_))
390                    => replaceProbe(0,e,b,[])
391                | (_,E.Probe(E.Conv (_,alpha,_,dx),_))
392                    => checkConst(dx,(0,e,b,[])) (*scans dx for contant*)
393                | (_,E.Sum(sx,p as E.Probe(E.Conv(_,_,_,[]),_)))
394                    => replaceProbe(0,e,p, sx)  (*no dx*)
395                | (_,E.Sum(sx,p as E.Probe(E.Conv(_,[],_,dx),_)))
396                    => checkConst(dx,(0,e,p,sx)) (*scalar field*)
397                | (_,E.Sum(sx,E.Probe p))
398                    => replaceProbe(0,e,E.Probe p, sx)
399                | (_,E.Sum(sx,E.Prod[eps,E.Probe p]))
400                    => replaceProbe(0,e,E.Probe p,sx)
401                | (_,_) => [e]
402                (* end case *))
403    
404            val (fieldset,var) = (case valnumflag
405                of true => einSet.rtnVar(fieldset,y,DstIL.EINAPP(ein,args))
406                | _     => (fieldset,NONE)
407            (*end case*))
408    
409            fun matchField b=(case b
410                of E.Probe _ => 1
411                | E.Sum (_, E.Probe _)=>1
412                | E.Sum(_, E.Prod[ _ ,E.Probe _])=>1
413                | _ =>0
414                (*end case*))
415    
416            in  (case var
417                of NONE=> (("\n \n mapp_not_replacing:"^(P.printerE ein)^":");(rewriteBody(Ein.body ein),fieldset,matchField(Ein.body ein),0))
418                | SOME v=> (("\n mapp_replacing"^(P.printerE ein)^":");( [(y,DstIL.VAR v)],fieldset, matchField(Ein.body ein),1))
419                (*end case*))
420      end      end
421    
422    end; (* local *)    end; (* local *)

Legend:
Removed from v.2611  
changed lines
  Added in v.3325

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0