Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] View of /branches/charisee/src/compiler/mid-to-low/step3.sml
ViewVC logotype

View of /branches/charisee/src/compiler/mid-to-low/step3.sml

Parent Directory Parent Directory | Revision Log Revision Log


Revision 2680 - (download) (annotate)
Wed Aug 6 00:51:53 2014 UTC (4 years, 11 months ago) by cchiw
File size: 7498 byte(s)
update
(* Helper functions for gen-ein.
 *
 * Utilities for generating LowIL code from pieces of an Ein expression.  Most
 * code-generating helpers return a pair (v, code): v is the LowIL variable that
 * holds the result, and code is the list of LowIL assignments that compute it,
 * in execution order.
 *
 * Index environments ("mapp") are functions from an index id (int) to an
 * int option.  They are extended with insert and queried with lookup.
 *)
structure step3 = struct
    local

    structure DstIL = LowIL
    structure DstTy = LowILTypes
    structure DstOp = LowOps
    structure Var = LowIL.Var
    structure SrcIL = MidIL
    structure SrcOp = MidOps
    structure E = Ein
    structure tS= toStringEin


    in

(* Debug switch.  When nonzero:
 *   - aaV and mkInt print every assignment they build;
 *   - environments built by insert print each binding they return.
 *)
val testing=1
val bV= ref 0

(* Raise Fail with the given message. *)
fun err str=raise Fail(str)

(* Tensor type with an empty shape, used in this file as the scalar type. *)
val Sca=DstTy.TensorTy []
(* Operator used by sumDot to add the scalar terms together. *)
val addR=DstOp.addSca

(* Look up index id k in environment d. *)
fun lookup k d = d k

(* Shorthand for Int.toString. *)
fun q e1=Int.toString e1

(* Return environment d extended with the binding key => value.
 * When testing is nonzero, every lookup that hits key prints "key=>value"
 * (with no trailing newline).
 *)
fun insert (key, value) d =fn s =>
    if s = key
        then (if testing <> 0
                then print(String.concat[Int.toString(key),"=>",Int.toString(value)])
                else ();
              SOME value)
        else d s


(* Get kernel and image bindings *)

(* Return the kernel of the MidIL Kernel operator bound to x; fail otherwise. *)
fun getKernel x  = (case SrcIL.Var.binding x
    of SrcIL.VB_RHS(SrcIL.OP(SrcOp.Kernel(h, _ ) ,_ ))=> h
    | vb => (err (String.concat["\n -- Not a kernel, ", SrcIL.Var.toString x," found ", SrcIL.vbToString vb,"\n"]))
        (* end case *))

(* Return the operand of the MidIL LoadImage operator bound to x; fail otherwise. *)
fun getImageSrc x  = (case SrcIL.Var.binding x
    of SrcIL.VB_RHS(SrcIL.OP(SrcOp.LoadImage img, _ )) => img
    | vb => (err (String.concat["\n -- Not an image, ", SrcIL.Var.toString x," found ", SrcIL.vbToString vb,"\n"]))
    (* end case *))

(* Return the single argument variable of the LowIL LoadImage operator bound to x;
 * fail otherwise.
 *)
fun getImageDst x  = (case DstIL.Var.binding x
    of DstIL.VB_RHS(DstIL.OP(DstOp.LoadImage _ ,[ivar])) =>ivar
    | vb => (err (String.concat["\n -- Not an image, ", DstIL.Var.toString x," found ", DstIL.vbToString vb,"\n"]))
    (* end case *))

(* Make a fresh LowIL variable with name prefix pre and type ty, together with
 * the assignment that binds it to operator opss applied to args.
 * Returns (var, [assignment]).
 *)
fun aaV(opss,args,pre,ty)=let
    val a=DstIL.Var.new(pre ,ty)
    val code=DstIL.ASSGN (a,DstIL.OP(opss,args))
    val _ = if testing <> 0 then print(tS.toStringAll(ty,code)) else ()
    in
        (a,[code])
    end

(* Make a fresh scalar variable bound to the integer literal n.
 * Returns (var, [assignment]).
 *)
fun mkInt n=let
    val a=DstIL.Var.new("Int" ,Sca)
    val code=DstIL.ASSGN (a,DstIL.LIT(Literal.Int(IntInf.fromInt n)))
    val _ = if testing <> 0 then print(tS.toStringAll(Sca,code)) else ()
    in
        (a,[code])
    end

(* Combine the variables in list1 left to right with the binary operator rator.
 * For example, [x1,x2,x3] yields
 *     t1 := rator(x1,x2); t2 := rator(t1,x3)
 * and every intermediate has type ty.
 * A one-element list yields that element with no code; an empty list is an error.
 *)
fun mkMultiple(list1,rator,ty)=let
    fun add([], _) = err"no element in mkMultiple"
      | add([e1], code) = (e1, code)
      | add(e1::e2::es, code) = let
            val (vA,A)=aaV(rator,[e1,e2],"MO",ty)
            in
                add(vA::es, code@A)
            end
    in
        add(list1,[])
    end

(* Resolve an Ein index to an int:
 *   - a variable index E.V e is looked up in mapp, failing with "Outside Bound"
 *     when it is unbound;
 *   - a constant index E.C c is returned unchanged.
 *)
fun mapIndex(e1,mapp)=(case e1
    of E.V e => (case (lookup e mapp)
        of NONE=> err("Outside Bound:"^Int.toString(e))
        |SOME s => s)
    | E.C c=> c
    (*end case*))

(* LowIL type of the id-th Ein parameter:
 *   - a TEN(3,[shape]) parameter maps to DstTy.iVecTy shape;
 *   - any other TEN parameter maps to a tensor type of its shape;
 *   - anything else is an error.
 * NOTE(review): the meaning of the leading TEN field (3) is not visible here,
 * and the original marked this case "FIX HERE" -- confirm.
 *)
fun getTensorTy(params, id)=(case List.nth(params,id)
    of E.TEN(3,[shape])=> DstTy.iVecTy(shape)
    | E.TEN(_,shape)=> DstTy.TensorTy shape
    |_=> err"NONE Tensor Param"
    (*end case*))

(* Scalar access to the id-th argument.
 *   - With no indices, the argument variable itself is returned (no code).
 *   - Otherwise the indices are resolved through mapp, and a scalar-typed
 *     IndexTensor load from the argument is emitted.
 *)
fun mkSca(mapp,(id, [],(params,args)))=(List.nth(args,id),[])
  | mkSca(mapp,(id,ix,(params,args)))= let
    val nU=List.nth(args,id)
    val ix'=DstTy.indexTy(List.map (fn e1 => mapIndex(e1,mapp)) ix)
    val argTy=getTensorTy(params,id)
    val opp=DstOp.IndexTensor(id,Sca,ix',argTy)
    in
        aaV(opp,[nU],"S"^Int.toString(id),Sca)
    end

(* Evaluate the epsilon (permutation) symbol at the indices bound to a, b and c,
 * emitting an integer literal:
 *   - 0 when any two resolved indices are equal;
 *   - 1 when (i,j,k) is an even permutation of its sorted order;
 *   - ~1 when it is an odd permutation.
 *)
fun evalEps(mapp,a,b,c)=let
    val i=mapIndex(E.V a,mapp)
    val j=mapIndex(E.V b,mapp)
    val k=mapIndex(E.V c,mapp)
    in
        if(i=j orelse j=k orelse i=k) then mkInt 0
        else
            if(j>i) then
                if(j>k andalso k>i) then mkInt ~1 else mkInt 1
            else if(i>k andalso k>j) then mkInt 1 else mkInt ~1
    end

(* Evaluate a delta: emit literal 1 if indices a and b resolve to the same value,
 * else 0.
 *)
fun evalDelta2(mapp,a,b)= let
    val i=mapIndex(a,mapp)
    val j=mapIndex(b,mapp)
    in
        if(i=j) then mkInt 1  else mkInt 0
    end

(* Evaluate a list of index pairs (i,j) under mapp.  Each pair contributes 1 when
 * its two resolved indices are equal and 0 otherwise; the contributions are summed.
 * NOTE(review): the deltas are summed, not multiplied -- confirm that is intended.
 *)
fun evalDels(mapp,dels)=let
    fun delta(i,j)=if mapIndex(i,mapp) = mapIndex(j,mapp) then 1 else 0
    in
        List.foldl (fn (pair, acc) => delta pair + acc) 0 dels
    end


(*--------------------Vectorization Helper Functions--------------------*)
(*val nextfnArgs=(body,params,args,origargs)*)

(* Vector access to the id-th argument.  Like mkSca, but the IndexTensor load
 * produces a value of type TensorTy [vecIX].  With no indices, the argument
 * variable itself is returned (no code).
 *)
fun mkVec(mapp,(id,[],vecIX,(params,args)))= let
    val nU=List.nth(args,id)
    in (nU,[]) end
  | mkVec(mapp,(id,ix,vecIX,(params,args)))= let
    val nU=List.nth(args,id)
    val ix'=DstTy.indexTy(List.map (fn (e1)=> mapIndex(e1,mapp)) ix)
    val argTy= getTensorTy(params,id)
    val vecTy=DstTy.TensorTy [vecIX]
    val opp=DstOp.IndexTensor(id,vecTy,ix',argTy)
    in
        aaV(opp,[nU],"V"^Int.toString(id),vecTy)
    end

(* Load a vector from argument id and combine it with the already-computed
 * variable vA using prodScaV.
 * NOTE(review): the original comment calls this the "product of -1 and 1
 * projection"; presumably vA holds a +/-1 scalar -- confirm.
 *)
fun mkNegV(mapp,((vA,id,ix),vecIX,info))=let
    val (vB, B)= mkVec(mapp,(id,ix,vecIX,info))
    val (vD, D)=aaV(DstOp.prodScaV vecIX,[vA, vB],"prodScaV",DstTy.TensorTy [vecIX])
    in
        (vD,B@D)
    end

(* Vector subtraction: load two vectors and emit subVec over [vA, vB]. *)
fun mksubVec(mapp,(id1,ix1,id2,ix2,vecIX,info))= let
    val (vA,A)= mkVec(mapp,(id1,ix1,vecIX,info))
    val (vB,B)= mkVec(mapp,(id2,ix2,vecIX,info))
    val (vD,D)= aaV(DstOp.subVec vecIX ,[vA, vB],"subVec",DstTy.TensorTy [vecIX])
    in
        (vD, A@B@D)
    end

(* Vector addition: load each (id, ix) in es as a vector, in order, then add
 * them all left to right with addVec.
 *)
fun handleAddVec(mapp,(es,vecIX,info))=let
    fun load((id1,ix1),(vars,code))=let
        val (vA,A)= mkVec(mapp,(id1,ix1,vecIX,info))
        in
            (vars@[vA],code@A)
        end
    val (vars,code)=List.foldl load ([],[]) es
    val (vSum,sumCode)=mkMultiple(vars,DstOp.addVec vecIX,DstTy.TensorTy([vecIX]))
    in
        (vSum,code@sumCode)
    end

(* Vector scaling: load a scalar (mkSca) and a vector (mkVec), then emit
 * prodScaV over [scalar, vector].
 *)
fun mkprodScaV(mapp,(id1,ix1,id2,ix2,vecIX,info))=let
    val (vA,A)= mkSca(mapp,(id1,ix1,info))
    val (vB,B)= mkVec(mapp,(id2,ix2,vecIX,info))
    val (vD,D)= aaV(DstOp.prodScaV vecIX,[vA, vB],"prodScaV",DstTy.TensorTy([vecIX]))
    in
        (vD,A@B@D)
    end

(* Vector product: load two vectors and emit prodVec over them. *)
fun mkprodVec(mapp,(id1,ix1,id2,ix2,vecIX, info))= let
    val (vA,A)= mkVec(mapp,(id1,ix1,vecIX,info))
    val (vB, B)= mkVec(mapp,(id2,ix2,vecIX,info))
    val (vD, D)=aaV(DstOp.prodVec(vecIX,1),[vA, vB],"prodV",DstTy.TensorTy([vecIX]))
    in
        (vD, A@B@D)
    end

(* Selects how mkprodSumVec lowers a sum of a vector product:
 *   - 0 emits prodVec followed by sumVec;
 *   - any other value emits a single dotVec.
 *)
val dotVec=0

(* Sum of a vector product, producing a value of type DstTy.realTy. *)
fun mkprodSumVec(mapp,(id1,ix1,id2,ix2,vecIX, info))=(case dotVec
    of 0 =>let
        val (vD,D)=mkprodVec(mapp,(id1,ix1,id2,ix2,vecIX, info))
        val (vSum, sumCode)=aaV(DstOp.sumVec(vecIX,1),[vD],"sumVec",DstTy.realTy)
        in
            (vSum, D@sumCode)
        end
    | _ => let
        val (vA,A)= mkVec(mapp,(id1,ix1,vecIX,info))
        val (vB,B)= mkVec(mapp,(id2,ix2,vecIX,info))
        val (vDot,dotCode)=aaV(DstOp.dotVec(vecIX,1),[vA,vB],"dotVec",DstTy.realTy)
        in
            (* The loads that define vA and vB must precede the dotVec that reads them. *)
            (vDot, A@B@dotCode)
        end
    (*end case*))

(* Dot-product-like summation over index variable v, from lb to ub inclusive.
 * For each value i, the term mkprodSumVec(mapp extended with v => i, t) is
 * generated.  The resulting scalars are then added left to right with addR.
 * The code for all terms comes first, in ascending order of i, followed by the
 * additions.  An empty range (ub < lb) is an error.
 *)
fun sumDot(mapp, ((E.V v,lb,ub),t))=let
    (* Walk i down from ub to lb, prepending each term, so that the accumulated
     * variables and code end up in ascending order of i.  Each term extends the
     * original mapp, not the previous term's environment.
     *)
    fun sumI(i,vars,code)=let
        val (vTerm,termCode)=mkprodSumVec(insert(v, i) mapp,t)
        val vars'=vTerm::vars
        val code'=termCode@code
        in
            if i = lb
                then let
                    val (vF,F)=mkMultiple(vars',addR,Sca)
                    in
                        (vF,code'@F)
                    end
                else sumI(i-1,vars',code')
        end
    in
        if ub < lb
            then err(String.concat["Empty summation range ", q lb, "..", q ub, " in sumDot"])
            else sumI(ub,[],[])
    end
  | sumDot _= err "Non-variable index in summation"

end


end

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0