Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] Diff of /branches/charisee_dev/src/compiler/high-to-mid/ProbeEin.sml
ViewVC logotype

Diff of /branches/charisee_dev/src/compiler/high-to-mid/ProbeEin.sml

Parent Directory Parent Directory | Revision Log Revision Log | View Patch Patch

branches/charisee/src/compiler/high-to-mid/ProbeEin.sml revision 2608, Fri May 2 18:04:54 2014 UTC branches/charisee_dev/src/compiler/high-to-mid/ProbeEin.sml revision 3655, Thu Feb 4 04:12:40 2016 UTC
# Line 1  Line 1 
1  (* Currently under construction  (* Expands probe ein
2   *   *
3   * COPYRIGHT (c) 2012 The Diderot Project (http://diderot-language.cs.uchicago.edu)   * This code is part of the Diderot Project (http://diderot-language.cs.uchicago.edu)
4     *
5     * COPYRIGHT (c) 2015 The University of Chicago
6   * All rights reserved.   * All rights reserved.
7   *)   *)
8    
   
 (*  
 A couple of different approaches.  
 One approach is to find all the Probe(Conv). Gerenerate exp for it  
 Then use Subst function to sub in. That takes care for index matching and  
   
 *)  
   
 (*This approach creates probe expanded terms, and adds params to the end. *)  
   
   
9  structure ProbeEin = struct  structure ProbeEin = struct
10    
11      local      local
12    
13      structure E = Ein      structure E = Ein
     structure mk= mkOperators  
     structure SrcIL = HighIL  
     structure SrcTy = HighILTypes  
     structure SrcOp = HighOps  
     structure SrcSV = SrcIL.StateVar  
     structure VTbl = SrcIL.Var.Tbl  
14      structure DstIL = MidIL      structure DstIL = MidIL
     structure DstTy = MidILTypes  
15      structure DstOp = MidOps      structure DstOp = MidOps
     structure DstV = DstIL.Var  
     structure SrcV = SrcIL.Var  
16      structure P=Printer      structure P=Printer
17      structure shift=ShiftEin      structure T = TransformEin
18      structure split=SplitEin      structure MidToS = MidToString
19      structure F=Filter      structure DstV = DstIL.Var
20        structure DstTy = MidILTypes
 val testing=1  
21    
 datatype peanut=    O of  DstOp.rator | E of Ein.ein|C of DstTy.ty|S of int  
 datatype peanut2=    O2 of  SrcOp.rator | E2 of Ein.ein|C2 of SrcTy.ty|S2 of int  
22      in      in
23    
24        (* This file expands probed fields
25        * Take a look at ProbeEin tex file for examples
26        *Note that the original field is an EIN operator in the form <V_alpha * H^(deltas)>(midIL.var list )
27        * Param_ids are used to note the placement of the argument in the midIL.var list
28        * Index_ids  keep track of the shape of an Image or differentiation.
29        * Mu  bind Index_id
30        * Generally, we will refer to the following
31        *dim:dimension of field V
32        * s: support of kernel H
33        * alpha: The alpha in <V_alpha * H^(deltas)>
34        * deltas: The deltas in <V_alpha * H^(deltas)>
35        * Vid:param_id for V
36        * hid:param_id for H
37        * nid: integer position param_id
38        * fid :fractional position param_id
39        * img-imginfo about V
40        *)
41    
42  fun assign (x, rator, args) = (x, DstIL.OP(rator, args))      val testing=0
43  fun assignEin (x, rator, args) = ((x, DstIL.EINAPP(rator, args)))      val valnumflag= true
44        val tsplitvar = true
45        val fieldliftflag=  true
46        val constflag =  true
47        val detflag = true
48        val detsumflag= true
49        fun incUse (DstIL.V{useCnt, ...}) = (useCnt := !useCnt + 1)
50        fun decUse (DstIL.V{useCnt, ...}) = (useCnt := !useCnt - 1)
51    
52        val liftimgflag =false
53        val pullKrn= false
54    
55        val cnt = ref 0
56        fun transformToIndexSpace e=T.transformToIndexSpace e
57        fun transformToImgSpace  e=T.transformToImgSpace  e
58        fun transformToImgSpaceF  e=T.transformToImgSpaceF  e
59        fun toStringBind e=(MidToString.toStringBind e)
60        fun toStringBindp e=(MidToString.toStringBind e)
61        fun mkEin e=Ein.mkEin e
62        fun mkEinApp(rator,args)=DstIL.EINAPP(rator,args)
63        fun setConst e = E.setConst e
64        fun setNeg e  =  E.setNeg e
65        fun setExp e  =  E.setExp e
66        fun setDiv e= E.setDiv e
67        fun setSub e= E.setSub e
68        fun setProd e= E.setProd e
69        fun setAdd e= E.setAdd e
70        fun mkCx es =List.map (fn c => E.C (c,true)) es
71        fun mkCxSingle c = E.C (c,true)
72    
73  fun getRHS x  = (case SrcIL.Var.binding x      fun testp n=(case testing
74      of SrcIL.VB_RHS(SrcIL.OP(rator, args)) => (O2 rator, args)          of 0=> 1
75      | SrcIL.VB_RHS(SrcIL.VAR x') => getRHS x'          | _ =>(print(String.concat n);1)
     | SrcIL.VB_RHS(SrcIL.EINAPP (e,args))=>(E2 e,args)  
     | SrcIL.VB_RHS(SrcIL.CONS (ty,args))=>(C2 ty,args)  
     | SrcIL.VB_NONE=>(S2 2,[])  
     | vb => raise Fail(concat[  
     "expected rhs operator for ", SrcIL.Var.toString x,  
     "but found ", SrcIL.vbToString vb])  
76      (* end case *))      (* end case *))
77    
 (*Transform differentiation index to world-space*)  
 (*Returns new deltas, summations, and List of Tensor Products*)  
 fun ImgtoWorldSpace(dels, outerShape,alpha,dim,oldArg,newArg,PArg)= let  
   
     val dim'=dim-1  
     fun setMatrix(imgix,wrdix) = E.Tensor(1,[ E.V imgix, E.V wrdix])  
78    
79      fun changeDel([],n,newdels,sx,rest)=(newdels,sx,rest)      fun getRHSDst x  = (case DstIL.Var.binding x
80       | changeDel((E.C _)::es,_,_,_,_)= raise Fail "unsure what to do with constant differentiation"          of DstIL.VB_RHS(DstIL.OP(rator, args)) => (rator, args)
81       | changeDel((E.V v)::es,n,newdels,sx,rest)= let          | DstIL.VB_RHS(DstIL.VAR x') => getRHSDst x'
82          val P=setMatrix(v, n)          | vb => raise Fail(concat[ "expected rhs operator for ", DstIL.Var.toString x, "but found ", DstIL.vbToString vb])
83          val n'=E.V n      (* end case *))
         in  
             changeDel(es,n+1,newdels@[n'],sx@[(n',0,dim')],rest@[P])  
         end  
84    
     val n=length(outerShape)  
     val (newdels,sx,rest)=changeDel(dels,n,[],[],[])  
     in  
         (newdels,sx,rest)  
     end  
85    
86  (*Transformation is Lifted*)      (* getArgsDst:MidIL.Var* MidIL.Var->int, ImageInfo, int
87  fun ImgtoWorldSpaceLift(dels, outerShape,alpha,dim,oldArg,newArg,PArg)= let          uses the Param_ids for the image, kernel,
88      val (newdels,sx,rest)=ImgtoWorldSpace(dels, outerShape,alpha,dim,oldArg,newArg,PArg)          and position tensor to get the Mid-IL arguments
89      val tshape=List.tabulate((length(alpha)),fn v=> E.V v)      returns the support of ther kernel, and image
90        *)
91        fun getArgsDst(hArg,imgArg,args) = (case (getRHSDst hArg, getRHSDst imgArg)
92            of ((DstOp.Kernel(h, i), _ ), (DstOp.LoadImage(_, _, img), _ ))=> ((Kernel.support h) ,img,ImageInfo.dim img)
93            |  ((k,_),(i,_)) => raise Fail (String.concat["Expected kernel:", (DstOp.toString k ),"Expected Image:", (DstOp.toString i)])
94            (*end case*))
95    
     val newbie'=Ein.EIN{  
             params=[E.TEN(1,outerShape), E.TEN(1,[dim,dim])],  
             index=outerShape,  
             body=E.Sum(sx,E.Prod([E.Tensor(0,tshape@newdels)]@rest))  
         }  
96    
97      val _ = print(String.concat["\n Transform \n ",(split.printA (newArg,newbie',[oldArg,PArg])) ,"\n"])      (*handleArgs():int*int*int*Mid IL.Var list
98      val data=assignEin (newArg, newbie', [oldArg,PArg])          ->int*Mid.ILVars list* code*int* low-il-var
99      val ix=List.tabulate((length(dels)),fn _=> dim)          * uses the Param_ids for the image, kernel, and tensor
100            * and gets the mid-IL vars for each.
101            *Transforms the position to index space
102            *P is the mid-il var for the (transformation matrix)transpose
103        *)
104        fun handleArgs(Vid,hid,tid,args)=let
105            val imgArg=List.nth(args,Vid)
106            val hArg=List.nth(args,hid)
107            val newposArg=List.nth(args,tid)
108            val (s,img,dim) =getArgsDst(hArg,imgArg,args)
109            val (argsT,P,code)=transformToImgSpace(dim,img,newposArg,imgArg)
110      in      in
111          (newdels,ix,[data],[],[],[])              (dim,args@argsT,code, s,P)
112      end      end
113    
114        fun handleArgsF(fieldset,Vid,hid,tid,args)=let
115  fun ImgtoWorldSpaceNoLift(dels, outerShape,alpha,dim,oldArg,newArg,PArg,id)= let          val imgArg=List.nth(args,Vid)
116      val (newdels,sx,rest)=ImgtoWorldSpace(dels, outerShape,alpha,dim,oldArg,newArg,PArg)          val hArg=List.nth(args,hid)
117      val params=[E.TEN(id,[dim,dim])]          val newposArg=List.nth(args,tid)
118            val (s,img,dim) =getArgsDst(hArg,imgArg,args)
119            val (fieldset,argsT,P,code)=transformToImgSpaceF(fieldset,dim,img,newposArg,imgArg)
120      in      in
121          (newdels,[],[],params,sx,rest)              (fieldset,dim,args@argsT,code, s,P)
122      end      end
123    
124    
 (*E.Sum(sx,rest) *)  
   
125    
126    
127  (*returns final argumentVar, new dels, and assignments*)      (*createBody:int*int*int,mu list, param_id, param_id, param_id, param_id
128  fun decideIfTransform(dels,outerShape,alpha,dim,PArg,Pid,lift)=let      * expands the body for the probed field
129      val oldArg = DstV.new ("ProbeResult", DstTy.tensorTy outerShape)      *)
130      in (case (dels,lift)      fun createBody(dim, s,sx,alpha,deltas,Vid, hid, nid, fid)=let
131          of ([],_) => (oldArg,oldArg,dels,[],[],[],[],[])          (*1-d fields*)
132          | (_,0) => let  (*need to transform to world-space*)          fun createKRND1 ()=let
133              val newArg = DstV.new ("IMG-Space", DstTy.tensorTy outerShape)              val sum=sx
134              val (dels',ix',assigments,_,_,_)=ImgtoWorldSpaceLift(dels,outerShape,alpha,dim,oldArg,newArg,PArg)              val dels=List.map (fn e=>(mkCxSingle 0,e)) deltas
135                val pos=[setAdd[E.Tensor(fid,[]),E.Value(sum)]]
136                val rest= E.Krn(hid,dels,setSub(E.Tensor(nid,[]),E.Value(sum)))
137              in              in
138                  (oldArg,newArg,dels', ix',assigments,[],[],[])                 setProd[E.Img(Vid,alpha,pos),rest]
             end  
         | ( _,_) =>  let  
                 val newArg = DstV.new ("IMG-Space", DstTy.tensorTy outerShape)  
                val (dels',_,_,params,sx,rest)=ImgtoWorldSpaceNoLift(dels, outerShape,alpha,dim,oldArg,newArg,PArg,Pid)  
                 in (oldArg, oldArg, dels',[],[],params,sx,rest)  
                 end  
         (*end case*))  
139      end      end
140    
141            fun mkImg(imgpos)=E.Img(Vid,alpha,imgpos)
142    
143  (*Create fractional, and integer position vectors*)          (*createKRN Image field and kernels *)
144  fun transformToImgSpace  (dim,v,posx)=let          fun createKRN(0,imgpos,rest)=setProd ([mkImg(imgpos)] @rest)
145            | createKRN(dim,imgpos,rest)=let
146      val translate=DstOp.Translate v              val dim'=dim-1
147      val transform=DstOp.Transform v              val sum=sx+dim'
148      val M  = DstV.new ("M", DstTy.tensorTy [dim,dim])   (*transform dim by dim?*)              val dels=List.map (fn e=>(mkCxSingle  dim',e)) deltas
149                val pos=[setAdd[E.Tensor(fid,[mkCxSingle dim']),E.Value(sum)]]
150                val rest'= E.Krn(hid,dels,setSub(E.Tensor(nid,[mkCxSingle dim']),E.Value(sum)))
151                in
152                    createKRN(dim',pos@imgpos,[rest']@rest)
153                end
154            val exp=(case dim
155                of 1 => createKRND1()
156                | _=> createKRN(dim, [],[])
157                (*end case*))
158            (*sumIndex creating summaiton Index for body*)
159            val slb=1-s
160            val _=List.tabulate(dim, (fn dim=> (String.concat[" sx:",Int.toString(sx)," dim:",Int.toString(dim),"esum",Int.toString(sx+dim) ]) ))
161            val esum=List.tabulate(dim, (fn dim=>(E.V (dim+sx),slb,s)))
162        in
163            E.Sum(esum, exp)
164        end
165    
166    
167    
168        (* build position *)
169        fun buildPos (dir,dim,argsA,hid,nid,s) =let
170            val vA  = DstV.new ("kernel_pos", DstTy.TensorTy([]))
171            val p=[E.KRN,E.TEN(1,[dim])]
172            val pos=setSub(E.Tensor(1,[mkCxSingle dir]),E.Value(0))
173            val exp= E.BuildPos(s,pos)
174            (*val exp = E.Sum([(E.V 0, slb, s)],E.Krn(0,[],pos))*)
175            val a=[List.nth(argsA,hid),List.nth(argsA,nid)]
176            val A=(vA,mkEinApp(mkEin(p,[],exp),a))
177            in (vA,A) end
178    
179        (* apply differentiation *)
180        fun getKrn1Del(dx,dim,args,slb,s)= let
181            val n=Int.toString(dx)
182            val vA  = DstV.new ("kernel_del"^n, DstTy.TensorTy([]))
183            val p=[E.KRN,E.TEN(1,[dim])]
184            val exp = E.EvalKrn dx
185            val A = (vA,mkEinApp(mkEin(p,[],exp),args))
186            in (vA,A) end
187    
188        (*create holder expression*)
189        fun mkHolder(dim,args) =let
190            val n=List.length(args)
191            val vA  = DstV.new ("kernel_cons", DstTy.TensorTy([n]))
192            val p=[E.KRN,E.TEN(1,[dim])]
193            val A= (vA,mkEinApp(mkEin(p,[],E.Holder n),args))
194            in (vA,A) end
195    
196        (*lifted Kernel expressions*)
197        fun liftKrn(dx,dir,dim,argsA,hid,nid,slb,s)=let
198            val  (vA,A)=buildPos(dir,dim,argsA,hid,nid,s)
199            val args=[List.nth(argsA,hid),vA]
200            fun iter(0,vBs,Bs)=let
201                val (vA,A)=getKrn1Del(0,dim,args,slb,s)
202                in (vA::vBs,A::Bs) end
203              | iter (n,vBs,Bs)= let
204                val (vA,A)=getKrn1Del(n,dim,args,slb,s)
205                in iter(n-1,vA::vBs,A::Bs) end
206            val (vBs,Bs)=iter(length(dx),[],[])
207            val (vC,C)  =mkHolder(dim,vBs)
208            in (vC,(A::Bs)@[C]) end
209    
210    
211    
212        fun createBody2(dim, s,sx,alpha,deltas,Vid, hid, nid, fid,params,argsA)=let
213            (*1-d fields*)
214            val slb=1-s
215    
216    
217            (*making image*)
218            val tid=(case liftimgflag
219                of true => length(params)-1
220                | _     => length(params)-1
221            (*end case*))
222            fun mkImg imgpos =(case liftimgflag
223                of true=>(E.Tensor(Vid,alpha),SOME(E.Sum(List.tabulate(dim, (fn dim=>(E.V (dim),slb,s))),E.Img(Vid,alpha,imgpos))))
224                |  _ =>let
225                    val imgpos= List.tabulate(dim,fn e=> setAdd[E.Tensor(fid,[mkCxSingle e]),E.Value(e+sx)])
226                    in (E.Img(Vid,alpha,imgpos),NONE) end
227            (*end case*))
228    
229      val T  = DstV.new ("T", DstTy.tensorTy [dim,dim])   (*translate*)          fun createKRND1 ()=let
230      val x  = DstV.new ("x", DstTy.vecTy dim)            (*Image-Space position*)              val sum=sx
231      val f  = DstV.new ("f", DstTy.vecTy dim)            (*fractional*)              val dels=List.map (fn e=>(mkCxSingle 0,e)) deltas
232      val nd = DstV.new ("nd", DstTy.vecTy dim)           (*real position*)              val imgpos=[setAdd[E.Tensor(fid,[]),E.Value(sum)]]
233      val n  = DstV.new ("n", DstTy.iVecTy dim)           (*integer position*)              val rest= E.Krn(hid,dels,setSub(E.Tensor(nid,[]),E.Value(sum)))
234      val PosToImgSpace=mk.transform(dim,dim)              val (talpha,iexp)= mkImg imgpos
235      val P  = DstV.new ("P", DstTy.tensorTy [dim,dim])   (*transform dim by dim?*)              in (setProd[talpha,rest],iexp,NONE,NONE)end
236    
237      val code=[          (*createKRN Image field and kernels *)
238          assign(M, transform, []),          fun createKRN(0,orig,imgpos,vAs,krnpos)= let
239          assign(T, translate, []),              val (talpha,iexp)= mkImg imgpos
240          assignEin(x, PosToImgSpace,[M,posx,T]) ,  (* MX+T*)              in  (setProd ([talpha]@orig),iexp,SOME vAs,SOME krnpos) end
241          assign(nd, DstOp.Floor dim, [x]),   (*nd *)          | createKRN(d,orig,imgpos,vAs,krnpos)=let
242          assignEin(f, mk.subTen([dim]),[x,nd]),           (*fractional*)              val dim'=d-1
243          assign(n, DstOp.RealToInt dim, [nd]), (*real to Int*)              val sum=sx+dim'
244          assignEin(P, mk.transpose([dim,dim]), [M])              val dels=List.map (fn e=>(mkCxSingle  dim',e)) deltas
245          ]              val ipos=setAdd[E.Tensor(fid,[mkCxSingle dim']),E.Value(dim')]
246      in ([n,f],P,code)              val opos= E.Krn(hid,dels,E.Tensor(tid+d,[]))
247                val (vA,A)= liftKrn(dels,dim',dim,argsA,hid,nid,slb,s)
248                in
249                    createKRN(dim',[opos]@orig,[ipos]@imgpos,[vA]@vAs,A@krnpos)
250      end      end
251    
252            val (oexp,iexp,vAs,keinapp)=(case dim
253                of 1 => createKRND1()
254                | _=> createKRN(dim, [],[],[],[])
255            (*end case*))
256    
257  fun replaceH(kvar, place,args)=let          val oexp=E.Sum(List.tabulate(dim, (fn dim=>(E.V (dim+sx),slb,s))), oexp)
258      val l1=List.take(args, place)          in (oexp,iexp,vAs,keinapp) end
     val l2=List.drop(args,place+1)  
     in l1@[kvar]@l2 end  
259    
260        fun createBody3(dim, s,sx,[],deltas,Vid, hid, nid, fid,params,argsA)=
261                createBody2(dim, s,sx,[],deltas,Vid, hid, nid, fid,params,argsA)
262        |   createBody3(dim, s,sx,alpha,deltas,Vid, hid, nid, fid,params,argsA)=
263                (createBody(dim, s,sx,alpha,deltas,Vid, hid, nid, fid),NONE,NONE,NONE)
264    
265  (*Get Img, and Kern Args*)      (*getsumshift:sum_indexid list* int list-> int
266  fun getArgs(hid,hArg,V,imgArg,args,lift)=case (getRHS hArg,getRHS imgArg)      *get fresh/unused index_id, returns int
267      of ((O2(SrcOp.Kernel(h, i)),argK),(O2(SrcOp.LoadImage img),_))=> let      *)
268          val hvar=DstV.new ("KNL", DstTy.KernelTy)      fun getsumshift(sx,n) =let
269          val imgvar=DstV.new ("IMG", DstTy.ImageTy img)          val nsumshift= (case sx
270          val argsVK= (case lift              of []=> n
271              of 0=> let              | _=>let
272                  val argsN=replaceH(hvar, hid,args)                  val (E.V v,_,_)=List.hd(List.rev sx)
273                  in replaceH(imgvar, V,argsN) end                  in v+1
274              | _ => [imgvar, hvar]                  end
275          (* end case *))          (* end case *))
         val assigments=[assign (hvar, DstOp.Kernel(h, i), []), assign(imgvar,DstOp.LoadImage img,[])]  
276    
277            val aa=List.map (fn (E.V v,_,_)=>Int.toString v) sx
278            val _ =(String.concat["\n", "SumIndex:" ,(String.concatWith"," aa),
279            "\n\t Index length:",Int.toString n,
280            "\n\t Freshindex: ", Int.toString nsumshift])
281          in          in
282              (Kernel.support h ,img, assigments,argsVK)              nsumshift
283          end          end
     | ((O2(SrcOp.Kernel(h, i)),arg),_)=> raise Fail "Not an img Argument"  
     |  _ => raise Fail "Not a kernel argument"  
284    
285        (*formBody:ein_exp->ein_exp
286        *just does a quick rewrite
287        *)
288        fun formBody(E.Sum([],e))=formBody e
289        | formBody(E.Sum(sx,e))= E.Sum(sx,formBody e)
290        | formBody(E.Opn(E.Prod, [e]))=e
291        | formBody e=e
292    
293  fun handleArgs(V,h,t,(params,args),origargs,lift)=let      (* silly change in order of the product to match vis branch WorldtoSpace functions*)
294      val E.IMG(dim)=List.nth(params,V)      fun multiPs([P0,P1,P2],sx,body)= formBody(E.Sum(sx, setProd[P0,P1,P2,body]))
295      val kArg=List.nth(origargs,h)      (*
296      val imgArg=List.nth(origargs,V)        | multiPs([P0,P1],sx,body)=formBody(E.Sum(sx, setProd([P0,body,P1])))
297      val newposArg=List.nth(args, t)        *)
298      val (s,img,argcode,argsVH) =getArgs(h,kArg,V,imgArg,args,lift)        | multiPs([P0,P1,P2,P3],sx,body)= formBody(E.Sum(sx, setProd[P0,P1,P2,P3,body]))
299      val (argsT,P,code')=transformToImgSpace(dim,img,newposArg)        | multiPs(Ps,sx,body)=formBody(E.Sum(sx,setProd([body]@Ps)))
     in (dim,argsVH@argsT,argcode@code', s,P)  
     end  
300    
301    
302  (*createDels=> creates the kronecker deltas for each Kernel*)      fun multiMergePs([P0,P1],[sx0,sx1],body)=E.Sum([sx0],setProd[P0,E.Sum([sx1],setProd[P1,body])])
303  fun createDels([],_)= []        | multiMergePs e=multiPs e
     | createDels(d::ds,dim)= [( E.C dim,d)]@createDels(ds,dim)  
304    
305  (*Created new body for probe*)      (* *******************************************  setImage *******************************************  *)
306  fun createBody(dim, s,sx,shape,deltas,V, h, nid, fid)=let      fun replaceImgA(es,vid,newbie)=List.take(es,vid)@[newbie]@List.drop(es,vid+1)
307        fun setImage(params',argsA,code,vexp2,index,alpha,paraminstant,Vid,s)=
308            (case vexp2
309                of NONE =>(params',argsA,code)
310                | SOME vexp  =>    let
311                    val iArg  = DstV.new ("Img", DstTy.TensorTy([]))
312                    val alphax=List.map (fn (E.V i)=>List.nth(index,i)) alpha
313                    val ieinapp=(iArg,mkEinApp(mkEin(paraminstant,alphax,vexp),argsA))
314                    (*
315                     val _ =print(String.concat["\n****\n Image (",Int.toString(length(argsA)),")"])
316                    val _ =print(String.concat(List.map (fn e=> DstIL.Var.toString(e)^",") argsA))
317                    val _ =print(String.concat["\n replace at ",Int.toString Vid ," with " , DstIL.Var.toString iArg ,"\n"])*)
318                    val argsA=replaceImgA(argsA,Vid,iArg)
319                    val params'=replaceImgA(params',Vid,E.TEN(2,[(s-(1-s)+1)*(s-(1-s)+1),(s-(1-s)+1)]))
320                    val code=code@[ieinapp]
321                    (*
322                    val _ =print(String.concat(List.map (fn e=> DstIL.Var.toString(e)^",") argsA))
323                    val _ =print(String.concat["\n****\n Image(",Int.toString(length(argsA)),")"])*)
324                    in (params',argsA,code) end
325            (*end case*))
326    
327      (*sumIndex creating summaiton Index for body*)        (*kernels*)
328      fun sumIndex(0)=[]        fun setKernel(params',args',code,vAs2,keinapp2,dim)=
329      |sumIndex(dim)= sumIndex(dim-1)@[(E.V (dim+sx-1),1-s,s)]          (case (vAs2,keinapp2)
330                of (NONE,NONE)=> (params',args',code)
331                | (SOME vAs,SOME keinapp) => let
332                (*
333                    val _ =print"\n****\n Kernels\n"
334                    val _ =print(String.concat["\n****\n Kernel(",Int.toString(length(args')),")"])
335                    val _ =print(String.concat(List.map (fn e=> DstIL.Var.toString(e)^",") args'))*)
336                    val args'=   args'@vAs
337                    val params'= params'@(List.tabulate(dim,fn _=> E.TEN(2,[])))
338                    val code=code@keinapp
339                    (*
340                    val _ =print"\n"
341                    val _ =print(String.concat(List.map (fn e=> DstIL.Var.toString(e)^",") args'))
342                    val _ =print(String.concat["\n****\n Kernel(",Int.toString(length(args')),")"])*)
343                    in (params',args',code) end
344            (*end case*))
345    
346          fun setImageKernel(params',args',code,vexp2,vAs2,keinapp2,dim,index,alpha,paraminstant,Vid,s)=let
347            val (params',args',code)=setImage(params',args',code,vexp2,index,alpha,paraminstant,Vid,s)
348            in setKernel(params',args',code,vAs2,keinapp2,dim) end
349    
350    
351        (* *******************************************  Replace probe *******************************************  *)
352        (* replaceProbe
353        * Transforms position to world space
354        * transforms result back to index_space
355        * rewrites body
356        * replace probe with expanded version
357        *)
358         fun replaceProbe0(fieldset,(y, DstIL.EINAPP(e,args)),p ,sx)
359            =let
360            val originalb=Ein.body e
361            val params=Ein.params e
362            val index=Ein.index e
363            val _ = (String.concat["\n***************** \n Replace ************ \n"])
364            val _=  toStringBindp (y, DstIL.EINAPP(e,args))
365    
366      (*createKRN Image field and kernels *)          val E.Probe(E.Conv(Vid,alpha,hid,dx),E.Tensor(tid,_))=p
367      fun createKRN(0,imgpos,rest)=E.Prod ([E.Img(V,shape,imgpos)] @rest)          val fid=length(params)
368      | createKRN(dim,imgpos,rest)=let          val nid=fid+1
369          val dim'=dim-1          val Pid=nid+1
370          val sum=sx+dim'          val nshift=length(dx)
371          val dels=createDels(deltas,dim')          val (dim,argsA,code,s,PArg) = handleArgs(Vid,hid,tid,args)
372          val pos=[E.Add[E.Tensor(fid,[E.C dim']),E.Value(sum)]]          val freshIndex=getsumshift(sx,length(index))
373          val rest'= E.Krn(h,dels,E.Sub(E.Tensor(nid,[E.C dim']),E.Value(sum)))          val (dx,newsx1,Ps)=transformToIndexSpace(freshIndex,dim,dx,Pid)
374            val params'=params@[E.TEN(3,[dim]),E.TEN(1,[dim]),E.TEN(1,[dim,dim])]
375            val body' = createBody(dim, s,freshIndex+nshift,alpha,dx,Vid, hid, nid, fid)
376            val body' = multiPs(Ps,newsx1,body')
377    
378            val body'=(case originalb
379                of E.Sum(sx, E.Probe _)              => E.Sum(sx,body')
380                | E.Sum(sx,E.Opn(E.Prod,[eps0,E.Probe _ ]))  => E.Sum(sx,setProd[eps0,body'])
381                | _                                  => body'
382                (*end case*))
383    
384            val args'=argsA@[PArg]
385            val einapp=(y,mkEinApp(mkEin(params',index,body'),args'))
386            val _= List.map toStringBindp(code@[einapp])
387          in          in
388              createKRN(dim',pos@imgpos,[rest']@rest)              (fieldset,code@[einapp])
389          end          end
390    
     val exp=createKRN(dim, [],[])  
     val esum=sumIndex (dim)  
     in E.Sum(esum, exp)  
     end  
391    
392    
393  fun ShapeConv([],n)=[]      fun replaceProbe3(fieldset,(y, DstIL.EINAPP(e,args)),p ,sx) = let
394      | ShapeConv(E.C c::es, n)=ShapeConv(es, n)          val originalb=Ein.body e
395      | ShapeConv(E.V v::es, n)=          val params=Ein.params e
396          if(n>v) then [E.V v] @ ShapeConv(es, n)          val index=Ein.index e
397          else ShapeConv(es,n)          val _ = testp["\n***************** \n Replace ************ \n"]
398            val _=  toStringBind (y, DstIL.EINAPP(e,args))
399    
400            val E.Probe(E.Conv(Vid,alpha,hid,dx),E.Tensor(tid,_))=p
401            val fid=length(params)
402            val nid=fid+1
403            val Pid=nid+1
404            val nshift=length(dx)
405            val (dim,argsA,code,s,PArg) = handleArgs(Vid,hid,tid,args)
406            val freshIndex=getsumshift(sx,length(index))
407            val (dx,newsx1,Ps)=transformToIndexSpace(freshIndex,dim,dx,Pid)
408    
409  fun mapIndex([],_)=[]          val paraminstant=params@[E.TEN(3,[dim]),E.TEN(1,[dim])]
410      | mapIndex(E.V v::es,index) = [List.nth(index, v)]@ mapIndex(es,index)          val params'=paraminstant@[E.TEN(1,[dim,dim])]
     | mapIndex(E.C c::es,index) = mapIndex(es,index)  
411    
412    
413    
414  (*Currently have three different Functions. One that Lifts Field and Transormation. One that just one lift, and the other does it all in place*)          val (body',vexp2,vAs2,keinapp2) = createBody3(dim, s,freshIndex+nshift,alpha,dx,Vid, hid, nid, fid,paraminstant,argsA)
415  (*Lift probe and does transformation Lifted*)          val body' = multiPs(Ps,newsx1,body')
 fun liftProbeTester(b,(params,args),index, sumIndex,origargs)=let  
416    
417      val E.Probe(E.Conv(V,alpha,H,dels),E.Tensor(t,_))=b          val body'=(case originalb
418      val newId=length(params)              of E.Sum(sx, E.Probe _)              => E.Sum(sx,body')
419      val n=length(index)              | E.Sum(sx,E.Opn(E.Prod,[eps0,E.Probe _ ]))  => E.Sum(sx,setProd[eps0,body'])
420                | _                                  => body'
421                (*end case*))
422    
423      (*Create new tensor replacement*)          (*images and kernels*)
424      val shape=ShapeConv(alpha@dels, n)          val (params',argsA,code)=setImageKernel(params',argsA,code,vexp2,vAs2,keinapp2,dim,index,alpha,paraminstant,Vid,s)
     val newB=E.Tensor(newId,shape)  
425    
     (* Create new Param*)  
     (*  val shape'=List.map (fn E.V v=>(List.nth(index, v))) shape*)  
     val shape'= mapIndex(shape,index)  
     val newP= E.TEN(1,shape')  
426    
427            (*replace term*)
428            val args'=argsA@[PArg]
429            val einapp=(y,mkEinApp(mkEin(params',index,body'),args'))
430            val _= List.map toStringBindp(code@[einapp])
431            in
432                (fieldset,code@[einapp])
433            end
434    
     val A1=ShapeConv(alpha, n)  
     val alpha1= mapIndex(A1,index)  
435    
436      (*Expand Probe*)      (* ******************************************* Lift probe *******************************************  *)
437      val ns=length sumIndex      fun createEinApp(originalb,alpha,index,freshIndex,dim,dx,sx)= let
438            val Pid=0
439            val tid=1
440    
441      val (dim,args',code,s,PArg) = handleArgs(V,H,t,(params,args), origargs,1)          (*Assumes body is already clean*)
442      val (oldArg,newArg,dx, ix,assigments,paramsA,sxA,restA) = decideIfTransform(dels,shape',alpha1,dim,PArg,n+1,ns)          val (newdx,newsx,Ps)=transformToIndexSpace(freshIndex,dim,dx,Pid)
443    
444      val body' =(case ns          (*need to rewrite dx*)
445          of 0=>    createBody(dim, s,n+length(dx),alpha,dx,0, 1, 3, 2)          val (_,sizes,e as E.Conv(_,alpha',_,dx))=(case sx@newsx
446          |_=>let              of []=> ([],index,E.Conv(9,alpha,7,newdx))
447              val (E.V v,_,_)=List.nth(sumIndex, ns-1)              | _ => cleanIndex.cleanIndex(E.Conv(9,alpha,7,newdx),index,sx@newsx)
             val v'=v+length(dx) (*Shifted because of the swapped indices *)  
             val body'=createBody(dim, s,v'+2,alpha,dx,0, 1, 3, 2)  
             in  E.Sum(sumIndex ,body')  
             end  
448      (* end case *))      (* end case *))
449    
450      val params'=[E.IMG(dim),E.KRN,E.TEN(3,[dim]),E.TEN(1,[dim])]          val params=[E.TEN(1,[dim,dim]),E.TEN(1,sizes)]
451      val (p',i',b',a')=shift.clean(params'@paramsA, index@ix,body', args')          fun filterAlpha []=[]
452      val newbie'=Ein.EIN{params=p', index=i', body=b'}            | filterAlpha(E.C _::es)= filterAlpha es
453              | filterAlpha(e1::es)=[e1]@(filterAlpha es)
454    
455      val data=assignEin (oldArg, newbie', a')          val tshape=filterAlpha(alpha')@newdx
456            val t=E.Tensor(tid,tshape)
457    
458      val _ = (case testing          val (splitvar,body)=(case originalb
459          of 0 => 1              of E.Sum(sx, E.Probe _)              => (true,multiPs(Ps,sx@newsx,t))
460          | _ => (print(String.concat["\n Lift Probe\n", split.printA(oldArg, newbie', a'),"\n"]);1)              | E.Sum(sx,E.Opn(E.Prod,[eps0,E.Probe _ ]))  => (false,E.Sum(sx,setProd[eps0,multiPs(Ps,newsx,t)]))
461                | _                                  => (case tsplitvar
462                  of(* true =>   (true,multiMergePs(Ps,newsx,t))  (*pushes summations in place*)
463                    | false*) _ =>   (true,multiPs(Ps,newsx,t))
464                    (*end case*))
465              (*end case *))              (*end case *))
     in (newB, (params@[newP],args@[newArg]) ,code@[data]@assigments)  
     end  
   
466    
467  (*Lift probe and Multiply by P*)          val _ =(case splitvar
468  fun liftProbe(b,(params,args),index, sumIndex,origargs)=let              of true=> (String.concat["splitvar is true", P.printbody body])
469                | _ => (String.concat["splitvar is false",P.printbody body])
470            (*end case*))
471    
 val E.Probe(E.Conv(V,alpha,H,dx),E.Tensor(t,_))=b  
 val newId=length(params)  
 val n=length(index)  
472    
473  (*Create new tensor replacement*)          val ein0=mkEin(params,index,body)
474  val shape=ShapeConv(alpha@dx, n)          in
475  val newB=E.Tensor(newId,shape)              (splitvar,ein0,sizes,dx,alpha')
476            end
477    
478  (* Create new Param*)      fun liftProbe0(fieldset,(y, DstIL.EINAPP(e,args)),p ,sx)=let
479  (*  val shape'=List.map (fn E.V v=>(List.nth(index, v))) shape*)          val _=(String.concat["\n******* Lift Geneirc Probe ***\n"])
480  val shape'= mapIndex(shape,index)          val originalb=Ein.body e
481  val newP= E.TEN(1,shape')          val params=Ein.params e
482            val index=Ein.index e
483            val _ =  (toStringBindp (y, DstIL.EINAPP(e,args)))
484    
485  val newArg = DstV.new ("PC", DstTy.tensorTy shape')          val E.Probe(E.Conv(Vid,alpha,hid,dx),E.Tensor(tid,_))=p
486            val fid=length(params)
487            val nid=fid+1
488            val nshift=length(dx)
489            val (fieldset,dim,args',code,s,PArg) = handleArgsF(fieldset,Vid,hid,tid,args)
490            val freshIndex=getsumshift(sx,length(index))
491    
492            (*transform T*P*P..Ps*)
493            val (splitvar,ein0,sizes,dx,alpha')= createEinApp(originalb,alpha,index,freshIndex,dim,dx,sx)
494            val FArg  = DstV.new ("F", DstTy.TensorTy(sizes))
495    
 (*Expand Probe*)  
 val ns=length sumIndex  
496    
 val (dim,args',code,s,PArg) = handleArgs(V,H,t,(params,args), origargs,1)  
497    
498            (*addedhere*)
499            val ein9=mkEin(params,sizes,E.Conv(Vid,alpha',hid,dx))
500            val einApp9=mkEinApp(ein9,args)
501            val rtn9=(FArg,einApp9)
502            val (fieldset,FArg,rtn1)= (case (einVarSet.rtnVarN(fieldset,rtn9))
503                of  (fieldset,SOME v)  => let
504                        val _ = (" \n did find"^toStringBind(rtn9))
505                        in (fieldset, v,[]) end
506                | (fieldset,NONE)  => let
507                        (*lifted probe*)
508                        val _ =(" \n did not find"^toStringBind(rtn9))
509                        val params'=params@[E.TEN(3,[dim]),E.TEN(1,[dim])]
510                        val freshIndex'= length(sizes)
511                        val body' = createBody(dim, s,freshIndex',alpha',dx,Vid, hid, nid, fid)
512                        val ein1=mkEin(params',sizes,body')
513                        val einApp1=mkEinApp(ein1,args')
514                        val rtn1=(FArg,einApp1)
515                        in (fieldset,FArg ,[rtn1]) end
516                (*end case*))
517    
518  val body' =(case ns          val einApp0=mkEinApp(ein0,[PArg,FArg])
519  of 0=>    createBody(dim, s,n,alpha,dx,0, 1, 3, 2)          val rtn0=(case splitvar
520                of false => [(y,mkEinApp(ein0,[PArg,FArg]))]
521  |_=>let  |_=>let
522  val (E.V v,_,_)=List.nth(sumIndex, ns-1)              val bind3 = (y,DstIL.EINAPP(SummationEin.main ein0,[PArg,FArg]))
523                    in Split.splitEinApp bind3
 val body'=createBody(dim, s,v+1,alpha,dx,0, 1, 3, 2)  
 in  E.Sum(sumIndex ,body')  
524  end  end
525  (* end case *))  (* end case *))
526    
527  val params'=[E.IMG(dim),E.KRN,E.TEN(3,[dim]),E.TEN(1,[dim])]          val rtn=code@rtn1@rtn0
528  val (p',i',b',a')=shift.clean(params', index,body', args')          val _= List.map toStringBindp (code@rtn1)
529  val newbie'=Ein.EIN{params=p', index=i', body=b'}          val _ ="\n**** split code **\n"
530            val _= List.map toStringBindp rtn0
531            in
532  val data=assignEin (newArg, newbie', a')              (fieldset,rtn)
   
 val _ = (case testing  
 of 0 => 1  
 | _ => (print(String.concat["\n Lift Probe\n", split.printA(newArg, newbie', a'),"\n"]);1)  
 (*end case *))  
 in (newB, (params@[newP],args@[newArg]) ,code@[data])  
533  end  end
534    
535        fun liftProbe3(fieldset,(y, DstIL.EINAPP(e,args)),p ,sx)=let
536            val _=testp["\n******* Lift Geneirc Probe ***\n"]
537            val originalb=Ein.body e
538            val params=Ein.params e
539            val index=Ein.index e
540            val _ =  (toStringBind (y, DstIL.EINAPP(e,args)))
541    
542            val E.Probe(E.Conv(Vid,alpha,hid,dx),E.Tensor(tid,_))=p
 (*Does not yet do transformation*)  
  (* Expand probe in place *)  
  fun replaceProbe(b,(params,args),index, sumIndex,origargs)=let  
   
     val E.Probe(E.Conv(V,alpha,h,dx),E.Tensor(t,_))=b  
543      val fid=length(params)      val fid=length(params)
     val n=length(index)  
   
     (*Expand Probe*)  
     val ns=length sumIndex  
     val (dim,args',code,s,P) = handleArgs(V,h,t,(params,args), origargs,0)  
544      val nid=fid+1      val nid=fid+1
545      val params'=params@[E.TEN(3,[dim]),E.TEN(1,[dim])]          val nshift=length(dx)
546      val body' =(case ns          val (dim,args',code,s,PArg) = handleArgs(Vid,hid,tid,args)
547          of 0=> createBody(dim, s,n,alpha,dx,V, h, nid, fid)          val freshIndex=getsumshift(sx,length(index))
548    
549            (*transform T*P*P..Ps*)
550            val (splitvar,ein0,sizes,dx,alpha')= createEinApp(originalb,alpha,index,freshIndex,dim,dx,sx)
551    
552            val FArg  = DstV.new ("F", DstTy.TensorTy(sizes))
553            val einApp0=mkEinApp(ein0,[PArg,FArg])
554            val rtn0=(case splitvar
555                of false => [(y,mkEinApp(ein0,[PArg,FArg]))]
556          |_=>let          |_=>let
557              val (E.V v,_,_)=List.nth(sumIndex, ns-1)                  val bind3 = (y,DstIL.EINAPP(SummationEin.main ein0,[PArg,FArg]))
558              in createBody(dim, s,v+1,alpha,dx,V, h, nid, fid)                  in Split.splitEinApp bind3
559              end              end
560          (* end case *))          (* end case *))
     val _ =(case testing  
         of 0=> 1  
         | _ =>  let  
             val subexp=Ein.EIN{params=params', index=index, body=body'}  
             val _= print(String.concat["\n Don't replace probe  \n $$$ new sub-expression $$$ \n",P.printerE(subexp),"\n"])  
             in 1 end  
         (* end case *))  
   
     in (body',(params',args') ,code)  
     end  
   
   
 fun flatten []=[]  
     | flatten(e1::es)=e1@(flatten es)  
561    
562            (*lifted probe*)
563            val params'=params@[E.TEN(3,[dim]),E.TEN(1,[dim])]
564            val freshIndex'= length(sizes)
565    
566   (* sx-[] then move out, otherwise keep in *)          (*val body' = createBody(dim, s,freshIndex',alpha',dx,Vid, hid, nid, fid)*)
 fun expandEinOp ( Ein.EIN{params, index, body}, origargs,args) = let  
567    
568      val dummy=E.Const 0          val (body',vexp2,vAs2,keinapp2) = createBody3(dim, s,freshIndex',alpha',dx,Vid, hid, nid, fid,params',args')
     val sumIndex=ref []  
569    
570      (*b-current body, info-original ein op, data-new assigments*)          (*set image and kernel*)
571      fun rewriteBody(b,info)= let          val (params',args',code)=setImageKernel(params',args',code,vexp2,vAs2,keinapp2,dim,index,alpha,params',Vid,s)
572    
573          fun callfn(c1,body)=let          val ein1=mkEin(params',sizes,body')
574              val ref x=sumIndex          val einApp1=mkEinApp(ein1,args')
575              val c'=[c1]@x          val rtn1=(FArg,einApp1)
576              val (bodyK,infoK,dataK)= (sumIndex:=c';rewriteBody(body ,info))          val rtn=code@[rtn1]@rtn0
577              val ref s=sumIndex          val _= List.map toStringBind ([rtn1]@rtn0)
578              val z=hd(s)          val _=(String.concat["\n* end  Lift Geneirc Probe  ******** \n"])
579              val e'=( case bodyK            val _= List.map toStringBindp(rtn)
                 of E.Const _ =>bodyK  
                 | _ => E.Sum(z,bodyK)  
                 (*end case*))  
580              in              in
581                  (sumIndex:=tl(s);(e',infoK,dataK))              (fieldset,rtn)
582              end              end
583    
584          in (case b      fun replaceProbe e= (case pullKrn
585              of  E.Sum(c,  E.Probe(E.Conv v, E.Tensor t)) =>let          of true=>replaceProbe3 e
586                  val ref sx=sumIndex          | false => replaceProbe0 e
                 in (case sx  
                     of [] => liftProbe(E.Probe(E.Conv v, E.Tensor t ), info,index, c,origargs)  
                     | _ =>  replaceProbe(E.Probe(E.Conv v, E.Tensor t ), info,index, (flatten sx)@c,origargs)  
587                  (* end case*))                  (* end case*))
588              end      fun liftProbe e=(case pullKrn
589          | E.Probe(E.Conv _, E.Tensor _) =>let          of true=>liftProbe3 e
590              val ref sx=sumIndex          | false => liftProbe0 e
             in (case sx  
                 of []=> liftProbe(b, info,index, [],origargs)  
                 | _=> replaceProbe(b, info,index, flatten sx,origargs)  
591               (* end case*))               (* end case*))
             end  
         | E.Probe _=> (dummy,info,[])  
         | E.Conv _=>  (dummy,info,[])  
         | E.Lift _=> (dummy,info,[])  
         | E.Field _ => (dummy,info,[])  
         | E.Apply _ => (dummy,info,[])  
         | E.Neg e=> let  
             val (body',info',data')=rewriteBody(e,info)  
             in  
                 (E.Neg(body'),info',data')  
             end  
         | E.Sum (c,e)=> callfn(c,e)  
         | E.Sub(a,b)=>let  
             val (bodyA,infoA,dataA)= rewriteBody(a,info)  
             val (bodyB, infoB, dataB)= rewriteBody(b,infoA)  
             in   (E.Sub(bodyA, bodyB),infoB,dataA@dataB)  
             end  
         | E.Div(a,b)=>let  
             val (bodyA,infoA,dataA)= rewriteBody(a,info)  
             val (bodyB, infoB,dataB)= rewriteBody(b,infoA)  
             in  (E.Div(bodyA, bodyB),infoB,dataA@dataB) end  
         | E.Add es=> let  
             fun filter([], done, info', data)= let  
                     val (_, e)=F.mkAdd done  
                     in (e, info',data)  
                     end  
                 | filter(e::es, done, info',data)= let  
                     val (body', info'',data')= rewriteBody(e,info')  
                     in filter(es, done@[body'], info'',data@data') end  
             in filter(es, [],info,[]) end  
592    
593          | E.Prod es=> let  
594              fun filter([], done, info',data)= let      (* ******************************************* Reconstruction -> Lift|Replace probe *******************************************  *)
595                      val (_, e)=F.mkProd done      (* scans dx for contant
596                      in  (e,info', data)       * arg:(1,code1, body1,[])
597                      end       *)
598                  | filter(e::es, done, info',data)= let      fun reconstruction([],arg)= replaceProbe arg
599                      val (body', info'',data')= rewriteBody(e, info')       | reconstruction(dx,arg)=(case (constflag,fieldliftflag)
600                      in filter(es, done@[body'], info'',data@data') end          of (true,true) => liftProbe arg
601                  in filter(es, [],info,[]) end          | (_,false)    => replaceProbe arg
602          | _=>  (b,info,[])          | _ => let
603                fun fConst [] = liftProbe arg
604                | fConst (E.C _::_) = replaceProbe arg
605                | fConst (_ ::es)= fConst es
606                in fConst dx end
607            (* end case*))
608    
609        (* **************************************************** Index Tensor **************************************************** *)
610        (*Push constant indices to tensor replacement*)
611        fun getF (e,fieldset,dim,newvx)= let
612            val (y, DstIL.EINAPP(ein,args))=e
613            val index0=Ein.index ein
614            val index1 = index0@dim
615            val b=Ein.body ein
616    
617            val (c1,dx,body1)=(case b
618                of  E.Sum([(vsum,0,n)],E.Probe(E.Conv(V,[c1,v0],h,dx),pos))=>let
619                    val shiftdx=List.tabulate(length(dx),fn n=>E.V (n+2))
620                    val b=E.Probe(E.Conv(V,[E.V 0,E.V 1],h,shiftdx),pos)
621                    in (c1,dx,b) end
622                | E.Probe(E.Conv(V,[c1,v0],h,dx),pos)=> let
623                    val body1_unshifted= E.Probe(E.Conv(V,[E.V newvx, v0],h,dx),pos)
624                    (* clean to get body indices in order *)
625                    val ( _ , _, body1)= cleanIndex.cleanIndex(body1_unshifted,index1,[])
626                    in (c1,dx,body1) end
627                |  E.Probe(E.Conv(V,[c1],h,dx),pos)=> let
628                   val body1_unshifted= E.Probe(E.Conv(V,[E.V newvx],h,dx),pos)
629                   val ( _ , _, body1)= cleanIndex.cleanIndex(body1_unshifted,index1,[])
630                   in (c1,dx,body1) end
631                (*end case*))
632    
633            val lhs1=DstV.new ("L", DstTy.TensorTy(index1))
634            val ein1 = mkEin(Ein.params ein,index1,body1)
635            val code1= (lhs1,mkEinApp(ein1,args))
636    
637            val (lhs0,(fieldset,codeAll))= (case valnumflag
638                of false    => (lhs1, reconstruction(dx,(fieldset,code1,body1,[])))
639                | true      => (case  (einVarSet.rtnVarN(fieldset,code1))
640                    of (fieldset,NONE)     => (lhs1, reconstruction(dx,(fieldset,code1,body1,[])))
641                    | (fieldset,SOME m)   =>  (m,(fieldset,[]))
642                    (*end case*))
643          (* end case *))          (* end case *))
         end  
644    
645       val empty =fn key =>NONE          (*Probe that tensor at a constant position  c1*)
646       val _ =(case testing          val param0 = [E.TEN(1,index1)]
647          of 0 => 1          val nx=List.tabulate(newvx,fn n=>E.V n)
648          | _ => (print "\n ************************** \n Starting Expand";1)          val body0 =  (case b
649                of E.Sum([(vsum,0,n)],_)=>  E.Sum([(vsum,0,n)],E.Tensor(0,[vsum,vsum]@nx))
650                | _ => E.Tensor(0,[c1]@nx)
651                (*end case*))
652            val ein0 = mkEin(param0,index0,body0)
653            val einApp0 = mkEinApp(ein0,[lhs0])
654            val code0 = (y,einApp0)
655            val _= toStringBind code0
656            in
657                (fieldset,codeAll@[code0])
658            end
659        (* **************************************************** General Fn **************************************************** *)
660        (* expandEinOp: code->  code list
661        * A this point we only have simple ein ops
662        * Looks to see if the expression has a probe. If so, replaces it.
663        * Note how we keeps eps expressions so only generate pieces that are used
664        *)
665        fun expandEinOp(e0 as (y, DstIL.EINAPP(ein,args)),fieldset,varset)=let
666            fun rewriteBody(fieldset,e,p as E.Probe(E.Conv(_,alpha,_,dx),_))= (case (detflag,alpha,dx)
667                of (true,[E.C(_,true), E.V 0],[])            => getF(e,fieldset,[3],1)
668                | (true,[E.C(_,true), E.V 0],[E.V 1])        => getF(e,fieldset,[3],2)
669                | (true,[E.C(_,true), E.V 0],[E.V 1,E.V 2])  => getF(e,fieldset,[3],3)
670                | (true,[E.C(_,true)],[])                    => getF(e,fieldset,[3],0)
671                | (true,[E.C(_,true)],[E.V 0])               => getF(e,fieldset,[3],1)
672                | (true,[E.C(_,true)],[E.V 0,E.V 1])         => getF(e,fieldset,[3],2)
673                | (true,[E.C(_,true)],[E.V 0,E.V 1,E.V 2])   => getF(e,fieldset,[3],3)
674                | _                                          => reconstruction(dx,(fieldset,e,p,[]))
675                (*end case*))
676            | rewriteBody(fieldset,e,E.Sum(sx,p as E.Probe(E.Conv(_,alpha,_,dx),_)))= (case (detsumflag,sx,alpha,dx)
677                of (true,[(E.V 0,0,_)],[E.V 0 ,E.V 0],[])              => getF(e,fieldset,[3,3],0)
678                | (true,[(E.V 1,0,_)],[E.V 1 ,E.V 1],[E.V 0])          => getF(e,fieldset,[3,3],1)
679                | (true,[(E.V 2,0,_)],[E.V 2 ,E.V 2],[E.V 0,E.V 1])    => getF(e,fieldset,[3,3],2)
680                | (_,_,_,[])                                => replaceProbe(fieldset,e,p, sx)  (*no dx*)
681                | (_,_,[],_)                                => reconstruction(dx,(fieldset,e,p,sx))
682                | _                                         => replaceProbe(fieldset,e,p, sx)
683                (* end case *))
684            | rewriteBody(fieldset,e,E.Sum(sx,E.Opn(E.Prod,[eps,E.Probe p])))     = replaceProbe(fieldset,e,E.Probe p,sx)
685            | rewriteBody (fieldset,e,_)  = (fieldset,[e])
686    
687            val b=Ein.body ein
688            fun pf()=("\n **************************** starting  **************************** \n"^(P.printerE(ein)))
689            fun matchField()=(case b
690                of E.Probe _ =>  (pf();1)
691                | E.Sum (_, E.Probe _)=> (pf();1)
692                | E.Sum(_, E.Opn(E.Prod,[ _ ,E.Probe _]))=> (pf();1)
693                | _ =>0
694                (*end case*))
695            val m=matchField()
696            val (fieldset,varset,code,flag) = (case valnumflag
697                of true => (case (einVarSet.rtnVarN(fieldset,e0))
698                   of  (fieldset,NONE)     => let
699                        val(fieldset,code)=rewriteBody(fieldset,e0,b)
700                        in (fieldset,varset,code,0) end
701                  | (fieldset,SOME v)    => (fieldset,varset,[(y,DstIL.VAR v)],1)
702                    (*end case*))
703                    | _     => let
704                    val(fieldset,code)=rewriteBody(fieldset,e0, b)
705                    in (fieldset,varset,code,0) end
706          (*end case*))          (*end case*))
707    
708      val (body',(params',args'),newbies)=rewriteBody(body,(params,args))          in  (code,fieldset,varset,m,flag) end
     val e'=Ein.EIN{params=params', index=index, body=body'}  
     (*val _ =(case testing  
         of 0 => 1  
         | _ => (String.concat[P.printerE(e'),"\n DONE expand ************************** \n "];1)  
         (*end case*))*)  
     in  
         ((e',args'),newbies)  
     end  
709    
710    end; (* local *)    end; (* local *)
711    

Legend:
Removed from v.2608  
changed lines
  Added in v.3655

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0