Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] Annotation of /trunk/src/compiler/high-il/normalize.sml
ViewVC logotype

Annotation of /trunk/src/compiler/high-il/normalize.sml

Parent Directory Parent Directory | Revision Log Revision Log


Revision 3349 - (view) (download)

1 : jhr 1232 (* normalize.sml
2 :     *
3 : jhr 3349 * This code is part of the Diderot Project (http://diderot-language.cs.uchicago.edu)
4 :     *
5 :     * COPYRIGHT (c) 2015 The University of Chicago
6 : jhr 1232 * All rights reserved.
7 :     *)
8 :    
9 :     structure Normalize : sig
10 :    
11 :     val transform : HighIL.program -> HighIL.program
12 :    
13 :     end = struct
14 :    
15 :     structure IL = HighIL
16 :     structure Op = HighOps
17 :     structure V = IL.Var
18 : jhr 2356 structure Ty = HighILTypes
19 : jhr 1232 structure ST = Stats
20 :    
21 :     (********** Counters for statistics **********)
22 : jhr 2356 val cntInsideScale = ST.newCounter "high-opt:inside-scale"
23 :     val cntInsideOffset = ST.newCounter "high-opt:inside-offset"
24 :     val cntInsideNeg = ST.newCounter "high-opt:inside-meg"
25 :     val cntInsideCurl = ST.newCounter "high-opt:inside-curl"
26 :     val cntInsideDiff = ST.newCounter "high-opt:inside-diff"
27 :     val cntProbeAdd = ST.newCounter "high-opt:probe-add"
28 :     val cntProbeSub = ST.newCounter "high-opt:probe-sub"
29 :     val cntProbeScale = ST.newCounter "high-opt:probe-scale"
30 :     val cntProbeOffset = ST.newCounter "high-opt:probe-offset"
31 :     val cntProbeNeg = ST.newCounter "high-opt:probe-neg"
32 :     val cntProbeCurl = ST.newCounter "high-opt:probe-curl"
33 :     val cntDiffField = ST.newCounter "high-opt:diff-field"
34 :     val cntDiffAdd = ST.newCounter "high-opt:diff-add"
35 :     val cntDiffScale = ST.newCounter "high-opt:diff-scale"
36 :     val cntDiffOffset = ST.newCounter "high-opt:diff-offset"
37 :     val cntDiffNeg = ST.newCounter "high-opt:diff-neg"
38 :     val cntCurlScale = ST.newCounter "high-opt:curl-scale"
39 :     val cntCurlNeg = ST.newCounter "high-opt:curl-neg"
40 :     val cntUnused = ST.newCounter "high-opt:unused"
41 :     val firstCounter = cntInsideScale
42 : jhr 1232 val lastCounter = cntUnused
43 :    
44 :     structure UnusedElim = UnusedElimFn (
45 : jhr 2356 structure IL = IL
46 :     val cntUnused = cntUnused)
47 : jhr 1232
48 :     fun useCount (IL.V{useCnt, ...}) = !useCnt
49 :    
50 :     (* adjust a variable's use count *)
51 :     fun incUse (IL.V{useCnt, ...}) = (useCnt := !useCnt + 1)
52 :     fun decUse (IL.V{useCnt, ...}) = (useCnt := !useCnt - 1)
53 : jhr 2356 fun use x = (incUse x; x)
54 : jhr 1232
55 :     fun getRHS x = (case V.binding x
56 : jhr 2356 of IL.VB_RHS(IL.OP arg) => SOME arg
57 :     | IL.VB_RHS(IL.VAR x') => getRHS x'
58 :     | _ => NONE
59 :     (* end case *))
60 : jhr 1232
61 : jhr 2356 (* get the binding of a kernel variable *)
62 :     fun getKernelRHS h = (case getRHS h
63 :     of SOME(Op.Kernel(kernel, k), []) => (kernel, k)
64 :     | _ => raise Fail(concat[
65 :     "bogus kernel binding ", V.toString h, " = ", IL.vbToString(V.binding h)
66 :     ])
67 :     (* end case *))
68 :    
69 : jhr 1232 (* optimize the rhs of an assignment, returning NONE if there is no change *)
70 :     fun doRHS (lhs, IL.OP rhs) = (case rhs
71 : jhr 2356 of (Op.Inside dim, [pos, f]) => (case getRHS f
72 :     of SOME(Op.Field _, _) => NONE (* direct inside test does not need rewrite *)
73 :     | SOME(Op.AddField, [f', g']) => raise Fail "inside(f+g)"
74 :     | SOME(Op.SubField, [f', g']) => raise Fail "inside(f-g)"
75 :     | SOME(Op.ScaleField, [_, f']) => (
76 :     ST.tick cntInsideScale;
77 :     decUse f;
78 :     SOME[(lhs, IL.OP(Op.Inside dim, [pos, use f']))])
79 :     | SOME(Op.OffsetField, [f', _]) => (
80 :     ST.tick cntInsideOffset;
81 :     decUse f;
82 :     SOME[(lhs, IL.OP(Op.Inside dim, [pos, use f']))])
83 :     | SOME(Op.NegField, [f']) => (
84 :     ST.tick cntInsideNeg;
85 :     decUse f;
86 :     SOME[(lhs, IL.OP(Op.Inside dim, [pos, use f']))])
87 :     | SOME(Op.CurlField _, [f']) => (
88 :     ST.tick cntInsideCurl;
89 :     decUse f;
90 :     SOME[(lhs, IL.OP(Op.Inside dim, [pos, use f']))])
91 :     | SOME(Op.DiffField, [f']) => (
92 :     ST.tick cntInsideDiff;
93 :     decUse f;
94 :     SOME[(lhs, IL.OP(Op.Inside dim, [pos, use f']))])
95 :     | _ => raise Fail(concat[
96 :     "inside: bogus field binding ", V.toString f, " = ", IL.vbToString(V.binding f)
97 :     ])
98 :     (* end case *))
99 :     | (Op.Probe(domTy, rngTy), [f, pos]) => (case getRHS f
100 :     of SOME(Op.Field _, _) => NONE (* direct probe does not need rewrite *)
101 :     | SOME(Op.AddField, [f', g']) => let
102 :     (* rewrite to (f@pos) + (g@pos) *)
103 :     val lhs1 = IL.Var.copy lhs
104 :     val lhs2 = IL.Var.copy lhs
105 :     in
106 :     ST.tick cntProbeAdd;
107 :     decUse f;
108 :     incUse lhs1; incUse f'; incUse lhs2; incUse g'; incUse pos;
109 :     SOME[
110 :     (lhs1, IL.OP(Op.Probe(domTy, rngTy), [f', pos])),
111 :     (lhs2, IL.OP(Op.Probe(domTy, rngTy), [g', pos])),
112 :     (lhs, IL.OP(Op.Add rngTy, [lhs1, lhs2]))
113 :     ]
114 :     end
115 :     | SOME(Op.SubField, [f', g']) => let
116 :     (* rewrite to (f@pos) - (g@pos) *)
117 :     val lhs1 = IL.Var.copy lhs
118 :     val lhs2 = IL.Var.copy lhs
119 :     in
120 :     ST.tick cntProbeSub;
121 :     decUse f;
122 :     incUse lhs1; incUse f'; incUse lhs2; incUse g'; incUse pos;
123 :     SOME[
124 :     (lhs1, IL.OP(Op.Probe(domTy, rngTy), [f', pos])),
125 :     (lhs2, IL.OP(Op.Probe(domTy, rngTy), [g', pos])),
126 :     (lhs, IL.OP(Op.Sub rngTy, [lhs1, lhs2]))
127 :     ]
128 :     end
129 :     | SOME(Op.ScaleField, [s, f']) => let
130 :     (* rewrite to s*(f'@pos) *)
131 :     val lhs' = IL.Var.copy lhs
132 :     val scaleOp = (case rngTy
133 :     of Ty.TensorTy[] => Op.Mul rngTy
134 :     | _ => Op.Scale rngTy
135 :     (* end case *))
136 :     in
137 :     ST.tick cntProbeScale;
138 :     decUse f;
139 :     SOME[
140 :     (lhs', IL.OP(Op.Probe(domTy, rngTy), [use f', pos])),
141 :     (lhs, IL.OP(scaleOp, [use s, use lhs']))
142 :     ]
143 :     end
144 :     | SOME(Op.OffsetField, [f', s]) => let
145 :     (* rewrite to (f'@pos) + s *)
146 :     val lhs' = IL.Var.copy lhs
147 :     in
148 :     ST.tick cntProbeOffset;
149 :     decUse f;
150 :     SOME[
151 :     (lhs', IL.OP(Op.Probe(domTy, rngTy), [use f', pos])),
152 :     (lhs, IL.OP(Op.Add rngTy, [use lhs', use s]))
153 :     ]
154 :     end
155 :     | SOME(Op.NegField, [f']) => let
156 :     (* rewrite to -(f'@pos) *)
157 :     val lhs' = IL.Var.copy lhs
158 :     in
159 :     ST.tick cntProbeNeg;
160 :     decUse f;
161 :     incUse lhs'; incUse f';
162 :     SOME[
163 :     (lhs', IL.OP(Op.Probe(domTy, rngTy), [f', pos])),
164 :     (lhs, IL.OP(Op.Neg rngTy, [lhs']))
165 :     ]
166 :     end
167 :     | SOME(Op.CurlField 2, [f']) => (case getRHS f'
168 :     of SOME(Op.Field dim, [v, h]) => let
169 :     (* rewrite to (D f')@pos[1,0] - (D f')@pos[0,1] *)
170 :     val (kernel, k) = getKernelRHS h
171 :     val h' = IL.Var.copy h
172 :     val f'' = IL.Var.copy f'
173 :     val mat22 = Ty.TensorTy[2,2]
174 :     val m = IL.Var.new("m", mat22)
175 :     val zero = IL.Var.new("zero", Ty.intTy)
176 :     val one = IL.Var.new("one", Ty.intTy)
177 :     val m10 = IL.Var.new("m_10", Ty.realTy)
178 :     val m01 = IL.Var.new("m_01", Ty.realTy)
179 :     in
180 :     ST.tick cntProbeCurl;
181 :     decUse f;
182 :     SOME[
183 :     (h', IL.OP(Op.Kernel(kernel, k+1), [])),
184 :     (f'', IL.OP(Op.Field dim, [use v, use h'])),
185 :     (m, IL.OP(Op.Probe(domTy, mat22), [use f'', pos])),
186 :     (zero, IL.LIT(Literal.Int 0)),
187 :     (one, IL.LIT(Literal.Int 1)),
188 :     (m10, IL.OP(Op.TensorSub mat22, [use m, use one, use zero])),
189 :     (m01, IL.OP(Op.TensorSub mat22, [use m, use zero, use one])),
190 :     (lhs, IL.OP(Op.Sub Ty.realTy, [use m10, use m01]))
191 :     ]
192 :     end
193 :     | _ => raise Fail(concat[
194 :     "bogus field binding ", V.toString f', " = ", IL.vbToString(V.binding f')
195 :     ])
196 :     (* end case *))
197 :     | SOME(Op.CurlField 3, [f']) => (case getRHS f'
198 :     of SOME(Op.Field dim, [v, h]) => let
199 :     (* rewrite to
200 :     * [ (D f')@pos[2,1] - (D f')@pos[1,2] ]
201 :     * [ (D f')@pos[0,2] - (D f')@pos[2,0] ]
202 :     * [ (D f')@pos[1,0] - (D f')@pos[0,1] ]
203 :     *)
204 :     val (kernel, k) = getKernelRHS h
205 :     val h' = IL.Var.copy h
206 :     val f'' = IL.Var.copy f'
207 :     val mat33 = Ty.TensorTy[3,3]
208 :     val m = IL.Var.new("m", mat33)
209 :     val zero = IL.Var.new("zero", Ty.intTy)
210 :     val one = IL.Var.new("one", Ty.intTy)
211 :     val two = IL.Var.new("two", Ty.intTy)
212 :     val m21 = IL.Var.new("m_21", Ty.realTy)
213 :     val m12 = IL.Var.new("m_12", Ty.realTy)
214 :     val m02 = IL.Var.new("m_02", Ty.realTy)
215 :     val m20 = IL.Var.new("m_20", Ty.realTy)
216 :     val m10 = IL.Var.new("m_10", Ty.realTy)
217 :     val m01 = IL.Var.new("m_01", Ty.realTy)
218 :     val lhs0 = IL.Var.new("lhs0", Ty.realTy)
219 :     val lhs1 = IL.Var.new("lhs1", Ty.realTy)
220 :     val lhs2 = IL.Var.new("lhs2", Ty.realTy)
221 :     in
222 :     ST.tick cntProbeCurl;
223 :     decUse f;
224 :     SOME[
225 :     (h', IL.OP(Op.Kernel(kernel, k+1), [])),
226 :     (f'', IL.OP(Op.Field dim, [use v, use h'])),
227 :     (m, IL.OP(Op.Probe(domTy, mat33), [use f'', pos])),
228 :     (zero, IL.LIT(Literal.Int 0)),
229 :     (one, IL.LIT(Literal.Int 1)),
230 :     (two, IL.LIT(Literal.Int 2)),
231 :     (m21, IL.OP(Op.TensorSub mat33, [use m, use two, use one])),
232 :     (m12, IL.OP(Op.TensorSub mat33, [use m, use one, use two])),
233 :     (lhs0, IL.OP(Op.Sub Ty.realTy, [use m21, use m12])),
234 :     (m02, IL.OP(Op.TensorSub mat33, [use m, use zero, use two])),
235 :     (m20, IL.OP(Op.TensorSub mat33, [use m, use two, use zero])),
236 :     (lhs1, IL.OP(Op.Sub Ty.realTy, [use m02, use m20])),
237 :     (m10, IL.OP(Op.TensorSub mat33, [use m, use one, use zero])),
238 :     (m01, IL.OP(Op.TensorSub mat33, [use m, use zero, use one])),
239 :     (lhs2, IL.OP(Op.Sub Ty.realTy, [use m10, use m01])),
240 : jhr 2660 (lhs, IL.CONS(Ty.TensorTy[3], [use lhs0, use lhs1, use lhs2]))
241 : jhr 2356 ]
242 :     end
243 :     | _ => raise Fail(concat[
244 :     "curl: bogus field binding ", V.toString f', " = ", IL.vbToString(V.binding f')
245 :     ])
246 :     (* end case *))
247 :     | SOME(Op.DiffField, _) => NONE (* need further rewriting *)
248 :     | _ => raise Fail(concat[
249 :     "probe: bogus field binding ", V.toString f, " = ", IL.vbToString(V.binding f)
250 :     ])
251 :     (* end case *))
252 :     | (Op.DiffField, [f]) => (case (getRHS f)
253 :     of SOME(Op.Field dim, [v, h]) => let
254 :     val (kernel, k) = getKernelRHS h
255 :     val h' = IL.Var.copy h
256 :     in
257 :     ST.tick cntDiffField;
258 :     decUse f;
259 :     incUse h'; incUse v;
260 :     SOME[
261 :     (h', IL.OP(Op.Kernel(kernel, k+1), [])),
262 :     (lhs, IL.OP(Op.Field dim, [v, h']))
263 :     ]
264 :     end
265 :     | SOME(Op.AddField, [f, g]) => raise Fail "Diff(f+g)"
266 :     | SOME(Op.SubField, [f, g]) => raise Fail "Diff(f-g)"
267 :     | SOME(Op.ScaleField, [s, f']) => let
268 :     (* rewrite to s*(D f) *)
269 :     val lhs' = IL.Var.copy lhs
270 :     in
271 :     ST.tick cntDiffScale;
272 :     decUse f;
273 :     SOME[
274 :     (lhs', IL.OP(Op.DiffField, [use f'])),
275 :     (lhs, IL.OP(Op.ScaleField, [use s, use lhs']))
276 :     ]
277 :     end
278 :     | SOME(Op.OffsetField, [f', s]) => (
279 :     (* rewrite to (D f) *)
280 :     ST.tick cntDiffOffset;
281 :     decUse f;
282 :     SOME[(lhs, IL.OP(Op.DiffField, [use f']))])
283 :     | SOME(Op.NegField, [f']) => let
284 :     (* rewrite to -(D f') *)
285 :     val lhs' = IL.Var.copy lhs
286 :     in
287 :     ST.tick cntDiffNeg;
288 :     decUse f;
289 :     incUse lhs'; incUse f';
290 :     SOME[
291 :     (lhs', IL.OP(Op.DiffField, [f'])),
292 :     (lhs, IL.OP(Op.NegField, [lhs']))
293 :     ]
294 :     end
295 :     | _ => NONE
296 :     (* end case *))
297 :     | (Op.CurlField dim, [f]) => (case (getRHS f)
298 :     of SOME(Op.AddField, [f, g]) => raise Fail "curl(f+g)"
299 :     | SOME(Op.SubField, [f, g]) => raise Fail "curl(f-g)"
300 :     | SOME(Op.ScaleField, [s, f']) => let
301 :     (* rewrite to s*curl(f) *)
302 :     val f'' = IL.Var.copy f'
303 : jhr 1640 in
304 : jhr 2356 ST.tick cntCurlScale;
305 :     decUse f;
306 :     SOME[
307 :     (f'', IL.OP(Op.CurlField dim, [use f'])),
308 :     (lhs, IL.OP(Op.ScaleField, [use s, use f'']))
309 :     ]
310 : jhr 1640 end
311 : jhr 2356 | SOME(Op.NegField, [f']) => let
312 :     (* rewrite to -curl(f) *)
313 :     val f'' = IL.Var.copy f'
314 :     in
315 :     ST.tick cntCurlNeg;
316 :     decUse f;
317 :     SOME[
318 :     (f'', IL.OP(Op.CurlField dim, [use f'])),
319 :     (lhs, IL.OP(Op.NegField, [use f'']))
320 :     ]
321 :     end
322 :     (* FIXME: the following is just the constant 0 field, but we don't have
323 :     * a representation of constant fields
324 :     *)
325 :     | SOME(Op.DiffField, _) => raise Fail "curl of del"
326 :     | _ => NONE
327 :     (* end case *))
328 :     | _ => NONE
329 :     (* end case *))
330 :     | doRHS _ = NONE
331 : jhr 1232
332 : jhr 2356 structure Rewrite = RewriteFn (
333 :     struct
334 :     structure IL = IL
335 :     val doAssign = doRHS
336 :     fun doMAssign _ = NONE
337 :     val elimUnusedVars = UnusedElim.reduce
338 :     end)
339 : jhr 1232
340 : jhr 2356 val transform = Rewrite.transform
341 : jhr 1232
342 :     end

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0