Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] View of /branches/charisee/src/compiler/high-to-mid/cleanIndex.sml
ViewVC logotype

View of /branches/charisee/src/compiler/high-to-mid/cleanIndex.sml

Parent Directory Parent Directory | Revision Log Revision Log


Revision 2838 - (download) (annotate)
Tue Nov 25 03:40:24 2014 UTC (4 years, 8 months ago) by cchiw
File size: 9647 byte(s)
edit split-ein

(* This approach creates probe-expanded terms and adds params to the end. *)


structure cleanIndex= struct

    local
   
    structure E = Ein
    structure SrcIL = HighIL
    structure DstIL = MidIL
    structure DstV = DstIL.Var
    structure SrcV = SrcIL.Var
    structure P=Printer
    structure cleanP=cleanParams

    in
    (* Debug switch read by testp: 0 silences it, any other value makes it print. *)
    val testing = 1
    (* itos : int -> string
     * Short alias for Int.toString. *)
    val itos = Int.toString
    (* insert : adds one binding to a function-encoded dictionary.
     * The result answers SOME value for key and defers to d for every
     * other key, so a newer binding shadows an older one for the same key. *)
    fun insert (key, value) d = let
        fun extended probe = if probe = key then SOME value else d probe
        in
            extended
        end

    (* lookup : queries a function-encoded dictionary by applying it to the key. *)
    fun lookup key dict = dict key
    (* empty : the dictionary with no bindings; every key yields NONE. *)
    val empty = fn _ => NONE
    (* flat : 'a list list -> 'a list
     * Concatenates the inner lists in order. *)
    fun flat lists = List.concat lists


    (* err : string -> 'a
     * Aborts by raising Fail carrying the given message. *)
    fun err msg = raise (Fail msg)
    (* testp : string list -> int
     * Debug printer. When testing is 0 nothing is printed; otherwise the
     * strings in n are concatenated and written with print.
     * The result is always 1. *)
    fun testp n =
        if testing = 0
        then 1
        else (print (String.concat n); 1)

(* lkupIndexSingle : int * (int -> 'a option) * string -> 'a
 * Looks up index e1 in the dictionary mapp and returns the bound value.
 * Raises Fail with str followed by the decimal form of e1 when e1 is unbound.
 * Example use:
 *   val sizeN=lkupIndexSingle(e1,sizeMapp,"Could not find Size of")
 *)
fun lkupIndexSingle (e1, mapp, str) =
    (case lookup e1 mapp
        of NONE => raise Fail (str ^ Int.toString e1)
         | SOME bound => bound
    (* end case *))

(* lkupIndexV : renames one index through mapp.
 * A variable index E.V v becomes E.V of the value bound to v; the lookup
 * goes through lkupIndexSingle, so an unbound v raises Fail built from str.
 * A constant index E.C c is returned as it is.
 *)
fun lkupIndexV (ix, mapp, str) =
    (case ix
        of E.V v => E.V (lkupIndexSingle (v, mapp, str))
         | E.C c => E.C c
    (* end case *))

(* lkupIndexSx : renames the index variable of each summation range through mapp.
 * Every entry is a triple whose first component is an index; the two bounds
 * that follow are passed through unchanged.
 *
 * A first component that is not an E.V variable raises
 * Fail "Non-V-index in sx"; previously it fell through the match and
 * surfaced as an uninformative Match exception.
 *
 * NOTE(review): when a variable is not bound in mapp, the result stops at
 * that entry and the remaining ranges are dropped without any error.
 * str is accepted but not used. This is the original behaviour and is kept
 * here; confirm whether it is intended or whether it should fail the way
 * lkupIndexSingle does.
 *)
fun lkupIndexSx ([], mapp, str) = []
  | lkupIndexSx ((E.V e1, b1, b2) :: es, mapp, str) =
        (case lookup e1 mapp
            of SOME l => (E.V l, b1, b2) :: lkupIndexSx (es, mapp, str)
             | NONE => []
        (* end case *))
  | lkupIndexSx (_ :: _, _, _) = raise Fail "Non-V-index in sx"


(* mkIndexMapp : builds the index renaming and the shape of the tensor replacement.
 *
 * Arguments
 *   index  : int list; element k is recorded as the size of index id k
 *   sx     : summation ranges; for an entry whose index is E.V v and whose
 *            third component is ub, ub+1 is recorded as the size of v
 *   eshape : mu list, candidate shape of the tensor replacement
 *   ashape : mu list, every index mentioned in the body
 *
 * Result is the triple (mapp, sizes, tshape)
 *   mapp   : dictionary from the index ids that occur in ashape to the
 *            consecutive ids 0,1,2,... assigned in increasing order of old id
 *   sizes  : int list, the recorded size of each entry of tshape
 *   tshape : mu list, eshape with constants, repeated variables, and
 *            variables whose id is above the largest known id removed
 *
 * Raises Fail "Non-V-index in sx" when sx holds a non-variable index, and
 * Fail "Could not find Size of<id>" when a tshape index has no recorded size.
 *)
fun mkIndexMapp(index,sx,eshape,ashape)= let

    (* largest index id implied by index alone; ~1 when index is empty *)
    val lastIx = length index - 1

    (* dictionary: index id -> size, from index first and then from sx *)
    fun mkSizeMapp () = let
        fun addIndex (_, [], mapp) = mapp
          | addIndex (counter, ix::es, mapp) = addIndex (counter+1, es, insert (counter, ix) mapp)
        fun addSx ([], mapp) = mapp
          | addSx ((E.V v, _, ub)::es, mapp) = addSx (es, insert (v, ub+1) mapp)
          | addSx (_, _) = err "Non-V-index in sx"
        in
            addSx (sx, addIndex (0, index, empty))
        end

    (* finds the largest id among lastIx and the variables of ashape,
     * reports it through testp, and returns [0,1,2,...,max] *)
    fun mkIXfromMax () = let
        val used = List.mapPartial (fn (E.V v) => SOME v | _ => NONE) ashape
        val max = List.foldl Int.max lastIx used
        val ix = List.tabulate (max+1, fn e => e)
        fun showInts es = String.concatWith "," (List.map Int.toString es)
        val _ = testp ["\nindices from index:", showInts index, "\n Max for map ", itos max,
                       "\n ix: ", showInts ix]
        in
            ix
        end

    (* assigns consecutive new ids, starting at 0, to the ids of ix that
     * actually occur as variables in ashape; other ids get no binding *)
    fun mkMapp ix = let
        fun occurs v = List.exists (fn x => x = E.V v) ashape
        fun m (mapp, [], _) = mapp
          | m (mapp, v::vs, next) =
                if occurs v
                then m (insert (v, next) mapp, vs, next+1)
                else m (mapp, vs, next)
        in
            m (empty, ix, 0)
        end

    (* getTshape: shape of the tensor replacement.
     * Keeps the first occurrence of each variable of eshape whose id does not
     * exceed max, where max is the largest of lastIx and the ids in sx. *)
    fun getTshape () = let
        (* mkSizeMapp has already rejected non-variable entries of sx *)
        val sumIds = List.mapPartial (fn (E.V v, _, _) => SOME v | _ => NONE) sx
        val max = List.foldl Int.max lastIx sumIds
        fun getT ([], rest) = rest
          | getT ((E.C _)::es, rest) = getT (es, rest)
          | getT ((e1 as E.V v)::es, rest) =
                if max >= v andalso not (List.exists (fn x => x = e1) rest)
                then getT (es, rest@[e1])
                else getT (es, rest)
        in
            getT (eshape, [])
        end

    (* The order below matters: the size dictionary is built, and sx is
     * validated, before anything is printed or looked up. *)
    val sizeMapp = mkSizeMapp ()
    val ix = mkIXfromMax ()
    val mapp = mkMapp ix
    val tshape = getTshape ()
    (* getTshape only ever keeps E.V entries *)
    val sizes = List.map (fn E.V e1 => lkupIndexSingle (e1, sizeMapp, "Could not find Size of")) tshape
    in
        (mapp, sizes, tshape)
    end

(* getOuterShape : ein_exp -> mu list
 * Returns the candidate shape for the replacement tensor of e.
 *   - a tensor yields its own indices; a convolution yields alpha followed by dx
 *   - Add, Sub and Div take the shape of their first operand
 *   - an Add with no summands has the empty shape (this case used to be
 *     unmatched and raised Match)
 *   - Sum, Neg and Probe take the shape of the expression they wrap
 *   - Prod concatenates the shapes of all factors
 *   - Delta and Epsilon contribute their own indices
 *   - Const, Field, Partial, Apply and Lift contribute nothing
 * Raises Fail for Img, Value and Krn.
 *)
fun getOuterShape e = let
    fun shapesOf exps = List.concat (List.map getOuterShape exps)
    in (case e
        of E.Tensor(_,alpha)   => alpha
        | E.Add []             => []
        | E.Add (e1::_)        => getOuterShape e1
        | E.Sub(e1,_)          => getOuterShape e1
        | E.Div(e1,_)          => getOuterShape e1
        | E.Sum(_ ,e1)         => getOuterShape e1
        | E.Prod es            => shapesOf es
        | E.Delta(i,j)         => [i,j]
        | E.Epsilon(i,j,k)     => [E.V i, E.V j, E.V k]
        | E.Neg e1             => getOuterShape e1
        | E.Conv(_,alpha,_,dx) => alpha@dx
        | E.Probe(e1,_)        => getOuterShape e1
        | E.Const _            => []
        | E.Field _            => []
        | E.Partial _          => []
        | E.Apply _            => []
        | E.Lift _             => []
        | E.Img _              => err  "should not be here"
        | E.Value _            => err "Should not be here"
        | E.Krn _              => err "Should not be here"
        (*end case*))
    end


(* getAllIx : ein_exp -> mu list
 * Collects the indices used in an expression, in left-to-right order and
 * with repeats kept. Tensor, Conv, Delta and Epsilon terms supply indices;
 * Add, Sub, Div and Prod gather from every operand; Sum, Neg and Probe
 * gather from the expression they wrap; Const, Field, Partial, Apply and
 * Lift supply none. Raises Fail for Img, Value and Krn.
 *)
fun getAllIx e = let
    fun gather exps = List.concat (List.map getAllIx exps)
    in
      (case e
        of E.Const _              => []
         | E.Field _              => []
         | E.Partial _            => []
         | E.Apply _              => []
         | E.Lift _               => []
         | E.Tensor (_, alpha)    => alpha
         | E.Conv (_, alpha, _, dx) => alpha @ dx
         | E.Delta (i, j)         => [i, j]
         | E.Epsilon (i, j, k)    => [E.V i, E.V j, E.V k]
         | E.Neg inner            => getAllIx inner
         | E.Sum (_, inner)       => getAllIx inner
         | E.Probe (inner, _)     => getAllIx inner
         | E.Add terms            => gather terms
         | E.Prod factors         => gather factors
         | E.Sub (lhs, rhs)       => gather [lhs, rhs]
         | E.Div (num, den)       => gather [num, den]
         | E.Img _                => err  "should not be here"
         | E.Value _              => err "Should not be here"
         | E.Krn _                => err "Should not be here"
      (* end case *))
    end


(* rewriteIndices : dict * ein_exp -> ein_exp
 * Rewrites every index in e through mapp and returns the rewritten expression.
 *   - Tensor, Conv, Delta and Epsilon have their indices renamed
 *   - Sum has its summation ranges renamed through lkupIndexSx
 *   - Add, Sub, Div, Prod, Neg and Probe are rebuilt from rewritten operands
 *   - Const carries no index and is returned unchanged (it used to be
 *     overwritten with E.Const 0, which changed the value of the expression)
 * An index that is missing from mapp raises Fail with a message that
 * includes the printed form of e. Img, Value and Krn raise Fail.
 *)
fun rewriteIndices(mapp,e)=let
    val str="Error indexMapp from expression:"^P.printbody e^"Index"
    fun getAlpha alpha = List.map (fn ix => lkupIndexV (ix, mapp, str)) alpha
    fun getIx ix = lkupIndexSingle (ix, mapp, str)
    fun getVx ix = lkupIndexV (ix, mapp, str)
    fun getSx sx = lkupIndexSx (sx, mapp, str)

    fun rewriteExp e=(case e
        of E.Tensor(id,alpha)  => E.Tensor(id, getAlpha alpha)
        | E.Add es             => E.Add(List.map rewriteExp es)
        | E.Sub(e1,e2)         => E.Sub(rewriteExp e1,rewriteExp e2)
        | E.Div(e1,e2)         => E.Div(rewriteExp e1,rewriteExp e2)
        | E.Sum(sx ,e1)        => E.Sum(getSx sx ,rewriteExp e1)
        | E.Prod es            => E.Prod(List.map rewriteExp es)
        | E.Neg e1             => E.Neg(rewriteExp e1)
        | E.Conv(v,alpha,h,dx) => E.Conv (v, getAlpha alpha,h, getAlpha dx)
        | E.Probe(e1,t)        => E.Probe(rewriteExp e1, rewriteExp t)
        | E.Delta(i,j)         => E.Delta(getVx i,getVx j)
        | E.Epsilon(i,j,k)     => E.Epsilon(getIx  i,getIx  j,getIx k)
        (* a constant has no indices to rename, so its value is preserved *)
        | E.Const c            => E.Const c
        (* NOTE(review): the four field-level forms below are still replaced
         * by a zero constant, as in the original; confirm that they cannot
         * reach this pass, or decide how their indices should be rewritten. *)
        | E.Field _            => E.Const 0
        | E.Partial _          => E.Const 0
        | E.Apply _            => E.Const 0
        | E.Lift _             => E.Const 0
        | E.Img _              => raise Fail "should not be here"
        | E.Value _            => raise Fail"Should not be here"
        | E.Krn _              => raise Fail"Should not be here"
        (*end case*))
    in
        rewriteExp e
    end


(* isZero : ein_exp -> int
 * Structural test for a body that is entirely zero: the result is 1 when it
 * is and 0 otherwise.
 *   - E.Const 0 is zero
 *   - Neg, Sum and Probe are zero when the expression they wrap is zero
 *   - Div is zero when its numerator is zero
 *   - Add, Sub and Prod are zero only when every operand is zero
 *     (an empty operand list counts as zero)
 *   - every other expression is treated as non-zero
 *)
fun isZero e = let
    (* 1 when no expression in the list is non-zero *)
    fun allZero exps =
        if List.all (fn x => isZero x <> 0) exps then 1 else 0
    in
      (case e
        of E.Const 0        => 1
         | E.Neg inner      => isZero inner
         | E.Sum (_, inner) => isZero inner
         | E.Probe (inner, _) => isZero inner
         | E.Div (num, _)   => isZero num
         | E.Sub (lhs, rhs) => allZero [lhs, rhs]
         | E.Add terms      => allZero terms
         | E.Prod factors   => allZero factors
         | _                => 0
      (* end case *))
    end

(* cleanIndex : ein_exp * int list * sum_id -> mu list * int list * ein_exp
 * Cleans the indices in the body e.
 * Returns the shape of the replacement in terms of index variables, the
 * matching sizes, and the body with its indices rewritten.
 * The incoming expression is reported through testp first.
 * Example use:
 *   val (tshape,sizes,body)=cleanI.cleanIndex(e,index,sx)
 *)
fun cleanIndex (e, index, sx) = let
    val _ = testp ["\n In clean Index einexp", P.printbody e]
    (* outer shape and all used indices feed the index map *)
    val (mapp, sizes, tshape) = mkIndexMapp (index, sx, getOuterShape e, getAllIx e)
    in
        (tshape, sizes, rewriteIndices (mapp, e))
    end


  end; (* local *)

end (* structure cleanIndex *)

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0