Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] View of /branches/charisee/src/compiler/high-to-mid/cleanParam.sml
ViewVC logotype

View of /branches/charisee/src/compiler/high-to-mid/cleanParam.sml

Parent Directory Parent Directory | Revision Log Revision Log


Revision 2867 - (download) (annotate)
Tue Feb 10 06:52:58 2015 UTC (4 years, 5 months ago) by cchiw
File size: 6031 byte(s)
moved split around, added norm to typechecker, added sqrt to ein
(*
*cleanParam.sml cleans the parameters of an EIN expression.
*Cleaning the parameters is simple:
*we collect the ids of all params used in the body (getIdCount),
*remap the surviving param ids to a dense range and select the
*mid-IL args that are still used (mkMapp),
*and lastly rewrite the body with the new ids (rewriteParam).
*)
structure cleanParams= struct

    local

    structure E = Ein
    structure DstIL = MidIL
    structure P=Printer
    structure DstV = DstIL.Var

    in
    (* Debug switch: any value other than 0 makes testp print its messages. *)
    val testing=0

    (* Functional dictionaries: a dictionary is a function from keys to
     * options.  insert shadows any earlier binding of the same key, and a
     * lookup walks back through every insert (linear in their number).
     *)
    fun insert (key, value) d =fn s =>
        if s = key then SOME value
        else d s
    fun lookup k d = d k
    val empty =fn key =>NONE

    fun err str=raise Fail str
    fun iTos i =Int.toString i

    (* Build an EIN operator from its params, index and body. *)
    fun setEin(params,index,body)=Ein.EIN{params=params, index=index, body=body}

    (* Current use count of a mid-IL variable. *)
    fun cnt(DstIL.V{useCnt, ...}) = !useCnt

    (* NOTE(review): every use bumps the count by 2, not 1.  Kept unchanged
     * because other passes may depend on it -- confirm it is intentional.
     *)
    fun incUse (DstIL.V{useCnt, ...}) = (useCnt := !useCnt + 2)

    (* Record a use of x and return x unchanged. *)
    fun use x = (incUse x; x)

    (* Debug string: variable name and its current use count. *)
    fun getUse e=String.concat["\n\t Varname "^DstIL.Var.name e," -use ",iTos(cnt e)]

    (* Pair the lhs variable y with an EINAPP of the rebuilt operator. *)
    fun assignEinApp(y,params,index,body,args)= (y,DstIL.EINAPP(setEin(params,index,body),args))

    (* Print the concatenated message list when testing <> 0; always returns 1. *)
    fun testp n=(case testing
        of 0=> 1
        | _ =>(print(String.concat n);1)
        (*end case*))

    (* lkupIndexSingle: look e1 up in mapp; raise Fail (str ^ e1) when absent. *)
    fun lkupIndexSingle(e1,mapp,str)=(case (lookup e1 mapp)
        of SOME l=>l
        | _=> raise Fail(str^iTos(e1))
        (*endcase*))

    (* mkMapp: dict*params*var list -> dict*params*var list
     * countmapp records which param ids occur in the body (see getIdCount).
     * params and args are walked in lockstep; every position i present in
     * countmapp is kept and renumbered to the next free id j = 0,1,2,...
     * Returns the old-id -> new-id dictionary together with the kept params
     * and kept args, both in their original order.  Args beyond the last
     * param are ignored.
     * Raises Fail if there are fewer args than params.
     *)
    fun mkMapp(countmapp,params,args)=let
        val n=(length params)
        (* i = current old id, j = next new id.  Kept params/args are
         * accumulated in reverse and restored at the end, which avoids the
         * quadratic cost of appending one element at a time.
         *)
        fun m(_,_,mapp,keptP,[],keptA,_)=(mapp,List.rev keptP,List.rev keptA)
          | m(i,j,mapp,keptP,p1::ps,keptA,a1::rest)=(case (lookup i countmapp)
            of SOME _=>let
                val _ =testp["\nInserting into dictionary",iTos i,"=>",iTos j]
                val mapp2=insert(i,j) mapp
                in
                    m(i+1,j+1,mapp2,p1::keptP,ps,a1::keptA,rest)
                end
            | _ =>  m(i+1,j,mapp,keptP,ps,keptA,rest)
            (*end case*))
          | m(_,_,_,_,_::_,_,[])=raise Fail "mkMapp: fewer args than params"
        val _ =testp["ix created up to",iTos(n),"length of params",
                iTos(length params),"lengths of args",iTos(length args)]
        in
            m(0, 0, empty, [], params, [], args)
        end

    (* getIdCount: ein_exp -> dict
     * Collects the param ids referenced by the expression: the id of every
     * E.Tensor and both ids of every E.Conv are bound to 1 in the result.
     * Field-level terms (Field, Partial, Apply, Lift) contribute nothing.
     * Raises Fail on E.Img, E.Value and E.Krn, which are not expected here.
     *)
    fun getIdCount e=let
        fun collect(e,mapp)=let
            (* thread mapp through the subexpressions left to right *)
            fun iterList(es,mapp)=List.foldl (fn (e1,acc)=> collect(e1,acc)) mapp es
        in (case e
            of E.Tensor(id,_)      => insert(id,1) mapp
            | E.Conv(v,_,h,_)      => insert(h,1) (insert(v,1) mapp)
            | E.Add es             => iterList (es,mapp)
            | E.Sub(e1,e2)         => iterList ([e1,e2],mapp)
            | E.Div(e1,e2)         => iterList ([e1,e2],mapp)
            | E.Sum(_ ,e1)         => collect (e1,mapp)
            | E.Prod es            => iterList (es,mapp)
            | E.Probe(e1,e2)       => iterList ([e1,e2],mapp)
            | E.Neg e1             => collect (e1,mapp)
            | E.Sqrt e1            => collect (e1,mapp)
            | E.Const _            => mapp
            | E.Field _            => mapp
            | E.Partial _          => mapp
            | E.Apply _            => mapp
            | E.Lift _             => mapp
            | E.Delta _            => mapp
            | E.Epsilon _          => mapp
            | E.Eps2 _             => mapp
            | E.Img _              => err  "should not be here"
            | E.Value _            => err "Should not be here"
            | E.Krn _              => err "Should not be here"
            (*end case*))
        end
    in
        collect(e,empty)
    end

    (* rewriteParam: dict*ein_exp -> ein_exp
     * Rewrites every param id in the expression through mapp (old id -> new
     * id).  A probe of a convolution keeps both rewritten ids; any other
     * Field, Partial, Apply, Lift, Probe or Conv term is replaced by E.Const 0.
     * Raises Fail if an id is missing from mapp, or on E.Img/E.Value/E.Krn.
     *)
    fun rewriteParam(mapp,e)=let
        fun getId id  = lkupIndexSingle(id,mapp,"Mapp doesn't have Param Id ")
        fun rewriteExp e=(case e
            of E.Tensor(id,alpha)  => E.Tensor(getId id, alpha)
            | E.Add es             => E.Add(List.map rewriteExp es)
            | E.Sub(e1,e2)         => E.Sub(rewriteExp e1,rewriteExp e2)
            | E.Div(e1,e2)         => E.Div(rewriteExp e1,rewriteExp e2)
            | E.Sum(c ,e1)         => E.Sum(c,rewriteExp e1)
            | E.Prod es            => E.Prod(List.map rewriteExp es)
            | E.Neg e1             => E.Neg(rewriteExp e1)
            | E.Sqrt e1            => E.Sqrt(rewriteExp e1)
            | E.Probe(E.Conv(v,alpha,h,dx),t) =>
                E.Probe(E.Conv (getId v, alpha, getId h, dx), rewriteExp t)
            | E.Delta _            => e
            | E.Epsilon _          => e
            | E.Eps2 _             => e
            | E.Const _            => e
            | E.Field _            => E.Const 0
            | E.Partial _          => E.Const 0
            | E.Apply _            => E.Const 0
            | E.Lift _             => E.Const 0
            | E.Probe _            => E.Const 0
            | E.Conv _             => E.Const 0
            | E.Img _              => raise Fail "should not be here"
            | E.Value _            => raise Fail"Should not be here"
            | E.Krn _              => raise Fail"Should not be here"
            (*end case*))
        in
            rewriteExp e
        end

    (* cleanParams: var*ein_exp*param*index*var list -> code
     * Drops the params (and matching args) the body never references,
     * renumbers the remaining param ids densely, records a use of every
     * surviving arg, and returns the assignment y = EINAPP(cleaned ein, args).
     *)
    fun cleanParams(y,body,params,index,args)=let
        val _ =testp["\n Cleaning e ",P.printbody body]
        val countmapp=getIdCount body
        val (mapp,Nparams,Nargs)=mkMapp(countmapp,params,args)
        val Nargs =List.map use Nargs
        val Nbody=rewriteParam(mapp,body)
        in
            assignEinApp(y,Nparams,index,Nbody,Nargs)
        end

  end (* local *)

end (* structure cleanParams *)

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0