Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] Annotation of /branches/vis15/src/compiler/high-to-mid/probe-ein.sml
ViewVC logotype

Annotation of /branches/vis15/src/compiler/high-to-mid/probe-ein.sml

Parent Directory Parent Directory | Revision Log Revision Log


Revision 3581 - (view) (download)

1 : jhr 3550 (* probe-ein.sml
2 :     *
3 :     * This code is part of the Diderot Project (http://diderot-language.cs.uchicago.edu)
4 :     *
5 :     * COPYRIGHT (c) 2016 The University of Chicago
6 :     * All rights reserved.
7 :     *)
8 :    
structure ProbeEin : sig

  (* NOTE(review): the signature is empty, so nothing below is visible outside
   * this structure; presumably expandEinOp is meant to be exported once this
   * file is finished -- confirm.
   *)

  end = struct

    structure IR = MidIR
    structure IROp = MidOps
    structure V = IR.Var
    structure Ty = MidTypes
    structure E = Ein
    structure T = CoordSpaceTransform

  (* This file expands probed fields.
   * See the ProbeEin tex file for examples.
   * The original field is an EIN operator of the form
   *     <V_alpha * H^(deltas)>(midIR.var list)
   * Param_ids record the position of an argument in the midIR.var list.
   * Index_ids keep track of the shape of an image or of a differentiation.
   * A mu binds an Index_id.
   * Throughout this file we use the following names:
   *   dim    : dimension of field V
   *   s      : support of kernel H
   *   alpha  : the alpha in <V_alpha * H^(deltas)>
   *   dx     : the dx in <V_alpha * nabla_dx H>
   *   deltas : the deltas in <V_alpha * h^(deltas) h^(deltas)>
   *   Vid    : param_id for V
   *   hid    : param_id for H
   *   nid    : param_id for the integer position
   *   fid    : param_id for the fractional position
   *   info   : the ImageInfo for V
   *)

  (* build an EIN operator from its three components *)
    fun mkEin (params, index, body) = E.EIN{params = params, index = index, body = body}

  (* getRHSDst x
   * Returns the (operator, arguments) pair of the right-hand side that defines x,
   * following chains of variable-to-variable copies.  Raises Fail when x is not
   * bound to an operator application.
   *)
    fun getRHSDst x = (case V.binding x
           of IR.VB_RHS(IR.OP(rator, args)) => (rator, args)
            | IR.VB_RHS(IR.VAR x') => getRHSDst x'
            | vb => raise Fail(concat[
                  "expected rhs operator for ", V.toString x,
                  " but found ", IR.vbToString vb
                ])
          (* end case *))

  (* getImageDst imgArg
   * Returns the image info of the LoadImage that defines imgArg; raises Fail otherwise.
   *)
    fun getImageDst imgArg = (case getRHSDst imgArg
           of (IROp.LoadImage(_, _, info), _) => info
            | (rator, _) => raise Fail(String.concat[" Expected Image: ", IROp.toString rator])
          (* end case *))

  (* getKernelDst hArg
   * Returns the support of the kernel that defines hArg; raises Fail otherwise.
   *)
    fun getKernelDst hArg = (case getRHSDst hArg
           of (IROp.Kernel(h, _), _) => Kernel.support h
            | (rator, _) => raise Fail(String.concat["Expected kernel: ", IROp.toString rator])
          (* end case *))

  (* handleArgs (Vid, hid, tid, args)
   * Uses the param_ids of the image, kernel, and position tensor to find their
   * mid-IR variables in args, then transforms the position to index space with
   * CoordSpaceTransform.worldToIndex.
   * Returns (dim, args @ argsT, code, s, P) where
   *   dim   : dimension of the image
   *   argsT : extra argument variables produced by worldToIndex
   *   code  : the code produced by worldToIndex
   *   s     : support of the kernel
   *   P     : the mid-IR variable for the (transformation matrix) transpose
   *)
    fun handleArgs (Vid, hid, tid, args) = let
          val imgArg = List.nth(args, Vid)
          val info = getImageDst imgArg
          val s = getKernelDst (List.nth(args, hid))
          val (argsT, P, code) = T.worldToIndex{info = info, img = imgArg, pos = List.nth(args, tid)}
          in
            (ImageInfo.dim info, args @ argsT, code, s, P)
          end

  (* fieldreconstruction (dimO, s, sx, alpha, dx, Vid, hid, nid, fid)
   * Expands the body of the probed field into a sum over the kernel support:
   *     Sum_{sx .. sx+dimO-1} V_alpha[f + i] * Prod_d H^(dx)(n_d - i_d)
   * sx is the first unused index id; the summation binds the indices
   * sx, sx+1, ..., sx+dimO-1, one per axis, each ranging over [1-s, s].
   *)
    fun fieldreconstruction (dimO, s, sx, alpha, dx, Vid, hid, nid, fid) = let
        (* 1-d fields: the positions are scalars, so fid/nid take no index *)
          fun createKRND1 () = let
                val imgpos = [E.Opn(E.Add, [E.Tensor(fid, []), E.Value sx])]
                val deltas = List.map (fn e => (E.C 0, e)) dx
                val rest = E.Krn(hid, deltas, E.Op2(E.Sub, E.Tensor(nid, []), E.Value sx))
                in
                  E.Opn(E.Prod, [E.Img(Vid, alpha, imgpos), rest])
                end
        (* higher-dimensional fields: one image-position component and one kernel
         * factor per axis; axes are processed from the last down to 0 and consed
         * on, so both lists come out in axis order.
         *)
          fun createKRN (0, imgpos, rest) = E.Opn(E.Prod, E.Img(Vid, alpha, imgpos) :: rest)
            | createKRN (d, imgpos, rest) = let
                val d' = d - 1
                val cx = E.C d'
                val vsum = E.Value(sx + d')
                val pos0 = E.Opn(E.Add, [E.Tensor(fid, [cx]), vsum])
                val deltas = List.map (fn e => (cx, e)) dx
                val rest0 = E.Krn(hid, deltas, E.Op2(E.Sub, E.Tensor(nid, [cx]), vsum))
                in
                  createKRN (d', pos0 :: imgpos, rest0 :: rest)
                end
        (* The summation must bind exactly the indices that the E.Value(sx + d)
         * terms above refer to.  Binding E.V d (starting at 0) would leave those
         * terms unbound and capture the caller's outer indices instead.
         *)
          val esum = List.tabulate(dimO, fn d => (E.V(sx + d), 1 - s, s))
          val exp = if (dimO = 1) then createKRND1() else createKRN(dimO, [], [])
          in
            E.Sum(esum, exp)
          end

  (* getsumshift (sx, n)
   * Returns a fresh (unused) index id: n when there are no summation indices,
   * otherwise one more than the id of the last summation index in sx.
   *)
    fun getsumshift ([], n) = n
      | getsumshift (sx, _) = (case List.last sx
           of (E.V v, _, _) => v + 1
            | _ => raise Fail "getsumshift: expected a variable summation index"
          (* end case *))

  (* formBody : ein_exp -> ein_exp
   * Drops empty summations and unwraps a single-factor product at the top of
   * the expression (under any summations).
   *)
    fun formBody (E.Sum([], e)) = formBody e
      | formBody (E.Sum(sx, e)) = E.Sum(sx, formBody e)
      | formBody (E.Opn(E.Prod, [e])) = e
      | formBody e = e

  (* multiPs (Ps, sx, body)
   * Multiplies body by the transformation terms Ps under the summation sx.
   * The order of the product is deliberately chosen to match the vis branch's
   * WorldtoSpace functions: with three or four P terms body comes last,
   * otherwise it comes first.
   *)
    fun multiPs (Ps, sx, body) = let
          val factors = (case Ps
                 of [_, _, _] => Ps @ [body]
                  | [_, _, _, _] => Ps @ [body]
                  | _ => body :: Ps
                (* end case *))
          in
            formBody(E.Sum(sx, E.Opn(E.Prod, factors)))
          end

  (* arrangeBody (body, Ps, newsx, exp)
   * Replaces the probe inside body with the product of exp and the Ps.
   * The boolean result is true when the probe was the whole (possibly summed)
   * body, and false when it was the second factor of a product.
   *)
    fun arrangeBody (body, Ps, newsx, exp) = (case body
           of E.Sum(sx, E.Probe _) => (true, multiPs(Ps, sx @ newsx, exp))
            | E.Sum(sx, E.Opn(E.Prod, [eps0, E.Probe _])) =>
                (false, E.Sum(sx, E.Opn(E.Prod, [eps0, multiPs(Ps, newsx, exp)])))
            | E.Probe _ => (true, multiPs(Ps, newsx, exp))
            | _ => raise Fail "impossible"
          (* end case *))

  (* replaceProbe ((y, einapp), probe, sx)
   * Replaces the probe in the EIN operator bound to y with its expanded field
   * reconstruction:
   *   - transforms the position from world to index space,
   *   - transforms the result back with the Ps from imageToWorld,
   *   - rewrites the body and extends the parameters with the fractional
   *     position, the integer position, and the matrix P.
   * Returns the worldToIndex code followed by the new assignment to y.
   *)
    fun replaceProbe ((y, IR.EINAPP(E.EIN{params, index, body}, args)), probe, sx) = let
          val fid = length params
          val nid = fid + 1
          val Pid = nid + 1
          val E.Probe(E.Conv(Vid, alpha, hid, dx), E.Tensor(tid, _)) = probe
          val (dim, argsA, code, s, PArg) = handleArgs(Vid, hid, tid, args)
          val freshIndex = getsumshift(sx, length index)
          val (dx', sx', Ps) = T.imageToWorld(freshIndex, dim, dx, Pid)
          val params' = params @ [E.TEN(true, [dim]), E.TEN(true, [dim]), E.TEN(true, [dim, dim])]
        (* the reconstruction's summation indices start after the ones in sx' *)
          val probe' = fieldreconstruction(dim, s, freshIndex + length dx', alpha, dx', Vid, hid, nid, fid)
          val (_, body') = arrangeBody(body, Ps, sx', probe')
          val einapp = (y, IR.EINAPP(mkEin(params', index, body'), argsA @ [PArg]))
          in
            code @ [einapp]
          end

  (* createEinApp (body, alpha, index, freshIndex, dim, dx, sx)
   * Builds the EIN operator T*P*P..Ps that applies the transformation terms to
   * a lifted field value T (parameter 1); P is parameter 0.
   * Returns (splitvar, ein, sizes, dx', alpha') where sizes is the shape of T
   * and dx'/alpha' are the rewritten differentiation and image indices.
   *)
    fun createEinApp (body, alpha, index, freshIndex, dim, dx, sx) = let
          val Pid = 0
          val tid = 1
          val (dx0, newsx, Ps) = T.imageToWorld(freshIndex, dim, dx, Pid)
          val sxx = sx @ newsx
        (* 9 and 7 are placeholder Vid/hid values: the ids in the result are
         * discarded, only the rewritten index lists are kept.
         *)
          val (_, sizes, E.Conv(_, alpha', _, dx')) = (case sxx
                 of [] => ([], index, E.Conv(9, alpha, 7, dx0))
                  | _ => CleanIndex.clean(E.Conv(9, alpha, 7, dx0), index, sxx)
                (* end case *))
        (* T is indexed by the non-constant image indices followed by dx' *)
          val exp = E.Tensor(tid, (List.filter (fn E.C _ => false | _ => true) alpha') @ dx')
          val (splitvar, body') = arrangeBody(body, Ps, newsx, exp)
          val params = [E.TEN(true, [dim, dim]), E.TEN(true, sizes)]
          val ein0 = mkEin(params, index, body')
          in
            (splitvar, ein0, sizes, dx', alpha')
          end

  (* liftProbe ((y, einapp), probe, sx)
   * Floats the reconstructed field term out into its own variable T, then
   * assigns to y the operator that applies the transformation terms to T.
   * When the probe was the whole body (splitvar), the outer operator is passed
   * through EinSums.transform and FloatEin.transform.
   *)
    fun liftProbe ((y, IR.EINAPP(E.EIN{params, index, body}, args)), probe, sx) = let
          val fid = length params
          val nid = fid + 1
          val E.Probe(E.Conv(Vid, alpha, hid, dx), E.Tensor(tid, _)) = probe
          val (dim, args', code, s, PArg) = handleArgs(Vid, hid, tid, args)
          val freshIndex = getsumshift(sx, length index)
        (* transform T*P*P..Ps *)
          val (splitvar, ein0, sizes, dx', alpha') = createEinApp(body, alpha, index, freshIndex, dim, dx, sx)
          val FArg = V.new("T", Ty.TensorTy sizes)
          val einApp0 = IR.EINAPP(ein0, [PArg, FArg])
        (* NOTE(review): this match only handles IR.ASSGN results *)
          val rtn0 = if splitvar
                then List.map (fn IR.ASSGN e => e) (FloatEin.transform(y, EinSums.transform ein0, [PArg, FArg]))
                else [(y, einApp0)]
        (* reconstruct the lifted probe; its output indices are 0..length sizes - 1,
         * so its summation indices start at length sizes.
         * FIXME: these params were previously E.TEN(3,[dim]), E.TEN(1,[dim]);
         * original note: "will get type error later".
         *)
          val params' = params @ [E.TEN(true, [dim]), E.TEN(true, [dim])]
          val body' = fieldreconstruction(dim, s, length sizes, alpha', dx', Vid, hid, nid, fid)
          val einApp1 = IR.EINAPP(mkEin(params', sizes, body'), args')
          in
            code @ ((FArg, einApp1) :: rtn0)
          end

  (* expandEinOp : code -> code list
   * At this point we only have simple EIN operators.  If the operator's body
   * is a probe (possibly under a summation, or as the second factor of a
   * product under a summation), it is replaced by its expansion; otherwise
   * the assignment is returned unchanged.
   *)
    fun expandEinOp (e as (_, IR.EINAPP(E.EIN{body, ...}, _))) = (case body
           of E.Probe(E.Conv(_, _, _, []), _) => replaceProbe(e, body, [])  (* no dx *)
            | E.Probe(E.Conv _, _) => liftProbe(e, body, [])
            | E.Sum(sx, p as E.Probe(E.Conv(_, _, _, []), _)) => replaceProbe(e, p, sx)  (* no dx *)
            | E.Sum(sx, p as E.Probe(E.Conv(_, [], _, _), _)) => liftProbe(e, p, sx)  (* scalar field *)
            | E.Sum(sx, p as E.Probe _) => replaceProbe(e, p, sx)
            | E.Sum(sx, E.Opn(E.Prod, [_, p as E.Probe _])) => replaceProbe(e, p, sx)
            | _ => [e]
          (* end case *))

  end (* ProbeEin *)

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0