Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] Annotation of /branches/staging/src/compiler/high-il/normalize.sml
ViewVC logotype

Annotation of /branches/staging/src/compiler/high-il/normalize.sml

Parent Directory Parent Directory | Revision Log Revision Log


Revision 2205 - (view) (download)

1 : jhr 1232 (* normalize.sml
2 :     *
3 :     * COPYRIGHT (c) 2011 The Diderot Project (http://diderot-language.cs.uchicago.edu)
4 :     * All rights reserved.
5 :     *)
6 :    
7 :     structure Normalize : sig
8 :    
9 :     val transform : HighIL.program -> HighIL.program
10 :    
11 :     end = struct
12 :    
13 :     structure IL = HighIL
14 :     structure Op = HighOps
15 :     structure V = IL.Var
16 : jhr 2205 structure Ty = HighILTypes
17 : jhr 1232 structure ST = Stats
18 :    
19 :     (********** Counters for statistics **********)
20 : jhr 2205 val cntInsideScale = ST.newCounter "high-opt:inside-scale"
21 :     val cntInsideOffset = ST.newCounter "high-opt:inside-offset"
22 :     val cntInsideNeg = ST.newCounter "high-opt:inside-meg"
23 :     val cntInsideCurl = ST.newCounter "high-opt:inside-curl"
24 :     val cntInsideDiff = ST.newCounter "high-opt:inside-diff"
25 :     val cntProbeAdd = ST.newCounter "high-opt:probe-add"
26 :     val cntProbeSub = ST.newCounter "high-opt:probe-sub"
27 :     val cntProbeScale = ST.newCounter "high-opt:probe-scale"
28 :     val cntProbeOffset = ST.newCounter "high-opt:probe-offset"
29 :     val cntProbeNeg = ST.newCounter "high-opt:probe-neg"
30 :     val cntProbeCurl = ST.newCounter "high-opt:probe-curl"
31 :     val cntDiffField = ST.newCounter "high-opt:diff-field"
32 :     val cntDiffAdd = ST.newCounter "high-opt:diff-add"
33 :     val cntDiffScale = ST.newCounter "high-opt:diff-scale"
34 :     val cntDiffOffset = ST.newCounter "high-opt:diff-offset"
35 :     val cntDiffNeg = ST.newCounter "high-opt:diff-neg"
36 :     val cntCurlScale = ST.newCounter "high-opt:curl-scale"
37 :     val cntCurlNeg = ST.newCounter "high-opt:curl-neg"
38 :     val cntUnused = ST.newCounter "high-opt:unused"
39 :     val firstCounter = cntInsideScale
40 : jhr 1232 val lastCounter = cntUnused
41 :    
42 :     structure UnusedElim = UnusedElimFn (
43 : jhr 2205 structure IL = IL
44 :     val cntUnused = cntUnused)
45 : jhr 1232
46 :     fun useCount (IL.V{useCnt, ...}) = !useCnt
47 :    
48 :     (* adjust a variable's use count *)
49 :     fun incUse (IL.V{useCnt, ...}) = (useCnt := !useCnt + 1)
50 :     fun decUse (IL.V{useCnt, ...}) = (useCnt := !useCnt - 1)
51 : jhr 2205 fun use x = (incUse x; x)
52 : jhr 1232
53 :     fun getRHS x = (case V.binding x
54 : jhr 2205 of IL.VB_RHS(IL.OP arg) => SOME arg
55 :     | IL.VB_RHS(IL.VAR x') => getRHS x'
56 :     | _ => NONE
57 :     (* end case *))
58 : jhr 1232
59 : jhr 2205 (* get the binding of a kernel variable *)
60 :     fun getKernelRHS h = (case getRHS h
61 :     of SOME(Op.Kernel(kernel, k), []) => (kernel, k)
62 :     | _ => raise Fail(concat[
63 :     "bogus kernel binding ", V.toString h, " = ", IL.vbToString(V.binding h)
64 :     ])
65 :     (* end case *))
66 :    
67 : jhr 1232 (* optimize the rhs of an assignment, returning NONE if there is no change *)
68 :     fun doRHS (lhs, IL.OP rhs) = (case rhs
69 : jhr 2205 of (Op.Inside dim, [pos, f]) => (case getRHS f
70 :     of SOME(Op.Field _, _) => NONE (* direct inside test does not need rewrite *)
71 :     | SOME(Op.AddField, [f', g']) => raise Fail "inside(f+g)"
72 :     | SOME(Op.SubField, [f', g']) => raise Fail "inside(f-g)"
73 :     | SOME(Op.ScaleField, [_, f']) => (
74 :     ST.tick cntInsideScale;
75 :     decUse f;
76 :     SOME[(lhs, IL.OP(Op.Inside dim, [pos, use f']))])
77 :     | SOME(Op.OffsetField, [f', _]) => (
78 :     ST.tick cntInsideOffset;
79 :     decUse f;
80 :     SOME[(lhs, IL.OP(Op.Inside dim, [pos, use f']))])
81 :     | SOME(Op.NegField, [f']) => (
82 :     ST.tick cntInsideNeg;
83 :     decUse f;
84 :     SOME[(lhs, IL.OP(Op.Inside dim, [pos, use f']))])
85 :     | SOME(Op.CurlField _, [f']) => (
86 :     ST.tick cntInsideCurl;
87 :     decUse f;
88 :     SOME[(lhs, IL.OP(Op.Inside dim, [pos, use f']))])
89 :     | SOME(Op.DiffField, [f']) => (
90 :     ST.tick cntInsideDiff;
91 :     decUse f;
92 :     SOME[(lhs, IL.OP(Op.Inside dim, [pos, use f']))])
93 :     | _ => raise Fail(concat[
94 :     "inside: bogus field binding ", V.toString f, " = ", IL.vbToString(V.binding f)
95 :     ])
96 :     (* end case *))
97 :     | (Op.Probe(domTy, rngTy), [f, pos]) => (case getRHS f
98 :     of SOME(Op.Field _, _) => NONE (* direct probe does not need rewrite *)
99 :     | SOME(Op.AddField, [f', g']) => let
100 :     (* rewrite to (f@pos) + (g@pos) *)
101 :     val lhs1 = IL.Var.copy lhs
102 :     val lhs2 = IL.Var.copy lhs
103 :     in
104 :     ST.tick cntProbeAdd;
105 :     decUse f;
106 :     incUse lhs1; incUse f'; incUse lhs2; incUse g'; incUse pos;
107 :     SOME[
108 :     (lhs1, IL.OP(Op.Probe(domTy, rngTy), [f', pos])),
109 :     (lhs2, IL.OP(Op.Probe(domTy, rngTy), [g', pos])),
110 :     (lhs, IL.OP(Op.Add rngTy, [lhs1, lhs2]))
111 :     ]
112 :     end
113 :     | SOME(Op.SubField, [f', g']) => let
114 :     (* rewrite to (f@pos) - (g@pos) *)
115 :     val lhs1 = IL.Var.copy lhs
116 :     val lhs2 = IL.Var.copy lhs
117 :     in
118 :     ST.tick cntProbeSub;
119 :     decUse f;
120 :     incUse lhs1; incUse f'; incUse lhs2; incUse g'; incUse pos;
121 :     SOME[
122 :     (lhs1, IL.OP(Op.Probe(domTy, rngTy), [f', pos])),
123 :     (lhs2, IL.OP(Op.Probe(domTy, rngTy), [g', pos])),
124 :     (lhs, IL.OP(Op.Sub rngTy, [lhs1, lhs2]))
125 :     ]
126 :     end
127 :     | SOME(Op.ScaleField, [s, f']) => let
128 :     (* rewrite to s*(f'@pos) *)
129 :     val lhs' = IL.Var.copy lhs
130 :     val scaleOp = (case rngTy
131 :     of Ty.TensorTy[] => Op.Mul rngTy
132 :     | _ => Op.Scale rngTy
133 :     (* end case *))
134 :     in
135 :     ST.tick cntProbeScale;
136 :     decUse f;
137 :     SOME[
138 :     (lhs', IL.OP(Op.Probe(domTy, rngTy), [use f', pos])),
139 :     (lhs, IL.OP(scaleOp, [use s, use lhs']))
140 :     ]
141 :     end
142 :     | SOME(Op.OffsetField, [f', s]) => let
143 :     (* rewrite to (f'@pos) + s *)
144 :     val lhs' = IL.Var.copy lhs
145 :     in
146 :     ST.tick cntProbeOffset;
147 :     decUse f;
148 :     SOME[
149 :     (lhs', IL.OP(Op.Probe(domTy, rngTy), [use f', pos])),
150 :     (lhs, IL.OP(Op.Add rngTy, [use lhs', use s]))
151 :     ]
152 :     end
153 :     | SOME(Op.NegField, [f']) => let
154 :     (* rewrite to -(f'@pos) *)
155 :     val lhs' = IL.Var.copy lhs
156 :     in
157 :     ST.tick cntProbeNeg;
158 :     decUse f;
159 :     incUse lhs'; incUse f';
160 :     SOME[
161 :     (lhs', IL.OP(Op.Probe(domTy, rngTy), [f', pos])),
162 :     (lhs, IL.OP(Op.Neg rngTy, [lhs']))
163 :     ]
164 :     end
165 :     | SOME(Op.CurlField 2, [f']) => (case getRHS f'
166 :     of SOME(Op.Field dim, [v, h]) => let
167 :     (* rewrite to (D f')@pos[1,0] - (D f')@pos[0,1] *)
168 :     val (kernel, k) = getKernelRHS h
169 :     val h' = IL.Var.copy h
170 :     val f'' = IL.Var.copy f'
171 :     val mat22 = Ty.TensorTy[2,2]
172 :     val m = IL.Var.new("m", mat22)
173 :     val zero = IL.Var.new("zero", Ty.intTy)
174 :     val one = IL.Var.new("one", Ty.intTy)
175 :     val m10 = IL.Var.new("m_10", Ty.realTy)
176 :     val m01 = IL.Var.new("m_01", Ty.realTy)
177 :     in
178 :     ST.tick cntProbeCurl;
179 :     decUse f;
180 :     SOME[
181 :     (h', IL.OP(Op.Kernel(kernel, k+1), [])),
182 :     (f'', IL.OP(Op.Field dim, [use v, use h'])),
183 :     (m, IL.OP(Op.Probe(domTy, mat22), [use f'', pos])),
184 :     (zero, IL.LIT(Literal.Int 0)),
185 :     (one, IL.LIT(Literal.Int 1)),
186 :     (m10, IL.OP(Op.TensorSub mat22, [use m, use one, use zero])),
187 :     (m01, IL.OP(Op.TensorSub mat22, [use m, use zero, use one])),
188 :     (lhs, IL.OP(Op.Sub Ty.realTy, [use m10, use m01]))
189 :     ]
190 :     end
191 :     | _ => raise Fail(concat[
192 :     "bogus field binding ", V.toString f', " = ", IL.vbToString(V.binding f')
193 :     ])
194 :     (* end case *))
195 :     | SOME(Op.CurlField 3, [f']) => (case getRHS f'
196 :     of SOME(Op.Field dim, [v, h]) => let
197 :     (* rewrite to
198 :     * [ (D f')@pos[2,1] - (D f')@pos[1,2] ]
199 :     * [ (D f')@pos[0,2] - (D f')@pos[2,0] ]
200 :     * [ (D f')@pos[1,0] - (D f')@pos[0,1] ]
201 :     *)
202 :     val (kernel, k) = getKernelRHS h
203 :     val h' = IL.Var.copy h
204 :     val f'' = IL.Var.copy f'
205 :     val mat33 = Ty.TensorTy[3,3]
206 :     val m = IL.Var.new("m", mat33)
207 :     val zero = IL.Var.new("zero", Ty.intTy)
208 :     val one = IL.Var.new("one", Ty.intTy)
209 :     val two = IL.Var.new("two", Ty.intTy)
210 :     val m21 = IL.Var.new("m_21", Ty.realTy)
211 :     val m12 = IL.Var.new("m_12", Ty.realTy)
212 :     val m02 = IL.Var.new("m_02", Ty.realTy)
213 :     val m20 = IL.Var.new("m_20", Ty.realTy)
214 :     val m10 = IL.Var.new("m_10", Ty.realTy)
215 :     val m01 = IL.Var.new("m_01", Ty.realTy)
216 :     val lhs0 = IL.Var.new("lhs0", Ty.realTy)
217 :     val lhs1 = IL.Var.new("lhs1", Ty.realTy)
218 :     val lhs2 = IL.Var.new("lhs2", Ty.realTy)
219 :     in
220 :     ST.tick cntProbeCurl;
221 :     decUse f;
222 :     SOME[
223 :     (h', IL.OP(Op.Kernel(kernel, k+1), [])),
224 :     (f'', IL.OP(Op.Field dim, [use v, use h'])),
225 :     (m, IL.OP(Op.Probe(domTy, mat33), [use f'', pos])),
226 :     (zero, IL.LIT(Literal.Int 0)),
227 :     (one, IL.LIT(Literal.Int 1)),
228 :     (two, IL.LIT(Literal.Int 2)),
229 :     (m21, IL.OP(Op.TensorSub mat33, [use m, use two, use one])),
230 :     (m12, IL.OP(Op.TensorSub mat33, [use m, use one, use two])),
231 :     (lhs0, IL.OP(Op.Sub Ty.realTy, [use m21, use m12])),
232 :     (m02, IL.OP(Op.TensorSub mat33, [use m, use zero, use two])),
233 :     (m20, IL.OP(Op.TensorSub mat33, [use m, use two, use zero])),
234 :     (lhs1, IL.OP(Op.Sub Ty.realTy, [use m02, use m20])),
235 :     (m10, IL.OP(Op.TensorSub mat33, [use m, use one, use zero])),
236 :     (m01, IL.OP(Op.TensorSub mat33, [use m, use zero, use one])),
237 :     (lhs2, IL.OP(Op.Sub Ty.realTy, [use m10, use m01])),
238 :     (lhs, IL.CONS(Ty.TensorTy[3], [lhs0, lhs1, lhs2]))
239 :     ]
240 :     end
241 :     | _ => raise Fail(concat[
242 :     "curl: bogus field binding ", V.toString f', " = ", IL.vbToString(V.binding f')
243 :     ])
244 :     (* end case *))
245 :     | SOME(Op.DiffField, _) => NONE (* need further rewriting *)
246 :     | _ => raise Fail(concat[
247 :     "probe: bogus field binding ", V.toString f, " = ", IL.vbToString(V.binding f)
248 :     ])
249 :     (* end case *))
250 :     | (Op.DiffField, [f]) => (case (getRHS f)
251 :     of SOME(Op.Field dim, [v, h]) => let
252 :     val (kernel, k) = getKernelRHS h
253 :     val h' = IL.Var.copy h
254 :     in
255 :     ST.tick cntDiffField;
256 :     decUse f;
257 :     incUse h'; incUse v;
258 :     SOME[
259 :     (h', IL.OP(Op.Kernel(kernel, k+1), [])),
260 :     (lhs, IL.OP(Op.Field dim, [v, h']))
261 :     ]
262 :     end
263 :     | SOME(Op.AddField, [f, g]) => raise Fail "Diff(f+g)"
264 :     | SOME(Op.SubField, [f, g]) => raise Fail "Diff(f-g)"
265 :     | SOME(Op.ScaleField, [s, f']) => let
266 :     (* rewrite to s*(D f) *)
267 :     val lhs' = IL.Var.copy lhs
268 :     in
269 :     ST.tick cntDiffScale;
270 :     decUse f;
271 :     SOME[
272 :     (lhs', IL.OP(Op.DiffField, [use f'])),
273 :     (lhs, IL.OP(Op.ScaleField, [use s, use lhs']))
274 :     ]
275 :     end
276 :     | SOME(Op.OffsetField, [f', s]) => (
277 :     (* rewrite to (D f) *)
278 :     ST.tick cntDiffOffset;
279 :     decUse f;
280 :     SOME[(lhs, IL.OP(Op.DiffField, [use f']))])
281 :     | SOME(Op.NegField, [f']) => let
282 :     (* rewrite to -(D f') *)
283 :     val lhs' = IL.Var.copy lhs
284 :     in
285 :     ST.tick cntDiffNeg;
286 :     decUse f;
287 :     incUse lhs'; incUse f';
288 :     SOME[
289 :     (lhs', IL.OP(Op.DiffField, [f'])),
290 :     (lhs, IL.OP(Op.NegField, [lhs']))
291 :     ]
292 :     end
293 :     | _ => NONE
294 :     (* end case *))
295 :     | (Op.CurlField dim, [f]) => (case (getRHS f)
296 :     of SOME(Op.AddField, [f, g]) => raise Fail "curl(f+g)"
297 :     | SOME(Op.SubField, [f, g]) => raise Fail "curl(f-g)"
298 :     | SOME(Op.ScaleField, [s, f']) => let
299 :     (* rewrite to s*curl(f) *)
300 :     val f'' = IL.Var.copy f'
301 : jhr 1640 in
302 : jhr 2205 ST.tick cntCurlScale;
303 :     decUse f;
304 :     SOME[
305 :     (f'', IL.OP(Op.CurlField dim, [use f'])),
306 :     (lhs, IL.OP(Op.ScaleField, [use s, use f'']))
307 :     ]
308 : jhr 1640 end
309 : jhr 2205 | SOME(Op.NegField, [f']) => let
310 :     (* rewrite to -curl(f) *)
311 :     val f'' = IL.Var.copy f'
312 :     in
313 :     ST.tick cntCurlNeg;
314 :     decUse f;
315 :     SOME[
316 :     (f'', IL.OP(Op.CurlField dim, [use f'])),
317 :     (lhs, IL.OP(Op.NegField, [use f'']))
318 :     ]
319 :     end
320 :     (* FIXME: the following is just the constant 0 field, but we don't have
321 :     * a representation of constant fields
322 :     *)
323 :     | SOME(Op.DiffField, _) => raise Fail "curl of del"
324 :     | _ => NONE
325 :     (* end case *))
326 :     | _ => NONE
327 :     (* end case *))
328 :     | doRHS _ = NONE
329 : jhr 1232
330 : jhr 2205 structure Rewrite = RewriteFn (
331 :     struct
332 :     structure IL = IL
333 :     val doAssign = doRHS
334 :     fun doMAssign _ = NONE
335 :     val elimUnusedVars = UnusedElim.reduce
336 :     end)
337 : jhr 1232
338 : jhr 2205 val transform = Rewrite.transform
339 : jhr 1232
340 :     end

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0