Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] View of /branches/ein16/src/compiler/ein/move-sums.sml
ViewVC logotype

View of /branches/ein16/src/compiler/ein/move-sums.sml

Parent Directory Parent Directory | Revision Log Revision Log


Revision 4196 - (download) (annotate)
Wed Jul 13 03:36:42 2016 UTC (4 years ago) by cchiw
File size: 11736 byte(s)
added to synthetic testing
(* Functions that push summations down to the subexpressions that need them.
 * Normalize rewrites the body without rewriting the summation.
 *)

structure SummationEin = struct

    local

    structure E = Ein
    structure P=Printer
    structure F=Filter

    in

    (* testing: debug switch; 0 disables the trace output of testp. *)
    val testing=0

    (* testp: string list -> int
     * When testing is non-zero, prints the concatenation of the strings n
     * (previously the concatenated string was built and then discarded, so
     * nothing was ever printed). Always returns 1.
     *)
    fun testp n = (case testing
        of 0 => 1
         | _ => (print (String.concat n); 1)
        (* end case *))


    (* Local names for helpers from Filter (F), Printer (P) and Ein (E).
     * All Filter helpers go through the F alias. rewriteProdSum no longer
     * evaluates a leftover no-op debug string. *)
    fun filterSca e = F.filterSca e
    fun rewriteProd e = F.rewriteProd e
    fun rewriteSum e = F.rewriteSum e
    fun rewriteProdSum e = F.rewriteProdSum e

    (* findIndex: returns SOME x for the first x in searchspace equal to v,
     * NONE if there is none *)
    fun findIndex (v, searchspace) = List.find (fn x => x = v) searchspace

    (* bodyToStr: printable form of an EIN operator's body *)
    fun bodyToStr e = P.printbody (E.body e)

    (* the constant-zero ein expression *)
    val zero = E.B (E.Const 0)

    (* smart constructors from Ein *)
    fun setConst e = E.setConst e
    fun setNeg e = E.setNeg e
    fun setExp e = E.setExp e
    fun setDiv e = E.setDiv e
    fun setSub e = E.setSub e
    fun setProd e = E.setProd e
    fun setAdd e = E.setAdd e

    (* foundSx: index * ein_exp -> index option
     * Searches expression e for the index c (compared with =) in tensor,
     * field and partial shapes, kernel deltas/epsilons and convolution
     * indices. Returns SOME of the matching element at the first occurrence
     * found, searching subexpressions left to right, or NONE if c does not
     * occur. Raises Fail on Img/Krn, which are not expected before expansion.
     *)
    fun foundSx (c, e) = let
        fun inList ids = findIndex (c, ids)
        (* search a list of subexpressions left to right, stopping at the
         * first SOME so later subexpressions are not visited *)
        fun firstHit [] = NONE
          | firstHit (x :: xs) = (case foundSx (c, x)
               of NONE => firstHit xs
                | hit => hit
              (* end case *))
        in
          (case e
            of E.B _ => NONE
             | E.Tensor (_, alpha) => inList alpha
             | E.Field (_, alpha) => inList alpha
             | E.Partial alpha => inList alpha
             | E.G (E.Delta (i, j)) => inList [i, j]
             | E.G (E.Epsilon (i, j, k)) => inList [E.V i, E.V j, E.V k]
             | E.G (E.Eps2 (i, j)) => inList [E.V i, E.V j]
             | E.Conv (_, alpha, _, dx) => inList (alpha @ dx)
             | E.Lift e1 => foundSx (c, e1)
             | E.Sum (_, e1) => foundSx (c, e1)
             | E.Op1 (_, e1) => foundSx (c, e1)
             | E.Apply (e1, e2) => firstHit [e1, e2]
             | E.Probe (e1, e2) => firstHit [e1, e2]
             | E.Op2 (_, e1, e2) => firstHit [e1, e2]
             | E.Opn (_, es) => firstHit es
             | E.Value _ => NONE
             | E.Img _ => raise Fail"Img used pre expansion"
             | E.Krn _ => raise Fail"Krn used pre expansion"
          (* end case *))
        end


    (* splitSum: sumrange * ein_exp list -> ein_exp list * ein_exp list
     * c is an (index, lb, ub) triple; only the index is used.
     * Partitions the factor list p by whether each factor mentions that
     * index (checked with foundSx). Nested products E.Opn(E.Prod, _) are
     * flattened into the list before being partitioned.
     * Returns (factors that do NOT mention the index, factors that do),
     * each in the original left-to-right order.
     *)
    fun splitSum (c, p) = let
        val (v, _, _) = c
        (* accumulators are built in reverse and flipped once at the end,
         * instead of appending one element at a time (quadratic) *)
        fun walk (outside, inside, []) = (List.rev outside, List.rev inside)
          | walk (outside, inside, E.Opn (E.Prod, ps) :: rest) =
              walk (outside, inside, ps @ rest)
          | walk (outside, inside, e1 :: rest) = (case foundSx (v, e1)
               of NONE => walk (e1 :: outside, inside, rest)
                | SOME _ => walk (outside, e1 :: inside, rest)
              (* end case *))
        in
          walk ([], [], p)
        end

    (* splitMultipleSum: sumrange * sumrange list * ein_exp list * ein_exp list
     *                   -> ein_exp list * sumrange list * ein_exp list
     * Folds one more summation index c1 into a partially split product.
     * Following the original notes, for Sum([c1,c2], pre*post) the input is
     * read as  pre * Σ_c2 post  and the result triple (pre', outer, post')
     * as  pre' * Σ_outer post'. Each factor list is split by whether it
     * mentions c1; see the Tex file for the derivation.
     * NOTE(review): several branches drop an index that no factor mentions
     * (c2 when post is empty; c1 when neither list mentions it) instead of
     * keeping it - confirm that is intended.
     *)
    fun splitMultipleSum (c1, c2, pre, post) =
        (case (pre, post)
          of (front, []) => let
               (* front' * Σ_c1 (factors of front that mention c1) *)
               val (indep, dep) = splitSum (c1, front)
               in
                 (indep, [c1], dep)
               end
           | ([], back) =>
               (case splitSum (c1, back)
                 of ([], dep) => ([], [c1] @ c2, dep)       (* Σ_(c1,c2) dep *)
                  | (indep, []) => ([], c2, indep)           (* Σ_c2 indep *)
                  | (indep, dep) =>                           (* Σ_c2 (indep * Σ_c1 dep) *)
                      ([], c2, indep @ [rewriteSum ([c1], dep)])
               (* end case *))
           | (front, back) =>
               (case (splitSum (c1, front), splitSum (c1, back))
                 of ((fIndep, []), (bIndep, [])) => (fIndep, c2, bIndep)
                  | ((fIndep, fDep), (bIndep, [])) =>
                      (fIndep @ [rewriteSum ([c1], fDep)], c2, bIndep)
                  | ((fIndep, []), (bIndep, bDep)) =>
                      (fIndep, c2, bIndep @ [rewriteSum ([c1], bDep)])
                  | ((fIndep, fDep), _) =>
                      (fIndep, [c1], fDep @ [rewriteSum (c2, back)])
               (* end case *))
        (* end case *))

    (* shiftSum: sumrange list * ein_exp list -> ein_exp
     * Rewrites Σ_sx (product of the factors e), regrouping factors around
     * the summation indices that use them.
     * The last index of sx is split off first with splitSum. The remaining
     * indices, from last to first, are then folded in with splitMultipleSum.
     * The final (pre, outer, post) triple is rebuilt with rewriteProdSum.
     * Raises Empty when sx is empty.
     *)
    fun shiftSum (sx, e) =
        (case List.rev sx
          of [] => raise Empty
           | innermost :: others => let
               val (pre0, post0) = splitSum (innermost, e)
               fun step (c1, (pre, outer, post)) = splitMultipleSum (c1, outer, pre, post)
               val (pre, outer, post) = List.foldl step (pre0, [innermost], post0) others
               in
                 rewriteProdSum (pre, outer, post)
               end
        (* end case *))

    (* merge: ein_exp -> ein_exp
     * Collapses directly nested summations Σ_sx0 Σ_sx1 e into one
     * Σ_(sx0 @ sx1) e. It descends through summation bodies and through the
     * factors of products. Any other expression is returned unchanged and
     * its subexpressions are not visited.
     *)
    fun merge e = let
        fun merge2 (E.Sum (sx0, E.Sum (sx1, body))) =
              (* re-examine the merged sum, so chains of three or more nested
               * sums collapse completely rather than only the outer pair *)
              merge2 (E.Sum (sx0 @ sx1, body))
          | merge2 (E.Opn (E.Prod, es)) = E.Opn (E.Prod, List.map merge2 es)
          | merge2 (E.Sum (sx0, body)) = E.Sum (sx0, merge2 body)
          | merge2 other = other
        in
          merge2 e
        end

(* cleanSummation: EIN -> EIN
 * Rewrites the body of an EIN operator: each summation is regrouped with
 * shiftSum and nested sums are collapsed with merge. Other nodes are
 * traversed (except the first operand of a probe). params and index are
 * passed through unchanged. Raises Fail if Value/Img/Krn appear, since
 * those only exist after expansion.
 *)
fun cleanSummation (E.EIN{params, index, body}) = let
    fun rewrite e = (case e
        of E.Sum (sx, E.Opn (E.Prod, factors)) => let
            (* Only products of at most two factors have their factors
             * rewritten first; longer products are shifted as they are.
             * NOTE(review): this asymmetry, and the plain E.Sum case below
             * (whose body is not rewritten), look unintentional - confirm. *)
            val factors' = if List.length factors <= 2
                  then List.map rewrite factors
                  else factors
            in
              merge (shiftSum (sx, factors'))
            end
         | E.Sum (sx, e1) => merge (shiftSum (sx, [e1]))
         | E.Lift e1 => E.Lift (rewrite e1)
         | E.Apply (e1, e2) => E.Apply (rewrite e1, rewrite e2)
         (* only the second operand of a probe is rewritten *)
         | E.Probe (e1, e2) => E.Probe (e1, rewrite e2)
         | E.Op1 (op1, e1) => E.Op1 (op1, rewrite e1)
         | E.Op2 (op2, e1, e2) => E.Op2 (op2, rewrite e1, rewrite e2)
         | E.Opn (opn, es) => E.Opn (opn, List.map rewrite es)
         | E.Value _ => raise Fail"Value before Expand"
         | E.Img _ => raise Fail"Img before Expand"
         | E.Krn _ => raise Fail"Krn before Expand"
         (* leaves for this pass *)
         | E.B _ => e
         | E.Tensor _ => e
         | E.G _ => e
         | E.Field _ => e
         | E.Conv _ => e
         | E.Partial _ => e
        (* end case *))
    in
      E.EIN{params = params, index = index, body = rewrite body}
    end

    (* distributeSummation: EIN -> EIN
     * Distributes summations through the body. rewrite is applied repeatedly
     * until a pass finishes with the `changed` flag unset (a fixed point).
     * params and index are passed through unchanged.
     * A pass:
     *  - moves Σ inside Lift, Op1, Op2 and non-product Opn nodes;
     *  - merges directly nested sums;
     *  - lets filterSca simplify summed products.
     * NOTE(review): nothing bounds the number of passes; termination relies
     * on these rewrites (and on the Filter/Ein helpers) reaching a fixed point.
     *)
    fun distributeSummation(Ein.EIN{params=params, index=index, body=body})=let
        val changed = ref false
        (* NOTE(review): not referenced anywhere in this function *)
        fun constant()=raise Fail "sum of constant"
        (* rewrite: one pass over the expression; sets `changed` whenever it
         * performs a rewrite *)
        fun rewrite b=(case b
            of (E.B _)                          => b
            | E.Tensor _                        => b
            | E.G _                             => b
            | E.Field _                         => b
            | E.Lift _                          => b
            | E.Conv _                          => b
            | E.Partial _                       => b
            | E.Apply _                         => b
            | E.Probe (e1,e2)                   => E.Probe(rewrite e1, rewrite e2)
            | E.Value _                         => b
            | E.Img  _                          => b
            | E.Krn _                           => b
            (* Σ moved inside the lift *)
            | E.Sum(sx,E.Lift e1)               => (changed:=true;(E.Lift(E.Sum(sx,e1))))
            (* directly nested sums are merged into one *)
            | E.Sum(sx1,E.Sum (sx2,e1))         => (changed:=true; E.Sum (sx1@sx2,e1))
            (* NOTE(review): Σ is moved inside ANY unary op. That is only an
             * identity for ops linear in their argument (e.g. negation);
             * confirm which E.Op1 operators can reach this case *)
            | E.Sum(sx, E.Op1(op1, e1))         => (changed:=true; E.Op1(op1,E.Sum(sx, e1)))
            (* Σ_sx (e1 / e2):
             *  - constant numerator (E.B): rewritten to e1 / Σ_sx e2.
             *    NOTE(review): Σ_sx (c/e2) = c / Σ_sx e2 is not an identity in
             *    general; it may exist to stop the fixed-point loop from
             *    cycling - confirm before changing it.
             *  - otherwise: rewritten to Σ_sx (e1 * (1 / e2)) *)
            | E.Sum(sx, E.Op2(E.Div, e1,e2))    =>
                let
                val e=(case e1
                    of (E.B _)=> (changed:=true; setDiv(e1, E.Sum(sx,e2)))
                    | _      => (changed:=true; E.Sum(sx,setProd [e1,setDiv(setConst 1,rewrite e2)]))
                    (* end case *))
                in (changed:=true;e)
                end
            (* Σ distributed over both operands of any other binary op.
             * NOTE(review): only sound for ops linear in both operands
             * (e.g. subtraction) *)
            | E.Sum(sx, E.Op2(op2, e1,e2))      => (changed:=true; E.Op2(op2, E.Sum(sx, e1),E.Sum(sx,e2)))
            (* summed product: rewrite each factor, then simplify with
             * filterSca. A first result of 1 is treated as "changed".
             * Presumably filterSca pulls out factors independent of sx;
             * confirm in Filter *)
            | E.Sum(sx, E.Opn(E.Prod, es))      =>
                let
                val p'=List.map (fn e=> rewrite e) es
                val (c,e)=filterSca(sx,p')
                in (case c of 1=> (changed:=true; e) | _=> e ) end
            (* Σ pushed into each operand of a non-product n-ary op *)
            | E.Sum(sx, E.Opn(opn, es))         => (changed:=true; E.Opn(opn, List.map (fn e1=> E.Sum(sx,e1)) es))
            (* any other summation is left as is; its body is not visited *)
            | E.Sum (sx,_)                     => b
            | E.Op1(op1, e1)                    => E.Op1(op1, rewrite e1)
            (* e1 - 0 simplifies to e1 *)
            | E.Op2(E.Sub, e1,E.B(E.Const 0))   =>(changed:=true; rewrite e1)
            | E.Op2(op2, e1,e2)                 => E.Op2(op2, rewrite e1, rewrite e2)
            | E.Opn(E.Prod, es)                 =>
                (case es
                    (* flatten a product nested as the second of exactly two factors *)
                    of [e1,E.Opn(E.Prod, es)]          => (changed:=true;setProd(e1::es))
                    (* special case: T0 * T1[i1] * Σ_v(T2[ix2] * T3[ix3]) is regrouped
                     * as (T0 * Σ_v(T2[ix2] * T3[ix3])) * T1[i1]; the reason for
                     * this specific pattern is not documented here *)
                    | [E.Tensor(id0,[]), E.Tensor(id1,[i1]), E.Sum([v],E.Opn(E.Prod,[E.Tensor(id2,[ix2]),E.Tensor(id3,[ix3])]))]
                        => let
                        val e1=E.Sum([v],setProd[E.Tensor(id2,[ix2]),E.Tensor(id3,[ix3])])
                        val e2=setProd[E.Tensor(id0,[]),e1]
                        val e3=setProd[e2, E.Tensor(id1,[i1])]
                        in (changed:=true; e3) end 
                    | _                         => E.Opn(E.Prod,List.map (fn e=> rewrite e) es)
                 (* end case *))
            | E.Opn(opn, es)        => E.Opn(opn, List.map (fn e=> rewrite e) es)
        (*end case*))
        (* loop: rerun rewrite until a pass makes no change, clearing the
         * flag before each further pass *)
        fun loop body  = let
            val body' = rewrite body
            in
                if !changed then  (changed := false ;loop body') else  body'
            end

        val  b = loop body
        in
           Ein.EIN{params=params, index=index, body=b}
        end



    (* main: EIN -> EIN
     * Runs two rounds of "distribute summations, then clean them up":
     * distributeSummation followed by cleanSummation, applied twice.
     *)
    fun main ein1 = let
        fun onePass ein = cleanSummation (distributeSummation ein)
        in
          onePass (onePass ein1)
        end


end

end (* structure SummationEin *)

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0