Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] Annotation of /branches/charisee_dev/src/compiler/high-to-mid/split.sml
ViewVC logotype

Annotation of /branches/charisee_dev/src/compiler/high-to-mid/split.sml

Parent Directory Parent Directory | Revision Log Revision Log


Revision 3316 - (view) (download)

1 : cchiw 2843 (* Currently under construction
2 : cchiw 2838 *
3 :     * COPYRIGHT (c) 2012 The Diderot Project (http://diderot-language.cs.uchicago.edu)
4 :     * All rights reserved.
5 :     *)
6 : cchiw 2843
(*
 * During the transition from high-IL to mid-IL, complicated EIN expressions are
 * split into simpler ones so that code-generation strategies and common
 * subexpressions are easier to identify. Combining EIN operators in the
 * optimization phase can produce large and complicated EIN operators. A general
 * code generator would have to expand every such operation to work on scalars,
 * which could miss opportunities for vectorization and lead to poor code.
 * Instead, every EIN operator is split into a set of simple EIN operators, so
 * that each EIN expression has only one operation working on constants, tensors,
 * deltas, epsilons, images and kernels.
 *
 * (1) When the outer EIN operator is one of $\{\mathrm{neg}, +, -, *, /, \sum\}$,
 *     each subexpression is analyzed to see whether it needs to be rewritten.
 *
 * (1a) When a subexpression is a field expression ($\circledast$, $\nabla$), it
 *      becomes 0. When it is another operation in
 *      $\{\mathrm{neg}, +, -, *, /, \sum\}$, that subexpression is lifted out into
 *      a new EIN operator and replaced by a tensor expression that represents its
 *      size.
 *
 * (1b) cleanIndex.sml is called to clean the indices in the subexpression and to
 *      compute the shape of the replacement tensor.
 *
 * (1c) cleanParams.sml is called to clean the params in the subexpression.
 *)
18 : cchiw 2838
structure Split = struct

    local

    structure E = Ein
    structure DstIL = MidIL
    structure DstTy = MidILTypes
    structure DstV = DstIL.Var
    structure P = Printer
    structure cleanP = cleanParams
    structure cleanI = cleanIndex

    in

    (* Flag passed to split by iterAll/iterMultiple.  When it is 1, lift consults
     * einSet.rtnVar so that a repeated lifted application can reuse an existing
     * variable (common-subexpression removal); any other value disables that check.
     *)
    val numFlag = 1

    (* non-zero enables the debugging output printed through testp *)
    val testing = 0

    fun mkEin e = E.mkEin e

    (* EIN application with no params, no index and body `Const 0` *)
    val einappzero = DstIL.EINAPP(mkEin([], [], E.Const 0), [])
    fun setEinZero y = (y, einappzero)

    fun cleanParams e = cleanP.cleanParams e
    fun cleanIndex e = cleanI.cleanIndex e
    fun toStringBind e = MidToString.toStringBind e
    fun itos i = Int.toString i
    fun err str = raise Fail str

    (* counter used by genName to make generated variable names unique *)
    val cnt = ref 0

    (* increment the use count of a mid-IL variable *)
    fun incUse (DstIL.V{useCnt, ...}) = (useCnt := !useCnt + 1)

    (* genName prefix : returns prefix ^ "_" ^ n, where n is the next counter value *)
    fun genName prefix = let
          val n = !cnt
          in
            cnt := n + 1;
            String.concat [prefix, "_", Int.toString n]
          end

    (* testp strs : when testing is non-zero, prints the concatenation of strs.
     * Always returns 1.
     *)
    fun testp strs = if testing = 0
          then 1
          else (print (String.concat strs); 1)

    (* lift : name * ein_exp * params * index * sum_index * args * fieldset * flag
     *          -> ein_exp * params * args * code * fieldset
     * Lifts the sub-expression e out of its parent into a new EIN application bound
     * to a fresh variable M, and returns the tensor expression that replaces e in the
     * parent.
     *   - cleanIndex gives the replacement tensor's index shape, its sizes, and the
     *     re-indexed body of e.
     *   - cleanParams cleans the params of the new application for M.
     *   - The parent gains one tensor parameter, appended, so its id is `length params`.
     *   - When flag = 1, einSet.rtnVar is consulted with the new application
     *     (presumably a lookup for an equivalent application already lifted -- TODO
     *     confirm).  If it returns SOME v, v is passed to the parent instead of M,
     *     v's use count is bumped, and no new binding is emitted.
     *)
    fun lift (name, e, params, index, sx, args, fieldset, flag) = let
          val (tshape, sizes, body) = cleanIndex (e, index, sx)
          val id = length params
          val params' = params @ [E.TEN(1, sizes)]
          val replacement = E.Tensor(id, tshape)
          val M = DstV.new (genName (name ^ "_l_" ^ itos id), DstTy.TensorTy sizes)
          val newBind = cleanParams (M, body, params', sizes, args @ [M])
          val (_, newApp) = newBind
          (* use the freshly lifted binding *)
          fun useNew fs = (args @ [M], [newBind], fs)
          val (args', newbies, fieldset') = if flag = 1
                then (case einSet.rtnVar (fieldset, M, newApp)
                   of (fs, NONE) => useNew fs
                    | (fs, SOME v) => (incUse v; (args @ [v], [], fs))
                  (* end case *))
                else useNew fieldset
          in
            (replacement, params', args', newbies, fieldset')
          end

    (* isOp : ein_exp -> int
     * Classifies a sub-expression of an operator that is being split:
     *   0 - a field-level term (Field, Conv, Apply, Lift); it is replaced by 0
     *   1 - an operation; it is lifted out into its own EIN application (see lift)
     *   2 - anything else (e.g. tensors, constants); it is left in place
     * Raises Fail for terms that must not appear at this point.
     *)
    fun isOp e = (case e
           of E.Field _ => 0
            | E.Conv _ => 0
            | E.Apply _ => 0
            | E.Lift _ => 0
            | E.Neg _ => 1
            | E.Sqrt _ => 1
            | E.Cosine _ => 1
            | E.ArcCosine _ => 1
            | E.Sine _ => 1
            | E.ArcSine _ => 1
            | E.PowInt _ => 1
            | E.PowReal _ => 1
            | E.Add _ => 1
            | E.Sub _ => 1
            | E.Prod _ => 1
            | E.Div _ => 1
            | E.Sum _ => 1
            | E.Probe _ => 1
            | E.Partial _ => err(" Partial used after normalize")
            | E.Krn _ => err("Krn used before expand")
            | E.Value _ => err("Value used before expand")
            | E.Img _ => err("Img used before expand")
            | _ => 2
          (* end case *))

    (* rewriteOp : name * ein_exp * params * index * sum_index * args * fieldset * flag
     *               -> ein_exp * params * args * code * fieldset
     * Rewrites one operand according to isOp: zeroed, lifted (see lift), or kept.
     *)
    fun rewriteOp (name, e1, params, index, sx, args, fieldset, flag) = (case isOp e1
           of 0 => (E.Const 0, params, args, [], fieldset)
            | 1 => lift (name, e1, params, index, sx, args, fieldset, flag)
            | _ => (e1, params, args, [], fieldset)  (* not lifted *)
          (* end case *))

    (* rewriteOp3 : like rewriteOp, but takes params, index, args, fieldset and flag
     * from x = ((y, EINAPP(ein, args)), fieldset, flag).
     *)
    fun rewriteOp3 (name, sx, e1, x) = let
          val ((_, DstIL.EINAPP(ein, args)), fieldset, flag) = x
          in
            rewriteOp (name, e1, Ein.params ein, Ein.index ein, sx, args, fieldset, flag)
          end

    (* rewriteOps : rewriteOp applied left to right over a list of operands.
     * The params, args and fieldset are threaded through the list; the lifted code
     * is concatenated in operand order.
     *)
    fun rewriteOps (name, es, params, index, sx, args, fieldset, flag) = let
          fun step (e, (revDone, params, args, code, fieldset)) = let
                val (e', params', args', code', fieldset') =
                      rewriteOp (name, e, params, index, sx, args, fieldset, flag)
                in
                  (e' :: revDone, params', args', code @ code', fieldset')
                end
          val (revDone, params', args', code, fieldset') =
                List.foldl step ([], params, args, [], fieldset) es
          in
            (List.rev revDone, params', args', code, fieldset')
          end

    (* rewriteOrig : var * ein_exp * params * index * sum_index * args
     * Rebuilds the (cleaned) EIN application for y from its rewritten body.
     * The sum index argument is not used.
     *)
    fun rewriteOrig (y, body, params, index, sx, args) = cleanParams (y, body, params, index, args)

    (* rewriteOrig3 : like rewriteOrig, but takes y and the index from
     * x = ((y, EINAPP(ein, _)), _, _).
     *)
    fun rewriteOrig3 (sx, body, params, args, x) = let
          val ((y, DstIL.EINAPP(ein, _)), _, _) = x
          in
            cleanParams (y, body, params, Ein.index ein, args)
          end

    (* rewritePair : rewrites two operands in order (name1 for e1, name2 for e2),
     * threading params/args/fieldset from the first into the second.
     *)
    fun rewritePair (name1, name2, e1, e2, params, index, args, fieldset, flag) = let
          val (e1', params1, args1, code1, fieldset1) =
                rewriteOp (name1, e1, params, index, [], args, fieldset, flag)
          val (e2', params2, args2, code2, fieldset2) =
                rewriteOp (name2, e2, params1, index, [], args1, fieldset1, flag)
          in
            ((e1', e2'), params2, args2, code1 @ code2, fieldset2)
          end

    (* handleUnary : shared implementation of the unary handlers.  Rewrites the
     * operand (lifting it if it is an operation), wraps it with mk, and rebuilds y.
     * Returns ((y, einapp), code, fieldset).
     *)
    fun handleUnary (name, mk, y, e1, params, index, args, fieldset, flag) = let
          val (e1', params', args', code, fieldset') =
                rewriteOp (name, e1, params, index, [], args, fieldset, flag)
          val einapp = rewriteOrig (y, mk e1', params', index, [], args')
          in
            (einapp, code, fieldset')
          end

    (* handleList : shared implementation of the n-ary handlers.  Rewrites every
     * operand with rewriteOps (using sum index sx), combines the results with mk,
     * and rebuilds y.  Returns ((y, einapp), code, fieldset).
     *)
    fun handleList (name, mk, y, es, params, index, sx, args, fieldset, flag) = let
          val (es', params', args', code, fieldset') =
                rewriteOps (name, es, params, index, sx, args, fieldset, flag)
          val einapp = rewriteOrig (y, mk es', params', index, sx, args')
          in
            (einapp, code, fieldset')
          end

    (* handleNeg : ein_exp * x -> (var * einapp) * code * fieldset
     * Lifts the operand of a negation if needed; everything else is taken from x.
     *)
    fun handleNeg (e1, x) = let
          val (e1', params', args', code, fieldset) = rewriteOp3 ("neg", [], e1, x)
          val einapp = rewriteOrig3 ([], E.Neg e1', params', args', x)
          in
            (einapp, code, fieldset)
          end

    (* unary operators: lift the operand if needed *)
    fun handleSqrt (y, e1, params, index, args, fieldset, flag) =
          handleUnary ("sqrt", E.Sqrt, y, e1, params, index, args, fieldset, flag)

    fun handleCosine (y, e1, params, index, args, fieldset, flag) =
          handleUnary ("cosine", E.Cosine, y, e1, params, index, args, fieldset, flag)

    fun handleArcCosine (y, e1, params, index, args, fieldset, flag) =
          handleUnary ("ArcCosine", E.ArcCosine, y, e1, params, index, args, fieldset, flag)

    fun handleSine (y, e1, params, index, args, fieldset, flag) =
          handleUnary ("sine", E.Sine, y, e1, params, index, args, fieldset, flag)

    fun handleArcSine (y, e1, params, index, args, fieldset, flag) =
          handleUnary ("ArcSine", E.ArcSine, y, e1, params, index, args, fieldset, flag)

    (* handleSub : lifts either operand of a subtraction if needed *)
    fun handleSub (y, e1, e2, params, index, args, fieldset, flag) = let
          val ((e1', e2'), params', args', code, fieldset') =
                rewritePair ("subt", "subt", e1, e2, params, index, args, fieldset, flag)
          val einapp = rewriteOrig (y, E.Sub(e1', e2'), params', index, [], args')
          in
            (einapp, code, fieldset')
          end

    (* handleDiv : lifts the numerator and/or denominator if needed *)
    fun handleDiv (y, e1, e2, params, index, args, fieldset, flag) = let
          val ((e1', e2'), params', args', code, fieldset') =
                rewritePair ("div-num", "div-denom", e1, e2, params, index, args, fieldset, flag)
          val einapp = rewriteOrig (y, E.Div(e1', e2'), params', index, [], args')
          in
            (einapp, code, fieldset')
          end

    (* handleAdd : lifts any summand that is itself an operation *)
    fun handleAdd (y, es, params, index, args, fieldset, flag) =
          handleList ("add", E.Add, y, es, params, index, [], args, fieldset, flag)

    (* handleProd : lifts any factor that is itself an operation *)
    fun handleProd (y, es, params, index, args, fieldset, flag) =
          handleList ("prod", E.Prod, y, es, params, index, [], args, fieldset, flag)

    (* handleSumProd : like handleProd for Sum(sx, Prod es); operands are cleaned
     * with respect to the summation index sx
     *)
    fun handleSumProd (y, es, params, index, sx, args, fieldset, flag) =
          handleList ("sumprod", fn es' => E.Sum(sx, E.Prod es'),
                      y, es, params, index, sx, args, fieldset, flag)

    (* split : (var * rhs) * fieldset * flag -> ((var * rhs) * code) * fieldset
     * Splits one EIN application into a simpler application for y plus the list of
     * newly lifted bindings (code) it now depends on.  Terminal bodies (tensors,
     * constants, deltas, epsilons, probes of convolutions, ...) are returned
     * unchanged, as are bindings whose right-hand side is not an EIN application.
     * Summation is left in place around probe expressions.
     *)
    fun split ((y, einapp as DstIL.EINAPP(Ein.EIN{params, index, body}, args)), fieldset, flag) = let
          val x = ((y, einapp), fieldset, flag)
          val default = ((y, einapp), [], fieldset)
          (* built only when an error is actually raised *)
          fun str () = "Poorly formed EIN operator. Argument needs to be applied in High-IL"
                ^ (P.printbody body)
          val _ = if testing = 0 then 1 else testp ["\n\nStarting split", P.printbody body]
          fun rewrite b = (case b
                 of E.Probe(E.Conv _, _) => default
                  | E.Probe _ => raise Fail (str ())
                  | E.Conv _ => raise Fail "should have been swept"
                  | E.Field _ => raise Fail "should have been swept"
                  | E.Apply _ => raise Fail "should have been swept"
                  | E.Lift _ => raise Fail "should have been swept"
                  | E.Delta _ => default
                  | E.Epsilon _ => default
                  | E.Eps2 _ => default
                  | E.Tensor _ => default
                  | E.Const _ => default
                  | E.ConstR _ => default
                  | E.Neg e1 => handleNeg (e1, x)
                  | E.Sqrt e1 => handleSqrt (y, e1, params, index, args, fieldset, flag)
                  | E.Cosine e1 => handleCosine (y, e1, params, index, args, fieldset, flag)
                  | E.ArcCosine e1 => handleArcCosine (y, e1, params, index, args, fieldset, flag)
                  | E.Sine e1 => handleSine (y, e1, params, index, args, fieldset, flag)
                  | E.ArcSine e1 => handleArcSine (y, e1, params, index, args, fieldset, flag)
                  | E.PowInt _ => err(" PowInt unsupported")
                  | E.PowReal _ => err(" PowReal unsupported")
                  | E.Sub(e1, e2) => handleSub (y, e1, e2, params, index, args, fieldset, flag)
                  | E.Div(e1, e2) => handleDiv (y, e1, e2, params, index, args, fieldset, flag)
                  | E.Sum(_, E.Probe(E.Conv _, _)) => default
                  | E.Sum(sx, E.Prod es) => handleSumProd (y, es, params, index, sx, args, fieldset, flag)
                  | E.Sum(sx, E.Delta d) =>
                      handleSumProd (y, [E.Delta d], params, index, sx, args, fieldset, flag)
                  | E.Sum(_, E.Tensor _) => default
                  | E.Sum _ => err(" summation not distributed:" ^ str ())
                  | E.Add es => handleAdd (y, es, params, index, args, fieldset, flag)
                  (* s0 * v_i * s2  ==>  (s0 * s2) * v_i : the inner scalar product is
                   * an operation, so handleProd lifts it into its own binding
                   *)
                  | E.Prod[E.Tensor(id0, []), E.Tensor(id1, [i]), E.Tensor(id2, [])] =>
                      rewrite (E.Prod[E.Prod[E.Tensor(id0, []), E.Tensor(id2, [])], E.Tensor(id1, [i])])
                  | E.Prod es => handleProd (y, es, params, index, args, fieldset, flag)
                  | E.Partial _ => err(" Partial used after normalize")
                  | E.Krn _ => err("Krn used before expand")
                  | E.Value _ => err("Value used before expand")
                  | E.Img _ => err("Img used before expand")
                (* end case *))
          val (einapp', newbies, fieldset') = rewrite body
          in
            ((einapp', newbies), fieldset')
          end
      | split ((y, app), fieldset, _) = (((y, app), []), fieldset)

    (* iterMultiple (einapp, newbies, fieldset) : recursively splits every binding in
     * newbies (and everything lifted from them) and returns all resulting bindings,
     * followed by einapp.
     * NOTE(review): the fieldset returned by split is discarded, so sharing of lifted
     * applications only happens within a single split.  Threading it naively would let
     * a later sibling reuse a variable whose definition is emitted after it (sibling
     * code is prepended) -- confirm before changing.
     *)
    fun iterMultiple (einapp, newbies, fieldset) = let
          fun itercode ([], rest, code) = (rest, code)
            | itercode (e1 :: es, rest, code) = let
                val ((e1', lifted), _) = split (e1, fieldset, numFlag)
                val (liftedRest, liftedCode) = itercode (lifted, [], [])
                in
                  itercode (es, rest @ [e1'], liftedCode @ liftedRest @ code)
                end
          val (rest, code) = itercode (newbies, [], [])
          in
            code @ rest @ [einapp]
          end

    (* iterAll (bindings, fieldset) : splits every binding and, recursively, every
     * binding lifted out of it.  Lifted bindings are placed before the bindings that
     * use them; the split forms of the input bindings come last.  The same
     * fieldset-threading note as for iterMultiple applies.
     *)
    fun iterAll (bindings, fieldset) = let
          fun itercode ([], rest, code) = (rest, code)
            | itercode (e1 :: es, rest, code) = let
                val ((e1', lifted), _) = split (e1, fieldset, numFlag)
                val (liftedRest, liftedCode) = itercode (lifted, [], [])
                val _ = if testing = 0 then 1
                      else testp [toStringBind(e1), "\n\t===>\n", toStringBind(e1'), "\nand\n",
                                  (String.concatWith ",\n\t" (List.map toStringBind (liftedCode @ liftedRest)))]
                in
                  itercode (es, rest @ [e1'], liftedCode @ liftedRest @ code)
                end
          val (rest, code) = itercode (bindings, [], [])
          in
            code @ rest
          end

    (* splitEinApp binding : entry point.  Splits binding completely ("all at once"
     * via iterAll, starting from an empty fieldset) and returns the resulting list of
     * bindings.  An older strategy (split once, then iterMultiple over the lifted
     * bindings) was kept commented out in earlier revisions.
     *)
    fun splitEinApp einapp = iterAll ([einapp], einSet.EinSet.empty)

    end; (* local *)

  end (* Split *)

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0