Home My Page Projects Code Snippets Project Openings diderot
 Summary Activity Tracker Tasks SCM

# SCM Repository

[diderot] View of /branches/vis15/src/compiler/high-opt/normalize-ein.sml
 [diderot] / branches / vis15 / src / compiler / high-opt / normalize-ein.sml

# View of /branches/vis15/src/compiler/high-opt/normalize-ein.sml

Revision 3520 - (download) (annotate)
Sat Dec 19 15:49:06 2015 UTC (4 years, 11 months ago) by jhr
File size: 8637 byte(s)
`working on merge`
```(* normalize-ein.sml
*
* This code is part of the Diderot Project (http://diderot-language.cs.uchicago.edu)
*
* COPYRIGHT (c) 2015 The University of Chicago
* All rights reserved.
*)

structure NormalizeEin : sig

  (* normalize an Ein function; if there are no changes, then NONE is returned. *)
    val transform : Ein.ein -> Ein.ein option

  end = struct

    structure E = Ein

  (* err : string -> 'a
   * Raise Fail for an ill-formed EIN expression; the detail string is
   * separated from the fixed prefix by ": ".
   *)
    fun err str = raise Fail(String.concat["Ill-formed EIN Operator: ", str])

  (* the scalar constant zero *)
    val zero = E.Const 0

  (* plain (unsimplified) product and division constructors *)
    fun mkProd exps = E.Opn(E.Prod, exps)
    fun mkDiv (e1, e2) = E.Op2(E.Div, e1, e2)

  (* mkSum : (summation indices) * ein_exp -> ein_exp option
   * Try to distribute the summation over the indices sx1 into its body b.
   * Returns SOME of the rewritten expression, or NONE when no rule applies:
   *   - Lift e                        => the sum moves inside the lift
   *   - scalar Tensor, Const, ConstR  => the body is returned without the sum
   *   - product                       => delegated to EinFilter.filterSca
   *)
    fun mkSum (sx1, b) = (case b
           of E.Lift e => SOME(E.Lift(E.Sum(sx1, e)))
            | E.Tensor(_, []) => SOME b
            | E.Const _ => SOME b
            | E.ConstR _ => SOME b
          (* QUESTION: what if filterSca returns false for its first result?
           * NOTE(review): SOME is returned (and so a change is reported by
           * transform) regardless of that flag; if filterSca can hand back an
           * unchanged expression, the normalization loop would not terminate.
           *)
            | E.Opn(E.Prod, es) => SOME(#2 (EinFilter.filterSca (sx1, es)))
            | _ => NONE
          (* end case *))

  (* mkProbe : ein_exp * ein_exp -> ein_exp option
   * Push a probe of the expression b at position x one level into b.
   * Returns SOME of the rewritten expression, NONE when the probe must stay
   * where it is, and raises Fail (via err) for forms that must not be probed
   * at this stage.
   *)
    fun mkProbe (b, x) = (case b
           of E.Const _ => NONE
            | E.ConstR _ => NONE
            | E.Tensor _ => err "Tensor without Lift"
            | E.Delta _ => NONE
            | E.Epsilon _ => NONE
            | E.Eps2 _ => NONE
            | E.Field _ => NONE
            | E.Lift e => SOME e
            | E.Conv _ => NONE
            | E.Partial _ => err "Probe Partial"
            | E.Apply _ => NONE
            | E.Probe _ => err "Probe of a Probe"
            | E.Value _ => err "Value used before expand"
            | E.Img _ => err "Img used before expand"
            | E.Krn _ => err "Krn used before expand"
            | E.Sum(sx1, e) => SOME(E.Sum(sx1, E.Probe(e, x)))
            | E.Op1(op1, e) => SOME(E.Op1(op1, E.Probe(e, x)))
            | E.Op2(op2, e1, e2) => SOME(E.Op2(op2, E.Probe(e1, x), E.Probe(e2, x)))
            | E.Opn(_, []) => err "Probe of empty operator"
            | E.Opn(opn, es) => SOME(E.Opn(opn, List.map (fn e => E.Probe(e, x)) es))
          (* end case *))

  (* transform : Ein.ein -> Ein.ein option
   * Rewrite the body of an EIN operator by repeatedly applying the rules in
   * `rewrite` until a full pass records no change.  Each rule that changes
   * the expression sets the local flag `changed`; `loop` resets it between
   * passes and counts the passes that changed something.  If the first pass
   * records no change, NONE is returned and the original operator is kept.
   *)
    fun transform (Ein.EIN{params, index, body}) = let
        (* set by any rule that rewrites the expression; reset by loop *)
          val changed = ref false
        (* build a product, preferring EinFilter.mkProd's simplified form
         * (which counts as a change) over the plain product constructor
         *)
          fun filterProd args = (case EinFilter.mkProd args
                 of SOME e => (changed := true; e)
                  | NONE => mkProd args
                (* end case *))
        (* one pass of the rewrite rules over an expression *)
          fun rewrite body = (case body
                 of E.Const _ => body
                  | E.ConstR _ => body
                  | E.Tensor _ => body
                  | E.Delta _ => body
                  | E.Epsilon _ => body
                  | E.Eps2 _ => body
                (************** Field Terms **************)
                  | E.Field _ => body
                  | E.Lift e1 => E.Lift(rewrite e1)
                  | E.Conv _ => body
                  | E.Partial _ => body
                (* a derivative with no indices is the identity *)
                  | E.Apply(E.Partial [], e1) => (changed := true; e1)
                  | E.Apply(E.Partial d1, e1) => let
                      val e1 = rewrite e1
                      in
                        case Derivative.mkApply(E.Partial d1, e1)
                         of SOME e => (changed := true; e)
                          | NONE => E.Apply(E.Partial d1, e1)
                        (* end case *)
                      end
                  | E.Apply _ => err "Ill-formed Apply expression"
                (* keep the rewritten operands even when the probe cannot be
                 * pushed inward: returning the original body here would drop
                 * operand rewrites that already set `changed`, and the loop
                 * would then repeat them forever
                 *)
                  | E.Probe(e1, e2) => let
                      val e1' = rewrite e1
                      val e2' = rewrite e2
                      in
                        case mkProbe(e1', e2')
                         of SOME body' => (changed := true; body')
                          | NONE => E.Probe(e1', e2')
                        (* end case *)
                      end
                (************** Terms that must be expanded earlier **************)
                  | E.Value _ => err "Value before Expand"
                  | E.Img _ => err "Img before Expand"
                  | E.Krn _ => err "Krn before Expand"
                (************** Sum **************)
                  | E.Sum([], e1) => (changed := true; rewrite e1)
                  | E.Sum(sx, e) => let
                      val e = rewrite e
                      in
                        case mkSum(sx, e)
                         of SOME e' => (changed := true; e')
                          | NONE => E.Sum(sx, e)
                        (* end case *)
                      end
                (************* Algebraic Rewrites Op1 **************)
                  | E.Op1(E.Neg, E.Op1(E.Neg, e)) => (changed := true; rewrite e)
                  | E.Op1(E.Neg, E.Const 0) => (changed := true; zero)
                  | E.Op1(op1, e1) => E.Op1(op1, rewrite e1)
                (************* Algebraic Rewrites Op2 **************)
                  | E.Op2(E.Sub, E.Const 0, e2) => (changed := true; E.Op1(E.Neg, rewrite e2))
                  | E.Op2(E.Sub, e1, E.Const 0) => (changed := true; rewrite e1)
                  | E.Op2(E.Div, E.Const 0, e2) => (changed := true; zero)
                (* (a/b)/(c/d) = (a*d)/(b*c) *)
                  | E.Op2(E.Div, E.Op2(E.Div, a, b), E.Op2(E.Div, c, d)) =>
                      (changed := true; rewrite (mkDiv (mkProd[a, d], mkProd[b, c])))
                (* (a/b)/c = a/(b*c) *)
                  | E.Op2(E.Div, E.Op2(E.Div, a, b), c) =>
                      (changed := true; rewrite (mkDiv (a, mkProd[b, c])))
                (* a/(b/c) = (a*c)/b *)
                  | E.Op2(E.Div, a, E.Op2(E.Div, b, c)) =>
                      (changed := true; rewrite (mkDiv (mkProd[a, c], b)))
                  | E.Op2(op2, e1, e2) => E.Op2(op2, rewrite e1, rewrite e2)
                (************* Algebraic Rewrites Opn **************)
                  | E.Opn(E.Add, es) => let
                      val es' = List.map rewrite es
                      in
                        case EinFilter.mkAdd es'
                         of SOME body' => (changed := true; body')
                          | NONE => E.Opn(E.Add, es')
                        (* end case *)
                      end
                (************* Product **************)
                  | E.Opn(E.Prod, []) => err "missing elements in product"
                (* deliberately not flagged as a change: several rules below
                 * call rewrite on internally built products that may have a
                 * single element (e.g. rewrite (mkProd es)), and flagging those
                 * would keep the loop from ever reaching a fixed point
                 *)
                  | E.Opn(E.Prod, [e1]) => rewrite e1
                (* sqrt(s) times sqrt(s) is s *)
                  | E.Opn(E.Prod, [e1 as E.Op1(E.Sqrt, s1), e2 as E.Op1(E.Sqrt, s2)]) =>
                      if EinUtil.sameExp(s1, s2)
                        then (changed := true; s1)
                        else filterProd [rewrite e1, rewrite e2]
                (************* Product EPS **************)
                  | E.Opn(E.Prod, (eps1 as E.Epsilon(i, j, k)) :: ps) => (case ps
                       of ((p1 as E.Apply(E.Partial d, _)) :: es) => (
                          (* NOTE(review): this arm yields the plain constant
                           * zero while the Conv arm below yields E.Lift zero;
                           * confirm which form is intended here
                           *)
                            case (EpsUtil.matchEps (d, [i, j, k]), es)
                             of (true, _) => (changed := true; zero)
                              | (_, []) => mkProd[eps1, rewrite p1]
                              | _ => filterProd [eps1, rewrite (mkProd (p1 :: es))]
                            (* end case *))
                        | ((p1 as E.Conv(_, _, _, d)) :: es) => (
                            case (EpsUtil.matchEps (d, [i, j, k]), es)
                             of (true, _) => (changed := true; E.Lift zero)
                              | (_, []) => mkProd[eps1, p1]
                              | _ => filterProd [eps1, rewrite (mkProd (p1 :: es))]
                            (* end case *))
                      (* NOTE(review): summed over j and k, eps_ijk T_jk vanishes
                       * only when T is symmetric in its two indices; confirm
                       * that tensors reaching this rule are symmetric
                       *)
                        | [E.Tensor(_, [E.V i1, E.V i2])] =>
                            if (j = i1 andalso k = i2)
                              then (changed := true; zero)
                              else body
                        | _ => (case EpsUtil.epsToDels (eps1 :: ps)
                             of (true, e, [], _, _) => (changed := true; e) (* Changed to Deltas *)
                            (* QUESTION: should we call rewrite on e? *)
                              | (true, e, sx, _, _) => (changed := true; E.Sum(sx, e))
                              | (_, _, _, _, []) => body
                              | (_, _, _, epsAll, rest) =>
                                  filterProd (epsAll @ [rewrite (mkProd rest)])
                            (* end case *))
                      (* end case *))
                (* product that starts with two sums *)
                  | E.Opn(E.Prod, (s1 as E.Sum(c1, e1)) :: (s2 as E.Sum(c2, e2)) :: es) => (
                      case (e1, e2, es)
                       of (E.Opn(E.Prod, (eps1 as E.Epsilon _) :: es1),
                           E.Opn(E.Prod, (eps2 as E.Epsilon _) :: es2),
                           _) => (case EpsUtil.epsToDels(eps1 :: eps2 :: es1 @ es2 @ es)
                             of (true, e, sx, _, _) => (changed := true; E.Sum(c1 @ c2 @ sx, e))
                              | _ => let
                                  val eA = rewrite s1
                                  val eB = rewrite (mkProd(s2 :: es))
                                  in
                                    filterProd [eA, eB]
                                  end
                            (* end case *))
                        | (_, _, []) => filterProd [rewrite s1, rewrite s2]
                        | _ => let
                            val e' = rewrite s1
                            val e2' = rewrite (mkProd(s2 :: es))
                            in
                              case e2'
                               of E.Opn(E.Prod, p') => filterProd (e' :: p')
                                | _ => filterProd [e', e2']
                              (* end case *)
                            end
                      (* end case *))
                (* product that starts with a Kronecker delta *)
                  | E.Opn(E.Prod, E.Delta d :: es) => (case es
                       of [E.Op1(E.Neg, e1)] =>
                            (changed := true; E.Op1(E.Neg, mkProd[E.Delta d, e1]))
                        | _ => let
                            val (_, eps, dels, post) = EinFilter.partitionGreek(E.Delta d :: es)
                            in
                              case EpsUtil.reduceDelta(eps, dels, post)
                               of (false, _) => mkProd [E.Delta d, rewrite(mkProd es)]
                                | (_, E.Opn(E.Prod, p)) => (changed := true; filterProd p)
                                | (_, a) => (changed := true; a)
                              (* end case *)
                            end
                      (* end case *))
                  | E.Opn(E.Prod, [e1, e2]) => filterProd [rewrite e1, rewrite e2]
                (* general product: rewrite the head and the tail separately,
                 * splicing the tail back in when it is still a product (and
                 * only a product: other operators such as Add stay nested)
                 *)
                  | E.Opn(E.Prod, e1 :: es) => let
                      val e' = rewrite e1
                      val e2 = rewrite (mkProd es)
                      in
                        case e2
                         of E.Opn(E.Prod, p') => filterProd (e' :: p')
                          | _ => filterProd [e', e2]
                        (* end case *)
                      end
                (* end case *))
        (* rewrite until a pass makes no change; count the changing passes *)
          fun loop (body, count) = let
                val body' = rewrite body
                in
                  if !changed
                    then (changed := false; loop(body', count + 1))
                    else (body', count)
                end
          val (b, count) = loop(body, 0)
          in
            if count = 0 then NONE else SOME(Ein.EIN{params=params, index=index, body=b})
          end

  end
```

 root@smlnj-gforge.cs.uchicago.edu ViewVC Help Powered by ViewVC 1.0.0