Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] View of /branches/charisee/src/compiler/tree-il/lowOp-to-treeOp.sml
ViewVC logotype

View of /branches/charisee/src/compiler/tree-il/lowOp-to-treeOp.sml

Parent Directory Parent Directory | Revision Log Revision Log


Revision 3263 - (download) (annotate)
Mon Sep 28 18:01:40 2015 UTC (3 years, 11 months ago) by cchiw
File size: 16925 byte(s)
project on i*j*k*l
(* This module translates Low-IL operators into Tree-IL operators.
 *    When a LowIL vector operation is wider than the hardware supports, it is split
 *    into HW-supported TreeIL vector operations, e.g. A6+B6 => Mux[A4+B4, A2+B2].
 *    The following variables are used:
 *    isAligned/A : bool - whether the array is aligned
 *    isFill : bool - whether the vector is padded with zeros, e.g. a length-3 vector
 *             represented with length 4
 *    nSize : int - the size of the (HW) vector operation, e.g. 4
 *    oSize : int - the size of the original arguments of the vector operation, when
 *            smaller than the new size, e.g. 3
 *    pieces : the sizes of the HW vector operations, e.g. 2 -> [2], 6 -> [4,2], 3 -> [4]
 *)

(* FIXME: add signature!!!
 * NOTE(review): lott is not ascribed to LowOpToTreeOp below, and the type it gives
 * for vecToTree does not match the definition there.  That function takes
 *   (LowOps rator option * (source var, TreeIL exp) list)
 *     * (lhs, oSize, rator-builder option, (isFill, _, pieces))
 * and returns (statement, statement list, TreeIL var list, binding list).
 * The exact type names still need to be confirmed before this signature is used.
 *)
signature lott= sig
val  vecToTree: (LowOps.rator)  * LowIL.var-> TreeIL.stm list * TreeIL.stm list *LowIL.var *(LowIL.var*TreeIL.exp * TreeIL.var) list 
end


structure LowOpToTreeOp =
  struct
    local
      structure Src = LowIL
      structure SrcOp = LowOps
      structure SrcTy = LowILTypes
      structure SrcV = Src.Var
      structure DstOp = TreeOps
      structure DstTy = TreeILTypes
      structure Dst = TreeIL
      structure DstSV = Dst.StateVar
      structure DstV = Dst.Var
      structure Ty = TreeIL.Ty
    in

  (**************************************)
  (* Debugging *)

  (* testing: set to 1 to make testp print its arguments.
   * testp : string list -> int
   *   Prints the concatenation of its argument strings when testing = 1.
   *   Always returns 1.
   *)
    val testing = 0
    fun testp t = (case testing
           of 1 => (print (String.concat t); 1)
            | _ => 1
          (* end case *))

  (**************************************)
  (* Fresh TreeIL variables *)

    val newVar = Dst.Var.new

  (* global counter used by genName *)
    val cnt = ref 0

  (* genName : string -> string
   * Returns prefix_N, where N is the current value of cnt, and then increments cnt,
   * so successive calls return distinct names.
   *)
    fun genName prefix = let
          val n = !cnt
          in
            cnt := n + 1;
            String.concat [prefix, "_", Int.toString n]
          end

    fun iTos n = Int.toString n

  (* newLocalWithTy : string * int -> TreeIL var
   * Creates a fresh local TreeIL variable of type TensorTy [n].  Its name is
   * genName applied to l_<n><name>.
   *)
    fun newLocalWithTy (name, n) =
          Dst.Var.new (genName ("l_" ^ iTos n ^ name), Dst.VK_Local, Ty.TensorTy [n])

  (**************************************)
  (* Alignment policy *)

  (* isAlignedLoad : bool * ty -> bool
   * The isAligned flag used when creating E_LoadArr.  It is true only for a
   * first-order tensor type (TensorTy [_]) that is not zero-filled; every other
   * type yields false.
   * FIXME: this is a conservative placeholder policy.
   *)
    fun isAlignedLoad (isFill, Ty.TensorTy [_]) = not isFill
      | isAlignedLoad _ = false

  (* isAlignedStore : bool * int -> bool
   * The isAligned flag used when creating S_StoreVec and E_Mux.
   * Currently returns true exactly when the vector is not zero-filled.  The second
   * argument (the number of pieces) is ignored.
   * FIXME: this is a placeholder policy.
   *)
    fun isAlignedStore (isFill, _) = not isFill

  (**************************************)
  (* Statement construction and argument fetching *)

  (* mkStmt : TreeIL var * bool * int * int list * TreeIL exp list -> TreeIL stm
   * Builds the final statement that writes the per-piece expressions ops to lhs:
   *   - local lhs: S_Assign of an E_Mux that wraps the pieces
   *   - any other kind of lhs: S_StoreVec into lhs starting at offset 0
   *)
    fun mkStmt (lhs, isFill, oSize, pieces, ops) = let
          val alignedStore = isAlignedStore (isFill, length pieces)
          val splitTy = DstTy.vectorLength pieces
          in
            (case DstV.kind lhs
              of TreeIL.VK_Local =>
                   Dst.S_Assign ([lhs], Dst.E_Mux (alignedStore, isFill, oSize, splitTy, ops))
               | _ =>
                   Dst.S_StoreVec (Dst.E_Var lhs, 0, alignedStore, isFill, oSize,
                     Ty.TensorTy [oSize], splitTy, ops)
             (* end case *))
          end

  (* getArg : bool * 'a * TreeIL exp * int * int * int * IntInf.int -> TreeIL exp
   * Returns the operand for piece number count (nSize elements, starting at element
   * offset) of the vector argument t:
   *   - local variable: t unchanged (assumed to already have the right size)
   *   - other variables and state variables: an E_LoadArr of nSize elements at offset
   *   - E_Mux: the count-th piece of the mux
   *   - anything else: prints a warning and returns t unchanged
   * lhsKind is currently unused.
   *)
    fun getArg (isFill, lhsKind, t, count, nSize, oSize, offset) = let
          fun load ty =
                Dst.E_LoadArr (isAlignedLoad (isFill, ty), nSize, oSize, t,
                  Dst.E_Lit (Literal.Int offset))
          in
            (case t
              of Dst.E_Var a => (case DstV.kind a
                    of TreeIL.VK_Local => t
                     | _ => load (DstV.ty a)
                   (* end case *))
               | Dst.E_State a => load (DstSV.ty a)
               | Dst.E_Mux (_, _, _, _, ops) => List.nth (ops, count)
               | _ => (
                   print (String.concat ["Warning argument to vector operation is: ", Dst.toString t]);
                   t)
             (* end case *))
          end

  (* Bbind : source var * 'a * 'b * int * TreeIL exp
   *           -> TreeIL exp * stm list * TreeIL var list * binding list
   * Optionally binds a load expression to a fresh local:
   *   - if srct is used fewer than 2 times, returns (expld, [], [], []) so the load is
   *     used in place
   *   - otherwise creates a fresh local vs of size n and returns
   *     (E_Var vs, [vs := expld], [vs], [(srct, E_Var vs, vs)]).  The caller must add
   *     the last component to its environment.
   * t and lhs are unused.
   *)
    fun Bbind (srct, t, lhs, n, expld) =
          if SrcV.useCount srct < 2
            then (expld, [], [], [])
            else let
              val vs = newLocalWithTy (SrcV.name srct ^ "ldarr", n)
              val expvar = Dst.E_Var vs
              in
                (expvar, [Dst.S_Assign ([vs], expld)], [vs], [(srct, expvar, vs)])
              end

  (* getArg2 : 'a * TreeIL exp * 'b * bool * int * int * int * IntInf.int
   *             -> TreeIL exp * 'c list * 'd list * 'e list
   * Like getArg, but also returns lists of extra statements, new locals, and
   * environment bindings.  These lists would be used if a load were bound to a
   * local via Bbind.  That path is currently disabled (the original source noted a
   * problem with illust-vr), so the three lists are always empty.  Every non-local
   * variable (VK_Input included) and every state variable is loaded in place with
   * E_LoadArr.
   * srct and lhs are currently unused.
   *)
    fun getArg2 (srct, dstt, lhs, isFill, count, nSize, oSize, offset) = let
          fun load ty =
                Dst.E_LoadArr (isAlignedLoad (isFill, ty), nSize, oSize, dstt,
                  Dst.E_Lit (Literal.Int offset))
          val operand = (case dstt
                 of Dst.E_Var a => (case DstV.kind a
                       of TreeIL.VK_Local => dstt
                        | _ => load (DstV.ty a)
                      (* end case *))
                  | Dst.E_State a => load (DstSV.ty a)
                  | Dst.E_Mux (_, _, _, _, ops) => List.nth (ops, count)
                  | _ => (
                      print (String.concat [
                          "Warning argument to vector operation is: ", Dst.toString dstt
                        ]);
                      dstt)
                (* end case *))
          in
            (operand, [], [], [])
          end

  (* sortArg : ('a * TreeIL exp) list * ('b * bool * int * int * int * IntInf.int)
   *             -> TreeIL exp list * 'c list * 'd list * 'e list
   * Fetches piece count of every vector argument with getArg2, in argument order.
   * Returns the operand list and the concatenated statement, local, and binding
   * lists.  Accepts any number of arguments.  Previously only one or two were
   * handled, and any other count raised Match.
   *)
    fun sortArg (args, (lhs, isFill, count, nSize, oSize, offset)) = let
          val results = List.map
                (fn (src, dst) => getArg2 (src, dst, lhs, isFill, count, nSize, oSize, offset))
                args
          val exps = List.map (fn (e, _, _, _) => e) results
          val stmts = List.concat (List.map (fn (_, s, _, _) => s) results)
          val vars = List.concat (List.map (fn (_, _, v, _) => v) results)
          val binds = List.concat (List.map (fn (_, _, _, b) => b) results)
          in
            (exps, stmts, vars, binds)
          end

  (**************************************)
  (* Special functions needed to transform specific operators.
   * Each transformX takes x = (lhs, oSize, rator-builder option, (isFill, _, pieces))
   * and returns (statement, extra statements, new locals, env bindings).
   *)

  (* transformProjectLast
   * ProjectLast slices a vector out of a higher-order tensor argVec.  That tensor is
   * assumed to be stored as an array.  The slice starts at the row-major offset of
   * the fixed indices indTy.  It is read one HW-sized piece at a time with E_LoadArr,
   * and mkStmt wraps the pieces.  Only 1 to 3 fixed indices are supported; more
   * raise Fail.  A rator that is not ProjectLast raises Bind.
   *)
    fun transformProjectLast (origrator, argVec, x) = let
          val (lhs, oSize, _, (isFill, _, pieces)) = x
          val SrcOp.ProjectLast (_, _, indTy, argTy) = origrator
          val alignedLd = isAlignedLoad (isFill, argTy)
        (* row-major offset of the fixed indices.  The size of the first dimension
         * does not contribute to any stride.
         *)
          val shift = (case (indTy, argTy)
                 of (Ty.indexTy [i], Ty.TensorTy [_, m]) => i*m
                  | (Ty.indexTy [i, j], Ty.TensorTy [_, m, n]) => (i*n*m) + (j*n)
                  | (Ty.indexTy [i, j, k], Ty.TensorTy [_, m, n, p]) => (i*m*n*p) + (j*p*n) + (k*p)
                  | _ => raise Fail "ProjectLast Tensor of a unhandled size" (* add this later *)
                (* end case *))
        (* one E_LoadArr per piece; each piece starts where the previous one ended *)
          fun mkLoads ([], _) = []
            | mkLoads (nSize::rest, offset) =
                Dst.E_LoadArr (alignedLd, nSize, oSize, argVec, Dst.E_Lit (Literal.Int offset))
                  :: mkLoads (rest, offset + IntInf.fromInt nSize)
          val ops = mkLoads (pieces, IntInf.fromInt shift)
          in
            (mkStmt (lhs, isFill, oSize, pieces, ops), [], [], [])
          end

  (* transformProjectFirst
   * Builds a vector of length vecIX whose element cnt is
   * IndexTensor(argVec, [cnt, i]) for the 2-D tensor argVec.  The elements are
   * packed in a single E_Cons, which mkStmt then writes to lhs.  A rator that is
   * not ProjectFirst with a 1-index / 2-D tensor type raises Bind.
   *)
    fun transformProjectFirst (origrator, argVec, x) = let
          val (lhs, oSize, _, (isFill, _, pieces)) = x
          val SrcOp.ProjectFirst (_, vecIX, Ty.indexTy [i], Ty.TensorTy [argTy, argTyX]) = origrator
          val tensorTy = Ty.TensorTy [argTy, argTyX]
          fun index cnt =
                Dst.E_Op (DstOp.IndexTensor (false, Ty.indexTy [cnt, i], tensorTy), [argVec])
          val ops = [Dst.E_Cons (oSize, oSize, List.tabulate (vecIX, index))]
          in
            (mkStmt (lhs, isFill, oSize, pieces, ops), [], [], [])
          end

  (* transformSumVec
   * Sums the elements of argVec.  Each piece of argVec is fetched with getArg, in
   * order, and all the pieces are passed to a single sumVec(pieces, oSize)
   * operation, which is assigned to lhs.
   *)
    fun transformSumVec (argVec, x) = let
          val (lhs, oSize, _, (isFill, _, pieces)) = x
          val lhsKind = DstV.kind lhs
          fun loadPieces ([], _, _) = []
            | loadPieces (nSize::rest, count, offset) =
                getArg (isFill, lhsKind, argVec, count, nSize, oSize, offset)
                  :: loadPieces (rest, count + 1, offset + IntInf.fromInt nSize)
          val args = loadPieces (pieces, 0, IntInf.fromInt 0)
          val sumExp = Dst.E_Op (DstOp.sumVec (pieces, oSize), args)
          in
            (Dst.S_Assign ([lhs], sumExp), [], [], [])
          end

  (* transformGenVec
   * Generic case: splits the vector op into pieces of HW-supported sizes.  For each
   * piece, the operands are the scalar arguments argsS followed by that piece of
   * every vector argument in argsV.  Each piece becomes dstrator nSize applied to
   * those operands, and mkStmt combines the pieces.  Requires x to carry
   * SOME dstrator; NONE raises Bind.
   *)
    fun transformGenVec (argsV, argsS, x) = let
          val (lhs, oSize, SOME dstrator, (isFill, _, pieces)) = x
          fun createOps ([], _, _, code, stmtn, vs, needtobind) = (code, stmtn, vs, needtobind)
            | createOps (nSize::es, count, offset, code, stmtn, vs, needtobind) = let
                val (exps, stmtn0, vs0, needtobind0) =
                      sortArg (argsV, (lhs, isFill, count, nSize, oSize, offset))
                in
                  createOps (es, count + 1, offset + IntInf.fromInt nSize,
                    code @ [(nSize, argsS @ exps)],
                    stmtn @ stmtn0, vs @ vs0, needtobind @ needtobind0)
                end
          val (code, stmtn, vs, needtobind) =
                createOps (pieces, 0, IntInf.fromInt 0, [], [], [], [])
          val ops = List.map (fn (nSize, args) => Dst.E_Op (dstrator nSize, args)) code
          in
            (mkStmt (lhs, isFill, oSize, pieces, ops), stmtn, vs, needtobind)
          end

  (**************************************)
  (* vecToTree : (LowOps rator option * ('a * TreeIL exp) list) * x -> result
   * Dispatches a Low-IL vector operation to the matching transform.  e1 pairs the
   * operator (NONE for the generic case) with its (source, translated) arguments.
   * Operator/argument shapes that are not supported raise Fail.
   *)
    fun vecToTree (e1, x) = (case e1
           of (SOME (rator as SrcOp.ProjectLast _), [(_, argVec)]) =>
                transformProjectLast (rator, argVec, x)
            | (SOME (rator as SrcOp.ProjectFirst _), [(_, argVec)]) =>
                transformProjectFirst (rator, argVec, x)
            | (SOME (SrcOp.sumVec _), [(_, argVec)]) =>
                transformSumVec (argVec, x)
            | (SOME (SrcOp.prodScaV _), (_, argSca)::argVecs) =>
                transformGenVec (argVecs, [argSca], x)
            | (SOME (SrcOp.Lerp (Ty.TensorTy [_])), [a, b, (_, argSca)]) =>
                transformGenVec ([a, b], [argSca], x)
            | (NONE, argVecs) =>
                transformGenVec (argVecs, [], x)
            | (SOME rator, _) =>
                raise Fail ("vecToTree: unsupported vector operator " ^ SrcOp.toString rator)
          (* end case *))

  (* consVecToTree : int * int * int list * TreeIL exp list * bool -> TreeIL exp
   * Converts a Cons of a vector into an E_Mux of HW-sized E_Cons pieces.
   *   - isFill with a single piece: pads args with integer-literal zeros up to nSize.
   *     List.tabulate raises Size if args is longer than nSize.
   *   - isFill with any other number of pieces: raises Fail.
   *   - otherwise: consecutive runs of args are packed into one E_Cons per piece.
   *)
    fun consVecToTree (_, oSize, [nSize], args, true) = let
          val padding = List.tabulate (nSize - length args, fn _ => Dst.E_Lit (Literal.Int 0))
          val cons = Dst.E_Cons (nSize, oSize, args @ padding)
          in
            Dst.E_Mux (isAlignedStore (true, 1), true, oSize, DstTy.vectorLength [nSize], [cons])
          end
      | consVecToTree (_, _, _, _, true) = raise Fail "In ConsVecToTree-isFill with more than 1 piece"
      | consVecToTree (_, oSize, pieces, args, isFill) = let
          fun split ([], _) = []
            | split (n::ns, rest) =
                Dst.E_Cons (n, n, List.take (rest, n)) :: split (ns, List.drop (rest, n))
          in
            Dst.E_Mux (isAlignedStore (isFill, length pieces), isFill, oSize,
              DstTy.vectorLength pieces, split (pieces, args))
          end

  (***************************************************************************)

  (* expandOp : LowOps rator -> TreeOps rator
   * One-to-one translation of scalar/non-vector Low-IL operators to Tree-IL
   * operators.  powRat and powInt, a Select on a non-tuple type, and any operator
   * not listed here raise Fail.
   *)
    fun expandOp rator = (case rator
           of SrcOp.IAdd => DstOp.IAdd
            | SrcOp.ISub => DstOp.ISub
            | SrcOp.IMul => DstOp.IMul
            | SrcOp.IDiv => DstOp.IDiv
            | SrcOp.INeg => DstOp.INeg
            | SrcOp.Abs ty => DstOp.Abs ty
            | SrcOp.LT ty => DstOp.LT ty
            | SrcOp.LTE ty => DstOp.LTE ty
            | SrcOp.EQ ty => DstOp.EQ ty
            | SrcOp.NEQ ty => DstOp.NEQ ty
            | SrcOp.GT ty => DstOp.GT ty
            | SrcOp.GTE ty => DstOp.GTE ty
            | SrcOp.Not => DstOp.Not
            | SrcOp.Max => DstOp.Max
            | SrcOp.Min => DstOp.Min
            | SrcOp.Clamp ty => DstOp.Clamp ty
            | SrcOp.Lerp ty => DstOp.Lerp ty
            | SrcOp.Sqrt => DstOp.Sqrt
            | SrcOp.Cosine => DstOp.Cosine
            | SrcOp.ArcCosine => DstOp.ArcCosine
            | SrcOp.Sine => DstOp.Sine
            | SrcOp.Zero ty => DstOp.Zero ty
            | SrcOp.PrincipleEvec ty => DstOp.PrincipleEvec ty
            | SrcOp.EigenVals2x2 => DstOp.EigenVals2x2
            | SrcOp.EigenVals3x3 => DstOp.EigenVals3x3
            | SrcOp.EigenVecs2x2 => DstOp.EigenVecs2x2
            | SrcOp.EigenVecs3x3 => DstOp.EigenVecs3x3
            | SrcOp.Select (ty as SrcTy.TupleTy _, i) => DstOp.Select (ty, i)
            | SrcOp.Index (ty, i) => DstOp.Index (ty, i)
            | SrcOp.Subscript ty => DstOp.Subscript ty
            | SrcOp.Ceiling d => DstOp.Ceiling d
            | SrcOp.Floor d => DstOp.Floor d
            | SrcOp.Round d => DstOp.Round d
            | SrcOp.Trunc d => DstOp.Trunc d
            | SrcOp.IntToReal => DstOp.IntToReal
            | SrcOp.R r => DstOp.R r
            | SrcOp.RealToInt d => DstOp.RealToInt d
          (*| SrcOp.VoxelAddress( info, offset)  => expandVoxelAddress  (y, info, offset, args')
            | SrcOp.LoadVoxels (rty, d ) =>    DstOp.LoadVoxels( rty, d)*)
          (* Maybe should delete? *)
          (*  | SrcOp.LoadImage info =>    DstOp.LoadImage info*)
            | SrcOp.Inside info => DstOp.Inside info
            | SrcOp.Translate V => DstOp.Translate V
            | SrcOp.addSca => DstOp.addSca
            | SrcOp.subSca => DstOp.subSca
            | SrcOp.prodSca => DstOp.prodSca
            | SrcOp.divSca => DstOp.divSca
            | SrcOp.powRat _ => raise Fail "PowRat"
            | SrcOp.powInt => raise Fail "PowInt"
          (*| SrcOp.Norm ty =>  (DstOp.Norm ty)*)
            | SrcOp.Normalize d => DstOp.Normalize d
            | SrcOp.imgAddr (v, indexAt, dim) => DstOp.imgAddr (v, indexAt, dim)
            | SrcOp.baseAddr V => DstOp.baseAddr V
          (* EigenVecs, mkDynamic, Append, Prepend, Concat, Length, ImageAddress,
           * LoadVoxel, Inputs, and Prints
           *)
          (*| SrcOp.Input e=>*)
            | SrcOp.LoadImage args => DstOp.LoadImage args
            | SrcOp.Input inp => DstOp.Input inp
            | SrcOp.Print tys => DstOp.Print tys
            | rator => raise Fail ("bogus operator " ^ SrcOp.toString rator)
          (* end case *))

    end (* local *)

  end

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0