SCM Repository
Annotation of /trunk/src/compiler/high-to-mid/probe.sml
Parent Directory
|
Revision Log
Revision 349 - (view) (download)
(* probe.sml
 *
 * COPYRIGHT (c) 2010 The Diderot Project (http://diderot.cs.uchicago.edu)
 * All rights reserved.
 *
 * Expansion of probe operations in the HighIL to MidIL translation.
 *)
8 : | |||
structure Probe : sig

  (* expand (result, fld, pos) returns the MidIL assignments, in execution order,
   * that evaluate the field fld at the position pos and bind the value to result.
   *)
    val expand : MidIL.var * FieldDef.field_def * MidIL.var -> MidIL.assign list

  end = struct

    structure SrcIL = HighIL
    structure SrcOp = HighOps
    structure DstIL = MidIL
    structure DstOp = MidOps
    structure DstV = DstIL.Var
    structure VMap = SrcIL.Var.Map
    structure IT = Shape

  (* generate a new variable whose name is prefix followed by the name of axis d *)
    fun newVar_dim (prefix, d) =
          DstV.new (prefix ^ Partials.axisToString(Partials.axis d))

  (* helpers for building MidIL assignments *)
    fun assign (x, rator, args) = (x, DstIL.OP(rator, args))
    fun cons (x, args) = (x, DstIL.CONS args)
    fun realLit (x, i) = (x, DstIL.LIT(Literal.Float(FloatLit.fromInt i)))
    fun intLit (x, i) = (x, DstIL.LIT(Literal.Int(IntInf.fromInt i)))

  (* generate code for evaluating a single element of a probe operation.  The element
   * is described by pdOp, whose list l pairs one derivative order with each axis
   * (in axis order 0 .. dim-1, as used by coeffCode below); its value is bound to
   * result.  The returned code is in execution order.
   *)
    fun probeElem {
          dim,          (* dimension of space *)
          h, s,         (* kernel h with support s *)
          n, f,         (* Dst vars for integer and fractional components of position *)
          voxIter       (* iterator over voxels *)
        } (result, pdOp) = let
          val Partials.D l = pdOp
        (* generate the variables that hold the convolution coefficients: one per
         * axis, named for the derivative order and for that axis.
         *)
          val convCoeffs = let
                fun mkVar (0, axis) = newVar_dim("h", axis)
                  | mkVar (1, axis) = newVar_dim("dh", axis)
                  | mkVar (i, axis) = newVar_dim(concat["d", Int.toString i, "h"], axis)
                fun mk (_, []) = []
                  | mk (axis, i::r) = mkVar(i, axis) :: mk(axis+1, r)
                in
                  mk (0, l)
                end
        (* for each axis, evaluate the kernel (derivative) at the 2*s sample positions
         * for that axis and bind the coefficient vector to the axis's convCoeffs
         * variable.  ListPair.foldr visits the axes from last to first, so d counts
         * down from dim-1 to 0, and each axis's code is prepended to the code already
         * accumulated for the higher axes; the result is in execution order.
         *)
          val coeffCode = let
                fun gen (x, k, (d, code)) = let
                      val d = d-1
                      val fd = newVar_dim ("f", d)
                      val a = DstV.new "a"
                    (* temporaries t_i = fd + (s-1-i), for i = 0 .. 2*s-1 *)
                      val tmps = List.tabulate(2*s, fn i => (DstV.new("t"^Int.toString i), s - (i+1)))
                      fun mkArg (t, offset) = let
                            val r = DstV.new "r"
                            in
                              [realLit (r, offset), assign (t, DstOp.Add DstOp.realTy, [fd, r])]
                            end
                      val axisCode =
                            assign(fd, DstOp.Select(dim, d), [f]) ::
                            List.concat (List.map mkArg tmps) @ [
                                cons(a, List.map #1 tmps),
                                assign(x, DstOp.EvalKernel(2*s, h, k), [a])
                              ]
                      in
                        (d, axisCode @ code)
                      end
                in
                  #2 (ListPair.foldr gen (dim, []) (convCoeffs, l))
                end
        (* generate the reduction code: at each interior level of voxIter, dot one
         * coefficient vector with the vector of the sub-results for that level.
         * NOTE(review): which axis each level of voxIter corresponds to depends on
         * Shape.create; confirm that convCoeffs is consumed here in the matching axis
         * order and that the number of levels equals the number of coefficients.
         *)
          fun genReduce (result, [hh], IT.ND(_, kids), code) = let
              (* the kids will all be leaves *)
                val vv = DstV.new "vv"
                fun getVox (IT.LF{vox, offsets = _}) = vox
                in
                  cons (vv, List.map getVox kids) ::
                  assign (result, DstOp.Dot(2*s), [hh, vv]) :: code
                end
            | genReduce (result, hh::r, IT.ND(_, kids), code) = let
                val tv = DstV.new "tv"
                val tmps = List.tabulate(2*s, fn i => DstV.new("t"^Int.toString i))
                fun lp ([], [], code) = code
                  | lp (t::ts, kid::kids, code) = genReduce(t, r, kid, lp(ts, kids, code))
                val code = cons(tv, tmps) :: assign(result, DstOp.Dot(2*s), [hh, tv]) :: code
                in
                  lp (tmps, kids, code)
                end
          val reduceCode = genReduce (result, convCoeffs, voxIter, [])
          in
            coeffCode @ reduceCode
          end

  (* generate code for probing the field (D^k (v * h)) at pos.  The returned code is
   * in execution order: it transforms pos to image space, loads the voxels, and then
   * computes the elements of the result tensor and binds the tensor to result.
   *)
    fun probe (result, (k, v, h), pos) = let
          val ImageInfo.ImgInfo{dim, ty=([], ty), ...} = v
          val dimTy = DstOp.VecTy dim
          val s = Kernel.support h
        (* generate the transform code *)
          val x = DstV.new "x"          (* image-space position *)
          val f = DstV.new "f"          (* fractional part of x *)
          val nd = DstV.new "nd"        (* floor of x *)
          val n = DstV.new "n"          (* floor of x as a dim-vector of ints *)
          val transformCode = [
                  assign(x, DstOp.Transform v, [pos]),
                  assign(nd, DstOp.Floor dim, [x]),
                  assign(f, DstOp.Sub dimTy, [x, nd]),
                  assign(n, DstOp.TruncToInt dim, [nd])
                ]
        (* generate the shape of the differentiation tensor with variables representing
         * the elements
         *)
          val diffIter = let
                val partial = Partials.partial dim
                fun f (i, axes) = Partials.axis i :: axes
                fun g axes =
                      (DstV.new(String.concat("y" :: List.map Partials.axisToString axes)), partial axes)
                in
                  IT.create (k-1, dim, fn _ => (), f, g, [])
                end
        (* generate code to load the voxel data *)
          val voxIter = let
                fun f (i, (offsets, id)) = (i - (s - 1) :: offsets, i::id)
                fun g (offsets, id) = {
                        offsets = offsets,
                        vox = DstV.new(String.concat("v" :: List.map Int.toString id))
                      }
                in
                  IT.create (dim-1, 2*s, fn _ => (), f, g, ([], []))
                end
          val loadCode = let
                fun genCode ({offsets, vox}, code) = let
                    (* compute index_i = offset_i + n[i] for each axis i *)
                      fun computeIndices (_, []) = ([], [])
                        | computeIndices (i, offset::offsets) = let
                            val index = newVar_dim("i", i)
                            val t1 = DstV.new "t1"
                            val t2 = DstV.new "t2"
                            val (indices, code) = computeIndices (i+1, offsets)
                            val code =
                                  intLit(t1, offset) ::
                                (* n is produced by TruncToInt dim, so it has dim
                                 * components (same width as the Select on f).
                                 *)
                                  assign(t2, DstOp.Select(dim, i), [n]) ::
                                  assign(index, DstOp.Add(DstOp.IntTy), [t1, t2]) ::
                                  code
                            in
                              (index::indices, code)
                            end
                    (* the first axis uses the fixed offset ~(s-1) -- presumably because
                     * LoadVoxels reads 2*s consecutive voxels along that axis -- and the
                     * remaining axes' offsets come from the voxIter leaf.
                     *)
                      val (indices, indicesCode) = computeIndices (0, ~(s-1) :: offsets)
                      val a = DstV.new "a"
                      in
                        indicesCode @ [
                            assign(a, DstOp.VoxelAddress v, indices),
                            assign(vox, DstOp.LoadVoxels(ty, 2*s), [a])
                          ] @ code
                      end
                in
                  IT.foldr genCode [] voxIter
                end
        (* generate code to evaluate and construct the result tensor *)
          val probeElem = probeElem {dim = dim, h = h, s = s, n = n, f = f, voxIter = voxIter}
          fun genProbe (result, IT.LF(_, pdOp), code) =
              (* a leaf at the root (presumably k = 0): the tensor is a single element *)
                probeElem (result, pdOp) @ code
            | genProbe (result, IT.ND(_, kids as (IT.LF _)::_), code) = let
              (* the kids will all be leaves; their elements must be computed before
               * they are packed into result.
               *)
                fun genProbeCode (IT.LF arg, code) = probeElem arg @ code
                fun getProbeVar (IT.LF(t, _)) = t
                in
                  List.foldr genProbeCode (cons (result, List.map getProbeVar kids) :: code) kids
                end
            | genProbe (result, IT.ND(_, kids), code) = let
                val tmps = List.tabulate(dim, fn i => DstV.new("t"^Int.toString i))
                fun lp ([], [], code) = code
                  | lp (t::ts, kid::kids, code) = genProbe(t, kid, lp(ts, kids, code))
                val code = cons(result, tmps) :: code
                in
                  lp (tmps, kids, code)
                end
          val probeCode = genProbe (result, diffIter, [])
          in
            transformCode @ loadCode @ probeCode
          end

  (* expand (result, fld, pos) returns the code, in execution order, that evaluates fld
   * at pos and binds the value to result.
   *)
    fun expand (result, fld, pos) = let
          fun expand' (result, FieldDef.CONV(k, v, h)) =
              (* probe itself transforms pos to image space, so no extra Transform
               * is emitted here.
               *)
                probe (result, (k, v, h), pos)
          (* TODO: should push negation down to probe operation.  Sketch:
            | expand' (result, FieldDef.NEG fld) = let
                val r = DstV.new "value"
                val ty = ??
                in
                  expand' (r, fld) @ [assign(result, DstOp.Neg ty, [r])]
                end
           *)
            | expand' (result, FieldDef.SUM _) = raise Fail "expandInside: SUM"
          in
            expand' (result, fld)
          end

  end
root@smlnj-gforge.cs.uchicago.edu | ViewVC Help |
Powered by ViewVC 1.0.0 |