Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] View of /branches/vis12/src/compiler/high-il/normalize.sml
ViewVC logotype

View of /branches/vis12/src/compiler/high-il/normalize.sml

Parent Directory | Revision Log


Revision 2196 - (download) (annotate)
Sun Feb 24 13:44:48 2013 UTC (6 years, 6 months ago) by jhr
File size: 15200 byte(s)
  Added some more rewriting for curl.
(* normalize.sml
 *
 * COPYRIGHT (c) 2011 The Diderot Project (http://diderot-language.cs.uchicago.edu)
 * All rights reserved.
 *)

structure Normalize : sig

  (* push probes, inside tests, and field derivatives (D and curl) through
   * field-level arithmetic (scale, offset, negation, curl, ...) in every CFG of
   * the program's global initialization and strands; raises Fail on field
   * expressions that are not supported yet (e.g., inside or D of a sum of fields).
   *)
    val transform : HighIL.program -> HighIL.program

  end = struct

    structure IL = HighIL
    structure Op = HighOps
    structure V = IL.Var
    structure Ty = HighILTypes
    structure ST = Stats

  (********** Counters for statistics **********)
  (* every rewrite performed by doRHS ticks one of these counters (cntUnused is
   * ticked by UnusedElim).  loopToFixPt sums the counters from firstCounter to
   * lastCounter to decide whether a pass made progress, so new counters must be
   * created between those two (assuming Stats.sum ranges over counters in
   * creation order -- confirm against Stats).
   *)
    val cntInsideScale          = ST.newCounter "high-opt:inside-scale"
    val cntInsideOffset         = ST.newCounter "high-opt:inside-offset"
    val cntInsideNeg            = ST.newCounter "high-opt:inside-neg"
    val cntInsideCurl           = ST.newCounter "high-opt:inside-curl"
    val cntInsideDiff           = ST.newCounter "high-opt:inside-diff"
    val cntProbeAdd             = ST.newCounter "high-opt:probe-add"
    val cntProbeSub             = ST.newCounter "high-opt:probe-sub"
    val cntProbeScale           = ST.newCounter "high-opt:probe-scale"
    val cntProbeOffset          = ST.newCounter "high-opt:probe-offset"
    val cntProbeNeg             = ST.newCounter "high-opt:probe-neg"
    val cntProbeCurl            = ST.newCounter "high-opt:probe-curl"
    val cntDiffField            = ST.newCounter "high-opt:diff-field"
    val cntDiffAdd              = ST.newCounter "high-opt:diff-add"
    val cntDiffScale            = ST.newCounter "high-opt:diff-scale"
    val cntDiffOffset           = ST.newCounter "high-opt:diff-offset"
    val cntDiffNeg              = ST.newCounter "high-opt:diff-neg"
    val cntCurlScale            = ST.newCounter "high-opt:curl-scale"
    val cntCurlNeg              = ST.newCounter "high-opt:curl-neg"
    val cntCurlOffset           = ST.newCounter "high-opt:curl-offset"
    val cntUnused               = ST.newCounter "high-opt:unused"
    val firstCounter            = cntInsideScale
    val lastCounter             = cntUnused

    structure UnusedElim = UnusedElimFn (
        structure IL = IL
        val cntUnused = cntUnused)

  (* the current use count of a variable *)
    fun useCount (IL.V{useCnt, ...}) = !useCnt

  (* adjust a variable's use count *)
    fun incUse (IL.V{useCnt, ...}) = (useCnt := !useCnt + 1)
    fun decUse (IL.V{useCnt, ...}) = (useCnt := !useCnt - 1)

  (* record one more use of x and return x; handy when building argument lists *)
    fun use x = (incUse x; x)

  (* follow x's binding through variable-to-variable copies; returns SOME(rator, args)
   * when the chain ends in an operator application and NONE otherwise.
   *)
    fun getRHS x = (case V.binding x
           of IL.VB_RHS(IL.OP arg) => SOME arg
            | IL.VB_RHS(IL.VAR x') => getRHS x'
            | _ => NONE
          (* end case *))

  (* get the binding of a kernel variable: the kernel together with its integer
   * parameter k, which the derivative rewrites below bump to k+1.  Raises Fail if
   * h is not bound to a nullary Op.Kernel application.
   *)
    fun getKernelRHS h = (case getRHS h
           of SOME(Op.Kernel(kernel, k), []) => (kernel, k)
            | _ => raise Fail(concat[
                  "bogus kernel binding ", V.toString h, " = ", IL.vbToString(V.binding h)
                ])
          (* end case *))

  (* optimize the rhs of the assignment lhs = rhs.  Returns NONE if there is no
   * change; otherwise returns SOME assigns, the sequence of assignments that
   * replaces the original one (the last of them defines lhs).  Use counts are kept
   * up to date here: the field variable consumed by the original rhs loses a use,
   * and every variable referenced by a new assignment gains one, except for
   * variables (such as pos) whose existing use is simply carried over.
   *)
    fun doRHS (lhs, IL.OP rhs) = (case rhs
           of (Op.Inside dim, [pos, f]) => let
              (* each field operation stripped below is treated as not changing the
               * result of the test: inside(pos, op(f')) is rewritten to inside(pos, f')
               *)
                fun stripTo (cntr, f') = (
                      ST.tick cntr;
                      decUse f;
                      SOME[(lhs, IL.OP(Op.Inside dim, [pos, use f']))])
                in
                  case getRHS f
                   of SOME(Op.Field _, _) => NONE (* direct inside test does not need rewrite *)
                    | SOME(Op.AddField, [_, _]) => raise Fail "inside(f+g)"
                    | SOME(Op.SubField, [_, _]) => raise Fail "inside(f-g)"
                    | SOME(Op.ScaleField, [_, f']) => stripTo (cntInsideScale, f')
                    | SOME(Op.OffsetField, [f', _]) => stripTo (cntInsideOffset, f')
                    | SOME(Op.NegField, [f']) => stripTo (cntInsideNeg, f')
                    | SOME(Op.CurlField _, [f']) => stripTo (cntInsideCurl, f')
                    | SOME(Op.DiffField, [f']) => stripTo (cntInsideDiff, f')
                    | _ => raise Fail(concat[
                          "inside: bogus field binding ", V.toString f, " = ", IL.vbToString(V.binding f)
                        ])
                  (* end case *)
                end
            | (Op.Probe(domTy, rngTy), [f, pos]) => let
              (* rewrite (f' op g')@pos to (f'@pos) rator (g'@pos) *)
                fun distribute (cntr, rator, f', g') = let
                      val lhs1 = IL.Var.copy lhs
                      val lhs2 = IL.Var.copy lhs
                      in
                        ST.tick cntr;
                        decUse f;
                      (* pos is now referenced by two probes, so it gains one use *)
                        incUse pos;
                        SOME[
                            (lhs1, IL.OP(Op.Probe(domTy, rngTy), [use f', pos])),
                            (lhs2, IL.OP(Op.Probe(domTy, rngTy), [use g', pos])),
                            (lhs, IL.OP(rator, [use lhs1, use lhs2]))
                          ]
                      end
                in
                  case getRHS f
                   of SOME(Op.Field _, _) => NONE (* direct probe does not need rewrite *)
                    | SOME(Op.AddField, [f', g']) =>
                        distribute (cntProbeAdd, Op.Add rngTy, f', g')
                    | SOME(Op.SubField, [f', g']) =>
                        distribute (cntProbeSub, Op.Sub rngTy, f', g')
                    | SOME(Op.ScaleField, [s, f']) => let
                      (* rewrite to s*(f'@pos); a scalar range uses Mul instead of Scale *)
                        val lhs' = IL.Var.copy lhs
                        val scaleOp = (case rngTy
                               of Ty.TensorTy[] => Op.Mul rngTy
                                | _ => Op.Scale rngTy
                              (* end case *))
                        in
                          ST.tick cntProbeScale;
                          decUse f;
                          SOME[
                              (lhs', IL.OP(Op.Probe(domTy, rngTy), [use f', pos])),
                              (lhs, IL.OP(scaleOp, [use s, use lhs']))
                            ]
                        end
                    | SOME(Op.OffsetField, [f', s]) => let
                      (* rewrite to (f'@pos) + s *)
                        val lhs' = IL.Var.copy lhs
                        in
                          ST.tick cntProbeOffset;
                          decUse f;
                          SOME[
                              (lhs', IL.OP(Op.Probe(domTy, rngTy), [use f', pos])),
                              (lhs, IL.OP(Op.Add rngTy, [use lhs', use s]))
                            ]
                        end
                    | SOME(Op.NegField, [f']) => let
                      (* rewrite to -(f'@pos) *)
                        val lhs' = IL.Var.copy lhs
                        in
                          ST.tick cntProbeNeg;
                          decUse f;
                          SOME[
                              (lhs', IL.OP(Op.Probe(domTy, rngTy), [use f', pos])),
                              (lhs, IL.OP(Op.Neg rngTy, [use lhs']))
                            ]
                        end
                    | SOME(Op.CurlField 2, [f']) => (case getRHS f'
                         of SOME(Op.Field dim, [v, h]) => let
                            (* rewrite to (D f')@pos[1,0] - (D f')@pos[0,1], where D f'
                             * is Field(v, h') and h' is h's kernel with k+1
                             *)
                              val (kernel, k) = getKernelRHS h
                              val h' = IL.Var.copy h
                              val f'' = IL.Var.copy f'
                              val mat22 = Ty.TensorTy[2,2]
                              val m = IL.Var.new("m", mat22)
                              val zero = IL.Var.new("zero", Ty.intTy)
                              val one = IL.Var.new("one", Ty.intTy)
                              val m10 = IL.Var.new("m_10", Ty.realTy)
                              val m01 = IL.Var.new("m_01", Ty.realTy)
                              in
                                ST.tick cntProbeCurl;
                                decUse f;
                                SOME[
                                    (h', IL.OP(Op.Kernel(kernel, k+1), [])),
                                    (f'', IL.OP(Op.Field dim, [use v, use h'])),
                                    (m, IL.OP(Op.Probe(domTy, mat22), [use f'', pos])),
                                    (zero, IL.LIT(Literal.Int 0)),
                                    (one, IL.LIT(Literal.Int 1)),
                                    (m10, IL.OP(Op.TensorSub mat22, [use m, use one, use zero])),
                                    (m01, IL.OP(Op.TensorSub mat22, [use m, use zero, use one])),
                                    (lhs, IL.OP(Op.Sub Ty.realTy, [use m10, use m01]))
                                  ]
                              end
                          | _ => raise Fail(concat[
                                "bogus field binding ", V.toString f', " = ", IL.vbToString(V.binding f')
                              ])
                        (* end case *))
                    | SOME(Op.CurlField 3, [f']) => (case getRHS f'
                         of SOME(Op.Field dim, [v, h]) => let
                            (* rewrite to
                             *  [ (D f')@pos[2,1] - (D f')@pos[1,2] ]
                             *  [ (D f')@pos[0,2] - (D f')@pos[2,0] ]
                             *  [ (D f')@pos[1,0] - (D f')@pos[0,1] ]
                             *)
                              val (kernel, k) = getKernelRHS h
                              val h' = IL.Var.copy h
                              val f'' = IL.Var.copy f'
                              val mat33 = Ty.TensorTy[3,3]
                              val m = IL.Var.new("m", mat33)
                              val zero = IL.Var.new("zero", Ty.intTy)
                              val one = IL.Var.new("one", Ty.intTy)
                              val two = IL.Var.new("two", Ty.intTy)
                              val m21 = IL.Var.new("m_21", Ty.realTy)
                              val m12 = IL.Var.new("m_12", Ty.realTy)
                              val m02 = IL.Var.new("m_02", Ty.realTy)
                              val m20 = IL.Var.new("m_20", Ty.realTy)
                              val m10 = IL.Var.new("m_10", Ty.realTy)
                              val m01 = IL.Var.new("m_01", Ty.realTy)
                              val lhs0 = IL.Var.new("lhs0", Ty.realTy)
                              val lhs1 = IL.Var.new("lhs1", Ty.realTy)
                              val lhs2 = IL.Var.new("lhs2", Ty.realTy)
                              in
                                ST.tick cntProbeCurl;
                                decUse f;
                                SOME[
                                    (h', IL.OP(Op.Kernel(kernel, k+1), [])),
                                    (f'', IL.OP(Op.Field dim, [use v, use h'])),
                                    (m, IL.OP(Op.Probe(domTy, mat33), [use f'', pos])),
                                    (zero, IL.LIT(Literal.Int 0)),
                                    (one, IL.LIT(Literal.Int 1)),
                                    (two, IL.LIT(Literal.Int 2)),
                                    (m21, IL.OP(Op.TensorSub mat33, [use m, use two, use one])),
                                    (m12, IL.OP(Op.TensorSub mat33, [use m, use one, use two])),
                                    (lhs0, IL.OP(Op.Sub Ty.realTy, [use m21, use m12])),
                                    (m02, IL.OP(Op.TensorSub mat33, [use m, use zero, use two])),
                                    (m20, IL.OP(Op.TensorSub mat33, [use m, use two, use zero])),
                                    (lhs1, IL.OP(Op.Sub Ty.realTy, [use m02, use m20])),
                                    (m10, IL.OP(Op.TensorSub mat33, [use m, use one, use zero])),
                                    (m01, IL.OP(Op.TensorSub mat33, [use m, use zero, use one])),
                                    (lhs2, IL.OP(Op.Sub Ty.realTy, [use m10, use m01])),
                                  (* the components must be counted as used; otherwise
                                   * simplify/UnusedElim treat their assignments as dead
                                   *)
                                    (lhs, IL.CONS(Ty.TensorTy[3], [use lhs0, use lhs1, use lhs2]))
                                  ]
                              end
                          | _ => raise Fail(concat[
                                "curl: bogus field binding ", V.toString f', " = ", IL.vbToString(V.binding f')
                              ])
                        (* end case *))
                    | SOME(Op.DiffField, _) => NONE (* need further rewriting *)
                    | _ => raise Fail(concat[
                          "probe: bogus field binding ", V.toString f, " = ", IL.vbToString(V.binding f)
                        ])
                  (* end case *)
                end
            | (Op.DiffField, [f]) => (case getRHS f
                 of SOME(Op.Field dim, [v, h]) => let
                    (* rewrite D(Field(v, h)) to Field(v, h'), where h' is h's kernel with k+1 *)
                      val (kernel, k) = getKernelRHS h
                      val h' = IL.Var.copy h
                      in
                        ST.tick cntDiffField;
                        decUse f;
                        incUse h'; incUse v;
                        SOME[
                            (h', IL.OP(Op.Kernel(kernel, k+1), [])),
                            (lhs, IL.OP(Op.Field dim, [v, h']))
                          ]
                      end
                  | SOME(Op.AddField, [_, _]) => raise Fail "Diff(f+g)"
                  | SOME(Op.SubField, [_, _]) => raise Fail "Diff(f-g)"
                  | SOME(Op.ScaleField, [s, f']) => let
                    (* rewrite to s*(D f') *)
                      val lhs' = IL.Var.copy lhs
                      in
                        ST.tick cntDiffScale;
                        decUse f;
                        SOME[
                            (lhs', IL.OP(Op.DiffField, [use f'])),
                            (lhs, IL.OP(Op.ScaleField, [use s, use lhs']))
                          ]
                      end
                  | SOME(Op.OffsetField, [f', _]) => (
                    (* rewrite to (D f'); the offset is a plain value, not a field *)
                      ST.tick cntDiffOffset;
                      decUse f;
                      SOME[(lhs, IL.OP(Op.DiffField, [use f']))])
                  | SOME(Op.NegField, [f']) => let
                    (* rewrite to -(D f') *)
                      val lhs' = IL.Var.copy lhs
                      in
                        ST.tick cntDiffNeg;
                        decUse f;
                        SOME[
                            (lhs', IL.OP(Op.DiffField, [use f'])),
                            (lhs, IL.OP(Op.NegField, [use lhs']))
                          ]
                      end
                  | _ => NONE
                (* end case *))
            | (Op.CurlField dim, [f]) => (case getRHS f
                 of SOME(Op.AddField, [_, _]) => raise Fail "curl(f+g)"
                  | SOME(Op.SubField, [_, _]) => raise Fail "curl(f-g)"
                  | SOME(Op.ScaleField, [s, f']) => let
                    (* rewrite to s*curl(f'); the intermediate has the same type as lhs *)
                      val lhs' = IL.Var.copy lhs
                      in
                        ST.tick cntCurlScale;
                        decUse f;
                        SOME[
                            (lhs', IL.OP(Op.CurlField dim, [use f'])),
                            (lhs, IL.OP(Op.ScaleField, [use s, use lhs']))
                          ]
                      end
                  | SOME(Op.OffsetField, [f', _]) => (
                    (* rewrite to curl(f'); like D, curl ignores the constant offset *)
                      ST.tick cntCurlOffset;
                      decUse f;
                      SOME[(lhs, IL.OP(Op.CurlField dim, [use f']))])
                  | SOME(Op.NegField, [f']) => let
                    (* rewrite to -curl(f'); the intermediate has the same type as lhs *)
                      val lhs' = IL.Var.copy lhs
                      in
                        ST.tick cntCurlNeg;
                        decUse f;
                        SOME[
                            (lhs', IL.OP(Op.CurlField dim, [use f'])),
                            (lhs, IL.OP(Op.NegField, [use lhs']))
                          ]
                      end
                (* FIXME: curl(D f) is just the constant 0 field, but we don't have
                 * a representation of constant fields
                 *)
                  | SOME(Op.DiffField, _) => raise Fail "curl of del"
                  | _ => NONE
                (* end case *))
            | _ => NONE
          (* end case *))
      | doRHS _ = NONE

  (* simplify one CFG node: an assignment whose lhs is still used and whose rhs
   * doRHS can rewrite is replaced by a block of the new assignments (each new lhs
   * gets its binding recorded), or deleted if the rewrite is empty; all other
   * nodes are left unchanged.
   *)
    fun simplify (nd as IL.ND{kind=IL.ASSIGN{stm=(y, rhs), ...}, ...}) =
          if (useCount y = 0)
            then () (* skip unused assignments *)
            else (case doRHS(y, rhs)
               of SOME[] => IL.CFG.deleteNode nd
                | SOME assigns => let
                    val assigns = List.map
                          (fn (y, rhs) => (V.setBinding(y, IL.VB_RHS rhs); IL.ASSGN(y, rhs)))
                            assigns
                    in
                      IL.CFG.replaceNodeWithCFG (nd, IL.CFG.mkBlock assigns)
                    end
                | NONE => ()
              (* end case *))
      | simplify _ = ()

  (* call f repeatedly until a call leaves the sum of the statistics counters
   * unchanged, i.e., until f performs no further counted rewrites
   *)
    fun loopToFixPt f = let
          fun loop n = let
                val () = f ()
                val n' = Stats.sum{from=firstCounter, to=lastCounter}
                in
                  if (n = n') then () else loop n'
                end
          in
            loop (Stats.sum{from=firstCounter, to=lastCounter})
          end

  (* normalize every CFG in the program (global initialization, strand state
   * initialization, and strand methods) in place, iterating to a fixed point
   *)
    fun transform (IL.Program{props, globalInit, initially, strands}) = let
          fun doCFG cfg = (
                loopToFixPt (fn () => IL.CFG.apply simplify cfg);
                loopToFixPt (fn () => ignore(UnusedElim.reduce cfg)))
          fun doMethod (IL.Method{body, ...}) = doCFG body
          fun doStrand (IL.Strand{stateInit, methods, ...}) = (
                doCFG stateInit;
                List.app doMethod methods)
          fun optPass () = (
                doCFG globalInit;
                List.app doStrand strands)
          in
            loopToFixPt optPass;
(* FIXME: after optimization, we should filter out any globals that are now unused *)
            IL.Program{
                props = props,
                globalInit = globalInit,
                initially = initially,  (* FIXME: we should optimize this code *)
                strands = strands
              }
          end

  end

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0