Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] Annotation of /branches/vis15/src/compiler/high-to-mid/probe-ein.sml
ViewVC logotype

Annotation of /branches/vis15/src/compiler/high-to-mid/probe-ein.sml

Parent Directory Parent Directory | Revision Log Revision Log


Revision 3569 - (view) (download)

1 : jhr 3550 (* probe-ein.sml
2 :     *
3 :     * This code is part of the Diderot Project (http://diderot-language.cs.uchicago.edu)
4 :     *
5 :     * COPYRIGHT (c) 2016 The University of Chicago
6 :     * All rights reserved.
7 :     *)
8 :    
structure ProbeEin : sig

  (* NOTE(review): this signature is empty, so nothing defined below
   * (including expandEinOp) is visible outside this structure.
   *)

  end = struct

    structure E = Ein
    structure DstIL = MidIL
    structure DstOp = MidOps
    structure T = TransformEin
    structure MidToS = MidToString
    structure DstV = DstIL.Var
    structure DstTy = MidILTypes

  (* This file expands probed fields.
   * Take a look at the ProbeEin tex file for examples.
   * Note that the original field is an EIN operator of the form
   *     <V_alpha * H^(deltas)>(midIL.var list)
   * Param_ids record the position of an argument in the midIL.var list.
   * Index_ids keep track of the shape of an image or of a differentiation.
   * Mu binds an Index_id.
   * Generally, we use the following names:
   *     dim:    dimension of field V
   *     s:      support of kernel H
   *     alpha:  the alpha in <V_alpha * H^(deltas)>
   *     deltas: the deltas in <V_alpha * H^(deltas)>
   *     Vid:    param_id for V
   *     hid:    param_id for H
   *     nid:    param_id for the integer position
   *     fid:    param_id for the fractional position
   *     img:    image info for V
   *)

  (* FIXME: what are these for? should they be settable from the command-line?
   * NOTE(review): of these, only valnumflag is read in this file (by expandEinOp).
   *)
    val valnumflag = true
    val tsplitvar = true
    val fieldliftflag = true
    val detflag = true

  (* local aliases for helpers defined in TransformEin and Ein *)
    fun transformToIndexSpace e = T.transformToIndexSpace e
    fun transformToImgSpace e = T.transformToImgSpace e

    fun mkEin e = Ein.mkEin e
    fun mkEinApp (rator, args) = DstIL.EINAPP(rator, args)
    fun setConst e = E.setConst e
    fun setNeg e = E.setNeg e
    fun setExp e = E.setExp e
    fun setDiv e = E.setDiv e
    fun setSub e = E.setSub e
    fun setProd e = E.setProd e
    fun setAdd e = E.setAdd e

  (* getRHSDst : DstIL.var -> rator * args
   * Returns the operator and arguments bound to x, following any chain of
   * variable-to-variable bindings.  Raises Fail if x is not (eventually)
   * bound to an operator application.
   *)
    fun getRHSDst x = (case DstIL.Var.binding x
           of DstIL.VB_RHS(DstIL.OP(rator, args)) => (rator, args)
            | DstIL.VB_RHS(DstIL.VAR x') => getRHSDst x'
            | vb => raise Fail(concat[
                  "expected rhs operator for ", DstIL.Var.toString x,
                  " but found ", DstIL.vbToString vb
                ])
          (* end case *))

  (* getArgsDst : MidIL.var * MidIL.var * MidIL.var list -> int * ImageInfo * int
   * Looks up the definitions of the kernel argument (hArg) and the image
   * argument (imgArg) and returns
   *     (support of the kernel, image info, dimension of the image).
   * Raises Fail unless hArg is bound to a Kernel and imgArg to a LoadImage.
   * The third argument is currently unused.
   *)
    fun getArgsDst (hArg, imgArg, args) = (case (getRHSDst hArg, getRHSDst imgArg)
           of ((DstOp.Kernel(h, _), _), (DstOp.LoadImage(_, _, img), _)) =>
                (Kernel.support h, img, ImageInfo.dim img)
            | ((k, _), (i, _)) => raise Fail(String.concat[
                  "Expected kernel: ", DstOp.toString k, ", Expected Image: ", DstOp.toString i
                ])
          (* end case *))

  (* handleArgs : int * int * int * MidIL.var list
   *     -> int * MidIL.var list * code * int * MidIL.var
   * Uses the Param_ids for the image (Vid), kernel (hid), and position tensor
   * (tid) to select the corresponding Mid-IL arguments, then passes the
   * position to transformToImgSpace.  Returns
   *     (dim, args @ new arguments, code, kernel support, P)
   * where the new arguments, code and P all come from transformToImgSpace.
   * NOTE(review): P is described elsewhere as the (transposed) transformation
   * matrix -- confirm against TransformEin.
   *)
    fun handleArgs (Vid, hid, tid, args) = let
          val imgArg = List.nth(args, Vid)
          val hArg = List.nth(args, hid)
          val newposArg = List.nth(args, tid)
          val (s, img, dim) = getArgsDst(hArg, imgArg, args)
          val (argsT, P, code) = transformToImgSpace(dim, img, newposArg, imgArg)
          in
            (dim, args @ argsT, code, s, P)
          end

  (* createBody : int * int * int * mu list * mu list
   *     * param_id * param_id * param_id * param_id -> ein_exp
   * Expands the body of the probed field.  The result is a sum over the dim
   * summation indices E.V sx .. E.V (sx+dim-1), each ranging over [1-s, s], of
   * the product of the image V_alpha sampled at (fid + summation index) and one
   * kernel term per axis evaluated at (nid - summation index), with the deltas
   * as derivative indices.
   *)
    fun createBody (dim, s, sx, alpha, deltas, Vid, hid, nid, fid) = let
        (* 1-d fields: the positions are scalars, so they are not indexed *)
          fun createKRND1 () = let
                val dels = List.map (fn e => (E.C 0, e)) deltas
                val pos = [setAdd[E.Tensor(fid, []), E.Value sx]]
                val krn = E.Krn(hid, dels, setSub(E.Tensor(nid, []), E.Value sx))
                in
                  setProd [E.Img(Vid, alpha, pos), krn]
                end
        (* image field and kernels: axes are processed from dim-1 down to 0 and
         * prepended, so both the image positions and the kernels end up in
         * axis order
         *)
          fun createKRN (0, imgpos, rest) = setProd ([E.Img(Vid, alpha, imgpos)] @ rest)
            | createKRN (dim, imgpos, rest) = let
                val dim' = dim-1
                val sum = sx+dim'
                val dels = List.map (fn e => (E.C dim', e)) deltas
                val pos = [setAdd[E.Tensor(fid, [E.C dim']), E.Value sum]]
                val rest' = E.Krn(hid, dels, setSub(E.Tensor(nid, [E.C dim']), E.Value sum))
                in
                  createKRN(dim', pos @ imgpos, rest' :: rest)
                end
          val exp = (case dim
                 of 1 => createKRND1()
                  | _ => createKRN(dim, [], [])
                (* end case *))
        (* summation indices for the body: one per axis, each over [1-s, s] *)
          val slb = 1-s
          val esum = List.tabulate(dim, fn i => (E.V(i+sx), slb, s))
          in
            E.Sum(esum, exp)
          end

  (* getsumshift : sum_id list * int -> int
   * Returns a fresh/unused index id: n when there are no summation indices,
   * otherwise one more than the id of the last summation index in sx.
   * Both cases must be clauses of a single fun: List.last raises Empty on [].
   *)
    fun getsumshift ([], n) = n
      | getsumshift (sx, n) = let
          val (E.V v, _, _) = List.last sx
          in
            v+1
          end

  (* formBody : ein_exp -> ein_exp
   * Removes empty summations and unwraps a single-factor product, looking
   * through (nested) summations only.
   *)
    fun formBody (E.Sum([], e)) = formBody e
      | formBody (E.Sum(sx, e)) = E.Sum(sx, formBody e)
      | formBody (E.Opn(E.Prod, [e])) = e
      | formBody e = e

  (* multiPs : ein_exp list * sum_id list * ein_exp -> ein_exp
   * Multiplies body by the matrices Ps, sums over sx, and simplifies the result
   * with formBody.  With three or four matrices the Ps come before body (a
   * silly change in the order of the product to match the vis branch
   * WorldtoSpace functions); otherwise body comes first.
   *)
    fun multiPs ([P0, P1, P2], sx, body) = formBody(E.Sum(sx, setProd[P0, P1, P2, body]))
(*
      | multiPs ([P0, P1], sx, body) = formBody(E.Sum(sx, setProd([P0, body, P1])))
*)
      | multiPs ([P0, P1, P2, P3], sx, body) = formBody(E.Sum(sx, setProd[P0, P1, P2, P3, body]))
      | multiPs (Ps, sx, body) = formBody(E.Sum(sx, setProd(body :: Ps)))

  (* multiMergePs : with two matrices and two summation indices, nests the sums
   * as Sum(sx0, P0 * Sum(sx1, P1 * body)); otherwise behaves like multiPs.
   * NOTE(review): not used in this file.
   *)
    fun multiMergePs ([P0, P1], [sx0, sx1], body) =
          E.Sum([sx0], setProd[P0, E.Sum([sx1], setProd[P1, body])])
      | multiMergePs e = multiPs e

  (* replaceProbe : int * (var * rhs) * ein_exp * sum_id list -> code
   * Replaces probe p in the EIN application bound to y with its expanded form:
   *   - the position argument is transformed via handleArgs,
   *   - the derivative indices are transformed with transformToIndexSpace,
   *     which yields the new dx, extra summation indices and the matrices Ps,
   *   - the body is built by createBody and multiplied by the Ps,
   *   - the outer summation of the original body (and its eps factor, if any)
   *     is re-applied.
   * Returns the code from handleArgs followed by the new assignment to y.
   * The first argument (testN) is unused.
   *)
    fun replaceProbe (testN, (y, DstIL.EINAPP(e, args)), p, sx) = let
          val params = Ein.params e
          val index = Ein.index e
        (* ids of the three parameters appended in params' below *)
          val fid = length params
          val nid = fid+1
          val Pid = nid+1
          val E.Probe(E.Conv(Vid, alpha, hid, dx), E.Tensor(tid, _)) = p
          val nshift = length dx
          val (dim, argsA, code, s, PArg) = handleArgs(Vid, hid, tid, args)
          val freshIndex = getsumshift(sx, length index)
          val (dx', newsx1, Ps) = transformToIndexSpace(freshIndex, dim, dx, Pid)
          val params' = params @ [E.TEN(3, [dim]), E.TEN(1, [dim]), E.TEN(1, [dim, dim])]
          val body' = createBody(dim, s, freshIndex+nshift, alpha, dx', Vid, hid, nid, fid)
          val body' = multiPs(Ps, newsx1, body')
          val body' = (case Ein.body e
                 of E.Sum(sx', E.Probe _) => E.Sum(sx', body')
                  | E.Sum(sx', E.Opn(E.Prod, [eps0, E.Probe _])) =>
                      E.Sum(sx', setProd[eps0, body'])
                  | _ => body'
                (* end case *))
          val einapp = (y, mkEinApp(mkEin(params', index, body'), argsA @ [PArg]))
          in
            code @ [einapp]
          end

  (* createEinApp : ein_exp * mu list * index * int * int * mu list * sum_id list
   *     -> bool * ein * int list * mu list * mu list
   * Builds the EIN operator ein0 with parameters (P, F) that multiplies a tensor
   * F by the index-space transformation matrices Ps.  Returns
   *     (splitvar, ein0, sizes, dx', alpha')
   * where sizes is the shape of F and dx'/alpha' are the derivative and alpha
   * indices after cleanIndex.  splitvar is false only when the original body
   * has the form Sum(sx, eps * probe).
   *)
    fun createEinApp (originalb, alpha, index, freshIndex, dim, dx, sx) = let
          val Pid = 0
          val tid = 1
        (* Assumes body is already clean *)
          val (newdx, newsx, Ps) = transformToIndexSpace(freshIndex, dim, dx, Pid)
        (* need to rewrite dx; the image/kernel ids 9 and 7 in the Conv are
         * placeholders -- only the cleaned alpha and dx are read back
         *)
          val (_, sizes, E.Conv(_, alpha', _, dx')) = (case sx @ newsx
                 of [] => ([], index, E.Conv(9, alpha, 7, newdx))
                  | _ => cleanIndex.cleanIndex(E.Conv(9, alpha, 7, newdx), index, sx @ newsx)
                (* end case *))
          val params = [E.TEN(1, [dim, dim]), E.TEN(1, sizes)]
        (* the shape of F: the non-constant alpha indices followed by newdx *)
          val tshape = List.filter (fn E.C _ => false | _ => true) alpha' @ newdx
          val t = E.Tensor(tid, tshape)
          val (splitvar, body) = (case originalb
                 of E.Sum(sx', E.Probe _) => (true, multiPs(Ps, sx' @ newsx, t))
                  | E.Sum(sx', E.Opn(E.Prod, [eps0, E.Probe _])) =>
                      (false, E.Sum(sx', setProd[eps0, multiPs(Ps, newsx, t)]))
                  | _ => (true, multiPs(Ps, newsx, t))
                (* end case *))
          val ein0 = mkEin(params, index, body)
          in
            (splitvar, ein0, sizes, dx', alpha')
          end

  (* liftProbe : int * (var * rhs) * ein_exp * sum_id list -> code
   * Expands probe p by lifting it: the probe itself (without the transformation
   * matrices) is computed into a fresh variable F, and y is then computed from
   * P and F with the operator built by createEinApp.  Unless splitvar is false,
   * that operator is passed through SummationEin.main and Split.splitEinApp.
   * Returns the code from handleArgs, then the assignment to F, then the code for y.
   * The first argument (printStrings) is unused.
   *)
    fun liftProbe (printStrings, (y, DstIL.EINAPP(e, args)), p, sx) = let
          val params = Ein.params e
          val index = Ein.index e
          val fid = length params
          val nid = fid+1
          val E.Probe(E.Conv(Vid, alpha, hid, dx), E.Tensor(tid, _)) = p
          val (dim, args', code, s, PArg) = handleArgs(Vid, hid, tid, args)
          val freshIndex = getsumshift(sx, length index)
        (* transform T*P*P..Ps *)
          val (splitvar, ein0, sizes, dx', alpha') =
                createEinApp(Ein.body e, alpha, index, freshIndex, dim, dx, sx)
          val FArg = DstV.new ("F", DstTy.TensorTy sizes)
          val rtn0 = if splitvar
                then Split.splitEinApp (y, DstIL.EINAPP(SummationEin.main ein0, [PArg, FArg]))
                else [(y, mkEinApp(ein0, [PArg, FArg]))]
        (* lifted probe *)
          val params' = params @ [E.TEN(3, [dim]), E.TEN(1, [dim])]
          val freshIndex' = length sizes
          val body' = createBody(dim, s, freshIndex', alpha', dx', Vid, hid, nid, fid)
          val ein1 = mkEin(params', sizes, body')
          val rtn1 = (FArg, mkEinApp(ein1, args'))
          in
            code @ [rtn1] @ rtn0
          end

  (* expandEinOp : code * fieldset -> code list * fieldset * int * int
   * At this point we only have simple EIN ops.
   * Looks to see if the expression has a probe; if so, replaces it.
   * Note how we keep eps expressions, so we only generate pieces that are used.
   * Returns (new assignments, updated fieldset, 1 if matchField recognizes a
   * probe form in the body (else 0), 1 if einSet.rtnVar supplied an existing
   * variable that y is bound to instead (else 0)).
   *)
    fun expandEinOp (e as (y, DstIL.EINAPP(ein, args)), fieldset) = let
          fun rewriteBody b = (case b
                 of E.Probe(E.Conv(_, _, _, []), _) =>
                      replaceProbe(0, e, b, [])
                  | E.Probe(E.Conv(_, _, _, _), _) =>
                      liftProbe(0, e, b, [])            (* non-empty dx *)
                  | E.Sum(sx, p as E.Probe(E.Conv(_, _, _, []), _)) =>
                      replaceProbe(0, e, p, sx)         (* no dx *)
                  | E.Sum(sx, p as E.Probe(E.Conv(_, [], _, _), _)) =>
                      liftProbe(0, e, p, sx)            (* scalar field *)
                  | E.Sum(sx, E.Probe p) =>
                      replaceProbe(0, e, E.Probe p, sx)
                  | E.Sum(sx, E.Opn(E.Prod, [_, E.Probe p])) =>
                      replaceProbe(0, e, E.Probe p, sx)
                  | _ => [e]
                (* end case *))
          val (fieldset, var) = if valnumflag
                then einSet.rtnVar(fieldset, y, DstIL.EINAPP(ein, args))
                else (fieldset, NONE)
          fun matchField b = (case b
                 of E.Probe _ => 1
                  | E.Sum(_, E.Probe _) => 1
                  | E.Sum(_, E.Opn(E.Prod, [_, E.Probe _])) => 1
                  | _ => 0
                (* end case *))
          val b = Ein.body ein
          in
            (case var
             of NONE => (rewriteBody b, fieldset, matchField b, 0)
              | SOME v => ([(y, DstIL.VAR v)], fieldset, matchField b, 1)
            (* end case *))
          end

  end (* ProbeEin *)

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0