Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] View of /branches/vis15/src/compiler/cxx-util/tree-to-cxx.sml
ViewVC logotype

View of /branches/vis15/src/compiler/cxx-util/tree-to-cxx.sml

Parent Directory Parent Directory | Revision Log Revision Log


Revision 3917 - (download) (annotate)
Sat May 28 16:41:39 2016 UTC (3 years, 5 months ago) by jhr
File size: 15644 byte(s)
  Working on merge: code generation
(* tree-to-cxx.sml
 *
 * This code is part of the Diderot Project (http://diderot-language.cs.uchicago.edu)
 *
 * COPYRIGHT (c) 2016 The University of Chicago
 * All rights reserved.
 *
 * Translate TreeIR to the C++ version of CLang.
 *)


structure TreeToCxx : sig

  (* translate a TreeIR type to the corresponding CLang (C++) type *)
    val trType : CodeGenEnv.t * TreeTypes.t -> CLang.ty

  (* translate a TreeIR block (its local declarations followed by its body) to a CLang block *)
    val trBlock : CodeGenEnv.t * TreeIR.block -> CLang.stm

  (* translate a TreeIR expression to a CLang expression *)
    val trExp : CodeGenEnv.t * TreeIR.exp -> CLang.exp

  (* translate an expression to a variable form; return the variable (as an expression)
   * and the (optional) declaration.
   *)
    val trExpToVar : CodeGenEnv.t * CLang.ty * string * TreeIR.exp -> CLang.exp * CLang.stm list

  (* translate the assignment "lhs = rhs" to a list of CLang statements (in execution order) *)
    val trAssign : CodeGenEnv.t * CLang.exp * TreeIR.exp -> CLang.stm list

  (* generate code to register an error message (requires that a world pointer "wrld" is in scope) *)
    val errorMsgAdd : CLang.exp -> CLang.stm

  (* translate a TreeIR variable to a C++ function parameter *)
    val trParam : CodeGenEnv.t -> TreeIR.var -> CLang.param

  end = struct

    structure CL = CLang
    structure IR = TreeIR
    structure Op = TreeOps
    structure Ty = TreeTypes
    structure V = TreeVar
    structure Env = CodeGenEnv

  (* translate a TreeIR type to a CLang type; tuple types are not yet supported *)
    fun trType (env, ty) = (case ty
           of Ty.BoolTy => CL.boolTy
            | Ty.StringTy => CL.T_Named "std::string"
            | Ty.IntTy => Env.intTy env
            | (Ty.VecTy(1, 1)) => Env.realTy env
            | (Ty.VecTy(d, _)) => CL.T_Named("vec" ^ Int.toString d)
            | (Ty.TupleTy tys) => raise Fail "FIXME: TupleTy"
          (* tensors are flattened to an array of reals whose length is the product of the shape *)
            | (Ty.TensorTy dd) => CL.T_Array(Env.realTy env, SOME(List.foldl Int.* 1 dd))
            | (Ty.SeqTy(t, NONE)) => CL.T_Template("diderot::dynseq", [trType(env, t)])
            | (Ty.SeqTy(t, SOME n)) => CL.T_Array(trType(env, t), SOME n)
            | (Ty.ImageTy info) =>
                CL.T_Template(
                  concat["diderot::image", Int.toString(ImageInfo.dim info), "d"],
                  [Env.realTy env])
            | (Ty.StrandTy name) => CL.T_Named("strand_" ^ Atom.toString name)
          (* end case *))

  (* translate a local variable that occurs in an l-value context *)
    fun lvalueVar (env, x) = CL.mkVar(Env.lookup(env, x))
  (* translate a variable that occurs in an r-value context *)
    fun rvalueVar (env, x) = CL.mkVar(Env.lookup(env, x))

  (* translate a global variable that occurs in an l-value context *)
    fun lvalueGlobalVar (env, x) = CL.mkIndirect(CL.mkVar(Env.global env), TreeGlobalVar.name x)
  (* translate a global variable that occurs in an r-value context *)
    val rvalueGlobalVar = lvalueGlobalVar

  (* translate a strand state variable that occurs in an l-value context (writes go to selfOut) *)
    fun lvalueStateVar (env, x) = CL.mkIndirect(CL.mkVar(Env.selfOut env), TreeStateVar.name x)
  (* translate a strand state variable that occurs in an r-value context (reads come from selfIn) *)
    fun rvalueStateVar (env, x) = CL.mkIndirect(CL.mkVar(Env.selfIn env), TreeStateVar.name x)

  (* generate new variables; names have the form "<prefix>_<n>" using a shared counter *)
    local
      val count = ref 0
      fun freshName prefix = let
            val n = !count
            in
              count := n+1;
              concat[prefix, "_", Int.toString n]
            end
    in
    fun tmpVar () = freshName "tmp"
    fun freshVar prefix = freshName prefix
    end (* local *)

  (* integer literal expression *)
    fun intExp (i : int) = CL.mkInt(IntInf.fromInt i)

    val zero = CL.mkInt 0

    fun addrOf e = CL.mkUnOp(CL.%&, e)

  (* make an application of a function from the "std" namespace *)
    fun mkStdApply (f, args) = CL.mkApply("std::" ^ f, args)

  (* make an application of a function from the "diderot" namespace *)
    fun mkDiderotApply (f, args) = CL.mkApply("diderot::" ^ f, args)
    fun mkDiderotCall (f, args) = CL.mkCall("diderot::" ^ f, args)

  (* Translate a TreeIR operator application to a CLang expression; the arguments
   * have already been translated.  Unsupported operators raise Fail.
   *)
    fun trOp (env, rator, args) = (case (rator, args)
           of (Op.IAdd, [a, b]) => CL.mkBinOp(a, CL.#+, b)
            | (Op.ISub, [a, b]) => CL.mkBinOp(a, CL.#-, b)
            | (Op.IMul, [a, b]) => CL.mkBinOp(a, CL.#*, b)
            | (Op.IDiv, [a, b]) => CL.mkBinOp(a, CL.#/, b)
            | (Op.IMod, [a, b]) => CL.mkBinOp(a, CL.#%, b)
            | (Op.INeg, [a]) => CL.mkUnOp(CL.%-, a)
            | (Op.RAdd, [a, b]) => CL.mkBinOp(a, CL.#+, b)
            | (Op.RSub, [a, b]) => CL.mkBinOp(a, CL.#-, b)
            | (Op.RMul, [a, b]) => CL.mkBinOp(a, CL.#*, b)
            | (Op.RDiv, [a, b]) => CL.mkBinOp(a, CL.#/, b)
            | (Op.RNeg, [a]) => CL.mkUnOp(CL.%-, a)
            | (Op.RClamp, [a, b, c]) => CL.mkApply("clamp", [a, b, c])
            | (Op.RLerp, [a, b, c]) => CL.mkApply("lerp", [a, b, c])
            | (Op.RCeiling, [a]) => mkStdApply("ceil", [a])
            | (Op.RFloor, [a]) => mkStdApply("floor", [a])
            | (Op.RRound, [a]) => mkStdApply("round", [a])
            | (Op.RTrunc, [a]) => mkStdApply("trunc", [a])
            | (Op.RealToInt, [a]) => mkStdApply("lround", [a])
            | (Op.LT ty, [a, b]) => CL.mkBinOp(a, CL.#<, b)
            | (Op.LTE ty, [a, b]) => CL.mkBinOp(a, CL.#<=, b)
            | (Op.EQ ty, [a, b]) => CL.mkBinOp(a, CL.#==, b)
            | (Op.NEQ ty, [a, b]) => CL.mkBinOp(a, CL.#!=, b)
            | (Op.GTE ty, [a, b]) => CL.mkBinOp(a, CL.#>=, b)
            | (Op.GT ty, [a, b]) => CL.mkBinOp(a, CL.#>, b)
            | (Op.Not, [a]) => CL.mkUnOp(CL.%!, a)
            | (Op.Abs ty, args) => mkStdApply("abs", args)
          (* Max maps to std::max and Min to std::min (these were previously swapped) *)
            | (Op.Max ty, args) => mkStdApply("max", args)
            | (Op.Min ty, args) => mkStdApply("min", args)
            | (Op.VAdd d, [a, b]) => CL.mkBinOp(a, CL.#+, b)
            | (Op.VSub d, [a, b]) => CL.mkBinOp(a, CL.#-, b)
            | (Op.VScale d, [a, b]) => CL.mkApply("vscale", [a, b])
            | (Op.VMul d, [a, b]) => CL.mkBinOp(a, CL.#*, b)
            | (Op.VNeg d, [a]) => CL.mkUnOp(CL.%-, a)
            | (Op.VSum d, [a]) => CL.mkApply("vsum", [a])
            | (Op.VIndex(w, p, i), [a]) => CL.mkSubscript(a, intExp i)
            | (Op.VClamp d, [a, b, c]) => CL.mkApply("vclamp", [a, b, c])
            | (Op.VMapClamp d, [a, b, c]) => CL.mkApply("vclamp", [a, b, c])
            | (Op.VLerp d, [a, b, c]) => CL.mkApply("vlerp", [a, b, c])
            | (Op.VCeiling d, [a]) => CL.mkApply("vceiling", [a])
            | (Op.VFloor d, [a]) => CL.mkApply("vfloor", [a])
            | (Op.VRound d, [a]) => CL.mkApply("vround", [a])
            | (Op.VTrunc d, [a]) => CL.mkApply("vtrunc", [a])
            | (Op.VToInt d, [a]) => CL.mkApply("vtoi", [a])
            | (Op.TensorIndex(Ty.TensorTy(_::dd), idxs), [a]) => let
              (* dimensions/indices are slowest to fastest; compute the flat offset
               * by Horner's rule over the trailing dimensions
               *)
                fun index ([], [i], acc) = acc + i
                  | index (d::dd, i::ii, acc) = index (dd, ii, d * (acc + i))
                in
                  CL.mkSubscript(a, intExp(index (dd, idxs, 0)))
                end
            | (Op.ProjectLast(Ty.TensorTy(_::dd), idxs), [a]) => let
              (* dimensions/indices are slowest to fastest; the result is the address
               * of the first element of the selected last-axis slice
               *)
                fun index ([], [], acc) = acc
                  | index (d::dd, i::ii, acc) = index (dd, ii, d * (acc + i))
                in
                  CL.mkAddrOf(CL.mkSubscript(a, intExp(index (dd, idxs, 0))))
                end
            | (Op.EigenVals2x2, [a]) => raise Fail "FIXME: EigenVals2x2"
            | (Op.EigenVals3x3, [a]) => raise Fail "FIXME: EigenVals3x3"
            | (Op.Select(ty, i), [a]) => raise Fail "FIXME: Select"
            | (Op.Subscript ty, [a, b]) => CL.mkSubscript(a, b)
            | (Op.MkDynamic(ty, i), [a]) => raise Fail "FIXME: MkDynamic"
            | (Op.Append ty, [a, b]) => raise Fail "FIXME: Append"
            | (Op.Prepend ty, [a, b]) => raise Fail "FIXME: Prepend"
            | (Op.Concat ty, [a, b]) => raise Fail "FIXME: Concat"
            | (Op.Range, [a, b]) => raise Fail "FIXME: Range"
            | (Op.Length ty, [a]) => raise Fail "FIXME: Length"
            | (Op.SphereQuery(ty1, ty2), []) => raise Fail "FIXME: SphereQuery"
            | (Op.Sqrt, [a]) => mkStdApply("sqrt", [a])
            | (Op.Cos, [a]) => mkStdApply("cos", [a])
            | (Op.ArcCos, [a]) => mkStdApply("acos", [a])
            | (Op.Sin, [a]) => mkStdApply("sin", [a])
            | (Op.ArcSin, [a]) => mkStdApply("asin", [a])
            | (Op.Tan, [a]) => mkStdApply("tan", [a])
            | (Op.ArcTan, [a]) => mkStdApply("atan", [a])
            | (Op.Exp, [a]) => mkStdApply("exp", [a])
            | (Op.IntToReal, [a]) => CL.mkStaticCast(Env.realTy env, a)
(*
            | R_All of ty
            | R_Exists of ty
            | R_Max of ty
            | R_Min of ty
            | R_Sum of ty
            | R_Product of ty
            | R_Mean of ty
            | R_Variance of ty
*)
            | (Op.Transform info, [img]) => CL.mkDispatch(img, "world2image", [])
            | (Op.Translate info, [img]) => CL.mkDispatch(img, "translate", [])
            | (Op.BaseAddress info, [img]) => CL.mkDispatch(img, "base_addr", [])
            | (Op.ControlIndex(info, ctl, d), [img, idx]) =>
                CL.mkDispatch(img, IndexCtl.toString ctl, [intExp d, idx])
            | (Op.Inside(info, s), [pos, img]) => CL.mkDispatch(img, "inside", [pos, intExp s])
            | (Op.ImageDim(info, i), [img]) => CL.mkDispatch(img, "size", [intExp i])
            | (Op.MathFn f, args) => mkStdApply(MathFns.toString f, args)
            | _ => raise Fail(concat[
                   "unknown or incorrect operator ", Op.toString rator
                 ])
          (* end case *))

  (* translate a TreeIR expression; E_Cons, E_Seq, and E_Pack are only expected on the
   * right-hand side of assignments/declarations (see trAssign and trDecl).
   *)
    fun trExp (env, e) = (case e
           of IR.E_Global x => rvalueGlobalVar (env, x)
            | IR.E_State(NONE, x) => rvalueStateVar (env, x)
            | IR.E_State(SOME e, x) => CL.mkIndirect(trExp(env, e), TreeStateVar.name x)
            | IR.E_Var x => rvalueVar (env, x)
            | IR.E_Lit(Literal.Int n) => CL.mkIntTy(n, Env.intTy env)
            | IR.E_Lit(Literal.Bool b) => CL.mkBool b
            | IR.E_Lit(Literal.Real f) => CL.mkFlt(f, Env.realTy env)
            | IR.E_Lit(Literal.String s) => CL.mkStr s
            | IR.E_Op(rator, args) => trOp (env, rator, trExps(env, args))
            | IR.E_Vec(w, pw, args) => let
                val args = trExps (env, args)
              (* pad the vector out to its padded width pw with zeros *)
                val args = if (w < pw) then args @ List.tabulate(pw-w, fn _ => zero) else args
                in
                  CL.mkVec(CL.T_Named("vec" ^ Int.toString pw), args)
                end
            | IR.E_Cons(args, Ty.TensorTy shape) => raise Fail "unexpected E_Cons"
            | IR.E_Seq(args, ty) => raise Fail "unexpected E_Seq"
            | IR.E_Pack(layout, args) => raise Fail "unexpected E_Pack"
(* FIXME: check if e is aligned and use "vload_aligned" in that case *)
            | IR.E_VLoad(layout, e, i) =>
                CL.mkTemplateApply("vload",
                  [trType(env, Ty.nthVec(layout, i))],
                  [CL.mkBinOp(trExp(env, e), CL.#+, intExp(Ty.offsetOf(layout, i)))])
            | _ => raise Fail "trExp"
          (* end case *))

  (* translate a list of expressions, preserving order *)
    and trExps (env, exps) = List.map (fn exp => trExp(env, exp)) exps

(* QUESTION: not sure that we need this function? *)
  (* if the translated expression is already a variable, return it with no declaration;
   * otherwise bind it to a fresh variable derived from name.
   *)
    fun trExpToVar (env, ty, name, exp) = (case trExp (env, exp)
           of e as CL.E_Var _ => (e, [])
            | e => let
                val x = freshVar name
                in
                  (CL.mkVar x, [CL.mkDeclInit(ty, x, e)])
                end
          (* end case *))

  (* translate "lhs = rhs"; tensor/sequence constructions become one element-wise
   * assignment per argument (in argument order).
   *)
    fun trAssign (env, lhs, rhs) = let
          fun trArg (i, arg) = CL.mkAssign(CL.mkSubscript(lhs, intExp i), trExp (env, arg))
          in
            case rhs
             of IR.E_Pack(_, args) =>
                (* NOTE(review): lhs is not passed to "vpack" here; confirm this is intended *)
                  [CL.mkCall ("vpack", List.map (fn e => trExp(env, e)) args)]
              | IR.E_Cons(args, _) => List.mapi trArg args
              | IR.E_Seq(args, _) => List.mapi trArg args
              | _ => [CL.mkAssign(lhs, trExp (env, rhs))]
            (* end case *)
          end

  (* translate a declaration of lhs with initializer rhs; constructions use an
   * initializer list.
   *)
    fun trDecl (env, ty, lhs, rhs) = let
          fun trArgs args = CL.mkDecl(
                ty, lhs, SOME(CL.I_Exps(List.map (fn arg => CL.I_Exp(trExp (env, arg))) args)))
          in
            case rhs
             of IR.E_Cons(args, _) => trArgs args
              | IR.E_Seq(args, _) => trArgs args
              | _ => CL.mkDeclInit(ty, lhs, trExp (env, rhs))
            (* end case *)
          end

  (* translate an operator with multiple results (currently only eigenvectors) *)
    fun trMultiAssign (env, lhs, IR.E_Op(rator, args)) = (case (lhs, rator, args)
           of ([vals, vecs], Op.EigenVecs2x2, [exp]) =>
                mkDiderotCall("eigenvecs", [trExp (env, exp), vals, vecs])
            | ([vals, vecs], Op.EigenVecs3x3, [exp]) =>
                mkDiderotCall("eigenvecs", [trExp (env, exp), vals, vecs])
            | _ => raise Fail "bogus multi-assignment"
          (* end case *))
      | trMultiAssign (env, lhs, rhs) = raise Fail "bogus multi-assignment"

  (* translate a statement list.  Translated statements are accumulated in REVERSE
   * order and the accumulator is reversed at the end, so any case that produces
   * several statements must add them with List.revAppend to preserve their order.
   *)
    fun trStms (env, stms : TreeIR.stm list) = let
          fun trStm (stm, (env, stms : CL.stm list)) = (case stm
                 of IR.S_Comment text => (env, CL.mkComment text :: stms)
                  | IR.S_Assign(true, x, exp) => let
                      val ty = trType (env, V.ty x)
                      val x' = V.name x
                      val env = Env.insert (env, x, x')
                      in
                        (env, trDecl (env, ty, x', exp) :: stms)
                      end
                  | IR.S_Assign(false, x, exp) =>
                      (env, List.revAppend (trAssign (env, lvalueVar (env, x), exp), stms))
                  | IR.S_MAssign(xs, exp) =>
                      (env, trMultiAssign (env, List.map (fn x => lvalueVar (env, x)) xs, exp) :: stms)
                  | IR.S_GAssign(x, exp) =>
                      (env, List.revAppend (trAssign (env, lvalueGlobalVar (env, x), exp), stms))
                  | IR.S_IfThen(cond, thenBlk) =>
                      (env, CL.mkIfThen(trExp(env, cond), trBlock(env, thenBlk)) :: stms)
                  | IR.S_IfThenElse(cond, thenBlk, elseBlk) => let
                      val stm = CL.mkIfThenElse(trExp(env, cond),
                            trBlock(env, thenBlk),
                            trBlock(env, elseBlk))
                      in
                        (env, stm :: stms)
                      end
                  | IR.S_Foreach(x, IR.E_Op(Op.Range, [lo, hi]), blk) => let
                      val x' = V.name x
                      val env' = Env.insert (env, x, x')
                    (* bind a non-simple upper bound to a fresh variable so it is evaluated once *)
                      val (hi', hiInit) = if CodeGenUtil.isSimple hi
                            then (trExp(env, hi), [])
                            else let
                              val hi' = freshVar "hi"
                              in
                                (CL.mkVar hi', [CL.mkDeclInit(CL.int32, hi', trExp(env, hi))])
                              end
                    (* the range is inclusive of hi *)
                      val loop = CL.mkFor(
                            [(CL.int32, x', trExp(env, lo))],
                            CL.mkBinOp(CL.mkVar x', CL.#<=, hi'),
                            [CL.mkUnOp(CL.%++, CL.mkVar x')],
                            trBlock (env', blk))
                      in
                      (* the bound's declaration must precede the loop in the output *)
                        (env, loop :: List.revAppend (hiInit, stms))
                      end
                  | IR.S_Foreach(x, e, blk) => raise Fail "Foreach"
                  | IR.S_New(strand, args) => raise Fail "New"
                  | IR.S_Save(x, exp) =>
                    (* keep the previously translated statements (they were dropped before) *)
                      (env, List.revAppend (trAssign (env, lvalueStateVar(env, x), exp), stms))
                  | IR.S_LoadNrrd(lhs, ty, nrrd) => let
                      val stm = (case ty
                             of APITypes.SeqTy(ty, NONE) =>
                                  GenLoadNrrd.loadSeqFromFile (lvalueVar (env, lhs), ty, CL.mkStr nrrd)
                              | APITypes.ImageTy _ =>
                                  GenLoadNrrd.loadImage (lvalueVar (env, lhs), CL.mkStr nrrd)
                              | _ => raise Fail(concat[
                                    "bogus type ", APITypes.toString ty, " for LoadNrrd"
                                  ])
                            (* end case *))
                      in
                        (env, stm :: stms)
                      end
                  | IR.S_Input(_, _, _, NONE) => (env, stms)
                  | IR.S_Input(gv, name, _, SOME dflt) =>
                      (env, CL.mkAssign(lvalueGlobalVar (env, gv), trExp(env, dflt)) :: stms)
                  | IR.S_InputNrrd _ => (env, stms)
                  | IR.S_Exit => (env, stms)
                  | IR.S_Print(tys, args) => let
                      val args = List.map (fn e => trExp(env, e)) args
                      val stm = GenPrint.genPrintStm (
                            CL.mkIndirect(CL.mkVar "wrld", "_output"),
                            tys, args)
                      in
                        (env, stm::stms)
                      end
                  | IR.S_Active => (env, CL.mkReturn(SOME(CL.mkVar "diderot::kActive")) :: stms)
                  | IR.S_Stabilize => (env, CL.mkReturn(SOME(CL.mkVar "diderot::kStabilize")) :: stms)
                  | IR.S_Die => (env, CL.mkReturn(SOME(CL.mkVar "diderot::kDie")) :: stms)
                (* end case *))
          in
            List.rev (#2 (List.foldl trStm (env, []) stms))
          end

  (* translate a block: declare its locals (uninitialized, in their original order),
   * then translate the body in the extended environment.
   *)
    and trBlock (env, IR.Block{locals, body}) = let
          fun trLocal (x, (env, dcls)) = let
                val x' = V.name x
                val dcl = CL.mkDecl(trType(env, V.ty x), x', NONE)
                in
                  (Env.insert(env, x, x'), dcl :: dcls)
                end
          val (env, dcls) = List.foldl trLocal (env, []) (!locals)
          in
            CL.mkBlock (List.rev dcls @ trStms (env, body))
          end

  (* generate a call "biffMsgAdd(wrld->_errors, msg)"; requires "wrld" in scope *)
    fun errorMsgAdd msg =
          CL.mkCall("biffMsgAdd", [CL.mkIndirect(CL.mkVar "wrld", "_errors"), msg])

  (* translate a variable to a parameter whose name is the variable's name *)
    fun trParam env x = let
          val x' = V.name x
        (* NOTE(review): the environment returned by Env.insert is discarded here, whereas
         * trStms/trBlock thread the result; confirm whether Env.insert is imperative or
         * whether this binding is simply lost.
         *)
          val _ = Env.insert (env, x, x')
          in
            CL.PARAM([], trType(env, V.ty x), x')
          end

  end

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0