Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] Diff of /branches/staging/src/compiler/high-il/normalize.sml
ViewVC logotype

Diff of /branches/staging/src/compiler/high-il/normalize.sml

Parent Directory Parent Directory | Revision Log Revision Log | View Patch Patch

revision 2204, Mon Feb 25 14:11:34 2013 UTC revision 2205, Mon Feb 25 14:38:35 2013 UTC
# Line 13  Line 13 
13      structure IL = HighIL      structure IL = HighIL
14      structure Op = HighOps      structure Op = HighOps
15      structure V = IL.Var      structure V = IL.Var
16        structure Ty = HighILTypes
17      structure ST = Stats      structure ST = Stats
18    
19    (********** Counters for statistics **********)    (********** Counters for statistics **********)
20        val cntInsideScale          = ST.newCounter "high-opt:inside-scale"
21        val cntInsideOffset         = ST.newCounter "high-opt:inside-offset"
22        val cntInsideNeg            = ST.newCounter "high-opt:inside-meg"
23        val cntInsideCurl           = ST.newCounter "high-opt:inside-curl"
24        val cntInsideDiff           = ST.newCounter "high-opt:inside-diff"
25      val cntProbeAdd             = ST.newCounter "high-opt:probe-add"      val cntProbeAdd             = ST.newCounter "high-opt:probe-add"
26      val cntProbeSub             = ST.newCounter "high-opt:probe-sub"      val cntProbeSub             = ST.newCounter "high-opt:probe-sub"
27      val cntProbeScale           = ST.newCounter "high-opt:probe-scale"      val cntProbeScale           = ST.newCounter "high-opt:probe-scale"
28        val cntProbeOffset          = ST.newCounter "high-opt:probe-offset"
29      val cntProbeNeg             = ST.newCounter "high-opt:probe-neg"      val cntProbeNeg             = ST.newCounter "high-opt:probe-neg"
30        val cntProbeCurl            = ST.newCounter "high-opt:probe-curl"
31      val cntDiffField            = ST.newCounter "high-opt:diff-field"      val cntDiffField            = ST.newCounter "high-opt:diff-field"
32      val cntDiffAdd              = ST.newCounter "high-opt:diff-add"      val cntDiffAdd              = ST.newCounter "high-opt:diff-add"
33      val cntDiffScale            = ST.newCounter "high-opt:diff-scale"      val cntDiffScale            = ST.newCounter "high-opt:diff-scale"
34        val cntDiffOffset           = ST.newCounter "high-opt:diff-offset"
35      val cntDiffNeg              = ST.newCounter "high-opt:diff-neg"      val cntDiffNeg              = ST.newCounter "high-opt:diff-neg"
36        val cntCurlScale            = ST.newCounter "high-opt:curl-scale"
37        val cntCurlNeg              = ST.newCounter "high-opt:curl-neg"
38      val cntUnused               = ST.newCounter "high-opt:unused"      val cntUnused               = ST.newCounter "high-opt:unused"
39      val firstCounter            = cntProbeAdd      val firstCounter            = cntInsideScale
40      val lastCounter             = cntUnused      val lastCounter             = cntUnused
41    
42      structure UnusedElim = UnusedElimFn (      structure UnusedElim = UnusedElimFn (
# Line 37  Line 48 
48    (* adjust a variable's use count *)    (* adjust a variable's use count *)
49      fun incUse (IL.V{useCnt, ...}) = (useCnt := !useCnt + 1)      fun incUse (IL.V{useCnt, ...}) = (useCnt := !useCnt + 1)
50      fun decUse (IL.V{useCnt, ...}) = (useCnt := !useCnt - 1)      fun decUse (IL.V{useCnt, ...}) = (useCnt := !useCnt - 1)
51        fun use x = (incUse x; x)
52    
53      fun getRHS x = (case V.binding x      fun getRHS x = (case V.binding x
54             of IL.VB_RHS(IL.OP arg) => SOME arg             of IL.VB_RHS(IL.OP arg) => SOME arg
# Line 44  Line 56 
56              | _ => NONE              | _ => NONE
57            (* end case *))            (* end case *))
58    
59      (* get the binding of a kernel variable *)
60        fun getKernelRHS h = (case getRHS h
61               of SOME(Op.Kernel(kernel, k), []) => (kernel, k)
62                | _ => raise Fail(concat[
63                      "bogus kernel binding ", V.toString h, " = ", IL.vbToString(V.binding h)
64                    ])
65              (* end case *))
66    
67    (* optimize the rhs of an assignment, returning NONE if there is no change *)    (* optimize the rhs of an assignment, returning NONE if there is no change *)
68      fun doRHS (lhs, IL.OP rhs) = (case rhs      fun doRHS (lhs, IL.OP rhs) = (case rhs
69             of (Op.Probe(domTy, rngTy), [f, pos]) => (case getRHS f             of (Op.Inside dim, [pos, f]) => (case getRHS f
70                     of SOME(Op.Field _, _) => NONE (* direct inside test does not need rewrite *)
71                      | SOME(Op.AddField, [f', g']) => raise Fail "inside(f+g)"
72                      | SOME(Op.SubField, [f', g']) => raise Fail "inside(f-g)"
73                      | SOME(Op.ScaleField, [_, f']) => (
74                          ST.tick cntInsideScale;
75                          decUse f;
76                          SOME[(lhs, IL.OP(Op.Inside dim, [pos, use f']))])
77                      | SOME(Op.OffsetField, [f', _]) => (
78                          ST.tick cntInsideOffset;
79                          decUse f;
80                          SOME[(lhs, IL.OP(Op.Inside dim, [pos, use f']))])
81                      | SOME(Op.NegField, [f']) => (
82                          ST.tick cntInsideNeg;
83                          decUse f;
84                          SOME[(lhs, IL.OP(Op.Inside dim, [pos, use f']))])
85                      | SOME(Op.CurlField _, [f']) => (
86                          ST.tick cntInsideCurl;
87                          decUse f;
88                          SOME[(lhs, IL.OP(Op.Inside dim, [pos, use f']))])
89                      | SOME(Op.DiffField, [f']) => (
90                          ST.tick cntInsideDiff;
91                          decUse f;
92                          SOME[(lhs, IL.OP(Op.Inside dim, [pos, use f']))])
93                      | _ => raise Fail(concat[
94                            "inside: bogus field binding ", V.toString f, " = ", IL.vbToString(V.binding f)
95                          ])
96                    (* end case *))
97                | (Op.Probe(domTy, rngTy), [f, pos]) => (case getRHS f
98                   of SOME(Op.Field _, _) => NONE (* direct probe does not need rewrite *)                   of SOME(Op.Field _, _) => NONE (* direct probe does not need rewrite *)
99                    | SOME(Op.AddField, [f', g']) => let                    | SOME(Op.AddField, [f', g']) => let
100                      (* rewrite to (f@pos) + (g@pos) *)                      (* rewrite to (f@pos) + (g@pos) *)
# Line 79  Line 127 
127                    | SOME(Op.ScaleField, [s, f']) => let                    | SOME(Op.ScaleField, [s, f']) => let
128                      (* rewrite to s*(f'@pos) *)                      (* rewrite to s*(f'@pos) *)
129                        val lhs' = IL.Var.copy lhs                        val lhs' = IL.Var.copy lhs
130                          val scaleOp = (case rngTy
131                                 of Ty.TensorTy[] => Op.Mul rngTy
132                                  | _ => Op.Scale rngTy
133                                (* end case *))
134                        in                        in
135                          ST.tick cntProbeScale;                          ST.tick cntProbeScale;
136                          decUse f;                          decUse f;
                         incUse lhs'; incUse f'; incUse s;  
137                          SOME[                          SOME[
138                              (lhs', IL.OP(Op.Probe(domTy, rngTy), [f', pos])),                              (lhs', IL.OP(Op.Probe(domTy, rngTy), [use f', pos])),
139                              (lhs, IL.OP(Op.Scale rngTy, [s, lhs']))                              (lhs, IL.OP(scaleOp, [use s, use lhs']))
140                              ]
141                          end
142                      | SOME(Op.OffsetField, [f', s]) => let
143                        (* rewrite to (f'@pos) + s *)
144                          val lhs' = IL.Var.copy lhs
145                          in
146                            ST.tick cntProbeOffset;
147                            decUse f;
148                            SOME[
149                                (lhs', IL.OP(Op.Probe(domTy, rngTy), [use f', pos])),
150                                (lhs, IL.OP(Op.Add rngTy, [use lhs', use s]))
151                            ]                            ]
152                        end                        end
153                    | SOME(Op.NegField, [f']) => let                    | SOME(Op.NegField, [f']) => let
# Line 100  Line 162 
162                              (lhs, IL.OP(Op.Neg rngTy, [lhs']))                              (lhs, IL.OP(Op.Neg rngTy, [lhs']))
163                            ]                            ]
164                        end                        end
165                      | SOME(Op.CurlField 2, [f']) => (case getRHS f'
166                           of SOME(Op.Field dim, [v, h]) => let
167                              (* rewrite to (D f')@pos[1,0] - (D f')@pos[0,1] *)
168                                val (kernel, k) = getKernelRHS h
169                                val h' = IL.Var.copy h
170                                val f'' = IL.Var.copy f'
171                                val mat22 = Ty.TensorTy[2,2]
172                                val m = IL.Var.new("m", mat22)
173                                val zero = IL.Var.new("zero", Ty.intTy)
174                                val one = IL.Var.new("one", Ty.intTy)
175                                val m10 = IL.Var.new("m_10", Ty.realTy)
176                                val m01 = IL.Var.new("m_01", Ty.realTy)
177                                in
178                                  ST.tick cntProbeCurl;
179                                  decUse f;
180                                  SOME[
181                                      (h', IL.OP(Op.Kernel(kernel, k+1), [])),
182                                      (f'', IL.OP(Op.Field dim, [use v, use h'])),
183                                      (m, IL.OP(Op.Probe(domTy, mat22), [use f'', pos])),
184                                      (zero, IL.LIT(Literal.Int 0)),
185                                      (one, IL.LIT(Literal.Int 1)),
186                                      (m10, IL.OP(Op.TensorSub mat22, [use m, use one, use zero])),
187                                      (m01, IL.OP(Op.TensorSub mat22, [use m, use zero, use one])),
188                                      (lhs, IL.OP(Op.Sub Ty.realTy, [use m10, use m01]))
189                                    ]
190                                end
191                            | _ => raise Fail(concat[
192                                  "bogus field binding ", V.toString f', " = ", IL.vbToString(V.binding f')
193                                ])
194                          (* end case *))
195                      | SOME(Op.CurlField 3, [f']) => (case getRHS f'
196                           of SOME(Op.Field dim, [v, h]) => let
197                              (* rewrite to
198                               *  [ (D f')@pos[2,1] - (D f')@pos[1,2] ]
199                               *  [ (D f')@pos[0,2] - (D f')@pos[2,0] ]
200                               *  [ (D f')@pos[1,0] - (D f')@pos[0,1] ]
201                               *)
202                                val (kernel, k) = getKernelRHS h
203                                val h' = IL.Var.copy h
204                                val f'' = IL.Var.copy f'
205                                val mat33 = Ty.TensorTy[3,3]
206                                val m = IL.Var.new("m", mat33)
207                                val zero = IL.Var.new("zero", Ty.intTy)
208                                val one = IL.Var.new("one", Ty.intTy)
209                                val two = IL.Var.new("two", Ty.intTy)
210                                val m21 = IL.Var.new("m_21", Ty.realTy)
211                                val m12 = IL.Var.new("m_12", Ty.realTy)
212                                val m02 = IL.Var.new("m_02", Ty.realTy)
213                                val m20 = IL.Var.new("m_20", Ty.realTy)
214                                val m10 = IL.Var.new("m_10", Ty.realTy)
215                                val m01 = IL.Var.new("m_01", Ty.realTy)
216                                val lhs0 = IL.Var.new("lhs0", Ty.realTy)
217                                val lhs1 = IL.Var.new("lhs1", Ty.realTy)
218                                val lhs2 = IL.Var.new("lhs2", Ty.realTy)
219                                in
220                                  ST.tick cntProbeCurl;
221                                  decUse f;
222                                  SOME[
223                                      (h', IL.OP(Op.Kernel(kernel, k+1), [])),
224                                      (f'', IL.OP(Op.Field dim, [use v, use h'])),
225                                      (m, IL.OP(Op.Probe(domTy, mat33), [use f'', pos])),
226                                      (zero, IL.LIT(Literal.Int 0)),
227                                      (one, IL.LIT(Literal.Int 1)),
228                                      (two, IL.LIT(Literal.Int 2)),
229                                      (m21, IL.OP(Op.TensorSub mat33, [use m, use two, use one])),
230                                      (m12, IL.OP(Op.TensorSub mat33, [use m, use one, use two])),
231                                      (lhs0, IL.OP(Op.Sub Ty.realTy, [use m21, use m12])),
232                                      (m02, IL.OP(Op.TensorSub mat33, [use m, use zero, use two])),
233                                      (m20, IL.OP(Op.TensorSub mat33, [use m, use two, use zero])),
234                                      (lhs1, IL.OP(Op.Sub Ty.realTy, [use m02, use m20])),
235                                      (m10, IL.OP(Op.TensorSub mat33, [use m, use one, use zero])),
236                                      (m01, IL.OP(Op.TensorSub mat33, [use m, use zero, use one])),
237                                      (lhs2, IL.OP(Op.Sub Ty.realTy, [use m10, use m01])),
238                                      (lhs, IL.CONS(Ty.TensorTy[3], [lhs0, lhs1, lhs2]))
239                                    ]
240                                end
241                            | _ => raise Fail(concat[
242                                  "curl: bogus field binding ", V.toString f', " = ", IL.vbToString(V.binding f')
243                                ])
244                          (* end case *))
245                      | SOME(Op.DiffField, _) => NONE (* need further rewriting *)
246                    | _ => raise Fail(concat[                    | _ => raise Fail(concat[
247                          "bogus field binding ", V.toString f, " = ", IL.vbToString(V.binding f)                          "probe: bogus field binding ", V.toString f, " = ", IL.vbToString(V.binding f)
248                        ])                        ])
249                  (* end case *))                  (* end case *))
250              | (Op.DiffField, [f]) => (case (getRHS f)              | (Op.DiffField, [f]) => (case (getRHS f)
251                   of SOME(Op.Field dim, [v, h]) => (case getRHS h                   of SOME(Op.Field dim, [v, h]) => let
252                         of SOME(Op.Kernel(kernel, k), []) => let                        val (kernel, k) = getKernelRHS h
253                              val h' = IL.Var.copy h                              val h' = IL.Var.copy h
254                              in                              in
255                                ST.tick cntDiffField;                                ST.tick cntDiffField;
# Line 117  Line 260 
260                                    (lhs, IL.OP(Op.Field dim, [v, h']))                                    (lhs, IL.OP(Op.Field dim, [v, h']))
261                                  ]                                  ]
262                              end                              end
                         | _ => raise Fail(concat[  
                               "bogus kernel binding ", V.toString h, " = ", IL.vbToString(V.binding h)  
                             ])  
                       (* end case *))  
263                    | SOME(Op.AddField, [f, g]) => raise Fail "Diff(f+g)"                    | SOME(Op.AddField, [f, g]) => raise Fail "Diff(f+g)"
264                    | SOME(Op.SubField, [f, g]) => raise Fail "Diff(f-g)"                    | SOME(Op.SubField, [f, g]) => raise Fail "Diff(f-g)"
265                    | SOME(Op.ScaleField, [s, f']) => let                    | SOME(Op.ScaleField, [s, f']) => let
# Line 129  Line 268 
268                        in                        in
269                          ST.tick cntDiffScale;                          ST.tick cntDiffScale;
270                          decUse f;                          decUse f;
                         incUse lhs'; incUse f'; incUse s;  
271                          SOME[                          SOME[
272                              (lhs', IL.OP(Op.DiffField, [f'])),                              (lhs', IL.OP(Op.DiffField, [use f'])),
273                              (lhs, IL.OP(Op.ScaleField, [s, lhs']))                              (lhs, IL.OP(Op.ScaleField, [use s, use lhs']))
274                            ]                            ]
275                        end                        end
276                      | SOME(Op.OffsetField, [f', s]) => (
277                        (* rewrite to (D f) *)
278                          ST.tick cntDiffOffset;
279                          decUse f;
280                          SOME[(lhs, IL.OP(Op.DiffField, [use f']))])
281                    | SOME(Op.NegField, [f']) => let                    | SOME(Op.NegField, [f']) => let
282                      (* rewrite to -(D f') *)                      (* rewrite to -(D f') *)
283                        val lhs' = IL.Var.copy lhs                        val lhs' = IL.Var.copy lhs
# Line 149  Line 292 
292                        end                        end
293                    | _ => NONE                    | _ => NONE
294                  (* end case *))                  (* end case *))
295                | (Op.CurlField dim, [f]) => (case (getRHS f)
296                     of SOME(Op.AddField, [f, g]) => raise Fail "curl(f+g)"
297                      | SOME(Op.SubField, [f, g]) => raise Fail "curl(f-g)"
298                      | SOME(Op.ScaleField, [s, f']) => let
299                        (* rewrite to s*curl(f) *)
300                        val f'' = IL.Var.copy f'
301                        in
302                          ST.tick cntCurlScale;
303                          decUse f;
304                          SOME[
305                              (f'', IL.OP(Op.CurlField dim, [use f'])),
306                              (lhs, IL.OP(Op.ScaleField, [use s, use f'']))
307                            ]
308                        end
309                      | SOME(Op.NegField, [f']) => let
310                        (* rewrite to -curl(f) *)
311                        val f'' = IL.Var.copy f'
312                        in
313                          ST.tick cntCurlNeg;
314                          decUse f;
315                          SOME[
316                              (f'', IL.OP(Op.CurlField dim, [use f'])),
317                              (lhs, IL.OP(Op.NegField, [use f'']))
318                            ]
319                        end
320                    (* FIXME: the following is just the constant 0 field, but we don't have
321                     * a representation of constant fields
322                     *)
323                      | SOME(Op.DiffField, _) => raise Fail "curl of del"
324                      | _ => NONE
325                    (* end case *))
326              | _ => NONE              | _ => NONE
327            (* end case *))            (* end case *))
328        | doRHS _ = NONE        | doRHS _ = NONE
329    
330    (* simplify expressions *)      structure Rewrite = RewriteFn (
331      fun simplify (nd as IL.ND{kind=IL.ASSIGN{stm=(y, rhs), ...}, ...}) =        struct
332            if (useCount y = 0)          structure IL = IL
333              then () (* skip unused assignments *)          val doAssign = doRHS
334              else (case doRHS(y, rhs)          fun doMAssign _ = NONE
335                 of SOME[] => IL.CFG.deleteNode nd          val elimUnusedVars = UnusedElim.reduce
336                  | SOME assigns => let        end)
337                      val assigns = List.map  
338                            (fn (y, rhs) => (V.setBinding(y, IL.VB_RHS rhs); IL.ASSGN(y, rhs)))      val transform = Rewrite.transform
                             assigns  
                     in  
                       IL.CFG.replaceNodeWithCFG (nd, IL.CFG.mkBlock assigns)  
                     end  
                 | NONE => ()  
               (* end case *))  
       | simplify _ = ()  
   
     fun loopToFixPt f = let  
           fun loop n = let  
                 val () = f ()  
                 val n' = Stats.sum{from=firstCounter, to=lastCounter}  
                 in  
                   if (n = n') then () else loop n'  
                 end  
           in  
             loop (Stats.sum{from=firstCounter, to=lastCounter})  
           end  
   
     fun transform (prog as IL.Program{props, globalInit, initially, strands}) = let  
           fun doCFG cfg = (  
                 loopToFixPt (fn () => IL.CFG.apply simplify cfg);  
                 loopToFixPt (fn () => ignore(UnusedElim.reduce cfg)))  
           fun doMethod (IL.Method{body, ...}) = doCFG body  
           fun doStrand (IL.Strand{stateInit, methods, ...}) = (  
                 doCFG stateInit;  
                 List.app doMethod methods)  
           fun optPass () = (  
                 doCFG globalInit;  
                 List.app doStrand strands)  
           in  
             loopToFixPt optPass;  
 (* FIXME: after optimization, we should filter out any globals that are now unused *)  
             IL.Program{  
                 props = props,  
                 globalInit = globalInit,  
                 initially = initially,  (* FIXME: we should optimize this code *)  
                 strands = strands  
               }  
           end  
339    
340    end    end

Legend:
Removed from v.2204  
changed lines
  Added in v.2205

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0