Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] Annotation of /branches/pure-cfg/src/compiler/c-target/c-target.sml
ViewVC logotype

Annotation of /branches/pure-cfg/src/compiler/c-target/c-target.sml

Parent Directory Parent Directory | Revision Log Revision Log


Revision 552 - (view) (download)

1 : jhr 519 (* c-target.sml
2 :     *
3 :     * COPYRIGHT (c) 2011 The Diderot Project (http://diderot-language.cs.uchicago.edu)
4 :     * All rights reserved.
5 :     *
6 :     * Generate C code with SSE 4.2 intrinsics.
7 :     *)
8 :    
9 :     structure CTarget : TARGET =
10 :     struct
11 :    
12 : jhr 522 structure CL = CLang
13 : jhr 551 structure RN = RuntimeNames
14 : jhr 522
15 : jhr 551 datatype ty = datatype TargetTy.ty
16 : jhr 519
17 : jhr 544 datatype strand = Strand of {
18 :     name : string,
19 :     tyName : string,
20 :     state : (ty * string) list ref,
21 :     code : CL.decl list ref
22 :     }
23 : jhr 525
24 :     type var = (ty * string) (* FIXME *)
25 :    
26 :     type exp = CLang.exp * ty
27 :    
28 :     type stm = CL.stm
29 :    
30 :     type method = unit (* FIXME *)
31 :    
32 : jhr 527 datatype program = Prog of {
33 :     globals : CL.decl list ref,
34 : jhr 533 topDecls : CL.decl list ref,
35 : jhr 527 strands : strand list ref
36 :     }
37 :    
38 : jhr 519 (* for SSE, we have 128-bit vectors *)
39 : jhr 551 fun vectorWidth () = !RN.gVectorWid
40 : jhr 519
41 :     (* target types *)
42 : jhr 525 val boolTy = T_Bool
43 :     val intTy = T_Int
44 :     val realTy = T_Real
45 :     fun vecTy 1 = T_Real
46 : jhr 551 | vecTy n = if (n < 1) orelse (!RN.gVectorWid < n)
47 : jhr 525 then raise Size
48 :     else T_Vec n
49 :     fun ivecTy 1 = T_Int
50 : jhr 551 | ivecTy n = if (n < 1) orelse (!RN.gVectorWid < n)
51 : jhr 525 then raise Size
52 :     else T_IVec n
53 : jhr 548 fun imageTy (ImageInfo.ImgInfo{ty=([], rTy), dim, ...}) = T_Image(dim, rTy)
54 :     fun imageDataTy (ImageInfo.ImgInfo{ty=([], rTy), ...}) = T_Ptr rTy
55 : jhr 534 val stringTy = T_String
56 : jhr 519
57 : jhr 552 val statusTy = CL.T_Named RN.statusTy
58 : jhr 534
59 : jhr 528 (* convert target types to CLang types *)
60 :     fun cvtTy T_Bool = CLang.T_Named "bool"
61 : jhr 534 | cvtTy T_String = CL.charPtr
62 : jhr 551 | cvtTy T_Int = !RN.gIntTy
63 :     | cvtTy T_Real = !RN.gRealTy
64 :     | cvtTy (T_Vec n) = CLang.T_Named(RN.vecTy n)
65 :     | cvtTy (T_IVec n) = CLang.T_Named(RN.ivecTy n)
66 :     | cvtTy (T_Image(n, _)) = CLang.T_Named(RN.imageTy n)
67 : jhr 548 | cvtTy (T_Ptr ty) = CL.T_Ptr(CL.T_Num ty)
68 : jhr 528
69 : jhr 548 (* report invalid arguments *)
70 :     fun invalid (name, []) = raise Fail("invaild "^name)
71 :     | invalid (name, args) = let
72 : jhr 551 fun arg2s (e, ty) = concat["(", CL.expToString e, " : ", TargetTy.toString ty, ")"]
73 : jhr 548 val args = String.concatWith ", " (List.map arg2s args)
74 :     in
75 :     raise Fail(concat["invalid arguments to ", name, ": ", args])
76 :     end
77 :    
78 : jhr 525 (* helper functions for checking the types of arguments *)
79 :     fun scalarTy T_Int = true
80 :     | scalarTy T_Real = true
81 :     | scalarTy _ = false
82 : jhr 548 fun numTy T_Int = true
83 :     | numTy T_Real = true
84 :     | numTy (T_Vec _) = true
85 :     | numTy (T_IVec _) = true
86 :     | numTy _ = false
87 : jhr 519
88 : jhr 528 fun newProgram () = (
89 : jhr 551 RN.initTargetSpec();
90 : jhr 528 Prog{
91 :     globals = ref [],
92 : jhr 533 topDecls = ref [],
93 : jhr 528 strands = ref []
94 :     })
95 :    
96 : jhr 533 fun globalInit (Prog{topDecls, ...}, init) = let
97 : jhr 551 val initFn = CL.D_Func([], CL.voidTy, RN.initGlobals, [], init)
98 : jhr 533 in
99 :     topDecls := initFn :: !topDecls
100 :     end
101 :    
102 : jhr 525 structure Var =
103 :     struct
104 : jhr 528 fun global (Prog{globals, ...}, ty, name) = (
105 :     globals := CL.D_Var([], cvtTy ty, name) :: !globals;
106 :     (ty, name))
107 : jhr 544 fun param (ty, name) = (ty, name)
108 :     fun state (Strand{state, ...}, ty, name) = (
109 :     state := (ty, name) :: !state;
110 :     (ty, name))
111 :     fun var (ty, name) = (ty, name)
112 : jhr 525 fun tmp ty = raise Fail "FIXME: Var.tmp"
113 : jhr 519 end
114 :    
115 :     (* expression construction *)
116 : jhr 525 structure Expr =
117 :     struct
118 : jhr 549 (* return true if the given expression from is allowed as a subexpression *)
119 :     fun allowedInline _ = true (* FIXME *)
120 :    
121 : jhr 519 (* variable references *)
122 : jhr 525 fun global (ty, x) = (CL.mkVar x, ty)
123 : jhr 547 fun getState (ty, x) = (CL.mkIndirect(CL.mkVar "selfIn", x), ty)
124 : jhr 525 fun param (ty, x) = (CL.mkVar x, ty)
125 :     fun var (ty, x) = (CL.mkVar x, ty)
126 :    
127 : jhr 519 (* literals *)
128 : jhr 551 fun intLit n = (CL.mkInt(n, !RN.gIntTy), intTy)
129 :     fun floatLit f = (CL.mkFlt(f, !RN.gRealTy), realTy)
130 : jhr 533 fun stringLit s = (CL.mkStr s, stringTy)
131 : jhr 525 fun boolLit b = (CL.mkBool b, boolTy)
132 :    
133 : jhr 519 (* select from a vector *)
134 : jhr 525 fun select (i, (e, T_Vec n)) =
135 :     if (i < 0) orelse (n <= i)
136 :     then raise Subscript
137 :     else (CL.mkSubscript(e, CL.mkInt(IntInf.fromInt i, CL.int32)), T_Real)
138 :     | select (i, (e, T_IVec n)) =
139 :     if (i < 0) orelse (n <= i)
140 :     then raise Subscript
141 :     else (CL.mkSubscript(e, CL.mkInt(IntInf.fromInt i, CL.int32)), T_Int)
142 : jhr 548 | select (_, x) = invalid("select", [x])
143 : jhr 525
144 : jhr 519 (* vector (and scalar) arithmetic *)
145 : jhr 525 local
146 :     fun checkTys (ty1, ty2) = (ty1 = ty2) andalso numTy ty1
147 :     fun binop rator ((e1, ty1), (e2, ty2)) =
148 :     if checkTys (ty1, ty2)
149 :     then (CL.mkBinOp(e1, rator, e2), ty1)
150 : jhr 548 else invalid (
151 :     concat["binary operator \"", CL.binopToString rator, "\""],
152 :     [(e1, ty1), (e2, ty2)])
153 : jhr 525 in
154 : jhr 548 fun add ((e1, ty as T_Ptr _), (e2, T_Int)) = (CL.mkBinOp(e1, CL.#+, e2), ty)
155 :     | add args = binop CL.#+ args
156 :     fun sub ((e1, ty as T_Ptr _), (e2, T_Int)) = (CL.mkBinOp(e1, CL.#-, e2), ty)
157 :     | sub args = binop CL.#- args
158 : jhr 544 (* NOTE: multiplication and division are also used for scaling *)
159 :     fun mul ((e1, T_Real), (e2, T_Vec n)) =
160 : jhr 551 (CL.E_Apply(RN.scale n, [e1, e2]), T_Vec n)
161 : jhr 544 | mul args = binop CL.#* args
162 :     fun divide ((e1, T_Vec n), (e2, T_Real)) =
163 : jhr 551 (CL.E_Apply(RN.scale n,
164 :     [CL.mkBinOp(#1(floatLit FloatLit.one), CL.#/, e2), e1]), T_Vec n)
165 : jhr 544 | divide args = binop CL.#/ args
166 : jhr 525 end (* local *)
167 :     fun neg (e, T_Bool) = raise Fail "invalid argument to neg"
168 :     | neg (e, ty) = (CL.mkUnOp(CL.%-, e), ty)
169 :    
170 :     fun abs (e, T_Int) = (CL.mkApply("abs", [e]), T_Int) (* FIXME: not the right type for 64-bit ints *)
171 : jhr 551 | abs (e, T_Real) = (CL.mkApply("fabs" ^ !RN.gRealSuffix, [e]), T_Real)
172 : jhr 525 | abs (e, T_Vec n) = raise Fail "FIXME: Expr.abs"
173 :     | abs (e, T_IVec n) = raise Fail "FIXME: Expr.abs"
174 :     | abs _ = raise Fail "invalid argument to abs"
175 :    
176 : jhr 551 fun dot ((e1, T_Vec n1), (e2, T_Vec n2)) = (CL.E_Apply(RN.dot n1, [e1, e2]), T_Real)
177 : jhr 525 | dot _ = raise Fail "invalid argument to dot"
178 :    
179 : jhr 551 fun cross ((e1, T_Vec 3), (e2, T_Vec 3)) = (CL.E_Apply(RN.cross(), [e1, e2]), T_Vec 3)
180 : jhr 525 | cross _ = raise Fail "invalid argument to cross"
181 :    
182 : jhr 551 fun length (e, T_Vec n) = (CL.E_Apply(RN.length n, [e]), T_Real)
183 : jhr 525 | length _ = raise Fail "invalid argument to length"
184 :    
185 : jhr 551 fun normalize (e, T_Vec n) = (CL.E_Apply(RN.normalize n, [e]), T_Vec n)
186 : jhr 525 | normalize _ = raise Fail "invalid argument to length"
187 :    
188 : jhr 519 (* comparisons *)
189 : jhr 525 local
190 :     fun checkTys (ty1, ty2) =
191 :     (ty1 = ty2) andalso scalarTy ty1
192 :     fun cmpop rator ((e1, ty1), (e2, ty2)) =
193 :     if checkTys (ty1, ty2)
194 :     then (CL.mkBinOp(e1, rator, e2), T_Bool)
195 : jhr 548 else invalid (
196 :     concat["compare operator \"", CL.binopToString rator, "\""],
197 :     [(e1, ty1), (e2, ty2)])
198 : jhr 525 in
199 :     val lt = cmpop CL.#<
200 :     val lte = cmpop CL.#<=
201 :     val equ = cmpop CL.#==
202 :     val neq = cmpop CL.#!=
203 :     val gte = cmpop CL.#>=
204 :     val gt = cmpop CL.#>
205 :     end (* local *)
206 :    
207 : jhr 519 (* logical connectives *)
208 : jhr 525 fun not (e, T_Bool) = (CL.mkUnOp(CL.%!, e), T_Bool)
209 :     | not _ = raise Fail "invalid argument to not"
210 :     fun && ((e1, T_Bool), (e2, T_Bool)) = (CL.mkBinOp(e1, CL.#&&, e2), T_Bool)
211 :     | && _ = raise Fail "invalid arguments to &&"
212 :     fun || ((e1, T_Bool), (e2, T_Bool)) = (CL.mkBinOp(e1, CL.#||, e2), T_Bool)
213 :     | || _ = raise Fail "invalid arguments to ||"
214 :    
215 :     local
216 :     fun checkTys (ty1, ty2) = (ty1 = ty2) andalso scalarTy ty1
217 :     fun binFn f ((e1, ty1), (e2, ty2)) =
218 :     if checkTys (ty1, ty2)
219 :     then (CL.mkApply(f, [e1, e2]), ty1)
220 :     else raise Fail "invalid arguments to binary function"
221 :     in
222 : jhr 519 (* misc functions *)
223 : jhr 525 val min = binFn "Diderot_min"
224 :     val max = binFn "Diderot_max"
225 :     end (* local *)
226 :    
227 : jhr 519 (* math functions *)
228 : jhr 525 fun pow ((e1, T_Real), (e2, T_Real)) =
229 :     if !Controls.doublePrecision
230 :     then (CL.mkApply("pow", [e1, e2]), T_Real)
231 :     else (CL.mkApply("powf", [e1, e2]), T_Real)
232 :     | pow _ = raise Fail "invalid arguments to pow"
233 :    
234 :     local
235 :     fun r2r (ff, fd) (e, T_Real) = if !Controls.doublePrecision
236 :     then (CL.mkApply(fd, [e]), T_Real)
237 :     else (CL.mkApply(ff, [e]), T_Real)
238 : jhr 551 | r2r (_, fd) e = invalid (fd, [e])
239 : jhr 525 in
240 :     val sin = r2r ("sinf", "sin")
241 :     val cos = r2r ("cosf", "cos")
242 :     val sqrt = r2r ("sqrtf", "sqrt")
243 :     end (* local *)
244 :    
245 : jhr 551 (* rounding *)
246 :     fun trunc (e, ty) = (CL.mkApply(RN.addTySuffix("trunc", ty), [e]), ty)
247 :     fun round (e, ty) = (CL.mkApply(RN.addTySuffix("round", ty), [e]), ty)
248 :     fun floor (e, ty) = (CL.mkApply(RN.addTySuffix("floor", ty), [e]), ty)
249 :     fun ceil (e, ty) = (CL.mkApply(RN.addTySuffix("ceil", ty), [e]), ty)
250 :    
251 : jhr 519 (* conversions *)
252 : jhr 551 fun toReal (e, T_Int) = (CL.mkCast(!RN.gRealTy, e), T_Real)
253 : jhr 548 | toReal e = invalid ("toReal", [e])
254 : jhr 525
255 : jhr 551 fun truncToInt (e as (_, T_Real)) = (CL.mkCast(!RN.gIntTy, #1(trunc e)), T_Int)
256 :     | truncToInt (e, T_Vec n) = (CL.mkApply(RN.truncToInt n, [e]), T_IVec n)
257 : jhr 548 | truncToInt e = invalid ("truncToInt", [e])
258 : jhr 551 fun roundToInt (e as (_, T_Real)) = (CL.mkCast(!RN.gIntTy, #1(round e)), T_Int)
259 : jhr 548 | roundToInt e = invalid ("roundToInt", [e])
260 : jhr 551 fun ceilToInt (e as (_, T_Real)) = (CL.mkCast(!RN.gIntTy, #1(floor e)), T_Int)
261 : jhr 548 | ceilToInt e = invalid ("ceilToInt", [e])
262 : jhr 551 fun floorToInt (e as (_, T_Real)) = (CL.mkCast(!RN.gIntTy, #1(ceil e)), T_Int)
263 : jhr 548 | floorToInt e = invalid ("floorToInt", [e])
264 : jhr 525
265 : jhr 519 (* runtime system hooks *)
266 : jhr 548 fun imageAddr (e, T_Image(_, rTy)) = let
267 : jhr 551 val cTy = CL.T_Ptr(!RN.gRealTy)
268 : jhr 528 in
269 : jhr 548 (CL.mkCast(cTy, CL.mkIndirect(e, "data")), T_Ptr rTy)
270 : jhr 528 end
271 : jhr 548 | imageAddr a = invalid("imageAddr", [a])
272 :     fun getImgData (e, T_Ptr rTy) = let
273 : jhr 551 val realTy as CL.T_Num rTy' = !RN.gRealTy
274 : jhr 548 val e = CL.E_UnOp(CL.%*, e)
275 :     in
276 :     if (rTy' = rTy)
277 :     then (e, T_Real)
278 :     else (CL.E_Cast(realTy, e), T_Real)
279 :     end
280 :     | getImgData a = invalid("getImgData", [a])
281 :     fun posToImgSpace ((img, T_Image(d, _)), (pos, T_Vec n)) = let
282 : jhr 551 val e = CL.mkApply(RN.toImageSpace d, [img, pos])
283 : jhr 548 in
284 :     (e, T_Vec n)
285 :     end
286 :     | posToImgSpace (a, b) = invalid("posToImgSpace", [a, b])
287 :     fun inside ((pos, T_Vec n), (img, T_Image(d, _)), s) = let
288 : jhr 551 val e = CL.mkApply(RN.inside d,
289 : jhr 547 [pos, img, CL.mkInt(IntInf.fromInt n, CL.int32)])
290 :     in
291 :     (e, T_Bool)
292 :     end
293 : jhr 548 | inside (a, b, _) = invalid("inside", [a, b])
294 : jhr 519
295 : jhr 547 end (* Expr *)
296 :    
297 : jhr 519 (* statement construction *)
298 : jhr 525 structure Stmt =
299 :     struct
300 :     val comment = CL.S_Comment
301 : jhr 547 fun assignState ((_, x), (e, _)) =
302 :     CL.mkAssign(CL.mkIndirect(CL.mkVar "selfOut", x), e)
303 : jhr 525 fun assign ((_, x), (e, _)) = CL.mkAssign(CL.mkVar x, e)
304 : jhr 528 fun decl ((ty, x), SOME(e, _)) = CL.mkDecl(cvtTy ty, x, SOME e)
305 :     | decl ((ty, x), NONE) = CL.mkDecl(cvtTy ty, x, NONE)
306 : jhr 525 val block = CL.mkBlock
307 : jhr 532 fun ifthen ((e, T_Bool), s1) = CL.mkIfThen(e, s1)
308 : jhr 525 fun ifthenelse ((e, T_Bool), s1, s2) = CL.mkIfThenElse(e, s1, s2)
309 : jhr 534 (* special Diderot forms *)
310 :     fun cons (lhs, args) = comment ["**** cons ****"] (* FIXME *)
311 : jhr 548 fun getImgData (lhs, n, e) = comment ["**** getImgData ****"] (* FIXME *)
312 : jhr 534 fun loadImage (lhs : var, dim, name : exp) = let
313 :     val sts = "sts"
314 : jhr 551 val imgTy = CL.T_Named(RN.imageTy dim)
315 :     val loadFn = RN.loadImage dim
316 : jhr 534 in [
317 :     CL.S_Decl(
318 :     statusTy, sts,
319 :     SOME(CL.E_Apply(loadFn, [#1 name, CL.mkUnOp(CL.%&, CL.E_Var(#2 lhs))])))
320 :     ] end
321 :     fun input (lhs : var, name, optDflt) = let
322 :     val sts = "sts"
323 : jhr 551 val inputFn = RN.input(#1 lhs)
324 : jhr 534 val lhs = CL.E_Var(#2 lhs)
325 :     val (initCode, hasDflt) = (case optDflt
326 :     of SOME(e, _) => ([CL.S_Assign(lhs, e)], true)
327 :     | NONE => ([], false)
328 :     (* end case *))
329 :     val code = [
330 :     CL.S_Decl(
331 :     statusTy, sts,
332 :     SOME(CL.E_Apply(inputFn, [
333 :     CL.E_Str name, CL.mkUnOp(CL.%&, lhs), CL.mkBool hasDflt
334 :     ])))
335 :     ]
336 :     in
337 :     initCode @ code
338 :     end
339 : jhr 528 fun die () = comment ["**** die ****"] (* FIXME *)
340 :     fun stabilize () = comment ["**** stabilize ****"] (* FIXME *)
341 : jhr 519 end
342 :    
343 : jhr 544 structure Strand =
344 :     struct
345 :     fun define (Prog{strands, ...}, strandId) = let
346 :     val strand = Strand{
347 :     name = strandId,
348 : jhr 552 tyName = RN.strandTy strandId,
349 : jhr 544 state = ref [],
350 :     code = ref []
351 :     }
352 :     in
353 :     strands := strand :: !strands;
354 :     strand
355 :     end
356 :    
357 :     (* register the strand-state initialization code. The variables are the strand
358 :     * parameters.
359 :     *)
360 :     fun init (Strand{name, tyName, code, ...}, params, init) = let
361 : jhr 551 val fName = RN.strandInit name
362 : jhr 544 val params =
363 : jhr 547 CL.PARAM([], CL.T_Ptr(CL.T_Named tyName), "selfOut") ::
364 : jhr 544 List.map (fn (ty, x) => CL.PARAM([], cvtTy ty, x)) params
365 :     val initFn = CL.D_Func([], CL.voidTy, fName, params, init)
366 :     in
367 :     code := initFn :: !code
368 :     end
369 : jhr 547
370 :     (* register a strand method *)
371 :     fun method (Strand{name, tyName, code, ...}, methName, body) = let
372 :     val fName = concat[name, "_", methName]
373 :     val params = [
374 :     CL.PARAM([], CL.T_Ptr(CL.T_Named tyName), "selfIn"),
375 :     CL.PARAM([], CL.T_Ptr(CL.T_Named tyName), "selfOut")
376 :     ]
377 :     val methFn = CL.D_Func([], CL.int32, fName, params, body)
378 :     in
379 :     code := methFn :: !code
380 :     end
381 : jhr 544 end (* Strand *)
382 :    
383 :     fun genStrand (Strand{name, tyName, state, code}) = let
384 :     val selfTyDef = CL.D_StructDef(
385 :     List.rev (List.map (fn (ty, x) => (cvtTy ty, x)) (!state)),
386 :     tyName)
387 :     in
388 :     selfTyDef :: List.rev (!code)
389 :     end
390 :    
391 : jhr 533 fun generate (baseName, Prog{globals, topDecls, strands}) = let
392 : jhr 527 val fileName = OS.Path.joinBaseExt{base=baseName, ext=SOME "c"}
393 :     val outS = TextIO.openOut fileName
394 :     val ppStrm = PrintAsC.new outS
395 : jhr 533 fun ppDecl dcl = PrintAsC.output(ppStrm, dcl)
396 : jhr 527 in
397 : jhr 533 List.app ppDecl (List.rev (!globals));
398 :     List.app ppDecl (List.rev (!topDecls));
399 : jhr 527 (* what about the strands, etc? *)
400 : jhr 544 List.app (fn strand => List.app ppDecl (genStrand strand)) (!strands);
401 : jhr 527 PrintAsC.close ppStrm;
402 :     TextIO.closeOut outS
403 :     end
404 :    
405 : jhr 519 end
406 :    
407 :     structure CBackEnd = CodeGenFn(CTarget)

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0