Fri Jul 1 14:46:08 2016 UTC (3 years, 3 months ago) by cchiw
File size: 4750 byte(s)
`epsilon rewrite incorrectly wrapped summation expression`
```(* normalize.sml
*
* This code is part of the Diderot Project (http://diderot-language.cs.uchicago.edu)
*
* COPYRIGHT (c) 2015 The University of Chicago
*)

structure Normalize : sig

    val transform : HighIR.program -> HighIR.program

  end = struct

    structure IR = HighIR
    structure Op = HighOps
    structure V = IR.Var
    structure ST = Stats

  (********** Counters for statistics **********)
    val cntUnused               = ST.newCounter "high-opt:unused"

    structure UnusedElim = UnusedElimFn (
        structure IR = IR
        val cntUnused = cntUnused)

  (* Set to true to trace the normalization steps on stdout.  Trace messages are passed
   * as thunks so that no EIN term is pretty-printed when tracing is disabled.
   *)
    val debug = false
    fun debugMsg (mkMsg : unit -> string list) =
          if debug then print(String.concat(mkMsg())) else ()

  (* the current use count of a variable *)
    fun useCount (IR.V{useCnt, ...}) = !useCnt

  (* adjust a variable's use count *)
    fun incUse (IR.V{useCnt, ...}) = (useCnt := !useCnt + 1)
    fun decUse (IR.V{useCnt, ...}) = (useCnt := !useCnt - 1)
    fun use x = (incUse x; x)

(*** OLD VERSION
    fun getEinApp x = (case V.getDef x
           of IR.EINAPP(e, arg) => SOME(e, arg)
            | _ => NONE
          (* end case *))
****)
  (* get the EIN application that "x" is bound to (if any).  Note that we are conservative
   * on following globals so as to avoid duplicating computation: only kernel- and
   * field-typed variables are looked up with V.getDef; every other variable is only
   * resolved through its local definition (V.getLocalDef).
   *)
    fun getEinApp x = let
          fun getEinRHS (IR.EINAPP app) = SOME app
            | getEinRHS _ = NONE
          in
            case V.ty x
             of HighTypes.KernelTy => getEinRHS(V.getDef x)
              | HighTypes.FieldTy => getEinRHS(V.getDef x)
              | _ => getEinRHS(V.getLocalDef x)
            (* end case *)
          end

  (* doNormalize : EIN -> EIN
   * Reorders the EIN term, then normalizes it.  When normalization produces a new term,
   * that term's summations are cleaned up (EinSums.clean); otherwise the reordered term
   * is returned as is.
   *)
    fun doNormalize e' = let
          val _ = debugMsg (fn () => ["\n\n\n do normalize:", EinPP.toString(e')])
          val ordered = Reorder.transform e'
          in
            case NormalizeEin.transform ordered
             of NONE => ordered
              | SOME e => (
                  debugMsg (fn () => ["\n transform :=>", EinPP.toString(e)]);
                  EinSums.clean e)
            (* end case *)
          end

  (* debugging helpers: render variables together with their use counts *)
    fun nameCnt e = String.concat[V.toString e, "(", Int.toString(useCount e), ")"]
    fun nameCnts e = String.concatWithMap "" nameCnt e

  (* rewriteEin : params * place * changed * newE * newArgs * done * newEinApp * orig * lhs
   *      -> changed' * ein' * place' * done'
   * Attempts to substitute the EIN term `newE` (bound to the argument variable `newEinApp`,
   * whose own arguments are `newArgs`) for the parameter at index `place` of `orig`.
   *   params    -- the parameter list of `orig`
   *   changed   -- true if an earlier argument has already been substituted
   *   done      -- the arguments of the rewritten application accumulated so far
   *   lhs       -- the variable being defined by the application
   * Results:
   *   - if the parameter is an `Ein.TEN(false, _)` (presumably a parameter that must not
   *     be substituted -- TODO confirm), `orig` is returned unchanged, `newEinApp` is kept
   *     as an argument, and `place` advances by one;
   *   - otherwise the result of `Apply.apply` is returned with `changed' = true`, the
   *     arguments of `newE` are appended in place of `newEinApp`, and `place` advances
   *     by the number of those arguments.
   *)
    fun rewriteEin (params, place, changed, newE, newArgs, done, newEinApp, orig, lhs) = (
          case List.nth(params, place)
           of Ein.TEN(false, _) => (
                incUse lhs; incUse newEinApp;
                (changed, orig, place+1, done@[newEinApp]))
            | _ => let
                val (c, subst) = Apply.apply(orig, place, newE)
                in
                  debugMsg (fn () => ["\n\n\n after substition:", EinPP.toString(subst)]);
                (* NOTE(review): use counts are only adjusted when Apply.apply reports
                 * false; confirm this against Apply.apply's contract.
                 *)
                  if c
                    then ()
                    else (incUse lhs; List.app incUse newArgs; decUse newEinApp);
                  (true, subst, place + length newArgs, done @ newArgs)
                end
          (* end case *))

  (* FIXME: it would be much more efficient to do all of the substitutions in one go,
   * instead of repeatedly rewriting the term for each argument.
   *)
  (* doRHS: HighIR.var * rhs -> (var * rhs) list option
   * Looks at each argument to the original EINAPP.  If an argument is itself bound to
   * an EINAPP, rewriteEin is called to substitute it into the operator.  "place" is the
   * parameter index in the EIN operator that corresponds to the current argument; it
   * is tracked because a substitution can replace one argument with several.
   * Returns NONE when no argument was substituted; otherwise the rewritten term is
   * normalized and returned as the new right-hand side of `lhs`.
   *)
    fun doRHS (lhs, IR.EINAPP(ein, args)) = let
          val _ = debugMsg (fn () => ["\n\n\n doRHS:", EinPP.toString(ein)])
          fun rewrite (false, _, _, [], _) = NONE
            | rewrite (true, orig, _, [], args') =
                SOME[(lhs, IR.EINAPP(doNormalize orig, args'))]
            | rewrite (changed, orig, place, e::es, args') = (case getEinApp e
                 of NONE => rewrite(changed, orig, place+1, es, args'@[e])
                  | SOME(newE, newA) => let
                      val Ein.EIN{params, ...} = orig
                      val (changed', orig', place', done') =
                            rewriteEin (params, place, changed, newE, newA, args', e, orig, lhs)
                      in
                        rewrite(changed', orig', place', es, done')
                      end
                (* end case *))
          in
            rewrite (false, ein, 0, args, [])
          end
      | doRHS _ = NONE

    structure Rewrite = RewriteFn (
      struct
        structure IR = IR
        val doAssign = doRHS

        fun doMAssign _ = NONE
        val elimUnusedVars = UnusedElim.reduce
      end)

    structure Promote = PromoteFn (IR)

    val transform = Promote.transform o Rewrite.transform

(*DEBUG*
    fun transform prog = let
          val prog = Rewrite.transform prog
          val _ = HighPP.output(Log.logFile(), "AFTER REWRITE", prog)
          val prog = Promote.transform prog
          val _ = HighPP.output(Log.logFile(), "AFTER PROMOTE", prog)
          in
            prog
          end
*DEBUG*)

  end
```