Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] View of /branches/vis15/src/compiler/high-opt/normalize-ein.sml
ViewVC logotype

View of /branches/vis15/src/compiler/high-opt/normalize-ein.sml

Parent Directory Parent Directory | Revision Log Revision Log


Revision 3516 - (download) (annotate)
Sat Dec 19 03:54:23 2015 UTC (3 years, 9 months ago) by jhr
File size: 8625 byte(s)
working on merge
(* normalize-ein.sml
 *
 * This code is part of the Diderot Project (http://diderot-language.cs.uchicago.edu)
 *
 * COPYRIGHT (c) 2015 The University of Chicago
 * All rights reserved.
 *)

structure NormalizeEin : sig

  (* normalize an Ein function; if there are no changes, then NONE is returned. *)
    val transform : Ein.ein -> Ein.ein option

  end = struct

    structure E = Ein

  (* raise `Fail` for an ill-formed EIN expression; `str` describes the problem
   * and is appended to a fixed prefix.
   *)
    fun err str = raise Fail(String.concat["Ill-formed EIN Operator: ", str])

  (* the scalar constant zero *)
    val zero = E.Const 0

  (* plain constructors for products and quotients (no simplification) *)
    fun mkProd exps = E.Opn(E.Prod, exps)
    fun mkDiv (e1, e2) = E.Op2(E.Div, e1, e2)

  (* mkSum (sx1, b)
   * Try to distribute the summation `sx1` over the (already rewritten) body `b`.
   * Returns SOME e when a rewrite applies and NONE otherwise; on NONE the caller
   * rebuilds `E.Sum(sx1, b)` itself.
   *   - a lifted body gets the summation moved under the `Lift`;
   *   - for `E.Tensor(_, [])`, `E.Const`, and `E.ConstR` the summation is dropped
   *     and `b` is returned unchanged (NOTE(review): this assumes the summation
   *     indices cannot occur in such terms; confirm);
   *   - for a product, `Filter.filterSca` tries to pull factors out of the sum.
   *)
    fun mkSum (sx1, b) = (case b
           of E.Lift e          => SOME(E.Lift(E.Sum(sx1, e)))
            | E.Tensor(_, [])   => SOME b
            | E.Const _         => SOME b
            | E.ConstR _        => SOME b
            | E.Opn(E.Prod, es) => (case Filter.filterSca (sx1, es)
              (* the first component reports whether filterSca did anything;
               * claiming a rewrite when it did not would flag a change on every
               * pass of the fixpoint loop in `transform`.
               *)
                 of (true, e) => SOME e
                  | (false, _) => NONE
                (* end case *))
            | _                 => NONE
          (* end case *))

  (* mkProbe (b, x)
   * Push a probe of the field expression `b` at position `x` one level inward.
   * Returns SOME e for the rewritten probe, NONE when `b` is left as a probe
   * target, and raises `Fail` (via `err`) for forms that should not be probed
   * at this stage.
   *)
    fun mkProbe (b, x) = (case b
           of E.Const _          => NONE
            | E.ConstR _         => NONE
            | E.Tensor _         => err "Tensor without Lift"
            | E.Delta _          => NONE
            | E.Epsilon _        => NONE
            | E.Eps2 _           => NONE
            | E.Field _          => NONE
            | E.Lift e           => SOME e
            | E.Conv _           => NONE
            | E.Partial _        => err "Probe Partial"
            | E.Apply _          => NONE
            | E.Probe _          => err "Probe of a Probe"
            | E.Value _          => err "Value used before expand"
            | E.Img _            => err "Img used before expand"
            | E.Krn _            => err "Krn used before expand"
            | E.Sum(sx1, e)      => SOME(E.Sum(sx1, E.Probe(e, x)))
            | E.Op1(op1, e)      => SOME(E.Op1(op1, E.Probe(e, x)))
            | E.Op2(op2, e1, e2) => SOME(E.Op2(op2, E.Probe(e1, x), E.Probe(e2, x)))
            | E.Opn(opn, [])     => err "Probe of empty operator"
            | E.Opn(opn, es)     => SOME(E.Opn(opn, List.map (fn e => E.Probe(e, x)) es))
          (* end case *))

  (* transform ein
   * Normalize the body of an Ein function by running `rewrite` repeatedly until
   * a pass leaves the `changed` flag unset.  `params` and `index` are carried
   * over unchanged.  Returns NONE when the first pass sets no flag.
   *)
    fun transform (Ein.EIN{params, index, body}) = let
        (* set by any rewrite that should trigger another pass *)
          val changed = ref false
        (* build a product, letting `Filter.mkProd` simplify it when it can *)
          fun filterProd args = (case Filter.mkProd args
                 of SOME e => (changed := true; e)
                  | NONE => mkProd args
                (* end case *))
        (* one rewriting pass over an Ein expression *)
          fun rewrite body = (case body
                 of E.Const _                     => body
                  | E.ConstR _                    => body
                  | E.Tensor _                    => body
                  | E.Delta _                     => body
                  | E.Epsilon _                   => body
                  | E.Eps2 _                      => body
                (************** Field Terms **************)
                  | E.Field _                     => body
                  | E.Lift e1                     => E.Lift(rewrite e1)
                  | E.Conv _                      => body
                  | E.Partial _                   => body
                  | E.Apply(E.Partial [], e1)     => e1
                  | E.Apply(E.Partial d1, e1)     => let
                      val e1 = rewrite e1
                      in
                        case Derivative.mkApply(E.Partial d1, e1)
                         of SOME e => (changed := true; e)
                          | NONE => E.Apply(E.Partial d1, e1)
                        (* end case *)
                      end
                  | E.Apply _                     => err "Ill-formed Apply expression"
                  | E.Probe(e1, e2)               => let
                    (* keep the rewritten operands even when the probe cannot be
                     * pushed inward: if rewriting them set `changed`, returning
                     * the original `body` would redo the same work on every
                     * pass and the fixpoint loop would never terminate.
                     *)
                      val e1' = rewrite e1
                      val e2' = rewrite e2
                      in
                        case mkProbe(e1', e2')
                         of NONE => E.Probe(e1', e2')
                          | SOME body' => (changed := true; body')
                        (* end case *)
                      end
                (************** Terms that should be gone before normalization **************)
                  | E.Value _                     => err "Value before Expand"
                  | E.Img _                       => err "Img before Expand"
                  | E.Krn _                       => err "Krn before Expand"
                (************** Sum **************)
                  | E.Sum([], e1)                 => (changed := true; rewrite e1)
                  | E.Sum(sx, e)                  => let
                      val e = rewrite e
                      in
                        case mkSum(sx, e)
                         of SOME e' => (changed := true; e')
                          | NONE => E.Sum(sx, e)
                        (* end case *)
                      end
                (************* Algebraic Rewrites Op1 **************)
                  | E.Op1(E.Neg, E.Op1(E.Neg, e)) => (changed := true; rewrite e)
                  | E.Op1(E.Neg, E.Const 0)       => (changed := true; zero)
                  | E.Op1(op1, e1)                => E.Op1(op1, rewrite e1)
                (************* Algebraic Rewrites Op2 **************)
                  | E.Op2(E.Sub, E.Const 0, e2)   => (changed := true; E.Op1(E.Neg, rewrite e2))
                  | E.Op2(E.Sub, e1, E.Const 0)   => (changed := true; rewrite e1)
                  | E.Op2(E.Div, E.Const 0, e2)   => (changed := true; zero)
                (* (a/b)/(c/d) ==> (a*d)/(b*c) *)
                  | E.Op2(E.Div, E.Op2(E.Div, a, b), E.Op2(E.Div, c, d))
                                                  => rewrite (mkDiv (mkProd[a, d], mkProd[b, c]))
                (* (a/b)/c ==> a/(b*c) *)
                  | E.Op2(E.Div, E.Op2(E.Div, a, b), c)
                                                  => rewrite (mkDiv (a, mkProd[b, c]))
                (* a/(b/c) ==> (a*c)/b *)
                  | E.Op2(E.Div, a, E.Op2(E.Div, b, c))
                                                  => rewrite (mkDiv (mkProd[a, c], b))
                  | E.Op2(op1, e1, e2)            => E.Op2(op1, rewrite e1, rewrite e2)
                (************* Algebraic Rewrites Opn **************)
                  | E.Opn(E.Add, es)              => let
                    (* as for probes, keep the rewritten summands when
                     * `Filter.mkAdd` has nothing further to do.
                     *)
                      val es' = List.map rewrite es
                      in
                        case Filter.mkAdd es'
                         of SOME body' => (changed := true; body')
                          | NONE => E.Opn(E.Add, es')
                        (* end case *)
                      end
                (************* Product **************)
                  | E.Opn(E.Prod, [])             => err "missing elements in product"
                (* a singleton product is unwrapped without setting `changed` *)
                  | E.Opn(E.Prod, [e1])           => rewrite e1
                (* sqrt(s) * sqrt(s) ==> s *)
                  | E.Opn(E.Prod, [e1 as E.Op1(E.Sqrt, s1), e2 as E.Op1(E.Sqrt, s2)]) =>
                      if EinUtil.sameExp(s1, s2)
                        then (changed := true; s1)
                        else filterProd [rewrite e1, rewrite e2]
                (************* Product EPS **************)
                  | E.Opn(E.Prod, (eps1 as E.Epsilon(i, j, k)) :: ps) => (case ps
                       of ((p1 as E.Apply(E.Partial d, e)) :: es) => (
                            case (EpsUtil.matchEps (d, [i, j, k]), es)
                             of (true, _) => (changed := true; zero)
                              | (_, []) => mkProd[eps1, rewrite p1]
                              | _ => filterProd [eps1, rewrite (mkProd (p1 :: es))]
                            (* end case *))
                        | ((p1 as E.Conv(_, _, _, d)) :: es) => (
                            case (EpsUtil.matchEps (d, [i, j, k]), es)
                             of (true, _) => (changed := true; E.Lift zero)
                              | (_, []) => mkProd[eps1, p1]
                              | _ => filterProd [eps1, rewrite (mkProd(p1 :: es))]
                            (* end case *))
                      (* NOTE(review): eps_ijk * T_jk vanishes only when T is
                       * symmetric in those two indices; this rewrite assumes
                       * that -- confirm.
                       *)
                        | [E.Tensor(_, [E.V i1, E.V i2])] =>
                            if (j = i1 andalso k = i2)
                              then (changed := true; zero)
                              else body
                        | _ => (case EpsUtil.epsToDels (eps1 :: ps)
                             of (true, e, [], _, _) => (changed := true; e) (* Changed to Deltas *)
                            (* QUESTION: should we call rewrite on e? *)
                              | (true, e, sx, _, _) => (changed := true; E.Sum(sx, e))
                              | (_, _, _, _, []) => body
                              | (_, _, _, epsAll, rest) =>
                                  filterProd (epsAll @ [rewrite (mkProd rest)])
                            (* end case *))
                      (* end case *))
                  | E.Opn(E.Prod, (s1 as E.Sum(c1, e1)) :: (s2 as E.Sum(c2, e2)) :: es) => (
                      case (e1, e2, es)
                       of (
                          E.Opn(E.Prod, (e1 as E.Epsilon _) :: es1),
                          E.Opn(E.Prod, (e2 as E.Epsilon _) :: es2),
                          _) => (case EpsUtil.epsToDels(e1 :: e2 :: es1 @ es2 @ es)
                            of (true, e, sx, _, _) => (changed := true; E.Sum(c1 @ c2 @ sx, e))
                             | _ => let
                                  val eA = rewrite s1
                                  val eB = rewrite (mkProd(s2 :: es))
                                  in
                                    filterProd [eA, eB]
                                  end
                            (* end case *))
                        | (_, _, []) => filterProd [rewrite(E.Sum(c1, e1)), rewrite(E.Sum(c2, e2))]
                        | _ => let
                            val e' = rewrite (E.Sum(c1, e1))
                            val e2 = rewrite (E.Opn(E.Prod, E.Sum(c2, e2) :: es))
                            in
                              case e2
                               of E.Opn(E.Prod, p') => filterProd ([e'] @ p')
                                | _ => filterProd [e', e2]
                              (* end case *)
                            end
                      (* end case *))
                  | E.Opn(E.Prod, E.Delta d :: es) => (case es
                       of [E.Op1(E.Neg, e1)] => (
                            changed := true; E.Op1(E.Neg, mkProd[E.Delta d, e1]))
                        | _ => let
                            val (pre', eps, dels, post) = Filter.partitionGreek(E.Delta d :: es)
                            in
                              case EpsUtil.reduceDelta(eps, dels, post)
                               of (false, _) => mkProd [E.Delta d, rewrite(mkProd es)]
                                | (_, E.Opn(E.Prod, p)) => (changed := true; filterProd p)
                                | (_, a) => (changed := true; a)
                              (* end case *)
                            end
                      (* end case *))
                  | E.Opn(E.Prod, [e1, e2]) => filterProd [rewrite e1, rewrite e2]
                  | E.Opn(E.Prod, e1 :: es) => let
                      val e' = rewrite e1
                      val e2 = rewrite (mkProd es)
                      in
                      (* only flatten a nested *product*; the constructor must be
                       * qualified, otherwise `Prod` would be a pattern variable
                       * matching any operator (e.g. flattening a sum).
                       *)
                        case e2
                         of E.Opn(E.Prod, p') => filterProd (e' :: p')
                          | _ => filterProd [e', e2]
                        (* end case *)
                      end
                (* end case *))
        (* run `rewrite` until a pass leaves `changed` unset; `count` is the number
         * of passes that did set it.
         *)
          fun loop (body, count) = let
                val body' = rewrite body
                in
                  if !changed
                    then (changed := false; loop(body', count+1))
                    else (body', count)
                end
          val (b, count) = loop(body, 0)
          in
          (* NOTE(review): when the first pass sets no flag, any unflagged rewrites
           * it made (e.g. unwrapping a singleton product) are discarded here.
           *)
            if count = 0 then NONE else SOME(Ein.EIN{params=params, index=index, body=b})
          end

  end

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0