Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] View of /branches/charisee_dev/src/compiler/low-il/helper.sml
ViewVC logotype

View of /branches/charisee_dev/src/compiler/low-il/helper.sml

Parent Directory Parent Directory | Revision Log Revision Log


Revision 3668 - (download) (annotate)
Sun Feb 7 16:11:21 2016 UTC (3 years, 6 months ago) by cchiw
File size: 10312 byte(s)
DVF
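(* Helper functions for emitting LowIL code: fresh variables and assignments
 * (optionally value-numbered), literals, tensor indexing, and scalar/vector ops.
 *)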
structure HelperSet = struct
    local

    structure DstIL = LowIL
    structure DstTy = LowILTypes
    structure DstOp = LowOps
    structure Var = LowIL.Var
    structure SrcIL = MidIL
    structure SrcOp = MidOps
    structure E = Ein
    structure LowToS= LowToString
    structure FL=FloatLit
    structure IMap = IntRedBlackMap
    in

    (* Debug switch: when nonzero, testp prints its message list. *)
    val testing=0
    (* When true, plugin routes each assignment through lowSet.filter so an
     * already-emitted right-hand side can be reused instead of re-emitted. *)
    val valnum = false
    (* Not referenced inside this structure; kept for external users. *)
    val bV= ref 0
    fun err str=raise Fail(str)
    val realTy=DstTy.TensorTy []
    val intTy=DstTy.intTy
    fun iTos e1=Int.toString e1
    fun iToss es=String.concat(List.map iTos es)

    (* testp : string list -> int
     * Prints the concatenated strings when testing is enabled (returns 1),
     * otherwise does nothing (returns 0).
     *)
    fun testp n = if testing = 0
        then 0
        else (print (String.concat n); 1)

    (* Bump the use count of a LowIL variable. *)
    fun incUse (LowIL.V{useCnt, ...}) = (useCnt := !useCnt + 1)

    (* Integer-keyed index environment (index id -> int value). *)
    val empty = IMap.empty
    fun lookup k d = IMap.find(d, k)
    (* The debug string the original built here was discarded; it is gone.
     * The value stays constrained to int, as before. *)
    fun insert (k, v : int) d = IMap.insert(d, k, v)
    fun insertP (k, v) d = insert (k, v) d

    (* find : int * int IMap.map -> int
     * Raises Fail "Outside Bound(v)" when v is unbound.
     *)
    fun find (v, mapp) = (case IMap.find(mapp, v)
        of NONE => raise Fail(concat["Outside Bound(", Int.toString v, ")"])
         | SOME s => s
        (* end case *))

    (* mapIndex : E.mu * dict -> int
     * A variable index is looked up in the environment; a constant index
     * E.C(c,_) evaluates to c.
     *)
    fun mapIndex (e1, mapp)=(case e1
        of E.V e => find(e,mapp)
         | E.C(c,_) => c
        (*end case*))

    (*     *************************** DstIL.ASSGN  ****************************  *)
    (* noplugin : opset * ty * Var * rhs -> opset * Var * code list
     * Emits lhs = rhs unconditionally; ty is unused.
     *)
    fun noplugin (opset, ty, lhs, rhs) = (opset, lhs, [DstIL.ASSGN(lhs, rhs)])

    (* plugin : opset * ty * Var * rhs -> opset * Var * code list
     * With valnum off this is noplugin. With valnum on, lowSet.filter is
     * consulted: if it returns SOME v, v is returned and no code is emitted;
     * otherwise the assignment is emitted as usual.
     *)
    fun plugin (opset, ty, lhs, rhs) =
        if not valnum
            then noplugin(opset, ty, lhs, rhs)
            else let
                val (opset, var) = lowSet.filter(opset, (lhs, rhs))
                val _ = incUse lhs (*FIXME*)
                val codeo = DstIL.ASSGN(lhs, rhs)
                in
                    (case var
                        of SOME v => (testp["\n Found(",DstIL.Var.toString(v),"):",LowToS.toStringAll(ty,codeo)]; (opset, v, []))
                         | NONE => (testp["\n Inserting:",LowToS.toStringAll(ty,codeo)]; (opset, lhs, [codeo]))
                        (*end case*))
                end

    (*     *************************** DstIL.LIT ****************************  *)

    (* mkInt : opset * int -> opset * Var * code list
     * Fresh intTy variable bound to the integer literal n.
     *)
    fun mkInt (opset,n)=let
        val lhs=DstIL.Var.new("Int" ,intTy)
        val rhs=DstIL.LIT(Literal.Int(IntInf.fromInt n))
    in
        plugin(opset,intTy,lhs,rhs)
    end

    (* mkReal : opset * int -> opset * Var * code list
     * Fresh realTy variable bound to the literal n.
     * NOTE(review): the literal is still Literal.Int although the variable is
     * realTy; confirm the backend accepts this before relying on it.
     *)
    fun mkReal (opset,n)=let
        val lhs=DstIL.Var.new("real" ,realTy)
        val rhs=DstIL.LIT(Literal.Int(IntInf.fromInt n))
    in
        plugin(opset,realTy,lhs,rhs)
    end

    (*     *************************** DstIL.CONS ****************************  *)
    (* assgnCons : build a tensor of the given shape from args (pre unused). *)
    fun assgnCons(opset,pre,shape, args)=let
        val ty=DstTy.TensorTy shape
        val lhs = DstIL.Var.new("cons_" ,ty)
        val rhs = DstIL.CONS(ty ,args)
        in
            plugin(opset,ty,lhs,rhs)
        end

    (* assgnConsV : like assgnCons, but the caller supplies the type directly. *)
    fun assgnConsV(opset,pre,ty, args)=let
        val lhs = DstIL.Var.new("cons_" ,ty)
        val rhs = DstIL.CONS(ty ,args)
        in plugin(opset,ty,lhs,rhs) end

    (*     ***************************  DstIL.OP  ****************************  *)
    (* assignOP : opset * LowOps.rator * Var list * string * ty
     *            -> opset * Var * code list
     * Fresh variable (named from pre) bound to opss applied to args.
     *)
    fun assignOP(opset,opss,args,pre,ty)=let
        val lhs=DstIL.Var.new(pre ,ty)
        val rhs=DstIL.OP(opss,args)
    in
        plugin(opset,ty,lhs,rhs)
    end

    (* mkSingle : apply unary op opp to nU, appending its code after code. *)
    fun mkSingle(opp,name,(opset,nU,code))=let
        val (opset,vA,A)=assignOP(opset,opp,[nU],name,realTy)
        in
            (opset,vA,code@A)
        end

    (*     *************************** DstOp.IndexTensor ****************************  *)
    (* getTensorTy : E.params * E.tensor_id -> LowIL.Ty
     * Picks the LowIL type of parameter id from its E.TEN tag:
     * 3 with a one-element shape -> iVecTy, 2 -> indexTy,
     * anything else -> TensorTy. Non-tensor params raise Fail.
     *)
    fun getTensorTy(params, id)=(case List.nth(params,id)
        of E.TEN(3,[shape])=> DstTy.iVecTy(shape) (*FIX HERE*)
        | E.TEN(2,shape) =>DstTy.indexTy shape
        | E.TEN(_,shape)=> DstTy.TensorTy shape
        |_=> err"NONE Tensor Param"
        (*end case*))

    (* indexTensor : opset * dict * (lhs, params, args, id, alpha, ty)
     *               -> opset * Var * code list
     * Index tensor argument id at the indices ix (resolved via mapIndex).
     * With no indices the argument variable itself is returned, emitting no code.
     * NOTE(review): exactly three indices with a rank-4 TensorTy result
     * raises Fail "uneven"; this looks like a debugging trap.
     *)
    fun indexTensor(opset,_,(lhs,params,args,id, []  ,ty)) = (opset,List.nth(args,id),[])
      | indexTensor(_,_,(lhs,params,args,id, [_,_,_],DstTy.TensorTy [_,_,_,_] )) = raise Fail "uneven"
      | indexTensor(opset,mapp,(lhs,params,args,id,ix,ty))= let
        val nU=List.nth(args,id)
        val ixx=List.map (fn e1 => mapIndex(e1,mapp)) ix
        val argTy=getTensorTy(params,id)
        val opp=DstOp.IndexTensor(id,ixx,argTy)
        val name=String.concat["Indx_",iToss ixx,"_"]
        in
            assignOP(opset,opp,[nU],name,ty)
        end

    (*     *************************** DstOp._ Shortcuts ****************************  *)

    (* Shortcuts. Arguments are Low-IL variables that are already indexed/projected.
     * opset * Var list -> opset * Var * code list
     *)
    fun mkAddSca(opset,args)= assignOP(opset,DstOp.addSca,args,"addSca",realTy)
    fun mkAddInt(opset,args)= assignOP(opset,DstOp.addSca,args,"addInt",intTy)
    fun mkAddPtr(opset,args,ty)= assignOP(opset,DstOp.addSca,args,"addPtr",ty)
    fun mkAddVec(opset,vecIX,args)=assignOP(opset,DstOp.addVec vecIX,args,"addV",DstTy.TensorTy([vecIX]))
    fun mkSubSca(opset,args)= assignOP(opset,DstOp.subSca,args,"subSca",realTy)
    fun mkProdSca(opset,args)=assignOP(opset,DstOp.prodSca,args,"prodSca",realTy)
    fun mkProdInt(opset,args)=assignOP(opset,DstOp.prodSca,args,"prodInt",intTy)
    fun mkProdVec(opset,vecIX,args)=assignOP(opset,DstOp.prodVec vecIX,args,"prodV",DstTy.TensorTy([vecIX]))
    fun mkDivSca(opset,args)= assignOP(opset,DstOp.divSca,args,"divSca",realTy)
    fun mkSumVec(opset,vecIX,args)= assignOP(opset,DstOp.sumVec vecIX,args,"sumVec",realTy)

    (*     *************************** DstOp. Other ****************************  *)
    (* mkDotVec : element-wise product of the two vectors, then sum -> real scalar. *)
    fun mkDotVec(opset,vecIX,args)=let
        val (opsetD,vD, D)=mkProdVec(opset,vecIX,args)
        val (opsetE,vE, E)=mkSumVec(opsetD,vecIX,[vD])
        in (opsetE,vE,D@E) end

    (* intToReal : opset * int -> opset * Var * code list
     * Emits the integer literal n as an intTy variable and casts it to realTy.
     * The operand of IntToReal is now int-typed; it used to be a realTy variable
     * created by mkReal.
     *)
    fun intToReal(setA,n)=let
        val (setC,vC,C)=mkInt(setA,n)
        val (setD,vD,D)=assignOP(setC,DstOp.IntToReal,[vC],"cast",realTy)
        in (setD,vD,C@D) end

    (* mkPowInt : (opset * Var * code list) * int -> opset * Var * code list
     * Raises the real scalar nU to the integer power nn by repeated squaring.
     * Every product is a real scalar (realTy); it was previously mistyped as intTy.
     * nn = 0 yields the constant 1.0. A negative nn raises Fail; previously both
     * cases recursed forever.
     *)
    fun mkPowInt((opset,nU,code),nn)=
        if nn = 0
            then let
                val (setA,vA,A)=intToReal(opset,1)
                in (setA,vA,code@A) end
        else if nn < 0
            then err(String.concat["mkPowInt: negative exponent ",Int.toString nn])
        else let
            (* pow is only ever called with n >= 1 *)
            fun pow(1,setA)=(setA,nU,[])
              | pow(2,setA)=assignOP(setA,DstOp.prodSca,[nU,nU],"_Pow2_",realTy)
              | pow(n,setA)=let
                    (* x^m for even m: square x^(m div 2) *)
                    fun half m= let
                        val (setB,vB,B)= pow(m div 2,setA)
                        val name=String.concat["_Pow",Int.toString(m),"_"]
                        val (setC,vC,C)= assignOP(setB,DstOp.prodSca,[vB,vB],name,realTy)
                        in (setC,vC,B@C) end
                    in
                        if (n mod 2) = 0
                            then half n
                            else let
                                val (setC,vC,C)=half(n-1)
                                val name=String.concat["_Pow",Int.toString(n),"_"]
                                val (setD,vD,D)=assignOP(setC,DstOp.prodSca,[nU,vC],name,realTy)
                                in
                                    (setD,vD,C@D)
                                end
                    end
            val (setA,vA,A)=pow(nn,opset)
            in
                (setA,vA,code@A)
            end

    (* mkOp1 : E.op1 * (opset * Var * code list) -> opset * Var * code list
     * Lowers a unary Ein operator. Constructors not listed below raise Match.
     *)
    fun mkOp1(E.PowInt n,e)     = mkPowInt(e,n)
      | mkOp1(t,e)=let
        val opp=(case t
            of E.Cosine         => DstOp.Cosine
            | E.ArcCosine       => DstOp.ArcCosine
            | E.Sine            => DstOp.Sine
            | E.ArcSine         => DstOp.ArcSine
            | E.Tangent         => DstOp.Tangent
            | E.ArcTangent      => DstOp.ArcTangent
            | E.Sqrt            => DstOp.Sqrt
            | E.Exp             => DstOp.Exp
            (*end case*))
        in  mkSingle(opp,"_op1_",e) end

    (* mkMultiple : opset * Var list * LowOps.rator * LowIL.Ty -> opset * Var * code list
     * Left-folds rator over list1: combines the first two variables, then
     * the result with the next, and so on. A single element is returned
     * as-is; an empty list raises Fail.
     *)
    fun mkMultiple(opsetM,list1,rator,ty)=let
        fun add(opset,[],_,_)         = err"no element in mkMultiple"
        | add(opset,[e1],_,_)         = (opset,e1,[])
        | add(opset,[e1,e2],code,_)     = let
            val (opsetA,vA,A)=assignOP(opset,rator,[e1,e2],"mult_2",ty)
            in  (opsetA,vA,code@A) end
        | add(opset,e1::e2::es,code,count)  = let
            val (opsetA,vA,A)=assignOP(opset,rator,[e1,e2],String.concat["mult_",iTos count],ty)
            in  add(opsetA,vA::es,code@A,count-1)
            end
        in
            add(opsetM,list1,[],List.length list1)
        end

    (*     *************************** DstOp. Greek ****************************  *)
    (* deltaToInt : dict * E.mu * E.mu -> int
     * Kronecker delta: 1 when both indices resolve to the same value, else 0.
     *)
    fun deltaToInt(mapp,a,b)= let
        val i=mapIndex(a,mapp)
        val j=mapIndex(b,mapp)
        in if(i=j) then 1 else  0 end

    fun evalDelta(opset,mapp,a,b)= intToReal(opset,deltaToInt(mapp,a,b))

    (* 2-d Levi-Civita symbol: 0 if i=j, 1 if i<j, ~1 otherwise (as a real). *)
    fun evalEps2(opset,mapp,a,b)=let
        val i=mapIndex(E.V a,mapp)
        val j=mapIndex(E.V b,mapp)
        in if(i=j) then intToReal(opset,0)
            else
                if(j>i) then intToReal(opset,1)
                else intToReal(opset, ~1)
        end

    (* 3-d Levi-Civita symbol: 0 on any repeated index, 1 for even
     * permutations of increasing order, ~1 for odd ones (as a real).
     *)
    fun evalEps3(opset,mapp,a,b,c)=let
        val i=mapIndex(E.V a,mapp)
        val j=mapIndex(E.V b,mapp)
        val k=mapIndex(E.V c,mapp)
        in
            if(i=j orelse j=k orelse i=k) then intToReal(opset, 0)
            else if(j>i)
                then if(j>k andalso k>i) then intToReal (opset, ~1) else intToReal(opset,  1)
                else if(i>k andalso k>j) then intToReal(opset,  1) else intToReal(opset,  ~1)
        end

    (* evalG : dispatch a Greek Ein term (epsilon or delta) to its evaluator. *)
    fun evalG(setG,mapp,b)=(case b
        of  E.Epsilon(i,j,k)    => evalEps3(setG,mapp,i,j,k)
        |   E.Eps2(i,j)         => evalEps2(setG,mapp,i,j)
        |   E.Delta(i,j)        => evalDelta(setG,mapp,i,j)
    (*end case*))
 end

end

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0