Home My Page Projects Code Snippets Project Openings diderot

# SCM Repository

[diderot] View of /branches/vis12/src/compiler/simplify/simplify.sml
 [diderot] / branches / vis12 / src / compiler / simplify / simplify.sml # View of /branches/vis12/src/compiler/simplify/simplify.sml

Mon Oct 19 19:39:09 2015 UTC (5 years ago) by jhr
File size: 20082 byte(s)
```
Fixed bug042.  This bug was a side effect of the change in precedence, which
required a different treatment of negative literals.  We now constant-fold
them during simplification.
```
```
(* simplify.sml
*
* This code is part of the Diderot Project (http://diderot-language.cs.uchicago.edu)
*
* COPYRIGHT (c) 2015 The University of Chicago
*
* Simplify the AST representation.
*)

structure Simplify : sig

  (* simplify a type-checked AST program into the Simple representation (and then
   * run the inliner over the result).
   *)
    val transform : Error.err_stream * AST.program -> Simple.program

  end = struct

    structure TU = TypeUtil
    structure S = Simple
    structure VMap = Var.Map
    structure InP = Inputs

  (* convert an AST type to a Simple type *)
    val cvtTy = SimpleTypes.simplify

  (* create a fresh local temporary variable (named "_t") of the given Simple type *)
    fun newTemp ty = SimpleVar.new ("_t", SimpleVar.LocalVar, ty)

  (* convert an AST variable to a Simple variable.  Returns the new variable together
   * with the environment extended by the mapping x |-> x'.  Only variables whose
   * type scheme has no bound type variables (i.e., `([], ty)`) are matched.
   *)
    fun cvtVar (env, x as Var.V{name, kind, ty=([], ty), ...}) = let
          val x' = SimpleVar.new (name, kind, cvtTy ty)
          in
            (x', VMap.insert(env, x, x'))
          end

  (* convert a list of AST variables.  Returns the converted variables (in the same
   * order as xs) and the extended environment.
   *)
    fun cvtVars (env, xs) = List.foldr
          (fn (x, (xs, env)) => let
              val (x', env) = cvtVar(env, x)
              in
                (x'::xs, env)
              end) ([], env) xs

  (* map an AST variable to its Simple counterpart; raises Fail if x is unbound in env *)
    fun lookupVar (env, x) = (case VMap.find (env, x)
           of SOME x' => x'
            | NONE => raise Fail(concat["lookupVar(", Var.uniqueNameOf x, ")"])
          (* end case *))

  (* make a block out of a list of statements that are in reverse order *)
    fun mkBlock stms = S.Block(List.rev stms)

  (* get the image info for the nrrd file `nrrd`, checking it against the expected
   * dimension and shape; raises Fail if the file does not have the expected type.
   * Shared by image inputs and by image-loading expressions so both report the
   * same error.
   *)
    fun imageInfoOfNrrd (nrrd, dim, shape) = (
          case ImageInfo.fromNrrd(NrrdInfo.getInfo nrrd, dim, shape)
           of NONE => raise Fail(concat["nrrd file \"", nrrd, "\" does not have expected type"])
            | SOME info => info
          (* end case *))

  (* the input initializer (a proxy) for an image input that is loaded from a nrrd file *)
    fun inputImage (nrrd, dim, shape) = InP.Proxy(nrrd, imageInfoOfNrrd(nrrd, dim, shape))

  (* control-flow summary of a statement sequence, used when pruning unreachable code *)
    datatype 'a ctl_flow_info
      = EXIT                    (* stm sequence always exits; no pruning so far *)
      | PRUNE of 'a             (* stm sequence always exits at last stm in argument, which
                                 * is either a block or a stm list *)
      | CONT                    (* stm sequence falls through *)
      | EDIT of 'a              (* pruned code that has non-exiting paths *)

  (* remove statements that follow an unconditional exit (die, stabilize, or return).
   * An if-then-else whose two branches both exit is treated as an exit.  If nothing
   * needs pruning, the original block is returned.
   *)
    fun pruneUnreachableCode blk = let
          fun isExit S.S_Die = true
            | isExit S.S_Stabilize = true
            | isExit (S.S_Return _) = true
            | isExit _ = false
        (* summarize (and possibly prune) a list of statements *)
          fun pruneStms [] = CONT
            | pruneStms [S.S_IfThenElse(x, blk1, blk2)] = (
                case pruneIf(x, blk1, blk2)
                 of EXIT => EXIT
                  | PRUNE stm => PRUNE[stm]
                  | CONT => CONT
                  | EDIT stm => EDIT[stm]
                (* end case *))
            | pruneStms [stm] = if isExit stm then EXIT else CONT
            | pruneStms ((stm as S.S_IfThenElse(x, blk1, blk2))::stms) = (
                case pruneIf(x, blk1, blk2)
                 of EXIT => PRUNE[stm]          (* both branches exit, so drop stms *)
                  | PRUNE stm => PRUNE[stm]
                  | CONT => (case pruneStms stms
                       of PRUNE stms => PRUNE(stm::stms)
                        | EDIT stms => EDIT(stm::stms)
                        | EXIT => EXIT
                        | CONT => CONT
                      (* end case *))
                  | EDIT stm => (case pruneStms stms
                       of PRUNE stms => PRUNE(stm::stms)
                        | EDIT stms => EDIT(stm::stms)
                        | _ => EDIT(stm::stms)
                      (* end case *))
                (* end case *))
            | pruneStms (stm::stms) = if isExit stm
                then PRUNE[stm]
                else (case pruneStms stms
                   of PRUNE stms => PRUNE(stm::stms)
                    | EDIT stms => EDIT(stm::stms)
                    | info => info
                  (* end case *))
        (* combine the summaries of the two branches of a conditional *)
          and pruneIf (x, blk1, blk2) = (case (pruneBlk blk1, pruneBlk blk2)
                 of (EXIT,       EXIT      ) => EXIT
                  | (CONT,       CONT      ) => CONT
                  | (CONT,       EXIT      ) => CONT
                  | (EXIT,       CONT      ) => CONT
                  | (CONT,       EDIT blk2 ) => EDIT(S.S_IfThenElse(x, blk1, blk2))
                  | (EDIT blk1,  CONT      ) => EDIT(S.S_IfThenElse(x, blk1, blk2))
                  | (CONT,       PRUNE blk2) => EDIT(S.S_IfThenElse(x, blk1, blk2))
                  | (PRUNE blk1, CONT      ) => EDIT(S.S_IfThenElse(x, blk1, blk2))
                  | (EXIT,       EDIT blk2 ) => EDIT(S.S_IfThenElse(x, blk1, blk2))
                  | (EDIT blk1,  EXIT      ) => EDIT(S.S_IfThenElse(x, blk1, blk2))
                  | (EDIT blk1,  EDIT blk2 ) => EDIT(S.S_IfThenElse(x, blk1, blk2))
                  | (EDIT blk1,  PRUNE blk2) => EDIT(S.S_IfThenElse(x, blk1, blk2))
                  | (PRUNE blk1, EDIT blk2 ) => EDIT(S.S_IfThenElse(x, blk1, blk2))
                  | (EXIT,       PRUNE blk2) => PRUNE(S.S_IfThenElse(x, blk1, blk2))
                  | (PRUNE blk1, EXIT      ) => PRUNE(S.S_IfThenElse(x, blk1, blk2))
                  | (PRUNE blk1, PRUNE blk2) => PRUNE(S.S_IfThenElse(x, blk1, blk2))
                (* end case *))
        (* summarize a block by summarizing its statement list *)
          and pruneBlk (S.Block stms) = (case pruneStms stms
                 of PRUNE stms => PRUNE(S.Block stms)
                  | EDIT stms => EDIT(S.Block stms)
                  | EXIT => EXIT (* different instances of ctl_flow_info *)
                  | CONT => CONT
                (* end case *))
          in
            case pruneBlk blk
             of PRUNE blk => blk
              | EDIT blk => blk
              | _ => blk
            (* end case *)
          end

  (* simplify a whole program.  The declarations are processed in order, threading
   * the variable environment; the per-kind results are accumulated in reverse
   * order and reversed when the Simple program is built.
   *)
    fun simplifyProgram (AST.Program{props, decls}) = let
          val inputs = ref []
          val inputInit = ref []
          val globals = ref []
          val globalInit = ref []
          val funcs = ref []
          val initially = ref NONE
          val strands = ref []
          fun setInitially init = (case !initially
                 of NONE => initially := SOME init
(* FIXME: the check for multiple initially decls should happen in type checking *)
                  | SOME _ => raise Fail "multiple initially declarations"
                (* end case *))
          fun simplifyDecl (dcl, env) = (case dcl
                 of AST.D_Input(x, desc, NONE) => let
                      val (x', env) = cvtVar(env, x)
                    (* image inputs without a default still get their image info *)
                      val (ty, init) = (case SimpleVar.typeOf x'
                             of ty as SimpleTypes.T_Image{dim, shape} => let
                                  val info = ImageInfo.mkInfo(dim, shape)
                                  in
                                    (ty, SOME(InP.Image info))
                                  end
                              | ty => (ty, NONE)
                            (* end case *))
                      val inp = InP.INP{
                              ty = ty,
                              name = SimpleVar.nameOf x',
                              desc = desc,
                              init = init
                            }
                      in
                        inputs := (x', inp) :: !inputs;
                        env
                      end
                  | AST.D_Input(x, desc, SOME(AST.E_LoadNrrd(tvs, nrrd, ty))) => let
                      val (x', env) = cvtVar(env, x)
                    (* load the nrrd proxy here.
                     * NOTE(review): the result of this call is unused; it is kept so that
                     * NrrdInfo.getInfo is still invoked for dynamic-sequence inputs
                     * (presumably validating the file).  Confirm before removing it.
                     *)
                      val _ = NrrdInfo.getInfo nrrd
                      val (ty, init) = (case SimpleVar.typeOf x'
                             of ty as SimpleTypes.T_DynSequence _ => (ty, InP.DynSeq nrrd)
                              | ty as SimpleTypes.T_Image{dim, shape} => (ty, inputImage(nrrd, dim, shape))
                              | _ => raise Fail "impossible"
                            (* end case *))
                      val inp = InP.INP{
                              ty = ty,
                              name = SimpleVar.nameOf x',
                              desc = desc,
                              init = SOME init
                            }
                      in
                        inputs := (x', inp) :: !inputs;
                        env
                      end
                  | AST.D_Input(x, desc, SOME e) => let
                    (* a general default expression: its code goes into inputDefaults *)
                      val (x', env) = cvtVar(env, x)
                      val (stms, e') = simplifyExp (env, e, [])
                      val inp = InP.INP{
                              ty = SimpleVar.typeOf x',
                              name = SimpleVar.nameOf x',
                              desc = desc,
                              init = NONE
                            }
                      in
                        inputs := (x', inp) :: !inputs;
                      (* inputInit is in reverse order, so the assignment goes in front of stms *)
                        inputInit := S.S_Assign(x', e') :: (stms @ !inputInit);
                        env
                      end
                  | AST.D_Var(AST.VD_Decl(x, e)) => let
                      val (x', env) = cvtVar(env, x)
                      val (stms, e') = simplifyExp (env, e, [])
                      in
                        globals := x' :: !globals;
                        globalInit := S.S_Assign(x', e') :: (stms @ !globalInit);
                        env
                      end
                  | AST.D_Func(f, params, body) => let
                    (* f is bound before the body is simplified, so recursive calls resolve *)
                      val (f', env) = cvtVar(env, f)
                      val (params', env) = cvtVars (env, params)
                      val body' = pruneUnreachableCode (simplifyBlock(env, body))
                      in
                        funcs := S.Func{f=f', params=params', body=body'} :: !funcs;
                        env
                      end
                  | AST.D_Strand info => (
                      strands := simplifyStrand(env, info) :: !strands;
                      env)
                  | AST.D_InitialArray(creat, iters) => (
                      setInitially (simplifyInit(env, true, creat, iters));
                      env)
                  | AST.D_InitialCollection(creat, iters) => (
                      setInitially (simplifyInit(env, false, creat, iters));
                      env)
                (* end case *))
        (* processed for the side effects on the accumulators; the final env is not needed *)
          val _ = List.foldl simplifyDecl VMap.empty decls
          in
            S.Program{
                props = props,
                inputDefaults = mkBlock (!inputInit),
                inputs = List.rev(!inputs),
                globals = List.rev(!globals),
                globalInit = mkBlock (!globalInit),
                funcs = List.rev(!funcs),
                init = (case !initially
(* FIXME: the check for the initially block should really happen in typechecking *)
                   of NONE => raise Fail "missing initially declaration"
                    | SOME blk => blk
                  (* end case *)),
                strands = List.rev(!strands)
              }
          end

  (* simplify the initially declaration.  isArray is true for an initial array and
   * false for an initial collection.  The range bounds of each iterator are evaluated
   * in the environment of the preceding iterators.
   *)
    and simplifyInit (env, isArray, AST.C_Create(strand, exps), iters) = let
          fun simplifyIter (AST.I_Range(x, e1, e2), (env, iters, stms)) = let
                val (stms, lo) = simplifyExpToVar (env, e1, stms)
                val (stms, hi) = simplifyExpToVar (env, e2, stms)
                val (x', env) = cvtVar (env, x)
                in
                  (env, {param=x', lo=lo, hi=hi}::iters, stms)
                end
          val (env, iters, iterStms) = List.foldl simplifyIter (env, [], []) iters
          val (stms, xs) = simplifyExpsToVars (env, exps, [])
          val creat = S.C_Create{
                  argInit = mkBlock stms,
                  name = strand,
                  args = xs
                }
          in
            S.Initially{
                isArray = isArray,
                rangeInit = mkBlock iterStms,
                iters = List.rev iters,
                create = creat
              }
          end

  (* simplify a strand definition: its parameters, its state variables (each
   * initializer sees the previously-declared state variables), and its methods.
   *)
    and simplifyStrand (env, AST.Strand{name, params, state, methods}) = let
          val (params', env) = cvtVars (env, params)
          fun simplifyState (env, [], xs, stms) = (List.rev xs, mkBlock stms, env)
            | simplifyState (env, AST.VD_Decl(x, e) :: r, xs, stms) = let
                val (stms, e') = simplifyExp (env, e, stms)
                val (x', env) = cvtVar(env, x)
                in
                  simplifyState (env, r, x'::xs, S.S_Assign(x', e') :: stms)
                end
          val (xs, stm, env) = simplifyState (env, state, [], [])
          in
            S.Strand{
                name = name,
                params = params',
                state = xs, stateInit = stm,
                methods = List.map (simplifyMethod env) methods
              }
          end

  (* simplify a strand method body and prune its unreachable code *)
    and simplifyMethod env (AST.M_Method(name, body)) =
          S.Method(name, pruneUnreachableCode (simplifyBlock(env, body)))

  (* simplify a statement into a single statement (i.e., a block if it expands
   * into more than one new statement).
   *)
    and simplifyBlock (env, stm) = mkBlock (#1 (simplifyStmt (env, stm, [])))

  (* simplify the statement stm where stms is a reverse-order list of preceding simplified
   * statements.  This function returns a reverse-order list of simplified statements
   * and the environment for the statements that follow.
   * Note that error reporting is done in the typechecker, but it does not prune unreachable
   * code.
   *)
    and simplifyStmt (env, stm, stms) = (case stm
           of AST.S_Block body => let
              (* declarations inside the block are scoped to it, so the outer env is returned *)
                fun simplify (_, [], stms) = stms
                  | simplify (env', stm::r, stms) = let
                      val (stms, env') = simplifyStmt (env', stm, stms)
                      in
                        simplify (env', r, stms)
                      end
                in
                  (simplify (env, body, stms), env)
                end
            | AST.S_Decl(AST.VD_Decl(x, e)) => let
                val (stms, e') = simplifyExp (env, e, stms)
                val (x', env) = cvtVar(env, x)
                in
                  (S.S_Assign(x', e') :: stms, env)
                end
            | AST.S_IfThenElse(e, s1, s2) => let
                val (stms, x) = simplifyExpToVar (env, e, stms)
                val s1 = simplifyBlock (env, s1)
                val s2 = simplifyBlock (env, s2)
                in
                  (S.S_IfThenElse(x, s1, s2) :: stms, env)
                end
            | AST.S_Assign(x, e) => let
                val (stms, e') = simplifyExp (env, e, stms)
                in
                  (S.S_Assign(lookupVar(env, x), e') :: stms, env)
                end
            | AST.S_New(name, args) => let
                val (stms, xs) = simplifyExpsToVars (env, args, stms)
                in
                  (S.S_New(name, xs) :: stms, env)
                end
            | AST.S_Continue => (S.S_Continue :: stms, env)
            | AST.S_Die => (S.S_Die :: stms, env)
            | AST.S_Stabilize => (S.S_Stabilize :: stms, env)
            | AST.S_Return e => let
                val (stms, x) = simplifyExpToVar (env, e, stms)
                in
                  (S.S_Return x :: stms, env)
                end
            | AST.S_Print args => let
                val (stms, xs) = simplifyExpsToVars (env, args, stms)
                in
                  (S.S_Print xs :: stms, env)
                end
          (* end case *))

  (* simplify the expression exp, where stms is a reverse-order list of preceding
   * simplified statements.  Returns the extended reverse-order statement list and a
   * Simple expression whose operands are variables.
   *)
    and simplifyExp (env, exp, stms) = let
        (* an application is either a user-function call or a basis primitive *)
          fun doApply (f, tyArgs, args, ty) = let
                val (stms, xs) = simplifyExpsToVars (env, args, stms)
                in
                  case Var.kindOf f
                   of S.FunVar => (stms, S.E_Apply(lookupVar(env, f), xs, cvtTy ty))
                    | S.BasisVar => let
                        fun cvtTyArg (Types.TYPE tv) = S.TY(cvtTy(TU.resolve tv))
                          | cvtTyArg (Types.DIFF dv) = S.DIFF(TU.monoDiff(TU.resolveDiff dv))
                          | cvtTyArg (Types.SHAPE sv) = S.SHAPE(TU.monoShape(TU.resolveShape sv))
                          | cvtTyArg (Types.DIM dv) = S.DIM(TU.monoDim(TU.resolveDim dv))
                        val tyArgs = List.map cvtTyArg tyArgs
                        in
                          (stms, S.E_Prim(f, tyArgs, xs, cvtTy ty))
                        end
                    | _ => raise Fail "bogus application"
                  (* end case *)
                end
          in
            case exp
             of AST.E_Var x => (case Var.kindOf x
                   of Var.BasisVar => let
                      (* a basis constant becomes a nullary primitive bound to a temporary *)
                        val ty = cvtTy(Var.monoTypeOf x)
                        val x' = newTemp ty
                        val stm = S.S_Assign(x', S.E_Prim(x, [], [], ty))
                        in
                          (stm::stms, S.E_Var x')
                        end
                    | _ => (stms, S.E_Var(lookupVar(env, x)))
                  (* end case *))
              | AST.E_Lit lit => (stms, S.E_Lit lit)
              | AST.E_Tuple es => raise Fail "E_Tuple not yet implemented"
              | AST.E_Apply(rator, tyArgs, args as [AST.E_Lit(Literal.Int n)], ty) =>
                (* constant-fold negation of integer literals *)
                  if Var.same(BasisVars.neg_i, rator)
                    then (stms, S.E_Lit(Literal.Int(~n)))
                    else doApply (rator, tyArgs, args, ty)
              | AST.E_Apply(rator, tyArgs, args as [AST.E_Lit(Literal.Float f)], ty as Types.T_Tensor sh) =>
                (* constant-fold negation of real literals.  A real is a tensor of empty
                 * shape, so the operator to match is the tensor negation neg_t at scalar
                 * shape (neg_i is the integer form handled in the case above).
                 *)
                  if Var.same(BasisVars.neg_t, rator) andalso List.null(TU.monoShape sh)
                    then (stms, S.E_Lit(Literal.Float(FloatLit.negate f)))
                    else doApply (rator, tyArgs, args, ty)
              | AST.E_Apply(f, tyArgs, args, ty) => doApply (f, tyArgs, args, ty)
              | AST.E_Cons es => let
                  val (stms, xs) = simplifyExpsToVars (env, es, stms)
                  in
                    (stms, S.E_Cons xs)
                  end
              | AST.E_Seq(es, ty) => let
                  val (stms, xs) = simplifyExpsToVars (env, es, stms)
                  in
                    (stms, S.E_Seq(xs, cvtTy ty))
                  end
              | AST.E_Slice(e, indices, ty) => let (* tensor slicing *)
                  val (stms, x) = simplifyExpToVar (env, e, stms)
                (* NONE marks an unsliced axis; SOME e is an index expression *)
                  fun f ([], ys, stms) = (stms, List.rev ys)
                    | f (NONE::es, ys, stms) = f (es, NONE::ys, stms)
                    | f (SOME e::es, ys, stms) = let
                        val (stms, y) = simplifyExpToVar (env, e, stms)
                        in
                          f (es, SOME y::ys, stms)
                        end
                  val (stms, indices) = f (indices, [], stms)
                  in
                    (stms, S.E_Slice(x, indices, cvtTy ty))
                  end
              | AST.E_Cond(e1, e2, e3, ty) => let
                (* a conditional expression gets turned into an if-then-else statement *)
                  val result = newTemp(cvtTy ty)
                  val (stms, x) = simplifyExpToVar (env, e1, S.S_Var result :: stms)
                  fun simplifyBranch e = let
                        val (stms, e) = simplifyExp (env, e, [])
                        in
                          mkBlock (S.S_Assign(result, e)::stms)
                        end
                  val s1 = simplifyBranch e2
                  val s2 = simplifyBranch e3
                  in
                    (S.S_IfThenElse(x, s1, s2) :: stms, S.E_Var result)
                  end
              | AST.E_LoadNrrd(_, nrrd, ty) => (case cvtTy ty
                   of ty as SimpleTypes.T_DynSequence _ => (stms, S.E_LoadSeq(ty, nrrd))
                    | ty as SimpleTypes.T_Image{dim, shape} =>
                        (stms, S.E_LoadImage(ty, nrrd, imageInfoOfNrrd(nrrd, dim, shape)))
                    | _ => raise Fail "bogus type for E_LoadNrrd"
                  (* end case *))
              | AST.E_Coerce{srcTy, dstTy, e} => let
                  val (stms, x) = simplifyExpToVar (env, e, stms)
                  val dstTy = cvtTy dstTy
                  val result = newTemp dstTy
                  val rhs = S.E_Coerce{srcTy = cvtTy srcTy, dstTy = dstTy, x = x}
                  in
                    (S.S_Assign(result, rhs)::stms, S.E_Var result)
                  end
            (* end case *)
          end

  (* simplify exp and make sure that the result is a variable, introducing a fresh
   * temporary (assigned in the returned statement list) when it is not already one.
   *)
    and simplifyExpToVar (env, exp, stms) = let
          val (stms, e) = simplifyExp (env, exp, stms)
          in
            case e
             of S.E_Var x => (stms, x)
              | _ => let
                  val x = newTemp (S.typeOf e)
                  in
                    (S.S_Assign(x, e)::stms, x)
                  end
            (* end case *)
          end

  (* simplify a list of expressions left-to-right into variables; the variables are
   * returned in the same order as exps.
   *)
    and simplifyExpsToVars (env, exps, stms) = let
          fun f ([], xs, stms) = (stms, List.rev xs)
            | f (e::es, xs, stms) = let
                val (stms, x) = simplifyExpToVar (env, e, stms)
                in
                  f (es, x::xs, stms)
                end
          in
            f (exps, [], stms)
          end

  (* simplify the program, then inline; both intermediate forms are logged for debugging.
   * NOTE(review): errStrm is currently unused.
   *)
    fun transform (errStrm, ast) = let
          val simple = simplifyProgram ast
          val _ = SimplePP.output (Log.logFile(), "simplify", simple)   (* DEBUG *)
          val simple = Inliner.transform simple
          val _ = SimplePP.output (Log.logFile(), "inlining", simple)   (* DEBUG *)
          in
            simple
          end

  end
```