Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] View of /branches/pure-cfg/src/compiler/c-target/c-target.sml
ViewVC logotype

View of /branches/pure-cfg/src/compiler/c-target/c-target.sml

Parent Directory Parent Directory | Revision Log Revision Log


Revision 987 - (download) (annotate)
Tue Apr 26 21:43:17 2011 UTC (8 years, 2 months ago) by jhr
File size: 12026 byte(s)
  Split code to compile and link C code into its own module.
(* c-target.sml
 *
 * COPYRIGHT (c) 2011 The Diderot Project (http://diderot-language.cs.uchicago.edu)
 * All rights reserved.
 *)

structure CTarget : TARGET =
  struct

    (* short names for the modules used by this target *)
    structure IL = TreeIL
    structure V = IL.Var
    structure Ty = IL.Ty
    structure CL = CLang
    structure RN = RuntimeNames
    structure ToC = TreeToC

    (* target-level representations of variables, expressions and statements;
     * variables come from the TreeIL-to-C translator and the rest from the C AST.
     *)
    type var = ToC.var
    type exp = CL.exp
    type stm = CL.stm

    (* the information collected about one strand while its code is being generated *)
    datatype strand = Strand of {
	name : string,				(* the strand's name (string form of its atom) *)
	tyName : string,			(* name of the C struct type that holds the strand's state *)
	state : var list ref,			(* state variables; newest first (reversed when emitted) *)
	output : (Ty.ty * CL.var) option ref,	(* the strand's output variable (only one for now) *)
	code : CL.decl list ref			(* the strand's functions; newest first (reversed when emitted) *)
      }

    (* the mutable pieces of a program that are accumulated and then printed by genSrc *)
    datatype program = Prog of {
	globals : CL.decl list ref,		(* global declarations; newest first (reversed when emitted) *)
	topDecls : CL.decl list ref,		(* other top-level declarations; newest first *)
	strands : strand AtomTable.hash_table,	(* the defined strands, keyed by strand name *)
	initially : CL.decl ref			(* the initially function; a placeholder comment until it is set *)
      }

    (* the translation environment that is threaded through code generation *)
    datatype env = ENV of {
	info : env_info,			(* per-program information *)
	vMap : var V.Map.map,			(* maps TreeIL variables to target variables *)
	scope : scope				(* the current translation context *)
      }

    (* the part of the environment that Env.setScope and Env.bind copy unchanged *)
    and env_info = INFO of {
	prog : program
      }

    (* the kind of code that is currently being translated *)
    and scope
      = NoScope				(* the scope of a freshly created environment *)
      | GlobalScope
      | InitiallyScope
      | StrandScope
      | MethodScope of TreeIL.var list	(* method body; vars are state variables *)

  (* The supported widths of vectors of reals on the target.  For the GNU vector
   * extensions, the supported sizes are powers of two, but float2 is broken, so
   * width 2 is only offered when compiling for double precision.
   * NOTE: we should also consider the AVX vector hardware, which has 256-bit registers.
   *)
    fun vectorWidths () = (case !Controls.doublePrecision
	   of true => [2, 4, 8]
	    | false => [4, 8]
	  (* end case *))

  (* predicates that say which expression forms may appear inline *)
    fun inlineCons n = (n <= 1)		(* true below 2: vectors are inline, but not matrices *)
    val inlineMatrixExp = false		(* matrix-valued expressions may not appear inline *)

  (* translation of TreeIL blocks and expressions to the target representation *)
    structure Tr =
      struct
      (* translate a block with ToC.trBlock.  In a method scope the translator is given
       * a state-saving function; in any other scope a state save is an error.
       *)
	fun block (ENV{vMap, scope, ...}, blk) = let
	    (* used outside of method bodies, where no state save should occur *)
	      fun noSave _ = raise Fail "unexpected state save"
	    (* the assignments of args to the state variables, followed by stm;
	     * ListPair.foldrEq raises UnequalLengths if the two lists differ in length.
	     *)
	      fun saveTo stateVars (env, args, stm) = let
		    fun assign (x, e, rest) = ToC.trAssign(env, x, e) @ rest
		    in
		      ListPair.foldrEq assign [stm] (stateVars, args)
		    end
	      val saveState = (case scope
		     of MethodScope stateVars => saveTo stateVars
		      | _ => noSave
		    (* end case *))
	      in
		ToC.trBlock (vMap, saveState, blk)
	      end
      (* translate an expression using the environment's variable map *)
	fun exp (ENV{vMap = varMap, ...}, e) = ToC.trExp(varMap, e)
      end

  (* creation of target variables from TreeIL variables *)
    structure Var =
      struct
      (* declare a global (without an initializer) in the program and return its target variable *)
	fun global (Prog{globals, ...}, x) = let
	      val (cName, cTy) = (V.name x, ToC.trType(V.ty x))
	      val dcl = CL.D_Var([], cTy, cName, NONE)
	      in
		globals := dcl :: !globals;
		ToC.V(cTy, cName)
	      end
      (* the target variable for a parameter; nothing is recorded *)
	fun param x = let
	      val cTy = ToC.trType(V.ty x)
	      in
		ToC.V(cTy, V.name x)
	      end
      (* add a state variable to the strand and return its target variable *)
	fun state (Strand{state = stateVars, ...}, x) = let
	      val cVar = ToC.V(ToC.trType(V.ty x), V.name x)
	      in
		stateVars := cVar :: !stateVars;
		cVar
	      end
      end

  (* operations on translation environments *)
    structure Env =
      struct
      (* a fresh environment for a program: no variable bindings and no scope *)
	fun new prog = let
	      val info = INFO{prog = prog}
	      in
		ENV{info = info, vMap = V.Map.empty, scope = NoScope}
	      end
      (* replace the translation context, keeping everything else *)
	fun setScope scope (ENV{info, vMap, ...}) = ENV{scope = scope, info = info, vMap = vMap}
	fun scopeGlobal env = setScope GlobalScope env
	fun scopeInitially env = setScope InitiallyScope env
      (* the strand argument is currently ignored *)
	fun scopeStrand (env, strand) = setScope StrandScope env
      (* svars are the state variables of the strand that owns the method *)
	fun scopeMethod (env, svars) = setScope (MethodScope svars) env
      (* bind a TreeIL variable to a target variable *)
	fun bind (ENV{info, vMap, scope}, x, x') = let
	      val vMap' = V.Map.insert(vMap, x, x')
	      in
		ENV{info = info, vMap = vMap', scope = scope}
	      end
      end

  (* programs *)
    structure Program =
      struct
	fun new () = (
	      RN.initTargetSpec();
	      Prog{
		  globals = ref [
		    CL.D_Verbatim[
			if !Controls.doublePrecision
			  then "#define DIDEROT_DOUBLE_PRECISION"
			  else "#define DIDEROT_SINGLE_PRECISION",
			"#include \"Diderot/diderot.h\""
		      ]],
		  topDecls = ref [],
		  strands = AtomTable.mkTable (16, Fail "strand table"),
		  initially = ref(CL.D_Comment["missing initially"])
		})
      (* register the global-initialization code of a program: the body is wrapped in a
       * void, zero-argument function that is added to the top-level declarations.
       *)
	fun init (Prog{topDecls, ...}, init) =
	      topDecls := CL.D_Func([], CL.voidTy, RN.initGlobals, [], init) :: !topDecls
      (* create and register the initially function for a program.  The generated function
       * runs the iteration prefix, allocates the world from per-dimension "base" and "size"
       * arrays, initializes one strand per point of the iteration space, and returns the
       * world.  Each iterator is a triple (variable, low bound, high bound) with an
       * inclusive high bound.
       *)
	fun initially {
	      prog = Prog{initially = initiallyRef, ...},
	      isArray : bool,
	      iterPrefix : stm,
	      iters : (var * exp * exp) list,
	      createPrefix : stm,
	      strand : Atom.atom,
	      args : exp list
	    } = let
	    (* view a statement as a list of statements, unwrapping a block *)
	      fun stmsOf (CL.S_Block stms) = stms
		| stmsOf stm = [stm]
	      val strandName = Atom.toString strand
	      val nDims = List.length iters
	      val worldTy = CL.T_Ptr(CL.T_Named RN.worldTy)
	      val strandTy = CL.T_Ptr(CL.T_Named(RN.strandTy strandName))
	      val wrld = CL.E_Var "wrld"
	      val ix = CL.E_Var "ix"
	    (* an array initializer with one indexed entry per iterator *)
	      fun perIter f = let
		    fun lp (_, [], acc) = List.rev acc
		      | lp (i, iter::rest, acc) = lp (i+1, rest, (i, CL.I_Exp(f iter)) :: acc)
		    in
		      CL.I_Array(lp (0, iters, []))
		    end
	    (* base[i] is the low bound of dimension i; size[i] is hi - lo + 1 *)
	      val baseInit = perIter (fn (_, lo, _) => lo)
	      val sizeInit = perIter (fn (ToC.V(ty, _), lo, hi) =>
		    CL.mkBinOp(CL.mkBinOp(hi, CL.#-, lo), CL.#+, CL.E_Int(1, ty)))
	      val allocWorld = CL.E_Apply(RN.allocInitially, [
		      CL.mkUnOp(CL.%&, CL.E_Var(RN.strandDesc strandName)),
		      CL.E_Bool isArray,
		      CL.E_Int(IntInf.fromInt nDims, CL.int32),
		      CL.E_Var "base",
		      CL.E_Var "size"
		    ])
	      val allocCode = [
		      CL.mkComment["allocate initial block of strands"],
		      CL.mkDecl(CL.T_Array(CL.int32, SOME nDims), "base", SOME baseInit),
		      CL.mkDecl(CL.T_Array(CL.uint32, SOME nDims), "size", SOME sizeInit),
		      CL.mkDecl(worldTy, "wrld", SOME(CL.I_Exp allocWorld))
		    ]
	    (* the innermost body: initialize the strand at index ix and advance ix *)
	      val createStrand = CL.mkBlock(stmsOf createPrefix @ [
		      CL.mkDecl(strandTy, "sp",
			SOME(CL.I_Exp(CL.E_Cast(strandTy, CL.E_Apply(RN.inState, [wrld, ix]))))),
		      CL.mkCall(RN.strandInit strandName, CL.E_Var "sp" :: args),
		      CL.mkAssign(ix, CL.mkBinOp(ix, CL.#+, CL.E_Int(1, CL.uint32)))
		    ])
	    (* wrap a body in the for loop of one iterator *)
	      fun wrapLoop ((ToC.V(ty, param), lo, hi), body) = CL.mkFor(
		    [(ty, param, lo)],
		    CL.mkBinOp(CL.E_Var param, CL.#<=, hi),
		    [CL.mkPostOp(CL.E_Var param, CL.^++)],
		    body)
	    (* folding from the right makes the first iterator the outermost loop *)
	      val iterCode = [
		      CL.mkComment["initially"],
		      CL.mkDecl(CL.uint32, "ix", SOME(CL.I_Exp(CL.E_Int(0, CL.uint32)))),
		      List.foldr wrapLoop createStrand iters
		    ]
	      val body = CL.mkBlock(List.concat[
		      stmsOf iterPrefix, allocCode, iterCode, [CL.mkReturn(SOME wrld)]
		    ])
	      in
		initiallyRef := CL.D_Func([], worldTy, RN.initially, [], body)
	      end

      (***** OUTPUT *****)
	fun genStrand (Strand{name, tyName, state, output, code}) = let
	    (* the type declaration for the strand's state struct *)
	      val selfTyDef = CL.D_StructDef(
		      List.rev (List.map (fn ToC.V(ty, x) => (ty, x)) (!state)),
		      tyName)
	    (* the print function *)
	      val prFnName = concat[name, "_print"]
	      val prFn = let
		    val params = [
			  CL.PARAM([], CL.T_Ptr(CL.T_Named "FILE"), "outS"),
			  CL.PARAM([], CL.T_Ptr(CL.T_Named tyName), "self")
			]
		    val SOME(ty, x) = !output
		    val outState = CL.mkIndirect(CL.mkVar "self", x)
		    val prArgs = (case ty
			   of Ty.IVecTy 1 => [CL.E_Str(!RN.gIntFormat ^ "\n"), outState]
			    | Ty.IVecTy d => let
				val fmt = CL.E_Str(
				      String.concatWith " " (List.tabulate(d, fn _ => !RN.gIntFormat))
				      ^ "\n")
				val args = List.tabulate (d, fn i => ToC.ivecIndex(outState, d, i))
				in
				  fmt :: args
				end
			    | Ty.TensorTy[] => [CL.E_Str "%f\n", outState]
			    | Ty.TensorTy[d] => let
				val fmt = CL.E_Str(
				      String.concatWith " " (List.tabulate(d, fn _ => "%f"))
				      ^ "\n")
				val args = List.tabulate (d, fn i => ToC.vecIndex(outState, d, i))
				in
				  fmt :: args
				end
			    | _ => raise Fail("genStrand: unsupported output type " ^ Ty.toString ty)
			  (* end case *))
		    in
		      CL.D_Func(["static"], CL.voidTy, prFnName, params,
			CL.mkCall("fprintf", CL.mkVar "outS" :: prArgs))
		    end
	    (* the strand's descriptor object *)
	      val descI = let
		    fun fnPtr (ty, f) = CL.I_Exp(CL.mkCast(CL.T_Named ty, CL.mkVar f))
		    in
		      CL.I_Struct[
			  ("name", CL.I_Exp(CL.E_Str name)),
			  ("stateSzb", CL.I_Exp(CL.mkSizeof(CL.T_Named(RN.strandTy name)))),
			  ("update", fnPtr("update_method_t", name ^ "_update")),
			  ("print", fnPtr("print_method_t", prFnName))
			]
		    end
	      val desc = CL.D_Var([], CL.T_Named RN.strandDescTy, RN.strandDesc name, SOME descI)
	      in
		selfTyDef :: List.rev (desc :: prFn :: !code)
	      end

      (* print the strand count and the table of pointers to the strand descriptors *)
	fun genStrandTable (ppStrm, strands) = let
	      val nStrands = List.length strands
	    (* the address of a strand's descriptor, as an initializer *)
	      fun descAddr (Strand{name, ...}) =
		    CL.I_Exp(CL.mkUnOp(CL.%&, CL.E_Var(RN.strandDesc name)))
	    (* the table entries paired with their positions 0, 1, ... *)
	      val entries = ListPair.zip (
		    List.tabulate (nStrands, fn i => i),
		    List.map descAddr strands)
	      val countDcl = CL.D_Var([], CL.int32, RN.numStrands,
		    SOME(CL.I_Exp(CL.E_Int(IntInf.fromInt nStrands, CL.int32))))
	      val tableDcl = CL.D_Var([],
		    CL.T_Array(CL.T_Ptr(CL.T_Named RN.strandDescTy), SOME nStrands),
		    RN.strands,
		    SOME(CL.I_Array entries))
	      in
		List.app (fn dcl => PrintAsC.output(ppStrm, dcl)) [countDcl, tableDcl]
	      end

	fun genSrc (baseName, Prog{globals, topDecls, strands, initially}) = let
	      val fileName = OS.Path.joinBaseExt{base=baseName, ext=SOME "c"}
	      val outS = TextIO.openOut fileName
	      val ppStrm = PrintAsC.new outS
	      fun ppDecl dcl = PrintAsC.output(ppStrm, dcl)
	      val strands = AtomTable.listItems strands
	      in
		List.app ppDecl (List.rev (!globals));
		List.app ppDecl (List.rev (!topDecls));
		List.app (fn strand => List.app ppDecl (genStrand strand)) strands;
		genStrandTable (ppStrm, strands);
		ppDecl (!initially);
		PrintAsC.close ppStrm;
		TextIO.closeOut outS
	      end

      (* output the code to a file, then compile and link it.  The string is the basename
       * of the file; the extension is provided by the target.
       *)
	fun generate (baseName, prog) = let
	      val () = genSrc (baseName, prog)
	      val _ = RunCC.compile baseName
	      in
		RunCC.link baseName
	      end

      end

  (* operations on strands *)
    structure Strand =
      struct
      (* create a strand with no state, output or code, and enter it into the program's strand table *)
	fun define (Prog{strands, ...}, strandId) = let
	      val strandName = Atom.toString strandId
	      val newStrand = Strand{
		      name = strandName,
		      tyName = RN.strandTy strandName,
		      state = ref [],
		      output = ref NONE,
		      code = ref []
		    }
	      val () = AtomTable.insert strands (strandId, newStrand)
	      in
		newStrand
	      end

      (* return the strand with the given name *)
	fun lookup (Prog{strands = tbl, ...}, strandId) = AtomTable.lookup tbl strandId

      (* register the strand-state initialization code.  The variables are the strand
       * parameters; the generated function takes the "selfOut" pointer to the strand
       * state as an extra first parameter.
       *)
	fun init (Strand{name, tyName, code, ...}, params, init) = let
	      fun toParam (ToC.V(ty, x)) = CL.PARAM([], ty, x)
	      val selfParam = CL.PARAM([], CL.T_Ptr(CL.T_Named tyName), "selfOut")
	      val initFn = CL.D_Func(
		    [], CL.voidTy, RN.strandInit name,
		    selfParam :: List.map toParam params,
		    init)
	      in
		code := initFn :: !code
	      end

      (* register a strand method as a static function named <strand>_<method> that takes
       * the "selfIn" and "selfOut" pointers to the strand state.
       *)
	fun method (Strand{name, tyName, code, ...}, methName, body) = let
	      val selfPtrTy = CL.T_Ptr(CL.T_Named tyName)
	      fun selfParam x = CL.PARAM([], selfPtrTy, x)
	      val methFn = CL.D_Func(
		    ["static"], CL.int32, name ^ "_" ^ methName,
		    [selfParam "selfIn", selfParam "selfOut"],
		    body)
	      in
		code := methFn :: !code
	      end

      (* record the strand's output variable, replacing any earlier one *)
	fun output (Strand{output = outRef, ...}, ty, ToC.V(_, x)) = (outRef := SOME(ty, x))

      end

  end

(* the C back end: the CodeGenFn functor applied to the CTarget structure *)
structure CBackEnd = CodeGenFn(CTarget)

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0