Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] Annotation of /branches/charisee_dev/src/compiler/high-to-mid/split.sml
ViewVC logotype

Annotation of /branches/charisee_dev/src/compiler/high-to-mid/split.sml

Parent Directory Parent Directory | Revision Log Revision Log


Revision 2847 - (view) (download)
Original Path: branches/charisee/src/compiler/high-to-mid/split.sml

1 : cchiw 2843 (* Currently under construction
2 : cchiw 2838 *
3 :     * COPYRIGHT (c) 2012 The Diderot Project (http://diderot-language.cs.uchicago.edu)
4 :     * All rights reserved.
5 :     *)
6 : cchiw 2843
7 :     (*
8 :     During the transition from high-IL to mid-IL, complicated EIN expressions are split into simpler ones in order to better identify methods for code generation and common subexpressions. Combining EIN operators in the optimization phase can lead to large and complicated EIN operators. A general code generator would need to expand every operation to work on scalars, which could miss the opportunity for vectorization and lead to poor code generation. Instead, every EIN operator is split into a set of simple EIN operators. Each EIN expression then only has one operation working on constants, tensors, deltas, epsilons, images and kernels.
9 :    
10 :     (1) When the outer EIN operator is one of negation ($-$), $+$, $-$, $*$, $/$, or $\sum$, analyze each subexpression to see whether it needs to be rewritten.
11 :    
12 :     (1a) When a subexpression is a field expression ($\circledast$, $\nabla$), it becomes 0. When it is another operation (probe $@$, negation, $+$, $-$, $*$, $/$, or $\sum$), we lift that subexpression into a new EIN operator and replace it with a tensor expression that represents its shape.
13 :    
14 :     (1b) Call cleanIndex.sml to clean the indices in the subexpression, and get the shape for the tensor replacement.
15 :    
16 :     (1c) Call cleanParams.sml to clean the params in the subexpression.
17 :     *)
18 : cchiw 2838
(* Split: breaks a MidIL EINAPP binding into a sequence of simpler EINAPP
 * bindings, lifting nested operators into fresh variables (see the header
 * comment of this file for the overall scheme).
 *)
structure Split = struct

  local

    structure E = Ein
    structure DstIL = MidIL
    structure DstTy = MidILTypes
    structure DstV = DstIL.Var
    structure P = Printer
    structure cleanP = cleanParams
    structure cleanI = cleanIndex
    structure handleE = handleEin

  in

  (* Debug switch: any nonzero value makes testp and gettest print tracing output. *)
    val testing = 0

  (* setEin: params*index*body -> Ein.ein
   * Packs the three components into an Ein.EIN record.
   *)
    fun setEin (params, index, body) = Ein.EIN{params=params, index=index, body=body}

  (* assignEinApp: var*params*index*body*args -> var*rhs
   * Builds the binding  y = EINAPP(EIN{params, index, body}, args).
   *)
    fun assignEinApp (y, params, index, body, args) =
          (y, DstIL.EINAPP(setEin(params, index, body), args))

  (* The scalar constant 0 as an EINAPP: no params, no index, no arguments. *)
    val einappzero = DstIL.EINAPP(setEin([], [], E.Const 0), [])

  (* setEinZero: var -> var*rhs
   * Binds y to the scalar zero EINAPP.
   *)
    fun setEinZero y = (y, einappzero)

  (* Thin local wrappers around helpers from other modules. *)
    fun cleanParams e = cleanP.cleanParams e
    fun cleanIndex e = cleanI.cleanIndex e
    fun printEINAPP e = MidToString.printEINAPP e
    fun isZero e = handleE.isZero e
    fun itos i = Int.toString i
    fun err str = raise Fail str

  (* Counter used by genName to make generated variable names unique. *)
    val cnt = ref 0

  (* genName: string -> string
   * Returns prefix ^ "_" ^ n, where n is the current counter value, and
   * increments the counter.
   *)
    fun genName prefix = let
          val n = !cnt
          in
            cnt := n + 1;
            String.concat[prefix, "_", Int.toString n]
          end

  (* testp: string list -> int
   * When testing is nonzero, prints the concatenation of the strings.
   * Always returns 1.
   *)
    fun testp n = (case testing
           of 0 => 1
            | _ => (print(String.concat n); 1)
          (* end case *))

  (* lift: string*ein_exp*params*index*sum_id list*args -> ein_exp*params*args*code
   * Lifts subexpression e out into its own EIN operator bound to a fresh
   * variable M, whose name is derived from `name`.
   *   - cleanIndex(e, index, sx) yields the replacement tensor's shape (tshape),
   *     the sizes used for M's type, and the re-indexed body.
   *   - A new E.TEN parameter with id = length params is appended to params,
   *     and M is appended to args.
   *   - cleanParams builds M's binding. It is handed the extended params/args;
   *     NOTE(review): presumably cleanParams drops entries the body does not use.
   * Returns the replacement E.Tensor(id, tshape), the extended params and args,
   * and M's binding as a one-element code list.
   *)
    fun lift (name, e, params, index, sx, args) = let
          val (tshape, sizes, body) = cleanIndex(e, index, sx)
          val id = length params
          val Rparams = params @ [E.TEN(1, sizes)]
          val Re = E.Tensor(id, tshape)
          val M = DstV.new (genName (name ^ "_l_" ^ itos id), DstTy.TensorTy sizes)
          val Rargs = args @ [M]
          val einapp = cleanParams(M, body, Rparams, sizes, Rargs)
          in
            (Re, Rparams, Rargs, [einapp])
          end

  (* isOp: ein_exp -> int
   * Classifies a subexpression of an outer operator (as consumed by rewriteOp):
   *   0 - field-level term (Field, Conv, Apply, Lift): becomes E.Const 0
   *   1 - operator (Neg, Add, Sub, Prod, Div, Sum, Probe): lifted out by lift()
   *   2 - anything else: kept unchanged
   * Raises Fail for Partial, Krn, Value, and Img, which must not appear here.
   *)
    fun isOp e = (case e
           of E.Field _ => 0
            | E.Conv _ => 0
            | E.Apply _ => 0
            | E.Lift _ => 0
            | E.Neg _ => 1
            | E.Add _ => 1
            | E.Sub _ => 1
            | E.Prod _ => 1
            | E.Div _ => 1
            | E.Sum _ => 1
            | E.Probe _ => 1
            | E.Partial _ => err(" Partial used after normalize")
            | E.Krn _ => err("Krn used before expand")
            | E.Value _ => err("Value used before expand")
            | E.Img _ => err("Img used before expand")
            | _ => 2
          (* end case *))

  (* rewriteOp: string*ein_exp*params*index*sum_id list*args -> ein_exp*params*args*code
   * Rewrites one subexpression according to isOp.
   *   0 - replaced by 0
   *   1 - lifted into a new binding (params, args, and code grow)
   *   2 - returned unchanged
   *)
    fun rewriteOp (name, e1, params, index, sx, args) = (case isOp e1
           of 0 => (E.Const 0, params, args, [])
            | 2 => (e1, params, args, [])
            | _ => lift(name, e1, params, index, sx, args)
          (* end case *))

  (* rewriteOps: string*ein_exp list*params*index*sum_id list*args
   *              -> ein_exp list*params*args*code
   * Applies rewriteOp to each expression left to right. params and args are
   * threaded through the calls, so parameter ids assigned by lift stay
   * consistent. Both the rewritten expressions and the generated code keep
   * source order.
   *)
    fun rewriteOps (name, list1, params, index, sx, args) = let
          fun step (e1, (revExps, params, args, revCode)) = let
                val (e1', params', args', code') = rewriteOp(name, e1, params, index, sx, args)
                in
                  (e1' :: revExps, params', args', List.revAppend(code', revCode))
                end
          val (revExps, params', args', revCode) = List.foldl step ([], params, args, []) list1
          in
            (List.rev revExps, params', args', List.rev revCode)
          end

  (* rewriteOrig: var*ein_exp*params*index*sum_id list*args -> var*rhs
   * Builds the binding for y from the rewritten body.
   *   - If isZero reports the body as zero (result 1), y is bound to scalar 0.
   *   - Otherwise cleanParams builds the binding.
   * sx is unused; it is kept so all handle* callers share one shape.
   *)
    fun rewriteOrig (y, body, params, index, sx, args) = (case isZero body
           of 1 => setEinZero y
            | _ => cleanParams(y, body, params, index, args)
          (* end case *))

  (* handleNeg: var*ein_exp*params*index*args -> (var*rhs)*code
   * y = -e1: rewrites e1 with rewriteOp, then rebuilds the negation.
   *)
    fun handleNeg (y, e1, params, index, args) = let
          val (e1', params', args', code) = rewriteOp(DstV.name y, e1, params, index, [], args)
          val body = E.Neg e1'
          val einapp = rewriteOrig(y, body, params', index, [], args')
          in
            (einapp, code)
          end

  (* handleSub: var*ein_exp*ein_exp*params*index*args -> (var*rhs)*code
   * y = e1 - e2: rewrites both operands with rewriteOps, then rebuilds the
   * subtraction.
   *)
    fun handleSub (y, e1, e2, params, index, args) = let
          val ([e1', e2'], params', args', code) =
                rewriteOps(DstV.name y, [e1, e2], params, index, [], args)
          val body = E.Sub(e1', e2')
          val einapp = rewriteOrig(y, body, params', index, [], args')
          in
            (einapp, code)
          end

  (* handleDiv: var*ein_exp*ein_exp*params*index*args -> (var*rhs)*code
   * y = e1 / e2: rewrites the numerator with the outer index and the
   * denominator with an empty index. NOTE(review): the empty index presumably
   * reflects that the denominator is scalar; confirm.
   *)
    fun handleDiv (y, e1, e2, params, index, args) = let
          val (e1', params1', args1', code1') = rewriteOp(DstV.name y, e1, params, index, [], args)
          val (e2', params2', args2', code2') = rewriteOp(DstV.name y, e2, params1', [], [], args1')
          val body = E.Div(e1', e2')
          val einapp = rewriteOrig(y, body, params2', index, [], args2')
          in
            (einapp, code1' @ code2')
          end

  (* handleAdd: var*ein_exp list*params*index*args -> (var*rhs)*code
   * y = e_1 + ... + e_n: rewrites every summand, then rebuilds the addition.
   *)
    fun handleAdd (y, e1, params, index, args) = let
          val (e1', params', args', code) = rewriteOps(DstV.name y, e1, params, index, [], args)
          val body = E.Add e1'
          val einapp = rewriteOrig(y, body, params', index, [], args')
          in
            (einapp, code)
          end

  (* handleProd: var*ein_exp list*params*index*args -> (var*rhs)*code
   * y = e_1 * ... * e_n: rewrites every factor, then rebuilds the product.
   *)
    fun handleProd (y, e1, params, index, args) = let
          val (e1', params', args', code) = rewriteOps(DstV.name y, e1, params, index, [], args)
          val body = E.Prod e1'
          val einapp = rewriteOrig(y, body, params', index, [], args')
          in
            (einapp, code)
          end

  (* handleSumProd: var*ein_exp list*params*index*sum_id list*args -> (var*rhs)*code
   * y = Sum(sx, Prod es): rewrites every factor with the summation indices sx
   * available, then rebuilds Sum(sx, Prod ...).
   *)
    fun handleSumProd (y, e1, params, index, sx, args) = let
          val (e1', params', args', code) = rewriteOps(DstV.name y, e1, params, index, sx, args)
          val body = E.Sum(sx, E.Prod e1')
          val einapp = rewriteOrig(y, body, params', index, sx, args')
          in
            (einapp, code)
          end

  (* split: var*rhs -> (var*rhs)*code
   * Splits one binding y = EINAPP(EIN{params, index, body}, args) one level
   * deep. Returns the replacement binding for y and the new bindings created
   * for lifted subexpressions.
   *   - Bodies that become scalar zero: Conv, Field, Apply, Lift, and Sum over
   *     Conv.
   *   - Bodies returned unchanged:
   *       * Delta, Epsilon, Eps2, Tensor, Const;
   *       * a probe of a convolution, either alone or under a Sum (optionally
   *         multiplied by an Epsilon/Eps2);
   *       * any other Sum form not listed below.
   *   - Neg/Sub/Div/Add/Prod and Sum-of-Prod go to the matching handle*
   *     function.
   *   - Other Sums are distributed over Neg/Add/Sub/Div, or nested Sums are
   *     flattened, and the result is split again.
   *   - A Probe of anything other than a convolution raises Fail.
   *   - Partial/Krn/Value/Img raise Fail.
   *   - A right-hand side that is not an EINAPP is returned unchanged.
   *)
    fun split (y, einapp as DstIL.EINAPP(Ein.EIN{params, index, body}, args)) = let
          val zero = (setEinZero y, [])
          val default = ((y, einapp), [])
          val str = "Poorly formed EIN operator. Argument needs to be applied in High-IL"
          fun rewrite b = (case b
                 of E.Probe(E.Conv _, _) => default
                  | E.Probe _ => raise Fail str
                  | E.Conv _ => zero
                  | E.Field _ => zero
                  | E.Apply _ => zero
                  | E.Lift _ => zero
                  | E.Delta _ => default
                  | E.Epsilon _ => default
                  | E.Eps2 _ => default
                  | E.Tensor _ => default
                  | E.Const _ => default
                  | E.Neg e1 => handleNeg(y, e1, params, index, args)
                  | E.Sub(e1, e2) => handleSub(y, e1, e2, params, index, args)
                  | E.Div(e1, e2) => handleDiv(y, e1, e2, params, index, args)
                  | E.Sum(_, E.Prod[E.Eps2 _, E.Probe(E.Conv _, _)]) => default
                  | E.Sum(_, E.Prod[E.Epsilon _, E.Probe(E.Conv _, _)]) => default
                  | E.Sum(_, E.Probe(E.Conv _, _)) => default
                  | E.Sum(_, E.Conv _) => zero
                  | E.Sum(sx, E.Prod e1) => handleSumProd(y, e1, params, index, sx, args)
                  | E.Sum(sx, E.Neg n) => rewrite (E.Neg(E.Sum(sx, n)))
                  | E.Sum(sx, E.Add a) => rewrite (E.Add(List.map (fn e => E.Sum(sx, e)) a))
                  | E.Sum(sx, E.Sub(e1, e2)) => rewrite (E.Sub(E.Sum(sx, e1), E.Sum(sx, e2)))
                  (* NOTE(review): Sum(sx, e1/e2) = Sum(sx,e1) / Sum(sx,e2) is not an
                   * algebraic identity. When e2 does not mention sx, the right-hand
                   * side divides by (range of sx) * e2 instead of e2. Verify against
                   * the intended EIN semantics before relying on this rewrite.
                   *)
                  | E.Sum(sx, E.Div(e1, e2)) => rewrite (E.Div(E.Sum(sx, e1), E.Sum(sx, e2)))
                  | E.Sum(c1, E.Sum(c2, e)) => rewrite (E.Sum(c1 @ c2, e))
                  | E.Sum(_, _) => default
                  | E.Add e1 => handleAdd(y, e1, params, index, args)
                  | E.Prod e1 => handleProd(y, e1, params, index, args)
                  | E.Partial _ => err(" Partial used after normalize")
                  | E.Krn _ => err("Krn used before expand")
                  | E.Value _ => err("Value used before expand")
                  | E.Img _ => err("Img used before expand")
                (* end case *))
          in
            rewrite body
          end
      | split (y, app) = ((y, app), [])

  (* iterMultiple: (var*rhs)*code -> (var*rhs)*code
   * Splits every lifted binding in newbies2, and recursively the bindings
   * those splits produce. In the returned code list, the bindings produced
   * while splitting a lifted expression come before that expression's own
   * (split) binding. einapp2 is passed through unchanged.
   *)
    fun iterMultiple (einapp2, newbies2) = let
          fun itercode ([], rest, code) = (rest, code)
            | itercode (e1 :: newbies, rest, code) = let
                val (einapp3, code3) = split e1
                val (rest4, code4) = itercode(code3, [], [])
                in
                  itercode(newbies, rest @ [einapp3], code4 @ rest4 @ code)
                end
          val (rest, code) = itercode(newbies2, [], [])
          in
            (einapp2, code @ rest)
          end

  (* iterSplit: var*rhs -> (var*rhs)*code
   * Entry point for splitting one binding.
   *   1. Cleans the indices of the body with cleanIndex.
   *   2. Splits the result once.
   *   3. Recursively splits all lifted bindings.
   * A right-hand side that is not an EINAPP is returned unchanged, mirroring
   * split.
   *)
    fun iterSplit (y, DstIL.EINAPP(Ein.EIN{params, index, body}, args)) = let
          val (_, _, body') = cleanIndex(body, index, [])
          val einapp1 = assignEinApp(y, params, index, body', args)
          val _ = testp["\n rewriten einapp\n \t", printEINAPP einapp1]
          val (einapp2, newbies2) = split einapp1
          in
            iterMultiple(einapp2, newbies2)
          end
      | iterSplit (y, app) = ((y, app), [])

  (* gettest: var*rhs -> (var*rhs)*code
   * Same result as iterSplit. When testing is nonzero, also prints the input
   * binding, the split result, and the newly created bindings.
   *)
    fun gettest einapp = (case testing
           of 0 => iterSplit einapp
            | _ => let
                val star = "\n************* SPLIT********\n"
                val _ = print(String.concat[star, "\n", "start get test", printEINAPP einapp])
                val (einapp2, newbies) = iterSplit einapp
                val a = printEINAPP einapp2
                val b = String.concatWith ",\n\t" (List.map printEINAPP newbies)
                val _ = print(String.concat[printEINAPP einapp, "=>", a, " newbies\n\t", b, "\n", a, star])
                in
                  (einapp2, newbies)
                end
          (* end case *))

  end (* local *)

end (* structure Split *)

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0