Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] Annotation of /branches/charisee/src/compiler/high-to-mid/split.sml
ViewVC logotype

Annotation of /branches/charisee/src/compiler/high-to-mid/split.sml

Parent Directory Parent Directory | Revision Log Revision Log


Revision 2922 - (view) (download)

1 : cchiw 2843 (* Currently under construction
2 : cchiw 2838 *
3 :     * COPYRIGHT (c) 2012 The Diderot Project (http://diderot-language.cs.uchicago.edu)
4 :     * All rights reserved.
5 :     *)
6 : cchiw 2843
7 :     (*
8 :     During the transition from high-IL to mid-IL, complicated EIN expressions are split into simpler ones in order to better identify methods for code generation and common subexpressions. Combining EIN operators in the optimization phase can lead to large and complicated EIN operators. A general code generator would need to expand every operation to work on scalars, which could miss the opportunity for vectorization and lead to poor code generation. Instead, every EIN operator is split into a set of simple EIN operators. Each EIN expression then only has one operation working on constants, tensors, deltas, epsilons, images and kernels.
9 :    
10 :     (1) When the outer EIN operator is one of $\{--, +, -, *, /, \sum\}$ (where $--$ denotes unary negation), analyze each subexpression to see whether it needs to be rewritten.
11 :    
12 :     (1a) When a subexpression is a field expression ($\circledast$ or $\nabla$), it becomes 0. When it is another operation ($@$, $--$, $+$, $-$, $*$, $/$, $\sum$), we lift that subexpression into a new EIN operator and replace the subexpression with a tensor expression of the same size that refers to the lifted result.
13 :    
14 :     (1b) Call cleanIndex.sml to clean the indices in the subexpression, and get the shape for the tensor replacement.
15 :    
16 :     (1c) Call cleanParams.sml to clean the params in the subexpression.
17 :     *)
18 : cchiw 2838
(* Split: breaks a mid-IL EIN application into a sequence of simpler EIN
 * applications, each containing a single operation (see the header comment).
 * gettest / iterSplit appear to be the intended entry points.
 *)
structure Split = struct

  local

    structure E = Ein
    structure DstIL = MidIL
    structure DstTy = MidILTypes
    structure DstV = DstIL.Var
    structure P = Printer
    structure cleanP = cleanParams
    structure cleanI = cleanIndex
    structure handleE = handleEin

  in

  (* Debug switch: 0 disables the diagnostic printing done by testp and
   * gettest; any other value enables it.
   *)
    val testing = 0

  (* setEin: params * index * body -> Ein.ein
   * Packs the three components into an EIN operator.
   *)
    fun setEin (params, index, body) = Ein.EIN{params=params, index=index, body=body}

  (* assignEinApp: var * params * index * body * args -> var * einapp
   * Builds the assignment  y = EINAPP(EIN{params, index, body}, args).
   *)
    fun assignEinApp (y, params, index, body, args) =
          (y, DstIL.EINAPP(setEin(params, index, body), args))

  (* The constant-zero EIN application: no params, no index, body Const 0. *)
    val einappzero = DstIL.EINAPP(setEin([], [], E.Const 0), [])

  (* setEinZero: var -> var * einapp
   * Binds y to the constant-zero EIN application.
   *)
    fun setEinZero y = (y, einappzero)

  (* Thin local aliases for helpers defined in other modules. *)
    fun cleanParams e = cleanP.cleanParams e
    fun cleanIndex e = cleanI.cleanIndex e
    fun printEINAPP e = MidToString.printEINAPP e
    fun isZero e = handleE.isZero e
    fun sweep e = handleE.sweep e
    fun itos i = Int.toString i
    fun err str = raise Fail str

  (* Counter read and bumped by genName so generated names are unique. *)
    val cnt = ref 0

  (* genName: string -> string
   * Returns  prefix ^ "_" ^ n  where n is the current counter value, and
   * increments the counter.
   *)
    fun genName prefix = let
          val n = !cnt
          in
            cnt := n + 1;
            String.concat[prefix, "_", Int.toString n]
          end

  (* testp: string list -> int
   * When testing is nonzero, prints the concatenation of n.  Always returns 1.
   *)
    fun testp n = (case testing
           of 0 => 1
            | _ => (print(String.concat n); 1)
          (* end case *))

  (* lift: string * ein_exp * params * index * sum_id list * args
   *         -> ein_exp * params * args * code
   * Lifts the subexpression e out into its own EIN application:
   *   - cleanIndex(e, index, sx) yields (tshape, sizes, body): the shape used
   *     for the replacement tensor, the sizes of the new tensor, and the
   *     re-indexed body of the lifted expression;
   *   - a new parameter E.TEN(1, sizes) is appended to params and a fresh
   *     variable M of type TensorTy sizes (named from `name`) is appended to args;
   *   - cleanParams builds the assignment for M from the re-indexed body.
   * Returns the replacement expression E.Tensor(id, tshape), where id is the
   * position of the new parameter, together with the extended params and args
   * and a one-element code list holding the new assignment.
   *)
    fun lift (name, e, params, index, sx, args) = let
          val (tshape, sizes, body) = cleanIndex(e, index, sx)
          val id = length params
          val Rparams = params @ [E.TEN(1, sizes)]
          val Re = E.Tensor(id, tshape)
          val M = DstV.new (genName (name ^ "_l_" ^ itos id), DstTy.TensorTy sizes)
          val Rargs = args @ [M]
          val einapp = cleanParams(M, body, Rparams, sizes, Rargs)
          in
            (Re, Rparams, Rargs, [einapp])
          end

  (* isOp: ein_exp -> int
   * Classifies a subexpression of an operator that is being split:
   *   0 - field-level term (Field, Conv, Apply, Lift): rewritten to zero;
   *   1 - operator or probe: lifted out into its own EIN application;
   *   2 - anything else: left in place unchanged.
   * Raises Fail for terms that should no longer appear at this stage.
   *)
    fun isOp e = (case e
           of E.Field _ => 0
            | E.Conv _ => 0
            | E.Apply _ => 0
            | E.Lift _ => 0
            | E.Neg _ => 1
            | E.Sqrt _ => 1
            | E.PowInt _ => 1
            | E.PowReal _ => 1
            | E.Add _ => 1
            | E.Sub _ => 1
            | E.Prod _ => 1
            | E.Div _ => 1
            | E.Sum _ => 1
            | E.Probe _ => 1
            | E.Partial _ => err(" Partial used after normalize")
            | E.Krn _ => err("Krn used before expand")
            | E.Value _ => err("Value used before expand")
            | E.Img _ => err("Img used before expand")
            | _ => 2
          (* end case *))

  (* rewriteOp: string * ein_exp * params * index * sum_id list * args
   *              -> ein_exp * params * args * code
   * Field-level terms become Const 0, operators are lifted (see lift), and
   * everything else is returned unchanged with no new code.
   *)
    fun rewriteOp (name, e1, params, index, sx, args) = (case isOp e1
           of 0 => (E.Const 0, params, args, [])
            | 2 => (e1, params, args, [])
            | _ => lift(name, e1, params, index, sx, args)
          (* end case *))

  (* rewriteOps: string * ein_exp list * params * index * sum_id list * args
   *               -> ein_exp list * params * args * code
   * Applies rewriteOp to each expression left to right, threading the growing
   * params/args through and concatenating the generated code in order.
   *)
    fun rewriteOps (name, list1, params, index, sx, args) = let
          fun loop ([], done, ps, vs, code) = (done, ps, vs, code)
            | loop (e1::es, done, ps, vs, code) = let
                val (e1', ps', vs', code') = rewriteOp(name, e1, ps, index, sx, vs)
                in
                  loop(es, done @ [e1'], ps', vs', code @ code')
                end
          in
            loop(list1, [], params, args, [])
          end

  (* rewriteOrig: var * ein_exp * params * index * sum_id list * args -> var * einapp
   * Builds the assignment for y from the rewritten body.  When isZero reports
   * the body as zero (result 1), y is bound to the constant-zero application
   * instead.  The isZero check was moved here from before splitting.
   * sx is accepted for uniformity but is not used.
   *)
    fun rewriteOrig (y, body, params, index, sx, args) = (case isZero body
           of 1 => setEinZero y
            | _ => cleanParams(y, body, params, index, args)
          (* end case *))

  (* handleUnary: string * (ein_exp -> ein_exp) * var * ein_exp * params * index * args
   *                -> (var * einapp) * code
   * Shared implementation of the single-operand handlers: rewrites the operand
   * with rewriteOp (lifting it if it is an operator), rebuilds the body with
   * mk, and returns the assignment for y plus any lifted code.
   *)
    fun handleUnary (name, mk, y, e1, params, index, args) = let
          val (e1', params', args', code) = rewriteOp(name, e1, params, index, [], args)
          val einapp = rewriteOrig(y, mk e1', params', index, [], args')
          in
            (einapp, code)
          end

  (* handleNeg / handleSqrt / handlePowInt / handlePowReal:
   *   var * ... * params * index * args -> (var * einapp) * code
   * Split  y = -e1,  sqrt e1,  e1^n1 (integer n1),  e1^n1 (real n1).
   *)
    fun handleNeg (y, e1, params, index, args) =
          handleUnary("neg", E.Neg, y, e1, params, index, args)

    fun handleSqrt (y, e1, params, index, args) =
          handleUnary("sqrt", E.Sqrt, y, e1, params, index, args)

    fun handlePowInt (y, (e1, n1), params, index, args) =
          handleUnary("powint", fn e => E.PowInt(e, n1), y, e1, params, index, args)

    fun handlePowReal (y, (e1, n1), params, index, args) =
          handleUnary("powreal", fn e => E.PowReal(e, n1), y, e1, params, index, args)

  (* handleSub: var * ein_exp * ein_exp * params * index * args -> (var * einapp) * code
   * Splits  y = e1 - e2 : rewrites e1 then e2 (threading params/args) and
   * concatenates their lifted code in that order.
   *)
    fun handleSub (y, e1, e2, params, index, args) = let
          val (e1', params1, args1, code1) = rewriteOp("subt", e1, params, index, [], args)
          val (e2', params2, args2, code2) = rewriteOp("subt", e2, params1, index, [], args1)
          val einapp = rewriteOrig(y, E.Sub(e1', e2'), params2, index, [], args2)
          in
            (einapp, code1 @ code2)
          end

  (* handleDiv: var * ein_exp * ein_exp * params * index * args -> (var * einapp) * code
   * Splits  y = e1 / e2.  The numerator is rewritten with the full index and
   * its lifted variable is named after y.  The denominator is rewritten with
   * an empty index; NOTE(review): presumably because it is expected to be a
   * scalar -- confirm.
   *)
    fun handleDiv (y, e1, e2, params, index, args) = let
          val (e1', params1', args1', code1') = rewriteOp(DstV.name y, e1, params, index, [], args)
          val (e2', params2', args2', code2') = rewriteOp("div", e2, params1', [], [], args1')
          val einapp = rewriteOrig(y, E.Div(e1', e2'), params2', index, [], args2')
          in
            (einapp, code1' @ code2')
          end

  (* handleNary: string * (ein_exp list -> ein_exp) * var * ein_exp list * params
   *               * index * sum_id list * args -> (var * einapp) * code
   * Shared implementation of the list-operand handlers: rewrites every operand
   * with rewriteOps, rebuilds the body with mk, and returns the assignment
   * for y plus the lifted code.
   *)
    fun handleNary (name, mk, y, es, params, index, sx, args) = let
          val (es', params', args', code) = rewriteOps(name, es, params, index, sx, args)
          val einapp = rewriteOrig(y, mk es', params', index, sx, args')
          in
            (einapp, code)
          end

  (* handleAdd: var * ein_exp list * params * index * args -> (var * einapp) * code *)
    fun handleAdd (y, e1, params, index, args) =
          handleNary("add", E.Add, y, e1, params, index, [], args)

  (* handleProd: var * ein_exp list * params * index * args -> (var * einapp) * code *)
    fun handleProd (y, e1, params, index, args) =
          handleNary("prod", E.Prod, y, e1, params, index, [], args)

  (* handleSumProd: var * ein_exp list * params * index * sum_id list * args
   *                  -> (var * einapp) * code
   * Splits  y = Sum(sx, Prod e1); the factors are rewritten with sx in scope.
   *)
    fun handleSumProd (y, e1, params, index, sx, args) =
          handleNary("sumprod", fn es => E.Sum(sx, E.Prod es), y, e1, params, index, sx, args)

  (* split: var * einapp -> (var * einapp) * code
   * Splits one level of an EIN application: returns the (possibly rewritten)
   * assignment for y and the newly lifted assignments it depends on.
   * Summations around probes of convolutions are left in place.
   * A summation is pushed inside an operator only where the sum distributes
   * over it: negation, addition, subtraction, lift, and a nested sum (whose
   * index lists are merged).  The sum does NOT distribute over sqrt, real
   * powers, or division
   *   (sum_i sqrt(x_i) <> sqrt(sum_i x_i),  sum_i a_i/b_i <> (sum_i a_i)/(sum_i b_i)),
   * so those cases fall through to the generic Sum(sx, _) case, exactly like
   * PowInt.  Non-EINAPP right-hand sides are returned unchanged.
   *)
    fun split (y, einapp as DstIL.EINAPP(Ein.EIN{params, index, body}, args)) = let
          val zero = (setEinZero y, [])
          val default = ((y, einapp), [])
        (* the message is only built (and the body only printed) on failure *)
          fun malformed () = raise Fail (
                "Poorly formed EIN operator. Argument needs to be applied in High-IL"
                ^ P.printbody body)
          fun rewrite b = (case b
                 of E.Probe(E.Conv _, _) => default
                  | E.Probe _ => malformed ()
                  | E.Conv _ => zero
                  | E.Field _ => zero
                  | E.Apply _ => zero
                  | E.Lift _ => zero
                  | E.Delta _ => default
                  | E.Epsilon _ => default
                  | E.Eps2 _ => default
                  | E.Tensor _ => default
                  | E.Const _ => default
                  | E.ConstR _ => default
                  | E.Neg e1 => handleNeg(y, e1, params, index, args)
                  | E.Sqrt e1 => handleSqrt(y, e1, params, index, args)
                  | E.PowInt e1 => handlePowInt(y, e1, params, index, args)
                  | E.PowReal e1 => handlePowReal(y, e1, params, index, args)
                  | E.Sub(e1, e2) => handleSub(y, e1, e2, params, index, args)
                  | E.Div(e1, e2) => handleDiv(y, e1, e2, params, index, args)
                  | E.Sum(sx, E.Neg n) => rewrite (E.Neg(E.Sum(sx, n)))
                  | E.Sum(sx, E.Add a) => rewrite (E.Add(List.map (fn e => E.Sum(sx, e)) a))
                  | E.Sum(sx, E.Sub(e1, e2)) => rewrite (E.Sub(E.Sum(sx, e1), E.Sum(sx, e2)))
                  | E.Sum(sx, E.Lift e) => rewrite (E.Lift(E.Sum(sx, e)))
                  | E.Sum(c1, E.Sum(c2, e)) => rewrite (E.Sum(c1 @ c2, e))
                  | E.Sum(_, E.Prod[E.Eps2 _, E.Probe(E.Conv _, _)]) => default
                  | E.Sum(_, E.Prod[E.Epsilon _, E.Probe(E.Conv _, _)]) => default
                  | E.Sum(_, E.Probe(E.Conv _, _)) => default
                  | E.Sum(_, E.Conv _) => zero
                  | E.Sum(sx, E.Prod e1) => handleSumProd(y, e1, params, index, sx, args)
                  | E.Sum(sx, _) => default
                  | E.Add e1 => handleAdd(y, e1, params, index, args)
                  | E.Prod e1 => handleProd(y, e1, params, index, args)
                  | E.Partial _ => err(" Partial used after normalize")
                  | E.Krn _ => err("Krn used before expand")
                  | E.Value _ => err("Value used before expand")
                  | E.Img _ => err("Img used before expand")
                (* end case *))
          in
            rewrite body
          end
      | split (y, app) = ((y, app), [])

  (* iterMultiple: (var * einapp) * code -> (var * einapp) * code
   * Recursively splits every lifted assignment in newbies2 until no more
   * splitting happens.  In the returned code list, the assignments lifted out
   * of an expression come before that expression's own assignment.
   *)
    fun iterMultiple (einapp2, newbies2) = let
          fun itercode ([], rest, code) = (rest, code)
            | itercode (e1::newbies, rest, code) = let
                val (einapp3, code3) = split e1
                val (rest4, code4) = itercode(code3, [], [])
                in
                  itercode(newbies, rest @ [einapp3], code4 @ rest4 @ code)
                end
          val (rest, code) = itercode(newbies2, [], [])
          in
            (einapp2, code @ rest)
          end

  (* iterSplit: var * einapp -> (var * einapp) * code
   * Sweeps the body with handleEin.sweep, splits the result, and then
   * recursively splits every lifted assignment.  Only EINAPP right-hand sides
   * are accepted (any other shape raises Match, as before).
   *)
    fun iterSplit (y, DstIL.EINAPP(Ein.EIN{params, index, body}, args)) = let
          val einapp2 = assignEinApp(y, params, index, sweep body, args)
          val (einapp3, newbies2) = split einapp2
          in
            iterMultiple(einapp3, newbies2)
          end

  (* gettest: var * einapp -> (var * einapp) * code
   * Same as iterSplit; when testing is nonzero it first prints a banner and
   * the input assignment.
   *)
    fun gettest einapp = (case testing
           of 0 => iterSplit einapp
            | _ => let
                val star = "\n************* SPLIT********\n"
                val _ = print(String.concat[star, "\n", "start get test", printEINAPP einapp])
                in
                  iterSplit einapp
                end
          (* end case *))

  end (* local *)

end (* structure Split *)

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0