Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] View of /branches/charisee/src/compiler/high-il/move-sums.sml
ViewVC logotype

View of /branches/charisee/src/compiler/high-il/move-sums.sml

Parent Directory Parent Directory | Revision Log Revision Log


Revision 2608 - (download) (annotate)
Fri May 2 18:04:54 2014 UTC (5 years, 3 months ago) by cchiw
File size: 4619 byte(s)
Quotient rule
(* Functions that push summations down to the subexpressions that need them *)

structure SummationEin = struct

    local

    structure E = Ein
    structure P=Printer
    structure F=Filter

    in

    (* rewriteProd factors
     * Builds a product from a list of factors.  A singleton list is returned
     * unwrapped; any other list (including the empty list) becomes E.Prod.
     *)
    fun rewriteProd [factor] = factor
      | rewriteProd factors = E.Prod factors

    (* rewriteSum (indices, factors)
     * Wraps the product of `factors` (built with rewriteProd) in a summation
     * over `indices`.
     *)
    fun rewriteSum (indices, factors) = E.Sum(indices, rewriteProd factors)

    (* embed (A, c1, B, c2, C)
     * Helper for splitMultipleSum: nests the factors C in their own
     * summation over c2 and appends that sum to B.  A and c1 are passed
     * through unchanged, giving the triple (A, c1, B @ [Sum(c2, C)]).
     *)
    fun embed (A, c1, B, c2, C) = (A, c1, B @ [rewriteSum(c2, C)])

    (* splitMultipleSum (c1, c2, pre, post)
     * One step of shiftSum: places the next (outer) summation index c1 onto
     * the partially rebuilt expression  pre * Sum(c2, post).
     *   c1   - summation index being placed
     *   c2   - index list of the summation built so far
     *   pre  - factors currently outside that summation
     *   post - factors currently inside that summation
     * Returns (pre', outer', post'), read as  pre' * Sum(outer', post').
     * F.splitSum(i, fs) is used here as (factors kept outside a sum over i,
     * factors kept inside it); presumably it partitions fs by whether a
     * factor mentions i -- confirm in Filter.
     * NOTE(review): the branches marked "drops" discard an index when no
     * factor ends up under it.  That is only sound if every summation index
     * occurs in some factor; confirm that invariant holds for all callers.
     *)
    fun splitMultipleSum (c1, c2, pre, post) = (case (pre, post)
        of (A, []) => let
            (* nothing is left under c2, so c1 starts a fresh outer sum
             * (drops c2) *)
            val (outside, inside) = F.splitSum(c1, A)
            in (outside, [c1], inside)
            end
        | ([], B) => (case F.splitSum(c1, B)
            of ([], D) => ([], [c1] @ c2, D)          (* merge c1 into the sum *)
            | (C, []) => ([], c2, C)                  (* drops c1 *)
            | (C, D) => embed([], c2, C, [c1], D)     (* Sum(c2, C * Sum(c1, D)) *)
            (* end case *))
        | (A, B) => (case (F.splitSum(c1, A), F.splitSum(c1, B))
            of ((preOut, []), (postOut, [])) =>
                (* drops c1 *)
                (preOut, c2, postOut)
            | ((preOut, preIn), (postOut, [])) =>
                (* only outside factors need c1: sum them locally *)
                (preOut @ [rewriteSum([c1], preIn)], c2, postOut)
            | ((preOut, []), (postOut, postIn)) =>
                (* only inside factors need c1: nest a sum over c1 under c2 *)
                embed(preOut, c2, postOut, [c1], postIn)
            | ((preOut, preIn), _) =>
                (* both sides need c1: c1 becomes the outer sum and the old
                 * Sum(c2, post) is nested inside it *)
                embed(preOut, [c1], preIn, c2, B)
            (* end case *))
        (* end case *))

    (* shiftSum (sx, e)
     * Rewrites Sum(sx, Prod e) by pushing each summation index down to the
     * factors that F.splitSum places under it.  Indices are placed from the
     * last one in sx outward to the first, one splitMultipleSum step each.
     * An empty index list yields just the product of e.
     *)
    fun shiftSum (sx, e) = let
        (* place the remaining indices, then rebuild  pre * Sum(outer, post) *)
        fun place ([], outer, pre, post) = rewriteProd(pre @ [rewriteSum(outer, post)])
          | place (c1 :: cs, outer, pre, post) = let
                val (pre', outer', post') = splitMultipleSum(c1, outer, pre, post)
                in place(cs, outer', pre', post')
                end
        in (case List.rev sx
            of [] => rewriteProd e
            | innermost :: rest => let
                val (outside, inside) = F.splitSum(innermost, e)
                in place(rest, [innermost], outside, inside)
                end
            (* end case *))
        end

    (* cleanSummation ein
     * Moves summation indices around: rewrites the body of an EIN operator,
     * applying shiftSum to every Sum whose body is a product and recursing
     * through the other compound forms.  Krn, Img and Value terms raise Fail.
     * params and index are returned unchanged.
     * NOTE(review): the factors of a Sum(_, Prod _) and the body of any other
     * Sum are not themselves rewritten, and the first argument of Probe is
     * left as-is; confirm this single-level behaviour is intended.
     *)
    fun cleanSummation (E.EIN{params, index, body}) = let
        fun rewriteBody body = (case body
            of E.Const _ => body
            | E.Tensor _ => body
            | E.Field _ => body
            | E.Delta _ => body
            | E.Epsilon _ => body
            | E.Conv _ => body
            | E.Partial _ => body
            | E.Krn _ => raise Fail"Krn before Expand"
            | E.Img _ => raise Fail"Img before Expand"
            | E.Value _ => raise Fail"Value before Expand"
            | E.Neg e => E.Neg(rewriteBody e)
            | E.Add es => E.Add(List.map rewriteBody es)
            | E.Sub(e1, e2) => E.Sub(rewriteBody e1, rewriteBody e2)
            | E.Prod es => E.Prod(List.map rewriteBody es)
            | E.Div(e1, e2) => E.Div(rewriteBody e1, rewriteBody e2)
            | E.Apply(e1, e2) => E.Apply(rewriteBody e1, rewriteBody e2)
            | E.Probe(e1, e2) => E.Probe(e1, rewriteBody e2)
            | E.Lift e => E.Lift(rewriteBody e)
            | E.Sum(sx, E.Prod e) => shiftSum(sx, e)
            | E.Sum _ => body
            (* end case *))
        in
            E.EIN{params=params, index=index, body=rewriteBody body}
        end

    (* NOTE(review): the scratch test harness below is stale -- it calls
     * cleanSummation2, which is not defined in this file.  Kept for
     * reference only.
     *)
(*
  fun tester e=print( String.concat["tester \n",P.printerE(e),"===>",P.printerE(cleanSummation2 e)])
val v0=E.V 0
val v1=E.V 1
val v2=E.V 2
val vv0=(v0,0,0)
val vv1=(v1,0,0)
val vv2=(v2,0,0)

val t0=E.Tensor(0,[v0])
val t1=E.Tensor(0,[v1])
val t2=E.Tensor(0,[v2])

val t01=E.Tensor(0,[v0,v1])
val t12=E.Tensor(0,[v0,v1,v2])

val A= E.EIN{params = [], index = [],
body = E.Sum([vv0,vv1],E.Prod[t0])}

val B= E.EIN{params = [], index = [],
body = E.Sum([vv0,vv1],E.Prod[t1])}

val C= E.EIN{params = [], index = [],
body = E.Sum([vv0,vv1],E.Prod[t0,t1,t2])}

val D= E.EIN{params = [], index = [],
body = E.Sum([vv0,vv1],E.Prod[t0,t01])}

val E= E.EIN{params = [], index = [],
body = E.Sum([vv0,vv1],E.Prod[t1,t01])}

val F= E.EIN{params = [], index = [],
body = E.Sum([vv0,vv1],E.Prod[t0,t1,t2,t01])}

val G= E.EIN{params = [], index = [],
body = E.Sum([vv0,vv1,vv2],E.Prod[t0,t01,t2])}

val H= E.EIN{params = [], index = [],
body = E.Sum([vv0,vv1,vv2],E.Prod[t1,t01,t2])}

val I= E.EIN{params = [], index = [],
body = E.Sum([vv0,vv1,vv2],E.Prod[t0,t1,t2,t01])}


val J= E.EIN{params = [], index = [],
body = E.Sum([vv0,vv1,vv2],E.Prod[t0,t01,t2,t12])}

val K= E.EIN{params = [], index = [],
body = E.Sum([vv0,vv1,vv2],E.Prod[t1,t01,t2,t12])}

val L= E.EIN{params = [], index = [],
body = E.Sum([vv0,vv1,vv2],E.Prod[t0,t1,t2,t01,t12])}



val M= E.EIN{params = [], index = [],
body = E.Sum([vv0,vv1,vv2],E.Prod[t0,t2,t12])}

val N= E.EIN{params = [], index = [],
body = E.Sum([vv0,vv1,vv2],E.Prod[t1,t2,t12])}

val O= E.EIN{params = [], index = [],
body = E.Sum([vv0,vv1,vv2],E.Prod[t0,t1,t2,t12])}




fun Y _=List.map tester [A,B,C,D,E,F,G,H,I,J,K,L,M,N,O ]

fun cleanSummation e=(print "pre";Y 1;cleanSummation2 e)
*)
    end (* local *)



end (* structure SummationEin *)

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0