Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] Annotation of /branches/vis15/src/compiler/high-to-mid/probe-ein.sml
ViewVC logotype

Annotation of /branches/vis15/src/compiler/high-to-mid/probe-ein.sml

Parent Directory Parent Directory | Revision Log Revision Log


Revision 3788 - (view) (download)

1 : jhr 3550 (* probe-ein.sml
2 :     *
3 :     * This code is part of the Diderot Project (http://diderot-language.cs.uchicago.edu)
4 :     *
5 :     * COPYRIGHT (c) 2016 The University of Chicago
6 :     * All rights reserved.
7 :     *)
8 :    
(* This file expands probed fields.
 * See the ProbeEin TeX file for examples.
 * Note that the original field is an EIN operator of the form
 * <V_alpha * H^(deltas)>(midIR.var list).
 * Param_ids record the position of each argument in the midIR.var list.
 * Index_ids keep track of the shape of an image or of a differentiation.
 * A Mu binds an Index_id.
 * Throughout this file we use the following names:
 *   dim:    dimension of the field V
 *   s:      support of the kernel H
 *   alpha:  the alpha in <V_alpha * H^(deltas)>
 *   dx:     the dx in <V_alpha * nabla_dx H>
 *   deltas: the deltas in <V_alpha * h^(deltas) h^(deltas)>
 *   Vid:    param_id for V
 *   hid:    param_id for H
 *   nid:    param_id for the integer position
 *   fid:    param_id for the fractional position
 *   img:    image info about V
 *)
27 :    
structure ProbeEin : sig

  (* expand avail (y, rhs)
   * If rhs is an EIN application whose body is a probe, a summation of a probe,
   * or a summation of a two-factor product whose second factor is a probe, the
   * probe is expanded into explicit voxel loads and kernel evaluations and the
   * resulting assignments are added to avail.  Any other EIN application is
   * added to avail unchanged.
   *)
    val expand : AvailRHS.t -> MidIR.var * MidIR.rhs -> unit

  end = struct

    structure IR = MidIR
    structure Op = MidOps
    structure V = IR.Var
    structure Ty = MidTypes
    structure E = Ein
    structure T = CoordSpaceTransform

  (* construct an EIN operator from its parameters, index, and body *)
    fun mkEin (params, index, body) = Ein.EIN{params = params, index = index, body = body}

  (* increment the use count of a variable *)
    fun incUse (IR.V{useCnt, ...}) = (useCnt := !useCnt + 1)

  (* getRHSDst x
   * Returns the (operator, arguments) pair of the OP that defines x.
   * Raises Fail if x is defined by any other kind of right-hand side.
   *)
    fun getRHSDst x = (case IR.Var.getDef x
           of IR.OP(rator, args) => (rator, args)
            | rhs => raise Fail(concat[
                  "expected rhs operator for ", IR.Var.toString x,
                  " but found ", IR.RHS.toString rhs
                ])
          (* end case *))

  (* checkImg imgArg
   * Returns imgArg when it is defined by a LoadImage operator; raises Fail
   * otherwise.
   *)
    fun checkImg imgArg = (case IR.Var.getDef imgArg
           of IR.OP(Op.LoadImage _, _) => imgArg
            | rhs => raise Fail(String.concat[
                  "expected image for ", IR.Var.toString imgArg,
                  " but found ", IR.RHS.toString rhs
                ])
          (* end case *))

  (* getImagInfo e
   * Returns (vI, info, border) for the image argument e:
   *   LoadImage (no arguments)  => (e, info, NONE)
   *   BorderCtlClamp  [imgArg]  => (imgArg, info, SOME IndexCtl.Clamp)
   *   BorderCtlMirror [imgArg]  => (imgArg, info, SOME IndexCtl.Mirror)
   *   BorderCtlWrap   [imgArg]  => (imgArg, info, SOME IndexCtl.Wrap)
   * Default border control is not supported and raises Fail, as does any
   * other definition of e.
   *)
    fun getImagInfo e = (case IR.Var.getDef e
           of IR.OP(Op.LoadImage(Ty.ImageTy info, _), []) => (e, info, NONE)
            | IR.OP(Op.BorderCtlDefault _, [_]) => raise Fail "Default border control"
            | IR.OP(Op.BorderCtlClamp info, [imgArg]) => (imgArg, info, SOME IndexCtl.Clamp)
            | IR.OP(Op.BorderCtlMirror info, [imgArg]) => (imgArg, info, SOME IndexCtl.Mirror)
            | IR.OP(Op.BorderCtlWrap info, [imgArg]) => (imgArg, info, SOME IndexCtl.Wrap)
            | rhs => raise Fail(String.concat[
                  "expected image for ", IR.Var.toString e, " but found ", IR.RHS.toString rhs
                ])
          (* end case *))

  (* getKernelDst hArg
   * Returns the kernel h of the Kernel operator that defines hArg; raises Fail
   * if hArg is defined some other way.
   *)
    fun getKernelDst hArg = (case IR.Var.getDef hArg
           of IR.OP(Op.Kernel(h, _), _) => h
            | rhs => raise Fail(String.concat[
                  "expected kernel for ", IR.Var.toString hArg,
                  " but found ", IR.RHS.toString rhs
                ])
          (* end case *))

  (* handleArgs (avail, Vid, hid, tid, args)
   * Uses the param ids of the image (Vid), kernel (hid), and position tensor (tid)
   * to select the corresponding mid-IR variables from args.  It then transforms
   * the position to index space with CoordSpaceTransform.worldToIndex.
   * Returns (vI, vH, vN, vF, vP, info, border, dim), where
   *   vI, info, border  come from getImagInfo on args[Vid]
   *   vH                is args[hid]
   *   vN, vF, vP        are the integer position, the fractional position, and
   *                     the transformation-matrix variable from worldToIndex
   *   dim               is the image dimension
   *)
    fun handleArgs (avail, Vid, hid, tid, args) = let
          val (vI, info, border) = getImagInfo (List.nth (args, Vid))
          val vH = List.nth (args, hid)
          val (vN, vF, vP) = T.worldToIndex{
                  avail = avail, info = info, img = vI, pos = List.nth (args, tid)
                }
          in
            (vI, vH, vN, vF, vP, info, border, ImageInfo.dim info)
          end

  (* liftKrn (avail, dir, dx, dim, h, vF, s)
   * Emits the kernel evaluations for axis dir.
   *   dir  the axis; component dir of vF is used when dim > 1, vF itself otherwise
   *   dx   the differentiation index ids; derivatives 0 .. length dx of h are evaluated
   *   dim  the image dimension
   *   h    the kernel
   *   vF   the fractional position
   *   s    the kernel support
   * Returns the single tensor[2s] evaluation when dx = [].  Otherwise returns a
   * tensor[length dx + 1, 2s] built from all of the evaluations.
   *)
    fun liftKrn (avail, dir, dx, dim, h, vF, s) = let
          val range = 2*s
          val tag = Int.toString dir
        (* position along axis dir; for a 1D image vF is already a real *)
          val vX = if (dim = 1)
                then vF
                else AvailRHS.addAssign (
                  avail, concat["vxindexed_dir", tag, "_"],
                  Ty.realTy, IR.OP(Op.Index(Ty.TensorTy[dim], dir), [vF]))
        (* the 2s sample positions used by EvalKernel *)
          val vPos = AvailRHS.addAssign (
                avail, concat["kernpos_dir", tag, "_"],
                Ty.TensorTy[range], IR.OP(Op.BuildPos s, [vX]))
          val nKernEvals = List.length dx + 1
        (* evaluate the k'th derivative of h at the sample positions *)
          fun mkEval k = AvailRHS.addAssign (
                avail, concat["mkeval_dir", tag, "_del", Int.toString k],
                Ty.TensorTy[range], IR.OP(Op.EvalKernel(range, h, k), [vPos]))
          val vKs = List.tabulate (nKernEvals, mkEval)
          in
            case vKs
             of [v] => v
              | _ => let
                  val consTy = Ty.TensorTy[nKernEvals, range]
                  in
                    AvailRHS.addAssign (
                      avail, concat["kernelCons_dir_", tag, "_"],
                      consTy, IR.CONS(vKs, consTy))
                  end
            (* end case *)
          end

  (* mkLdVoxel (avail, vI, vN, info, alpha, shape, dim, s, border)
   * Emits the code to load the block of voxels from image vI that is needed for
   * a reconstruction with support s around the integer position vN.
   * For each axis i in 0..dim-1 it computes vN[i] + (1-s).  These corner indices
   * are packed into a sequence.  A LoadVoxels operator is then emitted, or a
   * LoadVoxelsWithCtl operator when a border control is present.
   * The result has type tensor[shape @ [2s, ..., 2s]], with dim copies of 2s.
   * alpha is currently unused.
   *)
    fun mkLdVoxel (avail, vI, vN, info, alpha, shape, dim, s, border) = let
          val s' = 2*s
        (* the type of the sequence of corner indices (also the SEQ's own type) *)
          val seqTy = Ty.SeqTy(Ty.IntTy, SOME dim)
        (* offset from the integer position to the first sample: 1-s *)
          val vLb = AvailRHS.addAssign (
                avail, "lit", Ty.intTy, IR.LIT(Literal.Int(1 - IntInf.fromInt s)))
        (* cornerIdx i computes vN[i] + (1-s) *)
          fun cornerIdx i = let
                val vA = AvailRHS.addAssign (
                      avail, "idx", Ty.intTy, IR.LIT(Literal.Int(IntInf.fromInt i)))
                val vB = AvailRHS.addAssign (
                      avail, "subscript", Ty.intTy,
                      IR.OP(Op.Subscript(Ty.SeqTy(Ty.intTy, SOME dim)), [vN, vA]))
                in
                  AvailRHS.addAssign (avail, "add", Ty.intTy, IR.OP(Op.IAdd, [vB, vLb]))
                end
          val vNs = List.tabulate (dim, cornerIdx)
        (* the variable's type must agree with the sequence it is bound to
         * (this was previously the unexplained TensorTy[9])
         *)
          val vSq = AvailRHS.addAssign (avail, "seq", seqTy, IR.SEQ(vNs, seqTy))
        (* the voxel shape followed by 2s samples along each axis *)
          val ldTy = Ty.TensorTy(shape @ List.tabulate (dim, fn _ => s'))
          val rator = (case border
                 of NONE => Op.LoadVoxels(info, s')
                  | SOME b => Op.LoadVoxelsWithCtl(info, s', b)
                (* end case *))
          in
            AvailRHS.addAssign (avail, "ldvox", ldTy, IR.OP(rator, [vI, vSq]))
          end

  (* fieldReconstruction (avail, sx, alpha, shape, dx, Vid, Vidnew, kid, hid, tid, args)
   * Emits the voxel load and the kernel evaluations for the probe of image
   * args[Vid] with kernel args[hid] at position args[tid].  It also builds the
   * EIN body that reconstructs the field:
   *     Sum_{i_sx .. i_(sx+dim-1) in [1-s, s]}
   *         Img(Vidnew, alpha, ...) * Krn(kid+1, ...) * ... * Krn(kid+dim, ...)
   * Vidnew and kid are the param ids that the image and kernels have in the
   * EIN operator this body will belong to.
   * Returns (vLd :: vKs, vP, body), where vLd is the loaded voxel block and vKs
   * are the kernel evaluations for axes 1 .. dim, in that order.
   *)
    fun fieldReconstruction (avail, sx, alpha, shape, dx, Vid, Vidnew, kid, hid, tid, args) = let
          val (vI, vH, vN, vF, vP, info, border, dim) = handleArgs (avail, Vid, hid, tid, args)
          val h = getKernelDst vH
          val s = Kernel.support h
        (* summation index ids sx .. sx+dim-1, each ranging over [1-s, s] *)
          val vs = List.tabulate (dim, fn i => i + sx)
          val esum = List.map (fn i => (E.V i, 1-s, s)) vs
        (* the image term of the EIN body *)
          val imgexp = E.Img(Vidnew, alpha, List.map E.Value vs, s, E.None)
        (* the load-voxel operator for the image *)
          val vLd = mkLdVoxel (avail, vI, vN, info, alpha, shape, dim, s, border)
        (* build the kernel terms and evaluations for axes dim down to 1, consing
         * them so that the results are ordered by axis 1 .. dim
         *)
          fun createKrn (0, krnexp, vAs) = (krnexp, vAs)
            | createKrn (dir, krnexp, vAs) = let
                val dir' = dir-1
                val deltas = List.map (fn e => (E.C dir', e)) dx
              (* NOTE(review): dir is 1-based here (1..dim), while the deltas use
               * the 0-based constant dir'.  Confirm that E.Krn's last field and
               * Op.Index (in liftKrn) expect a 1-based axis.
               *)
                val kexp = E.Krn(kid+dir, deltas, dir)
                val vA = liftKrn (avail, dir, dx, dim, h, vF, s)
                in
                  createKrn (dir', kexp::krnexp, vA::vAs)
                end
          val (krnexp, vKs) = createKrn (dim, [], [])
          val exp = E.Sum(esum, E.Opn(E.Prod, imgexp::krnexp))
          in
            (vLd::vKs, vP, exp)
          end

  (* getsumshift (sx, n)
   * Returns a fresh (unused) index id.  This is n when there are no summation
   * indices; otherwise it is one more than the id of the last summation index
   * in sx.
   *)
    fun getsumshift ([], n) = n
      | getsumshift (sx, _) = let
          val (E.V v, _, _) = List.last sx
          in
            v + 1
          end

  (* formBody e
   * Simplifies e by dropping empty summations (recursively) and by replacing a
   * single-factor product with its factor.  The product rule applies at the top
   * of e or directly under non-empty summations.
   *)
    fun formBody (E.Sum([], e)) = formBody e
      | formBody (E.Sum(sx, e)) = E.Sum(sx, formBody e)
      | formBody (E.Opn(E.Prod, [e])) = e
      | formBody e = e

  (* multiPs (Ps, sx, body)
   * Builds Sum(sx, Prod[...]) from the transformation terms Ps and body, then
   * simplifies it with formBody.  When Ps has three or four elements, body is
   * placed last; otherwise it is placed first.  This ordering matches the vis
   * branch's WorldToSpace functions.
   *)
    fun multiPs (Ps, sx, body) = let
          val exp = (case Ps
                 of [P0, P1, P2] => [P0, P1, P2, body]
                  | [P0, P1, P2, P3] => [P0, P1, P2, P3, body]
                  | _ => body::Ps
                (* end case *))
          in
            formBody (E.Sum(sx, E.Opn(E.Prod, exp)))
          end

  (* arrangeBody (body, Ps, newsx, exp)
   * Rebuilds the original EIN body with its probe replaced by exp, multiplied
   * by the transformation terms Ps and summed over the extra indices newsx:
   *   Probe _                       => (true,  multiPs(Ps, newsx, exp))
   *   Sum(sx, Probe _)              => (true,  multiPs(Ps, sx @ newsx, exp))
   *   Sum(sx, Prod[eps0, Probe _])  => (false, Sum(sx, Prod[eps0, multiPs(Ps, newsx, exp)]))
   * The boolean is false only when the outer two-factor product is kept.
   * liftProbe uses it to decide whether to apply FloatEin.transform.
   * Raises Fail "impossible" for any other shape of body.
   *)
    fun arrangeBody (body, Ps, newsx, exp) = (case body
           of E.Sum(sx, E.Probe _) => (true, multiPs(Ps, sx@newsx, exp))
            | E.Sum(sx, E.Opn(E.Prod, [eps0, E.Probe _])) =>
                (false, E.Sum(sx, E.Opn(E.Prod, [eps0, multiPs(Ps, newsx, exp)])))
            | E.Probe _ => (true, multiPs(Ps, newsx, exp))
            | _ => raise Fail "impossible"
          (* end case *))

  (* replaceProbe (avail, (y, rhs), probe, sx)
   * Expands probe in place inside the EIN application rhs that is assigned to y.
   * The position is transformed to index space and the reconstructed field is
   * substituted for the probe.  The result is multiplied by the transformation
   * terms from CoordSpaceTransform.imageToWorld.
   * Three new kinds of param are appended to params: the transformation matrix,
   * the image, and dim kernels.  The matching arguments are appended to args:
   * vP, the loaded voxels, and the kernel evaluations.
   * sx are the summation indices that enclose the probe in the original body.
   * The new assignment is added to avail.
   *)
    fun replaceProbe (avail, (y, IR.EINAPP(Ein.EIN{params, index, body}, args)), probe, sx) = let
          val pid = length params     (* param id of the transformation matrix *)
          val Vidnew = pid+1          (* param id of the image *)
          val kid = Vidnew            (* kernel param ids are kid+1 .. kid+dim *)
          val E.Probe(E.Conv(Vid, alpha, hid, dx), E.Tensor(tid, _)) = probe
          val E.IMG(dim, shape) = List.nth (params, Vid)
          val freshIndex = getsumshift (sx, length index)
          val (dx', sx', Ps) = T.imageToWorld (freshIndex, dim, dx, pid)
          val sxn = freshIndex + length dx'   (* next available index id *)
          val (args', vP, probe') = fieldReconstruction (
                avail, sxn, alpha, shape, dx', Vid, Vidnew, kid, hid, tid, args)
        (* new params: transformation matrix, image, and dim kernels *)
          val pP = E.TEN(true, [dim, dim])
          val pV = List.nth (params, Vid)
          val pK = List.tabulate (dim, fn _ => E.KRN)
          val params' = params @ (pP::pV::pK)
          val (_, body') = arrangeBody (body, Ps, sx', probe')
          val einapp = (y, IR.EINAPP(mkEin(params', index, body'), args @ (vP::args')))
          in
            AvailRHS.addAssignToList (avail, einapp)
          end

  (* createEinApp (body, alpha, index, freshIndex, dim, dx, sx)
   * Builds the EIN operator ein0 that applies the transformation terms
   * (T*P*P..Ps) to a lifted probe.  ein0's params are the transformation matrix
   * (param 0) and the lifted probe tensor (param 1, of shape sizes).
   * Returns (splitvar, ein0, sizes, dx, alpha').  splitvar comes from
   * arrangeBody; dx and alpha' are the differentiation and image indices after
   * index cleaning.
   *)
    fun createEinApp (body, alpha, index, freshIndex, dim, dx, sx) = let
          val Pid = 0
          val tid = 1
          val (dx', newsx, Ps) = T.imageToWorld (freshIndex, dim, dx, Pid)
          val sxx = sx @ newsx
        (* The Conv term only carries alpha and dx' through index cleaning.  Its
         * image and kernel ids are placeholders that the pattern below discards.
         *)
          val placeholderVid = 9
          val placeholderHid = 7
          val conv = E.Conv(placeholderVid, alpha, placeholderHid, dx')
          val (_, sizes, E.Conv(_, alpha', _, dx)) = (case sxx
                 of [] => ([], index, conv)
                  | _ => CleanIndex.clean (conv, index, sxx)
                (* end case *))
        (* drop the constant indices from alpha' and append dx' *)
          fun filterAlpha [] = dx'
            | filterAlpha (E.C _ :: es) = filterAlpha es
            | filterAlpha (e1 :: es) = e1 :: filterAlpha es
          val exp = E.Tensor(tid, filterAlpha alpha')
          val (splitvar, body') = arrangeBody (body, Ps, newsx, exp)
          val params = [E.TEN(true, [dim, dim]), E.TEN(true, sizes)]
          val ein0 = mkEin (params, index, body')
          in
            (splitvar, ein0, sizes, dx, alpha')
          end

  (* liftProbe (avail, (y, rhs), probe, sx)
   * Floats the reconstructed field out of the EIN application.
   * First, a fresh variable vT is assigned the field reconstruction by its own
   * EIN operator; that operator has the image as param 0 and dim kernels as
   * params 1 .. dim.  Then y is assigned ein0 (from createEinApp) applied to
   * [vP, vT].  When arrangeBody reports splitvar, ein0 is first rewritten with
   * EinSums.transform and FloatEin.transform.
   * All assignments are added to avail.
   *)
    fun liftProbe (avail, (y, IR.EINAPP(Ein.EIN{params, index, body}, args)), probe, sx) = let
          val E.Probe(E.Conv(Vid, alpha, hid, dx), E.Tensor(tid, _)) = probe
          val freshIndex = getsumshift (sx, length index)
          val E.IMG(dim, shape) = List.nth (params, Vid)
        (* transform T*P*P..Ps *)
          val (splitvar, ein0, sizes, dx, alpha') =
                createEinApp (body, alpha, index, freshIndex, dim, dx, sx)
          val vT = V.new ("TPP", Ty.TensorTy sizes)
        (* The lifted operator's params are: the image at id kid (= 0) and the
         * kernels at ids kid+1 .. kid+dim.  The image term in the reconstructed
         * body must therefore use id kid, not the image's id (Vid) in the
         * original operator.
         *)
          val kid = 0
          val params' = List.nth (params, Vid) :: List.tabulate (dim, fn _ => E.KRN)
          val sxn = length sizes   (* next available index id *)
          val (args', vP, probe') =
                fieldReconstruction (avail, sxn, alpha', shape, dx, Vid, kid, kid, hid, tid, args)
          val einApp1 = IR.EINAPP(mkEin(params', sizes, probe'), args')
          val rtn0 = if splitvar
                then FloatEin.transform (y, EinSums.transform ein0, [vP, vT])
                else [(y, IR.EINAPP(ein0, [vP, vT]))]
          in
            List.app (fn e => AvailRHS.addAssignToList (avail, e)) ((vT, einApp1) :: rtn0)
          end

  (* expand avail (y, rhs)
   * At this point we only have simple EIN operators.  If the body contains a
   * probe, the probe is replaced (replaceProbe) or lifted (liftProbe).
   * Otherwise the assignment is added to avail unchanged.
   *)
    fun expand avail (e as (_, IR.EINAPP(Ein.EIN{body, ...}, _))) = (case body
           of E.Probe(E.Conv(_, _, _, []), _) =>
                replaceProbe (avail, e, body, [])     (* no differentiation *)
            | E.Probe(E.Conv _, _) =>
                liftProbe (avail, e, body, [])        (* differentiated probe *)
            | E.Sum(sx, p as E.Probe(E.Conv(_, _, _, []), _)) =>
                replaceProbe (avail, e, p, sx)        (* no differentiation *)
            | E.Sum(sx, p as E.Probe(E.Conv(_, [], _, _), _)) =>
                liftProbe (avail, e, p, sx)           (* scalar field *)
            | E.Sum(sx, p as E.Probe _) =>
                replaceProbe (avail, e, p, sx)
            | E.Sum(sx, E.Opn(E.Prod, [_, p as E.Probe _])) =>
                replaceProbe (avail, e, p, sx)
            | _ => AvailRHS.addAssignToList (avail, e)
          (* end case *))

  end (* ProbeEin *)

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0