Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] View of /branches/charisee/src/compiler/high-il/move-sums.sml
ViewVC logotype

View of /branches/charisee/src/compiler/high-il/move-sums.sml

Parent Directory Parent Directory | Revision Log Revision Log


Revision 2844 - (download) (annotate)
Tue Dec 9 18:05:29 2014 UTC (4 years, 11 months ago) by cchiw
File size: 5652 byte(s)
code cleanup
(* Functions that push summations down to the expressions that need them.
* Normalize rewrites the body without rewriting the summation.
*)

structure SummationEin = struct

    local

    structure E = Ein
    structure P=Printer
    structure F=Filter

    in

    (* Local aliases for the rewriting helpers exported by Filter. *)
    val rewriteProd = F.rewriteProd
    val rewriteSum = F.rewriteSum
    val rewriteProdSum = F.rewriteProdSum
    (* findIndex : ''a * ''a list -> ''a option
     * Returns SOME of the first element of searchspace that equals v,
     * or NONE when no element does.
     *)
    fun findIndex (v, searchspace) = let
        fun sameAsV candidate = (candidate = v)
        in
          List.find sameAsV searchspace
        end

    (* foundSx : index * ein_exp -> index option
     * Looks for the index c among the indices used by the expression e.
     * c has the same type as the elements of a tensor shape (for Epsilon
     * and Eps2 the integer indices are wrapped with E.V before comparing).
     * Returns SOME of the first matching occurrence, NONE when c does not
     * occur.  Compound expressions are searched left to right.
     * For E.Sum only the summed expression is searched; the summation's
     * own index list is not inspected.
     * Raises Fail on E.Krn and E.Img.
     *)
    fun foundSx (c, e) = let
        (* search a list of sub-expressions, stopping at the first hit *)
        fun firstHit [] = NONE
          | firstHit (x :: xs) = (case foundSx (c, x)
               of SOME hit => SOME hit
                | NONE => firstHit xs
              (* end case *))
        (* direct lookup of c in a list of indices; an empty list gives NONE *)
        fun among ixs = findIndex (c, ixs)
        in
          case e
           of E.Krn _ => raise Fail"Krn used pre expansion"
            | E.Img _ => raise Fail"Img used pre expansion"
            (* leaves that carry no indices *)
            | E.Value _ => NONE
            | E.Const _ => NONE
            (* leaves that carry indices *)
            | E.Tensor (_, shape) => among shape
            | E.Field (_, shape) => among shape
            | E.Conv (_, alpha, _, dx) => among (alpha @ dx)
            | E.Delta (i, j) => among [i, j]
            | E.Epsilon (i, j, k) => among [E.V i, E.V j, E.V k]
            | E.Eps2 (i, j) => among [E.V i, E.V j]
            | E.Partial shape => among shape
            (* unary wrappers *)
            | E.Neg a => foundSx (c, a)
            | E.Lift a => foundSx (c, a)
            | E.Sum (_, a) => foundSx (c, a)
            (* binary and n-ary nodes *)
            | E.Probe (e1, e2) => firstHit [e1, e2]
            | E.Apply (e1, e2) => firstHit [e1, e2]
            | E.Sub (e1, e2) => firstHit [e1, e2]
            | E.Div (e1, e2) => firstHit [e1, e2]
            | E.Prod es => firstHit es
            | E.Add es => firstHit es
          (* end case *)
        end


    (* splitSum : sum_index * ein_exp list -> ein_exp list * ein_exp list
     * c is a triple whose first component is the index to look for; the
     * other two components are not used here.
     * Partitions the factors in p into
     *   (factors that do not mention the index, factors that do),
     * each in their original left-to-right order.  A nested E.Prod is
     * flattened: its factors are classified individually in place of the
     * product itself.
     * Propagates any Fail raised by foundSx.
     *)
    fun splitSum (c, p) = let
        val (v, _, _) = c
        (* Accumulators are built in reverse with cons and reversed once at
         * the end, which avoids the quadratic cost of repeated `acc @ [x]`.
         *)
        fun partition ([], without, using) = (List.rev without, List.rev using)
          | partition (E.Prod inner :: rest, without, using) =
              partition (inner @ rest, without, using)
          | partition (e1 :: rest, without, using) = (case foundSx (v, e1)
               of NONE => partition (rest, e1 :: without, using)
                | SOME _ => partition (rest, without, e1 :: using)
              (* end case *))
        in
          partition (p, [], [])
        end

    (* splitMultipleSum :
     *     sum_index * sum_index list * ein_exp list * ein_exp list
     *       -> ein_exp list * sum_index list * ein_exp list
     * Given a new summation index c1, the indices c2 of the current outer
     * summation, and the factor lists pre and post, returns a new
     * (pre, outer indices, post) triple.
     * Below, "free" factors are those in which splitSum does not find c1
     * and "bound" factors are those in which it does.
     * Check the Tex file for the derivation.
     *)
    fun splitMultipleSum (c1, c2, pre, post) = let
        (* wrap a list of factors in a summation over c1 alone *)
        fun sumOverC1 factors = rewriteSum ([c1], factors)
        in
          case (pre, post)
           of (_, []) => let
                (* nothing under the outer sum: result is free * Sum_c1(bound);
                 * the indices in c2 do not appear in the result *)
                val (free, bound) = splitSum (c1, pre)
                in
                  (free, [c1], bound)
                end
            | ([], _) => (case splitSum (c1, post)
                 (* every factor uses c1:  Sum_(c1,c2) bound *)
                 of ([], bound) => ([], c1 :: c2, bound)
                 (* no factor uses c1:  Sum_c2 free  (c1 is not in the result) *)
                  | (free, []) => ([], c2, free)
                 (* mixed:  Sum_c2 (free * Sum_c1 bound) *)
                  | (free, bound) => ([], c2, free @ [sumOverC1 bound])
                (* end case *))
            | _ => (case (splitSum (c1, pre), splitSum (c1, post))
                 (* c1 is used on neither side *)
                 of ((preFree, []), (postFree, [])) => (preFree, c2, postFree)
                 (* c1 is used only in pre: sum it there *)
                  | ((preFree, preBound), (postFree, [])) =>
                      (preFree @ [sumOverC1 preBound], c2, postFree)
                 (* c1 is used only in post: sum it there *)
                  | ((preFree, []), (postFree, postBound)) =>
                      (preFree, c2, postFree @ [sumOverC1 postBound])
                 (* c1 is used on both sides: c1 becomes the outer sum and
                  * the whole of post is wrapped in a summation over c2 *)
                  | ((preFree, preBound), _) =>
                      (preFree, [c1], preBound @ [rewriteSum (c2, post)])
                (* end case *))
          (* end case *)
        end

    (* shiftSum : sum_index list * ein_exp list -> ein_exp
     * Rewrites an embedded summation.  The indices in sx are processed from
     * last to first: the last index seeds the split of the factors e via
     * splitSum, each remaining index is folded in with splitMultipleSum,
     * and the final (pre, outer, post) triple is handed to rewriteProdSum.
     * Raises Empty when sx is the empty list.
     *)
    fun shiftSum (sx, e) = (case List.rev sx
           of [] => raise List.Empty
            | lastIx :: earlier => let
                val (without, using) = splitSum (lastIx, e)
                (* fold the remaining indices into the running triple *)
                fun absorb ([], outer, pre, post) = rewriteProdSum (pre, outer, post)
                  | absorb (ix :: ixs, outer, pre, post) = let
                      val (pre', outer', post') = splitMultipleSum (ix, outer, pre, post)
                      in
                        absorb (ixs, outer', pre', post')
                      end
                in
                  absorb (earlier, [lastIx], without, using)
                end
          (* end case *))

(* cleanSummation : EIN -> EIN
 * Rebuilds the EIN operator with a rewritten body; params and index are
 * passed through unchanged.
 * The rewrite walks the body and hands every E.Sum node to shiftSum.  The
 * expression under a E.Sum is not itself walked first.  For E.Probe only
 * the second component is walked.
 * Raises Fail on E.Krn, E.Img and E.Value.
 *)
fun cleanSummation (Ein.EIN{params, index, body}) = let
    fun rewriteBody exp = (case exp
        (* forms that must not appear at this stage *)
        of E.Krn _ => raise Fail"Krn before Expand"
         | E.Img _ => raise Fail"Img before Expand"
         | E.Value _ => raise Fail"Value before Expand"
        (* summations: pass the list of factors to shiftSum; a body that is
         * not a product is passed as a one-element list *)
         | E.Sum (sx, E.Prod factors) => shiftSum (sx, factors)
         | E.Sum (sx, single) => shiftSum (sx, [single])
        (* structural recursion *)
         | E.Neg a => E.Neg (rewriteBody a)
         | E.Lift a => E.Lift (rewriteBody a)
         | E.Add terms => E.Add (List.map rewriteBody terms)
         | E.Prod factors => E.Prod (List.map rewriteBody factors)
         | E.Sub (a, b) => E.Sub (rewriteBody a, rewriteBody b)
         | E.Div (a, b) => E.Div (rewriteBody a, rewriteBody b)
         | E.Apply (a, b) => E.Apply (rewriteBody a, rewriteBody b)
         | E.Probe (fld, pos) => E.Probe (fld, rewriteBody pos)
        (* leaves are returned as they are *)
         | E.Const _ => exp
         | E.Tensor _ => exp
         | E.Field _ => exp
         | E.Delta _ => exp
         | E.Epsilon _ => exp
         | E.Eps2 _ => exp
         | E.Conv _ => exp
         | E.Partial _ => exp
        (* end case *))
    in
      Ein.EIN{params = params, index = index, body = rewriteBody body}
    end
end

end (* local *)

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0