Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] Annotation of /branches/vis15/src/compiler/target-cpu/gen-outputs.sml
ViewVC logotype

Annotation of /branches/vis15/src/compiler/target-cpu/gen-outputs.sml

Parent Directory Parent Directory | Revision Log Revision Log


Revision 3909 - (view) (download)

1 : jhr 3909 (* gen-outputs.sml
2 :     *
3 :     * This code is part of the Diderot Project (http://diderot-language.cs.uchicago.edu)
4 :     *
5 :     * COPYRIGHT (c) 2016 The University of Chicago
6 :     * All rights reserved.
7 :     *
8 :     * Generate strand output functions. The output formats always have a single axis for the
9 :     * data elements followed by one, or more, axes for the output structure. There are four
10 :     * cases that we handle:
11 :     *
12 :     * grid, fixed-size elements:
13 :     * nrrd has object axis followed by grid axes
14 :     *
15 :     * collection, fixed-size elements
16 :     * nrrd has object axis followed by a single axis
17 :     *
18 :     * grid, dynamic-size elements
19 :     * nLengths nrrd has size 2 for objects (offset, length) followed by grid axes
20 :     * nData nrrd has object axis followed by a single axis
21 :     *
22 :     * collection, dynamic-size elements
23 :     * nLengths nrrd has size 2 for objects (offset, length) followed by a single axis
24 :     * nData nrrd has object axis followed by a single axis
25 :     *
26 :     * The object axis kind depends on the output type, but it will either be one of the tensor types
27 :     * that Teem knows about or else nrrdKindList. In any case, the data elements are written as a
28 :     * flat vector following the in-memory layout. The other axes in the file will have nrrdKindSpace
29 :     * as their kind.
30 :     *
31 :     * TODO: some of this code will be common across all targets (e.g., writing outputs to files), so
32 :     * we will want to refactor it.
33 :     *
34 :     * TODO: for sequences of tensors (e.g., tensor[3][2]), we should use a separate axis for the
35 :     * sequence dimension with kind nrrdKindList.
36 :     *
37 :     * TODO: since the runtime tracks numbers of strands in various states, we should be
38 :     * able to use that information directly from the world without having to recompute it!
39 :     *)
40 :    
41 :     structure GenOutputs : sig
42 :    
43 :     (* gen (props, nAxes) outputs
44 :     * returns a list of function declarations for getting the output/snapshot nrrds from
45 :     * the program state. The arguments are:
46 :     * props - the target information
47 :     * nAxes - the number of axes in the grid of strands (NONE for a collection)
48 :     * outputs - the list of output state variables paired with their API types
49 :     *)
50 :     val gen : CodeGenEnv.t * int option -> (APITypes.t * string) list -> CLang.decl list
51 :    
52 :     end = struct
53 :    
54 :     structure IR = TreeIR
55 :     structure V = TreeVar
56 :     structure Ty = APITypes
57 :     structure CL = CLang
58 :     structure N = CNames
59 :     structure Nrrd = NrrdEnums
60 :     structure U = OutputUtil
61 :    
(* mapi f [x0, x1, ...] returns [f(0, x0), f(1, x1), ...]; the function is applied to
 * the elements in left-to-right order together with their zero-based index.
 *)
fun mapi f l = let
      fun step (x, (i, acc)) = (i+1, f(i, x) :: acc)
      val (_, revResult) = List.foldl step (0, []) l
      in
        List.rev revResult
      end
68 :    
(* C types and literal helpers used when building the generated code *)
val sizeTy = CL.T_Named "size_t"
val nrrdPtrTy = CL.T_Ptr(CL.T_Named "Nrrd")
(* the type of the world pointer for the given target *)
val wrldPtr = N.worldPtrTy
(* build an integer literal expression from an SML int *)
fun mkInt (n : int) = CL.mkInt(IntInf.fromInt n)
73 :    
(* variables in the generated code *)
local
  val var = CL.mkVar
in
val wrldV = var "wrld"
val sizesV = var "sizes"
val iV = var "i"
val nV = var "n"
val cpV = var "cp"
val ipV = var "ip"
val msgV = var "msg"
val offsetV = var "offset"
val nDataV = var "nData"
val nLengthsV = var "nLengths"
val numStableV = var "numStable"
val numElemsV = var "numElems"
val outSV = var "outS"
val DIDEROT_DIE = var "DIDEROT_DIE"
val DIDEROT_STABLE = var "DIDEROT_STABLE"
val NRRD = var "NRRD"
end (* local *)
91 :    
(* dynamic sequence operations: these build calls to the Diderot_DynSeqLength and
 * Diderot_DynSeqCopy functions in the generated code.
 *)
fun seqLength seq = CL.mkApply("Diderot_DynSeqLength", seq :: nil)
fun seqCopy (elemSz, dst, seq) = CL.mkApply("Diderot_DynSeqCopy", elemSz :: dst :: seq :: nil)
95 :    
(* helpers for initializing the sizes array in the generated code:
 *    sizes i          - the expression `sizes[i]`
 *    setSizes (i, v)  - the assignment `sizes[i] = v`
 *)
fun sizes axis = CL.mkSubscript(sizesV, mkInt axis)
fun setSizes (axis, value) = CL.mkAssign(CL.mkSubscript(sizesV, mkInt axis), value)
99 :    
(* code to access the state variable `name` of the i'th strand.  For targets with
 * dual state this is
 *      wrld->outState[i]->name
 * and otherwise it is
 *      wrld->state[i].name
 *)
fun stateVar props name = let
    (* wrld->arr[i] *)
      fun strand arr = CL.mkSubscript(CL.mkIndirect(wrldV, arr), iV)
      in
        case Properties.dualState props
         of true => CL.mkIndirect(strand "outState", name)
          | false => CL.mkSelect(strand "state", name)
        (* end case *)
      end
108 :    
(* wrap a statement in a loop over all of the strands:
 *      for (unsigned int i = 0; i < wrld->numStrands; i++) stm
 *)
fun forStrands stm = let
      val init = [(CL.uint32, "i", mkInt 0)]
      val test = CL.mkBinOp(iV, CL.#<, CL.mkIndirect(wrldV, "numStrands"))
      val step = [CL.mkPostOp(iV, CL.^++)]
      in
        CL.mkFor(init, test, step, stm)
      end
117 :    
(* guard a statement inside a strand loop so that it only runs for stable strands:
 *      if (wrld->status[i] == DIDEROT_STABLE)
 *          stm
 *)
fun ifStable stm = let
      val status = CL.mkSubscript(CL.mkIndirect(wrldV, "status"), iV)
      val isStable = CL.mkBinOp(status, CL.#==, DIDEROT_STABLE)
      in
        CL.mkIfThen(isStable, stm)
      end
125 :    
(* guard a statement inside a strand loop so that it only runs for strands that have
 * not died; note that NEW strands are considered active.
 *      if (wrld->status[i] != DIDEROT_DIE)
 *          stm
 *)
fun ifActive stm = let
      val status = CL.mkSubscript(CL.mkIndirect(wrldV, "status"), iV)
      val notDead = CL.mkBinOp(status, CL.#!=, DIDEROT_DIE)
      in
        CL.mkIfThen(notDead, stm)
      end
133 :    
(* generate the statements that initialize the axis kinds of a nrrd.  The data axis
 * (axis[0]) gets dataAxisKind, unless that kind is scalar, in which case (by convention)
 * there is no data axis and the domain axes start at index 0.  The nAxes domain axes
 * all get domAxisKind.
 *)
fun initAxisKinds (nrrd, dataAxisKind, nAxes, domAxisKind) = let
    (* nrrd->axis[axis].kind = <enum for kind> *)
      fun setKind (axis, kind) = CL.mkAssign (
            CL.mkSelect(CL.mkSubscript(CL.mkIndirect(nrrd, "axis"), mkInt axis), "kind"),
            CL.mkVar(Nrrd.kindToEnum kind))
    (* assignments for the domain axes, where `first` is the index of the first one *)
      fun domAxes first = List.tabulate(nAxes, fn i => setKind(first + i, domAxisKind))
      in
        case dataAxisKind
         of Nrrd.KindScalar => domAxes 0
          | _ => setKind(0, dataAxisKind) :: domAxes 1
        (* end case *)
      end
148 :    
(* create the parameters and body of an output function for dynamic-size outputs.
 * The arguments are:
 *    tgt       - the target properties
 *    snapshot  - true when generating the snapshot version of the function, which
 *                outputs all active strands instead of just the stable ones
 *    nAxes     - SOME n for a grid of strands with n axes; NONE for a collection
 *    ty        - the element type of the output sequence
 *    name      - the name of the strand-state field that holds the sequence
 * The result is the pair of the function's extra parameters (nLengths and nData) and
 * its body.  The structure of the function body is:
 *
 *      declarations
 *      compute sizes array for nLengths
 *      allocate nrrd for nLengths
 *      compute sizes array for nData
 *      allocate nrrd for nData
 *      copy data from strands to nrrd
 *)
fun genDynOutput (tgt, snapshot, nAxes, ty, name) = let
      val (elemCTy, nrrdType, axisKind, nElems) = U.infoOf (tgt, ty)
      val stateVar = stateVar tgt
      val (nAxes, domAxisKind) = (case nAxes
             of NONE => (1, Nrrd.KindList)
              | SOME n => (n, Nrrd.KindSpace)
            (* end case *))
    (* select the strands that contribute to the output of a collection.  The same
     * filter must be used for the counting loop and for both copy loops; otherwise
     * the copy loops write a different number of entries than were allocated.
     *)
      fun ifOutput stm = if snapshot then ifActive stm else ifStable stm
    (* for grids every strand contributes, so no filter is needed *)
      fun perStrand stm = if #isArray tgt then stm else ifOutput stm
    (* declarations *)
      val sizesDecl = CL.mkDecl(CL.T_Array(sizeTy, SOME(nAxes+1)), "sizes", NONE)
    (* count number of elements (and stable strands) *)
      val countElems = let
            val nElemsInit = CL.mkDeclInit(CL.uint32, "numElems", mkInt 0)
            val cntElems = CL.S_Exp(CL.mkAssignOp(numElemsV, CL.+=, seqLength(stateVar name)))
            in
              if #isArray tgt
                then [
                    CL.mkComment["count number of elements"],
                    nElemsInit, forStrands cntElems
                  ]
                else let
                  val cntBlk = CL.mkBlock[cntElems, CL.S_Exp(CL.mkPostOp(numStableV, CL.^++))]
                  in [
                    CL.mkComment["count number of output elements and stable strands"],
                    CL.mkDeclInit(CL.uint32, "numStable", mkInt 0),
                    nElemsInit,
                    forStrands (ifOutput cntBlk)
                  ] end
            end
    (* code to check for zero outputs, which happens for collections with no active strands *)
      val checkForNoStrands = if #isArray tgt
            then []
            else [
                CL.mkComment["check for no output"],
                CL.mkIfThen(
                  CL.mkBinOp(mkInt 0, CL.#==, numStableV),
                  CL.mkBlock[
                      CL.mkCall("nrrdEmpty", [nLengthsV]),
                      CL.mkCall("nrrdEmpty", [nDataV]),
                      CL.mkReturn(SOME(CL.mkVar "false"))
                    ])
              ]
    (* generate code to allocate the nLengths nrrd *)
      val lengthsNrrd = let
            val dimSizes = setSizes(0, mkInt 2)  (* nLengths is 2-element vector *)
            in
              CL.mkComment["allocate nLengths nrrd"] ::
              (if #isArray tgt
                then dimSizes ::
                  List.tabulate (nAxes, fn i =>
                    setSizes(i+1, CL.mkSubscript(CL.mkIndirect(wrldV, "size"), mkInt(nAxes-i-1)))) @
                  [U.maybeAlloc (nLengthsV, Nrrd.tyToEnum Nrrd.TypeInt, nAxes+1)]
                else [
                    dimSizes, setSizes(1, numStableV),
                    U.maybeAlloc (nLengthsV, Nrrd.tyToEnum Nrrd.TypeInt, 2)
                  ])
            end
    (* code to check for no data to output (i.e., all of the output sequences are empty) *)
      val checkForEmpty = [
              CL.mkComment["check for empty output"],
              CL.mkIfThen(
                CL.mkBinOp(mkInt 0, CL.#==, numElemsV),
                CL.mkBlock[
                    CL.mkCall("nrrdEmpty", [nDataV]),
                    CL.mkReturn(SOME(CL.mkVar "false"))
                  ])
            ]
    (* generate code to allocate the data nrrd *)
      val dataNrrd = if (axisKind = Nrrd.KindScalar)
            then [ (* drop data axis for scalar data by convention *)
                CL.mkComment["allocate nData nrrd"],
                setSizes(0, numElemsV),
                U.maybeAlloc (nDataV, Nrrd.tyToEnum nrrdType, 1)
              ]
            else [
                CL.mkComment["allocate nData nrrd"],
                setSizes(0, mkInt nElems),
                setSizes(1, numElemsV),
                U.maybeAlloc (nDataV, Nrrd.tyToEnum nrrdType, 2)
              ]
    (* generate the nLengths copy code: for each output strand, store the pair
     * (offset, length) and advance the running offset by the length.
     *)
      val copyLengths = let
            val pInit = CL.mkDeclInit(CL.T_Ptr CL.uint32, "ip",
                  CL.mkCast(CL.T_Ptr(CL.uint32), CL.mkIndirect(nLengthsV, "data")))
            val offsetDecl = CL.mkDeclInit(CL.uint32, "offset", mkInt 0)
            val copyBlk = CL.mkBlock[
                    CL.mkDeclInit(CL.uint32, "n", seqLength(stateVar name)),
                    CL.mkAssign(CL.mkUnOp(CL.%*, CL.mkPostOp(ipV, CL.^++)), offsetV),
                    CL.mkAssign(CL.mkUnOp(CL.%*, CL.mkPostOp(ipV, CL.^++)), nV),
                    CL.S_Exp(CL.mkAssignOp(offsetV, CL.+=, nV))
                  ]
            in
              CL.mkComment["initialize nLengths nrrd"] ::
              pInit ::
              offsetDecl ::
              forStrands (perStrand copyBlk) ::
              initAxisKinds (nLengthsV, Nrrd.Kind2Vector, nAxes, domAxisKind)
            end
    (* generate the nData copy code; the result of the sequence-copy call is assigned
     * back to the destination pointer `cp` that the next strand's copy uses.
     *)
      val copyData = let
            val pInit = CL.mkDeclInit(CL.charPtr, "cp",
                  CL.mkCast(CL.charPtr, CL.mkIndirect(nDataV, "data")))
            val copyStm = CL.mkAssign(cpV, seqCopy(
                  CL.mkBinOp(mkInt nElems, CL.#*, CL.mkSizeof(elemCTy)), cpV, stateVar name))
            in
              CL.mkComment["initialize nData nrrd"] ::
              pInit ::
              forStrands (perStrand copyStm) ::
              initAxisKinds (nDataV, axisKind, 1, Nrrd.KindList)
            end
    (* the function body *)
      val stms =
            sizesDecl ::
            countElems @
            checkForNoStrands @
            lengthsNrrd @
            checkForEmpty @
            dataNrrd @
            copyLengths @
            copyData @
            [CL.mkReturn(SOME(CL.mkVar "false"))]
      in
        ([CL.PARAM([], nrrdPtrTy, "nLengths"), CL.PARAM([], nrrdPtrTy, "nData")], CL.mkBlock stms)
      end
295 :    
(* create the parameters and body of an output function for fixed-size outputs.
 * The arguments are:
 *    tgt       - the target properties
 *    snapshot  - true when generating the snapshot version of the function, which
 *                outputs all active strands instead of just the stable ones
 *    nAxes     - SOME n for a grid of strands with n axes; NONE for a collection
 *    ty        - the type of the output
 *    name      - the name of the strand-state field that holds the output
 * The result is the pair of the function's extra parameter (nData) and its body.
 * The structure of the function body is:
 *
 *      declare and compute sizes array
 *      allocate nrrd nData
 *      copy data from strands to nrrd
 *)
fun genFixedOutput (tgt, snapshot, nAxes, ty, name) = let
      val (elemCTy, nrrdType, axisKind, nElems) = U.infoOf (tgt, ty)
      val stateVar = stateVar tgt
      val isGrid = #isArray tgt
    (* scalar outputs have no data axis (by convention) *)
      val hasDataAxis = (axisKind <> Nrrd.KindScalar)
      val nDataAxes = if hasDataAxis then 1 else 0
      val (nDomAxes, domAxisKind) = (case nAxes
             of SOME n => (n, Nrrd.KindSpace)
              | NONE => (1, Nrrd.KindList)
            (* end case *))
    (* total number of axes in the nData nrrd *)
      val nrrdDim = nDomAxes + nDataAxes
    (* for collections, restrict a per-strand statement to the strands being output *)
      fun perStrand stm = if isGrid
            then stm
            else if snapshot
              then ifActive stm
              else ifStable stm
    (* generate the sizes initialization code *)
      val initSizes = let
            val sizesDcl = CL.mkDecl(CL.T_Array(sizeTy, SOME nrrdDim), "sizes", NONE)
            val declSizes = if hasDataAxis
                  then [sizesDcl, setSizes(0, mkInt nElems)]
                  else [sizesDcl]
            in
              if isGrid
                then let
                (* the grid axes are taken from wrld->size in reverse order *)
                  fun gridSize i = setSizes(
                        i + nDataAxes,
                        CL.mkSubscript(CL.mkIndirect(wrldV, "size"), mkInt(nDomAxes - i - 1)))
                  in
                    declSizes @ List.tabulate (nDomAxes, gridSize)
                  end
                else List.concat [
                    [ CL.mkDeclInit(sizeTy, "numStable", mkInt 0),
                      forStrands (perStrand (CL.S_Exp(CL.mkPostOp(numStableV, CL.^++))))
                    ],
                    declSizes,
                    [setSizes(nDataAxes, numStableV)]
                  ]
            end
    (* code to check for no data to output (i.e., no active strands) *)
      val checkForEmpty = if isGrid
            then []
            else let
              val noStrands = CL.mkBinOp(mkInt 0, CL.#==, numStableV)
              val bailOut = CL.mkBlock[
                      CL.mkCall("nrrdEmpty", [nDataV]),
                      CL.mkReturn(SOME(CL.mkVar "false"))
                    ]
              in
                [CL.mkComment["check for empty output"], CL.mkIfThen(noStrands, bailOut)]
              end
    (* generate the copy code *)
      val copyCode = let
            val cpDecl = CL.mkDeclInit(CL.charPtr, "cp",
                  CL.mkCast(CL.charPtr, CL.mkIndirect(nDataV, "data")))
          (* number of bytes per output item *)
            val nBytes = CL.mkBinOp(mkInt nElems, CL.#*, CL.mkSizeof elemCTy)
            val copyItem = CL.mkBlock[
                    CL.mkCall("memcpy", [cpV, CL.mkUnOp(CL.%&, stateVar name), nBytes]),
                    CL.mkExpStm(CL.mkAssignOp(cpV, CL.+=, nBytes))
                  ]
            in
              cpDecl ::
              forStrands (perStrand copyItem) ::
              initAxisKinds (nDataV, axisKind, nDomAxes, domAxisKind)
            end
    (* the function body *)
      val body = List.concat [
              [CL.mkComment["Compute sizes of nrrd file"]],
              initSizes,
              checkForEmpty,
              [ CL.mkComment["Allocate nData nrrd"],
                U.maybeAlloc (nDataV, Nrrd.tyToEnum nrrdType, nrrdDim),
                CL.mkComment["copy data to output nrrd"]
              ],
              copyCode,
              [CL.mkReturn(SOME(CL.mkVar "false"))]
            ]
      in
        ([CL.PARAM([], nrrdPtrTy, "nData")], CL.mkBlock body)
      end
382 :    
(* gen (tgt, nAxes) outputs
 * returns the declarations of the functions that get the output nrrds for the given
 * list of (type, name) output variables.  When the target's `exec` property is set,
 * the declarations produced by U.genOutput are appended; otherwise, when its `snapshot`
 * property is set, the snapshot versions of the get functions are placed in front.
 *)
fun gen (tgt : Properties.props, nAxes) = let
    (* build the get function for one output; isSnapshot selects the snapshot version *)
      fun mkGetter isSnapshot (ty, name) = let
            val cName = if isSnapshot
                  then N.snapshotGet(tgt, name)
                  else N.outputGet(tgt, name)
          (* name of the strand-state field for the output *)
            val field = "sv_" ^ name
            val (params, body) = (case ty
                   of Ty.DynSeqTy elemTy => genDynOutput (tgt, isSnapshot, nAxes, elemTy, field)
                    | _ => genFixedOutput (tgt, isSnapshot, nAxes, ty, field)
                  (* end case *))
            in
              CL.D_Func([], CL.boolTy, cName, CL.PARAM([], wrldPtr tgt, "wrld") :: params, body)
            end
      in
        fn outputs => let
          val outputFns = List.map (mkGetter false) outputs
          in
            if (#exec tgt)
              then outputFns @ U.genOutput(tgt, outputs)
            else if (#snapshot tgt)
              then List.map (mkGetter true) outputs @ outputFns
              else outputFns
          end
      end
409 :    
410 :     end

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0