Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] Annotation of /branches/charisee/src/compiler/high-to-mid/split.sml
ViewVC logotype

Annotation of /branches/charisee/src/compiler/high-to-mid/split.sml

Parent Directory Parent Directory | Revision Log Revision Log


Revision 3138 - (view) (download)

(* Currently under construction
 *
 * COPYRIGHT (c) 2012 The Diderot Project (http://diderot-language.cs.uchicago.edu)
 * All rights reserved.
 *)

(*
During the transition from high-IL to mid-IL, complicated EIN expressions are split into simpler ones in order to better identify methods for code generation and common subexpressions. Combining EIN operators in the optimization phase can lead to large and complicated EIN operators. A general code generator would need to expand every operation to work on scalars, which could miss the opportunity for vectorization and lead to poor code generation. Instead, every EIN operator is split into a set of simple EIN operators, so that each resulting EIN expression has only one operation working on constants, tensors, deltas, epsilons, images and kernels.

(1) When the outer EIN operator is one of negation, $+$, $-$, $*$, $/$ or $\sum$ (or a unary function such as square root, sine, cosine, arcsine, arccosine, or an integer or real power), each subexpression is analyzed to determine whether it needs to be rewritten.

(1a) When a subexpression is a field expression (e.g. $\circledast$, $\nabla$), it becomes 0. When it is another operation (for example probe $@$, negation, $+$, $-$, $*$, $/$ or $\sum$), we lift that subexpression into a new EIN operator and replace it with a tensor expression of the same shape that refers to the lifted result.

(1b) cleanIndex.sml is called to clean the indices in the subexpression and to get the shape of the replacement tensor.

(1c) cleanParams.sml is called to clean the params in the subexpression.
*)
18 : cchiw 2838
structure Split = struct

  local

    structure E = Ein
    structure DstIL = MidIL
    structure DstTy = MidILTypes
    structure DstV = DstIL.Var
    structure P = Printer
    structure cleanP = cleanParams
    structure cleanI = cleanIndex

  in

    (* Debug switch: 0 disables the tracing output produced by testp/debug;
     * any other value prints it.
     *)
    val testing = 0

    (* setEin: params*index*body -> Ein.ein
     * Packages the three components into an EIN operator.
     *)
    fun setEin (params, index, body) = Ein.EIN{params=params, index=index, body=body}

    (* assignEinApp: var*params*index*body*args -> var*rhs
     * Pairs y with the application of a freshly built EIN operator to args.
     *)
    fun assignEinApp (y, params, index, body, args) =
          (y, DstIL.EINAPP(setEin(params, index, body), args))

    (* The constant-zero application: no params, empty index, body Const 0, no args. *)
    val einappzero = DstIL.EINAPP(setEin([], [], E.Const 0), [])

    (* setEinZero: var -> var*rhs
     * Pairs y with the constant-zero application.
     *)
    fun setEinZero y = (y, einappzero)

    (* Thin wrappers around helper modules. *)
    fun cleanParams e = cleanP.cleanParams e
    fun cleanIndex e = cleanI.cleanIndex e
    fun printEINAPP e = MidToString.printEINAPP e
    fun itos i = Int.toString i
    fun filterSca e = Filter.filterSca e
    fun err str = raise Fail str

    (* Counter used by genName so that generated names are distinct. *)
    val cnt = ref 0

    (* genName: string -> string
     * Returns prefix ^ "_" ^ n, where n is the current counter value, and
     * increments the counter.
     *)
    fun genName prefix = let
          val n = !cnt
          in
            cnt := n+1;
            String.concat[prefix, "_", Int.toString n]
          end

    (* testp: string list -> int
     * Prints the concatenation of the strings when testing is nonzero; always
     * returns 1.  The argument list is built by the caller even when testing
     * is 0, so prefer debug for messages that are expensive to produce.
     *)
    fun testp n = (case testing
           of 0 => 1
            | _ => (print(String.concat n); 1)
          (* end case *))

    (* debug: (unit -> string list) -> unit
     * Lazy variant of testp: the message is only constructed (and printed)
     * when testing is nonzero, so pretty-printing is skipped otherwise.
     *)
    fun debug msg = if testing = 0 then () else ignore (testp (msg ()))

    (* lift: string*ein_exp*params*index*sum_id list*args -> ein_exp*params*args*code
     * Lifts the subexpression e out into its own EIN operator:
     *  - cleanIndex(e,index,sx) gives the shape tshape of the replacement,
     *    the tensor sizes, and the cleaned body;
     *  - a tensor param E.TEN(1,sizes) is appended to params and a fresh
     *    mid-IL variable M : TensorTy sizes is appended to args;
     *  - e is replaced by E.Tensor(id,tshape), where id is the position of the
     *    new param;
     *  - the definition of M, built by cleanParams, is returned as the code.
     * name is used as the prefix of M's generated name.
     *)
    fun lift (name, e, params, index, sx, args) = let
          val (tshape, sizes, body) = cleanIndex(e, index, sx)
          val id = length params
          val Rparams = params @ [E.TEN(1, sizes)]
          val Re = E.Tensor(id, tshape)
          val M = DstV.new (genName (name ^ "_l_" ^ itos id), DstTy.TensorTy sizes)
          val Rargs = args @ [M]
          val einapp = cleanParams(M, body, Rparams, sizes, Rargs)
          in
            (Re, Rparams, Rargs, [einapp])
          end

    (* isOp: ein_exp -> int
     * Classifies a subexpression of an operator that is being split
     * (the result is consumed by rewriteOp):
     *   0 - field-level expression (Field, Conv, Apply, Lift): replaced by zero
     *   1 - an operation (including Probe): lifted into its own EIN operator
     *   2 - anything else (e.g. Tensor, Const, Delta): kept unchanged
     * Raises Fail for Partial, Krn, Value and Img, which must not occur here.
     *)
    fun isOp e = (case e
           of E.Field _ => 0
            | E.Conv _ => 0
            | E.Apply _ => 0
            | E.Lift _ => 0
            | E.Neg _ => 1
            | E.Sqrt _ => 1
            | E.Cosine _ => 1
            | E.ArcCosine _ => 1
            | E.Sine _ => 1
            | E.ArcSine _ => 1
            | E.PowInt _ => 1
            | E.PowReal _ => 1
            | E.Add _ => 1
            | E.Sub _ => 1
            | E.Prod _ => 1
            | E.Div _ => 1
            | E.Sum _ => 1
            | E.Probe _ => 1
            | E.Partial _ => err(" Partial used after normalize")
            | E.Krn _ => err("Krn used before expand")
            | E.Value _ => err("Value used before expand")
            | E.Img _ => err("Img used before expand")
            | _ => 2
          (* end case *))

    (* rewriteOp: string*ein_exp*params*index*sum_id list*args -> ein_exp*params*args*code
     * Rewrites one subexpression according to isOp: field-level expressions
     * become Const 0, operations are lifted (see lift), everything else is
     * returned unchanged with no new params, args or code.
     *)
    fun rewriteOp (name, e1, params, index, sx, args) = (case isOp e1
           of 0 => (E.Const 0, params, args, [])
            | 2 => (e1, params, args, [])
            | _ => lift(name, e1, params, index, sx, args)
          (* end case *))

    (* rewriteOps: string*ein_exp list*params*index*sum_id list*args -> ein_exp list*params*args*code
     * Applies rewriteOp to each expression from left to right, threading the
     * growing params/args through; the rewritten expressions and the lifted
     * code are returned in the original left-to-right order.
     *)
    fun rewriteOps (name, list1, params, index, sx, args) = let
          fun step (e1, (revEs, ps, xs, code)) = let
                val (e1', ps', xs', code') = rewriteOp(name, e1, ps, index, sx, xs)
                in
                  (e1' :: revEs, ps', xs', code @ code')
                end
          val (revEs, params', args', code) = List.foldl step ([], params, args, []) list1
          in
            (List.rev revEs, params', args', code)
          end

    (* rewriteOrig: var*ein_exp*params*index*sum_id list*args -> var*rhs
     * Builds the definition of y from the (already rewritten) body by calling
     * cleanParams.  The sum-index argument is unused; it is accepted only for
     * symmetry with the other helpers.
     *)
    fun rewriteOrig (y, body, params, index, _, args) = cleanParams(y, body, params, index, args)

    (* handleUnary: string*(ein_exp -> ein_exp)*var*ein_exp*params*index*args -> (var*rhs)*code
     * Shared implementation of the unary handlers: the operand e1 is rewritten
     * with rewriteOp (lifted when it is itself an operation), rewrapped by mk,
     * and assigned to y.  Returns y's definition and the lifted code.
     *)
    fun handleUnary (name, mk, y, e1, params, index, args) = let
          val (e1', params', args', code) = rewriteOp(name, e1, params, index, [], args)
          val einapp = rewriteOrig(y, mk e1', params', index, [], args')
          in
            (einapp, code)
          end

    (* Unary handlers: var*ein_exp*params*index*args -> (var*rhs)*code.
     * Each rewrites its operand with rewriteOp (see handleUnary).
     *)
    fun handleNeg (y, e1, params, index, args) =
          handleUnary("neg", E.Neg, y, e1, params, index, args)

    fun handleSqrt (y, e1, params, index, args) =
          handleUnary("sqrt", E.Sqrt, y, e1, params, index, args)

    fun handleCosine (y, e1, params, index, args) =
          handleUnary("cosine", E.Cosine, y, e1, params, index, args)

    fun handleArcCosine (y, e1, params, index, args) =
          handleUnary("ArcCosine", E.ArcCosine, y, e1, params, index, args)

    fun handleSine (y, e1, params, index, args) =
          handleUnary("sine", E.Sine, y, e1, params, index, args)

    fun handleArcSine (y, e1, params, index, args) =
          handleUnary("ArcSine", E.ArcSine, y, e1, params, index, args)

    (* The power handlers take the (base, exponent) pair; only the base is
     * rewritten, the exponent is carried through unchanged.
     *)
    fun handlePowInt (y, (e1, n1), params, index, args) =
          handleUnary("powint", fn e => E.PowInt(e, n1), y, e1, params, index, args)

    fun handlePowReal (y, (e1, n1), params, index, args) =
          handleUnary("powreal", fn e => E.PowReal(e, n1), y, e1, params, index, args)

    (* handleSub: var*ein_exp*ein_exp*params*index*args -> (var*rhs)*code
     * Rewrites both operands (left first, threading params/args) and assigns
     * their difference to y.
     *)
    fun handleSub (y, e1, e2, params, index, args) = let
          val (e1', params1, args1, code1) = rewriteOp("subt", e1, params, index, [], args)
          val (e2', params2, args2, code2) = rewriteOp("subt", e2, params1, index, [], args1)
          val einapp = rewriteOrig(y, E.Sub(e1', e2'), params2, index, [], args2)
          in
            (einapp, code1 @ code2)
          end

    (* handleDiv: var*ein_exp*ein_exp*params*index*args -> (var*rhs)*code
     * Rewrites the numerator, then the denominator, and assigns their quotient to y.
     *)
    fun handleDiv (y, e1, e2, params, index, args) = let
          val (e1', params1', args1', code1') = rewriteOp("div-num", e1, params, index, [], args)
          val (e2', params2', args2', code2') = rewriteOp("div-denom", e2, params1', index, [], args1')
          val einapp = rewriteOrig(y, E.Div(e1', e2'), params2', index, [], args2')
          in
            (einapp, code1' @ code2')
          end

    (* handleAdd: var*ein_exp list*params*index*args -> (var*rhs)*code
     * Rewrites every summand with rewriteOps and assigns their sum to y.
     *)
    fun handleAdd (y, e1, params, index, args) = let
          val (e1', params', args', code) = rewriteOps("add", e1, params, index, [], args)
          val einapp = rewriteOrig(y, E.Add e1', params', index, [], args')
          in
            (einapp, code)
          end

    (* handleProd: var*ein_exp list*params*index*args -> (var*rhs)*code
     * Rewrites every factor with rewriteOps and assigns their product to y.
     *)
    fun handleProd (y, e1, params, index, args) = let
          val (e1', params', args', code) = rewriteOps("prod", e1, params, index, [], args)
          val einapp = rewriteOrig(y, E.Prod e1', params', index, [], args')
          in
            (einapp, code)
          end

    (* handleSumProd: var*ein_exp list*params*index*sum_id list*args -> (var*rhs)*code
     * Rewrites the factors of a summed product (passing the summation indices
     * sx to rewriteOps) and assigns Sum(sx, Prod factors) to y.
     *)
    fun handleSumProd (y, e1, params, index, sx, args) = let
          val (e1', params', args', code) = rewriteOps("sumprod", e1, params, index, sx, args)
          val einapp = rewriteOrig(y, E.Sum(sx, E.Prod e1'), params', index, sx, args')
          in
            (einapp, code)
          end

    (* split: var*rhs -> (var*rhs)*code
     * Splits one EIN application into a definition of y plus the list of
     * lifted definitions it depends on.
     *  - field-level bodies (Conv, Field, Apply, Lift) become the zero application;
     *  - leaves (Delta, Epsilon, Eps2, Tensor, Const, ConstR) and probes of
     *    convolutions (also under the summation patterns below) are left as-is;
     *  - operator bodies are dispatched to the matching handler.
     * Raises Fail for probes of anything but a convolution, for summations
     * that are not distributed over a product or delta, and for Partial, Krn,
     * Value and Img.  Non-EINAPP right-hand sides are returned unchanged.
     *)
    fun split (y, einapp as DstIL.EINAPP(Ein.EIN{params, index, body}, args)) = let
          val zero = (setEinZero y, [])
          val default = ((y, einapp), [])
          (* only built when an error is actually reported *)
          fun malformed () =
                "Poorly formed EIN operator. Argument needs to be applied in High-IL"
                  ^ (P.printbody body)
          val () = debug (fn () => ["\n\nStarting split", P.printbody body])
          in
            (case body
             of E.Probe(E.Conv _, _) => default
              | E.Probe _ => raise Fail (malformed ())
              | E.Conv _ => zero
              | E.Field _ => zero
              | E.Apply _ => zero
              | E.Lift _ => zero
              | E.Delta _ => default
              | E.Epsilon _ => default
              | E.Eps2 _ => default
              | E.Tensor _ => default
              | E.Const _ => default
              | E.ConstR _ => default
              | E.Neg e1 => handleNeg(y, e1, params, index, args)
              | E.Sqrt e1 => handleSqrt(y, e1, params, index, args)
              | E.Cosine e1 => handleCosine(y, e1, params, index, args)
              | E.ArcCosine e1 => handleArcCosine(y, e1, params, index, args)
              | E.Sine e1 => handleSine(y, e1, params, index, args)
              | E.ArcSine e1 => handleArcSine(y, e1, params, index, args)
              | E.PowInt e1 => handlePowInt(y, e1, params, index, args)
              | E.PowReal e1 => handlePowReal(y, e1, params, index, args)
              | E.Sub(e1, e2) => handleSub(y, e1, e2, params, index, args)
              | E.Div(e1, e2) => handleDiv(y, e1, e2, params, index, args)
              (* summations around a probe are left intact *)
              | E.Sum(_, E.Prod[E.Eps2 _, E.Probe(E.Conv _, _)]) => default
              | E.Sum(_, E.Prod[E.Epsilon _, E.Probe(E.Conv _, _)]) => default
              | E.Sum(_, E.Probe(E.Conv _, _)) => default
              | E.Sum(sx, E.Prod e1) => handleSumProd(y, e1, params, index, sx, args)
              | E.Sum(sx, E.Delta d) => handleSumProd(y, [E.Delta d], params, index, sx, args)
              | E.Sum(sx, _) => err(" summation not distributed:" ^ malformed ())
              | E.Add e1 => handleAdd(y, e1, params, index, args)
              | E.Prod e1 => handleProd(y, e1, params, index, args)
              | E.Partial _ => err(" Partial used after normalize")
              | E.Krn _ => err("Krn used before expand")
              | E.Value _ => err("Value used before expand")
              | E.Img _ => err("Img used before expand")
            (* end case *))
          end
      | split (y, app) = ((y, app), [])

    (* iterMultiple: (var*rhs)*code -> (var*rhs)*code
     * Recursively splits every lifted definition in newbies2 until nothing
     * more can be split.  Each term's own lifted definitions are placed before
     * the term itself, so the returned code lists every definition before its
     * uses.  The counter passed through itercode only labels the debug output.
     *)
    fun iterMultiple (einapp2, newbies2) = let
          fun itercode ([], rest, code, _) = (rest, code)
            | itercode (e1::newbies, rest, code, n) = let
                val () = debug (fn () =>
                      ["\n\n******* split term **", Int.toString n, " *****",
                       "\n \n", printEINAPP e1, "\n=>\n"])
                val (einapp3, code3) = split e1
                val () = debug (fn () =>
                      ["\n\t===>\n", printEINAPP einapp3, "\nand\n",
                       String.concatWith ",\n\t" (List.map printEINAPP code3)])
                (* split the definitions lifted out of e1 before e1 itself *)
                val (rest4, code4) = itercode (code3, [], [], n+1)
                in
                  itercode (newbies, rest @ [einapp3], code4 @ rest4 @ code, n+2)
                end
          val (rest, code) = itercode (newbies2, [], [], 1)
          in
            (einapp2, code @ rest)
          end

  end; (* local *)

end (* structure Split *)

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0