Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] View of /branches/charisee/src/compiler/ein/getShape.sml
ViewVC logotype

View of /branches/charisee/src/compiler/ein/getShape.sml

Parent Directory | Revision Log


Revision 2870 - (download) (annotate)
Wed Feb 25 21:47:43 2015 UTC (4 years, 7 months ago) by cchiw
File size: 6043 byte(s)
added sqrt,pow, and examples
(*
aShape (mu list): all indices mentioned in the EIN body
eShape (mu list): possible shape of the tensor replacement
tShape (mu list): actual shape of the tensor replacement
*)
structure getShape= struct

    local

    structure E = Ein
    structure P = Printer

    in

    (* Debug switch: 0 silences testp; any other value makes testp print. *)
    val testing = 0

    fun iTos i = Int.toString i
    fun err str = raise Fail str

    (* Concatenate a list of lists (equivalent to List.concat). *)
    fun flat xs = List.concat xs

    (* Debug print: concatenates and prints msgs only when testing <> 0.
     * Always returns 1 so callers can write `val _ = testp [...]`.
     *)
    fun testp msgs = (case testing
        of 0 => 1
         | _ => (print (String.concat msgs); 1)
        (* end case *))

    (* aShape : ein_exp -> mu list
     * Every index mentioned anywhere in e, in order of appearance; duplicates
     * are kept.  For Sum(sx, e1) the summation indices of sx are listed even
     * when e1 never mentions them (e.g. in Σ_3 e, 3 is in aShape).
     * Raises Fail for Img, Value and Krn, which are not expected here.
     *)
    fun aShape e = let
        fun ofAll es = flat (List.map aShape es)
        fun sumVars sx = List.map (fn (v, _, _) => v) sx
        in
          case e
           of E.Tensor(_, alpha)      => alpha
            | E.Add es                => ofAll es
            | E.Sub(e1, e2)           => ofAll [e1, e2]
            | E.Div(e1, e2)           => ofAll [e1, e2]
            | E.Sum(sx, e1)           => sumVars sx @ aShape e1
            | E.Prod es               => ofAll es
            | E.Delta(i, j)           => [i, j]
            | E.Epsilon(i, j, k)      => [E.V i, E.V j, E.V k]
            | E.Eps2(i, j)            => [E.V i, E.V j]
            | E.Neg e1                => aShape e1
            | E.Conv(_, alpha, _, dx) => alpha @ dx
            | E.Probe(e1, _)          => aShape e1
            | E.Sqrt e1               => aShape e1
            | E.PowInt(e1, _)         => aShape e1
            | E.PowReal(e1, _)        => aShape e1
            | E.Const _               => []
            | E.Field _               => []
            | E.Partial _             => []
            | E.Apply _               => []
            | E.Lift _                => []
            | E.ConstR _              => []
            | E.Img _                 => err "should not be here"
            | E.Value _               => err "Should not be here"
            | E.Krn _                 => err "Should not be here"
          (* end case *)
        end

    (* eShape : ein_exp -> mu list
     * Potential shape of a tensor that could replace e, i.e. the indices with
     * the potential to appear in the replacement's shape:
     *   Tensor(_, alpha)                        -> alpha
     *   Add, Sub, Div                           -> eShape of the first operand
     *   Sum, Neg, Sqrt, PowInt, PowReal, Probe  -> eShape of the body
     *   Prod [e1, ..., en]                      -> eShape e1, followed by every
     *        variable (non-constant) index of e2..en not already present,
     *        in first-seen order
     *   Delta, Epsilon, Eps2                    -> their indices
     *   Conv(_, alpha, _, dx)                   -> alpha @ dx
     *   Const, ConstR, Field, Partial, Apply, Lift, Add [], Prod [] -> []
     * Raises Fail for Img, Value and Krn.
     *)
    fun eShape e = let
        (* debug trace goes through testp so the `testing` switch controls it *)
        val _ = testp ["\n eshape", P.printbody e]
        (* Union of the factors' shapes.  The first factor's shape is kept
         * verbatim (constants and repeats included); later factors only add
         * variable indices that are not present yet.  An empty product has
         * an empty shape (previously List.hd raised Empty).
         *)
        fun iterList es = let
            fun addNew ([], acc) = acc
              | addNew (E.C _ :: rest, acc) = addNew (rest, acc)
              | addNew (ix :: rest, acc) =
                  if List.exists (fn x => x = ix) acc
                    then addNew (rest, acc)
                    else addNew (rest, acc @ [ix])
            in
              case List.map eShape es
               of [] => []
                | first :: others => List.foldl addNew first others
              (* end case *)
            end
        in
          case e
           of E.Tensor(_, alpha)      => alpha
            | E.Add []                => []
            | E.Add (e1 :: _)         => eShape e1
            | E.Sub(e1, _)            => eShape e1
            | E.Div(e1, _)            => eShape e1
            | E.Sum(_, e1)            => eShape e1
            | E.PowInt(e1, _)         => eShape e1
            | E.PowReal(e1, _)        => eShape e1
            | E.Prod es               => iterList es
            | E.Delta(i, j)           => [i, j]
            | E.Epsilon(i, j, k)      => [E.V i, E.V j, E.V k]
            | E.Eps2(i, j)            => [E.V i, E.V j]
            | E.Neg e1                => eShape e1
            (* same result as the former Probe(Conv ...) special case,
             * and also covers a Conv that is not under a Probe *)
            | E.Conv(_, alpha, _, dx) => alpha @ dx
            | E.Probe(e1, _)          => eShape e1
            | E.Sqrt e1               => eShape e1
            | E.Const _               => []
            | E.Field _               => []
            | E.Partial _             => []
            | E.Apply _               => []
            | E.Lift _                => []
            | E.ConstR _              => []
            | E.Img _                 => err "should not be here"
            | E.Value _               => err "Should not be here"
            | E.Krn _                 => err "Should not be here"
          (* end case *)
        end

    (* tShape : (index, sx, e) -> mu list
     * Shape of the tensor replacement for e.  The "outer" indices are
     * V 0 .. V (length index - 1) followed by the summation variables of sx.
     * Walking eShape e in order, each variable index that is outer is kept;
     * every other index (constants, and indices bound only inside e) is
     * dropped.  Repeats are kept unless the local `removedup` toggle is set.
     *)
    fun tShape (index, sx, e) = let
        (* debug trace goes through testp so the `testing` switch controls it *)
        val _ = testp ["\n\n\n Tshape", P.printbody e]
        val eshape = eShape e
        val outerAlpha =
              List.tabulate (length index, E.V) @ List.map (fn (v, _, _) => v) sx
        fun isOuter ix = List.exists (fn x => x = ix) outerAlpha
        (* developer toggle: true drops repeated indices from the result *)
        val removedup = false
        fun getT ([], acc) = acc
          | getT (E.C _ :: rest, acc) = getT (rest, acc)
          | getT (ix :: rest, acc) =
              if not (isOuter ix)
                then getT (rest, acc)
              else if removedup andalso List.exists (fn x => x = ix) acc
                then getT (rest, acc)
                else getT (rest, acc @ [ix])
        in
          getT (eshape, [])
        end

    (* Debugging aid: prints the given shapes labelled eShape, aShape and
     * tShape (in that order) as comma-separated index lists, with constant
     * indices shown as "const<c>".  Always returns 1.
     *)
    fun tester (aShape, eShape, tShape) = let
        fun showMu (E.V v) = iTos v
          | showMu (E.C c) = "const" ^ iTos c
        fun show (label, ixs) =
              print (String.concat [label, " : ",
                  String.concatWith "," (List.map showMu ixs)])
        in
          show ("\neShape", eShape);
          show ("\naShape", aShape);
          show ("\ntShape", tShape);
          1
        end

    end (* local *)

end (* structure getShape *)

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0