SCM Repository
View of /branches/pure-cfg/src/compiler/c-target/c-target.sml
Parent Directory
|
Revision Log
Revision 623 -
(download)
(annotate)
Tue Mar 15 17:04:53 2011 UTC (9 years, 10 months ago) by jhr
File size: 18234 byte(s)
Working on generating initially code. Also changed var and exp types in CTarget to be datatypes for better typechecking and documentation.
(* c-target.sml
 *
 * COPYRIGHT (c) 2011 The Diderot Project (http://diderot-language.cs.uchicago.edu)
 * All rights reserved.
 *
 * Generate C code with SSE 4.2 intrinsics.
 *)

structure CTarget : TARGET =
  struct

    structure CL = CLang
    structure RN = RuntimeNames

    datatype ty = datatype TargetTy.ty

  (* a target variable: its target type paired with its C name *)
    datatype var = V of (ty * string)

  (* a target expression: a CLang expression paired with its target type *)
    datatype exp = E of CL.exp * ty

    type stm = CL.stm

  (* per-strand code-generation state.  The `state` and `code` lists are built by
   * prepending, so they are in reverse registration order.
   *)
    datatype strand = Strand of {
        name : string,
        tyName : string,
        state : var list ref,
        code : CL.decl list ref
      }

  (* the program under construction; every list is kept in reverse order *)
    datatype program = Prog of {
        globals : CL.decl list ref,
        topDecls : CL.decl list ref,
        strands : strand list ref
      }

  (* for SSE, we have 128-bit vectors; the width itself comes from RN.gVectorWid *)
    fun vectorWidth () = !RN.gVectorWid

  (* target types *)
    val boolTy = T_Bool
    val intTy = T_Int
    val realTy = T_Real

  (* vecTy n / ivecTy n: a 1-element vector is the scalar type; otherwise raise Size
   * when n < 1 or n exceeds the target vector width.
   *)
    fun vecTy 1 = T_Real
      | vecTy n = if (n < 1) orelse (!RN.gVectorWid < n)
          then raise Size
          else T_Vec n
    fun ivecTy 1 = T_Int
      | ivecTy n = if (n < 1) orelse (!RN.gVectorWid < n)
          then raise Size
          else T_IVec n

  (* only images with scalar voxels (empty shape) are matched here *)
    fun imageTy (ImageInfo.ImgInfo{ty=([], rTy), dim, ...}) = T_Image(dim, rTy)
    fun imageDataTy (ImageInfo.ImgInfo{ty=([], rTy), ...}) = T_Ptr rTy
    val stringTy = T_String

  (* the C type of runtime status codes (a CLang type, not a target type) *)
    val statusTy = CL.T_Named RN.statusTy

  (* convert target types to CLang types *)
    fun cvtTy T_Bool = CL.T_Named "bool"
      | cvtTy T_String = CL.charPtr
      | cvtTy T_Int = !RN.gIntTy
      | cvtTy T_Real = !RN.gRealTy
      | cvtTy (T_Vec n) = CL.T_Named(RN.vecTy n)
      | cvtTy (T_IVec n) = CL.T_Named(RN.ivecTy n)
      | cvtTy (T_Image(n, _)) = CL.T_Ptr(CL.T_Named(RN.imageTy n))
      | cvtTy (T_Ptr ty) = CL.T_Ptr(CL.T_Num ty)

  (* report invalid arguments to the operation `name` by raising Fail; the message
   * lists each argument's C text and target type.
   *)
    fun invalid (name, []) = raise Fail("invalid "^name)
      | invalid (name, args) = let
          fun arg2s (E(e, ty)) = concat["(", CL.expToString e, " : ", TargetTy.toString ty, ")"]
          val args = String.concatWith ", " (List.map arg2s args)
          in
            raise Fail(concat["invalid arguments to ", name, ": ", args])
          end

  (* helper functions for checking the types of arguments *)
    fun scalarTy T_Int = true
      | scalarTy T_Real = true
      | scalarTy _ = false
    fun numTy T_Int = true
      | numTy T_Real = true
      | numTy (T_Vec _) = true
      | numTy (T_IVec _) = true
      | numTy _ = false

  (* create an empty program.  This initializes the runtime-name target spec, and the
   * globals start with a preamble that selects the precision and includes the
   * Diderot runtime header.
   *)
    fun newProgram () = (
          RN.initTargetSpec();
          Prog{
              globals = ref [
                  CL.D_Verbatim[
                      if !Controls.doublePrecision
                        then "#define DIDEROT_DOUBLE_PRECISION"
                        else "#define DIDEROT_SINGLE_PRECISION",
                      "#include \"Diderot/diderot.h\""
                    ]],
              topDecls = ref [],
              strands = ref []
            })

  (* register the global initialization part of a program *)
    fun globalInit (Prog{topDecls, ...}, init) = let
          val initFn = CL.D_Func([], CL.voidTy, RN.initGlobals, [], init)
          in
            topDecls := initFn :: !topDecls
          end

  (* create and register the initially function for a program.
   *   iterPrefix   -- statements emitted before the strands are allocated
   *   iters        -- (index var, lo, hi) for each loop dimension; the bounds are
   *                   inclusive (each size is hi - lo + 1 and each loop tests <=)
   *   createPrefix -- statements emitted in the innermost loop body before the strand
   *                   is initialized
   *   args         -- arguments passed to the strand's init function after `sp`
   * The generated function allocates the world with RN.allocInitially, fills it
   * strand by strand using the running index `ix`, and returns the world.
   *)
    fun initially {
          prog = Prog{topDecls, ...},
          isArray : bool,
          iterPrefix : stm list,
          iters : (var * exp * exp) list,
          createPrefix : stm list,
          strand=Strand{name, ...},
          args : exp list
        } = let
          val nDims = List.length iters
          val worldTy = CL.T_Ptr(CL.T_Named RN.worldTy)
        (* map with a 0-based index *)
          fun mapi f xs = let
                fun mapf (_, []) = []
                  | mapf (i, x::xs) = f(i, x) :: mapf(i+1, xs)
                in
                  mapf (0, xs)
                end
          val baseInit = mapi (fn (i, (_, E(e, _), _)) => (i, CL.I_Exp e)) iters
          val sizeInit = mapi
                (fn (i, (V(ty, _), E(lo, _), E(hi, _))) =>
                    (i, CL.I_Exp(CL.mkBinOp(CL.mkBinOp(hi, CL.#-, lo), CL.#+, CL.E_Int(1, cvtTy ty))))
                ) iters
          val allocCode = [
                  CL.S_Comment["allocate initial block of strands"],
                  CL.S_Decl(CL.T_Array(CL.int32, SOME nDims), "base", SOME(CL.I_Array baseInit)),
                  CL.S_Decl(CL.T_Array(CL.uint32, SOME nDims), "size", SOME(CL.I_Array sizeInit)),
                  CL.S_Decl(worldTy, "wrld",
                    SOME(CL.I_Exp(CL.E_Apply(RN.allocInitially, [
                        CL.mkUnOp(CL.%&, CL.E_Var(RN.strandDesc name)),
                        CL.E_Bool isArray,
                        CL.E_Int(IntInf.fromInt nDims, CL.int32),
                        CL.E_Var "base",
                        CL.E_Var "size"
                      ]))))
                ]
        (* create the loop nest for the initially iterations *)
          val indexVar = "ix"
          fun mkLoopNest [] = CL.mkBlock(createPrefix @ [
                  CL.S_Decl(CL.T_Ptr(CL.T_Named(RN.strandTy name)), "sp",
                    SOME(CL.I_Exp(CL.E_Apply(RN.inState, [CL.E_Var "wrld", CL.E_Var indexVar])))),
                  CL.S_Call(RN.strandInit name, CL.E_Var "sp" :: List.map (fn (E(e, _)) => e) args),
                (* advance to the next strand slot; without this every iteration of the
                 * loop nest would initialize the strand at index 0.
                 *)
                  CL.mkAssign(CL.E_Var indexVar,
                    CL.mkBinOp(CL.E_Var indexVar, CL.#+, CL.E_Int(1, CL.uint32)))
                ])
            | mkLoopNest ((V(ty, param), E(lo, _), E(hi, _))::iters) = let
                val body = mkLoopNest iters
                in
                  CL.S_For(
                    [(cvtTy ty, param, lo)],
                    CL.mkBinOp(CL.E_Var param, CL.#<=, hi),
                    [CL.mkPostOp(CL.E_Var param, CL.^++)],
                    body)
                end
          val iterCode = [
                  CL.S_Comment["initially"],
                  CL.S_Decl(CL.uint32, indexVar, SOME(CL.I_Exp(CL.E_Int(0, CL.uint32)))),
                  mkLoopNest iters
                ]
          val body = CL.mkBlock(
                iterPrefix @
                allocCode @
                iterCode @
                [CL.S_Return(SOME(CL.E_Var "wrld"))])
          val initFn = CL.D_Func([], worldTy, RN.initially, [], body)
          in
            topDecls := initFn :: !topDecls
          end

  (* variable construction *)
    structure Var =
      struct
      (* declare an uninitialized global of the given type and register it with the program *)
        fun global (Prog{globals, ...}, ty, name) = (
              globals := CL.D_Var([], cvtTy ty, name, NONE) :: !globals;
              V(ty, name))
        fun param (ty, name) = V(ty, name)
      (* record a strand-state variable (it becomes a field of the strand's struct) *)
        fun state (Strand{state, ...}, ty, name) = (
              state := V(ty, name) :: !state;
              V(ty, name))
        fun var (ty, name) = V(ty, name)
        local
        (* one counter is shared by `tmp` and `fresh`; names have the form prefix_N *)
          val count = ref 0
          fun freshName prefix = let
                val n = !count
                in
                  count := n+1;
                  concat[prefix, "_", Int.toString n]
                end
        in
        fun tmp ty = V(ty, freshName "tmp")
        fun fresh prefix = freshName prefix
        end (* local *)
      end

  (* expression construction *)
    structure Expr =
      struct
      (* return true if the given expression form is allowed as a subexpression *)
        fun allowedInline _ = true (* FIXME *)

      (* variable references; strand state is read through the `selfIn` pointer *)
        fun global (V(ty, x)) = E(CL.mkVar x, ty)
        fun getState (V(ty, x)) = E(CL.mkIndirect(CL.mkVar "selfIn", x), ty)
        fun param (V(ty, x)) = E(CL.mkVar x, ty)
        fun var (V(ty, x)) = E(CL.mkVar x, ty)

      (* literals *)
        fun intLit n = E(CL.mkInt(n, !RN.gIntTy), intTy)
        fun floatLit f = E(CL.mkFlt(f, !RN.gRealTy), realTy)
        fun stringLit s = E(CL.mkStr s, stringTy)
        fun boolLit b = E(CL.mkBool b, boolTy)

      (* select from a vector.  We have to cast to the corresponding union type and then
       * select from the array field.  Raises Subscript if the index is out of range.
       *)
        local
          fun sel (tyCode, field, ty) (i, e, n) =
                if (i < 0) orelse (n <= i)
                  then raise Subscript
                  else let
                    val unionTy = CL.T_Named(concat["union", Int.toString n, !tyCode, "_t"])
                    val e1 = CL.mkCast(unionTy, e)
                    val e2 = CL.mkSelect(e1, field)
                    in
                      E(CL.mkSubscript(e2, CL.mkInt(IntInf.fromInt i, CL.int32)), ty)
                    end
          val selF = sel (RN.gRealSuffix, "r", T_Real)
          val selI = sel (RN.gIntSuffix, "i", T_Int)
        in
        fun select (i, E(e, T_Vec n)) = selF (i, e, n)
          | select (i, E(e, T_IVec n)) = selI (i, e, n)
          | select (_, x) = invalid("select", [x])
        end (* local *)

      (* vector (and scalar) arithmetic: both operands must have the same numeric type,
       * except for the pointer-offset and scaling cases below.
       *)
        local
          fun checkTys (ty1, ty2) = (ty1 = ty2) andalso numTy ty1
          fun binop rator (E(e1, ty1), E(e2, ty2)) =
                if checkTys (ty1, ty2)
                  then E(CL.mkBinOp(e1, rator, e2), ty1)
                  else invalid (
                    concat["binary operator \"", CL.binopToString rator, "\""],
                    [E(e1, ty1), E(e2, ty2)])
        in
        fun add (E(e1, ty as T_Ptr _), E(e2, T_Int)) = E(CL.mkBinOp(e1, CL.#+, e2), ty)
          | add args = binop CL.#+ args
        fun sub (E(e1, ty as T_Ptr _), E(e2, T_Int)) = E(CL.mkBinOp(e1, CL.#-, e2), ty)
          | sub args = binop CL.#- args
      (* NOTE: multiplication and division are also used for scaling *)
        fun mul (E(e1, T_Real), E(e2, T_Vec n)) = E(CL.E_Apply(RN.scale n, [e1, e2]), T_Vec n)
          | mul args = binop CL.#* args
      (* vector / scalar is emitted as a scale by the reciprocal of the scalar *)
        fun divide (E(e1, T_Vec n), E(e2, T_Real)) = let
              val E(one, _) = floatLit FloatLit.one
              in
                E(CL.E_Apply(RN.scale n, [CL.mkBinOp(one, CL.#/, e2), e1]), T_Vec n)
              end
          | divide args = binop CL.#/ args
        end (* local *)

        fun neg (E(_, T_Bool)) = raise Fail "invalid argument to neg"
          | neg (E(e, ty)) = E(CL.mkUnOp(CL.%-, e), ty)

        fun abs (E(e, T_Int)) = E(CL.mkApply("abs", [e]), T_Int) (* FIXME: not the right type for 64-bit ints *)
          | abs (E(e, T_Real)) = E(CL.mkApply("fabs" ^ !RN.gRealSuffix, [e]), T_Real)
          | abs (E(_, T_Vec _)) = raise Fail "FIXME: Expr.abs"
          | abs (E(_, T_IVec _)) = raise Fail "FIXME: Expr.abs"
          | abs _ = raise Fail "invalid argument to abs"

      (* dot product of two real vectors, which must have the same width *)
        fun dot (E(e1, T_Vec n1), E(e2, T_Vec n2)) =
              if (n1 = n2)
                then E(CL.E_Apply(RN.dot n1, [e1, e2]), T_Real)
                else invalid("dot", [E(e1, T_Vec n1), E(e2, T_Vec n2)])
          | dot _ = raise Fail "invalid argument to dot"

        fun cross (E(e1, T_Vec 3), E(e2, T_Vec 3)) = E(CL.E_Apply(RN.cross(), [e1, e2]), T_Vec 3)
          | cross _ = raise Fail "invalid argument to cross"

        fun length (E(e, T_Vec n)) = E(CL.E_Apply(RN.length n, [e]), T_Real)
          | length _ = raise Fail "invalid argument to length"

        fun normalize (E(e, T_Vec n)) = E(CL.E_Apply(RN.normalize n, [e]), T_Vec n)
          | normalize _ = raise Fail "invalid argument to normalize"

      (* comparisons: operands must have the same scalar type; the result is bool *)
        local
          fun checkTys (ty1, ty2) = (ty1 = ty2) andalso scalarTy ty1
          fun cmpop rator (E(e1, ty1), E(e2, ty2)) =
                if checkTys (ty1, ty2)
                  then E(CL.mkBinOp(e1, rator, e2), T_Bool)
                  else invalid (
                    concat["compare operator \"", CL.binopToString rator, "\""],
                    [E(e1, ty1), E(e2, ty2)])
        in
        val lt = cmpop CL.#<
        val lte = cmpop CL.#<=
        val equ = cmpop CL.#==
        val neq = cmpop CL.#!=
        val gte = cmpop CL.#>=
        val gt = cmpop CL.#>
        end (* local *)

      (* logical connectives *)
        fun not (E(e, T_Bool)) = E(CL.mkUnOp(CL.%!, e), T_Bool)
          | not _ = raise Fail "invalid argument to not"
        fun && (E(e1, T_Bool), E(e2, T_Bool)) = E(CL.mkBinOp(e1, CL.#&&, e2), T_Bool)
          | && _ = raise Fail "invalid arguments to &&"
        fun || (E(e1, T_Bool), E(e2, T_Bool)) = E(CL.mkBinOp(e1, CL.#||, e2), T_Bool)
          | || _ = raise Fail "invalid arguments to ||"

      (* binary runtime functions whose C name is chosen by the (shared scalar) type *)
        local
          fun checkTys (ty1, ty2) = (ty1 = ty2) andalso scalarTy ty1
          fun binFn f (E(e1, ty1), E(e2, ty2)) =
                if checkTys (ty1, ty2)
                  then E(CL.mkApply(f ty1, [e1, e2]), ty1)
                  else raise Fail "invalid arguments to binary function"
        in
      (* misc functions *)
        val min = binFn RN.min
        val max = binFn RN.max
        end (* local *)

      (* math functions; the C library name depends on Controls.doublePrecision *)
        fun pow (E(e1, T_Real), E(e2, T_Real)) =
              if !Controls.doublePrecision
                then E(CL.mkApply("pow", [e1, e2]), T_Real)
                else E(CL.mkApply("powf", [e1, e2]), T_Real)
          | pow _ = raise Fail "invalid arguments to pow"

        local
        (* real -> real library function: (float name, double name) *)
          fun r2r (ff, fd) (E(e, T_Real)) =
                if !Controls.doublePrecision
                  then E(CL.mkApply(fd, [e]), T_Real)
                  else E(CL.mkApply(ff, [e]), T_Real)
            | r2r (_, fd) e = invalid (fd, [e])
        in
        val sin = r2r ("sinf", "sin")
        val cos = r2r ("cosf", "cos")
        val sqrt = r2r ("sqrtf", "sqrt")
        end (* local *)

      (* rounding; the function name gets a type-specific suffix from RN.addTySuffix *)
        fun trunc (E(e, ty)) = E(CL.mkApply(RN.addTySuffix("trunc", ty), [e]), ty)
        fun round (E(e, ty)) = E(CL.mkApply(RN.addTySuffix("round", ty), [e]), ty)
        fun floor (E(e, ty)) = E(CL.mkApply(RN.addTySuffix("floor", ty), [e]), ty)
        fun ceil (E(e, ty)) = E(CL.mkApply(RN.addTySuffix("ceil", ty), [e]), ty)

      (* conversions *)
        fun toInt (E(e, T_Real)) = E(CL.mkCast(!RN.gIntTy, e), T_Int)
          | toInt (E(e, T_Vec n)) = E(CL.mkApply(RN.vecftoi n, [e]), ivecTy n)
          | toInt e = invalid ("toInt", [e])
        fun toReal (E(e, T_Int)) = E(CL.mkCast(!RN.gRealTy, e), T_Real)
          | toReal e = invalid ("toReal", [e])

      (* runtime system hooks *)

      (* address of an image's voxel data: `img->data` cast to a pointer to the voxel type *)
        fun imageAddr (E(e, T_Image(_, rTy))) = let
              val cTy = CL.T_Ptr(CL.T_Num rTy)
              in
                E(CL.mkCast(cTy, CL.mkIndirect(e, "data")), T_Ptr rTy)
              end
          | imageAddr a = invalid("imageAddr", [a])
      (* dereference a voxel pointer, casting to the real type if the voxel type differs *)
        fun getImgData (E(e, T_Ptr rTy)) = let
              val realTy as CL.T_Num rTy' = !RN.gRealTy
              val e = CL.E_UnOp(CL.%*, e)
              in
                if (rTy' = rTy)
                  then E(e, T_Real)
                  else E(CL.E_Cast(realTy, e), T_Real)
              end
          | getImgData a = invalid("getImgData", [a])
        fun posToImgSpace (E(img, T_Image(d, _)), E(pos, T_Vec n)) = let
              val e = CL.mkApply(RN.toImageSpace d, [img, pos])
              in
                E(e, T_Vec n)
              end
          | posToImgSpace (a, b) = invalid("posToImgSpace", [a, b])
        fun inside (E(pos, T_Vec n), E(img, T_Image(d, _)), s) = let
              val e = CL.mkApply(RN.inside d,
                    [pos, img, CL.mkInt(IntInf.fromInt s, CL.int32)])
              in
                E(e, T_Bool)
              end
          | inside (a, b, _) = invalid("inside", [a, b])

      end (* Expr *)

  (* statement construction *)
    structure Stmt =
      struct
        val comment = CL.S_Comment
      (* strand state is written through the `selfOut` pointer *)
        fun assignState (V(_, x), E(e, _)) =
              CL.mkAssign(CL.mkIndirect(CL.mkVar "selfOut", x), e)
        fun assign (V(_, x), E(e, _)) = CL.mkAssign(CL.mkVar x, e)
        fun decl (V(ty, x), SOME(E(e, _))) = CL.mkDecl(cvtTy ty, x, SOME(CL.I_Exp e))
          | decl (V(ty, x), NONE) = CL.mkDecl(cvtTy ty, x, NONE)
        val block = CL.mkBlock
        fun ifthen (E(e, T_Bool), s1) = CL.mkIfThen(e, s1)
          | ifthen _ = raise Fail "invalid argument to ifthen"
        fun ifthenelse (E(e, T_Bool), s1, s2) = CL.mkIfThenElse(e, s1, s2)
          | ifthenelse _ = raise Fail "invalid argument to ifthenelse"
      (* for loop over the inclusive range lo..hi *)
        fun for (V(ty, x), E(lo, _), E(hi, _), body) = CL.mkFor(
              [(cvtTy ty, x, lo)],
              CL.mkBinOp(CL.mkVar x, CL.#<=, hi),
              [CL.mkPostOp(CL.mkVar x, CL.^++)],
              body)

      (* special Diderot forms *)

      (* assign a vector built from its component expressions *)
        fun cons (V(T_Vec n, x), args : exp list) =
              CL.mkAssign(CL.mkVar x, CL.mkApply(RN.mkVec n, List.map (fn E(e, _) => e) args))
          | cons _ = raise Fail "bogus cons"
      (* load n consecutive voxels through a fresh pointer variable into the vector x,
       * casting each element to the real type when the voxel type differs
       *)
        fun getImgData (V(T_Vec n, x), E(e, T_Ptr rTy)) = let
              val addr = Var.fresh "vp"
              val needsCast = (CL.T_Num rTy <> !RN.gRealTy)
              fun mkLoad i = let
                    val e = CL.mkSubscript(CL.mkVar addr, CL.mkInt(IntInf.fromInt i, CL.int32))
                    in
                      if needsCast then CL.mkCast(!RN.gRealTy, e) else e
                    end
              in [
                CL.mkDecl(CL.T_Ptr(CL.T_Num rTy), addr, SOME(CL.I_Exp e)),
                CL.mkAssign(CL.mkVar x, CL.mkApply(RN.mkVec n, List.tabulate (n, mkLoad)))
              ] end
          | getImgData _ = raise Fail "bogus getImgData"

        local
        (* `mkDecl sts` produces code that declares the status variable `sts`; it is
         * followed by a check that calls exit(1) when the status is not DIDEROT_OK.
         *)
          fun checkSts mkDecl = let
                val sts = Var.fresh "sts"
                in
                  mkDecl sts @
                  [CL.mkIfThen(
                    CL.mkBinOp(CL.mkVar "DIDEROT_OK", CL.#!=, CL.mkVar sts),
                    CL.mkCall("exit", [CL.mkInt(1, CL.int32)]))]
                end
        in
        fun loadImage (V(_, lhs), dim, E(name, _)) = checkSts (fn sts => let
              val loadFn = RN.loadImage dim
              in [
                CL.S_Decl(
                  statusTy, sts,
                  SOME(CL.I_Exp(CL.E_Apply(loadFn, [name, CL.mkUnOp(CL.%&, CL.E_Var lhs)]))))
              ] end)
      (* read an input; when a default is given it is assigned first and the input
       * function is told that a default exists
       *)
        fun input (V(ty, lhs), name, optDflt) = checkSts (fn sts => let
              val inputFn = RN.input ty
              val lhs = CL.E_Var lhs
              val (initCode, hasDflt) = (case optDflt
                     of SOME(E(e, _)) => ([CL.S_Assign(lhs, e)], true)
                      | NONE => ([], false)
                    (* end case *))
              val code = [
                      CL.S_Decl(
                        statusTy, sts,
                        SOME(CL.I_Exp(CL.E_Apply(inputFn, [
                            CL.E_Str name,
                            CL.mkUnOp(CL.%&, lhs),
                            CL.mkBool hasDflt
                          ]))))
                    ]
              in
                initCode @ code
              end)
        end (* local *)

      (* strand-method exits *)
        fun exit () = CL.mkReturn NONE
        fun active () = CL.mkReturn(SOME(CL.mkVar RN.kActive))
        fun stabilize () = CL.mkReturn(SOME(CL.mkVar RN.kStabilize))
        fun die () = CL.mkReturn(SOME(CL.mkVar RN.kDie))
      end

    structure Strand =
      struct
      (* create a new strand and register it with the program *)
        fun define (Prog{strands, ...}, strandId) = let
              val strand = Strand{
                      name = strandId,
                      tyName = RN.strandTy strandId,
                      state = ref [],
                      code = ref []
                    }
              in
                strands := strand :: !strands;
                strand
              end

      (* register the strand-state initialization code.  The variables are the strand
       * parameters; the generated function also takes a `selfOut` pointer first.
       *)
        fun init (Strand{name, tyName, code, ...}, params, init) = let
              val fName = RN.strandInit name
              val params =
                    CL.PARAM([], CL.T_Ptr(CL.T_Named tyName), "selfOut") ::
                      List.map (fn (V(ty, x)) => CL.PARAM([], cvtTy ty, x)) params
              val initFn = CL.D_Func([], CL.voidTy, fName, params, init)
              in
                code := initFn :: !code
              end

      (* register a strand method named <strand>_<method> taking selfIn and selfOut *)
        fun method (Strand{name, tyName, code, ...}, methName, body) = let
              val fName = concat[name, "_", methName]
              val params = [
                      CL.PARAM([], CL.T_Ptr(CL.T_Named tyName), "selfIn"),
                      CL.PARAM([], CL.T_Ptr(CL.T_Named tyName), "selfOut")
                    ]
              val methFn = CL.D_Func([], CL.int32, fName, params, body)
              in
                code := methFn :: !code
              end

      end (* Strand *)

  (* the declarations for one strand: its state struct (fields in declaration order)
   * followed by its functions in registration order
   *)
    fun genStrand (Strand{tyName, state, code, ...}) = let
          val selfTyDef = CL.D_StructDef(
                List.rev (List.map (fn V(ty, x) => (cvtTy ty, x)) (!state)),
                tyName)
          in
            selfTyDef :: List.rev (!code)
          end

  (* generate the table of strand descriptors.  Entries follow the order of the given
   * list; each update entry refers to the function <strand>_update.
   *)
    fun genStrandTable (ppStrm, strands) = let
          val nStrands = length strands
          fun genInit (Strand{name, ...}) = let
                fun fnPtr (ty, f) = CL.I_Exp(CL.mkCast(CL.T_Named ty, CL.mkVar f))
                in
                  CL.I_Struct[
                      ("name", CL.I_Exp(CL.E_Str name)),
                      ("stateSzb", CL.I_Exp(CL.mkSizeof(CL.T_Named(RN.strandTy name)))),
                      ("init", fnPtr("strand_init_t", RN.strandInit name)),
                      ("update", fnPtr("update_method_t", name ^ "_update"))
                    ]
                end
          fun genInits (_, []) = []
            | genInits (i, s::ss) = (i, genInit s) :: genInits(i+1, ss)
          fun ppDecl dcl = PrintAsC.output(ppStrm, dcl)
          in
            ppDecl (CL.D_Var([], CL.int32, RN.numStrands,
              SOME(CL.I_Exp(CL.E_Int(IntInf.fromInt nStrands, CL.int32)))));
            ppDecl (CL.D_Var([],
              CL.T_Array(CL.T_Named RN.strandDescTy, SOME nStrands),
              RN.strands,
              SOME(CL.I_Array(genInits (0, strands)))))
          end

  (* write the program to the file <baseName>.c.  The globals and top-level
   * declarations are emitted in registration order.  NOTE(review): strands are emitted
   * most-recently-defined first (the list is built by prepending); the descriptor
   * table uses the same order.
   *)
    fun generate (baseName, Prog{globals, topDecls, strands}) = let
          val fileName = OS.Path.joinBaseExt{base=baseName, ext=SOME "c"}
          val outS = TextIO.openOut fileName
          val ppStrm = PrintAsC.new outS
          fun ppDecl dcl = PrintAsC.output(ppStrm, dcl)
          fun ppProgram () = (
                List.app ppDecl (List.rev (!globals));
                List.app ppDecl (List.rev (!topDecls));
                List.app (fn strand => List.app ppDecl (genStrand strand)) (!strands);
                genStrandTable (ppStrm, !strands);
                PrintAsC.close ppStrm)
          in
          (* make sure the output file is closed even if code generation fails *)
            ppProgram () handle ex => (TextIO.closeOut outS; raise ex);
            TextIO.closeOut outS
          end

  end

structure CBackEnd = CodeGenFn(CTarget)
root@smlnj-gforge.cs.uchicago.edu | ViewVC Help |
Powered by ViewVC 1.0.0 |