Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] Annotation of /branches/vis15/src/compiler/high-to-mid/probe-ein.sml
ViewVC logotype

Annotation of /branches/vis15/src/compiler/high-to-mid/probe-ein.sml

Parent Directory Parent Directory | Revision Log Revision Log


Revision 5314 - (view) (download)

1 : jhr 3550 (* probe-ein.sml
2 :     *
3 :     * This code is part of the Diderot Project (http://diderot-language.cs.uchicago.edu)
4 :     *
5 :     * COPYRIGHT (c) 2016 The University of Chicago
6 :     * All rights reserved.
7 :     *)
8 :    
9 : jhr 3787 (* This file expands probed fields
10 :     * Take a look at ProbeEin tex file for examples
11 :     * Note that the original field is an EIN operator in the form <V_alpha * H^(deltas)>(midIR.var list )
12 :     * Param_ids are used to note the placement of the argument in the midIR.var list
13 :     * Index_ids keep track of the shape of an Image or differentiation.
14 :     * Mu bind Index_id
15 :     * Generally, we will refer to the following
16 :     * dim:dimension of field V
17 :     * s: support of kernel H
18 :     * alpha: The alpha in <V_alpha * H^(deltas)>
19 :     * dx: The dx in <V_alpha * nabla_dx H>
20 :     * deltas: The deltas in <V_alpha * h^(deltas) h^(deltas)>
21 :     * Vid:param_id for V
22 :     * hid:param_id for H
23 :     * nid: integer position param_id
24 :     * fid: fractional position param_id
25 :     * img: image info about V
26 :     *)
27 :    
structure ProbeEin : sig

    val expand : AvailRHS.t -> MidIR.var * MidIR.rhs -> unit

  end = struct

    structure IR = MidIR
    structure Op = MidOps
    structure V = IR.Var
    structure Ty = MidTypes
    structure E = Ein
    structure T = CoordSpaceTransform

  (* build an Ein operator from its parameters, index (result shape), and body *)
    fun mkEin (params, index, body) = Ein.EIN{params = params, index = index, body = body}

  (* increment the use count of a mid-IR variable *)
    fun incUse (IR.V{useCnt, ...}) = (useCnt := !useCnt + 1)

  (* return the operator and arguments of the OP that defines x; raises Fail when x
   * is defined by any other kind of right-hand side.
   *)
    fun getRHSDst x = (case IR.Var.getDef x
           of IR.OP(rator, args) => (rator, args)
            | rhs => raise Fail(concat[
                  "expected rhs operator for ", IR.Var.toString x,
                  " but found ", IR.RHS.toString rhs
                ])
          (* end case *))

  (* suffix naming an axis (0 => "X", 1 => "Y", 2 => "Z", otherwise "dir<n>");
   * used only to build readable names for generated variables.
   *)
    fun axis dir = (case dir of 0 => "X" | 1 => "Y" | 2 => "Z" | _ => "dir" ^ Int.toString dir)

  (* return imgArg when it is defined by a LoadImage operator; raises Fail otherwise *)
    fun checkImg imgArg = (case IR.Var.getDef imgArg
           of IR.OP(Op.LoadImage _, _) => imgArg
            | rhs => raise Fail (String.concat[
                  "expected image for ", IR.Var.toString imgArg,
                  " but found ", IR.RHS.toString rhs
                ])
          (* end case *))

  (* get the image referenced on a RHS, its image info, and its border control (if any).
   * Returns (image var, image info, border control option).  A default border control
   * is rejected with Fail, as is any definition that is not an image.
   *)
    fun getImagInfo x = (case V.getDef x
           of IR.GLOBAL gv => let
              (* the global must have an image type; otherwise this binding raises Bind *)
                val Ty.ImageTy info = IR.GlobalVar.ty gv
                in
                  (x, info, NONE)
                end
            | IR.OP(Op.BorderCtlDefault _, [_, _]) =>
                raise Fail "Default border control is not supported"
            | IR.OP(Op.BorderCtlClamp info, [img]) => (img, info, SOME IndexCtl.Clamp)
            | IR.OP(Op.BorderCtlMirror info, [img]) => (img, info, SOME IndexCtl.Mirror)
            | IR.OP(Op.BorderCtlWrap info, [img]) => (img, info, SOME IndexCtl.Wrap)
            | IR.OP(Op.LoadImage(Ty.ImageTy info, _), _) => (x, info, NONE)
            | rhs => raise Fail (String.concat[
                  "expected image for ", V.toString x, " but found ", IR.RHS.toString rhs
                ])
          (* end case *))

  (* return the kernel carried by the Kernel operator that defines hArg; raises Fail
   * when hArg is defined by anything else.
   *)
    fun getKernelDst hArg = (case IR.Var.getDef hArg
           of IR.OP(Op.Kernel(h, _), _) => h
            | rhs => raise Fail (String.concat[
                  "expected kernel for ", IR.Var.toString hArg,
                  " but found ", IR.RHS.toString rhs
                ])
          (* end case *))

  (* handleArgs - uses the param_ids for the image (Vid), kernel (hid), and position
   * tensor (tid) to select the corresponding mid-IR vars from args, and transforms
   * the position to index space with CoordSpaceTransform.worldToIndex.
   * Returns (vI, vH, vN, vF, vP, info, border, dim), where
   *   vI     -- the image var (after looking through border control)
   *   vH     -- the kernel var
   *   vN, vF -- the integer and fractional index-space positions
   *   vP     -- the mid-IR var for the (transformation matrix) transpose
   *   info   -- the image info
   *   border -- optional border control
   *   dim    -- the image dimension
   *)
    fun handleArgs (avail, Vid, hid, tid, args) = let
          val (vI, info, border) = getImagInfo (List.nth (args, Vid))
          val vH = List.nth (args, hid)
          val (vN, vF, vP) = T.worldToIndex{
                  avail = avail, info = info, img = vI, pos = List.nth (args, tid)
                }
          in
            (vI, vH, vN, vF, vP, info, border, ImageInfo.dim info)
          end

  (* liftKrn (avail, dir, dx, dim, h, vF, s)
   * lifts the kernel evaluations for one axis into assignments.
   *   dir -- the axis
   *   dx  -- the ein index_ids that represent differentiation; the kernel and its
   *          first (length dx) derivatives are evaluated
   *   dim -- the image dimension
   *   h   -- the kernel
   *   vF  -- the fractional position
   *   s   -- half the support of the kernel (each evaluation yields 2*s values)
   * Returns the single EvalKernel var when dx is empty, otherwise a var bound to the
   * CONS of all evaluations (type TensorTy[length dx + 1, 2*s]).
   *)
    fun liftKrn (avail, dir, dx, dim, h, vF, s) = let
          val axis = axis dir
          val range = 2*s
        (* build position vector for EvalKernel *)
          val vX = if (dim = 1)
                then vF (* position is a real type *)
                else AvailRHS.addAssign (
                  avail, concat["v", axis, "_"],
                  Ty.realTy, IR.OP(Op.TensorIndex(Ty.TensorTy[dim], [dir]), [vF]))
          val vPos = AvailRHS.addAssign (
                avail, concat["kern", axis, "_"],
                Ty.TensorTy[range], IR.OP(Op.BuildPos s, [vX]))
          val nKernEvals = List.length dx + 1
        (* k-th derivative of the kernel at the built position *)
          fun mkEval k = AvailRHS.addAssign (
                avail, concat["keval", axis, "_d", Int.toString k, "_"],
                Ty.TensorTy[range], IR.OP(Op.EvalKernel(range, h, k), [vPos]))
        (* List.tabulate applies mkEval to 0, 1, ... in order, so assignments are
         * added in increasing derivative order.
         *)
          val vKs = List.tabulate (nKernEvals, mkEval)
          in
            case vKs
             of [v] => v (* scalar result *)
              | _ => let
                  val consTy = Ty.TensorTy[nKernEvals, range]
                  in
                    AvailRHS.addAssign (
                      avail, concat["kcons", axis, "_"],
                      consTy, IR.CONS(vKs, consTy))
                  end
            (* end case *)
          end

  (* `mkLdVoxel (avail, vI, vN, info, alpha, shape, dim, s, border)`
   * creates load voxel operator to represent image addressing. The parameters are
   *   avail -- available assignments
   *   vI    -- image argument
   *   vN    -- the integer indices into the image (IntTy for 1D, SeqTy for 2D+)
   *   info  -- image info
   *   alpha -- ein variable indices that represent the shape of a tensor field
   *            (not used here)
   *   shape -- binding of alpha
   *   dim   -- the dimension of the image
   *   s     -- half the support of the reconstruction kernel
   *   border -- optional border control
   * The result has type TensorTy(shape @ [2*s, ..., 2*s]) with dim copies of 2*s.
   *)
    fun mkLdVoxel (avail, vI, vN, info, alpha, shape, dim, s, border) = let
        (* lower-bound offset (1 - s) that is added to each integer index *)
          val vLb = AvailRHS.addAssign (avail, "lit", Ty.intTy,
                IR.LIT(Literal.Int(IntInf.fromInt(1 - s))))
        (* the index argument to LoadVoxels; this is a single integer for 1D images *)
          val idxs = if (dim = 1)
                then AvailRHS.addAssign (avail, "idx", Ty.intTy, IR.OP(Op.IAdd, [vN, vLb]))
                else let
                  val seqTy = Ty.SeqTy(Ty.intTy, SOME dim)
                (* create sequence n_0+lb .. n_{dim-1}+lb *)
                  fun f i = let
                        val vA = AvailRHS.addAssign (
                              avail, "idx", Ty.intTy, IR.LIT(Literal.Int (IntInf.fromInt i)))
                        val vB = AvailRHS.addAssign (
                              avail, concat["n", axis i, "_"],
                              Ty.intTy, IR.OP(Op.Subscript seqTy, [vN, vA]))
                        in
                          AvailRHS.addAssign (avail, "idx", Ty.intTy, IR.OP(Op.IAdd, [vB, vLb]))
                        end
                  val vNs = List.tabulate (dim, f)
                  in
                    AvailRHS.addAssign (avail, "seq", seqTy, IR.SEQ(vNs, seqTy))
                  end
        (* image positions: each axis contributes 2*s voxels *)
          val s' = 2*s
          val supportshape = List.tabulate (dim, fn _ => s')
          val ldty = Ty.TensorTy(shape @ supportshape)
          val op1 = (case border
                 of NONE => Op.LoadVoxels (info, s')
                  | SOME b => Op.LoadVoxelsWithCtl (info, s', b)
                (* end case *))
          in
            AvailRHS.addAssign (avail, "ldvox", ldty, IR.OP(op1, [vI, idxs]))
          end

  (* fieldReconstruction expands the body for the probed field.
   * Returns (vLd :: vKs, vP, exp) where
   *   vLd -- the load-voxel var for the image
   *   vKs -- the lifted kernel evaluations, one per axis in increasing axis order
   *   vP  -- the transformation-matrix var from handleArgs
   *   exp -- Sum over dim fresh indices (starting at sx, each ranging 1-s .. s) of
   *          the product of the image term and one kernel term per axis
   * The image term uses param id Vidnew and the kernel for axis d (1-based) uses
   * param id kid+d.
   *)
    fun fieldReconstruction (avail, sx, alpha, shape, dx, Vid, Vidnew, kid, hid, tid, args) = let
          val (vI, vH, vN, vF, vP, info, border, dim) = handleArgs (avail, Vid, hid, tid, args)
          val h = getKernelDst vH
          val s = Kernel.support h
        (* creating summation indices *)
          val vs = List.tabulate (dim, fn i => i + sx)
          val esum = List.map (fn i => (i, 1-s, s)) vs
        (* represent image in ein expression with tensor *)
          val imgexp = E.Img(Vidnew, alpha, List.map E.Value vs, s)
        (* create load voxel operator for image *)
          val vLd = mkLdVoxel (avail, vI, vN, info, alpha, shape, dim, s, border)
(* DEBUG
          fun f es = String.concatWithMap "," Int.toString es
          fun g es = String.concatWithMap ","
                (fn (E.V e) => "v"^Int.toString e | E.C(c) => "c"^Int.toString c) es
          val Ty.TensorTy cat = V.ty vLd
          val _ = print(String.concat[
                  "\n","after load voxel ", f cat, " = ", V.name(vLd),
                  " alpha = ", g alpha, " dim:", f[dim]," support: ", f[s]
                ])
*)
        (* create kernel body: counts dir down from dim to 1, prepending one kernel
         * term and one lifted kernel evaluation per axis, so both result lists end
         * up in increasing axis order.
         *)
          fun createKrn (0, krnexp, vAs) = (krnexp, vAs)
            | createKrn (dir, krnexp, vAs) = let
                val dir' = dir-1
              (* ein expression *)
                val deltas = List.map (fn e => (E.C dir', e)) dx
                val kexp0 = E.Krn(kid+dir, deltas, dir)
              (* evalkernel operators *)
                val vA = liftKrn (avail, dir', dx, dim, h, vF, s)
                in
                  createKrn (dir', kexp0::krnexp, vA::vAs)
                end
        (* final ein expression body to represent field reconstruction *)
          val (krnexp, vKs) = createKrn (dim, [], [])
          val exp = E.Sum(esum, E.Opn(E.Prod, imgexp::krnexp))
          in
            (vLd::vKs, vP, exp)
          end

  (* getsumshift : sum_indexid list * int -> int
   * returns n when sx is empty; otherwise one more than the index id bound by the
   * last summation in sx.  Used to pick a fresh (unused) index_id.
   *)
    fun getsumshift ([], n) = n
      | getsumshift (sx, _) = let
          val (v, _, _) = List.last sx
          in
            v + 1
          end

  (* formBody : ein_exp -> ein_exp
   * drops empty summations and unwraps single-factor products, at the top level and
   * beneath summations.
   *)
    fun formBody (E.Sum([], e)) = formBody e
      | formBody (E.Sum(sx, e)) = E.Sum(sx, formBody e)
      | formBody (E.Opn(E.Prod, [e])) = e
      | formBody e = e

  (* silly change in order of the product to match vis branch WorldtoSpace functions:
   * with three or four transformation terms the body goes last, otherwise first.
   *)
    fun multiPs (Ps, sx, body) = let
          val exp = (case Ps
                 of [P0, P1, P2] => [P0, P1, P2, body]
                  | [P0, P1, P2, P3] => [P0, P1, P2, P3, body]
                  | _ => body::Ps
                (* end case *))
          in
            formBody (E.Sum(sx, E.Opn(E.Prod, exp)))
          end

  (* arrangeBody - function changes the ordering of multiplication
   * to match vis12 branch and pass regression tests.
   * The boolean result is true when the probe is the whole body (possibly under a
   * summation) and false when it is the second factor of a two-factor product under
   * a summation.  Any other body shape raises Fail.
   *)
    fun arrangeBody (body, Ps, newsx, exp) = (case body
           of E.Sum(sx, E.Probe _) => (true, multiPs(Ps, sx@newsx, exp))
            | E.Sum(sx, E.Opn(E.Prod, [eps0, E.Probe _])) =>
                (false, E.Sum(sx, E.Opn(E.Prod, [eps0, multiPs(Ps, newsx, exp)])))
            | E.Probe _ => (true, multiPs(Ps, newsx, exp))
            | _ => raise Fail "impossible"
          (* end case *))

  (* replaceProbe : ein_exp * params * midIR.var list * int list * sum_id list
   *    -> ein_exp * code
   * Transforms position to world space, transforms result back to index_space,
   * rewrites body, and replaces the probe with the expanded version.  The rewritten
   * assignment is added to avail.
   *)
    fun replaceProbe (avail, (y, IR.EINAPP(Ein.EIN{params, index, body}, args)), probe, sx) = let
        (* ids for the parameters appended after the existing ones:
         *   pid            -- transformation matrix P
         *   Vidnew = pid+1 -- image
         *   kid+1 .. kid+dim (with kid = Vidnew) -- one kernel per axis
         *)
          val pid = length params
          val Vidnew = pid+1
          val kid = Vidnew
          val E.Probe(E.Conv(Vid, alpha, hid, dx), E.Tensor(tid, _)) = probe
          val E.IMG(dim, shape) = List.nth(params, Vid)
          val freshIndex = getsumshift (sx, length index)
          val (dx', _, sx', Ps) = T.imageToWorld (freshIndex, dim, dx, pid)
          val sxn = freshIndex + length dx' (* next available index id *)
          val (args', vP, probe') = fieldReconstruction (
                avail, sxn, alpha, shape, dx', Vid, Vidnew, kid, hid, tid, args)
        (* add new params: transformation matrix (pid), image, and kernel ids *)
          val pP = E.TEN(true, [dim, dim])
          val pV = List.nth(params, Vid)
          val pK = List.tabulate(dim, fn _ => E.KRN)
          val params' = params @ (pP::pV::pK)
          val (_, body') = arrangeBody (body, Ps, sx', probe')
          val einapp = (y, IR.EINAPP(mkEin(params', index, body'), args @ (vP::args')))
          in
            AvailRHS.addAssignToList (avail, einapp)
          end

  (* multiply probe by transformation matrix and split product operation.
   * input is a differentiated field of the form
   *     eF = ∇_dx F_α
   * variable differentiation indices are transformed
   *     (dx′, dx_tshape, Ps) = Transform(dx)
   * et is a tensor that represents the probe
   *     et = T_tshape,  tshape(e_F) = {α + dx_tshape}
   * (constant indices of α do not contribute to tshape)
   * multiply probe (et) by transformation matrix (Ps)
   *     eout = et * Ps
   *     ein0 = λ(T,P)⟨eout⟩ij(Tx,P)
   * The next part of the compiler reconstructs the probed field ec to lower IR
   *     ec = Clean(∇dx′Fα) −→ ∇dx′′Fα′′
   *     Tx = λ(F,x)⟨ec(x)⟩ij(F,x)
   * Returns (splitvar, ein0, sizes, dx'', alpha'').
   *)
    fun createEinApp (body, alpha, index, freshIndex, dim, dx, sx) = let
          val Pid = 0
          val tid = 1
          val (dx', dx_tshape, newsx, Ps) = T.imageToWorld(freshIndex, dim, dx, Pid)
        (* need to rewrite dx *)
          val sxx = sx @ newsx
          val tshape = List.filter (fn E.C _ => false | _ => true) alpha @ dx_tshape
          val et = E.Tensor(tid, tshape)
          val (splitvar, eout) = arrangeBody(body, Ps, newsx, et)
          val (_, sizes, ec) = (case sxx
                 of [] => ([], index, E.Conv(0, alpha, 1, dx'))
                  | _ => CleanIndex.clean(E.Conv(0, alpha, 1, dx'), index, sxx)
                (* end case *))
          val params = [E.TEN(true, [dim, dim]), E.TEN(true, sizes)]
          val ein0 = mkEin(params, index, eout)
        (* the cleaned convolution supplies the rewritten alpha and dx *)
          val E.Conv(_, alpha', _, dx) = ec
          in
            (splitvar, ein0, sizes, dx, alpha')
          end

  (* floats the reconstructed field term: the probe is reconstructed into a fresh
   * variable (TPP) and the original assignment is rewritten to multiply that variable
   * by the transformation matrices.  All generated assignments are added to avail.
   *)
    fun liftProbe (avail, (y, IR.EINAPP(Ein.EIN{params, index, body}, args)), probe, sx) = let
          val E.Probe(E.Conv(Vid, alpha, hid, dx), E.Tensor(tid, _)) = probe
          val freshIndex = getsumshift (sx, length index)
          val E.IMG(dim, shape) = List.nth(params, Vid)
        (* transform T*P*P..Ps *)
          val (splitvar, ein0, sizes, dx, alpha') =
                createEinApp (body, alpha, index, freshIndex, dim, dx, sx)
          val vT = V.new ("TPP", Ty.tensorTy sizes)
        (* reconstruct the lifted probe; its params are the image (id 0) followed by
         * one kernel per axis (ids 1 .. dim)
         *)
          val kid = 0
          val Vidnew = 0
          val params' = List.nth(params, Vid) :: List.tabulate(dim, fn _ => E.KRN)
          val sxn = length sizes (* next available index id *)
          val (args', vP, probe') =
                fieldReconstruction (avail, sxn, alpha', shape, dx, Vid, Vidnew, kid, hid, tid, args)
          val einApp1 = IR.EINAPP(mkEin(params', sizes, probe'), args')
        (* multiply the reconstructed field by the transformation: T*P*P..Ps *)
          val rtn0 = if splitvar
                then FloatEin.transform(y, EinSums.transform ein0, [vP, vT])
                else [(y, IR.EINAPP(ein0, [vP, vT]))]
          in
            List.app (fn e => AvailRHS.addAssignToList(avail, e)) ((vT, einApp1)::rtn0)
          end

  (* expand : AvailRHS.t -> code -> unit
   * At this point we only have simple ein ops.  Looks to see if the expression has a
   * probe.  If so, replaces it; otherwise the assignment is added unchanged.
   *)
    fun expand avail (e as (_, IR.EINAPP(Ein.EIN{body, ...}, _))) = (case body
           of E.Probe(E.Conv(_, _, _, []), _) =>
                replaceProbe (avail, e, body, [])
            | E.Probe(E.Conv(_, _, _, _), _) =>
                liftProbe (avail, e, body, []) (* probe of a differentiated field *)
            | E.Sum(sx, p as E.Probe(E.Conv(_, _, _, []), _)) =>
                replaceProbe (avail, e, p, sx) (* no dx *)
            | E.Sum(sx, p as E.Probe(E.Conv(_, [], _, _), _)) =>
                liftProbe (avail, e, p, sx) (* scalar field *)
            | E.Sum(sx, p as E.Probe _) =>
                replaceProbe (avail, e, p, sx)
            | E.Sum(sx, E.Opn(E.Prod, [_, p as E.Probe _])) =>
                replaceProbe (avail, e, p, sx)
            | _ => AvailRHS.addAssignToList (avail, e)
          (* end case *))

  end (* ProbeEin *)

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0