Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] Diff of /branches/charisee_dev/src/compiler/high-to-mid/ProbeEin.sml
ViewVC logotype

Diff of /branches/charisee_dev/src/compiler/high-to-mid/ProbeEin.sml

Parent Directory Parent Directory | Revision Log Revision Log | View Patch Patch

revision 3311, Fri Oct 16 20:09:14 2015 UTC revision 3362, Sun Nov 1 18:26:02 2015 UTC
# Line 1  Line 1 
1  (* Expands probe ein  (* Expands probe ein
2   *   *
3   * COPYRIGHT (c) 2012 The Diderot Project (http://diderot-language.cs.uchicago.edu)   * This code is part of the Diderot Project (http://diderot-language.cs.uchicago.edu)
4     *
5     * COPYRIGHT (c) 2015 The University of Chicago
6   * All rights reserved.   * All rights reserved.
7   *)   *)
8    
# Line 38  Line 40 
40      *)      *)
41    
42      val testing=0      val testing=0
43      val testlift=0      val testlift=1
44      val detflag =false      val detflag =true
45        val fieldliftflag=true
46        val valnumflag=true
47    
48    
49      val cnt = ref 0      val cnt = ref 0
# Line 156  Line 160 
160    
161      (* silly change in order of the product to match vis branch WorldtoSpace functions*)      (* silly change in order of the product to match vis branch WorldtoSpace functions*)
162      fun multiPs([P0,P1,P2],sx,body)= formBody(E.Sum(sx, E.Prod([P0,P1,P2,body])))      fun multiPs([P0,P1,P2],sx,body)= formBody(E.Sum(sx, E.Prod([P0,P1,P2,body])))
163        (*
164          | multiPs([P0,P1],sx,body)=formBody(E.Sum(sx, E.Prod([P0,body,P1])))
165          *)
166        | multiPs(Ps,sx,body)=formBody(E.Sum(sx, E.Prod([body]@Ps)))        | multiPs(Ps,sx,body)=formBody(E.Sum(sx, E.Prod([body]@Ps)))
167    
168    
# Line 235  Line 242 
242                  (*end case*))                  (*end case*))
243              (*end case*))              (*end case*))
244    
245            val _ =(case splitvar
246            of true=> (String.concat["splitvar is true", P.printbody body])
247            | _ => (String.concat["splitvar is false",P.printbody body])
248            (*end case*))
249    
250    
251          val ein0=mkEin(params,index,body)          val ein0=mkEin(params,index,body)
252          in          in
253              (splitvar,ein0,sizes,dx,alpha')              (splitvar,ein0,sizes,dx,alpha')
# Line 259  Line 272 
272          val FArg  = DstV.new ("F", DstTy.TensorTy(sizes))          val FArg  = DstV.new ("F", DstTy.TensorTy(sizes))
273          val einApp0=mkEinApp(ein0,[PArg,FArg])          val einApp0=mkEinApp(ein0,[PArg,FArg])
274          val rtn0=(case splitvar          val rtn0=(case splitvar
275              of false => [(y,einApp0)]              of false => [(y,mkEinApp(ein0,[PArg,FArg]))]
276              | _      => Split.splitEinApp (y,einApp0)              | _      => let
277                     val bind3 = (y,DstIL.EINAPP(SummationEin.main ein0,[PArg,FArg]))
278                     in Split.splitEinApp(bind3,0)
279                     end
280              (*end case*))              (*end case*))
281    
282          (*lifted probe*)          (*lifted probe*)
# Line 276  Line 292 
292              rtn              rtn
293          end          end
294    
     (* expandEinOp: code->  code list  
     * A this point we only have simple ein ops  
     * Looks to see if the expression has a probe. If so, replaces it.  
     * Note how we keeps eps expressions so only generate pieces that are used  
     *)  
    fun expandEinOp( e as (y, DstIL.EINAPP(ein,args)),fieldset)=let  
295    
296          fun checkConst ([],a) = liftProbe a      fun liftFieldMat(newvx,e)=
         | checkConst ((E.C _::_),a) = replaceProbe a  
         | checkConst ((_ ::es),a)= checkConst(es,a)  
   
         fun liftFieldMat(newvx,E.Probe(E.Conv(V,[E.C c1,E.V 0],h,dx),pos))=  
297              let              let
298                val (y, DstIL.EINAPP(ein,args))=e
299                  val _= toStringBind e              val E.Probe(E.Conv(V,[c1,v0],h,dx),pos)=Ein.body ein
300                  val index0=Ein.index ein                  val index0=Ein.index ein
301                  val index1 = index0@[3]                  val index1 = index0@[3]
302                  val body1_unshifted = E.Probe(E.Conv(V,[E.V newvx, E.V 0],h,dx),pos)              val body1_unshifted = E.Probe(E.Conv(V,[E.V newvx, v0],h,dx),pos)
303                   (* clean to get body indices in order *)                   (* clean to get body indices in order *)
304                  val ( _ , _, body1)= cleanIndex.cleanIndex(body1_unshifted,index1,[])                  val ( _ , _, body1)= cleanIndex.cleanIndex(body1_unshifted,index1,[])
305                  val _ = testp ["\n Shifted ",P.printbody body1_unshifted,"=>",P.printbody body1]                  val _ = testp ["\n Shifted ",P.printbody body1_unshifted,"=>",P.printbody body1]
# Line 306  Line 312 
312                      | _ =>liftProbe(1,code1,body1,[])                      | _ =>liftProbe(1,code1,body1,[])
313                  (*end case*))                  (*end case*))
314    
315                  (*Probe that tensor at a constant position E.C c1*)              (*Probe that tensor at a constant position  c1*)
316                  val param0 = [E.TEN(1,index1)]                  val param0 = [E.TEN(1,index1)]
317                  val nx=List.tabulate(length(dx)+1,fn n=>E.V n)                  val nx=List.tabulate(length(dx)+1,fn n=>E.V n)
318                  val body0 =  E.Tensor(0,[E.C c1]@nx)              val body0 =  E.Tensor(0,[c1]@nx)
319                  val ein0 = mkEin(param0,index0,body0)                  val ein0 = mkEin(param0,index0,body0)
320                  val einApp0 = mkEinApp(ein0,[lhs1])                  val einApp0 = mkEinApp(ein0,[lhs1])
321                  val code0 = (y,einApp0)                  val code0 = (y,einApp0)
# Line 318  Line 324 
324                  codeAll@[code0]                  codeAll@[code0]
325              end              end
326    
327        fun liftFieldSum e =
328        let
329            val _=print"\n*************************************\n"
330            val (y, DstIL.EINAPP(ein,args))=e
331            val E.Sum([(vsum,0,n)],E.Probe(E.Conv(V,[c1,v0],h,dx),pos))=Ein.body ein
332            val index0=Ein.index ein
333            val index1 = index0@[3]@[3]
334            val shiftdx=List.tabulate(length(dx),fn n=>E.V (n+2))
335            val body1 = E.Probe(E.Conv(V,[E.V 0,E.V 1],h,shiftdx),pos)
336    
337    
338            val lhs1=DstV.new ("L", DstTy.TensorTy(index1))
339            val ein1 = mkEin(Ein.params ein,index1,body1)
340            val code1= (lhs1,mkEinApp(ein1,args))
341            val codeAll= (case dx
342            of []=> replaceProbe(1,code1,body1,[])
343            | _ =>liftProbe(1,code1,body1,[])
344            (*end case*))
345    
346            (*Probe that tensor at a constant position  c1*)
347            val param0 = [E.TEN(1,index1)]
348            val nx=List.tabulate(length(dx),fn n=>E.V n)
349            val body0 =  E.Sum([(vsum,0,n)],E.Tensor(0,[vsum,vsum]@nx))
350            val ein0 = mkEin(param0,index0,body0)
351            val einApp0 = mkEinApp(ein0,[lhs1])
352            val code0 = (y,einApp0)
353            val _= toStringBind  e
354            val _ =toStringBind code0
355           val _ = (String.concat  ["\norig",P.printbody(Ein.body ein),"\n replace i  ",P.printbody body1,"\nfreshtensor",P.printbody body0])
356           val _ =(String.concat(List.map toStringBind (codeAll@[code0])))
357                   val _=print"\n*************************************\n"
358            in
359            codeAll@[code0]
360        end
361    
362    
363        (* expandEinOp: code->  code list
364        * A this point we only have simple ein ops
365        * Looks to see if the expression has a probe. If so, replaces it.
366        * Note how we keeps eps expressions so only generate pieces that are used
367        *)
368       fun expandEinOp( e as (y, DstIL.EINAPP(ein,args)),fieldset)=let
369    
370            fun checkConst ([],a) =
371                (case fieldliftflag
372                    of true => liftProbe a
373                    | _ => replaceProbe a
374                (*end case*))
375            | checkConst ((E.C _::_),a) = replaceProbe a
376            | checkConst ((_ ::es),a)= checkConst(es,a)
377    
378          fun rewriteBody b=(case (detflag,b)          fun rewriteBody b=(case (detflag,b)
379              of (true,E.Probe(E.Conv(_,[E.C _ ,E.V 0],_,[]),pos))              of (true,E.Probe(E.Conv(_,[E.C _ ,E.V 0],_,[]),pos))
380                  => liftFieldMat (1,b)                  => liftFieldMat (1,e)
381              | (true,E.Probe(E.Conv(_,[E.C _ ,E.V 0],_,[E.V 1]),pos))              | (true,E.Probe(E.Conv(_,[E.C _ ,E.V 0],_,[E.V 1]),pos))
382                  => liftFieldMat (2,b)                  => liftFieldMat (2,e)
383              | (true,E.Probe(E.Conv(_,[E.C _ ,E.V 0],_,[E.V 1,E.V 2] ),pos))              | (true,E.Probe(E.Conv(_,[E.C _ ,E.V 0],_,[E.V 1,E.V 2] ),pos))
384                  => liftFieldMat (3,b)                  => liftFieldMat (3,e)
385                | (true, E.Sum([(E.V 0,0,_)],E.Probe(E.Conv(_,[E.V 0 ,E.V 0],_,[]),pos)))
386                    => liftFieldSum e
387                | (true, E.Sum([(E.V 1,0,_)],E.Probe(E.Conv(_,[E.V 1 ,E.V 1],_,[E.V 0]),pos)))
388                    => liftFieldSum e
389                | (true, E.Sum([(E.V 2,0,_)],E.Probe(E.Conv(_,[E.V 2 ,E.V 2],_,[E.V 0,E.V 1]),pos)))
390                    => liftFieldSum e
391    
392    
393              | (_,E.Probe(E.Conv(_,_,_,[]),_))              | (_,E.Probe(E.Conv(_,_,_,[]),_))
394                  => replaceProbe(0,e,b,[])                  => replaceProbe(0,e,b,[])
395              | (_,E.Probe(E.Conv (_,alpha,_,dx),_))              | (_,E.Probe(E.Conv (_,alpha,_,dx),_))
# Line 340  Line 405 
405              | (_,_) => [e]              | (_,_) => [e]
406              (* end case *))              (* end case *))
407    
408          val (fieldset,var) = einSet.rtnVar(fieldset,y,DstIL.EINAPP(ein,args))          val (fieldset,var) = (case valnumflag
409                of true => einSet.rtnVar(fieldset,y,DstIL.EINAPP(ein,args))
410                | _     => (fieldset,NONE)
411            (*end case*))
412    
413          fun matchField b=(case b          fun matchField b=(case b
414              of E.Probe _ => 1              of E.Probe _ => 1
# Line 348  Line 416 
416              | E.Sum(_, E.Prod[ _ ,E.Probe _])=>1              | E.Sum(_, E.Prod[ _ ,E.Probe _])=>1
417              | _ =>0              | _ =>0
418              (*end case*))              (*end case*))
419            fun toStrField b=(case b
420                of E.Probe _ => print (P.printbody b)
421                | E.Sum (_, E.Probe _)=>print (P.printbody b)
422                | E.Sum(_, E.Prod[ _ ,E.Probe _])=>print (P.printbody b)
423                | _ =>print ""
424                (*end case*))
425                val b=Ein.body ein
426    (*
427            val _=  toStrField b
428      *)
429          in  (case var          in  (case var
430              of NONE=> (("\n \n mapp_not_replacing:"^(P.printerE ein)^":");(rewriteBody(Ein.body ein),fieldset,matchField(Ein.body ein),0))              of NONE=> ((rewriteBody(Ein.body ein),fieldset,matchField(Ein.body ein),0))
431              | SOME v=> (("\n mapp_replacing"^(P.printerE ein)^":");( [(y,DstIL.VAR v)],fieldset, matchField(Ein.body ein),1))              | SOME v=> (("\n mapp_replacing"^(P.printerE ein)^":");( [(y,DstIL.VAR v)],fieldset, matchField(Ein.body ein),1))
432              (*end case*))              (*end case*))
433          end          end

Legend:
Removed from v.3311  
changed lines
  Added in v.3362

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0