Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] Diff of /branches/charisee_dev/src/compiler/high-to-mid/ProbeEin.sml
ViewVC logotype

Diff of /branches/charisee_dev/src/compiler/high-to-mid/ProbeEin.sml

Parent Directory Parent Directory | Revision Log Revision Log | View Patch Patch

revision 3317, Sat Oct 17 02:36:54 2015 UTC revision 3325, Tue Oct 20 17:04:54 2015 UTC
# Line 38  Line 38 
38      *)      *)
39    
40      val testing=0      val testing=0
41      val testlift=0      val testlift=1
42      val detflag =false      val detflag =true
43      val fieldliftflag=false      val fieldliftflag=true
44      val valnumflag=false      val valnumflag=true
45    
46    
47      val cnt = ref 0      val cnt = ref 0
# Line 158  Line 158 
158    
159      (* silly change in order of the product to match vis branch WorldtoSpace functions*)      (* silly change in order of the product to match vis branch WorldtoSpace functions*)
160      fun multiPs([P0,P1,P2],sx,body)= formBody(E.Sum(sx, E.Prod([P0,P1,P2,body])))      fun multiPs([P0,P1,P2],sx,body)= formBody(E.Sum(sx, E.Prod([P0,P1,P2,body])))
161          | multiPs([P0,P1],sx,body)=formBody(E.Sum(sx, E.Prod([P0,body,P1])))
162        | multiPs(Ps,sx,body)=formBody(E.Sum(sx, E.Prod([body]@Ps)))        | multiPs(Ps,sx,body)=formBody(E.Sum(sx, E.Prod([body]@Ps)))
163    
164    
# Line 237  Line 238 
238                  (*end case*))                  (*end case*))
239              (*end case*))              (*end case*))
240    
241            val _ =(case splitvar
242            of true=> (String.concat["splitvar is true", P.printbody body])
243            | _ => (String.concat["splitvar is false",P.printbody body])
244            (*end case*))
245    
246    
247          val ein0=mkEin(params,index,body)          val ein0=mkEin(params,index,body)
248          in          in
249              (splitvar,ein0,sizes,dx,alpha')              (splitvar,ein0,sizes,dx,alpha')
# Line 261  Line 268 
268          val FArg  = DstV.new ("F", DstTy.TensorTy(sizes))          val FArg  = DstV.new ("F", DstTy.TensorTy(sizes))
269          val einApp0=mkEinApp(ein0,[PArg,FArg])          val einApp0=mkEinApp(ein0,[PArg,FArg])
270          val rtn0=(case splitvar          val rtn0=(case splitvar
271              of false => [(y,einApp0)]              of false => [(y,mkEinApp(ein0,[PArg,FArg]))]
272              | _      => Split.splitEinApp (y,einApp0)              | _      => let
273                     val bind3 = (y,DstIL.EINAPP(SummationEin.main ein0,[PArg,FArg]))
274                     in Split.splitEinApp bind3
275                     end
276              (*end case*))              (*end case*))
277    
278          (*lifted probe*)          (*lifted probe*)
# Line 278  Line 288 
288              rtn              rtn
289          end          end
290    
     (* expandEinOp: code->  code list  
     * A this point we only have simple ein ops  
     * Looks to see if the expression has a probe. If so, replaces it.  
     * Note how we keeps eps expressions so only generate pieces that are used  
     *)  
    fun expandEinOp( e as (y, DstIL.EINAPP(ein,args)),fieldset)=let  
291    
292          fun checkConst ([],a) =      fun liftFieldMat(newvx,e)=
             (case fieldliftflag  
                 of true => liftProbe a  
                 | _ => replaceProbe a  
             (*end case*))  
         | checkConst ((E.C _::_),a) = replaceProbe a  
         | checkConst ((_ ::es),a)= checkConst(es,a)  
   
         fun liftFieldMat(newvx,E.Probe(E.Conv(V,[E.C c1,E.V 0],h,dx),pos))=  
293              let              let
294                val (y, DstIL.EINAPP(ein,args))=e
295                  val _= toStringBind e              val E.Probe(E.Conv(V,[c1,v0],h,dx),pos)=Ein.body ein
296                  val index0=Ein.index ein                  val index0=Ein.index ein
297                  val index1 = index0@[3]                  val index1 = index0@[3]
298                  val body1_unshifted = E.Probe(E.Conv(V,[E.V newvx, E.V 0],h,dx),pos)              val body1_unshifted = E.Probe(E.Conv(V,[E.V newvx, v0],h,dx),pos)
299                   (* clean to get body indices in order *)                   (* clean to get body indices in order *)
300                  val ( _ , _, body1)= cleanIndex.cleanIndex(body1_unshifted,index1,[])                  val ( _ , _, body1)= cleanIndex.cleanIndex(body1_unshifted,index1,[])
301                  val _ = testp ["\n Shifted ",P.printbody body1_unshifted,"=>",P.printbody body1]                  val _ = testp ["\n Shifted ",P.printbody body1_unshifted,"=>",P.printbody body1]
# Line 312  Line 308 
308                      | _ =>liftProbe(1,code1,body1,[])                      | _ =>liftProbe(1,code1,body1,[])
309                  (*end case*))                  (*end case*))
310    
311                  (*Probe that tensor at a constant position E.C c1*)              (*Probe that tensor at a constant position  c1*)
312                  val param0 = [E.TEN(1,index1)]                  val param0 = [E.TEN(1,index1)]
313                  val nx=List.tabulate(length(dx)+1,fn n=>E.V n)                  val nx=List.tabulate(length(dx)+1,fn n=>E.V n)
314                  val body0 =  E.Tensor(0,[E.C c1]@nx)              val body0 =  E.Tensor(0,[c1]@nx)
315                  val ein0 = mkEin(param0,index0,body0)                  val ein0 = mkEin(param0,index0,body0)
316                  val einApp0 = mkEinApp(ein0,[lhs1])                  val einApp0 = mkEinApp(ein0,[lhs1])
317                  val code0 = (y,einApp0)                  val code0 = (y,einApp0)
# Line 324  Line 320 
320                  codeAll@[code0]                  codeAll@[code0]
321              end              end
322    
323        fun liftFieldSum e =
324        let
325            val _=print"\n*************************************\n"
326            val (y, DstIL.EINAPP(ein,args))=e
327            val E.Sum([(vsum,0,n)],E.Probe(E.Conv(V,[c1,v0],h,dx),pos))=Ein.body ein
328            val index0=Ein.index ein
329            val index1 = index0@[3]@[3]
330            val shiftdx=List.tabulate(length(dx),fn n=>E.V (n+2))
331            val body1 = E.Probe(E.Conv(V,[E.V 0,E.V 1],h,shiftdx),pos)
332    
333    
334            val lhs1=DstV.new ("L", DstTy.TensorTy(index1))
335            val ein1 = mkEin(Ein.params ein,index1,body1)
336            val code1= (lhs1,mkEinApp(ein1,args))
337            val codeAll= (case dx
338            of []=> replaceProbe(1,code1,body1,[])
339            | _ =>liftProbe(1,code1,body1,[])
340            (*end case*))
341    
342            (*Probe that tensor at a constant position  c1*)
343            val param0 = [E.TEN(1,index1)]
344            val nx=List.tabulate(length(dx),fn n=>E.V n)
345            val body0 =  E.Sum([(vsum,0,n)],E.Tensor(0,[vsum,vsum]@nx))
346            val ein0 = mkEin(param0,index0,body0)
347            val einApp0 = mkEinApp(ein0,[lhs1])
348            val code0 = (y,einApp0)
349            val _= toStringBind  e
350            val _ =toStringBind code0
351           val _ = (String.concat  ["\norig",P.printbody(Ein.body ein),"\n replace i  ",P.printbody body1,"\nfreshtensor",P.printbody body0])
352           val _ =(String.concat(List.map toStringBind (codeAll@[code0])))
353                   val _=print"\n*************************************\n"
354            in
355            codeAll@[code0]
356        end
357    
358    
359        (* expandEinOp: code->  code list
360        * A this point we only have simple ein ops
361        * Looks to see if the expression has a probe. If so, replaces it.
362        * Note how we keeps eps expressions so only generate pieces that are used
363        *)
364       fun expandEinOp( e as (y, DstIL.EINAPP(ein,args)),fieldset)=let
365    
366            fun checkConst ([],a) =
367                (case fieldliftflag
368                    of true => liftProbe a
369                    | _ => replaceProbe a
370                (*end case*))
371            | checkConst ((E.C _::_),a) = replaceProbe a
372            | checkConst ((_ ::es),a)= checkConst(es,a)
373    
374          fun rewriteBody b=(case (detflag,b)          fun rewriteBody b=(case (detflag,b)
375              of (true,E.Probe(E.Conv(_,[E.C _ ,E.V 0],_,[]),pos))              of (true,E.Probe(E.Conv(_,[E.C _ ,E.V 0],_,[]),pos))
376                  => liftFieldMat (1,b)                  => liftFieldMat (1,e)
377              | (true,E.Probe(E.Conv(_,[E.C _ ,E.V 0],_,[E.V 1]),pos))              | (true,E.Probe(E.Conv(_,[E.C _ ,E.V 0],_,[E.V 1]),pos))
378                  => liftFieldMat (2,b)                  => liftFieldMat (2,e)
379              | (true,E.Probe(E.Conv(_,[E.C _ ,E.V 0],_,[E.V 1,E.V 2] ),pos))              | (true,E.Probe(E.Conv(_,[E.C _ ,E.V 0],_,[E.V 1,E.V 2] ),pos))
380                  => liftFieldMat (3,b)                  => liftFieldMat (3,e)
381                | (true, E.Sum([(E.V 0,0,_)],E.Probe(E.Conv(_,[E.V 0 ,E.V 0],_,[]),pos)))
382                    => liftFieldSum e
383                | (true, E.Sum([(E.V 1,0,_)],E.Probe(E.Conv(_,[E.V 1 ,E.V 1],_,[E.V 0]),pos)))
384                    => liftFieldSum e
385                | (true, E.Sum([(E.V 2,0,_)],E.Probe(E.Conv(_,[E.V 2 ,E.V 2],_,[E.V 0,E.V 1]),pos)))
386                    => liftFieldSum e
387    
388    
389              | (_,E.Probe(E.Conv(_,_,_,[]),_))              | (_,E.Probe(E.Conv(_,_,_,[]),_))
390                  => replaceProbe(0,e,b,[])                  => replaceProbe(0,e,b,[])
391              | (_,E.Probe(E.Conv (_,alpha,_,dx),_))              | (_,E.Probe(E.Conv (_,alpha,_,dx),_))

Legend:
Removed from v.3317  
changed lines
  Added in v.3325

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0