Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] Annotation of /branches/pure-cfg/src/compiler/c-target/c-target.sml
ViewVC logotype

Annotation of /branches/pure-cfg/src/compiler/c-target/c-target.sml

Parent Directory Parent Directory | Revision Log Revision Log


Revision 695 - (view) (download)

1 : jhr 519 (* c-target.sml
2 :     *
3 :     * COPYRIGHT (c) 2011 The Diderot Project (http://diderot-language.cs.uchicago.edu)
4 :     * All rights reserved.
5 :     *
6 :     * Generate C code with SSE 4.2 intrinsics.
7 :     *)
8 :    
9 :     structure CTarget : TARGET =
10 :     struct
11 :    
12 : jhr 522 structure CL = CLang
13 : jhr 551 structure RN = RuntimeNames
14 : jhr 522
15 : jhr 551 datatype ty = datatype TargetTy.ty
16 : jhr 519
17 : jhr 623 datatype var = V of (ty * string)
18 :    
19 :     datatype exp = E of CLang.exp * ty
20 :    
21 :     type stm = CL.stm
22 :    
23 : jhr 544 datatype strand = Strand of {
24 :     name : string,
25 :     tyName : string,
26 : jhr 623 state : var list ref,
27 : jhr 654 output : var option ref, (* the strand's output variable (only one for now) *)
28 : jhr 544 code : CL.decl list ref
29 :     }
30 : jhr 525
31 : jhr 527 datatype program = Prog of {
32 :     globals : CL.decl list ref,
33 : jhr 533 topDecls : CL.decl list ref,
34 : jhr 624 strands : strand AtomTable.hash_table,
35 :     initially : CL.decl ref
36 : jhr 527 }
37 :    
38 : jhr 519 (* for SSE, we have 128-bit vectors *)
39 : jhr 551 fun vectorWidth () = !RN.gVectorWid
40 : jhr 519
41 :     (* target types *)
42 : jhr 525 val boolTy = T_Bool
43 :     val intTy = T_Int
44 :     val realTy = T_Real
45 :     fun vecTy 1 = T_Real
46 : jhr 551 | vecTy n = if (n < 1) orelse (!RN.gVectorWid < n)
47 : jhr 525 then raise Size
48 :     else T_Vec n
49 :     fun ivecTy 1 = T_Int
50 : jhr 551 | ivecTy n = if (n < 1) orelse (!RN.gVectorWid < n)
51 : jhr 525 then raise Size
52 :     else T_IVec n
53 : jhr 548 fun imageTy (ImageInfo.ImgInfo{ty=([], rTy), dim, ...}) = T_Image(dim, rTy)
54 :     fun imageDataTy (ImageInfo.ImgInfo{ty=([], rTy), ...}) = T_Ptr rTy
55 : jhr 534 val stringTy = T_String
56 : jhr 519
57 : jhr 552 val statusTy = CL.T_Named RN.statusTy
58 : jhr 534
59 : jhr 528 (* convert target types to CLang types *)
60 :     fun cvtTy T_Bool = CLang.T_Named "bool"
61 : jhr 534 | cvtTy T_String = CL.charPtr
62 : jhr 551 | cvtTy T_Int = !RN.gIntTy
63 :     | cvtTy T_Real = !RN.gRealTy
64 : jhr 561 | cvtTy (T_Vec n) = CL.T_Named(RN.vecTy n)
65 :     | cvtTy (T_IVec n) = CL.T_Named(RN.ivecTy n)
66 : jhr 683 | cvtTy (T_Mat(n,m)) = CL.T_Named(RN.matTy(n,m))
67 : jhr 561 | cvtTy (T_Image(n, _)) = CL.T_Ptr(CL.T_Named(RN.imageTy n))
68 : jhr 548 | cvtTy (T_Ptr ty) = CL.T_Ptr(CL.T_Num ty)
69 : jhr 528
70 : jhr 548 (* report invalid arguments *)
71 :     fun invalid (name, []) = raise Fail("invaild "^name)
72 :     | invalid (name, args) = let
73 : jhr 623 fun arg2s (E(e, ty)) = concat["(", CL.expToString e, " : ", TargetTy.toString ty, ")"]
74 : jhr 548 val args = String.concatWith ", " (List.map arg2s args)
75 :     in
76 :     raise Fail(concat["invalid arguments to ", name, ": ", args])
77 :     end
78 :    
79 : jhr 525 (* helper functions for checking the types of arguments *)
80 :     fun scalarTy T_Int = true
81 :     | scalarTy T_Real = true
82 :     | scalarTy _ = false
83 : jhr 548 fun numTy T_Int = true
84 :     | numTy T_Real = true
85 :     | numTy (T_Vec _) = true
86 :     | numTy (T_IVec _) = true
87 :     | numTy _ = false
88 : jhr 519
89 : jhr 528 fun newProgram () = (
90 : jhr 551 RN.initTargetSpec();
91 : jhr 528 Prog{
92 : jhr 554 globals = ref [
93 :     CL.D_Verbatim[
94 :     if !Controls.doublePrecision
95 :     then "#define DIDEROT_DOUBLE_PRECISION"
96 :     else "#define DIDEROT_SINGLE_PRECISION",
97 :     "#include \"Diderot/diderot.h\""
98 :     ]],
99 : jhr 533 topDecls = ref [],
100 : jhr 624 strands = AtomTable.mkTable (16, Fail "strand table"),
101 :     initially = ref(CL.D_Comment["missing initially"])
102 : jhr 528 })
103 :    
104 : jhr 618 (* register the global initialization part of a program *)
105 : jhr 533 fun globalInit (Prog{topDecls, ...}, init) = let
106 : jhr 551 val initFn = CL.D_Func([], CL.voidTy, RN.initGlobals, [], init)
107 : jhr 533 in
108 :     topDecls := initFn :: !topDecls
109 :     end
110 :    
111 : jhr 623 (* create and register the initially function for a program *)
112 :     fun initially {
113 : jhr 624 prog = Prog{strands, initially, ...},
114 : jhr 623 isArray : bool,
115 : jhr 624 iterPrefix : stm,
116 : jhr 623 iters : (var * exp * exp) list,
117 : jhr 624 createPrefix : stm,
118 :     strand : Atom.atom,
119 : jhr 623 args : exp list
120 :     } = let
121 : jhr 624 val iterPrefix = (case iterPrefix
122 :     of CL.S_Block stms => stms
123 :     | stm => [stm]
124 :     (* end case *))
125 :     val createPrefix = (case createPrefix
126 :     of CL.S_Block stms => stms
127 :     | stm => [stm]
128 :     (* end case *))
129 :     val name = Atom.toString strand
130 : jhr 623 val nDims = List.length iters
131 :     val worldTy = CL.T_Ptr(CL.T_Named RN.worldTy)
132 :     fun mapi f xs = let
133 :     fun mapf (_, []) = []
134 :     | mapf (i, x::xs) = f(i, x) :: mapf(i+1, xs)
135 :     in
136 :     mapf (0, xs)
137 :     end
138 :     val baseInit = mapi (fn (i, (_, E(e, _), _)) => (i, CL.I_Exp e)) iters
139 :     val sizeInit = mapi
140 :     (fn (i, (V(ty, _), E(lo, _), E(hi, _))) =>
141 :     (i, CL.I_Exp(CL.mkBinOp(CL.mkBinOp(hi, CL.#-, lo), CL.#+, CL.E_Int(1, cvtTy ty))))
142 :     ) iters
143 :     val allocCode = [
144 :     CL.S_Comment["allocate initial block of strands"],
145 :     CL.S_Decl(CL.T_Array(CL.int32, SOME nDims), "base", SOME(CL.I_Array baseInit)),
146 :     CL.S_Decl(CL.T_Array(CL.uint32, SOME nDims), "size", SOME(CL.I_Array sizeInit)),
147 :     CL.S_Decl(worldTy, "wrld",
148 :     SOME(CL.I_Exp(CL.E_Apply(RN.allocInitially, [
149 :     CL.mkUnOp(CL.%&, CL.E_Var(RN.strandDesc name)),
150 :     CL.E_Bool isArray,
151 :     CL.E_Int(IntInf.fromInt nDims, CL.int32),
152 :     CL.E_Var "base",
153 :     CL.E_Var "size"
154 :     ]))))
155 :     ]
156 :     (* create the loop nest for the initially iterations *)
157 :     val indexVar = "ix"
158 : jhr 634 val strandTy = CL.T_Ptr(CL.T_Named(RN.strandTy name))
159 : jhr 623 fun mkLoopNest [] = CL.mkBlock(createPrefix @ [
160 : jhr 634 CL.S_Decl(strandTy, "sp",
161 :     SOME(CL.I_Exp(
162 :     CL.E_Cast(strandTy,
163 :     CL.E_Apply(RN.inState, [CL.E_Var "wrld", CL.E_Var indexVar]))))),
164 : jhr 624 CL.S_Call(RN.strandInit name, CL.E_Var "sp" :: List.map (fn (E(e, _)) => e) args),
165 :     CL.S_Assign(CL.E_Var indexVar, CL.mkBinOp(CL.E_Var indexVar, CL.#+, CL.E_Int(1, CL.uint32)))
166 : jhr 623 ])
167 :     | mkLoopNest ((V(ty, param), E(lo,_), E(hi, _))::iters) = let
168 :     val body = mkLoopNest iters
169 :     in
170 :     CL.S_For(
171 :     [(cvtTy ty, param, lo)],
172 :     CL.mkBinOp(CL.E_Var param, CL.#<=, hi),
173 :     [CL.mkPostOp(CL.E_Var param, CL.^++)],
174 :     body)
175 :     end
176 :     val iterCode = [
177 :     CL.S_Comment["initially"],
178 :     CL.S_Decl(CL.uint32, indexVar, SOME(CL.I_Exp(CL.E_Int(0, CL.uint32)))),
179 :     mkLoopNest iters
180 :     ]
181 :     val body = CL.mkBlock(iterPrefix @ allocCode @ iterCode @ [CL.S_Return(SOME(CL.E_Var "wrld"))])
182 :     val initFn = CL.D_Func([], worldTy, RN.initially, [], body)
183 : jhr 618 in
184 : jhr 624 initially := initFn
185 : jhr 618 end
186 :    
187 : jhr 525 structure Var =
188 :     struct
189 : jhr 528 fun global (Prog{globals, ...}, ty, name) = (
190 : jhr 573 globals := CL.D_Var([], cvtTy ty, name, NONE) :: !globals;
191 : jhr 623 V(ty, name))
192 :     fun param (ty, name) = V(ty, name)
193 : jhr 544 fun state (Strand{state, ...}, ty, name) = (
194 : jhr 623 state := V(ty, name) :: !state;
195 :     V(ty, name))
196 :     fun var (ty, name) = V(ty, name)
197 : jhr 554 local
198 :     val count = ref 0
199 :     fun freshName prefix = let
200 :     val n = !count
201 :     in
202 :     count := n+1;
203 :     concat[prefix, "_", Int.toString n]
204 :     end
205 :     in
206 : jhr 623 fun tmp ty = V(ty, freshName "tmp")
207 : jhr 554 fun fresh prefix = freshName prefix
208 :     end (* local *)
209 : jhr 519 end
210 :    
211 :     (* expression construction *)
212 : jhr 525 structure Expr =
213 :     struct
214 : jhr 549 (* return true if the given expression from is allowed as a subexpression *)
215 :     fun allowedInline _ = true (* FIXME *)
216 :    
217 : jhr 519 (* variable references *)
218 : jhr 623 fun global (V(ty, x)) = E(CL.mkVar x, ty)
219 :     fun getState (V(ty, x)) = E(CL.mkIndirect(CL.mkVar "selfIn", x), ty)
220 :     fun param (V(ty, x)) = E(CL.mkVar x, ty)
221 :     fun var (V(ty, x)) = E(CL.mkVar x, ty)
222 : jhr 525
223 : jhr 519 (* literals *)
224 : jhr 623 fun intLit n = E(CL.mkInt(n, !RN.gIntTy), intTy)
225 :     fun floatLit f = E(CL.mkFlt(f, !RN.gRealTy), realTy)
226 :     fun stringLit s = E(CL.mkStr s, stringTy)
227 :     fun boolLit b = E(CL.mkBool b, boolTy)
228 : jhr 525
229 : jhr 561 (* select from a vector. We have to cast to the corresponding union type and then
230 :     * select from the array field.
231 :     *)
232 :     local
233 :     fun sel (tyCode, field, ty) (i, e, n) =
234 :     if (i < 0) orelse (n <= i)
235 :     then raise Subscript
236 :     else let
237 :     val unionTy = CL.T_Named(concat["union", Int.toString n, !tyCode, "_t"])
238 :     val e1 = CL.mkCast(unionTy, e)
239 :     val e2 = CL.mkSelect(e1, field)
240 :     in
241 : jhr 623 E(CL.mkSubscript(e2, CL.mkInt(IntInf.fromInt i, CL.int32)), ty)
242 : jhr 561 end
243 :     val selF = sel (RN.gRealSuffix, "r", T_Real)
244 :     val selI = sel (RN.gIntSuffix, "i", T_Int)
245 :     in
246 : jhr 654 fun ivecIndex (e, d, i) = let val E(e', _) = selI(i, e, d) in e' end
247 :     fun vecIndex (e, d, i) = let val E(e', _) = selF(i, e, d) in e' end
248 : jhr 623 fun select (i, E(e, T_Vec n)) = selF (i, e, n)
249 :     | select (i, E(e, T_IVec n)) = selI (i, e, n)
250 : jhr 548 | select (_, x) = invalid("select", [x])
251 : jhr 561 end (* local *)
252 : jhr 525
253 : jhr 519 (* vector (and scalar) arithmetic *)
254 : jhr 525 local
255 :     fun checkTys (ty1, ty2) = (ty1 = ty2) andalso numTy ty1
256 : jhr 623 fun binop rator (E(e1, ty1), E(e2, ty2)) =
257 : jhr 525 if checkTys (ty1, ty2)
258 : jhr 623 then E(CL.mkBinOp(e1, rator, e2), ty1)
259 : jhr 548 else invalid (
260 :     concat["binary operator \"", CL.binopToString rator, "\""],
261 : jhr 623 [E(e1, ty1), E(e2, ty2)])
262 : jhr 525 in
263 : jhr 623 fun add (E(e1, ty as T_Ptr _), E(e2, T_Int)) = E(CL.mkBinOp(e1, CL.#+, e2), ty)
264 : jhr 548 | add args = binop CL.#+ args
265 : jhr 623 fun sub (E(e1, ty as T_Ptr _), E(e2, T_Int)) = E(CL.mkBinOp(e1, CL.#-, e2), ty)
266 : jhr 548 | sub args = binop CL.#- args
267 : jhr 544 (* NOTE: multiplication and division are also used for scaling *)
268 : jhr 623 fun mul (E(e1, T_Real), E(e2, T_Vec n)) =
269 :     E(CL.E_Apply(RN.scale n, [e1, e2]), T_Vec n)
270 : jhr 544 | mul args = binop CL.#* args
271 : jhr 623 fun divide (E(e1, T_Vec n), E(e2, T_Real)) = let
272 :     val E(one, _) = floatLit FloatLit.one
273 :     in
274 :     E(CL.E_Apply(RN.scale n, [CL.mkBinOp(one, CL.#/, e2), e1]), T_Vec n)
275 :     end
276 : jhr 544 | divide args = binop CL.#/ args
277 : jhr 525 end (* local *)
278 : jhr 623 fun neg (E(e, T_Bool)) = raise Fail "invalid argument to neg"
279 :     | neg (E(e, ty)) = E(CL.mkUnOp(CL.%-, e), ty)
280 : jhr 525
281 : jhr 623 fun abs (E(e, T_Int)) = E(CL.mkApply("abs", [e]), T_Int) (* FIXME: not the right type for 64-bit ints *)
282 :     | abs (E(e, T_Real)) = E(CL.mkApply("fabs" ^ !RN.gRealSuffix, [e]), T_Real)
283 :     | abs (E(e, T_Vec n)) = raise Fail "FIXME: Expr.abs"
284 :     | abs (E(e, T_IVec n)) = raise Fail "FIXME: Expr.abs"
285 : jhr 525 | abs _ = raise Fail "invalid argument to abs"
286 :    
287 : jhr 623 fun dot (E(e1, T_Vec n1), E(e2, T_Vec n2)) = E(CL.E_Apply(RN.dot n1, [e1, e2]), T_Real)
288 : jhr 525 | dot _ = raise Fail "invalid argument to dot"
289 :    
290 : jhr 623 fun cross (E(e1, T_Vec 3), E(e2, T_Vec 3)) = E(CL.E_Apply(RN.cross(), [e1, e2]), T_Vec 3)
291 : jhr 525 | cross _ = raise Fail "invalid argument to cross"
292 :    
293 : jhr 623 fun length (E(e, T_Vec n)) = E(CL.E_Apply(RN.length n, [e]), T_Real)
294 : jhr 525 | length _ = raise Fail "invalid argument to length"
295 :    
296 : jhr 623 fun normalize (E(e, T_Vec n)) = E(CL.E_Apply(RN.normalize n, [e]), T_Vec n)
297 : jhr 525 | normalize _ = raise Fail "invalid argument to length"
298 :    
299 : jhr 683 (* matrix operations *)
300 :     fun trace (E(e, T_Mat(n,m))) = if (n = m) andalso (1 < n) andalso (m <= 4)
301 :     then E(CL.E_Apply(RN.trace n, [e]), T_Real)
302 :     else raise Fail "invalid matrix argument for trace"
303 :     | trace _ = raise Fail "invalid argument to trace"
304 :    
305 : jhr 519 (* comparisons *)
306 : jhr 525 local
307 :     fun checkTys (ty1, ty2) =
308 :     (ty1 = ty2) andalso scalarTy ty1
309 : jhr 623 fun cmpop rator (E(e1, ty1), E(e2, ty2)) =
310 : jhr 525 if checkTys (ty1, ty2)
311 : jhr 623 then E(CL.mkBinOp(e1, rator, e2), T_Bool)
312 : jhr 548 else invalid (
313 :     concat["compare operator \"", CL.binopToString rator, "\""],
314 : jhr 623 [E(e1, ty1), E(e2, ty2)])
315 : jhr 525 in
316 :     val lt = cmpop CL.#<
317 :     val lte = cmpop CL.#<=
318 :     val equ = cmpop CL.#==
319 :     val neq = cmpop CL.#!=
320 :     val gte = cmpop CL.#>=
321 :     val gt = cmpop CL.#>
322 :     end (* local *)
323 :    
324 : jhr 519 (* logical connectives *)
325 : jhr 623 fun not (E(e, T_Bool)) = E(CL.mkUnOp(CL.%!, e), T_Bool)
326 : jhr 525 | not _ = raise Fail "invalid argument to not"
327 : jhr 623 fun && (E(e1, T_Bool), E(e2, T_Bool)) = E(CL.mkBinOp(e1, CL.#&&, e2), T_Bool)
328 : jhr 525 | && _ = raise Fail "invalid arguments to &&"
329 : jhr 623 fun || (E(e1, T_Bool), E(e2, T_Bool)) = E(CL.mkBinOp(e1, CL.#||, e2), T_Bool)
330 : jhr 525 | || _ = raise Fail "invalid arguments to ||"
331 :    
332 :     local
333 :     fun checkTys (ty1, ty2) = (ty1 = ty2) andalso scalarTy ty1
334 : jhr 623 fun binFn f (E(e1, ty1), E(e2, ty2)) =
335 : jhr 525 if checkTys (ty1, ty2)
336 : jhr 623 then E(CL.mkApply(f ty1, [e1, e2]), ty1)
337 : jhr 525 else raise Fail "invalid arguments to binary function"
338 :     in
339 : jhr 519 (* misc functions *)
340 : jhr 561 val min = binFn RN.min
341 :     val max = binFn RN.max
342 : jhr 525 end (* local *)
343 :    
344 : jhr 551 (* rounding *)
345 : jhr 623 fun trunc (E(e, ty)) = E(CL.mkApply(RN.addTySuffix("trunc", ty), [e]), ty)
346 :     fun round (E(e, ty)) = E(CL.mkApply(RN.addTySuffix("round", ty), [e]), ty)
347 :     fun floor (E(e, ty)) = E(CL.mkApply(RN.addTySuffix("floor", ty), [e]), ty)
348 :     fun ceil (E(e, ty)) = E(CL.mkApply(RN.addTySuffix("ceil", ty), [e]), ty)
349 : jhr 551
350 : jhr 519 (* conversions *)
351 : jhr 623 fun toInt (E(e, T_Real)) = E(CL.mkCast(!RN.gIntTy, e), T_Int)
352 :     | toInt (E(e, T_Vec n)) = E(CL.mkApply(RN.vecftoi n, [e]), ivecTy n)
353 : jhr 565 | toInt e = invalid ("toInt", [e])
354 : jhr 623 fun toReal (E(e, T_Int)) = E(CL.mkCast(!RN.gRealTy, e), T_Real)
355 : jhr 548 | toReal e = invalid ("toReal", [e])
356 : jhr 525
357 : jhr 519 (* runtime system hooks *)
358 : jhr 623 fun imageAddr (E(e, T_Image(_, rTy))) = let
359 : jhr 561 val cTy = CL.T_Ptr(CL.T_Num rTy)
360 : jhr 528 in
361 : jhr 623 E(CL.mkCast(cTy, CL.mkIndirect(e, "data")), T_Ptr rTy)
362 : jhr 528 end
363 : jhr 548 | imageAddr a = invalid("imageAddr", [a])
364 : jhr 623 fun getImgData (E(e, T_Ptr rTy)) = let
365 : jhr 551 val realTy as CL.T_Num rTy' = !RN.gRealTy
366 : jhr 548 val e = CL.E_UnOp(CL.%*, e)
367 :     in
368 :     if (rTy' = rTy)
369 : jhr 623 then E(e, T_Real)
370 :     else E(CL.E_Cast(realTy, e), T_Real)
371 : jhr 548 end
372 :     | getImgData a = invalid("getImgData", [a])
373 : jhr 623 fun posToImgSpace (E(img, T_Image(d, _)), E(pos, T_Vec n)) = let
374 : jhr 551 val e = CL.mkApply(RN.toImageSpace d, [img, pos])
375 : jhr 548 in
376 : jhr 623 E(e, T_Vec n)
377 : jhr 548 end
378 :     | posToImgSpace (a, b) = invalid("posToImgSpace", [a, b])
379 : jhr 623 fun inside (E(pos, T_Vec n), E(img, T_Image(d, _)), s) = let
380 : jhr 551 val e = CL.mkApply(RN.inside d,
381 : jhr 576 [pos, img, CL.mkInt(IntInf.fromInt s, CL.int32)])
382 : jhr 547 in
383 : jhr 623 E(e, T_Bool)
384 : jhr 547 end
385 : jhr 548 | inside (a, b, _) = invalid("inside", [a, b])
386 : jhr 519
387 : jhr 695 (* other basis functions *)
388 :     local
389 :     val basis = [
390 :     (ILBasis.atan2, fn [E(e1, T_Real), E(e2, T_Real)] => SOME[e1, e2] | _ => NONE, T_Real),
391 :     (ILBasis.cos, fn [E(e, T_Real)] => SOME[e] | _ => NONE, T_Real),
392 :     (ILBasis.pow, fn [E(e1, T_Real), E(e2, T_Real)] => SOME[e1, e2] | _ => NONE, T_Real),
393 :     (ILBasis.sin, fn [E(e, T_Real)] => SOME[e] | _ => NONE, T_Real),
394 :     (ILBasis.sqrt, fn [E(e, T_Real)] => SOME[e] | _ => NONE, T_Real),
395 :     (ILBasis.tan, fn [E(e, T_Real)] => SOME[e] | _ => NONE, T_Real)
396 :     ]
397 :     fun mkLookup suffix = let
398 :     val tbl = ILBasis.Tbl.mkTable (16, Fail "basis table")
399 :     fun ins (f, chkTy, resTy) =
400 :     ILBasis.Tbl.insert tbl
401 :     (f, (ILBasis.toString f ^ suffix, chkTy, resTy))
402 :     in
403 :     List.app ins basis;
404 :     ILBasis.Tbl.lookup tbl
405 :     end
406 :     val fLookup = mkLookup "f"
407 :     val dLookup = mkLookup ""
408 :     in
409 :     fun apply (f, args) = let
410 :     val (f', chkArgs, resTy) = if !Controls.doublePrecision then dLookup f else fLookup f
411 :     in
412 :     case chkArgs args
413 :     of SOME args => E(CL.mkApply(f', args), resTy)
414 :     | NONE => raise Fail("invalid arguments for "^ILBasis.toString f)
415 :     end
416 :     end (* local *)
417 : jhr 547 end (* Expr *)
418 :    
419 : jhr 519 (* statement construction *)
420 : jhr 525 structure Stmt =
421 :     struct
422 :     val comment = CL.S_Comment
423 : jhr 623 fun assignState (V(_, x), E(e, _)) =
424 : jhr 547 CL.mkAssign(CL.mkIndirect(CL.mkVar "selfOut", x), e)
425 : jhr 623 fun assign (V(_, x), E(e, _)) = CL.mkAssign(CL.mkVar x, e)
426 :     fun decl (V(ty, x), SOME(E(e, _))) = CL.mkDecl(cvtTy ty, x, SOME(CL.I_Exp e))
427 :     | decl (V(ty, x), NONE) = CL.mkDecl(cvtTy ty, x, NONE)
428 : jhr 525 val block = CL.mkBlock
429 : jhr 623 fun ifthen (E(e, T_Bool), s1) = CL.mkIfThen(e, s1)
430 :     fun ifthenelse (E(e, T_Bool), s1, s2) = CL.mkIfThenElse(e, s1, s2)
431 :     fun for (V(ty, x), E(lo, _), E(hi, _), body) = CL.mkFor(
432 : jhr 617 [(cvtTy ty, x, lo)],
433 :     CL.mkBinOp(CL.mkVar x, CL.#<=, hi),
434 :     [CL.mkPostOp(CL.mkVar x, CL.^++)],
435 :     body)
436 : jhr 534 (* special Diderot forms *)
437 : jhr 623 fun cons (V(T_Vec n, x), args : exp list) =
438 :     CL.mkAssign(CL.mkVar x, CL.mkApply(RN.mkVec n, List.map (fn E(e, _) => e) args))
439 : jhr 553 | cons _ = raise Fail "bogus cons"
440 : jhr 623 fun getImgData (V(T_Vec n, x), E(e, T_Ptr rTy)) = let
441 : jhr 554 val addr = Var.fresh "vp"
442 : jhr 561 val needsCast = (CL.T_Num rTy <> !RN.gRealTy)
443 :     fun mkLoad i = let
444 :     val e = CL.mkSubscript(CL.mkVar addr, CL.mkInt(IntInf.fromInt i, CL.int32))
445 :     in
446 :     if needsCast then CL.mkCast(!RN.gRealTy, e) else e
447 :     end
448 : jhr 554 in [
449 : jhr 623 CL.mkDecl(CL.T_Ptr(CL.T_Num rTy), addr, SOME(CL.I_Exp e)),
450 : jhr 561 CL.mkAssign(CL.mkVar x,
451 :     CL.mkApply(RN.mkVec n, List.tabulate (n, mkLoad)))
452 : jhr 554 ] end
453 :     | getImgData _ = raise Fail "bogus getImgData"
454 :     local
455 :     fun checkSts mkDecl = let
456 :     val sts = Var.fresh "sts"
457 :     in
458 :     mkDecl sts @
459 :     [CL.mkIfThen(
460 :     CL.mkBinOp(CL.mkVar "DIDEROT_OK", CL.#!=, CL.mkVar sts),
461 :     CL.mkCall("exit", [CL.mkInt(1, CL.int32)]))]
462 :     end
463 :     in
464 : jhr 623 fun loadImage (V(_, lhs), dim, E(name, _)) = checkSts (fn sts => let
465 : jhr 551 val imgTy = CL.T_Named(RN.imageTy dim)
466 :     val loadFn = RN.loadImage dim
467 : jhr 534 in [
468 :     CL.S_Decl(
469 :     statusTy, sts,
470 : jhr 623 SOME(CL.I_Exp(CL.E_Apply(loadFn, [name, CL.mkUnOp(CL.%&, CL.E_Var lhs)]))))
471 : jhr 554 ] end)
472 : jhr 623 fun input (V(ty, lhs), name, optDflt) = checkSts (fn sts => let
473 :     val inputFn = RN.input ty
474 :     val lhs = CL.E_Var lhs
475 : jhr 534 val (initCode, hasDflt) = (case optDflt
476 : jhr 623 of SOME(E(e, _)) => ([CL.S_Assign(lhs, e)], true)
477 : jhr 534 | NONE => ([], false)
478 :     (* end case *))
479 :     val code = [
480 :     CL.S_Decl(
481 :     statusTy, sts,
482 : jhr 623 SOME(CL.I_Exp(CL.E_Apply(inputFn, [
483 : jhr 534 CL.E_Str name, CL.mkUnOp(CL.%&, lhs), CL.mkBool hasDflt
484 : jhr 623 ]))))
485 : jhr 534 ]
486 :     in
487 :     initCode @ code
488 : jhr 554 end)
489 :     end (* local *)
490 : jhr 564 fun exit () = CL.mkReturn NONE
491 :     fun active () = CL.mkReturn(SOME(CL.mkVar RN.kActive))
492 :     fun stabilize () = CL.mkReturn(SOME(CL.mkVar RN.kStabilize))
493 : jhr 562 fun die () = CL.mkReturn(SOME(CL.mkVar RN.kDie))
494 : jhr 519 end
495 :    
496 : jhr 544 structure Strand =
497 :     struct
498 :     fun define (Prog{strands, ...}, strandId) = let
499 : jhr 624 val name = Atom.toString strandId
500 : jhr 544 val strand = Strand{
501 : jhr 624 name = name,
502 :     tyName = RN.strandTy name,
503 : jhr 544 state = ref [],
504 : jhr 654 output = ref NONE,
505 : jhr 544 code = ref []
506 :     }
507 :     in
508 : jhr 624 AtomTable.insert strands (strandId, strand);
509 : jhr 544 strand
510 :     end
511 :    
512 : jhr 624 (* return the strand with the given name *)
513 :     fun lookup (Prog{strands, ...}, strandId) = AtomTable.lookup strands strandId
514 :    
515 : jhr 544 (* register the strand-state initialization code. The variables are the strand
516 :     * parameters.
517 :     *)
518 :     fun init (Strand{name, tyName, code, ...}, params, init) = let
519 : jhr 551 val fName = RN.strandInit name
520 : jhr 544 val params =
521 : jhr 547 CL.PARAM([], CL.T_Ptr(CL.T_Named tyName), "selfOut") ::
522 : jhr 623 List.map (fn (V(ty, x)) => CL.PARAM([], cvtTy ty, x)) params
523 : jhr 544 val initFn = CL.D_Func([], CL.voidTy, fName, params, init)
524 :     in
525 :     code := initFn :: !code
526 :     end
527 : jhr 547
528 :     (* register a strand method *)
529 :     fun method (Strand{name, tyName, code, ...}, methName, body) = let
530 :     val fName = concat[name, "_", methName]
531 :     val params = [
532 :     CL.PARAM([], CL.T_Ptr(CL.T_Named tyName), "selfIn"),
533 :     CL.PARAM([], CL.T_Ptr(CL.T_Named tyName), "selfOut")
534 :     ]
535 : jhr 654 val methFn = CL.D_Func(["static"], CL.int32, fName, params, body)
536 : jhr 547 in
537 :     code := methFn :: !code
538 :     end
539 : jhr 654
540 :     fun output (Strand{output, ...}, x) = (case !output
541 :     of NONE => output := SOME x
542 :     | _ => raise Fail "multiple outputs are not supported yet"
543 :     (* end case *))
544 : jhr 544 end (* Strand *)
545 :    
546 : jhr 654 fun genStrand (Strand{name, tyName, state, output, code}) = let
547 : jhr 624 (* the type declaration for the strand's state struct *)
548 : jhr 544 val selfTyDef = CL.D_StructDef(
549 : jhr 623 List.rev (List.map (fn V(ty, x) => (cvtTy ty, x)) (!state)),
550 : jhr 544 tyName)
551 : jhr 654 (* the print function *)
552 :     val prFnName = concat[name, "_print"]
553 :     val prFn = let
554 :     val params = [
555 :     CL.PARAM([], CL.T_Ptr(CL.T_Named "FILE"), "outS"),
556 :     CL.PARAM([], CL.T_Ptr(CL.T_Named tyName), "self")
557 :     ]
558 :     val SOME(V(ty, x)) = !output
559 :     val outState = CL.mkIndirect(CL.mkVar "self", x)
560 :     val prArgs = (case ty
561 :     of TargetTy.T_Int => [CL.E_Str(!RN.gIntFormat ^ "\n"), outState]
562 :     | TargetTy.T_Real => [CL.E_Str "%f\n", outState]
563 :     | TargetTy.T_Vec d => let
564 :     val fmt = CL.E_Str(
565 :     String.concatWith " " (List.tabulate(d, fn _ => "%f"))
566 :     ^ "\n")
567 : jhr 656 val args = List.tabulate (d, fn i => Expr.vecIndex(outState, d, i))
568 : jhr 654 in
569 :     fmt :: args
570 :     end
571 :     | TargetTy.T_IVec d => let
572 :     val fmt = CL.E_Str(
573 :     String.concatWith " " (List.tabulate(d, fn _ => !RN.gIntFormat))
574 :     ^ "\n")
575 : jhr 656 val args = List.tabulate (d, fn i => Expr.ivecIndex(outState, d, i))
576 : jhr 654 in
577 :     fmt :: args
578 :     end
579 :     | _ => raise Fail("genStrand: unsupported output type " ^ TargetTy.toString ty)
580 :     (* end case *))
581 :     in
582 :     CL.D_Func(["static"], CL.voidTy, prFnName, params,
583 :     CL.S_Call("fprintf", CL.mkVar "outS" :: prArgs))
584 :     end
585 : jhr 624 (* the strand's descriptor object *)
586 :     val descI = let
587 : jhr 573 fun fnPtr (ty, f) = CL.I_Exp(CL.mkCast(CL.T_Named ty, CL.mkVar f))
588 :     in
589 :     CL.I_Struct[
590 :     ("name", CL.I_Exp(CL.E_Str name)),
591 :     ("stateSzb", CL.I_Exp(CL.mkSizeof(CL.T_Named(RN.strandTy name)))),
592 : jhr 654 ("update", fnPtr("update_method_t", name ^ "_update")),
593 :     ("print", fnPtr("print_method_t", prFnName))
594 : jhr 573 ]
595 :     end
596 : jhr 624 val desc = CL.D_Var([], CL.T_Named RN.strandDescTy, RN.strandDesc name, SOME descI)
597 :     in
598 : jhr 654 selfTyDef :: List.rev (desc :: prFn :: !code)
599 : jhr 624 end
600 :    
601 :     (* generate the table of strand descriptors *)
602 :     fun genStrandTable (ppStrm, strands) = let
603 :     val nStrands = length strands
604 :     fun genInit (Strand{name, ...}) = CL.I_Exp(CL.mkUnOp(CL.%&, CL.E_Var(RN.strandDesc name)))
605 : jhr 573 fun genInits (_, []) = []
606 :     | genInits (i, s::ss) = (i, genInit s) :: genInits(i+1, ss)
607 :     fun ppDecl dcl = PrintAsC.output(ppStrm, dcl)
608 :     in
609 :     ppDecl (CL.D_Var([], CL.int32, RN.numStrands,
610 :     SOME(CL.I_Exp(CL.E_Int(IntInf.fromInt nStrands, CL.int32)))));
611 : jhr 624 ppDecl (CL.D_Var([],
612 :     CL.T_Array(CL.T_Ptr(CL.T_Named RN.strandDescTy), SOME nStrands),
613 :     RN.strands,
614 : jhr 573 SOME(CL.I_Array(genInits (0, strands)))))
615 :     end
616 :    
617 : jhr 624 fun generate (baseName, Prog{globals, topDecls, strands, initially}) = let
618 : jhr 527 val fileName = OS.Path.joinBaseExt{base=baseName, ext=SOME "c"}
619 :     val outS = TextIO.openOut fileName
620 :     val ppStrm = PrintAsC.new outS
621 : jhr 533 fun ppDecl dcl = PrintAsC.output(ppStrm, dcl)
622 : jhr 624 val strands = AtomTable.listItems strands
623 : jhr 527 in
624 : jhr 533 List.app ppDecl (List.rev (!globals));
625 :     List.app ppDecl (List.rev (!topDecls));
626 : jhr 624 List.app (fn strand => List.app ppDecl (genStrand strand)) strands;
627 :     genStrandTable (ppStrm, strands);
628 :     ppDecl (!initially);
629 : jhr 527 PrintAsC.close ppStrm;
630 :     TextIO.closeOut outS
631 :     end
632 :    
633 : jhr 519 end
634 :    
635 :     structure CBackEnd = CodeGenFn(CTarget)

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0