Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] View of /branches/ein16/src/compiler/high-il/normalize-ein.sml
ViewVC logotype

View of /branches/ein16/src/compiler/high-il/normalize-ein.sml

Parent Directory Parent Directory | Revision Log Revision Log


Revision 3557 - (download) (annotate)
Fri Jan 8 19:54:58 2016 UTC (3 years, 7 months ago) by cchiw
Original Path: branches/charisee_dev/src/compiler/high-il/normalize-ein.sml
File size: 11712 byte(s)
added hard limit to float size
structure NormalizeEin = struct

    local

    structure E = Ein
    structure P=Printer
    structure F=Filter
    structure G=EpsHelpers
    structure Eq=EqualEin
    structure R=RationalEin

    in

    (* Debug switch: any value other than 0 makes testp print its message. *)
    val testing=0

    (* err: string -> 'a
     * Raise Fail with "Ill-formed EIN Operator: " prepended to str.
     * The ": " separator keeps the fixed prefix from running into the detail.
     *)
    fun err str=raise Fail (String.concat["Ill-formed EIN Operator: ",str])

    (* Local names for helpers defined in Filter and derivativeEin. *)
    fun mkProd e= F.mkProd e
    fun filterSca e=F.filterSca e
    fun mkAdd e=F.mkAdd e
    fun filterGreek e=F.filterGreek e
    fun mkapply e= derivativeEin.mkapply e

    (* testp: string list -> int
     * Print the concatenation of n when testing <> 0; always returns 1.
     *)
    fun testp n=(case testing
        of 0=> 1
        | _ =>(print(String.concat n);1)
    (*end case*))

    (* The scalar constant 0. *)
    val zero=E.B(E.Const 0)

    (* Local names for the constructor helpers defined in Ein. *)
    fun setConst e = E.setConst e
    fun setNeg e  =  E.setNeg e
    fun setExp e  =  E.setExp e
    fun setDiv e= E.setDiv e
    fun setSub e= E.setSub e
    fun setProd e= E.setProd e
    fun setAdd e= E.setAdd e

    (* mkSum: sum_indexid list * ein_exp -> int * ein_exp
     * Push the summation sx1 into the (already rewritten) body b.
     * The int is 1 when a rewrite happened and 0 when the result is just
     * E.Sum(sx1,b).
     *   - Lift e            : the sum moves inside the Lift.
     *   - E.Tensor(_,[]) and E.B terms : the sum is dropped, b is returned.
     *     NOTE(review): a sum over an index absent from the term would usually
     *     scale it by the index range - confirm dropping it is intended.
     *   - Prod              : delegated to Filter.filterSca.
     *)
    fun mkSum(sx1,b)=(case b
        of E.Lift e             => (1,E.Lift(E.Sum(sx1,e)))
        | E.Tensor(_,[])        => (1,b)
        | E.B _                 => (1,b)
        | E.Opn(E.Prod, es)     => filterSca(sx1,es)
        | _                     => (0,E.Sum(sx1,b))
        (*end case*))

    (* mkprobe: ein_exp * ein_exp -> int * ein_exp
     * Rewrite "probe b at position x" one level down.
     * Returns (1,e) when the probe was pushed into a Sum/Op1/Op2/Opn or a Lift
     * was stripped off, and (0,e) when b is returned unchanged or merely wrapped
     * in E.Probe. Raises Fail (via err) for terms that must not occur here.
     *)
    fun mkprobe(b,x)=(case b
        of E.B _                => (0,b)
        | E.Tensor _            => err "Tensor without Lift"
        | E.G _                 => (0,b)
        | E.Field _             => (0,E.Probe(b,x))
        | E.Lift e1             => (1,e1)
        | E.Conv _              => (0,E.Probe(b,x))
        | E.Partial _           => err "Probe Partial"
        | E.Apply _             => (0,E.Probe(b,x))
        | E.Probe _             => err "Probe of a Probe"
        | E.Value _             => err "Value used before expand"
        | E.Img _               => err "Img used before expand"
        | E.Krn _               => err "Krn used before expand"
        | E.Sum(sx1,e1)         => (1,E.Sum(sx1,E.Probe(e1,x)))
        | E.Op1(op1,e1)         => (1,E.Op1(op1,E.Probe(e1,x)))
        | E.Op2(op2,e1,e2)      => (1,E.Op2(op2,E.Probe(e1,x),E.Probe(e2,x)))
        | E.Opn(_,[])           => err "Probe of empty operator"
        | E.Opn(opn,es)         => (1,E.Opn(opn,List.map (fn e1 => E.Probe(e1,x)) es))
        (*end case*))

    (* normalize: Ein.ein * 'a -> Ein.ein * int
     * Rewrite the body of an EIN operator until a full pass reports no change.
     * args is accepted for interface compatibility and is not used.
     * Returns the EIN with the rewritten body (params and index unchanged) and
     * the number of passes that reported a change.
     * Raises Fail (via err) on terms that are ill-formed at this stage
     * (Value/Img/Krn, malformed Apply, empty products, ...).
     *)
    fun normalize (ee as Ein.EIN{params, index, body},args) = let
      (* Set by any rule that restructures the term; loop runs another pass
       * while it is set. *)
      val changed = ref false
      fun rewrite body =(case body
        of E.B _                                => body
        | E.Tensor _                            => body
        | E.G _                                 => body
            (************** Field Terms **************)
        | E.Field _                             => body
        | E.Lift e1                             => E.Lift(rewrite e1)
        | E.Conv _                              => body
        | E.Partial _                           => body
          (* Differentiation by no indices is the identity. Like the empty-Sum
           * rule below, drop the node, keep normalizing the operand and flag
           * the change so enclosing rules get another chance to match. *)
        | E.Apply(E.Partial [],e1)              => (changed:=true; rewrite e1)
        | E.Apply(E.Partial d1, e1)             => let
            val e2 = rewrite e1
            val (c,e3)=mkapply(E.Partial d1,e2)
            val _= testp["\nafter apply:",P.printbody body,"-->",P.printbody e3]
            in
                (case c of 1=>(changed:=true;e3)| _ =>e3 (*end case*))
            end
        | E.Apply _                             => err "Not well-formed Apply expression"
        | E.Probe(e1,e2)                        => let
            val (c',b')=mkprobe(rewrite e1,rewrite e2)
            in (case c'
                of 1=> (changed:=true;b')
                | _ => b'
                (*end case*))
            end
            (************** Terms that must be gone before normalization **************)
        | E.Value _                             => err "Value before Expand"
        | E.Img _                               => err "Img before Expand"
        | E.Krn _                               => err "Krn before Expand"
            (************** Sum **************)
        | E.Sum([],e1)                          => (changed:=true;rewrite e1)
        | E.Sum(sx1,e1)                         => let
            val e2=rewrite e1
            val (c,e')=mkSum(sx1,e2)
            val _= testp["\nafter mksum:\n\t",P.printbody body,"\n\t-->",P.printbody e2,"\n\t-->",P.printbody e']
            in
                (case c of 0 => e'|_ => (changed:=true;e'))
            end
            (*************Algebraic Rewrites Op1 **************)
        | E.Op1(E.Neg,e1)                       => (case e1
            of E.Op1(E.Neg,e2)                  => (changed:=true;rewrite e2)   (* --x = x *)
            | E.B(E.Const 0)                    => (changed:=true;zero)         (* -0 = 0 *)
            | _                                 => E.Op1(E.Neg,rewrite e1)
            (*end case*))
        | E.Op1(op1,e1)                         => E.Op1(op1,rewrite e1)
            (*************Algebraic Rewrites Op2 **************)
        | E.Op2(E.Sub,e1,e2)                    => (case (e1,e2)
            of (E.B(E.Const 0),_)               => (changed:=true;setNeg(rewrite e2))  (* 0-b = -b *)
            | (_,E.B(E.Const 0))                => (changed:=true;rewrite e1)          (* a-0 = a *)
            | _                                 => setSub(rewrite e1, rewrite e2)
            (*end case*))
        | E.Op2(E.Div,e1,e2)                    => (case (e1,e2)
            of (E.B(E.Const 0),_)               => (changed:=true;zero)                (* 0/b = 0 *)
              (* (a/b)/(c/d) = (a*d)/(b*c) *)
            | (E.Op2(E.Div,a,b), E.Op2(E.Div,c,d)) => rewrite(setDiv(setProd[a,d],setProd[b,c]))
              (* (a/b)/c = a/(b*c) *)
            | (E.Op2(E.Div,a,b), c)             => rewrite(setDiv(a, setProd[b,c]))
              (* a/(b/c) = (a*c)/b *)
            | (a,E.Op2(E.Div,b,c))              => rewrite(setDiv(setProd[a,c],b))
            | _                                 => setDiv(rewrite e1, rewrite e2)
            (*end case*))
            (*************Algebraic Rewrites Opn **************)
        | E.Opn(E.Add,es)                       => let
            val (change,body')= mkAdd(List.map rewrite es)
            in if (change=1) then (changed:=true;body') else body' end

            (*************Product**************)
        | E.Opn(E.Prod,[])                      => err "missing elements in product"
          (* Not flagged as a change: several rules below call rewrite on freshly
           * built setProd lists that may be singletons, so flagging here could
           * keep the fixpoint loop from ever terminating. *)
        | E.Opn(E.Prod,[e1])                    => rewrite e1
          (* sqrt(s)*sqrt(s) = s *)
        | E.Opn(E.Prod,[E.Op1(E.Sqrt,s1),E.Op1(E.Sqrt,s2)]) =>
            if (Eq.isBodyEq(s1,s2)) then (changed:=true;s1)
            else let
                val a=rewrite (E.Op1(E.Sqrt,s1))
                val b=rewrite (E.Op1(E.Sqrt,s2))
                val (_,d)=mkProd [a,b]
                in d end
            (*************Product EPS **************)
          (* Product headed by an epsilon; the earlier [] and [e1] cases
           * guarantee that ps is non-empty here. *)
        | E.Opn(E.Prod,(eps1 as E.G(E.Epsilon(i,j,k)))::ps) => (case ps
            of ((p1 as E.Apply(E.Partial d,_))::es) => let
                val change= G.matchEps(0,d,[],[i,j,k])
                in (case (change,es)
                    of (1,_)    => (changed:=true; zero)
                    | (_,[])    => setProd[eps1,rewrite p1]
                    | (_,_)     => let
                        val a=rewrite(setProd(p1::es))
                        val (_,b)=mkProd [eps1,a]
                        in b end
                    (*end case*))
                end
            | ((p1 as E.Conv(_,_,_,d))::es) => let
                val change= G.matchEps(0,d,[],[i,j,k])
                in (case (change,es)
                    of (1,_)    => (changed:=true; E.Lift zero)
                    | (_,[])    => setProd[eps1,p1]
                    | (_,_)     => let
                        val a=rewrite(setProd(p1::es))
                        val (_,b)=mkProd [eps1,a]
                        in b end
                    (*end case*))
                end
              (* NOTE(review): eps_ijk T_jk vanishes only when T is symmetric;
               * confirm every tensor reaching this rule is symmetric. *)
            | [E.Tensor(_,[E.V i1,E.V i2])] =>
                if (j=i1 andalso k=i2) then (changed:=true;zero) else body
            | _ => (case G.epsToDels(eps1::ps)
                of (1,e,[],_,_)       => (changed:=true;e)        (* Changed to Deltas *)
                | (1,e,sx,_,_)        => (changed:=true;E.Sum(sx,e))
                | (_,_,_,_,[])        => body
                | (_,_,_,epsAll,rest) => let
                    val p'=rewrite(setProd rest)
                    val (_,b)= mkProd(epsAll@[p'])
                    in b end
                (*end case*))
            (*end case*))
          (* Product of two summations (possibly followed by more factors). *)
        | E.Opn(E.Prod,E.Sum(c1,e1)::E.Sum(c2,e2)::es) => (case (e1,e2,es)
            of (E.Opn(E.Prod,E.G(E.Epsilon ix1)::es1),E.Opn(E.Prod,E.G(E.Epsilon ix2)::es2),_) =>
                (* Both summands are epsilon products: try the eps-eps -> delta identity. *)
                (case G.epsToDels([E.G(E.Epsilon ix1), E.G(E.Epsilon ix2)]@es1@es2@es)
                    of (1,e,sx,_,_) => (changed:=true; E.Sum(c1@c2@sx,e))
                    | _ => let
                        val eA=rewrite(E.Sum(c1,setProd(E.G(E.Epsilon ix1)::es1)))
                        val eB=rewrite(setProd(E.Sum(c2,setProd(E.G(E.Epsilon ix2)::es2))::es))
                        val (_,e)=mkProd [eA,eB]
                        in e end
                (*end case*))
            | (_,_,[]) => let
                val (_,b)=mkProd[rewrite(E.Sum(c1,e1)), rewrite(E.Sum(c2,e2))]
                in b end
            | _ => let
                val e'=rewrite(E.Sum(c1,e1))
                val rest=rewrite(E.Opn(E.Prod,E.Sum(c2,e2)::es))
                val (_,b)=(case rest
                    of E.Opn(E.Prod, p') => mkProd(e'::p')
                    | _ => mkProd [e',rest]
                    (*end case*))
                in b end
            (*end case*))
          (* Product headed by a Kronecker delta. *)
        | E.Opn(E.Prod,(delta as E.G(E.Delta _))::es) => (case es
            of [E.Op1(E.Neg, e1)] => (changed:=true;setNeg(setProd[delta, e1]))
            | _ => let
                val (_,eps,dels,post)= filterGreek(delta::es)
                val _= testp["\n\n Reduce delta--",P.printbody(body)]
                val (change,a)=G.reduceDelta(eps, dels, post)
                val _= testp["\n\n ---delta moved--",P.printbody(a)]
                in (case (change,a)
                    of (0, _) => setProd [delta,rewrite(setProd es)]
                    | (_, E.Opn(E.Prod, p)) => let
                        val (_, p') = mkProd p
                        in (changed:=true;p') end
                    | _ => (changed:=true;a)
                    (*end case*))
                end
            (*end case*))
        | E.Opn(E.Prod,[e1,e2]) => let
            val (_,b)=mkProd[rewrite e1, rewrite e2]
            in b end
        | E.Opn(E.Prod,e1::es) => let
            val e'=rewrite e1
            val rest=rewrite(setProd es)
            (* Only splice the factors of a genuine product. The constructor must
             * be qualified (E.Prod): a bare "Prod" is a variable pattern that
             * would also match an Add and flatten its summands into the product. *)
            val (_,b)=(case rest
                of E.Opn(E.Prod, p') => mkProd(e'::p')
                | _ => mkProd [e',rest]
                (*end case*))
            in b end
        (*end case*))

    val _=testp["\n******** Start Normalize: \n",P.printerE ee,"\n*****\n"]

    (* loop: run rewrite passes until a pass leaves changed unset.
     * count is the number of passes so far that reported a change. *)
    fun loop(body,count) = let
        val body' = rewrite body
        in
            if !changed
            then (changed := false; loop(body',count+1))
            else (body',count)
        end

    val (b,count) = loop(body,0)
    val _ =testp["\n Out of normalize \n",P.printbody(b),
        "\n    Final CounterXX:",Int.toString(count),"\n\n"]
    in
        (Ein.EIN{params=params, index=index, body=b},count)
    end

    end (* local *)

end (* structure NormalizeEin *)

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0