Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] Annotation of /branches/vis15/src/compiler/high-to-mid/probe-ein.sml
ViewVC logotype

Annotation of /branches/vis15/src/compiler/high-to-mid/probe-ein.sml

Parent Directory Parent Directory | Revision Log Revision Log


Revision 3733 - (view) (download)

1 : jhr 3550 (* probe-ein.sml
2 :     *
3 :     * This code is part of the Diderot Project (http://diderot-language.cs.uchicago.edu)
4 :     *
5 :     * COPYRIGHT (c) 2016 The University of Chicago
6 :     * All rights reserved.
7 :     *)
8 :    
structure ProbeEin : sig

  (* expand avail (y, rhs)
   *   If rhs is an EIN application whose body is (or sums over) a probe of a
   *   convolution field, the probe is replaced by its expanded reconstruction and
   *   the resulting assignment(s) are recorded with AvailRHS.addAssignToList;
   *   any other assignment is passed to AvailRHS.addAssignToList unchanged.
   *
   * NOTE(review): every path returns the result of AvailRHS.addAssignToList;
   * confirm that its result type agrees with the "MidIR.assign list" declared
   * here (this signature predates the switch to AvailRHS in rev 3733).
   *)
    val expand : AvailRHS.t -> MidIR.assign -> MidIR.assign list

  end = struct

    structure IR = MidIR
    structure Op = MidOps
    structure V = IR.Var
    structure Ty = MidTypes
    structure E = Ein
    structure T = CoordSpaceTransform

  (* This file expands probed fields.
   * Take a look at the ProbeEin tex file for examples.
   * Note that the original field is an EIN operator in the form
   *     <V_alpha * H^(deltas)>(midIR.var list)
   * Param_ids are used to note the placement of the argument in the midIR.var list.
   * Index_ids keep track of the shape of an image or of a differentiation.
   * Mu binds an Index_id.
   * Generally, we will refer to the following:
   *   dim:    dimension of field V
   *   s:      support of kernel H
   *   alpha:  the alpha in <V_alpha * H^(deltas)>
   *   dx:     the dx in <V_alpha * nabla_dx H>
   *   deltas: the deltas in <V_alpha * h^(deltas) h^(deltas)>
   *   Vid:    param_id for V
   *   hid:    param_id for H
   *   fid:    param_id of the tensor added to the summation index to sample the image
   *   nid:    param_id of the tensor from which the summation index is subtracted
   *           to evaluate the kernel
   *   img:    image info about V
   *
   * NOTE(review): replaceProbe and liftProbe bind param fid to the "n" result of
   * CoordSpaceTransform.worldToIndex and nid to its "f" result, which is the
   * reverse of what the names suggest; the code is consistent, only the naming
   * is confusing.
   *)

  (* build an EIN operator from its components *)
    fun mkEin (params, index, body) = Ein.EIN{params = params, index = index, body = body}

  (* return the operator and arguments of the RHS that defines x; fails if x is
   * not defined by an operator application
   *)
    fun getRHSDst x = (case IR.Var.getDef x
           of IR.OP(rator, args) => (rator, args)
            | rhs => raise Fail(concat[
                  "expected rhs operator for ", IR.Var.toString x,
                  " but found ", IR.RHS.toString rhs
                ])
          (* end case *))

  (* return the image info of the LoadImage operation that defines imgArg *)
    fun getImageDst imgArg = (case IR.Var.getDef imgArg
           of IR.OP(Op.LoadImage(Ty.ImageTy info, _), _) => info
(* FIXME: also border control! *)
            | rhs => raise Fail (String.concat[
                  "expected image for ", IR.Var.toString imgArg,
                  " but found ", IR.RHS.toString rhs
                ])
          (* end case *))

  (* return the kernel of the Kernel operation that defines hArg *)
    fun getKernelDst hArg = (case IR.Var.getDef hArg
           of IR.OP(Op.Kernel(h, _), _) => h
            | rhs => raise Fail (String.concat[
                  "expected kernel for ", IR.Var.toString hArg,
                  " but found ", IR.RHS.toString rhs
                ])
          (* end case *))

  (* addAssigns (avail, asgns)
   *   adds the assignments in asgns to avail in list order and returns the
   *   result of the final AvailRHS.addAssignToList call, so that callers have the
   *   same result type as a single AvailRHS.addAssignToList.  asgns must be
   *   non-empty.
   *)
    fun addAssigns (_, []) = raise Fail "addAssigns: empty assignment list"
      | addAssigns (avail, [asgn]) = AvailRHS.addAssignToList (avail, asgn)
      | addAssigns (avail, asgn::rest) = (
          ignore (AvailRHS.addAssignToList (avail, asgn));
          addAssigns (avail, rest))

  (* handleArgs (avail, Vid, hid, tid, args)
   *   uses the param_ids of the image (Vid), kernel (hid), and position tensor (tid)
   *   to look up the corresponding mid-IR variables in args, and transforms the
   *   position to index space with CoordSpaceTransform.worldToIndex (which is
   *   given avail, presumably to record the code it generates).  Returns
   *     (dimension of the image, n, f, support of the kernel, P)
   *   where (n, f, P) are the variables returned by worldToIndex.
   *)
    fun handleArgs (avail, Vid, hid, tid, args) = let
          val imgArg = List.nth (args, Vid)
          val info = getImageDst imgArg
          val s = Kernel.support (getKernelDst (List.nth(args, hid)))
          val (n, f, P) = T.worldToIndex{
                  avail = avail, info = info, img = imgArg, pos = List.nth(args, tid)
                }
          in
            (ImageInfo.dim info, n, f, s, P)
          end

  (* buildPos (s, dir, dim, f)
   *   builds the position vector for EvalKernel: extracts component dir of the
   *   dim-vector f and applies BuildPos with support s.  Returns the result
   *   variable (a tensor of size 2*s) and the statements in execution order.
   *)
    fun buildPos (s, dir, dim, f) = let
          val x = V.new ("x", Ty.realTy)
          val u = V.new ("kernel_pos", Ty.TensorTy[2*s])
          val stms = [
                  IR.ASSGN(x, IR.OP(Op.Index(Ty.TensorTy[dim], dir), [f])),
                  IR.ASSGN(u, IR.OP(Op.BuildPos(s), [x]))
                ]
          in
            (u, stms)
          end

  (* getKrnDel (s, h, k, u)
   *   evaluates the k'th derivative of kernel h at the position vector u;
   *   returns the result variable and its defining statement
   *)
    fun getKrnDel (s, h, k, u) = let
          val v = V.new ("kernel_del"^Int.toString k, Ty.TensorTy[2*s])
          in
            (v, IR.ASSGN(v, IR.OP(Op.EvalKernel(2*s, h, k), [u])))
          end

  (* liftKrn (dx, dir, dim, argsA, hid, fid, s)
   *   lifted kernel expressions: generates code that evaluates the kernel defined
   *   by argsA[hid] and its derivatives up to order (length dx) at position
   *   argsA[fid] along axis dir.  Returns the result variable and the statements
   *   in execution order.  With no differentiation the result is the single
   *   evaluation; otherwise the evaluations are stacked into a
   *   (length dx + 1) x 2*s tensor.
   *   NOTE(review): not called anywhere in this file yet.
   *)
  (* TODO: match against argsA list? *)
    fun liftKrn (dx, dir, dim, argsA, hid, fid, s) = let
          val (posV, posStms) = buildPos (s, dir, dim, List.nth(argsA, fid))
        (* EvalKernel needs the kernel itself, not the variable bound to it *)
          val h = getKernelDst (List.nth(argsA, hid))
          val nKernEvals = List.length dx + 1
        (* evaluate derivative orders 0 .. nKernEvals-1; both result lists end up in
         * increasing order of derivative
         *)
          fun iter (0, vs, evalStms) = (vs, evalStms)
            | iter (k, vs, evalStms) = let
                val k = k-1
                val (v, stm) = getKrnDel(s, h, k, posV)
                in
                  iter (k, v::vs, stm::evalStms)
                end
          val (vs, evalStms) = iter (nKernEvals, [], [])
          in
          (* the position statements define posV, so they must come first *)
            case vs
             of [v] => (v, posStms @ evalStms) (* scalar result *)
              | _ => let
                  val consTy = Ty.TensorTy[nKernEvals, 2*s]
                  val resV = V.new ("kernel_cons", consTy)
                  val stm = IR.ASSGN(resV, IR.CONS(vs, consTy))
                  in
                    (resV, posStms @ evalStms @ [stm])
                  end
            (* end case *)
          end

  (* fieldReconstruction (dimO, s, sx, alpha, dx, Vid, hid, nid, fid)
   *   expands the body for the probed field:
   *     Sum_{i_d in [1-s, s], d < dimO} V_alpha[fid + i] * Prod_d H^(dx)(nid_d - i_d)
   *   The summation variables are E.V sx, ..., E.V (sx+dimO-1), so sx must be an
   *   index id that is not otherwise in use.  Vid and hid are the param_ids of
   *   the image and kernel; fid and nid are the param_ids of the position tensors.
   *)
    fun fieldReconstruction (dimO, s, sx, alpha, dx, Vid, hid, nid, fid) = let
        (* 1-d fields: the position tensors are indexed with no components *)
          fun createKRND1 () = let
                val imgpos = [E.Opn(E.Add, [E.Tensor(fid, []), E.Value sx])]
                val deltas = List.map (fn e => (E.C 0, e)) dx
                val rest = E.Krn(hid, deltas, E.Op2(E.Sub, E.Tensor(nid, []), E.Value sx))
                in
                  E.Opn(E.Prod, [E.Img(Vid, alpha, imgpos), rest])
                end
        (* createKRN: image and kernel terms, built from dimension d-1 down to 0 *)
          fun createKRN (0, imgpos, rest) = E.Opn(E.Prod, E.Img(Vid, alpha, imgpos)::rest)
            | createKRN (d, imgpos, rest) = let
                val d' = d-1
                val cx = E.C(d')
                val vsum = E.Value(sx+d')
                val pos0 = E.Opn(E.Add, [E.Tensor(fid, [cx]), vsum])
                val deltas = List.map (fn e => (cx, e)) dx
                val rest0 = E.Krn(hid, deltas, E.Op2(E.Sub, E.Tensor(nid, [cx]), vsum))
                in
                  createKRN (d', pos0::imgpos, rest0::rest)
                end
        (* summation indices for the body; numbered from sx so that they bind the
         * E.Value(sx+d) references above
         *)
          val esum = List.tabulate (dimO, fn d => (E.V(sx+d), 1-s, s))
          val exp = if (dimO = 1) then createKRND1() else createKRN(dimO, [], [])
          in
            E.Sum(esum, exp)
          end

  (* getsumshift (sx, n)
   *   returns a fresh index_id: one more than the id of the last summation index
   *   in sx, or n when sx is empty.
   *   NOTE(review): assumes the last entry of sx carries the largest id.
   *)
    fun getsumshift ([], n) = n
      | getsumshift (sx, _) = (case List.last sx
           of (E.V v, _, _) => v+1
            | _ => raise Fail "getsumshift: expected a variable summation index"
          (* end case *))

  (* formBody: drops empty summations and unwraps a singleton product that
   * appears directly under the (possibly nested) summations
   *)
    fun formBody (E.Sum([], e)) = formBody e
      | formBody (E.Sum(sx, e)) = E.Sum(sx, formBody e)
      | formBody (E.Opn(E.Prod, [e])) = e
      | formBody e = e

  (* multiPs (Ps, sx, body)
   *   builds Sum(sx, Prod(...)) of the transformation factors Ps and body.  When
   *   there are 3 or 4 factors the body is placed last (a silly change in order of
   *   the product to match the vis branch WorldtoSpace functions); otherwise it is
   *   placed first.
   *)
    fun multiPs (Ps, sx, body) = let
          val exp = (case Ps
                 of [P0, P1, P2] => [P0, P1, P2, body]
                  | [P0, P1, P2, P3] => [P0, P1, P2, P3, body]
                  | _ => body::Ps
                (* end case *))
          in
            formBody(E.Sum(sx, E.Opn(E.Prod, exp)))
          end

  (* arrangeBody (body, Ps, newsx, exp)
   *   rebuilds body with the probe replaced by exp combined with the factors Ps.
   *   Returns (splitvar, body'), where splitvar is false only when the probe is
   *   multiplied by another term (eps0) inside a summation.
   *)
    fun arrangeBody (body, Ps, newsx, exp) = (case body
           of E.Sum(sx, E.Probe _ ) => (true, multiPs(Ps, sx@newsx, exp))
            | E.Sum(sx, E.Opn(E.Prod, [eps0, E.Probe _ ])) =>
                (false, E.Sum(sx, E.Opn(E.Prod, [eps0, multiPs(Ps, newsx, exp)])))
            | E.Probe _ => (true, multiPs(Ps, newsx, exp))
            | _ => raise Fail "impossible"
          (* end case *))

  (* replaceProbe (avail, (y, einapp), probe, sx)
   *   transforms the position to index space, rewrites the body with the
   *   expanded probe (transforming the result back with T.imageToWorld), and
   *   records the new assignment to y in avail.  The new EIN operator gets three
   *   extra params and arguments:
   *     param fid (= length params) <- n, param nid <- f, param Pid <- P.
   *)
    fun replaceProbe (avail, (y, IR.EINAPP(Ein.EIN{params, index, body}, args)), probe, sx) = let
          val fid = length params
          val nid = fid+1
          val Pid = nid+1
          val E.Probe(E.Conv(Vid, alpha, hid, dx), E.Tensor(tid, _)) = probe
          val (dim, n, f, s, PArg) = handleArgs (avail, Vid, hid, tid, args)
          val freshIndex = getsumshift (sx, length index)
          val (dx', sx', Ps) = T.imageToWorld (freshIndex, dim, dx, Pid)
          val params' = params @ [E.TEN(true, [dim]), E.TEN(true, [dim]), E.TEN(true, [dim, dim])]
          val probe' = fieldReconstruction (
                dim, s, freshIndex + length dx', alpha, dx', Vid, hid, nid, fid)
          val (_, body') = arrangeBody (body, Ps, sx', probe')
          val einapp = (y, IR.EINAPP(mkEin(params', index, body'), args @ [n, f, PArg]))
          in
            AvailRHS.addAssignToList (avail, einapp)
          end
      | replaceProbe _ = raise Fail "replaceProbe: expected an EIN application"

  (* createEinApp (body, alpha, index, freshIndex, dim, dx, sx)
   *   transform T*P*P..Ps: builds the EIN operator that combines the factors Ps
   *   from T.imageToWorld (param 0, a dim x dim tensor) with the lifted field
   *   value (param 1, a tensor of shape sizes).  Returns
   *     (splitvar, ein0, sizes, dx', alpha')
   *   where splitvar comes from arrangeBody and dx'/alpha' are the rewritten
   *   differentiation and image indices.
   *)
    fun createEinApp (body, alpha, index, freshIndex, dim, dx, sx) = let
          val Pid = 0
          val tid = 1
          val (dx', newsx, Ps) = T.imageToWorld(freshIndex, dim, dx, Pid)
        (* need to rewrite dx *)
          val sxx = sx@newsx
        (* 9 and 7 are dummy image/kernel ids: only alpha' and dx' are read back
         * from the cleaned Conv
         *)
          val (_, sizes, E.Conv(_, alpha', _, dx')) = (case sxx
                 of [] => ([], index, E.Conv(9, alpha, 7, dx'))
                  | _ => CleanIndex.clean(E.Conv(9, alpha, 7, dx'), index, sxx)
                (* end case *))
        (* the lifted tensor is indexed by the non-constant image indices followed
         * by the differentiation indices
         *)
          val exp = E.Tensor(tid, List.filter (fn (E.C _) => false | _ => true) alpha' @ dx')
          val (splitvar, body') = arrangeBody(body, Ps, newsx, exp)
          val params = [E.TEN(true, [dim, dim]), E.TEN(true, sizes)]
          val ein0 = mkEin(params, index, body')
          in
            (splitvar, ein0, sizes, dx', alpha')
          end

  (* liftProbe (avail, (y, einapp), probe, sx)
   *   floats the reconstructed field term: the probe is reconstructed into a fresh
   *   variable T (assigned first), and y is defined by an EIN operator that
   *   combines T with the transformation matrix P.  All resulting assignments are
   *   recorded in avail in execution order.
   *)
    fun liftProbe (avail, (y, IR.EINAPP(Ein.EIN{params, index, body}, args)), probe, sx) = let
          val fid = length params
          val nid = fid+1
          val E.Probe(E.Conv(Vid, alpha, hid, dx), E.Tensor(tid, _)) = probe
          val (dim, n, f, s, PArg) = handleArgs(avail, Vid, hid, tid, args)
          val freshIndex = getsumshift(sx, length index)
        (* transform T*P*P..Ps *)
          val (splitvar, ein0, sizes, dx', alpha') =
                createEinApp (body, alpha, index, freshIndex, dim, dx, sx)
          val FArg = V.new ("T", Ty.TensorTy(sizes))
          val rtn0 = if splitvar
                then FloatEin.transform(y, EinSums.transform ein0, [PArg, FArg])
                else [(y, IR.EINAPP(ein0, [PArg, FArg]))]
        (* reconstruct the lifted probe; params fid and nid are bound to n and f *)
          val params' = params @ [E.TEN(true, [dim]), E.TEN(true, [dim])]
          val freshIndex' = length sizes
          val body' = fieldReconstruction (dim, s, freshIndex', alpha', dx', Vid, hid, nid, fid)
          val einApp1 = IR.EINAPP(mkEin(params', sizes, body'), args @ [n, f])
          in
          (* FArg must be defined before the rtn0 code that uses it *)
            addAssigns (avail, (FArg, einApp1) :: rtn0)
          end
      | liftProbe _ = raise Fail "liftProbe: expected an EIN application"

  (* expand: at this point we only have simple EIN ops.  Looks to see if the
   * expression has a probe; if so, replaces it.  Assignments without a probe
   * (including non-EIN assignments) are passed through unchanged.
   *)
    fun expand avail (e as (_, IR.EINAPP(Ein.EIN{body, ...}, _))) = (case body
           of E.Probe(E.Conv(_, _, _, []), _) =>
                replaceProbe (avail, e, body, [])       (* no dx *)
            | E.Probe(E.Conv _, _) =>
                liftProbe (avail, e, body, [])          (* differentiated probe *)
            | E.Sum(sx, p as E.Probe(E.Conv(_, _, _, []), _)) =>
                replaceProbe (avail, e, p, sx)          (* no dx *)
            | E.Sum(sx, p as E.Probe(E.Conv(_, [], _, _), _)) =>
                liftProbe (avail, e, p, sx)             (* scalar field *)
            | E.Sum(sx, p as E.Probe _) =>
                replaceProbe (avail, e, p, sx)
            | E.Sum(sx, E.Opn(E.Prod, [_, p as E.Probe _])) =>
                replaceProbe (avail, e, p, sx)
            | _ => AvailRHS.addAssignToList (avail, e)
          (* end case *))
      | expand avail e = AvailRHS.addAssignToList (avail, e)

  end (* ProbeEin *)

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0