Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] Annotation of /branches/pure-cfg/src/compiler/c-target/c-target.sml
ViewVC logotype

Annotation of /branches/pure-cfg/src/compiler/c-target/c-target.sml

Parent Directory Parent Directory | Revision Log Revision Log


Revision 806 - (view) (download)

1 : jhr 519 (* c-target.sml
2 :     *
3 :     * COPYRIGHT (c) 2011 The Diderot Project (http://diderot-language.cs.uchicago.edu)
4 :     * All rights reserved.
5 :     *
6 :     * Generate C code with SSE 4.2 intrinsics.
7 :     *)
8 :    
9 :     structure CTarget : TARGET =
10 :     struct
11 :    
12 : jhr 522 structure CL = CLang
13 : jhr 551 structure RN = RuntimeNames
14 : jhr 522
15 : jhr 551 datatype ty = datatype TargetTy.ty
16 : jhr 519
17 : jhr 623 datatype var = V of (ty * string)
18 :    
19 :     datatype exp = E of CLang.exp * ty
20 :    
21 :     type stm = CL.stm
22 :    
23 : jhr 544 datatype strand = Strand of {
24 :     name : string,
25 :     tyName : string,
26 : jhr 623 state : var list ref,
27 : jhr 654 output : var option ref, (* the strand's output variable (only one for now) *)
28 : jhr 544 code : CL.decl list ref
29 :     }
30 : jhr 525
31 : jhr 527 datatype program = Prog of {
32 :     globals : CL.decl list ref,
33 : jhr 533 topDecls : CL.decl list ref,
34 : jhr 624 strands : strand AtomTable.hash_table,
35 :     initially : CL.decl ref
36 : jhr 527 }
37 :    
38 : jhr 519 (* for SSE, we have 128-bit vectors *)
39 : jhr 551 fun vectorWidth () = !RN.gVectorWid
40 : jhr 519
41 :     (* target types *)
42 : jhr 525 val boolTy = T_Bool
43 :     val intTy = T_Int
44 :     val realTy = T_Real
45 :     fun vecTy 1 = T_Real
46 : jhr 551 | vecTy n = if (n < 1) orelse (!RN.gVectorWid < n)
47 : jhr 525 then raise Size
48 :     else T_Vec n
49 :     fun ivecTy 1 = T_Int
50 : jhr 551 | ivecTy n = if (n < 1) orelse (!RN.gVectorWid < n)
51 : jhr 525 then raise Size
52 :     else T_IVec n
53 : jhr 736 fun tensorTy [] = realTy
54 :     | tensorTy [n] = vecTy n
55 :     | tensorTy [d1, d2] = T_Mat(d1, d2)
56 :     | tensorTy dd = raise Fail "FIXME: order > 2 tensor type"
57 : jhr 548 fun imageTy (ImageInfo.ImgInfo{ty=([], rTy), dim, ...}) = T_Image(dim, rTy)
58 :     fun imageDataTy (ImageInfo.ImgInfo{ty=([], rTy), ...}) = T_Ptr rTy
59 : jhr 534 val stringTy = T_String
60 : jhr 519
61 : jhr 552 val statusTy = CL.T_Named RN.statusTy
62 : jhr 534
63 : jhr 528 (* convert target types to CLang types *)
64 :     fun cvtTy T_Bool = CLang.T_Named "bool"
65 : jhr 534 | cvtTy T_String = CL.charPtr
66 : jhr 551 | cvtTy T_Int = !RN.gIntTy
67 :     | cvtTy T_Real = !RN.gRealTy
68 : jhr 561 | cvtTy (T_Vec n) = CL.T_Named(RN.vecTy n)
69 :     | cvtTy (T_IVec n) = CL.T_Named(RN.ivecTy n)
70 : jhr 683 | cvtTy (T_Mat(n,m)) = CL.T_Named(RN.matTy(n,m))
71 : jhr 561 | cvtTy (T_Image(n, _)) = CL.T_Ptr(CL.T_Named(RN.imageTy n))
72 : jhr 548 | cvtTy (T_Ptr ty) = CL.T_Ptr(CL.T_Num ty)
73 : jhr 528
74 : jhr 548 (* report invalid arguments *)
75 :     fun invalid (name, []) = raise Fail("invaild "^name)
76 :     | invalid (name, args) = let
77 : jhr 623 fun arg2s (E(e, ty)) = concat["(", CL.expToString e, " : ", TargetTy.toString ty, ")"]
78 : jhr 548 val args = String.concatWith ", " (List.map arg2s args)
79 :     in
80 :     raise Fail(concat["invalid arguments to ", name, ": ", args])
81 :     end
82 :    
83 : jhr 525 (* helper functions for checking the types of arguments *)
84 :     fun scalarTy T_Int = true
85 :     | scalarTy T_Real = true
86 :     | scalarTy _ = false
87 : jhr 548 fun numTy T_Int = true
88 :     | numTy T_Real = true
89 :     | numTy (T_Vec _) = true
90 :     | numTy (T_IVec _) = true
91 :     | numTy _ = false
92 : jhr 519
93 : jhr 798 (* integer literal expression *)
94 :     fun intExp (i : int) = CL.mkInt(IntInf.fromInt i, CL.int32)
95 :    
96 : jhr 528 fun newProgram () = (
97 : jhr 551 RN.initTargetSpec();
98 : jhr 528 Prog{
99 : jhr 554 globals = ref [
100 :     CL.D_Verbatim[
101 :     if !Controls.doublePrecision
102 :     then "#define DIDEROT_DOUBLE_PRECISION"
103 :     else "#define DIDEROT_SINGLE_PRECISION",
104 :     "#include \"Diderot/diderot.h\""
105 :     ]],
106 : jhr 533 topDecls = ref [],
107 : jhr 624 strands = AtomTable.mkTable (16, Fail "strand table"),
108 :     initially = ref(CL.D_Comment["missing initially"])
109 : jhr 528 })
110 :    
111 : jhr 618 (* register the global initialization part of a program *)
112 : jhr 533 fun globalInit (Prog{topDecls, ...}, init) = let
113 : jhr 551 val initFn = CL.D_Func([], CL.voidTy, RN.initGlobals, [], init)
114 : jhr 533 in
115 :     topDecls := initFn :: !topDecls
116 :     end
117 :    
118 : jhr 623 (* create and register the initially function for a program *)
119 :     fun initially {
120 : jhr 624 prog = Prog{strands, initially, ...},
121 : jhr 623 isArray : bool,
122 : jhr 624 iterPrefix : stm,
123 : jhr 623 iters : (var * exp * exp) list,
124 : jhr 624 createPrefix : stm,
125 :     strand : Atom.atom,
126 : jhr 623 args : exp list
127 :     } = let
128 : jhr 624 val iterPrefix = (case iterPrefix
129 :     of CL.S_Block stms => stms
130 :     | stm => [stm]
131 :     (* end case *))
132 :     val createPrefix = (case createPrefix
133 :     of CL.S_Block stms => stms
134 :     | stm => [stm]
135 :     (* end case *))
136 :     val name = Atom.toString strand
137 : jhr 623 val nDims = List.length iters
138 :     val worldTy = CL.T_Ptr(CL.T_Named RN.worldTy)
139 :     fun mapi f xs = let
140 :     fun mapf (_, []) = []
141 :     | mapf (i, x::xs) = f(i, x) :: mapf(i+1, xs)
142 :     in
143 :     mapf (0, xs)
144 :     end
145 :     val baseInit = mapi (fn (i, (_, E(e, _), _)) => (i, CL.I_Exp e)) iters
146 :     val sizeInit = mapi
147 :     (fn (i, (V(ty, _), E(lo, _), E(hi, _))) =>
148 :     (i, CL.I_Exp(CL.mkBinOp(CL.mkBinOp(hi, CL.#-, lo), CL.#+, CL.E_Int(1, cvtTy ty))))
149 :     ) iters
150 :     val allocCode = [
151 :     CL.S_Comment["allocate initial block of strands"],
152 :     CL.S_Decl(CL.T_Array(CL.int32, SOME nDims), "base", SOME(CL.I_Array baseInit)),
153 :     CL.S_Decl(CL.T_Array(CL.uint32, SOME nDims), "size", SOME(CL.I_Array sizeInit)),
154 :     CL.S_Decl(worldTy, "wrld",
155 :     SOME(CL.I_Exp(CL.E_Apply(RN.allocInitially, [
156 :     CL.mkUnOp(CL.%&, CL.E_Var(RN.strandDesc name)),
157 :     CL.E_Bool isArray,
158 :     CL.E_Int(IntInf.fromInt nDims, CL.int32),
159 :     CL.E_Var "base",
160 :     CL.E_Var "size"
161 :     ]))))
162 :     ]
163 :     (* create the loop nest for the initially iterations *)
164 :     val indexVar = "ix"
165 : jhr 634 val strandTy = CL.T_Ptr(CL.T_Named(RN.strandTy name))
166 : jhr 623 fun mkLoopNest [] = CL.mkBlock(createPrefix @ [
167 : jhr 634 CL.S_Decl(strandTy, "sp",
168 :     SOME(CL.I_Exp(
169 :     CL.E_Cast(strandTy,
170 :     CL.E_Apply(RN.inState, [CL.E_Var "wrld", CL.E_Var indexVar]))))),
171 : jhr 624 CL.S_Call(RN.strandInit name, CL.E_Var "sp" :: List.map (fn (E(e, _)) => e) args),
172 :     CL.S_Assign(CL.E_Var indexVar, CL.mkBinOp(CL.E_Var indexVar, CL.#+, CL.E_Int(1, CL.uint32)))
173 : jhr 623 ])
174 :     | mkLoopNest ((V(ty, param), E(lo,_), E(hi, _))::iters) = let
175 :     val body = mkLoopNest iters
176 :     in
177 :     CL.S_For(
178 :     [(cvtTy ty, param, lo)],
179 :     CL.mkBinOp(CL.E_Var param, CL.#<=, hi),
180 :     [CL.mkPostOp(CL.E_Var param, CL.^++)],
181 :     body)
182 :     end
183 :     val iterCode = [
184 :     CL.S_Comment["initially"],
185 :     CL.S_Decl(CL.uint32, indexVar, SOME(CL.I_Exp(CL.E_Int(0, CL.uint32)))),
186 :     mkLoopNest iters
187 :     ]
188 :     val body = CL.mkBlock(iterPrefix @ allocCode @ iterCode @ [CL.S_Return(SOME(CL.E_Var "wrld"))])
189 :     val initFn = CL.D_Func([], worldTy, RN.initially, [], body)
190 : jhr 618 in
191 : jhr 624 initially := initFn
192 : jhr 618 end
193 :    
194 : jhr 525 structure Var =
195 :     struct
196 : jhr 528 fun global (Prog{globals, ...}, ty, name) = (
197 : jhr 573 globals := CL.D_Var([], cvtTy ty, name, NONE) :: !globals;
198 : jhr 623 V(ty, name))
199 :     fun param (ty, name) = V(ty, name)
200 : jhr 544 fun state (Strand{state, ...}, ty, name) = (
201 : jhr 623 state := V(ty, name) :: !state;
202 :     V(ty, name))
203 :     fun var (ty, name) = V(ty, name)
204 : jhr 554 local
205 :     val count = ref 0
206 :     fun freshName prefix = let
207 :     val n = !count
208 :     in
209 :     count := n+1;
210 :     concat[prefix, "_", Int.toString n]
211 :     end
212 :     in
213 : jhr 623 fun tmp ty = V(ty, freshName "tmp")
214 : jhr 554 fun fresh prefix = freshName prefix
215 :     end (* local *)
216 : jhr 519 end
217 :    
218 :     (* expression construction *)
219 : jhr 525 structure Expr =
220 :     struct
221 : jhr 549 (* return true if the given expression from is allowed as a subexpression *)
222 :     fun allowedInline _ = true (* FIXME *)
223 :    
224 : jhr 519 (* variable references *)
225 : jhr 623 fun global (V(ty, x)) = E(CL.mkVar x, ty)
226 :     fun getState (V(ty, x)) = E(CL.mkIndirect(CL.mkVar "selfIn", x), ty)
227 :     fun param (V(ty, x)) = E(CL.mkVar x, ty)
228 :     fun var (V(ty, x)) = E(CL.mkVar x, ty)
229 : jhr 525
230 : jhr 519 (* literals *)
231 : jhr 623 fun intLit n = E(CL.mkInt(n, !RN.gIntTy), intTy)
232 :     fun floatLit f = E(CL.mkFlt(f, !RN.gRealTy), realTy)
233 :     fun stringLit s = E(CL.mkStr s, stringTy)
234 :     fun boolLit b = E(CL.mkBool b, boolTy)
235 : jhr 525
236 : jhr 561 (* select from a vector. We have to cast to the corresponding union type and then
237 :     * select from the array field.
238 :     *)
239 :     local
240 :     fun sel (tyCode, field, ty) (i, e, n) =
241 :     if (i < 0) orelse (n <= i)
242 :     then raise Subscript
243 :     else let
244 :     val unionTy = CL.T_Named(concat["union", Int.toString n, !tyCode, "_t"])
245 :     val e1 = CL.mkCast(unionTy, e)
246 :     val e2 = CL.mkSelect(e1, field)
247 :     in
248 : jhr 798 E(CL.mkSubscript(e2, intExp i), ty)
249 : jhr 561 end
250 :     val selF = sel (RN.gRealSuffix, "r", T_Real)
251 :     val selI = sel (RN.gIntSuffix, "i", T_Int)
252 :     in
253 : jhr 654 fun ivecIndex (e, d, i) = let val E(e', _) = selI(i, e, d) in e' end
254 :     fun vecIndex (e, d, i) = let val E(e', _) = selF(i, e, d) in e' end
255 : jhr 623 fun select (i, E(e, T_Vec n)) = selF (i, e, n)
256 :     | select (i, E(e, T_IVec n)) = selI (i, e, n)
257 : jhr 548 | select (_, x) = invalid("select", [x])
258 : jhr 561 end (* local *)
259 : jhr 525
260 : jhr 802 fun subscript1 (E(e1, ty), E(e2, T_Int)) = let
261 :     val (n, tyCode, elemTy, fld) = (case ty
262 :     of T_Vec n => (n, !RN.gRealSuffix, T_Real, "r")
263 :     | T_IVec n => (n, !RN.gIntSuffix, T_Int, "i")
264 : jhr 705 (* end case *))
265 :     val unionTy = CL.T_Named(concat["union", Int.toString n, tyCode, "_t"])
266 : jhr 802 val vecExp = CL.mkSelect(CL.mkCast(unionTy, e1), fld)
267 : jhr 705 in
268 : jhr 732 E(CL.mkSubscript(vecExp, e2), elemTy)
269 : jhr 705 end
270 :    
271 : jhr 802 fun subscript2 (E(e1, T_Mat(n,m)), E(e2, T_Int), E(e3, T_Int)) =
272 :     E(CL.mkSubscript(CL.mkSelect(CL.mkSubscript(e1, e2), "r"), e3), T_Real)
273 :    
274 : jhr 519 (* vector (and scalar) arithmetic *)
275 : jhr 525 local
276 :     fun checkTys (ty1, ty2) = (ty1 = ty2) andalso numTy ty1
277 : jhr 623 fun binop rator (E(e1, ty1), E(e2, ty2)) =
278 : jhr 525 if checkTys (ty1, ty2)
279 : jhr 623 then E(CL.mkBinOp(e1, rator, e2), ty1)
280 : jhr 548 else invalid (
281 :     concat["binary operator \"", CL.binopToString rator, "\""],
282 : jhr 623 [E(e1, ty1), E(e2, ty2)])
283 : jhr 525 in
284 : jhr 623 fun add (E(e1, ty as T_Ptr _), E(e2, T_Int)) = E(CL.mkBinOp(e1, CL.#+, e2), ty)
285 : jhr 548 | add args = binop CL.#+ args
286 : jhr 623 fun sub (E(e1, ty as T_Ptr _), E(e2, T_Int)) = E(CL.mkBinOp(e1, CL.#-, e2), ty)
287 : jhr 548 | sub args = binop CL.#- args
288 : jhr 544 (* NOTE: multiplication and division are also used for scaling *)
289 : jhr 623 fun mul (E(e1, T_Real), E(e2, T_Vec n)) =
290 :     E(CL.E_Apply(RN.scale n, [e1, e2]), T_Vec n)
291 : jhr 544 | mul args = binop CL.#* args
292 : jhr 623 fun divide (E(e1, T_Vec n), E(e2, T_Real)) = let
293 :     val E(one, _) = floatLit FloatLit.one
294 :     in
295 :     E(CL.E_Apply(RN.scale n, [CL.mkBinOp(one, CL.#/, e2), e1]), T_Vec n)
296 :     end
297 : jhr 544 | divide args = binop CL.#/ args
298 : jhr 525 end (* local *)
299 : jhr 623 fun neg (E(e, T_Bool)) = raise Fail "invalid argument to neg"
300 :     | neg (E(e, ty)) = E(CL.mkUnOp(CL.%-, e), ty)
301 : jhr 525
302 : jhr 623 fun abs (E(e, T_Int)) = E(CL.mkApply("abs", [e]), T_Int) (* FIXME: not the right type for 64-bit ints *)
303 :     | abs (E(e, T_Real)) = E(CL.mkApply("fabs" ^ !RN.gRealSuffix, [e]), T_Real)
304 :     | abs (E(e, T_Vec n)) = raise Fail "FIXME: Expr.abs"
305 :     | abs (E(e, T_IVec n)) = raise Fail "FIXME: Expr.abs"
306 : jhr 525 | abs _ = raise Fail "invalid argument to abs"
307 :    
308 : jhr 623 fun dot (E(e1, T_Vec n1), E(e2, T_Vec n2)) = E(CL.E_Apply(RN.dot n1, [e1, e2]), T_Real)
309 : jhr 525 | dot _ = raise Fail "invalid argument to dot"
310 :    
311 : jhr 623 fun cross (E(e1, T_Vec 3), E(e2, T_Vec 3)) = E(CL.E_Apply(RN.cross(), [e1, e2]), T_Vec 3)
312 : jhr 525 | cross _ = raise Fail "invalid argument to cross"
313 :    
314 : jhr 623 fun length (E(e, T_Vec n)) = E(CL.E_Apply(RN.length n, [e]), T_Real)
315 : jhr 525 | length _ = raise Fail "invalid argument to length"
316 :    
317 : jhr 623 fun normalize (E(e, T_Vec n)) = E(CL.E_Apply(RN.normalize n, [e]), T_Vec n)
318 : jhr 525 | normalize _ = raise Fail "invalid argument to length"
319 :    
320 : jhr 683 (* matrix operations *)
321 :     fun trace (E(e, T_Mat(n,m))) = if (n = m) andalso (1 < n) andalso (m <= 4)
322 :     then E(CL.E_Apply(RN.trace n, [e]), T_Real)
323 :     else raise Fail "invalid matrix argument for trace"
324 :     | trace _ = raise Fail "invalid argument to trace"
325 :    
326 : jhr 519 (* comparisons *)
327 : jhr 525 local
328 :     fun checkTys (ty1, ty2) =
329 :     (ty1 = ty2) andalso scalarTy ty1
330 : jhr 623 fun cmpop rator (E(e1, ty1), E(e2, ty2)) =
331 : jhr 525 if checkTys (ty1, ty2)
332 : jhr 623 then E(CL.mkBinOp(e1, rator, e2), T_Bool)
333 : jhr 548 else invalid (
334 :     concat["compare operator \"", CL.binopToString rator, "\""],
335 : jhr 623 [E(e1, ty1), E(e2, ty2)])
336 : jhr 525 in
337 :     val lt = cmpop CL.#<
338 :     val lte = cmpop CL.#<=
339 :     val equ = cmpop CL.#==
340 :     val neq = cmpop CL.#!=
341 :     val gte = cmpop CL.#>=
342 :     val gt = cmpop CL.#>
343 :     end (* local *)
344 :    
345 : jhr 519 (* logical connectives *)
346 : jhr 623 fun not (E(e, T_Bool)) = E(CL.mkUnOp(CL.%!, e), T_Bool)
347 : jhr 525 | not _ = raise Fail "invalid argument to not"
348 : jhr 623 fun && (E(e1, T_Bool), E(e2, T_Bool)) = E(CL.mkBinOp(e1, CL.#&&, e2), T_Bool)
349 : jhr 525 | && _ = raise Fail "invalid arguments to &&"
350 : jhr 623 fun || (E(e1, T_Bool), E(e2, T_Bool)) = E(CL.mkBinOp(e1, CL.#||, e2), T_Bool)
351 : jhr 525 | || _ = raise Fail "invalid arguments to ||"
352 :    
353 : jhr 754 (* misc functions *)
354 : jhr 525 local
355 :     fun checkTys (ty1, ty2) = (ty1 = ty2) andalso scalarTy ty1
356 : jhr 623 fun binFn f (E(e1, ty1), E(e2, ty2)) =
357 : jhr 525 if checkTys (ty1, ty2)
358 : jhr 623 then E(CL.mkApply(f ty1, [e1, e2]), ty1)
359 : jhr 525 else raise Fail "invalid arguments to binary function"
360 :     in
361 : jhr 561 val min = binFn RN.min
362 :     val max = binFn RN.max
363 : jhr 754 fun lerp (E(e1, ty1), E(e2, ty2), E(e3, T_Real)) =
364 :     if (ty1 = ty2)
365 :     then (case ty1
366 :     of T_Real => E(CL.mkApply(RN.lerp 0, [e1, e2, e3]), T_Real)
367 :     | T_Vec n => E(CL.mkApply(RN.lerp n, [e1, e2, e3]), ty1)
368 :     | ty => raise Fail(concat["lerp<", TargetTy.toString ty, "> not supported"])
369 :     (* end case *))
370 :     else raise Fail "invalid arguments to lerp"
371 :     | lerp _ = raise Fail "invalid arguments to lerp"
372 : jhr 525 end (* local *)
373 :    
374 : jhr 551 (* rounding *)
375 : jhr 623 fun trunc (E(e, ty)) = E(CL.mkApply(RN.addTySuffix("trunc", ty), [e]), ty)
376 :     fun round (E(e, ty)) = E(CL.mkApply(RN.addTySuffix("round", ty), [e]), ty)
377 :     fun floor (E(e, ty)) = E(CL.mkApply(RN.addTySuffix("floor", ty), [e]), ty)
378 :     fun ceil (E(e, ty)) = E(CL.mkApply(RN.addTySuffix("ceil", ty), [e]), ty)
379 : jhr 551
380 : jhr 519 (* conversions *)
381 : jhr 623 fun toInt (E(e, T_Real)) = E(CL.mkCast(!RN.gIntTy, e), T_Int)
382 :     | toInt (E(e, T_Vec n)) = E(CL.mkApply(RN.vecftoi n, [e]), ivecTy n)
383 : jhr 565 | toInt e = invalid ("toInt", [e])
384 : jhr 623 fun toReal (E(e, T_Int)) = E(CL.mkCast(!RN.gRealTy, e), T_Real)
385 : jhr 548 | toReal e = invalid ("toReal", [e])
386 : jhr 525
387 : jhr 519 (* runtime system hooks *)
388 : jhr 623 fun imageAddr (E(e, T_Image(_, rTy))) = let
389 : jhr 561 val cTy = CL.T_Ptr(CL.T_Num rTy)
390 : jhr 528 in
391 : jhr 623 E(CL.mkCast(cTy, CL.mkIndirect(e, "data")), T_Ptr rTy)
392 : jhr 528 end
393 : jhr 548 | imageAddr a = invalid("imageAddr", [a])
394 : jhr 623 fun getImgData (E(e, T_Ptr rTy)) = let
395 : jhr 551 val realTy as CL.T_Num rTy' = !RN.gRealTy
396 : jhr 548 val e = CL.E_UnOp(CL.%*, e)
397 :     in
398 :     if (rTy' = rTy)
399 : jhr 623 then E(e, T_Real)
400 :     else E(CL.E_Cast(realTy, e), T_Real)
401 : jhr 548 end
402 :     | getImgData a = invalid("getImgData", [a])
403 : jhr 623 fun posToImgSpace (E(img, T_Image(d, _)), E(pos, T_Vec n)) = let
404 : jhr 551 val e = CL.mkApply(RN.toImageSpace d, [img, pos])
405 : jhr 548 in
406 : jhr 623 E(e, T_Vec n)
407 : jhr 548 end
408 :     | posToImgSpace (a, b) = invalid("posToImgSpace", [a, b])
409 : jhr 623 fun inside (E(pos, T_Vec n), E(img, T_Image(d, _)), s) = let
410 : jhr 798 val e = CL.mkApply(RN.inside d, [pos, img, intExp s])
411 : jhr 547 in
412 : jhr 623 E(e, T_Bool)
413 : jhr 547 end
414 : jhr 548 | inside (a, b, _) = invalid("inside", [a, b])
415 : jhr 519
416 : jhr 695 (* other basis functions *)
417 :     local
418 :     val basis = [
419 :     (ILBasis.atan2, fn [E(e1, T_Real), E(e2, T_Real)] => SOME[e1, e2] | _ => NONE, T_Real),
420 :     (ILBasis.cos, fn [E(e, T_Real)] => SOME[e] | _ => NONE, T_Real),
421 :     (ILBasis.pow, fn [E(e1, T_Real), E(e2, T_Real)] => SOME[e1, e2] | _ => NONE, T_Real),
422 :     (ILBasis.sin, fn [E(e, T_Real)] => SOME[e] | _ => NONE, T_Real),
423 :     (ILBasis.sqrt, fn [E(e, T_Real)] => SOME[e] | _ => NONE, T_Real),
424 :     (ILBasis.tan, fn [E(e, T_Real)] => SOME[e] | _ => NONE, T_Real)
425 :     ]
426 :     fun mkLookup suffix = let
427 :     val tbl = ILBasis.Tbl.mkTable (16, Fail "basis table")
428 :     fun ins (f, chkTy, resTy) =
429 :     ILBasis.Tbl.insert tbl
430 :     (f, (ILBasis.toString f ^ suffix, chkTy, resTy))
431 :     in
432 :     List.app ins basis;
433 :     ILBasis.Tbl.lookup tbl
434 :     end
435 :     val fLookup = mkLookup "f"
436 :     val dLookup = mkLookup ""
437 :     in
438 :     fun apply (f, args) = let
439 :     val (f', chkArgs, resTy) = if !Controls.doublePrecision then dLookup f else fLookup f
440 :     in
441 :     case chkArgs args
442 :     of SOME args => E(CL.mkApply(f', args), resTy)
443 :     | NONE => raise Fail("invalid arguments for "^ILBasis.toString f)
444 :     end
445 :     end (* local *)
446 : jhr 547 end (* Expr *)
447 :    
448 : jhr 519 (* statement construction *)
449 : jhr 525 structure Stmt =
450 :     struct
451 :     val comment = CL.S_Comment
452 : jhr 806 fun assignState (V(ty, x), E(e, _)) = (case (ty, e)
453 :     of (T_Mat(n,m), CL.E_Var _) =>
454 :     CL.mkCall(RN.copyMat(n,m), [CL.mkIndirect(CL.mkVar "selfOut", x), e])
455 :     | _ => CL.mkAssign(CL.mkIndirect(CL.mkVar "selfOut", x), e)
456 :     (* end case *))
457 :     fun assign (V(ty, x), E(e, _)) = (case (ty, e)
458 :     of (T_Mat(n,m), CL.E_Var y) => CL.mkCall(RN.copyMat(n,m), [CL.mkVar x, e])
459 :     | _ => CL.mkAssign(CL.mkVar x, e)
460 :     (* end case *))
461 : jhr 623 fun decl (V(ty, x), SOME(E(e, _))) = CL.mkDecl(cvtTy ty, x, SOME(CL.I_Exp e))
462 :     | decl (V(ty, x), NONE) = CL.mkDecl(cvtTy ty, x, NONE)
463 : jhr 525 val block = CL.mkBlock
464 : jhr 623 fun ifthen (E(e, T_Bool), s1) = CL.mkIfThen(e, s1)
465 :     fun ifthenelse (E(e, T_Bool), s1, s2) = CL.mkIfThenElse(e, s1, s2)
466 :     fun for (V(ty, x), E(lo, _), E(hi, _), body) = CL.mkFor(
467 : jhr 617 [(cvtTy ty, x, lo)],
468 :     CL.mkBinOp(CL.mkVar x, CL.#<=, hi),
469 :     [CL.mkPostOp(CL.mkVar x, CL.^++)],
470 :     body)
471 : jhr 534 (* special Diderot forms *)
472 : jhr 623 fun cons (V(T_Vec n, x), args : exp list) =
473 : jhr 798 [CL.mkAssign(CL.mkVar x, CL.mkApply(RN.mkVec n, List.map (fn E(e, _) => e) args))]
474 :     | cons (V(T_Mat _, x), args) = let
475 :     val x = CL.mkVar x
476 :     (* matrices are represented as arrays of union<d><ty>_t vectors *)
477 :     fun doRows (_, []) = []
478 :     | doRows (i, E(e, _)::es) =
479 :     CL.mkAssign(CL.mkSelect(CL.mkSubscript(x, intExp i), "v"),e)
480 :     :: doRows (i+1, es)
481 :     in
482 :     doRows (0, args)
483 :     end
484 : jhr 553 | cons _ = raise Fail "bogus cons"
485 : jhr 623 fun getImgData (V(T_Vec n, x), E(e, T_Ptr rTy)) = let
486 : jhr 554 val addr = Var.fresh "vp"
487 : jhr 561 val needsCast = (CL.T_Num rTy <> !RN.gRealTy)
488 :     fun mkLoad i = let
489 : jhr 798 val e = CL.mkSubscript(CL.mkVar addr, intExp i)
490 : jhr 561 in
491 :     if needsCast then CL.mkCast(!RN.gRealTy, e) else e
492 :     end
493 : jhr 554 in [
494 : jhr 623 CL.mkDecl(CL.T_Ptr(CL.T_Num rTy), addr, SOME(CL.I_Exp e)),
495 : jhr 561 CL.mkAssign(CL.mkVar x,
496 :     CL.mkApply(RN.mkVec n, List.tabulate (n, mkLoad)))
497 : jhr 554 ] end
498 :     | getImgData _ = raise Fail "bogus getImgData"
499 :     local
500 :     fun checkSts mkDecl = let
501 :     val sts = Var.fresh "sts"
502 :     in
503 :     mkDecl sts @
504 :     [CL.mkIfThen(
505 :     CL.mkBinOp(CL.mkVar "DIDEROT_OK", CL.#!=, CL.mkVar sts),
506 : jhr 798 CL.mkCall("exit", [intExp 1]))]
507 : jhr 554 end
508 :     in
509 : jhr 623 fun loadImage (V(_, lhs), dim, E(name, _)) = checkSts (fn sts => let
510 : jhr 551 val imgTy = CL.T_Named(RN.imageTy dim)
511 :     val loadFn = RN.loadImage dim
512 : jhr 534 in [
513 :     CL.S_Decl(
514 :     statusTy, sts,
515 : jhr 623 SOME(CL.I_Exp(CL.E_Apply(loadFn, [name, CL.mkUnOp(CL.%&, CL.E_Var lhs)]))))
516 : jhr 554 ] end)
517 : jhr 623 fun input (V(ty, lhs), name, optDflt) = checkSts (fn sts => let
518 :     val inputFn = RN.input ty
519 :     val lhs = CL.E_Var lhs
520 : jhr 534 val (initCode, hasDflt) = (case optDflt
521 : jhr 623 of SOME(E(e, _)) => ([CL.S_Assign(lhs, e)], true)
522 : jhr 534 | NONE => ([], false)
523 :     (* end case *))
524 :     val code = [
525 :     CL.S_Decl(
526 :     statusTy, sts,
527 : jhr 623 SOME(CL.I_Exp(CL.E_Apply(inputFn, [
528 : jhr 534 CL.E_Str name, CL.mkUnOp(CL.%&, lhs), CL.mkBool hasDflt
529 : jhr 623 ]))))
530 : jhr 534 ]
531 :     in
532 :     initCode @ code
533 : jhr 554 end)
534 :     end (* local *)
535 : jhr 564 fun exit () = CL.mkReturn NONE
536 :     fun active () = CL.mkReturn(SOME(CL.mkVar RN.kActive))
537 :     fun stabilize () = CL.mkReturn(SOME(CL.mkVar RN.kStabilize))
538 : jhr 562 fun die () = CL.mkReturn(SOME(CL.mkVar RN.kDie))
539 : jhr 519 end
540 :    
541 : jhr 544 structure Strand =
542 :     struct
543 :     fun define (Prog{strands, ...}, strandId) = let
544 : jhr 624 val name = Atom.toString strandId
545 : jhr 544 val strand = Strand{
546 : jhr 624 name = name,
547 :     tyName = RN.strandTy name,
548 : jhr 544 state = ref [],
549 : jhr 654 output = ref NONE,
550 : jhr 544 code = ref []
551 :     }
552 :     in
553 : jhr 624 AtomTable.insert strands (strandId, strand);
554 : jhr 544 strand
555 :     end
556 :    
557 : jhr 624 (* return the strand with the given name *)
558 :     fun lookup (Prog{strands, ...}, strandId) = AtomTable.lookup strands strandId
559 :    
560 : jhr 544 (* register the strand-state initialization code. The variables are the strand
561 :     * parameters.
562 :     *)
563 :     fun init (Strand{name, tyName, code, ...}, params, init) = let
564 : jhr 551 val fName = RN.strandInit name
565 : jhr 544 val params =
566 : jhr 547 CL.PARAM([], CL.T_Ptr(CL.T_Named tyName), "selfOut") ::
567 : jhr 623 List.map (fn (V(ty, x)) => CL.PARAM([], cvtTy ty, x)) params
568 : jhr 544 val initFn = CL.D_Func([], CL.voidTy, fName, params, init)
569 :     in
570 :     code := initFn :: !code
571 :     end
572 : jhr 547
573 :     (* register a strand method *)
574 :     fun method (Strand{name, tyName, code, ...}, methName, body) = let
575 :     val fName = concat[name, "_", methName]
576 :     val params = [
577 :     CL.PARAM([], CL.T_Ptr(CL.T_Named tyName), "selfIn"),
578 :     CL.PARAM([], CL.T_Ptr(CL.T_Named tyName), "selfOut")
579 :     ]
580 : jhr 654 val methFn = CL.D_Func(["static"], CL.int32, fName, params, body)
581 : jhr 547 in
582 :     code := methFn :: !code
583 :     end
584 : jhr 654
585 :     fun output (Strand{output, ...}, x) = (case !output
586 :     of NONE => output := SOME x
587 :     | _ => raise Fail "multiple outputs are not supported yet"
588 :     (* end case *))
589 : jhr 544 end (* Strand *)
590 :    
591 : jhr 654 fun genStrand (Strand{name, tyName, state, output, code}) = let
592 : jhr 624 (* the type declaration for the strand's state struct *)
593 : jhr 544 val selfTyDef = CL.D_StructDef(
594 : jhr 623 List.rev (List.map (fn V(ty, x) => (cvtTy ty, x)) (!state)),
595 : jhr 544 tyName)
596 : jhr 654 (* the print function *)
597 :     val prFnName = concat[name, "_print"]
598 :     val prFn = let
599 :     val params = [
600 :     CL.PARAM([], CL.T_Ptr(CL.T_Named "FILE"), "outS"),
601 :     CL.PARAM([], CL.T_Ptr(CL.T_Named tyName), "self")
602 :     ]
603 :     val SOME(V(ty, x)) = !output
604 :     val outState = CL.mkIndirect(CL.mkVar "self", x)
605 :     val prArgs = (case ty
606 :     of TargetTy.T_Int => [CL.E_Str(!RN.gIntFormat ^ "\n"), outState]
607 :     | TargetTy.T_Real => [CL.E_Str "%f\n", outState]
608 :     | TargetTy.T_Vec d => let
609 :     val fmt = CL.E_Str(
610 :     String.concatWith " " (List.tabulate(d, fn _ => "%f"))
611 :     ^ "\n")
612 : jhr 656 val args = List.tabulate (d, fn i => Expr.vecIndex(outState, d, i))
613 : jhr 654 in
614 :     fmt :: args
615 :     end
616 :     | TargetTy.T_IVec d => let
617 :     val fmt = CL.E_Str(
618 :     String.concatWith " " (List.tabulate(d, fn _ => !RN.gIntFormat))
619 :     ^ "\n")
620 : jhr 656 val args = List.tabulate (d, fn i => Expr.ivecIndex(outState, d, i))
621 : jhr 654 in
622 :     fmt :: args
623 :     end
624 :     | _ => raise Fail("genStrand: unsupported output type " ^ TargetTy.toString ty)
625 :     (* end case *))
626 :     in
627 :     CL.D_Func(["static"], CL.voidTy, prFnName, params,
628 :     CL.S_Call("fprintf", CL.mkVar "outS" :: prArgs))
629 :     end
630 : jhr 624 (* the strand's descriptor object *)
631 :     val descI = let
632 : jhr 573 fun fnPtr (ty, f) = CL.I_Exp(CL.mkCast(CL.T_Named ty, CL.mkVar f))
633 :     in
634 :     CL.I_Struct[
635 :     ("name", CL.I_Exp(CL.E_Str name)),
636 :     ("stateSzb", CL.I_Exp(CL.mkSizeof(CL.T_Named(RN.strandTy name)))),
637 : jhr 654 ("update", fnPtr("update_method_t", name ^ "_update")),
638 :     ("print", fnPtr("print_method_t", prFnName))
639 : jhr 573 ]
640 :     end
641 : jhr 624 val desc = CL.D_Var([], CL.T_Named RN.strandDescTy, RN.strandDesc name, SOME descI)
642 :     in
643 : jhr 654 selfTyDef :: List.rev (desc :: prFn :: !code)
644 : jhr 624 end
645 :    
646 :     (* generate the table of strand descriptors *)
647 :     fun genStrandTable (ppStrm, strands) = let
648 :     val nStrands = length strands
649 :     fun genInit (Strand{name, ...}) = CL.I_Exp(CL.mkUnOp(CL.%&, CL.E_Var(RN.strandDesc name)))
650 : jhr 573 fun genInits (_, []) = []
651 :     | genInits (i, s::ss) = (i, genInit s) :: genInits(i+1, ss)
652 :     fun ppDecl dcl = PrintAsC.output(ppStrm, dcl)
653 :     in
654 :     ppDecl (CL.D_Var([], CL.int32, RN.numStrands,
655 :     SOME(CL.I_Exp(CL.E_Int(IntInf.fromInt nStrands, CL.int32)))));
656 : jhr 624 ppDecl (CL.D_Var([],
657 :     CL.T_Array(CL.T_Ptr(CL.T_Named RN.strandDescTy), SOME nStrands),
658 :     RN.strands,
659 : jhr 573 SOME(CL.I_Array(genInits (0, strands)))))
660 :     end
661 :    
662 : jhr 731 fun genSrc (baseName, Prog{globals, topDecls, strands, initially}) = let
663 : jhr 527 val fileName = OS.Path.joinBaseExt{base=baseName, ext=SOME "c"}
664 :     val outS = TextIO.openOut fileName
665 :     val ppStrm = PrintAsC.new outS
666 : jhr 533 fun ppDecl dcl = PrintAsC.output(ppStrm, dcl)
667 : jhr 624 val strands = AtomTable.listItems strands
668 : jhr 527 in
669 : jhr 533 List.app ppDecl (List.rev (!globals));
670 :     List.app ppDecl (List.rev (!topDecls));
671 : jhr 624 List.app (fn strand => List.app ppDecl (genStrand strand)) strands;
672 :     genStrandTable (ppStrm, strands);
673 :     ppDecl (!initially);
674 : jhr 527 PrintAsC.close ppStrm;
675 :     TextIO.closeOut outS
676 :     end
677 :    
678 : jhr 731 (* FIXME: control flags that should go somewhere else *)
679 :     val debug = ref false
680 :     val verbose = ref true
681 :    
682 :     fun system cmd = (
683 :     if !verbose
684 :     then print(cmd ^ "\n")
685 :     else ();
686 :     if OS.Process.isSuccess(OS.Process.system cmd)
687 :     then ()
688 :     else raise Fail "error compiling/linking")
689 :    
690 :     fun compile baseName = let
691 :     val cFile = OS.Path.joinBaseExt{base=baseName, ext=SOME"c"}
692 :     val cflags = if !debug
693 :     then Paths.cflags
694 :     else String.concatWith " " ["-NDEBUG", Paths.cflags]
695 :     val cmd = String.concatWith " " [
696 :     Paths.cc, "-c", cflags,
697 :     "-I" ^ Paths.diderotInclude, "-I" ^ Paths.teemInclude,
698 :     cFile
699 :     ]
700 :     in
701 :     system cmd
702 :     end
703 :    
704 :     fun link baseName = let
705 :     val objFile = OS.Path.joinBaseExt{base=baseName, ext=SOME"o"}
706 :     val exeFile = baseName
707 :     val cmd = String.concatWith " " [
708 :     Paths.cc, "-o", exeFile, objFile,
709 :     "-L" ^ Paths.teemLib, "-lteem",
710 :     OS.Path.concat(Paths.diderotLib, "diderot-lib.o")
711 :     ]
712 :     in
713 :     system cmd
714 :     end
715 :    
716 :     fun generate (baseName, prog) = (
717 :     genSrc (baseName, prog);
718 :     compile baseName;
719 :     link baseName)
720 :    
721 : jhr 519 end
722 :    
723 :     structure CBackEnd = CodeGenFn(CTarget)

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0