Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] View of /branches/ein16/src/compiler/ein/getShape.sml
ViewVC logotype

View of /branches/ein16/src/compiler/ein/getShape.sml

Parent Directory Parent Directory | Revision Log Revision Log


Revision 3687 - (download) (annotate)
Sun Feb 28 03:46:43 2016 UTC (4 years, 5 months ago) by cchiw
File size: 5827 byte(s)
added flags
(*
aShape (mu list): all indices mentioned anywhere in the EIN body
eShape (mu list): candidate shape of the tensor replacement
tShape (mu list): actual shape of the tensor replacement
*)
structure GetShape= struct

    local

    structure E = Ein
    structure P=Printer

    in

    (* Debug switch: 0 disables diagnostic printing, any other value enables it. *)
    val testing=0

    fun iTos i = Int.toString i

    (* Abort shape analysis by raising Fail with the given message. *)
    fun err str = raise Fail str

    (* Flatten a list of lists, preserving element order. *)
    fun flat xs = List.concat xs

    (* Print the concatenation of strs when testing is enabled.
     * Always returns 1 so it can be used in a val binding. *)
    fun testp strs = (case testing
        of 0 => 1
        | _ => (print (String.concat strs); 1)
        (*end case*))

    (* aShape: ein_exp -> mu list
     * All indices used anywhere in the expression, in traversal order and with
     * duplicates kept. For Sum(sx, e1) the indices bound by sx are included
     * even when they do not occur in e1.
     * Raises Fail on Value, Img and Krn nodes.
     *)
    fun aShape b = let
       (* val _ =testp["\n ashape",P.printbody b]*)
        fun iterList es = flat (List.map aShape es)
        fun iterSx sx = List.map (fn (v,_,_) => v) sx
        in (case b
            of E.B _                    => []
            | E.Tensor(_,alpha)         => alpha
            | E.G(E.Delta(i,j))         => [i,j]
            | E.G(E.Epsilon(i,j,k))     => [E.V i, E.V j, E.V k]
            | E.G(E.Eps2(i,j))          => [E.V i, E.V j]
            | E.Field(_,alpha)          => alpha
            | E.Lift e1                 => aShape e1
            | E.Conv(_,alpha,_,dx)      => alpha@dx
            | E.Partial alpha           => alpha
            | E.Apply(E.Partial dx,e1)  => (aShape e1)@dx
            | E.Probe(e1,_)             => aShape e1
            | E.Value _                 => err "Error in Ashape"
            | E.Img _                   => err "Error in Ashape"
            | E.Krn _                   => err "Error in Ashape"
            | E.Sum(sx,e1)              => (iterSx sx)@(aShape e1)
            | E.Op1(_,e1)               => aShape e1
            | E.Op2(_,e1,e2)            => iterList [e1,e2]
            | E.Opn(_,es)               => iterList es
            (*end case*))
        end

    (* eShape: ein_exp -> mu list
     * Candidate shape (rho) of a tensor that could replace the expression:
     *   A_alpha             -> alpha
     *   Sum(sx, e1)         -> eShape e1 (summation indices are not added)
     *   Op2/Opn(e1, e2 ...) -> eShape e1, followed by each index of the later
     *                          operands that is not already present; constant
     *                          indices of the later operands are skipped.
     * An Opn with no operands has the empty shape.
     * Raises Fail on Value, Img and Krn nodes.
     *)
    fun eShape b = let
       (* val _ =testp["\n eshape",P.printbody b]*)
        (* Append to rest every non-constant index of es that rest lacks. *)
        fun merge ([], rest) = rest
          | merge (E.C _::es, rest) = merge (es, rest)
          | merge (e1::es, rest) =
                if List.exists (fn x => x = e1) rest
                then merge (es, rest)
                else merge (es, rest@[e1])
        (* The first operand's shape is kept verbatim; the others are merged in.
         * Matching on the list avoids List.hd raising Empty for no operands. *)
        fun iterList operands = (case List.map eShape operands
            of [] => []
            | first::others => foldl merge first others
            (*end case*))
        in (case b
            of E.B _                    => []
            | E.Tensor(_,alpha)         => alpha
            | E.G(E.Delta(i,j))         => [i,j]
            | E.G(E.Epsilon(i,j,k))     => [E.V i, E.V j, E.V k]
            | E.G(E.Eps2(i,j))          => [E.V i, E.V j]
            | E.Field(_,alpha)          => alpha
            | E.Lift e1                 => eShape e1
            | E.Conv(_,alpha,_,dx)      => alpha@dx
            | E.Partial alpha           => alpha
            | E.Apply(E.Partial dx,e1)  => (eShape e1)@dx
            | E.Probe(e1,_)             => eShape e1
            | E.Value _                 => err "Error in Eshape"
            | E.Img _                   => err "Error in Eshape"
            | E.Krn _                   => err "Error in Eshape"
            | E.Sum(_,e1)               => eShape e1
            | E.Op1(_,e1)               => eShape e1
            | E.Op2(_,e1,e2)            => iterList [e1,e2]
            | E.Opn(_,es)               => iterList es
            (*end case*))
        end

    (* tShape: index * sx * ein_exp * mu list -> mu list
     * Shape of the tensor replacement, computed from eshape.
     * outerAlpha holds the indices supported by the original EIN:
     * V 0 .. V (length index - 1) for its index list, followed by the
     * summation indices in sx. Each index of eshape that is in outerAlpha is
     * kept in order; any other index is supported by the subexpression alone
     * and is dropped. Constant indices are always dropped.
     * The expression argument e is currently unused.
     *)
    fun tShape (index, sx, e, eshape) = let
        val outerAlpha =
            List.tabulate (length index, fn i => E.V i)
            @ List.map (fn (v,_,_) => v) sx
        (* When true, repeated indices are collapsed to their first occurrence. *)
        val removedup = false
        fun isOuter ix = List.exists (fn x => x = ix) outerAlpha
        fun getT ([], rest) = rest
          | getT ((E.C _)::es, rest) = getT (es, rest)
          | getT ((e1 as E.V _)::es, rest) =
                if not (isOuter e1) then getT (es, rest)
                else if removedup andalso List.exists (fn x => x = e1) rest
                then getT (es, rest)
                else getT (es, rest@[e1])
        in
            getT (eshape, [])
        end

    (* tester: mu list * mu list * mu list -> int
     * Debug hook for getShapes. When testing is enabled, prints the aShape,
     * eShape and tShape lists, showing constant indices as "const<n>". The
     * strings are only built in that case. Always returns 1.
     *)
    fun tester (ashape, eshape, tshape) = let
        fun showMu (E.V v) = iTos v
          | showMu (E.C (c,_)) = "const" ^ iTos c
        fun getShape (str, es) = String.concat [str, " : ",
            String.concatWith "," (List.map showMu es)]
        in
            if testing = 0 then 1
            else testp [getShape ("\naShape", ashape),
                        getShape ("\neShape", eshape),
                        getShape ("\ntShape", tshape)]
        end

    (* getShapes: ein_exp * index * sx -> mu list * mu list
     * Returns (aShape e, tShape) for body e of an EIN with the given index and
     * summation list sx. eShape e is only an intermediate used to derive tShape.
     *)
    fun getShapes (e, index, sx) = let
        val ashape = aShape e
        val eshape = eShape e
        val tshape = tShape (index, sx, e, eshape)
        val _ = tester (ashape, eshape, tshape)
        in
            (ashape, tshape)
        end

    end (* local *)

end (* structure GetShape *)

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0