Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] Diff of /branches/charisee_dev/src/compiler/high-to-mid/ProbeEin.sml
ViewVC logotype

Diff of /branches/charisee_dev/src/compiler/high-to-mid/ProbeEin.sml

Parent Directory Parent Directory | Revision Log Revision Log | View Patch Patch

revision 3440, Mon Nov 16 19:16:54 2015 UTC revision 3441, Wed Nov 18 00:24:04 2015 UTC
# Line 52  Line 52 
52      fun toStringBind e=(MidToString.toStringBind e)      fun toStringBind e=(MidToString.toStringBind e)
53      fun mkEin e=Ein.mkEin e      fun mkEin e=Ein.mkEin e
54      fun mkEinApp(rator,args)=DstIL.EINAPP(rator,args)      fun mkEinApp(rator,args)=DstIL.EINAPP(rator,args)
55        fun setConst e = E.setConst e
56        fun setNeg e  =  E.setNeg e
57        fun setExp e  =  E.setExp e
58        fun setDiv e= E.setDiv e
59        fun setSub e= E.setSub e
60        fun setProd e= E.setProd e
61        fun setAdd e= E.setAdd e
62    
63      fun testp n=(case testing      fun testp n=(case testing
64          of 0=> 1          of 0=> 1
# Line 76  Line 83 
83              in              in
84                ((Kernel.support h) ,img,ImageInfo.dim img)                ((Kernel.support h) ,img,ImageInfo.dim img)
85              end              end
86          |  _ => raise Fail "Expected Image and kernel arguments"              |  ((k,_),(i,_)) => raise Fail (String.concat["Expected kernel:", (DstOp.toString k ),"Expected Image:", (DstOp.toString i)])
87          (*end case*))          (*end case*))
88    
89    
# Line 105  Line 112 
112          fun createKRND1 ()=let          fun createKRND1 ()=let
113              val sum=sx              val sum=sx
114              val dels=List.map (fn e=>(E.C 0,e)) deltas              val dels=List.map (fn e=>(E.C 0,e)) deltas
115              val pos=[E.Add[E.Tensor(fid,[]),E.Value(sum)]]              val pos=[setAdd[E.Tensor(fid,[]),E.Value(sum)]]
116              val rest= E.Krn(hid,dels,E.Sub(E.Tensor(nid,[]),E.Value(sum)))              val rest= E.Krn(hid,dels,setSub(E.Tensor(nid,[]),E.Value(sum)))
117              in              in
118                  E.Prod [E.Img(Vid,alpha,pos),rest]                 setProd[E.Img(Vid,alpha,pos),rest]
119              end              end
120          (*createKRN Image field and kernels *)          (*createKRN Image field and kernels *)
121          fun createKRN(0,imgpos,rest)=E.Prod ([E.Img(Vid,alpha,imgpos)] @rest)          fun createKRN(0,imgpos,rest)=setProd ([E.Img(Vid,alpha,imgpos)] @rest)
122          | createKRN(dim,imgpos,rest)=let          | createKRN(dim,imgpos,rest)=let
123              val dim'=dim-1              val dim'=dim-1
124              val sum=sx+dim'              val sum=sx+dim'
125              val dels=List.map (fn e=>(E.C dim',e)) deltas              val dels=List.map (fn e=>(E.C dim',e)) deltas
126              val pos=[E.Add[E.Tensor(fid,[E.C dim']),E.Value(sum)]]              val pos=[setAdd[E.Tensor(fid,[E.C dim']),E.Value(sum)]]
127              val rest'= E.Krn(hid,dels,E.Sub(E.Tensor(nid,[E.C dim']),E.Value(sum)))              val rest'= E.Krn(hid,dels,setSub(E.Tensor(nid,[E.C dim']),E.Value(sum)))
128              in              in
129                  createKRN(dim',pos@imgpos,[rest']@rest)                  createKRN(dim',pos@imgpos,[rest']@rest)
130              end              end
# Line 158  Line 165 
165      *)      *)
166      fun formBody(E.Sum([],e))=formBody e      fun formBody(E.Sum([],e))=formBody e
167      | formBody(E.Sum(sx,e))= E.Sum(sx,formBody e)      | formBody(E.Sum(sx,e))= E.Sum(sx,formBody e)
168      | formBody(E.Prod [e])=e      | formBody(E.Opn(E.Prod, [e]))=e
169      | formBody e=e      | formBody e=e
170    
171      (* silly change in order of the product to match vis branch WorldtoSpace functions*)      (* silly change in order of the product to match vis branch WorldtoSpace functions*)
172      fun multiPs([P0,P1,P2],sx,body)= formBody(E.Sum(sx, E.Prod([P0,P1,P2,body])))      fun multiPs([P0,P1,P2],sx,body)= formBody(E.Sum(sx, setProd[P0,P1,P2,body]))
173      (*      (*
174        | multiPs([P0,P1],sx,body)=formBody(E.Sum(sx, E.Prod([P0,body,P1])))        | multiPs([P0,P1],sx,body)=formBody(E.Sum(sx, setProd([P0,body,P1])))
175        *)        *)
176        | multiPs([P0,P1,P2,P3],sx,body)= formBody(E.Sum(sx, E.Prod([P0,P1,P2,P3,body])))        | multiPs([P0,P1,P2,P3],sx,body)= formBody(E.Sum(sx, setProd[P0,P1,P2,P3,body]))
177        | multiPs(Ps,sx,body)=formBody(E.Sum(sx, E.Prod([body]@Ps)))        | multiPs(Ps,sx,body)=formBody(E.Sum(sx,setProd([body]@Ps)))
178    
179    
180      fun multiMergePs([P0,P1],[sx0,sx1],body)=E.Sum([sx0],E.Prod[P0,E.Sum([sx1],E.Prod[P1,body])])      fun multiMergePs([P0,P1],[sx0,sx1],body)=E.Sum([sx0],setProd[P0,E.Sum([sx1],setProd[P1,body])])
181        | multiMergePs e=multiPs e        | multiMergePs e=multiPs e
182    
183    
# Line 205  Line 212 
212    
213          val body'=(case originalb          val body'=(case originalb
214              of E.Sum(sx, E.Probe _)              => E.Sum(sx,body')              of E.Sum(sx, E.Probe _)              => E.Sum(sx,body')
215              | E.Sum(sx,E.Prod[eps0,E.Probe _ ])  => E.Sum(sx,E.Prod[eps0,body'])              | E.Sum(sx,E.Opn(E.Prod,[eps0,E.Probe _ ]))  => E.Sum(sx,setProd[eps0,body'])
216              | _                                  => body'              | _                                  => body'
217              (*end case*))              (*end case*))
218    
# Line 237  Line 244 
244    
245          val tshape=filterAlpha(alpha')@newdx          val tshape=filterAlpha(alpha')@newdx
246          val t=E.Tensor(tid,tshape)          val t=E.Tensor(tid,tshape)
247    
248          val (splitvar,body)=(case originalb          val (splitvar,body)=(case originalb
249              of E.Sum(sx, E.Probe _)              => (*(false,E.Sum(sx,multiPs(Ps,newsx,t)))*) (true,multiPs(Ps,sx@newsx,t))              of E.Sum(sx, E.Probe _)              => (true,multiPs(Ps,sx@newsx,t))
250              | E.Sum(sx,E.Prod[eps0,E.Probe _ ])  => (false,E.Sum(sx,E.Prod[eps0,multiPs(Ps,newsx,t)]))              | E.Sum(sx,E.Opn(E.Prod,[eps0,E.Probe _ ]))  => (false,E.Sum(sx,setProd[eps0,multiPs(Ps,newsx,t)]))
251              | _                                  => (case tsplitvar              | _                                  => (case tsplitvar
252                  of(* true =>   (true,multiMergePs(Ps,newsx,t))  (*pushes summations in place*)                  of(* true =>   (true,multiMergePs(Ps,newsx,t))  (*pushes summations in place*)
253                  | false*) _ =>   (true,multiPs(Ps,newsx,t))                  | false*) _ =>   (true,multiPs(Ps,newsx,t))
# Line 258  Line 266 
266          end          end
267    
268      fun liftProbe(printStrings,(y, DstIL.EINAPP(e,args)),p ,sx)=let      fun liftProbe(printStrings,(y, DstIL.EINAPP(e,args)),p ,sx)=let
269          val _=(String.concat["\n******* Lift Geneirc Probe ***\n"])          val _=testp["\n******* Lift Geneirc Probe ***\n"]
270          val originalb=Ein.body e          val originalb=Ein.body e
271          val params=Ein.params e          val params=Ein.params e
272          val index=Ein.index e          val index=Ein.index e
# Line 273  Line 281 
281    
282          (*transform T*P*P..Ps*)          (*transform T*P*P..Ps*)
283          val (splitvar,ein0,sizes,dx,alpha')= createEinApp(originalb,alpha,index,freshIndex,dim,dx,sx)          val (splitvar,ein0,sizes,dx,alpha')= createEinApp(originalb,alpha,index,freshIndex,dim,dx,sx)
284    
285          val FArg  = DstV.new ("F", DstTy.TensorTy(sizes))          val FArg  = DstV.new ("F", DstTy.TensorTy(sizes))
286          val einApp0=mkEinApp(ein0,[PArg,FArg])          val einApp0=mkEinApp(ein0,[PArg,FArg])
287          val rtn0=(case splitvar          val rtn0=(case splitvar
# Line 318  Line 327 
327    
328      fun liftFieldMat(newvx,e)=      fun liftFieldMat(newvx,e)=
329          let          let
330              val _=print "\n ***************************** start FieldMat\n"              val _=testp[ "\n ***************************** start FieldMat\n"]
331              val (y, DstIL.EINAPP(ein,args))=e              val (y, DstIL.EINAPP(ein,args))=e
332              val E.Probe(E.Conv(V,[c1,v0],h,dx),pos)=Ein.body ein              val E.Probe(E.Conv(V,[c1,v0],h,dx),pos)=Ein.body ein
333              val index0=Ein.index ein              val index0=Ein.index ein
# Line 344  Line 353 
353              val einApp0 = mkEinApp(ein0,[lhs1])              val einApp0 = mkEinApp(ein0,[lhs1])
354              val code0 = (y,einApp0)              val code0 = (y,einApp0)
355              val _= toStringBind code0              val _= toStringBind code0
356                      val _=print "\n end FieldMat *****************************\n "              val _=testp["\n end FieldMat *****************************\n "]
357          in          in
358              codeAll@[code0]              codeAll@[code0]
359      end      end
360    
361      fun liftFieldVec(newvx,e,fieldset)=      fun liftFieldVec(newvx,e,fieldset)=
362      let      let
363          val _=print "\n ***************************** start FieldVec\n"          val _=testp[ "\n ***************************** start FieldVec\n"]
364          val (y, DstIL.EINAPP(ein,args))=e          val (y, DstIL.EINAPP(ein,args))=e
365          val E.Probe(E.Conv(V,[c1],h,dx),pos)=Ein.body ein          val E.Probe(E.Conv(V,[c1],h,dx),pos)=Ein.body ein
366          val index0=Ein.index ein          val index0=Ein.index ein
# Line 374  Line 383 
383          val einApp0 = mkEinApp(ein0,[lhs0])          val einApp0 = mkEinApp(ein0,[lhs0])
384          val code0 = (y,einApp0)          val code0 = (y,einApp0)
385    
386          val _ = (String.concat ["\n Shifted ",P.printbody body1_unshifted,"=>",P.printbody body1])          val _ = testp ["\n Shifted ",P.printbody body1_unshifted,"=>",P.printbody body1]
387          val _=  (toStringBind code0)          val _=  (toStringBind code0)
388          val _=print "\n end FieldVec *****************************\n "          val _ = testp[ "\n end FieldVec *****************************\n "]
389          in          in
390              codeAll@[code0]              codeAll@[code0]
391      end      end
# Line 385  Line 394 
394    
395      fun liftFieldSum e =      fun liftFieldSum e =
396      let      let
397          val _=print "\n************************************* Start Lift Field Sum\n"          val _=testp[ "\n************************************* Start Lift Field Sum\n"]
398          val (y, DstIL.EINAPP(ein,args))=e          val (y, DstIL.EINAPP(ein,args))=e
399          val E.Sum([(vsum,0,n)],E.Probe(E.Conv(V,[c1,v0],h,dx),pos))=Ein.body ein          val E.Sum([(vsum,0,n)],E.Probe(E.Conv(V,[c1,v0],h,dx),pos))=Ein.body ein
400          val index0=Ein.index ein          val index0=Ein.index ein
# Line 413  Line 422 
422          val _ = toStringBind code0          val _ = toStringBind code0
423          val _  = (String.concat  ["\norig",P.printbody(Ein.body ein),"\n replace i  ",P.printbody body1,"\nfreshtensor",P.printbody body0])          val _  = (String.concat  ["\norig",P.printbody(Ein.body ein),"\n replace i  ",P.printbody body1,"\nfreshtensor",P.printbody body0])
424          val _  =((List.map toStringBind (codeAll@[code0])))          val _  =((List.map toStringBind (codeAll@[code0])))
425          val _ = print "\n*** end Field Sum*************************************\n"          val _ = testp["\n*** end Field Sum*************************************\n"]
426          in          in
427          codeAll@[code0]          codeAll@[code0]
428      end      end
# Line 426  Line 435 
435      *)      *)
436     fun expandEinOp( e as (y, DstIL.EINAPP(ein,args)),fieldset)=let     fun expandEinOp( e as (y, DstIL.EINAPP(ein,args)),fieldset)=let
437    
438          fun (*checkConst ([],a) =          fun checkConst ([],a) =
439              (case fieldliftflag              (case fieldliftflag
440                  of true => liftProbe a                  of true => liftProbe a
441                  | _ => replaceProbe a                  | _ => replaceProbe a
442              (*end case*))              (*end case*))
443          | checkConst ((E.C _::_),a) = replaceProbe a          | checkConst ((E.C _::_),a) = replaceProbe a
444          | checkConst ((_ ::es),a)= checkConst(es,a)          | checkConst ((_ ::es),a)= checkConst(es,a)
         *)  
         checkConst (_,a) = liftProbe a  
445          fun rewriteBody b=(case (detflag,b)          fun rewriteBody b=(case (detflag,b)
446              of (true,E.Probe(E.Conv(_,[E.C _ ,E.V 0],_,[]),pos))              of (true,E.Probe(E.Conv(_,[E.C _ ,E.V 0],_,[]),pos))
447                  => liftFieldMat (1,e)                  => liftFieldMat (1,e)
# Line 466  Line 473 
473                  => checkConst(dx,(0,e,p,sx)) (*scalar field*)                  => checkConst(dx,(0,e,p,sx)) (*scalar field*)
474              | (_,E.Sum(sx,E.Probe p))              | (_,E.Sum(sx,E.Probe p))
475                  => replaceProbe(0,e,E.Probe p, sx)                  => replaceProbe(0,e,E.Probe p, sx)
476              | (_,E.Sum(sx,E.Prod[eps,E.Probe p]))              | (_, E.Sum(sx,E.Opn(E.Prod,[eps,E.Probe p])))
477                  => replaceProbe(0,e,E.Probe p,sx)                  => replaceProbe(0,e,E.Probe p,sx)
478              | (_,_) => [e]              | (_,_) => [e]
479              (* end case *))              (* end case *))
# Line 479  Line 486 
486          fun matchField b=(case b          fun matchField b=(case b
487              of E.Probe _ => 1              of E.Probe _ => 1
488              | E.Sum (_, E.Probe _)=>1              | E.Sum (_, E.Probe _)=>1
489              | E.Sum(_, E.Prod[ _ ,E.Probe _])=>1              | E.Sum(_, E.Opn(E.Prod,[ _ ,E.Probe _]))=>1
490              | _ =>0              | _ =>0
491              (*end case*))              (*end case*))
492          fun toStrField b=(case b          fun toStrField b=(case b
493              of E.Probe _ => print("\n"^(P.printbody b))              of E.Probe _ => print("\n"^(P.printbody b))
494              | E.Sum (_, E.Probe _)=>print("\n"^ (P.printbody b))              | E.Sum (_, E.Probe _)=>print("\n"^ (P.printbody b))
495              | E.Sum(_, E.Prod[ _ ,E.Probe _])=>print("\n"^ (P.printbody b))              | E.Sum(_, E.Opn(E.Prod,[ _ ,E.Probe _]))=>print("\n"^ (P.printbody b))
496              | _ => print""              | _ => print""
497              (*end case*))              (*end case*))
498              val b=Ein.body ein              val b=Ein.body ein
499    
         val _=  toStrField b  
   
500          in  (case var          in  (case var
501          of NONE=> ((rewriteBody(Ein.body ein),fieldset,matchField(Ein.body ein),0))          of NONE=> ((rewriteBody(Ein.body ein),fieldset,matchField(Ein.body ein),0))
502              | SOME v=> (("\n mapp_replacing"^(P.printerE ein)^":");( [(y,DstIL.VAR v)],fieldset, matchField(Ein.body ein),1))              | SOME v=> (("\n mapp_replacing"^(P.printerE ein)^":");( [(y,DstIL.VAR v)],fieldset, matchField(Ein.body ein),1))

Legend:
Removed from v.3440  
changed lines
  Added in v.3441

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0