Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] Annotation of /branches/charisee/src/compiler/high-to-mid/ProbeEin.sml
ViewVC logotype

Annotation of /branches/charisee/src/compiler/high-to-mid/ProbeEin.sml

Parent Directory Parent Directory | Revision Log Revision Log


Revision 3174 - (view) (download)

1 : cchiw 2845 (* Expands probe ein
2 : cchiw 2606 *
3 :     * COPYRIGHT (c) 2012 The Diderot Project (http://diderot-language.cs.uchicago.edu)
4 :     * All rights reserved.
5 :     *)
6 :    
7 :     structure ProbeEin = struct
8 :    
9 :     local
10 :    
11 :     structure E = Ein
12 :     structure DstIL = MidIL
13 :     structure DstOp = MidOps
14 : jhr 3060 structure P = Printer
15 :     structure T = TransformEin
16 :     structure MidToS = MidToString
17 : cchiw 2976 structure DstV = DstIL.Var
18 :     structure DstTy = MidILTypes
19 :    
20 : cchiw 2606 in
21 :    
(* This file expands probed fields.
 * See the ProbeEin TeX file for examples.
 * Note that the original field is an EIN operator of the form
 * <V_alpha * H^(deltas)>(midIL.var list).
 * Param_ids record the position of each argument in the midIL.var list.
 * Index_ids keep track of the shape of an image or of a differentiation.
 * A mu binds an Index_id.
 * Throughout, the following names are used:
 *   dim:    dimension of field V
 *   s:      support of kernel H
 *   alpha:  the alpha in <V_alpha * H^(deltas)>
 *   deltas: the deltas in <V_alpha * H^(deltas)>
 *   Vid:    param_id for V
 *   hid:    param_id for H
 *   nid:    param_id for the integer position
 *   fid:    param_id for the fractional position
 *   img:    image info for V
 *)
39 : cchiw 3033
(* Debug switches.  A value of 0 keeps the matching trace helper silent. *)
val testing : int = 0      (* consulted by testp *)
val testlift : int = 0     (* consulted by einapptostring *)
(* counter cell; not referenced by the code in this structure body *)
val cnt : int ref = ref 0
43 : cchiw 2606
(* Thin wrappers around helpers from other modules.
 *   printEINAPP:           string rendering of an EINAPP binding (MidToString)
 *   transformToIndexSpace: delegates to TransformEin.transformToIndexSpace
 *   transformToImgSpace:   delegates to TransformEin.transformToImgSpace
 * printEINAPP is defined exactly once; a second, identical definition that
 * merely shadowed the first has been removed.
 *)
fun printEINAPP e = MidToString.printEINAPP e
fun transformToIndexSpace e = T.transformToIndexSpace e
fun transformToImgSpace e = T.transformToImgSpace e
48 : cchiw 3033
(* transitionToString (mode, before, after) : int
 * Debug trace for a probe replacement.
 *   mode 0      -> prints nothing
 *   mode 2      -> prints both the original and the rewritten body
 *   other modes -> prints only the original body
 * Always returns 1.
 *)
fun transitionToString (testreplace, a, b) =
      if testreplace = 0 then 1
      else let
        val pieces = if testreplace = 2
              then ["\n\n\n Replace probe:\n", P.printbody a, "\n=>", P.printbody b]
              else ["\nReplaced:", P.printbody a]
        in
          print (String.concat pieces); 1
        end
(* mkEin (params, index, body): build an Ein operator from its three fields *)
fun mkEin (ps, ix, b) = E.EIN{body = b, index = ix, params = ps}

(* mkEinApp (ein, args): the MidIL right-hand side applying ein to args *)
val mkEinApp = DstIL.EINAPP

(* getBody (lhs, rhs): body of the Ein operator in an EINAPP right-hand side.
 * The left-hand side is ignored; a non-EINAPP rhs raises Match. *)
fun getBody (_, DstIL.EINAPP(ein, _)) = (case ein of E.EIN{body, ...} => body)

(* setBody (newBody, (lhs, rhs)): the same binding with the Ein operator's body
 * replaced by newBody; params, index and arguments are kept. *)
fun setBody (newBody, (lhs, DstIL.EINAPP(E.EIN{params, index, ...}, args))) =
      (lhs, mkEinApp(mkEin(params, index, newBody), args))
59 : cchiw 3048
(* testp pieces : int
 * Prints the concatenation of pieces unless the `testing` switch is 0.
 * Always returns 1.
 *)
fun testp pieces =
      if testing = 0 then 1
      else (print (String.concat pieces); 1)
(* einapptostring (body, a, b) : int
 * Debug trace for a lifted probe, controlled by `testlift`:
 *   0           -> prints nothing
 *   2           -> prints the original body and both new bindings a and b
 *   other value -> prints only the original body
 * Always returns 1.
 *)
fun einapptostring (body, a, b) =
      if testlift = 0 then 1
      else let
        val pieces = if testlift = 2
              then ["\n lift probe of ", P.printbody body, "=>\n\t",
                    printEINAPP a, "&\n\t", printEINAPP b]
              else ["\nLifted", P.printbody body]
        in
          print (String.concat pieces); 1
        end
69 :    
70 :    
(* getRHSDst x : rator * arg list
 * Follows VAR copy bindings starting at x until it reaches an operator
 * application, and returns that operator with its arguments.
 * Raises Fail for any other kind of binding.  The message now has a space
 * before "but found"; previously the variable name ran into it.
 *)
fun getRHSDst x = (case DstIL.Var.binding x
       of DstIL.VB_RHS(DstIL.OP(rator, args)) => (rator, args)
        | DstIL.VB_RHS(DstIL.VAR x') => getRHSDst x'
        | vb => raise Fail(concat[
              "expected rhs operator for ", DstIL.Var.toString x,
              " but found ", DstIL.vbToString vb
            ])
      (* end case *))
76 : cchiw 2838
77 : cchiw 2606
(* getArgsDst (hArg, imgArg, args) : int * image info * int
 * Looks up the definitions of the kernel variable hArg and the image
 * variable imgArg (in that order) and returns
 *   (support of the kernel, the image info, dimension of the image).
 * The third argument is not used.
 * Raises Fail unless hArg is bound to a Kernel operator and imgArg to a
 * LoadImage operator; getRHSDst may also raise Fail during the lookup.
 *)
fun getArgsDst (hArg, imgArg, _) = let
      val kernelDef = getRHSDst hArg
      val imageDef = getRHSDst imgArg
      in
        case (kernelDef, imageDef)
         of ((DstOp.Kernel(h, _), _), (DstOp.LoadImage(_, _, img), _)) =>
              (Kernel.support h, img, ImageInfo.dim img)
          | _ => raise Fail "Expected Image and kernel arguments"
      end
90 : cchiw 2606
91 :    
(* handleArgs (Vid, hid, tid, args) : int * MidIL.var list * code * int * MidIL.var
 * Vid, hid and tid are the positions in args of the image, the kernel and
 * the position tensor.  Looks up the kernel support and image
 * (getArgsDst), then hands the position to transformToImgSpace.  Returns
 *   (image dimension,
 *    args followed by the variables produced by transformToImgSpace,
 *    the code that defines those variables,
 *    kernel support,
 *    P)
 * where P is the extra variable returned by transformToImgSpace (per the
 * original notes, the transpose of the transformation matrix).
 *)
fun handleArgs (Vid, hid, tid, args) = let
      fun argAt i = List.nth(args, i)
      val imgArg = argAt Vid
      val hArg = argAt hid
      val posArg = argAt tid
      val (support, img, dim) = getArgsDst(hArg, imgArg, args)
      val (extraArgs, pArg, code) = transformToImgSpace(dim, img, posArg, imgArg)
      in
        (dim, args @ extraArgs, code, support, pArg)
      end
108 : cchiw 2838
(* createBody (dim, s, sx, alpha, deltas, Vid, hid, nid, fid) : ein_exp
 * Builds the expanded body of the probe <V_alpha * H^(deltas)>:
 *
 *   Sum_{v_sx .. v_(sx+dim-1) in [1-s, s]}
 *     Img(Vid, alpha, [fid_i + v_(sx+i)]_i)
 *       * Prod_i Krn(hid, deltas on axis i, nid_i - v_(sx+i))
 *
 *   dim:      dimension of the field
 *   s:        kernel support
 *   sx:       first index id to use for the summation variables
 *   alpha:    image indices
 *   deltas:   derivative indices; each is paired with E.C i for axis i
 *   Vid, hid: param ids of the image and the kernel
 *   nid, fid: param ids of the integer and fractional positions
 * fid_i / nid_i is component i of that position tensor.  When dim = 1 the
 * position is used without an index.
 *)
fun createBody (dim, s, sx, alpha, deltas, Vid, hid, nid, fid) = let
      (* component i of a position parameter; 1-D positions are unindexed *)
      fun component (id, i) =
            if dim = 1 then E.Tensor(id, []) else E.Tensor(id, [E.C i])
      (* image position and kernel factor for axis i; both use index sx+i *)
      fun axis i = let
            val offset = E.Value(sx + i)
            val dels = List.map (fn d => (E.C i, d)) deltas
            in
              (E.Add[component(fid, i), offset],
               E.Krn(hid, dels, E.Sub(component(nid, i), offset)))
            end
      (* walk axes dim-1 down to 0, prepending, so both lists end in axis order *)
      fun build (0, imgPos, kernels) = E.Prod(E.Img(Vid, alpha, imgPos) :: kernels)
        | build (n, imgPos, kernels) = let
            val (pos, krn) = axis (n - 1)
            in
              build (n - 1, pos :: imgPos, krn :: kernels)
            end
      val exp = build (dim, [], [])
      val sumRange = List.tabulate(dim, fn i => (E.V(sx + i), 1 - s, s))
      in
        E.Sum(sumRange, exp)
      end
143 :    
(* getsumshift (sx, index) : int
 * Picks the index id from which new summation variables are numbered.
 * The result is one past the id of the last entry of sx, or length index
 * when sx is empty.  Entries of sx are expected to be E.V indices; any
 * other form raises an exception.  The choice is traced through testp.
 *)
fun getsumshift (sx, index) = let
      val next = if null sx
            then length index
            else let val (E.V last, _, _) = List.last sx in last + 1 end
      val shown = List.map (fn (E.V v, _, _) => Int.toString v) sx
      in
        testp ["\n", "SumIndex", String.concatWith "," shown,
               "\nThink nshift is ", Int.toString next];
        next
      end
161 : cchiw 2611
(* formBody exp : ein_exp
 * Light clean-up: removes summations over an empty index list, unwraps a
 * product with a single factor, and recurses only under summations.
 *)
fun formBody exp = (case exp
       of E.Sum([], inner) => formBody inner
        | E.Sum(sx, inner) => E.Sum(sx, formBody inner)
        | E.Prod [single] => single
        | _ => exp
      (* end case *))
169 : cchiw 2606
(* multiPs (Ps, sx, body) : ein_exp
 * Multiplies body by the factors Ps under the summation sx, then tidies the
 * result with formBody.  With exactly three factors, they come before body,
 * matching the product order of the vis branch's WorldtoSpace functions.
 * Otherwise body comes first.
 *)
fun multiPs (Ps, sx, body) = let
      val factors = (case Ps
             of [_, _, _] => Ps @ [body]
              | _ => body :: Ps
            (* end case *))
      in
        formBody (E.Sum(sx, E.Prod factors))
      end
173 :    
(* replaceProbe (testN, (y, einapp), p, sx) : code
 * Expands, in place, the probe p found in the EINAPP bound to y:
 *   - handleArgs finds the kernel/image and moves the position to image
 *     space, adding new argument variables and the code defining them;
 *   - transformToIndexSpace rewrites the derivative indices and yields
 *     the factors Ps;
 *   - createBody builds the convolution and multiPs multiplies in Ps;
 *   - an outer summation (optionally around an eps product) in the original
 *     body is put back around the expansion.
 * Three tensor parameters are appended to params, with ids
 *   fid = length params, nid = fid+1 and Pid = nid+1.
 * The matching argument PArg (from handleArgs) is appended last.
 * Returns the code from handleArgs followed by the new binding for y.
 * testN selects the trace printed by transitionToString.
 *)
fun replaceProbe (testN, (y, DstIL.EINAPP(e, args)), p, sx) = let
      val originalb = Ein.body e
      val params = Ein.params e
      val index = Ein.index e
      val E.Probe(E.Conv(Vid, alpha, hid, dx), E.Tensor(tid, _)) = p
      (* ids of the parameters appended below *)
      val fid = length params
      val nid = fid + 1
      val Pid = nid + 1
      val nshift = length dx
      val (dim, argsA, code, s, PArg) = handleArgs(Vid, hid, tid, args)
      val freshIndex = getsumshift(sx, index)
      val (dxIdx, newsx, Ps) = transformToIndexSpace(freshIndex, dim, dx, Pid)
      val newParams = params @ [E.TEN(3, [dim]), E.TEN(1, [dim]), E.TEN(1, [dim, dim])]
      val expanded = multiPs(Ps, newsx,
            createBody(dim, s, freshIndex + nshift, alpha, dxIdx, Vid, hid, nid, fid))
      (* keep the summation (and eps factor) that surrounded the probe *)
      val newBody = (case originalb
             of E.Sum(sx', E.Probe _) => E.Sum(sx', expanded)
              | E.Sum(sx', E.Prod[eps0, E.Probe _]) => E.Sum(sx', E.Prod[eps0, expanded])
              | _ => expanded
            (* end case *))
      val _ = transitionToString(testN, originalb, newBody)
      val rhs = mkEinApp(mkEin(newParams, index, newBody), argsA @ [PArg])
      in
        code @ [(y, rhs)]
      end
214 : cchiw 2976
215 :    
(* createEinApp (originalb, alpha, index, freshIndex, dim, dx, sx)
 *     : ein * sizes * derivative indices
 * Builds the operator that finishes a lifted probe.  It has two parameters:
 *   id 0: E.TEN(1, [dim, dim])  -- P
 *   id 1: E.TEN(1, sizes)       -- the lifted tensor
 * Its body is Tensor(1, alpha @ newdx) multiplied by the factors from
 * transformToIndexSpace (via multiPs).  That product is re-wrapped in the
 * summation (and eps factor) of originalb, if present.
 * sizes and the returned derivative indices come from
 * cleanIndex.cleanIndex when there are any summation indices (sx @ newsx).
 * Otherwise they are index and newdx unchanged.
 *)
fun createEinApp (originalb, alpha, index, freshIndex, dim, dx, sx) = let
      val Pid = 0
      val tid = 1
      val (newdx, newsx, Ps) = transformToIndexSpace(freshIndex, dim, dx, Pid)
      val allSx = sx @ newsx
      (* the image/kernel ids 9 and 7 are arbitrary: only the derivative
       * indices of the Conv handed back are kept *)
      val probe = E.Conv(9, alpha, 7, newdx)
      val (_, sizes, E.Conv(_, _, _, cleanDx)) =
            if null allSx
              then ([], index, probe)
              else cleanIndex.cleanIndex(probe, index, allSx)
      val exp = multiPs(Ps, newsx, E.Tensor(tid, alpha @ newdx))
      val body = (case originalb
             of E.Sum(sx', E.Probe _) => E.Sum(sx', exp)
              | E.Sum(sx', E.Prod[eps0, E.Probe _]) => E.Sum(sx', E.Prod[eps0, exp])
              | _ => exp
            (* end case *))
      in
        (mkEin([E.TEN(1, [dim, dim]), E.TEN(1, sizes)], index, body), sizes, cleanDx)
      end
242 : cchiw 3048
(* liftProbe (testN, (y, einapp), p, sx) : code
 * Expands the probe p by splitting the work across two bindings:
 *   F = fresh variable of type TensorTy sizes, bound to the convolution
 *       (createBody) with index sizes.  It takes the original arguments
 *       plus those added by handleArgs, and appends two tensor parameters
 *       with ids fid = length params and nid = fid+1.
 *   y = the createEinApp operator applied to [P, F].
 * Returns code from handleArgs @ [F binding, y binding] and traces the result
 * with einapptostring.  The first component (testN) is not used.
 *)
fun liftProbe (_, (y, DstIL.EINAPP(e, args)), p, sx) = let
      val originalb = Ein.body e
      val params = Ein.params e
      val index = Ein.index e
      val E.Probe(E.Conv(Vid, alpha, hid, dx), E.Tensor(tid, _)) = p
      val fid = length params
      val nid = fid + 1
      val nshift = length dx
      val (dim, liftedArgs, code, s, PArg) = handleArgs(Vid, hid, tid, args)
      val freshIndex = getsumshift(sx, index)
      (* operator for y, combining P with the lifted tensor *)
      val (ein0, sizes, cleanDx) =
            createEinApp(originalb, alpha, index, freshIndex, dim, dx, sx)
      val FArg = DstV.new("F", DstTy.TensorTy sizes)
      val outer = (y, mkEinApp(ein0, [PArg, FArg]))
      (* the lifted convolution itself *)
      val liftedEin = mkEin(params @ [E.TEN(3, [dim]), E.TEN(1, [dim])], sizes,
            createBody(dim, s, freshIndex + nshift, alpha, cleanDx, Vid, hid, nid, fid))
      val lifted = (FArg, mkEinApp(liftedEin, liftedArgs))
      in
        einapptostring(originalb, lifted, outer);
        code @ [lifted, outer]
      end
274 :    
275 :    
(* expandEinOp ((y, einapp), fieldset) : code * fieldset
 * At this point only simple Ein operators remain.
 * First consults einSet.rtnVar with the EINAPP.  If it yields a variable v,
 * y is bound to that variable instead (presumably an equivalent field that
 * was already expanded -- TODO confirm against einSet).  Otherwise the body
 * is examined for a probe:
 *   Probe with no derivatives                 -> replaceProbe
 *   Probe with derivatives                    -> checkConst(alpha @ dx)
 *   Sum(sx, Probe with no derivatives)        -> replaceProbe
 *   Sum(sx, Probe of a scalar field, dx)      -> checkConst(dx)
 *   Sum(sx, Probe) / Sum(sx, Prod[eps,Probe]) -> replaceProbe
 *   anything else                             -> left unchanged
 * checkConst lifts the probe (liftProbe) unless one of the given indices is a
 * constant E.C, in which case it is replaced in place.  eps factors are kept,
 * so only the pieces that are used get generated.
 * The debug strings that used to be built here and then discarded have been
 * removed; they pretty-printed every operator for no effect.
 *)
fun expandEinOp (e as (y, DstIL.EINAPP(ein, args)), fieldset) = let
      fun checkConst ([], a) = liftProbe a
        | checkConst (E.C _ :: _, a) = replaceProbe a
        | checkConst (_ :: es, a) = checkConst (es, a)
      fun rewriteBody b = (case b
             of E.Probe(E.Conv(_, _, _, []), _) =>
                  replaceProbe(0, e, b, [])
              | E.Probe(E.Conv(_, alpha, _, dx), _) =>
                  checkConst(alpha @ dx, (0, e, b, []))
              | E.Sum(sx, p as E.Probe(E.Conv(_, _, _, []), _)) =>
                  replaceProbe(0, e, p, sx)
              | E.Sum(sx, p as E.Probe(E.Conv(_, [], _, dx), _)) =>
                  checkConst(dx, (0, e, p, sx))
              | E.Sum(sx, E.Probe p) =>
                  replaceProbe(0, e, E.Probe p, sx)
              | E.Sum(sx, E.Prod[_, E.Probe p]) =>
                  replaceProbe(0, e, E.Probe p, sx)
              | _ => [e]
            (* end case *))
      val (fieldset', found) = einSet.rtnVar(fieldset, y, DstIL.EINAPP(ein, args))
      in
        case found
         of NONE => (rewriteBody (Ein.body ein), fieldset')
          | SOME v => ([(y, DstIL.VAR v)], fieldset')
      end
319 : cchiw 2843
320 : cchiw 2606 end; (* local *)
321 :    
322 : cchiw 2845 end (* local *)

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0