Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] Annotation of /branches/vis12/src/compiler/high-to-mid/probe.sml
ViewVC logotype

Annotation of /branches/vis12/src/compiler/high-to-mid/probe.sml

Parent Directory Parent Directory | Revision Log Revision Log


Revision 3291 - (view) (download)

1 : jhr 328 (* probe.sml
2 :     *
3 : jhr 3291 * This code is part of the Diderot Project (http://diderot-language.cs.uchicago.edu)
4 :     *
5 :     * COPYRIGHT (c) 2015 The University of Chicago
6 : jhr 328 * All rights reserved.
7 :     *
8 :     * Expansion of probe operations in the HighIL to MidIL translation.
9 :     *)
10 :    
11 :     structure Probe : sig
12 :    
(* expand a probe operation into a list of MidIL assignments *)
val expand : {
        (* variable that receives the result of the probe *)
        result : MidIL.var,
        (* the image argument of the probe *)
        img : MidIL.var,
        (* summary information about the image *)
        v : ImageInfo.info,
        (* the reconstruction kernel *)
        h : Kernel.kernel,
        (* the number of levels of differentiation *)
        k : int,
        (* the position argument of the probe *)
        pos : MidIL.var,
        (* border-control used for image indexing *)
        ctl : MidIL.var BorderCtl.ctl
      } -> MidIL.assign list
22 : jhr 328
23 :     end = struct
24 :    
(* abbreviations for the source (HighIL) side of the translation *)
structure SrcIL = HighIL
structure SrcOp = HighOps
structure VMap = SrcIL.Var.Map

(* abbreviations for the destination (MidIL) side of the translation *)
structure DstIL = MidIL
structure DstOp = MidOps
structure DstTy = MidILTypes
structure DstV = DstIL.Var

(* other abbreviations *)
structure BCtl = BorderCtl
structure IT = Shape
34 : jhr 328
(* create a fresh variable of type ty whose name is prefix followed by the
 * string form of the axis for dimension d.
 *)
fun newVar_dim (prefix, d, ty) = let
      val axisName = Partials.axisToString (Partials.axis d)
      in
        DstV.new (prefix ^ axisName, ty)
      end
38 : jhr 328
(* helpers for building assignments; each returns a (lhs, rhs) pair *)

(* lhs = rator(args) *)
fun assign (lhs, rator, args) = let
      val rhs = DstIL.OP(rator, args)
      in
        (lhs, rhs)
      end

(* lhs = CONS of elems, tagged with the type of lhs *)
fun cons (lhs, elems) = let
      val ty = DstV.ty lhs
      in
        (lhs, DstIL.CONS(ty, elems))
      end

(* lhs = the floating-point literal with integer value n *)
fun realLit (lhs, n) =
      (lhs, (DstIL.LIT o Literal.Float o FloatLit.fromInt o IntInf.fromInt) n)

(* lhs = the integer literal n *)
fun intLit (lhs, n) =
      (lhs, (DstIL.LIT o Literal.Int o IntInf.fromInt) n)
43 :    
(* generate code for evaluating a single element of a probe operation.  The first
 * (record) argument describes the probe context; the second argument pairs the
 * variable that receives the element with the partial-derivative operator that
 * selects it.  The result is the code that evaluates the kernel coefficients
 * followed by the code that reduces the voxels against those coefficients.
 *)
fun probeElem {
        dim,            (* dimension of space *)
        h, s,           (* kernel h with support s *)
        n, f,           (* Dst vars for integer and fractional components of position;
                         * n is not referenced in this function, but it is part of the
                         * record that callers pass.
                         *)
        voxIter         (* iterator over voxels *)
      } (result, pdOp) = let
      val vecsTy = DstTy.vecTy(2*s) (* vectors of coefficients cover support of kernel *)
    (* the list carried by the partial-derivative operator; below it is consumed one
     * entry per axis, where the entry is the differentiation level for that axis.
     *)
      val Partials.D l = pdOp
    (* generate the variables that hold the convolution coefficients.  The
     * resulting list is in slowest-to-fastest axes order.
     *)
      val convCoeffs = let
            fun mkVar (_, [], coeffs) = coeffs
              | mkVar (i, d::dd, coeffs) = let
                (* the variable name reflects the differentiation level: h, dh, d2h, ... *)
                  val prefix = (case d
                         of 0 => "h"
                          | 1 => "dh"
                          | _ => concat["d", Int.toString d, "h"]
                        (* end case *))
                  in
                    mkVar (i+1, dd, newVar_dim(prefix, i, vecsTy) :: coeffs)
                  end
            in
              mkVar (0, l, [])
            end
    (* for each dimension in space, we evaluate the kernel at the coordinates for that axis.
     * the coefficients are
     *    h_{s-i} (f - i) for 1-s <= i <= s
     *)
      val coeffCode = let
            fun gen (x, k, (d, code)) = let
                (* note that for 1D images, the f vector is a scalar *)
                  val fd = if (dim > 1)
                        then newVar_dim ("f", d, DstTy.realTy)
                        else f
                  val a = DstV.new ("a", vecsTy)
                (* note that we reverse the order of the list since the convolution
                 * space is flipped from the image space and we want the voxel vector
                 * to be in increasing address order.
                 *)
                  val tmps = List.rev(List.tabulate(2*s,
                        fn i => (DstV.new("t"^Int.toString i, DstTy.realTy), i - s)))
                (* emit "t = fd + n": a plain copy when n = 0, otherwise an Add or Sub
                 * of the literal |n|; the literal is defined before it is used.
                 *)
                  fun mkArg ((t, 0), code) = (t, DstIL.VAR fd) :: code
                    | mkArg ((t, n), code) = let
                        val (rator, n) = if (n < 0) then (DstOp.Sub, ~n) else (DstOp.Add, n)
                        val t' = DstV.new ("r", DstTy.realTy)
                        in
                          realLit (t', n) ::
                          assign (t, rator DstTy.realTy, [fd, t']) ::
                          code
                        end
                (* the code is assembled back to front: first the kernel evaluation,
                 * then the argument computations are prepended, and finally (for
                 * dim > 1) the extraction of fd from the f vector.
                 *)
                  val code =
                        cons(a, List.map #1 tmps) ::
                        assign(x, DstOp.EvalKernel(2*s, h, k), [a]) ::
                        code
                  val code = List.foldr mkArg code tmps
                  val code = if (dim > 1)
                        then assign(fd, DstOp.Index(DstTy.vecTy dim, d), [f]) :: code
                        else code
                  in
                    (d+1, code)
                  end
            in
            (* we iterate from fastest to slowest axis *)
              #2 (ListPair.foldr gen (0, []) (convCoeffs, List.rev l))
            end
    (* generate the reduction code in reverse order *)
      fun genReduce (result, [hh], IT.LF{vox, offsets=_}, code) =
            assign (result, DstOp.Dot(2*s), [vox, hh]) :: code
        | genReduce (result, hh::r, IT.ND(_, kids), code) = let
            val tv = DstV.new ("tv", vecsTy)
            val tmps = List.tabulate(2*s, fn i => DstV.new("t"^Int.toString i, DstTy.realTy))
            fun lp ([], [], code) = code
              | lp (t::ts, kid::kids, code) = genReduce(t, r, kid, lp(ts, kids, code))
              (* the number of kids must equal the number of temporaries (2*s) *)
              | lp _ = raise Fail "genReduce: arity mismatch"
            val code = cons(tv, tmps) :: assign(result, DstOp.Dot(2*s), [hh, tv]) :: code
            in
              lp (tmps, kids, code)
            end
        | genReduce _ = raise Fail "genReduce"
      val reduceCode = genReduce (result, convCoeffs, voxIter, [])
      in
        coeffCode @ reduceCode
      end
126 : jhr 349
(* generate the statements that load 2*s voxels into the vector variable vox.  The
 * last argument is the border control, which selects how the load is generated:
 *   BCtl.None      -- one address computation followed by one 2*s-wide load
 *   BCtl.Default _ -- not supported; raises Fail
 *   BCtl.Index _   -- one address computation (with control) and one single-voxel
 *                     load per voxel, followed by a CONS that builds vox
 *)
fun loadVoxelVec (vox, v, offset, s, img, indices, BCtl.None) = let
      val a = DstV.new ("a", DstTy.AddrTy v)
      in [
        assign(a, DstOp.VoxelAddress(v, offset), img::indices),
        assign(vox, DstOp.LoadVoxels(v, 2*s), [a])
      ] end
  | loadVoxelVec (_, _, _, _, _, _, BCtl.Default _) = raise Fail "Border control default"
  | loadVoxelVec (vox, v, offset, s, img, indices, BCtl.Index ictl) = let
      val numVoxels = 2*s
    (* split the indices into a reverse-order prefix and the last index *)
      val (revIndices, baseIdx) = let
            fun split ([], _) = raise Fail "loadVoxelVec: empty index list"
              | split ([i], prefix) = (prefix, i)
              | split (idx::idxs, prefix) = split (idxs, idx::prefix)
            in
              split (indices, [])
            end
    (* generate the load of the i'th voxel; voxs and stms are accumulated in
     * reverse order and are put in order once all of the voxels have been loaded.
     *)
      fun mk (i, voxs, stms) = if (i < numVoxels)
            then let
            (* the first voxel uses the indices as given; for the others the last
             * index is replaced by baseIdx + i.
             *)
              val (indices, stms) = if (i = 0)
                    then (indices, stms)
                    else let
                      val idx = DstV.new ("idx", DstTy.IntTy)
                      val n = DstV.new ("n", DstTy.IntTy)
                      val stms = assign(idx, DstOp.Add DstTy.IntTy, [baseIdx, n]) ::
                            intLit(n, i) :: stms
                      in
                        (List.revAppend(revIndices, [idx]), stms)
                      end
              val a = DstV.new ("a", DstTy.AddrTy v)
              val addrStm = assign(a, DstOp.VoxelAddressWithCtl(v, offset, ictl), img::indices)
              val vox = DstV.new ("vox", DstTy.realTy)
              val loadStm = assign(vox, DstOp.LoadVoxels(v, 1), [a])
              in
                mk (i+1, vox::voxs, loadStm :: addrStm :: stms)
              end
          (* here vox is the function's parameter, i.e., the result vector *)
            else List.rev(cons(vox, List.rev voxs) :: stms)
      in
        mk (0, [], [])
      end
166 : jhr 3072
(* generate the code for probing one sample of the image: the code that loads the
 * required voxels (using voxel offset `offset` and border control `ctl`) followed by
 * the code that evaluates the probe into result.  When k > 0 the probe is evaluated
 * into a temporary that is then mapped to result by TensorToWorldSpace.
 *)
fun doVoxelSample (result, v, k, s, diffIter, {h, n, f, img}, offset, ctl) = let
      val dim = ImageInfo.dim v
      val vecsTy = DstTy.vecTy(2*s) (* vectors of coefficients cover support of kernel *)
    (* generate code to load the voxel data; since we use a vector load operation to load the
     * fastest dimension, the height of the tree is one less than the dimension of space.
     *)
      val voxIter = let
          (* NOTE(review): the roles of these two callbacks are inferred from their use as
           * the 4th and 5th arguments of IT.create -- confirm against Shape.create.
           *)
            fun extend (i, (offsets, id)) = (i - (s - 1) :: offsets, i::id)
            fun mkLeaf (offsets, id) = {
                    offsets = ~(s-1) :: offsets,
                    vox = DstV.new(String.concat("v" :: List.map Int.toString id), vecsTy)
                  }
            in
              IT.create (dim-1, 2*s, fn _ => (), extend, mkLeaf, ([], []))
            end
      val loadCode = let
            fun genCode ({offsets, vox}, code) = let
                (* for each voxel offset, generate an index variable that is the sum of
                 * the offset and the corresponding component of n (or n itself in 1D).
                 *)
                  fun computeIndices (_, []) = ([], [])
                    | computeIndices (i, off::offs) = let
                        val index = newVar_dim("i", i, DstTy.intTy)
                        val t1 = DstV.new ("t1", DstTy.intTy)
                        val t2 = DstV.new ("t2", DstTy.intTy)
                        val (indices, code) = computeIndices (i+1, offs)
                        val code = if (dim > 1)
                              then
                                intLit(t1, off) ::
                                assign(t2, DstOp.Index(DstTy.iVecTy dim, i), [n]) ::
                                assign(index, DstOp.Add(DstTy.intTy), [t1, t2]) ::
                                code
                              else
                                intLit(t1, off) ::
                                assign(index, DstOp.Add(DstTy.intTy), [t1, n]) ::
                                code
                        in
                          (index::indices, code)
                        end
                  val (indices, indicesCode) = computeIndices (0, offsets)
                (* note that offset is the voxel offset passed to doVoxelSample *)
                  val loadStms = loadVoxelVec (vox, v, offset, s, img, indices, ctl)
                  in
                    indicesCode @ loadStms @ code
                  end
            in
              IT.foldr genCode [] voxIter
            end
    (* generate code to evaluate and construct the result tensor *)
      val probeElem = probeElem {dim = dim, h = h, s = s, n = n, f = f, voxIter = voxIter}
      fun genProbe (result, IT.ND(_, kids as (IT.LF _)::_), code) = let
          (* the kids will all be leaves *)
            fun genProbeCode (IT.LF arg, code) = probeElem arg @ code
            fun getProbeVar (IT.LF(t, _)) = t
            in
              List.foldr genProbeCode (cons (result, List.map getProbeVar kids) :: code) kids
            end
        | genProbe (result, IT.ND(ty, kids), code) = let
(* FIXME: the type of the tmps depends on the types of the kids *)
            val tmps = List.tabulate(dim, fn i => DstV.new("t"^Int.toString i, ty))
            val code = cons(result, tmps) :: code
            fun lp ([], [], code) = code
              | lp (t::ts, kid::kids, code) = genProbe(t, kid, lp(ts, kids, code))
            in
              lp (tmps, kids, code)
            end
        | genProbe (result, IT.LF(t, pdOp), code) = (* for scalar fields *)
            probeElem (result, pdOp) @ code
      val probeCode = if (k > 0)
            then let
            (* for gradients, etc. we have to transform back to world space *)
              val ty = DstV.ty result
              val tensor = DstV.new("tensor", ty)
              val xform = assign(result, DstOp.TensorToWorldSpace(v, ty), [img, tensor])
              in
                genProbe (tensor, diffIter, [xform])
              end
            else genProbe (result, diffIter, [])
      in
(* FIXME: for dim > 1 and k > 1, we need to transform the result back into world space *)
        loadCode @ probeCode
      end
248 : jhr 328
(* generate code for probing the field (D^k (v * h)) at pos.  The generated code first
 * maps pos into image space and splits it into integer (n) and fractional (f) parts,
 * and then samples the image: once when the voxel shape is empty, and otherwise once
 * per voxel component, with the samples assembled into result by CONS operations.
 *)
fun expand {result, img, v, h, k, pos, ctl} = let
      val dim = ImageInfo.dim v
      val s = Kernel.support h
      val vecDimTy = DstTy.vecTy dim
    (* generate the transform code *)
      val x = DstV.new ("x", vecDimTy)        (* image-space position *)
      val f = DstV.new ("f", vecDimTy)        (* fractional part: x - nd *)
      val nd = DstV.new ("nd", vecDimTy)      (* floor of x *)
      val n = DstV.new ("n", DstTy.iVecTy dim) (* nd converted to integers *)
      val toImgSpaceCode = [
              assign(x, DstOp.PosToImgSpace v, [img, pos]),
              assign(nd, DstOp.Floor dim, [x]),
              assign(f, DstOp.Sub vecDimTy, [x, nd]),
              assign(n, DstOp.RealToInt dim, [nd])
            ]
    (* generate the shape of the differentiation tensor with variables representing
     * the elements
     *)
      fun diffIter () = let
            val partial = Partials.partial dim
          (* NOTE(review): the roles of the callbacks are inferred from their positions
           * in the call to IT.create -- confirm against Shape.create.
           *)
            fun extend (i, (_::dd, axes)) = (dd, Partials.axis i :: axes)
            fun labelNd (_::dd, _) = DstTy.tensorTy dd
            fun labelLf (_, axes) = let
                  val r = DstV.new(
                        String.concat("r" :: List.map Partials.axisToString axes),
                        DstTy.realTy)
                  in
                    (r, partial axes)
                  end
            in
              IT.create (k, dim, labelNd, extend, labelLf, (List.tabulate(k, fn _ => dim), []))
            end
      val vars = {h=h, n=n, f=f, img=img}
      in
        case ImageInfo.voxelShape v
         of [] => toImgSpaceCode @ doVoxelSample (result, v, k, s, diffIter(), vars, 0, ctl)
          | shape => let
            (* the result will be a d1 x d2 matrix of k-order tensors; each element of the matrix
             * has the sampleTy = tensor[dim,...,dim].
             *)
              val sampleShape = List.tabulate(k, fn _ => dim)
            (* gen walks the voxel shape; at each level it creates one variable per
             * component and a CONS that builds y from them.  The voxel offset is
             * threaded through the fold and decremented after each sample.
             *)
              fun gen (y, d::dd, offset, code) = let
                    val prefix = DstV.name y
                    val xs = List.tabulate(d,
                          fn i => DstV.new(
                              concat[prefix, "_", Int.toString i],
                              DstTy.TensorTy(dd @ sampleShape)))
                    val code = cons(y, xs) :: code
                    in
                      List.foldr (fn (x, (off, cd)) => gen(x, dd, off, cd)) (offset, code) xs
                    end
                | gen (y, [], offset, code) =
                    (offset-1, doVoxelSample (y, v, k, s, diffIter(), vars, offset, ctl) @ code)
              val (_, code) = gen (result, shape, ImageInfo.stride v - 1, [])
              in
                toImgSpaceCode @ code
              end
        (* end case *)
      end
310 : jhr 328
311 :     end

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0