Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] Annotation of /branches/vis15/src/compiler/high-to-mid/probe-ein.sml
ViewVC logotype

Annotation of /branches/vis15/src/compiler/high-to-mid/probe-ein.sml

Parent Directory Parent Directory | Revision Log Revision Log


Revision 3777 - (view) (download)

1 : jhr 3550 (* probe-ein.sml
2 :     *
3 :     * This code is part of the Diderot Project (http://diderot-language.cs.uchicago.edu)
4 :     *
5 :     * COPYRIGHT (c) 2016 The University of Chicago
6 :     * All rights reserved.
7 :     *)
8 :    
structure ProbeEin : sig

  (* expand avail (y, rhs)
   * If rhs is an EIN application whose body is a field probe (directly, under a
   * summation, or as the second factor of a product under a summation), the probe
   * is expanded and the resulting assignment(s) are added to avail.  Any other
   * assignment is added to avail unchanged.
   *)
    val expand : AvailRHS.t -> MidIR.var * MidIR.rhs -> unit

  end = struct

    structure IR = MidIR
    structure Op = MidOps
    structure V = IR.Var
    structure Ty = MidTypes
    structure E = Ein
    structure T = CoordSpaceTransform

  (* This file expands probed fields.
   * See the ProbeEin tex file for examples.
   * Note that the original field is an EIN operator of the form
   *     <V_alpha * H^(deltas)>(midIR.var list)
   * Param_ids record the position of an argument in the midIR.var list.
   * Index_ids keep track of the shape of an image or of differentiation.
   * Mu binds Index_ids.
   * Generally, we use the following names:
   *    dim:    dimension of field V
   *    s:      support of kernel H
   *    alpha:  the alpha in <V_alpha * H^(deltas)>
   *    dx:     the dx in <V_alpha * nabla_dx H>
   *    deltas: the deltas in <V_alpha * h^(deltas) h^(deltas)>
   *    Vid:    param_id for V
   *    hid:    param_id for H
   *    nid:    param_id for the integer position
   *    fid:    param_id for the fractional position
   *    info:   image info about V
   *)

  (* build an EIN operator from its parameter list, index shape, and body *)
    fun mkEin (params, index, body) = Ein.EIN{params = params, index = index, body = body}

  (* increment the use count of a mid-IR variable
   * (not referenced elsewhere in this structure)
   *)
    fun incUse (IR.V{useCnt, ...}) = (useCnt := !useCnt + 1)

  (* return the operator and arguments of the OP right-hand side that defines x;
   * raises Fail if x is defined by any other kind of right-hand side.
   * (not referenced elsewhere in this structure)
   *)
    fun getRHSDst x = (case IR.Var.getDef x
           of IR.OP(rator, args) => (rator, args)
            | rhs => raise Fail(concat[
                  "expected rhs operator for ", IR.Var.toString x,
                  " but found ", IR.RHS.toString rhs
                ])
          (* end case *))

  (* return imgArg if it is defined by a LoadImage operation; raises Fail otherwise.
   * (not referenced elsewhere in this structure)
   *)
    fun checkImg imgArg = (case IR.Var.getDef imgArg
           of IR.OP(Op.LoadImage _, _) => imgArg
            | rhs => raise Fail (String.concat[
                  "expected image for ", IR.Var.toString imgArg,
                  " but found ", IR.RHS.toString rhs
                ])
          (* end case *))

  (* getImagInfo e
   * Returns (image variable, image info, border mode) for the image variable e.
   * If e is a border-control wrapper, the wrapped image variable is returned
   * together with the corresponding Ein border mode; a plain LoadImage yields
   * E.None.  Raises Fail for any other definition.
   *)
    fun getImagInfo e = (case IR.Var.getDef e
           of IR.OP(Op.LoadImage(Ty.ImageTy info, _), []) => (e, info, E.None)
            | IR.OP(Op.BorderCtlDefault info, [imgArg]) => (imgArg, info, E.Default)
            | IR.OP(Op.BorderCtlClamp info, [imgArg]) => (imgArg, info, E.Clamp)
            | IR.OP(Op.BorderCtlMirror info, [imgArg]) => (imgArg, info, E.Mirror)
            | IR.OP(Op.BorderCtlWrap info, [imgArg]) => (imgArg, info, E.Wrap)
            | rhs => raise Fail (String.concat[
                  "expected image for ", IR.Var.toString e,
                  " but found ", IR.RHS.toString rhs
                ])
          (* end case *))

  (* return the kernel bound by the Kernel operation that defines hArg;
   * raises Fail if hArg is not defined by a Kernel operation.
   *)
    fun getKernelDst hArg = (case IR.Var.getDef hArg
           of IR.OP(Op.Kernel(h, _), _) => h
            | rhs => raise Fail (String.concat[
                  "expected kernel for ", IR.Var.toString hArg,
                  " but found ", IR.RHS.toString rhs
                ])
          (* end case *))

  (* handleArgs (avail, Vid, hid, tid, args)
   * Looks up the image (param Vid), kernel (param hid), and position (param tid)
   * arguments in args and transforms the position to index space.  Returns
   *     (vI, vH, vN, vF, vP, info, border, dim)
   * where
   *     vI       -- the image variable, with any border control stripped
   *     vH       -- the kernel variable
   *     vN, vF   -- the integer and fractional positions produced by
   *                 CoordSpaceTransform.worldToIndex
   *     vP       -- the transformation-matrix variable produced by worldToIndex
   *                 (the original notes call it the transpose; see worldToIndex)
   *     info     -- the image info
   *     border   -- the Ein border-control mode
   *     dim      -- the image dimension
   *)
    fun handleArgs (avail, Vid, hid, tid, args) = let
          val (vI, info, border) = getImagInfo (List.nth (args, Vid))
          val vH = List.nth (args, hid)
          val (vN, vF, vP) = T.worldToIndex{
                  avail = avail, info = info, img = vI, pos = List.nth (args, tid)
                }
          in
            (vI, vH, vN, vF, vP, info, border, ImageInfo.dim info)
          end

  (* liftKrn (avail, dir, dx, dim, h, vF, s)
   * Emits the kernel evaluations for one axis and returns the variable that holds
   * them.
   *     dir  -- the 0-based axis; when dim > 1 it selects the component of vF
   *     dx   -- the differentiation indices; |dx|+1 kernel evaluations (derivative
   *             levels 0 .. |dx|) are emitted
   *     dim  -- the image dimension (when 1, vF is already a scalar)
   *     h    -- the kernel
   *     vF   -- the fractional position
   *     s    -- the kernel support; each evaluation is a tensor[2*s]
   * If only one evaluation is needed, its variable is returned directly;
   * otherwise the evaluations are packed into a tensor[|dx|+1, 2*s] by a CONS.
   *)
    fun liftKrn (avail, dir, dx, dim, h, vF, s) = let
          val range = 2*s
          val dirS = Int.toString dir
          (* build position vector for EvalKernel *)
          val vX = if (dim = 1)
                then vF (* position is a real type *)
                else AvailRHS.addAssign (
                  avail, "vxindexed_dir" ^ dirS ^ "_", Ty.realTy,
                  IR.OP(Op.Index(Ty.TensorTy[dim], dir), [vF]))
          val vPos = AvailRHS.addAssign (
                avail, "kernelpos_dir" ^ dirS ^ "_", Ty.TensorTy[range],
                IR.OP(Op.BuildPos s, [vX]))
          val nKernEvals = List.length dx + 1
          fun mkEval k = AvailRHS.addAssign (
                avail, "mkeval_dir" ^ dirS ^ "_del" ^ Int.toString k,
                Ty.TensorTy[range], IR.OP(Op.EvalKernel(range, h, k), [vPos]))
          (* List.tabulate applies mkEval left to right, so the evaluations are
           * emitted in order k = 0, 1, ...
           *)
          val vKs = List.tabulate (nKernEvals, mkEval)
          in
            case vKs
             of [v] => v (* scalar result *)
              | _ => let
                  val consTy = Ty.TensorTy[nKernEvals, range]
                  in
                    AvailRHS.addAssign (
                      avail, "kernelCons_dir_" ^ dirS, consTy, IR.CONS(vKs, consTy))
                  end
            (* end case *)
          end

  (* mkLdVoxel (avail, vI, vN, info, alpha, shape, dim, s)
   * Emits a LoadVoxels operation on image vI at the voxel coordinates
   * vN[i] + (1-s), for i = 0 .. dim-1, and returns the resulting variable.
   * The alpha argument is not used.
   *)
    fun mkLdVoxel (avail, vI, vN, info, alpha, shape, dim, s) = let
          (* offset of the lower end of the kernel support: 1-s *)
          val vLb = AvailRHS.addAssign (
                avail, "lit", Ty.intTy, IR.LIT(Literal.Int(1 - (IntInf.fromInt s))))
          (* the i'th voxel coordinate: vN[i] + (1-s) *)
          fun f i = let
                val vA = AvailRHS.addAssign (
                      avail, "lit", Ty.intTy, IR.LIT(Literal.Int(IntInf.fromInt i)))
                val vB = AvailRHS.addAssign (
                      avail, "subscript", Ty.intTy,
                      IR.OP(Op.Subscript(Ty.TensorTy[dim]), [vN, vA]))
                in
                  AvailRHS.addAssign (avail, "add", Ty.intTy, IR.OP(Op.IAdd, [vB, vLb]))
                end
          (* image positions *)
          val vNs = List.tabulate (dim, f)
          (* NOTE(review): the load type is just `shape`, although a neighborhood of
           * (2*s)^dim voxels is addressed; confirm whether the support dimensions
           * should be part of this type.
           *)
          val ldty = Ty.TensorTy shape
          in
            AvailRHS.addAssign (avail, "ldvox", ldty, IR.OP(Op.LoadVoxels(info, s), vI::vNs))
          end

  (* fieldReconstruction (avail, sx, alpha, shape, dx, Vid, kid, hid, tid, args)
   * Emits the code needed to reconstruct the probed field and returns
   *     (vAs, vLd, vN, vP, exp)
   * where
   *     vAs  -- the lifted kernel-evaluation variables, one per axis, axis 0 first
   *     vLd  -- the loaded voxel data
   *     vN   -- the integer position from handleArgs
   *     vP   -- the transformation-matrix variable from handleArgs
   *     exp  -- the EIN body: a summation over index ids sx .. sx+dim-1 (each
   *             ranging over 1-s .. s) of the product of the image tensor (param
   *             Vid) and one kernel term per axis (params kid+1 .. kid+dim)
   *)
    fun fieldReconstruction (avail, sx, alpha, shape, dx, Vid, kid, hid, tid, args) = let
          val (vI, vH, vN, vF, vP, info, border, dim) = handleArgs (avail, Vid, hid, tid, args)
          val h = getKernelDst vH
          val s = Kernel.support h
          val imgexp = E.Tensor(Vid, List.rev alpha)
          val vLd = mkLdVoxel (avail, vI, vN, info, alpha, shape, dim, s)
          (* create kernel body: walks dir = dim down to 1 and conses onto the
           * accumulators, so both result lists are ordered by increasing axis.
           *)
          fun createKrn (0, krnexp, vAs) = (krnexp, vAs)
            | createKrn (dir, krnexp, vAs) = let
                val dir' = dir - 1      (* 0-based axis *)
                (* ein expression *)
                val deltas = List.map (fn e => (E.C dir', e)) dx
                val kexp0 = E.Krn(kid + dir, deltas, dir)
                (* evalkernel operators; liftKrn takes the 0-based axis, since it
                 * indexes the dim-vector vF with it (as the deltas above and the
                 * subscripts in mkLdVoxel do).
                 *)
                val vA = liftKrn (avail, dir', dx, dim, h, vF, s)
                in
                  createKrn (dir', kexp0::krnexp, vA::vAs)
                end
          (* summation indices: one per axis, each ranging over 1-s .. s *)
          val esum = List.tabulate (dim, fn i => (E.V(i + sx), 1 - s, s))
          val (krnexp, vAs) = createKrn (dim, [], [])
          (* final ein expression body to represent field reconstruction *)
          val exp = E.Sum(esum, E.Opn(E.Prod, imgexp::krnexp))
          in
            (vAs, vLd, vN, vP, exp)
          end

  (* getsumshift (sx, n)
   * Returns the first unused index id: n when there are no summation indices,
   * otherwise one past the index id of the last summation index in sx.
   *)
    fun getsumshift ([], n) = n
      | getsumshift (sx, _) = (case List.last sx
           of (E.V v, _, _) => v + 1
            | _ => raise Fail "getsumshift: expected a summation over a variable index"
          (* end case *))

  (* formBody : ein_exp -> ein_exp
   * Simplifies a body by removing empty summations and single-factor products.
   *)
    fun formBody (E.Sum([], e)) = formBody e
      | formBody (E.Sum(sx, e)) = E.Sum(sx, formBody e)
      | formBody (E.Opn(E.Prod, [e])) = e
      | formBody e = e

  (* multiPs (Ps, sx, body)
   * Builds Sum(sx, Prod(...)) of the transformation terms Ps and body, then
   * simplifies it with formBody.  For three or four Ps the body is placed last,
   * otherwise first; this order matches the vis branch WorldToSpace functions.
   *)
    fun multiPs (Ps, sx, body) = let
          val exp = (case Ps
                 of [_, _, _] => Ps @ [body]
                  | [_, _, _, _] => Ps @ [body]
                  | _ => body :: Ps
                (* end case *))
          in
            formBody (E.Sum(sx, E.Opn(E.Prod, exp)))
          end

  (* arrangeBody (body, Ps, newsx, exp)
   * Replaces the probe in body with the transformed expression exp.  Returns
   * (true, ...) when the probe is the whole body or is directly under a Sum, and
   * (false, ...) when the body has the form Sum(sx, Prod[eps0, probe]); in that
   * case the outer summation and eps0 are kept.  Raises Fail for any other shape.
   *)
    fun arrangeBody (body, Ps, newsx, exp) = (case body
           of E.Sum(sx, E.Probe _) => (true, multiPs(Ps, sx @ newsx, exp))
            | E.Sum(sx, E.Opn(E.Prod, [eps0, E.Probe _])) =>
                (false, E.Sum(sx, E.Opn(E.Prod, [eps0, multiPs(Ps, newsx, exp)])))
            | E.Probe _ => (true, multiPs(Ps, newsx, exp))
            | _ => raise Fail "impossible"
          (* end case *))

  (* replaceProbe (avail, (y, EINAPP(ein, args)), probe, sx)
   * Replaces the probe in ein's body with its expanded field reconstruction and
   * adds the rewritten assignment to avail:
   *   - the image argument is replaced by the loaded voxel data;
   *   - new params are appended for the position (id nid), the transformation
   *     matrix (id pid), and one scalar kernel term per axis (ids pid+1 .. pid+dim),
   *     with the matching arguments vN, vP, and the kernel variables;
   *   - the result is transformed back from index space via the Ps terms.
   * sx is the list of summation indices that enclose the probe.
   *)
    fun replaceProbe (avail, (y, IR.EINAPP(Ein.EIN{params, index, body}, args)), probe, sx) = let
          (* tensor ids for position, transform matrix P, and kernel terms; the
           * kernel terms use ids kid+1 .. kid+dim, i.e., the params after pid.
           *)
          val nid = length params
          val pid = nid + 1
          val kid = pid
          val E.Probe(E.Conv(Vid, alpha, hid, dx), E.Tensor(tid, _)) = probe
          val E.IMG(dim, shape) = List.nth (params, Vid)
          val freshIndex = getsumshift (sx, length index)
          val (dx', sx', Ps) = T.imageToWorld (freshIndex, dim, dx, pid)
          val sxn = freshIndex + length dx'     (* next available index id *)
          val (vKs, vLd, vN, vP, probe') =
                fieldReconstruction (avail, sxn, alpha, shape, dx', Vid, kid, hid, tid, args)
          (* add new params: position (nid), transformation matrix (pid), and kernel ids *)
          val params' = params
                @ [E.TEN(true, [dim]), E.TEN(true, [dim, dim])]
                @ List.tabulate (dim, fn _ => E.TEN(true, []))
          val (_, body') = arrangeBody (body, Ps, sx', probe')
          (* replace the image argument with the loaded voxel data *)
          val args' = List.take (args, Vid) @ (vLd :: List.drop (args, Vid + 1))
          val einapp = (y, IR.EINAPP(mkEin(params', index, body'), args' @ [vN, vP] @ vKs))
          in
            AvailRHS.addAssignToList (avail, einapp)
          end

  (* createEinApp (body, alpha, index, freshIndex, dim, dx, sx)
   * Builds the EIN operator that applies the world-space transformation
   * (the T*P*P..Ps product) to a lifted probe.  The operator has two params:
   * param 0 is the dim x dim transformation matrix and param 1 is the lifted
   * probe tensor of shape sizes.  Returns
   *     (splitvar, ein0, sizes, dx, alpha')
   * where splitvar is the flag from arrangeBody, and sizes, dx, and alpha' come
   * from CleanIndex.clean (or are index, dx', and alpha when no summation
   * indices are involved).
   *)
    fun createEinApp (body, alpha, index, freshIndex, dim, dx, sx) = let
          val Pid = 0
          val tid = 1
          val (dx', newsx, Ps) = T.imageToWorld (freshIndex, dim, dx, Pid)
          (* need to rewrite dx *)
          val sxx = sx @ newsx
          (* QUESTION: what is the significance of "9" and "7" in this code?
           * NOTE(review): they look like placeholder image/kernel ids; only the
           * alpha' and dx components of the cleaned Conv are used below.
           *)
          val (_, sizes, E.Conv(_, alpha', _, dx)) = (case sxx
                 of [] => ([], index, E.Conv(9, alpha, 7, dx'))
                  | _ => CleanIndex.clean (E.Conv(9, alpha, 7, dx'), index, sxx)
                (* end case *))
          (* tensor indices: the non-constant indices of alpha' followed by dx' *)
          fun filterAlpha [] = dx'
            | filterAlpha (E.C _ :: es) = filterAlpha es
            | filterAlpha (e1 :: es) = e1 :: filterAlpha es
          val exp = E.Tensor(tid, filterAlpha alpha')
          val (splitvar, body') = arrangeBody (body, Ps, newsx, exp)
          val params = [E.TEN(true, [dim, dim]), E.TEN(true, sizes)]
          val ein0 = mkEin (params, index, body')
          in
            (splitvar, ein0, sizes, dx, alpha')
          end

  (* liftProbe (avail, (y, EINAPP(ein, args)), probe, sx)
   * Floats the reconstructed field term into its own assignment:
   *     vT = <field reconstruction>(vLd, vN, kernel terms)
   * then assigns y the world-space transformation of vT (via createEinApp).
   * When arrangeBody reports splitvar, the transformation is further processed
   * by EinSums.transform and FloatEin.transform.
   *)
    fun liftProbe (avail, (y, IR.EINAPP(Ein.EIN{params, index, body}, args)), probe, sx) = let
          val E.Probe(E.Conv(Vid, alpha, hid, dx), E.Tensor(tid, _)) = probe
          val freshIndex = getsumshift (sx, length index)
          val E.IMG(dim, shape) = List.nth (params, Vid)
          (* transform T*P*P..Ps *)
          val (splitvar, ein0, sizes, dx, alpha') =
                createEinApp (body, alpha, index, freshIndex, dim, dx, sx)
          val vT = V.new ("TPP", Ty.TensorTy sizes)
          (* reconstruct the lifted probe; params: image (0), transformed image
           * position (nid = 1), and kernel terms (kid+1 .. kid+dim = 2 .. dim+1)
           *)
          val nid = 1
          val kid = nid
          val params' = List.nth (params, Vid)
                :: E.TEN(true, [dim])
                :: List.tabulate (dim, fn _ => E.TEN(true, []))
          (* create body for ein expression *)
          val sxn = length sizes    (* next available index id *)
          (* NOTE(review): probe' refers to the image as param Vid (its position in
           * the outer EIN), whereas params' places the image at position 0; these
           * agree only when Vid = 0.  hid and tid are only used to index the outer
           * args, which is consistent.  Confirm.
           *)
          val (vKs, vLd, vN, vP, probe') =
                fieldReconstruction (avail, sxn, alpha', shape, dx, Vid, kid, hid, tid, args)
          val args' = vLd :: vN :: vKs
          (* NOTE(review): the trailing E.TEN(true, [10]) param has no matching
           * entry in args' (the params list has dim+3 entries, args' has dim+2);
           * confirm whether it is a placeholder.
           *)
          val einApp1 = IR.EINAPP(mkEin(params' @ [E.TEN(true, [10])], sizes, probe'), args')
          (* transform T*P*P..Ps *)
          val rtn0 = if splitvar
                then FloatEin.transform (y, EinSums.transform ein0, [vP, vT])
                else [(y, IR.EINAPP(ein0, [vP, vT]))]
          in
            List.app (fn e => AvailRHS.addAssignToList (avail, e)) ((vT, einApp1) :: rtn0)
          end

  (* expand avail (y, rhs)
   * At this point we only have simple EIN operators.  If the body is a probe,
   * directly or under a summation, it is expanded:
   *   - probes without differentiation (dx = []) are replaced in place;
   *   - a bare differentiated probe, or a summed differentiated probe of a scalar
   *     field (alpha = []), is lifted into its own assignment;
   *   - any other summed probe, or Sum(sx, Prod[eps, probe]), is replaced in place.
   * Every other assignment (including non-EIN right-hand sides) is added to
   * avail unchanged.
   *)
    fun expand avail (e as (_, IR.EINAPP(Ein.EIN{body, ...}, _))) = (case body
           of E.Probe(E.Conv(_, _, _, []), _) =>
                replaceProbe (avail, e, body, [])
            | E.Probe(E.Conv _, _) =>
                liftProbe (avail, e, body, [])
            | E.Sum(sx, p as E.Probe(E.Conv(_, _, _, []), _)) =>
                replaceProbe (avail, e, p, sx)      (* no dx *)
            | E.Sum(sx, p as E.Probe(E.Conv(_, [], _, _), _)) =>
                liftProbe (avail, e, p, sx)         (* scalar field *)
            | E.Sum(sx, p as E.Probe _) =>
                replaceProbe (avail, e, p, sx)
            | E.Sum(sx, E.Opn(E.Prod, [_, p as E.Probe _])) =>
                replaceProbe (avail, e, p, sx)
            | _ => AvailRHS.addAssignToList (avail, e)
          (* end case *))
      | expand avail e = AvailRHS.addAssignToList (avail, e)

  end (* ProbeEin *)

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0