Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] Annotation of /branches/charisee/src/compiler/high-to-mid/split.sml
ViewVC logotype

Annotation of /branches/charisee/src/compiler/high-to-mid/split.sml

Parent Directory Parent Directory | Revision Log Revision Log


Revision 3166 - (view) (download)

1 : cchiw 2843 (* Currently under construction
2 : cchiw 2838 *
3 :     * COPYRIGHT (c) 2012 The Diderot Project (http://diderot-language.cs.uchicago.edu)
4 :     * All rights reserved.
5 :     *)
6 : cchiw 2843
7 :     (*
8 :     During the transition from high-IL to mid-IL, complicated EIN expressions are split into simpler ones in order to better identify methods for code generation and common subexpressions. Combining EIN operators in the optimization phase can lead to large and complicated EIN operators. A general code generator would need to expand every operation to work on scalars, which could miss the opportunity for vectorization and lead to poor code generation. Instead, every EIN operator is split into a set of simple EIN operators. Each EIN expression then only has one operation working on constants, tensors, deltas, epsilons, images and kernels.
9 :    
10 :     (1) When the outer EIN operator is $ \in {--, +, -, *, /, \sum}$ then for each subexpression analyze to see if they need to be rewritten.
11 :    
12 :     (1a.) When a subexpression is a field expression $\circledast,\nabla $ then it becomes 0. When it is another operation $ {@ --, +, -, *, /, \sum}$ then we lift that subexpression and create a new EIN operator. We replace the subexpression with a tensor expression that represent it's size.
13 :    
14 :     (1b) Call cleanIndex.sml to clean the indices in the subexpression, and get the shape for the tensor replacement.
15 :    
16 :     (1c) Call cleanParams.sml to clean the params in the subexpression.\\
17 :     *)
18 : cchiw 2838
19 :     structure Split = struct
20 :    
21 :     local
22 :    
23 :     structure E = Ein
24 :     structure DstIL = MidIL
25 :     structure DstTy = MidILTypes
26 :     structure DstV = DstIL.Var
27 :     structure P=Printer
28 :     structure cleanP=cleanParams
29 :     structure cleanI=cleanIndex
30 :    
31 : cchiw 3033
32 : cchiw 2838 in
33 :    
34 : cchiw 3166 val numFlag=1 (*remove common subexpression*)
35 : cchiw 3033 val testing=0
36 : cchiw 2838 fun setEin(params,index,body)=Ein.EIN{params=params, index=index, body=body}
37 :     fun assignEinApp(y,params,index,body,args)= (y,DstIL.EINAPP(setEin(params,index,body),args))
38 : cchiw 2843 val einappzero=DstIL.EINAPP(setEin([],[],E.Const 0),[])
39 :     fun setEinZero y= (y,einappzero)
40 : cchiw 2838 fun cleanParams e =cleanP.cleanParams e
41 :     fun cleanIndex e =cleanI.cleanIndex e
42 : cchiw 2845 fun printEINAPP e=MidToString.printEINAPP e
43 : cchiw 2838 fun itos i =Int.toString i
44 : cchiw 2923 fun filterSca e=Filter.filterSca e
45 : cchiw 2838 fun err str=raise Fail str
46 :     val cnt = ref 0
47 : cchiw 3166 fun incUse (DstIL.V{useCnt, ...}) = (useCnt := !useCnt + 1)
48 : cchiw 2838 fun genName prefix = let
49 :     val n = !cnt
50 :     in
51 :     cnt := n+1;
52 :     String.concat[prefix, "_", Int.toString n]
53 :     end
54 :     fun testp n=(case testing
55 :     of 0=> 1
56 :     | _ =>(print(String.concat n);1)
57 :     (*end case*))
58 :    
59 :     (* lift:ein_app*params*index*sum_id*args-> (ein_exp* params*args*code)
60 :     *lifts expression and returns replacement tensor
61 : cchiw 2843 * cleans the index and params of subexpression
62 :     *creates new param and replacement tensor for the original ein_exp
63 : cchiw 2838 *)
64 : cchiw 3166 fun lift(name,e,params,index,sx,args,fieldset,flag)=let
65 : cchiw 2838 val (tshape,sizes,body)=cleanIndex(e,index,sx)
66 : cchiw 2843 val id=length(params)
67 :     val Rparams=params@[E.TEN(1,sizes)]
68 :     val Re=E.Tensor(id,tshape)
69 : cchiw 2845 val M = DstV.new (genName (name^"_l_"^itos id), DstTy.TensorTy sizes)
70 : cchiw 2843 val Rargs=args@[M]
71 :     val einapp=cleanParams(M,body,Rparams,sizes,Rargs)
72 : cchiw 3166 val (_,einapp0)=einapp
73 :     val (Rargs,newbies,fieldset) =(case flag
74 :     of 1=> let
75 :     val (fieldset,var) = einSet.rtnVar(fieldset,M,einapp0)
76 :     in (case var
77 :     of NONE=> (args@[M],[einapp],fieldset)
78 :     | SOME v=> (incUse v ;(args@[v],[],fieldset))
79 :     (*end case*))
80 :     end
81 :     | _=>(args@[M],[einapp],fieldset)
82 :     (*end case*))
83 :     in
84 :     (Re,Rparams,Rargs,newbies,fieldset)
85 :     end
86 : cchiw 2838
87 : cchiw 3166
88 : cchiw 2838 (* isOp: ein->int
89 :     * checks to see if this sub-expression is pulled out or split form original
90 :     * 0-becomes zero,1-remains the same, 2-operator
91 :     *)
92 :     fun isOp e =(case e
93 :     of E.Field _ => 0
94 :     | E.Conv _ => 0
95 :     | E.Apply _ => 0
96 :     | E.Lift _ => 0
97 :     | E.Neg _ => 1
98 : cchiw 2870 | E.Sqrt _ => 1
99 : cchiw 3138 | E.Cosine _ => 1
100 :     | E.ArcCosine _ => 1
101 :     | E.Sine _ => 1
102 : cchiw 3166 | E.ArcSine _ => 1
103 : cchiw 3138 | E.PowInt _ => 1
104 :     | E.PowReal _ => 1
105 : cchiw 2838 | E.Add _ => 1
106 :     | E.Sub _ => 1
107 :     | E.Prod _ => 1
108 :     | E.Div _ => 1
109 :     | E.Sum _ => 1
110 :     | E.Probe _ => 1
111 :     | E.Partial _ => err(" Partial used after normalize")
112 :     | E.Krn _ => err("Krn used before expand")
113 :     | E.Value _ => err("Value used before expand")
114 :     | E.Img _ => err("Probe used before expand")
115 :     | _ => 2
116 :     (*end case*))
117 :    
118 : cchiw 3166
119 :    
120 :     fun rewriteOp3(name,sx,e1,x)=let
121 :     val ((y, DstIL.EINAPP(ein,args)),fieldset,flag)=x
122 :     val params=Ein.params ein
123 :     val index=Ein.index ein
124 :     in (case (isOp e1)
125 :     of 0 => (E.Const 0,params,args,[],fieldset)
126 :     | 1 => lift(name,e1,params,index,sx,args,fieldset,flag)
127 :     | 2 => (e1,params,args,[],fieldset)
128 :     (*end*))
129 :     end
130 :    
131 : cchiw 2843 (* rewriteOp:ein_exp*params*index*args-> ein_exp*params*args*code
132 : cchiw 3166 * If e1 an op then call lift() to replace it
133 :     *)
134 :     fun rewriteOp(name,e1,params,index,sx,args,fieldset,flag)=(case (isOp e1)
135 :     of 0 => (E.Const 0,params,args,[],fieldset)
136 :     | 1 => lift(name,e1,params,index,sx,args,fieldset,flag)
137 :     | 2 => (e1,params,args,[],fieldset) (*not lifted*)
138 : cchiw 2838 (*end*))
139 :    
140 : cchiw 3166
141 : cchiw 3017
142 : cchiw 3166
143 :     fun rewriteOps(name,list1,params,index,sx,args,fieldset0,flag)=let
144 :     fun m([],rest,params,args,code,fieldset)=(rest,params,args,code,fieldset)
145 :     | m(e1::es,rest,params,args,code,fieldset)=let
146 :    
147 :     val (e1',params',args',code',fieldset)= rewriteOp(name,e1,params,index,sx,args,fieldset,flag)
148 : cchiw 2838 in
149 : cchiw 3166 m(es,rest@[e1'],params',args',code@code',fieldset)
150 : cchiw 2838 end
151 :     in
152 : cchiw 3166 m(list1,[],params,args,[],fieldset0)
153 : cchiw 2838 end
154 : cchiw 3166
155 :    
156 : cchiw 2845 (*rewriteOrig: var* ein_exp* params*index list*mid-il vars
157 : cchiw 2843 When the operation is zero then we return a real.
158 : cchiw 2845 -Moved is Zero to before split.
159 : cchiw 2843 *)
160 : cchiw 3033 fun rewriteOrig(y,body,params,index,sx,args) =cleanParams(y,body,params,index,args)
161 : cchiw 2838
162 : cchiw 3166 fun rewriteOrig3(sx,body,params,args,x) =let
163 :     val ((y,DstIL.EINAPP(ein,_)),_,_)=x
164 :     val index=Ein.index ein
165 :     in cleanParams(y,body,params,index,args)
166 :     end
167 :    
168 : cchiw 2838 (* handleNeg:var*ein_exp *params*index*args-> (var*einap)*code
169 : cchiw 2843 * calls rewriteOp() lift on ein_exp
170 : cchiw 2838 *)
171 : cchiw 3166 fun handleNeg(e1,x)=let
172 :     val (e1',params',args',code,fieldset)= rewriteOp3("neg",[],e1,x)
173 :     val body' =E.Neg e1'
174 :     val einapp= rewriteOrig3([],body',params',args',x)
175 : cchiw 2845 in
176 : cchiw 3166 (einapp,code,fieldset)
177 : cchiw 2845 end
178 : cchiw 2838
179 : cchiw 2867 (* handleSqrt:var*ein_exp *params*index*args-> (var*einap)*code
180 :     * calls rewriteOp() lift on ein_exp
181 :     *)
182 : cchiw 3166 fun handleSqrt(y,e1,params,index,args,fieldset,flag)=let
183 :     val (e1',params',args',code,fieldset)= rewriteOp("sqrt", e1,params,index,[],args,fieldset,flag)
184 : cchiw 2867 val body =E.Sqrt e1'
185 :     val einapp= rewriteOrig(y,body,params',index,[],args')
186 :     in
187 : cchiw 3166 (einapp,code,fieldset)
188 : cchiw 2867 end
189 :    
190 :    
191 : cchiw 3138 (* handleCosine:var*ein_exp *params*index*args-> (var*einap)*code
192 :     * calls rewriteOp() lift on ein_exp
193 :     *)
194 : cchiw 3166 fun handleCosine(y,e1,params,index,args,fieldset,flag)=let
195 :     val (e1',params',args',code,fieldset)= rewriteOp("cosine", e1,params,index,[],args,fieldset,flag)
196 : cchiw 3138 val body =E.Cosine e1'
197 :     val einapp= rewriteOrig(y,body,params',index,[],args')
198 :     in
199 : cchiw 3166 (einapp,code,fieldset)
200 : cchiw 3138 end
201 :    
202 :     (* handleArcCosine:var*ein_exp *params*index*args-> (var*einap)*code
203 :     * calls rewriteOp() lift on ein_exp
204 :     *)
205 : cchiw 3166 fun handleArcCosine(y,e1,params,index,args,fieldset,flag)=let
206 :     val (e1',params',args',code,fieldset)= rewriteOp("ArcCosine", e1,params,index,[],args,fieldset,flag)
207 : cchiw 3138 val body =E.ArcCosine e1'
208 :     val einapp= rewriteOrig(y,body,params',index,[],args')
209 :     in
210 : cchiw 3166 (einapp,code,fieldset)
211 : cchiw 3138 end
212 :    
213 :     (* handleSine:var*ein_exp *params*index*args-> (var*einap)*code
214 :     * calls rewriteOp() lift on ein_exp
215 :     *)
216 : cchiw 3166 fun handleSine(y,e1,params,index,args,fieldset,flag)=let
217 :     val (e1',params',args',code,fieldset)= rewriteOp("sine", e1,params,index,[],args,fieldset,flag)
218 : cchiw 3138 val body =E.Sine e1'
219 :     val einapp= rewriteOrig(y,body,params',index,[],args')
220 :     in
221 : cchiw 3166 (einapp,code,fieldset)
222 : cchiw 3138 end
223 :    
224 :     (* handleSine:var*ein_exp *params*index*args-> (var*einap)*code
225 :     * calls rewriteOp() lift on ein_exp
226 :     *)
227 : cchiw 3166 fun handleArcSine(y,e1,params,index,args,fieldset,flag)=let
228 :     val (e1',params',args',code,fieldset)= rewriteOp("ArcSine", e1,params,index,[],args,fieldset,flag)
229 :     val body =E.ArcSine e1'
230 : cchiw 3033 val einapp= rewriteOrig(y,body,params',index,[],args')
231 :     in
232 : cchiw 3166 (einapp,code,fieldset)
233 : cchiw 3033 end
234 : cchiw 2870
235 :    
236 : cchiw 2838 (* handleSub:var*ein_exp*ein_exp *params*index*args-> (var*einap)*code
237 : cchiw 2843 * calls rewriteOps() lift on ein_exp
238 : cchiw 2838 *)
239 : cchiw 3166 fun handleSub(y,e1,e2,params,index,args,fieldset,flag)=let
240 :     val ([e1',e2'],params',args',code,fieldset)= rewriteOps("subt",[e1,e2],params,index,[],args,fieldset,flag)
241 : cchiw 2838 val body =E.Sub(e1',e2')
242 : cchiw 2845 val einapp= rewriteOrig(y,body,params',index,[],args')
243 :     in
244 : cchiw 3166 (einapp,code,fieldset)
245 : cchiw 2845 end
246 : cchiw 2838
247 :     (* handleDiv:var*ein_exp *ein_exp*params*index*args-> (var*einap)*code
248 : cchiw 2843 * calls rewriteOp() lift on ein_exp
249 : cchiw 2838 *)
250 : cchiw 3166 fun handleDiv(y,e1,e2,params,index,args,fieldset,flag)=let
251 :     val (e1',params1',args1',code1',fieldset)=rewriteOp("div-num",e1,params,index,[],args,fieldset,flag)
252 :     val (e2',params2',args2',code2',fieldset)=rewriteOp("div-denom",e2,params1',index,[],args1',fieldset,flag)
253 : cchiw 2838 val body =E.Div(e1',e2')
254 : cchiw 2845 val einapp= rewriteOrig(y,body,params2',index,[],args2')
255 :     in
256 : cchiw 3166 (einapp,code1'@code2',fieldset)
257 : cchiw 2845 end
258 : cchiw 2838
259 :     (* handleAdd:var*ein_exp list *params*index*args-> (var*einap)*code
260 : cchiw 2843 * calls rewriteOps() lift on ein_exp
261 : cchiw 2838 *)
262 : cchiw 3166 fun handleAdd(y,e1,params,index,args,fieldset,flag)=let
263 : cchiw 3030
264 : cchiw 3166 val (e1',params',args',code,fieldset)= rewriteOps("add",e1,params,index,[],args,fieldset,flag)
265 : cchiw 2838 val body =E.Add e1'
266 : cchiw 2845 val einapp= rewriteOrig(y,body,params',index,[],args')
267 :     in
268 : cchiw 3166 (einapp,code,fieldset)
269 : cchiw 2845 end
270 : cchiw 2838
271 :     (* handleProd:var*ein_exp list*params*index*args-> (var*einap)*code
272 : cchiw 2843 * calls rewriteOps() lift on ein_exp
273 : cchiw 2838 *)
274 : cchiw 3166 fun handleProd(y,e1,params,index,args,fieldset,flag)=let
275 :     val (e1',params',args',code,fieldset)= rewriteOps("prod",e1,params,index,[],args,fieldset,flag)
276 : cchiw 2838 val body =E.Prod e1'
277 : cchiw 2845 val einapp= rewriteOrig(y,body,params',index,[],args')
278 :     in
279 : cchiw 3166 (einapp,code,fieldset)
280 : cchiw 2845 end
281 : cchiw 2838
282 :     (* handleSumProd:var*ein_exp *params*index*args-> (var*einap)*code
283 : cchiw 2843 * calls rewriteOps() lift on ein_exp
284 : cchiw 2838 *)
285 : cchiw 3166 fun handleSumProd(y,e1,params,index,sx,args,fieldset,flag)=let
286 :     val (e1',params',args',code,fieldset)= rewriteOps("sumprod",e1,params,index,sx,args,fieldset,flag)
287 : cchiw 2838 val body= E.Sum(sx,E.Prod e1')
288 : cchiw 2845 val einapp= rewriteOrig(y,body,params',index,sx,args')
289 :     in
290 : cchiw 3166 (einapp,code,fieldset)
291 : cchiw 2845 end
292 : cchiw 2838
293 :     (* split:var*ein_app-> (var*einap)*code
294 :     * split ein expression into smaller pieces
295 : cchiw 2843 note we leave summation around probe exp
296 : cchiw 2838 *)
297 : cchiw 3166 fun split((y,einapp as DstIL.EINAPP(Ein.EIN{params, index, body},args)),fieldset,flag) =let
298 :     val x= ((y,einapp),fieldset,flag)
299 :     val zero= (setEinZero y,[],fieldset)
300 :     val default=((y,einapp),[],fieldset)
301 : cchiw 2838 val sumIndex=ref []
302 : cchiw 2867 val str="Poorly formed EIN operator. Argument needs to be applied in High-IL"^(P.printbody body)
303 : cchiw 3017 val _=testp["\n\nStarting split",P.printbody body]
304 :     fun rewrite b=(case b
305 : cchiw 2847 of E.Probe (E.Conv _,_) => default
306 : cchiw 2870 | E.Probe(E.Field _,_) => raise Fail str
307 : cchiw 2847 | E.Probe _ => raise Fail str
308 : cchiw 2838 | E.Conv _ => zero
309 :     | E.Field _ => zero
310 :     | E.Apply _ => zero
311 :     | E.Lift e => zero
312 :     | E.Delta _ => default
313 :     | E.Epsilon _ => default
314 : cchiw 2843 | E.Eps2 _ => default
315 : cchiw 2838 | E.Tensor _ => default
316 :     | E.Const _ => default
317 : cchiw 2923 | E.ConstR _ => default
318 : cchiw 3166 | E.Neg e1 => handleNeg(e1,x)
319 :     | E.Sqrt e1 => handleSqrt(y,e1,params,index,args,fieldset,flag)
320 :     | E.Cosine e1 => handleCosine(y,e1,params,index,args,fieldset,flag)
321 :     | E.ArcCosine e1 => handleArcCosine(y,e1,params,index,args,fieldset,flag)
322 :     | E.Sine e1 => handleSine(y,e1,params,index,args,fieldset,flag)
323 :     | E.ArcSine e1 => handleArcSine(y,e1,params,index,args,fieldset,flag)
324 :     | E.PowInt e1 => err(" PowInt unsupported")
325 :     | E.PowReal e1 => err(" PowReal unsupported")
326 :     | E.Sub (e1,e2) => handleSub(y,e1,e2,params,index,args,fieldset,flag)
327 :     | E.Div (e1,e2) => handleDiv(y,e1,e2,params,index,args,fieldset,flag)
328 : cchiw 2847 | E.Sum(_,E.Prod[E.Eps2 _, E.Probe(E.Conv _,_) ]) => default
329 :     | E.Sum(_,E.Prod[E.Epsilon _, E.Probe(E.Conv _,_) ]) => default
330 :     | E.Sum(_,E.Probe(E.Conv _,_)) => default
331 : cchiw 3166 | E.Sum(sx,E.Prod e1) => handleSumProd(y,e1,params,index,sx,args,fieldset,flag)
332 :     | E.Sum(sx,E.Delta d) => handleSumProd(y,[E.Delta d],params,index,sx,args,fieldset,flag)
333 : cchiw 3033 | E.Sum(sx,_) => err(" summation not distributed:"^str)
334 : cchiw 3166 | E.Add e1 => handleAdd(y,e1,params,index,args,fieldset,flag)
335 :     | E.Prod e1 => handleProd(y,e1,params,index,args,fieldset,flag)
336 : cchiw 2838 | E.Partial _ => err(" Partial used after normalize")
337 :     | E.Krn _ => err("Krn used before expand")
338 :     | E.Value _ => err("Value used before expand")
339 :     | E.Img _ => err("Probe used before expand")
340 :     (*end case *))
341 : cchiw 3166 val (einapp2,newbies,fieldset) =rewrite body
342 : cchiw 2838 in
343 : cchiw 3166 ((einapp2,newbies),fieldset)
344 : cchiw 2838 end
345 : cchiw 3166 |split((y,app),fieldset,_) =(((y,app),[]),fieldset)
346 : cchiw 2923
347 : cchiw 3166
348 :     fun iterMultiple(einapp2,newbies2,fieldset)=let
349 : cchiw 3017 fun itercode([],rest,code,_)=(rest,code)
350 : cchiw 3166 | itercode(e1::newbies,rest,code,cnt)=let
351 :     val ((einapp3,code3),_) = split(e1,fieldset,numFlag)
352 :     val (rest4,code4)=itercode(code3,[],[],cnt+1)
353 :     in
354 :     itercode(newbies,rest@[einapp3],code4@( rest4)@code,cnt+2)
355 :     end
356 :     val(rest,code)= itercode(newbies2,[],[],1)
357 :     in
358 :     ((code)@rest@[einapp2])
359 :     end
360 :    
361 :    
362 :     fun iterAll(einapp2,fieldset)=let
363 :     fun itercode([],rest,code,_)=(rest,code)
364 : cchiw 3017 | itercode(e1::newbies,rest,code,cnt)=let
365 : cchiw 3166 val ((einapp3,code3),_) = split(e1,fieldset,numFlag)
366 : cchiw 3017 val (rest4,code4)=itercode(code3,[],[],cnt+1)
367 : cchiw 3166 in
368 :     itercode(newbies,rest@[einapp3],code4@( rest4)@code,cnt+2)
369 : cchiw 2838 end
370 : cchiw 3166 val(rest,code)= itercode(einapp2,[],[],0)
371 : cchiw 2838 in
372 : cchiw 3166 (code@rest)
373 : cchiw 2838 end
374 :    
375 : cchiw 3166 fun splitEinApp einapp3= let
376 :     val fieldset= einSet.EinSet.empty
377 : cchiw 3017
378 : cchiw 3166 (* **** split in parts **** *)
379 :     (*
380 :     val ((einapp4,newbies4),fieldset)=split(einapp3,fieldset,0)
381 :     val _ =testp["\n\t===>\n",printEINAPP(einapp4),"\nand\n",(String.concatWith",\n\t"(List.map printEINAPP newbies4))]
382 :     val (newbies5)= iterMultiple(einapp4,newbies4,fieldset)
383 :     *)
384 :    
385 :     (* **** split all at once **** *)
386 :     val (newbies5)= iterAll([einapp3],fieldset)
387 :    
388 :     in
389 :     newbies5
390 :     end
391 :    
392 :    
393 : cchiw 2838 end; (* local *)
394 :    
395 :     end (* local *)

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0