Annotation of /trunk/src/compiler/high-to-mid/probe.sml
Revision 374
(* probe.sml
 *
 * COPYRIGHT (c) 2010 The Diderot Project (http://diderot.cs.uchicago.edu)
 * All rights reserved.
 *
 * Expansion of probe operations in the HighIL to MidIL translation.
 *)

structure Probe : sig

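  (* expand (result, fld, pos) produces the MidIL assignments that bind result to the
   * value of the field fld probed at position pos.
   *)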
    val expand : MidIL.var * FieldDef.field_def * MidIL.var -> MidIL.assign list

  end = struct

    structure SrcIL = HighIL
    structure SrcOp = HighOps
    structure DstIL = MidIL
    structure DstOp = MidOps
    structure DstV = DstIL.Var
    structure VMap = SrcIL.Var.Map
    structure IT = Shape

  (* generate a new variable indexed by dimension *)
    fun newVar_dim (prefix, d) =
          DstV.new (prefix ^ Partials.axisToString(Partials.axis d))
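  (* e.g., newVar_dim ("h", 0) presumably yields a variable named "hx", assuming
   * Partials.axisToString maps axes 0, 1, 2 to "x", "y", "z".
   *)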

    fun assign (x, rator, args) = (x, DstIL.OP(rator, args))
    fun cons (x, args) = (x, DstIL.CONS args)
    fun realLit (x, i) = (x, DstIL.LIT(Literal.Float(FloatLit.fromInt i)))
    fun intLit (x, i) = (x, DstIL.LIT(Literal.Int(IntInf.fromInt i)))

  (* generate code for evaluating a single element of a probe operation *)
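  (* Sketch: each element is computed as a separable convolution.  For every axis we
   * evaluate the kernel h (or the derivative order selected by pdOp) at the 2*s
   * positions f_d + (s-1), ..., f_d - s, and the resulting per-axis coefficient
   * vectors are reduced against the loaded voxels with Dot(2*s) operations
   * (see genReduce below).
   *)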
    fun probeElem {
          dim,          (* dimension of space *)
          h, s,         (* kernel h with support s *)
          n, f,         (* Dst vars for integer and fractional components of position *)
          voxIter       (* iterator over voxels *)
        } (result, pdOp) = let
        (* generate the variables that hold the convolution coefficients *)
          val convCoeffs = let
                val Partials.D l = pdOp
                fun mkVar (_, []) = []
                  | mkVar (i, d::dd) = (case d
                       of 0 => newVar_dim("h", i) :: mkVar(i+1, dd)
                        | 1 => newVar_dim("dh", i) :: mkVar(i+1, dd)
                        | _ => newVar_dim(concat["d", Int.toString d, "h"], i) :: mkVar(i+1, dd)
                      (* end case *))
                in
                  mkVar (0, l)
                end
          val _ = print(concat["probeElem: ", Partials.partialToString pdOp, " in ", Int.toString(List.length convCoeffs), "D space\n"])
        (* for each dimension, we evaluate the kernel at the coordinates for that axis *)
          val coeffCode = let
                fun gen (x, k, (d, code)) = let
                      val d = d-1
                      val fd = newVar_dim ("f", d)
                      val a = DstV.new "a"
                      val tmps = List.tabulate(2*s, fn i => (DstV.new("t"^Int.toString i), s - (i+1)))
                      fun mkArg ((t, n), code) = let
                            val t' = DstV.new "r"
                            in
                              realLit (t', n) ::
                              assign (t, DstOp.Add DstOp.realTy, [fd, t']) ::
                              code
                            end
                      val code =
                            cons(a, List.map #1 tmps) ::
                            assign(x, DstOp.EvalKernel(2*s, h, k), [a]) ::
                            code
                      val code =
                            assign(fd, DstOp.Select(dim, d), [f]) ::
                            List.foldr mkArg code tmps
                      in
                        (d, code)
                      end
                val Partials.D l = pdOp
                in
                  #2 (ListPair.foldr gen (dim, []) (convCoeffs, l))
                end
        (* generate the reduction code *)
          fun genReduce (result, [hh], IT.LF{vox, offsets}, code) =
                assign (result, DstOp.Dot(2*s), [vox, hh]) :: code
            | genReduce (result, hh::r, IT.ND(_, kids), code) = let
                val tv = DstV.new "tv"
                val tmps = List.tabulate(2*s, fn i => DstV.new("t"^Int.toString i))
                fun lp ([], [], code) = code
                  | lp (t::ts, kid::kids, code) = genReduce(t, r, kid, lp(ts, kids, code))
                val code = cons(tv, tmps) :: assign(result, DstOp.Dot(2*s), [hh, tv]) :: code
                in
                  lp (tmps, kids, code)
                end
            | genReduce _ = raise Fail "genReduce"
          val reduceCode = genReduce (result, convCoeffs, voxIter, [])
          in
            coeffCode @ reduceCode
          end

  (* generate code for probing the field (D^k (v * h)) at pos *)
    fun probe (result, (k, v, h), pos) = let
          val ImageInfo.ImgInfo{dim, ty=([], ty), ...} = v
          val dimTy = DstOp.VecTy dim
          val s = Kernel.support h
        (* generate the transform code *)
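        (* Note: pos is transformed to image-space x, then split into a fractional part
         * f = x - floor(x) and an integer part n = truncToInt(floor(x)); n selects the
         * voxel neighborhood and f feeds the kernel evaluation.
         *)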
          val x = DstV.new "x"    (* image-space position *)
          val f = DstV.new "f"
          val nd = DstV.new "nd"
          val n = DstV.new "n"
          val transformCode = [
                  assign(x, DstOp.Transform v, [pos]),
                  assign(nd, DstOp.Floor dim, [x]),
                  assign(f, DstOp.Sub dimTy, [x, nd]),
                  assign(n, DstOp.TruncToInt dim, [nd])
                ]
        (* generate the shape of the differentiation tensor with variables representing
         * the elements
         *)
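        (* A sketch (assuming Partials.axisToString gives "x", "y", ...): for k = 1 in 2D,
         * diffIter is a node whose two leaves pair the result variables rx and ry with
         * the partial derivatives along x and y, so the probe yields the gradient.
         *)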
          val diffIter = let
                val partial = Partials.partial dim
                fun f (i, axes) = Partials.axis i :: axes
                fun g axes =
                      (DstV.new(String.concat("r" :: List.map Partials.axisToString axes)), partial axes)
                in
                  IT.create (k, dim, fn _ => (), f, g, [])
                end
        (* debug printing of the differentiation-tensor shape *)
          val _ = let
                val indentWid = ref 2
                fun inc () = (indentWid := !indentWid + 2)
                fun dec () = (indentWid := !indentWid - 2)
                fun indent () = print(CharVector.tabulate(!indentWid, fn _ => #" "))
                fun nd () = (indent(); print "ND\n")
                fun lf (x, partial) = (
                      indent(); print(concat["LF(", DstV.toString x, ", ", Partials.partialToString partial, ")\n"]))
                fun pr (Shape.ND(attr, kids)) = (nd attr; inc(); List.app pr kids; dec())
                  | pr (Shape.LF attr) = lf attr
                in
                  print "diffIter:\n";
                  pr diffIter
                end
        (* generate code to load the voxel data; since we use a vector load operation to load
         * the fastest dimension, the height of the tree is one less than the dimension of space.
         *)
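        (* e.g., for a 3D image with kernel support s, the tree has height 2 (= dim-1)
         * with 2*s children per node; each leaf carries per-axis offsets starting at
         * ~(s-1) and a fresh variable that will hold the 2*s voxels vector-loaded along
         * the fastest dimension.
         *)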
          val voxIter = let
                fun f (i, (offsets, id)) = (i - (s - 1) :: offsets, i::id)
                fun g (offsets, id) = {
                        offsets = ~(s-1) :: offsets,
                        vox = DstV.new(String.concat("v" :: List.map Int.toString id))
                      }
                in
                  IT.create (dim-1, 2*s, fn _ => (), f, g, ([], []))
                end
        (* debug printing of the voxel iterator *)
          val _ = let
                val indentWid = ref 2
                fun inc () = (indentWid := !indentWid + 2)
                fun dec () = (indentWid := !indentWid - 2)
                fun indent () = print(CharVector.tabulate(!indentWid, fn _ => #" "))
                fun nd () = (indent(); print "ND\n")
                fun lf {offsets, vox} = (
                      indent(); print "LF{offsets = ["; print(String.concatWith "," (List.map Int.toString offsets));
                      print "], vox = "; print(DstV.toString vox); print "}\n")
                fun pr (Shape.ND(attr, kids)) = (nd attr; inc(); List.app pr kids; dec())
                  | pr (Shape.LF attr) = lf attr
                in
                  print "voxIter:\n";
                  pr voxIter
                end
          val loadCode = let
                fun genCode ({offsets, vox}, code) = let
                      fun computeIndices (_, []) = ([], [])
                        | computeIndices (i, offset::offsets) = let
                            val index = newVar_dim("i", i)
                            val t1 = DstV.new "t1"
                            val t2 = DstV.new "t2"
                            val (indices, code) = computeIndices (i+1, offsets)
                            val code =
                                  intLit(t1, offset) ::
                                  assign(t2, DstOp.Select(2*s, i), [n]) ::
                                  assign(index, DstOp.Add(DstOp.IntTy), [t1, t2]) ::
                                  code
                            val indices = index::indices
                            in
                              (indices, code)
                            end
                      val (indices, indicesCode) = computeIndices (0, offsets)
                      val a = DstV.new "a"
                      in
                        indicesCode @ [
                            assign(a, DstOp.VoxelAddress v, indices),
                            assign(vox, DstOp.LoadVoxels(ty, 2*s), [a])
                          ] @ code
                      end
                in
                  IT.foldr genCode [] voxIter
                end
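        (* The load code for each leaf of voxIter computes an index per axis (the leaf's
         * offset added to the corresponding component of n), a VoxelAddress into the
         * image v, and a single LoadVoxels of 2*s values.
         *)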
        (* generate code to evaluate and construct the result tensor *)
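        (* genProbe walks diffIter: interior nodes become CONS of their children's values,
         * each leaf is evaluated with probeElem, and a bare leaf corresponds to a
         * scalar-field probe.
         *)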
          val probeElem = probeElem {dim = dim, h = h, s = s, n = n, f = f, voxIter = voxIter}
          fun genProbe (result, IT.ND(_, kids as (IT.LF _)::_), code) = let
              (* the kids will all be leaves *)
                fun genProbeCode (IT.LF arg, code) = probeElem arg @ code
                fun getProbeVar (IT.LF(t, _)) = t
                in
                  List.foldr genProbeCode (cons (result, List.map getProbeVar kids) :: code) kids
                end
            | genProbe (result, IT.ND(_, kids), code) = let
                val tmps = List.tabulate(dim, fn i => DstV.new("t"^Int.toString i))
                val code = cons(result, tmps) :: code
                fun lp ([], [], code) = code
                  | lp (t::ts, kid::kids, code) = genProbe(t, kid, lp(ts, kids, code))
                in
                  lp (tmps, kids, code)
                end
            | genProbe (result, IT.LF(t, pdOp), code) = (* for scalar fields *)
                probeElem (result, pdOp) @ code
          val probeCode = genProbe (result, diffIter, [])
          in
            transformCode @ loadCode @ probeCode
          end

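  (* Entry point: a field defined by convolution (FieldDef.CONV) expands via probe above;
   * the NEG case is commented out pending pushing negation down into the probe operation,
   * and the SUM case is not yet implemented.
   *)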
    fun expand (result, fld, pos) = let
          fun expand' (result, FieldDef.CONV(k, v, h)) = probe (result, (k, v, h), pos)
          (* should push negation down to probe operation
            | expand' (result, FieldDef.NEG fld) = let
                val r = DstV.new "value"
                val stms = expand' (r, fld)
                val ty = ??
                in
                  expand' (r, fld) @ [assign(r, DstOp.Neg ty, [r])]
                end
          *)
            | expand' (result, FieldDef.SUM(fld1, fld2)) = raise Fail "expandInside: SUM"
          in
            expand' (result, fld)
          end

  end