Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] Annotation of /trunk/src/compiler/high-to-mid/probe.sml
ViewVC logotype

Annotation of /trunk/src/compiler/high-to-mid/probe.sml

Parent Directory Parent Directory | Revision Log Revision Log


Revision 352 - (view) (download)

1 : jhr 328 (* probe.sml
2 :     *
3 :     * COPYRIGHT (c) 2010 The Diderot Project (http://diderot.cs.uchicago.edu)
4 :     * All rights reserved.
5 :     *
6 :     * Expansion of probe operations in the HighIL to MidIL translation.
7 :     *)
8 :    
9 :     structure Probe : sig
10 :    
11 : jhr 349 val expand : MidIL.var * FieldDef.field_def * MidIL.var -> MidIL.assign list
12 : jhr 328
13 :     end = struct
14 :    
15 :     structure SrcIL = HighIL
16 : jhr 334 structure SrcOp = HighOps
17 : jhr 328 structure DstIL = MidIL
18 : jhr 334 structure DstOp = MidOps
19 : jhr 349 structure DstV = DstIL.Var
20 : jhr 328 structure VMap = SrcIL.Var.Map
21 : jhr 349 structure IT = Shape
22 : jhr 328
(* Create a fresh destination-IL variable whose name is `prefix` followed by
 * the string form of axis `d` (as produced by Partials.axisToString).
 *)
fun newVar_dim (prefix, d) = let
      val axisName = Partials.axisToString (Partials.axis d)
      in
        DstV.new (prefix ^ axisName)
      end
26 : jhr 328
(* statement that binds `lhs` to the application of operator `rator` to `args` *)
fun assign (lhs, rator, args) = let
      val rhs = DstIL.OP(rator, args)
      in
        (lhs, rhs)
      end
(* statement that binds `lhs` to the DstIL.CONS of the variables `elems` *)
fun cons (lhs, elems) = let
      val rhs = DstIL.CONS elems
      in
        (lhs, rhs)
      end
(* statement that binds `lhs` to the floating-point literal with integer value `i` *)
fun realLit (lhs, i) = let
      val lit = Literal.Float(FloatLit.fromInt i)
      in
        (lhs, DstIL.LIT lit)
      end
(* statement that binds `lhs` to the integer literal `i` *)
fun intLit (lhs, i) = let
      val lit = Literal.Int(IntInf.fromInt i)
      in
        (lhs, DstIL.LIT lit)
      end
31 :    
(* Generate code for evaluating a single element of a probe operation.
 *
 * The record argument describes the probe context (see the field comments
 * below).  The second argument pairs the variable that receives the element
 * (`result`) with the partial-derivative operator (`pdOp`) that selects it.
 *
 * The result is a list of statements in execution order: the evaluation of
 * the kernel coefficients for every axis, followed by the reduction of the
 * voxels against those coefficients.
 *)
fun probeElem {
        dim,            (* dimension of space *)
        h, s,           (* kernel h with support s *)
        n, f,           (* Dst vars for integer and fractional components of position *)
        voxIter         (* iterator over voxels *)
      } (result, pdOp) = let
    (* the per-axis components of the partial-derivative operator *)
      val Partials.D l = pdOp
    (* generate the variables that hold the convolution coefficients.  Each
     * variable is named after the axis that `gen` (below) pairs it with:
     * `gen` counts down from `dim`, so the last component gets axis dim-1.
     *)
      val convCoeffs = let
            fun coeffName 0 = "h"
              | coeffName 1 = "dh"
              | coeffName i = concat["d", Int.toString i, "h"]
            fun mkVars (_, []) = []
              | mkVars (axis, k::ks) = let
                  val v = newVar_dim(coeffName k, axis)
                  in
                    v :: mkVars (axis+1, ks)
                  end
            in
              mkVars (dim - List.length l, l)
            end
    (* for each dimension, we evaluate the kernel at the coordinates for that axis *)
      val coeffCode = let
          (* prepend the code that computes coefficient variable `x` (k'th
           * component of the operator) to `code`, which holds the statements
           * already generated for the later axes.
           *)
            fun gen (x, k, (d, code)) = let
                  val d = d-1
                  val fd = newVar_dim ("f", d)
                  val a = DstV.new "a"
                  val tmps = List.tabulate(2*s,
                        fn i => (DstV.new("t"^Int.toString i), s - (i+1)))
                (* t = fd + offset; the literal is defined before it is used *)
                  fun mkArg ((t, offset), code) = let
                        val t' = DstV.new "r"
                        in
                          realLit (t', offset) ::
                          assign (t, DstOp.Add DstOp.realTy, [fd, t']) ::
                          code
                        end
                (* collect the kernel arguments and evaluate the kernel; this
                 * must follow the statements that define the arguments.
                 *)
                  val code =
                        cons(a, List.map #1 tmps) ::
                        assign(x, DstOp.EvalKernel(2*s, h, k), [a]) ::
                        code
                (* fd must be defined before the arguments that are computed from it *)
                  val code =
                        assign(fd, DstOp.Select(dim, d), [f]) ::
                        List.foldr mkArg code tmps
                  in
                    (d, code)
                  end
            in
              #2 (ListPair.foldr gen (dim, []) (convCoeffs, l))
            end
    (* generate the reduction code; the statements for the children of a node
     * are placed in front of the statement that combines their results.
     *)
      fun genReduce (result, [hh], IT.ND(_, kids), code) = let
          (* the kids will all be leaves *)
            val vv = DstV.new "vv"
            fun getVox (IT.LF{vox, offsets}) = vox
            in
              cons (vv, List.map getVox kids) ::
              assign (result, DstOp.Dot(2*s), [hh, vv]) :: code
            end
        | genReduce (result, hh::r, IT.ND(_, kids), code) = let
            val tv = DstV.new "tv"
            val tmps = List.tabulate(2*s, fn i => DstV.new("t"^Int.toString i))
            fun lp ([], [], code) = code
              | lp (t::ts, kid::kids, code) = genReduce(t, r, kid, lp(ts, kids, code))
            val code = cons(tv, tmps) :: assign(result, DstOp.Dot(2*s), [hh, tv]) :: code
            in
              lp (tmps, kids, code)
            end
      val reduceCode = genReduce (result, convCoeffs, voxIter, [])
      in
        coeffCode @ reduceCode
      end
97 :    
(* Generate code for probing the field (D^k (v * h)) at pos.
 *
 *   result     -- variable that receives the value of the probe
 *   (k, v, h)  -- the k, image info `v` and kernel `h` of the convolution
 *   pos        -- variable holding the position being probed; it is mapped
 *                 through `DstOp.Transform v` here, so callers pass the
 *                 untransformed position
 *
 * The result is a list of statements in execution order: the position
 * transform, the voxel loads, and then the evaluation of the result tensor.
 *)
fun probe (result, (k, v, h), pos) = let
      val ImageInfo.ImgInfo{dim, ty=([], ty), ...} = v
      val dimTy = DstOp.VecTy dim
      val s = Kernel.support h
    (* generate the transform code *)
      val x = DstV.new "x"      (* image-space position *)
      val f = DstV.new "f"      (* x - nd *)
      val nd = DstV.new "nd"    (* floor of x *)
      val n = DstV.new "n"      (* nd truncated to integers *)
      val transformCode = [
              assign(x, DstOp.Transform v, [pos]),
              assign(nd, DstOp.Floor dim, [x]),
              assign(f, DstOp.Sub dimTy, [x, nd]),
              assign(n, DstOp.TruncToInt dim, [nd])
            ]
    (* generate the shape of the differentiation tensor with variables representing
     * the elements
     *)
      val diffIter = let
            val partial = Partials.partial dim
            fun addAxis (i, axes) = Partials.axis i :: axes
            fun mkElem axes = (
                  DstV.new(String.concat("y" :: List.map Partials.axisToString axes)),
                  partial axes)
            in
              IT.create (k-1, dim, fn _ => (), addAxis, mkElem, [])
            end
    (* generate code to load the voxel data *)
      val voxIter = let
            fun addIndex (i, (offsets, id)) = (i - (s - 1) :: offsets, i::id)
            fun mkLeaf (offsets, id) = {
                    offsets = offsets,
                    vox = DstV.new(String.concat("v" :: List.map Int.toString id))
                  }
            in
              IT.create (dim-1, 2*s, fn _ => (), addIndex, mkLeaf, ([], []))
            end
      val loadCode = let
            fun genCode ({offsets, vox}, code) = let
                (* for each offset, compute index_i = offset + n[i]; returns the
                 * index variables together with the statements that define them.
                 *)
                  fun computeIndices (_, []) = ([], [])
                    | computeIndices (i, offset::offsets) = let
                        val index = newVar_dim("i", i)
                        val t1 = DstV.new "t1"
                        val t2 = DstV.new "t2"
                        val (indices, code) = computeIndices (i+1, offsets)
                      (* n is produced by `TruncToInt dim`, so the component is
                       * selected from a dim-element vector.
                       *)
                        val code =
                              intLit(t1, offset) ::
                              assign(t2, DstOp.Select(dim, i), [n]) ::
                              assign(index, DstOp.Add(DstOp.IntTy), [t1, t2]) ::
                              code
                        in
                          (index::indices, code)
                        end
                  val (indices, indicesCode) = computeIndices (0, ~(s-1) :: offsets)
                  val a = DstV.new "a"
                  in
                    indicesCode @ [
                        assign(a, DstOp.VoxelAddress v, indices),
                        assign(vox, DstOp.LoadVoxels(ty, 2*s), [a])
                      ] @ code
                  end
            in
              IT.foldr genCode [] voxIter
            end
    (* generate code to evaluate and construct the result tensor *)
      val probeElem = probeElem {dim = dim, h = h, s = s, n = n, f = f, voxIter = voxIter}
      fun genProbe (result, IT.ND(_, kids as (IT.LF _)::_), code) = let
          (* the kids will all be leaves *)
            fun genProbeCode (IT.LF arg, code) = probeElem arg @ code
            fun getProbeVar (IT.LF(t, _)) = t
          (* the element variables must be computed before they are collected
           * into the result, so the cons follows the code for the kids.
           *)
            val code = cons (result, List.map getProbeVar kids) :: code
            in
              List.foldr genProbeCode code kids
            end
        | genProbe (result, IT.ND(_, kids), code) = let
            val tmps = List.tabulate(dim, fn i => DstV.new("t"^Int.toString i))
            fun lp ([], [], code) = code
              | lp (t::ts, kid::kids, code) = genProbe(t, kid, lp(ts, kids, code))
            val code = cons(result, tmps) :: code
            in
              lp (tmps, kids, code)
            end
        | genProbe (result, IT.LF(_, pdOp), code) = (* for scalar fields *)
            probeElem (result, pdOp) @ code
      val probeCode = genProbe (result, diffIter, [])
      in
        transformCode @ loadCode @ probeCode
      end
186 :    
(* Expand a probe of the field `fld` at position `pos` into the list of
 * statements, in execution order, that compute the probe into `result`.
 *
 * Only convolution fields (FieldDef.CONV) are supported at present; a
 * FieldDef.SUM field raises Fail.
 *)
fun expand (result, fld, pos) = let
    (* `probe` applies `DstOp.Transform v` to the position it is given and
     * returns its statements in execution order, so it is passed the
     * untransformed position and its output is returned unchanged.
     *)
      fun expand' (result, FieldDef.CONV(k, v, h)) = probe (result, (k, v, h), pos)
(* should push negation down to probe operation
        | expand' (result, FieldDef.NEG fld) = let
            val r = DstV.new "value"
            val stms = expand' (r, fld)
            val ty = ??
            in
              expand' (r, fld) @ [assign(r, DstOp.Neg ty, [r])]
            end
*)
        | expand' (_, FieldDef.SUM _) = raise Fail "expandInside: SUM"
      in
        expand' (result, fld)
      end
207 :    
208 :     end

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0