Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] Annotation of /branches/vis15/src/compiler/cxx-util/gen-tys-and-ops.sml
ViewVC logotype

Annotation of /branches/vis15/src/compiler/cxx-util/gen-tys-and-ops.sml

Parent Directory Parent Directory | Revision Log Revision Log


Revision 4351 - (view) (download)

1 : jhr 3918 (* gen-tys-and-ops.sml
2 :     *
3 :     * This code is part of the Diderot Project (http://diderot-language.cs.uchicago.edu)
4 :     *
5 :     * COPYRIGHT (c) 2016 The University of Chicago
6 :     * All rights reserved.
7 :     *)
8 :    
9 :     structure GenTysAndOps : sig
10 :    
11 : jhr 4351 val gen : CodeGenEnv.t * CollectInfo.t -> {
12 :     preWorld : CLang.decl list, (* decls to place before world outside of namespace *)
13 :     postWorld : CLang.decl list (* decls to place after world inside of namespace *)
14 :     }
15 : jhr 3918
16 :     end = struct
17 :    
18 :     structure IR = TreeIR
19 :     structure Ty = TreeTypes
20 :     structure CL = CLang
21 :     structure RN = CxxNames
22 :     structure Env = CodeGenEnv
23 :    
24 : jhr 3927 val zero = RealLit.zero false
25 :    
26 :     fun mkReturn exp = CL.mkReturn(SOME exp)
27 : jhr 3931 fun mkInt i = CL.mkInt(IntInf.fromInt i)
28 : jhr 4028 fun mkFunc (ty, name, params, body) = CL.D_Func(["inline"], ty, [], name, params, body)
29 : jhr 4033 (* make a constructor function's prototype and out-of-line definition *)
30 :     fun mkConstr (cls, params, inits) = (
31 : jhr 4317 CL.D_Constr([], [], cls, params, NONE),
32 :     CL.D_Constr(["inline"], [CL.SC_Type(CL.T_Named cls)], cls, params, SOME(inits, CL.mkBlock[]))
33 :     )
34 : jhr 4033 (* make a member function prototype and out-of-line definition *)
35 :     fun mkMemberFn (cls, ty, f, params, body) = (
36 : jhr 4317 CL.D_Proto([], ty, f, params),
37 :     CL.D_Func(["inline"], ty, [CL.SC_Type(CL.T_Named cls)], f, params, body)
38 :     )
39 : jhr 3927
40 : jhr 3994 fun genTyDecl env = let
41 : jhr 4317 val (realTy, realTyName, realTySz) = if #double(Env.target env)
42 :     then (CL.double, "double", 8)
43 :     else (CL.float, "float", 4)
44 :     (* generate the type and member declarations for a recorded type *)
45 :     fun genDecl (ty, (tyDcls, fnDefs)) = (case ty
46 :     of Ty.VecTy(w, pw) => let
47 :     val cTyName = RN.vecTyName w
48 :     val cTy = CL.T_Named cTyName
49 :     val typedefDcl = CL.D_Verbatim[concat[
50 :     "typedef ", realTyName, " ", cTyName,
51 :     " __attribute__ ((vector_size (",
52 :     Int.toString(realTySz * pw), ")));"
53 :     ]]
54 :     in
55 :     (typedefDcl :: tyDcls, fnDefs)
56 :     end
57 :     | Ty.TensorRefTy shape => let
58 :     val name = RN.tensorRefStruct shape
59 :     val baseCls = concat[
60 :     "diderot::tensor_ref<", realTyName, ",",
61 :     Int.toString(List.foldl Int.* 1 shape), ">"
62 :     ]
63 :     fun mkConstr' (paramTy, paramId, arg) = mkConstr (
64 :     name,
65 :     [CL.PARAM([], paramTy, paramId)],
66 :     [CL.mkApply(baseCls, [arg])])
67 :     (* constructor from float/double pointer *)
68 :     val (constrProto1, constrDef1) = mkConstr' (
69 :     CL.constPtrTy realTy, "src", CL.mkVar "src")
70 :     (* constructor from tensor struct *)
71 :     val (constrProto2, constrDef2) = mkConstr' (
72 :     CL.T_Named(concat["struct ", RN.tensorStruct shape, " const &"]),
73 :     "ten", CL.mkSelect(CL.mkVar "ten", "_data"))
74 :     (* copy constructor *)
75 :     val (constrProto3, constrDef3) = mkConstr' (
76 :     CL.T_Named(name ^ " const &"), "ten",
77 :     CL.mkSelect(CL.mkVar "ten", "_data"))
78 :     val thisData = CL.mkIndirect(CL.mkVar "this", "_data")
79 :     (* last vector as tensor_ref *)
80 :     val lastDcl = (case shape
81 :     of [] => raise Fail "unexpected TensorRef[]"
82 :     | [_] => []
83 :     | _::dd => let
84 :     val d = List.last dd
85 :     in [
86 :     CL.D_Func([], RN.tensorRefTy[d], [], "last",
87 :     [CL.PARAM([], CL.uint32, "i")],
88 :     CL.mkReturn(
89 :     SOME(CL.mkAddrOf(CL.mkSubscript(thisData, CL.mkVar "i")))))
90 :     ] end
91 :     (* end case *))
92 :     val members = constrProto1 :: constrProto2 :: constrProto3 :: lastDcl
93 :     val structDcl = CL.D_ClassDef{
94 :     name = name,
95 :     args = NONE,
96 :     from = SOME("public " ^ baseCls),
97 :     public = members,
98 :     protected = [],
99 :     private = []
100 :     }
101 :     in
102 :     (structDcl :: tyDcls, constrDef1 :: constrDef2 :: constrDef3 :: fnDefs)
103 :     end
104 :     | Ty.TensorTy shape => let
105 :     val len = List.foldl Int.* 1 shape
106 :     val name = RN.tensorStruct shape
107 :     val baseCls = concat[
108 :     "diderot::tensor<", realTyName, ",",
109 :     Int.toString(List.foldl Int.* 1 shape), ">"
110 :     ]
111 :     fun mkConstr (paramTy, paramId, arg) = CL.D_Constr (
112 :     [], [], name,
113 :     [CL.PARAM([], paramTy, paramId)],
114 :     SOME([CL.mkApply(baseCls, [arg])], CL.mkBlock[]))
115 :     (* default constructor *)
116 :     val constrDcl1 = CL.D_Constr (
117 :     [], [], name, [], SOME([CL.mkApply(baseCls, [])], CL.mkBlock[]))
118 :     (* constructor from initializer list *)
119 :     val constrDcl2 = mkConstr (
120 :     CL.T_Template("std::initializer_list", [realTy]), "const & il",
121 :     CL.mkVar "il")
122 :     (* constructor from float/double pointer *)
123 :     val constrDcl3 = mkConstr (
124 :     CL.constPtrTy realTy, "src", CL.mkVar "src")
125 :     (* copy constructor *)
126 :     val constrDcl4 = mkConstr (
127 :     CL.T_Named(name ^ " const &"), "ten",
128 :     CL.mkSelect(CL.mkVar "ten", "_data"))
129 :     (* destructor *)
130 :     val destrDcl = CL.D_Destr([], [], name, SOME(CL.mkBlock[]))
131 :     val thisData = CL.mkIndirect(CL.mkVar "this", "_data")
132 :     val returnThis = CL.mkReturn(SOME(CL.mkUnOp(CL.%*, CL.mkVar "this")))
133 :     (* assignment from Tensor *)
134 :     val (assignProto1, assignDef1) = mkMemberFn(name,
135 :     CL.T_Named(name ^ " &"), "operator=",
136 :     [CL.PARAM([], CL.T_Named name, "const & src")],
137 :     CL.mkBlock[
138 :     CL.mkCall("this->copy", [CL.mkSelect(CL.mkVar "src", "_data")]),
139 :     returnThis
140 :     ])
141 :     (* assignment from TensorRef *)
142 :     val (assignProto2, assignDef2) = mkMemberFn(name,
143 :     CL.T_Named(name ^ " &"), "operator=",
144 :     [CL.PARAM([], CL.T_Named(RN.tensorRefStruct shape), "const & src")],
145 :     CL.mkBlock[
146 :     CL.mkCall("this->copy", [CL.mkSelect(CL.mkVar "src", "_data")]),
147 :     returnThis
148 :     ])
149 :     (* assignment from initializer list *)
150 :     val (assignProto3, assignDef3) = mkMemberFn(name,
151 :     CL.T_Named(name ^ " &"), "operator=",
152 :     [CL.PARAM([], CL.T_Template("std::initializer_list", [realTy]), "const & il")],
153 :     CL.mkBlock[
154 :     CL.mkCall("this->copy", [CL.mkVar "il"]),
155 :     returnThis
156 :     ])
157 :     (* assignment from array *)
158 :     val (assignProto4, assignDef4) = mkMemberFn(name,
159 :     CL.T_Named(name ^ " &"), "operator=",
160 :     [CL.PARAM([], CL.constPtrTy realTy, "src")],
161 :     CL.mkBlock[
162 :     CL.mkCall("this->copy", [CL.mkVar "src"]),
163 :     returnThis
164 :     ])
165 :     (* last vector as tensor_ref *)
166 :     val lastDcl = (case shape
167 :     of [] => raise Fail "unexpected TensorTy[]"
168 :     | [_] => []
169 :     | _::dd => let
170 :     val d = List.last dd
171 :     in [
172 :     CL.D_Func([], RN.tensorRefTy[d], [], "last",
173 :     [CL.PARAM([], CL.uint32, "i")],
174 :     CL.mkReturn(
175 :     SOME(CL.mkAddrOf(CL.mkSubscript(thisData, CL.mkVar "i")))))
176 :     ] end
177 :     (* end case *))
178 :     val structDcl = CL.D_ClassDef{
179 :     name = name,
180 :     args = NONE,
181 :     from = SOME("public " ^ baseCls),
182 :     public =
183 :     constrDcl1 :: constrDcl2 :: constrDcl3 :: constrDcl4 ::
184 :     destrDcl ::
185 :     assignProto1 :: assignProto2 :: assignProto3 :: assignProto4 ::
186 :     lastDcl,
187 :     protected = [],
188 :     private = []
189 :     }
190 :     val fnDefs = assignDef1 :: assignDef2 :: assignDef3 :: assignDef4 :: fnDefs
191 :     in
192 :     (structDcl :: tyDcls, fnDefs)
193 :     end
194 :     | Ty.TupleTy tys => raise Fail "FIXME: TupleTy"
195 : jhr 3918 (* TODO
196 : jhr 4317 | Ty.SeqTy(ty, NONE) =>
197 :     | Ty.SeqTy(ty, SOME n) =>
198 : jhr 3918 *)
199 : jhr 4317 | ty => (tyDcls, fnDefs)
200 :     (* end case *))
201 :     in
202 :     genDecl
203 :     end
204 : jhr 3918
205 : jhr 4014 fun genSeqTrait env = let
206 : jhr 4317 val ns = #namespace(Env.target env)
207 :     val realTy = Env.realTy env
208 :     fun trType ty = TypeToCxx.trQType(env, TypeToCxx.NSDiderot, ty)
209 :     fun trait ({argTy, baseTy, elemTy, ndims, dims}, dcls) = let
210 :     (* the name of the teem function table for the given base type *)
211 :     val loadTbl = (case baseTy
212 :     of Ty.BoolTy => "nrrdILoad"
213 :     | Ty.IntTy => "nrrdILoad"
214 :     | Ty.VecTy(1, 1) => if #double(Env.target env)
215 :     then "nrrdDLoad"
216 :     else "nrrdFLoad"
217 :     | ty => raise Fail("genSeqTrait.loadFn: unexpected type " ^ Ty.toString ty)
218 :     (* end case *))
219 :     val loadTblTy = CL.constPtrTy(CL.T_Named "__details::load_fn_ptr<base_type>")
220 :     val dimArrTy = CL.T_Array(CL.uint32, SOME ndims)
221 :     val seqTy = CL.T_Template("dynseq_traits", [argTy])
222 :     val scope = CL.SC_Type seqTy
223 :     in
224 :     CL.D_Template([], CL.D_ClassDef{
225 :     name = "dynseq_traits",
226 :     args = SOME[argTy],
227 :     from = NONE,
228 :     public = [
229 :     CL.D_Typedef("value_type", elemTy),
230 :     CL.D_Typedef("base_type", trType baseTy),
231 :     CL.D_Var(
232 :     ["static"],
233 :     CL.constPtrTy(CL.T_Named "__details::load_fn_ptr<base_type>"),
234 :     [], "load_fn_tbl", NONE),
235 :     CL.D_Var(
236 :     ["static", "const"], CL.uint32, [], "ndims", SOME(CL.I_Exp(mkInt ndims))),
237 :     CL.D_Var(
238 :     ["static", "const"], dimArrTy, [], "dims", NONE)
239 :     ],
240 :     protected = [],
241 :     private = []
242 :     }) ::
243 :     CL.D_Var(
244 :     ["const"],
245 :     CL.T_Ptr(
246 :     CL.T_Template("__details::load_fn_ptr", [CL.T_Member(seqTy, "base_type")])),
247 :     [scope], "load_fn_tbl",
248 :     SOME(CL.I_Exp(CL.mkVar loadTbl))) ::
249 :     CL.D_Var(
250 :     ["const"], dimArrTy, [scope], "dims",
251 :     SOME(CL.I_Exps(List.map (CL.I_Exp o mkInt) dims))) ::
252 :     dcls
253 :     end
254 :     fun genTrait (ty, dcls) = (case ty
255 :     of Ty.SeqTy(argTy, NONE) => let
256 :     fun baseTy (Ty.SeqTy(ty, _)) = baseTy ty
257 :     | baseTy (Ty.TensorTy[]) = Ty.realTy
258 :     | baseTy ty = ty
259 :     val argTy = trType argTy
260 :     (* for sequences of scalar values, we set nDims to 0 so that it matches the
261 :     * format of a nrrd, where the dimension is not represented.
262 :     *)
263 :     fun scalarSeqTrait ty = trait ({
264 :     argTy = argTy, baseTy = ty, elemTy = argTy,
265 :     ndims = 0, dims = []
266 :     },
267 :     dcls)
268 :     in
269 :     case baseTy ty
270 :     of ty as Ty.TensorTy(shp as _::_) => trait ({
271 :     argTy = argTy, baseTy = Ty.realTy,
272 :     elemTy = argTy, ndims = List.length shp,
273 :     dims = shp
274 :     },
275 :     dcls)
276 :     | ty as Ty.BoolTy => scalarSeqTrait ty
277 :     | ty as Ty.IntTy => scalarSeqTrait ty
278 :     | ty as Ty.VecTy(1, 1) => scalarSeqTrait ty
279 :     | ty => raise Fail "FIXME: unsupported dynamic sequence type"
280 :     (* end case *)
281 :     end
282 :     | _ => dcls
283 :     (* end case *))
284 :     in
285 :     genTrait
286 :     end
287 : jhr 4014
288 :     datatype operation = datatype CollectInfo.operation
289 :    
290 : jhr 3918 val ostreamRef = CL.T_Named "std::ostream&"
291 :    
292 :     fun output (e, e') = CL.mkBinOp(e, CL.#<<, e')
293 :    
294 :     (* generate code for the expression "e << s", where "s" is string literal *)
295 :     fun outString (CL.E_BinOp(e, CL.#<<, CL.E_Str s1), s2) =
296 : jhr 4317 output (e, CL.mkStr(s1 ^ String.toCString s2))
297 : jhr 3918 | outString (e, s) = output (e, CL.mkStr(String.toCString s))
298 :    
299 :     (* generate a printing function for tensors with the given shape *)
300 :     fun genTensorPrinter shape = let
301 : jhr 4317 fun ten i = CL.mkSubscript(CL.mkSelect(CL.mkVar "ten", "_data"), mkInt i)
302 :     fun prefix (true, lhs) = lhs
303 :     | prefix (false, lhs) = outString(lhs, ",")
304 :     fun lp (isFirst, lhs, i, [d]) = let
305 :     fun lp' (_, lhs, i, 0) = (i, outString(lhs, "]"))
306 :     | lp' (isFirst, lhs, i, n) =
307 :     lp' (false, output (prefix (isFirst, lhs), ten i), i+1, n-1)
308 :     in
309 :     lp' (true, outString(lhs, "["), i, d)
310 :     end
311 :     | lp (isFirst, lhs, i, d::dd) = let
312 :     fun lp' (_, lhs, i, 0) = (i, outString(lhs, "]"))
313 :     | lp' (isFirst, lhs, i, n) = let
314 :     val (i, lhs) = lp (true, prefix (isFirst, lhs), i, dd)
315 :     in
316 :     lp' (false, lhs, i, n-1)
317 :     end
318 :     in
319 :     lp' (true, outString(lhs, "["), i, d)
320 :     end
321 :     val params = [
322 :     CL.PARAM([], ostreamRef, "outs"),
323 :     CL.PARAM([], RN.tensorRefTy shape, "const & ten")
324 :     ]
325 :     val (_, exp) = lp (true, CL.mkVar "outs", 0, shape)
326 :     in
327 :     CL.D_Func(["static"], ostreamRef, [], "operator<<", params, mkReturn exp)
328 :     end
329 : jhr 3918
330 : jhr 4128 (* builds AST for the expression "(x <= lo) ? lo : (hi <= x) ? hi : x;" *)
331 :     fun mkClampExp (lo, hi, x) =
332 : jhr 4317 CL.mkCond(CL.mkBinOp(x, CL.#<=, lo), lo,
333 :     CL.mkCond(CL.mkBinOp(hi, CL.#<=, x), hi,
334 :     x))
335 : jhr 3935
336 : jhr 3950 fun mkLerp (ty, name, realTy, mkT) = mkFunc(
337 : jhr 4317 ty, name,
338 :     [CL.PARAM([], ty, "a"), CL.PARAM([], ty, "b"), CL.PARAM([], realTy, "t")],
339 :     mkReturn (
340 :     CL.mkBinOp(
341 :     CL.mkVar "a",
342 :     CL.#+,
343 :     CL.mkBinOp(
344 :     mkT(CL.mkVar "t"),
345 :     CL.#*,
346 :     CL.mkBinOp(CL.mkVar "b", CL.#-, CL.mkVar "a")))))
347 : jhr 3927
348 : jhr 4277 fun expandFrag (env, frag) =
349 : jhr 4317 CL.verbatimDcl [frag] [("REALTY", if #double(Env.target env) then "double" else "float")]
350 : jhr 4277
351 : jhr 3921 fun doOp env (rator, dcls) = let
352 : jhr 4317 val realTy = Env.realTy env
353 :     fun mkVec (w, pw, f) = CL.mkVec(
354 :     RN.vecTy w,
355 :     List.tabulate(pw, fn i => if i < w then f i else CL.mkFlt(zero, realTy)))
356 :     fun mkVMap (ty, name, f, w, pw) = let
357 :     fun f' i = CL.mkApply(f, [CL.mkSubscript(CL.mkVar "v", mkInt i)])
358 :     in
359 :     mkFunc(ty, name, [CL.PARAM([], ty, "v")], mkReturn (mkVec (w, pw, f')))
360 :     end
361 :     val dcl = (case rator
362 :     of Print(Ty.TensorRefTy shape) => genTensorPrinter shape
363 :     | Print(Ty.TupleTy tys) => raise Fail "FIXME: printer for tuples"
364 :     | Print(Ty.SeqTy(ty, NONE)) => raise Fail "FIXME: printer for dynseq"
365 :     | Print(Ty.SeqTy(ty, SOME n)) => raise Fail "FIXME: printer for sequence"
366 :     | Print ty => CL.D_Verbatim[] (* no printer needed *)
367 :     | RClamp => let
368 :     val name = "clamp"
369 :     val params = [
370 :     CL.PARAM([], realTy, "lo"),
371 :     CL.PARAM([], realTy, "hi"),
372 :     CL.PARAM([], realTy, "x")
373 :     ]
374 :     in
375 :     mkFunc(realTy, name, params,
376 :     mkReturn(mkClampExp (CL.mkVar "lo", CL.mkVar "hi", CL.mkVar "x")))
377 :     end
378 :     | RLerp => mkLerp (realTy, "lerp", realTy, fn x => x)
379 :     | VScale(w, pw) => let
380 :     val cTy = RN.vecTy w
381 :     in
382 :     mkFunc(cTy, RN.vscale w,
383 :     [CL.PARAM([], realTy, "s"), CL.PARAM([], cTy, "v")],
384 :     mkReturn(
385 :     CL.mkBinOp(mkVec(w, pw, fn _ => CL.mkVar "s"), CL.#*, CL.mkVar "v")))
386 :     end
387 :     | VSum(w, pw) => let
388 :     val name = RN.vsum w
389 :     val params = [CL.PARAM([], RN.vecTy w, "v")]
390 :     fun mkSum 0 = CL.mkSubscript(CL.mkVar "v", mkInt 0)
391 :     | mkSum i = CL.mkBinOp(mkSum(i-1), CL.#+, CL.mkSubscript(CL.mkVar "v", mkInt i))
392 :     in
393 :     mkFunc(realTy, name, params, mkReturn(mkSum(w-1)))
394 :     end
395 :     | VDot(w, pw) => let
396 :     val name = RN.vdot w
397 :     val vTy = RN.vecTy w
398 :     val params = [CL.PARAM([], vTy, "u"), CL.PARAM([], vTy, "v")]
399 :     fun mkSum 0 = CL.mkSubscript(CL.mkVar "w", mkInt 0)
400 :     | mkSum i = CL.mkBinOp(mkSum(i-1), CL.#+, CL.mkSubscript(CL.mkVar "w", mkInt i))
401 :     in
402 :     mkFunc(realTy, name, params,
403 :     CL.mkBlock[
404 :     CL.mkDeclInit(vTy, "w", CL.mkBinOp(CL.mkVar "u", CL.#*, CL.mkVar "v")),
405 :     mkReturn(mkSum(w-1))
406 :     ])
407 :     end
408 :     | VClamp(w, pw) => let
409 :     val cTy = RN.vecTy w
410 :     val name = RN.vclamp w
411 :     val params = [
412 :     CL.PARAM([], realTy, "lo"),
413 :     CL.PARAM([], realTy, "hi"),
414 :     CL.PARAM([], cTy, "v")
415 :     ]
416 :     fun mkItem i = mkClampExp(
417 :     CL.mkVar "lo", CL.mkVar "hi", CL.mkSubscript(CL.mkVar "v", mkInt i))
418 :     in
419 :     mkFunc(cTy, name, params,
420 :     mkReturn (mkVec (w, pw, mkItem)))
421 :     end
422 :     | VMapClamp(w, pw) => let
423 :     val cTy = RN.vecTy w
424 :     val name = RN.vclamp w
425 :     val params = [
426 :     CL.PARAM([], cTy, "vlo"),
427 :     CL.PARAM([], cTy, "vhi"),
428 :     CL.PARAM([], cTy, "v")
429 :     ]
430 :     fun mkItem i = mkClampExp(
431 :     CL.mkSubscript(CL.mkVar "vlo", mkInt i),
432 :     CL.mkSubscript(CL.mkVar "vhi", mkInt i),
433 :     CL.mkSubscript(CL.mkVar "v", mkInt i))
434 :     in
435 :     mkFunc(cTy, name, params,
436 :     mkReturn (mkVec (w, pw, mkItem)))
437 :     end
438 :     | VLerp(w, pw) =>
439 :     mkLerp (RN.vecTy w, RN.vlerp w, realTy, fn x => mkVec(w, pw, fn i => x))
440 :     | VCeiling(w, pw) => mkVMap (RN.vecTy w, RN.vceiling w, "std::ceiling", w, pw)
441 :     | VFloor(w, pw) => mkVMap (RN.vecTy w, RN.vfloor w, "std::floor", w, pw)
442 :     | VRound(w, pw) => mkVMap (RN.vecTy w, RN.vround w, "std::round", w, pw)
443 :     | VTrunc(w, pw) => mkVMap (RN.vecTy w, RN.vtrunc w, "std::trunc", w, pw)
444 :     | VToInt(w, pw) => let
445 :     val intTy = Env.intTy env
446 :     in
447 :     mkFunc(CL.voidTy, RN.vtoi w,
448 :     [ CL.PARAM([], CL.T_Array(intTy, SOME w), "dst"),
449 :     CL.PARAM([], RN.vecTy w, "src")],
450 :     CL.mkBlock(List.tabulate (w,
451 :     fn i => CL.mkAssign(
452 :     CL.mkSubscript(CL.mkVar "dst", mkInt i),
453 :     CL.mkCons(intTy, [CL.mkSubscript(CL.mkVar "src", mkInt i)])))))
454 :     end
455 :     | VLoad(w, pw) => let
456 :     val name = RN.vload w
457 :     val cTy = RN.vecTy w
458 :     fun arg i = CL.mkSubscript(CL.mkVar "vp", mkInt i)
459 :     in
460 :     mkFunc(cTy, name,
461 :     [CL.PARAM(["const"], CL.T_Ptr realTy, "vp")],
462 :     mkReturn(mkVec (w, pw, arg)))
463 :     end
464 :     | VCons(w, pw) => let
465 :     val name = RN.vcons w
466 :     val cTy = RN.vecTy w
467 :     val params = List.tabulate(w, fn i => CL.PARAM([], realTy, "r"^Int.toString i))
468 :     fun arg i = CL.mkVar("r"^Int.toString i)
469 :     in
470 :     mkFunc(cTy, name, params, mkReturn(mkVec (w, pw, arg)))
471 :     end
472 :     | VPack layout => let
473 :     val name = RN.vpack (#wid layout)
474 :     val vParamTys = Ty.piecesOf layout
475 :     val vParams = List.mapi
476 :     (fn (i, Ty.VecTy(w, _)) => CL.PARAM([], RN.vecTy w, "v"^Int.toString i))
477 :     vParamTys
478 :     val dstTy = RN.tensorTy[#wid layout]
479 :     fun mkAssign (i, v, j) =
480 :     CL.mkAssign(
481 :     CL.mkSubscript(CL.mkSelect(CL.mkVar "dst", "_data"), mkInt i),
482 :     CL.mkSubscript(v, mkInt j))
483 :     fun mkAssignsForPiece (dstStart, pieceIdx, wid, stms) = let
484 :     val piece = CL.mkVar("v"^Int.toString pieceIdx)
485 :     fun mk (j, stms) = if (j < wid)
486 :     then mk (j+1, mkAssign (dstStart+j, piece, j) :: stms)
487 :     else stms
488 :     in
489 :     mk (0, stms)
490 :     end
491 :     fun mkAssigns (_, [], _, stms) = CL.mkBlock(List.rev stms)
492 :     | mkAssigns (i, Ty.VecTy(w, _)::tys, offset, stms) =
493 :     mkAssigns (i+1, tys, offset+w, mkAssignsForPiece(offset, i, w, stms))
494 :     in
495 :     mkFunc(CL.voidTy, name,
496 :     CL.PARAM([], dstTy, "&dst") :: vParams,
497 :     mkAssigns (0, vParamTys, 0, []))
498 :     end
499 :     | TensorCopy shp => CL.D_Verbatim[]
500 : jhr 3999 (*
501 : jhr 4317 | TensorCopy shp => let
502 :     val name = RN.tensorCopy shp
503 :     val dim = List.foldl Int.* 1 shp
504 :     val dstTy = CL.T_Array(realTy, SOME dim)
505 :     in
506 :     mkFunc(CL.voidTy, name,
507 :     [CL.PARAM([], dstTy, "dst"), CL.PARAM([], CL.constPtrTy realTy, "src")],
508 :     CL.mkCall("std::memcpy", [
509 :     CL.mkVar "dst", CL.mkVar "src", CL.mkSizeof dstTy
510 :     ]))
511 :     end
512 : jhr 3999 *)
513 : jhr 4317 | Transform d => let
514 :     val imgTy = CL.T_Template(RN.qImageTyName d, [realTy, CL.T_Named "TY"])
515 :     val e = CL.mkDispatch(CL.mkVar "img", "world2image", [])
516 :     val (resTy, e) = if (d = 1)
517 :     then (realTy, e)
518 :     else let val ty = RN.tensorRefTy[d, d]
519 :     in (ty, CL.mkCons(ty, [e])) end
520 :     in
521 :     CL.D_Template([CL.TypeParam "TY"],
522 :     mkFunc(resTy, "world2image",
523 :     [CL.PARAM([], imgTy, "const & img")],
524 :     CL.mkReturn(SOME e)))
525 :     end
526 :     | Translate d => let
527 :     val imgTy = CL.T_Template(RN.qImageTyName d, [realTy, CL.T_Named "TY"])
528 :     val e = CL.mkDispatch(CL.mkVar "img", "translate", [])
529 :     val (resTy, e) = if (d = 1)
530 :     then (realTy, e)
531 :     else let val ty = RN.tensorRefTy[d]
532 :     in (ty, CL.mkCons(ty, [e])) end
533 :     in
534 :     CL.D_Template([CL.TypeParam "TY"],
535 :     mkFunc(resTy, "translate",
536 :     [CL.PARAM([], imgTy, "const & img")],
537 :     CL.mkReturn(SOME e)))
538 :     end
539 :     | Inside(layout, s) => let
540 :     val dim = #wid layout
541 :     val vTys = List.map
542 :     (fn ty => TypeToCxx.trType (env, ty))
543 :     (TreeTypes.piecesOf layout)
544 :     val xs = List.mapi (fn (i, ty) => "x"^Int.toString i) vTys
545 :     val vParams =
546 :     ListPair.map (fn (ty, x) => CL.PARAM([], ty, x)) (vTys, xs)
547 :     val imgTy = CL.T_Template(RN.qImageTyName dim, [realTy, CL.T_Named "TY"])
548 :     (* make the tests `(x < img.size(i)-s)` and `((s-1) < x)` *)
549 :     fun mkTests (x, i) = [
550 :     CL.mkBinOp(x, CL.#<,
551 :     CL.mkBinOp(
552 :     CL.mkDispatch(CL.mkVar "img", "size", [mkInt i]),
553 :     CL.#-, mkInt s)),
554 :     CL.mkBinOp(mkInt(s-1), CL.#<, x)
555 :     ]
556 :     (* build the test expression from the pieces *)
557 :     fun mkExps (i, w, v::vr, pw::pwr, tests) =
558 :     if (i < dim)
559 :     then if (w < pw)
560 :     then let
561 :     val x = if (pw = 1)
562 :     then CL.mkVar v
563 :     else CL.mkSubscript(CL.mkVar v, mkInt w)
564 :     in
565 :     mkExps (i+1, w+1, v::vr, pw::pwr,
566 :     mkTests(x, w) @ tests)
567 :     end
568 :     else mkExps (i, pw, vr, pwr, tests)
569 :     else List.rev tests
570 :     | mkExps _ = raise Fail "inconsistent"
571 :     val (t1::tr) = mkExps (0, 0, xs, #pieces layout, [])
572 :     val exp = List.foldr
573 :     (fn (e1, e2) => CL.mkBinOp(e2, CL.#&&, e1))
574 :     t1 tr
575 :     in
576 :     CL.D_Template([CL.TypeParam "TY"],
577 :     mkFunc(CL.boolTy, RN.inside(dim, s),
578 :     vParams @ [CL.PARAM([], imgTy, "img")],
579 :     mkReturn exp))
580 :     end
581 :     | EigenVals2x2 => expandFrag (env, CxxFragments.eigenvals2x2)
582 :     | EigenVals3x3 => expandFrag (env, CxxFragments.eigenvals3x3)
583 :     | EigenVecs2x2 => expandFrag (env, CxxFragments.eigenvecs2x2)
584 :     | EigenVecs3x3 => expandFrag (env, CxxFragments.eigenvecs3x3)
585 : jhr 4351 | SphereQuery d => let
586 :     val seqTy = CL.T_Template("diderot::dynseq", [CL.uint32])
587 :     val (posTy, posExp) = if d > 1
588 :     then (
589 :     TypeToCxx.trType(env, Ty.TensorRefTy[d]),
590 :     CL.mkDispatch(CL.mkVar "pos", "base", [])
591 :     )
592 :     else (realTy, CL.mkVar "pos")
593 :     val wrldV = CL.mkVar RN.worldVar
594 :     in
595 :     mkFunc(
596 :     seqTy, "sphere_query",
597 :     [RN.worldParam, CL.PARAM([], posTy, "pos"), CL.PARAM([], realTy, "radius")],
598 :     CL.mkReturn(SOME(
599 :     CL.mkIndirectDispatch(
600 :     CL.mkIndirect(wrldV, "_tree"),
601 :     "sphere_query",
602 :     [CL.mkIndirect(wrldV, "_inState"), posExp, CL.mkVar "radius"]))))
603 :     end
604 : jhr 4317 (* end case *))
605 :     in
606 :     dcl :: dcls
607 :     end
608 : jhr 3921
609 : jhr 4351 val firstTy = CL.D_Comment["***** Begin synthesized types *****"]
610 :     val lastTy = CL.D_Comment["***** End synthesized types *****"]
611 :     val noDclsTy = CL.D_Comment["***** No synthesized types *****"]
612 : jhr 3994
613 : jhr 4351 val firstOp = CL.D_Comment["***** Begin synthesized operations *****"]
614 :     val lastOp = CL.D_Comment["***** End synthesized operations *****"]
615 :     val noDclsOp = CL.D_Comment["***** No synthesized operations *****"]
616 :    
617 : jhr 3918 fun gen (env, info) = let
618 : jhr 4317 val spec = Env.target env
619 :     val genTrait = genSeqTrait env
620 :     val genTyDecl = genTyDecl env
621 :     val opDcls = List.foldl (doOp env) [] (CollectInfo.listOps info)
622 :     val tys = CollectInfo.listTypes info
623 :     val (tyDcls, fnDefs) = List.foldr genTyDecl ([], []) tys
624 : jhr 4351 val dcls = tyDcls @ fnDefs
625 : jhr 4317 val traitDcls = List.foldl genTrait [] tys
626 : jhr 4351 val preDcls = if List.null dcls andalso List.null traitDcls
627 :     then [noDclsTy]
628 :     else let
629 :     val res = [lastTy]
630 :     val res = if List.null traitDcls
631 :     then res
632 :     else CL.D_Namespace("diderot", traitDcls) :: res
633 :     val res = if List.null dcls
634 :     then res
635 :     else CL.D_Namespace(#namespace(Env.target env), dcls) :: res
636 :     in
637 :     firstTy :: res
638 :     end
639 :     val postDcls = if List.null opDcls
640 :     then [noDclsOp]
641 :     else firstOp :: opDcls @ [lastOp]
642 : jhr 4317 in
643 : jhr 4351 {preWorld = preDcls, postWorld = postDcls}
644 : jhr 4317 end
645 : jhr 3918
646 :     end

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0