Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] Annotation of /branches/pure-cfg/src/compiler/high-to-mid/probe.sml
ViewVC logotype

Annotation of /branches/pure-cfg/src/compiler/high-to-mid/probe.sml

Parent Directory | Revision Log


Revision 517 - (view) (download)

1 : jhr 328 (* probe.sml
2 :     *
3 : jhr 435 * COPYRIGHT (c) 2010 The Diderot Project (http://diderot-language.cs.uchicago.edu)
4 : jhr 328 * All rights reserved.
5 :     *
6 :     * Expansion of probe operations in the HighIL to MidIL translation.
7 :     *)
8 :    
structure Probe : sig

  (* expand a probe of the field (D^k (v * h)) at position pos into a list of
   * MidIL assignments whose final effect is to define the result variable.
   *)
    val expand : {
            result : MidIL.var,         (* result variable for probe *)
            img : MidIL.var,            (* probe image argument *)
            v : ImageInfo.info,         (* summary info about image *)
            h : Kernel.kernel,          (* reconstruction kernel *)
            k : int,                    (* number of levels of differentiation *)
            pos : MidIL.var             (* probe position argument *)
          } -> MidIL.assign list

  end = struct

    structure SrcIL = HighIL
    structure SrcOp = HighOps
    structure DstIL = MidIL
    structure DstTy = MidILTypes
    structure DstOp = MidOps
    structure DstV = DstIL.Var
    structure VMap = SrcIL.Var.Map
    structure IT = Shape

  (* when true, dump the iteration trees and per-element information to stdout
   * while expanding probes; off by default so that compilation is quiet.
   *)
    val debug = false

  (* generate a new variable whose name is prefix followed by the name of axis d *)
    fun newVar_dim (prefix, d, ty) =
          DstV.new (prefix ^ Partials.axisToString(Partials.axis d), ty)

  (* constructors for MidIL assignments *)
    fun assign (x, rator, args) = (x, DstIL.OP(rator, args))
    fun cons (x, args) = (x, DstIL.CONS args)
    fun realLit (x, i) = (x, DstIL.LIT(Literal.Float(FloatLit.fromInt i)))
    fun intLit (x, i) = (x, DstIL.LIT(Literal.Int(IntInf.fromInt i)))

  (* debugging support: print title and then the tree, one node per line, indenting
   * each level by two more spaces; prLeaf renders a leaf's attribute (without the
   * trailing newline).
   *)
    fun printTree (title, prLeaf, tree) = let
          fun indent wid = print(CharVector.tabulate(wid, fn _ => #" "))
          fun pr wid (IT.ND(_, kids)) = (
                indent wid; print "ND\n";
                List.app (pr (wid + 2)) kids)
            | pr wid (IT.LF attr) = (indent wid; prLeaf attr; print "\n")
          in
            print title;
            pr 2 tree
          end

  (* generate code for evaluating a single element of a probe operation; the element
   * is the partial derivative pdOp of the convolution and its value is assigned
   * to result.
   *)
    fun probeElem {
          dim,          (* dimension of space *)
          h, s,         (* kernel h with support s *)
          n, f,         (* Dst vars for integer and fractional components of position *)
          voxIter       (* iterator over voxels *)
        } (result, pdOp) = let
          val vecsTy = DstTy.VecTy(2*s) (* vectors of coefficients cover support of kernel *)
        (* the level of differentiation along each axis (axis 0 first) *)
          val Partials.D diffLevels = pdOp
        (* generate the variables that hold the convolution coefficients; element i
         * of the list holds the coefficients for axis i.
         *)
          val convCoeffs = let
                fun coeffPrefix 0 = "h"
                  | coeffPrefix 1 = "dh"
                  | coeffPrefix lvl = concat["d", Int.toString lvl, "h"]
                fun mkVar (_, []) = []
                  | mkVar (i, lvl::lvls) = newVar_dim(coeffPrefix lvl, i, vecsTy) :: mkVar(i+1, lvls)
                in
                  mkVar (0, diffLevels)
                end
          val _ = if debug
                then print(concat[
                    "probeElem: ", Partials.partialToString pdOp, " in ",
                    Int.toString(List.length convCoeffs), "D space\n"
                  ])
                else ()
        (* for each axis, evaluate the (differentiated) kernel at the 2*s sample
         * offsets around the fractional position along that axis.
         *)
          val coeffCode = let
                fun gen (x, lvl, (d, code)) = let
                      val d = d-1
                      val fd = newVar_dim ("f", d, DstTy.realTy)
                      val a = DstV.new ("a", vecsTy)
                    (* pairs of (temp, offset) with offsets s-1 down to -s *)
                      val tmps = List.tabulate(2*s,
                            fn i => (DstV.new("t"^Int.toString i, DstTy.realTy), s - (i+1)))
                      fun mkArg ((t, offset), code) = let
                            val t' = DstV.new ("r", DstTy.realTy)
                            in
                              realLit (t', offset) ::
                              assign (t, DstOp.Add DstTy.realTy, [fd, t']) ::
                              code
                            end
                      val code =
                            cons(a, List.map #1 tmps) ::
                            assign(x, DstOp.EvalKernel(2*s, h, lvl), [a]) ::
                            code
                      val code =
                            assign(fd, DstOp.Select(DstTy.VecTy dim, d), [f]) ::
                            List.foldr mkArg code tmps
                      in
                        (d, code)
                      end
                in
                  #2 (ListPair.foldr gen (dim, []) (convCoeffs, diffLevels))
                end
        (* generate the reduction code.  The leaves of voxIter are vectors loaded along
         * axis 0 (their first offset is the vector-load offset), while the top level of
         * the tree corresponds to axis dim-1.  genReduce consumes one coefficient
         * vector per tree level, so it must be given the coefficients from the last
         * axis down to axis 0.
         *)
          fun genReduce (result, [hh], IT.LF{vox, offsets = _}, code) =
                assign (result, DstOp.Dot(2*s), [vox, hh]) :: code
            | genReduce (result, hh::r, IT.ND(_, kids), code) = let
                val tv = DstV.new ("tv", vecsTy)
                val tmps = List.tabulate(2*s, fn i => DstV.new("t"^Int.toString i, DstTy.realTy))
                fun lp ([], [], code) = code
                  | lp (t::ts, kid::kids, code) = genReduce(t, r, kid, lp(ts, kids, code))
                  | lp _ = raise Fail "genReduce: voxel tree width mismatch"
                val code = cons(tv, tmps) :: assign(result, DstOp.Dot(2*s), [hh, tv]) :: code
                in
                  lp (tmps, kids, code)
                end
            | genReduce _ = raise Fail "genReduce"
          val reduceCode = genReduce (result, List.rev convCoeffs, voxIter, [])
          in
            coeffCode @ reduceCode
          end

  (* generate code for probing the field (D^k (v * h)) at pos *)
    fun expand {result, img, v, h, k, pos} = let
          val ImageInfo.ImgInfo{dim, ty=([], ty), ...} = v
          val s = Kernel.support h
          val vecsTy = DstTy.VecTy(2*s) (* vectors of coefficients cover support of kernel *)
          val vecDimTy = DstTy.VecTy dim
        (* generate the transform code: map pos into image space (x) and split it
         * into its floor (nd, also as integers in n) and fractional part (f).
         *)
          val x = DstV.new ("x", vecDimTy) (* image-space position *)
          val f = DstV.new ("f", vecDimTy)
          val nd = DstV.new ("nd", vecDimTy)
          val n = DstV.new ("n", DstTy.IVecTy dim)
          val transformCode = [
                  assign(x, DstOp.PosToImgSpace dim, [img, pos]),
                  assign(nd, DstOp.Floor dim, [x]),
                  assign(f, DstOp.Sub vecDimTy, [x, nd]),
                  assign(n, DstOp.TruncToInt dim, [nd])
                ]
        (* generate the shape of the differentiation tensor with variables representing
         * the elements
         *)
          val diffIter = let
                val partial = Partials.partial dim
                fun addAxis (i, axes) = Partials.axis i :: axes
                fun mkElem axes = let
                      val r = DstV.new(
                            String.concat("r" :: List.map Partials.axisToString axes),
                            DstTy.realTy)
                      in
                        (r, partial axes)
                      end
                in
                  IT.create (k, dim, fn _ => (), addAxis, mkElem, [])
                end
          val _ = if debug
                then printTree ("diffIter:\n",
                  fn (r, pd) => print(concat[
                      "LF(", DstV.toString r, ", ", Partials.partialToString pd, ")"
                    ]),
                  diffIter)
                else ()
        (* generate code to load the voxel data; since we use a vector load operation to
         * load the fastest dimension, the height of the tree is one less than the
         * dimension of space.
         *)
          val voxIter = let
                fun addOffset (i, (offsets, id)) = (i - (s - 1) :: offsets, i::id)
                fun mkLeaf (offsets, id) = {
                        offsets = ~(s-1) :: offsets,
                        vox = DstV.new(String.concat("v" :: List.map Int.toString id), vecsTy)
                      }
                in
                  IT.create (dim-1, 2*s, fn _ => (), addOffset, mkLeaf, ([], []))
                end
          val _ = if debug
                then printTree ("voxIter:\n",
                  fn {offsets, vox} => (
                      print "LF{offsets = [";
                      print(String.concatWith "," (List.map Int.toString offsets));
                      print "], vox = "; print(DstV.toString vox); print "}"),
                  voxIter)
                else ()
          val loadCode = let
                fun genCode ({offsets, vox}, code) = let
                    (* index along axis i is n[i] + offsets[i] *)
                      fun computeIndices (_, []) = ([], [])
                        | computeIndices (i, offset::offsets) = let
                            val index = newVar_dim("i", i, DstTy.intTy)
                            val t1 = DstV.new ("t1", DstTy.intTy)
                            val t2 = DstV.new ("t2", DstTy.intTy)
                            val (indices, code) = computeIndices (i+1, offsets)
                            val code =
                                  intLit(t1, offset) ::
                                  assign(t2, DstOp.Select(DstTy.IVecTy dim, i), [n]) ::
                                  assign(index, DstOp.Add(DstTy.intTy), [t1, t2]) ::
                                  code
                            in
                              (index::indices, code)
                            end
                      val (indices, indicesCode) = computeIndices (0, offsets)
                      val a = DstV.new ("a", DstTy.AddrTy)
                      in
                        indicesCode @ [
                            assign(a, DstOp.VoxelAddress dim, img::indices),
                            assign(vox, DstOp.LoadVoxels(ty, 2*s), [a])
                          ] @ code
                      end
                in
                  IT.foldr genCode [] voxIter
                end
        (* generate code to evaluate and construct the result tensor *)
          val probeElem = probeElem {dim = dim, h = h, s = s, n = n, f = f, voxIter = voxIter}
          fun genProbe (result, IT.ND(_, kids as (IT.LF _)::_), code) = let
              (* the kids will all be leaves *)
                fun genProbeCode (IT.LF arg, code) = probeElem arg @ code
                  | genProbeCode _ = raise Fail "genProbe: expected leaf"
                fun getProbeVar (IT.LF(t, _)) = t
                  | getProbeVar _ = raise Fail "genProbe: expected leaf"
                in
                  List.foldr genProbeCode (cons (result, List.map getProbeVar kids) :: code) kids
                end
            | genProbe (result, IT.ND(_, kids), code) = let
                val tmps = List.tabulate(dim, fn i => DstV.new("t"^Int.toString i, DstTy.realTy))
                val code = cons(result, tmps) :: code
                fun lp ([], [], code) = code
                  | lp (t::ts, kid::kids, code) = genProbe(t, kid, lp(ts, kids, code))
                  | lp _ = raise Fail "genProbe: tensor width mismatch"
                in
                  lp (tmps, kids, code)
                end
            | genProbe (result, IT.LF(_, pdOp), code) = (* for scalar fields *)
                probeElem (result, pdOp) @ code
          val probeCode = genProbe (result, diffIter, [])
          in
          (* FIXME: for dim > 1 and k > 1, we need to transform the result back into world space *)
            transformCode @ loadCode @ probeCode
          end

  end

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0