Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] View of /branches/charisee/src/compiler/mid-to-low/helper.sml
ViewVC logotype

View of /branches/charisee/src/compiler/mid-to-low/helper.sml

Parent Directory Parent Directory | Revision Log Revision Log


Revision 3138 - (download) (annotate)
Thu Mar 26 16:27:35 2015 UTC (4 years, 5 months ago) by cchiw
File size: 8272 byte(s)
lifted sine,cosine,arcsine,arccosine
(* Helper functions for building LowIL code: literal materialization,
 * tensor indexing/projection, scalar and vector arithmetic shortcuts,
 * unary math operators, and evaluation of the Kronecker delta and the
 * 2-d/3-d epsilon (Levi-Civita) symbols at known index values.
 *)
structure Helper = struct
    local

    structure DstIL = LowIL
    structure DstTy = LowILTypes
    structure DstOp = LowOps
    structure Var = LowIL.Var
    structure SrcIL = MidIL
    structure SrcOp = MidOps
    structure E = Ein
    structure LowToS= LowToString
    structure FL=FloatLit
    structure IMap = IntRedBlackMap
    in

    (* Debug switch: when nonzero, testp prints the strings it is given. *)
    val testing=0
    val bV= ref 0
    fun err str=raise Fail(str)
    val realTy=DstTy.TensorTy []
    val intTy=DstTy.intTy
    fun iTos e1=Int.toString e1
    fun iToss es=String.concat(List.map iTos es)

    (* testp: string list -> int
     * When testing is nonzero, prints the concatenation of n and returns 1;
     * otherwise prints nothing and returns 0.
     *)
    fun testp n= (case testing
        of 0 => 0
        | _ =>((print (String.concat n));1)
        (*end case*))

    (* Integer-keyed finite maps (IntRedBlackMap), used below as the
     * dictionary that binds index variables to their values.
     *)
    val empty = IMap.empty
    fun lookup k d = IMap.find(d, k)
    fun insert (k, v) d = IMap.insert(d, k, v)

    (* find: int * dict -> 'a
     * Like lookup, but raises Fail "Outside Bound(v)" when v is unbound.
     *)
    fun find (v, mapp) = (case IMap.find(mapp, v)
        of NONE => raise Fail(concat["Outside Bound(", Int.toString v, ")"])
         | SOME s => s
        (* end case *))

    (* mapIndex: E.mu * dict -> int
     * Resolves an index: a variable index (E.V) is looked up in mapp
     * (raising Fail if unbound); a constant index (E.C) is returned as is.
     *)
    fun mapIndex (e1, mapp)=(case e1
        of E.V e => find(e,mapp)
         | E.C c => c
        (*end case*))

    (* intToReal: int -> Var * LowIL.ASSGN list
     * Materializes n as a real-valued variable. It assigns the integer
     * literal to a fresh variable and then converts that variable with
     * IntToReal. Returns the converted variable and both assignments.
     * The literal's variable is typed intTy (as in mkInt) because it holds
     * an integer literal; only the converted result has type realTy.
     *)
    fun intToReal n=let
        val a=DstIL.Var.new("real" ,intTy)
        val b=DstIL.Var.new("cast" ,realTy)
        val code=[DstIL.ASSGN (a,DstIL.LIT(Literal.Int(IntInf.fromInt n))),
                  DstIL.ASSGN (b,DstIL.OP(DstOp.IntToReal,[a]))]
        in
            (b,code)
        end

    (* mkInt: int -> Var * LowIL.ASSGN list
     * Assigns the integer literal n to a fresh intTy variable.
     *)
    fun mkInt n=let
        val a=DstIL.Var.new("Int" ,intTy)
        val code=DstIL.ASSGN (a,DstIL.LIT(Literal.Int(IntInf.fromInt n)))
        val _ =testp[LowToS.toStringAll(intTy,code)]
        in
            (a,[code])
        end

    (* assgnCons: 'a * int list * Var list -> Var * LowIL.ASSGN list
     * Builds a tensor of type TensorTy shape from args with CONS.
     * NOTE(review): pre is currently unused; the fresh variable is always
     * named "cons_".
     *)
    fun assgnCons(pre,shape, args)=let
        val ty=DstTy.TensorTy shape
        val a=DstIL.Var.new("cons_" ,ty)
        val code=DstIL.ASSGN (a,DstIL.CONS(ty ,args))
        val _ =testp[LowToS.toStringAll(ty,code)]
        in
            (a, [code])
        end

    (* assgn: LowOps.op * Var list * string * LowIL.Ty
     *        -> Var * LowIL.ASSGN list
     * Creates a fresh variable named with prefix pre and of type ty, and
     * the single assignment binding it to opss applied to args.
     *)
    fun assgn(opss,args,pre,ty)=let
        val a=DstIL.Var.new(pre ,ty)
        val code=DstIL.ASSGN (a,DstIL.OP(opss,args))
        val _ =testp[LowToS.toStringAll(ty,code)]
        in
            (a,[code])
        end

    (* getTensorTy: E.params * E.tensor_id -> LowIL.Ty
     * Type of the id-th parameter, which must be an E.TEN:
     *   - E.TEN(3,[shape]) maps to DstTy.iVecTy shape (integer vector);
     *   - any other E.TEN maps to DstTy.TensorTy shape (generic tensor).
     * Any other parameter kind raises Fail.
     * NOTE(review): the meaning of the literal 3 is not visible here (the
     * original marks this case "FIX HERE"); confirm against Ein.
     *)
    fun getTensorTy(params, id)=(case List.nth(params,id)
        of E.TEN(3,[shape])=> DstTy.iVecTy(shape) (*FIX HERE*)
        | E.TEN(_,shape)=> DstTy.TensorTy shape
        |_=> err"NONE Tensor Param"
        (*end case*))

    (* indexTensor: dict * (string * E.params * Var list * E.tensor_id
     *                      * E.alpha * LowIL.Ty)
     *              -> Var * LowIL.ASSGN list
     * Indexes the id-th argument variable at the indices ix (resolved
     * through mapp) with IndexTensor, giving a fresh variable of type ty.
     * With no indices, the argument variable itself is returned and no
     * code is generated.
     *)
    fun indexTensor(_,(_,_,args,id, [],ty)) =
        (List.nth(args,id),[])
      | indexTensor(mapp,(lhs,params,args,id,ix,ty))= let
        val nU=List.nth(args,id)
        val ixx=(List.map (fn (e1)=> mapIndex(e1,mapp)) ix)
        val ix'=DstTy.indexTy ixx
        val argTy=getTensorTy(params,id)
        val opp=DstOp.IndexTensor(id,ix',argTy)
        val name=String.concat["Indx_",iToss ixx,"_"]
        in
            assgn(opp,[nU],name,ty)
        end

    (* projTensor: dict * (string * E.params * Var list * int * E.tensor_id
     *                     * E.alpha)
     *             -> Var * LowIL.ASSGN list
     * Projects the id-th argument at indices ix (resolved through mapp)
     * with ProjectTensor, giving a variable of type TensorTy [vecIX].
     * With no indices, the argument variable is returned unchanged.
     * Only used by EintoVecOps, but kept here with the other helpers.
     *)
    fun projTensor(_,(_,_,args,_,id,[]))= (List.nth(args,id),[])
      | projTensor(mapp,(lhs,params,args,vecIX,id,ix))= let
        val nU=List.nth(args,id)
        val ixx=(List.map (fn (e1)=> mapIndex(e1,mapp)) ix)
        val ix'=DstTy.indexTy ixx
        val argTy= getTensorTy(params,id)
        val vecTy=DstTy.TensorTy [vecIX]
        val opp=DstOp.ProjectTensor(id,vecIX,ix',argTy)
        val name=String.concat["Proj_",iToss ixx,"_"]
        in
            assgn(opp,[nU],name,vecTy)
        end

    (* Unary real-valued operators.
     * Each mkX takes (nU, code), where code is the LowIL code computing nU.
     * It returns the result variable together with code followed by the
     * new assignment.
     *)
    local
        (* mkUnary: shared body. Applies opp to nU in a fresh realTy
         * variable named with prefix name, appending the assignment to code.
         *)
        fun mkUnary (opp, name) (nU, code) = let
            val (vA, A) = assgn(opp, [nU], name, realTy)
            in
                (vA, code@A)
            end
    in
    fun mkSqrt (nU, code) = mkUnary (DstOp.Sqrt, "_Sqrt_") (nU, code)
    fun mkCosine (nU, code) = mkUnary (DstOp.Cosine, "_Cosine_") (nU, code)
    fun mkArcCosine (nU, code) = mkUnary (DstOp.ArcCosine, "_ArcCosine_") (nU, code)
    fun mkSine (nU, code) = mkUnary (DstOp.Sine, "_Sine_") (nU, code)
    fun mkArcSine (nU, code) = mkUnary (DstOp.ArcSine, "_ArcSine_") (nU, code)

    (* mkPowRat: raises nU to the rational power rat (powRat operator). *)
    fun mkPowRat ((nU, code), rat) =
        mkUnary (DstOp.powRat(DstTy.R rat), "_PowRat_") (nU, code)
    end (* local *)

    (* mkPowInt: raises nU to the integer power n. The exponent is first
     * materialized as an intTy variable with mkInt.
     *)
    fun mkPowInt((nU,code),n)= let
        val opp=DstOp.powInt
        val (r,rcode)=mkInt n
        val (vA,A)=assgn(opp,[nU,r],"_PowInt_",realTy)
        in
            (vA,code@rcode@A)
        end

    (* Some shortcuts. Arguments are LowIL variables that are already
     * indexed/projected.
     *   string * Var list -> Var * LowIL.ASSGN list
     * The lhs argument is not used by any of these.
     *)
    fun mkAddSca(lhs,args)= assgn(DstOp.addSca,args,"addSca",realTy)
    fun mkAddInt(lhs,args)= assgn(DstOp.addSca,args,"addInt",intTy)
    fun mkAddPtr(lhs,args,ty)= assgn(DstOp.addSca,args,"addPtr",ty)
    fun mkAddVec(lhs,vecIX,args)=assgn(DstOp.addVec vecIX,args,"addV",DstTy.TensorTy([vecIX]))
    fun mkSubSca(lhs,args)= assgn(DstOp.subSca,args,"subSca",realTy)
    fun mkProdSca(lhs,args)=assgn(DstOp.prodSca,args,"prodSca",realTy)
    fun mkProdInt(lhs,args)=assgn(DstOp.prodSca,args,"prodInt",intTy)
    fun mkProdVec(lhs,vecIX,args)=assgn(DstOp.prodVec vecIX,args,"prodV",DstTy.TensorTy([vecIX]))
    fun mkDivSca(lhs,args)= assgn(DstOp.divSca,args,"divSca",realTy)
    fun mkSumVec(lhs,vecIX,args)= assgn(DstOp.sumVec vecIX,args,"sumVec",realTy)

    (* mkDotVec: dot product as an element-wise product (prodVec) followed
     * by a sum over the vector (sumVec).
     *)
    fun mkDotVec(lhs,vecIX,args)=let
        val (vD, D)=mkProdVec("",vecIX,args)
        val (vE, E)=mkSumVec("",vecIX,[vD])
        in (vE,D@E) end

    (* mkMultiple: string * Var list * LowOps.Op * LowIL.Ty
     *             -> Var * LowIL.ASSGN list
     * Folds rator left-to-right over list1, e.g. [a,b,c] becomes
     * rator(rator(a,b),c). Each intermediate gets a fresh variable of
     * type ty. A single-element list yields that element with no code;
     * an empty list raises Fail.
     *)
    fun mkMultiple(lhs,list1,rator,ty)=let
        fun add([],_,_) = err"no element in mkMultiple"
          | add([e1],code,_) = (e1,code)
          | add([e1,e2],code,_) = let
                val (vA,A)=assgn(rator,[e1,e2],"mult_2",ty)
                in (vA,code@A) end
          | add(e1::e2::es,code,count) = let
                val (vA,A)=assgn(rator,[e1,e2],String.concat["mult_",iTos count],ty)
                in add(vA::es,code@A,count-1) end
        in
            add(list1,[],List.length list1)
        end

    (* deltaToInt: dict * E.mu * E.mu -> int
     * Kronecker delta: 1 if the two resolved indices are equal, else 0.
     *)
    fun deltaToInt(mapp,a,b)= let
        val i=mapIndex(a,mapp)
        val j=mapIndex(b,mapp)
        in
            if(i=j) then 1 else  0
        end

    (* evalDelta: the Kronecker delta materialized as a real variable. *)
    fun evalDelta(mapp,a,b)= intToReal(deltaToInt(mapp,a,b))

    (* evalEps2: 2-d epsilon at the values bound to index variables a, b.
     * Yields 0 if they are equal, 1 if the second is larger, else ~1,
     * materialized as a real variable.
     *)
    fun evalEps2(mapp,a,b)=let
        val i=mapIndex(E.V a,mapp)
        val j=mapIndex(E.V b,mapp)
        in
            if(i=j) then intToReal 0
            else if(j>i) then intToReal 1
            else intToReal ~1
        end

    (* evalEps3: 3-d epsilon (Levi-Civita) at the values bound to index
     * variables a, b, c. Yields 0 if any two are equal, 1 for an even
     * ordering, and ~1 for an odd ordering, materialized as a real variable.
     *)
    fun evalEps3(mapp,a,b,c)=let
        val i=mapIndex(E.V a,mapp)
        val j=mapIndex(E.V b,mapp)
        val k=mapIndex(E.V c,mapp)
        in
            if(i=j orelse j=k orelse i=k) then intToReal 0
            else if(j>i) then
                if(j>k andalso k>i) then intToReal ~1 else intToReal 1
            else if(i>k andalso k>j) then intToReal 1 else intToReal ~1
        end

    end (* local *)

end

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0