Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] Diff of /branches/charisee_dev/src/compiler/high-to-mid/ProbeEin.sml
ViewVC logotype

Diff of /branches/charisee_dev/src/compiler/high-to-mid/ProbeEin.sml

Parent Directory Parent Directory | Revision Log Revision Log | View Patch Patch

revision 2671, Fri Jul 18 18:57:06 2014 UTC revision 3092, Tue Mar 17 20:02:38 2015 UTC
# Line 1  Line 1 
1  (* Currently under construction  (* Expands probe ein
2   *   *
3   * COPYRIGHT (c) 2012 The Diderot Project (http://diderot-language.cs.uchicago.edu)   * COPYRIGHT (c) 2012 The Diderot Project (http://diderot-language.cs.uchicago.edu)
4   * All rights reserved.   * All rights reserved.
5   *)   *)
6    
   
 (*  
 A couple of different approaches.  
 One approach is to find all the Probe(Conv). Gerenerate exp for it  
 Then use Subst function to sub in. That takes care for index matching and  
   
 *)  
   
 (*This approach creates probe expanded terms, and adds params to the end. *)  
   
   
7  structure ProbeEin = struct  structure ProbeEin = struct
8    
9      local      local
10    
11      structure E = Ein      structure E = Ein
     structure mk= mkOperators  
     structure SrcIL = HighIL  
     structure SrcTy = HighILTypes  
     structure SrcOp = HighOps  
     structure SrcSV = SrcIL.StateVar  
     structure VTbl = SrcIL.Var.Tbl  
12      structure DstIL = MidIL      structure DstIL = MidIL
     structure DstTy = MidILTypes  
13      structure DstOp = MidOps      structure DstOp = MidOps
     structure DstV = DstIL.Var  
     structure SrcV = SrcIL.Var  
14      structure P=Printer      structure P=Printer
     structure shift=ShiftEin  
     structure split=SplitEin  
     structure F=Filter  
15      structure T=TransformEin      structure T=TransformEin
16        structure MidToS = MidToString
17        structure DstV = DstIL.Var
18        structure DstTy = MidILTypes
19    
 val testing=0  
   
 datatype peanut=    O of  DstOp.rator | E of Ein.ein|C of DstTy.ty|S of int  
 datatype peanut2=    O2 of  SrcOp.rator | E2 of Ein.ein|C2 of SrcTy.ty|S2 of int  
20      in      in
21    
22        (* This file expands probed fields
23        * Take a look at ProbeEin tex file for examples
24        *Note that the original field is an EIN operator in the form <V_alpha * H^(deltas)>(midIL.var list )
25        * Param_ids are used to note the placement of the argument in the midIL.var list
26        * Index_ids  keep track of the shape of an Image or differentiation.
27        * Mu  bind Index_id
28        * Generally, we will refer to the following
29        *dim:dimension of field V
30        * s: support of kernel H
31        * alpha: The alpha in <V_alpha * H^(deltas)>
32        * deltas: The deltas in <V_alpha * H^(deltas)>
33        * Vid:param_id for V
34        * hid:param_id for H
35        * nid: integer position param_id
36        * fid :fractional position param_id
37        *img-imginfo about V
38        *)
39    
40  fun assign (x, rator, args) = (x, DstIL.OP(rator, args))      val testing=0
41  fun assignEin (x, rator, args) = ((x, DstIL.EINAPP(rator, args)))      val testlift=1
42        val cnt = ref 0
43    
44  fun getRHS x  = (case SrcIL.Var.binding x      fun printEINAPP e=MidToString.printEINAPP e
45      of SrcIL.VB_RHS(SrcIL.OP(rator, args)) => (O2 rator, args)      fun transformToIndexSpace e=T.transformToIndexSpace e
46      | SrcIL.VB_RHS(SrcIL.VAR x') => getRHS x'      fun transformToImgSpace  e=T.transformToImgSpace  e
     | SrcIL.VB_RHS(SrcIL.EINAPP (e,args))=>(E2 e,args)  
     | SrcIL.VB_RHS(SrcIL.CONS (ty,args))=>(C2 ty,args)  
     | SrcIL.VB_NONE=>(S2 2,[])  
     | vb => raise Fail(concat[  
     "expected rhs operator for ", SrcIL.Var.toString x,  
     "but found ", SrcIL.vbToString vb])  
     (* end case *))  
47    
48        fun transitionToString(testreplace,a,b)=(case testreplace
49            of 0=> 1
50            | _ => (print(String.concat["\n\n\n Replace probe:\n",P.printbody a,"\n=>",P.printbody b]);1)
51            (*end case*))
52        fun mkEin(params,index,body)=E.EIN{params=params, index=index,body=body}
53        fun mkEinApp(rator,args)=DstIL.EINAPP(rator,args)
54        fun getBody(_,DstIL.EINAPP(E.EIN{body,...},_))=body
55        fun setBody(body',(y,DstIL.EINAPP(E.EIN{params,index,body},args)))=
56                (y,DstIL.EINAPP(E.EIN{params=params,index=index,body=body'},args))
57    
58        fun testp n=(case testing
59            of 0=> 1
60            | _ =>(print(String.concat n);1)
61            (*end case*))
62        fun  einapptostring (body,a,b)=(case testlift
63            of 0=>1
64            | _=> (print(String.concat["\n lift probe of ",P.printbody body,"=>\n\t", printEINAPP a,  "&\n\t", printEINAPP b]);1)
65            (*end case*))
66    
 (*Create fractional, and integer position vectors*)  
 fun transformToImgSpace  (dim,v,posx)=let  
67    
68      val translate=DstOp.Translate v      fun getRHSDst x  = (case DstIL.Var.binding x
69      val transform=DstOp.Transform v          of DstIL.VB_RHS(DstIL.OP(rator, args)) => (rator, args)
70      val M  = DstV.new ("M", DstTy.tensorTy [dim,dim])   (*transform dim by dim?*)          | DstIL.VB_RHS(DstIL.VAR x') => getRHSDst x'
71            | vb => raise Fail(concat[ "expected rhs operator for ", DstIL.Var.toString x, "but found ", DstIL.vbToString vb])
72        (* end case *))
73    
     val T  = DstV.new ("T", DstTy.tensorTy [dim,dim])   (*translate*)  
     val x  = DstV.new ("x", DstTy.vecTy dim)            (*Image-Space position*)  
     val f  = DstV.new ("f", DstTy.vecTy dim)            (*fractional*)  
     val nd = DstV.new ("nd", DstTy.vecTy dim)           (*real position*)  
     val n  = DstV.new ("n", DstTy.iVecTy dim)           (*integer position*)  
     val PosToImgSpace=mk.transform(dim,dim)  
     val P  = DstV.new ("P", DstTy.tensorTy [dim,dim])   (*transform dim by dim?*)  
74    
75      val code=[      (* getArgsDst:MidIL.Var* MidIL.Var->int, ImageInfo, int
76          assign(M, transform, []),          uses the Param_ids for the image, kernel,
77          assign(T, translate, []),          and position tensor to get the Mid-IL arguments
78          assignEin(x, PosToImgSpace,[M,posx,T]) ,  (* MX+T*)      returns the support of ther kernel, and image
79          assign(nd, DstOp.Floor dim, [x]),   (*nd *)      *)
80          assignEin(f, mk.subTen([dim]),[x,nd]),           (*fractional*)      fun getArgsDst(hArg,imgArg,args) = (case (getRHSDst hArg, getRHSDst imgArg)
81          assign(n, DstOp.RealToInt dim, [nd]), (*real to Int*)          of ((DstOp.Kernel(h, i), _ ), (DstOp.LoadImage(_, _, img), _ ))=> let
82          assignEin(P, mk.transpose([dim,dim]), [M])              in
83          ]                ((Kernel.support h) ,img,ImageInfo.dim img)
     in ([n,f],P,code)  
84      end      end
85            |  _ => raise Fail "Expected Image and kernel arguments"
   
 fun replaceH(kvar, place,args)=let  
     val l1=List.take(args, place)  
     val l2=List.drop(args,place+1)  
     in l1@[kvar]@l2 end  
   
   
 (*Get Img, and Kern Args*)  
 fun getArgs(hid,hArg,V,imgArg,args,lift)=case (getRHS hArg,getRHS imgArg)  
     of ((O2(SrcOp.Kernel(h, i)),argK),(O2(SrcOp.LoadImage img),_))=> let  
         val hvar=DstV.new ("KNL", DstTy.KernelTy)  
         val imgvar=DstV.new ("IMG", DstTy.ImageTy img)  
         val argsVK= (case lift  
             of 0=> let  
                 val argsN=replaceH(hvar, hid,args)  
                 in replaceH(imgvar, V,argsN) end  
             | _ => [imgvar, hvar]  
86          (* end case *))          (* end case *))
         val assigments=[assign (hvar, DstOp.Kernel(h, i), []), assign(imgvar,DstOp.LoadImage img,[])]  
87    
88    
89        (*handleArgs():int*int*int*Mid IL.Var list
90            ->int*Mid.ILVars list* code*int* low-il-var
91            * uses the Param_ids for the image, kernel, and tensor
92            * and gets the mid-IL vars for each.
93            *Transforms the position to index space
94            *P is the mid-il var for the (transformation matrix)transpose
95        *)
96        fun handleArgs(Vid,hid,tid,args)=let
97            val imgArg=List.nth(args,Vid)
98            val hArg=List.nth(args,hid)
99            val newposArg=List.nth(args,tid)
100            val (s,img,dim) =getArgsDst(hArg,imgArg,args)
101            val (argsT,P,code)=transformToImgSpace(dim,img,newposArg,imgArg)
102          in          in
103              (Kernel.support h ,img, assigments,argsVK)              (dim,args@argsT,code, s,P)
104          end          end
     | ((O2(SrcOp.Kernel(h, i)),arg),_)=> raise Fail "Not an img Argument"  
     |  _ => raise Fail "Not a kernel argument"  
   
105    
106  fun handleArgs(V,h,t,(params,args),origargs,lift)=let      (*createBody:int*int*int,mu list, param_id, param_id, param_id, param_id
107      val E.IMG(dim)=List.nth(params,V)      * expands the body for the probed field
108      val kArg=List.nth(origargs,h)      *)
109      val imgArg=List.nth(origargs,V)      fun createBody(dim, s,sx,alpha,deltas,Vid, hid, nid, fid)=let
110      val newposArg=List.nth(args, t)          (*1-d fields*)
111      val (s,img,argcode,argsVH) =getArgs(h,kArg,V,imgArg,args,lift)          fun createKRND1 ()=let
112      val (argsT,P,code')=transformToImgSpace(dim,img,newposArg)              val sum=sx
113      in (dim,argsVH@argsT,argcode@code', s,P)              val dels=List.map (fn e=>(E.C 0,e)) deltas
114                val pos=[E.Add[E.Tensor(fid,[]),E.Value(sum)]]
115                val rest= E.Krn(hid,dels,E.Sub(E.Tensor(nid,[]),E.Value(sum)))
116                in
117                    E.Prod [E.Img(Vid,alpha,pos),rest]
118      end      end
   
   
 (*createDels=> creates the kronecker deltas for each Kernel*)  
 fun createDels([],_)= []  
     | createDels(d::ds,dim)= [( E.C dim,d)]@createDels(ds,dim)  
   
 (*Created new body for probe*)  
 fun createBody(dim, s,sx,shape,deltas,V, h, nid, fid)=let  
   
     (*sumIndex creating summaiton Index for body*)  
     fun sumIndex(0)=[]  
     |sumIndex(dim)= sumIndex(dim-1)@[(E.V (dim+sx-1),1-s,s)]  
   
119      (*createKRN Image field and kernels *)      (*createKRN Image field and kernels *)
120      fun createKRN(0,imgpos,rest)=E.Prod ([E.Img(V,shape,imgpos)] @rest)          fun createKRN(0,imgpos,rest)=E.Prod ([E.Img(Vid,alpha,imgpos)] @rest)
121      | createKRN(dim,imgpos,rest)=let      | createKRN(dim,imgpos,rest)=let
122          val dim'=dim-1          val dim'=dim-1
123          val sum=sx+dim'          val sum=sx+dim'
124          val dels=createDels(deltas,dim')              val dels=List.map (fn e=>(E.C dim',e)) deltas
125          val pos=[E.Add[E.Tensor(fid,[E.C dim']),E.Value(sum)]]          val pos=[E.Add[E.Tensor(fid,[E.C dim']),E.Value(sum)]]
126          val rest'= E.Krn(h,dels,E.Sub(E.Tensor(nid,[E.C dim']),E.Value(sum)))              val rest'= E.Krn(hid,dels,E.Sub(E.Tensor(nid,[E.C dim']),E.Value(sum)))
127          in          in
128              createKRN(dim',pos@imgpos,[rest']@rest)              createKRN(dim',pos@imgpos,[rest']@rest)
129          end          end
130            val exp=(case dim
131      val exp=createKRN(dim, [],[])              of 1 => createKRND1()
132      val esum=sumIndex (dim)              | _=> createKRN(dim, [],[])
133      in E.Sum(esum, exp)              (*end case*))
134      end          (*sumIndex creating summaiton Index for body*)
135            val slb=1-s
136            val esum=List.tabulate(dim, (fn dim=>(E.V (dim+sx),slb,s)))
 fun ShapeConv([],n)=[]  
     | ShapeConv(E.C c::es, n)=ShapeConv(es, n)  
     | ShapeConv(E.V v::es, n)=  
         if(n>v) then [E.V v] @ ShapeConv(es, n)  
         else ShapeConv(es,n)  
   
   
 fun mapIndex([],_)=[]  
     | mapIndex(E.V v::es,index) = [List.nth(index, v)]@ mapIndex(es,index)  
     | mapIndex(E.C c::es,index) = mapIndex(es,index)  
   
   
 (*Lift probe and Multiply by P*)  
 fun liftProbe(E.Probe(E.Conv(V,alpha,H,dx),E.Tensor(t,_)),(params,args),index, sumIndex,origargs)=let  
     val _ =print "Lift Probe"  
   
     val n=length(index)  
     val ns=length sumIndex  
     val nshift=length(dx)  
     val np=length(params)  
     val nsumshift =(case ns  
         of 0=>   n  
         |_=>let  val (E.V v,_,_)=List.nth(sumIndex, ns-1)  
137              in              in
138                  v+1          E.Sum(esum, exp)
139              end              end
         (* end case *))  
   
   
     (*Outer Index-id Of Probe*)  
     val VShape=ShapeConv(alpha, n)  
     val HShape=ShapeConv(dx, n)  
     val shape=VShape@HShape  
   
     (* Bindings for Shape*)  
     val shapebind= mapIndex(shape,index)  
     val Vshapebind= mapIndex(VShape,index)  
   
     (*Look at Args and get dim, mid-il ops, support, and Arg for transformation matrix P*)  
     val (dim,args',code,support,PArg) = handleArgs(V,H,t,(params,args), origargs,1)  
   
     (*New transformations:params, sx, rest, will be empty if no transformation is made*)  
     val (oldArg,newArg,dx, paramsT,sxT,restT,ixT,dataT) = T.Transform(dx,shapebind,Vshapebind,dim,PArg,nsumshift,ns,4)  
140    
141      (*rewriteBody*)      (*getsumshift:sum_indexid list* int list-> int
142      val bodyExpanded = createBody(dim, support,nsumshift+nshift,alpha,dx,0, 1, 3, 2)      *get fresh/unused index_id, returns int
143        *)
144      val sx=sumIndex@sxT      fun getsumshift(sx,index) =let
145      val body'=(case sx          val nsumshift= (case sx
146          of [] =>E.Prod(restT@[bodyExpanded])              of []=> length(index)
147          | _ => E.Sum(sx, E.Prod(restT@[bodyExpanded]))              | _=>let
148          (*end case*))                  val (E.V v,_,_)=List.hd(List.rev sx)
149                    in v+1
150      (*create new EIN OPerator*)                  end
     val _ =print("Found this many args ")  
     val _ =print(Int.toString(length(args')))  
   
     val params'=[E.IMG(dim),E.KRN,E.TEN(3,[dim]),E.TEN(1,[dim])]@paramsT  
     val (p',i',b',a')=shift.clean(params', index@ixT, body', args'@[PArg])  
     val newbie'=Ein.EIN{params=p', index=i', body=b'}  
     val data=assignEin (oldArg, newbie', a')  
   
     val _ = (case testing  
         of 0 => 1  
         | _ => (print(String.concat["\n Lift Probe\n", split.printA(newArg, newbie', a'),"\n"]);1)  
151          (*end case *))          (*end case *))
152            val aa=List.map (fn (E.V v,_,_)=>Int.toString v) sx
153            val _ =testp["\n", "SumIndex" ,(String.concatWith"," aa),
154                "\nThink nshift is ", Int.toString nsumshift]
155      in      in
156          (E.Tensor(np,shape), (params@[E.TEN(1,shapebind)],args@[newArg]),code@[data]@dataT)              nsumshift
157      end      end
  |liftProbe _ =raise Fail"Incorrect body for Probe"  
158    
159        (*formBody:ein_exp->ein_exp
160        *just does a quick rewrite
161        *)
162        fun formBody(E.Sum([],e))=formBody e
163        | formBody(E.Sum(sx,e))= E.Sum(sx,formBody e)
164        | formBody(E.Prod [e])=e
165        | formBody e=e
166    
167        (* silly change in order of the product to match vis branch WorldtoSpace functions*)
168        fun multiPs([P0,P1,P2],sx,body)= formBody(E.Sum(sx, E.Prod([P0,P1,P2,body])))
169          | multiPs(Ps,sx,body)=formBody(E.Sum(sx, E.Prod([body]@Ps)))
170    
171        (* replaceProbe:ein_exp* params *midIL.var list * int list* sum_id list
172                -> ein_exp* *code
173        * Transforms position to world space
174        * transforms result back to index_space
175        * rewrites body
176        * replace probe with expanded version
177        *)
178    (*    fun replaceProbe(testN,y,originalb,b,params,args,index, sx)*)
179    
180         fun replaceProbe(testN,(y, DstIL.EINAPP(e,args)),p ,sx)
181            =let
182            val originalb=Ein.body e
183            val params=Ein.params e
184            val index=Ein.index e
185    
 (*Does not yet do transformation*)  
  (* Expand probe in place *)  
  fun replaceProbe(b,(params,args),index, sumIndex,origargs)=let  
186    
187      val E.Probe(E.Conv(V,alpha,h,dx),E.Tensor(t,_))=b          val E.Probe(E.Conv(Vid,alpha,hid,dx),E.Tensor(tid,_))=p
188      val fid=length(params)      val fid=length(params)
189      val nid=fid+1      val nid=fid+1
190      val n=length(index)          val Pid=nid+1
     val ns=length sumIndex  
191      val nshift=length(dx)      val nshift=length(dx)
192      val nsumshift =(case ns          val (dim,argsA,code,s,PArg) = handleArgs(Vid,hid,tid,args)
193          of 0=> n          val freshIndex=getsumshift(sx,index)
194          | _=>let          val (dx,newsx1,Ps)=transformToIndexSpace(freshIndex,dim,dx,Pid)
             val (E.V v,_,_)=List.nth(sumIndex, ns-1)  
             in v+1  
             end  
     (* end case *))  
   
     (*Outer Index-id Of Probe*)  
     val VShape=ShapeConv(alpha, n)  
     val HShape=ShapeConv(dx, n)  
     val shape=VShape@HShape  
     (* Bindings for Shape*)  
     val shapebind= mapIndex(shape,index)  
     val Vshapebind= mapIndex(VShape,index)  
   
   
     val (dim,argsA,code,s,PArg) = handleArgs(V,h,t,(params,args), origargs,0)  
     val (_,_,dx, _,sxT,restT,_,_) = T.Transform(dx,shapebind,Vshapebind,dim,PArg,nsumshift,1,nid+1)  
   
195      val params'=params@[E.TEN(3,[dim]),E.TEN(1,[dim]),E.TEN(1,[dim,dim])]      val params'=params@[E.TEN(3,[dim]),E.TEN(1,[dim]),E.TEN(1,[dim,dim])]
196      val body'' = createBody(dim, s,nsumshift+nshift,alpha,dx,V, h, nid, fid)          val body' = createBody(dim, s,freshIndex+nshift,alpha,dx,Vid, hid, nid, fid)
197      val body' =(case nshift          val body' = multiPs(Ps,newsx1,body')
         of 0=> body''  
         | _ => E.Sum(sxT, E.Prod(restT@[body'']))  
         (*end case*))  
     val args'=argsA@[PArg]  
     val _ =(case testing  
         of 0=> 1  
         | _ =>  let  
             val subexp=Ein.EIN{params=params', index=index, body=body'}  
             val _= print(String.concat["\n Don't replace probe  \n $$$ new sub-expression $$$ \n",P.printerE(subexp),"\n"])  
             in 1 end  
         (* end case *))  
   
198    
199      in (body',(params',args') ,code)          val body'=(case originalb
200      end              of E.Sum(sx, E.Probe _)              => E.Sum(sx,body')
201                | E.Sum(sx,E.Prod[eps0,E.Probe _ ])  => E.Sum(sx,E.Prod[eps0,body'])
202                | _                                  => body'
203                (*end case*))
204            val _=transitionToString(testN,originalb,body')
205    
206  (*Checks if (1) Summation variable occurs just once (2) it matches n.          val args'=argsA@[PArg]
207  Then we lift otherwise expand in place *)          val einapp=(y,mkEinApp(mkEin(params',index,body'),args'))
 fun checkSum(sx,b,info,index,origargs)=(case sx  
     of [(E.V i,lb,ub)]=>  let  
         val E.Probe(E.Conv(V,alpha,h,dx), E.Tensor(id,beta))=b  
         val n=length(index)  
         val _=(case testing  
             of 1=> (print(String.concat["in check Sum\n " ,P.printbody(E.Sum([(E.V i,lb,ub)],b))]);1)  
             |_ => 1)  
208          in          in
209              if (i=n) then (case F.countSx(sx,b)              code@[einapp]
                 of (1,ixx) => liftProbe(b,info,index@[ub], [],origargs)  
                 | _ => replaceProbe(b, info,index,sx,origargs)  
                 (*end case*))  
             else replaceProbe(b, info,index, sx,origargs)  
210          end          end
     | _ =>replaceProbe(b, info,index, sx,origargs)  
     (*end case*))  
   
   
 fun flatten []=[]  
     | flatten(e1::es)=e1@(flatten es)  
211    
212    
213   (* sx-[] then move out, otherwise keep in *)      fun createEinApp(originalb,alpha,index,freshIndex,dim,dx,sx)= let
214  fun expandEinOp ( Ein.EIN{params, index, body}, origargs,args) = let          val Pid=0
215            val tid=1
216    
217      val dummy=E.Const 0          val (newdx,newsx,Ps)=transformToIndexSpace(freshIndex,dim,dx,Pid)
     val sumIndex=ref []  
218    
219      (*b-current body, info-original ein op, data-new assigments*)          (*need to rewrite dx*)
220      fun rewriteBody(b,info)= let          val (_,sizes,E.Conv(_,_,_,dx))=(case sx@newsx
221                of []=> ([],index,E.Conv(9,alpha,7,newdx))
222                | _ =>cleanIndex.cleanIndex(E.Conv(9,alpha,7,newdx),index,sx@newsx)
223                (*end case*))
224    
225          fun callfn(c1,body)=let          val params=[E.TEN(1,[dim,dim]),E.TEN(1,sizes)]
226              val ref x=sumIndex          val tshape=alpha@newdx
227              val c'=[c1]@x          val t=E.Tensor(tid,tshape)
228              val (bodyK,infoK,dataK)= (sumIndex:=c';rewriteBody(body ,info))          val exp = multiPs(Ps,newsx,t)
229              val ref s=sumIndex          val body=(case originalb
230              val z=hd(s)              of E.Sum(sx, E.Probe _)              => E.Sum(sx,exp)
231              val e'=( case bodyK              | E.Sum(sx,E.Prod[eps0,E.Probe _ ])  => E.Sum(sx,E.Prod[eps0,exp])
232                  of E.Const _ =>bodyK              | _                                  => exp
                 | _ => E.Sum(z,bodyK)  
233                  (*end case*))                  (*end case*))
234    
235            val ein0=mkEin(params,index,body)
236              in              in
237                  (sumIndex:=tl(s);(e',infoK,dataK))              (ein0,sizes,dx)
238              end              end
239    
240        fun liftProbe(testN,(y, DstIL.EINAPP(e,args)),p ,sx)=let
241            val originalb=Ein.body e
242            val params=Ein.params e
243            val index=Ein.index e
244    
245          (*Nothing liftProbe and checkSum are commented out.          val E.Probe(E.Conv(Vid,alpha,hid,dx),E.Tensor(tid,_))=p
246              Some mistake underestimating size of dimension*)          val fid=length(params)
247            val nid=fid+1
248            val nshift=length(dx)
249            val (dim,args',code,s,PArg) = handleArgs(Vid,hid,tid,args)
250            val freshIndex=getsumshift(sx,index)
251    
252          fun filter es=let  
253              fun filterApply([], doneB, infoB, dataB)= (doneB, infoB,dataB)          (*transform T*P*P..Ps*)
254              | filterApply(B::es, doneA, infoA,dataA)= let          val (ein0,sizes,dx)= createEinApp(originalb,alpha,index,freshIndex,dim,dx,sx)
255                  val (bodyB, infoB,dataB)= rewriteBody(B,infoA)          val FArg  = DstV.new ("F", DstTy.TensorTy(sizes))
256                  in          val einApp0=mkEinApp(ein0,[PArg,FArg])
257                      filterApply(es, doneA@[bodyB], infoB,dataA@dataB)          val rtn0=(y,einApp0)
258                  end  
259              in filterApply(es, [],info,[])          (*lifted probe*)
260              end          val params'=params@[E.TEN(3,[dim]),E.TEN(1,[dim])]
261          in (case b          val body' = createBody(dim, s,freshIndex+nshift,alpha,dx,Vid, hid, nid, fid)
262              of  E.Sum(c,  E.Probe(E.Conv v, E.Tensor t)) =>let          val ein1=mkEin(params',sizes,body')
263                  val ref sx=sumIndex          val einApp1=mkEinApp(ein1,args')
264                  in (case sx          val rtn1=(FArg,einApp1)
265                      of  (* [] => liftProbe(E.Probe(E.Conv v, E.Tensor t ), info,index, c,origargs)          val rtn=code@[rtn1,rtn0]
266                        | [i]=> checkSum(i,b, info,index,origargs)          val _= einapptostring (p,rtn1,rtn0)
                       |*) _ => let  
                         val (b,m,code)=replaceProbe(E.Probe(E.Conv v, E.Tensor t ), info,index, (flatten sx)@c,origargs)  
                         in (E.Sum(c,b),m,code)  
                         end  
                 (* end case*))  
             end  
         | E.Probe(E.Conv _, E.Tensor _) =>let  
             val ref sx=sumIndex  
             in (case sx  
                 of (* []=> liftProbe(b, info,index, [],origargs)  
                 | [i]=> checkSum(i,b, info,index,origargs)  
                 |*) _ => replaceProbe(b, info,index, flatten sx,origargs)  
              (* end case*))  
             end  
         | E.Probe _=> (dummy,info,[])  
         | E.Conv _=>  (dummy,info,[])  
         | E.Lift _=> (dummy,info,[])  
         | E.Field _ => (dummy,info,[])  
         | E.Apply _ => (dummy,info,[])  
         | E.Neg e=> let  
             val (body',info',data')=rewriteBody(e,info)  
267              in              in
268                  (E.Neg(body'),info',data')              rtn
             end  
         | E.Sum (c,e)=> callfn(c,e)  
         | E.Sub(a,b)=>let  
             val (bodyA,infoA,dataA)= rewriteBody(a,info)  
             val (bodyB, infoB, dataB)= rewriteBody(b,infoA)  
             in   (E.Sub(bodyA, bodyB),infoB,dataA@dataB)  
             end  
         | E.Div(a,b)=>let  
             val (bodyA,infoA,dataA)= rewriteBody(a,info)  
             val (bodyB, infoB,dataB)= rewriteBody(b,infoA)  
             in  (E.Div(bodyA, bodyB),infoB,dataA@dataB) end  
         | E.Add es=> let  
             val (done, info',data')= filter es  
             val (_, e)=F.mkAdd done  
             in (e, info',data')  
             end  
         | E.Prod es=> let  
             val (done, info',data')= filter es  
             val (_, e)=F.mkProd done  
             in (e, info',data')  
             end  
         | _=>  (b,info,[])  
         (* end case *))  
269          end          end
270    
      val empty =fn key =>NONE  
      val _ =(case testing  
         of 0 => 1  
         | _ => (print "\n ************************** \n Starting Expand";1)  
         (*end case*))  
271    
272      val (body',(params',args'),newbies)=rewriteBody(body,(params,args))      (* expandEinOp: code->  code list
273      val e'=Ein.EIN{params=params', index=index, body=body'}      *A this point we only have simple ein ops
274      (*val _ =(case testing      *Looks to see if the expression has a probe. If so, replaces it.
275          of 0 => 1      * Note how we keeps eps expressions so only generate pieces that are used
276          | _ => (String.concat[P.printerE(e'),"\n DONE expand ************************** \n "];1)      *)
277          (*end case*))*)      fun expandEinOp( e as (y, DstIL.EINAPP(ein,args))) = let
278            fun checkConst ([],a) = liftProbe a
279            | checkConst ((E.C _::_),a) =replaceProbe a
280            | checkConst ((_ ::es),a)=checkConst(es,a)
281            fun rewriteBody b=(case b
282                of E.Probe(E.Conv(_,_,_,[]),_)
283                    => replaceProbe(1,e,b, [])
284                | E.Probe(E.Conv (_,alpha,_,dx),_)
285                    => checkConst(alpha@dx,(0,e,b,[]))
286                | E.Sum(sx,p as E.Probe(E.Conv(_,_,_,[]),_))
287                    => replaceProbe(1,e,p, sx)
288                | E.Sum(sx,p as E.Probe(E.Conv(_,[],_,dx),_))
289                    => checkConst(dx,(0,e,p,sx))
290                | E.Sum(sx,E.Probe p)
291                    => replaceProbe(1,e,E.Probe p, sx)
292                | E.Sum(sx,E.Prod[eps,E.Probe p])
293                    => replaceProbe(1,e,E.Probe p,sx)
294                | _ => [e]
295                (* end case *))
296      in      in
297          ((e',args'),newbies)              rewriteBody (Ein.body ein)
298      end      end
299    
300    end; (* local *)    end; (* local *)

Legend:
Removed from v.2671  
changed lines
  Added in v.3092

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0