Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] View of /branches/charisee/src/compiler/tree-il/low-to-tree-fn.sml
ViewVC logotype

View of /branches/charisee/src/compiler/tree-il/low-to-tree-fn.sml

Parent Directory Parent Directory | Revision Log Revision Log


Revision 3544 - (download) (annotate)
Tue Jan 5 00:01:44 2016 UTC (3 years, 6 months ago) by cchiw
File size: 45542 byte(s)
code cleanup
(* low-to-tree-fn.sml
 *
 * This code is part of the Diderot Project (http://diderot-language.cs.uchicago.edu)
 *
 * COPYRIGHT (c) 2015 The University of Chicago
 * All rights reserved.
 *
 * This module translates the LowIL representation of a program (i.e., a pure CFG) to
 * a block-structured AST with nested expressions.
 *
 * NOTE: this translation is pretty dumb about variable coalescing (i.e., it doesn't do any).
 *)

functor LowToTreeFn (Target : sig

    val supportsPrinting : unit -> bool (* does the target support the Print op? *)

  (* tests for whether various expression forms can appear inline *)
    val inlineCons : int -> bool        (* can n'th-order tensor construction appear inline *)
    val inlineMatrixExp : bool          (* can matrix-valued expressions appear inline? *)
(* FIXME: isHwVec, isVecTy, and getPieces do not appear to be used *)
    val isHwVec   : int -> bool
    val isVecTy   : int -> bool
    val getPieces : int -> int list
(* FIXME: what does this function do?  what are its results? *)
    val getVecTy  : int -> bool * int * int list

  end) : sig

    val translate : LowIL.program -> TreeIL.program

  end = struct

    (* source-side (LowIL) structures *)
    structure Src = LowIL
    structure SrcOp = LowOps
    structure SrcV = LowIL.Var
    structure SrcSV = LowIL.StateVar
    structure SrcTy = LowILTypes
    structure Ty = LowILTypes
    structure Nd = LowIL.Node
    structure CFG = LowIL.CFG
    structure VA = VarAnalysis
    structure InP = Inputs
    structure LowToS = LowToString

    (* target-side (TreeIL) structures *)
    structure Dst = TreeIL
    structure DstOp = TreeOps
    structure DstSV = Dst.StateVar
    structure DstTy = TreeILTypes
    structure DstV = Dst.Var

    (* translation helpers and the type/operator sets they produce *)
    structure LowOpToTreeOp = LowOpToTreeOp
    structure TreeToOpr = TreeToOpr
    structure Fnc = TreeFunc
    structure TySet = Fnc.TySet
    structure OprSet = Fnc.OprSet

  (* create new tree IL variables *)
    local
      val newVar = Dst.Var.new
      val counter = ref 0
    (* return a fresh name "<prefix>_<k>", where k is the current value of a counter that
     * is shared by all of the variable constructors below (the counter is then bumped).
     *)
      fun genName prefix =
            String.concat [prefix, "_", Int.toString (!counter before counter := !counter + 1)]
    (* make a variable named after the LowIL variable x, with the given kind and x's type *)
      fun fromSrc (prefix, kind, x) = newVar (genName (prefix ^ SrcV.name x), kind, SrcV.ty x)
    in
  (* debugging support: when testing is true, testp prints the concatenation of its
   * argument; it always returns 1.
   *)
    val testing = false
    fun testp strs = (if testing then print (String.concat strs) else (); 1)
    val iTos = Int.toString
  (* variables derived from a LowIL variable *)
    fun newGlobal x = fromSrc ("G_", Dst.VK_Global, x)
    fun newParam x = fromSrc ("p_", Dst.VK_Local, x)
    fun newLocal x = fromSrc ("l_", Dst.VK_Local, x)
    fun newIter x = fromSrc ("i_", Dst.VK_Local, x)
    fun newTmp (x, n) = fromSrc ("l_" ^ iTos n, Dst.VK_Local, x)
  (* variables with an explicitly given name and type *)
    fun newLocalPtrTy (name, ty) = newVar (genName ("l_rp_" ^ name), Dst.VK_Local, ty)
    fun newLocalWithTy (name, n) =
          newVar (genName ("l_" ^ iTos n ^ name), Dst.VK_Local, Ty.TensorTy [n])
    fun newGlobalWithTy (name, n) =
          newVar (genName ("G_" ^ iTos n ^ name), Dst.VK_Global, Ty.TensorTy [n])
    end

  (* associate Tree IL state variables with Low IL variables using properties *)
    local
    (* build the TreeIL state variable that mirrors the LowIL state variable sv *)
      fun mkStateVar sv = Dst.SV{
              name = SrcSV.name sv,
              id = Stamp.new(),
              ty = SrcSV.ty sv,
              varying = VA.isVarying sv,
              output = SrcSV.isOutput sv
            }
      val {getFn, ...} = SrcSV.newProp mkStateVar
    in
  (* getStateVar sv: the TreeIL state variable for sv, obtained through the property's
   * getFn (presumably created with mkStateVar on first use -- see SrcSV.newProp).
   *)
    val getStateVar = getFn
    end

    (* wrap a statement list as a block with no locals *)
    fun mkBlock stms = Dst.Block{locals = [], body = stms}
    (* build an if statement; an empty else-branch yields an if-then with no else *)
    fun mkIf (cond, thenStms, elseStms) = (case elseStms
           of [] => Dst.S_IfThen(cond, mkBlock thenStms)
            | _ => Dst.S_IfThenElse(cond, mkBlock thenStms, mkBlock elseStms)
          (* end case *))

  (* an environment that tracks bindings of variables to target expressions and the list
   * of locals that have been defined.  The environment also accumulates the sets of types
   * (TySet) and operators (OprSet) referenced by the generated TreeIL code.
   *)
    local
      structure VT = SrcV.Tbl

    (* how a LowIL variable is represented in the target *)
      datatype target_binding
        = GLOB of Dst.var         (* variable is global *)
        | TREE of Dst.exp         (* variable bound to target expression tree *)
        | DEF of Dst.exp          (* either a target variable or constant for a defined variable *)

      datatype env = E of {
          tbl : target_binding VT.hash_table,
(* FIXME: perhaps the types and functs should be set refs, since they are global? *)
          types : TySet.set,
          functs : OprSet.set,
          locals : Dst.var list         (* most recently added first *)
        }

    in

  (* the accumulated (types, operators) sets of an environment *)
    fun peelEnv (E{types, functs, ...}) = (types, functs)

  (* the locals recorded in an environment (most recently added first) *)
    fun peelEnvLoc (E{locals, ...}) = locals

  (* replace the type and operator sets of an environment *)
    fun setEnv (E{tbl, locals, ...}, types, functs) =
          E{tbl=tbl, types=types, functs=functs, locals=locals}

  (* addOprFromExp : env * Dst.exp -> env
   * add the types and operators used by an expression to the environment's sets
   *)
    fun addOprFromExp (env, exp) = let
          val (types, functs) = TreeToOpr.expToOpr (peelEnv env, exp)
          in
            setEnv (env, types, functs)
          end

  (* addOprFromStmt : env * Dst.stm list -> env
   * add the types and operators used by a list of statements to the environment's sets
   *)
    fun addOprFromStmt (env, stms) = let
          val (types, functs) = TreeToOpr.stmtsToOpr (peelEnv env, stms)
          in
            setEnv (env, types, functs)
          end

  (* DEBUG: a short description of a binding *)
    fun bindToString binding = (case binding
           of GLOB y => "GLOB " ^ Dst.Var.name y
            | TREE _ => "TREE"
            | DEF(Dst.E_Var y) => "DEFVar " ^ Dst.Var.name y
            | DEF e => "DEF" ^ Dst.toString e
          (* end case *))

  (* DEBUG: print the bindings in the environment.  Output is only produced when the
   * testing flag is set (testp also checks it), so this is silent by default.
   *)
    fun dumpEnv (E{tbl, ...}) = let
          fun prEntry (x, binding) =
                ignore (testp ["  ", Src.Var.toString x, " --> ", bindToString binding, "\n"])
          in
            if testing
              then (
                print "\n *** dump environment\n";
                VT.appi prEntry tbl;
                print "***\n")
              else ()
          end

    fun newEnv () = E{tbl = VT.mkTable (512, Fail "tbl"), types=TySet.empty, functs=OprSet.empty, locals=[]}

  (* DEBUG: describe the binding of x ("none" if x is unbound) *)
    fun peek (E{tbl, ...}) x = (case VT.find tbl x
           of NONE => "none"
            | SOME b => bindToString b
          (* end case *))

  (* use a variable.  If it is a pending expression, we remove it from the table *)
    fun useVar (env as E{tbl, ...}) x = (case VT.find tbl x
           of SOME(GLOB x') => Dst.E_Var x'
            | SOME(TREE e) => (ignore(VT.remove tbl x); e)
            | SOME(DEF e) => e    (* DEF bindings are left in the table *)
            | NONE => (dumpEnv env; raise Fail(concat ["useVar(", SrcV.toString x, ")"]))
          (* end case *))

  (* record a local variable *)
    fun addLocal (E{tbl, types, functs, locals}, x) =
          E{tbl=tbl, types=types, functs=functs, locals=x::locals}
  (* record a list of local variables *)
    fun addLocals (E{tbl, types, functs, locals}, xs) =
          E{tbl=tbl, types=types, functs=functs, locals=xs@locals}
  (* bind x to the global x' (destructively updates the table) *)
    fun global (E{tbl, ...}, x, x') = VT.insert tbl (x, GLOB x')

  (* insert a pending expression into the table.  Note that x should only be used once! *)
    fun insert (env as E{tbl, ...}, x, exp) = (VT.insert tbl (x, TREE exp); env)
  (* bind x to the target variable x' *)
    fun rename (env as E{tbl, ...}, x, x') = (VT.insert tbl (x, DEF(Dst.E_Var x')); env)
  (* bind x to the global x' *)
    fun renameGlob (env as E{tbl, ...}, x, x') = (VT.insert tbl (x, GLOB x'); env)
  (* bind x to the target expression e *)
    fun renameExp (env as E{tbl, ...}, x, e) = (VT.insert tbl (x, DEF e); env)

  (* if x is bound to a global, return that global *)
    fun peekGlobal (E{tbl, ...}, x) = (case VT.find tbl x
           of SOME(GLOB x') => SOME x'
            | _ => NONE
          (* end case *))

  (* bindLocal : env * SrcV.var * Dst.exp -> env * Dst.stm list
   * bind the local lhs to rhs:
   *   - if lhs is unused, nothing is generated;
   *   - if lhs is used exactly once, rhs is recorded as a pending (inlinable) expression
   *     and its operators are added to the environment;
   *   - if lhs is used more than once and rhs is a Mux, each piece of the Mux is assigned
   *     to a new local vector and lhs is bound to a Mux of those locals;
   *   - otherwise, rhs is assigned to a new local that lhs is renamed to.
   *)
    fun bindLocal (env, lhs, rhs) = (case (SrcV.useCount lhs, rhs)
           of (0, _) => (env, [])
            | (1, _) => (insert(addOprFromExp(env, rhs), lhs, rhs), [])
            | (_, Dst.E_Mux(A, isFill, nOrig, tys, exps)) => let
                val name = SrcV.name lhs
              (* one new local vector per piece; tys gives the vector length of each piece *)
                val vs = List.map (fn n => newLocalWithTy(name, n)) tys
                val rhs = Dst.E_Mux(A, isFill, nOrig, tys, List.map Dst.E_Var vs)
                val stms = ListPair.map (fn (x, e) => Dst.S_Assign([x], e)) (vs, exps)
                in
                  (renameExp(addLocals(env, vs), lhs, rhs), stms)
                end
            | _ => let
                val t = newLocal lhs
                in
                  (rename(addLocal(env, t), lhs, t), [Dst.S_Assign([t], rhs)])
                end
          (* end case *))

  (* bind lhs to rhs: assign directly if lhs is a global, otherwise use bindLocal *)
    fun bind (env, lhs, rhs) = (case peekGlobal (env, lhs)
           of SOME x => (env, [Dst.S_Assign([x], rhs)])
            | NONE => bindLocal (env, lhs, rhs)
          (* end case *))

  (* set the definition of a variable, where the RHS is either a literal constant or a variable *)
    fun bindSimple (env as E{tbl, ...}, lhs, rhs) = (case peekGlobal (env, lhs)
           of SOME x => (env, [Dst.S_Assign([x], rhs)])
            | NONE => (VT.insert tbl (lhs, DEF rhs); (env, []))
          (* end case *))

  (* at the end of a block, we need to assign any pending expressions to locals.  The
   * blkStms list and the resulting statement list are in reverse order.
   *)
    fun flushPending (E{tbl, types, functs, locals}, blkStms) = let
          fun doVar (x, TREE e, (locals, stms)) = let
                val t = newLocal x
                in
                  VT.insert tbl (x, DEF(Dst.E_Var t));
                  (t::locals, Dst.S_Assign([t], e)::stms)
                end
            | doVar (_, _, acc) = acc
          val (locals, stms) = VT.foldi doVar (locals, blkStms) tbl
          in
            (E{tbl=tbl, types=types, functs=functs, locals=locals}, stms)
          end

  (* translate a phi node: each predecessor block gets an assignment of its phi argument
   * to a fresh local t, and lhs is renamed to t.  Note that ListPair.map pairs the phi
   * arguments with the given predecessor blocks, ignoring any extra elements.
   *)
    fun doPhi ((lhs, rhs), (env, predBlks : Dst.stm list list)) = let
        (* t will be the variable in the continuation of the JOIN *)
          val t = newLocal lhs
          val predBlks = ListPair.map
                (fn (x, stms) => Dst.S_Assign([t], useVar env x)::stms)
                  (rhs, predBlks)
          in
            (rename (addLocal(env, t), lhs, t), predBlks)
          end

  (* close a scope: a block whose body is stms, whose locals are those recorded in env
   * (in the order they were added), and whose type/operator sets include those of stms.
   *)
    fun endScope (env, stms) = let
          val env' = addOprFromStmt(env, stms)
          val (types, opr) = peelEnv env'
          in
            Dst.BlockWithOpr{
                locals = List.rev(peelEnvLoc env),
                types = types,
                opr = opr,
                body = stms
              }
          end
    end

  (* Certain IL operators cannot be compiled to inline expressions.  Return false for
   * those (currently the 2x2 and 3x3 eigenvector/eigenvalue operators) and true for all
   * others.  In doAssign, inline operators are passed to bind, while the others are
   * compiled to an explicit assignment.
   *)
    fun isInlineOp rator = (case rator
           of SrcOp.EigenVecs2x2 => false
            | SrcOp.EigenVecs3x3 => false
            | SrcOp.EigenVals2x2 => false
            | SrcOp.EigenVals3x3 => false
          (* NOTE: SrcOp.Zero _ used to map to Target.inlineMatrixExp (currently disabled) *)
            | _ => true
          (* end case *))
    (* Can a CONS of the given type appear inline?  Currently this always answers false.
     * The type-based test that used to live here is disabled; it treated int sequences as
     * inline, tensors according to Target.inlineCons on their order, other sequences as
     * not inline, and rejected any other type as an invalid CONS.
     *)
    fun isInlineCons _ = false

  (* Translation of LowIL assignments to TreeIL statements (see doAssign below); the
   * following helper translates input-variable initializers.
   *)

  (* translate input-variable initialization to a TreeIL expression.  The result pairs the
   * statements needed to compute the expression (always empty here) with the expression.
   * Only scalar and 1-D tensor initializers are supported.
   *)
    fun trInitialization init = (case init
           of InP.String s => ([], Dst.E_Lit(Literal.String s))
            | InP.Int n => ([], Dst.E_Lit(Literal.Int n))
            | InP.Real f => ([], Dst.E_Lit(Literal.Float f))
            | InP.Bool b => ([], Dst.E_Lit(Literal.Bool b))
            | InP.Tensor([d], vs) => let
              (* the target's representation of a d-vector *)
                val (isPadded, wid, pieces) = Target.getVecTy d
              (* one float literal per initializer, in order *)
                val lits = Vector.foldr (fn (v, acc) => Dst.E_Lit(Literal.Float v) :: acc) [] vs
                in
                  ([], LowOpToTreeOp.consVecToTree(wid, d, pieces, lits, isPadded))
                end
            | InP.Tensor _ => raise Fail "trInitialization: Tensor"
            | InP.Seq _ => raise Fail "trInitialization: Seq"
            | _ => raise Fail "trInitialization: impossible"
          (* end case *))

  (* translate a LowIL assignment to a list of zero or more target statements in reverse
   * order (the statement that must execute first is the last element of the list).
   *)
    fun doAssign (env, (lhs, rhs)) = let
        (* get the target variable for lhs: the global it is bound to, or else a fresh
         * local that is recorded in the environment and that lhs is renamed to.
         *)
          fun doLHS () = (case peekGlobal(env, lhs)
                 of SOME lhs' => (env, lhs')
                  | NONE => let
                      val t = newLocal lhs
                      in
                        (rename (addLocal(env, t), lhs, t), t)
                      end
                (* end case *))
        (* for expressions that are going to be compiled to a call statement.  doLHS always
         * works from doAssign's env, so the environment argument is not used.
         *)
          fun assignExp (_, exp) = let
                val (env, t) = doLHS()
                in
                  (env, [Dst.S_Assign([t], exp)])
                end
          in
            case rhs
             of Src.STATE x => let
                (* FIXME: state variables are represented like globals, but their other
                 * operations register their kind as local, which leads to trouble in their
                 * representation; revisit once the representation of state and local
                 * variables changes.  Currently, when the lhs is a local vector, the state
                 * variable is loaded with E_LoadArr operations that are combined with E_Mux.
                 *)
                  val (env, vt) = doLHS()
                  val stateExp = Dst.E_State(getStateVar x)
                  val exp = (case (DstV.kind vt, DstV.ty vt)
                         of (Dst.VK_Local, DstTy.TensorTy [6]) => let
                            (* a 6-vector is loaded by two E_LoadArr operations, at offsets 0 and 4 *)
                              val (_, _, pieces) = Target.getVecTy 6
                              val op1 = Dst.E_LoadArr(false, 4, 6, stateExp, Dst.E_Lit(Literal.Int 0))
                              val op2 = Dst.E_LoadArr(false, 2, 6, stateExp, Dst.E_Lit(Literal.Int 4))
                              in
                                Dst.E_Mux(false, false, 6, pieces, [op1, op2])
                              end
                          | (Dst.VK_Local, DstTy.TensorTy [oSize]) => let
                              val (isFill, nSize, pieces) = Target.getVecTy oSize
                              val op1 = Dst.E_LoadArr(false, nSize, oSize, stateExp, Dst.E_Lit(Literal.Int 0))
                              in
                                Dst.E_Mux(false, isFill, oSize, pieces, [op1])
                              end
                          | _ => stateExp
                        (* end case *))
                  val (env, stmt) = bindSimple (env, lhs, exp)
                  in
                  (* record the operators (e.g., loadArr) used by exp *)
                    (addOprFromExp(env, exp), stmt)
                  end
              | Src.VAR x => bindSimple (env, lhs, useVar env x)
              | Src.LIT lit => bindSimple (env, lhs, Dst.E_Lit lit)
              | Src.OP(SrcOp.Kernel _, _) => (env, [])
              | Src.OP(SrcOp.LoadImage(ty, nrrd, info), []) => let
                  val (env, t) = doLHS()
                  in
                    (env, [Dst.S_LoadNrrd(t, ty, nrrd)])
                  end
              | Src.OP(SrcOp.Input(InP.INP{ty=Ty.ImageTy _, name, desc, init}), []) => let
                  val (env, t) = doLHS()
                  in
                    case init
                     of SOME(InP.Proxy(nrrd, _)) => (env, [Dst.S_InputNrrd(t, name, desc, SOME nrrd)])
                      | SOME(InP.Image _) => (env, [Dst.S_InputNrrd(t, name, desc, NONE)])
                      | _ => raise Fail "bogus initialization for image"
                    (* end case *)
                  end
              | Src.OP(SrcOp.Input(InP.INP{ty, name, desc, init=NONE}), []) => let
                  val (env, t) = doLHS()
                  in
                    (env, [Dst.S_Input(t, name, desc, NONE)])
                  end
              | Src.OP(SrcOp.Input(InP.INP{ty, name, desc, init=SOME init}), []) => let
                  val (env, t) = doLHS()
                  val (stms, exp) = trInitialization init
                  in
                    (env, stms @ [Dst.S_Input(t, name, desc, SOME exp)])
                  end
              | Src.OP(SrcOp.Inside(info, s), args) => let
                  val (a, b) = (case List.map (useVar env) args
                         of [a, b] => (a, b)
                          | _ => raise Fail "Inside: expected two arguments"
                        (* end case *))
                  val size = (case ImageInfo.dim info
                         of 1 => raise Fail "Inside of 1-D dimension"
                          | 2 => 2
                          | 3 => 4
                          | d => raise Fail(concat["Inside: unsupported dimension ", Int.toString d])
                        (* end case *))
                (* the position is stored in its own local to make the generated code cleaner *)
                  val k = newLocalWithTy ("Pos_" ^ SrcV.name lhs, size)
                  val s1 = Dst.S_Assign([k], a)
                  val exp = Dst.E_Op(DstOp.Inside(info, s), [Dst.E_Var k, b])
                  val (env, stms) = bind (addLocal(env, k), lhs, exp)
                  in
                  (* the result is in reverse order, so the assignment to k (which must execute
                   * before any statement in stms that uses k) goes at the end of the list.
                   *)
                    (env, stms @ [s1])
                  end
              | Src.OP(SrcOp.Translate v, [a]) => let
                (* the result is a vector, so we have to use Mux *)
                  val dim = ImageInfo.dim v
                  val op1 = Dst.E_Op(DstOp.Translate v, [useVar env a])
                  val exp = (case dim
                         of 1 => op1
                          | 2 => Dst.E_Mux(true, false, dim, [dim], [op1])
                          | 3 => Dst.E_Mux(false, false, dim, [dim], [op1])
                        (* end case *))
                  in
                    bind (env, lhs, exp)
                  end
              | Src.OP(SrcOp.Transform v, args) => let
                (* the result is an array, so we have to use Store *)
                  val (env2, t) = doLHS()
                  val V = Dst.E_Var t
                  val dim = ImageInfo.dim v
                  val ty = DstTy.TensorTy [dim, dim]
                  val args' = List.map (useVar env) args
                  val a = List.tabulate(dim, fn n => Dst.E_Op(DstOp.Transform(v, n), args'))
                  in
                    case dim
                     of 1 => bind(env2, lhs, Dst.E_Op(DstOp.Transform(v, 1), args'))
                      | 2 => (env2, [Dst.S_StoreVec(V, 0, true, false, dim, ty, [2, 2], a)])
                      | 3 => (env2, [Dst.S_StoreVec(V, 0, false, true, dim, ty, [4, 4, 4], a)])
                    (* end case *)
                  end
              | Src.OP(SrcOp.imgLoad(info, _, oSize), [a]) => let
                (* create pointer variable vp and index it with the image stride *)
                  val vp = newLocalPtrTy(SrcV.name lhs, DstTy.AddrTy info)
                  val stmt = Dst.S_Assign([vp], useVar env a)
                  val stride = ImageInfo.stride info
                  val indexArgs = List.tabulate(oSize, fn n =>
                        Dst.E_Op(DstOp.IndexTensor(false, [n*stride], Ty.TensorTy [oSize]), [Dst.E_Var vp]))
                (* create the cons expression *)
                  val (isFill, nSize, pieces) = Target.getVecTy oSize
                  val exp = LowOpToTreeOp.consVecToTree(nSize, oSize, pieces, indexArgs, isFill)
                  val (env, stmt2) = bind (addLocal(env, vp), lhs, exp)
                  in
                  (* reverse order: the assignment to vp ends up last in the list *)
                    (env, List.rev(stmt::stmt2))
                  end
              | Src.OP(SrcOp.IndexTensor e, [a]) => let
                (* The IndexTensor operation is int * ty * ty.
                 * The first ty is the list of indexed positions and the second ty is the type
                 * of the argument.  When the argument is a Mux (a matrix or larger argument),
                 * we look for the piece that holds the index.  Otherwise we just pass the
                 * variable kind to the TreeIL op and let c-util decide how/where to index.
                 * The kind of variable decides if there is a cast in the C code.
                 *)
                  val a' = useVar env a
                  val exp = (case (SrcOp.IndexTensor e, a')
                         of (SrcOp.IndexTensor(_, [i], _), Dst.E_Mux(_, _, _, pieces, ops)) => let
                            (* walk the pieces (with sizes v) until the one containing index i *)
                              fun findLocal (c, indexAt, v::vs, a1::args) = let
                                    val newsize = c + v
                                    in
                                      if (newsize > i)
                                        then Dst.E_Op(DstOp.IndexTensor(true, [indexAt], Ty.TensorTy [v]), [a1])
                                        else findLocal (newsize, indexAt - v, vs, args)
                                    end
                                | findLocal _ = raise Fail "IndexTensor: index out of range"
                              in
                                findLocal (0, i, pieces, ops)
                              end
                          | (SrcOp.IndexTensor(_, indTy, argTy), Dst.E_Var v) => (case DstV.kind v
                               of TreeIL.VK_Local => Dst.E_Op(DstOp.IndexTensor(true, indTy, argTy), [a'])
                                | _ => Dst.E_Op(DstOp.IndexTensor(false, indTy, argTy), [a'])
                              (* end case *))
                          | (SrcOp.IndexTensor(_, indTy, argTy), _) =>
                              Dst.E_Op(DstOp.IndexTensor(true, indTy, argTy), [a'])
                        (* end case *))
                  in
                    bind (env, lhs, exp)
                  end
              | Src.OP(rator, args) => let
                  val args' = List.map (useVar env) args
                (* translate a width-n vector operation using LowOpToTreeOp.vecToTree *)
                  fun foundVec (n, srcRator, dstRator) = let
                        val (env0, lhsT) = doLHS()
                        val m = ListPair.zip (args, args')
                        val stm = LowOpToTreeOp.vecToTree((srcRator, m), (lhsT, n, dstRator, Target.getVecTy n))
                        in
                          case stm
                           of Dst.S_Assign(_, exp) => bind (env0, lhs, exp)
                            | _ => (env0, [stm])
                          (* end case *)
                        end
                  in
                    case (rator, args')
                     of (SrcOp.addVec n, _) => foundVec(n, NONE, SOME DstOp.addVec)
                      | (SrcOp.subVec n, _) => foundVec(n, NONE, SOME DstOp.subVec)
                      | (SrcOp.prodVec n, _) => foundVec(n, NONE, SOME DstOp.prodVec)
                      | (SrcOp.Floor n, _) => foundVec(n, NONE, SOME DstOp.Floor)
                      | (SrcOp.Clamp(Ty.TensorTy[n]), _) => foundVec(n, NONE, SOME DstOp.clampVec)
                      | (SrcOp.Normalize n, _) => foundVec(n, NONE, SOME DstOp.Normalize)
                      | (SrcOp.ProjectLast(_, n, _, _), _) => foundVec(n, SOME rator, NONE)
                      | (SrcOp.ProjectFirst(_, n, _, _), _) => foundVec(n, SOME rator, NONE)
                      | (SrcOp.sumVec n, _) => foundVec(n, SOME rator, NONE)
                      | (SrcOp.prodScaV n, _) => foundVec(n, SOME rator, SOME DstOp.prodScaV)
                      | (SrcOp.Lerp(Ty.TensorTy[n]), _) => foundVec(n, SOME rator, SOME DstOp.lerpVec)
                    (* adding or subtracting a literal 0 reduces to assigning the other operand *)
                      | (SrcOp.addSca, [a, Dst.E_Lit(Literal.Int 0)]) => assignExp (env, a)
                      | (SrcOp.addSca, [Dst.E_Lit(Literal.Int 0), a]) => assignExp (env, a)
                      | (SrcOp.subSca, [a, Dst.E_Lit(Literal.Int 0)]) => assignExp (env, a)
                      | _ => let
                          val exp = Dst.E_Op(LowOpToTreeOp.expandOp rator, args')
                          in
                            if isInlineOp rator
                              then bind (env, lhs, exp)
                              else assignExp (env, exp)
                          end
                    (* end case *)
                  end
              | Src.APPLY(f, args) =>
                  bind (env, lhs, Dst.E_Apply(f, List.map (useVar env) args))
              | Src.CONS(Ty.TensorTy[oSize], args) => let
                (* CONS of a vector with real arguments.  If lhs is a local variable, then we
                 * assume it is represented with vectors and we use Mux and E_ConsVec on the
                 * pieces, much like a vector op.  Otherwise, we assume it is an array and use
                 * S_Cons.
                 *)
                  val args' = List.map (useVar env) args
                  val (env2, t) = doLHS()
                  in
                    case DstV.kind t
                     of TreeIL.VK_Local => let
                          val (isFill, nSize, pieces) = Target.getVecTy oSize
                          val exp = LowOpToTreeOp.consVecToTree(nSize, oSize, pieces, args', isFill)
                          in
                            bind (env2, lhs, exp)
                          end
                      | _ => (env2, [Dst.S_Cons(t, oSize, args')])
                    (* end case *)
                  end
              | Src.CONS(ty as Ty.TensorTy[_, j], args) => let
                (* CONS of a matrix with vector arguments.  Each vector argument could be a
                 * global, local, or state variable, so their representations may differ:
                 * globals and state variables are copied with S_Copy, while locals and
                 * other expressions are stored with S_StoreVec.
                 *)
                  val args' = List.map (useVar env) args
                  val (env2, t) = doLHS()
                (* vector parameters for the last matrix index, used by S_StoreVec *)
                  val (isFill, _, pieces) = Target.getVecTy j
                  val A = LowOpToTreeOp.isAlignedStore(isFill, length pieces)
                  val dst = Dst.E_Var t
                (* the statement that writes row e at element offset `offset` of dst *)
                  fun storeRow (e, offset) = (case e
                         of Dst.E_State _ => Dst.S_Copy(dst, e, offset, j)
                          | Dst.E_Var v => (case DstV.kind v
                               of TreeIL.VK_Global => Dst.S_Copy(dst, e, offset, j)
                                | _ => Dst.S_StoreVec(dst, offset, A, isFill, j, ty, pieces, [e])
                              (* end case *))
                          | _ => Dst.S_StoreVec(dst, offset, A, isFill, j, ty, pieces, [e])
                        (* end case *))
                  fun f ([], _) = []
                    | f (e::es, offset) = storeRow (e, offset) :: f (es, offset + j)
                  in
                  (* the rows are generated in forward order; the result is in reverse order *)
                    (env2, List.rev (f (args', 0)))
                  end
              | Src.CONS(ty, args) => let
                (* CONS of a larger tensor with non-vector arguments.
                 * We can assume everything is an array and S_Copy everything.
                 *)
                  val args' = List.map (useVar env) args
                  val (env2, t) = doLHS()
                (* the offset between consecutive arguments (the new row index shift) *)
                  val shift = (case ty
                         of Ty.TensorTy [_, i, j] => j*i
                          | Ty.TensorTy [_, i, j, k] => j*i*k
                          | _ => raise Fail "CONS unsupported"
                        (* end case *))
                  fun f ([], _) = []
                    | f (e::es, count) = Dst.S_Copy(Dst.E_Var t, e, count, shift) :: f(es, count + shift)
                  in
                    (env2, List.rev (f (args', 0)))
                  end
              | Src.EINAPP _ => raise Fail "EINAPP in Low-IL to Tree-IL"
            (* end case *)
          end

  (* In order to reconstruct the block structure from the CFG, we keep a stack of open ifs.
   * The items on this stack distinguish between when we are processing the "then" and
   * "else" branches of the if.
   *)
    datatype open_if
    (* working on the "then" branch.  The fields are the statements that precede the if
     * (in reverse order), the condition, and the else-branch node.
     *)
      = THEN_BR of Dst.stm list * Dst.exp * Src.node
    (* working on the "else" branch.  The fields are the statements that precede the if,
     * the condition, the "then" branch statements, and the kind of the node that
     * terminated the "then" branch (will be a JOIN, DIE, or STABILIZE).
     *)
      | ELSE_BR of Dst.stm list * Dst.exp * Dst.stm list * Src.node_kind


    (* convert a block carrying type/operator sets into a plain block (the sets are dropped) *)
    fun mkBlockOrig (Dst.BlockWithOpr{locals, body, ...}) = Dst.Block{locals = locals, body = body}
                
    (* split a block-with-operators into the environment updated with the block's type and
     * operator sets, and a plain block with the same locals and body.
     *)
    fun peelBlockOrig (env, Dst.BlockWithOpr{locals, types, opr, body}) =
          (setEnv (env, types, opr), Dst.Block{locals = locals, body = body})
                
    (* decrement the use count of a LowIL variable; true when no uses remain *)
    fun decCount (Src.V{useCnt, ...}) = (
          useCnt := !useCnt - 1;
          !useCnt <= 0)
                
  (* trCFG (env, prefix, finish, cfg)
   * Translate the LowIL CFG `cfg` into a TreeIL block by walking it from its entry node
   * and rebuilding nested if statements from the COND/JOIN structure (see open_if).
   *   env    -- translation environment; threaded through the walk and updated with
   *             operator information via addOprFromStmt
   *   prefix -- statements placed in front of the translated body
   *   finish -- applied to the environment at a FRAGMENT, SINIT, RETURN, or ACTIVE exit;
   *             its statements are appended after the translated statements (and before
   *             the S_Exit/S_Active statement, when there is one)
   * The result is whatever endScope returns; the callers in this file pass it to
   * peelBlockOrig, which only matches Dst.BlockWithOpr.
   * Statements for the current block are accumulated in reverse order and reversed when
   * the block is closed.
   *)
    fun trCFG (env, prefix, finish, cfg) = let
        (* join (env, ifStk, stms, k): a branch has ended at a node of kind k (a JOIN, or a
         * DIE/STABILIZE exit); stms are that branch's statements in reverse order.
         *   - empty stack + JOIN: error, there is no if to close
         *   - empty stack otherwise: end of the top-level CFG; close the block
         *   - THEN_BR on top: the then branch is done; start on the else branch
         *   - ELSE_BR on top: both branches are done; build the if and continue after it
         *)
          fun join (env, [], _, Src.JOIN _) = raise Fail "JOIN with no open if"
                | join (env, [], stms, _) = let
                    val env'=addOprFromStmt(env,  stms)
                    in endScope (env', prefix @ List.rev stms) end
            | join (env, THEN_BR(stms1, cond, elseBr)::stk, thenBlk, k) = let
              
                val (env, thenBlk) = flushPending (env, thenBlk)
              (* NOTE(review): operators are collected from stms1 (the statements that
               * precede the if) rather than from thenBlk here -- confirm intended.
               *)
                val env'=addOprFromStmt(env,  stms1)
                in
                  doNode (env', ELSE_BR(stms1, cond, thenBlk, k)::stk, [], elseBr)
                end
            | join (env, ELSE_BR(stms, cond, thenBlk, k1)::stk, elseBlk, k2) = let

                val (env, elseBlk) = flushPending (env, elseBlk)
                in
                (* phi assignments (via doPhi) are added only to the branches that end at
                 * the JOIN; a branch that ended in a DIE/STABILIZE exit gets none.
                 *)
                  case (k1, k2)
                   of ( Src.JOIN{phis, succ, ...}, Src.JOIN _) => let
                        val (env, [thenBlk, elseBlk]) =
                              List.foldl doPhi (env, [thenBlk, elseBlk]) (!phis)
                        val stm = mkIf(cond, List.rev thenBlk, List.rev elseBlk)
        
                        val env'=addOprFromStmt(env,  stm::stms)
                        in
                          doNode (env', stk, stm::stms, !succ)
                        end
                    | ( Src.JOIN{phis, succ, ...}, _) => let
                        val (env, [thenBlk]) = List.foldl doPhi (env, [thenBlk]) (!phis)
                        val stm = mkIf(cond, List.rev thenBlk, List.rev elseBlk)
                                                 in
                          doNode (addOprFromStmt(env, [stm]), stk, stm::stms, !succ)
                        end
                    | (_, Src.JOIN{phis, succ, ...}) => let
                        val (env, [elseBlk]) = List.foldl doPhi (env, [elseBlk]) (!phis)
                        val stm = mkIf(cond, List.rev thenBlk, List.rev elseBlk)
           
         
                        in
                          doNode (addOprFromStmt(env, [stm]), stk, stm::stms, !succ)
                        end
                  (* neither branch reaches the JOIN (both ended in DIE/STABILIZE) *)
                    | (_, _) => raise Fail "no path to exit unimplemented" (* FIXME *)
                  (* end case *)
                end
              (* doNode (env, ifStk, stms, nd): translate node nd and continue along the CFG;
               * ifStk is the stack of open ifs and stms holds the current block's
               * statements in reverse order.
               *)
                and doNode (env, ifStk : open_if list, stms, nd) = (
                case Nd.kind nd
                 of Src.NULL => raise Fail "unexpected NULL"
                  | Src.ENTRY{succ} => doNode (env, ifStk, stms, !succ)
                  | k as Src.JOIN{phis, succ, ...} =>  (join (env, ifStk, stms, k))
                (* open a new if: the condition is translated, pending statements are flushed,
                 * and the true branch is translated first.
                 *)
                  | Src.COND{cond, trueBranch, falseBranch, ...} => let
                      val cond = useVar env cond
                      val (env, stms) = flushPending (env, stms)
                      in
                        doNode (env, THEN_BR(stms, cond, !falseBranch)::ifStk, [], !trueBranch)
                      end
                  | Src.COM {text, succ, ...} =>
                      doNode (env, ifStk, Dst.S_Comment text :: stms, !succ)
                  | Src.ASSIGN{stm, succ, ...} => let
                      val (env, stms') = doAssign (env, stm)
                      in
                            doNode (addOprFromStmt(env, stms')  , ifStk, stms' @ stms, !succ)
                      end
                (* multi-assignment; Print operations are dropped when the target does not
                 * support printing.
                 * NOTE(review): unlike ASSIGN/SAVE this case does not call addOprFromStmt;
                 * the operator is only recorded if a later join re-scans these statements,
                 * which does not happen on the FRAGMENT/SINIT/RETURN/ACTIVE exit paths.
                 *)
                  | Src.MASSIGN{stm=(ys, rator, xs), succ, ...} => let
                    
                      fun doit () = let
                          (* use the global for y when it has one, otherwise a fresh local *)
                            fun doLHSVar (y, (env, ys)) = (case peekGlobal(env, y)
                                of SOME y' => ((env, y'::ys))
                                | NONE => let
                                        val t = newLocal y
                                        in
                                          (rename (addLocal(env, t), y, t), t::ys)
                                        end
                                  (* end case *))
                            val (env, ys) = List.foldr doLHSVar (env, []) ys
                             val Trator =  LowOpToTreeOp.expandOp rator
                            val exp = Dst.E_Op(Trator, List.map (useVar env) xs)
                            val stm = Dst.S_Assign(ys, exp)
                            in
                              doNode (env, ifStk, stm :: stms, !succ)
                            end
                      in
                        case rator
                         of SrcOp.Print _ => if Target.supportsPrinting()
                              then doit ()
                              else doNode (env, ifStk, stms, !succ)
                          | _ => doit()
                        (* end case *)
                      end
                  | Src.NEW{strand, args, succ, ...} => raise Fail "NEW unimplemented"
                  | Src.SAVE{lhs, rhs, succ, ...} => let
                    (* There is a Save and lhs is an array,
                    *  Stmt depends on how rhs exp is stored
                    *  If rhs is stored as a vector then use S_StoreVec(Mux, or Local Vector)
                    *  If rhs is an array uses S_copy (higher order tensors)
                    *  otherwise regular save
                    *)
                    val x=getStateVar lhs
                    val rhs2=useVar env rhs
                  (* number of elements in a tensor of the given shape *)
                    fun size n=foldl (fn (a,b) => b*a) 1 n
                    val stm=(case  rhs2
                      (* NOTE(review): the oSize bound by the E_LoadArr pattern below shadows the
                       * oSize bound by E_Mux; S_Copy uses the inner one -- confirm intended.
                       * The use count of rhs is decremented in this case only.
                       *)
                        of Dst.E_Mux(A,isFill, oSize,splitTy,args) => let
                            val varstatex=Dst.E_State x
                            val stmt1=(case args
                                of [Dst.E_LoadArr(_,nSize,oSize,t,Dst.E_Lit(Literal.Int 0))] => Dst.S_Copy(varstatex, t,0,oSize)
                                | _ => Dst.S_StoreVec(varstatex,0, A,isFill, oSize,Ty.TensorTy [oSize],splitTy,args)
                                (*end case*))
                            in
                                (decCount rhs ; stmt1)
                            end 
                        |  Dst.E_Var rhs3               => (case (DstV.kind rhs3,DstV.rTy rhs3)
                            of ( _ ,Ty.TensorTy [])              => Dst.S_Save([x], rhs2)
                            | (Dst.VK_Local,Ty.TensorTy [oSize]) =>let
                                val (isFill,nSize,pieces)=Target.getVecTy oSize
                                val tyP=pieces
                                val tyO=Ty.TensorTy [oSize]
                                in  Dst.S_StoreVec(Dst.E_State x,0, false,isFill, oSize,tyO,tyP,[rhs2])
                                end
                            | (_,Ty.TensorTy xs)              => Dst.S_Copy( Dst.E_State x, rhs2,0,size xs)
                            |(_ ,Ty.SeqTy(Ty.TensorTy xs,j))  => Dst.S_Copy( Dst.E_State x, rhs2,0,size([j]@xs))
                            | _                               => Dst.S_Save([x], rhs2)
                        (*end case*))
                        |  Dst.E_State rhs3               => (case (DstSV.ty rhs3)
                            of Ty.TensorTy xs                 => Dst.S_Copy( Dst.E_State x, rhs2,0,size xs)
                            | _                               => Dst.S_Save([x], rhs2)
                            (*end case*))
                        |  _ => Dst.S_Save([x], rhs2)
                        (*end case*))
                        val stmts=stm::stms
                      in
                        doNode (addOprFromStmt(env, stmts), ifStk, stmts, !succ)
                      end
                (* exits: FRAGMENT/SINIT/RETURN/ACTIVE close the whole block via endScope;
                 * STABILIZE and DIE end the current branch via join.
                 *)
                  | k as Src.EXIT{kind, live, ...} => (case kind
                       of ExitKind.FRAGMENT =>
                            endScope (env, prefix @ List.revAppend(stms, finish env))
                        | ExitKind.SINIT => let
(* FIXME: we should probably call flushPending here! *)
                            val suffix = finish env @ [Dst.S_Exit[]]
                            in
                              endScope (env, prefix @ List.revAppend(stms, suffix))
                            end
                        | ExitKind.RETURN => let
(* FIXME: we should probably call flushPending here! *)
                            val suffix = finish env @ [Dst.S_Exit(List.map (useVar env) live)]
                            in
                              endScope (env, prefix @ List.revAppend(stms, suffix))
                            end
                        | ExitKind.ACTIVE => let
(* FIXME: we should probably call flushPending here! *)
                            val suffix = finish env @ [Dst.S_Active]
                            in
                              endScope (env, prefix @ List.revAppend(stms, suffix))
                            end
                        | ExitKind.STABILIZE => let
(* FIXME: we should probably call flushPending here! *)
                            val stms = Dst.S_Stabilize :: stms
                            in
(* FIXME: we should probably call flushPending here! *)
                              (join (env, ifStk, stms, k))
                            end
                        | ExitKind.DIE => (join (env, ifStk, Dst.S_Die :: stms, k))
                      (* end case *))
                (* end case *))

          in
            doNode (env, [], [], CFG.entry cfg)
          end

    fun trInitially (env, Src.Initially{isArray, rangeInit, iters, create=(createInit, strand, args)}) =
          let
          val (env,iterPrefix) = peelBlockOrig(env,trCFG (env, [], fn _ => [], rangeInit))
        (*  val (env2,iterPrefix) = peelBlockOrig(env,trCFG (env, [], fn _ => [], rangeInit))*)
         (*val (iterPrefix) = mkBlockOrig(trCFG (env, [], fn _ => [], rangeInit))*)
                
          fun cvtIter ((param, lo, hi), (env, iters)) = let
                val param' = newIter param
                val env = rename (env, param, param')
                in
                  (env, (param', useVar env lo, useVar env hi)::iters)
                end
          val (env, iters) = List.foldr cvtIter (env, []) iters
          val (env,createPrefix) = peelBlockOrig(env,trCFG (env, [], fn _ => [], createInit))
          in (env,{
            isArray = isArray,
            iterPrefix = iterPrefix,
            iters = iters,
            createPrefix = createPrefix,
            strand = strand,
            args = List.map (useVar env) args
          }) end

    fun trMethod (env, Src.Method{name, body}) = let
          val (env, blk) = peelBlockOrig(env, trCFG (env, [], fn _ => [], body))
          (*val (blk)=mkBlockOrig(trCFG (env, [], fn _ => [], body))*)
	  in
	    (env, Dst.Method{name = name, body = blk})
	  end

    fun trStrands (env, strands) = let
	  fun tr (Src.Strand{name, params, state, stateInit, methods}, (env, strands')) = let
		val params' = List.map newParam params
		val env = ListPair.foldlEq (fn (x, x', env) => rename(env, x, x')) env (params, params')
		val (env', sInit) = peelBlockOrig(env,trCFG (env, [], fn _ => [], stateInit))
		fun callmethod (env, [], M) = (env,M)
		  | callmethod (env, b::es, Ms) = let
		      val (env2,M1) = trMethod(env,b)
		      in callmethod(env2, es, [M1]@Ms) end
		val (env', methods) = callmethod(env',methods,[])
		val strand' = Dst.Strand{
			name = name,
			params = params',
			state = List.map getStateVar state,
			stateInit =sInit,
			methods = methods
		      }
		in
		  (env', strand'::strands')
		end
	  val (env', strands') = List.foldl tr (env, []) strands
	  in
	    (env', List.rev strands')
	  end

  (* splitGlobalInit globalInit
   * Split the global-initialization CFG in two: the part that computes the inputs (up to
   * and including the last Input assignment found in the leading run of ENTRY/COM/ASSIGN
   * nodes) and the rest of the global initialization.  The input part returns the
   * variables bound by the Input assignments.  When no Input assignment is found, the
   * input part is an empty ENTRY -> EXIT(RETURN) CFG and globalInit is returned unchanged.
   * NOTE: the CFG is modified in place when it is split.
   *)
    fun splitGlobalInit globalInit = let
(* FIXME: can split as soon as we see a non-Input statement! *)
        (* no Input operations: build an empty input-initialization CFG *)
          fun noInputs () = let
                val entry = Nd.mkENTRY()
                val exit = Nd.mkEXIT(ExitKind.RETURN, [])
                in
                  Nd.addEdge (entry, exit);
                  {inputInit = Src.CFG{entry=entry, exit=exit}, globalInit = globalInit}
                end
        (* cut the CFG right after node lastIn: lastIn now flows into a new RETURN exit,
         * and lastIn's old successor becomes the first node after a new ENTRY.
         *)
          fun splitAfter (lastIn, liveVars) = let
                val inputExit = Nd.mkEXIT(ExitKind.RETURN, liveVars)
                val globalEntry = Nd.mkENTRY()
                val [gFirst] = Nd.succs lastIn
                in
                  Nd.replaceInEdge {src = lastIn, oldDst = gFirst, dst = inputExit};
                  Nd.replaceOutEdge {oldSrc = lastIn, src = globalEntry, dst = gFirst};
                  {
                    inputInit = Src.CFG{entry = Src.CFG.entry globalInit, exit = inputExit},
                    globalInit = Src.CFG{entry = globalEntry, exit = Src.CFG.exit globalInit}
                  }
                end
        (* scan forward, remembering the last Input assignment and the variables it binds *)
          fun walk (nd, lastInput, live) = (case Nd.kind nd
                 of Src.ENTRY{succ} => walk (!succ, lastInput, live)
                  | Src.COM{succ, ...} => walk (!succ, lastInput, live)
                  | Src.ASSIGN{stm=(lhs, Src.OP(SrcOp.Input _, _)), succ, ...} =>
                      walk (!succ, nd, lhs::live)
                  | Src.ASSIGN{succ, ...} => walk (!succ, lastInput, live)
                  | _ => if Nd.isNULL lastInput
                      then noInputs ()
                      else splitAfter (lastInput, live)
                (* end case *))
          in
            walk (Src.CFG.entry globalInit, Nd.dummy, [])
          end

    fun getInfo(env,Init)=let
        val inputInit' = trCFG (env, [], fn _ => [], Init)
        in
            peelBlockOrig(env,inputInit')
        end
                
    fun translate prog = let
        (* first we do a variable analysis pass on the Low IL *)
          val prog as Src.Program{props, globalInit, initially, strands} = VA.optimize prog
(* FIXME: here we should do a contraction pass to eliminate unused variables that VA may have created *)
          val _ = (* DEBUG *)
                LowPP.output (Log.logFile(), "LowIL after variable analysis", prog)
          val envOrig = newEnv()
          val globals = List.map
                (fn x => let val x' = newGlobal x in global(envOrig, x, x'); x' end)
                  ( Src.CFG.liveAtExit globalInit)
          val {inputInit, globalInit} = splitGlobalInit globalInit
          val (env, inputInit) = getInfo(envOrig,inputInit)
          val (env, globalInit) = getInfo(env, globalInit)
          val (env, strands) = trStrands (env, strands)
          val (env, initially) = trInitially (env, initially)        
          val (typs, opr) = peelEnv(env)
          val typsList = TySet.listItems(typs);
          val oprList = OprSet.listItems(opr);
          
          in
	    Dst.Program{
		props = props,
		types=typsList,
		operations = oprList,
		globals = globals,
		inputInit = inputInit,
		globalInit = globalInit,
		strands = strands,
		initially = initially
	      }
          end

  end

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0