1 : |
cchiw |
2843 |
(* Currently under construction
|
2 : |
|
|
*
|
3 : |
|
|
* COPYRIGHT (c) 2012 The Diderot Project (http://diderot-language.cs.uchicago.edu)
|
4 : |
|
|
* All rights reserved.
|
5 : |
|
|
*)
|
6 : |
|
|
|
7 : |
|
|
structure ProbeEin = struct
|
8 : |
|
|
|
9 : |
|
|
local
|
10 : |
|
|
|
11 : |
|
|
structure E = Ein
|
12 : |
|
|
structure mk= mkOperators
|
13 : |
|
|
structure SrcIL = HighIL
|
14 : |
|
|
structure SrcTy = HighILTypes
|
15 : |
|
|
structure SrcOp = HighOps
|
16 : |
|
|
structure SrcSV = SrcIL.StateVar
|
17 : |
|
|
structure VTbl = SrcIL.Var.Tbl
|
18 : |
|
|
structure DstIL = MidIL
|
19 : |
|
|
structure DstTy = MidILTypes
|
20 : |
|
|
structure DstOp = MidOps
|
21 : |
|
|
structure DstV = DstIL.Var
|
22 : |
|
|
structure SrcV = SrcIL.Var
|
23 : |
|
|
structure P=Printer
|
24 : |
|
|
structure F=Filter
|
25 : |
|
|
structure T=TransformEin
|
26 : |
|
|
structure split=Split
|
27 : |
|
|
structure cleanI=cleanIndex
|
28 : |
|
|
|
29 : |
|
|
|
30 : |
|
|
(* debugging switch read by testp: 0 silences testp, any other value makes it print *)
val testing : int = 1
|
31 : |
|
|
|
32 : |
|
|
|
33 : |
|
|
in
|
34 : |
|
|
|
35 : |
|
|
|
36 : |
|
|
(* This file expands probed fields.
 * Note that the original field is an EIN operator of the form <V_alpha * H^(deltas)>(midIL.var list).
 * Param_ids record the position of an argument in the midIL.var list.
 * Index_ids bind the shape of an image or of a differentiation.
 * Throughout, we use the following names:
 *   dim:    dimension of field V
 *   s:      support of kernel H
 *   alpha:  the alpha in <V_alpha * H^(deltas)>
 *   deltas: the deltas in <V_alpha * H^(deltas)>
 *   Vid:    param_id for V
 *   hid:    param_id for H
 *   nid:    integer position param_id
 *   fid:    fractional position param_id
 *   img:    image info about V
 *)
|
51 : |
|
|
|
52 : |
|
|
|
53 : |
|
|
(* counter shared by every call to genName *)
val cnt = ref 0

(* genName : string -> string
 * Returns prefix ^ "_" ^ k, where k is the current counter value, then
 * advances the counter so successive calls yield distinct names.
 *)
fun genName prefix = let
      val k = !cnt
      val () = cnt := k + 1
      in
        prefix ^ "_" ^ Int.toString k
      end
|
60 : |
|
|
|
61 : |
|
|
|
62 : |
|
|
(* Local aliases so the rest of this structure can call these Filter and
 * TransformEin helpers unqualified. *)
val iterSx = F.iterSx
val transformToIndexSpace = T.transformToIndexSpace
val transformToImgSpace = T.transformToImgSpace
|
65 : |
|
|
(* assign: pairs the left-hand-side variable lhs with the operator
 * application DstIL.OP(rator, operands). *)
fun assign (lhs, rator, operands) = let
      val rhs = DstIL.OP(rator, operands)
      in
        (lhs, rhs)
      end

(* assignEin: pairs the left-hand-side variable lhs with the EIN
 * application DstIL.EINAPP(ein, operands). *)
fun assignEin (lhs, ein, operands) = let
      val rhs = DstIL.EINAPP(ein, operands)
      in
        (lhs, rhs)
      end
|
67 : |
|
|
(* testp : string list -> int
 * Debug printer: when testing is nonzero, prints the concatenation of the
 * given strings to stdout.  Always returns 1.
 *)
fun testp strs =
      if testing = 0
        then 1
        else (print (String.concat strs); 1)
|
71 : |
|
|
(* getRHSDst : DstIL.var -> (rator * args)
 * Returns the operator and argument list of the DstIL.OP that defines x,
 * following chains of variable copies (VB_RHS(VAR x')).
 * Raises Fail when x is bound to anything else; the message names x and
 * shows its actual binding.
 *)
fun getRHSDst x = (case DstIL.Var.binding x
       of DstIL.VB_RHS(DstIL.OP(rator, args)) => (rator, args)
        | DstIL.VB_RHS(DstIL.VAR x') => getRHSDst x'
        | vb => raise Fail(concat[
              "expected rhs operator for ", DstIL.Var.toString x,
              " but found ", DstIL.vbToString vb
            ])
      (* end case *))
|
76 : |
|
|
|
77 : |
|
|
|
78 : |
|
|
(* getArgsDst : MidIL.var * MidIL.var * _ -> int * ImageInfo * int
 * Looks up the definitions of the kernel argument hArg and the image
 * argument imgArg (via getRHSDst) and returns
 *   (support of the kernel, image info of the image, dimension of the image).
 * The third parameter is ignored.
 * Raises Fail unless hArg is defined by DstOp.Kernel and imgArg by
 * DstOp.LoadImage.
 *)
fun getArgsDst (hArg, imgArg, _) = let
      val kernelDef = getRHSDst hArg
      val imageDef = getRHSDst imgArg
      in
        case (kernelDef, imageDef)
         of ((DstOp.Kernel(h, _), _), (DstOp.LoadImage img, _)) =>
              (Kernel.support h, img, ImageInfo.dim img)
          | _ => raise Fail "Expected Image and kernel arguments"
        (* end case *)
      end
|
89 : |
|
|
|
90 : |
|
|
|
91 : |
|
|
(* handleArgs : int * int * int * MidIL.var list
 *                -> int * MidIL.var list * code * int * MidIL.var
 * Vid, hid and tid are the param_ids (positions in args) of the image, the
 * kernel and the position tensor.  Looks up the image and kernel
 * definitions, then uses transformToImgSpace to transform the position
 * into image space.
 * Returns (dim, args extended with the new position arguments, the code
 * that computes them, the kernel support s, the variable that callers pass
 * as the transformation matrix P).
 *)
fun handleArgs (Vid, hid, tid, args) = let
      fun argAt i = List.nth (args, i)
      val imgArg = argAt Vid
      val hArg = argAt hid
      val posArg = argAt tid
      val (s, img, dim) = getArgsDst (hArg, imgArg, args)
      val (newArgs, P, code) = transformToImgSpace (dim, img, posArg, imgArg)
      in
        (dim, args @ newArgs, code, s, P)
      end
|
104 : |
|
|
|
105 : |
|
|
|
106 : |
|
|
(* createBody : int * int * int * alpha * deltas * param_id * param_id * param_id * param_id
 *                -> ein_exp
 * Expands the body of a probed field <V_alpha * H^(deltas)> into an
 * explicit convolution sum.
 *   dim           - dimension of the image
 *   s             - kernel support; every summation range is (_, 1-s, s)
 *   sx            - first index id for the summation indices; dimension i
 *                   is summed over E.V(sx+i)
 *   alpha, deltas - image shape indices and kernel derivative indices
 *   Vid, hid      - param_ids of the image and the kernel
 *   nid, fid      - param_ids of the two position tensors
 * Result:
 *   Sum([(V(sx+i), 1-s, s) | i < dim],
 *       Prod(Img(Vid, alpha, [fid_i + (sx+i) | i < dim])
 *            :: [Krn(hid, [(C i, d) | d in deltas], nid_i - (sx+i)) | i < dim]))
 * where fid_i / nid_i select component [C i] of the tensor, except that for
 * dim = 1 the tensors are used with the empty index [].
 * NOTE(review): the image is indexed with fid and the kernel is evaluated at
 * nid, while the file header calls nid the integer position and fid the
 * fractional one -- confirm which naming is intended.
 *)
fun createBody (dim, s, sx, alpha, deltas, Vid, hid, nid, fid) = let
      (* 1-d fields use the position tensors unindexed *)
      fun component i = if dim = 1 then [] else [E.C i]
      (* summation index bound to dimension i *)
      fun sumId i = sx + i
      val imgPos = List.tabulate (dim, fn i =>
            E.Add[E.Tensor(fid, component i), E.Value(sumId i)])
      val kernels = List.tabulate (dim, fn i =>
            E.Krn(hid,
              List.map (fn d => (E.C i, d)) deltas,
              E.Sub(E.Tensor(nid, component i), E.Value(sumId i))))
      val ranges = List.tabulate (dim, fn i => (E.V(sumId i), 1 - s, s))
      in
        E.Sum(ranges, E.Prod(E.Img(Vid, alpha, imgPos) :: kernels))
      end
|
143 : |
|
|
|
144 : |
|
|
(* getsumshift : sum_index_id list * index_id list -> int
 * Returns an index id that is not yet in use: one past the id of the last
 * summation index in sx, or length index when sx is empty.
 * Also reports the summation ids and the result through testp.
 * The patterns only accept E.V summation indices; any other index in sx
 * raises Bind/Match.
 *)
fun getsumshift (sx, index) = let
      val fresh =
            if null sx
              then List.length index
              else let val (E.V lastId, _, _) = List.last sx in lastId + 1 end
      val ids = String.concatWith "," (List.map (fn (E.V v, _, _) => Int.toString v) sx)
      val _ = testp ["\n", "SumIndex", ids, "\nThink nshift is ", Int.toString fresh]
      in
        fresh
      end
|
160 : |
|
|
|
161 : |
|
|
(* formBody : ein_exp -> ein_exp
 * Light clean-up of an expanded body:
 *   - a summation with an empty index list is dropped (and cleaned further),
 *   - a non-empty summation is kept and its body cleaned,
 *   - a single-factor product is replaced by its factor (not cleaned further).
 * Any other expression is returned unchanged.
 *)
fun formBody body = (case body
       of E.Sum([], e) => formBody e
        | E.Sum(sx, e) => E.Sum(sx, formBody e)
        | E.Prod[e] => e
        | e => e
      (* end case *))
|
168 : |
|
|
|
169 : |
|
|
|
170 : |
|
|
(* replaceProbe : ein_exp * param list * MidIL.var list * index_id list * sum_index_id list
 *                  -> ein_exp * param list * MidIL.var list * code
 * b must be E.Probe(E.Conv(Vid, alpha, hid, dx), E.Tensor(tid, _)); any
 * other shape raises Bind.
 *  - handleArgs transforms the probe position into image space and yields
 *    the extra arguments plus the code that computes them;
 *  - three parameters are appended: fid = E.TEN(3,[dim]),
 *    nid = E.TEN(1,[dim]) and P = E.TEN(1,[dim,dim]); their arguments are
 *    the extra arguments from handleArgs followed by PArg;
 *  - transformToIndexSpace rewrites the derivative indices dx against P,
 *    giving new summation indices and the P factors Ps;
 *  - the probe becomes createBody's convolution sum, multiplied by Ps,
 *    summed over the new indices and cleaned with formBody.
 * Returns (new body, new params, new args, code to emit before the EIN app).
 *)
fun replaceProbe (b, params, args, index, sx) = let
      val E.Probe(E.Conv(Vid, alpha, hid, dx), E.Tensor(tid, _)) = b
      (* param_ids of the three parameters appended below *)
      val fid = List.length params
      val nid = fid + 1
      val Pid = nid + 1
      val (dim, argsA, code, s, PArg) = handleArgs (Vid, hid, tid, args)
      val freshIndex = getsumshift (sx, index)
      val (dx', newSx, Ps) = transformToIndexSpace (freshIndex, dim, dx, Pid)
      (* createBody's summation ids start length(dx) past freshIndex,
       * presumably leaving room for the indices transformToIndexSpace used *)
      val expanded = createBody (dim, s, freshIndex + List.length dx, alpha, dx', Vid, hid, nid, fid)
      val newParams = params @ [E.TEN(3, [dim]), E.TEN(1, [dim]), E.TEN(1, [dim, dim])]
      in
        (formBody (E.Sum(newSx, E.Prod(Ps @ [expanded]))), newParams, argsA @ [PArg], code)
      end
|
193 : |
|
|
|
194 : |
|
|
|
195 : |
|
|
|
196 : |
|
|
(* liftedProbe : ein_exp * param list * MidIL.var list * index_id list * sum_index_id list
 *                 -> ein_exp * param list * MidIL.var list * MidIL.var * code
 * Like replaceProbe, but the derivative indices dx are used as given (no
 * transformToIndexSpace, no P factors) and no P parameter is appended.
 * Only fid = E.TEN(3,[dim]) and nid = E.TEN(1,[dim]) are added.  The P
 * variable is returned separately so the caller can use it in an outer
 * application.  A b that is not a probe of a convolution raises Bind.
 * Returns (expanded body, new params, new args, PArg, code).
 *)
fun liftedProbe (b, params, args, index, sumIndex) = let
      val E.Probe(E.Conv(V, alpha, h, dx), E.Tensor(t, _)) = b
      val fid = List.length params
      val nid = fid + 1
      val freshIndex = getsumshift (sumIndex, index)
      val (dim, argsA, code, s, PArg) = handleArgs (V, h, t, args)
      val expanded = createBody (dim, s, freshIndex + List.length dx, alpha, dx, V, h, nid, fid)
      in
        (expanded, params @ [E.TEN(3, [dim]), E.TEN(1, [dim])], argsA, PArg, code)
      end
|
212 : |
|
|
|
213 : |
|
|
(* expandEinOp99 : code -> code list
 * If the body of the EIN application e is
 *   - a probe,
 *   - a summation over a probe, or
 *   - a summation over a two-factor product whose second factor is a probe
 *     (e.g. an epsilon times a probe),
 * the probe is replaced by its expansion (see replaceProbe).  The result is
 * the code computing the new arguments followed by the rewritten
 * application.  The surrounding summation/product is kept as is (the
 * original notes say keeping eps-style terms saves work in the mid-to-low-IL
 * stage).  Any other application is returned unchanged as [e].
 * An e whose right-hand side is not an EINAPP raises Match.
 *)
fun expandEinOp99 (e as (y, DstIL.EINAPP(Ein.EIN{params, index, body}, args))) = let
      (* expand probe (summing over sx) and rebuild the body with wrap *)
      fun expand (probe, sx, wrap) = let
            val (body', params', args', newbies) = replaceProbe (probe, params, args, index, sx)
            val einapp = (y, DstIL.EINAPP(Ein.EIN{params=params', index=index, body=wrap body'}, args'))
            in
              newbies @ [einapp]
            end
      in
        case body
         of E.Probe _ => expand (body, [], fn b => b)
          | E.Sum(sx, p as E.Probe _) => expand (p, sx, fn b => E.Sum(sx, b))
          | E.Sum(sx, E.Prod[eps, p as E.Probe _]) =>
              expand (p, sx, fn b => E.Sum(sx, E.Prod[eps, b]))
          | _ => [e]
        (* end case *)
      end
|
250 : |
|
|
|
251 : |
|
|
|
252 : |
|
|
(* expandEinOp2 : code -> code list * MidIL.var list
 * Same body shapes as expandEinOp99 (a probe, a summation over a probe, or a
 * summation over a two-factor product whose second factor is a probe), but
 * the probe is expanded with liftedProbe.  Returns the new code together
 * with [PArg], the P variable produced by the expansion.  Any other
 * application is returned as ([e], []).
 *)
fun expandEinOp2 (e as (y, DstIL.EINAPP(Ein.EIN{params, index, body}, args))) = let
      (* expand probe (summing over sx) and rebuild the body with wrap *)
      fun expand (probe, sx, wrap) = let
            val (body', params', args', PArg, newbies) = liftedProbe (probe, params, args, index, sx)
            val einapp = (y, DstIL.EINAPP(Ein.EIN{params=params', index=index, body=wrap body'}, args'))
            in
              (newbies @ [einapp], [PArg])
            end
      in
        case body
         of E.Probe _ => expand (body, [], fn b => b)
          | E.Sum(sx, p as E.Probe _) => expand (p, sx, fn b => E.Sum(sx, b))
          | E.Sum(sx, E.Prod[eps, p as E.Probe _]) =>
              expand (p, sx, fn b => E.Sum(sx, E.Prod[eps, b]))
          | _ => ([e], [])
        (* end case *)
      end
|
288 : |
|
|
|
289 : |
|
|
(* mkbody : index_id list * ein_exp -> sum_index_id list * ein_exp
 * Replaces the derivative indices of the probed field in body with beta.
 * Accepted shapes:
 *   - a probe: returns ([], probe with beta);
 *   - a summation over a probe, or over a two-factor product whose second
 *     factor is a probe: F.iterSx(sx, dx) returns (pre, post); post stays
 *     on the rewritten body and pre is returned to the caller.  Presumably
 *     pre holds the summation indices bound to the old derivative, which the
 *     original notes say must move to the enclosing EIN application -- TODO
 *     confirm against Filter.iterSx.
 * Raises Fail for any other body shape.
 *)
fun mkbody (beta, body) = let
      (* the probe of field (v, alpha, h) at t, differentiated by beta *)
      fun newProbe (v, alpha, h, t) = E.Probe(E.Conv(v, alpha, h, beta), E.Tensor t)
      in
        case body
         of E.Probe(E.Conv(v, alpha, h, _), E.Tensor t) => ([], newProbe (v, alpha, h, t))
          | E.Sum(sx, E.Probe(E.Conv(v, alpha, h, dx), E.Tensor t)) => let
              val (pre, post) = F.iterSx (sx, dx)
              in
                (pre, E.Sum(post, newProbe (v, alpha, h, t)))
              end
          | E.Sum(sx, E.Prod[eps, E.Probe(E.Conv(v, alpha, h, dx), E.Tensor t)]) => let
              val (pre, post) = F.iterSx (sx, dx)
              in
                (pre, E.Sum(post, E.Prod[eps, newProbe (v, alpha, h, t)]))
              end
          | _ => raise Fail "mkbody: expected a probe, a sum over a probe, or a sum over a product with a probe"
        (* end case *)
      end
|
308 : |
|
|
|
309 : |
|
|
(* geT : sum_index_id list * index_id list * int * index_id list * ein_exp
 *         * param list * MidIL.var list * MidIL.var -> code list
 * Builds two simpler EIN applications instead of one large one:
 *   1. dx is rewritten by transformToIndexSpace, giving new derivative
 *      indices, new summation indices and the P factors Ps.  mkbody puts the
 *      new indices on the probe, and split.lift moves the probed field into
 *      its own application.  That application is expanded with expandEinOp2.
 *   2. The outer application (bound to y) multiplies the lifted result Re by
 *      Ps and sums over all new indices.  It takes an extra parameter
 *      E.TEN(1,[dim,dim]) whose argument is the P variable from step 1.
 * Returns the lifted code followed by the outer application.  The pattern
 * on split.lift's result assumes it returns exactly one application.
 *)
fun geT (sx, index, dim, dx, body, params, args, y) = let
      (* NOTE(review): P is appended after Rparams below, so this is the right
       * param_id only if length Rparams = length args + 1 -- confirm against
       * split.lift *)
      val Pid = List.length args + 1
      val freshIndex = getsumshift (sx, index)
      val (newdx, indexSx, Ps) = transformToIndexSpace (freshIndex, dim, dx, Pid)
      val (liftedSx, probeBody) = mkbody (newdx, body)
      val allSx = indexSx @ liftedSx
      val (Re, Rparams, Rargs, [lifted]) = split.lift (probeBody, params, index, allSx, args)
      val (liftedCode, Parg) = expandEinOp2 lifted
      val outerBody = E.Sum(allSx, E.Prod(Ps @ [Re]))
      val outerParams = Rparams @ [E.TEN(1, [dim, dim])]
      val outer = (y, DstIL.EINAPP(Ein.EIN{params=outerParams, index=index, body=outerBody}, Rargs @ Parg))
      in
        liftedCode @ [outer]
      end
|
331 : |
|
|
|
332 : |
|
|
|
333 : |
|
|
(* liftTransform : code -> code list
 * Used when testing the lifting transformation (selected by testlift in
 * expandEinOp).  It recognizes three body shapes:
 *   - a probe,
 *   - a summation over a probe,
 *   - a summation over a two-factor product whose second factor is a probe.
 * These are handed to geT; any other application is returned as [e].
 * The dimension passed to geT is the dimension of the probed image.  It is
 * looked up from the image and kernel arguments with getArgsDst, just as
 * handleArgs does; it was previously hard-coded to 2.
 * Raises Match if e's right-hand side is not an EINAPP.
 *)
fun liftTransform (e as (y, DstIL.EINAPP(Ein.EIN{params, index, body}, args))) = let
      fun printResult code = testp["\nINSIDE PROBEEIN","\nbody",split.printEINAPP e, "\n=>\n",
            (String.concatWith",\t"(List.map split.printEINAPP code))]
      (* dimension of the image whose image/kernel param_ids are Vid/hid *)
      fun imageDim (Vid, hid) = let
            val (_, _, dim) = getArgsDst (List.nth(args, hid), List.nth(args, Vid), args)
            in
              dim
            end
      (* lift the probe of image Vid with kernel hid and derivatives dx, summed over sx *)
      fun lift (sx, Vid, hid, dx) =
            geT (sx, index, imageDim (Vid, hid), dx, body, params, args, y)
      in
        case body
         of E.Probe(E.Conv(Vid, _, hid, dx), E.Tensor _) => let
              val code = lift ([], Vid, hid, dx)
              val _ = printResult code
              in
                code
              end
          | E.Sum(sx, E.Probe(E.Conv(Vid, _, hid, dx), E.Tensor _)) => lift (sx, Vid, hid, dx)
          | E.Sum(sx, E.Prod[_, E.Probe(E.Conv(Vid, _, hid, dx), E.Tensor _)]) => lift (sx, Vid, hid, dx)
          | _ => [e]
        (* end case *)
      end
|
363 : |
|
|
|
364 : |
|
|
(* selects the strategy used by expandEinOp:
 * 1 uses liftTransform; any other value uses expandEinOp99 *)
val testlift = 0

(* expandEinOp : code -> code list
 * Expands probed fields in one EIN application, using the strategy chosen
 * by testlift.
 *)
fun expandEinOp e =
      if testlift = 1
        then liftTransform e
        else expandEinOp99 e
|
372 : |
|
|
|
373 : |
|
|
|
374 : |
|
|
|
375 : |
|
|
end; (* local *)
|
376 : |
|
|
|
377 : |
|
|
end (* local *)
|