Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] View of /branches/pure-cfg/src/compiler/c-target/c-target.sml
ViewVC logotype

View of /branches/pure-cfg/src/compiler/c-target/c-target.sml

Parent Directory Parent Directory | Revision Log Revision Log


Revision 617 - (download) (annotate)
Sun Mar 13 16:51:09 2011 UTC (8 years, 9 months ago) by jhr
File size: 15754 byte(s)
  Adding for loops to C code generator
(* c-target.sml
 *
 * COPYRIGHT (c) 2011 The Diderot Project (http://diderot-language.cs.uchicago.edu)
 * All rights reserved.
 *
 * Generate C code with SSE 4.2 intrinsics.
 *)

structure CTarget : TARGET =
  struct

    structure CL = CLang
    structure RN = RuntimeNames

    datatype ty = datatype TargetTy.ty

    (* The code-generation state for one strand.  A value is created by Strand.define
     * and is then filled in imperatively through its ref fields.
     *)
    datatype strand = Strand of {
	name : string,				(* the strand's name, as given to Strand.define *)
	tyName : string,			(* name of the C struct type for the strand's state (RN.strandTy name) *)
	state : (ty * string) list ref,		(* state variables; Var.state prepends, so the newest is first *)
	code : CL.decl list ref			(* init/method functions; Strand.init and Strand.method prepend *)
      }

    (* a target variable is represented by its target type paired with its C name *)
    type var = (ty * string) (* FIXME *)

    (* a target expression is a CLang expression paired with its target type *)
    type exp = CLang.exp * ty

    (* target statements are CLang statements *)
    type stm = CL.stm

    (* placeholder representation; no information is carried for methods yet *)
    type method = unit (* FIXME *)

    (* The code-generation state for a whole program.  Each field is a ref to a list
     * that is built by prepending, so the lists are held in reverse order of insertion.
     *)
    datatype program = Prog of {
	globals : CL.decl list ref,		(* global declarations; see newProgram and Var.global *)
	topDecls : CL.decl list ref,		(* other top-level declarations; see globalInit *)
	strands : strand list ref		(* the strands registered by Strand.define *)
      }

  (* the widest vector that the target supports; for SSE we have 128-bit vectors.  The
   * value is read from RuntimeNames on every call.
   *)
    fun vectorWidth () = let
          val ref width = RN.gVectorWid
          in
            width
          end

  (* target types *)
    val (boolTy, intTy, realTy) = (T_Bool, T_Int, T_Real)
    (* the target type of a real vector with n elements; a one-element vector is just
     * a real.  Raises Size when n is less than one or exceeds the target's vector width.
     *)
    fun vecTy n = (case n
           of 1 => T_Real
            | _ => if (1 <= n) andalso (n <= !RN.gVectorWid)
                then T_Vec n
                else raise Size
          (* end case *))
    (* the target type of an integer vector with n elements; a one-element vector is
     * just an int.  Raises Size when n is less than one or exceeds the target's vector width.
     *)
    fun ivecTy n = (case n
           of 1 => T_Int
            | _ => if (1 <= n) andalso (n <= !RN.gVectorWid)
                then T_IVec n
                else raise Size
          (* end case *))
    (* the target type of an image with the given info.  Only image infos whose `ty`
     * field has an empty list as its first component are handled; any other info is
     * reported with Fail instead of an anonymous Match exception.
     *)
    fun imageTy (ImageInfo.ImgInfo{ty=([], rTy), dim, ...}) = T_Image(dim, rTy)
      | imageTy _ = raise Fail "imageTy: unsupported image info"
    (* the target type of a pointer to an image's raw data.  Only image infos whose `ty`
     * field has an empty list as its first component are handled; any other info is
     * reported with Fail instead of an anonymous Match exception.
     *)
    fun imageDataTy (ImageInfo.ImgInfo{ty=([], rTy), ...}) = T_Ptr rTy
      | imageDataTy _ = raise Fail "imageDataTy: unsupported image info"
    (* the target type of string values *)
    val stringTy = T_String

    (* the C type of the status value declared by Stmt.loadImage and Stmt.input *)
    val statusTy = CL.T_Named RN.statusTy

  (* map a target type to the CLang type that represents it in the generated code.  The
   * int and real representations are read from RuntimeNames at the time of the call.
   *)
    fun cvtTy ty = (case ty
           of T_Bool => CL.T_Named "bool"
            | T_String => CL.charPtr
            | T_Int => !RN.gIntTy
            | T_Real => !RN.gRealTy
            | T_Vec n => CL.T_Named(RN.vecTy n)
            | T_IVec n => CL.T_Named(RN.ivecTy n)
            | T_Image(dim, _) => CL.T_Ptr(CL.T_Named(RN.imageTy dim))
            | T_Ptr rawTy => CL.T_Ptr(CL.T_Num rawTy)
          (* end case *))

  (* report invalid arguments by raising Fail.  `name` identifies the operation and
   * `args` are the offending (expression, type) pairs; when there are any, each one
   * is rendered as "(exp : type)" in the message.
   *)
    fun invalid (name, []) = raise Fail("invalid "^name)
      | invalid (name, args) = let
          fun arg2s (e, ty) = concat["(", CL.expToString e, " : ", TargetTy.toString ty, ")"]
          in
            raise Fail(concat[
                "invalid arguments to ", name, ": ",
                String.concatWith ", " (List.map arg2s args)
              ])
          end

  (* helper functions for checking the types of arguments *)
  (* true for the scalar numeric types (int and real) *)
    fun scalarTy ty = (case ty
           of T_Int => true
            | T_Real => true
            | _ => false
          (* end case *))
    (* true for the numeric types: int, real, and the real and integer vector types *)
    fun numTy ty = (case ty
           of T_Int => true
            | T_Real => true
            | T_Vec _ => true
            | T_IVec _ => true
            | _ => false
          (* end case *))

    (* create an empty program.  The target specification is initialized first; the
     * program's globals start with a verbatim prelude that defines the precision macro
     * selected by Controls.doublePrecision and includes the Diderot header.
     *)
    fun newProgram () = let
          val _ = RN.initTargetSpec()
          val precisionDef = if !Controls.doublePrecision
                then "#define DIDEROT_DOUBLE_PRECISION"
                else "#define DIDEROT_SINGLE_PRECISION"
          val prelude = CL.D_Verbatim[precisionDef, "#include \"Diderot/diderot.h\""]
          in
            Prog{globals = ref [prelude], topDecls = ref [], strands = ref []}
          end

    (* register the code that initializes the globals: `init` becomes the body of a
     * void function named RN.initGlobals, which is added to the top-level declarations.
     *)
    fun globalInit (Prog{topDecls, ...}, init) =
          topDecls := CL.D_Func([], CL.voidTy, RN.initGlobals, [], init) :: !topDecls

    structure Var =
      struct
      (* declare a global variable in the program and return its representation *)
        fun global (Prog{globals, ...}, ty, name) = let
              val dcl = CL.D_Var([], cvtTy ty, name, NONE)
              in
                globals := dcl :: !globals;
                (ty, name)
              end
      (* parameters and local variables need no registration *)
        fun param (x as (_, _)) = x
        fun var (x as (_, _)) = x
      (* record a new state variable for the strand and return its representation *)
        fun state (Strand{state, ...}, ty, name) = let
              val x = (ty, name)
              in
                state := x :: !state;
                x
              end
      (* fresh names are formed as <prefix>_<n>, where n comes from a single counter
       * that is shared by tmp and fresh.
       *)
        local
          val nextId = ref 0
        in
        fun fresh prefix = let
              val id = !nextId
              in
                nextId := id + 1;
                concat[prefix, "_", Int.toString id]
              end
        end (* local *)
        fun tmp ty = (ty, fresh "tmp")
      end

  (* expression construction *)
    structure Expr =
      struct
      (* return true if the given expression form is allowed as a subexpression;
       * currently every expression is accepted.
       *)
        val allowedInline = fn _ => true (* FIXME *)

      (* variable references.  Globals, parameters, and locals are all read by name;
       * strand state is read through the "selfIn" pointer.
       *)
        local
          fun byName (ty, x) = (CL.mkVar x, ty)
        in
        val global = byName
        val param = byName
        val var = byName
        fun getState (ty, x) = let
              val self = CL.mkVar "selfIn"
              in
                (CL.mkIndirect(self, x), ty)
              end
        end (* local *)

      (* literals; each is paired with the target type of the literal.  Integer and
       * float literals use the C types currently selected in RuntimeNames.
       *)
        fun intLit n = let
              val lit = CL.mkInt(n, !RN.gIntTy)
              in
                (lit, intTy)
              end
        fun floatLit f = let
              val lit = CL.mkFlt(f, !RN.gRealTy)
              in
                (lit, realTy)
              end
        fun stringLit s = let val lit = CL.mkStr s in (lit, stringTy) end
        fun boolLit b = let val lit = CL.mkBool b in (lit, boolTy) end

      (* select the i'th element of a vector.  The vector expression is cast to the
       * corresponding union type and the element is read from the union's array field.
       * Raises Subscript when i is not in the range 0..n-1.
       *)
        local
          fun sel (tyCode, field, elemTy) (i, vec, n) =
                if (0 <= i) andalso (i < n)
                  then let
                    val unionTy = CL.T_Named(concat["union", Int.toString n, !tyCode, "_t"])
                    val elems = CL.mkSelect(CL.mkCast(unionTy, vec), field)
                    val index = CL.mkInt(IntInf.fromInt i, CL.int32)
                    in
                      (CL.mkSubscript(elems, index), elemTy)
                    end
                  else raise Subscript
          val selF = sel (RN.gRealSuffix, "r", T_Real)
          val selI = sel (RN.gIntSuffix, "i", T_Int)
        in
        fun select (i, arg as (e, ty)) = (case ty
               of T_Vec n => selF (i, e, n)
                | T_IVec n => selI (i, e, n)
                | _ => invalid("select", [arg])
              (* end case *))
        end (* local *)

      (* vector (and scalar) arithmetic.  The generic case requires both operands to
       * have the same numeric type, which is also the type of the result.
       *)
        local
          fun binop rator (a as (e1, ty1), b as (e2, ty2)) =
                if (ty1 = ty2) andalso numTy ty1
                  then (CL.mkBinOp(e1, rator, e2), ty1)
                  else invalid (
                    concat["binary operator \"", CL.binopToString rator, "\""],
                    [a, b])
        (* addition and subtraction also accept a pointer with an integer offset *)
          fun ptrOrNum rator args = (case args
                 of ((p, ty as T_Ptr _), (offset, T_Int)) => (CL.mkBinOp(p, rator, offset), ty)
                  | _ => binop rator args
                (* end case *))
        in
        fun add args = ptrOrNum CL.#+ args
        fun sub args = ptrOrNum CL.#- args
      (* NOTE: multiplication and division are also used for scaling *)
        fun mul args = (case args
               of ((s, T_Real), (v, T_Vec n)) => (CL.E_Apply(RN.scale n, [s, v]), T_Vec n)
                | _ => binop CL.#* args
              (* end case *))
      (* dividing a vector by a real is expressed as scaling by the reciprocal *)
        fun divide args = (case args
               of ((v, T_Vec n), (s, T_Real)) => let
                    val scaleFn = RN.scale n
                    val (one, _) = floatLit FloatLit.one
                    val recip = CL.mkBinOp(one, CL.#/, s)
                    in
                      (CL.E_Apply(scaleFn, [recip, v]), T_Vec n)
                    end
                | _ => binop CL.#/ args
              (* end case *))
        end (* local *)
	(* arithmetic negation; booleans are rejected, every other type is passed through *)
	fun neg (e, ty) = (case ty
	       of T_Bool => raise Fail "invalid argument to neg"
	        | _ => (CL.mkUnOp(CL.%-, e), ty)
	      (* end case *))

	(* absolute value.  Only the scalar cases are implemented; the vector cases fail
	 * with a FIXME message.
	 *)
	fun abs (e, ty) = (case ty
	       of T_Int => (CL.mkApply("abs", [e]), T_Int)	(* FIXME: not the right type for 64-bit ints *)
	        | T_Real => (CL.mkApply("fabs" ^ !RN.gRealSuffix, [e]), T_Real)
	        | T_Vec _ => raise Fail "FIXME: Expr.abs"
	        | T_IVec _ => raise Fail "FIXME: Expr.abs"
	        | _ => raise Fail "invalid argument to abs"
	      (* end case *))

	(* dot product of two real vectors; the result is a real.  Both operands must be
	 * vectors of the same width, since the runtime function is chosen by the width
	 * of the first operand alone.  Any other arguments raise Fail.
	 *)
	fun dot (a as (e1, T_Vec n1), b as (e2, T_Vec n2)) =
	      if (n1 = n2)
	        then (CL.E_Apply(RN.dot n1, [e1, e2]), T_Real)
	        else invalid ("dot", [a, b])
	  | dot _ = raise Fail "invalid argument to dot"

	(* cross product; defined only for a pair of 3-element real vectors *)
	fun cross args = (case args
	       of ((e1, T_Vec 3), (e2, T_Vec 3)) => (CL.E_Apply(RN.cross(), [e1, e2]), T_Vec 3)
	        | _ => raise Fail "invalid argument to cross"
	      (* end case *))

	(* length of a real vector; the result is a real *)
	fun length (e, ty) = (case ty
	       of T_Vec n => (CL.E_Apply(RN.length n, [e]), T_Real)
	        | _ => raise Fail "invalid argument to length"
	      (* end case *))

	(* normalize a real vector; the result has the same vector type as the argument.
	 * Any other argument raises Fail.
	 *)
	fun normalize (e, T_Vec n) = (CL.E_Apply(RN.normalize n, [e]), T_Vec n)
	  | normalize _ = raise Fail "invalid argument to normalize"

      (* comparisons.  Both operands must have the same scalar type; the result is a
       * boolean expression.
       *)
        local
          fun cmpop rator (a as (e1, ty1), b as (e2, ty2)) =
                if (ty1 = ty2) andalso scalarTy ty1
                  then (CL.mkBinOp(e1, rator, e2), T_Bool)
                  else invalid (
                    concat["compare operator \"", CL.binopToString rator, "\""],
                    [a, b])
        in
        fun lt args = cmpop CL.#< args
        fun lte args = cmpop CL.#<= args
        fun equ args = cmpop CL.#== args
        fun neq args = cmpop CL.#!= args
        fun gte args = cmpop CL.#>= args
        fun gt args = cmpop CL.#> args
        end (* local *)

      (* logical connectives; all operands must be boolean *)
        fun not (e, ty) = (case ty
               of T_Bool => (CL.mkUnOp(CL.%!, e), T_Bool)
                | _ => raise Fail "invalid argument to not"
              (* end case *))
        fun && args = (case args
               of ((e1, T_Bool), (e2, T_Bool)) => (CL.mkBinOp(e1, CL.#&&, e2), T_Bool)
                | _ => raise Fail "invalid arguments to &&"
              (* end case *))
        fun || args = (case args
               of ((e1, T_Bool), (e2, T_Bool)) => (CL.mkBinOp(e1, CL.#||, e2), T_Bool)
                | _ => raise Fail "invalid arguments to ||"
              (* end case *))

	local
	(* apply a binary runtime function, whose name is selected by the operand type,
	 * to two operands of the same scalar type.
	 *)
	  fun binFn f ((e1, ty1), (e2, ty2)) = (
	        case (ty1 = ty2) andalso scalarTy ty1
	         of true => (CL.mkApply(f ty1, [e1, e2]), ty1)
	          | false => raise Fail "invalid arguments to binary function"
	        (* end case *))
	in
      (* misc functions *)
	fun min args = binFn RN.min args
	fun max args = binFn RN.max args
	end (* local *)

      (* math functions *)
      (* power of two reals; the C function is chosen by the precision control *)
        fun pow ((base, T_Real), (exponent, T_Real)) = let
              val f = if !Controls.doublePrecision then "pow" else "powf"
              in
                (CL.mkApply(f, [base, exponent]), T_Real)
              end
          | pow _ = raise Fail "invalid arguments to pow"

	local
	(* a real-to-real math function: ff names the single-precision C function and fd
	 * the double-precision one; the choice follows the precision control.
	 *)
	  fun r2r (ff, fd) arg = (case arg
	         of (e, T_Real) => let
	              val f = if !Controls.doublePrecision then fd else ff
	              in
	                (CL.mkApply(f, [e]), T_Real)
	              end
	          | _ => invalid (fd, [arg])
	        (* end case *))
	in
	fun sin arg = r2r ("sinf", "sin") arg
	fun cos arg = r2r ("cosf", "cos") arg
	fun sqrt arg = r2r ("sqrtf", "sqrt") arg
	end (* local *)

      (* rounding; the C function name is the base name with a suffix added for the
       * argument type, and the result has the same type as the argument.
       *)
        local
          fun roundFn base (e, ty) = let
                val f = RN.addTySuffix(base, ty)
                in
                  (CL.mkApply(f, [e]), ty)
                end
        in
        fun trunc arg = roundFn "trunc" arg
        fun round arg = roundFn "round" arg
        fun floor arg = roundFn "floor" arg
        fun ceil arg = roundFn "ceil" arg
        end (* local *)

      (* conversions.  A real is converted to an int with a C cast and a real vector
       * with a runtime function; an int is converted to a real with a C cast.
       *)
        fun toInt arg = (case arg
               of (e, T_Real) => (CL.mkCast(!RN.gIntTy, e), T_Int)
                | (e, T_Vec n) => (CL.mkApply(RN.vecftoi n, [e]), ivecTy n)
                | _ => invalid ("toInt", [arg])
              (* end case *))
        fun toReal arg = (case arg
               of (e, T_Int) => (CL.mkCast(!RN.gRealTy, e), T_Real)
                | _ => invalid ("toReal", [arg])
              (* end case *))

      (* runtime system hooks *)
      (* the address of an image's data: the "data" field, cast to a pointer to the
       * image's raw sample type.
       *)
        fun imageAddr arg = (case arg
               of (img, T_Image(_, rTy)) => let
                    val ptrTy = CL.T_Ptr(CL.T_Num rTy)
                    in
                      (CL.mkCast(ptrTy, CL.mkIndirect(img, "data")), T_Ptr rTy)
                    end
                | _ => invalid("imageAddr", [arg])
              (* end case *))
      (* load one sample through a data pointer; the loaded value is cast to the real
       * type when the raw sample type differs from it.
       *)
        fun getImgData arg = (case arg
               of (addr, T_Ptr rTy) => let
                    val cRealTy as CL.T_Num realKind = !RN.gRealTy
                    val sample = CL.E_UnOp(CL.%*, addr)
                    val value = if (realKind = rTy) then sample else CL.E_Cast(cRealTy, sample)
                    in
                      (value, T_Real)
                    end
                | _ => invalid("getImgData", [arg])
              (* end case *))
      (* transform a position by calling the runtime function for the image's dimension *)
        fun posToImgSpace (a, b) = (case (a, b)
               of ((img, T_Image(d, _)), (pos, T_Vec n)) =>
                    (CL.mkApply(RN.toImageSpace d, [img, pos]), T_Vec n)
                | _ => invalid("posToImgSpace", [a, b])
              (* end case *))
      (* test a position against an image by calling the runtime function for the
       * image's dimension, passing the integer s as the last argument.
       *)
        fun inside (a, b, s) = (case (a, b)
               of ((pos, T_Vec _), (img, T_Image(d, _))) => let
                    val sArg = CL.mkInt(IntInf.fromInt s, CL.int32)
                    in
                      (CL.mkApply(RN.inside d, [pos, img, sArg]), T_Bool)
                    end
                | _ => invalid("inside", [a, b])
              (* end case *))

      end (* Expr *)

  (* statement construction *)
    structure Stmt =
      struct
	val comment = CL.S_Comment
	(* assign to a strand-state variable through the "selfOut" pointer *)
	fun assignState ((_, x), (rhs, _)) = let
	      val lhs = CL.mkIndirect(CL.mkVar "selfOut", x)
	      in
	        CL.mkAssign(lhs, rhs)
	      end
	(* assign to a variable by name *)
	fun assign ((_, x), (rhs, _)) = let
	      val lhs = CL.mkVar x
	      in
	        CL.mkAssign(lhs, rhs)
	      end
	(* declare a variable, with an optional initializing expression *)
	fun decl ((ty, x), optInit) =
	      CL.mkDecl(cvtTy ty, x, Option.map (fn (e, _) => e) optInit)
	fun block stms = CL.mkBlock stms
	(* conditionals.  The condition must be a boolean expression; any other condition
	 * is reported through `invalid` (which raises Fail) instead of an anonymous
	 * Match exception.
	 *)
	fun ifthen ((e, T_Bool), s1) = CL.mkIfThen(e, s1)
	  | ifthen (cond, _) = invalid("ifthen", [cond])
	fun ifthenelse ((e, T_Bool), s1, s2) = CL.mkIfThenElse(e, s1, s2)
	  | ifthenelse (cond, _, _) = invalid("ifthenelse", [cond])
	(* a counting loop: x is declared with initial value lo, the loop continues
	 * while x <= hi, and x is post-incremented after each iteration.
	 *)
	fun for ((ty, x), (lo, _), (hi, _), body) = let
	      val init = [(cvtTy ty, x, lo)]
	      val test = CL.mkBinOp(CL.mkVar x, CL.#<=, hi)
	      val step = [CL.mkPostOp(CL.mkVar x, CL.^++)]
	      in
	        CL.mkFor(init, test, step, body)
	      end
      (* special Diderot forms *)
      (* build a real vector from its elements and assign it to the variable *)
        fun cons (lhs, args : exp list) = (case lhs
               of (T_Vec n, x) => let
                    val elems = List.map #1 args
                    in
                      CL.mkAssign(CL.mkVar x, CL.mkApply(RN.mkVec n, elems))
                    end
                | _ => raise Fail "bogus cons"
              (* end case *))
	(* load n consecutive samples through a data pointer into a real vector.  The
	 * pointer is first bound to a fresh local variable; each sample is then read by
	 * subscripting that variable and is cast to the real type when the raw sample
	 * type differs from it.  The result is a list of two statements.
	 *)
	fun getImgData (lhs, rhs) = (case (lhs, rhs)
	       of ((T_Vec n, x), (e, T_Ptr rTy)) => let
	            val ptrName = Var.fresh "vp"
	            val elemTy = CL.T_Num rTy
	            val needsCast = (elemTy <> !RN.gRealTy)
	            fun load i = let
	                  val ix = CL.mkInt(IntInf.fromInt i, CL.int32)
	                  val sample = CL.mkSubscript(CL.mkVar ptrName, ix)
	                  in
	                    if needsCast then CL.mkCast(!RN.gRealTy, sample) else sample
	                  end
	            val ptrDcl = CL.mkDecl(CL.T_Ptr elemTy, ptrName, SOME e)
	            val vecAssign = CL.mkAssign(CL.mkVar x,
	                  CL.mkApply(RN.mkVec n, List.tabulate (n, load)))
	            in
	              [ptrDcl, vecAssign]
	            end
	        | _ => raise Fail "bogus getImgData"
	      (* end case *))
	local
	(* wrap a runtime call that yields a status value.  `mkDecl` is given a fresh
	 * variable name and returns the statements that declare that variable; a test
	 * is appended that calls exit(1) when the variable differs from DIDEROT_OK.
	 *)
	  fun checkSts mkDecl = let
		val sts = Var.fresh "sts"
		in
		  mkDecl sts @
		  [CL.mkIfThen(
		    CL.mkBinOp(CL.mkVar "DIDEROT_OK", CL.#!=, CL.mkVar sts),
		    CL.mkCall("exit", [CL.mkInt(1, CL.int32)]))]
		end
	in
	(* load the image named by the expression `name` into the variable `lhs`, using
	 * the runtime load function for the given dimension; the address of `lhs` is
	 * passed as the second argument.
	 *)
	fun loadImage (lhs : var, dim, name : exp) = checkSts (fn sts => let
	      val loadFn = RN.loadImage dim
	      in [
		CL.S_Decl(
		  statusTy, sts,
		  SOME(CL.E_Apply(loadFn, [#1 name, CL.mkUnOp(CL.%&, CL.E_Var(#2 lhs))])))
	      ] end)
	(* read the input called `name` into the variable `lhs`.  When a default is
	 * given, it is assigned to `lhs` before the runtime input function is called,
	 * and the function's last argument tells it that a default is present.
	 *)
	fun input (lhs : var, name, optDflt) = checkSts (fn sts => let
	      val inputFn = RN.input(#1 lhs)
	      val lhs = CL.E_Var(#2 lhs)
	      val (initCode, hasDflt) = (case optDflt
		     of SOME(e, _) => ([CL.S_Assign(lhs, e)], true)
		      | NONE => ([], false)
		    (* end case *))
	      val code = [
		    CL.S_Decl(
		      statusTy, sts,
		      SOME(CL.E_Apply(inputFn, [
			  CL.E_Str name, CL.mkUnOp(CL.%&, lhs), CL.mkBool hasDflt
			])))
		    ]
	      in
		initCode @ code
	      end)
	end (* local *)
	local
	(* return the named status value from the enclosing function *)
	  fun returnStatus status = CL.mkReturn(SOME(CL.mkVar status))
	in
	fun exit () = CL.mkReturn NONE
	fun active () = returnStatus RN.kActive
	fun stabilize () = returnStatus RN.kStabilize
	fun die () = returnStatus RN.kDie
	end (* local *)
      end

    structure Strand =
      struct
      (* create a strand with the given name, with no state variables or code yet,
       * and add it to the program's strands.
       *)
        fun define (Prog{strands, ...}, strandId) = let
              val newStrand = Strand{
                      name = strandId, tyName = RN.strandTy strandId,
                      state = ref [], code = ref []
                    }
              val () = (strands := newStrand :: !strands)
              in
                newStrand
              end

      (* the type of a pointer to the strand's state struct *)
        fun selfPtrTy tyName = CL.T_Ptr(CL.T_Named tyName)

      (* register the strand-state initialization code.  The variables are the strand
       * parameters; they follow the "selfOut" pointer in the generated function's
       * parameter list.
       *)
        fun init (Strand{name, tyName, code, ...}, params, init) = let
              val fName = RN.strandInit name
              fun toParam (ty, x) = CL.PARAM([], cvtTy ty, x)
              val allParams = CL.PARAM([], selfPtrTy tyName, "selfOut") :: List.map toParam params
              in
                code := CL.D_Func([], CL.voidTy, fName, allParams, init) :: !code
              end

      (* register a strand method.  The generated function is named <strand>_<method>,
       * takes the "selfIn" and "selfOut" state pointers, and returns a 32-bit int.
       *)
        fun method (Strand{name, tyName, code, ...}, methName, body) = let
              val fName = concat[name, "_", methName]
              fun selfParam x = CL.PARAM([], selfPtrTy tyName, x)
              val methFn = CL.D_Func([], CL.int32, fName, [selfParam "selfIn", selfParam "selfOut"], body)
              in
                code := methFn :: !code
              end
      end (* Strand *)

    (* the declarations for one strand: the struct definition of its state (fields
     * in the order in which the state variables were registered), followed by its
     * functions in the order in which they were registered.
     *)
    fun genStrand (Strand{tyName, state, code, ...}) = let
          fun addField ((ty, x), fields) = (cvtTy ty, x) :: fields
          val fields = List.foldl addField [] (!state)
          in
            CL.D_StructDef(fields, tyName) :: List.rev (!code)
          end

  (* generate the table of strand descriptors.  Two declarations are printed to the
   * stream: the number of strands, and an array with one descriptor per strand (in
   * the order of the given list) holding the strand's name, the size of its state
   * struct, and pointers to its init function and its "update" method.
   *)
    fun genStrandTable (ppStrm, strands) = let
          val nStrands = List.length strands
          fun emit dcl = PrintAsC.output(ppStrm, dcl)
          fun fnPtr (ty, f) = CL.I_Exp(CL.mkCast(CL.T_Named ty, CL.mkVar f))
          fun descriptor (Strand{name, ...}) = CL.I_Struct[
                  ("name", CL.I_Exp(CL.E_Str name)),
                  ("stateSzb", CL.I_Exp(CL.mkSizeof(CL.T_Named(RN.strandTy name)))),
                  ("init", fnPtr("strand_init_t", RN.strandInit name)),
                  ("update", fnPtr("update_method_t", name ^ "_update"))
                ]
        (* pair each descriptor with its index in the table *)
          fun addIndexed (s, (i, acc)) = (i+1, (i, descriptor s) :: acc)
          val countDcl = CL.D_Var([], CL.int32, RN.numStrands,
                SOME(CL.I_Exp(CL.E_Int(IntInf.fromInt nStrands, CL.int32))))
          in
            emit countDcl;
            emit (let
              val tableTy = CL.T_Array(CL.T_Named RN.strandDescTy, SOME nStrands)
              val (_, revInits) = List.foldl addIndexed (0, []) strands
              in
                CL.D_Var([], tableTy, RN.strands, SOME(CL.I_Array(List.rev revInits)))
              end)
          end

    (* write the generated C program to the file <baseName>.c.  The globals and the
     * top-level declarations are printed in the order in which they were added,
     * followed by the code for each strand and then the strand table.  The output
     * stream is closed even when printing raises an exception, which is then re-raised.
     *)
    fun generate (baseName, Prog{globals, topDecls, strands}) = let
          val fileName = OS.Path.joinBaseExt{base=baseName, ext=SOME "c"}
          val outS = TextIO.openOut fileName
          fun emit () = let
                val ppStrm = PrintAsC.new outS
                fun ppDecl dcl = PrintAsC.output(ppStrm, dcl)
                in
                  List.app ppDecl (List.rev (!globals));
                  List.app ppDecl (List.rev (!topDecls));
(* what about the strands, etc? *)
                (* NOTE(review): unlike the lists above, the strands are not reversed, so
                 * they are emitted most-recently-defined first -- confirm this is intended.
                 *)
                  List.app (fn strand => List.app ppDecl (genStrand strand)) (!strands);
                  genStrandTable (ppStrm, !strands);
                  PrintAsC.close ppStrm
                end
          in
          (* do not leak the file descriptor if code generation fails part way *)
            (emit () handle ex => (TextIO.closeOut outS; raise ex));
            TextIO.closeOut outS
          end

  end

(* the C back end: the CodeGenFn functor applied to the C target defined above *)
structure CBackEnd = CodeGenFn(CTarget)

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0