Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] Annotation of /branches/pure-cfg/src/compiler/c-target/c-target.sml
ViewVC logotype

Annotation of /branches/pure-cfg/src/compiler/c-target/c-target.sml

Parent Directory Parent Directory | Revision Log Revision Log


Revision 798 - (view) (download)

1 : jhr 519 (* c-target.sml
2 :     *
3 :     * COPYRIGHT (c) 2011 The Diderot Project (http://diderot-language.cs.uchicago.edu)
4 :     * All rights reserved.
5 :     *
6 :     * Generate C code with SSE 4.2 intrinsics.
7 :     *)
8 :    
9 :     structure CTarget : TARGET =
10 :     struct
11 :    
12 : jhr 522 structure CL = CLang
13 : jhr 551 structure RN = RuntimeNames
14 : jhr 522
15 : jhr 551 datatype ty = datatype TargetTy.ty
16 : jhr 519
17 : jhr 623 datatype var = V of (ty * string)
18 :    
19 :     datatype exp = E of CLang.exp * ty
20 :    
21 :     type stm = CL.stm
22 :    
23 : jhr 544 datatype strand = Strand of {
24 :     name : string,
25 :     tyName : string,
26 : jhr 623 state : var list ref,
27 : jhr 654 output : var option ref, (* the strand's output variable (only one for now) *)
28 : jhr 544 code : CL.decl list ref
29 :     }
30 : jhr 525
31 : jhr 527 datatype program = Prog of {
32 :     globals : CL.decl list ref,
33 : jhr 533 topDecls : CL.decl list ref,
34 : jhr 624 strands : strand AtomTable.hash_table,
35 :     initially : CL.decl ref
36 : jhr 527 }
37 :    
38 : jhr 519 (* for SSE, we have 128-bit vectors *)
39 : jhr 551 fun vectorWidth () = !RN.gVectorWid
40 : jhr 519
41 :     (* target types *)
42 : jhr 525 val boolTy = T_Bool
43 :     val intTy = T_Int
44 :     val realTy = T_Real
45 :     fun vecTy 1 = T_Real
46 : jhr 551 | vecTy n = if (n < 1) orelse (!RN.gVectorWid < n)
47 : jhr 525 then raise Size
48 :     else T_Vec n
49 :     fun ivecTy 1 = T_Int
50 : jhr 551 | ivecTy n = if (n < 1) orelse (!RN.gVectorWid < n)
51 : jhr 525 then raise Size
52 :     else T_IVec n
53 : jhr 736 fun tensorTy [] = realTy
54 :     | tensorTy [n] = vecTy n
55 :     | tensorTy [d1, d2] = T_Mat(d1, d2)
56 :     | tensorTy dd = raise Fail "FIXME: order > 2 tensor type"
57 : jhr 548 fun imageTy (ImageInfo.ImgInfo{ty=([], rTy), dim, ...}) = T_Image(dim, rTy)
58 :     fun imageDataTy (ImageInfo.ImgInfo{ty=([], rTy), ...}) = T_Ptr rTy
59 : jhr 534 val stringTy = T_String
60 : jhr 519
61 : jhr 552 val statusTy = CL.T_Named RN.statusTy
62 : jhr 534
63 : jhr 528 (* convert target types to CLang types *)
64 :     fun cvtTy T_Bool = CLang.T_Named "bool"
65 : jhr 534 | cvtTy T_String = CL.charPtr
66 : jhr 551 | cvtTy T_Int = !RN.gIntTy
67 :     | cvtTy T_Real = !RN.gRealTy
68 : jhr 561 | cvtTy (T_Vec n) = CL.T_Named(RN.vecTy n)
69 :     | cvtTy (T_IVec n) = CL.T_Named(RN.ivecTy n)
70 : jhr 683 | cvtTy (T_Mat(n,m)) = CL.T_Named(RN.matTy(n,m))
71 : jhr 561 | cvtTy (T_Image(n, _)) = CL.T_Ptr(CL.T_Named(RN.imageTy n))
72 : jhr 548 | cvtTy (T_Ptr ty) = CL.T_Ptr(CL.T_Num ty)
73 : jhr 528
74 : jhr 548 (* report invalid arguments *)
75 :     fun invalid (name, []) = raise Fail("invaild "^name)
76 :     | invalid (name, args) = let
77 : jhr 623 fun arg2s (E(e, ty)) = concat["(", CL.expToString e, " : ", TargetTy.toString ty, ")"]
78 : jhr 548 val args = String.concatWith ", " (List.map arg2s args)
79 :     in
80 :     raise Fail(concat["invalid arguments to ", name, ": ", args])
81 :     end
82 :    
83 : jhr 525 (* helper functions for checking the types of arguments *)
84 :     fun scalarTy T_Int = true
85 :     | scalarTy T_Real = true
86 :     | scalarTy _ = false
87 : jhr 548 fun numTy T_Int = true
88 :     | numTy T_Real = true
89 :     | numTy (T_Vec _) = true
90 :     | numTy (T_IVec _) = true
91 :     | numTy _ = false
92 : jhr 519
93 : jhr 798 (* integer literal expression *)
94 :     fun intExp (i : int) = CL.mkInt(IntInf.fromInt i, CL.int32)
95 :    
96 : jhr 528 fun newProgram () = (
97 : jhr 551 RN.initTargetSpec();
98 : jhr 528 Prog{
99 : jhr 554 globals = ref [
100 :     CL.D_Verbatim[
101 :     if !Controls.doublePrecision
102 :     then "#define DIDEROT_DOUBLE_PRECISION"
103 :     else "#define DIDEROT_SINGLE_PRECISION",
104 :     "#include \"Diderot/diderot.h\""
105 :     ]],
106 : jhr 533 topDecls = ref [],
107 : jhr 624 strands = AtomTable.mkTable (16, Fail "strand table"),
108 :     initially = ref(CL.D_Comment["missing initially"])
109 : jhr 528 })
110 :    
111 : jhr 618 (* register the global initialization part of a program *)
112 : jhr 533 fun globalInit (Prog{topDecls, ...}, init) = let
113 : jhr 551 val initFn = CL.D_Func([], CL.voidTy, RN.initGlobals, [], init)
114 : jhr 533 in
115 :     topDecls := initFn :: !topDecls
116 :     end
117 :    
118 : jhr 623 (* create and register the initially function for a program *)
119 :     fun initially {
120 : jhr 624 prog = Prog{strands, initially, ...},
121 : jhr 623 isArray : bool,
122 : jhr 624 iterPrefix : stm,
123 : jhr 623 iters : (var * exp * exp) list,
124 : jhr 624 createPrefix : stm,
125 :     strand : Atom.atom,
126 : jhr 623 args : exp list
127 :     } = let
128 : jhr 624 val iterPrefix = (case iterPrefix
129 :     of CL.S_Block stms => stms
130 :     | stm => [stm]
131 :     (* end case *))
132 :     val createPrefix = (case createPrefix
133 :     of CL.S_Block stms => stms
134 :     | stm => [stm]
135 :     (* end case *))
136 :     val name = Atom.toString strand
137 : jhr 623 val nDims = List.length iters
138 :     val worldTy = CL.T_Ptr(CL.T_Named RN.worldTy)
139 :     fun mapi f xs = let
140 :     fun mapf (_, []) = []
141 :     | mapf (i, x::xs) = f(i, x) :: mapf(i+1, xs)
142 :     in
143 :     mapf (0, xs)
144 :     end
145 :     val baseInit = mapi (fn (i, (_, E(e, _), _)) => (i, CL.I_Exp e)) iters
146 :     val sizeInit = mapi
147 :     (fn (i, (V(ty, _), E(lo, _), E(hi, _))) =>
148 :     (i, CL.I_Exp(CL.mkBinOp(CL.mkBinOp(hi, CL.#-, lo), CL.#+, CL.E_Int(1, cvtTy ty))))
149 :     ) iters
150 :     val allocCode = [
151 :     CL.S_Comment["allocate initial block of strands"],
152 :     CL.S_Decl(CL.T_Array(CL.int32, SOME nDims), "base", SOME(CL.I_Array baseInit)),
153 :     CL.S_Decl(CL.T_Array(CL.uint32, SOME nDims), "size", SOME(CL.I_Array sizeInit)),
154 :     CL.S_Decl(worldTy, "wrld",
155 :     SOME(CL.I_Exp(CL.E_Apply(RN.allocInitially, [
156 :     CL.mkUnOp(CL.%&, CL.E_Var(RN.strandDesc name)),
157 :     CL.E_Bool isArray,
158 :     CL.E_Int(IntInf.fromInt nDims, CL.int32),
159 :     CL.E_Var "base",
160 :     CL.E_Var "size"
161 :     ]))))
162 :     ]
163 :     (* create the loop nest for the initially iterations *)
164 :     val indexVar = "ix"
165 : jhr 634 val strandTy = CL.T_Ptr(CL.T_Named(RN.strandTy name))
166 : jhr 623 fun mkLoopNest [] = CL.mkBlock(createPrefix @ [
167 : jhr 634 CL.S_Decl(strandTy, "sp",
168 :     SOME(CL.I_Exp(
169 :     CL.E_Cast(strandTy,
170 :     CL.E_Apply(RN.inState, [CL.E_Var "wrld", CL.E_Var indexVar]))))),
171 : jhr 624 CL.S_Call(RN.strandInit name, CL.E_Var "sp" :: List.map (fn (E(e, _)) => e) args),
172 :     CL.S_Assign(CL.E_Var indexVar, CL.mkBinOp(CL.E_Var indexVar, CL.#+, CL.E_Int(1, CL.uint32)))
173 : jhr 623 ])
174 :     | mkLoopNest ((V(ty, param), E(lo,_), E(hi, _))::iters) = let
175 :     val body = mkLoopNest iters
176 :     in
177 :     CL.S_For(
178 :     [(cvtTy ty, param, lo)],
179 :     CL.mkBinOp(CL.E_Var param, CL.#<=, hi),
180 :     [CL.mkPostOp(CL.E_Var param, CL.^++)],
181 :     body)
182 :     end
183 :     val iterCode = [
184 :     CL.S_Comment["initially"],
185 :     CL.S_Decl(CL.uint32, indexVar, SOME(CL.I_Exp(CL.E_Int(0, CL.uint32)))),
186 :     mkLoopNest iters
187 :     ]
188 :     val body = CL.mkBlock(iterPrefix @ allocCode @ iterCode @ [CL.S_Return(SOME(CL.E_Var "wrld"))])
189 :     val initFn = CL.D_Func([], worldTy, RN.initially, [], body)
190 : jhr 618 in
191 : jhr 624 initially := initFn
192 : jhr 618 end
193 :    
194 : jhr 525 structure Var =
195 :     struct
196 : jhr 528 fun global (Prog{globals, ...}, ty, name) = (
197 : jhr 573 globals := CL.D_Var([], cvtTy ty, name, NONE) :: !globals;
198 : jhr 623 V(ty, name))
199 :     fun param (ty, name) = V(ty, name)
200 : jhr 544 fun state (Strand{state, ...}, ty, name) = (
201 : jhr 623 state := V(ty, name) :: !state;
202 :     V(ty, name))
203 :     fun var (ty, name) = V(ty, name)
204 : jhr 554 local
205 :     val count = ref 0
206 :     fun freshName prefix = let
207 :     val n = !count
208 :     in
209 :     count := n+1;
210 :     concat[prefix, "_", Int.toString n]
211 :     end
212 :     in
213 : jhr 623 fun tmp ty = V(ty, freshName "tmp")
214 : jhr 554 fun fresh prefix = freshName prefix
215 :     end (* local *)
216 : jhr 519 end
217 :    
218 :     (* expression construction *)
219 : jhr 525 structure Expr =
220 :     struct
221 : jhr 549 (* return true if the given expression from is allowed as a subexpression *)
222 :     fun allowedInline _ = true (* FIXME *)
223 :    
224 : jhr 519 (* variable references *)
225 : jhr 623 fun global (V(ty, x)) = E(CL.mkVar x, ty)
226 :     fun getState (V(ty, x)) = E(CL.mkIndirect(CL.mkVar "selfIn", x), ty)
227 :     fun param (V(ty, x)) = E(CL.mkVar x, ty)
228 :     fun var (V(ty, x)) = E(CL.mkVar x, ty)
229 : jhr 525
230 : jhr 519 (* literals *)
231 : jhr 623 fun intLit n = E(CL.mkInt(n, !RN.gIntTy), intTy)
232 :     fun floatLit f = E(CL.mkFlt(f, !RN.gRealTy), realTy)
233 :     fun stringLit s = E(CL.mkStr s, stringTy)
234 :     fun boolLit b = E(CL.mkBool b, boolTy)
235 : jhr 525
236 : jhr 561 (* select from a vector. We have to cast to the corresponding union type and then
237 :     * select from the array field.
238 :     *)
239 :     local
240 :     fun sel (tyCode, field, ty) (i, e, n) =
241 :     if (i < 0) orelse (n <= i)
242 :     then raise Subscript
243 :     else let
244 :     val unionTy = CL.T_Named(concat["union", Int.toString n, !tyCode, "_t"])
245 :     val e1 = CL.mkCast(unionTy, e)
246 :     val e2 = CL.mkSelect(e1, field)
247 :     in
248 : jhr 798 E(CL.mkSubscript(e2, intExp i), ty)
249 : jhr 561 end
250 :     val selF = sel (RN.gRealSuffix, "r", T_Real)
251 :     val selI = sel (RN.gIntSuffix, "i", T_Int)
252 :     in
253 : jhr 654 fun ivecIndex (e, d, i) = let val E(e', _) = selI(i, e, d) in e' end
254 :     fun vecIndex (e, d, i) = let val E(e', _) = selF(i, e, d) in e' end
255 : jhr 623 fun select (i, E(e, T_Vec n)) = selF (i, e, n)
256 :     | select (i, E(e, T_IVec n)) = selI (i, e, n)
257 : jhr 548 | select (_, x) = invalid("select", [x])
258 : jhr 561 end (* local *)
259 : jhr 525
260 : jhr 705 fun subscript (E(e1, ty), E(e2, T_Int)) = let
261 :     val (n, tyCode, elemTy) = (case ty
262 :     of T_Vec n => (n, !RN.gRealSuffix, T_Real)
263 :     | T_IVec n => (n, !RN.gIntSuffix, T_Int)
264 :     (* end case *))
265 :     val unionTy = CL.T_Named(concat["union", Int.toString n, tyCode, "_t"])
266 : jhr 732 val vecExp = CL.mkSelect(CL.mkCast(unionTy, e1), "r")
267 : jhr 705 in
268 : jhr 732 E(CL.mkSubscript(vecExp, e2), elemTy)
269 : jhr 705 end
270 :    
271 : jhr 519 (* vector (and scalar) arithmetic *)
272 : jhr 525 local
273 :     fun checkTys (ty1, ty2) = (ty1 = ty2) andalso numTy ty1
274 : jhr 623 fun binop rator (E(e1, ty1), E(e2, ty2)) =
275 : jhr 525 if checkTys (ty1, ty2)
276 : jhr 623 then E(CL.mkBinOp(e1, rator, e2), ty1)
277 : jhr 548 else invalid (
278 :     concat["binary operator \"", CL.binopToString rator, "\""],
279 : jhr 623 [E(e1, ty1), E(e2, ty2)])
280 : jhr 525 in
281 : jhr 623 fun add (E(e1, ty as T_Ptr _), E(e2, T_Int)) = E(CL.mkBinOp(e1, CL.#+, e2), ty)
282 : jhr 548 | add args = binop CL.#+ args
283 : jhr 623 fun sub (E(e1, ty as T_Ptr _), E(e2, T_Int)) = E(CL.mkBinOp(e1, CL.#-, e2), ty)
284 : jhr 548 | sub args = binop CL.#- args
285 : jhr 544 (* NOTE: multiplication and division are also used for scaling *)
286 : jhr 623 fun mul (E(e1, T_Real), E(e2, T_Vec n)) =
287 :     E(CL.E_Apply(RN.scale n, [e1, e2]), T_Vec n)
288 : jhr 544 | mul args = binop CL.#* args
289 : jhr 623 fun divide (E(e1, T_Vec n), E(e2, T_Real)) = let
290 :     val E(one, _) = floatLit FloatLit.one
291 :     in
292 :     E(CL.E_Apply(RN.scale n, [CL.mkBinOp(one, CL.#/, e2), e1]), T_Vec n)
293 :     end
294 : jhr 544 | divide args = binop CL.#/ args
295 : jhr 525 end (* local *)
296 : jhr 623 fun neg (E(e, T_Bool)) = raise Fail "invalid argument to neg"
297 :     | neg (E(e, ty)) = E(CL.mkUnOp(CL.%-, e), ty)
298 : jhr 525
299 : jhr 623 fun abs (E(e, T_Int)) = E(CL.mkApply("abs", [e]), T_Int) (* FIXME: not the right type for 64-bit ints *)
300 :     | abs (E(e, T_Real)) = E(CL.mkApply("fabs" ^ !RN.gRealSuffix, [e]), T_Real)
301 :     | abs (E(e, T_Vec n)) = raise Fail "FIXME: Expr.abs"
302 :     | abs (E(e, T_IVec n)) = raise Fail "FIXME: Expr.abs"
303 : jhr 525 | abs _ = raise Fail "invalid argument to abs"
304 :    
305 : jhr 623 fun dot (E(e1, T_Vec n1), E(e2, T_Vec n2)) = E(CL.E_Apply(RN.dot n1, [e1, e2]), T_Real)
306 : jhr 525 | dot _ = raise Fail "invalid argument to dot"
307 :    
308 : jhr 623 fun cross (E(e1, T_Vec 3), E(e2, T_Vec 3)) = E(CL.E_Apply(RN.cross(), [e1, e2]), T_Vec 3)
309 : jhr 525 | cross _ = raise Fail "invalid argument to cross"
310 :    
311 : jhr 623 fun length (E(e, T_Vec n)) = E(CL.E_Apply(RN.length n, [e]), T_Real)
312 : jhr 525 | length _ = raise Fail "invalid argument to length"
313 :    
314 : jhr 623 fun normalize (E(e, T_Vec n)) = E(CL.E_Apply(RN.normalize n, [e]), T_Vec n)
315 : jhr 525 | normalize _ = raise Fail "invalid argument to length"
316 :    
317 : jhr 683 (* matrix operations *)
318 :     fun trace (E(e, T_Mat(n,m))) = if (n = m) andalso (1 < n) andalso (m <= 4)
319 :     then E(CL.E_Apply(RN.trace n, [e]), T_Real)
320 :     else raise Fail "invalid matrix argument for trace"
321 :     | trace _ = raise Fail "invalid argument to trace"
322 :    
323 : jhr 519 (* comparisons *)
324 : jhr 525 local
325 :     fun checkTys (ty1, ty2) =
326 :     (ty1 = ty2) andalso scalarTy ty1
327 : jhr 623 fun cmpop rator (E(e1, ty1), E(e2, ty2)) =
328 : jhr 525 if checkTys (ty1, ty2)
329 : jhr 623 then E(CL.mkBinOp(e1, rator, e2), T_Bool)
330 : jhr 548 else invalid (
331 :     concat["compare operator \"", CL.binopToString rator, "\""],
332 : jhr 623 [E(e1, ty1), E(e2, ty2)])
333 : jhr 525 in
334 :     val lt = cmpop CL.#<
335 :     val lte = cmpop CL.#<=
336 :     val equ = cmpop CL.#==
337 :     val neq = cmpop CL.#!=
338 :     val gte = cmpop CL.#>=
339 :     val gt = cmpop CL.#>
340 :     end (* local *)
341 :    
342 : jhr 519 (* logical connectives *)
343 : jhr 623 fun not (E(e, T_Bool)) = E(CL.mkUnOp(CL.%!, e), T_Bool)
344 : jhr 525 | not _ = raise Fail "invalid argument to not"
345 : jhr 623 fun && (E(e1, T_Bool), E(e2, T_Bool)) = E(CL.mkBinOp(e1, CL.#&&, e2), T_Bool)
346 : jhr 525 | && _ = raise Fail "invalid arguments to &&"
347 : jhr 623 fun || (E(e1, T_Bool), E(e2, T_Bool)) = E(CL.mkBinOp(e1, CL.#||, e2), T_Bool)
348 : jhr 525 | || _ = raise Fail "invalid arguments to ||"
349 :    
350 : jhr 754 (* misc functions *)
351 : jhr 525 local
352 :     fun checkTys (ty1, ty2) = (ty1 = ty2) andalso scalarTy ty1
353 : jhr 623 fun binFn f (E(e1, ty1), E(e2, ty2)) =
354 : jhr 525 if checkTys (ty1, ty2)
355 : jhr 623 then E(CL.mkApply(f ty1, [e1, e2]), ty1)
356 : jhr 525 else raise Fail "invalid arguments to binary function"
357 :     in
358 : jhr 561 val min = binFn RN.min
359 :     val max = binFn RN.max
360 : jhr 754 fun lerp (E(e1, ty1), E(e2, ty2), E(e3, T_Real)) =
361 :     if (ty1 = ty2)
362 :     then (case ty1
363 :     of T_Real => E(CL.mkApply(RN.lerp 0, [e1, e2, e3]), T_Real)
364 :     | T_Vec n => E(CL.mkApply(RN.lerp n, [e1, e2, e3]), ty1)
365 :     | ty => raise Fail(concat["lerp<", TargetTy.toString ty, "> not supported"])
366 :     (* end case *))
367 :     else raise Fail "invalid arguments to lerp"
368 :     | lerp _ = raise Fail "invalid arguments to lerp"
369 : jhr 525 end (* local *)
370 :    
371 : jhr 551 (* rounding *)
372 : jhr 623 fun trunc (E(e, ty)) = E(CL.mkApply(RN.addTySuffix("trunc", ty), [e]), ty)
373 :     fun round (E(e, ty)) = E(CL.mkApply(RN.addTySuffix("round", ty), [e]), ty)
374 :     fun floor (E(e, ty)) = E(CL.mkApply(RN.addTySuffix("floor", ty), [e]), ty)
375 :     fun ceil (E(e, ty)) = E(CL.mkApply(RN.addTySuffix("ceil", ty), [e]), ty)
376 : jhr 551
377 : jhr 519 (* conversions *)
378 : jhr 623 fun toInt (E(e, T_Real)) = E(CL.mkCast(!RN.gIntTy, e), T_Int)
379 :     | toInt (E(e, T_Vec n)) = E(CL.mkApply(RN.vecftoi n, [e]), ivecTy n)
380 : jhr 565 | toInt e = invalid ("toInt", [e])
381 : jhr 623 fun toReal (E(e, T_Int)) = E(CL.mkCast(!RN.gRealTy, e), T_Real)
382 : jhr 548 | toReal e = invalid ("toReal", [e])
383 : jhr 525
384 : jhr 519 (* runtime system hooks *)
385 : jhr 623 fun imageAddr (E(e, T_Image(_, rTy))) = let
386 : jhr 561 val cTy = CL.T_Ptr(CL.T_Num rTy)
387 : jhr 528 in
388 : jhr 623 E(CL.mkCast(cTy, CL.mkIndirect(e, "data")), T_Ptr rTy)
389 : jhr 528 end
390 : jhr 548 | imageAddr a = invalid("imageAddr", [a])
391 : jhr 623 fun getImgData (E(e, T_Ptr rTy)) = let
392 : jhr 551 val realTy as CL.T_Num rTy' = !RN.gRealTy
393 : jhr 548 val e = CL.E_UnOp(CL.%*, e)
394 :     in
395 :     if (rTy' = rTy)
396 : jhr 623 then E(e, T_Real)
397 :     else E(CL.E_Cast(realTy, e), T_Real)
398 : jhr 548 end
399 :     | getImgData a = invalid("getImgData", [a])
400 : jhr 623 fun posToImgSpace (E(img, T_Image(d, _)), E(pos, T_Vec n)) = let
401 : jhr 551 val e = CL.mkApply(RN.toImageSpace d, [img, pos])
402 : jhr 548 in
403 : jhr 623 E(e, T_Vec n)
404 : jhr 548 end
405 :     | posToImgSpace (a, b) = invalid("posToImgSpace", [a, b])
406 : jhr 623 fun inside (E(pos, T_Vec n), E(img, T_Image(d, _)), s) = let
407 : jhr 798 val e = CL.mkApply(RN.inside d, [pos, img, intExp s])
408 : jhr 547 in
409 : jhr 623 E(e, T_Bool)
410 : jhr 547 end
411 : jhr 548 | inside (a, b, _) = invalid("inside", [a, b])
412 : jhr 519
413 : jhr 695 (* other basis functions *)
414 :     local
415 :     val basis = [
416 :     (ILBasis.atan2, fn [E(e1, T_Real), E(e2, T_Real)] => SOME[e1, e2] | _ => NONE, T_Real),
417 :     (ILBasis.cos, fn [E(e, T_Real)] => SOME[e] | _ => NONE, T_Real),
418 :     (ILBasis.pow, fn [E(e1, T_Real), E(e2, T_Real)] => SOME[e1, e2] | _ => NONE, T_Real),
419 :     (ILBasis.sin, fn [E(e, T_Real)] => SOME[e] | _ => NONE, T_Real),
420 :     (ILBasis.sqrt, fn [E(e, T_Real)] => SOME[e] | _ => NONE, T_Real),
421 :     (ILBasis.tan, fn [E(e, T_Real)] => SOME[e] | _ => NONE, T_Real)
422 :     ]
423 :     fun mkLookup suffix = let
424 :     val tbl = ILBasis.Tbl.mkTable (16, Fail "basis table")
425 :     fun ins (f, chkTy, resTy) =
426 :     ILBasis.Tbl.insert tbl
427 :     (f, (ILBasis.toString f ^ suffix, chkTy, resTy))
428 :     in
429 :     List.app ins basis;
430 :     ILBasis.Tbl.lookup tbl
431 :     end
432 :     val fLookup = mkLookup "f"
433 :     val dLookup = mkLookup ""
434 :     in
435 :     fun apply (f, args) = let
436 :     val (f', chkArgs, resTy) = if !Controls.doublePrecision then dLookup f else fLookup f
437 :     in
438 :     case chkArgs args
439 :     of SOME args => E(CL.mkApply(f', args), resTy)
440 :     | NONE => raise Fail("invalid arguments for "^ILBasis.toString f)
441 :     end
442 :     end (* local *)
443 : jhr 547 end (* Expr *)
444 :    
445 : jhr 519 (* statement construction *)
446 : jhr 525 structure Stmt =
447 :     struct
448 :     val comment = CL.S_Comment
449 : jhr 623 fun assignState (V(_, x), E(e, _)) =
450 : jhr 547 CL.mkAssign(CL.mkIndirect(CL.mkVar "selfOut", x), e)
451 : jhr 623 fun assign (V(_, x), E(e, _)) = CL.mkAssign(CL.mkVar x, e)
452 :     fun decl (V(ty, x), SOME(E(e, _))) = CL.mkDecl(cvtTy ty, x, SOME(CL.I_Exp e))
453 :     | decl (V(ty, x), NONE) = CL.mkDecl(cvtTy ty, x, NONE)
454 : jhr 525 val block = CL.mkBlock
455 : jhr 623 fun ifthen (E(e, T_Bool), s1) = CL.mkIfThen(e, s1)
456 :     fun ifthenelse (E(e, T_Bool), s1, s2) = CL.mkIfThenElse(e, s1, s2)
457 :     fun for (V(ty, x), E(lo, _), E(hi, _), body) = CL.mkFor(
458 : jhr 617 [(cvtTy ty, x, lo)],
459 :     CL.mkBinOp(CL.mkVar x, CL.#<=, hi),
460 :     [CL.mkPostOp(CL.mkVar x, CL.^++)],
461 :     body)
462 : jhr 534 (* special Diderot forms *)
463 : jhr 623 fun cons (V(T_Vec n, x), args : exp list) =
464 : jhr 798 [CL.mkAssign(CL.mkVar x, CL.mkApply(RN.mkVec n, List.map (fn E(e, _) => e) args))]
465 :     | cons (V(T_Mat _, x), args) = let
466 :     val x = CL.mkVar x
467 :     (* matrices are represented as arrays of union<d><ty>_t vectors *)
468 :     fun doRows (_, []) = []
469 :     | doRows (i, E(e, _)::es) =
470 :     CL.mkAssign(CL.mkSelect(CL.mkSubscript(x, intExp i), "v"),e)
471 :     :: doRows (i+1, es)
472 :     in
473 :     doRows (0, args)
474 :     end
475 : jhr 553 | cons _ = raise Fail "bogus cons"
476 : jhr 623 fun getImgData (V(T_Vec n, x), E(e, T_Ptr rTy)) = let
477 : jhr 554 val addr = Var.fresh "vp"
478 : jhr 561 val needsCast = (CL.T_Num rTy <> !RN.gRealTy)
479 :     fun mkLoad i = let
480 : jhr 798 val e = CL.mkSubscript(CL.mkVar addr, intExp i)
481 : jhr 561 in
482 :     if needsCast then CL.mkCast(!RN.gRealTy, e) else e
483 :     end
484 : jhr 554 in [
485 : jhr 623 CL.mkDecl(CL.T_Ptr(CL.T_Num rTy), addr, SOME(CL.I_Exp e)),
486 : jhr 561 CL.mkAssign(CL.mkVar x,
487 :     CL.mkApply(RN.mkVec n, List.tabulate (n, mkLoad)))
488 : jhr 554 ] end
489 :     | getImgData _ = raise Fail "bogus getImgData"
490 :     local
491 :     fun checkSts mkDecl = let
492 :     val sts = Var.fresh "sts"
493 :     in
494 :     mkDecl sts @
495 :     [CL.mkIfThen(
496 :     CL.mkBinOp(CL.mkVar "DIDEROT_OK", CL.#!=, CL.mkVar sts),
497 : jhr 798 CL.mkCall("exit", [intExp 1]))]
498 : jhr 554 end
499 :     in
500 : jhr 623 fun loadImage (V(_, lhs), dim, E(name, _)) = checkSts (fn sts => let
501 : jhr 551 val imgTy = CL.T_Named(RN.imageTy dim)
502 :     val loadFn = RN.loadImage dim
503 : jhr 534 in [
504 :     CL.S_Decl(
505 :     statusTy, sts,
506 : jhr 623 SOME(CL.I_Exp(CL.E_Apply(loadFn, [name, CL.mkUnOp(CL.%&, CL.E_Var lhs)]))))
507 : jhr 554 ] end)
508 : jhr 623 fun input (V(ty, lhs), name, optDflt) = checkSts (fn sts => let
509 :     val inputFn = RN.input ty
510 :     val lhs = CL.E_Var lhs
511 : jhr 534 val (initCode, hasDflt) = (case optDflt
512 : jhr 623 of SOME(E(e, _)) => ([CL.S_Assign(lhs, e)], true)
513 : jhr 534 | NONE => ([], false)
514 :     (* end case *))
515 :     val code = [
516 :     CL.S_Decl(
517 :     statusTy, sts,
518 : jhr 623 SOME(CL.I_Exp(CL.E_Apply(inputFn, [
519 : jhr 534 CL.E_Str name, CL.mkUnOp(CL.%&, lhs), CL.mkBool hasDflt
520 : jhr 623 ]))))
521 : jhr 534 ]
522 :     in
523 :     initCode @ code
524 : jhr 554 end)
525 :     end (* local *)
526 : jhr 564 fun exit () = CL.mkReturn NONE
527 :     fun active () = CL.mkReturn(SOME(CL.mkVar RN.kActive))
528 :     fun stabilize () = CL.mkReturn(SOME(CL.mkVar RN.kStabilize))
529 : jhr 562 fun die () = CL.mkReturn(SOME(CL.mkVar RN.kDie))
530 : jhr 519 end
531 :    
532 : jhr 544 structure Strand =
533 :     struct
534 :     fun define (Prog{strands, ...}, strandId) = let
535 : jhr 624 val name = Atom.toString strandId
536 : jhr 544 val strand = Strand{
537 : jhr 624 name = name,
538 :     tyName = RN.strandTy name,
539 : jhr 544 state = ref [],
540 : jhr 654 output = ref NONE,
541 : jhr 544 code = ref []
542 :     }
543 :     in
544 : jhr 624 AtomTable.insert strands (strandId, strand);
545 : jhr 544 strand
546 :     end
547 :    
548 : jhr 624 (* return the strand with the given name *)
549 :     fun lookup (Prog{strands, ...}, strandId) = AtomTable.lookup strands strandId
550 :    
551 : jhr 544 (* register the strand-state initialization code. The variables are the strand
552 :     * parameters.
553 :     *)
554 :     fun init (Strand{name, tyName, code, ...}, params, init) = let
555 : jhr 551 val fName = RN.strandInit name
556 : jhr 544 val params =
557 : jhr 547 CL.PARAM([], CL.T_Ptr(CL.T_Named tyName), "selfOut") ::
558 : jhr 623 List.map (fn (V(ty, x)) => CL.PARAM([], cvtTy ty, x)) params
559 : jhr 544 val initFn = CL.D_Func([], CL.voidTy, fName, params, init)
560 :     in
561 :     code := initFn :: !code
562 :     end
563 : jhr 547
564 :     (* register a strand method *)
565 :     fun method (Strand{name, tyName, code, ...}, methName, body) = let
566 :     val fName = concat[name, "_", methName]
567 :     val params = [
568 :     CL.PARAM([], CL.T_Ptr(CL.T_Named tyName), "selfIn"),
569 :     CL.PARAM([], CL.T_Ptr(CL.T_Named tyName), "selfOut")
570 :     ]
571 : jhr 654 val methFn = CL.D_Func(["static"], CL.int32, fName, params, body)
572 : jhr 547 in
573 :     code := methFn :: !code
574 :     end
575 : jhr 654
576 :     fun output (Strand{output, ...}, x) = (case !output
577 :     of NONE => output := SOME x
578 :     | _ => raise Fail "multiple outputs are not supported yet"
579 :     (* end case *))
580 : jhr 544 end (* Strand *)
581 :    
582 : jhr 654 fun genStrand (Strand{name, tyName, state, output, code}) = let
583 : jhr 624 (* the type declaration for the strand's state struct *)
584 : jhr 544 val selfTyDef = CL.D_StructDef(
585 : jhr 623 List.rev (List.map (fn V(ty, x) => (cvtTy ty, x)) (!state)),
586 : jhr 544 tyName)
587 : jhr 654 (* the print function *)
588 :     val prFnName = concat[name, "_print"]
589 :     val prFn = let
590 :     val params = [
591 :     CL.PARAM([], CL.T_Ptr(CL.T_Named "FILE"), "outS"),
592 :     CL.PARAM([], CL.T_Ptr(CL.T_Named tyName), "self")
593 :     ]
594 :     val SOME(V(ty, x)) = !output
595 :     val outState = CL.mkIndirect(CL.mkVar "self", x)
596 :     val prArgs = (case ty
597 :     of TargetTy.T_Int => [CL.E_Str(!RN.gIntFormat ^ "\n"), outState]
598 :     | TargetTy.T_Real => [CL.E_Str "%f\n", outState]
599 :     | TargetTy.T_Vec d => let
600 :     val fmt = CL.E_Str(
601 :     String.concatWith " " (List.tabulate(d, fn _ => "%f"))
602 :     ^ "\n")
603 : jhr 656 val args = List.tabulate (d, fn i => Expr.vecIndex(outState, d, i))
604 : jhr 654 in
605 :     fmt :: args
606 :     end
607 :     | TargetTy.T_IVec d => let
608 :     val fmt = CL.E_Str(
609 :     String.concatWith " " (List.tabulate(d, fn _ => !RN.gIntFormat))
610 :     ^ "\n")
611 : jhr 656 val args = List.tabulate (d, fn i => Expr.ivecIndex(outState, d, i))
612 : jhr 654 in
613 :     fmt :: args
614 :     end
615 :     | _ => raise Fail("genStrand: unsupported output type " ^ TargetTy.toString ty)
616 :     (* end case *))
617 :     in
618 :     CL.D_Func(["static"], CL.voidTy, prFnName, params,
619 :     CL.S_Call("fprintf", CL.mkVar "outS" :: prArgs))
620 :     end
621 : jhr 624 (* the strand's descriptor object *)
622 :     val descI = let
623 : jhr 573 fun fnPtr (ty, f) = CL.I_Exp(CL.mkCast(CL.T_Named ty, CL.mkVar f))
624 :     in
625 :     CL.I_Struct[
626 :     ("name", CL.I_Exp(CL.E_Str name)),
627 :     ("stateSzb", CL.I_Exp(CL.mkSizeof(CL.T_Named(RN.strandTy name)))),
628 : jhr 654 ("update", fnPtr("update_method_t", name ^ "_update")),
629 :     ("print", fnPtr("print_method_t", prFnName))
630 : jhr 573 ]
631 :     end
632 : jhr 624 val desc = CL.D_Var([], CL.T_Named RN.strandDescTy, RN.strandDesc name, SOME descI)
633 :     in
634 : jhr 654 selfTyDef :: List.rev (desc :: prFn :: !code)
635 : jhr 624 end
636 :    
637 :     (* generate the table of strand descriptors *)
638 :     fun genStrandTable (ppStrm, strands) = let
639 :     val nStrands = length strands
640 :     fun genInit (Strand{name, ...}) = CL.I_Exp(CL.mkUnOp(CL.%&, CL.E_Var(RN.strandDesc name)))
641 : jhr 573 fun genInits (_, []) = []
642 :     | genInits (i, s::ss) = (i, genInit s) :: genInits(i+1, ss)
643 :     fun ppDecl dcl = PrintAsC.output(ppStrm, dcl)
644 :     in
645 :     ppDecl (CL.D_Var([], CL.int32, RN.numStrands,
646 :     SOME(CL.I_Exp(CL.E_Int(IntInf.fromInt nStrands, CL.int32)))));
647 : jhr 624 ppDecl (CL.D_Var([],
648 :     CL.T_Array(CL.T_Ptr(CL.T_Named RN.strandDescTy), SOME nStrands),
649 :     RN.strands,
650 : jhr 573 SOME(CL.I_Array(genInits (0, strands)))))
651 :     end
652 :    
653 : jhr 731 fun genSrc (baseName, Prog{globals, topDecls, strands, initially}) = let
654 : jhr 527 val fileName = OS.Path.joinBaseExt{base=baseName, ext=SOME "c"}
655 :     val outS = TextIO.openOut fileName
656 :     val ppStrm = PrintAsC.new outS
657 : jhr 533 fun ppDecl dcl = PrintAsC.output(ppStrm, dcl)
658 : jhr 624 val strands = AtomTable.listItems strands
659 : jhr 527 in
660 : jhr 533 List.app ppDecl (List.rev (!globals));
661 :     List.app ppDecl (List.rev (!topDecls));
662 : jhr 624 List.app (fn strand => List.app ppDecl (genStrand strand)) strands;
663 :     genStrandTable (ppStrm, strands);
664 :     ppDecl (!initially);
665 : jhr 527 PrintAsC.close ppStrm;
666 :     TextIO.closeOut outS
667 :     end
668 :    
669 : jhr 731 (* FIXME: control flags that should go somewhere else *)
670 :     val debug = ref false
671 :     val verbose = ref true
672 :    
673 :     fun system cmd = (
674 :     if !verbose
675 :     then print(cmd ^ "\n")
676 :     else ();
677 :     if OS.Process.isSuccess(OS.Process.system cmd)
678 :     then ()
679 :     else raise Fail "error compiling/linking")
680 :    
681 :     fun compile baseName = let
682 :     val cFile = OS.Path.joinBaseExt{base=baseName, ext=SOME"c"}
683 :     val cflags = if !debug
684 :     then Paths.cflags
685 :     else String.concatWith " " ["-NDEBUG", Paths.cflags]
686 :     val cmd = String.concatWith " " [
687 :     Paths.cc, "-c", cflags,
688 :     "-I" ^ Paths.diderotInclude, "-I" ^ Paths.teemInclude,
689 :     cFile
690 :     ]
691 :     in
692 :     system cmd
693 :     end
694 :    
695 :     fun link baseName = let
696 :     val objFile = OS.Path.joinBaseExt{base=baseName, ext=SOME"o"}
697 :     val exeFile = baseName
698 :     val cmd = String.concatWith " " [
699 :     Paths.cc, "-o", exeFile, objFile,
700 :     "-L" ^ Paths.teemLib, "-lteem",
701 :     OS.Path.concat(Paths.diderotLib, "diderot-lib.o")
702 :     ]
703 :     in
704 :     system cmd
705 :     end
706 :    
707 :     fun generate (baseName, prog) = (
708 :     genSrc (baseName, prog);
709 :     compile baseName;
710 :     link baseName)
711 :    
712 : jhr 519 end
713 :    
714 :     structure CBackEnd = CodeGenFn(CTarget)

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0