Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] View of /branches/staging/src/compiler/translate/translate.sml
ViewVC logotype

View of /branches/staging/src/compiler/translate/translate.sml

Parent Directory Parent Directory | Revision Log Revision Log


Revision 2746 - (download) (annotate)
Wed Oct 1 21:08:30 2014 UTC (4 years, 9 months ago) by jhr
File size: 24087 byte(s)
  Porting changes from vis12 branch
(* translate.sml
 *
 * COPYRIGHT (c) 2010 The Diderot Project (http://diderot-language.cs.uchicago.edu)
 * All rights reserved.
 *
 * Translate Simple-AST code into the IL representation.  This translation is based on the
 * algorithm described in
 *
 *      Single-pass generation of static single assignment form for structured languages
 *      ACM TOPLAS, Nov. 1994
 *      by Brandis and Moessenboeck.
 *)

structure Translate : sig

  (* translate a Simple-AST program into the HighIL (SSA-form) representation *)
    val translate : Simple.program -> HighIL.program

  end = struct

    structure S = Simple
    structure Ty = SimpleTypes
    structure VMap = SimpleVar.Map
    structure VSet = SimpleVar.Set
    structure IL = HighIL
    structure Op = HighOps
    structure DstTy = HighILTypes
    structure Census = HighILCensus

  (* convert a Simple-AST type to its HighIL counterpart *)
    val cvtTy = TranslateTy.tr

  (* maps from SimpleAST variables to the current corresponding SSA variable *)
    type env = IL.var VMap.map

(* +DEBUG *)
  (* print the bindings of env, preceded by prefix, to stdout; a line break is inserted
   * once at least 100 characters have been printed on the current line.
   *)
    fun prEnv (prefix, env) = let
          val wid = ref 0
          fun pr s = (print s; wid := !wid + size s)
          fun nl () = if (!wid > 0) then (print "\n"; wid := 0) else ()
          fun prElem (src, dst) = let
                val s = String.concat [
                        " ", SimpleVar.uniqueNameOf src, "->", IL.Var.toString dst
                      ]
                in
                  pr s;
                  if (!wid >= 100) then (nl(); pr " ") else ()
                end
          in
            pr prefix; pr " ENV: {"; nl(); pr " ";
            VMap.appi prElem env;
            nl(); pr "}"; nl()
          end
(* -DEBUG *)

  (* return the SSA variable currently bound to the Simple-AST variable x in env;
   * raises Fail if x has no binding.
   *)
    fun lookup env x = (case VMap.find (env, x)
           of SOME x' => x'
            | NONE => raise Fail(concat[
                  "no binding for ", SimpleVar.uniqueNameOf x, " in environment"
                ])
          (* end case *))

  (* create a new instance of a variable (same name, converted type) *)
    fun newVar x = IL.Var.new (SimpleVar.nameOf x, cvtTy(SimpleVar.typeOf x))

  (* generate fresh SSA variables and add them to the environment; the returned list
   * of fresh variables is in the same order as xs.
   *)
    fun freshVars (env, xs) = let
          fun cvtVar (x, (env, xs)) = let
                val x' = newVar x
                in
                  (VMap.insert(env, x, x'), x'::xs)
                end
          val (env, xs) = List.foldl cvtVar (env, []) xs
          in
            (env, List.rev xs)
          end

  (* a pending-join node tracks the phi nodes needed to join the assignments
   * that flow into the join node.
   *)
    datatype join = JOIN of {
        env : env,                      (* the environment that was current at the conditional *)
                                        (* associated with this node. *)
        arity : int ref,                (* actual number of predecessors *)
        nd : IL.node,                   (* the CFG node for this pending join *)
        phiMap : IL.phi VMap.map ref,   (* a mapping from Simple AST variables that are assigned *)
                                        (* to their phi nodes. *)
        predKill : bool array           (* killed predecessor edges (because of DIE or STABILIZE) *)
      }

  (* a stack of pending joins.  The first component specifies the path index of the current
   * path to the join.
   *)
    type pending_joins = (int * join) list

  (* create a new pending-join node with arity incoming paths; initially no path is
   * killed and there are no phi nodes.
   *)
    fun newJoin (env, arity) = JOIN{
            env = env, arity = ref arity, nd = IL.Node.mkJOIN [], phiMap = ref VMap.empty,
            predKill = Array.array(arity, false)
          }

  (* record that a path to the top join in the stack has been killed because of DIE or
   * STABILIZE: the join's live arity is decremented and the path's slot in predKill is
   * marked.  Only the innermost pending join is affected; an empty stack is a no-op.
   *)
    fun killPath ((i, JOIN{arity, predKill, ...}) :: _) = (
          arity := !arity - 1;
          Array.update (predKill, i, true))
      | killPath _ = ()

  (* record an assignment to the IL variable dstVar (corresponding to the Simple AST variable
   * srcVar) in the current pending-join node.  The predIndex specifies which path into the
   * JOIN node this assignment occurs on.  Assignments to variables that are not bound in
   * the join's environment (i.e., locals of the branch) are ignored.  The first assignment
   * to srcVar creates a phi whose arguments default to srcVar's binding at the conditional;
   * later assignments overwrite only the argument for predIndex.
   *)
    fun recordAssign ([], _, _) = ()
      | recordAssign ((predIndex, JOIN{env, phiMap, predKill, nd, ...})::_, srcVar, dstVar) = let
          val arity = Array.length predKill (* the original arity before any killPath calls *)
          val m = !phiMap
          in
            case VMap.find (env, srcVar)
             of NONE => () (* local temporary *)
              | SOME dstVar' => (case VMap.find (m, srcVar)
                   of NONE => let
                        val lhs = newVar srcVar
                        val rhs = List.tabulate (arity, fn i => if (i = predIndex) then dstVar else dstVar')
                        in
(*
print(concat["recordAssign: ", SimpleVar.uniqueNameOf srcVar, " --> ", IL.Var.toString lhs,
" @ ", IL.Node.toString nd, "\n"]);
*)
                          phiMap := VMap.insert (m, srcVar, (lhs, rhs))
                        end
                    | SOME(lhs, rhs) => let
                      (* replace the predIndex'th argument of the phi with dstVar *)
                        fun update (i, x::r) = if (i = predIndex)
                              then dstVar::r
                              else x::update(i+1, r)
                          | update _ = raise Fail "invalid predecessor index"
                        in
                          phiMap := VMap.insert (m, srcVar, (lhs, update(0, rhs)))
                        end
                  (* end case *))
            (* end case *)
          end

  (* complete a pending join operation by filling in the phi nodes from the phi map and
   * updating the environment.  The result pairs the environment at the conditional,
   * updated with the joined variables, with an optional CFG node:
   *   - if every path into the join was killed, the result is (env, NONE);
   *   - if exactly one path survives, no phi nodes are needed: each assigned variable is
   *     bound to its value from the surviving path and the result is (env', SOME nd);
   *   - if no path was killed, the phi nodes are installed in nd and the result is
   *     (env', SOME nd).
   * In the latter two cases the joined variables are also recorded as assignments in the
   * next pending join on joinStk.  Any other partially-killed join raises Fail.
   *)
    fun commitJoin (joinStk, JOIN{env, arity, nd, phiMap, predKill}) = (case !arity
           of 0 => (env, NONE)
            | 1 => let
              (* there is only one path to the join, so we do not need phi nodes, but
               * we still need to propagate assignments to the next join on the stack.
               *)
                val ix = let (* find the index of the surviving predecessor of this join *)
                      fun find i = if Array.sub(predKill, i) then find(i+1) else i
                      in
                        find 0
                      end
                fun doVar (srcVar, (_, xs), env) = let
                      val dstVar = List.nth(xs, ix)
                      in
(*
print(concat["doVar (", SimpleVar.uniqueNameOf srcVar, ", ", IL.phiToString phi, ", _) @ ", IL.Node.toString nd, "\n"]);
*)
                        recordAssign (joinStk, srcVar, dstVar);
                        VMap.insert (env, srcVar, dstVar)
                      end
                val env = VMap.foldli doVar env (!phiMap)
                in
                  (env, SOME nd)
                end
            | n => if (n = Array.length predKill)
                then let
                  val IL.ND{kind=IL.JOIN{phis, ...}, ...} = nd
                  fun doVar (srcVar, phi as (dstVar, _), (env, phis)) = (
(*
print(concat["doVar (", SimpleVar.uniqueNameOf srcVar, ", ", IL.phiToString phi, ", _) @ ", IL.Node.toString nd, "\n"]);
*)
                        recordAssign (joinStk, srcVar, dstVar);
                        (VMap.insert (env, srcVar, dstVar), phi::phis))
                  val (env, phis') = VMap.foldli doVar (env, []) (!phiMap)
                  in
                    phis := phis';
                    (env, SOME nd)
                  end
                else raise Fail "FIXME: prune killed paths."
          (* end case *))

  (* expression translation: return the list of IL assignments that compute exp into the
   * IL variable lhs, resolving Simple-AST variables through env.
   *)
    fun cvtExp (env : env, lhs, exp) = (case exp
           of S.E_Var x => [IL.ASSGN(lhs, IL.VAR(lookup env x))]
            | S.E_Lit lit => [IL.ASSGN(lhs, IL.LIT lit)]
            | S.E_Tuple xs => raise Fail "E_Tuple not implemeted"
            | S.E_Apply _ => raise Fail "unexpected E_Apply"
            | S.E_Prim(f, tyArgs, args, ty) => let
                val args' = List.map (lookup env) args
                in
                  TranslateBasis.translate (lhs, f, tyArgs, args')
                end
            | S.E_Cons args => [IL.ASSGN(lhs, IL.CONS(IL.Var.ty lhs, List.map (lookup env) args))]
            | S.E_Seq(args, _) => [IL.ASSGN(lhs, IL.CONS(IL.Var.ty lhs, List.map (lookup env) args))]
            | S.E_Slice(x, indices, ty) => let
                val x = lookup env x
              (* mask records which index positions are present (SOME) *)
                val mask = List.map isSome indices
              (* the IL variables of the present indices, in order *)
                val indices = List.mapPartial (Option.map (lookup env)) indices
                in
                (* when every index is present this is a plain subscript, not a slice *)
                  if List.all (fn b => b) mask
                    then [IL.ASSGN(lhs, IL.OP(Op.TensorSub(IL.Var.ty x), x::indices))]
                    else [IL.ASSGN(lhs, IL.OP(Op.Slice(IL.Var.ty x, mask), x::indices))]
                end
            | S.E_Coerce{srcTy, dstTy, x} => (case (srcTy, dstTy)
                 of (Ty.T_Int, Ty.T_Tensor _) =>
                      [IL.ASSGN(lhs, IL.OP(Op.IntToReal, [lookup env x]))]
                  | (Ty.T_Field _, Ty.T_Field _) =>
                    (* change in continuity is a no-op *)
                      [IL.ASSGN(lhs, IL.VAR(lookup env x))]
                  | _ => raise Fail(concat[
                        "unsupported type coercion: ", Ty.toString srcTy,
                        " ==> ", Ty.toString dstTy
                      ])
                (* end case *))
            | S.E_LoadImage(ty, nrrd, info) => [IL.ASSGN(lhs, IL.OP(Op.LoadImage(cvtTy ty, nrrd, info), []))]
          (* end case *))

  (* add nodes to save the strand state, followed by an exit node.  For each pair of a
   * Simple-AST state variable in srcState and the corresponding IL state variable in
   * dstState, a SAVE node stores the variable's current SSA value (looked up in env).
   * Raises ListPair.UnequalLengths if srcState and dstState differ in length.
   *)
    fun saveStrandState (env, (srcState, dstState), exit) = let
          val stateOut = List.map (lookup env) srcState
          fun save (x, x', cfg) = IL.CFG.appendNode (cfg, IL.Node.mkSAVE(x, x'))
          in
            IL.CFG.appendNode (
              ListPair.foldlEq save IL.CFG.empty (dstState, stateOut),
              exit)
          end
(*DEBUG*)handle ex => raise ex

  (* convert a Simple-AST block to a CFG.  The state argument pairs the Simple-AST state
   * variables with their IL state variables (used when STABILIZE saves the strand state)
   * and joinStk is the stack of pending joins of the enclosing conditionals.  Returns the
   * CFG and the environment at the end of the block.  Statements that follow a DIE or
   * STABILIZE in the same block are not translated.
   *)
    fun cvtBlock (state, env : env, joinStk, S.Block stms) = let
          fun cvt (env : env, cfg, []) = (cfg, env)
            | cvt (env, cfg, stm::stms) = (case stm
                 of S.S_Var x => let
                      val x' = newVar x
                      in
                        cvt (VMap.insert (env, x, x'), cfg, stms)
                      end
                  | S.S_Assign(lhs, rhs) => let
                      val lhs' = newVar lhs
                      val assigns = cvtExp (env, lhs', rhs)
                      in
(*
print "doAssign\n";
*)
                        recordAssign (joinStk, lhs, lhs');
                        cvt (
                          VMap.insert(env, lhs, lhs'),
                          IL.CFG.concat(cfg, IL.CFG.mkBlock assigns),
                          stms)
                      end
                  | S.S_IfThenElse(x, b0, b1) => let
                      val x' = lookup env x
                      val join = newJoin (env, 2)
                      val (cfg0, _) = cvtBlock (state, env, (0, join)::joinStk, b0)
                      val (cfg1, _) = cvtBlock (state, env, (1, join)::joinStk, b1)
                      val cond = IL.Node.mkCOND {
                              cond = x',
                              trueBranch = IL.Node.dummy,
                              falseBranch = IL.Node.dummy
                            }
                      in
                        IL.Node.addEdge (IL.CFG.exit cfg, cond);
                        case commitJoin (joinStk, join)
                         of (env, SOME joinNd) => (
                            (* an empty branch goes directly from the conditional to the join *)
                              if IL.CFG.isEmpty cfg0
                                then (
                                  IL.Node.setTrueBranch (cond, joinNd);
                                  IL.Node.setPred (joinNd, cond))
                                else (
                                  IL.Node.setTrueBranch (cond, IL.CFG.entry cfg0);
                                  IL.Node.setPred (IL.CFG.entry cfg0, cond);
                                  IL.Node.addEdge (IL.CFG.exit cfg0, joinNd));
                              if IL.CFG.isEmpty cfg1
                                then (
                                  IL.Node.setFalseBranch (cond, joinNd);
                                  IL.Node.setPred (joinNd, cond))
                                else (
                                  IL.Node.setFalseBranch (cond, IL.CFG.entry cfg1);
                                  IL.Node.setPred (IL.CFG.entry cfg1, cond);
                                  IL.Node.addEdge (IL.CFG.exit cfg1, joinNd));
                              cvt (
                                env,
                                IL.CFG.concat (
                                  cfg,
                                  IL.CFG{entry = cond, exit = joinNd}),
                                stms))
                        (* the join node has zero predecessors (every path into it
                         * was killed), so it was killed.
                         *)
                          | (env, NONE) => raise Fail "unimplemented" (* FIXME *)
                        (* end case *)
                      end
                  | S.S_New(strandId, args) => let
                      val nd = IL.Node.mkNEW{
                              strand = strandId,
                              args = List.map (lookup env) args
                            }
                      in
                        cvt (env, IL.CFG.appendNode (cfg, nd), stms)
                      end
                  | S.S_Die => (
                      killPath joinStk;
                      (IL.CFG.appendNode (cfg, IL.Node.mkDIE ()), env))
                  | S.S_Stabilize => (
                      killPath joinStk;
                      (IL.CFG.concat (cfg, saveStrandState (env, state, IL.Node.mkSTABILIZE())), env))
                  | S.S_Return _ => raise Fail "unexpected return"
                  | S.S_Print args => let
                      val args = List.map (lookup env) args
                      val nd = IL.Node.mkMASSIGN([], Op.Print(List.map IL.Var.ty args), args)
                      in
                        cvt (env, IL.CFG.appendNode (cfg, nd), stms)
                      end
                (* end case *))
          in
            cvt (env, IL.CFG.empty, stms)
          end
(*DEBUG*)handle ex => raise ex

  (* convert a top-level block: the converted body is preceded by an ENTRY node and
   * followed by the CFG that mkExit builds from the body's final environment.
   *)
    fun cvtTopLevelBlock (env, blk, mkExit) = let
          val (cfg, env) = cvtBlock (([], []), env, [], blk)
          val cfg = IL.CFG.prependNode (IL.Node.mkENTRY(), cfg)
          val cfg = IL.CFG.concat (cfg, mkExit env)
          in
            (cfg, env)
          end
(*DEBUG*)handle ex => raise ex

(* FIXME: the following function could be refactored with cvtTopLevelBlock to share code *)
  (* convert a block fragment into a CFG that runs from a fresh ENTRY node to a FRAGMENT
   * exit node listing the variables defined by the fragment.
   *)
    fun cvtFragmentBlock (env0, blk) = let
          val (cfg, env) = cvtBlock (([], []), env0, [], blk)
          val entry = IL.Node.mkENTRY ()
        (* the live variables out are those that were not live coming in *)
          val liveOut = VMap.foldli
                (fn (x, x', xs) => if VMap.inDomain(env0, x) then xs else x'::xs)
                  [] env
          val exit = IL.Node.mkFRAGMENT liveOut
          in
            if IL.CFG.isEmpty cfg
              then IL.Node.addEdge (entry, exit)
              else (
                IL.Node.addEdge (entry, IL.CFG.entry cfg);
                IL.Node.addEdge (IL.CFG.exit cfg, exit));
            (IL.CFG{entry = entry, exit = exit}, env)
          end
(*DEBUG*)handle ex => raise ex

  (* convert a strand method.  The body is preceded by an ENTRY node and by loads of the
   * state variables svars into fresh SSA variables for the Simple-AST state variables,
   * and ends with an ACTIVE (Update) or RETURN (Stabilize) exit node.
   *)
    fun cvtMethod (env, name, state, svars, blk) = let
        (* load the state into fresh variables *)
          val (env, loadCFG) = let
              (* allocate shadow variables for the state variables *)
                val (env, stateIn) = freshVars (env, state)
                fun load (x, x') = IL.ASSGN(x, IL.STATE x')
                in
                  (env, IL.CFG.mkBlock (ListPair.map load (stateIn, svars)))
                end
        (* convert the body of the method *)
          val (cfg, env) = cvtBlock ((state, svars), env, [], blk)
        (* add the entry/exit nodes *)
          val entry = IL.Node.mkENTRY ()
          val loadCFG = IL.CFG.prependNode (entry, loadCFG)
          val exit = (case name
                 of StrandUtil.Update => IL.Node.mkACTIVE ()
                  | StrandUtil.Stabilize => IL.Node.mkRETURN []
                (* end case *))
          val body = IL.CFG.concat (loadCFG, cfg)
(*DEBUG**val _ = prEnv (StrandUtil.nameToString name, env);*)
(* FIXME: the following code doesn't work properly *)
        (* NOTE(review): when the body's exit node has no successor (presumably because the
         * body ended with DIE or STABILIZE), the resulting CFG's exit is the fresh exit node
         * above, which is not connected to the body.
         *)
          val body = if IL.Node.hasSucc(IL.CFG.exit body)
                then IL.CFG.concat (body, saveStrandState (env, (state, svars), exit))
                else IL.CFG{entry = IL.CFG.entry body, exit = exit}
          in
            IL.Method{
                name = name,
                body = body
              }
          end
(*DEBUG*)handle ex => (print(concat["error in cvtMethod(", StrandUtil.nameToString name, ", ...)\n"]); raise ex)

  (* convert the initially code: the range-initialization fragment, then the iteration
   * parameters (each bound to a fresh variable), then the argument-initialization
   * fragment of the create clause.
   *)
    fun cvtInitially (env, S.Initially{isArray, rangeInit, create, iters}) = let
          val S.C_Create{argInit, name, args} = create
          fun cvtIter ({param, lo, hi}, (env, iters)) = let
                val param' = newVar param
                val env = VMap.insert (env, param, param')
                val iter = (param', lookup env lo, lookup env hi)
                in
                  (env, iter::iters)
                end
          val (cfg, env) = cvtFragmentBlock (env, rangeInit)
          val (env, iters) = List.foldl cvtIter (env, []) iters
          val (argInitCFG, env) = cvtFragmentBlock (env, argInit)
          in
            IL.Initially{
                isArray = isArray,
                rangeInit = cfg,
                iters = List.rev iters,
                create = (argInitCFG, name, List.map (lookup env) args)
              }
          end

  (* check strands for properties: StrandsMayDie if any state-initialization or method
   * block (including nested conditionals) contains DIE, and NewStrands if any contains NEW.
   *)
    fun checkProps strands = let
          val hasDie = ref false
          val hasNew = ref false
          fun chkStm e = (case e
                 of S.S_IfThenElse(_, b1, b2) => (chkBlk b1; chkBlk b2)
                  | S.S_New _ => (hasNew := true)
                  | S.S_Die => (hasDie := true)
                  | _ => ()
              (* end case *))
          and chkBlk (S.Block body) = List.app chkStm body
          fun chkStrand (S.Strand{stateInit, methods, ...}) = let
                fun chkMeth (S.Method(_, body)) = chkBlk body
                in
                  chkBlk stateInit;
                  List.app chkMeth methods
                end
          fun condCons (x, v, l) = if !x then v::l else l
          in
            List.app chkStrand strands;
            condCons (hasDie, StrandUtil.StrandsMayDie,
            condCons (hasNew, StrandUtil.NewStrands, []))
          end

  (* convert the program inputs: each input variable gets a fresh IL variable that is
   * assigned from an Input operator.  Returns the straight-line CFG of those assignments
   * (in the order of inputs) and the environment binding the input variables.
   *)
    fun cvtInputs inputs = let
          fun cvt ((x, inp), (env, stms)) = let
                val x' = newVar x
                val stm = IL.ASSGN(x', IL.OP(Op.Input(Inputs.map cvtTy inp), []))
                in
                  (VMap.insert(env, x, x'), stm::stms)
                end
          val (env, stms) = List.foldr cvt (VMap.empty, []) inputs
          in
            (IL.CFG.mkBlock stms, env)
          end

  (* gather the top-level definitions in a block.  This is a hack that is used to make all
   * of the globally defined variables visible to the rest of the program (including intermediate
   * results) so that later transforms (e.g., field normalization) will work.  Eventually the
   * variable analysis phase ought to clean things up.
   *
   * Only the straight-line path from the entry is walked: for a conditional, the results
   * of the phi nodes at its join are collected, but assignments inside the branches are
   * not.  The accumulated list is reversed at the end, so variables of earlier nodes come
   * first.  Raises Fail on any node kind the walk does not expect.
   *)
    fun definedVars (IL.CFG{entry, ...}) = let
          fun gather (nd, vars) = (case IL.Node.kind nd
                 of IL.NULL => vars
                  | IL.ENTRY{succ, ...} => gather(!succ, vars)
                  | IL.COND{trueBranch, ...} => let
                      val (phis, succ) = findJoin (!trueBranch)
                      val vars = List.foldl (fn ((x, _), vars) => x::vars) vars (!phis)
                      in
                        gather (succ, vars)
                      end
                  | IL.COM{succ, ...} => gather (!succ, vars)
                  | IL.ASSIGN{stm=(x, _), succ, ...} => gather(!succ, x::vars)
                  | IL.MASSIGN{stm=(xs, _, _), succ, ...} => gather(!succ, xs@vars)
                  | _ => raise Fail("gather: unexpected " ^ IL.Node.toString nd)
                (* end case *))
        (* follow a branch to its JOIN node, skipping over nested conditionals, and
         * return the join's phis and its successor.
         *)
          and findJoin nd = (case IL.Node.kind nd
                 of IL.JOIN{phis, succ, ...} => (phis, !succ)
                  | IL.COND{trueBranch, ...} => findJoin (#2 (findJoin (!trueBranch)))
                  | IL.COM{succ, ...} => findJoin (!succ)
                  | IL.ASSIGN{succ, ...} => findJoin (!succ)
                  | IL.MASSIGN{succ, ...} => findJoin (!succ)
                  | _ => raise Fail("findJoin: unexpected " ^ IL.Node.toString nd)
                (* end case *))
          in
            List.rev (gather (entry, []))
          end

  (* translate a Simple-AST program into a HighIL program and initialize the census
   * counts of the result.
   *)
    fun translate (S.Program{props, inputs, globals, globalInit, init, strands, ...}) = let
          val (globalInit, env) = let
                val (inputBlk, inputEnv) = cvtInputs inputs
                val (globBlk, env) = cvtBlock (([], []), inputEnv, [], globalInit)
                val cfg = IL.CFG.prependNode (IL.Node.mkENTRY(), inputBlk)
                val cfg = IL.CFG.concat(cfg, globBlk)
                val exit = IL.Node.mkRETURN(VMap.listItems inputEnv @ definedVars globBlk)
                val cfg = IL.CFG.concat (cfg, IL.CFG{entry = exit, exit = exit})
                in
                  (cfg, env)
                end
        (* construct a reduced environment that just defines the globals (including inputs). *)
          val env = let
                val lookup = lookup env
                fun cvtVar (x, env) = VMap.insert(env, x, lookup x)
                val env = List.foldl (fn ((x, _), env) => cvtVar(x, env)) VMap.empty inputs
                val env = List.foldl cvtVar env globals
                in
                  env
                end
          val init = cvtInitially (env, init)
          fun cvtStrand (S.Strand{name, params, state, stateInit, methods}) = let
              (* extend the global environment with the strand's parameters *)
                val (env, params) = freshVars (env, params)
              (* create the state variables *)
                val svars = let
                      fun newSVar x = IL.StateVar.new (
                            SimpleVar.kindOf x = S.StrandOutputVar,
                            SimpleVar.nameOf x, cvtTy(SimpleVar.typeOf x))
                      in
                        List.map newSVar state
                      end
              (* convert the state initialization code *)
                val (stateInit, env) = let
                      fun mkExit env = saveStrandState (env, (state, svars), IL.Node.mkSINIT())
                      in
                        cvtTopLevelBlock (env, stateInit, mkExit)
                      end
                fun cvtMeth (S.Method(name, blk)) = cvtMethod (env, name, state, svars, blk)
                in
                  IL.Strand{
                      name = name,
                      params = params,
                      state = svars,
                      stateInit = stateInit,
                      methods = List.map cvtMeth methods
                    }
                end
          val prog = IL.Program{
(* FIXME: we should just use the properties from the Simple program *)
                  props = checkProps strands,
                  globalInit = globalInit,
                  initially = init,
                  strands = List.map cvtStrand strands
                }
          in
            Census.init prog;
            prog
          end

  end

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0