Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] View of /branches/ein16/src/compiler/high-il/normalize-ein.sml
ViewVC logotype

View of /branches/ein16/src/compiler/high-il/normalize-ein.sml

Parent Directory Parent Directory | Revision Log Revision Log


Revision 3682 - (download) (annotate)
Thu Feb 18 20:13:18 2016 UTC (3 years, 6 months ago) by cchiw
File size: 11919 byte(s)
creating stable branch that represents ein ir
structure NormalizeEin = struct

    local

    structure E = Ein
    structure P=Printer
    structure F=Filter
    structure G=EpsHelpers
    structure Eq=EqualEin
    structure R=RationalEin

    in

    (* Debug switch read by testp: 0 silences the trace output, any other
     * value makes testp print its argument. *)
    val testing=0

    (* err: string -> 'a
     * Aborts by raising Fail with the given detail appended to a fixed
     * "Ill-formed EIN Operator" prefix.  A ": " separator is inserted so the
     * prefix and the detail do not run together in the reported message.
     *)
    fun err str=raise Fail (String.concat["Ill-formed EIN Operator: ",str])
    (* Short local aliases for the filter helpers (structure F) and for the
     * derivative helper, so the rewrite rules below can call them unqualified. *)
    val mkProd = F.mkProd
    val filterSca = F.filterSca
    val mkAdd = F.mkAdd
    val filterGreek = F.filterGreek
    val mkapply = derivativeEin.mkapply
    (* testp: string list -> int
     * Debug trace helper.  When the testing switch is 0 nothing is printed;
     * otherwise the concatenation of the given strings is printed.
     * Always returns 1.
     *)
    fun testp n =
        if testing = 0
        then 1
        else (print (String.concat n); 1)

    (* The constant-zero EIN expression used by the simplification rules. *)
    val zero = E.B (E.Const 0)

    (* Local aliases for the expression constructors/helpers exported by Ein. *)
    val setConst = E.setConst
    val setNeg = E.setNeg
    val setExp = E.setExp
    val setDiv = E.setDiv
    val setSub = E.setSub
    val setProd = E.setProd
    val setAdd = E.setAdd

    (* mkSum: sum_indexid list * ein_exp -> int * ein_exp
     * Pushes a summation over the indices sx1 into the expression b.
     * The int in the result is 1 when the expression was restructured and
     * 0 when the summation was simply wrapped around b:
     *   - a lifted term keeps the Lift outside and sums its contents;
     *   - a tensor with no indices and an E.B term are returned as-is
     *     (the summation is dropped);
     *   - a product is handed to filterSca together with the indices;
     *   - anything else becomes E.Sum(sx1, b).
     *)
    fun mkSum (sx1, E.Lift e)            = (1, E.Lift (E.Sum (sx1, e)))
      | mkSum (_, b as E.Tensor (_, [])) = (1, b)
      | mkSum (_, b as E.B _)            = (1, b)
      | mkSum (sx1, E.Opn (E.Prod, es))  = filterSca (sx1, es)
      | mkSum (sx1, b)                   = (0, E.Sum (sx1, b))

    (* mkprobe: ein_exp * ein_exp -> int * ein_exp
     * Rewrites the probe of expression b at position x.
     * The int is 1 when the probe was pushed into (or cancelled against) b
     * and 0 otherwise:
     *   - E.B and E.G terms are returned without the probe (flag stays 0);
     *   - Field, Conv and Apply terms are wrapped in E.Probe unchanged;
     *   - probing a Lift cancels the Lift and yields its contents;
     *   - Sum, Op1, Op2 and non-empty Opn distribute the probe over their
     *     operands.
     * Raises Fail (via err) for terms that must not be probed: a bare
     * Tensor, a Partial, a nested Probe, an empty Opn, and the Value/Img/Krn
     * terms that only exist after expansion.
     *)
    fun mkprobe(b,x)=(case b
        of (E.B _)              => (0,b)
        | E.Tensor _            => err("Tensor without Lift")
        | E.G _                 => (0,b)
        | E.Field _             => (0,E.Probe(b,x))
        | E.Lift e1             => (1,e1)
        | E.Conv _              => (0,E.Probe(b,x))
        | E.Partial _           => err("Probe Partial")
        | E.Apply _             => (0,E.Probe(b,x))
        | E.Probe _             => err("Probe of a Probe")
        | E.Value _             => err("Value used before expand")
        (* message previously said "Probe", a copy-paste slip: this branch is Img *)
        | E.Img _               => err("Img used before expand")
        | E.Krn _               => err("Krn used before expand")
        | E.Sum(sx1,e1)         => (1,E.Sum(sx1,E.Probe(e1,x)))
        | E.Op1(op1, e1)        => (1,E.Op1(op1, E.Probe(e1,x)))
        | E.Op2(op2, e1,e2)     => (1,E.Op2(op2, E.Probe(e1,x), E.Probe(e2,x)))
        | E.Opn(_, [])          => err("Probe of empty operator")
        | E.Opn(opn, es)        => (1,E.Opn(opn, List.map(fn e1=> E.Probe(e1,x)) es))
        (*end case*))

    (* normalize: EIN * 'a -> EIN * int
     * Rewrites the body of an EIN operator to a fixed point.
     *
     * The nested function rewrite performs one bottom-up pass over the body
     * and sets the flag "changed" whenever a rule restructures the term;
     * loop repeats the pass until a pass leaves the flag unset.
     *
     * Returns the EIN operator with the rewritten body (params and index are
     * copied through) paired with the number of passes that reported a
     * change.  The second argument, args, is accepted but not used here.
     *
     * Raises Fail (via err) on ill-formed terms: a malformed Apply, an empty
     * product, or Value/Img/Krn terms, which are only legal after expansion.
     *)
    fun normalize (ee as Ein.EIN{params, index, body},args) = let
      val changed = ref false
      fun rewrite body =let

        (* prod2: rewrites the product e1 * e2 * es.  A constant-zero factor
         * collapses the whole product to zero in BOTH clauses; the
         * two-factor clause used to return the other operand instead,
         * i.e. it rewrote 0*x to x. *)
        fun prod2(e1, e2,[]) =
            (case (rewrite e1, rewrite e2)
            of (E.B(E.Const 0), _) => (changed:=true;zero)
            | (_, E.B(E.Const 0)) => (changed:=true;zero)
            | (e1', e2') => E.Opn(E.Prod,[e1',e2'])
            (*end case*))
        | prod2(e1, e2,es) = let
            (* fold the tail into one product so it is rewritten as a unit *)
            val e2= E.Opn(E.Prod,e2::es)
            in
            (case (rewrite e1, rewrite e2)
            of (E.B(E.Const 0), _) => (changed:=true;zero)
            | (_ , E.B(E.Const 0)) => (changed:=true;zero)
            (* tail is still a product: keep the result flat *)
            | (e1', E.Opn(E.Prod, ps')) => E.Opn(E.Prod, e1'::ps')
            | (e1', e2') =>(changed:=true; E.Opn(E.Prod,[e1',e2']))
            (*end case*))
            end

        in (case body
        of E.B _                                => body
        | E.Tensor _                            => body
        | E.G _                                 => body
                (************** Field Terms **************)
        | E.Field _                             => body
        | E.Lift e1                             => E.Lift(rewrite e1)
        | E.Conv _                              => body
        | E.Partial _                           => body
        | E.Probe(e1,e2)              =>
            let
            val (c',b')=mkprobe(rewrite e1,rewrite e2)
            in (case c'
            of 1=> (changed:=true;b')
            | _ => b'
            (*end case*))
            end
        (* An empty derivative is the identity.  The operand is now rewritten
         * and the change recorded, mirroring the E.Sum([],e1) rule below;
         * previously e1 was returned untouched and unflagged. *)
        | E.Apply(E.Partial [],e1)              => (changed:=true;rewrite e1)
        | E.Apply(E.Partial d1, e1)             =>
            let
            val e2 = rewrite e1
            val (c,e3)=mkapply(E.Partial d1,e2)
            val _= testp["\nafter apply:",P.printbody body,"-->",P.printbody e3]
            in
                (case c of 1=>(changed:=true;e3)| _ =>e3 (*end case*))
            end
        | E.Apply _                             => err " Not well-formed Apply expression"

            (************** Terms only legal after expansion **************)
        | E.Value _                             => err "Value before Expand"
        | E.Img _                               => err "Img before Expand"
        | E.Krn _                               => err "Krn before Expand"
            (************** Sum **************)

        | E.Sum([],e1)                           => (changed:=true;rewrite e1)
        | E.Sum(sx1,e1)                            => let
                val e2=rewrite e1
                val (c,e')=mkSum(sx1,e2)
                val _= testp["\nafter mksum:\n\t",P.printbody body,"\n\t-->",P.printbody e2,"\n\t-->",P.printbody e']
                in
                (case c of 0 => e'|_ => (changed:=true;e'))
                end
            (*************Algebraic Rewrites Op1 **************)
        | E.Op1(E.Neg,e1)                       => (case e1
            (* double negation cancels; flagged so the enclosing term is revisited *)
            of E.Op1(E.Neg,e2)                  => (changed:=true;rewrite e2)
            | E.B(E.Const 0)                    =>(changed:=true;zero)
            | _                                 => E.Op1(E.Neg,rewrite e1)
            (*end case*))
        | E.Op1(op1,e1)                         => E.Op1(op1,rewrite e1)
            (*************Algebraic Rewrites Op2 **************)
        | E.Op2(E.Sub,e1,e2)                        => (case (e1,e2)
            of (E.B(E.Const 0),_)                   => (changed:=true;setNeg(rewrite e2))
            | (_,E.B(E.Const 0))                     => (changed:=true;rewrite e1)
            | _                                 => setSub(rewrite e1, rewrite e2)
            (*end case*))
        (* NOTE(review): the three nested-division rules below restructure the
         * term without setting "changed"; left as-is because setDiv/setProd
         * are not visible here -- confirm whether they should flag a change. *)
        | E.Op2(E.Div,e1,e2)                        =>(case (e1,e2)
            of (E.B(E.Const 0),_)                    => (changed:=true;zero)
            (* (a/b)/(c/d) --> (a*d)/(b*c) *)
            |(E.Op2(E.Div,a,b), E.Op2(E.Div,c,d))   => rewrite(setDiv(setProd[a,d],setProd[b,c]))
            (* (a/b)/c --> a/(b*c) *)
            |(E.Op2(E.Div,a,b), c)   =>  rewrite (setDiv(a, setProd[b,c]))
            (* a/(b/c) --> (a*c)/b *)
            | (a,E.Op2(E.Div,b,c))                   => rewrite (setDiv(setProd[a,c],b))
            |  _                        => setDiv(rewrite e1, rewrite e2)
            (*end case*))
            (*************Algebraic Rewrites Opn **************)
(* Disabled experimental rule, kept for reference:
| E.Opn(E.Add,[E.Sum([(E.V 4,0,2)],E.Opn(E.Prod,[E.Tensor(0,[E.V 4,E.V 0]), E.Tensor(1,[E.V 4,E.V 1])])),E.Opn(E.Prod,[E.Tensor(2,[]),
E.Sum([(E.V 8,0,2)],E.Opn(E.Prod,[E.Tensor(0,[E.V 8,E.V 0]), E.Tensor(4,[E.V 8,E.V 1])]))])]) => let
val a = E.Tensor(0,[E.V 2,E.V 0])
val b = E.Tensor(1,[E.V 2,E.V 1])
val c = E.Tensor(4,[E.V 2,E.V 1])
val s = E.Tensor(2,[])
val add = E.Opn(E.Add,[b,E.Opn(E.Prod,[s,c])])
val prod= E.Sum([(E.V 2,0,2)],E.Opn(E.Prod,[a,add]))
val _ =print(String.concat["\nMatched"])
in prod end
*)
        | E.Opn(E.Add,es)          => let
            val (change,body')= mkAdd(List.map rewrite es)
            in if (change=1) then ( changed:=true;body') else body' end

        (*************Product**************)
        | E.Opn(E.Prod,[])                                 => err "missing elements in product"
        (* a one-element product is just its element; flagged as a change *)
        | E.Opn(E.Prod,[e1])                               => (changed:=true;rewrite e1)
        (* sqrt(s)*sqrt(s) --> s when both radicands are equal *)
        | E.Opn(E.Prod,[e1 as E.Op1(E.Sqrt,s1),e2 as E.Op1(E.Sqrt,s2)])=>
            if(Eq.isBodyEq(s1,s2)) then (changed :=true;s1)
            else (*let
                val a=rewrite (E.Op1(E.Sqrt,s1))
                val b=rewrite (E.Op1(E.Sqrt,s2))
                val  (_,d)=mkProd ([a,b])
                in d 
                end*)  prod2(e1, e2,[])
        (*************Product EPS **************)
        (* ps cannot be empty here: the one-element product was matched above *)
        | E.Opn(E.Prod,(E.G(E.Epsilon e1)::ps))=> let
            val (i,j,k)=e1
            val eps1=E.G(E.Epsilon e1)
            in (case ps
                (* epsilon next to a derivative: zero when matchEps reports 1 *)
                of ((p1 as E.Apply(E.Partial d,_))::es)=>
                    (case G.matchEps(0,d,[],[i,j,k])
                        of 1 => (changed:=true; zero)
                        | _ => prod2(eps1, p1, es)
                    (*end case*))
                (* same test for a convolution; the zero is lifted in this case *)
                | ((p1 as E.Conv(_,_,_,d))::es)=>
                    (case G.matchEps(0,d,[],[i,j,k])
                        of 1 => (changed:=true; E.Lift zero )
                        | _ => prod2(eps1, p1 ,es)
                    (*end case*))
                | _  => (case (G.epsToDels(eps1::ps))
                    of (1,e,[],_,_)       => (changed:=true;e)(* Changed to Deltas*)
                    | (1,e,sx,_,_)        => (changed:=true;E.Sum(sx,e))
                    | (_,_,_,_,[])        =>  body
                    | (_,_,_,epsAll,[r]) =>  E.Opn(E.Prod,epsAll@[rewrite r])
                    | (_,_,_,epsAll,rest) => (case (rewrite(E.Opn(E.Prod, rest)))
                        of E.Opn(E.Prod, ps')=> E.Opn(E.Prod, epsAll@ ps')
                        | t => (changed:=true; E.Opn(E.Prod,epsAll@[t])))
                    (*end case*))
                (*end case*))
            end
        (* product of two summations, each headed by an epsilon *)
        | E.Opn(E.Prod,E.Sum(c1,e1)::E.Sum(c2,e2)::es)=>(case (e1,e2)
            of (E.Opn(E.Prod,E.G(E.Epsilon ix1)::es1),E.Opn(E.Prod,E.G(E.Epsilon ix2)::es2)) =>
                (case G.epsToDels([E.G(E.Epsilon ix1), E.G(E.Epsilon ix2)]@es1@es2@es)
                    of (1,e,sx,_,_)=> (changed:=true; E.Sum(c1@c2@sx,e))
                    | _ =>
                        let
                            val eA= E.Sum(c1,setProd(E.G(E.Epsilon ix1)::es1))
                            val eB= E.Sum(c2,setProd(E.G(E.Epsilon ix2)::es2))
                        in prod2(eA, eB, es) end
                    (*end case*))
            | _  => prod2(E.Sum(c1,e1),E.Sum(c2,e2),es)
            (*end case*))
        (* es cannot be empty here: the one-element product was matched above *)
        | E.Opn(E.Prod,E.G(E.Delta d)::es) => let
            val (_,eps, dels,post)= filterGreek(E.G(E.Delta d)::es)
            val _= testp["\n\n Reduce delta--",P.printbody(body)]
            val (change,a)=G.reduceDelta(eps, dels, post)
            val _= testp["\n\n ---delta moved--",P.printbody(a)]
            in (case (change,a)
                of (0, _)=>  prod2(E.G(E.Delta d), List.hd(es) , List.tl(es))
                | (_, E.Opn(E.Prod, p))=>let
                    val (_, p') = mkProd p
                    in (changed:=true;p') end
                 | _ => raise Fail"impossible"
                (*end case*))
            end
        (* NOTE(review): the flag returned by mkProd is discarded in the two
         * rules below, as in the original -- confirm that is intended. *)
        | E.Opn(E.Prod,[e1,e2])=> let
            val (_,b)=mkProd[rewrite e1, rewrite e2]
            in b end
        | E.Opn(E.Prod,e1::es)=>let
            val e'=rewrite e1
            val e2=rewrite(setProd es)
            val(_,b)=(case e2
                (* The constructor must be qualified: a bare "Prod" is a
                 * variable pattern here and matched every n-ary operator,
                 * splicing e.g. the terms of an Add into the product. *)
                of E.Opn(E.Prod, p')=> mkProd(e'::p')
                |_=>mkProd [e',e2])
            in b end
        (*end case*))
        end


    val _=testp["\n******** Start Normalize: \n",P.printerE ee,"\n*****\n"]

    (* loop: runs rewrite until a pass leaves "changed" unset; count is the
     * number of passes that reported a change.  The flag is cleared before
     * each further pass. *)
    fun loop(body ,count) = let
        val body' = rewrite body
        in
            if !changed
            then  (changed := false ;loop(body',count+1))
            else (body',count)
        end

    val (b,count) = loop(body,0)
    val _ =testp["\n Out of normalize \n",P.printbody(b),
        "\n    Final CounterXX:",Int.toString(count),"\n\n"]
    in
        (Ein.EIN{params=params, index=index, body=b},count)
    end
end
                

end (* structure NormalizeEin *)

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0