Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] View of /branches/pure-cfg/src/compiler/c-target/c-target.sml
ViewVC logotype

View of /branches/pure-cfg/src/compiler/c-target/c-target.sml

Parent Directory Parent Directory | Revision Log Revision Log


Revision 1475 - (download) (annotate)
Wed Aug 31 18:46:42 2011 UTC (8 years, 3 months ago) by jhr
File size: 20991 byte(s)
  Tweaking runtime system names
(* c-target.sml
 *
 * COPYRIGHT (c) 2011 The Diderot Project (http://diderot-language.cs.uchicago.edu)
 * All rights reserved.
 *)

structure CTarget : TARGET =
  struct

    structure IL = TreeIL
    structure V = IL.Var
    structure Ty = IL.Ty
    structure CL = CLang
    structure N = CNames

  (* variable translation: how TreeIL variables are turned into C expressions.  The
   * environment maps each TreeIL variable to the typed C variable chosen for it.
   *)
    structure TrVar =
      struct
        type env = CL.typed_var TreeIL.Var.Map.map
      (* return the C name bound to x in env; raises Fail if x is not bound *)
        fun lookup (env, x) = (case V.Map.find (env, x)
               of SOME(CL.V(_, x')) => x'
                | NONE => raise Fail(concat["lookup(_, ", V.name x, ")"])
              (* end case *))
      (* translate a variable that occurs in an l-value context (i.e., as the target of an
       * assignment).  Strand-state variables are written through the "selfOut" pointer;
       * globals and locals are referenced directly.
       *)
        fun lvalueVar (env, x) = (case V.kind x
               of IL.VK_Global => CL.mkVar(lookup(env, x))
                | IL.VK_State _ => CL.mkIndirect(CL.mkVar "selfOut", lookup(env, x))
                | IL.VK_Local => CL.mkVar(lookup(env, x))
              (* end case *))
      (* translate a variable that occurs in an r-value context.  Strand-state variables
       * are read through the "selfIn" pointer; globals and locals are referenced directly.
       *)
        fun rvalueVar (env, x) = (case V.kind x
               of IL.VK_Global => CL.mkVar(lookup(env, x))
                | IL.VK_State _ => CL.mkIndirect(CL.mkVar "selfIn", lookup(env, x))
                | IL.VK_Local => CL.mkVar(lookup(env, x))
              (* end case *))
      end

    structure ToC = TreeToCFn (TrVar)

    type var = CL.typed_var
    type exp = CL.exp
    type stm = CL.stm

  (* per-strand code-generation state *)
    datatype strand = Strand of {
        name : string,
        tyName : string,
        state : var list ref,                   (* state variables; NOTE: in reverse order *)
        output : (Ty.ty * CL.var) option ref,   (* the strand's output variable (only one for now) *)
        code : CL.decl list ref                 (* strand functions; NOTE: in reverse order *)
      }

  (* per-program code-generation state *)
    datatype program = Prog of {
        name : string,                  (* stem of source file *)
        double : bool,                  (* true for double-precision support *)
        parallel : bool,                (* true for multithreaded (or multi-GPU) target *)
        debug : bool,                   (* true for debug support in executable *)
        globals : CL.decl list ref,     (* NOTE: in reverse order *)
        topDecls : CL.decl list ref,    (* NOTE: in reverse order *)
        strands : strand AtomTable.hash_table,
        initially : CL.decl ref
      }

    datatype env = ENV of {
        info : env_info,
        vMap : var V.Map.map,
        scope : scope
      }

    and env_info = INFO of {
        prog : program
      }

    and scope
      = NoScope
      | GlobalScope
      | InitiallyScope
      | StrandScope of TreeIL.var list  (* strand initialization *)
      | MethodScope of MethodName.name * TreeIL.var list  (* method body; vars are state variables *)

  (* the supported widths of vectors of reals on the target.  For the GNU vector extensions,
   * the supported sizes are powers of two, but float2 is broken.
   * NOTE: we should also consider the AVX vector hardware, which has 256-bit registers.
   *)
    fun vectorWidths () = if !N.doublePrecision
          then [2, 4, 8]
          else [4, 8]

  (* tests for whether various expression forms can appear inline *)
    fun inlineCons n = (n < 2)          (* vectors are inline, but not matrices *)
    val inlineMatrixExp = false         (* can matrix-valued expressions appear inline? *)

  (* TreeIL to target translations *)
    structure Tr =
      struct
      (* translate a block fragment, returning the extended environment and C statements *)
        fun fragment (ENV{info, vMap, scope}, blk) = let
              val (vMap, stms) = ToC.trFragment (vMap, blk)
              in
                (ENV{info=info, vMap=vMap, scope=scope}, stms)
              end
      (* assign each of args to the corresponding state variable, followed by stm.  On a
       * length mismatch, a diagnostic naming cxt is printed and Fail is raised.
       *)
        fun saveState cxt stateVars (env, args, stm) = (
              ListPair.foldrEq
                (fn (x, e, stms) => ToC.trAssign(env, x, e)@stms)
                  [stm]
                    (stateVars, args)
              ) handle ListPair.UnequalLengths => (
                print(concat["saveState ", cxt, ": length mismatch; ", Int.toString(List.length args), " args\n"]);
                raise Fail(concat["saveState ", cxt, ": length mismatch"]))
      (* translate a block; inside strand-initialization and method scopes the state
       * variables are saved by saveState, elsewhere the statement is emitted unchanged.
       *)
        fun block (ENV{vMap, scope, ...}, blk) = (case scope
               of StrandScope stateVars => ToC.trBlock (vMap, saveState "StrandScope" stateVars, blk)
                | MethodScope(_, stateVars) => ToC.trBlock (vMap, saveState "MethodScope" stateVars, blk)
                | _ => ToC.trBlock (vMap, fn (_, _, stm) => [stm], blk)
              (* end case *))
      (* translate an expression *)
        fun exp (ENV{vMap, ...}, e) = ToC.trExp(vMap, e)
      end

  (* variables *)
    structure Var =
      struct
        fun name (CL.V(_, name)) = name
      (* declare a global variable, recording its declaration in the program's globals *)
        fun global (Prog{globals, ...}, name, ty) = let
              val ty' = ToC.trType ty
              in
                globals := CL.D_Var([], ty', name, NONE) :: !globals;
                CL.V(ty', name)
              end
      (* the C variable for a TreeIL parameter *)
        fun param x = CL.V(ToC.trType(V.ty x), V.name x)
      (* declare a strand-state variable, recording it in the strand's state list *)
        fun state (Strand{state, ...}, x) = let
              val ty' = ToC.trType(V.ty x)
              val x' = CL.V(ty', V.name x)
              in
                state := x' :: !state;
                x'
              end
      end

  (* environments *)
    structure Env =
      struct
      (* create a new environment *)
        fun new prog = ENV{
                info=INFO{prog = prog},
                vMap = V.Map.empty,
                scope = NoScope
              }
      (* define the current translation context *)
        fun setScope scope (ENV{info, vMap, ...}) = ENV{info=info, vMap=vMap, scope=scope}
        val scopeGlobal = setScope GlobalScope
        val scopeInitially = setScope InitiallyScope
        fun scopeStrand (env, svars) = setScope (StrandScope svars) env
        fun scopeMethod (env, name, svars) = setScope (MethodScope(name, svars)) env
      (* bind a TreeIL variable to a target variable *)
        fun bind (ENV{info, vMap, scope}, x, x') = ENV{
                info = info,
                vMap = V.Map.insert(vMap, x, x'),
                scope = scope
              }
      end

  (* programs *)
    structure Program =
      struct
      (* create a new program.  This also initializes the target spec for the requested
       * precision (N.initTargetSpec) and seeds the globals with the ProgramName string and
       * the precision/target #defines plus the Diderot runtime header.
       *)
        fun new {name, double, parallel, debug} = (
              N.initTargetSpec double;
              Prog{
                  name = name,
                  double = double, parallel = parallel, debug = debug,
                  globals = ref [ (* NOTE: in reverse order! *)
                      CL.D_Var(["static"], CL.charPtr, "ProgramName",
                        SOME(CL.I_Exp(CL.mkStr name))),
                      CL.D_Verbatim[
                          if double
                            then "#define DIDEROT_DOUBLE_PRECISION"
                            else "#define DIDEROT_SINGLE_PRECISION",
                          if parallel
                            then "#define DIDEROT_TARGET_PARALLEL"
                            else "#define DIDEROT_TARGET_C",
                          "#include \"Diderot/diderot.h\""
                        ]
                    ],
                  topDecls = ref [],
                  strands = AtomTable.mkTable (16, Fail "strand table"),
                  initially = ref(CL.D_Comment["missing initially"])
                })
      (* register the code that is used to register command-line options for input variables *)
        fun inputs (Prog{topDecls, ...}, stm) = let
              val inputsFn = CL.D_Func(
                    [], CL.voidTy, N.registerOpts,
                    [CL.PARAM([], CL.T_Ptr(CL.T_Named N.optionsTy), "opts")],
                    stm)
              in
                topDecls := inputsFn :: !topDecls
              end
      (* register the global initialization part of a program, together with an (empty)
       * shutdown function.
       *)
        fun init (Prog{topDecls, ...}, init) = let
              val initFn = CL.D_Func([], CL.voidTy, N.initGlobals, [], init)
              val shutdownFn = CL.D_Func(
                    [], CL.voidTy, N.shutdown,
                    [CL.PARAM([], CL.T_Ptr(CL.T_Named N.worldTy), "wrld")],
                    CL.S_Block[])
              in
                topDecls := shutdownFn :: initFn :: !topDecls
              end
      (* create and register the initially function for a program.  Each element of iters
       * is (index variable, lo, hi); the generated loop nest runs each index from lo to hi
       * inclusive, creating one strand per iteration.  The generated function allocates
       * the world and returns it.
       *)
        fun initially {
              prog = Prog{initially, ...},
              isArray : bool,
              iterPrefix : stm list,
              iters : (var * exp * exp) list,
              createPrefix : stm list,
              strand : Atom.atom,
              args : exp list
            } = let
              val name = Atom.toString strand
              val nDims = List.length iters
              val worldTy = CL.T_Ptr(CL.T_Named N.worldTy)
              fun mapi f xs = let
                    fun mapf (_, []) = []
                      | mapf (i, x::xs) = f(i, x) :: mapf(i+1, xs)
                    in
                      mapf (0, xs)
                    end
            (* per-dimension lower bounds and extents (hi - lo + 1) of the iteration space *)
              val baseInit = mapi (fn (i, (_, e, _)) => (i, CL.I_Exp e)) iters
              val sizeInit = mapi
                    (fn (i, (CL.V(ty, _), lo, hi)) =>
                        (i, CL.I_Exp(CL.mkBinOp(CL.mkBinOp(hi, CL.#-, lo), CL.#+, CL.E_Int(1, ty))))
                    ) iters
            (* code to allocate the world and initial strands *)
              val wrld = "wrld"
              val allocCode = [
                      CL.mkComment["allocate initial block of strands"],
                      CL.mkDecl(CL.T_Array(CL.int32, SOME nDims), "base", SOME(CL.I_Array baseInit)),
                      CL.mkDecl(CL.T_Array(CL.uint32, SOME nDims), "size", SOME(CL.I_Array sizeInit)),
                      CL.mkDecl(worldTy, wrld,
                        SOME(CL.I_Exp(CL.E_Apply(N.allocInitially, [
                            CL.mkVar "ProgramName",
                            CL.mkUnOp(CL.%&, CL.E_Var(N.strandDesc name)),
                            CL.E_Bool isArray,
                            CL.E_Int(IntInf.fromInt nDims, CL.int32),
                            CL.E_Var "base",
                            CL.E_Var "size"
                          ]))))
                    ]
            (* create the loop nest for the initially iterations; the innermost body
             * initializes the strand at the current linear index and increments that index.
             *)
              val indexVar = "ix"
              val strandTy = CL.T_Ptr(CL.T_Named(N.strandTy name))
              fun mkLoopNest [] = CL.mkBlock(createPrefix @ [
                      CL.mkDecl(strandTy, "sp",
                        SOME(CL.I_Exp(
                          CL.E_Cast(strandTy,
                          CL.E_Apply(N.inState, [CL.E_Var wrld, CL.E_Var indexVar]))))),
                      CL.mkCall(N.strandInit name, CL.E_Var "sp" :: args),
                      CL.mkAssign(CL.E_Var indexVar, CL.mkBinOp(CL.E_Var indexVar, CL.#+, CL.E_Int(1, CL.uint32)))
                    ])
                | mkLoopNest ((CL.V(ty, param), lo, hi)::iters) = let
                    val body = mkLoopNest iters
                    in
                      CL.mkFor(
                        [(ty, param, lo)],
                        CL.mkBinOp(CL.E_Var param, CL.#<=, hi),
                        [CL.mkPostOp(CL.E_Var param, CL.^++)],
                        body)
                    end
              val iterCode = [
                      CL.mkComment["initially"],
                      CL.mkDecl(CL.uint32, indexVar, SOME(CL.I_Exp(CL.E_Int(0, CL.uint32)))),
                      mkLoopNest iters
                    ]
              val body = CL.mkBlock(
                    iterPrefix @
                    allocCode @
                    iterCode @
                    [CL.mkReturn(SOME(CL.E_Var wrld))])
              val initFn = CL.D_Func([], worldTy, N.initially, [], body)
              in
                initially := initFn
              end

      (***** OUTPUT *****)

      (* generate the declarations for a strand: its state struct, its registered functions
       * (in registration order), and its output, print, and descriptor definitions.
       * Raises Fail if the strand has no output variable or the output type is unsupported.
       *)
        fun genStrand (Strand{name, tyName, state, output, code}) = let
            (* the type declaration for the strand's state struct *)
              val selfTyDef = CL.D_StructDef(
                      List.rev (List.map (fn CL.V(ty, x) => (ty, x)) (!state)),
                      tyName)
            (* the type and access expression for the strand's output variable *)
              val (outTy, outState) = (case !output
                     of SOME(ty, x) => (ty, CL.mkIndirect(CL.mkVar "self", x))
                      | NONE => raise Fail "no output variable"
                    (* end case *))
            (* the print function *)
              val prFnName = concat[name, "_print"]
              val prFn = let
                    val params = [
                          CL.PARAM([], CL.T_Ptr(CL.T_Named "FILE"), "outS"),
                          CL.PARAM([], CL.T_Ptr(CL.T_Named tyName), "self")
                        ]
                    val prArgs = (case outTy
                           of Ty.IVecTy 1 => [CL.E_Str(!N.gIntFormat ^ "\n"), outState]
                            | Ty.IVecTy d => let
                                val fmt = CL.E_Str(
                                      String.concatWith " " (List.tabulate(d, fn _ => !N.gIntFormat))
                                      ^ "\n")
                                val args = List.tabulate (d, fn i => ToC.ivecIndex(outState, d, i))
                                in
                                  fmt :: args
                                end
                            | Ty.TensorTy[] => [CL.E_Str "%f\n", outState]
                            | Ty.TensorTy[d] => let
                                val fmt = CL.E_Str(
                                      String.concatWith " " (List.tabulate(d, fn _ => "%f"))
                                      ^ "\n")
                                val args = List.tabulate (d, fn i => ToC.vecIndex(outState, d, i))
                                in
                                  fmt :: args
                                end
                            | _ => raise Fail("genStrand: unsupported output type " ^ Ty.toString outTy)
                          (* end case *))
                    in
                      CL.D_Func(["static"], CL.voidTy, prFnName, params,
                        CL.mkCall("fprintf", CL.mkVar "outS" :: prArgs))
                    end
            (* the output function: copies the output variable into the buffer outS *)
              val outFnName = concat[name, "_output"]
              val outFn = let
                    val params = [
                          CL.PARAM([], CL.T_Ptr CL.voidTy, "outS"),
                          CL.PARAM([], CL.T_Ptr(CL.T_Named tyName), "self")
                        ]
                  (* get address of output variable *)
                    val outState = CL.mkUnOp(CL.%&, outState)
                    in
                      CL.D_Func(["static"], CL.voidTy, outFnName, params,
                        CL.mkCall("memcpy", [CL.mkVar "outS", outState, CL.mkSizeof(ToC.trType outTy)] ))
                    end
            (* the strand's descriptor object *)
              val descI = let
                    fun fnPtr (ty, f) = CL.I_Exp(CL.mkCast(CL.T_Named ty, CL.mkVar f))
                    val nrrdTy = NrrdTypes.toNrrdType outTy
                    in
                      CL.I_Struct[
                          ("name", CL.I_Exp(CL.mkStr name)),
                          ("stateSzb", CL.I_Exp(CL.mkSizeof(CL.T_Named(N.strandTy name)))),
                          ("outputSzb", CL.I_Exp(CL.mkSizeof(ToC.trType outTy))),
                          ("nrrdType", CL.I_Exp(CL.mkInt nrrdTy)),
                          ("update", fnPtr("update_method_t", name ^ "_update")),
                          ("stabilize", fnPtr("stabilize_method_t", name ^ "_stabilize")),
                          ("print", fnPtr("print_method_t", prFnName)),
                          ("output", fnPtr("output_method_t", outFnName))
                        ]
                    end
              val desc = CL.D_Var([], CL.T_Named N.strandDescTy, N.strandDesc name, SOME descI)
              in
                selfTyDef :: List.rev (desc :: prFn :: outFn :: !code)
              end

      (* generate the table of strand descriptors (the count N.numStrands and the array
       * N.strands of pointers to each strand's descriptor, in the order given).
       *)
        fun genStrandTable (ppStrm, strands) = let
              val nStrands = length strands
              fun genInit (Strand{name, ...}) = CL.I_Exp(CL.mkUnOp(CL.%&, CL.E_Var(N.strandDesc name)))
              fun genInits (_, []) = []
                | genInits (i, s::ss) = (i, genInit s) :: genInits(i+1, ss)
              fun ppDecl dcl = PrintAsC.output(ppStrm, dcl)
              in
                ppDecl (CL.D_Var([], CL.int32, N.numStrands,
                  SOME(CL.I_Exp(CL.E_Int(IntInf.fromInt nStrands, CL.int32)))));
                ppDecl (CL.D_Var([],
                  CL.T_Array(CL.T_Ptr(CL.T_Named N.strandDescTy), SOME nStrands),
                  N.strands,
                  SOME(CL.I_Array(genInits (0, strands)))))
              end

      (* write the generated C code for prog to the file "baseName.c".  If code generation
       * raises an exception (e.g., genStrand's Fail for a strand without an output
       * variable), the output file is closed before the exception is re-raised, so the
       * stream is not leaked.
       *)
        fun genSrc (baseName, prog) = let
              val Prog{globals, topDecls, strands, initially, ...} = prog
              val fileName = OS.Path.joinBaseExt{base=baseName, ext=SOME "c"}
              val outS = TextIO.openOut fileName
              fun emit () = let
                    val ppStrm = PrintAsC.new outS
                    fun ppDecl dcl = PrintAsC.output(ppStrm, dcl)
                    val strands = AtomTable.listItems strands
                    in
                      List.app ppDecl (List.rev (!globals));
                      List.app ppDecl (List.rev (!topDecls));
                      List.app (fn strand => List.app ppDecl (genStrand strand)) strands;
                      genStrandTable (ppStrm, strands);
                      ppDecl (!initially);
                      PrintAsC.close ppStrm
                    end
              in
                (emit () handle ex => (TextIO.closeOut outS; raise ex));
                TextIO.closeOut outS
              end

      (* output the code to a file.  The string is the basename of the file, the extension
       * is provided by the target.  The generated file is then compiled and linked
       * against the runtime library selected by the program's target options.
       *)
        fun generate (basename, prog as Prog{double, parallel, debug, ...}) = let
              fun condCons (true, x, xs) = x::xs
                | condCons (false, _, xs) = xs
            (* generate the C compiler flags *)
              val cflags = ["-I" ^ Paths.diderotInclude, "-I" ^ Paths.teemInclude]
              val cflags = condCons (parallel, #pthread Paths.cflags, cflags)
              val cflags = if debug
                    then #debug Paths.cflags :: cflags
                    else #ndebug Paths.cflags :: cflags
              val cflags = #base Paths.cflags :: cflags
            (* generate the loader flags *)
              val extraLibs = condCons (parallel, #pthread Paths.extraLibs, [])
              val extraLibs = Paths.teemLinkFlags @ #base Paths.extraLibs :: extraLibs
              val rtLib = TargetUtil.runtimeName {
                      target = TargetUtil.TARGET_C,
                      parallel = parallel, double = double, debug = debug
                    }
              val ldOpts = rtLib :: extraLibs
              in
                genSrc (basename, prog);
                RunCC.compile (basename, cflags);
                RunCC.link (basename, ldOpts)
              end

      end

  (* strands *)
    structure Strand =
      struct
      (* create a new strand and register it in the program's strand table *)
        fun define (Prog{strands, ...}, strandId) = let
              val name = Atom.toString strandId
              val strand = Strand{
                      name = name,
                      tyName = N.strandTy name,
                      state = ref [],
                      output = ref NONE,
                      code = ref []
                    }
              in
                AtomTable.insert strands (strandId, strand);
                strand
              end

      (* return the strand with the given name; if it is not defined, this raises the
       * table's not-found exception (Fail "strand table", see Program.new).
       *)
        fun lookup (Prog{strands, ...}, strandId) = AtomTable.lookup strands strandId

      (* register the strand-state initialization code.  The variables are the strand
       * parameters; the generated function also takes a "selfOut" pointer to the state.
       *)
        fun init (Strand{name, tyName, code, ...}, params, init) = let
              val fName = N.strandInit name
              val params =
                    CL.PARAM([], CL.T_Ptr(CL.T_Named tyName), "selfOut") ::
                      List.map (fn (CL.V(ty, x)) => CL.PARAM([], ty, x)) params
              val initFn = CL.D_Func([], CL.voidTy, fName, params, init)
              in
                code := initFn :: !code
              end

      (* register a strand method.  Methods take "selfIn" and "selfOut" state pointers;
       * update returns a StrandStatus_t, stabilize returns void.
       *)
        fun method (Strand{name, tyName, code, ...}, methName, body) = let
              val fName = concat[name, "_", MethodName.toString methName]
              val params = [
                      CL.PARAM([], CL.T_Ptr(CL.T_Named tyName), "selfIn"),
                      CL.PARAM([], CL.T_Ptr(CL.T_Named tyName), "selfOut")
                    ]
              val resTy = (case methName
                     of MethodName.Update => CL.T_Named "StrandStatus_t"
                      | MethodName.Stabilize => CL.voidTy
                    (* end case *))
              val methFn = CL.D_Func(["static"], resTy, fName, params, body)
              in
                code := methFn :: !code
              end

      (* record the strand's output variable (its type and C name) for genStrand *)
        fun output (Strand{output, ...}, ty, CL.V(_, x)) = output := SOME(ty, x)

      end

  end

(* the C back end: the result of applying the CodeGenFn functor to the C target above *)
structure CBackEnd = CodeGenFn(CTarget)

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0