Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] Annotation of /branches/charisee/src/compiler/high-to-mid/split.sml
ViewVC logotype

Annotation of /branches/charisee/src/compiler/high-to-mid/split.sml

Parent Directory Parent Directory | Revision Log Revision Log


Revision 2843 - (view) (download)

1 : cchiw 2843 (* Currently under construction
2 : cchiw 2838 *
3 :     * COPYRIGHT (c) 2012 The Diderot Project (http://diderot-language.cs.uchicago.edu)
4 :     * All rights reserved.
5 :     *)
6 : cchiw 2843
7 :     (*
8 :     During the transition from high-IL to mid-IL, complicated EIN expressions are split into simpler ones in order to better identify methods for code generation and common subexpressions. Combining EIN operators in the optimization phase can lead to large and complicated EIN operators. A general code generator would need to expand every operation to work on scalars, which could miss the opportunity for vectorization and lead to poor code generation. Instead, every EIN operator is split into a set of simple EIN operators. Each EIN expression then only has one operation working on constants, tensors, deltas, epsilons, images and kernels.
9 :    
10 :     (1) When the outer EIN operator is one of $\{--, +, -, *, /, \sum\}$, analyze each of its subexpressions to see whether it needs to be rewritten.
11 :    
12 :     (1a) When a subexpression is a field expression ($\circledast$, $\nabla$), it becomes 0. When it is another operation ($@$, $--$, $+$, $-$, $*$, $/$, $\sum$), we lift that subexpression into a new EIN operator and replace it with a tensor expression that represents its size.
13 :    
14 :     (1b) Call cleanIndex.sml to clean the indices in the subexpression, and get the shape for the tensor replacement.
15 :    
16 :     (1c) Call cleanParams.sml to clean the params in the subexpression.
17 :     (2) In the original EIN operator, every lifted subexpression is replaced with a tensor and every non-probed field with zero. Call isZero() to determine whether the resulting body is zero; if so, return 0. Otherwise, clean the EIN operator.
18 :    
19 :     *)
20 : cchiw 2838
structure Split = struct

    local

    structure E = Ein
    structure mk = mkOperators
    structure SrcIL = HighIL
    structure SrcTy = HighILTypes
    structure SrcOp = HighOps
    structure SrcSV = SrcIL.StateVar
    structure VTbl = SrcIL.Var.Tbl
    structure DstIL = MidIL
    structure DstTy = MidILTypes
    structure DstOp = MidOps
    structure DstV = DstIL.Var
    structure SrcV = SrcIL.Var
    structure P = Printer
    structure F = Filter
    structure T = TransformEin
    structure Var = MidIL.Var
    structure cleanP = cleanParams
    structure cleanI = cleanIndex

    (* Debug switch: 0 disables the printing done by testp and gettest;
     * any other value makes them print to stdout.
     *)
    val testing = 1

    in

    (* setEin: params*index*body -> Ein.ein
     * Packs the three components into an Ein.EIN record.
     *)
    fun setEin (params, index, body) = Ein.EIN{params=params, index=index, body=body}

    (* assignEinApp: var*params*index*body*args -> var*rhs
     * Pairs the destination variable y with an EINAPP of the given operator.
     *)
    fun assignEinApp (y, params, index, body, args) =
          (y, DstIL.EINAPP(setEin(params, index, body), args))

    (* The constant-zero EIN application: no params, no index, body 0, no args. *)
    val einappzero = DstIL.EINAPP(setEin([], [], E.Const 0), [])

    (* setEinZero: var -> var*rhs
     * Binds y to the constant-zero EIN application.
     *)
    fun setEinZero y = (y, einappzero)

    (* Thin wrappers around the cleanParams and cleanIndex structures. *)
    fun cleanParams e = cleanP.cleanParams e
    fun cleanIndex e = cleanI.cleanIndex e

    fun itos i = Int.toString i

    (* err: string -> 'a -- raises Fail with the given message *)
    fun err str = raise Fail str

    (* Counter used by genName to make generated variable names unique. *)
    val cnt = ref 0

    (* genName: string -> string
     * Returns prefix ^ "_" ^ n, where n is the current counter value,
     * and increments the counter.
     *)
    fun genName prefix = let
          val n = !cnt
          in
            cnt := n+1;
            String.concat[prefix, "_", Int.toString n]
          end

    (* testp: string list -> int
     * Prints the concatenated strings when testing is nonzero.
     * Always returns 1.
     *)
    fun testp n = (case testing
           of 0 => 1
            | _ => (print(String.concat n); 1)
          (* end case *))

    (* printEINAPP: var*rhs -> string
     * Debug rendering of an assignment, terminated by a newline:
     *   EINAPP -> "ty<y> ==" followed by the EIN operator and the args
     *   OP     -> "ty<y> =" followed by the operator and the args
     *   other  -> "y<ty> non-einapp"
     * Arguments are joined with " , ".
     * NOTE(review): the fallback prints y<ty> rather than ty<y> like the
     * other two cases -- confirm whether that is intended.
     *)
    fun printEINAPP (id, DstIL.EINAPP(rator, args)) = let
          val a = String.concatWith " , " (List.map Var.toString args)
          in
            String.concat([(DstTy.toString (Var.ty id)), "<", Var.toString id, "> ==", P.printerE rator, a, "\n"])
          end
      | printEINAPP (id, DstIL.OP(rator, args)) = let
          val a = String.concatWith " , " (List.map Var.toString args)
          in
            String.concat([(DstTy.toString (Var.ty id)), "<", Var.toString id, "> =", DstOp.toString rator, a, "\n"])
          end
      | printEINAPP (id, _) = String.concat([Var.toString id, "<", (DstTy.toString (Var.ty id)), "> non-einapp\n"])

    (* lift: ein_exp*params*index*sum_id list*args -> ein_exp*params*args*code
     * Lifts subexpression e out of its enclosing EIN operator:
     *   - cleanIndex(e,index,sx) gives the replacement tensor's shape,
     *     its sizes, and the re-indexed body;
     *   - a fresh mid-IL variable M of type TensorTy sizes is created;
     *   - a new param E.TEN(1,sizes), at position length params, and the
     *     arg M are appended;
     *   - cleanParams builds the assignment that defines M from the body.
     * Returns the replacement E.Tensor(id,tshape), the extended params and
     * args, and the one new assignment as code.
     *)
    fun lift (e, params, index, sx, args) = let
          val (tshape, sizes, body) = cleanIndex(e, index, sx)
          val id = length params
          val Rparams = params @ [E.TEN(1, sizes)]
          val Re = E.Tensor(id, tshape)
          val M = DstV.new (genName ("TLifted_" ^ itos id), DstTy.TensorTy sizes)
          val Rargs = args @ [M]
          val einapp = cleanParams(M, body, Rparams, sizes, Rargs)
          in
            (Re, Rparams, Rargs, [einapp])
          end

    (* isOp: ein_exp -> int
     * Classifies a subexpression of the operator being split:
     *   0 - field-level term (Field, Conv, Apply, Lift): replaced by zero
     *   1 - operator (Neg, Add, Sub, Prod, Div, Sum, Probe): lifted out
     *   2 - anything else: kept in place
     * Raises Fail on Partial, Krn, Value and Img, which are not expected here.
     *)
    fun isOp e = (case e
           of E.Field _ => 0
            | E.Conv _ => 0
            | E.Apply _ => 0
            | E.Lift _ => 0
            | E.Neg _ => 1
            | E.Add _ => 1
            | E.Sub _ => 1
            | E.Prod _ => 1
            | E.Div _ => 1
            | E.Sum _ => 1
            | E.Probe _ => 1
            | E.Partial _ => err(" Partial used after normalize")
            | E.Krn _ => err("Krn used before expand")
            | E.Value _ => err("Value used before expand")
            | E.Img _ => err("Probe used before expand")
            | _ => 2
          (* end case *))

    (* rewriteOp: ein_exp*params*index*sum_id list*args -> ein_exp*params*args*code
     * Rewrites one subexpression according to isOp:
     *   - zero-class terms become E.Const 0;
     *   - kept terms are returned unchanged;
     *   - operators are lifted.
     * Only lifting extends params/args and produces code.
     *)
    fun rewriteOp (e1, params, index, sx, args) = (case isOp e1
           of 0 => (E.Const 0, params, args, [])
            | 2 => (e1, params, args, [])
            | _ => lift(e1, params, index, sx, args)
          (* end case *))

    (* rewriteOps: ein_exp list*params*index*sum_id list*args
     *               -> ein_exp list*params*args*code
     * Applies rewriteOp left to right. Params and args are threaded from
     * one call to the next. Both the rewritten expressions and the
     * generated code keep the input order.
     *)
    fun rewriteOps (list1, params, index, sx, args) = let
          fun step (e1, (revExps, ps, xs, revCode)) = let
                val (e1', ps', xs', code') = rewriteOp(e1, ps, index, sx, xs)
                in
                  (e1' :: revExps, ps', xs', code' :: revCode)
                end
          val (revExps, params', args', revCode) = List.foldl step ([], params, args, []) list1
          in
            (List.rev revExps, params', args', List.concat (List.rev revCode))
          end

    (* isZero: var*ein_exp*params*index*sum_id list*args -> var*rhs
     * If cleanP.isZero reports the body as zero (returns 1), y is bound to
     * the constant-zero EIN application. Otherwise cleanParams builds y's
     * assignment. sx is currently unused.
     *)
    fun isZero (y, body, params, index, sx, args) = (case cleanP.isZero body
           of 1 => setEinZero y
            | _ => cleanParams(y, body, params, index, args)
          (* end case *))

    (* handleNeg: var*ein_exp*params*index*args -> (var*rhs)*code
     * Rewrites the operand of a negation, then rebuilds y's assignment.
     *)
    fun handleNeg (y, e1, params, index, args) = let
          val (e1', params', args', code) = rewriteOp(e1, params, index, [], args)
          val body = E.Neg e1'
          val einapp = isZero(y, body, params', index, [], args')
          in
            (einapp, code)
          end

    (* handleSub: var*ein_exp*ein_exp*params*index*args -> (var*rhs)*code
     * Rewrites both operands in order, threading params/args, then rebuilds
     * y's assignment. Two explicit rewriteOp calls avoid a nonexhaustive
     * list-pattern binding.
     *)
    fun handleSub (y, e1, e2, params, index, args) = let
          val (e1', params1, args1, code1) = rewriteOp(e1, params, index, [], args)
          val (e2', params2, args2, code2) = rewriteOp(e2, params1, index, [], args1)
          val body = E.Sub(e1', e2')
          val einapp = isZero(y, body, params2, index, [], args2)
          in
            (einapp, code1 @ code2)
          end

    (* handleDiv: var*ein_exp*ein_exp*params*index*args -> (var*rhs)*code
     * Rewrites numerator and denominator, then rebuilds y's assignment.
     * NOTE(review): the denominator is rewritten with an empty index ([]),
     * presumably because it is expected to be a scalar -- confirm.
     *)
    fun handleDiv (y, e1, e2, params, index, args) = let
          val (e1', params1', args1', code1') = rewriteOp(e1, params, index, [], args)
          val (e2', params2', args2', code2') = rewriteOp(e2, params1', [], [], args1')
          val body = E.Div(e1', e2')
          val einapp = isZero(y, body, params2', index, [], args2')
          in
            (einapp, code1' @ code2')
          end

    (* handleAdd: var*ein_exp list*params*index*args -> (var*rhs)*code
     * Rewrites every summand, then rebuilds y's assignment.
     *)
    fun handleAdd (y, e1, params, index, args) = let
          val (e1', params', args', code) = rewriteOps(e1, params, index, [], args)
          val body = E.Add e1'
          val einapp = isZero(y, body, params', index, [], args')
          in
            (einapp, code)
          end

    (* handleProd: var*ein_exp list*params*index*args -> (var*rhs)*code
     * Rewrites every factor, then rebuilds y's assignment.
     *)
    fun handleProd (y, e1, params, index, args) = let
          val (e1', params', args', code) = rewriteOps(e1, params, index, [], args)
          val body = E.Prod e1'
          val einapp = isZero(y, body, params', index, [], args')
          in
            (einapp, code)
          end

    (* handleSumProd: var*ein_exp list*params*index*sum_id list*args -> (var*rhs)*code
     * Rewrites the factors of a summed product, passing the summation
     * indices sx to the lifting, then rebuilds y's assignment as Sum(sx,Prod ...).
     *)
    fun handleSumProd (y, e1, params, index, sx, args) = let
          val (e1', params', args', code) = rewriteOps(e1, params, index, sx, args)
          val body = E.Sum(sx, E.Prod e1')
          val einapp = isZero(y, body, params', index, sx, args')
          in
            (einapp, code)
          end

    (* split: var*rhs -> (var*rhs)*code
     * Splits one level of an EINAPP. The outer operator stays in y's
     * assignment; operator subexpressions are lifted into new assignments,
     * which are returned as code.
     *   - Probes, and sums around probes, are left as they are.
     *   - Conv/Field/Apply/Lift, and sums of Conv, become zero.
     *   - Sums over Neg/Add/Sub/Div are pushed inward, and nested sums are
     *     merged, before splitting.
     * Non-EINAPP right-hand sides are returned unchanged with no code.
     *)
    fun split (y, einapp as DstIL.EINAPP(Ein.EIN{params, index, body}, args)) = let
          val zero = (setEinZero y, [])
          val default = ((y, einapp), [])
          fun rewrite b = (case b
                 of E.Probe _ => default
                  | E.Conv _ => zero
                  | E.Field _ => zero
                  | E.Apply _ => zero
                  | E.Lift _ => zero
                  | E.Delta _ => default
                  | E.Epsilon _ => default
                  | E.Eps2 _ => default
                  | E.Tensor _ => default
                  | E.Const _ => default
                  | E.Neg e1 => handleNeg(y, e1, params, index, args)
                  | E.Sub(e1, e2) => handleSub(y, e1, e2, params, index, args)
                  | E.Div(e1, e2) => handleDiv(y, e1, e2, params, index, args)
                  | E.Sum(_, E.Prod[E.Eps2 _, E.Probe _]) => default
                  | E.Sum(_, E.Prod[E.Epsilon _, E.Probe _]) => default
                  | E.Sum(_, E.Probe _) => default
                  | E.Sum(_, E.Conv _) => zero
                  | E.Sum(sx, E.Prod e1) => handleSumProd(y, e1, params, index, sx, args)
                  | E.Sum(sx, E.Neg n) => rewrite (E.Neg(E.Sum(sx, n)))
                  | E.Sum(sx, E.Add a) => rewrite (E.Add(List.map (fn e => E.Sum(sx, e)) a))
                  | E.Sum(sx, E.Sub(e1, e2)) => rewrite (E.Sub(E.Sum(sx, e1), E.Sum(sx, e2)))
                  (* NOTE(review): this sums both numerator and denominator.
                   * In general sum_i (a/b) = (sum_i a)/b, not (sum_i a)/(sum_i b),
                   * unless later passes drop sums over indices b does not use.
                   * Confirm before relying on it.
                   *)
                  | E.Sum(sx, E.Div(e1, e2)) => rewrite (E.Div(E.Sum(sx, e1), E.Sum(sx, e2)))
                  | E.Sum(c1, E.Sum(c2, e)) => rewrite (E.Sum(c1 @ c2, e))
                  | E.Sum _ => default
                  | E.Add e1 => handleAdd(y, e1, params, index, args)
                  | E.Prod e1 => handleProd(y, e1, params, index, args)
                  | E.Partial _ => err(" Partial used after normalize")
                  | E.Krn _ => err("Krn used before expand")
                  | E.Value _ => err("Value used before expand")
                  | E.Img _ => err("Probe used before expand")
                (* end case *))
          in
            rewrite body
          end
      | split (y, app) = ((y, app), [])

    (* iterMultiple: (var*rhs)*code -> (var*rhs)*code
     * Recursively splits each assignment in newbies2. For each element:
     *   - split is applied to it;
     *   - the assignments that split produces are split recursively;
     *   - their fully split results are placed ahead of the element's own
     *     split assignment.
     * The result is (einapp2, code @ rest):
     *   - rest holds the split form of each element of newbies2, in order;
     *   - code holds everything generated while splitting them, with later
     *     elements' code first.
     * Presumably this ordering keeps definitions before uses -- confirm.
     *)
    fun iterMultiple (einapp2, newbies2) = let
          fun itercode ([], rest, code) = (rest, code)
            | itercode (e1::newbies, rest, code) = let
                val (einapp3, code3) = split(e1)
                val (rest4, code4) = itercode(code3, [], [])
                in
                  itercode(newbies, rest @ [einapp3], code4 @ rest4 @ code)
                end
          val (rest, code) = itercode(newbies2, [], [])
          in
            (einapp2, code @ rest)
          end

    (* iterSplit: var*rhs -> (var*rhs)*code
     * Entry point.
     *   - For an EINAPP: cleans the body's indices, splits the rebuilt
     *     assignment, then recursively splits everything that was lifted out.
     *   - Any other right-hand side: returned unchanged with no code,
     *     matching split's fallback (previously this raised Match).
     *)
    fun iterSplit (y, DstIL.EINAPP(Ein.EIN{params, index, body}, args)) = let
          val (_, _, body') = cleanIndex(body, index, [])
          val einapp1 = assignEinApp(y, params, index, body', args)
          val (einapp2, newbies2) = split einapp1
          in
            iterMultiple(einapp2, newbies2)
          end
      | iterSplit (y, app) = ((y, app), [])

    (* gettest: var*rhs -> (var*rhs)*code
     * Runs iterSplit.
     * When testing is nonzero, it also prints to stdout:
     *   - the input,
     *   - the split assignment (shown twice),
     *   - the new assignments,
     * framed by a line of stars.
     *)
    fun gettest einapp = (case testing
           of 0 => iterSplit einapp
            | _ => let
                val star = "\n*************\n"
                val _ = print(String.concat[star])
                val (einapp2, newbies) = iterSplit einapp
                val a = printEINAPP einapp2
                val b = String.concatWith ",\n\t" (List.map printEINAPP newbies)
                val _ = print(String.concat[printEINAPP einapp, "=>", a, " newbies\n\t", b, "\n", a, star])
                in
                  (einapp2, newbies)
                end
          (* end case *))

    end (* local *)

end (* Split *)

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0