Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] Diff of /branches/charisee_dev/src/compiler/high-to-mid/ProbeEin.sml
ViewVC logotype

Diff of /branches/charisee_dev/src/compiler/high-to-mid/ProbeEin.sml

Parent Directory Parent Directory | Revision Log Revision Log | View Patch Patch

revision 3460, Mon Nov 23 20:27:35 2015 UTC revision 3472, Tue Dec 1 18:45:25 2015 UTC
# Line 40  Line 40 
40      *)      *)
41    
42      val testing=0      val testing=0
   
   
   
43      val valnumflag=true      val valnumflag=true
44      val tsplitvar=true      val tsplitvar=true
45      val fieldliftflag=true      val fieldliftflag=true
# Line 64  Line 61 
61      fun setSub e= E.setSub e      fun setSub e= E.setSub e
62      fun setProd e= E.setProd e      fun setProd e= E.setProd e
63      fun setAdd e= E.setAdd e      fun setAdd e= E.setAdd e
64        fun mkCx es =List.map (fn c => E.C (c,true)) es
65        fun mkCxSingle c = E.C (c,true)
66    
67      fun testp n=(case testing      fun testp n=(case testing
68          of 0=> 1          of 0=> 1
69          | _ =>((String.concat n);1)          | _ =>(print(String.concat n);1)
70          (*end case*))          (*end case*))
71    
72    
# Line 116  Line 115 
115          (*1-d fields*)          (*1-d fields*)
116          fun createKRND1 ()=let          fun createKRND1 ()=let
117              val sum=sx              val sum=sx
118              val dels=List.map (fn e=>(E.C 0,e)) deltas              val dels=List.map (fn e=>(mkCxSingle 0,e)) deltas
119              val pos=[setAdd[E.Tensor(fid,[]),E.Value(sum)]]              val pos=[setAdd[E.Tensor(fid,[]),E.Value(sum)]]
120              val rest= E.Krn(hid,dels,setSub(E.Tensor(nid,[]),E.Value(sum)))              val rest= E.Krn(hid,dels,setSub(E.Tensor(nid,[]),E.Value(sum)))
121              in              in
# Line 127  Line 126 
126          | createKRN(dim,imgpos,rest)=let          | createKRN(dim,imgpos,rest)=let
127              val dim'=dim-1              val dim'=dim-1
128              val sum=sx+dim'              val sum=sx+dim'
129              val dels=List.map (fn e=>(E.C dim',e)) deltas              val dels=List.map (fn e=>(mkCxSingle  dim',e)) deltas
130              val pos=[setAdd[E.Tensor(fid,[E.C dim']),E.Value(sum)]]              val pos=[setAdd[E.Tensor(fid,[mkCxSingle dim']),E.Value(sum)]]
131              val rest'= E.Krn(hid,dels,setSub(E.Tensor(nid,[E.C dim']),E.Value(sum)))              val rest'= E.Krn(hid,dels,setSub(E.Tensor(nid,[mkCxSingle dim']),E.Value(sum)))
132              in              in
133                  createKRN(dim',pos@imgpos,[rest']@rest)                  createKRN(dim',pos@imgpos,[rest']@rest)
134              end              end
# Line 186  Line 185 
185        | multiMergePs e=multiPs e        | multiMergePs e=multiPs e
186    
187    
188      (* replaceProbe:ein_exp* params *midIL.var list * int list* sum_id list      (* *******************************************  Replace probe *******************************************  *)
189              -> ein_exp* *code      (* replaceProbe
190      * Transforms position to world space      * Transforms position to world space
191      * transforms result back to index_space      * transforms result back to index_space
192      * rewrites body      * rewrites body
193      * replace probe with expanded version      * replace probe with expanded version
194      *)      *)
195  (*    fun replaceProbe(testN,y,originalb,b,params,args,index, sx)*)       fun replaceProbe((y, DstIL.EINAPP(e,args)),p ,sx)
   
      fun replaceProbe(testN,(y, DstIL.EINAPP(e,args)),p ,sx)  
196          =let          =let
197          val originalb=Ein.body e          val originalb=Ein.body e
198          val params=Ein.params e          val params=Ein.params e
# Line 221  Line 218 
218              | _                                  => body'              | _                                  => body'
219              (*end case*))              (*end case*))
220    
   
221          val args'=argsA@[PArg]          val args'=argsA@[PArg]
222          val einapp=(y,mkEinApp(mkEin(params',index,body'),args'))          val einapp=(y,mkEinApp(mkEin(params',index,body'),args'))
223          in          in
224              code@[einapp]              code@[einapp]
225          end          end
226    
227        (* ******************************************* Lift probe *******************************************  *)
228      fun createEinApp(originalb,alpha,index,freshIndex,dim,dx,sx)= let      fun createEinApp(originalb,alpha,index,freshIndex,dim,dx,sx)= let
229          val Pid=0          val Pid=0
230          val tid=1          val tid=1
# Line 270  Line 266 
266              (splitvar,ein0,sizes,dx,alpha')              (splitvar,ein0,sizes,dx,alpha')
267          end          end
268    
269      fun liftProbe(printStrings,(y, DstIL.EINAPP(e,args)),p ,sx)=let      fun liftProbe((y, DstIL.EINAPP(e,args)),p ,sx)=let
270          val _=testp["\n******* Lift Geneirc Probe ***\n"]          val _=testp["\n******* Lift Geneirc Probe ***\n"]
271          val originalb=Ein.body e          val originalb=Ein.body e
272          val params=Ein.params e          val params=Ein.params e
# Line 312  Line 308 
308              rtn              rtn
309          end          end
310    
311      fun searchFullField (fieldset,code1,body1,dx)=let      (* ******************************************* Reconstruction -> Lift|Replace probe *******************************************  *)
312          val (lhs,_)=code1      (* scans dx for contant
313          fun continueReconstruction ()=let       * arg:(1,code1, body1,[])
314              val _=print"Tash:don't replaced"       *)
315              in (case dx      fun reconstruction([],arg)= replaceProbe arg
316                  of []=> (lhs,replaceProbe(1,code1,body1,[]))       | reconstruction(dx,arg)=(case (constflag,fieldliftflag)
317                  | _ =>(lhs,liftProbe(1,code1,body1,[]))          of (true,true) => liftProbe arg
318                  (*end case*))          | (_,false)    => replaceProbe arg
319               end          | _ => let
320          in  (case valnumflag              fun fConst [] = liftProbe arg
321              of false    => (fieldset,continueReconstruction())              | fConst (E.C _::_) = replaceProbe arg
322              | true      => (case  (einSet.rtnVarN(fieldset,code1))              | fConst (_ ::es)= fConst es
323                  of (fieldset,NONE)     => (fieldset,continueReconstruction())              in fConst dx end
                  | (fieldset,SOME m)   =>(print"TASH:replaced"; (fieldset,(m,[])))  
                 (*end case*))  
324              (*end case*))              (*end case*))
         end  
325    
326      fun liftFieldMat(newvx,e)=      (* **************************************************** Index Tensor **************************************************** *)
327          let      (*Push constant indices to tensor replacement*)
328              val _=testp[ "\n ***************************** start FieldMat\n"]      fun getF (e,fieldset,dim,newvx)= let
329              val (y, DstIL.EINAPP(ein,args))=e              val (y, DstIL.EINAPP(ein,args))=e
             val E.Probe(E.Conv(V,[c1,v0],h,dx),pos)=Ein.body ein  
330              val index0=Ein.index ein              val index0=Ein.index ein
331              val index1 = index0@[3]          val index1 = index0@dim
332            val b=Ein.body ein
333    
334            val (c1,dx,body1)=(case b
335                of  E.Sum([(vsum,0,n)],E.Probe(E.Conv(V,[c1,v0],h,dx),pos))=>let
336                    val shiftdx=List.tabulate(length(dx),fn n=>E.V (n+2))
337                    val b=E.Probe(E.Conv(V,[E.V 0,E.V 1],h,shiftdx),pos)
338                    in (c1,dx,b) end
339                | E.Probe(E.Conv(V,[c1,v0],h,dx),pos)=> let
340              val body1_unshifted = E.Probe(E.Conv(V,[E.V newvx, v0],h,dx),pos)              val body1_unshifted = E.Probe(E.Conv(V,[E.V newvx, v0],h,dx),pos)
341              (* clean to get body indices in order *)              (* clean to get body indices in order *)
342              val ( _ , _, body1)= cleanIndex.cleanIndex(body1_unshifted,index1,[])              val ( _ , _, body1)= cleanIndex.cleanIndex(body1_unshifted,index1,[])
343              val _ = testp ["\n Shifted ",P.printbody body1_unshifted,"=>",P.printbody body1]                  in (c1,dx,body1) end
344                |  E.Probe(E.Conv(V,[c1],h,dx),pos)=> let
             val lhs1=DstV.new ("L", DstTy.TensorTy(index1))  
             val ein1 = mkEin(Ein.params ein,index1,body1)  
             val code1= (lhs1,mkEinApp(ein1,args))  
             val codeAll= (case dx  
             of []=> replaceProbe(1,code1,body1,[])  
             | _ =>liftProbe(1,code1,body1,[])  
             (*end case*))  
   
             (*Probe that tensor at a constant position  c1*)  
             val param0 = [E.TEN(1,index1)]  
             val nx=List.tabulate(length(dx)+1,fn n=>E.V n)  
             val body0 =  E.Tensor(0,[c1]@nx)  
             val ein0 = mkEin(param0,index0,body0)  
             val einApp0 = mkEinApp(ein0,[lhs1])  
             val code0 = (y,einApp0)  
             val _= toStringBind code0  
             val _=testp["\n end FieldMat *****************************\n "]  
         in  
             codeAll@[code0]  
     end  
   
     fun liftFieldVec(newvx,e,fieldset)=  
     let  
         val _=testp[ "\n ***************************** start FieldVec\n"]  
         val (y, DstIL.EINAPP(ein,args))=e  
         val E.Probe(E.Conv(V,[c1],h,dx),pos)=Ein.body ein  
         val index0=Ein.index ein  
         val index1 = index0@[3]  
345          val body1_unshifted = E.Probe(E.Conv(V,[E.V newvx],h,dx),pos)          val body1_unshifted = E.Probe(E.Conv(V,[E.V newvx],h,dx),pos)
         (* clean to get body indices in order *)  
346          val ( _ , _, body1)= cleanIndex.cleanIndex(body1_unshifted,index1,[])          val ( _ , _, body1)= cleanIndex.cleanIndex(body1_unshifted,index1,[])
347                   in (c1,dx,body1) end
348                (*end case*))
349    
350          val lhs1=DstV.new ("L", DstTy.TensorTy(index1))          val lhs1=DstV.new ("L", DstTy.TensorTy(index1))
351          val ein1 = mkEin(Ein.params ein,index1,body1)          val ein1 = mkEin(Ein.params ein,index1,body1)
352          val code1= (lhs1,mkEinApp(ein1,args))          val code1= (lhs1,mkEinApp(ein1,args))
         val (fieldset,(lhs0,codeAll))=searchFullField (fieldset,code1,body1,dx)  
353    
354          (*Probe that tensor at a constant position  c1*)          val (_,(lhs0,codeAll))= (case valnumflag
355          val param0 = [E.TEN(1,index1)]              of false    => (fieldset,(lhs1, reconstruction(dx,(code1,body1,[]))))
356          val nx=List.tabulate(length(dx),fn n=>E.V n)              | true      => (case  (einSet.rtnVarN(fieldset,code1))
357          val body0 =  E.Tensor(0,[c1]@nx)                  of (fieldset,NONE)     => (fieldset,(lhs1, reconstruction(dx,(code1,body1,[]))))
358          val ein0 = mkEin(param0,index0,body0)                  | (fieldset,SOME m)   =>  (fieldset,(m,[]))
359          val einApp0 = mkEinApp(ein0,[lhs0])                  (*end case*))
         val code0 = (y,einApp0)  
   
         val _ = testp ["\n Shifted ",P.printbody body1_unshifted,"=>",P.printbody body1]  
         val _ = (toStringBind code0)  
         val _ = testp[ "\n end FieldVec *****************************\n "]  
         in  
             codeAll@[code0]  
     end  
   
   
   
     fun liftFieldSum e =  
     let  
         val _=testp[ "\n************************************* Start Lift Field Sum\n"]  
         val (y, DstIL.EINAPP(ein,args))=e  
         val E.Sum([(vsum,0,n)],E.Probe(E.Conv(V,[c1,v0],h,dx),pos))=Ein.body ein  
         val index0=Ein.index ein  
         val index1 = index0@[3]@[3]  
         val shiftdx=List.tabulate(length(dx),fn n=>E.V (n+2))  
         val body1 = E.Probe(E.Conv(V,[E.V 0,E.V 1],h,shiftdx),pos)  
   
   
         val lhs1=DstV.new ("L", DstTy.TensorTy(index1))  
         val ein1 = mkEin(Ein.params ein,index1,body1)  
         val code1= (lhs1,mkEinApp(ein1,args))  
         val codeAll= (case dx  
             of []   => replaceProbe(1,code1,body1,[])  
             | _     =>liftProbe(1,code1,body1,[])  
360              (*end case*))              (*end case*))
361    
362          (*Probe that tensor at a constant position  c1*)          (*Probe that tensor at a constant position  c1*)
363          val param0 = [E.TEN(1,index1)]          val param0 = [E.TEN(1,index1)]
364          val nx=List.tabulate(length(dx),fn n=>E.V n)          val nx=List.tabulate(newvx,fn n=>E.V n)
365          val body0 =  E.Sum([(vsum,0,n)],E.Tensor(0,[vsum,vsum]@nx))          val body0 =  (case b
366                of E.Sum([(vsum,0,n)],_)=>  E.Sum([(vsum,0,n)],E.Tensor(0,[vsum,vsum]@nx))
367                | _ => E.Tensor(0,[c1]@nx)
368                (*end case*))
369          val ein0 = mkEin(param0,index0,body0)          val ein0 = mkEin(param0,index0,body0)
370          val einApp0 = mkEinApp(ein0,[lhs1])          val einApp0 = mkEinApp(ein0,[lhs0])
371          val code0 = (y,einApp0)          val code0 = (y,einApp0)
         val _ = toStringBind  e  
372          val _ = toStringBind code0          val _ = toStringBind code0
         val _  = (String.concat  ["\norig",P.printbody(Ein.body ein),"\n replace i  ",P.printbody body1,"\nfreshtensor",P.printbody body0])  
         val _  =((List.map toStringBind (codeAll@[code0])))  
         val _ = testp["\n*** end Field Sum*************************************\n"]  
373          in          in
374          codeAll@[code0]          codeAll@[code0]
375      end      end
376        (* **************************************************** General Fn **************************************************** *)
   
377      (* expandEinOp: code->  code list      (* expandEinOp: code->  code list
378      * A this point we only have simple ein ops      * A this point we only have simple ein ops
379      * Looks to see if the expression has a probe. If so, replaces it.      * Looks to see if the expression has a probe. If so, replaces it.
380      * Note how we keeps eps expressions so only generate pieces that are used      * Note how we keeps eps expressions so only generate pieces that are used
381      *)      *)
382     fun expandEinOp( e as (y, DstIL.EINAPP(ein,args)),fieldset)=let     fun expandEinOp( e as (y, DstIL.EINAPP(ein,args)),fieldset)=let
383            fun rewriteBody(p as E.Probe(E.Conv(_,alpha,_,dx),_))= (case (detflag,alpha,dx)
384      fun checkConst(es,a)=(case constflag              of (true,[E.C(_,true), E.V 0],[])            => getF(e,fieldset,[3],1)
385          of true => liftProbe a              | (true,[E.C(_,true), E.V 0],[E.V 1])        => getF(e,fieldset,[3],2)
386          | _ => let              | (true,[E.C(_,true), E.V 0],[E.V 1,E.V 2])  => getF(e,fieldset,[3],3)
387              fun fConst ([],a) =              | (true,[E.C(_,true)],[])                    => getF(e,fieldset,[3],0)
388                  (case fieldliftflag              | (true,[E.C(_,true)],[E.V 0])               => getF(e,fieldset,[3],1)
389                      of true => liftProbe a              | (true,[E.C(_,true)],[E.V 0,E.V 1])         => getF(e,fieldset,[3],2)
390                      | _ => replaceProbe a              | (true,[E.C(_,true)],[E.V 0,E.V 1,E.V 2])   => getF(e,fieldset,[3],3)
391                  (*end case*))              | _                                          => reconstruction(dx,(e,p,[]))
392                  | fConst ((E.C _::_),a) = replaceProbe a              (*end case*))
393                      (*raise Fail (String.concat["\nFound it:",P.printerE(ein)])*)          | rewriteBody(E.Sum(sx,p as E.Probe(E.Conv(_,alpha,_,dx),_)))= (case (detsumflag,sx,alpha,dx)
394              | fConst ((_ ::es),a)= checkConst(es,a)              of (true,[(E.V 0,0,_)],[E.V 0 ,E.V 0],[])              => getF(e,fieldset,[3,3],0)
395              in fConst(es,a) end              | (true,[(E.V 1,0,_)],[E.V 1 ,E.V 1],[E.V 0])          => getF(e,fieldset,[3,3],1)
396        (* end case*))              | (true,[(E.V 2,0,_)],[E.V 2 ,E.V 2],[E.V 0,E.V 1])    => getF(e,fieldset,[3,3],2)
397                | (_,_,_,[])                                => replaceProbe(e,p, sx)  (*no dx*)
398                | (_,_,[],_)                                => reconstruction(dx,(e,p,sx))
399                | _                                         => replaceProbe(e,p, sx)
     fun rewriteBodyB b=(case b  
         of  (E.Probe(E.Conv(_,_,_,[]),_))  
             => replaceProbe(0,e,b,[])  
         | (E.Probe(E.Conv (_,alpha,_,dx),_))  
             => checkConst(dx,(0,e,b,[])) (*scans dx for contant*)  
         | (E.Sum(sx,p as E.Probe(E.Conv(_,_,_,[]),_)))  
             => replaceProbe(0,e,p, sx)  (*no dx*)  
         | (E.Sum(sx,p as E.Probe(E.Conv(_,[],_,dx),_)))  
             => checkConst(dx,(0,e,p,sx)) (*scalar field*)  
         | (E.Sum(sx,E.Probe p))  
             => replaceProbe(0,e,E.Probe p, sx)  
         | (E.Sum(sx,E.Opn(E.Prod,[eps,E.Probe p])))  
             => replaceProbe(0,e,E.Probe p,sx)  
         | _ => [e]  
         (* end case *))  
   
   
         fun rewriteBody b=(case detflag  
             of true => (case (detsumflag,b)  
                 of (_,E.Probe(E.Conv(_,[E.C _ ,E.V 0],_,[]),pos))  
                     => liftFieldMat (1,e)  
                 | (_,E.Probe(E.Conv(_,[E.C _ ,E.V 0],_,[E.V 1]),pos))  
                     => liftFieldMat (2,e)  
                 | (_,E.Probe(E.Conv(_,[E.C _ ,E.V 0],_,[E.V 1,E.V 2] ),pos))  
                     => liftFieldMat (3,e)  
                 | (true, E.Sum([(E.V 0,0,_)],E.Probe(E.Conv(_,[E.V 0 ,E.V 0],_,[]),pos)))  
                     => liftFieldSum e  
                 | (true, E.Sum([(E.V 1,0,_)],E.Probe(E.Conv(_,[E.V 1 ,E.V 1],_,[E.V 0]),pos)))  
                     => liftFieldSum e  
                 | (true, E.Sum([(E.V 2,0,_)],E.Probe(E.Conv(_,[E.V 2 ,E.V 2],_,[E.V 0,E.V 1]),pos)))  
                     => liftFieldSum e  
                 | (true,E.Probe(E.Conv(_,[E.C _ ],_,[]),pos))  
                     => liftFieldVec (0,e,fieldset)  
                 | (true,E.Probe(E.Conv(_,[E.C _],_,[E.V 0]),pos))  
                     => liftFieldVec (1,e,fieldset)  
                 | (true,E.Probe(E.Conv(_,[E.C _],_,[E.V 0,E.V 1] ),pos))  
                     => liftFieldVec (2,e,fieldset)  
                 | (true,E.Probe(E.Conv(_,[E.C _],_,[E.V 0,E.V 1,E.V 2] ),pos))  
                         => liftFieldVec (3,e,fieldset)  
                 | _   => rewriteBodyB b  
                 (* end case *))  
             | _   => rewriteBodyB b  
             (* end case *))  
   
         val (fieldset,var) = (case valnumflag  
             of true => einSet.rtnVar(fieldset,y,DstIL.EINAPP(ein,args))  
             | _     => (fieldset,NONE)  
400          (*end case*))          (*end case*))
401            | rewriteBody(E.Sum(sx,E.Opn(E.Prod,[eps,E.Probe p])))     = replaceProbe(e,E.Probe p,sx)
402            | rewriteBody _  = [e]
403    
404          fun matchField b=(case b          val b=Ein.body ein
405            fun matchField()=(case b
406              of E.Probe _ => 1              of E.Probe _ => 1
407              | E.Sum (_, E.Probe _)=>1              | E.Sum (_, E.Probe _)=>1
408              | E.Sum(_, E.Opn(E.Prod,[ _ ,E.Probe _]))=>1              | E.Sum(_, E.Opn(E.Prod,[ _ ,E.Probe _]))=>1
409              | _ =>0              | _ =>0
410              (*end case*))              (*end case*))
411          fun toStrField b=(case b          val (fieldset,code,flag) = (case valnumflag
412              of E.Probe _ => print("\n"^(P.printbody b))              of true => (case (einSet.rtnVar(fieldset,y,DstIL.EINAPP(ein,args)))
413              | E.Sum (_, E.Probe _)=>print("\n"^ (P.printbody b))                  of (fldset,NONE)     => (fldset,rewriteBody b,0)
414              | E.Sum(_, E.Opn(E.Prod,[ _ ,E.Probe _]))=>print("\n"^ (P.printbody b))                  | (fldset,SOME v)    => (fldset,[(y,DstIL.VAR v)],1)
             | _ => print""  
415              (*end case*))              (*end case*))
416          val b=Ein.body ein              | _     => (fieldset,rewriteBody b,0)
   
         in  (case var  
         of NONE=> (toStrField(b);(rewriteBody(Ein.body ein),fieldset,matchField(Ein.body ein),0))  
             | SOME v=> (("\n mapp_replacing"^(P.printerE ein)^":");( [(y,DstIL.VAR v)],fieldset, matchField(Ein.body ein),1))  
417              (*end case*))              (*end case*))
418          end          val m=matchField()
419            in  (code,fieldset,m,flag) end
420    
421    end; (* local *)    end; (* local *)
422    

Legend:
Removed from v.3460  
changed lines
  Added in v.3472

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0