Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] Annotation of /branches/vis12/src/compiler/high-il/normalize.sml
ViewVC logotype

Annotation of /branches/vis12/src/compiler/high-il/normalize.sml

Parent Directory Parent Directory | Revision Log Revision Log


Revision 2182 - (view) (download)

1 : jhr 1232 (* normalize.sml
2 :     *
3 :     * COPYRIGHT (c) 2011 The Diderot Project (http://diderot-language.cs.uchicago.edu)
4 :     * All rights reserved.
5 :     *)
6 :    
7 :     structure Normalize : sig
8 :    
9 :     val transform : HighIL.program -> HighIL.program
10 :    
11 :     end = struct
12 :    
13 :     structure IL = HighIL
14 :     structure Op = HighOps
15 :     structure V = IL.Var
16 : jhr 2165 structure Ty = HighILTypes
17 : jhr 1232 structure ST = Stats
18 :    
19 :     (********** Counters for statistics **********)
20 : jhr 2173 val cntInsideScale = ST.newCounter "high-opt:inside-scale"
21 :     val cntInsideOffset = ST.newCounter "high-opt:inside-offset"
22 :     val cntInsideNeg = ST.newCounter "high-opt:inside-meg"
23 :     val cntInsideCurl = ST.newCounter "high-opt:inside-curl"
24 :     val cntInsideDiff = ST.newCounter "high-opt:inside-diff"
25 : jhr 1232 val cntProbeAdd = ST.newCounter "high-opt:probe-add"
26 :     val cntProbeSub = ST.newCounter "high-opt:probe-sub"
27 :     val cntProbeScale = ST.newCounter "high-opt:probe-scale"
28 : jhr 2173 val cntProbeOffset = ST.newCounter "high-opt:probe-offset"
29 : jhr 1232 val cntProbeNeg = ST.newCounter "high-opt:probe-neg"
30 : jhr 2173 val cntProbeCurl = ST.newCounter "high-opt:probe-curl"
31 : jhr 1232 val cntDiffField = ST.newCounter "high-opt:diff-field"
32 :     val cntDiffAdd = ST.newCounter "high-opt:diff-add"
33 :     val cntDiffScale = ST.newCounter "high-opt:diff-scale"
34 : jhr 2173 val cntDiffOffset = ST.newCounter "high-opt:diff-offset"
35 : jhr 1232 val cntDiffNeg = ST.newCounter "high-opt:diff-neg"
36 :     val cntUnused = ST.newCounter "high-opt:unused"
37 :     val firstCounter = cntProbeAdd
38 :     val lastCounter = cntUnused
39 :    
40 :     structure UnusedElim = UnusedElimFn (
41 :     structure IL = IL
42 :     val cntUnused = cntUnused)
43 :    
44 :     fun useCount (IL.V{useCnt, ...}) = !useCnt
45 :    
46 :     (* adjust a variable's use count *)
47 :     fun incUse (IL.V{useCnt, ...}) = (useCnt := !useCnt + 1)
48 :     fun decUse (IL.V{useCnt, ...}) = (useCnt := !useCnt - 1)
49 : jhr 2165 fun use x = (incUse x; x)
50 : jhr 1232
51 :     fun getRHS x = (case V.binding x
52 :     of IL.VB_RHS(IL.OP arg) => SOME arg
53 :     | IL.VB_RHS(IL.VAR x') => getRHS x'
54 :     | _ => NONE
55 :     (* end case *))
56 :    
57 : jhr 2165 (* get the binding of a kernel variable *)
58 :     fun getKernelRHS h = (case getRHS h
59 :     of SOME(Op.Kernel(kernel, k), []) => (kernel, k)
60 :     | _ => raise Fail(concat[
61 :     "bogus kernel binding ", V.toString h, " = ", IL.vbToString(V.binding h)
62 :     ])
63 :     (* end case *))
64 :    
65 : jhr 1232 (* optimize the rhs of an assignment, returning NONE if there is no change *)
66 :     fun doRHS (lhs, IL.OP rhs) = (case rhs
67 : jhr 2173 of (Op.Inside dim, [pos, f]) => (case getRHS f
68 :     of SOME(Op.Field _, _) => NONE (* direct inside test does not need rewrite *)
69 :     | SOME(Op.AddField, [f', g']) => raise Fail "inside(f+g)"
70 :     | SOME(Op.SubField, [f', g']) => raise Fail "inside(f-g)"
71 :     | SOME(Op.ScaleField, [_, f']) => (
72 :     ST.tick cntInsideScale;
73 :     decUse f;
74 :     SOME[(lhs, IL.OP(Op.Inside dim, [pos, use f']))])
75 :     | SOME(Op.OffsetField, [f', _]) => (
76 :     ST.tick cntInsideOffset;
77 :     decUse f;
78 :     SOME[(lhs, IL.OP(Op.Inside dim, [pos, use f']))])
79 :     | SOME(Op.NegField, [f']) => (
80 :     ST.tick cntInsideNeg;
81 :     decUse f;
82 :     SOME[(lhs, IL.OP(Op.Inside dim, [pos, use f']))])
83 :     | SOME(Op.CurlField _, [f']) => (
84 :     ST.tick cntInsideCurl;
85 :     decUse f;
86 :     SOME[(lhs, IL.OP(Op.Inside dim, [pos, use f']))])
87 :     | SOME(Op.DiffField, [f']) => (
88 :     ST.tick cntInsideDiff;
89 :     decUse f;
90 :     SOME[(lhs, IL.OP(Op.Inside dim, [pos, use f']))])
91 :     | _ => raise Fail(concat[
92 :     "inside: bogus field binding ", V.toString f, " = ", IL.vbToString(V.binding f)
93 :     ])
94 :     (* end case *))
95 :     | (Op.Probe(domTy, rngTy), [f, pos]) => (case getRHS f
96 : jhr 1232 of SOME(Op.Field _, _) => NONE (* direct probe does not need rewrite *)
97 :     | SOME(Op.AddField, [f', g']) => let
98 :     (* rewrite to (f@pos) + (g@pos) *)
99 :     val lhs1 = IL.Var.copy lhs
100 :     val lhs2 = IL.Var.copy lhs
101 :     in
102 :     ST.tick cntProbeAdd;
103 :     decUse f;
104 :     incUse lhs1; incUse f'; incUse lhs2; incUse g'; incUse pos;
105 :     SOME[
106 :     (lhs1, IL.OP(Op.Probe(domTy, rngTy), [f', pos])),
107 :     (lhs2, IL.OP(Op.Probe(domTy, rngTy), [g', pos])),
108 :     (lhs, IL.OP(Op.Add rngTy, [lhs1, lhs2]))
109 :     ]
110 :     end
111 :     | SOME(Op.SubField, [f', g']) => let
112 :     (* rewrite to (f@pos) - (g@pos) *)
113 :     val lhs1 = IL.Var.copy lhs
114 :     val lhs2 = IL.Var.copy lhs
115 :     in
116 :     ST.tick cntProbeSub;
117 :     decUse f;
118 :     incUse lhs1; incUse f'; incUse lhs2; incUse g'; incUse pos;
119 :     SOME[
120 :     (lhs1, IL.OP(Op.Probe(domTy, rngTy), [f', pos])),
121 :     (lhs2, IL.OP(Op.Probe(domTy, rngTy), [g', pos])),
122 :     (lhs, IL.OP(Op.Sub rngTy, [lhs1, lhs2]))
123 :     ]
124 :     end
125 :     | SOME(Op.ScaleField, [s, f']) => let
126 :     (* rewrite to s*(f'@pos) *)
127 :     val lhs' = IL.Var.copy lhs
128 : jhr 2182 val scaleOp = (case rngTy
129 :     of Ty.TensorTy[] => Op.Mul rngTy
130 :     | _ => Op.Scale rngTy
131 :     (* end case *))
132 : jhr 1232 in
133 :     ST.tick cntProbeScale;
134 :     decUse f;
135 :     SOME[
136 : jhr 2173 (lhs', IL.OP(Op.Probe(domTy, rngTy), [use f', pos])),
137 : jhr 2182 (lhs, IL.OP(scaleOp, [use s, use lhs']))
138 : jhr 1232 ]
139 :     end
140 : jhr 2173 | SOME(Op.OffsetField, [f', s]) => let
141 :     (* rewrite to (f'@pos) + s *)
142 :     val lhs' = IL.Var.copy lhs
143 :     in
144 :     ST.tick cntProbeOffset;
145 :     decUse f;
146 :     SOME[
147 :     (lhs', IL.OP(Op.Probe(domTy, rngTy), [use f', pos])),
148 :     (lhs, IL.OP(Op.Add rngTy, [use lhs', use s]))
149 :     ]
150 :     end
151 : jhr 1232 | SOME(Op.NegField, [f']) => let
152 :     (* rewrite to -(f'@pos) *)
153 :     val lhs' = IL.Var.copy lhs
154 :     in
155 :     ST.tick cntProbeNeg;
156 :     decUse f;
157 :     incUse lhs'; incUse f';
158 :     SOME[
159 :     (lhs', IL.OP(Op.Probe(domTy, rngTy), [f', pos])),
160 :     (lhs, IL.OP(Op.Neg rngTy, [lhs']))
161 :     ]
162 :     end
163 : jhr 2165 | SOME(Op.CurlField 2, [f']) => (case getRHS f'
164 :     of SOME(Op.Field dim, [v, h]) => let
165 :     (* rewrite to (D f')@pos[1,0] - (D f')@pos[0,1] *)
166 :     val (kernel, k) = getKernelRHS h
167 :     val h' = IL.Var.copy h
168 :     val f'' = IL.Var.copy f'
169 :     val mat22 = Ty.TensorTy[2,2]
170 :     val m = IL.Var.new("m", mat22)
171 :     val zero = IL.Var.new("zero", Ty.intTy)
172 :     val one = IL.Var.new("one", Ty.intTy)
173 :     val m10 = IL.Var.new("m_10", Ty.realTy)
174 :     val m01 = IL.Var.new("m_01", Ty.realTy)
175 :     in
176 : jhr 2173 ST.tick cntProbeCurl;
177 : jhr 2165 decUse f;
178 :     SOME[
179 :     (h', IL.OP(Op.Kernel(kernel, k+1), [])),
180 :     (f'', IL.OP(Op.Field dim, [use v, use h'])),
181 :     (m, IL.OP(Op.Probe(domTy, mat22), [use f'', pos])),
182 :     (zero, IL.LIT(Literal.Int 0)),
183 :     (one, IL.LIT(Literal.Int 1)),
184 :     (m10, IL.OP(Op.TensorSub mat22, [use m, use one, use zero])),
185 :     (m01, IL.OP(Op.TensorSub mat22, [use m, use zero, use one])),
186 :     (lhs, IL.OP(Op.Sub Ty.realTy, [use m10, use m01]))
187 :     ]
188 :     end
189 :     | _ => raise Fail(concat[
190 :     "bogus field binding ", V.toString f', " = ", IL.vbToString(V.binding f')
191 :     ])
192 :     (* end case *))
193 :     | SOME(Op.CurlField 3, [f']) => (case getRHS f'
194 :     of SOME(Op.Field dim, [v, h]) => let
195 :     (* rewrite to
196 :     * [ (D f')@pos[2,1] - (D f')@pos[1,2] ]
197 :     * [ (D f')@pos[0,2] - (D f')@pos[2,0] ]
198 :     * [ (D f')@pos[1,0] - (D f')@pos[0,1] ]
199 :     *)
200 :     val (kernel, k) = getKernelRHS h
201 :     val h' = IL.Var.copy h
202 :     val f'' = IL.Var.copy f'
203 :     val mat33 = Ty.TensorTy[3,3]
204 :     val m = IL.Var.new("m", mat33)
205 :     val zero = IL.Var.new("zero", Ty.intTy)
206 :     val one = IL.Var.new("one", Ty.intTy)
207 :     val two = IL.Var.new("two", Ty.intTy)
208 :     val m21 = IL.Var.new("m_21", Ty.realTy)
209 :     val m12 = IL.Var.new("m_12", Ty.realTy)
210 :     val m02 = IL.Var.new("m_02", Ty.realTy)
211 :     val m20 = IL.Var.new("m_20", Ty.realTy)
212 :     val m10 = IL.Var.new("m_10", Ty.realTy)
213 :     val m01 = IL.Var.new("m_01", Ty.realTy)
214 :     val lhs0 = IL.Var.new("lhs0", Ty.realTy)
215 :     val lhs1 = IL.Var.new("lhs1", Ty.realTy)
216 :     val lhs2 = IL.Var.new("lhs2", Ty.realTy)
217 :     in
218 : jhr 2173 ST.tick cntProbeCurl;
219 : jhr 2165 decUse f;
220 :     SOME[
221 :     (h', IL.OP(Op.Kernel(kernel, k+1), [])),
222 :     (f'', IL.OP(Op.Field dim, [use v, use h'])),
223 :     (m, IL.OP(Op.Probe(domTy, mat33), [use f'', pos])),
224 :     (zero, IL.LIT(Literal.Int 0)),
225 :     (one, IL.LIT(Literal.Int 1)),
226 :     (two, IL.LIT(Literal.Int 2)),
227 :     (m21, IL.OP(Op.TensorSub mat33, [use m, use two, use one])),
228 :     (m12, IL.OP(Op.TensorSub mat33, [use m, use one, use two])),
229 :     (lhs0, IL.OP(Op.Sub Ty.realTy, [use m21, use m12])),
230 :     (m02, IL.OP(Op.TensorSub mat33, [use m, use zero, use two])),
231 :     (m20, IL.OP(Op.TensorSub mat33, [use m, use two, use zero])),
232 :     (lhs1, IL.OP(Op.Sub Ty.realTy, [use m02, use m20])),
233 :     (m10, IL.OP(Op.TensorSub mat33, [use m, use one, use zero])),
234 :     (m01, IL.OP(Op.TensorSub mat33, [use m, use zero, use one])),
235 :     (lhs2, IL.OP(Op.Sub Ty.realTy, [use m10, use m01])),
236 :     (lhs, IL.CONS(Ty.TensorTy[3], [lhs0, lhs1, lhs2]))
237 :     ]
238 :     end
239 :     | _ => raise Fail(concat[
240 : jhr 2173 "curl: bogus field binding ", V.toString f', " = ", IL.vbToString(V.binding f')
241 : jhr 2165 ])
242 :     (* end case *))
243 : jhr 2173 | SOME(Op.DiffField, _) => NONE (* need further rewriting *)
244 : jhr 1232 | _ => raise Fail(concat[
245 : jhr 2173 "probe: bogus field binding ", V.toString f, " = ", IL.vbToString(V.binding f)
246 : jhr 1232 ])
247 :     (* end case *))
248 :     | (Op.DiffField, [f]) => (case (getRHS f)
249 : jhr 2165 of SOME(Op.Field dim, [v, h]) => let
250 :     val (kernel, k) = getKernelRHS h
251 :     val h' = IL.Var.copy h
252 :     in
253 :     ST.tick cntDiffField;
254 :     decUse f;
255 :     incUse h'; incUse v;
256 :     SOME[
257 :     (h', IL.OP(Op.Kernel(kernel, k+1), [])),
258 :     (lhs, IL.OP(Op.Field dim, [v, h']))
259 :     ]
260 :     end
261 : jhr 1232 | SOME(Op.AddField, [f, g]) => raise Fail "Diff(f+g)"
262 :     | SOME(Op.SubField, [f, g]) => raise Fail "Diff(f-g)"
263 :     | SOME(Op.ScaleField, [s, f']) => let
264 :     (* rewrite to s*(D f) *)
265 :     val lhs' = IL.Var.copy lhs
266 :     in
267 :     ST.tick cntDiffScale;
268 :     decUse f;
269 :     incUse lhs'; incUse f'; incUse s;
270 :     SOME[
271 :     (lhs', IL.OP(Op.DiffField, [f'])),
272 :     (lhs, IL.OP(Op.ScaleField, [s, lhs']))
273 :     ]
274 :     end
275 : jhr 2173 | SOME(Op.OffsetField, [f', s]) => (
276 :     (* rewrite to (D f) *)
277 :     ST.tick cntDiffOffset;
278 :     decUse f;
279 :     SOME[(lhs, IL.OP(Op.DiffField, [use f']))])
280 : jhr 1232 | SOME(Op.NegField, [f']) => let
281 :     (* rewrite to -(D f') *)
282 :     val lhs' = IL.Var.copy lhs
283 :     in
284 :     ST.tick cntDiffNeg;
285 :     decUse f;
286 :     incUse lhs'; incUse f';
287 :     SOME[
288 :     (lhs', IL.OP(Op.DiffField, [f'])),
289 :     (lhs, IL.OP(Op.NegField, [lhs']))
290 :     ]
291 :     end
292 :     | _ => NONE
293 :     (* end case *))
294 : jhr 2165 | (Op.CurlField _, [f]) => (case (getRHS f)
295 :     (* FIXME: the following is just the constant 0 field, but we don't have
296 :     * a representation of constant fields
297 :     *)
298 :     of SOME(Op.AddField, [f, g]) => raise Fail "curl(f+g)"
299 :     | SOME(Op.SubField, [f, g]) => raise Fail "curl(f-g)"
300 :     | SOME(Op.ScaleField, [s, f']) => raise Fail "curl(s*f)"
301 :     | SOME(Op.NegField, [f']) => raise Fail "curl(-f)"
302 :     | SOME(Op.DiffField, _) => raise Fail "curl of del"
303 :     | _ => NONE
304 :     (* end case *))
305 : jhr 1232 | _ => NONE
306 :     (* end case *))
307 :     | doRHS _ = NONE
308 :    
309 :     (* simplify expressions *)
310 :     fun simplify (nd as IL.ND{kind=IL.ASSIGN{stm=(y, rhs), ...}, ...}) =
311 :     if (useCount y = 0)
312 :     then () (* skip unused assignments *)
313 :     else (case doRHS(y, rhs)
314 :     of SOME[] => IL.CFG.deleteNode nd
315 : jhr 1640 | SOME assigns => let
316 :     val assigns = List.map
317 :     (fn (y, rhs) => (V.setBinding(y, IL.VB_RHS rhs); IL.ASSGN(y, rhs)))
318 :     assigns
319 :     in
320 :     IL.CFG.replaceNodeWithCFG (nd, IL.CFG.mkBlock assigns)
321 :     end
322 : jhr 1232 | NONE => ()
323 :     (* end case *))
324 :     | simplify _ = ()
325 :    
326 :     fun loopToFixPt f = let
327 :     fun loop n = let
328 :     val () = f ()
329 :     val n' = Stats.sum{from=firstCounter, to=lastCounter}
330 :     in
331 :     if (n = n') then () else loop n'
332 :     end
333 :     in
334 :     loop (Stats.sum{from=firstCounter, to=lastCounter})
335 :     end
336 :    
337 : jhr 1640 fun transform (prog as IL.Program{props, globalInit, initially, strands}) = let
338 : jhr 1232 fun doCFG cfg = (
339 :     loopToFixPt (fn () => IL.CFG.apply simplify cfg);
340 :     loopToFixPt (fn () => ignore(UnusedElim.reduce cfg)))
341 :     fun doMethod (IL.Method{body, ...}) = doCFG body
342 :     fun doStrand (IL.Strand{stateInit, methods, ...}) = (
343 :     doCFG stateInit;
344 :     List.app doMethod methods)
345 :     fun optPass () = (
346 :     doCFG globalInit;
347 :     List.app doStrand strands)
348 :     in
349 :     loopToFixPt optPass;
350 :     (* FIXME: after optimization, we should filter out any globals that are now unused *)
351 :     IL.Program{
352 : jhr 1640 props = props,
353 : jhr 1232 globalInit = globalInit,
354 :     initially = initially, (* FIXME: we should optimize this code *)
355 :     strands = strands
356 :     }
357 :     end
358 :    
359 :     end

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0