Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] View of /branches/charisee/src/compiler/mid-to-low/step3.sml
ViewVC logotype

View of /branches/charisee/src/compiler/mid-to-low/step3.sml

Parent Directory Parent Directory | Revision Log Revision Log


Revision 2845 - (download) (annotate)
Fri Dec 12 06:46:23 2014 UTC (4 years, 7 months ago) by cchiw
File size: 8657 byte(s)
added norm
(*Helper function gen-ein*)
structure step3 = struct
    local

    structure DstIL = LowIL
    structure DstTy = LowILTypes
    structure DstOp = LowOps
    structure Var = LowIL.Var
    structure SrcIL = MidIL
    structure SrcOp = MidOps
    structure E = Ein
    structure LowToS= LowToString
    
    in

    val testing=0
    val bV= ref 0
    fun err str=raise Fail(str)
    val iTy=DstTy.IntTy
    val Sca=DstTy.TensorTy []
    val addR=DstOp.addSca
    fun lookup k d = d k
    fun q e1=Int.toString e1

    fun testp n= (case testing
        of 0=> 0
        | _ =>((print (String.concat n));1)
        (*end case*))

    fun insert (key, value) d =fn s =>
        if s = key then (testp[Int.toString(key),"=>",Int.toString(value)];SOME value)
        else d s

    (*Get kernel and Image bindings*)
    fun getKernel x  = (case SrcIL.Var.binding x
        of SrcIL.VB_RHS(SrcIL.OP(SrcOp.Kernel(h, _ ) ,_ ))=> h
        | vb => (err (String.concat["\n -- Not a kernel, ", SrcIL.Var.toString x," found ", SrcIL.vbToString vb,"\n"]))
            (* end case *))

    fun getImageSrc x  = (case SrcIL.Var.binding x
        of SrcIL.VB_RHS(SrcIL.OP(SrcOp.LoadImage img, _ )) => img
        | vb => (err (String.concat["\n -- Not an image, ", SrcIL.Var.toString x," found ", SrcIL.vbToString vb,"\n"]))
        (* end case *))

    (*Make assignment*)
    fun aaV(opss,args,pre,ty)=let
        val a=DstIL.Var.new(pre ,ty)
        val code=DstIL.ASSGN (a,DstIL.OP(opss,args))
        val _ =testp[LowToS.toStringAll(ty,code)]
        in
            (a,[code])
        end

    fun mkReal n=let
        val a=DstIL.Var.new("Real" ,Sca)
        val code=DstIL.ASSGN (a,DstIL.LIT(Literal.Int(IntInf.fromInt n)))
        val _ =testp[LowToS.toStringAll(Sca,code)]
        in
            (a,[code])
        end

    fun mkInt n=let
        val a=DstIL.Var.new("Int" ,iTy)
        val code=DstIL.ASSGN (a,DstIL.LIT(Literal.Int(IntInf.fromInt n)))
        val _ =testp[LowToS.toStringAll(iTy,code)]
        in
            (a,[code])
        end

    (*mk Multiple, Add Ids on list1*)
    fun mkMultiple((lhs,_,_,_),list1,rator,ty)=let
        fun add([],_,_)           = err"no element in mkMultiple"
        | add([e1],_,_)         = (e1,[])
        | add([e1,e2],code,_)     = let
                val (vA,A)=aaV(rator,[e1,e2],lhs^"_2",ty)
                in  (vA,code@A) end
        | add(e1::e2::es,code,count)  = let
                val (vA,A)=aaV(rator,[e1,e2],lhs^"_"^Int.toString count,ty)
                in  add(vA::es,code@A,count-1)
                end
        in
            add(list1,[],List.length list1)
        end

    fun mapIndex(e1,mapp)=(case e1
        of E.V e => (case (lookup e mapp)
            of NONE=> err("Outside Bound:"^Int.toString(e))
            |SOME s => s)
        | E.C c=> c
        (*end case*))

    (*Integer, or Generic Tensor*)
    fun getTensorTy(params, id)=(case List.nth(params,id)
        of E.TEN(3,[shape])=> DstTy.iVecTy(shape) (*FIX HERE*)
        | E.TEN(_,shape)=> DstTy.TensorTy shape
        |_=> err"NONE Tensor Param")

    fun q e=Int.toString(e)

    fun mkSca(mapp,(id, [],(lhs,params,_,args))) = (List.nth(args,id),[])
      | mkSca(mapp,(id,ix,(lhs,params,_,args)))= let
        val nU=List.nth(args,id)
        val ixx=(List.map (fn (e1)=> mapIndex(e1,mapp)) ix)
        val ix'=DstTy.indexTy ixx
        val argTy=getTensorTy(params,id)
        val opp=DstOp.IndexTensor(id,ix',argTy)
        in
        aaV(opp,[nU],lhs^"R"^Int.toString(id),Sca)
        end

    fun mkIntAsn(_,(id, [],(_,_,_,args))) = (List.nth(args,id),[])
        | mkIntAsn(mapp,(id,ix,(lhs,params,_,args)))= let
        val nU=List.nth(args,id)
        val ixx=(List.map (fn (e1)=> mapIndex(e1,mapp)) ix)
        val ix'=DstTy.indexTy ixx
        val argTy=getTensorTy(params,id)
        val opp=DstOp.IndexTensor(id,ix',argTy)
        in
            aaV(opp,[nU],lhs^"I"^Int.toString(id),iTy)
        end

    (*eval Epsilon*)
    fun evalEps(mapp,a,b,c)=let
        val i=mapIndex(E.V a,mapp)
        val j=mapIndex(E.V b,mapp)
        val k=mapIndex(E.V c,mapp)
        in
            if(i=j orelse j=k orelse i=k) then mkReal 0
            else
                if(j>i) then
                    if(j>k andalso k>i) then mkReal ~1 else mkReal 1
                else if(i>k andalso k>j) then mkReal 1 else mkReal ~1

        end

    (*eval Epsilon*)
    fun evalEps2(mapp,a,b)=let
        val i=mapIndex(E.V a,mapp)
        val j=mapIndex(E.V b,mapp)
        in
            if(i=j) then mkReal 0
            else
                if(j>i) then mkReal 1
                else mkReal ~1
        end

    (*eval Delta*)
    fun evalDelta2(mapp,a,b)= let
        val i=mapIndex(a,mapp)
        val j=mapIndex(b,mapp)
        in
            if(i=j) then mkReal 1  else mkReal 0
        end

    fun evalDels(mapp,dels)=let
        fun m(a,b)=if(a=b) then 1 else 0
        fun ij(i,j)=(case (i,j)
            of (E.V a, E.V b)=>m(mapIndex(i,mapp),mapIndex(j,mapp))
            | (E.C a, E.V b)=>m(a,mapIndex(j,mapp))
            | (E.V a, E.C b)=>m(mapIndex(i,mapp),b)
            | (E.C a, E.C b)=>m(i,j)
            (*end case*))
        val dels'=List.map ij dels
        in
            List.foldl(fn(x,y)=>x+y) 0 dels'
        end

    (*--------------------Vectorization Helper Functions--------------------*)
    (*val nextfnArgs=(body,params,args,origargs)*)
    fun mkVec(mapp,(id,[],vecIX,(lhs,params,_,args)))= let
        val nU=List.nth(args,id)
        in (nU,[]) end
      | mkVec(mapp,(id,ix,vecIX,(lhs,params,_,args)))= let
        val nU=List.nth(args,id)
        val ix'=DstTy.indexTy(List.map (fn (e1)=> mapIndex(e1,mapp)) ix)
        val argTy= getTensorTy(params,id)
        val vecTy=DstTy.TensorTy [vecIX]
        val opp=DstOp.ProjectTensor(id,vecIX,ix',argTy)
        in
            aaV(opp,[nU],lhs^"V"^Int.toString(id),vecTy)
        end

    (*product of -1 and 1 projection*)
    fun mkNegV(mapp,((vA,id,ix),vecIX,info as (lhs,_,_,_)))=let
        val (vB, B)= mkVec(mapp,(id,ix,vecIX,info))
        val (vD, D)=aaV(DstOp.prodScaV vecIX,[vA, vB],lhs^"prodScaV",DstTy.TensorTy [vecIX])
        in
            (vD,B@D)
        end

    (* Vector Subtraction*)
    fun mksubVec(mapp,(id1,ix1,id2,ix2,vecIX,info as (lhs,_,_,_)))= let
        val (vA,A)= mkVec(mapp,(id1,ix1,vecIX,info))
        val (vB,B)= mkVec(mapp,(id2,ix2,vecIX,info))
        val (vD,D)= aaV(DstOp.subVec vecIX ,[vA, vB],lhs^"subVec",DstTy.TensorTy [vecIX])
        in
            (vD, A@B@D)
        end

    (*Vector Addition *)
    fun handleAddVec(mapp,(es,vecIX,info))=let
        fun add([],rest,code)=(rest,code)
        | add((id1,ix1)::es,rest,code)=let
            val (vA,A)= mkVec(mapp,(id1,ix1,vecIX,info))
            in add(es,rest@[vA],code@A)
            end
        val (rest,code)=add(es,[],[])
        val (vA,A)=mkMultiple(info,rest,DstOp.addVec vecIX,DstTy.TensorTy([vecIX]))
        in
            (vA,code@A)
        end

    (*Vector Scaling*)
    fun mkprodScaV(mapp,(id1,ix1,id2,ix2,vecIX,info as (lhs,_,_,_)))=let
        val (vA,A)= mkSca(mapp,(id1,ix1,info))
        val (vB,B)= mkVec(mapp,(id2,ix2,vecIX,info))
        val (vD,D)= aaV(DstOp.prodScaV vecIX,[vA, vB],lhs^"prodScaV",DstTy.TensorTy([vecIX]))
        in
            (vD,A@B@D)
        end

    (*Vector Product*)
    fun mkprodVec(mapp,(id1,ix1,id2,ix2,vecIX, info as (lhs,_,_,_)))= let
        val (vA,A)= mkVec(mapp,(id1,ix1,vecIX,info))
        val (vB, B)= mkVec(mapp,(id2,ix2,vecIX,info))
        val (vD, D)=aaV(DstOp.prodVec vecIX,[vA, vB],lhs^"prodV",DstTy.TensorTy([vecIX]))
        in
            (vD, A@B@D)
        end

    (*Sum of Vector Product*)
    fun mkprodSumVec(mapp,(id1,ix1,id2,ix2,vecIX, info as (lhs,_,_,_)))=let
        val (vD,D)=mkprodVec(mapp,(id1,ix1,id2,ix2,vecIX, info))
        val (vE, E)=aaV(DstOp.sumVec vecIX,[vD], lhs^"sumVec",DstTy.realTy)
        in
            (vE, D @E)
        end

    (*Dot Product like summation *)
    fun sumDot(mapp, ((E.V v,lb,ub),t as (_,_,_,_,_,info) ))=let

        fun sumI(a,0,rest,code)=let
            val mapp =insert(v, 0) a
            val (vE, E)=mkprodSumVec(mapp,t)
            val rest'=[vE]@rest
            val (vF, F)=mkMultiple(info,rest',addR,Sca)
            in
                (vF,E@code@F)
            end
        | sumI(a,sx,rest',code')=let
            val mapp =insert(v, (sx+lb)) a
            val (vE, E)=mkprodSumVec(mapp,t)
            in
                sumI(a,sx-1,[vE]@rest',E@code')
            end
        in
            sumI(mapp, (ub-lb), [],[])
        end
        | sumDot _= raise Fail "Non-variable index in summation"

end


end

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0