Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] Diff of /branches/charisee_dev/src/compiler/high-to-mid/ProbeEin.sml
ViewVC logotype

Diff of /branches/charisee_dev/src/compiler/high-to-mid/ProbeEin.sml

Parent Directory Parent Directory | Revision Log Revision Log | View Patch Patch

revision 3443, Thu Nov 19 23:24:18 2015 UTC revision 3503, Thu Dec 17 23:13:57 2015 UTC
# Line 40  Line 40 
40      *)      *)
41    
42      val testing=0      val testing=0
   
   
   
43      val valnumflag=true      val valnumflag=true
44      val tsplitvar=false      val tsplitvar=true
45      val fieldliftflag=false      val fieldliftflag=true
46      val constflag=false      val constflag=true
47      val detflag =false      val detflag =true
48      val detsumflag=false      val detsumflag=true
49        fun incUse (DstIL.V{useCnt, ...}) = (useCnt := !useCnt + 1)
50        fun decUse (DstIL.V{useCnt, ...}) = (useCnt := !useCnt - 1)
51    
52      val cnt = ref 0      val cnt = ref 0
53      fun transformToIndexSpace e=T.transformToIndexSpace e      fun transformToIndexSpace e=T.transformToIndexSpace e
# Line 64  Line 62 
62      fun setSub e= E.setSub e      fun setSub e= E.setSub e
63      fun setProd e= E.setProd e      fun setProd e= E.setProd e
64      fun setAdd e= E.setAdd e      fun setAdd e= E.setAdd e
65        fun mkCx es =List.map (fn c => E.C (c,true)) es
66        fun mkCxSingle c = E.C (c,true)
67    
68      fun testp n=(case testing      fun testp n=(case testing
69          of 0=> 1          of 0=> 1
70          | _ =>((String.concat n);1)          | _ =>(print(String.concat n);1)
71          (*end case*))          (*end case*))
72    
73    
# Line 116  Line 116 
116          (*1-d fields*)          (*1-d fields*)
117          fun createKRND1 ()=let          fun createKRND1 ()=let
118              val sum=sx              val sum=sx
119              val dels=List.map (fn e=>(E.C 0,e)) deltas              val dels=List.map (fn e=>(mkCxSingle 0,e)) deltas
120              val pos=[setAdd[E.Tensor(fid,[]),E.Value(sum)]]              val pos=[setAdd[E.Tensor(fid,[]),E.Value(sum)]]
121              val rest= E.Krn(hid,dels,setSub(E.Tensor(nid,[]),E.Value(sum)))              val rest= E.Krn(hid,dels,setSub(E.Tensor(nid,[]),E.Value(sum)))
122              in              in
# Line 127  Line 127 
127          | createKRN(dim,imgpos,rest)=let          | createKRN(dim,imgpos,rest)=let
128              val dim'=dim-1              val dim'=dim-1
129              val sum=sx+dim'              val sum=sx+dim'
130              val dels=List.map (fn e=>(E.C dim',e)) deltas              val dels=List.map (fn e=>(mkCxSingle  dim',e)) deltas
131              val pos=[setAdd[E.Tensor(fid,[E.C dim']),E.Value(sum)]]              val pos=[setAdd[E.Tensor(fid,[mkCxSingle dim']),E.Value(sum)]]
132              val rest'= E.Krn(hid,dels,setSub(E.Tensor(nid,[E.C dim']),E.Value(sum)))              val rest'= E.Krn(hid,dels,setSub(E.Tensor(nid,[mkCxSingle dim']),E.Value(sum)))
133              in              in
134                  createKRN(dim',pos@imgpos,[rest']@rest)                  createKRN(dim',pos@imgpos,[rest']@rest)
135              end              end
# Line 186  Line 186 
186        | multiMergePs e=multiPs e        | multiMergePs e=multiPs e
187    
188    
189      (* replaceProbe:ein_exp* params *midIL.var list * int list* sum_id list      (* *******************************************  Replace probe *******************************************  *)
190              -> ein_exp* *code      (* replaceProbe
191      * Transforms position to world space      * Transforms position to world space
192      * transforms result back to index_space      * transforms result back to index_space
193      * rewrites body      * rewrites body
194      * replace probe with expanded version      * replace probe with expanded version
195      *)      *)
196  (*    fun replaceProbe(testN,y,originalb,b,params,args,index, sx)*)       fun replaceProbe((y, DstIL.EINAPP(e,args)),p ,sx)
   
      fun replaceProbe(testN,(y, DstIL.EINAPP(e,args)),p ,sx)  
197          =let          =let
198          val originalb=Ein.body e          val originalb=Ein.body e
199          val params=Ein.params e          val params=Ein.params e
# Line 221  Line 219 
219              | _                                  => body'              | _                                  => body'
220              (*end case*))              (*end case*))
221    
   
222          val args'=argsA@[PArg]          val args'=argsA@[PArg]
223          val einapp=(y,mkEinApp(mkEin(params',index,body'),args'))          val einapp=(y,mkEinApp(mkEin(params',index,body'),args'))
224          in          in
225              code@[einapp]              code@[einapp]
226          end          end
227    
228        (* ******************************************* Lift probe *******************************************  *)
229      fun createEinApp(originalb,alpha,index,freshIndex,dim,dx,sx)= let      fun createEinApp(originalb,alpha,index,freshIndex,dim,dx,sx)= let
230          val Pid=0          val Pid=0
231          val tid=1          val tid=1
# Line 270  Line 267 
267              (splitvar,ein0,sizes,dx,alpha')              (splitvar,ein0,sizes,dx,alpha')
268          end          end
269    
270      fun liftProbe(printStrings,(y, DstIL.EINAPP(e,args)),p ,sx)=let      fun liftProbe((y, DstIL.EINAPP(e,args)),p ,sx)=let
271          val _=testp["\n******* Lift Geneirc Probe ***\n"]          val _=testp["\n******* Lift Geneirc Probe ***\n"]
272          val originalb=Ein.body e          val originalb=Ein.body e
273          val params=Ein.params e          val params=Ein.params e
# Line 312  Line 309 
309              rtn              rtn
310          end          end
311    
312      fun searchFullField (fieldset,code1,body1,dx)=let      (* ******************************************* Reconstruction -> Lift|Replace probe *******************************************  *)
313          val (lhs,_)=code1      (* scans dx for contant
314          fun continueReconstruction ()=let       * arg:(1,code1, body1,[])
315              val _=print"Tash:don't replaced"       *)
316              in (case dx      fun reconstruction([],arg)= replaceProbe arg
317                  of []=> (lhs,replaceProbe(1,code1,body1,[]))       | reconstruction(dx,arg)=(case (constflag,fieldliftflag)
318                  | _ =>(lhs,liftProbe(1,code1,body1,[]))          of (true,true) => liftProbe arg
319                  (*end case*))          | (_,false)    => replaceProbe arg
320               end          | _ => let
321          in  (case valnumflag              fun fConst [] = liftProbe arg
322              of false    => (fieldset,continueReconstruction())              | fConst (E.C _::_) = replaceProbe arg
323              | true      => (case  (einSet.rtnVarN(fieldset,code1))              | fConst (_ ::es)= fConst es
324                  of (fieldset,NONE)     => (fieldset,continueReconstruction())              in fConst dx end
                  | (fieldset,SOME m)   =>(print"TASH:replaced"; (fieldset,(m,[])))  
                 (*end case*))  
325              (*end case*))              (*end case*))
         end  
326    
327      fun liftFieldMat(newvx,e)=      (* **************************************************** Index Tensor **************************************************** *)
328          let      (*Push constant indices to tensor replacement*)
329              val _=testp[ "\n ***************************** start FieldMat\n"]      fun getF (e,fieldset,dim,newvx)= let
330              val (y, DstIL.EINAPP(ein,args))=e              val (y, DstIL.EINAPP(ein,args))=e
             val E.Probe(E.Conv(V,[c1,v0],h,dx),pos)=Ein.body ein  
331              val index0=Ein.index ein              val index0=Ein.index ein
332              val index1 = index0@[3]          val index1 = index0@dim
333            val b=Ein.body ein
334    
335            val (c1,dx,body1)=(case b
336                of  E.Sum([(vsum,0,n)],E.Probe(E.Conv(V,[c1,v0],h,dx),pos))=>let
337                    val shiftdx=List.tabulate(length(dx),fn n=>E.V (n+2))
338                    val b=E.Probe(E.Conv(V,[E.V 0,E.V 1],h,shiftdx),pos)
339                    in (c1,dx,b) end
340                | E.Probe(E.Conv(V,[c1,v0],h,dx),pos)=> let
341              val body1_unshifted = E.Probe(E.Conv(V,[E.V newvx, v0],h,dx),pos)              val body1_unshifted = E.Probe(E.Conv(V,[E.V newvx, v0],h,dx),pos)
342              (* clean to get body indices in order *)              (* clean to get body indices in order *)
343              val ( _ , _, body1)= cleanIndex.cleanIndex(body1_unshifted,index1,[])              val ( _ , _, body1)= cleanIndex.cleanIndex(body1_unshifted,index1,[])
344              val _ = testp ["\n Shifted ",P.printbody body1_unshifted,"=>",P.printbody body1]                  in (c1,dx,body1) end
345                |  E.Probe(E.Conv(V,[c1],h,dx),pos)=> let
             val lhs1=DstV.new ("L", DstTy.TensorTy(index1))  
             val ein1 = mkEin(Ein.params ein,index1,body1)  
             val code1= (lhs1,mkEinApp(ein1,args))  
             val codeAll= (case dx  
             of []=> replaceProbe(1,code1,body1,[])  
             | _ =>liftProbe(1,code1,body1,[])  
             (*end case*))  
   
             (*Probe that tensor at a constant position  c1*)  
             val param0 = [E.TEN(1,index1)]  
             val nx=List.tabulate(length(dx)+1,fn n=>E.V n)  
             val body0 =  E.Tensor(0,[c1]@nx)  
             val ein0 = mkEin(param0,index0,body0)  
             val einApp0 = mkEinApp(ein0,[lhs1])  
             val code0 = (y,einApp0)  
             val _= toStringBind code0  
             val _=testp["\n end FieldMat *****************************\n "]  
         in  
             codeAll@[code0]  
     end  
   
     fun liftFieldVec(newvx,e,fieldset)=  
     let  
         val _=testp[ "\n ***************************** start FieldVec\n"]  
         val (y, DstIL.EINAPP(ein,args))=e  
         val E.Probe(E.Conv(V,[c1],h,dx),pos)=Ein.body ein  
         val index0=Ein.index ein  
         val index1 = index0@[3]  
346          val body1_unshifted = E.Probe(E.Conv(V,[E.V newvx],h,dx),pos)          val body1_unshifted = E.Probe(E.Conv(V,[E.V newvx],h,dx),pos)
         (* clean to get body indices in order *)  
347          val ( _ , _, body1)= cleanIndex.cleanIndex(body1_unshifted,index1,[])          val ( _ , _, body1)= cleanIndex.cleanIndex(body1_unshifted,index1,[])
348                   in (c1,dx,body1) end
349                (*end case*))
350    
351          val lhs1=DstV.new ("L", DstTy.TensorTy(index1))          val lhs1=DstV.new ("L", DstTy.TensorTy(index1))
352          val ein1 = mkEin(Ein.params ein,index1,body1)          val ein1 = mkEin(Ein.params ein,index1,body1)
353          val code1= (lhs1,mkEinApp(ein1,args))          val code1= (lhs1,mkEinApp(ein1,args))
         val (fieldset,(lhs0,codeAll))=searchFullField (fieldset,code1,body1,dx)  
   
         (*Probe that tensor at a constant position  c1*)  
         val param0 = [E.TEN(1,index1)]  
         val nx=List.tabulate(length(dx),fn n=>E.V n)  
         val body0 =  E.Tensor(0,[c1]@nx)  
         val ein0 = mkEin(param0,index0,body0)  
         val einApp0 = mkEinApp(ein0,[lhs0])  
         val code0 = (y,einApp0)  
   
         val _ = testp ["\n Shifted ",P.printbody body1_unshifted,"=>",P.printbody body1]  
         val _ = (toStringBind code0)  
         val _ = testp[ "\n end FieldVec *****************************\n "]  
         in  
             codeAll@[code0]  
     end  
   
354    
355            val (_,(lhs0,codeAll))= (case valnumflag
356      fun liftFieldSum e =              of false    => (fieldset,(lhs1, reconstruction(dx,(code1,body1,[]))))
357      let              | true      => (case  (einVarSet.rtnVarN(fieldset,code1))
358          val _=testp[ "\n************************************* Start Lift Field Sum\n"]                  of (fieldset,NONE)     => (fieldset,(lhs1, reconstruction(dx,(code1,body1,[]))))
359          val (y, DstIL.EINAPP(ein,args))=e                  | (fieldset,SOME m)   =>  (fieldset,(m,[]))
360          val E.Sum([(vsum,0,n)],E.Probe(E.Conv(V,[c1,v0],h,dx),pos))=Ein.body ein                  (*end case*))
         val index0=Ein.index ein  
         val index1 = index0@[3]@[3]  
         val shiftdx=List.tabulate(length(dx),fn n=>E.V (n+2))  
         val body1 = E.Probe(E.Conv(V,[E.V 0,E.V 1],h,shiftdx),pos)  
   
   
         val lhs1=DstV.new ("L", DstTy.TensorTy(index1))  
         val ein1 = mkEin(Ein.params ein,index1,body1)  
         val code1= (lhs1,mkEinApp(ein1,args))  
         val codeAll= (case dx  
             of []   => replaceProbe(1,code1,body1,[])  
             | _     =>liftProbe(1,code1,body1,[])  
361              (*end case*))              (*end case*))
362    
363          (*Probe that tensor at a constant position  c1*)          (*Probe that tensor at a constant position  c1*)
364          val param0 = [E.TEN(1,index1)]          val param0 = [E.TEN(1,index1)]
365          val nx=List.tabulate(length(dx),fn n=>E.V n)          val nx=List.tabulate(newvx,fn n=>E.V n)
366          val body0 =  E.Sum([(vsum,0,n)],E.Tensor(0,[vsum,vsum]@nx))          val body0 =  (case b
367                of E.Sum([(vsum,0,n)],_)=>  E.Sum([(vsum,0,n)],E.Tensor(0,[vsum,vsum]@nx))
368                | _ => E.Tensor(0,[c1]@nx)
369                (*end case*))
370          val ein0 = mkEin(param0,index0,body0)          val ein0 = mkEin(param0,index0,body0)
371          val einApp0 = mkEinApp(ein0,[lhs1])          val einApp0 = mkEinApp(ein0,[lhs0])
372          val code0 = (y,einApp0)          val code0 = (y,einApp0)
         val _ = toStringBind  e  
373          val _ = toStringBind code0          val _ = toStringBind code0
         val _  = (String.concat  ["\norig",P.printbody(Ein.body ein),"\n replace i  ",P.printbody body1,"\nfreshtensor",P.printbody body0])  
         val _  =((List.map toStringBind (codeAll@[code0])))  
         val _ = testp["\n*** end Field Sum*************************************\n"]  
374          in          in
375          codeAll@[code0]          codeAll@[code0]
376      end      end
377        (* **************************************************** General Fn **************************************************** *)
   
378      (* expandEinOp: code->  code list      (* expandEinOp: code->  code list
379      * A this point we only have simple ein ops      * A this point we only have simple ein ops
380      * Looks to see if the expression has a probe. If so, replaces it.      * Looks to see if the expression has a probe. If so, replaces it.
381      * Note how we keeps eps expressions so only generate pieces that are used      * Note how we keeps eps expressions so only generate pieces that are used
382      *)      *)
383     fun expandEinOp( e as (y, DstIL.EINAPP(ein,args)),fieldset)=let      fun expandEinOp( e0 as (y, DstIL.EINAPP(ein,args)),fieldset,varset)=let
384            fun rewriteBody(e,p as E.Probe(E.Conv(_,alpha,_,dx),_))= (case (detflag,alpha,dx)
385      fun checkConst(es,a)=(case constflag              of (true,[E.C(_,true), E.V 0],[])            => getF(e,fieldset,[3],1)
386          of true => liftProbe a              | (true,[E.C(_,true), E.V 0],[E.V 1])        => getF(e,fieldset,[3],2)
387          | _ => let              | (true,[E.C(_,true), E.V 0],[E.V 1,E.V 2])  => getF(e,fieldset,[3],3)
388              fun fConst ([],a) =              | (true,[E.C(_,true)],[])                    => getF(e,fieldset,[3],0)
389                  (case fieldliftflag              | (true,[E.C(_,true)],[E.V 0])               => getF(e,fieldset,[3],1)
390                      of true => liftProbe a              | (true,[E.C(_,true)],[E.V 0,E.V 1])         => getF(e,fieldset,[3],2)
391                      | _ => replaceProbe a              | (true,[E.C(_,true)],[E.V 0,E.V 1,E.V 2])   => getF(e,fieldset,[3],3)
392                  (*end case*))              | _                                          => reconstruction(dx,(e,p,[]))
393              | fConst ((E.C _::_),a) = replaceProbe a              (*end case*))
394              | fConst ((_ ::es),a)= checkConst(es,a)          | rewriteBody(e,E.Sum(sx,p as E.Probe(E.Conv(_,alpha,_,dx),_)))= (case (detsumflag,sx,alpha,dx)
395              in fConst(es,a) end              of (true,[(E.V 0,0,_)],[E.V 0 ,E.V 0],[])              => getF(e,fieldset,[3,3],0)
396        (* end case*))              | (true,[(E.V 1,0,_)],[E.V 1 ,E.V 1],[E.V 0])          => getF(e,fieldset,[3,3],1)
397                | (true,[(E.V 2,0,_)],[E.V 2 ,E.V 2],[E.V 0,E.V 1])    => getF(e,fieldset,[3,3],2)
398                | (_,_,_,[])                                => replaceProbe(e,p, sx)  (*no dx*)
399                | (_,_,[],_)                                => reconstruction(dx,(e,p,sx))
400      fun rewriteBodyB b=(case b              | _                                         => replaceProbe(e,p, sx)
         of  (E.Probe(E.Conv(_,_,_,[]),_))  
             => replaceProbe(0,e,b,[])  
         | (E.Probe(E.Conv (_,alpha,_,dx),_))  
             => checkConst(dx,(0,e,b,[])) (*scans dx for contant*)  
         | (E.Sum(sx,p as E.Probe(E.Conv(_,_,_,[]),_)))  
             => replaceProbe(0,e,p, sx)  (*no dx*)  
         | (E.Sum(sx,p as E.Probe(E.Conv(_,[],_,dx),_)))  
             => checkConst(dx,(0,e,p,sx)) (*scalar field*)  
         | (E.Sum(sx,E.Probe p))  
             => replaceProbe(0,e,E.Probe p, sx)  
         | (E.Sum(sx,E.Opn(E.Prod,[eps,E.Probe p])))  
             => replaceProbe(0,e,E.Probe p,sx)  
         | _ => [e]  
         (* end case *))  
   
   
         fun rewriteBody b=(case detflag  
             of true => (case (detsumflag,b)  
                 of (_,E.Probe(E.Conv(_,[E.C _ ,E.V 0],_,[]),pos))  
                     => liftFieldMat (1,e)  
                 | (_,E.Probe(E.Conv(_,[E.C _ ,E.V 0],_,[E.V 1]),pos))  
                     => liftFieldMat (2,e)  
                 | (_,E.Probe(E.Conv(_,[E.C _ ,E.V 0],_,[E.V 1,E.V 2] ),pos))  
                     => liftFieldMat (3,e)  
                 | (true, E.Sum([(E.V 0,0,_)],E.Probe(E.Conv(_,[E.V 0 ,E.V 0],_,[]),pos)))  
                     => liftFieldSum e  
                 | (true, E.Sum([(E.V 1,0,_)],E.Probe(E.Conv(_,[E.V 1 ,E.V 1],_,[E.V 0]),pos)))  
                     => liftFieldSum e  
                 | (true, E.Sum([(E.V 2,0,_)],E.Probe(E.Conv(_,[E.V 2 ,E.V 2],_,[E.V 0,E.V 1]),pos)))  
                     => liftFieldSum e  
                 | (true,E.Probe(E.Conv(_,[E.C _ ],_,[]),pos))  
                     => liftFieldVec (0,e,fieldset)  
                 | (true,E.Probe(E.Conv(_,[E.C _],_,[E.V 0]),pos))  
                     => liftFieldVec (1,e,fieldset)  
                 | (true,E.Probe(E.Conv(_,[E.C _],_,[E.V 0,E.V 1] ),pos))  
                     => liftFieldVec (2,e,fieldset)  
                 | (true,E.Probe(E.Conv(_,[E.C _],_,[E.V 0,E.V 1,E.V 2] ),pos))  
                         => liftFieldVec (3,e,fieldset)  
                 | _   => rewriteBodyB b  
                 (* end case *))  
             | _   => rewriteBodyB b  
             (* end case *))  
   
         val (fieldset,var) = (case valnumflag  
             of true => einSet.rtnVar(fieldset,y,DstIL.EINAPP(ein,args))  
             | _     => (fieldset,NONE)  
401          (*end case*))          (*end case*))
402            | rewriteBody(e,E.Sum(sx,E.Opn(E.Prod,[eps,E.Probe p])))     = replaceProbe(e,E.Probe p,sx)
403            | rewriteBody (e,_)  = [e]
404    
405          fun matchField b=(case b          val b=Ein.body ein
406            fun matchField()=(case b
407              of E.Probe _ => 1              of E.Probe _ => 1
408              | E.Sum (_, E.Probe _)=>1              | E.Sum (_, E.Probe _)=>1
409              | E.Sum(_, E.Opn(E.Prod,[ _ ,E.Probe _]))=>1              | E.Sum(_, E.Opn(E.Prod,[ _ ,E.Probe _]))=>1
410              | _ =>0              | _ =>0
411              (*end case*))              (*end case*))
412          fun toStrField b=(case b          val (fieldset,varset,code,flag) = (case valnumflag
413              of E.Probe _ => print("\n"^(P.printbody b))              of true => (case (einVarSet.rtnVarN(fieldset,e0))
414              | E.Sum (_, E.Probe _)=>print("\n"^ (P.printbody b))                 of  (fldset,NONE)     => (fldset,varset,rewriteBody(e0,b),0)
415              | E.Sum(_, E.Opn(E.Prod,[ _ ,E.Probe _]))=>print("\n"^ (P.printbody b))                | (fldset,SOME v)    => (fldset,varset,[(y,DstIL.VAR v)],1)
416              | _ => print""                   (*of (fldset, NONE)      => (fldset,varset,rewriteBody((y,DstIL.EINAPP(ein,List.map (fn a=>einVarSet.replaceArg(varset,a)) args)), b),0)
417                     | (fldset,SOME v)    => (fldset,einVarSet.VarSet.add(varset,einVarSet.VAR(v,y)),[],1)*)
418              (*end case*))              (*end case*))
419              val b=Ein.body ein              | _     => (fieldset,varset,rewriteBody(e0, b),0)
   
         in  (case var  
         of NONE=> ((rewriteBody(Ein.body ein),fieldset,matchField(Ein.body ein),0))  
             | SOME v=> (("\n mapp_replacing"^(P.printerE ein)^":");( [(y,DstIL.VAR v)],fieldset, matchField(Ein.body ein),1))  
420              (*end case*))              (*end case*))
421          end          val m=matchField()
422            in  (code,fieldset,varset,m,flag) end
423    
424    end; (* local *)    end; (* local *)
425    

Legend:
Removed from v.3443  
changed lines
  Added in v.3503

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0