Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] Diff of /branches/vis15/src/compiler/high-to-mid/probe-ein.sml
ViewVC logotype

Diff of /branches/vis15/src/compiler/high-to-mid/probe-ein.sml

Parent Directory Parent Directory | Revision Log Revision Log | View Patch Patch

revision 3551, Wed Jan 6 16:01:32 2016 UTC revision 3569, Mon Jan 11 05:47:54 2016 UTC
# Line 40  Line 40 
40      val valnumflag = true      val valnumflag = true
41      val tsplitvar = true      val tsplitvar = true
42      val fieldliftflag = true      val fieldliftflag = true
     val constflag = false  
43      val detflag = true      val detflag = true
44    
45      fun transformToIndexSpace e = T.transformToIndexSpace e      fun transformToIndexSpace e = T.transformToIndexSpace e
46      fun transformToImgSpace e = T.transformToImgSpace  e      fun transformToImgSpace e = T.transformToImgSpace  e
47      fun toStringBind e = (MidToString.toStringBind e)  
48      fun mkEin e = Ein.mkEin e      fun mkEin e = Ein.mkEin e
49      fun mkEinApp (rator, args) = DstIL.EINAPP(rator, args)      fun mkEinApp (rator, args) = DstIL.EINAPP(rator, args)
50      fun setConst e = E.setConst e      fun setConst e = E.setConst e
# Line 125  Line 124 
124              (* end case *))              (* end case *))
125          (*sumIndex creating summaiton Index for body*)          (*sumIndex creating summaiton Index for body*)
126          val slb=1-s          val slb=1-s
         val _=List.tabulate(dim, (fn dim=> (String.concat[" sx:",Int.toString(sx)," dim:",Int.toString(dim),"esum",Int.toString(sx+dim) ]) ))  
127          val esum=List.tabulate(dim, (fn dim=>(E.V (dim+sx),slb,s)))          val esum=List.tabulate(dim, (fn dim=>(E.V (dim+sx),slb,s)))
128      in      in
129          E.Sum(esum, exp)          E.Sum(esum, exp)
# Line 134  Line 132 
132      (*getsumshift:sum_indexid list* int list-> int      (*getsumshift:sum_indexid list* int list-> int
133      *get fresh/unused index_id, returns int      *get fresh/unused index_id, returns int
134      *)      *)
135        fun getsumshift ([], n) = n
136      fun getsumshift (sx, n) =let      fun getsumshift (sx, n) =let
         val nsumshift= (case sx  
             of []=> n  
             | _=>let  
137                  val (E.V v,_,_)=List.hd(List.rev sx)                  val (E.V v,_,_)=List.hd(List.rev sx)
                 in v+1  
                 end  
             (* end case *))  
   
         val aa=List.map (fn (E.V v,_,_)=>Int.toString v) sx  
         val _ =(String.concat["\n", "SumIndex:" ,(String.concatWith"," aa),  
         "\n\t Index length:",Int.toString n,  
         "\n\t Freshindex: ", Int.toString nsumshift])  
138          in          in
139              nsumshift              v+1
140          end          end
141    
142      (*formBody:ein_exp->ein_exp      (*formBody:ein_exp->ein_exp
     *just does a quick rewrite  
143      *)      *)
144      fun formBody (E.Sum([],e))=formBody e      fun formBody (E.Sum([],e))=formBody e
145        | formBody (E.Sum(sx,e))= E.Sum(sx,formBody e)        | formBody (E.Sum(sx,e))= E.Sum(sx,formBody e)
# Line 179  Line 166 
166      * rewrites body      * rewrites body
167      * replace probe with expanded version      * replace probe with expanded version
168      *)      *)
169  (*    fun replaceProbe(testN,y,originalb,b,params,args,index, sx)*)       fun replaceProbe (testN, (y, DstIL.EINAPP(e,args)), p, sx) = let
   
      fun replaceProbe (testN, (y, DstIL.EINAPP(e,args)), p, sx)  
         =let  
         val originalb=Ein.body e  
170          val params=Ein.params e          val params=Ein.params e
         val index=Ein.index e  
         val _ = testp["\n***************** \n Replace ************ \n"]  
         val _=  toStringBind (y, DstIL.EINAPP(e,args))  
   
         val E.Probe(E.Conv(Vid,alpha,hid,dx),E.Tensor(tid,_))=p  
171          val fid=length(params)          val fid=length(params)
172          val nid=fid+1          val nid=fid+1
173          val Pid=nid+1          val Pid=nid+1
174            val E.Probe(E.Conv(Vid,alpha,hid,dx),E.Tensor(tid,_)) = p
175          val nshift=length(dx)          val nshift=length(dx)
176          val (dim,argsA,code,s,PArg) = handleArgs(Vid,hid,tid,args)          val (dim,argsA,code,s,PArg) = handleArgs(Vid,hid,tid,args)
177          val freshIndex=getsumshift(sx,length(index))          val freshIndex = getsumshift(sx,length(Ein.index e))
178          val (dx,newsx1,Ps)=transformToIndexSpace(freshIndex,dim,dx,Pid)          val (dx,newsx1,Ps)=transformToIndexSpace(freshIndex,dim,dx,Pid)
179          val params'=params@[E.TEN(3,[dim]),E.TEN(1,[dim]),E.TEN(1,[dim,dim])]          val params'=params@[E.TEN(3,[dim]),E.TEN(1,[dim]),E.TEN(1,[dim,dim])]
180          val body' = createBody(dim, s,freshIndex+nshift,alpha,dx,Vid, hid, nid, fid)          val body' = createBody(dim, s,freshIndex+nshift,alpha,dx,Vid, hid, nid, fid)
181          val body' = multiPs(Ps,newsx1,body')          val body' = multiPs(Ps,newsx1,body')
182            val body'= (case (Ein.body e)
         val body'=(case originalb  
183              of E.Sum(sx, E.Probe _)              => E.Sum(sx,body')              of E.Sum(sx, E.Probe _)              => E.Sum(sx,body')
184              | E.Sum(sx,E.Opn(E.Prod,[eps0,E.Probe _ ]))  => E.Sum(sx,setProd[eps0,body'])              | E.Sum(sx,E.Opn(E.Prod,[eps0,E.Probe _ ]))  => E.Sum(sx,setProd[eps0,body'])
185              | _                                  => body'              | _                                  => body'
186              (* end case *))              (* end case *))
187            val einapp=(y,mkEinApp(mkEin(params',index,body'),argsA@[PArg]))
   
         val args'=argsA@[PArg]  
         val einapp=(y,mkEinApp(mkEin(params',index,body'),args'))  
188          in          in
189              code@[einapp]              code@[einapp]
190          end          end
191    
   
192      fun createEinApp (originalb, alpha, index, freshIndex, dim, dx, sx) = let      fun createEinApp (originalb, alpha, index, freshIndex, dim, dx, sx) = let
193          val Pid=0          val Pid=0
194          val tid=1          val tid=1
# Line 227  Line 201 
201              of []=> ([],index,E.Conv(9,alpha,7,newdx))              of []=> ([],index,E.Conv(9,alpha,7,newdx))
202              | _ => cleanIndex.cleanIndex(E.Conv(9,alpha,7,newdx),index,sx@newsx)              | _ => cleanIndex.cleanIndex(E.Conv(9,alpha,7,newdx),index,sx@newsx)
203              (* end case *))              (* end case *))
   
204          val params=[E.TEN(1,[dim,dim]),E.TEN(1,sizes)]          val params=[E.TEN(1,[dim,dim]),E.TEN(1,sizes)]
205          fun filterAlpha []=[]          fun filterAlpha []=[]
206            | filterAlpha(E.C _::es)= filterAlpha es            | filterAlpha(E.C _::es)= filterAlpha es
207            | filterAlpha(e1::es)=[e1]@(filterAlpha es)            | filterAlpha(e1::es)=[e1]@(filterAlpha es)
   
208          val tshape=filterAlpha(alpha')@newdx          val tshape=filterAlpha(alpha')@newdx
209          val t=E.Tensor(tid,tshape)          val t=E.Tensor(tid,tshape)
   
210          val (splitvar,body)=(case originalb          val (splitvar,body)=(case originalb
211              of E.Sum(sx, E.Probe _)              => (true,multiPs(Ps,sx@newsx,t))              of E.Sum(sx, E.Probe _)              => (true,multiPs(Ps,sx@newsx,t))
212              | E.Sum(sx,E.Opn(E.Prod,[eps0,E.Probe _ ]))  => (false,E.Sum(sx,setProd[eps0,multiPs(Ps,newsx,t)]))              | E.Sum(sx,E.Opn(E.Prod,[eps0,E.Probe _ ]))  => (false,E.Sum(sx,setProd[eps0,multiPs(Ps,newsx,t)]))
213              | _                                  => (case tsplitvar              | _  => (true, multiPs(Ps, newsx, t))
               of(* true =>   (true,multiMergePs(Ps,newsx,t))  (*pushes summations in place*)  
                 | false*) _ =>   (true,multiPs(Ps,newsx,t))  
                 (* end case *))  
214          (* end case *))          (* end case *))
215    
         val _ =(case splitvar  
         of true=> (String.concat["splitvar is true", P.printbody body])  
         | _ => (String.concat["splitvar is false",P.printbody body])  
         (* end case *))  
   
   
216          val ein0=mkEin(params,index,body)          val ein0=mkEin(params,index,body)
217          in          in
218              (splitvar,ein0,sizes,dx,alpha')              (splitvar,ein0,sizes,dx,alpha')
219          end          end
220    
221      fun liftProbe (printStrings, (y, DstIL.EINAPP(e, args)), p, sx) = let      fun liftProbe (printStrings, (y, DstIL.EINAPP(e, args)), p, sx) = let
222          val _=testp["\n******* Lift Geneirc Probe ***\n"]  
         val originalb=Ein.body e  
223          val params=Ein.params e          val params=Ein.params e
224          val index=Ein.index e          val index=Ein.index e
         val _ =  (toStringBind (y, DstIL.EINAPP(e,args)))  
   
         val E.Probe(E.Conv(Vid,alpha,hid,dx),E.Tensor(tid,_))=p  
225          val fid=length(params)          val fid=length(params)
226          val nid=fid+1          val nid=fid+1
227            val E.Probe(E.Conv(Vid,alpha,hid,dx),E.Tensor(tid,_)) = p
228          val nshift=length(dx)          val nshift=length(dx)
229          val (dim,args',code,s,PArg) = handleArgs(Vid,hid,tid,args)          val (dim,args',code,s,PArg) = handleArgs(Vid,hid,tid,args)
230          val freshIndex=getsumshift(sx,length(index))          val freshIndex=getsumshift(sx,length(index))
231    
232          (*transform T*P*P..Ps*)          (*transform T*P*P..Ps*)
233          val (splitvar,ein0,sizes,dx,alpha')= createEinApp(originalb,alpha,index,freshIndex,dim,dx,sx)          val (splitvar, ein0, sizes, dx, alpha') = createEinApp(Ein.body e, alpha, index, freshIndex, dim, dx, sx)
   
234          val FArg  = DstV.new ("F", DstTy.TensorTy(sizes))          val FArg  = DstV.new ("F", DstTy.TensorTy(sizes))
235          val einApp0=mkEinApp(ein0,[PArg,FArg])          val einApp0=mkEinApp(ein0,[PArg,FArg])
236          val rtn0=(case splitvar          val rtn0=(case splitvar
237              of false => [(y,mkEinApp(ein0,[PArg,FArg]))]              of false => [(y,mkEinApp(ein0,[PArg,FArg]))]
238              | _      => let              | _      => let
239                   val bind3 = (y,DstIL.EINAPP(SummationEin.main ein0,[PArg,FArg]))                   val bind3 = (y,DstIL.EINAPP(SummationEin.main ein0,[PArg,FArg]))
240                   in Split.splitEinApp bind3                  in
241                        Split.splitEinApp bind3
242                   end                   end
243              (* end case *))              (* end case *))
244    
245          (*lifted probe*)          (*lifted probe*)
246          val params'=params@[E.TEN(3,[dim]),E.TEN(1,[dim])]          val params'=params@[E.TEN(3,[dim]),E.TEN(1,[dim])]
247          val freshIndex'= length(sizes)          val freshIndex'= length(sizes)
   
248          val body' = createBody(dim, s,freshIndex',alpha',dx,Vid, hid, nid, fid)          val body' = createBody(dim, s,freshIndex',alpha',dx,Vid, hid, nid, fid)
249          val ein1=mkEin(params',sizes,body')          val ein1=mkEin(params',sizes,body')
250          val einApp1=mkEinApp(ein1,args')          val einApp1=mkEinApp(ein1,args')
251          val rtn1=(FArg,einApp1)          val rtn1=(FArg,einApp1)
         val rtn=code@[rtn1]@rtn0  
         val _= List.map toStringBind ([rtn1]@rtn0)  
          val _=(String.concat["\n* end  Lift Geneirc Probe  ******** \n"])  
252          in          in
253              rtn              code@[rtn1]@rtn0
         end  
   
     fun searchFullField (fieldset,code1,body1,dx)=let  
         val (lhs,_)=code1  
         fun continueReconstruction ()=let  
             val _=print"Tash:don't replaced"  
             in (case dx  
                 of []=> (lhs,replaceProbe(1,code1,body1,[]))  
                 | _ =>(lhs,liftProbe(1,code1,body1,[]))  
                 (* end case *))  
              end  
         in  (case valnumflag  
             of false    => (fieldset,continueReconstruction())  
             | true      => (case  (einSet.rtnVarN(fieldset,code1))  
                 of (fieldset,NONE)     => (fieldset,continueReconstruction())  
                  | (fieldset,SOME m)   =>(print"TASH:replaced"; (fieldset,(m,[])))  
                 (* end case *))  
             (* end case *))  
         end  
   
     fun liftFieldMat(newvx,e)=  
         let  
             val _=testp[ "\n ***************************** start FieldMat\n"]  
             val (y, DstIL.EINAPP(ein,args))=e  
             val E.Probe(E.Conv(V,[c1,v0],h,dx),pos)=Ein.body ein  
             val index0=Ein.index ein  
             val index1 = index0@[3]  
             val body1_unshifted = E.Probe(E.Conv(V,[E.V newvx, v0],h,dx),pos)  
             (* clean to get body indices in order *)  
             val ( _ , _, body1)= cleanIndex.cleanIndex(body1_unshifted,index1,[])  
             val _ = testp ["\n Shifted ",P.printbody body1_unshifted,"=>",P.printbody body1]  
   
             val lhs1=DstV.new ("L", DstTy.TensorTy(index1))  
             val ein1 = mkEin(Ein.params ein,index1,body1)  
             val code1= (lhs1,mkEinApp(ein1,args))  
             val codeAll= (case dx  
             of []=> replaceProbe(1,code1,body1,[])  
             | _ =>liftProbe(1,code1,body1,[])  
             (* end case *))  
   
             (*Probe that tensor at a constant position  c1*)  
             val param0 = [E.TEN(1,index1)]  
             val nx=List.tabulate(length(dx)+1,fn n=>E.V n)  
             val body0 =  E.Tensor(0,[c1]@nx)  
             val ein0 = mkEin(param0,index0,body0)  
             val einApp0 = mkEinApp(ein0,[lhs1])  
             val code0 = (y,einApp0)  
             val _= toStringBind code0  
             val _=testp["\n end FieldMat *****************************\n "]  
         in  
             codeAll@[code0]  
     end  
   
     fun liftFieldVec(newvx,e,fieldset)=  
     let  
         val _=testp[ "\n ***************************** start FieldVec\n"]  
         val (y, DstIL.EINAPP(ein,args))=e  
         val E.Probe(E.Conv(V,[c1],h,dx),pos)=Ein.body ein  
         val index0=Ein.index ein  
         val index1 = index0@[3]  
         val body1_unshifted = E.Probe(E.Conv(V,[E.V newvx],h,dx),pos)  
         (* clean to get body indices in order *)  
         val ( _ , _, body1)= cleanIndex.cleanIndex(body1_unshifted,index1,[])  
   
   
         val lhs1=DstV.new ("L", DstTy.TensorTy(index1))  
         val ein1 = mkEin(Ein.params ein,index1,body1)  
         val code1= (lhs1,mkEinApp(ein1,args))  
         val (fieldset,(lhs0,codeAll))=searchFullField (fieldset,code1,body1,dx)  
   
         (*Probe that tensor at a constant position  c1*)  
         val param0 = [E.TEN(1,index1)]  
         val nx=List.tabulate(length(dx),fn n=>E.V n)  
         val body0 =  E.Tensor(0,[c1]@nx)  
         val ein0 = mkEin(param0,index0,body0)  
         val einApp0 = mkEinApp(ein0,[lhs0])  
         val code0 = (y,einApp0)  
   
         val _ = testp ["\n Shifted ",P.printbody body1_unshifted,"=>",P.printbody body1]  
         val _ = (toStringBind code0)  
         val _ = testp[ "\n end FieldVec *****************************\n "]  
         in  
             codeAll@[code0]  
     end  
   
   
   
     fun liftFieldSum e =  
     let  
         val _=testp[ "\n************************************* Start Lift Field Sum\n"]  
         val (y, DstIL.EINAPP(ein,args))=e  
         val E.Sum([(vsum,0,n)],E.Probe(E.Conv(V,[c1,v0],h,dx),pos))=Ein.body ein  
         val index0=Ein.index ein  
         val index1 = index0@[3]@[3]  
         val shiftdx=List.tabulate(length(dx),fn n=>E.V (n+2))  
         val body1 = E.Probe(E.Conv(V,[E.V 0,E.V 1],h,shiftdx),pos)  
   
   
         val lhs1=DstV.new ("L", DstTy.TensorTy(index1))  
         val ein1 = mkEin(Ein.params ein,index1,body1)  
         val code1= (lhs1,mkEinApp(ein1,args))  
         val codeAll= (case dx  
             of []   => replaceProbe(1,code1,body1,[])  
             | _     =>liftProbe(1,code1,body1,[])  
             (* end case *))  
   
         (*Probe that tensor at a constant position  c1*)  
         val param0 = [E.TEN(1,index1)]  
         val nx=List.tabulate(length(dx),fn n=>E.V n)  
         val body0 =  E.Sum([(vsum,0,n)],E.Tensor(0,[vsum,vsum]@nx))  
         val ein0 = mkEin(param0,index0,body0)  
         val einApp0 = mkEinApp(ein0,[lhs1])  
         val code0 = (y,einApp0)  
         val _ = toStringBind  e  
         val _ = toStringBind code0  
         val _  = (String.concat  ["\norig",P.printbody(Ein.body ein),"\n replace i  ",P.printbody body1,"\nfreshtensor",P.printbody body0])  
         val _  =((List.map toStringBind (codeAll@[code0])))  
         val _ = testp["\n*** end Field Sum*************************************\n"]  
         in  
         codeAll@[code0]  
254      end      end
255    
256    
# Line 425  Line 260 
260      * Note how we keeps eps expressions so only generate pieces that are used      * Note how we keeps eps expressions so only generate pieces that are used
261      *)      *)
262     fun expandEinOp (e as (y, DstIL.EINAPP(ein, args)), fieldset) = let     fun expandEinOp (e as (y, DstIL.EINAPP(ein, args)), fieldset) = let
263            fun rewriteBody b=(case b
     fun checkConst(es,a)=(case constflag  
         of true => liftProbe a  
         | _ => let  
             fun fConst ([],a) =  
                 (case fieldliftflag  
                     of true => liftProbe a  
                     | _ => replaceProbe a  
                 (* end case *))  
             | fConst ((E.C _::_),a) = replaceProbe a  
             | fConst ((_ ::es),a)= checkConst(es,a)  
             in fConst(es,a) end  
       (* end case*))  
         fun rewriteBodyB b=(case b  
264              of  (E.Probe(E.Conv(_,_,_,[]),_))              of  (E.Probe(E.Conv(_,_,_,[]),_))
265                  => replaceProbe(0,e,b,[])                  => replaceProbe(0,e,b,[])
266              | (E.Probe(E.Conv (_,alpha,_,dx),_))              | (E.Probe(E.Conv (_,alpha,_,dx),_))
267                  => checkConst(dx,(0,e,b,[])) (*scans dx for contant*)                  => liftProbe (0,e,b,[]) (*scans dx for contant*)
268              | (E.Sum(sx,p as E.Probe(E.Conv(_,_,_,[]),_)))              | (E.Sum(sx,p as E.Probe(E.Conv(_,_,_,[]),_)))
269                  => replaceProbe(0,e,p, sx)  (*no dx*)                  => replaceProbe(0,e,p, sx)  (*no dx*)
270              | (E.Sum(sx,p as E.Probe(E.Conv(_,[],_,dx),_)))              | (E.Sum(sx,p as E.Probe(E.Conv(_,[],_,dx),_)))
271                  => checkConst(dx,(0,e,p,sx)) (*scalar field*)                  => liftProbe (0,e,p,sx) (*scalar field*)
272              | (E.Sum(sx,E.Probe p))              | (E.Sum(sx,E.Probe p))
273                  => replaceProbe(0,e,E.Probe p, sx)                  => replaceProbe(0,e,E.Probe p, sx)
274              | (E.Sum(sx,E.Opn(E.Prod,[eps,E.Probe p])))              | (E.Sum(sx,E.Opn(E.Prod,[eps,E.Probe p])))
275                  => replaceProbe(0,e,E.Probe p,sx)                  => replaceProbe(0,e,E.Probe p,sx)
276              | _ => [e]              | _ => [e]
277              (* end case *))              (* end case *))
         fun rewriteBody b=(case detflag  
             of true => (case b  
                 of (E.Probe(E.Conv(_,[E.C _ ,E.V 0],_,[]),pos))  
                     => liftFieldMat (1,e)  
                 | (E.Probe(E.Conv(_,[E.C _ ,E.V 0],_,[E.V 1]),pos))  
                     => liftFieldMat (2,e)  
                 | (E.Probe(E.Conv(_,[E.C _ ,E.V 0],_,[E.V 1,E.V 2]),pos))  
                     => liftFieldMat (3,e)  
                 | _   => rewriteBodyB b  
                 (* end case *))  
             | _   => rewriteBodyB b  
             (* end case *))  
278          val (fieldset,var) = (case valnumflag          val (fieldset,var) = (case valnumflag
279              of true => einSet.rtnVar(fieldset,y,DstIL.EINAPP(ein,args))              of true => einSet.rtnVar(fieldset,y,DstIL.EINAPP(ein,args))
280              | _     => (fieldset,NONE)              | _     => (fieldset,NONE)

Legend:
Removed from v.3551  
changed lines
  Added in v.3569

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0