Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] Annotation of /branches/vis12/src/compiler/mid-to-low/mid-to-low.sml
ViewVC logotype

Annotation of /branches/vis12/src/compiler/mid-to-low/mid-to-low.sml

Parent Directory Parent Directory | Revision Log Revision Log


Revision 2822 - (view) (download)

1 : lamonts 345 (* mid-to-low.sml
2 :     *
3 : jhr 435 * COPYRIGHT (c) 2010 The Diderot Project (http://diderot-language.cs.uchicago.edu)
4 : lamonts 345 * All rights reserved.
5 :     *
6 :     * Translation from MidIL to LowIL representations.
7 :     *)
8 :    
9 :     structure MidToLow : sig
10 :    
11 : jhr 459 val translate : MidIL.program -> LowIL.program
12 : lamonts 345
13 : jhr 387 end = struct
14 : lamonts 345
15 :     structure SrcIL = MidIL
16 :     structure SrcOp = MidOps
17 : jhr 2792 structure SrcGV = SrcIL.GlobalVar
18 : jhr 1640 structure SrcSV = SrcIL.StateVar
19 :     structure SrcTy = MidILTypes
20 : jhr 387 structure VTbl = SrcIL.Var.Tbl
21 : lamonts 345 structure DstIL = LowIL
22 : jhr 464 structure DstTy = LowILTypes
23 : lamonts 345 structure DstOp = LowOps
24 :    
(* The translation environment for this pass, obtained by applying TranslateEnvFn
 * to the source (MidIL) and destination (LowIL) representations.  Types are
 * carried across unchanged, so the type-conversion function is the identity.
 *)
structure Env = TranslateEnvFn (
  struct
    structure SrcIL = SrcIL
    structure DstIL = DstIL
    val cvtTy = fn ty => ty
  end)
32 :    
(* ratToFloat r
 *
 * Convert the rational number r to a FloatLit.float value.  Zero and integral
 * values are converted directly.  Any other value is first scaled by powers of
 * ten and then expanded digit-by-digit using long division; the expansion is
 * cut off after 12 steps, with the final digit rounded according to the digit
 * that would have followed it.
 *)
fun ratToFloat r = (case Rational.explode r
       of {sign=0, ...} => FloatLit.zero false
        | {sign, num, denom=1} => FloatLit.fromInt(IntInf.fromInt sign * num)
        | {sign, num, denom} => let
          (* scale the denominator up by powers of ten until it is no longer
           * smaller than the numerator; the exponent count starts at one.
           *)
            fun scaleDenom (e, dn) =
                  if (dn < num) then scaleDenom (e+1, dn*10) else (dn, e)
            val (denom, exp) = scaleDenom (1, denom)
          (* scale the numerator up by powers of ten for as long as ten times
           * the numerator is still below the denominator, lowering the
           * exponent by one for each step.
           *)
            fun scaleNum (e, nm) =
                  if (10*nm < denom) then scaleNum (e-1, 10*nm) else (nm, e)
            val (num, exp) = scaleNum (exp, num)
          (* long division: returns the quotient of a by denom together with
           * the list of digits that follow it.  A carry out of the following
           * digits (a "digit" of ten) is folded into the quotient.
           *)
            fun divLp (n, a) = let
                  val (q, remainder) = IntInf.divMod(a, denom)
                  in
                    if (remainder = 0)
                      then (q, [])
                    else if (n < 12)
                      then (case divLp (n+1, 10*remainder)
                         of (d, dd) => if (d < 10)
                              then (q, (IntInf.toInt d)::dd)
                              else (q+1, 0::dd)
                        (* end case *))
                    else if (IntInf.div(10*remainder, denom) < 5)
                      then (q, [])
                      else (q+1, []) (* round up *)
                  end
            val (lead, rest) = divLp (0, num)
            in
              FloatLit.fromDigits{
                  isNeg = (sign < 0),
                  digits = (IntInf.toInt lead)::rest,
                  exp = exp
                }
            end
      (* end case *))
81 : jhr 463
(* Builders for LowIL assignments.  Each returns the pair (lhs, rhs) for the
 * given destination variable:
 *      imul -- integer multiplication of two variables
 *      iadd -- integer addition of two variables
 *      ilit -- integer literal
 *      radd -- real addition of two variables
 *)
fun imul (dst : DstIL.var, x, y) = (dst, DstIL.OP(DstOp.Mul DstTy.intTy, [x, y]))
fun iadd (dst : DstIL.var, x, y) = (dst, DstIL.OP(DstOp.Add DstTy.intTy, [x, y]))
fun ilit (dst : DstIL.var, value) = (dst, DstIL.LIT(Literal.Int(IntInf.fromInt value)))
fun radd (dst : DstIL.var, x, y) = (dst, DstIL.OP(DstOp.Add DstTy.realTy, [x, y]))
86 : jhr 511
(* expandEvalKernel (result, d, h, k, [x])
 *
 * Expand an EvalKernel operation into vector operations.  The parameters are
 *      result  -- the lhs variable to store the result
 *      d       -- the vector width of the operation, which should be equal
 *                 to twice the support of the kernel
 *      h       -- the kernel
 *      k       -- the derivative of the kernel to evaluate
 *      [x]     -- the single argument variable of the operation
 *
 * The generated code computes
 *
 *      result = a_0 + x*(a_1 + x*(a_2 + ... x*a_n) ... )
 *
 * as a d-wide vector operation, where n is the degree of the kth derivative
 * of h and the a_i are coefficient vectors that have one element for each
 * piece of h.  The computation is laid out as
 *
 *      m_n     = x * a_n
 *      s_{n-1} = a_{n-1} + m_n
 *      m_{n-1} = x * s_{n-1}
 *      s_{n-2} = a_{n-2} + m_{n-1}
 *      m_{n-2} = x * s_{n-2}
 *      ...
 *      s_1     = a_1 + m_2
 *      m_1     = x * s_1
 *      result  = a_0 + m_1
 *
 * Note that the coefficient vectors are flipped (cf high-to-low/probe.sml).
 * The result is the list of (lhs, rhs) assignments in execution order: first
 * the definitions of the coefficient vectors, then the polynomial evaluation.
 *)
fun expandEvalKernel (result, d, h, k, [x]) = let
      val {isCont=_, segs} = Kernel.curve (h, k)
    (* degree of the polynomial *)
      val degree = List.length(hd segs) - 1
    (* reversed table of segments, as a vector of vectors for fast access *)
      val segTbl = Vector.fromList (List.rev (List.map Vector.fromList segs))
    (* the literal for the coefficient of the given term in the given segment *)
      fun coeffLit term seg =
            Literal.Float(ratToFloat (Vector.sub (Vector.sub(segTbl, seg), term)))
      val vecTy = DstTy.vecTy d
    (* the variables a0, ..., a<degree> that hold the coefficient vectors *)
      val coeffVars = List.tabulate (degree+1,
            fn i => DstIL.Var.new("a" ^ Int.toString i, vecTy))
    (* code to define the coefficient vectors: for each a_i, the d literal
     * elements followed by the CONS that packs them into the vector.
     *)
      val coeffDefs = let
            fun define (a, (term, rest)) = let
                  val lits = List.tabulate(d, coeffLit term)
                  val elems = List.tabulate(d, fn _ => DstIL.Var.new("_f", DstTy.realTy))
                  val litDefs =
                        ListPair.map (fn (v, lit) => (v, DstIL.LIT lit)) (elems, lits)
                  in
                    (term-1, litDefs @ ((a, DstIL.CONS(DstIL.Var.ty a, elems)) :: rest))
                  end
            val (_, defs) = List.foldr define (degree, []) coeffVars
            in
              defs
            end
    (* temporaries for the products and sums of the evaluation *)
      fun prodVar i = DstIL.Var.new("prod" ^ Int.toString i, vecTy)
      fun sumVar i = DstIL.Var.new("sum" ^ Int.toString i, vecTy)
    (* build the evaluation of the polynomial in reverse order; the result is
     * the variable holding the outermost product and the statements so far.
     *)
      fun horner (i, [a]) = let
            val m = prodVar i
            in
              (m, [(m, DstIL.OP(DstOp.Mul vecTy, [x, a]))])
            end
        | horner (i, a::rest) = let
            val (m, stms) = horner (i+1, rest)
            val s = sumVar i
            val m' = prodVar i
            in
              (m',
                (m', DstIL.OP(DstOp.Mul vecTy, [x, s])) ::
                (s, DstIL.OP(DstOp.Add vecTy, [a, m])) ::
                stms)
            end
      val evalCode = (case coeffVars
             of [a0] => (* constant function *)
                  [(result, DstIL.VAR a0)]
              | a0::rest => let
                  val (m, stms) = horner (1, rest)
                  in
                    List.rev ((result, DstIL.OP(DstOp.Add vecTy, [a0, m])) :: stms)
                  end
            (* end case *))
      in
        coeffDefs @ evalCode
      end
174 : jhr 387
(* FIXME: we will get better down-stream CSE if we structure the address computation
 * as
 *      (base + stride * (...)) + offset
 * since the lhs argument will be the same for each sample.
 *)
(* adjustForStrideAndOffset (stride, offset, ix, code)
 *
 * Add the code that accounts for the stride and offset when addressing
 * non-scalar image data.  The code list is in reverse order (most recent
 * statement first); the result is the variable holding the adjusted index
 * paired with the extended code list.
 *      stride = 1      -- ix is returned unchanged and no code is added
 *      offset = 0      -- the index becomes stride * ix
 *      otherwise       -- the index becomes offset + stride * ix
 *)
fun adjustForStrideAndOffset (1, _, ix, code) = (ix, code)
  | adjustForStrideAndOffset (stride, offset, ix, code) = let
      val offp = DstIL.Var.new ("offp", DstTy.intTy)
      val strideVar = DstIL.Var.new ("stride", DstTy.intTy)
      val strideDef = ilit (strideVar, stride)
      in
        if (offset = 0)
          then (offp, imul (offp, strideVar, ix) :: strideDef :: code)
          else let
            val offsetVar = DstIL.Var.new ("offset", DstTy.intTy)
            val scaled = DstIL.Var.new ("t", DstTy.intTy)
            in
              (offp,
                iadd (offp, offsetVar, scaled) ::
                ilit (offsetVar, offset) ::
                imul (scaled, strideVar, ix) ::
                strideDef ::
                code)
            end
      end
202 : jhr 1116
(* expandVoxelAddress (result, info, offset, img::indices)
 *
 * Compute the load address for a given set of voxel indices.  For the
 * operation
 *
 *      VoxelAddress<info,offset>(i_1, ..., i_d)
 *
 * the address is given by
 *
 *      base + offset + stride * (i_1 + N_1 * (i_2 + N_2 * (... + N_{d-1} * i_d) ...))
 *
 * where
 *      base    -- base address of the image data
 *      stride  -- number of samples per voxel
 *      offset  -- offset of sample being addressed
 *      N_i     -- size of ith axis in elements
 *
 * Note that we are following the Nrrd convention that the axes are ordered
 * in fastest to slowest order.  We are also assuming the C semantics of address
 * arithmetic, where the offset will be automatically scaled by the size of the
 * elements.
 *
 * The arguments are the image variable followed by one index variable per
 * axis; the result is the list of (lhs, rhs) assignments in execution order,
 * the last of which defines `result`.  Raises Fail if there is no index
 * argument, if the number of indices does not agree with the axis sizes, or
 * if the axis sizes are not available for a multi-dimensional image.
 *)
fun expandVoxelAddress (result, info, offset, img::ix1::indices) = let
      val stride = ImageInfo.stride info
      val addrTy = DstTy.AddrTy info
    (* given the variable holding the linear voxel index and the code that
     * computes it (in reverse order), add the stride/offset adjustment and
     * the base-address addition, and return the code in execution order.
     *)
      fun finish (ix, code) = let
            val (offp, code) = adjustForStrideAndOffset (stride, offset, ix, code)
            val base = DstIL.Var.new ("imgBaseAddr", addrTy)
            in
              List.rev (
                (result, DstIL.OP(DstOp.Add addrTy, [base, offp])) ::
                (base, DstIL.OP(DstOp.ImageAddress info, [img])) ::
                code)
            end
      in
        case indices
         of [] => (* 1D image: the single index is the linear index *)
              finish (ix1, [])
          | _ => let
            (* get N_1 ... N_{d-1} *)
(* FIXME: sizes is [] when the image does not have a proxy *)
              val sizes = (case ImageInfo.sizes info
                     of [] => raise Fail "expandVoxelAddress: image axis sizes are not available"
                      | szs => List.take (szs, List.length szs - 1)
                    (* end case *))
            (* generate the address computation code in reverse order; the
             * result variable holds N_j * (i_{j+1} + N_{j+1} * (...)).
             *)
              fun gen (d, [n], [ix]) = let
                    val n' = DstIL.Var.new ("n" ^ Int.toString d, DstTy.intTy)
                    val t = DstIL.Var.new ("t", DstTy.intTy)
                    in
                      (t, [imul(t, n', ix), ilit(n', n)])
                    end
                | gen (d, n::ns, ix::ixs) = let
                    val n' = DstIL.Var.new ("n" ^ Int.toString d, DstTy.intTy)
                    val t1 = DstIL.Var.new ("t1", DstTy.intTy)
                    val t2 = DstIL.Var.new ("t2", DstTy.intTy)
                    val (t, code) = gen (d+1, ns, ixs)
                    in
                      (t2, imul(t2, n', t1) :: ilit(n', n) :: iadd(t1, ix, t) :: code)
                    end
                | gen _ = raise Fail
                    "expandVoxelAddress: number of indices does not match image dimension"
              val (tmp, code) = gen (0, sizes, indices)
              val t = DstIL.Var.new ("index", DstTy.intTy)
              in
                finish (t, iadd(t, ix1, tmp) :: code)
              end
        (* end case *)
      end
  | expandVoxelAddress _ =
      raise Fail "expandVoxelAddress: expected an image and at least one index"
280 : lamonts 345
(* expandTrace (y, d, [m])
 *
 * Expand trace(M) for the d x d matrix held in m into scalar code that sums
 * the diagonal elements and stores the result in y.  Each diagonal element is
 * fetched by a Subscript with a literal index; the partial sums are chained
 * so that the sum for index i adds m[i,i] to the sum for the indices above
 * it.  The result is the list of (lhs, rhs) assignments in execution order.
 * Raises Fail if the operation is not given exactly one argument.
 *)
fun expandTrace (y, d, [m]) = let
      val matTy = DstTy.TensorTy[d,d]
    (* the code (in reverse order) that defines ix as the literal i and then
     * assigns m[ix,ix] to lhs.
     *)
      fun diag (lhs, ix, i) = [
              (lhs, DstIL.OP(DstOp.Subscript(matTy), [m, ix, ix])),
              ilit(ix, i)
            ]
    (* generate, in reverse order, the code that stores the sum of the
     * diagonal elements i .. d-1 in dst.
     *)
      fun sumFrom (i, dst) = let
            val i' = Int.toString i
            val ix = DstIL.Var.new ("ix" ^ i', DstTy.intTy)
            in
              if (i < d-1)
                then let
                  val x = DstIL.Var.new ("x" ^ i', DstTy.realTy)
                  val acc = DstIL.Var.new ("acc" ^ i', DstTy.realTy)
                  val stms = sumFrom (i+1, acc)
                  in
                    radd(dst, acc, x) :: (diag (x, ix, i) @ stms)
                  end
                else diag (dst, ix, i)
            end
      in
        List.rev (sumFrom (0, y))
      end
  | expandTrace _ = raise Fail "expandTrace: expected exactly one argument"
307 : jhr 1370
(* expandOp (env, y, rator, args)
 *
 * Translate the application of the MidIL operator rator to args, with result
 * variable y (already a LowIL variable), into a list of LowIL (lhs, rhs)
 * assignments.  The arguments are renamed through env.  Trace, VoxelAddress,
 * and EvalKernel are expanded into multiple assignments; every other operator
 * handled here maps to the LowIL operator of the same name, applied to the
 * renamed arguments.  Any other operator raises Fail.
 *)
fun expandOp (env, y, rator, args) = let
      val xs = Env.renameList (env, args)
    (* a single assignment of the given LowIL operator applied to xs *)
      fun emit rator' = [(y, DstIL.OP(rator', xs))]
      in
        case rator
        (* operators that are expanded into multiple assignments *)
         of SrcOp.Trace d => expandTrace (y, d, xs)
          | SrcOp.VoxelAddress(info, offset) => expandVoxelAddress (y, info, offset, xs)
          | SrcOp.EvalKernel(d, h, k) => expandEvalKernel (y, d, h, k, xs)
        (* arithmetic *)
          | SrcOp.Add ty => emit (DstOp.Add ty)
          | SrcOp.Sub ty => emit (DstOp.Sub ty)
          | SrcOp.Mul ty => emit (DstOp.Mul ty)
          | SrcOp.Div ty => emit (DstOp.Div ty)
          | SrcOp.Neg ty => emit (DstOp.Neg ty)
          | SrcOp.Abs ty => emit (DstOp.Abs ty)
        (* comparisons and logic *)
          | SrcOp.LT ty => emit (DstOp.LT ty)
          | SrcOp.LTE ty => emit (DstOp.LTE ty)
          | SrcOp.EQ ty => emit (DstOp.EQ ty)
          | SrcOp.NEQ ty => emit (DstOp.NEQ ty)
          | SrcOp.GT ty => emit (DstOp.GT ty)
          | SrcOp.GTE ty => emit (DstOp.GTE ty)
          | SrcOp.Not => emit DstOp.Not
          | SrcOp.Max => emit DstOp.Max
          | SrcOp.Min => emit DstOp.Min
          | SrcOp.Clamp ty => emit (DstOp.Clamp ty)
          | SrcOp.Lerp ty => emit (DstOp.Lerp ty)
        (* vector, matrix, and tensor operations *)
          | SrcOp.Dot d => emit (DstOp.Dot d)
          | SrcOp.MulVecMat dims => emit (DstOp.MulVecMat dims)
          | SrcOp.MulMatVec dims => emit (DstOp.MulMatVec dims)
          | SrcOp.MulMatMat dims => emit (DstOp.MulMatMat dims)
          | SrcOp.MulVecTen3 dims => emit (DstOp.MulVecTen3 dims)
          | SrcOp.MulTen3Vec dims => emit (DstOp.MulTen3Vec dims)
          | SrcOp.ColonMul tys => emit (DstOp.ColonMul tys)
          | SrcOp.Cross => emit DstOp.Cross
          | SrcOp.Norm ty => emit (DstOp.Norm ty)
          | SrcOp.Normalize d => emit (DstOp.Normalize d)
          | SrcOp.Scale ty => emit (DstOp.Scale ty)
          | SrcOp.Zero ty => emit (DstOp.Zero ty)
          | SrcOp.PrincipleEvec ty => emit (DstOp.PrincipleEvec ty)
          | SrcOp.EigenVals2x2 => emit DstOp.EigenVals2x2
          | SrcOp.EigenVals3x3 => emit DstOp.EigenVals3x3
          | SrcOp.Identity n => emit (DstOp.Identity n)
          | SrcOp.Transpose dims => emit (DstOp.Transpose dims)
          | SrcOp.Slice arg => emit (DstOp.Slice arg)
        (* tuples and sequences; Select is only accepted on tuple types *)
          | SrcOp.Select(ty as SrcTy.TupleTy _, i) => emit (DstOp.Select(ty, i))
          | SrcOp.Index arg => emit (DstOp.Index arg)
          | SrcOp.Subscript ty => emit (DstOp.Subscript ty)
          | SrcOp.MkDynamic arg => emit (DstOp.MkDynamic arg)
          | SrcOp.Append ty => emit (DstOp.Append ty)
          | SrcOp.Prepend ty => emit (DstOp.Prepend ty)
          | SrcOp.Concat ty => emit (DstOp.Concat ty)
          | SrcOp.Length ty => emit (DstOp.Length ty)
        (* rounding and conversions *)
          | SrcOp.Ceiling d => emit (DstOp.Ceiling d)
          | SrcOp.Floor d => emit (DstOp.Floor d)
          | SrcOp.Round d => emit (DstOp.Round d)
          | SrcOp.Trunc d => emit (DstOp.Trunc d)
          | SrcOp.IntToReal => emit DstOp.IntToReal
          | SrcOp.RealToInt d => emit (DstOp.RealToInt d)
        (* images, loading, and inputs *)
          | SrcOp.LoadVoxels arg => emit (DstOp.LoadVoxels arg)
          | SrcOp.PosToImgSpace info => emit (DstOp.PosToImgSpace info)
          | SrcOp.TensorToWorldSpace arg => emit (DstOp.TensorToWorldSpace arg)
          | SrcOp.Inside info => emit (DstOp.Inside info)
          | SrcOp.LoadSeq arg => emit (DstOp.LoadSeq arg)
          | SrcOp.LoadImage arg => emit (DstOp.LoadImage arg)
          | SrcOp.Input inp => emit (DstOp.Input inp)
          | SrcOp.InputWithDefault inp => emit (DstOp.InputWithDefault inp)
          | rator => raise Fail("bogus operator " ^ SrcOp.toString rator)
        (* end case *)
      end
376 : jhr 431
(* expand (env, (y, rhs))
 *
 * Expand the SrcIL assignment of rhs to y into a list of DstIL assignments.
 * The lhs variable and any variables mentioned in rhs are renamed through
 * env.  Operator applications are translated by expandOp and may produce
 * several assignments; every other form of rhs produces exactly one.
 *)
fun expand (env, (y, rhs)) = let
      val lhs = Env.rename (env, y)
      fun single rhs' = [DstIL.ASSGN(lhs, rhs')]
      fun renameArgs xs = Env.renameList (env, xs)
      in
        case rhs
         of SrcIL.OP(rator, args) =>
              List.map DstIL.ASSGN (expandOp (env, lhs, rator, args))
          | SrcIL.GLOBAL gv => single (DstIL.GLOBAL(Env.renameGV(env, gv)))
          | SrcIL.STATE sv => single (DstIL.STATE(Env.renameSV(env, sv)))
          | SrcIL.VAR x => single (DstIL.VAR(Env.rename(env, x)))
          | SrcIL.LIT lit => single (DstIL.LIT lit)
          | SrcIL.APPLY(f, args) => single (DstIL.APPLY(f, renameArgs args))
          | SrcIL.CONS(ty, args) => single (DstIL.CONS(ty, renameArgs args))
        (* end case *)
      end
392 : lamonts 345
(* mexpand (env, (ys, rator, xs))
 *
 * Expand a SrcIL multi-assignment into a DstIL CFG that consists of a single
 * MASSIGN node.  The result and argument variables are renamed through env.
 * The supported operators are EigenVecs2x2, EigenVecs3x3, and Print; any
 * other operator raises Fail.
 *)
fun mexpand (env, (ys, rator, xs)) = let
      val lhs = Env.renameList(env, ys)
      val dstRator = (case rator
             of SrcOp.EigenVecs2x2 => DstOp.EigenVecs2x2
              | SrcOp.EigenVecs3x3 => DstOp.EigenVecs3x3
              | SrcOp.Print tys => DstOp.Print tys
              | _ => raise Fail("bogus operator " ^ SrcOp.toString rator)
            (* end case *))
      val node = DstIL.Node.mkMASSIGN(lhs, dstRator, Env.renameList(env, xs))
      in
        DstIL.CFG{entry=node, exit=node}
      end
407 :    
(* The MidIL to LowIL translator, obtained by applying TranslateFn to the
 * translation environment together with the two expansion functions defined
 * above.  Single assignments are expanded to a list of assignments, which is
 * then wrapped as a block by DstIL.CFG.mkBlock.
 *)
structure Trans = TranslateFn (
  struct
    open Env
    val expand = fn assignment => DstIL.CFG.mkBlock (expand assignment)
    val mexpand = mexpand
  end)
414 :    
(* translate prog
 *
 * Translate a MidIL program to LowIL, and then run LowILCensus.init on the
 * translated program before returning it.
 *)
fun translate prog = let
      val lowProg = Trans.translate prog
      val _ = LowILCensus.init lowProg
      in
        lowProg
      end
421 : jhr 387
422 : jhr 435 end

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0