Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] View of /branches/pure-cfg/src/compiler/c-target/c-target.sml
ViewVC logotype

View of /branches/pure-cfg/src/compiler/c-target/c-target.sml

Parent Directory Parent Directory | Revision Log Revision Log


Revision 812 - (download) (annotate)
Tue Apr 12 18:01:07 2011 UTC (8 years, 8 months ago) by jhr
File size: 25631 byte(s)
  Added codegen support for vector-matrix, matrix-vector, and matrix-matrix
  multiplication
(* c-target.sml
 *
 * COPYRIGHT (c) 2011 The Diderot Project (http://diderot-language.cs.uchicago.edu)
 * All rights reserved.
 *
 * Generate C code with SSE 4.2 intrinsics.
 *)

structure CTarget : TARGET =
  struct

    structure CL = CLang
    structure RN = RuntimeNames

  (* re-export the target-type constructors (T_Bool, T_Int, T_Vec, ...) *)
    datatype ty = datatype TargetTy.ty

  (* a target variable: its target type paired with its C name *)
    datatype var = V of (ty * string)

  (* a target expression: a CLang expression tagged with its target type *)
    datatype exp = E of CLang.exp * ty

  (* target statements are CLang statements *)
    type stm = CL.stm

  (* per-strand code-generation state; the mutable fields are extended as
   * state variables, methods, and the output are registered.
   *)
    datatype strand = Strand of {
	name : string,
	tyName : string,		(* C type name of the strand's state struct *)
	state : var list ref,		(* state variables, most recently added first *)
	output : var option ref,	(* the strand's output variable (only one for now) *)
	code : CL.decl list ref		(* strand functions, most recently added first *)
      }

  (* accumulated code-generation state for a whole program *)
    datatype program = Prog of {
	globals : CL.decl list ref,		(* global declarations, most recent first *)
	topDecls : CL.decl list ref,		(* other top-level declarations, most recent first *)
	strands : strand AtomTable.hash_table,	(* strands, keyed by strand name *)
	initially : CL.decl ref			(* the generated initially function *)
      }

  (* the SIMD vector width is a runtime-target setting (128-bit vectors for SSE) *)
    fun vectorWidth () = !RN.gVectorWid

  (* target types *)
    val boolTy = T_Bool
    val intTy = T_Int
    val realTy = T_Real

  (* vector types: width 1 degenerates to the scalar type, and widths outside
   * the range [1, vectorWidth()] raise Size.
   *)
    fun vecTy n =
	  if n = 1 then T_Real
	  else if (n < 1) orelse (n > !RN.gVectorWid) then raise Size
	  else T_Vec n
    fun ivecTy n =
	  if n = 1 then T_Int
	  else if (n < 1) orelse (n > !RN.gVectorWid) then raise Size
	  else T_IVec n

  (* tensor types of order 0, 1, and 2; higher orders are not supported yet *)
    fun tensorTy shape = (case shape
	   of [] => realTy
	    | [n] => vecTy n
	    | [d1, d2] => T_Mat(d1, d2)
	    | _ => raise Fail "FIXME: order > 2 tensor type"
	  (* end case *))

  (* image types; only images whose ty has an empty first component are matched
   * (presumably scalar-valued voxels -- confirm against ImageInfo).
   *)
    fun imageTy (ImageInfo.ImgInfo{dim, ty=([], voxTy), ...}) = T_Image(dim, voxTy)
    fun imageDataTy (ImageInfo.ImgInfo{ty=([], voxTy), ...}) = T_Ptr voxTy
    val stringTy = T_String

    val statusTy = CL.T_Named RN.statusTy

  (* map a target type to the CLang type that represents it in generated C code *)
    fun cvtTy ty = (case ty
	   of T_Bool => CLang.T_Named "bool"
	    | T_String => CL.charPtr
	    | T_Int => !RN.gIntTy
	    | T_Real => !RN.gRealTy
	    | T_Vec n => CL.T_Named(RN.vecTy n)
	    | T_IVec n => CL.T_Named(RN.ivecTy n)
	    | T_Mat(rows, cols) => CL.T_Named(RN.matTy(rows, cols))
	    | T_Image(dim, _) => CL.T_Ptr(CL.T_Named(RN.imageTy dim))
	    | T_Ptr elemTy => CL.T_Ptr(CL.T_Num elemTy)
	  (* end case *))

  (* report invalid arguments to the operation `name` by raising Fail.  When
   * arguments are given, the message lists each argument expression together
   * with its target type.
   *)
    fun invalid (name, []) = raise Fail("invalid "^name)
      | invalid (name, args) = let
	  fun arg2s (E(e, ty)) = concat["(", CL.expToString e, " : ", TargetTy.toString ty, ")"]
	  val argStrs = String.concatWith ", " (List.map arg2s args)
	  in
	    raise Fail(concat["invalid arguments to ", name, ": ", argStrs])
	  end

  (* predicates used to check argument types: scalars are int and real, and
   * numeric types are scalars plus real and int vectors.
   *)
    fun scalarTy ty = (case ty of T_Int => true | T_Real => true | _ => false)
    fun numTy ty = (case ty
	   of T_Vec _ => true
	    | T_IVec _ => true
	    | _ => scalarTy ty
	  (* end case *))

  (* a 32-bit C integer literal for the SML integer i *)
    fun intExp i = CL.mkInt(IntInf.fromInt i, CL.int32)

    fun newProgram () = (
	  RN.initTargetSpec();
	  Prog{
	      globals = ref [
		CL.D_Verbatim[
		    if !Controls.doublePrecision
		      then "#define DIDEROT_DOUBLE_PRECISION"
		      else "#define DIDEROT_SINGLE_PRECISION",
		    "#include \"Diderot/diderot.h\""
		  ]],
	      topDecls = ref [],
	      strands = AtomTable.mkTable (16, Fail "strand table"),
	      initially = ref(CL.D_Comment["missing initially"])
	    })

  (* register the global-initialization code: it becomes the body of the
   * parameterless void C function RN.initGlobals.
   *)
    fun globalInit (Prog{topDecls, ...}, init) =
	  topDecls := CL.D_Func([], CL.voidTy, RN.initGlobals, [], init) :: !topDecls

  (* create and register the initially function for a program.
   * The generated C function RN.initially takes no arguments and returns a
   * pointer to the world.  Its body runs iterPrefix, allocates the initial
   * block of strands with RN.allocInitially (passing the per-dimension lower
   * bounds and extents), and then runs a loop nest over iters; each index runs
   * from lo to hi inclusive, and each iteration initializes one strand.  The
   * new function replaces the program's current initially declaration.
   *)
    fun initially {
	    prog = Prog{strands, initially, ...},
	    isArray : bool,
	    iterPrefix : stm,
	    iters : (var * exp * exp) list,
	    createPrefix : stm,
	    strand : Atom.atom,
	    args : exp list
	  } = let
	(* flatten the prefix statements so they can be spliced into a block *)
	  val iterPrefix = (case iterPrefix
		 of CL.S_Block stms => stms
		  | stm => [stm]
		(* end case *))
	  val createPrefix = (case createPrefix
		 of CL.S_Block stms => stms
		  | stm => [stm]
		(* end case *))
	  val name = Atom.toString strand
	  val nDims = List.length iters
	  val worldTy = CL.T_Ptr(CL.T_Named RN.worldTy)
	(* map f over xs, also passing each element's 0-based position *)
	  fun mapi f xs = let
		fun mapf (_, []) = []
		  | mapf (i, x::xs) = f(i, x) :: mapf(i+1, xs)
		in
		  mapf (0, xs)
		end
	(* base[i] is the lower bound of the i'th iteration *)
	  val baseInit = mapi (fn (i, (_, E(e, _), _)) => (i, CL.I_Exp e)) iters
	(* size[i] is the number of values of the i'th iteration: hi - lo + 1 *)
	  val sizeInit = mapi
		(fn (i, (V(ty, _), E(lo, _), E(hi, _))) =>
		    (i, CL.I_Exp(CL.mkBinOp(CL.mkBinOp(hi, CL.#-, lo), CL.#+, CL.E_Int(1, cvtTy ty))))
		) iters
	  val allocCode = [
		  CL.S_Comment["allocate initial block of strands"],
		  CL.S_Decl(CL.T_Array(CL.int32, SOME nDims), "base", SOME(CL.I_Array baseInit)),
		  CL.S_Decl(CL.T_Array(CL.uint32, SOME nDims), "size", SOME(CL.I_Array sizeInit)),
		  CL.S_Decl(worldTy, "wrld",
		    SOME(CL.I_Exp(CL.E_Apply(RN.allocInitially, [
			CL.mkUnOp(CL.%&, CL.E_Var(RN.strandDesc name)),
			CL.E_Bool isArray,
			CL.E_Int(IntInf.fromInt nDims, CL.int32),
			CL.E_Var "base",
			CL.E_Var "size"
		      ]))))
		]
	(* create the loop nest for the initially iterations *)
	  val indexVar = "ix"
	  val strandTy = CL.T_Ptr(CL.T_Named(RN.strandTy name))
	(* innermost body: fetch the ix'th strand state from the world (RN.inState),
	 * initialize it with the strand arguments, and advance ix.
	 *)
	  fun mkLoopNest [] = CL.mkBlock(createPrefix @ [
		  CL.S_Decl(strandTy, "sp",
		    SOME(CL.I_Exp(
		      CL.E_Cast(strandTy,
		      CL.E_Apply(RN.inState, [CL.E_Var "wrld", CL.E_Var indexVar]))))),
		  CL.S_Call(RN.strandInit name, CL.E_Var "sp" :: List.map (fn (E(e, _)) => e) args),
		  CL.S_Assign(CL.E_Var indexVar, CL.mkBinOp(CL.E_Var indexVar, CL.#+, CL.E_Int(1, CL.uint32)))
		])
	  (* one C for loop per iteration, outermost first; bounds are inclusive *)
	    | mkLoopNest ((V(ty, param), E(lo,_), E(hi, _))::iters) = let
		val body = mkLoopNest iters
		in
		  CL.S_For(
		    [(cvtTy ty, param, lo)],
		    CL.mkBinOp(CL.E_Var param, CL.#<=, hi),
		    [CL.mkPostOp(CL.E_Var param, CL.^++)],
		    body)
		end
	(* ix counts strands in creation order, starting from 0 *)
	  val iterCode = [
		  CL.S_Comment["initially"],
		  CL.S_Decl(CL.uint32, indexVar, SOME(CL.I_Exp(CL.E_Int(0, CL.uint32)))),
		  mkLoopNest iters
		]
	  val body = CL.mkBlock(iterPrefix @ allocCode @ iterCode @ [CL.S_Return(SOME(CL.E_Var "wrld"))])
	  val initFn = CL.D_Func([], worldTy, RN.initially, [], body)
	  in
	    initially := initFn
	  end

    structure Var =
      struct
      (* a global variable; its C declaration is recorded in the program's globals *)
	fun global (Prog{globals, ...}, ty, name) = let
	      val x = V(ty, name)
	      in
		globals := CL.D_Var([], cvtTy ty, name, NONE) :: !globals;
		x
	      end
      (* a function parameter *)
	fun param (ty, name) = V(ty, name)
      (* a strand state variable; it is recorded in the strand's state list *)
	fun state (Strand{state, ...}, ty, name) = let
	      val x = V(ty, name)
	      in
		state := x :: !state;
		x
	      end
      (* a local variable *)
	fun var (ty, name) = V(ty, name)
      (* fresh names are "<prefix>_<n>", where n comes from a counter shared by
       * tmp and fresh.
       *)
	local
	  val counter = ref 0
	  fun freshName prefix = let
		val id = !counter
		in
		  counter := id + 1;
		  String.concat[prefix, "_", Int.toString id]
		end
	in
	fun tmp ty = V(ty, freshName "tmp")
	val fresh = freshName
	end (* local *)
      end

  (* expression construction *)
    structure Expr =
      struct
      (* return true if the given expression form is allowed as a subexpression;
       * for now every form is allowed.
       *)
	fun allowedInline _ = true (* FIXME *)

      (* variable references: globals, parameters, and locals are plain C variables,
       * while strand state is read through the selfIn pointer.
       *)
	local
	  fun cvar (V(ty, x)) = E(CL.mkVar x, ty)
	in
	val global = cvar
	val param = cvar
	val var = cvar
	end (* local *)
	fun getState (V(ty, x)) = E(CL.mkIndirect(CL.mkVar "selfIn", x), ty)

      (* literals; int and real literals use the target's configured C types *)
	fun intLit i = E(CL.mkInt(i, !RN.gIntTy), intTy)
	fun floatLit r = E(CL.mkFlt(r, !RN.gRealTy), realTy)
	fun stringLit str = E(CL.mkStr str, stringTy)
	fun boolLit flag = E(CL.mkBool flag, boolTy)

      (* Select element i of an n-element vector.  The vector is cast to its union
       * type (union<n><suffix>_t) and the union's array field is subscripted; the
       * field is "r" for real vectors and "i" for int vectors.  Indices outside
       * [0, n) raise Subscript.
       *)
	local
	  fun sel (tyCode, field, ty) (i, e, n) =
		if (0 <= i) andalso (i < n)
		  then let
		    val unionTy = CL.T_Named(String.concat["union", Int.toString n, !tyCode, "_t"])
		    val elems = CL.mkSelect(CL.mkCast(unionTy, e), field)
		    in
		      E(CL.mkSubscript(elems, intExp i), ty)
		    end
		  else raise Subscript
	  val selF = sel (RN.gRealSuffix, "r", T_Real)
	  val selI = sel (RN.gIntSuffix, "i", T_Int)
	  fun cexp (E(e, _)) = e
	in
	fun ivecIndex (e, d, i) = cexp (selI (i, e, d))
	fun vecIndex (e, d, i) = cexp (selF (i, e, d))
	fun select (i, E(e, ty)) = (case ty
	       of T_Vec n => selF (i, e, n)
		| T_IVec n => selI (i, e, n)
		| _ => invalid("select", [E(e, ty)])
	      (* end case *))
	end (* local *)

	fun subscript1 (E(e1, ty), E(e2, T_Int)) = let
	      val (n, tyCode, elemTy, fld) = (case ty
		     of T_Vec n => (n, !RN.gRealSuffix, T_Real, "r")
		      | T_IVec n => (n, !RN.gIntSuffix, T_Int, "i")
		    (* end case *))
	      val unionTy = CL.T_Named(concat["union", Int.toString n, tyCode, "_t"])
	      val vecExp = CL.mkSelect(CL.mkCast(unionTy, e1), fld)
	      in
		E(CL.mkSubscript(vecExp, e2), elemTy)
	      end

	fun subscript2 (E(e1, T_Mat(n,m)), E(e2, T_Int), E(e3, T_Int)) =
	      E(CL.mkSubscript(CL.mkSelect(CL.mkSubscript(e1, e2), "r"), e3), T_Real)

      (* Vector (and scalar) arithmetic.  Both operands must have the same numeric
       * type; add/sub additionally allow pointer +/- int address arithmetic.
       * NOTE: multiplication and division are also used for scaling.
       *)
	local
	  fun binop rator (arg1 as E(e1, ty1), arg2 as E(e2, ty2)) =
		if (ty1 <> ty2) orelse Bool.not(numTy ty1)
		  then invalid (
		    concat["binary operator \"", CL.binopToString rator, "\""],
		    [arg1, arg2])
		  else E(CL.mkBinOp(e1, rator, e2), ty1)
	  fun ptrOrBin rator (E(ptr, ty as T_Ptr _), E(offset, T_Int)) =
		E(CL.mkBinOp(ptr, rator, offset), ty)
	    | ptrOrBin rator args = binop rator args
	in
	val add = ptrOrBin CL.#+
	val sub = ptrOrBin CL.#-
      (* real * vector scales the vector *)
	fun mul (E(s, T_Real), E(v, T_Vec n)) = E(CL.E_Apply(RN.scale n, [s, v]), T_Vec n)
	  | mul args = binop CL.#* args
      (* vector / real scales the vector by the reciprocal *)
	fun divide (E(v, T_Vec n), E(s, T_Real)) = let
	      val E(one, _) = floatLit FloatLit.one
	      val recip = CL.mkBinOp(one, CL.#/, s)
	      in
		E(CL.E_Apply(RN.scale n, [recip, v]), T_Vec n)
	      end
	  | divide args = binop CL.#/ args
	end (* local *)
      (* arithmetic negation; booleans are rejected *)
	fun neg arg = (case arg
	       of E(_, T_Bool) => raise Fail "invalid argument to neg"
		| E(e, ty) => E(CL.mkUnOp(CL.%-, e), ty)
	      (* end case *))

	fun abs (E(e, T_Int)) = E(CL.mkApply("abs", [e]), T_Int)	(* FIXME: not the right type for 64-bit ints *)
	  | abs (E(e, T_Real)) = E(CL.mkApply("fabs" ^ !RN.gRealSuffix, [e]), T_Real)
	  | abs (E(e, T_Vec n)) = raise Fail "FIXME: Expr.abs"
	  | abs (E(e, T_IVec n)) = raise Fail "FIXME: Expr.abs"
	  | abs _ = raise Fail "invalid argument to abs"

	fun dot (E(e1, T_Vec n1), E(e2, T_Vec n2)) = E(CL.E_Apply(RN.dot n1, [e1, e2]), T_Real)
	  | dot _ = raise Fail "invalid argument to dot"

	fun mulVecMat (E(vec, T_Vec m), E(mat, T_Mat(m', n))) =
	      if (1 < m) andalso (m < 4) andalso (m = m') andalso (m = n)
		then E(CL.E_Apply(RN.mulVecMat(m,n), [vec, mat]), T_Vec n)
		else raise Fail "unsupported vector-matrix multiply"
	  | mulVecMat _ = raise Fail "invalid argument to mulVecMat"

	fun mulMatVec (E(mat, T_Mat(m, n)), E(vec, T_Vec n')) =
	      if (1 < m) andalso (m < 4) andalso (m = n) andalso (n = n')
		then E(CL.E_Apply(RN.mulMatVec(m,n), [mat, vec]), T_Vec n)
		else raise Fail "unsupported matrix-vector multiply"
	  | mulMatVec _ = raise Fail "invalid argument to mulMatVec"

	fun mulMatMat (E(mat1, T_Mat(m, n)), E(mat2, T_Mat(n', p))) =
	      if (1 < m) andalso (m < 4) andalso (m = n) andalso (n = p)
		then E(CL.E_Apply(RN.mulMatMat(m,n,p), [mat1, mat2]), T_Mat(m, p))
		else raise Fail "unsupported matrix-matrix multiply"
	  | mulMatMat _ = raise Fail "invalid argument to mulMatMat"

	fun cross (E(e1, T_Vec 3), E(e2, T_Vec 3)) = E(CL.E_Apply(RN.cross(), [e1, e2]), T_Vec 3)
	  | cross _ = raise Fail "invalid argument to cross"

	fun length (E(e, T_Vec n)) = E(CL.E_Apply(RN.length n, [e]), T_Real)
	  | length _ = raise Fail "invalid argument to length"

	fun normalize (E(e, T_Vec n)) = E(CL.E_Apply(RN.normalize n, [e]), T_Vec n)
	  | normalize _ = raise Fail "invalid argument to length"

      (* matrix operations *)
	fun trace (E(e, T_Mat(n,m))) = if (n = m) andalso (1 < n) andalso (m <= 4)
	      then E(CL.E_Apply(RN.trace n, [e]), T_Real)
	      else raise Fail "invalid matrix argument for trace"
	  | trace _ = raise Fail "invalid argument to trace"

      (* comparisons: both operands must have the same scalar (int or real) type,
       * and the result is boolean.
       *)
	local
	  fun cmpop rator (E(e1, ty1), E(e2, ty2)) =
		if (ty1 = ty2) andalso scalarTy ty1
		  then E(CL.mkBinOp(e1, rator, e2), T_Bool)
		  else invalid (
		    String.concat["compare operator \"", CL.binopToString rator, "\""],
		    [E(e1, ty1), E(e2, ty2)])
	in
	val lt = cmpop CL.#<
	val lte = cmpop CL.#<=
	val equ = cmpop CL.#==
	val neq = cmpop CL.#!=
	val gte = cmpop CL.#>=
	val gt = cmpop CL.#>
	end (* local *)

      (* logical connectives over boolean operands *)
	fun not arg = (case arg
	       of E(e, T_Bool) => E(CL.mkUnOp(CL.%!, e), T_Bool)
		| _ => raise Fail "invalid argument to not"
	      (* end case *))
	local
	  fun logic (rator, _) (E(e1, T_Bool), E(e2, T_Bool)) = E(CL.mkBinOp(e1, rator, e2), T_Bool)
	    | logic (_, msg) _ = raise Fail msg
	in
	val && = logic (CL.#&&, "invalid arguments to &&")
	val || = logic (CL.#||, "invalid arguments to ||")
	end (* local *)

      (* misc functions: min/max need operands of the same scalar type; lerp
       * interpolates reals or real vectors, with a real interpolation parameter.
       *)
	local
	  fun binFn f (E(e1, ty1), E(e2, ty2)) =
		if (ty1 <> ty2) orelse Bool.not(scalarTy ty1)
		  then raise Fail "invalid arguments to binary function"
		  else E(CL.mkApply(f ty1, [e1, e2]), ty1)
	in
	val min = binFn RN.min
	val max = binFn RN.max
	fun lerp (E(e1, ty1), E(e2, ty2), E(e3, T_Real)) =
	      if (ty1 <> ty2)
		then raise Fail "invalid arguments to lerp"
		else let
		  val (width, resTy) = (case ty1
			 of T_Real => (0, T_Real)
			  | T_Vec n => (n, ty1)
			  | ty => raise Fail(concat["lerp<", TargetTy.toString ty, "> not supported"])
			(* end case *))
		  in
		    E(CL.mkApply(RN.lerp width, [e1, e2, e3]), resTy)
		  end
	  | lerp _ = raise Fail "invalid arguments to lerp"
	end (* local *)

      (* rounding: call the type-suffixed C function, keeping the argument's type *)
	local
	  fun roundFn name (E(e, ty)) = E(CL.mkApply(RN.addTySuffix(name, ty), [e]), ty)
	in
	val trunc = roundFn "trunc"
	val round = roundFn "round"
	val floor = roundFn "floor"
	val ceil = roundFn "ceil"
	end (* local *)

      (* conversions between int and real (scalars, plus real-to-int for vectors) *)
	fun toInt arg = (case arg
	       of E(e, T_Real) => E(CL.mkCast(!RN.gIntTy, e), T_Int)
		| E(e, T_Vec n) => E(CL.mkApply(RN.vecftoi n, [e]), ivecTy n)
		| _ => invalid ("toInt", [arg])
	      (* end case *))
	fun toReal arg = (case arg
	       of E(e, T_Int) => E(CL.mkCast(!RN.gRealTy, e), T_Real)
		| _ => invalid ("toReal", [arg])
	      (* end case *))

      (* runtime system hooks *)
      (* address of an image's voxel data: its "data" field cast to the voxel pointer type *)
	fun imageAddr (E(img, T_Image(_, rTy))) =
	      E(CL.mkCast(CL.T_Ptr(CL.T_Num rTy), CL.mkIndirect(img, "data")), T_Ptr rTy)
	  | imageAddr a = invalid("imageAddr", [a])
      (* load one voxel through a data pointer; cast when the voxel type differs
       * from the target's real type.
       *)
	fun getImgData (E(addr, T_Ptr rTy)) = let
	      val cRealTy as CL.T_Num rTy' = !RN.gRealTy
	      val load = CL.E_UnOp(CL.%*, addr)
	      in
		if (rTy' <> rTy)
		  then E(CL.E_Cast(cRealTy, load), T_Real)
		  else E(load, T_Real)
	      end
	  | getImgData a = invalid("getImgData", [a])
      (* map a world-space position into the image's index space *)
	fun posToImgSpace (E(img, T_Image(d, _)), E(pos, T_Vec n)) =
	      E(CL.mkApply(RN.toImageSpace d, [img, pos]), T_Vec n)
	  | posToImgSpace (a, b) = invalid("posToImgSpace", [a, b])
      (* test whether pos is inside img, given the support s *)
	fun inside (E(pos, T_Vec _), E(img, T_Image(d, _)), s) =
	      E(CL.mkApply(RN.inside d, [pos, img, intExp s]), T_Bool)
	  | inside (a, b, _) = invalid("inside", [a, b])

      (* Other basis functions.  Each table entry gives the IL basis function, a
       * checker that validates the argument types and extracts their C expressions,
       * and the result type.  The C name is the basis name, with an "f" suffix
       * in single precision.
       *)
	local
	  fun unary [E(e, T_Real)] = SOME[e]
	    | unary _ = NONE
	  fun binary [E(e1, T_Real), E(e2, T_Real)] = SOME[e1, e2]
	    | binary _ = NONE
	  val basis = [
		  (ILBasis.atan2, binary, T_Real),
		  (ILBasis.cos,   unary,  T_Real),
		  (ILBasis.pow,   binary, T_Real),
		  (ILBasis.sin,   unary,  T_Real),
		  (ILBasis.sqrt,  unary,  T_Real),
		  (ILBasis.tan,   unary,  T_Real)
		]
	  fun mkLookup suffix = let
		val tbl = ILBasis.Tbl.mkTable (16, Fail "basis table")
		fun add (f, check, resTy) =
		      ILBasis.Tbl.insert tbl (f, (ILBasis.toString f ^ suffix, check, resTy))
		in
		  List.app add basis;
		  ILBasis.Tbl.lookup tbl
		end
	  val fLookup = mkLookup "f"
	  val dLookup = mkLookup ""
	in
	fun apply (f, args) = let
	      val lookup = if !Controls.doublePrecision then dLookup else fLookup
	      val (cFn, checkArgs, resTy) = lookup f
	      in
		case checkArgs args
		 of NONE => raise Fail("invalid arguments for "^ILBasis.toString f)
		  | SOME cArgs => E(CL.mkApply(cFn, cArgs), resTy)
	      end
	end (* local *)
      end (* Expr *)

  (* statement construction *)
    structure Stmt =
      struct
	val comment = CL.S_Comment
      (* Assignment.  A matrix-valued variable whose right-hand side is a variable
       * reference is copied with RN.copyMat; everything else is a plain C
       * assignment.  State variables are written through the selfOut pointer.
       *)
	local
	  fun copyOrAssign (T_Mat(n, m), lhs, rhs as CL.E_Var _) = CL.mkCall(RN.copyMat(n, m), [lhs, rhs])
	    | copyOrAssign (_, lhs, rhs) = CL.mkAssign(lhs, rhs)
	in
	fun assignState (V(ty, x), E(e, _)) = copyOrAssign (ty, CL.mkIndirect(CL.mkVar "selfOut", x), e)
	fun assign (V(ty, x), E(e, _)) = copyOrAssign (ty, CL.mkVar x, e)
	end (* local *)
      (* declare a local variable, with an optional initializer *)
	fun decl (V(ty, x), init) = CL.mkDecl(cvtTy ty, x, Option.map (fn (E(e, _)) => CL.I_Exp e) init)
	val block = CL.mkBlock
      (* conditionals; the condition must be boolean *)
	fun ifthen (E(cond, T_Bool), thenStm) = CL.mkIfThen(cond, thenStm)
	fun ifthenelse (E(cond, T_Bool), thenStm, elseStm) = CL.mkIfThenElse(cond, thenStm, elseStm)
      (* counting loop: x runs from lo to hi inclusive *)
	fun for (V(ty, x), E(lo, _), E(hi, _), body) = let
	      val idx = CL.mkVar x
	      in
		CL.mkFor([(cvtTy ty, x, lo)], CL.mkBinOp(idx, CL.#<=, hi), [CL.mkPostOp(idx, CL.^++)], body)
	      end
      (* special Diderot forms *)
      (* cons: build a vector with RN.mkVec, or fill a matrix one row at a time *)
	fun cons (V(T_Vec n, x), args : exp list) = let
	      val elems = List.map (fn E(e, _) => e) args
	      in
		[CL.mkAssign(CL.mkVar x, CL.mkApply(RN.mkVec n, elems))]
	      end
	  | cons (V(T_Mat _, x), args) = let
	    (* matrices are represented as arrays of union<d><ty>_t vectors; row i is
	     * assigned through the union's "v" field.
	     *)
	      val mat = CL.mkVar x
	      fun setRow (i, E(e, _)) = CL.mkAssign(CL.mkSelect(CL.mkSubscript(mat, intExp i), "v"), e)
	      in
		ListPair.map setRow (List.tabulate (List.length args, fn i => i), args)
	      end
	  | cons _ = raise Fail "bogus cons"
      (* load an n-vector from image data: bind the data pointer to a fresh
       * variable, then build the vector from its first n elements, casting each
       * element when the voxel type differs from the target's real type.
       *)
	fun getImgData (V(T_Vec n, x), E(e, T_Ptr rTy)) = let
	      val ptr = Var.fresh "vp"
	      val castToReal = (CL.T_Num rTy <> !RN.gRealTy)
	      fun load i = let
		    val elem = CL.mkSubscript(CL.mkVar ptr, intExp i)
		    in
		      if castToReal then CL.mkCast(!RN.gRealTy, elem) else elem
		    end
	      val ptrDecl = CL.mkDecl(CL.T_Ptr(CL.T_Num rTy), ptr, SOME(CL.I_Exp e))
	      in
		[ptrDecl, CL.mkAssign(CL.mkVar x, CL.mkApply(RN.mkVec n, List.tabulate (n, load)))]
	      end
	  | getImgData _ = raise Fail "bogus getImgData"
	local
	(* pass a fresh status-variable name to mkDecl, then append a check that
	 * calls exit(1) unless the status is DIDEROT_OK.
	 *)
	  fun checkSts mkDecl = let
		val sts = Var.fresh "sts"
		val decls = mkDecl sts
		val failed = CL.mkBinOp(CL.mkVar "DIDEROT_OK", CL.#!=, CL.mkVar sts)
		in
		  decls @ [CL.mkIfThen(failed, CL.mkCall("exit", [intExp 1]))]
		end
	in
      (* load a dim-dimensional image named by `name` into lhs *)
	fun loadImage (V(_, lhs), dim, E(name, _)) = checkSts (fn sts => let
	      val imgTy = CL.T_Named(RN.imageTy dim)
	      val loadCall = CL.E_Apply(RN.loadImage dim, [name, CL.mkUnOp(CL.%&, CL.E_Var lhs)])
	      in
		[CL.S_Decl(statusTy, sts, SOME(CL.I_Exp loadCall))]
	      end)
      (* read input `name` into lhs; an optional default is assigned beforehand *)
	fun input (V(ty, lhs), name, optDflt) = checkSts (fn sts => let
	      val inputFn = RN.input ty
	      val dst = CL.E_Var lhs
	      val dfltCode = (case optDflt
		     of SOME(E(e, _)) => [CL.S_Assign(dst, e)]
		      | NONE => []
		    (* end case *))
	      val inputCall = CL.E_Apply(inputFn, [
		      CL.E_Str name, CL.mkUnOp(CL.%&, dst), CL.mkBool(Option.isSome optDflt)
		    ])
	      in
		dfltCode @ [CL.S_Decl(statusTy, sts, SOME(CL.I_Exp inputCall))]
	      end)
	end (* local *)
      (* strand-method exits: return with no status, or with a runtime status code *)
	fun exit () = CL.mkReturn NONE
	local
	  fun returnStatus k = CL.mkReturn(SOME(CL.mkVar k))
	in
	fun active () = returnStatus RN.kActive
	fun stabilize () = returnStatus RN.kStabilize
	fun die () = returnStatus RN.kDie
	end (* local *)
      end

    structure Strand =
      struct
      (* create a strand named by strandId, with empty state and code, and record
       * it in the program's strand table.
       *)
	fun define (Prog{strands, ...}, strandId) = let
	      val name = Atom.toString strandId
	      val newStrand = Strand{
		      name = name, tyName = RN.strandTy name,
		      state = ref [], output = ref NONE, code = ref []
		    }
	      in
		AtomTable.insert strands (strandId, newStrand);
		newStrand
	      end

      (* return the strand with the given name *)
	fun lookup (Prog{strands, ...}, strandId) = AtomTable.lookup strands strandId

      (* register the strand-state initialization function RN.strandInit name.  It
       * takes a pointer to the strand's state (selfOut), followed by the strand
       * parameters.
       *)
	fun init (Strand{name, tyName, code, ...}, params, init) = let
	      fun cParam (V(ty, x)) = CL.PARAM([], cvtTy ty, x)
	      val selfOut = CL.PARAM([], CL.T_Ptr(CL.T_Named tyName), "selfOut")
	      val initFn = CL.D_Func([], CL.voidTy, RN.strandInit name, selfOut :: List.map cParam params, init)
	      in
		code := initFn :: !code
	      end

      (* register a strand method as the static C function <name>_<methName>.  It
       * takes the input and output state pointers and returns an int32 status.
       *)
	fun method (Strand{name, tyName, code, ...}, methName, body) = let
	      val selfPtrTy = CL.T_Ptr(CL.T_Named tyName)
	      val methFn = CL.D_Func(["static"], CL.int32, concat[name, "_", methName],
		    [CL.PARAM([], selfPtrTy, "selfIn"), CL.PARAM([], selfPtrTy, "selfOut")],
		    body)
	      in
		code := methFn :: !code
	      end

      (* record the strand's output variable; only one output is supported *)
	fun output (Strand{output, ...}, x) =
	      if Option.isSome(!output)
		then raise Fail "multiple outputs are not supported yet"
		else output := SOME x
      end (* Strand *)

    fun genStrand (Strand{name, tyName, state, output, code}) = let
	(* the type declaration for the strand's state struct *)
	  val selfTyDef = CL.D_StructDef(
		  List.rev (List.map (fn V(ty, x) => (cvtTy ty, x)) (!state)),
		  tyName)
	(* the print function *)
	  val prFnName = concat[name, "_print"]
	  val prFn = let
		val params = [
		      CL.PARAM([], CL.T_Ptr(CL.T_Named "FILE"), "outS"),
		      CL.PARAM([], CL.T_Ptr(CL.T_Named tyName), "self")
		    ]
		val SOME(V(ty, x)) = !output
		val outState = CL.mkIndirect(CL.mkVar "self", x)
		val prArgs = (case ty
		       of TargetTy.T_Int => [CL.E_Str(!RN.gIntFormat ^ "\n"), outState]
			| TargetTy.T_Real => [CL.E_Str "%f\n", outState]
			| TargetTy.T_Vec d => let
			    val fmt = CL.E_Str(
				  String.concatWith " " (List.tabulate(d, fn _ => "%f"))
				  ^ "\n")
			    val args = List.tabulate (d, fn i => Expr.vecIndex(outState, d, i))
			    in
			      fmt :: args
			    end
			| TargetTy.T_IVec d => let
			    val fmt = CL.E_Str(
				  String.concatWith " " (List.tabulate(d, fn _ => !RN.gIntFormat))
				  ^ "\n")
			    val args = List.tabulate (d, fn i => Expr.ivecIndex(outState, d, i))
			    in
			      fmt :: args
			    end
			| _ => raise Fail("genStrand: unsupported output type " ^ TargetTy.toString ty)
		      (* end case *))
		in
		  CL.D_Func(["static"], CL.voidTy, prFnName, params,
		    CL.S_Call("fprintf", CL.mkVar "outS" :: prArgs))
		end
	(* the strand's descriptor object *)
	  val descI = let
		fun fnPtr (ty, f) = CL.I_Exp(CL.mkCast(CL.T_Named ty, CL.mkVar f))
		in
		  CL.I_Struct[
		      ("name", CL.I_Exp(CL.E_Str name)),
		      ("stateSzb", CL.I_Exp(CL.mkSizeof(CL.T_Named(RN.strandTy name)))),
		      ("update", fnPtr("update_method_t", name ^ "_update")),
		      ("print", fnPtr("print_method_t", prFnName))
		    ]
		end
	  val desc = CL.D_Var([], CL.T_Named RN.strandDescTy, RN.strandDesc name, SOME descI)
	  in
	    selfTyDef :: List.rev (desc :: prFn :: !code)
	  end

  (* generate the table of strand descriptors: the strand count RN.numStrands,
   * then the array RN.strands, which holds the address of each strand's
   * descriptor in list order.
   *)
    fun genStrandTable (ppStrm, strands) = let
	  fun ppDecl dcl = PrintAsC.output(ppStrm, dcl)
	  val nStrands = List.length strands
	  fun descAddr (Strand{name, ...}) = CL.I_Exp(CL.mkUnOp(CL.%&, CL.E_Var(RN.strandDesc name)))
	  val countDecl = CL.D_Var([], CL.int32, RN.numStrands,
		SOME(CL.I_Exp(CL.E_Int(IntInf.fromInt nStrands, CL.int32))))
	  val entries = ListPair.zip (List.tabulate (nStrands, fn i => i), List.map descAddr strands)
	  val tableDecl = CL.D_Var([],
		CL.T_Array(CL.T_Ptr(CL.T_Named RN.strandDescTy), SOME nStrands),
		RN.strands,
		SOME(CL.I_Array entries))
	  in
	    ppDecl countDecl;
	    ppDecl tableDecl
	  end

    fun genSrc (baseName, Prog{globals, topDecls, strands, initially}) = let
	  val fileName = OS.Path.joinBaseExt{base=baseName, ext=SOME "c"}
	  val outS = TextIO.openOut fileName
	  val ppStrm = PrintAsC.new outS
	  fun ppDecl dcl = PrintAsC.output(ppStrm, dcl)
	  val strands = AtomTable.listItems strands
	  in
	    List.app ppDecl (List.rev (!globals));
	    List.app ppDecl (List.rev (!topDecls));
	    List.app (fn strand => List.app ppDecl (genStrand strand)) strands;
	    genStrandTable (ppStrm, strands);
	    ppDecl (!initially);
	    PrintAsC.close ppStrm;
	    TextIO.closeOut outS
	  end

  (* FIXME: control flags that should go somewhere else *)
    val debug = ref false	(* consulted by compile to choose the C flags *)
    val verbose = ref true	(* echo each shell command before running it *)

  (* run a shell command (echoing it first when verbose); raise Fail if it fails *)
    fun system cmd = let
	  val () = if !verbose then print(cmd ^ "\n") else ()
	  val sts = OS.Process.system cmd
	  in
	    if OS.Process.isSuccess sts
	      then ()
	      else raise Fail "error compiling/linking"
	  end

    fun compile baseName = let
	  val cFile = OS.Path.joinBaseExt{base=baseName, ext=SOME"c"}
	  val cflags = if !debug
		then Paths.cflags
		else String.concatWith " " ["-NDEBUG", Paths.cflags]
	  val cmd = String.concatWith " " [
		  Paths.cc, "-c", cflags,
		  "-I" ^ Paths.diderotInclude, "-I" ^ Paths.teemInclude,
		  cFile
		]
	  in
	    system cmd
	  end

    fun link baseName = let
	  val objFile = OS.Path.joinBaseExt{base=baseName, ext=SOME"o"}
	  val exeFile = baseName
	  val cmd = String.concatWith " " [
		  Paths.cc, "-o", exeFile, objFile,
		  "-L" ^ Paths.teemLib, "-lteem",
		  OS.Path.concat(Paths.diderotLib, "diderot-lib.o")
		]
	  in
	    system cmd
	  end

    fun generate (baseName, prog) = (
	  genSrc (baseName, prog);
	  compile baseName;
	  link baseName)

  end

(* the C back end: the CodeGenFn functor instantiated with the CTarget structure *)
structure CBackEnd = CodeGenFn(CTarget)

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0