Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] Annotation of /branches/charisee/src/compiler/high-to-mid/split.sml
ViewVC logotype

Annotation of /branches/charisee/src/compiler/high-to-mid/split.sml

Parent Directory Parent Directory | Revision Log Revision Log


Revision 3448 - (view) (download)

1 : cchiw 2843 (* Currently under construction
2 : cchiw 2838 *
3 : jhr 3349 * This code is part of the Diderot Project (http://diderot-language.cs.uchicago.edu)
4 :     *
5 :     * COPYRIGHT (c) 2015 The University of Chicago
6 : cchiw 2838 * All rights reserved.
7 :     *)
8 : cchiw 2843
9 :     (*
10 :     During the transition from high-IL to mid-IL, complicated EIN expressions are split into simpler ones in order to better identify methods for code generation and common subexpressions. Combining EIN operators in the optimization phase can lead to large and complicated EIN operators. A general code generator would need to expand every operation to work on scalars, which could miss the opportunity for vectorization and lead to poor code generation. Instead, every EIN operator is split into a set of simple EIN operators. Each EIN expression then only has one operation working on constants, tensors, deltas, epsilons, images and kernels.
11 :    
12 :     (1) When the outer EIN operator is one of $\{--, +, -, *, /, \sum\}$, analyze each subexpression to see whether it needs to be rewritten.
13 :    
14 :     (1a) When a subexpression is a field expression ($\circledast$, $\nabla$), it becomes 0. When it is another operation in $\{@, --, +, -, *, /, \sum\}$, we lift that subexpression into a new EIN operator and replace it with a tensor expression that represents its size.
15 :    
16 :     (1b) Call cleanIndex.sml to clean the indices in the subexpression, and get the shape for the tensor replacement.
17 :    
18 :     (1c) Call cleanParams.sml to clean the params in the subexpression.
19 :     *)
20 : cchiw 2838
(* Split: splits complicated EIN operators into a sequence of simpler ones
 * during the translation from high-IL to mid-IL.  Each lifted
 * subexpression is bound to a fresh MidIL variable and referenced from the
 * original operator through a new tensor parameter.
 *)
structure Split = struct

  local

    structure E = Ein
    structure DstIL = MidIL
    structure DstTy = MidILTypes
    structure DstV = DstIL.Var

    structure P = Printer
    structure cleanP = cleanParams
    structure cleanI = cleanIndex

  in

  (* When numFlag = 1, lift consults einSet.rtnVar so that a lifted
   * subexpression can reuse an existing variable instead of emitting new code
   * (common-subexpression removal).
   *)
    val numFlag = 1

  (* When testing is non-zero, testp prints its debugging messages. *)
    val testing = 0

    fun mkEin e = E.mkEin e

  (* EIN application whose body is the scalar constant 0, with no params and no args. *)
    val einappzero = DstIL.EINAPP(mkEin([], [], E.B(E.Const 0)), [])

  (* setEinZero y: binds y to the constant-zero EIN application. *)
    fun setEinZero y = (y, einappzero)

    fun cleanParams e = cleanP.cleanParams e
    fun cleanIndex e = cleanI.cleanIndex e
    fun toStringBind e = MidToString.toStringBind e
    fun itos i = Int.toString i
    fun err str = raise Fail str

  (* Counter used by genName to make variable names unique. *)
    val cnt = ref 0

  (* incUse v: increments the use count of MidIL variable v. *)
    fun incUse (DstIL.V{useCnt, ...}) = (useCnt := !useCnt + 1)

  (* genName prefix: returns prefix ^ "_" ^ n, where n is the current value of
   * cnt, and then advances cnt.
   *)
    fun genName prefix = let
          val n = !cnt
          in
            cnt := n + 1;
            String.concat[prefix, "_", Int.toString n]
          end

  (* testp strs: prints the concatenation of strs when testing is non-zero.
   * It always returns 1, so it can be used in `val _ = testp [...]` bindings.
   *)
    fun testp strs = (case testing
           of 0 => 1
            | _ => (print (String.concat strs); 1)
          (* end case *))

  (* lift: string * ein_exp * params * index * sum_id * args * fieldset * int
   *         -> ein_exp * params * args * code * fieldset
   *
   * Lifts subexpression e out of its parent operator:
   *   - cleans e's index (cleanIndex) to get the replacement tensor's shape
   *     and sizes;
   *   - appends a new tensor parameter E.TEN(1, sizes) to params;
   *   - creates a fresh variable M (named from `name`) and cleans the params
   *     of the lifted operator (cleanParams).
   * The returned replacement expression is E.Tensor(id, tshape), where id is
   * the new parameter's position.
   * When flag = 1 and einSet.rtnVar yields an existing variable v, the
   * argument list gets v instead of M, v's use count is bumped, and no new
   * code is returned.  Otherwise M is appended to args and the new binding is
   * returned as code.
   *)
    fun lift (name, e, params, index, sx, args, fieldset, flag) = let
          val (tshape, sizes, body) = cleanIndex(e, index, sx)
          val id = length params
          val Rparams = params @ [E.TEN(1, sizes)]
          val Re = E.Tensor(id, tshape)
          val M = DstV.new (genName (name ^ "_l_" ^ itos id), DstTy.TensorTy sizes)
          val einapp = cleanParams(M, body, Rparams, sizes, args @ [M])
          val (_, einapp0) = einapp
          val (Rargs, newbies, fieldset) = (case flag
                 of 1 => let
                      val (fieldset, var) = einSet.rtnVar(fieldset, M, einapp0)
                      in
                        (case var
                         of NONE => (args @ [M], [einapp], fieldset)
                          | SOME v => (incUse v; (args @ [v], [], fieldset))
                        (* end case *))
                      end
                  | _ => (args @ [M], [einapp], fieldset)
                (* end case *))
          in
            (Re, Rparams, Rargs, newbies, fieldset)
          end

  (* isOp: ein_exp -> int
   * Classifies how a subexpression is treated when it is split from its parent:
   *   0 - replaced by zero (field-level terms: Field, Conv, Apply, Lift);
   *   1 - lifted into its own operator (Op1, Op2, Opn, Sum, Probe);
   *   2 - kept in place (anything else).
   * Raises Fail for Partial, Krn, Value and Img, which should not occur here.
   *)
    fun isOp e = (case e
           of E.Field _ => 0
            | E.Conv _ => 0
            | E.Apply _ => 0
            | E.Lift _ => 0
            | E.Op1 _ => 1
            | E.Op2 _ => 1
            | E.Opn _ => 1
            | E.Sum _ => 1
            | E.Probe _ => 1
            | E.Partial _ => err(" Partial used after normalize")
            | E.Krn _ => err("Krn used before expand")
            | E.Value _ => err("Value used before expand")
            | E.Img _ => err("Img used before expand")
            | _ => 2
          (* end case *))

(* *************************************** helpers ******************************** *)

  (* rewriteOp: rewrites one subexpression e1 according to isOp.
   * Returns (replacement, params, args, new code, fieldset).
   *)
    fun rewriteOp (name, e1, params, index, sx, args, fieldset, flag) = (case isOp e1
           of 0 => (E.B(E.Const 0), params, args, [], fieldset)
            | 1 => lift(name, e1, params, index, sx, args, fieldset, flag)
            | _ => (e1, params, args, [], fieldset) (* not lifted *)
          (* end case *))

  (* unaryOp: rewrites the single operand e1 of the EIN application held in x,
   * using that application's params, index and args.
   *)
    fun unaryOp (name, sx, e1, x) = let
          val ((_, DstIL.EINAPP(ein, args)), fieldset, flag) = x
          val params = Ein.params ein
          val index = Ein.index ein
          in
            rewriteOp(name, e1, params, index, sx, args, fieldset, flag)
          end

  (* multOp: rewrites each operand in list1, left to right.  params, args and
   * fieldset are threaded from one operand to the next, and the generated code
   * is concatenated in operand order.
   * Returns (rewritten operands, params, args, code, fieldset).
   *)
    fun multOp (name, sx, list1, x) = let
          val ((_, DstIL.EINAPP(ein, args)), fieldset, flag) = x
          val params = Ein.params ein
          val index = Ein.index ein
          fun loop ([], rest, params, args, code, fieldset) = (rest, params, args, code, fieldset)
            | loop (e1::es, rest, params, args, code, fieldset) = let
                val (e1', params', args', code', fieldset) =
                      rewriteOp(name, e1, params, index, sx, args, fieldset, flag)
                in
                  loop(es, rest @ [e1'], params', args', code @ code', fieldset)
                end
          in
            loop(list1, [], params, args, [], fieldset)
          end

  (* cleanOrig: rebuilds the original binding y (taken from x) with the new
   * body, params and args, cleaning its params against the original index.
   *)
    fun cleanOrig (body, params, args, x) = let
          val ((y, DstIL.EINAPP(ein, _)), _, _) = x
          val index = Ein.index ein
          in
            cleanParams(y, body, params, index, args)
          end

(* *************************************** general handle Ops ******************************** *)

  (* handleUnaryOp: splits E.Op1(opp, e1).  Returns (einapp, code, fieldset). *)
    fun handleUnaryOp (name, opp, x, e1) = let
          val (e1', params', args', code, fieldset) = unaryOp(name, [], e1, x)
          val einapp = cleanOrig(E.Op1(opp, e1'), params', args', x)
          in
            (einapp, code, fieldset)
          end

  (* handleBinaryOp: splits E.Op2(opp, e1, e2), where es = [e1, e2]. *)
    fun handleBinaryOp (name, opp, x, es) = let
          val ([e1', e2'], params', args', code, fieldset) = multOp(name, [], es, x)
          val einapp = cleanOrig(E.Op2(opp, e1', e2'), params', args', x)
          in
            (einapp, code, fieldset)
          end

  (* handleMultOp: splits the n-ary operator E.Opn(opp, es). *)
    fun handleMultOp (name, opp, x, es) = let
          val (es', params', args', code, fieldset) = multOp(name, [], es, x)
          val einapp = cleanOrig(E.Opn(opp, es'), params', args', x)
          in
            (einapp, code, fieldset)
          end

(* *************************************** specific handle Ops ******************************** *)

  (* handleDiv: splits E.Op2(E.Div, e1, e2).  The numerator is rewritten
   * first, and its params/args are threaded into the denominator's rewrite.
   *)
    fun handleDiv (e1, e2, x) = let
          val ((_, DstIL.EINAPP(ein, args)), fieldset, flag) = x
          val params = Ein.params ein
          val index = Ein.index ein
          val (e1', params1', args1', code1', fieldset) =
                rewriteOp("div-num", e1, params, index, [], args, fieldset, flag)
          val (e2', params2', args2', code2', fieldset) =
                rewriteOp("div-denom", e2, params1', index, [], args1', fieldset, flag)
          val einapp = cleanOrig(E.Op2(E.Div, e1', e2'), params2', args2', x)
          in
            (einapp, code1' @ code2', fieldset)
          end

  (* handleSumProd: splits E.Sum(sx, E.Opn(E.Prod, e1)).  Each factor is
   * rewritten with the summation ids sx passed down to lift/cleanIndex.
   *)
    fun handleSumProd (e1, sx, x) = let
          val (e1', params', args', code, fieldset) = multOp("sumprod", sx, e1, x)
          val einapp = cleanOrig(E.Sum(sx, E.Opn(E.Prod, e1')), params', args', x)
          in
            (einapp, code, fieldset)
          end

(* *************************************** Split ******************************** *)

  (* split: (var * einapp) * fieldset * int -> ((var * einapp) * code) * fieldset
   * Splits one EIN binding into a simpler binding plus the list of new
   * bindings (code) for the lifted subexpressions.  Summations directly around
   * a probe of a convolution are left in place.  Bindings whose right-hand side
   * is not an EINAPP are returned unchanged, with no code.
   *)
    fun split ((y, einapp as DstIL.EINAPP(Ein.EIN{body, ...}, _)), fieldset, flag) = let
          val x = ((y, einapp), fieldset, flag)
          val default = ((y, einapp), [], fieldset)
          val str = "Poorly formed EIN operator. Argument needs to be applied in High-IL"
                ^ (P.printbody body)
          val _ = testp["\n\nStarting split", P.printbody body]
          fun rewrite b = (case b
                 of E.B _ => default
                  | E.Tensor _ => default
                  | E.G _ => default
                  | E.Field _ => raise Fail "should have been swept"
                  | E.Lift _ => raise Fail "should have been swept"
                  | E.Conv _ => raise Fail "should have been swept"
                  | E.Partial _ => err(" Partial used after normalize")
                  | E.Apply _ => raise Fail "should have been swept"
                  | E.Probe(E.Conv _, _) => default
                  | E.Probe _ => raise Fail str
                  | E.Value _ => err("Value used before expand")
                  | E.Img _ => err("Img used before expand")
                  | E.Krn _ => err("Krn used before expand")
                  | E.Sum(_, E.Probe(E.Conv _, _)) => default
                  | E.Sum(_, E.Tensor _) => default
                  (* | E.Sum(_,E.Opn(E.Prod,[E.Eps2 _, E.Probe(E.Conv _,_)])) => default
                     | E.Sum(_,E.Opn(E.Prod,[E.Epsilon _, E.Probe(E.Conv _,_) ])) => default *)
                  | E.Sum(sx, E.Opn(E.Prod, e1)) => handleSumProd(e1, sx, x)
                  | E.Sum(sx, E.G(E.Delta d)) => handleSumProd([E.G(E.Delta d)], sx, x)
                  | E.Sum(_, _) => err(" summation not distributed:" ^ str)
                  | E.Op1(op1, e1) => (case op1
                       of E.Neg => handleUnaryOp("neg", op1, x, e1)
                        | E.Sqrt => handleUnaryOp("sqrt", op1, x, e1)
                        | E.Exp => handleUnaryOp("exp", op1, x, e1)
                        | E.PowInt _ => handleUnaryOp("PowInt", op1, x, e1)
                        | _ => handleUnaryOp("Trig", op1, x, e1)
                      (* end case *))
                  | E.Op2(E.Sub, e1, e2) => handleBinaryOp("subtract", E.Sub, x, [e1, e2])
                  | E.Op2(E.Div, e1, e2) => handleDiv(e1, e2, x)
                  | E.Opn(E.Add, es) => handleMultOp("add", E.Add, x, es)
                  (* Regroup [s0, v1, s2], where s0 and s2 are tensors with no
                   * indices, as [s0 * s2, v1] so the two index-free factors are
                   * split out as their own product.  E.Prod must be qualified:
                   * a bare `Prod` would be a variable pattern matching any
                   * operator.
                   *)
                  | E.Opn(E.Prod, [E.Tensor(id0, []), E.Tensor(id1, [i]), E.Tensor(id2, [])]) =>
                      rewrite (E.Opn(E.Prod, [
                          E.Opn(E.Prod, [E.Tensor(id0, []), E.Tensor(id2, [])]),
                          E.Tensor(id1, [i])
                        ]))
                  | E.Opn(E.Prod, es) => handleMultOp("prod", E.Prod, x, es)
                (* end case *))
          val (einapp2, newbies, fieldset) = rewrite body
          in
            ((einapp2, newbies), fieldset)
          end
      | split ((y, app), fieldset, _) = (((y, app), []), fieldset)

(* *************************************** main ******************************** *)

  (* limitSplit: like splitEinApp, but once the number of collected bindings
   * (rest @ newbies @ code) exceeds splitlimit, the remaining work-list entries
   * are appended to the result unsplit.  The bindings in fields2 are placed
   * first in the result.
   *)
    fun limitSplit (einapp2, fields2, splitlimit) = let
          val fieldset = einSet.EinSet.empty
          (* NOTE(review): unconditional print; looks like debugging output. *)
          val _ = print ("\nSPLit with limit" ^ (Int.toString splitlimit))
          fun itercode ([], rest, code, cnt) =
                (testp["\n Empty-SplitCount: ", Int.toString cnt]; (rest, code))
            | itercode (e1::newbies, rest, code, cnt) = let
                val ((einapp3, code3), _) = split(e1, fieldset, numFlag)
                val (rest4, code4) = itercode(code3, [], [], cnt+1)
                val _ = testp [
                        toStringBind e1, "\n\t===>\n", toStringBind einapp3, "\nand\n",
                        String.concatWith ",\n\t" (List.map toStringBind (code4 @ rest4))
                      ]
                in
                  if length(rest @ newbies @ code) > splitlimit
                    then (
                      testp["\n SplitCount: ", Int.toString cnt];
                      (rest @ [einapp3], code4 @ rest4 @ code @ newbies))
                    else itercode(newbies, rest @ [einapp3], code4 @ rest4 @ code, cnt+2)
                end
          val (rest, code) = itercode([einapp2], [], [], 0)
          in
            fields2 @ code @ rest
          end

  (* splitEinApp: repeatedly splits einapp0 and every binding that splitting
   * produces.  Lifted bindings come before the bindings that use them.
   *)
    fun splitEinApp einapp0 = let
          val fieldset = einSet.EinSet.empty
          fun itercode ([], rest, code, _) = (rest, code)
            | itercode (e1::newbies, rest, code, cnt) = let
                val ((einapp3, code3), _) = split(e1, fieldset, numFlag)
                val (rest4, code4) = itercode(code3, [], [], cnt+1)
                val _ = testp [
                        toStringBind e1, "\n\t===>\n", toStringBind einapp3, "\nand\n",
                        String.concatWith ",\n\t" (List.map toStringBind (code4 @ rest4))
                      ]
                in
                  itercode(newbies, rest @ [einapp3], code4 @ rest4 @ code, cnt+2)
                end
          val (rest, code) = itercode([einapp0], [], [], 0)
          in
            code @ rest
          end

  end (* local *)

end (* structure Split *)

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0