22 |
structure IT = Shape |
structure IT = Shape |
23 |
|
|
24 |
(* generate a new variable indexed by dimension *) |
(* generate a new variable indexed by dimension *) |
25 |
fun newVar_dim (prefix, d) = |
fun newVar_dim (prefix, d, ty) = |
26 |
DstV.new (prefix ^ Partials.axisToString(Partials.axis d)) |
DstV.new (prefix ^ Partials.axisToString(Partials.axis d), ty) |
27 |
|
|
28 |
fun assign (x, rator, args) = (x, DstIL.OP(rator, args)) |
fun assign (x, rator, args) = (x, DstIL.OP(rator, args)) |
29 |
fun cons (x, args) = (x, DstIL.CONS args) |
fun cons (x, args) = (x, DstIL.CONS args) |
37 |
n, f, (* Dst vars for integer and fractional components of position *) |
n, f, (* Dst vars for integer and fractional components of position *) |
38 |
voxIter (* iterator over voxels *) |
voxIter (* iterator over voxels *) |
39 |
} (result, pdOp) = let |
} (result, pdOp) = let |
40 |
|
val vecsTy = DstTy.VecTy(2*s) (* vectors of coefficients cover support of kernel *) |
41 |
|
val vecDimTy = DstTy.VecTy dim |
42 |
(* generate the variables that hold the convolution coefficients *) |
(* generate the variables that hold the convolution coefficients *) |
43 |
val convCoeffs = let |
val convCoeffs = let |
44 |
val Partials.D l = pdOp |
val Partials.D l = pdOp |
45 |
fun mkVar (_, []) = [] |
fun mkVar (_, []) = [] |
46 |
| mkVar (i, d::dd) = (case d |
| mkVar (i, d::dd) = (case d |
47 |
of 0 => newVar_dim("h", i) :: mkVar(i+1, dd) |
of 0 => newVar_dim("h", i, vecsTy) :: mkVar(i+1, dd) |
48 |
| 1 => newVar_dim("dh", i) :: mkVar(i+1, dd) |
| 1 => newVar_dim("dh", i, vecsTy) :: mkVar(i+1, dd) |
49 |
| _ => newVar_dim(concat["d", Int.toString d, "h"], i) :: mkVar(i+1, dd) |
| _ => newVar_dim(concat["d", Int.toString d, "h"], i, vecsTy) :: mkVar(i+1, dd) |
50 |
(* end case *)) |
(* end case *)) |
51 |
in |
in |
52 |
mkVar (0, l) |
mkVar (0, l) |
56 |
val coeffCode = let |
val coeffCode = let |
57 |
fun gen (x, k, (d, code)) = let |
fun gen (x, k, (d, code)) = let |
58 |
val d = d-1 |
val d = d-1 |
59 |
val fd = newVar_dim ("f", d) |
val fd = newVar_dim ("f", d, DstTy.realTy) |
60 |
val a = DstV.new "a" |
val a = DstV.new ("a", vecsTy) |
61 |
val tmps = List.tabulate(2*s, fn i => (DstV.new("t"^Int.toString i), s - (i+1))) |
val tmps = List.tabulate(2*s, |
62 |
|
fn i => (DstV.new("t"^Int.toString i, DstTy.realTy), s - (i+1))) |
63 |
fun mkArg ((t, n), code) = let |
fun mkArg ((t, n), code) = let |
64 |
val t' = DstV.new "r" |
val t' = DstV.new ("r", DstTy.realTy) |
65 |
in |
in |
66 |
realLit (t', n) :: |
realLit (t', n) :: |
67 |
assign (t, DstOp.Add DstTy.realTy, [fd, t']) :: |
assign (t, DstOp.Add DstTy.realTy, [fd, t']) :: |
85 |
fun genReduce (result, [hh], IT.LF{vox, offsets}, code) = |
fun genReduce (result, [hh], IT.LF{vox, offsets}, code) = |
86 |
assign (result, DstOp.Dot(2*s), [vox, hh]) :: code |
assign (result, DstOp.Dot(2*s), [vox, hh]) :: code |
87 |
| genReduce (result, hh::r, IT.ND(_, kids), code) = let |
| genReduce (result, hh::r, IT.ND(_, kids), code) = let |
88 |
val tv = DstV.new "tv" |
val tv = DstV.new ("tv", vecsTy) |
89 |
val tmps = List.tabulate(2*s, fn i => DstV.new("t"^Int.toString i)) |
val tmps = List.tabulate(2*s, fn i => DstV.new("t"^Int.toString i, DstTy.realTy)) |
90 |
fun lp ([], [], code) = code |
fun lp ([], [], code) = code |
91 |
| lp (t::ts, kid::kids, code) = genReduce(t, r, kid, lp(ts, kids, code)) |
| lp (t::ts, kid::kids, code) = genReduce(t, r, kid, lp(ts, kids, code)) |
92 |
val code = cons(tv, tmps) :: assign(result, DstOp.Dot(2*s), [hh, tv]) :: code |
val code = cons(tv, tmps) :: assign(result, DstOp.Dot(2*s), [hh, tv]) :: code |
102 |
(* generate code for probing the field (D^k (v * h)) at pos *) |
(* generate code for probing the field (D^k (v * h)) at pos *) |
103 |
fun probe (result, (k, v, h), pos) = let |
fun probe (result, (k, v, h), pos) = let |
104 |
val ImageInfo.ImgInfo{dim, ty=([], ty), ...} = v |
val ImageInfo.ImgInfo{dim, ty=([], ty), ...} = v |
|
val dimTy = DstTy.VecTy dim |
|
105 |
val s = Kernel.support h |
val s = Kernel.support h |
106 |
|
val vecsTy = DstTy.VecTy(2*s) (* vectors of coefficients cover support of kernel *) |
107 |
|
val vecDimTy = DstTy.VecTy dim |
108 |
(* generate the transform code *) |
(* generate the transform code *) |
109 |
val x = DstV.new "x" (* image-space position *) |
val x = DstV.new ("x", vecDimTy) (* image-space position *) |
110 |
val f = DstV.new "f" |
val f = DstV.new ("f", vecDimTy) |
111 |
val nd = DstV.new "nd" |
val nd = DstV.new ("nd", vecDimTy) |
112 |
val n = DstV.new "n" |
val n = DstV.new ("n", DstTy.IVecTy dim) |
113 |
val transformCode = [ |
val transformCode = [ |
114 |
assign(x, DstOp.Transform v, [pos]), |
assign(x, DstOp.Transform v, [pos]), |
115 |
assign(nd, DstOp.Floor dim, [x]), |
assign(nd, DstOp.Floor dim, [x]), |
116 |
assign(f, DstOp.Sub dimTy, [x, nd]), |
assign(f, DstOp.Sub vecDimTy, [x, nd]), |
117 |
assign(n, DstOp.TruncToInt dim, [nd]) |
assign(n, DstOp.TruncToInt dim, [nd]) |
118 |
] |
] |
119 |
(* generate the shape of the differentiation tensor with variables representing |
(* generate the shape of the differentiation tensor with variables representing |
122 |
val diffIter = let |
val diffIter = let |
123 |
val partial = Partials.partial dim |
val partial = Partials.partial dim |
124 |
fun f (i, axes) = Partials.axis i :: axes |
fun f (i, axes) = Partials.axis i :: axes |
125 |
fun g axes = |
fun g axes = let |
126 |
(DstV.new(String.concat("r" :: List.map Partials.axisToString axes)), partial axes) |
val r = DstV.new( |
127 |
|
String.concat("r" :: List.map Partials.axisToString axes), |
128 |
|
DstTy.realTy) |
129 |
|
in |
130 |
|
(r, partial axes) |
131 |
|
end |
132 |
in |
in |
133 |
IT.create (k, dim, fn _ => (), f, g, []) |
IT.create (k, dim, fn _ => (), f, g, []) |
134 |
end |
end |
153 |
fun f (i, (offsets, id)) = (i - (s - 1) :: offsets, i::id) |
fun f (i, (offsets, id)) = (i - (s - 1) :: offsets, i::id) |
154 |
fun g (offsets, id) = { |
fun g (offsets, id) = { |
155 |
offsets = ~(s-1) :: offsets, |
offsets = ~(s-1) :: offsets, |
156 |
vox = DstV.new(String.concat("v" :: List.map Int.toString id)) |
vox = DstV.new(String.concat("v" :: List.map Int.toString id), vecsTy) |
157 |
} |
} |
158 |
in |
in |
159 |
IT.create (dim-1, 2*s, fn _ => (), f, g, ([], [])) |
IT.create (dim-1, 2*s, fn _ => (), f, g, ([], [])) |
177 |
fun genCode ({offsets, vox}, code) = let |
fun genCode ({offsets, vox}, code) = let |
178 |
fun computeIndices (_, []) = ([], []) |
fun computeIndices (_, []) = ([], []) |
179 |
| computeIndices (i, offset::offsets) = let |
| computeIndices (i, offset::offsets) = let |
180 |
val index = newVar_dim("i", i) |
val index = newVar_dim("i", i, DstTy.intTy) |
181 |
val t1 = DstV.new "t1" |
val t1 = DstV.new ("t1", DstTy.intTy) |
182 |
val t2 = DstV.new "t2" |
val t2 = DstV.new ("t2", DstTy.intTy) |
183 |
val (indices, code) = computeIndices (i+1, offsets) |
val (indices, code) = computeIndices (i+1, offsets) |
184 |
val code = |
val code = |
185 |
intLit(t1, offset) :: |
intLit(t1, offset) :: |
191 |
(indices, code) |
(indices, code) |
192 |
end |
end |
193 |
val (indices, indicesCode) = computeIndices (0, offsets) |
val (indices, indicesCode) = computeIndices (0, offsets) |
194 |
val a = DstV.new "a" |
val a = DstV.new ("a", DstTy.AddrTy) |
195 |
in |
in |
196 |
indicesCode @ [ |
indicesCode @ [ |
197 |
assign(a, DstOp.VoxelAddress v, indices), |
assign(a, DstOp.VoxelAddress v, indices), |
211 |
List.foldr genProbeCode (cons (result, List.map getProbeVar kids) :: code) kids |
List.foldr genProbeCode (cons (result, List.map getProbeVar kids) :: code) kids |
212 |
end |
end |
213 |
| genProbe (result, IT.ND(_, kids), code) = let |
| genProbe (result, IT.ND(_, kids), code) = let |
214 |
val tmps = List.tabulate(dim, fn i => DstV.new("t"^Int.toString i)) |
val tmps = List.tabulate(dim, fn i => DstV.new("t"^Int.toString i, DstTy.realTy)) |
215 |
val code = cons(result, tmps) :: code |
val code = cons(result, tmps) :: code |
216 |
fun lp ([], [], code) = code |
fun lp ([], [], code) = code |
217 |
| lp (t::ts, kid::kids, code) = genProbe(t, kid, lp(ts, kids, code)) |
| lp (t::ts, kid::kids, code) = genProbe(t, kid, lp(ts, kids, code)) |