Home My Page Projects Code Snippets Project Openings diderot

SCM Repository

[diderot] View of /branches/lamont/src/compiler/simplify/reduction_lift.sml
 [diderot] / branches / lamont / src / compiler / simplify / reduction_lift.sml View of /branches/lamont/src/compiler/simplify/reduction_lift.sml

Sun Dec 9 19:54:32 2012 UTC (9 years, 1 month ago) by lamonts
File size: 3503 byte(s)
(* reduction_lift.sml
*
* COPYRIGHT (c) 2010 The Diderot Project (http://diderot-language.cs.uchicago.edu)
*
* Lift reduction operations to global scope.
*
* NOTE: this process can be streamlined as follows:
*	1) Find the reduction expressions.
*	2) Create a new global variable for each reduction expression and place its computation in the global reduction block.
*)

structure ReductionLift : sig

  (* transform (errStrm, prog)
   *
   * Returns a copy of prog in which each reduction expression that appears as the
   * right-hand side of an assignment or as a foreach range inside a strand method
   * (including inside nested foreach/if-then-else blocks) is replaced by a reference
   * to a fresh global variable.  The global is added to the program's globals, it is
   * initialized in the global-initialization block, and the reduction computation is
   * placed in the program's global-reduction block.
   *)
    val transform : Error.err_stream * Simple.program -> Simple.program

  end = struct

    structure S = Simple

  (* fresh names for lifted globals have the form "r_<op>_<n>", where <n> comes from a
   * counter that is shared by all calls to transform.
   *)
    local
      val cnt = ref 0
      fun genName prefix = let
            val n = !cnt
            in
              cnt := n+1;
              String.concat[prefix, "_", Int.toString n]
            end
    in
  (* create a new global variable of type ty to hold the result of a reduction whose
   * operator is named r.
   *)
    fun newGlobalR (r, ty) =
          Var.new (Atom.atom (genName (String.concat["r_", r])), AST.GlobalVar, ty)
    end

  (* the name of a reduction operator; used to build the lifted variable's name *)
    fun reductionToString r = (case r
           of S.R_Max => "max"
            | S.R_Min => "min"
            | S.R_Or => "or"
            | S.R_And => "and"
            | S.R_Xor => "xor"
            | S.R_Product => "product"
            | S.R_Sum => "sum"
          (* end case *))

  (* the literal that the lifted global is initialized with for reduction operator r.
   * NOTE(review): the numeric reductions always use a Float literal, regardless of the
   * reduction's result type -- confirm that non-real reductions cannot reach this pass.
   *)
    fun initReduction r = (case r
           of S.R_And => S.E_Lit(Literal.Bool true)
            | S.R_Or => S.E_Lit(Literal.Bool false)
            | S.R_Xor => S.E_Lit(Literal.Bool false)
            | S.R_Min => S.E_Lit(Literal.Float FloatLit.posInf)
            | S.R_Max => S.E_Lit(Literal.Float FloatLit.negInf)
            | S.R_Product => S.E_Lit(Literal.Float(FloatLit.fromInt 1))
            | S.R_Sum => S.E_Lit(Literal.Float(FloatLit.fromInt 0))
          (* end case *))

    fun transform (errStrm, prog as S.Program{
            globals, globalInit = S.Block globalInitStms, init, strands, inputs, ...
          }) = let
        (* lifted reduction computations, most recently lifted first *)
          val globalReduc = ref []
        (* fresh globals and their initializations, most recently created first *)
          val extraGlobals = ref []
          val extraGlobalInit = ref []

          fun reduceStrand (S.Strand{name, params, state, stateInit, methods}) = S.Strand{
                  name = name,
                  params = params,
                  state = state, stateInit = stateInit,
                  methods = List.map reduceMethod methods
                }

          and reduceMethod (S.Method(name, body)) = S.Method(name, reduceBlock body)

        (* reduceStmts builds its result in reverse, so restore the order here *)
          and reduceBlock (S.Block stms) = S.Block(List.rev (reduceStmts (stms, [])))

        (* rewrite each statement in order, consing the results onto acc (so acc ends up
         * reversed).  Only foreach, if-then-else, and assignment statements are
         * examined; all other statements are copied unchanged.
         *)
          and reduceStmts ([], acc) = acc
            | reduceStmts (stm::stms, acc) = let
                val stm' = (case stm
                       of S.S_Foreach(v, e, b) => S.S_Foreach(v, reduceExp e, reduceBlock b)
                        | S.S_IfThenElse(v, b1, b2) =>
                            S.S_IfThenElse(v, reduceBlock b1, reduceBlock b2)
                        | S.S_Assign(v, e) => S.S_Assign(v, reduceExp e)
                        | _ => stm
                      (* end case *))
                in
                  reduceStmts (stms, stm'::acc)
                end

        (* if e is a reduction, lift it: create a fresh global x', record x' and its
         * initialization, add the computation "x' = e" to the global-reduction block,
         * and return a reference to x'.  Only a reduction at the root of e is lifted;
         * any other expression is returned unchanged.
         *)
          and reduceExp (S.E_Reduction(r, sv, stms, x'', ty)) = let
                val x' = newGlobalR (reductionToString r, ty)
                val stm = S.S_Assign(x', S.E_Reduction(r, sv, stms, x'', ty))
                in
                  extraGlobalInit := S.S_Assign(x', initReduction r) :: !extraGlobalInit;
                  extraGlobals := x' :: !extraGlobals;
                (* NOTE(review): stms is emitted here at the top level of the reduction
                 * block and is also kept inside the lifted E_Reduction; confirm that
                 * the duplication is intended.
                 *)
                  globalReduc := (stms @ [stm]) @ !globalReduc;
                  S.E_Var x'
                end
            | reduceExp e = e

        (* Rewrite the strands first: this is what populates extraGlobals,
         * extraGlobalInit, and globalReduc.  Binding it explicitly avoids depending on
         * the evaluation order of the record fields below.
         *)
          val strands' = List.map reduceStrand strands
          val prog' = S.Program{
                  strands = strands',
                  globals = globals @ !extraGlobals,
                  globalInit = S.Block(globalInitStms @ !extraGlobalInit),
                  globalReduc = S.Block(!globalReduc),
                  init = init,
                  inputs = inputs
                }
(* val _ = SimplePP.output (Log.logFile(), prog)*)	(* DEBUG *)
          in
            prog'
          end

  end