Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] Annotation of /branches/vis15/src/compiler/high-to-mid/split.sml
ViewVC logotype

Annotation of /branches/vis15/src/compiler/high-to-mid/split.sml

Parent Directory Parent Directory | Revision Log Revision Log


Revision 3552 - (view) (download)

1 : jhr 3552 (* split.sml
2 :     *
3 :     * This code is part of the Diderot Project (http://diderot-language.cs.uchicago.edu)
4 :     *
5 :     * COPYRIGHT (c) 2015 The University of Chicago
6 :     * All rights reserved.
7 :     *)
8 :    
9 :     (*
10 :     During the transition from high-IR to mid-IR, complicated EIN expressions are split into simpler ones in order to better identify methods for code generation and common subexpressions. Combining EIN operators in the optimization phase can lead to large and complicated EIN operators. A general code generator would need to expand every operation to work on scalars, which could miss the opportunity for vectorization and lead to poor code generation. Instead, every EIN operator is split into a set of simple EIN operators. Each EIN expression then only has one operation working on constants, tensors, deltas, epsilons, images and kernels.
11 :    
12 :     (1) When the outer EIN operator is in $\{--, +, -, *, /, \sum\}$, analyze each subexpression to see whether it needs to be rewritten.
13 :    
14 :     (1a) When a subexpression is a field expression ($\circledast$, $\nabla$) then it becomes 0. When it is another operation in $\{@, --, +, -, *, /, \sum\}$ then we lift that subexpression and create a new EIN operator. We replace the subexpression with a tensor expression that represents its size.
15 :    
16 :     (1b) Call cleanIndex.sml to clean the indices in the subexpression, and get the shape for the tensor replacement.
17 :    
18 :     (1c) Call cleanParams.sml to clean the params in the subexpression.
19 :     *)
20 :    
(* Split: rewrites a mid-IR EIN application into a list of simpler EIN
 * applications.
 *
 * The structure aliases are placed in the first half of a `local ... in`
 * declaration: the `in` keyword and the closing `end; (* local *)` at the
 * bottom of the structure had no matching `local`.
 *
 * NOTE(review): the signature is empty, so nothing defined in this structure
 * is visible to clients -- confirm which functions should be exported.
 *)
structure Split : sig

  end = struct

    local

    (* short names for the EIN representation and the destination (mid) IR *)
    structure E = Ein
    structure DstIR = MidIR
    structure DstTy = MidTypes
    structure DstV = DstIR.Var

    (* helper passes used on every floated subexpression *)
    structure cleanP = cleanParams
    structure cleanI = cleanIndex


    in
35 :    
(* flag handed to split by the drivers; when true, float consults
 * einSet.rtnVar before emitting a new binding (remove common subexpressions)
 *)
val numFlag = true

(* build an EIN operator from its components *)
fun mkEin e = E.mkEin e

(* an EIN application with no parameters, no indices and the constant body 0 *)
val einappzero = DstIR.EINAPP(mkEin([], [], E.Const 0), [])

(* bind the variable y to the zero EIN application *)
fun setEinZero y = (y, einappzero)

(* local shorthands for the helper passes and the binding pretty-printer *)
fun cleanParams e = cleanP.cleanParams e
fun cleanIndex e = cleanI.cleanIndex e
fun toStringBind e = MidToString.toStringBind e

(* debug tracing: testp is called by the split drivers but was never defined
 * in this file.  It prints the concatenated strings only when testing is on.
 *)
val testing = false
fun testp strs = if testing then print (String.concat strs) else ()
43 :    
(* float (name, e, params, index, sx, args, fieldset, flag)
 *      -> (replacement, params', args', newbies, fieldset')
 *
 * Floats the subexpression e out to be its own EIN application and returns
 * the tensor expression that replaces it in the enclosing body.
 *   - cleanIndex computes the replacement's index shape, the sizes of the new
 *     tensor, and the cleaned body.
 *   - a new tensor parameter is appended to params; its position (the old
 *     length of params) is the id of the replacement tensor.
 *   - a fresh variable M, named name ^ "_l_" ^ id, is bound to the cleaned
 *     EIN application produced by cleanParams.
 *   - when flag is true, einSet.rtnVar is consulted: if it returns SOME v,
 *     v is used as the argument and no new binding is emitted.
 *)
fun float (name, e, params, index, sx, args, fieldset, flag) = let
      val (tshape, sizes, body) = cleanIndex(e, index, sx)
      val id = length params
      val Rparams = params@[E.TEN(true, sizes)]
      val Re = E.Tensor(id, tshape)
      val M = DstV.new (concat[name, "_l_", Int.toString id], DstTy.TensorTy sizes)
      val Rargs = args@[M]
      val einapp as (_, einapp0) = cleanParams (M, body, Rparams, sizes, Rargs)
      val (Rargs, newbies, fieldset) = if flag
            then let
              val (fieldset, var) = einSet.rtnVar(fieldset, M, einapp0)
              in
                case var
                 of NONE => (Rargs, [einapp], fieldset)
                  | SOME v => (args@[v], [], fieldset)
                (* end case *)
              end
            else (Rargs, [einapp], fieldset)
      in
        (Re, Rparams, Rargs, newbies, fieldset)
      end
71 :    
(* How a subexpression of an EIN body is treated by the splitter:
 *   ZERO  -- replaced by the constant zero (e.g., fields)
 *   FLOAT -- floated out to its own top-level EIN application
 *   SAME  -- left in place unchanged
 *)
datatype op_replace
  = ZERO
  | FLOAT
  | SAME
76 :    
(* shouldFloat e classifies the subexpression e:
 *   - field-level forms (Field, Conv, Apply, Lift) become ZERO
 *   - operators, summations and probes are FLOATed out
 *   - Partial, Krn, Value and Img are not expected here and raise Fail, with
 *     the same messages that split uses for these forms
 *   - anything else stays the SAME
 *)
fun shouldFloat e = (case e
       of E.Field _ => ZERO
        | E.Conv _ => ZERO
        | E.Apply _ => ZERO
        | E.Lift _ => ZERO
        | E.Op1 _ => FLOAT
        | E.Op2 _ => FLOAT
        | E.Opn _ => FLOAT
        | E.Sum _ => FLOAT
        | E.Probe _ => FLOAT
        | E.Partial _ => raise Fail "Partial used after normalize"
        | E.Krn _ => raise Fail "Krn used before expand"
        | E.Value _ => raise Fail "Value used before expand"
        | E.Img _ => raise Fail "Img used before expand"
        | _ => SAME
      (* end case *))
94 :    
95 :     (* *************************************** helpers ******************************** *)
96 :    
(* rewriteOp (name, e1, params, index, sx, args, fieldset, flag)
 *      -> (e1', params', args', newbies, fieldset')
 *
 * Rewrites one operand of an EIN operator according to shouldFloat:
 *   ZERO  -- the operand becomes the constant 0; nothing else changes
 *   FLOAT -- the operand is floated out by float, which extends params and
 *            args and may return new bindings
 *   SAME  -- the operand is returned untouched
 *)
fun rewriteOp (name, e1, params, index, sx, args, fieldset, flag) = (case shouldFloat e1
       of ZERO => (E.Const 0, params, args, [], fieldset)
        | FLOAT => float(name, e1, params, index, sx, args, fieldset, flag)
        | SAME => (e1, params, args, [], fieldset)
      (* end case *))
102 :    
(* unaryOp (name, sx, e1, x): rewrite the single operand e1 using the params,
 * index and args of the EIN application carried in x, where
 * x = ((var, EINAPP(ein, args)), fieldset, flag).
 *)
fun unaryOp (name, sx, e1, x) = let
      val ((_, DstIR.EINAPP(ein, args)), fieldset, flag) = x
      in
        rewriteOp (name, e1, Ein.params ein, Ein.index ein, sx, args, fieldset, flag)
      end
110 :    
(* multOp (name, sx, list1, x): rewrite every operand in list1, left to right.
 * The params, args and fieldset produced for one operand are fed into the
 * rewrite of the next; the rewritten operands and the new bindings are
 * accumulated in order.
 * Returns (operands', params', args', newbies, fieldset').
 *)
fun multOp (name, sx, list1, x) = let
      val ((_, DstIR.EINAPP(ein, args)), fieldset, flag) = x
      val params0 = Ein.params ein
      val index = Ein.index ein
      (* one step of the left-to-right pass over the operands *)
      fun step (e, (done, ps, actuals, lifted, fs)) = let
            val (e', ps', actuals', lifted', fs') =
                  rewriteOp (name, e, ps, index, sx, actuals, fs, flag)
            in
              (done @ [e'], ps', actuals', lifted @ lifted', fs')
            end
      in
        List.foldl step ([], params0, args, [], fieldset) list1
      end
125 :    
(* cleanOrig (body, params, args, x): run cleanParams on the rewritten body,
 * keeping the variable and the index of the original EIN application in x.
 *)
fun cleanOrig (body, params, args, x) = let
      val ((y, DstIR.EINAPP(ein, _)), _, _) = x
      in
        cleanParams (y, body, params, Ein.index ein, args)
      end
133 :    
134 :     (* *************************************** general handle Ops ******************************** *)
(* handleUnaryOp (name, opp, x, e1): rewrite the operand of a unary operator,
 * rebuild the body as Op1(opp, operand') and clean the resulting application.
 * Returns (einapp, newbies, fieldset).
 *)
fun handleUnaryOp (name, opp, x, e1) = let
      val (operand, ps, actuals, lifted, fieldset) = unaryOp(name, [], e1, x)
      in
        (cleanOrig(E.Op1(opp, operand), ps, actuals, x), lifted, fieldset)
      end
(* handleBinaryOp (name, opp, x, es): rewrite both operands of a binary
 * operator and rebuild the body as Op2(opp, lhs, rhs).  The binding below
 * expects es to hold exactly two expressions.
 * Returns (einapp, newbies, fieldset).
 *)
fun handleBinaryOp (name, opp, x, es) = let
      val ([lhs, rhs], ps, actuals, lifted, fieldset) = multOp(name, [], es, x)
      in
        (cleanOrig(E.Op2(opp, lhs, rhs), ps, actuals, x), lifted, fieldset)
      end
(* handleMultOp (name, opp, x, es): rewrite every operand of an n-ary operator
 * and rebuild the body as Opn(opp, operands').
 * Returns (einapp, newbies, fieldset).
 *)
fun handleMultOp (name, opp, x, es) = let
      val (operands, ps, actuals, lifted, fieldset) = multOp(name, [], es, x)
      in
        (cleanOrig(E.Opn(opp, operands), ps, actuals, x), lifted, fieldset)
      end
156 :    
157 :     (* ***************************************specific handle Ops ******************************** *)
(* handleDiv (e1, e2, x): rewrite the numerator and then the denominator of a
 * division.  The params, args and fieldset produced for the numerator are
 * passed on to the rewrite of the denominator.
 * Returns (einapp, newbies, fieldset), numerator bindings first.
 *)
fun handleDiv (e1, e2, x) = let
      val ((_, DstIR.EINAPP(ein, args)), fieldset, flag) = x
      val params = Ein.params ein
      val index = Ein.index ein
      val (num, psN, argsN, codeN, fieldset) =
            rewriteOp("div-num", e1, params, index, [], args, fieldset, flag)
      val (den, psD, argsD, codeD, fieldset) =
            rewriteOp("div-denom", e2, psN, index, [], argsN, fieldset, flag)
      in
        (cleanOrig(E.Op2(E.Div, num, den), psD, argsD, x), codeN @ codeD, fieldset)
      end
169 :    
(* handleSumProd (e1, sx, x): rewrite the factors e1 of a product under the
 * summation sx (sx is passed on to the operand rewrites) and rebuild the body
 * as Sum(sx, Opn(Prod, factors')).
 * Returns (einapp, newbies, fieldset).
 *)
fun handleSumProd (e1, sx, x) = let
      val (factors, ps, actuals, lifted, fieldset) = multOp("sumprod", sx, e1, x)
      in
        (cleanOrig(E.Sum(sx, E.Opn(E.Prod, factors)), ps, actuals, x), lifted, fieldset)
      end
177 :    
178 :     (* *************************************** Split ******************************** *)
179 :    
(* split ((y, rhs), fieldset, flag) -> (((y, rhs'), newbies), fieldset')
 *
 * Splits one EIN application into smaller pieces: rhs' is the rewritten
 * application bound to y, and newbies are the bindings created for floated
 * subexpressions.  A right-hand side that is not an EINAPP is returned
 * unchanged with no new bindings.
 *
 * Note that a summation around a probe expression is left in place.
 * Field-level forms (Field, Lift, Conv, Apply) and forms that belong to other
 * phases (Partial, Value, Img, Krn) raise Fail.
 *)
fun split ((y, einapp as DstIR.EINAPP(Ein.EIN{params, index, body}, args)), fieldset, flag) = let
      val x = ((y, einapp), fieldset, flag)
      (* result for bodies that are already simple: nothing changes *)
      val default = ((y, einapp), [], fieldset)
      fun error () = raise Fail("Poorly formed EIN operator: " ^ EinPP.expToString body)
      fun rewrite b = (case b
             of E.Const _ => default
              | E.ConstR _ => default
              | E.Tensor _ => default
              | E.Delta _ => default
              | E.Epsilon _ => default
              | E.Eps2 _ => default
              | E.Field _ => raise Fail "should have been swept"
              | E.Lift _ => raise Fail "should have been swept"
              | E.Conv _ => raise Fail "should have been swept"
              | E.Partial _ => raise Fail "Partial used after normalize"
              | E.Apply _ => raise Fail "should have been swept"
              | E.Probe(E.Conv _, _) => default
              | E.Probe(E.Field _, _) => error()
              | E.Probe _ => error()
              | E.Value _ => raise Fail "Value used before expand"
              | E.Img _ => raise Fail "Img used before expand"
              | E.Krn _ => raise Fail "Krn used before expand"
              | E.Sum(_, E.Probe(E.Conv _, _)) => default
              | E.Sum(_, E.Tensor _) => default
              | E.Sum(sx, E.Opn(E.Prod, e1)) => handleSumProd (e1, sx, x)
              | E.Sum(sx, E.Delta d) => handleSumProd ([E.Delta d], sx, x)
              | E.Sum _ => raise Fail("summation not distributed:" ^ EinPP.expToString b)
              | E.Op1(op1, e1) => (case op1
                   of E.Neg => handleUnaryOp ("neg", op1, x, e1)
                    | E.Sqrt => handleUnaryOp ("sqrt", op1, x, e1)
                    | E.Exp => handleUnaryOp ("exp", op1, x, e1)
                    | E.PowInt _ => handleUnaryOp ("PowInt", op1, x, e1)
                    | _ => handleUnaryOp ("Trig", op1, x, e1)
                  (* end case *))
              | E.Op2(E.Sub, e1, e2) => handleBinaryOp ("subtract", E.Sub, x, [e1, e2])
              | E.Op2(E.Div, e1, e2) => handleDiv (e1, e2, x)
              | E.Opn(E.Add, es) => handleMultOp ("add", E.Add, x, es)
              (* scalar * vector * scalar: group the two scalars first.  The
               * constructor must be written E.Prod; a bare Prod is a variable
               * pattern that matches any operator.
               *)
              | E.Opn(E.Prod, [E.Tensor(id0, []), E.Tensor(id1, [i]), E.Tensor(id2, [])]) =>
                  rewrite (E.Opn(E.Prod, [
                      E.Opn(E.Prod, [E.Tensor(id0, []), E.Tensor(id2, [])]), E.Tensor(id1, [i])
                    ]))
              | E.Opn(E.Prod, es) => handleMultOp("prod", E.Prod, x, es)
            (* end case *))
      val (einapp2, newbies, fieldset) = rewrite body
      in
        ((einapp2, newbies), fieldset)
      end
  | split ((y, app), fieldset, _) = (((y, app), []), fieldset)
235 :    
236 :     (* *************************************** main ******************************** *)
(* limitSplit (einapp2, fields2, splitlimit) -> binding list
 *
 * Splits einapp2 and, recursively, every binding created by a split.  After
 * each binding is processed, the number of bindings in hand (finished +
 * pending + lifted) is compared with splitlimit; once it exceeds the limit the
 * remaining pending bindings are returned unsplit.
 * Each split starts from the empty ein set.
 * The result is fields2 followed by the lifted bindings and then the finished
 * ones.  The limit is printed to stdout on every call.
 *
 * The two strings that used to be built here and then discarded (never
 * printed), and the counter that only fed them, have been removed.
 *)
fun limitSplit (einapp2, fields2, splitlimit) = let
      val fieldset = einSet.EinSet.empty
      val _ = print ("\nSPLit with limit"^(Int.toString(splitlimit)))
      fun itercode ([], rest, code) = (rest, code)
        | itercode (e1::newbies, rest, code) = let
            val ((einapp3, code3), _) = split(e1, fieldset, numFlag)
            (* split the bindings that were just created for e1 *)
            val (rest4, code4) = itercode(code3, [], [])
            val _ = testp [toStringBind(e1), "\n\t===>\n", toStringBind(einapp3), "\nand\n", (String.concatWith", \n\t"(List.map toStringBind (code4@rest4)))]
            val rest' = rest@[einapp3]
            val code' = code4@rest4@code
            in
              (* same count as length(rest@newbies@code), without building the list *)
              if (length rest + length newbies + length code > splitlimit)
                then (rest', code'@newbies)
                else itercode(newbies, rest', code')
            end
      val (rest, code) = itercode([einapp2], [], [])
      in
        fields2@code@rest
      end
260 :    
(* splitEinApp einapp0 -> binding list
 *
 * Splits einapp0 and, recursively, every binding created by a split, with no
 * limit.  Each split starts from the empty ein set.  The result lists the
 * lifted bindings first, followed by the rewritten top-level binding.
 *)
fun splitEinApp einapp0 = let
      val emptySet = einSet.EinSet.empty
      fun loop (pending, done, lifted, cnt) = (case pending
             of [] => (done, lifted)
              | b :: bs => let
                  val ((b', fresh), _) = split(b, emptySet, numFlag)
                  (* split the bindings that were just created for b *)
                  val (subDone, subLifted) = loop(fresh, [], [], cnt+1)
                  val _ = testp [
                          toStringBind b, "\n\t===>\n", toStringBind b', "\nand\n",
                          String.concatWith ", \n\t" (List.map toStringBind (subLifted @ subDone))
                        ]
                  in
                    loop(bs, done @ [b'], subLifted @ subDone @ lifted, cnt+2)
                  end
            (* end case *))
      val (done, lifted) = loop([einapp0], [], [], 0)
      in
        lifted @ done
      end
276 :    
277 :     end; (* local *)
278 :    
279 :     end (* structure Split *)

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0