Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] View of /branches/ein16/src/compiler/tree-il/lowOp-to-treeOp.sml
ViewVC logotype

View of /branches/ein16/src/compiler/tree-il/lowOp-to-treeOp.sml

Parent Directory Parent Directory | Revision Log Revision Log

Revision 3682 - (download) (annotate)
Thu Feb 18 20:13:18 2016 UTC (4 years, 5 months ago) by cchiw
File size: 16550 byte(s)
creating stable branch that represents ein ir
(* This module translates Low-IL operators into Tree-IL operators.
    Each Low-IL vector operation is broken into hardware-supported Tree-IL vector operations,
    e.g. A6+B6 => Mux[A4+B4, A2+B2].
    The following variables are used:
    isAligned/A: bool - is the array aligned?
    isFill: bool - is the vector padded with zeros? e.g. length-3 vectors represented with length 4
    nSize: int - the size of the vector operation (e.g. 4)
    oSize: int - the size of the original arguments of the vector operation, if less than the new size (e.g. 3)
    pieces: the sizes of the vector operations, e.g. 2 -> [2], 6 -> [4,2], 3 -> [4]
*)

(* FIXME: add signature!!! *)
(* lott: intended interface of LowOpToTreeOp.
 * NOTE(review): the declared argument type `LowOps.rator * LowIL.var` does not
 * match how vecToTree pattern-matches its argument (a pair of an optional
 * operator and an argument list, plus a context tuple), and LowOpToTreeOp is
 * not ascribed to this signature.  Reconcile before relying on it. *)
signature lott = sig
    val vecToTree : (LowOps.rator) * LowIL.var
          -> TreeIL.stm list * TreeIL.stm list * LowIL.var
           * (LowIL.var * TreeIL.exp * TreeIL.var) list
end

(* LowOpToTreeOp: translates Low-IL operators into Tree-IL operators, splitting
 * vector operations into hardware-supported pieces.
 * NOTE(review): the `struct` keyword that should follow `=` (and the matching
 * `end` after expandOp) appears to be missing from this copy of the file. *)
structure LowOpToTreeOp =
      (* Source side: Low-IL *)
      structure Src = LowIL
      structure SrcOp = LowOps
      structure SrcTy = LowILTypes
      structure SrcV = Src.Var
      (* Destination side: Tree-IL *)
      structure DstOp = TreeOps
      structure DstTy = TreeILTypes
      structure Dst = TreeIL
      structure DstSV=Dst.StateVar
      structure DstV = Dst.Var
      (* Ty is TreeIL.Ty; used in the tensor-type patterns below *)
      structure Ty=TreeIL.Ty

    (* Debug switch: testp only prints when this is 1. *)
    val testing = 0

    (* testp: string list -> int
     * Prints the concatenation of `t` when `testing` is 1; always returns 1. *)
    fun testp t =
          if testing = 1
            then (print (String.concat t); 1)
            else 1

    (* newVar: shorthand for creating a fresh Tree-IL variable (Dst.Var.new). *)
    fun newVar args = Dst.Var.new args
    (* Counter used by genName to make generated names unique. *)
    val cnt = ref 0

    (* genName: string -> string
     * Returns prefix ^ "_" ^ n, where n is the current value of `cnt`, and
     * increments `cnt`, so successive calls yield distinct names. *)
    fun genName prefix = let
          val n = !cnt
          in
            cnt := n + 1;
            String.concat [prefix, "_", Int.toString n]
          end
    (* iTos: int -> string; shorthand for Int.toString. *)
    val iTos = Int.toString
    (* newLocalWithTy: string * int -> TreeIL.var
     * Creates a fresh local Tree-IL variable of type TensorTy [n]; its name is
     * produced by genName from "l_" ^ <n> ^ name. *)
    fun newLocalWithTy (name, n) = let
          val prefix = String.concat ["l_", iTos n, name]
          in
            Dst.Var.new (genName prefix, Dst.VK_Local, Ty.TensorTy [n])
          end

    (* isAlignedLoad: bool * ty -> bool
     * Decides the `isAligned` flag used when creating an E_LoadArr.
     * Returns true only for a single-axis tensor type (TensorTy [_]) whose
     * vector is not zero-padded (isFill = false); every other case is false.
     * FIXME: refine this policy. *)
    fun isAlignedLoad (isFill, Ty.TensorTy [_]) = not isFill
      | isAlignedLoad _ = false

    (* isAlignedStore: bool * 'a -> bool
     * Decides the `isAligned` flag used when creating S_StoreVec and E_Mux.
     * Returns true exactly when the vector is not zero-padded (isFill = false);
     * the second argument (callers pass a piece count) is ignored.
     * FIXME: refine this policy. *)
    fun isAlignedStore (isFill, _) = not isFill
    (* mkStmt: TreeIL.var * bool * int * int list * TreeIL.exp list -> TreeIL.stm
     * Builds the final Tree-IL statement that writes the per-piece results `ops`
     * (presumably one expression per entry of `pieces`) into `lhs`:
     *   - a local `lhs` is assigned an E_Mux of the pieces (S_Assign);
     *   - any other kind of variable is written with S_StoreVec at offset 0,
     *     with type TensorTy [oSize].
     * The alignment flag comes from isAlignedStore. *)
    fun mkStmt (lhs, isFill, oSize, pieces, ops) = let
          val alignedStore = isAlignedStore (isFill, length pieces)
          in
            case DstV.kind lhs
             of TreeIL.VK_Local =>
                  Dst.S_Assign ([lhs], Dst.E_Mux (alignedStore, isFill, oSize, pieces, ops))
              | _ =>
                  Dst.S_StoreVec (Dst.E_Var lhs, 0, alignedStore, isFill, oSize,
                    Ty.TensorTy [oSize], pieces, ops)
            (* end case *)
          end

    (* getArg: returns the operand `t` for piece number `count` of a vector op.
     *   - a local variable is returned unchanged (it is assumed to already
     *     have the right size);
     *   - any other variable, or a state variable, is loaded with E_LoadArr of
     *     nSize elements (original size oSize) starting at `offset`;
     *   - an E_Mux yields its `count`-th component;
     *   - anything else is returned unchanged after printing a warning (the
     *     argument was probably not loaded).
     * `lhsKind` is currently unused. *)
    fun getArg (isFill, lhsKind, t, count, nSize, oSize, offset) = let
          (* load nSize elements of `t`, whose declared type is `ty`, at `offset` *)
          fun load ty =
                Dst.E_LoadArr (isAlignedLoad (isFill, ty), nSize, oSize, t,
                  Dst.E_Lit (Literal.Int offset))
          in
            case t
             of Dst.E_Var a =>
                  (case DstV.kind a
                    of TreeIL.VK_Local => t
                     | _ => load (DstV.ty a)
                  (* end case *))
              | Dst.E_State a => load (DstSV.ty a)
              | Dst.E_Mux (_, _, _, _, ops) => List.nth (ops, count)
              | a1 => (
                  print (String.concat ["Warning argument to vector operation is: ", Dst.toString a1]);
                  a1)
            (* end case *)
          end

    (* Bbind: optionally binds an array-load expression `expld` to a fresh local.
     *   - if the source variable `srct` has a use count below 2, returns
     *     (expld, [], [], []) and binds nothing;
     *   - otherwise creates a local `vs` of type TensorTy [n] (via
     *     newLocalWithTy), an assignment statement `vs := expld`, and an
     *     environment entry (srct, E_Var vs, vs) -- presumably so that multiple
     *     uses share one load.
     * `t` and `lhs` are not used.
     * NOTE(review): the `in ... end` closing the inner `let` is missing from
     * this copy; given the then-branch's shape it presumably returns
     * (expvar, stmtn, [vs], needtobind) -- confirm against the original. *)
    fun Bbind(srct,t,lhs,n,expld)=
        if(2>(SrcV.useCount srct))
        then (expld,[],[],[]) (*don't bind*)
        else let
            val vs=newLocalWithTy((SrcV.name srct)^"ldarr",n)
            val expvar=Dst.E_Var vs
            val stmtn=[Dst.S_Assign([vs] ,expld)]
            val needtobind=[(srct,expvar,vs)]         (*needs to be added to env*)


    (* getArg2: like getArg, but returns a 4-tuple
     *   (operand expression, extra statements, extra variables, env bindings).
     * Every branch below currently goes through `sort`, which pairs the
     * expression with three empty lists:
     *   - a local E_Var is returned unchanged;
     *   - any other E_Var (VK_Input or otherwise; both branches are identical)
     *     and any E_State is loaded with E_LoadArr of nSize elements at `offset`;
     *   - an E_Mux yields its `count`-th component;
     *   - anything else is returned unchanged after printing a warning.
     * `ldArr`, which would share the load through Bbind, is defined but its
     * call sites are commented out ("problem with illust-vr"), so `srct` and
     * `lhs` are only referenced from ldArr.
     * NOTE(review): the `in ... end` lines of `ldArr` and of the outer `let`
     * (presumably `in y end`) are missing from this copy. *)
    fun getArg2(srct,dstt,lhs,isFill,count, nSize, oSize,offset)=let
        val offsetlit=Dst.E_Lit(Literal.Int offset)
            (* no-op: this debug string is built but never printed *)
            val _= ("\n pre-getArg2: Exp-"^Dst.toString dstt)
        fun ldArr vTy=  let
            val allignedld=isAlignedLoad(isFill,vTy)
            val expld=Dst.E_LoadArr(allignedld ,nSize,oSize,dstt, offsetlit)
            val (expvar,stmtn,vs,needtobind)=Bbind(srct,dstt,lhs,nSize,expld)
        (* sort: pairs an operand with empty statement/variable/binding lists *)
        fun sort e =(e,[],[],[])

        (* the leading string in the E_Var branch is a discarded debug message *)
        val y=(case dstt
            of  (Dst.E_Var a) =>( "\n*getArg E_Var"; case (DstV.kind a,oSize)
                of (TreeIL.VK_Local,_) => sort dstt
                | (TreeIL.VK_Input ,_)  =>
                    sort (Dst.E_LoadArr(isAlignedLoad(isFill,DstV.ty a) ,nSize,oSize,dstt,offsetlit))
                | (_,_)  => (*ldArr(DstV.ty a)-problem with illust-vr*)
                    sort (Dst.E_LoadArr(isAlignedLoad(isFill,DstV.ty a) ,nSize,oSize,dstt,offsetlit))
                (*end case*))
            | (Dst.E_State a)=>(*ldArr(DstSV.ty a)*)
                sort (Dst.E_LoadArr(isAlignedLoad(isFill,DstSV.ty a) ,nSize,oSize,dstt,offsetlit))
            | (Dst.E_Mux(_,_,_,_,ops))=> sort(List.nth(ops,count))
            | a1 =>( print(String.concat["Warning argument to vector operation is: ",Dst.toString a1]); sort a1)
            (*end case*))


    (* sortArg: fetches the operands for one piece of a vector op by calling
     * getArg2 on each (source var, Tree-IL exp) pair in the argument list, and
     * concatenates the extra statements, variables and bindings it returns.
     * `x` is (lhs, isFill, count, nSize, oSize, offset).
     * Only one- and two-operand lists are handled; any other length raises Match.
     * The `val _ = "..."` lines are discarded debug strings (no output).
     * NOTE(review): the `let`/`in`/`end` around each clause body are missing
     * from this copy. *)
    fun sortArg([(srca,dsta)],x) =
            val _="\n sortArg1"
            val (lhs,isFill,count, nSize, oSize, offset)=x
            val (exp0,stmt0,v0,bind0)  = getArg2(srca,dsta,lhs,isFill,count, nSize, oSize, offset)
            ([exp0], stmt0, v0,bind0)
      | sortArg([(srca,dsta),(srcb,dstb)],x) =
            val _= "\n sortArg2"
            val (lhs,isFill,count, nSize, oSize, offset)=x
            val (exp0,stmt0,v0,bind0)  = getArg2(srca,dsta,lhs,isFill,count, nSize, oSize, offset)
            val (exp1,stmt1,v1,bind1)  = getArg2(srcb,dstb,lhs,isFill,count, nSize, oSize, offset)
                ([exp0,exp1], stmt0@stmt1, v0@v1,bind0@bind1)



    (*Special functions needed to transform specific operators*)

    (* transformProjectLast:
    * ProjectLast is an op that slices a vector from a higher order tensor
    * That vec could be in pieces
    * So we load each one and the next step wraps a mux around it
    * We can assume the higher order tensor is an array
    * This rewrites the ProjectLast op as an E_LoadArr.
    *Otherwise hand off to above function
    *
    * `x` is (lhs, oSize, _, (isFill, _, pieces)).  The starting offset `shift`
    * is the flat row-major offset of the selected vector in a tensor of type
    * TensorTy [_, m], [_, m, n] or [_, m, n, p]; any other index/shape
    * combination raises Fail.  mkLoad emits one E_LoadArr per entry of
    * `pieces`, advancing the offset by each piece's size, and mkStmt wraps
    * the loads into the statement that writes `lhs`.
    * Raises Bind if `origrator` is not a ProjectLast.
    * NOTE(review): the `in ... end` of the outer `let` is missing from this
    * copy; it presumably returns stmt0 together with stmtn and needtobind.
    *)
    fun transformProjectLast(origrator,argVec,x) = let
        val (lhs,oSize,_,(isFill,_,pieces))=x
        val SrcOp.ProjectLast(_,_,indTy,argTy) = origrator
        val alignedLd=isAlignedLoad(isFill,argTy)
        (*shift based on type of argument*)
        val shift=(case (indTy,argTy)
            of ([i],Ty.TensorTy[_,m])=> i*m
            | ([i,j],Ty.TensorTy[_,m,n])=> (i*n*m)+(j*n)
            | ([i,j,k],Ty.TensorTy[_,m,n,p])=> (i*m*n*p)+(j*p*n)+(k*p)
            |  _ =>raise Fail "ProjectLast Tensor of a unhandled size" (*add this later*)
            (*end case *))
        (* one E_LoadArr per piece, each starting where the previous ended *)
        fun mkLoad ([], _, code) = code
          | mkLoad (nSize::es, offset, code)=
                mkLoad (es, offset + IntInf.fromInt nSize,
                code@[Dst.E_LoadArr(alignedLd, nSize,oSize,argVec, Dst.E_Lit(Literal.Int offset))])
        val ops=  mkLoad (pieces, IntInf.fromInt shift, [])
        val stmt0= mkStmt(lhs,isFill,oSize,pieces,ops)
        val stmtn=[]
        val needtobind=[]

    (* transformProjectFirst: builds the vector whose components are the
     * elements [0, i], [1, i], ..., [vecIX-1, i] of the 2-D tensor `argVec`
     * (one IndexTensor op per component), packs them into a single E_Cons of
     * size oSize, and writes it to `lhs` via mkStmt.
     * `x` is (lhs, oSize, _, (isFill, _, pieces)).
     * Only a ProjectFirst with exactly one index and a 2-D tensor type
     * matches; anything else raises Bind.
     * NOTE(review): the `in ... end` of the `let` is missing from this copy;
     * it presumably returns stmt0 together with stmtn and needtobind. *)
    fun transformProjectFirst(origrator,argVec,x)=  let
        val (lhs,oSize,_,(isFill,_,pieces))=x
        val SrcOp.ProjectFirst(_,vecIX,[i],Ty.TensorTy[argTy,argTyX])= origrator
        fun f cnt = Dst.E_Op(DstOp.IndexTensor(false,[cnt,i],Ty.TensorTy[argTy,argTyX]),[argVec])
        val indops=List.tabulate(vecIX, fn e=> f e)
        val ops=[Dst.E_Cons(oSize, oSize,indops)]
        val stmt0= mkStmt(lhs,isFill,oSize,pieces,ops)
        val stmtn=[]
        val needtobind=[]

    (* transformSumVec: sums the components of the vector `argVec` into `lhs`.
     * `x` is (lhs, oSize, _, (isFill, _, pieces)).
     *   - one piece: a single sumVec op over the argument fetched by getArg;
     *   - several pieces: each piece's argument is fetched with getArg at its
     *     offset (this branch's result expression is not visible in this copy).
     * addSca folds the resulting ops left-to-right into nested addSca
     * operations, and the result is always written with S_Assign, regardless
     * of the kind of `lhs`.  addSca has no clause for [] (raises Match).
     * NOTE(review): the `in ... end` lines of the outer `let`, of both case
     * branches and of createOps are missing from this copy. *)
    fun transformSumVec(argVec,x)= let
        val (lhs,oSize,_,(isFill,_,pieces))=x

        val lhsKind=DstV.kind lhs
        val ops=(case pieces
            of [nSize]  => let
                val op2=DstOp.sumVec([nSize],oSize)
                val arg2=getArg(isFill,lhsKind,argVec,0, nSize, oSize, 0)
                    [Dst.E_Op(op2, [arg2])]
            | _ => let
                (*createOps:int list*int*int*DstIL.exp->DstIL.exp list
                *  Gets all the Arguments in order
                *   i.e. nsize=[4,2]=>  [[A4,B4],[A2,B2]]
                *)
                fun createOps ([], _,_, code) = code
                | createOps (nSize::es, count,offset, code)=let
                    val argsLd= getArg(isFill,lhsKind,argVec,count, nSize, oSize, offset)
                    val exp = (nSize,[argsLd])
                        createOps (es, count+1,offset + IntInf.fromInt nSize, code@[exp])
                val indexAt=IntInf.fromInt 0
                val code=createOps (pieces, 0,indexAt, [])
                val args=List.foldr op@ [] (List.map (fn(_,args)=>args) code)
            (*end case*))
        (* addSca [e1, e2, e3, ...] = addSca(addSca(e1, e2), e3) ... *)
        fun addSca [e1]=e1
         | addSca(e1::e2::es)= addSca([Dst.E_Op(DstOp.addSca,[e1,e2])]@es)

        val stmt0=  Dst.S_Assign([lhs], addSca ops)
        val stmtn=[]
        val needtobind=[]

    (* mkGenericOps:
    TreeIL.Var*TreeILExp list *TreeILExp list * bool * int * int list *LowILOP ->TreeIL.stmt
    Take the original vector op and breaks it "pieces" which are HW supported sizes
    gets arguments with getArg().
    and then puts them inside TreeIL op ("rator")
    special case sumVecOp, because it maintains new and old size.

    NOTE(review): this header still uses the older name mkGenericOps and an
    older type; the code below is transformGenVec and fetches its vector
    operands with sortArg/getArg2 rather than getArg.
    `x` is (lhs, oSize, SOME dstrator, (isFill, _, pieces)); dstrator maps a
    piece size to the Tree-IL operator for that piece.  For each piece,
    createOps fetches the vector operands (argsV) with sortArg, prepends the
    scalar operands argsS, and accumulates the extra statements, variables and
    bindings; the per-piece ops are then combined into one statement by mkStmt.
    Raises Bind if the operator in `x` is NONE.
    NOTE(review): the `let`/`in`/`end` around createOps' second clause and the
    outer `in ... end` are missing from this copy.
    *)

    fun transformGenVec(argsV,argsS,x) =let

        val (lhs,oSize,SOME dstrator,(isFill,_,pieces))=x

        (* no-op: debug string, never printed *)
        val _= "\n pre-transformGenVec"
        fun createOps ([], _,_, code,stmtn,vs,needtobind) = (code,stmtn,vs,needtobind)
          | createOps (nSize::es, count,offset, code,stmtn,vs,needtobind)=
                val x1 =(lhs,isFill,count, nSize, oSize, offset)
                val (exp0, stmtn0,v0,needtobind0)=   sortArg(argsV,x1)
                val code0 = (nSize, argsS@exp0)
                createOps (es, count+1,offset + IntInf.fromInt nSize, code@[code0],stmtn@stmtn0,vs@v0,needtobind@needtobind0)
        val indexAt=IntInf.fromInt 0
        val (code,stmtn,vs,needtobind)=createOps (pieces, 0,indexAt, [],[],[],[])
        (* no-op: debug string, never printed *)
         val _= "\n post-transformGenVec"
        val ops= List.map (fn(nSize,args)=> Dst.E_Op(dstrator nSize,args)) code
        val stmt0= mkStmt(lhs,isFill,oSize,pieces,ops)


    (* vecToTree: dispatches a vector operation to the matching transformer.
     * `e1` is (optional Low-IL operator, argument list), where each argument
     * is a pair whose second component is the Tree-IL operand; `x` is the
     * context tuple forwarded unchanged to the transformer.
     *   - ProjectLast / ProjectFirst / sumVec take a single vector argument;
     *   - prodScaV takes the scalar first, followed by the vector arguments;
     *   - Lerp on a single-axis tensor takes two vectors and a scalar;
     *   - NONE is handled generically by transformGenVec (the Tree-IL
     *     operator then comes from `x`).
     * Any other operator/argument combination raises Fail with a message
     * naming the operator, instead of an anonymous Match. *)
    fun vecToTree (e1, x) = (case e1
           of (SOME(SrcOp.ProjectLast e), [(_, argVec)]) =>
                transformProjectLast (SrcOp.ProjectLast e, argVec, x)
            | (SOME(SrcOp.ProjectFirst e), [(_, argVec)]) =>
                transformProjectFirst (SrcOp.ProjectFirst e, argVec, x)
            | (SOME(SrcOp.sumVec _), [(_, argVec)]) =>
                transformSumVec (argVec, x)
            | (SOME(SrcOp.prodScaV _), (_, argSca)::argVecs) =>
                transformGenVec (argVecs, [argSca], x)
            | (SOME(SrcOp.Lerp(Ty.TensorTy[_])), [a, b, (_, argSca)]) =>
                transformGenVec ([a, b], [argSca], x)
            | (NONE, argVecs) =>
                transformGenVec (argVecs, [], x)
            | (SOME rator, args) => raise Fail (String.concat [
                  "vecToTree: unsupported vector operator ", SrcOp.toString rator,
                  " with ", Int.toString (length args), " argument(s)"
                ])
          (* end case *))

    (* consVecToTree:int*int*int list*TreeIL.Exp* bool->TreeIL.Exp
     * Takes Cons of a vector and returns TreeIL.exp inside E_Mux.
     *  When isFill is true, creates zeros
     *
     * Clauses:
     *   - isFill with a single piece nSize: pads `args` with
     *     nSize - length args zeros (Literal.Int 0) and builds one E_Cons of
     *     size nSize (original size oSize);
     *   - isFill with any other number of pieces: raises Fail;
     *   - otherwise: splits `args` into consecutive E_Cons of the sizes listed
     *     in `pieces`.  createOps' `offset` is only threaded through the
     *     recursion and `code` is never used.
     * NOTE(review): padding uses integer literals; confirm this is intended
     * when the vector is real-valued.
     * NOTE(review): the `in ... end` of both `let`s is missing from this copy;
     * given `aligned` and `splitTy`, each presumably builds an E_Mux -- confirm.
     *)
    fun consVecToTree(_,oSize,[nSize],args,true)= let
        val nArg=length(args)
        val n=nSize-nArg
        val newArgs=List.tabulate(n, (fn _=>Dst.E_Lit(Literal.Int 0)))
        val op1=Dst.E_Cons(nSize, oSize,args@newArgs)
        val aligned= isAlignedStore(true,1)
        val splitTy= [nSize]
    | consVecToTree(_,_,_,_,true)= raise Fail"In ConsVecToTree-isFill with more than 1 piece"
    | consVecToTree(nSize,oSize,pieces,args,isFill)= let
        val aligned= isAlignedStore(isFill,length pieces)
        val splitTy= pieces
        (* take the next nSize args for each piece, in order *)
        fun createOps ([], _,_,_) = []
        | createOps (nSize::es, offset,arg, code)=
            [Dst.E_Cons(nSize, nSize,List.take(arg,nSize))]@
            createOps (es, offset + IntInf.fromInt nSize, List.drop(arg,nSize),code)
        val ops= createOps (pieces, 0,args, [])

    (* expandOp: LowOps.rator -> TreeOps.rator
     * Translates a Low-IL operator into the Tree-IL operator of the same name,
     * carrying its type/info argument over unchanged.
     * Raises Fail "PowRat" for powRat, and Fail "bogus operator ..." for any
     * operator not listed here -- including Select on a non-tuple type and
     * Norm, whose translation is currently commented out. *)
    fun expandOp SrcOp.IAdd = DstOp.IAdd
      | expandOp SrcOp.ISub = DstOp.ISub
      | expandOp SrcOp.IMul = DstOp.IMul
      | expandOp SrcOp.IDiv = DstOp.IDiv
      | expandOp SrcOp.INeg = DstOp.INeg
      | expandOp (SrcOp.Abs ty) = DstOp.Abs ty
      | expandOp (SrcOp.LT ty) = DstOp.LT ty
      | expandOp (SrcOp.LTE ty) = DstOp.LTE ty
      | expandOp (SrcOp.EQ ty) = DstOp.EQ ty
      | expandOp (SrcOp.NEQ ty) = DstOp.NEQ ty
      | expandOp (SrcOp.GT ty) = DstOp.GT ty
      | expandOp (SrcOp.GTE ty) = DstOp.GTE ty
      | expandOp SrcOp.Not = DstOp.Not
      | expandOp SrcOp.Max = DstOp.Max
      | expandOp SrcOp.Min = DstOp.Min
      | expandOp (SrcOp.Clamp ty) = DstOp.Clamp ty
      | expandOp (SrcOp.Lerp ty) = DstOp.Lerp ty
      | expandOp SrcOp.Sqrt = DstOp.Sqrt
      | expandOp SrcOp.Cosine = DstOp.Cosine
      | expandOp SrcOp.ArcCosine = DstOp.ArcCosine
      | expandOp SrcOp.Sine = DstOp.Sine
      | expandOp SrcOp.ArcSine = DstOp.ArcSine
      | expandOp SrcOp.Tangent = DstOp.Tangent
      | expandOp SrcOp.ArcTangent = DstOp.ArcTangent
      | expandOp SrcOp.Exp = DstOp.Exp
      | expandOp (SrcOp.Zero ty) = DstOp.Zero ty
      | expandOp (SrcOp.PrincipleEvec ty) = DstOp.PrincipleEvec ty
      | expandOp SrcOp.EigenVals2x2 = DstOp.EigenVals2x2
      | expandOp SrcOp.EigenVals3x3 = DstOp.EigenVals3x3
      | expandOp SrcOp.EigenVecs2x2 = DstOp.EigenVecs2x2
      | expandOp SrcOp.EigenVecs3x3 = DstOp.EigenVecs3x3
      | expandOp (SrcOp.Select (ty as SrcTy.TupleTy _, i)) = DstOp.Select (ty, i)
      | expandOp (SrcOp.Index (ty, i)) = DstOp.Index (ty, i)
      | expandOp (SrcOp.Subscript ty) = DstOp.Subscript ty
      | expandOp (SrcOp.Ceiling d) = DstOp.Ceiling d
      | expandOp (SrcOp.Floor d) = DstOp.Floor d
      | expandOp (SrcOp.Round d) = DstOp.Round d
      | expandOp (SrcOp.Trunc d) = DstOp.Trunc d
      | expandOp SrcOp.IntToReal = DstOp.IntToReal
      | expandOp (SrcOp.RealToInt d) = DstOp.RealToInt d
      | expandOp (SrcOp.Inside info) = DstOp.Inside info
      | expandOp (SrcOp.Translate v) = DstOp.Translate v
      | expandOp SrcOp.addSca = DstOp.addSca
      | expandOp SrcOp.subSca = DstOp.subSca
      | expandOp SrcOp.prodSca = DstOp.prodSca
      | expandOp SrcOp.divSca = DstOp.divSca
      | expandOp (SrcOp.powRat _) = raise Fail "PowRat"
      | expandOp SrcOp.powInt = DstOp.powInt
      | expandOp (SrcOp.Normalize d) = DstOp.Normalize d
      | expandOp (SrcOp.baseAddr v) = DstOp.baseAddr v
      | expandOp (SrcOp.LoadImage args) = DstOp.LoadImage args
      | expandOp (SrcOp.Input inp) = DstOp.Input inp
      | expandOp (SrcOp.Print tys) = DstOp.Print tys
      | expandOp rator = raise Fail ("bogus operator " ^ SrcOp.toString rator)


ViewVC Help
Powered by ViewVC 1.0.0