Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] View of /branches/vis15/src/compiler/high-opt/normalize-ein.sml
ViewVC logotype

View of /branches/vis15/src/compiler/high-opt/normalize-ein.sml

Parent Directory Parent Directory | Revision Log Revision Log


Revision 3585 - (download) (annotate)
Thu Jan 14 14:08:46 2016 UTC (3 years, 6 months ago) by jhr
File size: 9620 byte(s)
debugging merge
(* normalize-ein.sml
 *
 * This code is part of the Diderot Project (http://diderot-language.cs.uchicago.edu)
 *
 * COPYRIGHT (c) 2015 The University of Chicago
 * All rights reserved.
 *)

structure NormalizeEin : sig

  (* normalize an Ein function by repeatedly rewriting its body until a round of
   * rewriting ticks none of the rewrite counters; if no rewrite ever fires, then
   * NONE is returned.
   *)
    val transform : Ein.ein -> Ein.ein option

  end = struct

    structure E = Ein
    structure ST = Stats

  (********** Counters for statistics **********)
  (* `transform` decides whether another rewriting round is needed by summing the
   * counters from firstCounter to lastCounter, so every rewrite that changes the
   * term should tick one of them.  NOTE(review): this presumably relies on
   * ST.sum covering the counters created between the two bounds, so new rewrite
   * counters belong inside that range -- confirm against Stats.
   *)
    val cntNullSum        = ST.newCounter "high-opt:null-sum"
    val cntSumRewrite     = ST.newCounter "high-opt:sum-rewrite"
    val cntProbe          = ST.newCounter "high-opt:normalize-probe"
    val cntFilter         = ST.newCounter "high-opt:filter"
    val cntApplyPartial   = ST.newCounter "high-opt:apply-partial"
    val cntNegElim        = ST.newCounter "high-opt:neg-elim"
    val cntSubElim        = ST.newCounter "high-opt:sub-elim"
    val cntDivElim        = ST.newCounter "high-opt:div-elim"
    val cntDivDiv         = ST.newCounter "high-opt:div-div"
    val cntAddRewrite     = ST.newCounter "high-opt:add-rewrite"
    val cntSqrtElim       = ST.newCounter "high-opt:sqrt-elim"
    val cntEpsElim        = ST.newCounter "high-opt:eps-elim"
    val cntEpsToDeltas    = ST.newCounter "high-opt:eps-to-deltas"
    val cntNegDelta       = ST.newCounter "high-opt:neg-delta"
    val cntReduceDelta    = ST.newCounter "high-opt:reduce-delta"
    val firstCounter      = cntNullSum
    val lastCounter       = cntReduceDelta
  (* counts rounds of the rewrite loop; it is created after lastCounter so that
   * (presumably) it is not part of the progress sum used by `transform`.
   *)
    val cntRounds         = ST.newCounter "high-opt:normalize-round"

  (* raise Fail for a malformed Ein term; `str` describes the problem *)
    fun err str = raise Fail(String.concat["Ill-formed EIN Operator: ", str])

    val zero = E.Const 0

    fun mkProd exps = E.Opn(E.Prod, exps)
    fun mkDiv (e1, e2) = E.Op2(E.Div, e1, e2)

  (* build a normalized summation over the indices `sx` with summand `b`:
   *   - an empty index list drops the summation;
   *   - a sum over a lifted term is pushed inside the Lift;
   *   - a sum over a scalar tensor or a constant returns the summand unchanged;
   *   - a sum over a product is replaced by the result of EinFilter.filterSca
   *     when that function reports a change;
   * anything else is left as an explicit E.Sum.  Each simplification ticks a
   * counter so that `transform` knows the term changed.
   * NOTE(review): the scalar/constant cases drop the summation without scaling
   * by the range of the summed indices; presumably those indices never occur in
   * such a summand -- confirm.
   *)
    fun mkSum ([], b) = (ST.tick cntNullSum; b)
      | mkSum (sx, b) = let
          fun return e = (ST.tick cntSumRewrite; e)
          in
            case b
             of E.Lift e => return (E.Lift(E.Sum(sx, e)))
              | E.Tensor(_, []) => return b
              | E.Const _ => return b
              | E.ConstR _ => return b
              | E.Opn(E.Prod, es) => (case EinFilter.filterSca (sx, es)
                   of (true, e) => return e
                    | _ => E.Sum(sx, b)
                  (* end case *))
              | _ => E.Sum(sx, b)
            (* end case *)
          end

  (* build a normalized probe of the field expression `fld` at position `x`:
   * probing a lifted tensor yields the tensor itself, and probes are pushed
   * through sums and operators onto their operands.  Forms that must not reach
   * this point (an unlifted tensor, a bare Partial, a nested Probe, the
   * post-expansion forms Value/Img/Krn, or an empty operator) raise Fail via
   * `err`.  Any other field term is left as an explicit E.Probe.
   *)
    fun mkProbe (fld, x) = let
          fun return e = (ST.tick cntProbe; e)
          in
            case fld
             of E.Tensor _         => err "Tensor without Lift"
              | E.Lift e           => return e
              | E.Partial _        => err "Probe Partial"
              | E.Probe _          => err "Probe of a Probe"
              | E.Value _          => err "Value used before expand"
              | E.Img _            => err "Probe used before expand"
              | E.Krn _            => err "Krn used before expand"
              | E.Sum(sx1, e)      => return (E.Sum(sx1, E.Probe(e, x)))
              | E.Op1(op1, e)      => return (E.Op1(op1, E.Probe(e, x)))
              | E.Op2(op2, e1, e2) => return (E.Op2(op2, E.Probe(e1, x), E.Probe(e2, x)))
              | E.Opn(opn, [])     => err "Probe of empty operator"
              | E.Opn(opn, es)     => return (E.Opn(opn, List.map (fn e => E.Probe(e, x)) es))
              | _                  => E.Probe(fld, x)
            (* end case *)
          end

  (* rewrite the body of an EIN function to a fixed point.  Each round applies
   * `rewrite` once to the whole body; the loop continues while the round ticked
   * at least one counter in firstCounter..lastCounter.  Returns SOME of the new
   * function if any round ticked a counter, and NONE otherwise.
   *)
    fun transform (ein as Ein.EIN{params, index, body}) = let
        (* build a product, preferring the simplified form from EinFilter.mkProd *)
          fun filterProd args = (case EinFilter.mkProd args
                 of SOME e => (ST.tick cntFilter; e)
                  | NONE => mkProd args
                (* end case *))
        (* one bottom-up rewriting pass over an Ein expression *)
          fun rewrite body = (case body
                 of E.Const _                     => body
                  | E.ConstR _                    => body
                  | E.Tensor _                    => body
                  | E.Delta _                     => body
                  | E.Epsilon _                   => body
                  | E.Eps2 _                      => body
                (************** Field Terms **************)
                  | E.Field _                     => body
                  | E.Lift e1                     => E.Lift(rewrite e1)
                  | E.Conv _                      => body
                  | E.Partial _                   => body
                (* applying an empty derivative is the identity; tick a counter so
                 * that this change is not lost when it is the only one in a round
                 *)
                  | E.Apply(E.Partial [], e1)     => (ST.tick cntApplyPartial; e1)
                  | E.Apply(E.Partial d1, e1)     => let
                      val e1 = rewrite e1
                      in
                        case Derivative.mkApply(E.Partial d1, e1)
                         of SOME e => (ST.tick cntApplyPartial; e)
                          | NONE => E.Apply(E.Partial d1, e1)
                        (* end case *)
                      end
                  | E.Apply _                     => err "Ill-formed Apply expression"
                  | E.Probe(e1, e2)               => mkProbe(rewrite e1, rewrite e2)
                (************** Field Terms **************)
                  | E.Value _                     => err "Value before Expand"
                  | E.Img _                       => err "Img before Expand"
                  | E.Krn _                       => err "Krn before Expand"
                (************** Sum **************)
                  | E.Sum(sx, e)                  => mkSum (sx, rewrite e)
                (************* Algebraic Rewrites Op1 **************)
                  | E.Op1(E.Neg, E.Op1(E.Neg, e)) => (ST.tick cntNegElim; rewrite e)
                  | E.Op1(E.Neg, E.Const 0)       => (ST.tick cntNegElim; zero)
                  | E.Op1(op1, e1)                => E.Op1(op1, rewrite e1)
                (************* Algebraic Rewrites Op2 **************)
                  | E.Op2(E.Sub, E.Const 0, e2)   => (ST.tick cntSubElim; E.Op1(E.Neg, rewrite e2))
                  | E.Op2(E.Sub, e1, E.Const 0)   => (ST.tick cntSubElim; rewrite e1)
                  | E.Op2(E.Div, E.Const 0, e2)   => (ST.tick cntDivElim; zero)
                (* (a/b)/(c/d) ==> (a*d)/(b*c) *)
                  | E.Op2(E.Div, E.Op2(E.Div, a, b), E.Op2(E.Div, c, d)) =>
                      (ST.tick cntDivDiv;
                       rewrite (mkDiv (mkProd[a, d], mkProd[b, c])))
                (* (a/b)/c ==> a/(b*c) *)
                  | E.Op2(E.Div, E.Op2(E.Div, a, b), c) =>
                      (ST.tick cntDivDiv;
                       rewrite (mkDiv (a, mkProd[b, c])))
                (* a/(b/c) ==> (a*c)/b *)
                  | E.Op2(E.Div, a, E.Op2(E.Div, b, c)) =>
                      (ST.tick cntDivDiv;
                       rewrite (mkDiv (mkProd[a, c], b)))
                  | E.Op2(op1, e1, e2)            => E.Op2(op1, rewrite e1, rewrite e2)
                (************* Algebraic Rewrites Opn **************)
                  | E.Opn(E.Add, es)              => let
                      val es' = List.map rewrite es
                      in
                        case EinFilter.mkAdd es'
                         of SOME body' => (ST.tick cntAddRewrite; body')
                          | NONE => E.Opn(E.Add, es')
                        (* end case *)
                      end
                (************* Product **************)
                  | E.Opn(E.Prod, [])             => err "missing elements in product"
                (* NOTE(review): this unwrap ticks no counter, so if it is the only
                 * change in the first round, `transform` returns NONE and drops it.
                 *)
                  | E.Opn(E.Prod, [e1])           => rewrite e1
                (* sqrt(s) * sqrt(s) ==> s *)
                  | E.Opn(E.Prod, [e1 as E.Op1(E.Sqrt, s1), e2 as E.Op1(E.Sqrt, s2)]) =>
                      if EinUtil.sameExp(s1, s2)
                        then (ST.tick cntSqrtElim; s1)
                        else filterProd [rewrite e1, rewrite e2]
                (************* Product EPS **************)
                  | E.Opn(E.Prod, (eps1 as E.Epsilon(i,j,k))::ps) => (case ps
                       of ((p1 as E.Apply(E.Partial d, e)) :: es) => (
                            case (EpsUtil.matchEps (d, [i,j,k]), es)
                             of (true, _) => (ST.tick cntEpsElim; zero)
                              | (_, []) => mkProd[eps1, rewrite p1]
                              | _ => filterProd [eps1, rewrite (mkProd (p1 :: es))]
                            (* end case *))
                        | ((p1 as E.Conv(_, _, _, d)) :: es) => (
                            case (EpsUtil.matchEps (d, [i,j,k]), es)
                             of (true, _) => (ST.tick cntEpsElim; E.Lift zero)
                              | (_, []) => mkProd[eps1, p1]
                              | _ => filterProd [eps1, rewrite (mkProd(p1 :: es))]
                            (* end case *))
                      (* NOTE(review): eps_ijk * T_jk vanishes only when T is
                       * symmetric in its two indices; presumably that holds for
                       * the tensors that reach this rule -- confirm.
                       *)
                        | [E.Tensor(_, [E.V i1, E.V i2])] =>
                            if (j=i1 andalso k=i2)
                              then (ST.tick cntEpsElim; zero)
                              else body
                        | _  => (case EpsUtil.epsToDels (eps1::ps)
                             of (true, e, [], _, _) => (ST.tick cntEpsToDeltas; e) (* Changed to Deltas *)
(* QUESTION: should we call rewrite on e? *)
                              | (true, e, sx, _, _) => (ST.tick cntEpsToDeltas; E.Sum(sx, e))
                              | (_, _, _, _, []) => body
                              | (_, _, _, epsAll, rest) =>
                                  filterProd (epsAll @ [rewrite (mkProd rest)])
                            (* end case *))
                      (* end case *))
                  | E.Opn(E.Prod, (s1 as E.Sum(c1, e1)) :: (s2 as E.Sum(c2, e2)) :: es) => (
                      case (e1, e2, es)
                       of (
                          E.Opn(E.Prod, (e1 as E.Epsilon _) :: es1),
                          E.Opn(E.Prod, (e2 as E.Epsilon _) :: es2),
                          _) => (case EpsUtil.epsToDels(e1 :: e2 :: es1 @ es2 @ es)
                            of (true, e, sx, _, _) => (ST.tick cntEpsToDeltas; E.Sum(c1@c2@sx, e))
                             | _ => let
                                  val eA = rewrite s1
                                  val eB = rewrite (mkProd(s2 :: es))
                                  in
                                    filterProd [eA, eB]
                                  end
                            (* end case *))
                        | (_, _, []) => filterProd [rewrite(E.Sum(c1, e1)), rewrite(E.Sum(c2, e2))]
                        | _ => let
                            val e' = rewrite (E.Sum(c1, e1))
                            val e2 = rewrite (E.Opn(E.Prod, E.Sum(c2, e2) :: es))
                            in
                              case e2
                               of E.Opn(E.Prod, p') => filterProd (e' :: p')
                                | _ => filterProd [e', e2]
                              (* end case *)
                            end
                      (* end case *))
                  | E.Opn(E.Prod, E.Delta d::es) => (case es
                       of [E.Op1(E.Neg, e1)] => (
                            ST.tick cntNegDelta; E.Op1(E.Neg, mkProd[E.Delta d, e1]))
                        | _ => let
                          (* NOTE(review): pre' is not used in the reduced result;
                           * presumably it is always empty when the product starts
                           * with a Delta -- confirm against
                           * EinFilter.partitionGreek.
                           *)
                            val (pre', eps, dels, post) = EinFilter.partitionGreek(E.Delta d::es)
                            in
                              case EpsUtil.reduceDelta(eps, dels, post)
                               of (false, _) => mkProd [E.Delta d, rewrite(mkProd es)]
                                | (_, E.Opn(E.Prod, p)) => (ST.tick cntReduceDelta; filterProd p)
                                | (_, a) => (ST.tick cntReduceDelta; a)
                              (* end case *)
                            end
                      (* end case *))
                  | E.Opn(E.Prod, [e1, e2]) => filterProd [rewrite e1, rewrite e2]
                  | E.Opn(E.Prod, e1::es) => let
                      val e' = rewrite e1
                      val e2 = rewrite (mkProd es)
                      in
                      (* the pattern must use the qualified constructor E.Prod: an
                       * unqualified `Prod` is a variable pattern that would also
                       * match E.Add and splice a sum's terms into this product.
                       *)
                        case e2
                         of E.Opn(E.Prod, p') => filterProd (e' :: p')
                          | _ => filterProd [e', e2]
                        (* end case *)
                      end
                (* end case *))
(*DEBUG*)val start = ST.count cntRounds
        (* `total` is the progress-counter sum before this round; `changed`
         * records whether any earlier round ticked a progress counter.
         *)
          fun loop (body, total, changed) = let
                val body' = rewrite body
                val totalTicks = ST.sum{from = firstCounter, to = lastCounter}
                in
                  ST.tick cntRounds;
(*DEBUG*)if (ST.count cntRounds - start > 50) then raise Fail "too many steps" else ();
                  if (totalTicks > total)
                    then loop(body', totalTicks, true)
                  else if changed
                    then SOME(Ein.EIN{params=params, index=index, body=body'})
                    else NONE
                end
          in
            loop(body, ST.sum{from = firstCounter, to = lastCounter}, false)
          end

  end

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0