Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] Annotation of /branches/charisee/src/compiler/high-to-mid/split.sml
ViewVC logotype

Annotation of /branches/charisee/src/compiler/high-to-mid/split.sml

Parent Directory Parent Directory | Revision Log Revision Log


Revision 3030 - (view) (download)

1 : cchiw 2843 (* Currently under construction
2 : cchiw 2838 *
3 :     * COPYRIGHT (c) 2012 The Diderot Project (http://diderot-language.cs.uchicago.edu)
4 :     * All rights reserved.
5 :     *)
6 : cchiw 2843
7 :     (*
8 :     During the transition from high-IL to mid-IL, complicated EIN expressions are split into simpler ones in order to better identify methods for code generation and common subexpressions. Combining EIN operators in the optimization phase can lead to large and complicated EIN operators. A general code generator would need to expand every operation to work on scalars, which could miss the opportunity for vectorization and lead to poor code generation. Instead, every EIN operator is split into a set of simple EIN operators. Each EIN expression then only has one operation working on constants, tensors, deltas, epsilons, images and kernels.
9 :    
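10 :     (1) When the outer EIN operator is one of $\{--, +, -, *, /, \sum\}$ (here $--$ denotes unary negation; $\sqrt{\cdot}$ and integer/real powers are handled the same way), analyze each subexpression to see whether it needs to be rewritten.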
11 :    
12 :     (1a) When a subexpression is a field expression ($\circledast$, $\nabla$), it becomes 0. When it is another operation ($@$, $--$, $+$, $-$, $*$, $/$, $\sum$), we lift that subexpression into a new EIN operator and replace it with a tensor expression of the same shape that refers to the lifted result.
13 :    
14 :     (1b) Call cleanIndex.sml to clean the indices in the subexpression, and get the shape for the tensor replacement.
15 :    
16 :     (1c) Call cleanParams.sml to clean the params in the subexpression.
17 :     *)
18 : cchiw 2838
structure Split = struct

    local

    structure E = Ein
    structure DstIL = MidIL
    structure DstTy = MidILTypes
    structure DstV = DstIL.Var
    structure P=Printer
    structure cleanP=cleanParams
    structure cleanI=cleanIndex
    structure handleE=handleEin

    in

    (* testing: debug switch read by testp and gettest.
     * 0 silences all tracing; any other value prints it. *)
    val testing=1

    (* setEin: params*index*body -> ein
     * Builds an EIN operator record from its three components. *)
    fun setEin(params,index,body)=Ein.EIN{params=params, index=index, body=body}

    (* assignEinApp: var*params*index*body*args -> var*einapp
     * Builds the assignment y = EINAPP(ein, args). *)
    fun assignEinApp(y,params,index,body,args)= (y,DstIL.EINAPP(setEin(params,index,body),args))

    (* EIN application of the scalar constant 0 (no params, index or args). *)
    val einappzero=DstIL.EINAPP(setEin([],[],E.Const 0),[])

    (* setEinZero: var -> var*einapp -- binds y to the constant 0. *)
    fun setEinZero y= (y,einappzero)

    (* Thin wrappers over helper modules. *)
    fun cleanParams e =cleanP.cleanParams e
    fun cleanIndex e =cleanI.cleanIndex e
    fun printEINAPP e=MidToString.printEINAPP e
    fun isZero e=handleE.isZero e
    fun sweep e=handleE.sweep e
    fun itos i =Int.toString i
    fun filterSca e=Filter.filterSca e
    fun err str=raise Fail str

    (* genName: string -> string
     * Returns prefix ^ "_" ^ N, where N comes from a module-level counter that
     * is incremented on every call. *)
    val cnt = ref 0
    fun genName prefix = let
        val n = !cnt
        in
            cnt := n+1;
            String.concat[prefix, "_", Int.toString n]
        end

    (* testp: string list -> int
     * Prints the concatenation of n unless testing = 0; always returns 1.
     * The caller still builds the strings even when testing = 0. *)
    fun testp n=(case testing
        of 0=> 1
        | _ =>(print(String.concat n);1)
        (*end case*))


    (* lift: name*ein_exp*params*index*sum_id list*args -> ein_exp*params*args*code
     * Lifts sub-expression e into its own EIN operator, bound to a fresh MidIL
     * variable M. The steps are:
     *   - cleanIndex cleans e's indices and yields the tensor shape (tshape),
     *     the sizes and the cleaned body,
     *   - a new tensor parameter TEN(1,sizes) is appended to params; its
     *     position id = length params,
     *   - M is appended to args,
     *   - cleanParams builds the lifted operator M = body.
     * Returns (Tensor(id,tshape), extended params, extended args, [lifted op]).
     * The returned Tensor is meant to replace e in the enclosing expression.
     *)
    fun lift(name,e,params,index,sx,args)=let
        val (tshape,sizes,body)=cleanIndex(e,index,sx)
        val id=length params
        val Rparams=params@[E.TEN(1,sizes)]
        val Re=E.Tensor(id,tshape)
        val M = DstV.new (genName (name^"_l_"^itos id), DstTy.TensorTy sizes)
        val Rargs=args@[M]
        val einapp=cleanParams(M,body,Rparams,sizes,Rargs)
        in
            (Re,Rparams,Rargs,[einapp])
        end

    (* isOp: ein_exp -> int
     * Decides how a sub-expression is treated when it is split from the
     * original operator:
     *   0 - field-level term (Field, Conv, Apply, Lift): it becomes zero,
     *   1 - an operation (Neg, Sqrt, PowInt, PowReal, Add, Sub, Prod, Div,
     *       Sum, Probe): it is lifted into a new EIN operator,
     *   2 - anything else: it remains the same.
     * Raises Fail for terms that should already have been eliminated.
     *)
    fun isOp e =(case e
        of E.Field _ => 0
        | E.Conv _ => 0
        | E.Apply _ => 0
        | E.Lift _ => 0
        | E.Neg _ => 1
        | E.Sqrt _ => 1
        | E.PowInt _ => 1
        | E.PowReal _ => 1
        | E.Add _ => 1
        | E.Sub _ => 1
        | E.Prod _ => 1
        | E.Div _ => 1
        | E.Sum _ => 1
        | E.Probe _ => 1
        | E.Partial _ => err(" Partial used after normalize")
        | E.Krn _ => err("Krn used before expand")
        | E.Value _ => err("Value used before expand")
        | E.Img _ => err("Img used before expand")
        | _ => 2
        (*end case*))

    (* rewriteOp: name*ein_exp*params*index*sum_id list*args -> ein_exp*params*args*code
     * Rewrites one sub-expression e1 according to isOp:
     *   0 => the constant 0 (params/args unchanged, no new code),
     *   2 => e1 itself (params/args unchanged, no new code),
     *   1 => lift e1 (see lift).
     *)
    fun rewriteOp(name,e1,params,index,sx,args)=(case (isOp e1)
        of 0 => (E.Const 0,params,args,[])
        | 2 => (e1,params,args,[])
        | _ => lift(name,e1,params,index,sx,args)
        (*end case*))

    (* rewriteOps: name*ein_exp list*params*index*sum_id list*args
     *             -> ein_exp list*params*args*code
     * Applies rewriteOp to each expression in list1, left to right. Params and
     * args are threaded through so each lift sees the previous additions.
     * Results and generated code are kept in input order.
     *)
    fun rewriteOps(name,list1,params,index,sx,args)=let
        fun m([],rest,params,args,code)=(rest,params,args,code)
          | m(e1::es,rest,params,args,code)=let
                val (e1',params',args',code')= rewriteOp(name,e1,params,index,sx,args)
                (* Debug trace; routed through testp so the testing switch applies. *)
                val _ =testp["rewriteOP:\n",P.printbody e1,"\n\t=>",P.printbody e1']
                in
                    m(es,rest@[e1'],params',args',code@code')
                end
        in
            m(list1,[],params,args,[])
        end

    (* rewriteOrig: var*ein_exp*params*index*sum_id list*args -> var*einapp
     * Builds the rewritten top-level operator for y. If isZero returns 1 for
     * body, y is bound to the scalar 0. Otherwise cleanParams cleans the params
     * and args of the new body. The sx argument is not used.
     *)
    fun rewriteOrig(y,body,params,index,sx,args) =(case (isZero body)
        of 1=> setEinZero y
        | _ => cleanParams(y,body,params,index,args)
        (*end case*))

    (* handleNeg: var*ein_exp*params*index*args -> (var*einapp)*code
     * Splits y = Neg e1: rewrites e1 with rewriteOp, then rebuilds Neg.
     *)
    fun handleNeg(y,e1,params,index,args)=let
        val (e1',params',args',code)= rewriteOp("neg", e1,params,index,[],args)
        val body =E.Neg e1'
        val einapp= rewriteOrig(y,body,params',index,[],args')
        in
            (einapp,code)
        end

    (* handleSqrt: var*ein_exp*params*index*args -> (var*einapp)*code
     * Splits y = Sqrt e1: rewrites e1 with rewriteOp, then rebuilds Sqrt.
     *)
    fun handleSqrt(y,e1,params,index,args)=let
        val (e1',params',args',code)= rewriteOp("sqrt", e1,params,index,[],args)
        val body =E.Sqrt e1'
        val einapp= rewriteOrig(y,body,params',index,[],args')
        in
            (einapp,code)
        end

    (* handlePowInt: var*(ein_exp*exponent)*params*index*args -> (var*einapp)*code
     * Splits y = PowInt(e1,n1): rewrites e1 with rewriteOp; n1 is kept as is.
     *)
    fun handlePowInt(y,(e1,n1),params,index,args)=let
        val (e1',params',args',code)= rewriteOp("powint", e1,params,index,[],args)
        val body =E.PowInt(e1',n1)
        val einapp= rewriteOrig(y,body,params',index,[],args')
        in
            (einapp,code)
        end

    (* handlePowReal: var*(ein_exp*exponent)*params*index*args -> (var*einapp)*code
     * Splits y = PowReal(e1,n1): rewrites e1 with rewriteOp; n1 is kept as is.
     *)
    fun handlePowReal(y,(e1,n1),params,index,args)=let
        val (e1',params',args',code)= rewriteOp("powreal", e1,params,index,[],args)
        val body =E.PowReal(e1',n1)
        val einapp= rewriteOrig(y,body,params',index,[],args')
        in
            (einapp,code)
        end

    (* handleSub: var*ein_exp*ein_exp*params*index*args -> (var*einapp)*code
     * Splits y = Sub(e1,e2): rewrites both operands with rewriteOps.
     *)
    fun handleSub(y,e1,e2,params,index,args)=let
        val ([e1',e2'],params',args',code)= rewriteOps("subt",[e1,e2],params,index,[],args)
        val body =E.Sub(e1',e2')
        val einapp= rewriteOrig(y,body,params',index,[],args')
        in
            (einapp,code)
        end

    (* handleDiv: var*ein_exp*ein_exp*params*index*args -> (var*einapp)*code
     * Splits y = Div(e1,e2). The numerator is rewritten first; the denominator
     * is rewritten with the params/args extended by the numerator's lift. The
     * numerator's code precedes the denominator's.
     *)
    fun handleDiv(y,e1,e2,params,index,args)=let
        val (e1',params1',args1',code1')=rewriteOp("div-num",e1,params,index,[],args)
        val (e2',params2',args2',code2')=rewriteOp("div-denom",e2,params1',index,[],args1')
        val body =E.Div(e1',e2')
        val einapp= rewriteOrig(y,body,params2',index,[],args2')
        in
            (einapp,code1'@code2')
        end

    (* handleAdd: var*ein_exp list*params*index*args -> (var*einapp)*code
     * Splits y = Add e1: rewrites every summand with rewriteOps.
     *)
    fun handleAdd(y,e1,params,index,args)=let
        val (e1',params',args',code)= rewriteOps("add",e1,params,index,[],args)
        val body =E.Add e1'
        val einapp= rewriteOrig(y,body,params',index,[],args')
        in
            (einapp,code)
        end

    (* handleProd: var*ein_exp list*params*index*args -> (var*einapp)*code
     * Splits y = Prod e1: rewrites every factor with rewriteOps.
     *)
    fun handleProd(y,e1,params,index,args)=let
        val (e1',params',args',code)= rewriteOps("prod",e1,params,index,[],args)
        val body =E.Prod e1'
        val einapp= rewriteOrig(y,body,params',index,[],args')
        in
            (einapp,code)
        end

    (* handleSumProd: var*ein_exp list*params*index*sum_id list*args
     *                -> (var*einapp)*code
     * Splits y = Sum(sx, Prod e1): rewrites every factor with rewriteOps. The
     * summation ids sx are passed on, so cleanIndex can account for them in
     * lifted factors.
     *)
    fun handleSumProd(y,e1,params,index,sx,args)=let
        val (e1',params',args',code)= rewriteOps("sumprod",e1,params,index,sx,args)
        val body= E.Sum(sx,E.Prod e1')
        val einapp= rewriteOrig(y,body,params',index,sx,args')
        in
            (einapp,code)
        end

    (* split: var*einapp -> (var*einapp)*code
     * Splits one EIN assignment y = body into a simpler top-level operator plus
     * the list of lifted operators (code) it depends on.
     *   - Probes of convolutions, and summations around them, are left as is.
     *   - Field-level bodies (Conv, Field, Apply, Lift) become 0.
     *   - Leaves (Delta, Epsilon, Eps2, Tensor, Const, ConstR) are unchanged.
     *   - Summations are first pushed through Neg/Add/Sub/Lift/PowReal/Sqrt,
     *     and nested summations are merged.
     * A right-hand side that is not an EINAPP is returned unchanged.
     *)
    fun split(y,einapp as DstIL.EINAPP(Ein.EIN{params, index, body},args))=let
        val zero= (setEinZero y,[])
        val default=((y,einapp),[])
        (* Raised for probes whose argument is not a convolution. Built on demand
         * so the body is only pretty-printed when the error actually occurs. *)
        fun malformed () =
            raise Fail ("Poorly formed EIN operator. Argument needs to be applied in High-IL"^(P.printbody body))
        val _=testp["\n\nStarting split",P.printbody body]
        fun rewrite b=(case b
            of E.Probe (E.Conv _,_) => default
            | E.Probe _ => malformed ()   (* includes probes of a bare Field *)
            | E.Conv _ => zero
            | E.Field _ => zero
            | E.Apply _ => zero
            | E.Lift _ => zero
            | E.Delta _ => default
            | E.Epsilon _ => default
            | E.Eps2 _ => default
            | E.Tensor _ => default
            | E.Const _ => default
            | E.ConstR _ => default
            | E.Neg e1 => handleNeg(y,e1,params,index,args)
            | E.Sqrt e1 => handleSqrt(y,e1,params,index,args)
            | E.PowInt e1 => handlePowInt(y,e1,params,index,args)
            | E.PowReal e1 => handlePowReal(y,e1,params,index,args)
            | E.Sub (e1,e2) => handleSub(y,e1,e2,params,index,args)
            | E.Div (e1,e2) => handleDiv(y,e1,e2,params,index,args)
            (* NOTE(review): the next three cases drop the summation entirely.
             * For a term that does not depend on sx, Sum(sx,T) is n*T in
             * general, not T -- confirm these cases only arise where that is
             * intended. *)
            | E.Sum(sx,E.Tensor(id,[]))=> rewrite (E.Tensor(id,[]))
            | E.Sum(sx,E.Const c) =>rewrite ( E.Const c )
            | E.Sum(sx,E.ConstR r) => rewrite (E.ConstR r)
            | E.Sum(sx,E.Neg n) => rewrite (E.Neg(E.Sum(sx,n)))
            | E.Sum(sx,E.Add a) => rewrite (E.Add(List.map (fn e=> E.Sum(sx,e)) a))
            | E.Sum(sx,E.Sub (e1,e2)) => rewrite (E.Sub(E.Sum(sx,e1),E.Sum(sx,e2)))
            (* Sum of a quotient becomes the sum of e1 * (1/e2); the Div factor
             * is then lifted by handleSumProd. *)
            | E.Sum(sx,E.Div(e1,e2)) => rewrite (E.Sum(sx,E.Prod[e1,E.Div(E.Const 1,e2)]))
            | E.Sum(sx,E.Lift e ) => rewrite (E.Lift(E.Sum(sx,e)))
            (* NOTE(review): powers and square roots do not distribute over a
             * sum in general (sum_i sqrt(e_i) <> sqrt(sum_i e_i)) -- confirm
             * these rewrites are only reached when they are valid. *)
            | E.Sum(sx,E.PowReal(e,n1)) => rewrite (E.PowReal(E.Sum(sx,e),n1))
            | E.Sum(sx,E.Sqrt e) => rewrite (E.Sqrt(E.Sum(sx,e)))
            | E.Sum(c1,E.Sum (c2,e)) => rewrite (E.Sum (c1@c2,e))
            | E.Sum(_,E.Prod[E.Eps2 _, E.Probe(E.Conv _,_) ]) => default
            | E.Sum(_,E.Prod[E.Epsilon _, E.Probe(E.Conv _,_) ]) => default
            | E.Sum(_,E.Probe(E.Conv _,_)) => default
            | E.Sum(_,E.Conv _) => zero
            | E.Sum(sx,E.Prod e1) => handleSumProd(y,e1,params,index,sx,args)
            | E.Sum(sx,_) => default
            | E.Add e1 => handleAdd(y,e1,params,index,args)
            | E.Prod e1 => handleProd(y,e1,params,index,args)
            | E.Partial _ => err(" Partial used after normalize")
            | E.Krn _ => err("Krn used before expand")
            | E.Value _ => err("Value used before expand")
            | E.Img _ => err("Img used before expand")
            (*end case *))
        in
            rewrite body
        end
      | split(y,app) =((y,app),[])


    (* distributeSummation: var*einapp -> (var*einapp)*code
     * Pushes summations inward through the body (see rewrite below) and drops
     * a trailing "- 0". The result is cleaned with SummationEin.cleanSummation
     * and then passed to split. A right-hand side that is not an EINAPP is
     * returned unchanged.
     *)
    fun distributeSummation(y,einapp as DstIL.EINAPP(Ein.EIN{params, index, body},args))=let
        fun rewrite b=(case b
            (* NOTE(review): dropping the summation over a scalar/constant is
             * only exact for a trivial summation range -- confirm. *)
            of E.Sum(sx,E.Tensor(id,[])) => E.Tensor(id,[])
            | E.Sum(sx,E.Const c) => E.Const c
            | E.Sum(sx,E.ConstR r) => E.ConstR r
            | E.Sum(sx,E.Neg n) => rewrite(E.Neg(E.Sum(sx,n)))
            | E.Sum(sx,E.Add a) => rewrite(E.Add(List.map (fn e=> E.Sum(sx,e)) a))
            | E.Sum(sx,E.Sub (e1,e2)) => rewrite(E.Sub(E.Sum(sx,e1),E.Sum(sx,e2)))
            | E.Sum(sx,E.Lift e ) => rewrite (E.Lift(E.Sum(sx,e)))
            (* NOTE(review): see split -- sqrt/pow do not distribute over a sum
             * in general. *)
            | E.Sum(sx,E.PowReal(e,n1)) => rewrite(E.PowReal(E.Sum(sx,e),n1))
            | E.Sum(sx,E.Sqrt e) => rewrite(E.Sqrt(E.Sum(sx,e)))
            | E.Sum(sx,E.Sum (c2,e)) => rewrite (E.Sum (sx@c2,e))
            (* Only the second component of filterSca's result is kept. *)
            | E.Sum(sx,E.Prod p) => let
                val (_,e)=filterSca(sx,p)
                in e end
            | E.Div(e1,e2) => E.Div(rewrite e1, rewrite e2)
            | E.Sub(e1,E.Const 0) => rewrite e1
            | E.Sub(e1,e2) => E.Sub(rewrite e1, rewrite e2)
            | E.Add es => E.Add(List.map rewrite es)
            | E.Prod es => E.Prod(List.map rewrite es)
            | E.Neg e => E.Neg(rewrite e)
            | E.Sqrt e => E.Sqrt(rewrite e)
            | _ => b
            (*end case*))
        val body =rewrite body
        val _ =testp["\nAfter distributeSummation \n",P.printbody body]
        val ein=SummationEin.cleanSummation(Ein.EIN{params=params, index=index, body=body})
        val einapp2= (y,DstIL.EINAPP(ein,args))
        in
            split(einapp2)
        end
      | distributeSummation(y,app) =((y,app),[])


    (* iterMultiple: (var*einapp)*code -> (var*einapp)*code
     * Re-splits every lifted operator in newbies2 with distributeSummation, and
     * recursively re-splits the operators that produces. einapp2 is returned
     * unchanged. The returned code is code@rest: rest holds the re-split
     * newbies in their original order; code holds everything generated from
     * them, with a term's deeper code placed before the term itself.
     *)
    fun iterMultiple(einapp2,newbies2)=let
        fun itercode([],rest,code,_)=(rest,code)
          | itercode(e1::newbies,rest,code,termNo)=let
                val _ =testp["\n\n******* split term **",Int.toString termNo," *****","\n \n",printEINAPP(e1),"\n=>\n"]
                val (einapp3,code3) = distributeSummation e1
                val _ =testp["\n\t===>\n",printEINAPP(einapp3),"\nand\n",(String.concatWith",\n\t"(List.map printEINAPP code3))]
                val (rest4,code4)=itercode(code3,[],[],termNo+1)
                in
                    itercode(newbies,rest@[einapp3],code4@rest4@code,termNo+2)
                end
        val (rest,code)= itercode(newbies2,[],[],1)
        in
            (einapp2,code@rest)
        end

    (* iterSplit: var*einapp -> (var*einapp)*code
     * Full splitting pipeline for one assignment:
     *   handleEin.sweep, SummationEin.cleanSummation, distributeSummation
     *   (which splits), then iterMultiple on the lifted operators.
     * A right-hand side that is not an EINAPP is returned unchanged. This
     * matches split and distributeSummation; previously such input raised
     * Match.
     *)
    fun iterSplit(y,DstIL.EINAPP(Ein.EIN{params, index, body},args))=let
        val bodysweep=handleE.sweep body
        val _=testp["\nPresweep\n",P.printbody body,"\n\n Sweep\n",P.printbody bodysweep,"\n"]
        val ein=SummationEin.cleanSummation(Ein.EIN{params=params, index=index, body=bodysweep})
        val _=testp["\n\n Clean Summation\n",P.printbody(Ein.body ein),"\n"]
        val einapp2=(y,DstIL.EINAPP(ein, args))
        val _ =testp["\n\n******* split term **",Int.toString (0)," ***** \n \t==>\n",printEINAPP(einapp2)]
        val (einapp3,newbies2)=distributeSummation einapp2
        val _ =testp["\n\t===>\n",printEINAPP(einapp3),"\nand\n",(String.concatWith",\n\t"(List.map printEINAPP newbies2))]
        in
            iterMultiple(einapp3,newbies2)
        end
      | iterSplit(y,app) =((y,app),[])

    (* gettest: var*einapp -> (var*einapp)*code
     * Entry point. It runs iterSplit; when testing <> 0 it also prints the
     * input assignment and the resulting operator and lifted code.
     *)
    fun gettest einapp=(case testing
        of 0=>iterSplit einapp
        | _=>let
            val star="\n************* SPLIT INITIAL********\n"
            val _ =testp[star,"\n","start get test",printEINAPP einapp]
            val (einapp2,newbies)=iterSplit einapp
            val _ =testp["\n\n Returning \n\n =>",printEINAPP einapp2,
                " newbies\n\t",String.concatWith",\n\t"(List.map printEINAPP newbies), "\n",star]
            in
                (einapp2,newbies)
            end
        (*end case*))

    end; (* local *)

end (* structure Split *)

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0