Home My Page Projects Code Snippets Project Openings diderot

# SCM Repository

[diderot] Diff of /branches/vis15/src/compiler/simplify/simplify.sml
 [diderot] / branches / vis15 / src / compiler / simplify / simplify.sml

# Diff of /branches/vis15/src/compiler/simplify/simplify.sml

revision 4393, Tue Aug 9 22:00:05 2016 UTC revision 4394, Wed Aug 10 01:03:33 2016 UTC
# Line 114  Line 114
114    (* make a variable definition *)    (* make a variable definition *)
115      fun mkDef (x, e) = S.S_Var(x, SOME e)      fun mkDef (x, e) = S.S_Var(x, SOME e)
116
117        fun mkRDiv (res, a, b) = mkDef (res, S.E_Prim(BV.div_rr, [], [a, b], STy.realTy))
118        fun mkToReal (res, a) =
119              mkDef (res, S.E_Coerce{srcTy = STy.T_Int, dstTy = STy.realTy, x = a})
120        fun mkLength (res, elemTy, xs) =
121              mkDef (res, S.E_Prim(BV.fn_length, [STy.TY elemTy], [xs], STy.T_Int))
122
123    (* simplify a statement into a single statement (i.e., a block if it expands    (* simplify a statement into a single statement (i.e., a block if it expands
124     * into more than one new statement).     * into more than one new statement).
125     *)     *)
# Line 297  Line 303
303                                    val rNum = SimpleVar.new ("rNum", Var.LocalVar, STy.realTy)                                    val rNum = SimpleVar.new ("rNum", Var.LocalVar, STy.realTy)
304                                    val mean = SimpleVar.new ("mean", Var.LocalVar, STy.realTy)                                    val mean = SimpleVar.new ("mean", Var.LocalVar, STy.realTy)
305                                    val stms =                                    val stms =
306                                          mkDef(mean,                                          mkRDiv (mean, resultV, rNum) ::
307                                            S.E_Prim(BV.div_rr, [], [resultV, rNum], STy.realTy)) ::                                          mkToReal (rNum, num) ::
308                                          mkDef(rNum, S.E_Coerce{                                          mkLength (num, elemTy, xs) ::
srcTy = STy.T_Int, dstTy = STy.realTy, x = num
}) ::
mkDef(num,
S.E_Prim(BV.fn_length, [STy.TY elemTy], [xs], STy.T_Int)) ::
309                                          stms                                          stms
310                                    in                                    in
311                                      (stms, S.E_Var mean)                                      (stms, S.E_Var mean)
# Line 316  Line 318
318                      | AST.E_ParallelMap(e', x, xs, _) =>                      | AST.E_ParallelMap(e', x, xs, _) =>
319                          if Basis.isReductionOp rator                          if Basis.isReductionOp rator
320                            then let                            then let
321                              val (result, stm) = simplifyReduction (cxt, rator, e', x, xs, ty)                              val (result, stms) = simplifyReduction (cxt, rator, e', x, xs, ty, stms)
322                              in                              in
323                                (stm :: stms, S.E_Var result)                                (stms, S.E_Var result)
324                              end                              end
325                            else raise Fail "unsupported operation on parallel map"                            else raise Fail "unsupported operation on parallel map"
326                      | _ => doPrimApply (rator, tyArgs, args, ty)                      | _ => doPrimApply (rator, tyArgs, args, ty)
# Line 452  Line 454
454            end            end
455
456    (* simplify a parallel map-reduce *)    (* simplify a parallel map-reduce *)
457      and simplifyReduction (cxt, rator, e, x, xs, resTy) = let      and simplifyReduction (cxt, rator, e, x, xs, resTy, stms) = let
458              val result = SimpleVar.new ("res", Var.LocalVar, cvtTy resTy)              val result = SimpleVar.new ("res", Var.LocalVar, cvtTy resTy)
(* convert the reduction operator from a variable to a Reductions.t value *)
val rator' = if Var.same(BV.red_all, rator) then Reductions.ALL
else if Var.same(BV.red_exists, rator) then Reductions.EXISTS
else if Var.same(BV.red_max, rator) then Reductions.MAX
(* use SUM and divide by number of strands *)
else if Var.same(BV.red_mean, rator) then raise Fail "FIXME: mean reduction"
else if Var.same(BV.red_min, rator) then Reductions.MIN
else if Var.same(BV.red_product, rator) then Reductions.PRODUCT
else if Var.same(BV.red_sum, rator) then Reductions.SUM
(* two passes *)
else if Var.same(BV.red_variance, rator) then raise Fail "FIXME: variance reduction"
else raise Fail "impossible: not a reduction"
459              val x' = cvtVar x              val x' = cvtVar x
460              val (bodyStms, bodyResult) = simplifyExpToVar (cxt, e, [])              val (bodyStms, bodyResult) = simplifyExpToVar (cxt, e, [])
461            (* convert the domain from a variable to a StrandSets.t value *)            (* convert the domain from a variable to a StrandSets.t value *)
# Line 476  Line 466
466              val (func, args) = Util.makeFunction(              val (func, args) = Util.makeFunction(
467                    Var.nameOf rator, mkBlock(S.S_Return bodyResult :: bodyStms),                    Var.nameOf rator, mkBlock(S.S_Return bodyResult :: bodyStms),
468                    SimpleVar.typeOf bodyResult)                    SimpleVar.typeOf bodyResult)
469
470                in
471                  case Util.identifyReduction rator
472                   of Util.MEAN => let
473                        val mapReduceStm = S.S_MapReduce[
474                                S.MapReduce{
475                                    result = result, reduction = Reductions.SUM, mapf = func,
476                                    args = args, source = x', domain = domain
477                                  }]
478                        val num = SimpleVar.new ("num", Var.LocalVar, STy.T_Int)
479                        val rNum = SimpleVar.new ("rNum", Var.LocalVar, STy.realTy)
480                        val mean = SimpleVar.new ("mean", Var.LocalVar, STy.realTy)
481                        val numStrandsOp = (case domain
482                               of StrandSets.ACTIVE => BV.numActive
483                                | StrandSets.ALL => BV.numStrands
484                                | StrandSets.STABLE => BV.numStable
485                              (* end case *))
486                        val stms =
487                              mkRDiv (mean, result, rNum) ::
488                              mkToReal (rNum, num) ::
489                              mkDef (num, S.E_Prim(numStrandsOp, [], [], STy.T_Int)) ::
490                              mapReduceStm ::
491                              stms
492                        in
493                          (mean, stms)
494                        end
495                    | Util.VARIANCE => raise Fail "FIXME: variance reduction"
496                    | Util.RED rator' => let
497              val mapReduceStm = S.S_MapReduce[              val mapReduceStm = S.S_MapReduce[
498                      S.MapReduce{                      S.MapReduce{
499                          result = result, reduction = rator', mapf = func, args = args,                          result = result, reduction = rator', mapf = func, args = args,
500                          source = x', domain = domain                          source = x', domain = domain
501                        }]                        }]
502              in              in
503                (result, mapReduceStm)                        (result, mapReduceStm :: stms)
504                        end
505                  (* end case *)
506              end              end
507
508    (* simplify a block and then prune unreachable and dead code *)    (* simplify a block and then prune unreachable and dead code *)

Legend:
 Removed from v.4393 changed lines Added in v.4394