structure Probe : sig

    val expand : MidIL.var * FieldDef.field_def * MidIL.var -> MidIL.assign list

  end = struct

    structure SrcIL = HighIL
    structure SrcOp = HighOps
    structure DstIL = MidIL
    structure DstOp = MidOps
    structure DstV = DstIL.Var
    structure VMap = SrcIL.Var.Map
  (* a tree representation of nested iterations over the image space, where the
   * height of the tree corresponds to the number of dimensions and at each node
   * we have as many children as there are iterations.
   *)
    structure IT = Shape

  (* generate a new variable indexed by dimension *)
    fun newVar_dim (prefix, d) =
          DstV.new (prefix ^ Partials.axisToString(Partials.axis d))

  (* helpers for building Mid-IL assignment statements *)
    fun assign (x, rator, args) = (x, DstIL.OP(rator, args))
    fun cons (x, args) = (x, DstIL.CONS args)
    fun realLit (x, i) = (x, DstIL.LIT(Literal.Float(FloatLit.fromInt i)))
    fun intLit (x, i) = (x, DstIL.LIT(Literal.Int(IntInf.fromInt i)))
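
  (* For illustration: assign(t, DstOp.Floor dim, [x]) builds the binding
   * (t, DstIL.OP(DstOp.Floor dim, [x])), i.e. the statement "t = Floor<dim>(x)",
   * and realLit(t, 2) binds t to the real literal 2.0.
   *)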
|
  (* generate code for evaluating a single element of a probe operation *)
    fun probeElem {
            dim,        (* dimension of space *)
            h, s,       (* kernel h with support s *)
            n, f,       (* Dst vars for integer and fractional components of position *)
            voxIter     (* iterator over voxels *)
          } (result, pdOp) = let
        (* generate the variables that hold the convolution coefficients *)
          val convCoeffs = let
                val Partials.D l = pdOp
                fun mkVar 0 = newVar_dim("h", dim)
                  | mkVar 1 = newVar_dim("dh", dim)
                  | mkVar i = newVar_dim(concat["d", Int.toString i, "h"], dim)
                in
                  List.map mkVar l
                end
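        (* convCoeffs has one variable per axis of the space; the prefix of its name records
         * the derivative order requested for that axis in pdOp ("h" for the kernel itself,
         * "dh", "d2h", ... for its derivatives).
         *)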
        (* for each dimension, we evaluate the kernel at the coordinates for that axis *)
          val coeffCode = let
                fun gen (x, k, (d, code)) = let
                      val d = d-1
                      val fd = newVar_dim ("f", d)
                      val a = DstV.new "a"
                      val tmps = List.tabulate(2*s, fn i => (DstV.new("t"^Int.toString i), s - (i+1)))
                      fun mkArg ((t, n), code) = let
                            val t' = DstV.new "r"
                            in
                              realLit (t', n) ::
                              assign (t, DstOp.Add DstOp.realTy, [fd, t']) ::
                              code
                            end
                    (* code in reverse order *)
                      val code =
                            assign(x, DstOp.EvalKernel(2*s, h, k), [a]) ::
                            assign(fd, DstOp.Select(dim, d), [f]) ::
                            cons(a, List.map #1 tmps) ::
                            (List.foldl mkArg [] tmps)
                      in
                        (d, List.rev code)
                      end
                val Partials.D l = pdOp
                in
                  #2 (ListPair.foldr gen (dim, []) (convCoeffs, l))
                end
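        (* for each axis d, coeffCode builds the 2s kernel arguments f_d + (s-1), ..., f_d - s,
         * packs them into a vector, and evaluates the kernel (EvalKernel) on that vector,
         * leaving the coefficients in the corresponding convCoeffs variable.
         *)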
        (* generate the reduction code *)
          fun genReduce (result, [hh], IT.ND(_, kids), code) = let
              (* the kids will all be leaves *)
                val vv = DstV.new "vv"
                fun getVox (IT.LF{vox, offsets}) = vox
                in
                  cons (vv, List.map getVox kids) ::
                  assign (result, DstOp.Dot(2*s), [hh, vv]) :: code
                end
            | genReduce (result, hh::r, IT.ND(_, kids), code) = let
                val tv = DstV.new "tv"
                val tmps = List.tabulate(2*s, fn i => DstV.new("t"^Int.toString i))
                fun lp ([], [], code) = code
                  | lp (t::ts, kid::kids, code) = genReduce(t, r, kid, lp(ts, kids, code))
                val code = cons(tv, tmps) :: assign(result, DstOp.Dot(2*s), [hh, tv]) :: code
                in
                  lp (tmps, kids, code)
                end
          val reduceCode = genReduce (result, convCoeffs, voxIter, [])
          in
            coeffCode @ reduceCode
          end
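
  (* probeElem thus produces the coefficient-evaluation code followed by a nested
   * dot-product reduction of the loaded voxel values against the per-axis coefficient
   * vectors, leaving the value of the requested tensor element in result.
   *)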
|

  (* generate code for probing the field (D^k (v * h)) at pos *)
    fun probe (result, (k, v, h), pos) = let
          val ImageInfo.ImgInfo{dim, ty=([], ty), ...} = v
          val dimTy = DstOp.VecTy dim
          val s = Kernel.support h
          val sTy = DstOp.VecTy(2*s)
        (* generate the transform code *)
          val x = DstV.new "x"          (* image-space position *)
          val f = DstV.new "f"
          val nd = DstV.new "nd"
          val n = DstV.new "n"
          val transformCode = [
                  assign(x, DstOp.Transform v, [pos]),
                  assign(nd, DstOp.Floor dim, [x]),
                  assign(f, DstOp.Sub dimTy, [x, nd]),
                  assign(n, DstOp.TruncToInt dim, [nd])
                ]
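        (* x is pos transformed into image space; nd = floor(x), f = x - nd is the fractional
         * position within the cell, and n is nd truncated to integer voxel coordinates.
         *)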
        (* generate the shape of the differentiation tensor with variables representing
         * the elements
         *)
          val diffIter = let
                val partial = Partials.partial dim
                fun f (i, axes) = Partials.axis i :: axes
                fun g axes =
                      (DstV.new(String.concat("y" :: List.map Partials.axisToString axes)), partial axes)
                in
                  IT.create (k-1, dim, fn _ => (), f, g, [])
                end
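        (* the tree has depth k-1 and width dim; each leaf pairs a fresh variable (named "y"
         * followed by the axis names along its path) with the corresponding partial-derivative
         * operator.
         *)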
        (* generate code to load the voxel data *)
          val voxIter = let
                fun f (i, (offsets, id)) = (i - (s - 1) :: offsets, i::id)
                fun g (offsets, id) = {
                        offsets = offsets,
                        vox = DstV.new(String.concat("v" :: List.map Int.toString id))
                      }
                in
                  IT.create (dim-1, 2*s, fn _ => (), f, g, ([], []))
                end
          val loadCode = let
                fun genCode ({vox, offsets}, code) = let
                      fun computeIndices (_, []) = ([], [])
                        | computeIndices (i, offset::offsets) = let
                            val index = newVar_dim("i", i)
                            val t1 = DstV.new "t1"
                            val t2 = DstV.new "t2"
                            val (indices, code) = computeIndices (i+1, offsets)
                            val code =
                                  intLit(t1, offset) ::
                                  assign(t2, DstOp.Select(dim, i), [n]) ::
                                  assign(index, DstOp.Add(DstOp.IntTy), [t1, t2]) ::
                                  code
                            val indices = index::indices
                            in
                              (indices, code)
                            end
                      val (indices, indicesCode) = computeIndices (0, ~(s-1) :: offsets)
                      val a = DstV.new "a"
                      in
                        indicesCode @ [
                            assign(a, DstOp.VoxelAddress v, indices),
                            assign(vox, DstOp.LoadVoxels(ty, 2*s), [a])
                          ] @ code
                      end
                in
                  IT.foldr genCode [] voxIter
                end
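        (* for each voxel leaf, loadCode computes an integer index per axis (n[i] plus the
         * leaf's offset, with an extra -(s-1) offset along the first axis), maps the indices
         * to a voxel address, and loads 2s voxels of type ty from that address into vox.
         *)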
        (* generate code to evaluate and construct the result tensor *)
          val probeElem = probeElem {dim = dim, h = h, s = s, n = n, f = f, voxIter = voxIter}
          fun genProbe (result, IT.ND(_, kids as (IT.LF _)::_), code) = let
              (* the kids will all be leaves *)
                fun genProbeCode (IT.LF arg, code) = probeElem arg @ code
                fun getProbeVar (IT.LF(t, _)) = t
                in
                  cons (result, List.map getProbeVar kids) :: List.foldr genProbeCode code kids
                end
            | genProbe (result, IT.ND(_, kids), code) = let
                val tmps = List.tabulate(dim, fn i => DstV.new("t"^Int.toString i))
                fun lp ([], [], code) = code
                  | lp (t::ts, kid::kids, code) = genProbe(t, kid, lp(ts, kids, code))
                val code = cons(result, tmps) :: code
                in
                  lp (tmps, kids, code)
                end
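        (* genProbe mirrors the shape of diffIter: at interior nodes it CONSes the results of
         * the subtrees; at the last level it evaluates each leaf element with probeElem and
         * CONSes the element variables into result.
         *)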
          val probeCode = genProbe (result, diffIter, [])
          in
            transformCode @ loadCode @ probeCode
          end

    fun expand (result, fld, pos) = let
          fun expand' (result, FieldDef.CONV(k, v, h)) = let
                val x = DstV.new "x"
                val xformStm = (x, DstIL.OP(DstOp.Transform v, [pos]))
                in
                  probe (result, (k, v, h), x) @ [xformStm]
                end
        (* should push negation down to probe operation
            | expand' (result, FieldDef.NEG fld) = let
                val r = DstV.new "value"
                val stms = expand' (r, fld)
                val ty = ??
                in
                  expand' (r, fld) @ [assign(r, DstOp.Neg ty, [r])]
                end
        *)
            | expand' (result, FieldDef.SUM(fld1, fld2)) = raise Fail "expandInside: SUM"
          in
            List.rev (expand' (result, fld))
          end

  end
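
(* A minimal usage sketch (hypothetical; the caller and the names result, fld, and pos are
 * assumptions, not part of this module):
 *
 *   val stms = Probe.expand (result, fld, pos)
 *
 * where result and pos are MidIL.var values and fld is the FieldDef.field_def being probed;
 * the returned MidIL.assign list implements the probe and can be spliced into the Mid-IL
 * code being generated.
 *)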