Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] Diff of /branches/charisee_dev/src/compiler/high-to-mid/ProbeEin.sml
ViewVC logotype

Diff of /branches/charisee_dev/src/compiler/high-to-mid/ProbeEin.sml

Parent Directory Parent Directory | Revision Log Revision Log | View Patch Patch

revision 3362, Sun Nov 1 18:26:02 2015 UTC revision 3383, Mon Nov 9 02:39:26 2015 UTC
# Line 55  Line 55 
55    
56      fun testp n=(case testing      fun testp n=(case testing
57          of 0=> 1          of 0=> 1
58          | _ =>(print(String.concat n);1)          | _ =>((String.concat n);1)
59          (*end case*))          (*end case*))
60    
61    
# Line 127  Line 127 
127              (*end case*))              (*end case*))
128          (*sumIndex creating summaiton Index for body*)          (*sumIndex creating summaiton Index for body*)
129          val slb=1-s          val slb=1-s
130            val _=List.tabulate(dim, (fn dim=> (String.concat[" sx:",Int.toString(sx)," dim:",Int.toString(dim),"esum",Int.toString(sx+dim) ]) ))
131          val esum=List.tabulate(dim, (fn dim=>(E.V (dim+sx),slb,s)))          val esum=List.tabulate(dim, (fn dim=>(E.V (dim+sx),slb,s)))
132      in      in
133          E.Sum(esum, exp)          E.Sum(esum, exp)
# Line 135  Line 136 
136      (*getsumshift:sum_indexid list* int list-> int      (*getsumshift:sum_indexid list* int list-> int
137      *get fresh/unused index_id, returns int      *get fresh/unused index_id, returns int
138      *)      *)
139      fun getsumshift(sx,index) =let      fun getsumshift(sx,n) =let
140          val nsumshift= (case sx          val nsumshift= (case sx
141              of []=> length(index)              of []=> n
142              | _=>let              | _=>let
143                  val (E.V v,_,_)=List.hd(List.rev sx)                  val (E.V v,_,_)=List.hd(List.rev sx)
144                  in v+1                  in v+1
145                  end                  end
146              (* end case *))              (* end case *))
147    
148          val aa=List.map (fn (E.V v,_,_)=>Int.toString v) sx          val aa=List.map (fn (E.V v,_,_)=>Int.toString v) sx
149          val _ =testp["\n", "SumIndex" ,(String.concatWith"," aa),          val _ =(String.concat["\n", "SumIndex:" ,(String.concatWith"," aa),
150              "\nThink nshift is ", Int.toString nsumshift]          "\n\t Index length:",Int.toString n,
151            "\n\t Freshindex: ", Int.toString nsumshift])
152          in          in
153              nsumshift              nsumshift
154          end          end
# Line 193  Line 196 
196          val Pid=nid+1          val Pid=nid+1
197          val nshift=length(dx)          val nshift=length(dx)
198          val (dim,argsA,code,s,PArg) = handleArgs(Vid,hid,tid,args)          val (dim,argsA,code,s,PArg) = handleArgs(Vid,hid,tid,args)
199          val freshIndex=getsumshift(sx,index)          val freshIndex=getsumshift(sx,length(index))
200          val (dx,newsx1,Ps)=transformToIndexSpace(freshIndex,dim,dx,Pid)          val (dx,newsx1,Ps)=transformToIndexSpace(freshIndex,dim,dx,Pid)
201          val params'=params@[E.TEN(3,[dim]),E.TEN(1,[dim]),E.TEN(1,[dim,dim])]          val params'=params@[E.TEN(3,[dim]),E.TEN(1,[dim]),E.TEN(1,[dim,dim])]
202          val body' = createBody(dim, s,freshIndex+nshift,alpha,dx,Vid, hid, nid, fid)          val body' = createBody(dim, s,freshIndex+nshift,alpha,dx,Vid, hid, nid, fid)
# Line 254  Line 257 
257          end          end
258    
259      fun liftProbe(printStrings,(y, DstIL.EINAPP(e,args)),p ,sx)=let      fun liftProbe(printStrings,(y, DstIL.EINAPP(e,args)),p ,sx)=let
260          val _=testp["\n******* Lift ******** \n"]          val _=(String.concat["\n******* Lift Geneirc Probe ***\n"])
261          val originalb=Ein.body e          val originalb=Ein.body e
262          val params=Ein.params e          val params=Ein.params e
263          val index=Ein.index e          val index=Ein.index e
264          val _=  toStringBind (y, DstIL.EINAPP(e,args))          val _ =  (toStringBind (y, DstIL.EINAPP(e,args)))
265    
266          val E.Probe(E.Conv(Vid,alpha,hid,dx),E.Tensor(tid,_))=p          val E.Probe(E.Conv(Vid,alpha,hid,dx),E.Tensor(tid,_))=p
267          val fid=length(params)          val fid=length(params)
268          val nid=fid+1          val nid=fid+1
269          val nshift=length(dx)          val nshift=length(dx)
270          val (dim,args',code,s,PArg) = handleArgs(Vid,hid,tid,args)          val (dim,args',code,s,PArg) = handleArgs(Vid,hid,tid,args)
271          val freshIndex=getsumshift(sx,index)          val freshIndex=getsumshift(sx,length(index))
272    
273          (*transform T*P*P..Ps*)          (*transform T*P*P..Ps*)
274          val (splitvar,ein0,sizes,dx,alpha')= createEinApp(originalb,alpha,index,freshIndex,dim,dx,sx)          val (splitvar,ein0,sizes,dx,alpha')= createEinApp(originalb,alpha,index,freshIndex,dim,dx,sx)
# Line 281  Line 284 
284    
285          (*lifted probe*)          (*lifted probe*)
286          val params'=params@[E.TEN(3,[dim]),E.TEN(1,[dim])]          val params'=params@[E.TEN(3,[dim]),E.TEN(1,[dim])]
287          val body' = createBody(dim, s,freshIndex+nshift,alpha',dx,Vid, hid, nid, fid)          val freshIndex'= length(sizes)
288    
289            val body' = createBody(dim, s,freshIndex',alpha',dx,Vid, hid, nid, fid)
290          val ein1=mkEin(params',sizes,body')          val ein1=mkEin(params',sizes,body')
291          val einApp1=mkEinApp(ein1,args')          val einApp1=mkEinApp(ein1,args')
292          val rtn1=(FArg,einApp1)          val rtn1=(FArg,einApp1)
293          val rtn=code@[rtn1]@rtn0          val rtn=code@[rtn1]@rtn0
294          val _= List.map toStringBind ([rtn1]@rtn0)          val _= List.map toStringBind ([rtn1]@rtn0)
295             val _=(String.concat["\n* end  Lift Geneirc Probe  ******** \n"])
296          in          in
297              rtn              rtn
298          end          end
299    
300        fun searchFullField (fieldset,code1,body1,dx)=let
301            val (lhs,_)=code1
302            fun continueReconstruction ()=let
303                val _=print"Tash:don't replaced"
304                in (case dx
305                    of []=> (lhs,replaceProbe(1,code1,body1,[]))
306                    | _ =>(lhs,liftProbe(1,code1,body1,[]))
307                    (*end case*))
308                 end
309            in  (case valnumflag
310                of false => (fieldset,continueReconstruction())
311                | true => (case  (einSet.rtnVarN(fieldset,code1))
312                    of (fieldset,NONE)     => (fieldset,continueReconstruction())
313                     | (fieldset,SOME m)   =>(print"TASH:replaced"; (fieldset,(m,[])))
314                    (*end case*))
315                (*end case*))
316            end
317    
318      fun liftFieldMat(newvx,e)=      fun liftFieldMat(newvx,e)=
319          let          let
320                val _=print "\n ***************************** start FieldMat\n"
321              val (y, DstIL.EINAPP(ein,args))=e              val (y, DstIL.EINAPP(ein,args))=e
322              val E.Probe(E.Conv(V,[c1,v0],h,dx),pos)=Ein.body ein              val E.Probe(E.Conv(V,[c1,v0],h,dx),pos)=Ein.body ein
323              val index0=Ein.index ein              val index0=Ein.index ein
# Line 320  Line 343 
343              val einApp0 = mkEinApp(ein0,[lhs1])              val einApp0 = mkEinApp(ein0,[lhs1])
344              val code0 = (y,einApp0)              val code0 = (y,einApp0)
345              val _= toStringBind code0              val _= toStringBind code0
346                        val _=print "\n end FieldMat *****************************\n "
347            in
348                codeAll@[code0]
349        end
350    
351        fun liftFieldVec(newvx,e,fieldset)=
352        let
353            val _=print "\n ***************************** start FieldVec\n"
354            val (y, DstIL.EINAPP(ein,args))=e
355            val E.Probe(E.Conv(V,[c1],h,dx),pos)=Ein.body ein
356            val index0=Ein.index ein
357            val index1 = index0@[3]
358            val body1_unshifted = E.Probe(E.Conv(V,[E.V newvx],h,dx),pos)
359            (* clean to get body indices in order *)
360            val ( _ , _, body1)= cleanIndex.cleanIndex(body1_unshifted,index1,[])
361    
362    
363            val lhs1=DstV.new ("L", DstTy.TensorTy(index1))
364            val ein1 = mkEin(Ein.params ein,index1,body1)
365            val code1= (lhs1,mkEinApp(ein1,args))
366            val (fieldset,(lhs0,codeAll))=searchFullField (fieldset,code1,body1,dx)
367    
368            (*Probe that tensor at a constant position  c1*)
369            val param0 = [E.TEN(1,index1)]
370            val nx=List.tabulate(length(dx),fn n=>E.V n)
371            val body0 =  E.Tensor(0,[c1]@nx)
372            val ein0 = mkEin(param0,index0,body0)
373            val einApp0 = mkEinApp(ein0,[lhs0])
374            val code0 = (y,einApp0)
375    
376            val _ = (String.concat ["\n Shifted ",P.printbody body1_unshifted,"=>",P.printbody body1])
377            val _=  (toStringBind code0)
378            val _=print "\n end FieldVec *****************************\n "
379          in          in
380              codeAll@[code0]              codeAll@[code0]
381      end      end
382    
383    
384    
385      fun liftFieldSum e =      fun liftFieldSum e =
386      let      let
387          val _=print"\n*************************************\n"          val _=print "\n************************************* Start Lift Field Sum\n"
388          val (y, DstIL.EINAPP(ein,args))=e          val (y, DstIL.EINAPP(ein,args))=e
389          val E.Sum([(vsum,0,n)],E.Probe(E.Conv(V,[c1,v0],h,dx),pos))=Ein.body ein          val E.Sum([(vsum,0,n)],E.Probe(E.Conv(V,[c1,v0],h,dx),pos))=Ein.body ein
390          val index0=Ein.index ein          val index0=Ein.index ein
# Line 353  Line 411 
411          val _= toStringBind  e          val _= toStringBind  e
412          val _ =toStringBind code0          val _ =toStringBind code0
413         val _ = (String.concat  ["\norig",P.printbody(Ein.body ein),"\n replace i  ",P.printbody body1,"\nfreshtensor",P.printbody body0])         val _ = (String.concat  ["\norig",P.printbody(Ein.body ein),"\n replace i  ",P.printbody body1,"\nfreshtensor",P.printbody body0])
414         val _ =(String.concat(List.map toStringBind (codeAll@[code0])))          val _  =((List.map toStringBind (codeAll@[code0])))
415                 val _=print"\n*************************************\n"          val _ = print "\n*** end Field Sum*************************************\n"
416          in          in
417          codeAll@[code0]          codeAll@[code0]
418      end      end
# Line 367  Line 425 
425      *)      *)
426     fun expandEinOp( e as (y, DstIL.EINAPP(ein,args)),fieldset)=let     fun expandEinOp( e as (y, DstIL.EINAPP(ein,args)),fieldset)=let
427    
428          fun checkConst ([],a) =          fun (*checkConst ([],a) =
429              (case fieldliftflag              (case fieldliftflag
430                  of true => liftProbe a                  of true => liftProbe a
431                  | _ => replaceProbe a                  | _ => replaceProbe a
432              (*end case*))              (*end case*))
433          | checkConst ((E.C _::_),a) = replaceProbe a          | checkConst ((E.C _::_),a) = replaceProbe a
434          | checkConst ((_ ::es),a)= checkConst(es,a)          | checkConst ((_ ::es),a)= checkConst(es,a)
435            *)
436            checkConst (_,a) = liftProbe a
437          fun rewriteBody b=(case (detflag,b)          fun rewriteBody b=(case (detflag,b)
438              of (true,E.Probe(E.Conv(_,[E.C _ ,E.V 0],_,[]),pos))              of (true,E.Probe(E.Conv(_,[E.C _ ,E.V 0],_,[]),pos))
439                  => liftFieldMat (1,e)                  => liftFieldMat (1,e)
# Line 388  Line 447 
447                  => liftFieldSum e                  => liftFieldSum e
448              | (true, E.Sum([(E.V 2,0,_)],E.Probe(E.Conv(_,[E.V 2 ,E.V 2],_,[E.V 0,E.V 1]),pos)))              | (true, E.Sum([(E.V 2,0,_)],E.Probe(E.Conv(_,[E.V 2 ,E.V 2],_,[E.V 0,E.V 1]),pos)))
449                  => liftFieldSum e                  => liftFieldSum e
450                | (true,E.Probe(E.Conv(_,[E.C _ ],_,[]),pos))
451                    => liftFieldVec (0,e,fieldset)
452                | (true,E.Probe(E.Conv(_,[E.C _],_,[E.V 0]),pos))
453                    => liftFieldVec (1,e,fieldset)
454                | (true,E.Probe(E.Conv(_,[E.C _],_,[E.V 0,E.V 1] ),pos))
455                    => liftFieldVec (2,e,fieldset)
456                | (true,E.Probe(E.Conv(_,[E.C _],_,[E.V 0,E.V 1,E.V 2] ),pos))
457                    => liftFieldVec (3,e,fieldset)
458              | (_,E.Probe(E.Conv(_,_,_,[]),_))              | (_,E.Probe(E.Conv(_,_,_,[]),_))
459                  => replaceProbe(0,e,b,[])                  => replaceProbe(0,e,b,[])
460              | (_,E.Probe(E.Conv (_,alpha,_,dx),_))              | (_,E.Probe(E.Conv (_,alpha,_,dx),_))
# Line 417  Line 482 
482              | _ =>0              | _ =>0
483              (*end case*))              (*end case*))
484          fun toStrField b=(case b          fun toStrField b=(case b
485              of E.Probe _ => print (P.printbody b)              of E.Probe _ => print("\n"^(P.printbody b))
486              | E.Sum (_, E.Probe _)=>print (P.printbody b)              | E.Sum (_, E.Probe _)=>print("\n"^ (P.printbody b))
487              | E.Sum(_, E.Prod[ _ ,E.Probe _])=>print (P.printbody b)              | E.Sum(_, E.Prod[ _ ,E.Probe _])=>print("\n"^ (P.printbody b))
488              | _ =>print ""              | _ =>print ""
489              (*end case*))              (*end case*))
490              val b=Ein.body ein              val b=Ein.body ein
491  (*  (*
492          val _=  toStrField b          val _=  toStrField b*)
493    *)  
494          in  (case var          in  (case var
495              of NONE=> ((rewriteBody(Ein.body ein),fieldset,matchField(Ein.body ein),0))              of NONE=> ((rewriteBody(Ein.body ein),fieldset,matchField(Ein.body ein),0))
496              | SOME v=> (("\n mapp_replacing"^(P.printerE ein)^":");( [(y,DstIL.VAR v)],fieldset, matchField(Ein.body ein),1))              | SOME v=> (("\n mapp_replacing"^(P.printerE ein)^":");( [(y,DstIL.VAR v)],fieldset, matchField(Ein.body ein),1))

Legend:
Removed from v.3362  
changed lines
  Added in v.3383

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0