Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] Annotation of /branches/pure-cfg/src/compiler/cl-target/tree-to-cl.sml
ViewVC logotype

Annotation of /branches/pure-cfg/src/compiler/cl-target/tree-to-cl.sml

Parent Directory Parent Directory | Revision Log Revision Log


Revision 918 - (view) (download)

1 : jhr 918 (* tree-to-cl.sml
2 :     *
3 :     * COPYRIGHT (c) 2011 The Diderot Project (http://diderot-language.cs.uchicago.edu)
4 :     * All rights reserved.
5 :     *
6 :     * Translate TreeIL to the OpenCL version of CLang.
7 :     *)
8 :    
9 :     structure TreeToCL : sig
10 :    
11 :     datatype var = V of (CLang.ty * CLang.var)
12 :    
13 :     type env = var TreeIL.Var.Map.map
14 :    
15 :     val trType : TreeIL.Ty.ty -> CLang.ty
16 :    
17 :     val trBlock : env * (env * TreeIL.exp list * CLang.stm -> CLang.stm list) * TreeIL.block -> CLang.stm
18 :    
19 :     val trAssign : env * TreeIL.var * TreeIL.exp -> CLang.stm list
20 :    
21 :     val trExp : env * TreeIL.exp -> CLang.exp
22 :    
23 :     (* vector indexing support. Arguments are: vector, arity, index *)
24 :     val ivecIndex : CLang.exp * int * int -> CLang.exp
25 :     val vecIndex : CLang.exp * int * int -> CLang.exp
26 :    
27 :     end = struct
28 :    
29 :     structure CL = CLang
30 :     structure RN = RuntimeNames
31 :     structure IL = TreeIL
32 :     structure Op = IL.Op
33 :     structure Ty = IL.Ty
34 :     structure V = IL.Var
35 :    
36 :     datatype var = V of (CLang.ty * CLang.var)
37 :    
38 :     type env = var TreeIL.Var.Map.map
39 :    
40 :     fun lookup (env, x) = (case V.Map.find (env, x)
41 :     of SOME(V(_, x')) => x'
42 :     | NONE => raise Fail(concat["lookup(_, ", V.name x, ")"])
43 :     (* end case *))
44 :    
45 :     (* integer literal expression *)
46 :     fun intExp (i : int) = CL.mkInt(IntInf.fromInt i, CL.int32)
47 :    
48 :     (* translate TreeIL types to CLang types *)
49 :     fun trType ty = (case ty
50 :     of Ty.BoolTy => CLang.T_Named "bool"
51 :     | Ty.StringTy => CL.charPtr
52 :     | Ty.IVecTy 1 => !RN.gIntTy
53 :     | Ty.IVecTy n => CL.T_Named(RN.ivecTy n)
54 :     | Ty.TensorTy[] => !RN.gRealTy
55 :     | Ty.TensorTy[n] => CL.T_Named(RN.vecTy n)
56 :     | Ty.TensorTy[n, m] => CL.T_Named(RN.matTy(n,m))
57 :     | Ty.AddrTy(ImageInfo.ImgInfo{ty=([], rTy), ...}) => CL.T_Ptr(CL.T_Num rTy)
58 :     | Ty.ImageTy(ImageInfo.ImgInfo{dim, ...}) => CL.T_Ptr(CL.T_Named(RN.imageTy dim))
59 :     | _ => raise Fail(concat["TreeToC.trType(", Ty.toString ty, ")"])
60 :     (* end case *))
61 :    
62 :     (* generate new variables *)
63 :     local
64 :     val count = ref 0
65 :     fun freshName prefix = let
66 :     val n = !count
67 :     in
68 :     count := n+1;
69 :     concat[prefix, "_", Int.toString n]
70 :     end
71 :     in
72 :     fun tmpVar ty = freshName "tmp"
73 :     fun freshVar prefix = freshName prefix
74 :     end (* local *)
75 :    
76 :     (* translate IL basis functions *)
77 :     local
78 :     fun mkLookup suffix = let
79 :     val tbl = ILBasis.Tbl.mkTable (16, Fail "basis table")
80 :     fun ins f = ILBasis.Tbl.insert tbl (f, ILBasis.toString f ^ suffix)
81 :     in
82 :     List.app ins ILBasis.allFuns;
83 :     ILBasis.Tbl.lookup tbl
84 :     end
85 :     val fLookup = mkLookup "f"
86 :     val dLookup = mkLookup ""
87 :     in
88 :     fun trApply (f, args) = let
89 :     val f' = if !Controls.doublePrecision then dLookup f else fLookup f
90 :     in
91 :     CL.mkApply(f', args)
92 :     end
93 :     end (* local *)
94 :    
95 :     (* vector indexing support. Arguments are: vector, arity, index *)
96 :     fun ivecIndex (v, n, ix) = let
97 :     val unionTy = CL.T_Named(concat["union", Int.toString n, !RN.gIntSuffix, "_t"])
98 :     val e1 = CL.mkCast(unionTy, v)
99 :     val e2 = CL.mkSelect(e1, "i")
100 :     in
101 :     CL.mkSubscript(e2, intExp ix)
102 :     end
103 :    
104 :     fun vecIndex (v, n, ix) = let
105 :     val unionTy = CL.T_Named(concat["union", Int.toString n, !RN.gRealSuffix, "_t"])
106 :     val e1 = CL.mkCast(unionTy, v)
107 :     val e2 = CL.mkSelect(e1, "r")
108 :     in
109 :     CL.mkSubscript(e2, intExp ix)
110 :     end
111 :    
112 :     (* translate a variable use *)
113 :     fun trVar (env, x) = (case V.kind x
114 :     of IL.VK_Global => CL.mkVar(lookup(env, x))
115 :     | IL.VK_State strand => CL.mkIndirect(CL.mkVar "selfIn", lookup(env, x))
116 :     | IL.VK_Local => CL.mkVar(lookup(env, x))
117 :     (* end case *))
118 :    
119 :     (* Translate a TreeIL operator application to a CLang expression *)
120 :     fun trOp (rator, args) = (case (rator, args)
121 :     of (Op.Add ty, [a, b]) =>
122 :     CL.mkBinOp(a, CL.#+, b)
123 :     | (Op.Sub ty, [a, b]) =>
124 :     CL.mkBinOp(a, CL.#-, b)
125 :     | (Op.Mul ty, [a, b]) =>
126 :     CL.mkBinOp(a, CL.#*, b)
127 :     | (Op.Div ty, [a, b]) =>
128 :     CL.mkBinOp(a, CL.#/, b)
129 :     | (Op.Neg ty, [a]) =>
130 :     CL.mkUnOp(CL.%-, a)
131 :     | (Op.Abs(Ty.IVecTy 1), args) =>
132 :     CL.mkApply("abs", args)
133 :     | (Op.Abs(Ty.TensorTy[]), args) =>
134 :     CL.mkApply(RN.fabs(), args)
135 :     | (Op.Abs ty, [a]) =>
136 :     raise Fail(concat["Abs<", Ty.toString ty, ">"])
137 :     | (Op.LT ty, [a, b]) =>
138 :     CL.mkBinOp(a, CL.#<, b)
139 :     | (Op.LTE ty, [a, b]) =>
140 :     CL.mkBinOp(a, CL.#<=, b)
141 :     | (Op.EQ ty, [a, b]) =>
142 :     CL.mkBinOp(a, CL.#==, b)
143 :     | (Op.NEQ ty, [a, b]) =>
144 :     CL.mkBinOp(a, CL.#!=, b)
145 :     | (Op.GTE ty, [a, b]) =>
146 :     CL.mkBinOp(a, CL.#>=, b)
147 :     | (Op.GT ty, [a, b]) =>
148 :     CL.mkBinOp(a, CL.#>, b)
149 :     | (Op.Not, [a]) =>
150 :     CL.mkUnOp(CL.%!, a)
151 :     | (Op.Max, args) =>
152 :     CL.mkApply(RN.max(), args)
153 :     | (Op.Min, args) =>
154 :     CL.mkApply(RN.min(), args)
155 :     | (Op.Lerp ty, args) => (case ty
156 :     of Ty.TensorTy[] => CL.mkApply(RN.lerp 1, args)
157 :     | Ty.TensorTy[n] => CL.mkApply(RN.lerp n, args)
158 :     | _ => raise Fail(concat[
159 :     "lerp<", Ty.toString ty, "> not supported"
160 :     ])
161 :     (* end case *))
162 :     | (Op.Dot d, args) =>
163 :     CL.E_Apply(RN.dot d, args)
164 :     | (Op.MulVecMat(m, n), args) =>
165 :     if (1 < m) andalso (m < 4) andalso (m = n)
166 :     then CL.E_Apply(RN.mulVecMat(m,n), args)
167 :     else raise Fail "unsupported vector-matrix multiply"
168 :     | (Op.MulMatVec(m, n), args) =>
169 :     if (1 < m) andalso (m < 4) andalso (m = n)
170 :     then CL.E_Apply(RN.mulMatVec(m,n), args)
171 :     else raise Fail "unsupported matrix-vector multiply"
172 :     | (Op.MulMatMat(m, n, p), args) =>
173 :     if (1 < m) andalso (m < 4) andalso (m = n) andalso (n = p)
174 :     then CL.E_Apply(RN.mulMatMat(m,n,p), args)
175 :     else raise Fail "unsupported matrix-matrix multiply"
176 :     | (Op.Cross, args) =>
177 :     CL.E_Apply(RN.cross(), args)
178 :     | (Op.Select(Ty.IVecTy n, i), [a]) =>
179 :     ivecIndex (a, n, i)
180 :     | (Op.Select(Ty.TensorTy[n], i), [a]) =>
181 :     vecIndex (a, n, i)
182 :     | (Op.Norm(Ty.TensorTy[n]), args) =>
183 :     CL.E_Apply(RN.length n, args)
184 :     | (Op.Norm(Ty.TensorTy[m,n]), args) =>
185 :     CL.E_Apply(RN.norm(m,n), args)
186 :     | (Op.Normalize d, args) =>
187 :     CL.E_Apply(RN.normalize d, args)
188 :     | (Op.Trace n, args) =>
189 :     CL.E_Apply(RN.trace n, args)
190 :     | (Op.Scale(Ty.TensorTy[n]), args) =>
191 :     CL.E_Apply(RN.scale n, args)
192 :     | (Op.CL, _) =>
193 :     raise Fail "CL unimplemented"
194 :     | (Op.PrincipleEvec ty, _) =>
195 :     raise Fail "PrincipleEvec unimplemented"
196 :     | (Op.Subscript(Ty.IVecTy n), [v, ix]) => let
197 :     val unionTy = CL.T_Named(concat["union", Int.toString n, !RN.gIntSuffix, "_t"])
198 :     val vecExp = CL.mkSelect(CL.mkCast(unionTy, v), "i")
199 :     in
200 :     CL.mkSubscript(vecExp, ix)
201 :     end
202 :     | (Op.Subscript(Ty.TensorTy[n]), [v, ix]) => let
203 :     val unionTy = CL.T_Named(concat["union", Int.toString n, !RN.gRealSuffix, "_t"])
204 :     val vecExp = CL.mkSelect(CL.mkCast(unionTy, v), "r")
205 :     in
206 :     CL.mkSubscript(vecExp, ix)
207 :     end
208 :     | (Op.Subscript(Ty.TensorTy[_,_]), [m, ix, jx]) =>
209 :     CL.mkSubscript(CL.mkSelect(CL.mkSubscript(m, ix), "r"), jx)
210 :     | (Op.Subscript ty, t::(ixs as _::_)) =>
211 :     raise Fail(concat["Subscript<", Ty.toString ty, "> unsupported"])
212 :     | (Op.Ceiling d, args) =>
213 :     CL.mkApply(RN.addTySuffix("ceil", d), args)
214 :     | (Op.Floor d, args) =>
215 :     CL.mkApply(RN.addTySuffix("floor", d), args)
216 :     | (Op.Round d, args) =>
217 :     CL.mkApply(RN.addTySuffix("round", d), args)
218 :     | (Op.Trunc d, args) =>
219 :     CL.mkApply(RN.addTySuffix("trunc", d), args)
220 :     | (Op.IntToReal, [a]) =>
221 :     CL.mkCast(!RN.gRealTy, a)
222 :     | (Op.RealToInt 1, [a]) =>
223 :     CL.mkCast(!RN.gIntTy, a)
224 :     | (Op.RealToInt d, args) =>
225 :     CL.mkApply(RN.vecftoi d, args)
226 :     (* FIXME: need type info *)
227 :     | (Op.ImageAddress(ImageInfo.ImgInfo{ty=(_,rTy), ...}), [a]) => let
228 :     val cTy = CL.T_Ptr(CL.T_Num rTy)
229 :     in
230 :     CL.mkCast(cTy, CL.mkIndirect(a, "data"))
231 :     end
232 :     | (Op.LoadVoxels(info, 1), [a]) => let
233 :     val realTy as CL.T_Num rTy = !RN.gRealTy
234 :     val a = CL.E_UnOp(CL.%*, a)
235 :     in
236 :     if (rTy = ImageInfo.sampleTy info)
237 :     then a
238 :     else CL.E_Cast(realTy, a)
239 :     end
240 :     | (Op.LoadVoxels _, [a]) =>
241 :     raise Fail("impossible " ^ Op.toString rator)
242 :     | (Op.PosToImgSpace(ImageInfo.ImgInfo{dim, ...}), [img, pos]) =>
243 :     CL.mkApply(RN.toImageSpace dim, [img, pos])
244 :     | (Op.GradToWorldSpace d, [v, x]) =>
245 :     raise Fail "GradToWorldSpace unimplemented"
246 :     | (Op.LoadImage info, [a]) =>
247 :     raise Fail("impossible " ^ Op.toString rator)
248 :     | (Op.Inside(ImageInfo.ImgInfo{dim, ...}, s), [pos, img]) =>
249 :     CL.mkApply(RN.inside dim, [pos, img, intExp s])
250 :     | (Op.Input(ty, name), []) =>
251 :     raise Fail("impossible " ^ Op.toString rator)
252 :     | (Op.InputWithDefault(ty, name), [a]) =>
253 :     raise Fail("impossible " ^ Op.toString rator)
254 :     | _ => raise Fail(concat[
255 :     "unknown or incorrect operator ", Op.toString rator
256 :     ])
257 :     (* end case *))
258 :    
259 :     fun trExp (env, e) = (case e
260 :     of IL.E_Var x => trVar (env, x)
261 :     | IL.E_Lit(Literal.Int n) => CL.mkInt(n, !RN.gIntTy)
262 :     | IL.E_Lit(Literal.Bool b) => CL.mkBool b
263 :     | IL.E_Lit(Literal.Float f) => CL.mkFlt(f, !RN.gRealTy)
264 :     | IL.E_Lit(Literal.String s) => CL.mkStr s
265 :     | IL.E_Op(rator, args) => trOp (rator, trExps(env, args))
266 :     | IL.E_Apply(f, args) => trApply(f, trExps(env, args))
267 :     | IL.E_Cons(Ty.TensorTy[n], args) => CL.mkApply(RN.mkVec n, trExps(env, args))
268 :     | IL.E_Cons(ty, _) => raise Fail(concat["E_Cons(", Ty.toString ty, ", _) in expression"])
269 :     (* end case *))
270 :    
271 :     and trExps (env, exps) = List.map (fn exp => trExp(env, exp)) exps
272 :    
273 :     fun trAssign (env, lhs, rhs) = let
274 :     val lhs = (case V.kind lhs
275 :     of IL.VK_Global => CL.mkVar(lookup(env, lhs))
276 :     | IL.VK_State strand => CL.mkIndirect(CL.mkVar "selfOut", lookup(env, lhs))
277 :     | IL.VK_Local => CL.mkVar(lookup(env, lhs))
278 :     (* end case *))
279 :     in
280 :     (* certain rhs forms, such as those that return a matrix,
281 :     * require a function call instead of an assignment
282 :     *)
283 :     case rhs
284 :     of IL.E_Op(Op.Add(Ty.TensorTy[m,n]), args) =>
285 :     [CL.mkCall(RN.addMat(m,n), lhs :: trExps(env, args))]
286 :     | IL.E_Op(Op.Sub(Ty.TensorTy[m,n]), args) =>
287 :     [CL.mkCall(RN.subMat(m,n), lhs :: trExps(env, args))]
288 :     | IL.E_Op(Op.Neg(Ty.TensorTy[m,n]), args) =>
289 :     [CL.mkCall(RN.scaleMat(m,n), lhs :: intExp ~1 :: trExps(env, args))]
290 :     | IL.E_Op(Op.Scale(Ty.TensorTy[m,n]), args) =>
291 :     [CL.mkCall(RN.scaleMat(m,n), lhs :: trExps(env, args))]
292 :     | IL.E_Op(Op.MulMatMat(m,n,p), args) =>
293 :     [CL.mkCall(RN.mulMatMat(m,n,p), lhs :: trExps(env, args))]
294 :     | IL.E_Op(Op.Identity n, args) =>
295 :     [CL.mkCall(RN.identityMat n, [lhs])]
296 :     | IL.E_Op(Op.Zero(Ty.TensorTy[n,m]), args) =>
297 :     [CL.mkCall(RN.zeroMat(m,n), [lhs])]
298 :     | IL.E_Op(Op.LoadVoxels(info, n), [a]) =>
299 :     if (n > 1)
300 :     then let
301 :     val stride = ImageInfo.stride info
302 :     val rTy = ImageInfo.sampleTy info
303 :     val vp = freshVar "vp"
304 :     val needsCast = (CL.T_Num rTy <> !RN.gRealTy)
305 :     fun mkLoad i = let
306 :     val e = CL.mkSubscript(CL.mkVar vp, intExp(i*stride))
307 :     in
308 :     if needsCast then CL.mkCast(!RN.gRealTy, e) else e
309 :     end
310 :     in [
311 :     CL.mkDecl(CL.T_Ptr(CL.T_Num rTy), vp, SOME(CL.I_Exp(trExp(env, a)))),
312 :     CL.mkAssign(lhs,
313 :     CL.mkApply(RN.mkVec n, List.tabulate (n, mkLoad)))
314 :     ] end
315 :     else [CL.mkAssign(lhs, trExp(env, rhs))]
316 :     | IL.E_Cons(Ty.TensorTy[n,m], args) => let
317 :     (* matrices are represented as arrays of union<d><ty>_t vectors *)
318 :     fun doRows (_, []) = []
319 :     | doRows (i, e::es) =
320 :     CL.mkAssign(CL.mkSelect(CL.mkSubscript(lhs, intExp i), "v"), e)
321 :     :: doRows (i+1, es)
322 :     in
323 :     doRows (0, trExps(env, args))
324 :     end
325 :     | IL.E_Var x => (case IL.Var.ty x
326 :     of Ty.TensorTy[n,m] => [CL.mkCall(RN.copyMat(n,m), [lhs, trVar(env, x)])]
327 :     | _ => [CL.mkAssign(lhs, trVar(env, x))]
328 :     (* end case *))
329 :     | _ => [CL.mkAssign(lhs, trExp(env, rhs))]
330 :     (* end case *)
331 :     end
332 :    
333 :     fun trBlock (env : env, saveState, blk) = let
334 :     (* generate code to check the status of runtime-system calls *)
335 :     fun checkSts mkDecl = let
336 :     val sts = freshVar "sts"
337 :     in
338 :     mkDecl sts @
339 :     [CL.mkIfThen(
340 :     CL.mkBinOp(CL.mkVar "DIDEROT_OK", CL.#!=, CL.mkVar sts),
341 :     CL.mkCall("exit", [intExp 1]))]
342 :     end
343 :     fun trStmt (env, stm) = (case stm
344 :     of IL.S_Comment text => [CL.mkComment text]
345 :     | IL.S_Assign(x, exp) => trAssign (env, x, exp)
346 :     | IL.S_IfThen(cond, thenBlk) =>
347 :     [CL.mkIfThen(trExp(env, cond), trBlk(env, thenBlk))]
348 :     | IL.S_IfThenElse(cond, thenBlk, elseBlk) =>
349 :     [CL.mkIfThenElse(trExp(env, cond),
350 :     trBlk(env, thenBlk),
351 :     trBlk(env, elseBlk))]
352 :     | IL.S_LoadImage(lhs, dim, name) => checkSts (fn sts => let
353 :     val lhs = lookup(env, lhs)
354 :     val name = trExp(env, name)
355 :     val imgTy = CL.T_Named(RN.imageTy dim)
356 :     val loadFn = RN.loadImage dim
357 :     in [
358 :     CL.mkDecl(
359 :     CL.T_Named RN.statusTy, sts,
360 :     SOME(CL.I_Exp(CL.E_Apply(loadFn, [name, CL.mkUnOp(CL.%&, CL.E_Var lhs)]))))
361 :     ] end)
362 :     | IL.S_Input(lhs, name, optDflt) => checkSts (fn sts => let
363 :     val inputFn = RN.input(V.ty lhs)
364 :     val lhs = lookup(env, lhs)
365 :     val lhs = CL.E_Var lhs
366 :     val (initCode, hasDflt) = (case optDflt
367 :     of SOME e => ([CL.mkAssign(lhs, trExp(env, e))], true)
368 :     | NONE => ([], false)
369 :     (* end case *))
370 :     val code = [
371 :     CL.mkDecl(
372 :     CL.T_Named RN.statusTy, sts,
373 :     SOME(CL.I_Exp(CL.E_Apply(inputFn, [
374 :     CL.E_Str name, CL.mkUnOp(CL.%&, lhs), CL.mkBool hasDflt
375 :     ]))))
376 :     ]
377 :     in
378 :     initCode @ code
379 :     end)
380 :     (* FIXME: what about the args? *)
381 :     | IL.S_Exit args => [CL.mkReturn NONE]
382 :     | IL.S_Active args =>
383 :     saveState (env, args, CL.mkReturn(SOME(CL.mkVar RN.kActive)))
384 :     | IL.S_Stabilize args =>
385 :     saveState (env, args, CL.mkReturn(SOME(CL.mkVar RN.kStabilize)))
386 :     | IL.S_Die => [CL.mkReturn(SOME(CL.mkVar RN.kDie))]
387 :     (* end case *))
388 :     and trBlk (env, IL.Block{locals, body}) = let
389 :     val env = List.foldl
390 :     (fn (x, env) => V.Map.insert(env, x, V(trType(V.ty x), V.name x)))
391 :     env locals
392 :     val stms = List.foldr (fn (stm, stms) => trStmt(env, stm)@stms) [] body
393 :     fun mkDecl (x, stms) = (case V.Map.find (env, x)
394 :     of SOME(V(ty, x')) => CL.mkDecl(ty, x', NONE) :: stms
395 :     | NONE => raise Fail(concat["mkDecl(", V.name x, ", _)"])
396 :     (* end case *))
397 :     val stms = List.foldr mkDecl stms locals
398 :     in
399 :     CL.mkBlock stms
400 :     end
401 :     in
402 :     trBlk (env, blk)
403 :     end
404 :    
405 :     end

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0