Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] View of /branches/staging/src/compiler/tree-il/low-to-tree-fn.sml
ViewVC logotype

View of /branches/staging/src/compiler/tree-il/low-to-tree-fn.sml

Parent Directory Parent Directory | Revision Log Revision Log


Revision 1942 - (download) (annotate)
Tue Jul 3 15:22:53 2012 UTC (7 years ago) by jhr
File size: 18649 byte(s)
  changes being staged from vis12 branch
(* low-to-tree-fn.sml
 *
 * COPYRIGHT (c) 2011 The Diderot Project (http://diderot-language.cs.uchicago.edu)
 * All rights reserved.
 *
 * This module translates the LowIL representation of a program (i.e., a pure CFG) to
 * a block-structured AST with nested expressions.
 *
 * NOTE: this translation is pretty dumb about variable coalescing (i.e., it doesn't do any).
 *)

functor LowToTreeFn (Target : sig

    val supportsPrinting : unit -> bool (* does the target support the Print op? *)

  (* tests for whether various expression forms can appear inline *)
    val inlineCons : int -> bool	(* can n'th-order tensor construction appear inline *)
    val inlineMatrixExp : bool		(* can matrix-valued expressions appear inline? *)

  end) : sig

    val translate : LowIL.program -> TreeIL.program

  end = struct

    structure IL = LowIL		(* source representation: a pure CFG *)
    structure Nd = LowIL.Node		(* CFG nodes *)
    structure CFG = LowIL.CFG		(* whole CFGs *)
    structure V = LowIL.Var		(* Low IL variables *)
    structure StV = LowIL.StateVar	(* Low IL strand-state variables *)
    structure Ty = LowILTypes		(* Low IL types *)
    structure Op = LowOps		(* Low IL operators *)
    structure T = TreeIL		(* target representation: a block-structured AST *)
    structure VA = VarAnalysis		(* analysis/optimization run before translation *)

  (* Constructors for fresh Tree IL variables.  Inputs and globals keep the name of
   * the Low IL variable; parameters, locals, and iteration variables get a prefix
   * ("p_", "l_", "i_") and a numeric suffix drawn from a counter that is never reset.
   *)
    local
      val counter = ref 0
    (* the string "<pfx>_<k>"; k is the current counter value, which is then bumped *)
      fun freshName pfx = let
	    val k = !counter before counter := !counter + 1
	    in
	      String.concat [pfx, "_", Int.toString k]
	    end
    (* a Tree IL variable of the given kind; each one gets a new stamp *)
      fun mkVar kind (name, ty) = T.V{name = name, id = Stamp.new(), kind = kind, ty = ty}
    (* a local-kind variable whose name is derived from the Low IL variable x *)
      fun fresh pfx x = mkVar T.VK_Local (freshName(pfx ^ V.name x), V.ty x)
    in
    fun newInput x = mkVar T.VK_Input (V.name x, V.ty x)
    fun newGlobal x = mkVar T.VK_Global (V.name x, V.ty x)
    fun newParam x = fresh "p_" x
    fun newLocal x = fresh "l_" x
    fun newIter x = fresh "i_" x
    end

  (* getStateVar maps a Low IL state variable to its Tree IL counterpart.  The
   * association is kept in a property on the Low IL variable (created by StV.newProp);
   * the function below builds the Tree IL state variable from the Low IL one
   * (presumably only once per variable -- depends on StV.newProp's semantics).
   *)
    val {getFn = getStateVar, ...} = StV.newProp (fn x => T.SV{
	    name = StV.name x,
	    id = Stamp.new(),
	    ty = StV.ty x,
	    varying = VA.isVarying x,
	    output = StV.isOutput x
	  })

    fun mkBlock body = T.Block{locals = [], body = body}	(* a block with no local declarations *)

  (* build a conditional statement; an empty else-branch yields a one-armed if *)
    fun mkIf (cond, thenStms, elseStms) = if null elseStms
	  then T.S_IfThen(cond, mkBlock thenStms)
	  else T.S_IfThenElse(cond, mkBlock thenStms, mkBlock elseStms)

  (* An environment that tracks bindings of Low IL variables to target expressions and
   * the list of Tree IL locals that have been defined.  The binding table is mutable
   * and is shared by every environment derived from the same newEnv() call (addLocal
   * only extends the functional list of locals).
   *)
    local
      structure VT = V.Tbl

    (* decrement the use count of a Low IL variable; true when no uses remain *)
      fun decCount (IL.V{useCnt, ...}) = let
	    val n = !useCnt - 1
	    in
	      useCnt := n;  (n <= 0)
	    end

      datatype target_binding
	= GLOB of T.var		(* variable is global *)
	| TREE of T.exp		(* variable bound to a pending (single-use) expression tree *)
	| DEF of T.exp		(* either a target variable or constant for a defined variable *)

      datatype env = E of {
	  tbl : target_binding VT.hash_table,
	  locals : T.var list
	}
    in
(* DEBUG *)
    fun bindToString binding = (case binding
	   of GLOB y => "GLOB " ^ T.Var.name y
	    | TREE _ => "TREE"
	    | DEF(T.E_Var y) => "DEF " ^ T.Var.name y
	    | DEF _ => "DEF"
	  (* end case *))

    fun dumpEnv (E{tbl, ...}) = let
	  fun prEntry (x, binding) =
		print(concat["  ", IL.Var.toString x, " --> ", bindToString binding, "\n"])
	  in
	    print "*** dump environment\n";
	    VT.appi prEntry tbl;
	    print "***\n"
	  end
(* DEBUG *)

    fun newEnv () = E{tbl = VT.mkTable (512, Fail "tbl"), locals=[]}

  (* use a variable.  A pending expression (TREE) is removed from the table, since it
   * has exactly one use; a DEF binding is removed once its use count drops to zero.
   * An unbound variable is a compiler bug: the environment is dumped and Fail raised.
   *)
    fun useVar (env as E{tbl, ...}) x = (case VT.find tbl x
	   of SOME(GLOB x') => T.E_Var x'
	    | SOME(TREE e) => (ignore(VT.remove tbl x); e)
	    | SOME(DEF e) => (
		if decCount x then ignore(VT.remove tbl x) else ();
		e)
	    | NONE => (
		dumpEnv env;
		raise Fail(concat ["useVar(", V.toString x, ")"]))
	  (* end case *))

  (* record a local variable *)
    fun addLocal (E{tbl, locals}, x) = E{tbl=tbl, locals=x::locals}

  (* bind the Low IL variable x to the Tree IL global x' *)
    fun global (E{tbl, ...}, x, x') = VT.insert tbl (x, GLOB x')

  (* insert a pending expression into the table.  Note that x should only be used once! *)
    fun insert (env as E{tbl, ...}, x, exp) = (
	  VT.insert tbl (x, TREE exp);
	  env)

  (* bind the Low IL variable x to the Tree IL variable x' *)
    fun rename (env as E{tbl, ...}, x, x') = (
	  VT.insert tbl (x, DEF(T.E_Var x'));
	  env)

  (* return the Tree IL global bound to x, if any *)
    fun peekGlobal (E{tbl, ...}, x) = (case VT.find tbl x
	   of SOME(GLOB x') => SOME x'
	    | _ => NONE
	  (* end case *))

  (* bind a non-global lhs: a variable with exactly one use becomes a pending expression
   * (no statements); otherwise rhs is assigned to a fresh local.
   *)
    fun bindLocal (env, lhs, rhs) = if (V.useCount lhs = 1)
	  then (insert(env, lhs, rhs), [])
	  else let
	    val t = newLocal lhs
	    in
	      (rename(addLocal(env, t), lhs, t), [T.S_Assign([t], rhs)])
	    end

  (* bind lhs to rhs; a global lhs is assigned immediately *)
    fun bind (env, lhs, rhs) = (case peekGlobal (env, lhs)
	   of SOME x => (env, [T.S_Assign([x], rhs)])
	    | NONE => bindLocal (env, lhs, rhs)
	  (* end case *))

  (* set the definition of a variable, where the RHS is a literal constant, a variable,
   * or a state-variable read; a global lhs is assigned immediately.
   *)
    fun bindSimple (env as E{tbl, ...}, lhs, rhs) = (
	  case peekGlobal (env, lhs)
	   of SOME x => (env, [T.S_Assign([x], rhs)])
	    | NONE => (VT.insert tbl (lhs, DEF rhs); (env, []))
	  (* end case *))

  (* at the end of a block, we need to assign any pending expressions to locals.  The
   * blkStms list and the resulting statement list are in reverse order.  The pending
   * entries are collected before any of them is rebound, so that the table is never
   * modified while it is being iterated over.
   *)
    fun flushPending (E{tbl, locals}, blkStms) = let
	(* the pending bindings, in reverse table-iteration order *)
	  val pending = VT.foldi
		(fn (x, TREE e, acc) => (x, e)::acc | (_, _, acc) => acc)
		  [] tbl
	(* assign a pending expression to a fresh local and rebind x to that local *)
	  fun doVar ((x, e), (locals, stms)) = let
		val t = newLocal x
		in
		  VT.insert tbl (x, DEF(T.E_Var t));
		  (t::locals, T.S_Assign([t], e)::stms)
		end
	(* foldr over the reversed list visits the bindings in table-iteration order *)
	  val (locals, stms) = List.foldr doVar (locals, blkStms) pending
	  in
	    (E{tbl=tbl, locals=locals}, stms)
	  end

  (* translate one phi node lhs = phi(rhs): create a fresh local t, add the assignment
   * t := rhs_i to the i'th predecessor block (blocks are in reverse order, so consing
   * puts it at the end), and rename lhs to t.
   *)
    fun doPhi ((lhs, rhs), (env, predBlks : T.stm list list)) = let
	(* t will be the variable in the continuation of the JOIN *)
	  val t = newLocal lhs
	  val predBlks = ListPair.map
		(fn (x, stms) => T.S_Assign([t], useVar env x)::stms)
		  (rhs, predBlks)
	  in
	    (rename (addLocal(env, t), lhs, t), predBlks)
	  end

  (* close a scope: the statements become a block that declares all recorded locals,
   * in the order they were created.
   *)
    fun endScope (E{locals, ...}, stms) = T.Block{
	    locals = List.rev locals,
	    body = stms
	  }

    end

  (* isInlineOp rator is false for the IL operators that cannot be compiled to inline
   * expressions (given the target's capabilities) and true for all others.
   *)
    fun isInlineOp rator = let
	(* tensor operations of order two or more are inline only when the target supports
	 * inline matrix-valued expressions; scalars, vectors, and non-tensors always are.
	 *)
	  fun tensorOK (Ty.TensorTy(_::_::_)) = Target.inlineMatrixExp
	    | tensorOK _ = true
	  in
	    case rator
	     of Op.LoadVoxels(_, n) => (n = 1)
	      | Op.Add ty => tensorOK ty
	      | Op.Sub ty => tensorOK ty
	      | Op.Neg ty => tensorOK ty
	      | Op.Scale ty => tensorOK ty
	      | Op.TensorToWorldSpace(_, ty) => tensorOK ty
	    (* matrix-valued results *)
	      | Op.MulMatMat _ => Target.inlineMatrixExp
	      | Op.Identity _ => Target.inlineMatrixExp
	      | Op.Zero _ => Target.inlineMatrixExp
	    (* never inline *)
	      | Op.MulVecTen3 _ => false
	      | Op.MulTen3Vec _ => false
	      | Op.EigenVecs2x2 => false
	      | Op.EigenVecs3x3 => false
	      | Op.EigenVals2x2 => false
	      | Op.EigenVals3x3 => false
	      | _ => true
	    (* end case *)
	  end

  (* translate a LowIL assignment to a list of zero or more target statements.  Returns
   * the updated environment and the statements; the list is empty when the RHS is only
   * recorded in the environment (a pending expression or a simple binding).
   *)
    fun doAssign (env, (lhs, rhs)) = let
	(* the target variable for lhs: its global, if it has one, otherwise a fresh local *)
	  fun doLHS () = (case peekGlobal(env, lhs)
		 of SOME lhs' => (env, lhs')
		  | NONE => let
		      val t = newLocal lhs
		      in
			(rename (addLocal(env, t), lhs, t), t)
		      end
		(* end case *))
	(* for expressions that may not appear inline (e.g., operations that return
	 * matrices on some targets), assign the result to lhs as a separate statement.
	 *)
	  fun assignExp (env, exp) = let
		val (env, t) = doLHS()
		in
		  (env, [T.S_Assign([t], exp)])
		end
	(* a copy lhs = x.  Simple values (variables, literals, and state reads) can be
	 * bound directly, but a pending expression tree must go through bind; otherwise it
	 * would be recorded as a DEF and duplicated at every use of lhs.
	 *)
	  fun copyVar x = (case useVar env x
		 of (e as T.E_Var _) => bindSimple (env, lhs, e)
		  | (e as T.E_Lit _) => bindSimple (env, lhs, e)
		  | (e as T.E_State _) => bindSimple (env, lhs, e)
		  | e => bind (env, lhs, e)
		(* end case *))
	  in
	    case rhs
	     of IL.STATE x => bindSimple (env, lhs, T.E_State(getStateVar x))
	      | IL.VAR x => copyVar x
	      | IL.LIT lit => bindSimple (env, lhs, T.E_Lit lit)
	      | IL.OP(Op.LoadImage info, [a]) => let
		  val (env, t) = doLHS()
		  in
		    (env, [T.S_LoadImage(t, ImageInfo.dim info, useVar env a)])
		  end
	      | IL.OP(Op.Input(_, name, desc), []) => let
		  val (env, t) = doLHS()
		  in
		    (env, [T.S_Input(t, name, desc, NONE)])
		  end
	      | IL.OP(Op.InputWithDefault(_, name, desc), [a]) => let
		  val (env, t) = doLHS()
		  in
		    (env, [T.S_Input(t, name, desc, SOME(useVar env a))])
		  end
	      | IL.OP(rator, args) => let
		  val exp = T.E_Op(rator, List.map (useVar env) args)
		  in
		    if isInlineOp rator
		      then bind (env, lhs, exp)
		      else assignExp (env, exp)
		  end
	      | IL.APPLY(f, args) =>
		  bind (env, lhs, T.E_Apply(f, List.map (useVar env) args))
	      | IL.CONS(ty, args) => let
		(* whether the target allows this construction inline *)
		  val inline = (case ty
			 of Ty.SeqTy(Ty.IntTy, _) => true
			  | Ty.TensorTy dd => Target.inlineCons(List.length dd)
			  | Ty.SeqTy _ => false
			  | _ => raise Fail(concat["invalid CONS<", Ty.toString ty, ">"])
			(* end case *))
		  val exp = T.E_Cons(ty, List.map (useVar env) args)
		  in
		    if inline
		      then bind (env, lhs, exp)
		      else assignExp (env, exp)
		  end
	    (* end case *)
	  end

  (* In order to reconstruct the block structure from the CFG, we keep a stack of open ifs.
   * The items on this stack distinguish between when we are processing the "then" and
   * "else" branches of the if.
   *)
    datatype open_if
    (* working on the "then" branch.  The fields are the statements that precede the if
     * (in reverse order), the condition, and the else-branch node.
     *)
      = THEN_BR of T.stm list * T.exp * IL.node
    (* working on the "else" branch.  The fields are the statements that precede the if
     * (in reverse order), the condition, the "then" branch statements (in reverse
     * order), and the kind of the node that terminated the "then" branch (a JOIN, or
     * an EXIT of kind DIE or STABILIZE).
     *)
      | ELSE_BR of T.stm list * T.exp * T.stm list * IL.node_kind

    (* trCFG (env, prefix, finish, cfg) translates cfg to a Tree IL block, rebuilding
     * nested if statements from the COND/JOIN structure of the graph.  The statements in
     * prefix start the block, and the statements returned by (finish env) are placed
     * just before the exit code of FRAGMENT, SINIT, RETURN, and ACTIVE exits; they are
     * not added on the DIE/STABILIZE paths, which are closed by join.  Statements are
     * accumulated in reverse order and reversed when a block is closed.
     *)
    fun trCFG (env, prefix, finish, cfg) = let
	(* join (env, ifStk, stms, k) is called when the current path reaches a node of
	 * kind k that ends a branch: a JOIN, or a DIE or STABILIZE exit.  With no open if,
	 * a JOIN is an error and any other kind closes the outermost scope.  At the end of
	 * a "then" branch we flush pending expressions and start the "else" branch; at the
	 * end of an "else" branch we assign the phi variables in the branch(es) that reach
	 * the JOIN, build the if statement, and continue with the JOIN's successor.  If
	 * neither branch reaches a JOIN, we raise Fail (not yet implemented).
	 *)
	  fun join (env, [], _, IL.JOIN _) = raise Fail "JOIN with no open if"
	    | join (env, [], stms, _) = endScope (env, prefix @ List.rev stms)
	    | join (env, THEN_BR(stms1, cond, elseBr)::stk, thenBlk, k) = let
		val (env, thenBlk) = flushPending (env, thenBlk)
		in
		  doNode (env, ELSE_BR(stms1, cond, thenBlk, k)::stk, [], elseBr)
		end
	    | join (env, ELSE_BR(stms, cond, thenBlk, k1)::stk, elseBlk, k2) = let
		val (env, elseBlk) = flushPending (env, elseBlk)
		in
		  case (k1, k2)
		   of (IL.JOIN{phis, succ, ...}, IL.JOIN _) => let
			(* both branches reach the JOIN *)
			val (env, [thenBlk, elseBlk]) =
			      List.foldl doPhi (env, [thenBlk, elseBlk]) (!phis)
			val stm = mkIf(cond, List.rev thenBlk, List.rev elseBlk)
			in
			  doNode (env, stk, stm::stms, !succ)
			end
		    | (IL.JOIN{phis, succ, ...}, _) => let
			(* only the "then" branch reaches the JOIN *)
			val (env, [thenBlk]) = List.foldl doPhi (env, [thenBlk]) (!phis)
			val stm = mkIf(cond, List.rev thenBlk, List.rev elseBlk)
			in
			  doNode (env, stk, stm::stms, !succ)
			end
		    | (_, IL.JOIN{phis, succ, ...}) => let
			(* only the "else" branch reaches the JOIN *)
			val (env, [elseBlk]) = List.foldl doPhi (env, [elseBlk]) (!phis)
			val stm = mkIf(cond, List.rev thenBlk, List.rev elseBlk)
			in
			  doNode (env, stk, stm::stms, !succ)
			end
		    | (_, _) => raise Fail "no path to exit unimplemented" (* FIXME *)
		  (* end case *)
		end
	(* doNode (env, ifStk, stms, nd) translates nd and its successors, where ifStk is
	 * the stack of open ifs and stms holds the statements of the current block in
	 * reverse order.
	 *)
	  and doNode (env, ifStk : open_if list, stms, nd) = (
		case Nd.kind nd
		 of IL.NULL => raise Fail "unexpected NULL"
		  | IL.ENTRY{succ} => doNode (env, ifStk, stms, !succ)
		  | k as IL.JOIN{phis, succ, ...} => join (env, ifStk, stms, k)
		  | IL.COND{cond, trueBranch, falseBranch, ...} => let
		    (* the condition is used before flushing, so a pending condition
		     * expression is inlined into the test; any other pending expressions
		     * are assigned to locals before the if.
		     *)
		      val cond = useVar env cond
		      val (env, stms) = flushPending (env, stms)
		      in
			doNode (env, THEN_BR(stms, cond, !falseBranch)::ifStk, [], !trueBranch)
		      end
		  | IL.COM {text, succ, ...} =>
		      doNode (env, ifStk, T.S_Comment text :: stms, !succ)
		  | IL.ASSIGN{stm, succ, ...} => let
		      val (env, stms') = doAssign (env, stm)
		      in
			doNode (env, ifStk, stms' @ stms, !succ)
		      end
                  | IL.MASSIGN{stm=(ys, rator, xs), succ, ...} => let
                    (* a multi-assignment becomes a single assignment statement to the
                     * globals/fresh locals for ys
                     *)
                      fun doit () = let
                            fun doLHSVar (y, (env, ys)) = (case peekGlobal(env, y)
                                   of SOME y' => (env, y'::ys)
                                    | NONE => let
                                        val t = newLocal y
                                        in
                                          (rename (addLocal(env, t), y, t), t::ys)
                                        end
                                  (* end case *))
                            val (env, ys) = List.foldr doLHSVar (env, []) ys
                            val exp = T.E_Op(rator, List.map (useVar env) xs)
                            val stm = T.S_Assign(ys, exp)
                            in
                              doNode (env, ifStk, stm :: stms, !succ)
                            end
                      in
                      (* Print operations are dropped when the target cannot print *)
                        case rator
                         of Op.Print _ => if Target.supportsPrinting()
                              then doit ()
                              else doNode (env, ifStk, stms, !succ)
                          | _ => doit()
                        (* end case *)
                      end
		  | IL.NEW{strand, args, succ, ...} => raise Fail "NEW unimplemented"
                  | IL.SAVE{lhs, rhs, succ, ...} => let
                      val stm = T.S_Save([getStateVar lhs], useVar env rhs)
                      in
                        doNode (env, ifStk, stm::stms, !succ)
                      end
		  | k as IL.EXIT{kind, live, ...} => (case kind
		       of ExitKind.FRAGMENT =>
			    endScope (env, prefix @ List.revAppend(stms, finish env))
			| ExitKind.SINIT => let
(* FIXME: we should probably call flushPending here! *)
			    val suffix = finish env @ [T.S_Exit[]]
			    in
			      endScope (env, prefix @ List.revAppend(stms, suffix))
			    end
			| ExitKind.RETURN => let
(* FIXME: we should probably call flushPending here! *)
			  (* the exit statement carries the live variables *)
			    val suffix = finish env @ [T.S_Exit(List.map (useVar env) live)]
			    in
			      endScope (env, prefix @ List.revAppend(stms, suffix))
			    end
			| ExitKind.ACTIVE => let
(* FIXME: we should probably call flushPending here! *)
			    val suffix = finish env @ [T.S_Active]
			    in
			      endScope (env, prefix @ List.revAppend(stms, suffix))
			    end
			| ExitKind.STABILIZE => let
(* FIXME: we should probably call flushPending here! *)
			    val stms = T.S_Stabilize :: stms
			    in
(* FIXME: we should probably call flushPending here! *)
			    (* STABILIZE and DIE end a branch (or the whole CFG) via join *)
			      join (env, ifStk, stms, k)
			    end
			| ExitKind.DIE => join (env, ifStk, T.S_Die :: stms, k)
		      (* end case *))
		(* end case *))
	  in
	    doNode (env, [], [], CFG.entry cfg)
	  end

    fun trInitially (env, IL.Initially{isArray, rangeInit, iters, create=(createInit, strand, args)}) = let
	(* translate the initially section: the code computing the iteration ranges, the
	 * iterators themselves, the strand-creation code, and the creation arguments.
	 *)
	  val iterPrefix = trCFG (env, [], fn _ => [], rangeInit)
	(* convert the iterators from last to first, binding each Low IL iteration variable
	 * to a fresh Tree IL iteration variable and translating its bounds.
	 *)
	  fun cvtIters [] = (env, [])
	    | cvtIters ((param, lo, hi)::rest) = let
		val (env', rest') = cvtIters rest
		val param' = newIter param
		val env' = rename (env', param, param')
		val lo' = useVar env' lo
		val hi' = useVar env' hi
		in
		  (env', (param', lo', hi')::rest')
		end
	  val (env, iters') = cvtIters iters
	(* the creation code sees the iteration-variable bindings *)
	  val createPrefix = trCFG (env, [], fn _ => [], createInit)
	  val args' = List.map (useVar env) args
	  in {
	    isArray = isArray,
	    iterPrefix = iterPrefix,
	    iters = iters',
	    createPrefix = createPrefix,
	    strand = strand,
	    args = args'
	  } end

    fun trMethod env (IL.Method{name, body}) = let
	(* translate a strand method; its body gets no extra prefix or suffix statements *)
	  val blk = trCFG (env, [], fn _ => [], body)
	  in
	    T.Method{name = name, body = blk}
	  end

    fun trStrand globalEnv (IL.Strand{name, params, state, stateInit, methods}) = let
	(* translate a strand definition: one fresh Tree IL parameter per Low IL parameter *)
	  val params' = List.map newParam params
	(* add the parameter bindings to the environment *)
	  fun bindParam (x, x', env) = rename (env, x, x')
	  val env = ListPair.foldlEq bindParam globalEnv (params, params')
	  val state' = List.map getStateVar state
	  val stateInit' = trCFG (env, [], fn _ => [], stateInit)
	  val methods' = List.map (trMethod env) methods
	  in
	    T.Strand{
		name = name,
		params = params',
		state = state',
		stateInit = stateInit',
		methods = methods'
	      }
	  end

  (* split the globalInit into the part that specifies the inputs and the rest of
   * the global initialization.  The walk follows the straight-line prefix of the CFG
   * (ENTRY, COM, and ASSIGN nodes) and splits right after the last Input or
   * InputWithDefault assignment seen there; the input part returns the input variables.
   * If there are no inputs, the input part is an empty CFG and globalInit is unchanged.
   *)
    fun splitGlobalInit globalInit = let
	(* an input-initialization CFG that does nothing *)
	  fun noInputs () = let
		val entry = Nd.mkENTRY()
		val exit = Nd.mkEXIT(ExitKind.RETURN, [])
		in
		  Nd.addEdge (entry, exit);
		  {inputInit = IL.CFG{entry=entry, exit=exit}, globalInit = globalInit}
		end
	(* cut the CFG immediately after lastInput; the first half returns live *)
	  fun splitAfter (lastInput, live) = let
		val inputExit = Nd.mkEXIT(ExitKind.RETURN, live)
		val globalEntry = Nd.mkENTRY()
		val [gFirst] = Nd.succs lastInput
		in
		  Nd.replaceInEdge {src = lastInput, oldDst = gFirst, dst = inputExit};
		  Nd.replaceOutEdge {oldSrc = lastInput, src = globalEntry, dst = gFirst};
		  {
		    inputInit = IL.CFG{entry = IL.CFG.entry globalInit, exit = inputExit},
		    globalInit = IL.CFG{entry = globalEntry, exit = IL.CFG.exit globalInit}
		  }
		end
	(* does the RHS read an input? *)
	  fun isInput (IL.OP(Op.Input _, _)) = true
	    | isInput (IL.OP(Op.InputWithDefault _, _)) = true
	    | isInput _ = false
	  fun walk (nd, lastInput, live) = (case Nd.kind nd
		 of IL.ENTRY{succ} => walk (!succ, lastInput, live)
		  | IL.COM{succ, ...} => walk (!succ, lastInput, live)
		  | IL.ASSIGN{stm=(lhs, rhs), succ, ...} =>
		      if isInput rhs
			then walk (!succ, nd, lhs::live)
			else walk (!succ, lastInput, live)
		  | _ => if Nd.isNULL lastInput
		      then noInputs ()
		      else splitAfter (lastInput, live)
		(* end case *))
	  in
	    walk (IL.CFG.entry globalInit, Nd.dummy, [])
	  end

    fun translate prog = let
	(* first we do a variable analysis pass on the Low IL *)
	  val prog as IL.Program{props, globalInit, initially, strands} = VA.optimize prog
(* FIXME: here we should do a contraction pass to eliminate unused variables that VA may have created *)
	  val _ = (* DEBUG *)
		LowPP.output (Log.logFile(), "LowIL after variable analysis", prog)
	  val env = newEnv()
	(* each variable live at the exit of the global initialization becomes a global *)
	  fun mkGlobal x = let
		val x' = newGlobal x
		in
		  global (env, x, x');
		  x'
		end
	  val globals = List.map mkGlobal (IL.CFG.liveAtExit globalInit)
	  val {inputInit, globalInit} = splitGlobalInit globalInit
	(* the components are translated in this order: strands, inputs, global
	 * initialization, and then the initially section.
	 *)
	  val strands' = List.map (trStrand env) strands
	  val inputInit' = trCFG (env, [], fn _ => [], inputInit)
	  val globalInit' = trCFG (env, [], fn _ => [], globalInit)
	  val initially' = trInitially (env, initially)
	  in
	    T.Program{
		props = props,
		globals = globals,
		inputInit = inputInit',
		globalInit = globalInit',
		strands = strands',
		initially = initially'
	      }
	  end

  end

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0