
SCM Repository

[diderot] Annotation of /branches/vis12/src/compiler/high-to-mid/probe.sml

Revision 1166
Original Path: trunk/src/compiler/high-to-mid/probe.sml

(* probe.sml
 *
 * COPYRIGHT (c) 2010 The Diderot Project (http://diderot-language.cs.uchicago.edu)
 * All rights reserved.
 *
 * Expansion of probe operations in the HighIL to MidIL translation.
 *)

structure Probe : sig

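  (* expand a probe of the field (D^k (v * h)) at a given position into a list of MidIL assignments *)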
    val expand : {
            result : MidIL.var,    (* result variable for probe *)
            img : MidIL.var,       (* probe image argument *)
            v : ImageInfo.info,    (* summary info about image *)
            h : Kernel.kernel,     (* reconstruction kernel *)
            k : int,               (* number of levels of differentiation *)
            pos : MidIL.var        (* probe position argument *)
          } -> MidIL.assign list

  end = struct

    structure SrcIL = HighIL
    structure SrcOp = HighOps
    structure DstIL = MidIL
    structure DstTy = MidILTypes
    structure DstOp = MidOps
    structure DstV = DstIL.Var
    structure VMap = SrcIL.Var.Map
    structure IT = Shape

  (* generate a new variable indexed by dimension *)
    fun newVar_dim (prefix, d, ty) =
          DstV.new (prefix ^ Partials.axisToString(Partials.axis d), ty)
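  (* e.g., newVar_dim ("h", 0, ty) yields a fresh variable whose name is the prefix followed by
   * the printed name of axis 0 (presumably "hx", assuming axes print as x, y, z, ...).
   *)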

    fun assign (x, rator, args) = (x, DstIL.OP(rator, args))
    fun cons (x, args) = (x, DstIL.CONS(DstV.ty x, args))
    fun realLit (x, i) = (x, DstIL.LIT(Literal.Float(FloatLit.fromInt i)))
    fun intLit (x, i) = (x, DstIL.LIT(Literal.Int(IntInf.fromInt i)))

  (* generate code for evaluating a single element of a probe operation *)
    fun probeElem {
          dim,      (* dimension of space *)
          h, s,     (* kernel h with support s *)
          n, f,     (* Dst vars for integer and fractional components of position *)
          voxIter   (* iterator over voxels *)
        } (result, pdOp) = let
          val vecsTy = DstTy.vecTy(2*s)  (* vectors of coefficients cover support of kernel *)
          val vecDimTy = DstTy.vecTy dim
        (* generate the variables that hold the convolution coefficients.  The
         * resulting list is in slowest-to-fastest axis order.
         *)
          val convCoeffs = let
                val Partials.D l = pdOp
                fun mkVar (_, [], coeffs) = coeffs
                  | mkVar (i, d::dd, coeffs) = (case d
                       of 0 => mkVar(i+1, dd, newVar_dim("h", i, vecsTy) :: coeffs)
                        | 1 => mkVar(i+1, dd, newVar_dim("dh", i, vecsTy) :: coeffs)
                        | _ => mkVar(i+1, dd, newVar_dim(concat["d", Int.toString d, "h"], i, vecsTy) :: coeffs)
                      (* end case *))
                in
                  mkVar (0, l, [])
                end
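        (* each variable in convCoeffs will hold the 2s per-axis weights: kernel values for
         * d = 0, first-derivative values for d = 1, and d'th-derivative values otherwise.
         *)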
          val _ = Log.msg(concat["probeElem: ", Partials.partialToString pdOp, " in ", Int.toString(List.length convCoeffs), "D space\n"])
        (* for each dimension in space, we evaluate the kernel at the coordinates for that axis.
         * the coefficients are
         *      h_{s-i} (f - i)   for 1-s <= i <= s
         *)
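        (* For example, with a kernel of support s = 2 (e.g. Catmull-Rom), the argument vector
         * passed to EvalKernel along an axis is [f+1, f, f-1, f-2].
         *)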
          val coeffCode = let
                fun gen (x, k, (d, code)) = let
                    (* note that for 1D images, the f vector is a scalar *)
                      val fd = if (dim > 1)
                            then newVar_dim ("f", d, DstTy.realTy)
                            else f
                      val a = DstV.new ("a", vecsTy)
                    (* note that we reverse the order of the list since the convolution
                     * space is flipped from the image space and we want the voxel vector
                     * to be in increasing address order.
                     *)
                      val tmps = List.rev(List.tabulate(2*s,
                            fn i => (DstV.new("t"^Int.toString i, DstTy.realTy), i - s)))
                      fun mkArg ((t, 0), code) = (t, DstIL.VAR fd) :: code
                        | mkArg ((t, n), code) = let
                            val (rator, n) = if (n < 0) then (DstOp.Sub, ~n) else (DstOp.Add, n)
                            val t' = DstV.new ("r", DstTy.realTy)
                            in
                              realLit (t', n) ::
                              assign (t, rator DstTy.realTy, [fd, t']) ::
                              code
                            end
                      val code =
                            cons(a, List.map #1 tmps) ::
                            assign(x, DstOp.EvalKernel(2*s, h, k), [a]) ::
                            code
                      val code = List.foldr mkArg code tmps
                      val code = if (dim > 1)
                            then assign(fd, DstOp.Select(DstTy.vecTy dim, d), [f]) :: code
                            else code
                      in
                        (d+1, code)
                      end
                val Partials.D l = pdOp
                in
                (* we iterate from fastest to slowest axis *)
                  #2 (ListPair.foldr gen (0, []) (convCoeffs, List.rev l))
                end
        (* generate the reduction code in reverse order *)
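        (* the reduction contracts the loaded voxel vectors against the per-axis coefficient
         * vectors: each leaf dots a voxel vector with the fastest-axis coefficients, and each
         * interior node dots its axis's coefficients with the vector of its children's results.
         *)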
          fun genReduce (result, [hh], IT.LF{vox, offsets}, code) =
                assign (result, DstOp.Dot(2*s), [vox, hh]) :: code
            | genReduce (result, hh::r, IT.ND(_, kids), code) = let
                val tv = DstV.new ("tv", vecsTy)
                val tmps = List.tabulate(2*s, fn i => DstV.new("t"^Int.toString i, DstTy.realTy))
                fun lp ([], [], code) = code
                  | lp (t::ts, kid::kids, code) = genReduce(t, r, kid, lp(ts, kids, code))
                val code = cons(tv, tmps) :: assign(result, DstOp.Dot(2*s), [hh, tv]) :: code
                in
                  lp (tmps, kids, code)
                end
            | genReduce _ = raise Fail "genReduce"
          val reduceCode = genReduce (result, convCoeffs, voxIter, [])
          in
            coeffCode @ reduceCode
          end

    fun doVoxelSample (result, v, k, s, diffIter, {h, n, f, img}, offset) = let
          val stride = ImageInfo.stride
          val dim = ImageInfo.dim v
          val vecsTy = DstTy.vecTy(2*s)  (* vectors of coefficients cover support of kernel *)
        (* generate code to load the voxel data; since we use a vector load operation to load the
         * fastest dimension, the height of the tree is one less than the dimension of space.
         *)
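        (* e.g., for a 3D image and a kernel with support s = 2, the tree has 4 x 4 leaves, each
         * holding a 4-element vector of voxels loaded along the fastest axis.
         *)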
          val voxIter = let
                fun f (i, (offsets, id)) = (i - (s - 1) :: offsets, i::id)
                fun g (offsets, id) = {
                        offsets = ~(s-1) :: offsets,
                        vox = DstV.new(String.concat("v" :: List.map Int.toString id), vecsTy)
                      }
                in
                  IT.create (dim-1, 2*s, fn _ => (), f, g, ([], []))
                end
          val _ = let
                val indentWid = ref 2
                fun inc () = (indentWid := !indentWid + 2)
                fun dec () = (indentWid := !indentWid - 2)
                fun indent () = Log.msg(CharVector.tabulate(!indentWid, fn _ => #" "))
                fun nd () = (indent(); Log.msg "ND\n");
                fun lf {offsets, vox} = (
                      indent(); Log.msg "LF{offsets = ["; Log.msg(String.concatWith "," (List.map Int.toString offsets));
                      Log.msg "], vox = "; Log.msg(DstV.toString vox); Log.msg "}\n")
                fun pr (Shape.ND(attr, kids)) = (nd attr; inc(); List.app pr kids; dec())
                  | pr (Shape.LF attr) = lf attr
                in
                  Log.msg "voxIter:\n";
                  pr voxIter
                end
          val loadCode = let
                fun genCode ({offsets, vox}, code) = let
                      fun computeIndices (_, []) = ([], [])
                        | computeIndices (i, offset::offsets) = let
                            val index = newVar_dim("i", i, DstTy.intTy)
                            val t1 = DstV.new ("t1", DstTy.intTy)
                            val t2 = DstV.new ("t2", DstTy.intTy)
                            val (indices, code) = computeIndices (i+1, offsets)
                            val code = if (dim > 1)
                                  then
                                    intLit(t1, offset) ::
                                    assign(t2, DstOp.Select(DstTy.IVecTy dim, i), [n]) ::
                                    assign(index, DstOp.Add(DstTy.intTy), [t1, t2]) ::
                                    code
                                  else
                                    intLit(t1, offset) ::
                                    assign(index, DstOp.Add(DstTy.intTy), [t1, n]) ::
                                    code
                            val indices = index::indices
                            in
                              (indices, code)
                            end
                      val (indices, indicesCode) = computeIndices (0, offsets)
                      val a = DstV.new ("a", DstTy.AddrTy v)
                      in
                        indicesCode @ [
                            assign(a, DstOp.VoxelAddress(v, offset), img::indices),
                            assign(vox, DstOp.LoadVoxels(v, 2*s), [a])
                          ] @ code
                      end
                in
                  IT.foldr genCode [] voxIter
                end
        (* generate code to evaluate and construct the result tensor *)
          val probeElem = probeElem {dim = dim, h = h, s = s, n = n, f = f, voxIter = voxIter}
          fun genProbe (result, IT.ND(_, kids as (IT.LF _)::_), code) = let
              (* the kids will all be leaves *)
                fun genProbeCode (IT.LF arg, code) = probeElem arg @ code
                fun getProbeVar (IT.LF(t, _)) = t
                in
                  List.foldr genProbeCode (cons (result, List.map getProbeVar kids) :: code) kids
                end
            | genProbe (result, IT.ND(ty, kids), code) = let
              (* FIXME: the type of the tmps depends on the types of the kids *)
                val tmps = List.tabulate(dim, fn i => DstV.new("t"^Int.toString i, ty))
                val code = cons(result, tmps) :: code
                fun lp ([], [], code) = code
                  | lp (t::ts, kid::kids, code) = genProbe(t, kid, lp(ts, kids, code))
                in
                  lp (tmps, kids, code)
                end
            | genProbe (result, IT.LF(t, pdOp), code) =  (* for scalar fields *)
                probeElem (result, pdOp) @ code
          val probeCode = if (k > 0)
                then let
                (* for gradients, etc. we have to transform back to world space *)
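                (* derivatives are taken with respect to image-space coordinates, so the probed
                 * tensor must be mapped back to world space; for a gradient this amounts to
                 * multiplying by the inverse transpose of the image-to-world transform.
                 *)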
                  val ty = DstV.ty result
                  val tensor = DstV.new("tensor", ty)
                  val xform = assign(result, DstOp.TensorToWorldSpace(v, ty), [img, tensor])
                  in
                    genProbe (tensor, diffIter, [xform])
                  end
                else genProbe (result, diffIter, [])
          in
        (* FIXME: for dim > 1 and k > 1, we need to transform the result back into world space *)
            loadCode @ probeCode
          end

  (* generate code for probing the field (D^k (v * h)) at pos *)
    fun expand {result, img, v, h, k, pos} = let
          val dim = ImageInfo.dim v
          val s = Kernel.support h
          val vecsTy = DstTy.vecTy(2*s)  (* vectors of coefficients to cover support of kernel *)
          val vecDimTy = DstTy.vecTy dim
        (* generate the transform code *)
          val x = DstV.new ("x", vecDimTy)  (* image-space position *)
          val f = DstV.new ("f", vecDimTy)
          val nd = DstV.new ("nd", vecDimTy)
          val n = DstV.new ("n", DstTy.IVecTy dim)
          val toImgSpaceCode = [
                assign(x, DstOp.PosToImgSpace v, [img, pos]),
                assign(nd, DstOp.Floor dim, [x]),
                assign(f, DstOp.Sub vecDimTy, [x, nd]),
                assign(n, DstOp.RealToInt dim, [nd])
              ]
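        (* x is pos mapped into the image's index space; n and f are its integer and fractional
         * parts, so the probe loads voxels around n and weights them with kernel values
         * evaluated at offsets from f.
         *)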
        (* generate the shape of the differentiation tensor with variables representing
         * the elements
         *)
          val diffIter = let
                val partial = Partials.partial dim
                fun f (i, (_::dd, axes)) = (dd, Partials.axis i :: axes)
                fun labelNd (_::dd, _) = DstTy.tensorTy dd
                fun labelLf (_, axes) = let
                      val r = DstV.new(
                            String.concat("r" :: List.map Partials.axisToString axes),
                            DstTy.realTy)
                      in
                        (r, partial axes)
                      end
                in
                  IT.create (k, dim, labelNd, f, labelLf, (List.tabulate(k, fn _ => dim), []))
                end
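        (* e.g., for k = 1 in 3D space, diffIter is a single node with three leaves, one per axis,
         * each pairing a fresh result variable with the first partial derivative along that axis;
         * these become the components of the (image-space) gradient.
         *)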
          val _ = let
                val indentWid = ref 2
                fun inc () = (indentWid := !indentWid + 2)
                fun dec () = (indentWid := !indentWid - 2)
                fun indent () = Log.msg(CharVector.tabulate(!indentWid, fn _ => #" "))
                fun nd ty = (indent(); Log.msg(concat["ND(", DstTy.toString ty, ")\n"]))
                fun lf (x, partial) = (
                      indent(); Log.msg(concat["LF(", DstV.toString x, ", ", Partials.partialToString partial, ")\n"]))
                fun pr (Shape.ND(attr, kids)) = (nd attr; inc(); List.app pr kids; dec())
                  | pr (Shape.LF attr) = lf attr
                in
                  Log.msg "diffIter:\n";
                  pr diffIter
                end
          val vars = {h=h, n=n, f=f, img=img}
          in
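          (* an image with scalar samples is probed once; an image whose voxels hold d samples is
           * probed once per sample offset and the results are combined into a vector with CONS.
           *)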
            case ImageInfo.voxelShape v
             of [] => toImgSpaceCode @ doVoxelSample (result, v, k, s, diffIter, vars, 0)
              | [d] => let
                  fun doSamples (offset, xs, code) = if (offset < 0)
                        then code @ [cons(result, xs)]
                        else let
                          val res = DstV.new ("probe" ^ Int.toString offset, DstTy.realTy)
                          val code = doVoxelSample (res, v, k, s, diffIter, vars, offset) @ code
                          in
                            doSamples (offset-1, res::xs, code)
                          end
                  in
                    toImgSpaceCode @ doSamples (d-1, [], [])
                  end
              | _ => raise Fail "image data with order > 1 not supported yet"
            (* end case *)
          end

  end
