(* c-target.sml
 *
 * COPYRIGHT (c) 2011 The Diderot Project (http://diderot-language.cs.uchicago.edu)
 * All rights reserved.
 *
 * The sequential/parallel C target: an implementation of the TARGET signature that
 * builds CLang declarations from TreeIL and prints them as a C source file, which is
 * then compiled and linked.
 *)

structure CTarget : TARGET =
  struct

    structure IL = TreeIL
    structure V = IL.Var
    structure Ty = IL.Ty
    structure CL = CLang
    structure N = CNames

  (* variable translation *)
    structure TrVar =
      struct
        type env = CL.typed_var TreeIL.Var.Map.map

      (* return the name of the target variable that x is bound to in env; raises Fail
       * if x has no binding.
       *)
        fun lookup (env, x) = (case V.Map.find (env, x)
               of SOME(CL.V(_, x')) => x'
                | NONE => raise Fail(concat["lookup(_, ", V.name x, ")"])
              (* end case *))

      (* translate a variable that occurs in an l-value context (i.e., as the target of
       * an assignment).  State variables are written through the "selfOut" pointer.
       *)
        fun lvalueVar (env, x) = (case V.kind x
               of IL.VK_Global => CL.mkVar(lookup(env, x))
                | IL.VK_State _ => CL.mkIndirect(CL.mkVar "selfOut", lookup(env, x))
                | IL.VK_Local => CL.mkVar(lookup(env, x))
              (* end case *))

      (* translate a variable that occurs in an r-value context.  State variables are
       * read through the "selfIn" pointer.
       *)
        fun rvalueVar (env, x) = (case V.kind x
               of IL.VK_Global => CL.mkVar(lookup(env, x))
                | IL.VK_State _ => CL.mkIndirect(CL.mkVar "selfIn", lookup(env, x))
                | IL.VK_Local => CL.mkVar(lookup(env, x))
              (* end case *))
      end

    structure ToC = TreeToCFn (TrVar)

    type var = CL.typed_var
    type exp = CL.exp
    type stm = CL.stm

    datatype strand = Strand of {
        name : string,
        tyName : string,
        state : var list ref,
        output : (Ty.ty * CL.var) option ref,   (* the strand's output variable (only one for now) *)
        code : CL.decl list ref
      }

    datatype program = Prog of {
        name : string,                  (* stem of source file *)
        double : bool,                  (* true for double-precision support *)
        parallel : bool,                (* true for multithreaded (or multi-GPU) target *)
        debug : bool,                   (* true for debug support in executable *)
        globals : CL.decl list ref,
        topDecls : CL.decl list ref,
        strands : strand AtomTable.hash_table,
        initially : CL.decl ref
      }

    datatype env = ENV of {
        info : env_info,
        vMap : var V.Map.map,
        scope : scope
      }

    and env_info = INFO of {
        prog : program
      }

    and scope
      = NoScope
      | GlobalScope
      | InitiallyScope
      | StrandScope of TreeIL.var list  (* strand initialization *)
      | MethodScope of TreeIL.var list  (* method body; vars are state variables *)

  (* the supported widths of vectors of reals on the target.  For the GNU vector extensions,
   * the supported sizes are powers of two, but float2 is broken.
   * NOTE: we should also consider the AVX vector hardware, which has 256-bit registers.
   *)
    fun vectorWidths () = if !N.doublePrecision
          then [2, 4, 8]
          else [4, 8]

  (* tests for whether various expression forms can appear inline *)
    fun inlineCons n = (n < 2)          (* vectors are inline, but not matrices *)
    val inlineMatrixExp = false         (* can matrix-valued expressions appear inline? *)

  (* TreeIL to target translations *)
    structure Tr =
      struct
      (* translate a code fragment; returns the environment extended with the variable
       * map produced by ToC.trFragment, paired with the translated statements.
       *)
        fun fragment (ENV{info, vMap, scope}, blk) = let
              val (vMap, stms) = ToC.trFragment (vMap, blk)
              in
                (ENV{info=info, vMap=vMap, scope=scope}, stms)
              end

      (* build the statements that assign args to the corresponding stateVars, followed
       * by stm.  The cxt string only identifies the caller in the diagnostic; a length
       * mismatch between stateVars and args is reported on stdout and raised as Fail.
       *)
        fun saveState cxt stateVars (env, args, stm) = (
              ListPair.foldrEq
                (fn (x, e, stms) => ToC.trAssign(env, x, e)@stms)
                  [stm]
                    (stateVars, args)
              ) handle ListPair.UnequalLengths => (
                print(concat["saveState ", cxt, ": length mismatch; ",
                  Int.toString(List.length args), " args\n"]);
                raise Fail(concat["saveState ", cxt, ": length mismatch"]))

      (* translate a block.  In strand and method scopes the state variables are saved
       * via saveState; in every other scope only the given statement is kept.
       *)
        fun block (ENV{vMap, scope, ...}, blk) = (case scope
               of StrandScope stateVars =>
                    ToC.trBlock (vMap, saveState "StrandScope" stateVars, blk)
                | MethodScope stateVars =>
                    ToC.trBlock (vMap, saveState "MethodScope" stateVars, blk)
                | _ => ToC.trBlock (vMap, fn (_, _, stm) => [stm], blk)
              (* end case *))

      (* translate an expression using the environment's variable map *)
        fun exp (ENV{vMap, ...}, e) = ToC.trExp(vMap, e)
      end

  (* variables *)
    structure Var =
      struct
      (* the name of a target variable *)
        fun name (CL.V(_, name)) = name

      (* declare a global variable of the given TreeIL type; the declaration is pushed
       * onto the program's globals list.
       *)
        fun global (Prog{globals, ...}, name, ty) = let
              val ty' = ToC.trType ty
              in
                globals := CL.D_Var([], ty', name, NONE) :: !globals;
                CL.V(ty', name)
              end

      (* the target variable for a TreeIL parameter (same name, translated type) *)
        fun param x = CL.V(ToC.trType(V.ty x), V.name x)

      (* declare a strand state variable; it is pushed onto the strand's state list *)
        fun state (Strand{state, ...}, x) = let
              val ty' = ToC.trType(V.ty x)
              val x' = CL.V(ty', V.name x)
              in
                state := x' :: !state;
                x'
              end
      end

  (* environments *)
    structure Env =
      struct
      (* create a new environment *)
        fun new prog = ENV{
                info = INFO{prog = prog},
                vMap = V.Map.empty,
                scope = NoScope
              }

      (* define the current translation context *)
        fun setScope scope (ENV{info, vMap, ...}) = ENV{info=info, vMap=vMap, scope=scope}
        val scopeGlobal = setScope GlobalScope
        val scopeInitially = setScope InitiallyScope
        fun scopeStrand (env, svars) = setScope (StrandScope svars) env
        fun scopeMethod (env, svars) = setScope (MethodScope svars) env

      (* bind a TreeIL variable to a target variable *)
        fun bind (ENV{info, vMap, scope}, x, x') = ENV{
                info = info,
                vMap = V.Map.insert(vMap, x, x'),
                scope = scope
              }
      end

  (* programs *)
    structure Program =
      struct
      (* create a new program object.  The target specification in CNames is initialized
       * from the precision flag before the program is built.
       *)
        fun new {name, double, parallel, debug} = (
              N.initTargetSpec double;
              Prog{
                  name = name,
                  double = double,
                  parallel = parallel,
                  debug = debug,
                  globals = ref [ (* NOTE: in reverse order! *)
                      CL.D_Var(["static"], CL.charPtr, "ProgramName",
                        SOME(CL.I_Exp(CL.mkStr name))),
                      CL.D_Verbatim[
                          if double
                            then "#define DIDEROT_DOUBLE_PRECISION"
                            else "#define DIDEROT_SINGLE_PRECISION",
                          if parallel
                            then "#define DIDEROT_TARGET_PARALLEL"
                            else "#define DIDEROT_TARGET_C",
                          "#include \"Diderot/diderot.h\""
                        ]
                    ],
                  topDecls = ref [],
                  strands = AtomTable.mkTable (16, Fail "strand table"),
                  initially = ref(CL.D_Comment["missing initially"])
                })

      (* register the code that is used to register command-line options for input variables *)
        fun inputs (Prog{topDecls, ...}, stm) = let
              val inputsFn = CL.D_Func(
                    [], CL.voidTy, N.registerOpts,
                    [CL.PARAM([], CL.T_Ptr(CL.T_Named N.optionsTy), "opts")],
                    stm)
              in
                topDecls := inputsFn :: !topDecls
              end

      (* register the global initialization part of a program.  An empty shutdown
       * function is registered along with it.
       *)
        fun init (Prog{topDecls, ...}, init) = let
              val initFn = CL.D_Func([], CL.voidTy, N.initGlobals, [], init)
              val shutdownFn = CL.D_Func(
                    [], CL.voidTy, N.shutdown,
                    [CL.PARAM([], CL.T_Ptr(CL.T_Named N.worldTy), "wrld")],
                    CL.S_Block[])
              in
                topDecls := shutdownFn :: initFn :: !topDecls
              end

      (* create and register the initially function for a program.  Each element of
       * iters is (index variable, lower bound, upper bound); the bounds are inclusive,
       * so the size of a dimension is hi - lo + 1.
       *)
        fun initially {
              prog = Prog{initially, ...},
              isArray : bool,
              iterPrefix : stm list,
              iters : (var * exp * exp) list,
              createPrefix : stm list,
              strand : Atom.atom,
              args : exp list
            } = let
              val name = Atom.toString strand
              val nDims = List.length iters
              val worldTy = CL.T_Ptr(CL.T_Named N.worldTy)
            (* map with the zero-based position of each element *)
              fun mapi f xs = let
                    fun mapf (_, []) = []
                      | mapf (i, x::xs) = f(i, x) :: mapf(i+1, xs)
                    in
                      mapf (0, xs)
                    end
              val baseInit = mapi (fn (i, (_, e, _)) => (i, CL.I_Exp e)) iters
              val sizeInit = mapi
                    (fn (i, (CL.V(ty, _), lo, hi)) =>
                      (i, CL.I_Exp(CL.mkBinOp(CL.mkBinOp(hi, CL.#-, lo), CL.#+, CL.E_Int(1, ty)))))
                      iters
            (* code to allocate the world and initial strands *)
              val wrld = "wrld"
              val allocCode = [
                      CL.mkComment["allocate initial block of strands"],
                      CL.mkDecl(CL.T_Array(CL.int32, SOME nDims), "base", SOME(CL.I_Array baseInit)),
                      CL.mkDecl(CL.T_Array(CL.uint32, SOME nDims), "size", SOME(CL.I_Array sizeInit)),
                      CL.mkDecl(worldTy, wrld,
                        SOME(CL.I_Exp(CL.E_Apply(N.allocInitially, [
                            CL.mkVar "ProgramName",
                            CL.mkUnOp(CL.%&, CL.E_Var(N.strandDesc name)),
                            CL.E_Bool isArray,
                            CL.E_Int(IntInf.fromInt nDims, CL.int32),
                            CL.E_Var "base",
                            CL.E_Var "size"
                          ]))))
                    ]
            (* create the loop nest for the initially iterations *)
              val indexVar = "ix"
              val strandTy = CL.T_Ptr(CL.T_Named(N.strandTy name))
              fun mkLoopNest [] = CL.mkBlock(createPrefix @ [
                      CL.mkDecl(strandTy, "sp",
                        SOME(CL.I_Exp(
                          CL.E_Cast(strandTy,
                          CL.E_Apply(N.inState, [CL.E_Var wrld, CL.E_Var indexVar]))))),
                      CL.mkCall(N.strandInit name, CL.E_Var "sp" :: args),
                      CL.mkAssign(CL.E_Var indexVar,
                        CL.mkBinOp(CL.E_Var indexVar, CL.#+, CL.E_Int(1, CL.uint32)))
                    ])
                | mkLoopNest ((CL.V(ty, param), lo, hi)::iters) = let
                    val body = mkLoopNest iters
                    in
                      CL.mkFor(
                        [(ty, param, lo)],
                        CL.mkBinOp(CL.E_Var param, CL.#<=, hi),
                        [CL.mkPostOp(CL.E_Var param, CL.^++)],
                        body)
                    end
              val iterCode = [
                      CL.mkComment["initially"],
                      CL.mkDecl(CL.uint32, indexVar, SOME(CL.I_Exp(CL.E_Int(0, CL.uint32)))),
                      mkLoopNest iters
                    ]
              val body = CL.mkBlock(
                    iterPrefix @
                    allocCode @
                    iterCode @
                    [CL.mkReturn(SOME(CL.E_Var wrld))])
              val initFn = CL.D_Func([], worldTy, N.initially, [], body)
              in
                initially := initFn
              end

      (***** OUTPUT *****)

      (* generate the declarations for a strand: the state struct definition, then the
       * registered code in registration order, then the print function, and finally the
       * strand descriptor.  Raises Fail if no output variable has been registered for
       * the strand or if its type is not a supported output type.
       *)
        fun genStrand (Strand{name, tyName, state, output, code}) = let
            (* the strand's output variable; a missing output is an internal error, which
             * is reported explicitly rather than as an anonymous Bind exception.
             *)
              val (outTy, outVar) = (case !output
                     of SOME out => out
                      | NONE => raise Fail(concat[
                            "genStrand: strand ", name, " has no output variable"
                          ])
                    (* end case *))
            (* the type declaration for the strand's state struct; the state list is held
             * in reverse order of declaration.
             *)
              val selfTyDef = CL.D_StructDef(
                      List.rev (List.map (fn CL.V(ty, x) => (ty, x)) (!state)),
                      tyName)
            (* the print function *)
              val prFnName = concat[name, "_print"]
              val prFn = let
                    val params = [
                            CL.PARAM([], CL.T_Ptr(CL.T_Named "FILE"), "outS"),
                            CL.PARAM([], CL.T_Ptr(CL.T_Named tyName), "self")
                          ]
                    val outState = CL.mkIndirect(CL.mkVar "self", outVar)
                  (* the fprintf format string followed by the values to print *)
                    val prArgs = (case outTy
                           of Ty.IVecTy 1 => [CL.E_Str(!N.gIntFormat ^ "\n"), outState]
                            | Ty.IVecTy d => let
                                val fmt = CL.E_Str(
                                      String.concatWith " " (List.tabulate(d, fn _ => !N.gIntFormat))
                                      ^ "\n")
                                val args = List.tabulate (d, fn i => ToC.ivecIndex(outState, d, i))
                                in
                                  fmt :: args
                                end
                            | Ty.TensorTy[] => [CL.E_Str "%f\n", outState]
                            | Ty.TensorTy[d] => let
                                val fmt = CL.E_Str(
                                      String.concatWith " " (List.tabulate(d, fn _ => "%f"))
                                      ^ "\n")
                                val args = List.tabulate (d, fn i => ToC.vecIndex(outState, d, i))
                                in
                                  fmt :: args
                                end
                            | _ => raise Fail("genStrand: unsupported output type " ^ Ty.toString outTy)
                          (* end case *))
                    in
                      CL.D_Func(["static"], CL.voidTy, prFnName, params,
                        CL.mkCall("fprintf", CL.mkVar "outS" :: prArgs))
                    end
            (* the strand's descriptor object *)
              val descI = let
                    fun fnPtr (ty, f) = CL.I_Exp(CL.mkCast(CL.T_Named ty, CL.mkVar f))
                    in
                      CL.I_Struct[
                          ("name", CL.I_Exp(CL.mkStr name)),
                          ("stateSzb", CL.I_Exp(CL.mkSizeof(CL.T_Named(N.strandTy name)))),
(*
                          ("outputSzb", CL.I_Exp(CL.mkSizeof(ToC.trTy outTy))),
*)
                          ("update", fnPtr("update_method_t", name ^ "_update")),
                          ("print", fnPtr("print_method_t", prFnName))
                        ]
                    end
              val desc = CL.D_Var([], CL.T_Named N.strandDescTy, N.strandDesc name, SOME descI)
              in
                selfTyDef :: List.rev (desc :: prFn :: !code)
              end

      (* generate the table of strand descriptors: the strand count followed by an array
       * of pointers to the descriptors.
       *)
        fun genStrandTable (ppStrm, strands) = let
              val nStrands = length strands
              fun genInit (Strand{name, ...}) = CL.I_Exp(CL.mkUnOp(CL.%&, CL.E_Var(N.strandDesc name)))
              fun genInits (_, []) = []
                | genInits (i, s::ss) = (i, genInit s) :: genInits(i+1, ss)
              fun ppDecl dcl = PrintAsC.output(ppStrm, dcl)
              in
                ppDecl (CL.D_Var([], CL.int32, N.numStrands,
                  SOME(CL.I_Exp(CL.E_Int(IntInf.fromInt nStrands, CL.int32)))));
                ppDecl (CL.D_Var([],
                  CL.T_Array(CL.T_Ptr(CL.T_Named N.strandDescTy), SOME nStrands),
                  N.strands,
                  SOME(CL.I_Array(genInits (0, strands)))))
              end

      (* print the program to the file "<baseName>.c": globals, top-level declarations,
       * the code for each strand, the strand table, and the initially function.
       *)
        fun genSrc (baseName, prog) = let
              val Prog{globals, topDecls, strands, initially, ...} = prog
              val fileName = OS.Path.joinBaseExt{base=baseName, ext=SOME "c"}
              val outS = TextIO.openOut fileName
              val ppStrm = PrintAsC.new outS
              fun ppDecl dcl = PrintAsC.output(ppStrm, dcl)
              val strands = AtomTable.listItems strands
              in
              (* if code generation fails part way through, do not leak the open output
               * stream: close it and then propagate the original exception.
               *)
                (
                  List.app ppDecl (List.rev (!globals));
                  List.app ppDecl (List.rev (!topDecls));
                  List.app (fn strand => List.app ppDecl (genStrand strand)) strands;
                  genStrandTable (ppStrm, strands);
                  ppDecl (!initially)
                ) handle ex => (TextIO.closeOut outS; raise ex);
                PrintAsC.close ppStrm;
                TextIO.closeOut outS
              end

      (* output the code to a file.  The string is the basename of the file, the extension
       * is provided by the target.  The generated source is then compiled and linked.
       *)
        fun generate (basename, prog as Prog{double, parallel, debug, ...}) = let
              fun condCons (true, x, xs) = x::xs
                | condCons (false, _, xs) = xs
            (* generate the C compiler flags *)
              val cflags = ["-I" ^ Paths.diderotInclude, "-I" ^ Paths.teemInclude]
              val cflags = condCons (parallel, #pthread Paths.cflags, cflags)
              val cflags = if debug
                    then #debug Paths.cflags :: cflags
                    else #ndebug Paths.cflags :: cflags
              val cflags = #base Paths.cflags :: cflags
            (* generate the loader flags *)
              val extraLibs = condCons (parallel, #pthread Paths.extraLibs, [])
              val extraLibs = Paths.teemLinkFlags @ #base Paths.extraLibs :: extraLibs
              val rtLib = TargetUtil.runtimeName {
                      target = TargetUtil.TARGET_C,
                      parallel = parallel, double = double, debug = debug
                    }
              val ldOpts = rtLib :: extraLibs
              in
                genSrc (basename, prog);
                RunCC.compile (basename, cflags);
                RunCC.link (basename, ldOpts)
              end

      end

  (* strands *)
    structure Strand =
      struct
      (* create a new, empty strand and add it to the program's strand table *)
        fun define (Prog{strands, ...}, strandId) = let
              val name = Atom.toString strandId
              val strand = Strand{
                      name = name,
                      tyName = N.strandTy name,
                      state = ref [],
                      output = ref NONE,
                      code = ref []
                    }
              in
                AtomTable.insert strands (strandId, strand);
                strand
              end

      (* return the strand with the given name *)
        fun lookup (Prog{strands, ...}, strandId) = AtomTable.lookup strands strandId

      (* register the strand-state initialization code.  The variables are the strand
       * parameters; the generated function takes the "selfOut" state pointer as its
       * first parameter.
       *)
        fun init (Strand{name, tyName, code, ...}, params, init) = let
              val fName = N.strandInit name
              val params =
                    CL.PARAM([], CL.T_Ptr(CL.T_Named tyName), "selfOut") ::
                      List.map (fn (CL.V(ty, x)) => CL.PARAM([], ty, x)) params
              val initFn = CL.D_Func([], CL.voidTy, fName, params, init)
              in
                code := initFn :: !code
              end

      (* register a strand method; the generated static function is named
       * "<strand>_<method>" and takes the "selfIn" and "selfOut" state pointers.
       *)
        fun method (Strand{name, tyName, code, ...}, methName, body) = let
              val fName = concat[name, "_", methName]
              val params = [
                      CL.PARAM([], CL.T_Ptr(CL.T_Named tyName), "selfIn"),
                      CL.PARAM([], CL.T_Ptr(CL.T_Named tyName), "selfOut")
                    ]
              val methFn = CL.D_Func(["static"], CL.int32, fName, params, body)
              in
                code := methFn :: !code
              end

      (* record the strand's output variable and its TreeIL type *)
        fun output (Strand{output, ...}, ty, CL.V(_, x)) = output := SOME(ty, x)

      end

  end

structure CBackEnd = CodeGenFn(CTarget)