Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] View of /branches/charisee_dev/src/compiler/high-to-mid/handleEin.sml
ViewVC logotype

View of /branches/charisee_dev/src/compiler/high-to-mid/handleEin.sml

Parent Directory Parent Directory | Revision Log Revision Log


Revision 3189 - (download) (annotate)
Thu Apr 2 18:49:21 2015 UTC (4 years, 5 months ago) by cchiw
Original Path: branches/charisee/src/compiler/high-to-mid/handleEin.sml
File size: 7807 byte(s)
lift curl expression
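(* Expands EIN applications during high-to-mid translation: sweeps zero terms,
 * cleans and distributes summations, splits the result into pieces, and
 * expands probes in each piece.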
 *
 * COPYRIGHT (c) 2012 The Diderot Project (http://diderot-language.cs.uchicago.edu)
 * All rights reserved.
 *)

structure handleEin = struct

    local

    structure E = Ein
    structure DstIL = MidIL
    structure DstOp = MidOps
    structure P=Printer
    structure T=TransformEin
    structure MidToS=MidToString
    structure DstV  = DstIL.Var
    in

    (* Debug switches: a non-zero value turns on the matching trace output
     * (testinitial gates test0, testing gates testp). *)
    val testinitial=0
    val testing=0

    (* setEin (params, index, body): build an Ein.EIN record from its fields. *)
    fun setEin(params,index,body)=Ein.EIN{params=params, index=index, body=body}

    (* assignEinApp (y, params, index, body, args): pair the lhs y with a
     * mid-IL EINAPP of the EIN built from params/index/body, applied to args. *)
    fun assignEinApp(y,params,index,body,args)= (y,DstIL.EINAPP(setEin(params,index,body),args))

    (* iTos i: decimal string for the integer i. *)
    fun iTos i =Int.toString i

    (* testp strs: when testing <> 0, print the concatenation of strs.
     * Always returns 1, so callers discard it with `val _ = ...`. *)
    fun testp n=(case testing
        of 0=> 1
        | _ =>(print(String.concat n);1)
        (*end case*))

    (* test0 strs: like testp, but gated on testinitial. *)
    fun test0 n=(case testinitial
        of 0=>1
        | _ => (print(String.concat n);1))

    (* An EINAPP of the scalar constant 0, with no params, index or args. *)
    val einappzero=DstIL.EINAPP(setEin([],[],E.Const 0),[])

    (* setEinZero y: assign the constant-zero EINAPP to y. *)
    fun setEinZero y=  (y,einappzero)

    fun filterSca e=Filter.filterSca e
    fun printEINAPP e=MidToString.printEINAPP e

    (* z e: diagnostic text reporting that sweep replaced term e by 0. *)
    fun z e=String.concat["\n\n Found: ",P.printbody e,"=> 0\n"]

    (* sweep e: replace every Field, Partial, Apply, Lift and bare Conv subterm
     * of e with the constant 0, recursing through the arithmetic and
     * trigonometric operators.  A probe of a convolution is kept as is; any
     * other probe means a field substitution was missed, so Fail is raised.
     * Each zeroed term is reported through testp (printed only when
     * testing <> 0). *)
    fun sweep e= let
        (* report the zeroed term, then yield the zero constant *)
        fun toZero () = (testp[z e]; E.Const 0)
        in
            (case e
            of E.Tensor _          => e
            | E.Add es             => E.Add(List.map sweep es)
            | E.Sub(e1,e2)         => E.Sub(sweep e1,sweep e2)
            | E.Div(e1,e2)         => E.Div(sweep e1,sweep e2)
            | E.Sum(c ,e')         => E.Sum(c, sweep e')
            | E.Prod es            => E.Prod(List.map sweep es)
            | E.Neg e'             => E.Neg(sweep e')
            | E.Probe(E.Conv _,_)  => e
            | E.Sqrt e'            => E.Sqrt(sweep e')
            | E.Cosine e'          => E.Cosine(sweep e')
            | E.ArcCosine e'       => E.ArcCosine(sweep e')
            | E.Sine e'            => E.Sine(sweep e')
            | E.ArcSine e'         => E.ArcSine(sweep e')
            | E.Const _            => e
            | E.ConstR _           => e
            | E.Delta _            => e
            | E.Epsilon _          => e
            | E.Eps2 _             => e
            | E.Field _            => toZero ()
            | E.Partial _          => toZero ()
            | E.Apply _            => toZero ()
            | E.Lift _             => toZero ()
            | E.Conv _             => toZero ()
            | E.PowInt(e' ,n)      => E.PowInt(sweep e' ,n)
            | E.PowReal(e',n)      => E.PowReal(sweep e' ,n)
            | E.Probe _            =>
                raise Fail (String.concat["\n Incorrect probe, substition was not made. Is the Field in an if statement? :",P.printbody e])
            (*end case*))
        end

    (* distributeSummation (y, EINAPP(ein, args)):
     * Push summations inward with `rewrite` until a pass changes nothing.
     * Then run SummationEin.cleanSummation and rewrite to a fixed point
     * again.  Returns (y, EINAPP(ein', args)) with params and index taken
     * from the cleaned EIN.  Any other rhs is returned unchanged. *)
    fun distributeSummation(y,DstIL.EINAPP(Ein.EIN{params, index, body},args))=let
        (* Set by every rule that alters the term; `loop` repeats `rewrite`
         * until a full pass leaves it false. *)
        val changed = ref false
        fun rewrite b=(case b
            (* NOTE(review): the next three rules drop a summation over a term
             * that mentions no index.  Over a range with more than one value,
             * the sum would scale the term, so this is not an identity in
             * general.  Presumably an invariant elsewhere makes it safe --
             * confirm. *)
            of E.Sum(sx,E.Tensor(id,[]))    => (changed:=true;E.Tensor(id,[]))
            | E.Sum(sx,E.Const c)           => (changed:=true;E.Const c)
            | E.Sum(sx,E.ConstR r)          => (changed:=true;E.ConstR r)
            | E.Sum(sx,E.Neg n)             => (changed:=true;(E.Neg(E.Sum(sx,n))))
            | E.Sum(sx,E.Add a)             =>
                (changed:=true;(E.Add(List.map (fn e=> E.Sum(sx,e)) a)))
            | E.Sum(sx,E.Sub (e1,e2))       =>
                (changed:=true;(E.Sub(E.Sum(sx,e1),E.Sum(sx,e2))))
            (* NOTE(review): sum of c/x_i is not c/(sum of x_i) in general;
             * confirm this rule is intended. *)
            | E.Sum(sx,E.Div(E.Const c,e2)) =>
                (changed:=true;(E.Div(E.Const c, E.Sum(sx,e2))))
            | E.Sum(sx,E.Div(e1,e2))        => (changed:=true;
                (E.Sum(sx,E.Prod[e1,E.Div(E.Const 1,rewrite e2)])))
            | E.Sum(sx,E.Lift e )           => (changed:=true;(E.Lift(E.Sum(sx,e))))
            (* NOTE(review): a sum does not commute with PowReal or Sqrt in
             * general (e.g. sum of sqrt(x_i) vs sqrt(sum of x_i)); confirm
             * these two rules are intended. *)
            | E.Sum(sx,E.PowReal(e,n1))     => (changed:=true;(E.PowReal(E.Sum(sx,e),n1)))
            | E.Sum(sx,E.Sqrt e)            => (changed:=true;(E.Sqrt(E.Sum(sx,e))))
            | E.Sum(sx,E.Sum (c2,e))        => (changed:=true; (E.Sum (sx@c2,e)))
            | E.Sum(sx,E.Prod p)            => let
                val p'=List.map rewrite p
                (* Filter.filterSca signals a change with c = 1 *)
                val (c,e)=filterSca(sx,p')
                in if c = 1 then (changed:=true; e) else e end
            | E.Div(e1,e2)                  => E.Div(rewrite e1, rewrite e2)
            | E.Sub(e1,E.Const 0)           => (changed:=true; rewrite e1)
            | E.Sub(e1,e2)                  => E.Sub(rewrite e1, rewrite e2)
            | E.Add es                      => E.Add(List.map rewrite es)
            | E.Prod es                     => E.Prod(List.map rewrite es)
            | E.Neg e                       => E.Neg(rewrite e)
            | E.Sqrt e                      => E.Sqrt(rewrite e)
            | E.Cosine e                    => E.Cosine(rewrite e)
            | E.ArcCosine e                 => E.ArcCosine(rewrite e)
            | E.Sine e                      => E.Sine(rewrite e)
            | E.ArcSine e                   => E.ArcSine(rewrite e)
            | E.Probe(e1,e2)                => E.Probe(rewrite e1, rewrite e2)
            | _                             => b
            (*end case*))
        (* apply `rewrite` repeatedly until it reports no change *)
        fun loop body  = let
            val body' = rewrite body
            in
                if !changed then  (changed := false ;loop body') else  body'
            end

        val  b = loop body
        val _ =testp["\nAfter distributeSummation \n",P.printbody b]
        val ein=SummationEin.cleanSummation(Ein.EIN{params=params, index=index, body=b})
        val b = loop(Ein.body ein)
        val ein=Ein.EIN{params=Ein.params ein, index=Ein.index ein, body=b}
        in
            (y,DstIL.EINAPP(ein,args))
        end
      | distributeSummation(y,app) =(y,app)

    (* expandEinOp (y, EINAPP(ein, args)): lower one EIN assignment to a list
     * of mid-IL assignments.  The phases are:
     *   1. sweep out zero terms;
     *   2. clean summations;
     *   3. distribute summations;
     *   4. split into pieces (Split.splitEinApp);
     *   5. expand probes in each piece (ProbeEin.expandEinOp), threading the
     *      EinSet returned by one piece into the next.
     * The pattern binding below raises Bind if the rhs is not an EINAPP. *)
    fun expandEinOp einapp00=let
        val star="************"
        val _ =test0[star,"\n Original EinApp",star,"\n\n","start get test",printEINAPP einapp00]
        val (y,DstIL.EINAPP(Ein.EIN{params, index, body},args))=einapp00

        (* ************* Sweep for 0's *********** *)
        val bodysweep=sweep body
        val ein1=Ein.EIN{params=params, index=index, body=bodysweep}
        val _=testp["\nPresweep\n",P.printbody body,"\n\n Sweep\n",P.printbody bodysweep,"\n"]

        (* **************Clean Summation*********** *)
        val ein2=SummationEin.cleanSummation(ein1)
        val einapp2=(y,DstIL.EINAPP(ein2, args))
        val _ =testp["\n\n******* after clean summation**",Int.toString (0)," ***** \n \t==>\n",printEINAPP(einapp2)]

        (* ************** distribute Summation*********** *)
        val einapp3 = distributeSummation einapp2

        (* **************** split phase ************* *)
        val newbies5= Split.splitEinApp einapp3
        val _ =testp["\n\n Returning \n\n =>",
            String.concatWith",\n\t"(List.map printEINAPP newbies5)]

        (* ************** ProbeEIN *********** *)
        (* Expand each piece in order.  The EinSet returned for one piece is
         * passed on to the next. *)
        fun iter([],_)=[]
          | iter(e1::es,fieldset)= let
            val (e2,fieldset') = ProbeEin.expandEinOp(e1,fieldset)
            in e2 :: iter(es,fieldset') end
        val code=iter(newbies5, einSet.EinSet.empty)
        val flatcode= List.concat code
        (*val _=List.map (fn(_,DstIL.EINAPP(e,_))=>checkEin.checkEIN e | _=> 1) flatcode*)
        in
            List.map (fn (y,rator)=> DstIL.ASSGN(y,rator)) flatcode
        end


  end; (* local *)

end (* structure handleEin *)

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0