Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] View of /branches/vis12/src/compiler/simplify/simplify.sml
ViewVC logotype

View of /branches/vis12/src/compiler/simplify/simplify.sml

Parent Directory Parent Directory | Revision Log Revision Log


Revision 2823 - (download) (annotate)
Sun Nov 9 03:57:35 2014 UTC (4 years, 11 months ago) by jhr
File size: 19770 byte(s)
  new unreachable code pruning; checks for missing strands/initially (the parser prevents
  these errors currently, but future language evolution may change that)
(* simplify.sml
 *
 * COPYRIGHT (c) 2014 The Diderot Project (http://diderot-language.cs.uchicago.edu)
 * All rights reserved.
 *
 * Simplify the AST representation.
 *)

structure Simplify : sig

  (* Convert a type-checked AST program into the Simple IR.  Function and method
   * bodies have unreachable code pruned, and the result is then run through
   * Inliner.transform.
   *)
    val transform : Error.err_stream * AST.program -> Simple.program

  end = struct

    structure TU = TypeUtil
    structure S = Simple
    structure VMap = Var.Map
    structure InP = Inputs

    val cvtTy = SimpleTypes.simplify

  (* create a fresh local temporary variable of the given Simple type *)
    fun newTemp ty = SimpleVar.new ("_t", SimpleVar.LocalVar, ty)

  (* convert an AST variable to a Simple variable, returning the new variable and
   * the environment extended with the mapping.  The pattern only matches variables
   * whose type has an empty first component (presumably a monomorphic scheme);
   * any other variable raises Match.
   *)
    fun cvtVar (env, x as Var.V{name, kind, ty=([], ty), ...}) = let
          val x' = SimpleVar.new (name, kind, cvtTy ty)
          in
            (x', VMap.insert(env, x, x'))
          end

  (* convert a list of AST variables; the converted variables are returned in the
   * same order as the input, paired with the extended environment.
   *)
    fun cvtVars (env, xs) = List.foldr
          (fn (x, (xs, env)) => let
            val (x', env) = cvtVar(env, x)
            in
              (x'::xs, env)
            end) ([], env) xs

  (* map an AST variable to its Simple counterpart; raises Fail if it is unbound *)
    fun lookupVar (env, x) = (case VMap.find (env, x)
           of SOME x' => x'
            | NONE => raise Fail(concat["lookupVar(", Var.uniqueNameOf x, ")"])
          (* end case *))

  (* make a block out of a list of statements that are in reverse order *)
    fun mkBlock stms = S.Block(List.rev stms)

  (* get the image info for a nrrd file, checked against the expected dimension
   * and shape; raises Fail if the file does not have the expected type.  Shared
   * by image inputs and by image-loading expressions.
   *)
    fun loadImageInfo (nrrd, dim, shape) = (
          case ImageInfo.fromNrrd(NrrdInfo.getInfo nrrd, dim, shape)
           of NONE => raise Fail(concat["nrrd file \"", nrrd, "\" does not have expected type"])
            | SOME info => info
          (* end case *))

  (* the initial value of an image input that is loaded from a nrrd file *)
    fun inputImage (nrrd, dim, shape) = InP.Proxy(nrrd, loadImageInfo(nrrd, dim, shape))

  (* Control-flow summary of a statement sequence, computed while pruning.  The
   * argument of PRUNE and EDIT is the rewritten code; it is a block, a single
   * statement, or a statement list, depending on which helper produced it.
   *)
    datatype 'a ctl_flow_info
      = EXIT            (* stm sequence always exits; nothing pruned so far *)
      | PRUNE of 'a     (* stm sequence always exits; the argument is the pruned code *)
      | CONT            (* stm sequence may fall through; nothing pruned so far *)
      | EDIT of 'a      (* pruned code that still has non-exiting paths *)

  (* Remove statements that can never execute because they follow a statement
   * that always exits (die, stabilize, or return), including a conditional whose
   * branches both always exit.  Returns the (possibly rewritten) block.
   *)
    fun pruneUnreachableCode blk = let
          fun isExit S.S_Die = true
            | isExit S.S_Stabilize = true
            | isExit (S.S_Return _) = true
            | isExit _ = false
        (* prune a statement list; the argument of PRUNE/EDIT is a statement list *)
          fun pruneStms [] = CONT
            | pruneStms [S.S_IfThenElse(x, blk1, blk2)] = (
                case pruneIf(x, blk1, blk2)
                 of EXIT => EXIT
                  | PRUNE stm' => PRUNE[stm']
                  | CONT => CONT
                  | EDIT stm' => EDIT[stm']
                (* end case *))
            | pruneStms [stm] = if isExit stm then EXIT else CONT
            | pruneStms ((stm as S.S_IfThenElse(x, blk1, blk2))::stms) = (
                case pruneIf(x, blk1, blk2)
                 of EXIT => PRUNE[stm]          (* the remaining stms are unreachable *)
                  | PRUNE stm' => PRUNE[stm']   (* the remaining stms are unreachable *)
                  | CONT => (case pruneStms stms
                       of PRUNE stms' => PRUNE(stm::stms')
                        | EDIT stms' => EDIT(stm::stms')
                        | EXIT => EXIT
                        | CONT => CONT
                      (* end case *))
                  | EDIT stm' => (case pruneStms stms
                       of PRUNE stms' => PRUNE(stm'::stms')
                        | EDIT stms' => EDIT(stm'::stms')
                      (* the unchanged tail always exits, so the edited sequence
                       * always exits too; report PRUNE so that an enclosing
                       * conditional can be recognized as always exiting.
                       *)
                        | EXIT => PRUNE(stm'::stms)
                        | CONT => EDIT(stm'::stms)
                      (* end case *))
                (* end case *))
            | pruneStms (stm::stms) = if isExit stm
                then PRUNE[stm]  (* the remaining stms are unreachable *)
                else (case pruneStms stms
                   of PRUNE stms' => PRUNE(stm::stms')
                    | EDIT stms' => EDIT(stm::stms')
                    | info => info
                  (* end case *))
        (* prune both branches of a conditional; the argument of PRUNE/EDIT is the
         * rewritten S_IfThenElse statement.  The conditional always exits only
         * when both branches do.
         *)
          and pruneIf (x, blk1, blk2) = (case (pruneBlk blk1, pruneBlk blk2)
                 of (EXIT,       EXIT      ) => EXIT
                  | (CONT,       CONT      ) => CONT
                  | (CONT,       EXIT      ) => CONT
                  | (EXIT,       CONT      ) => CONT
                  | (CONT,       EDIT blk2 ) => EDIT(S.S_IfThenElse(x, blk1, blk2))
                  | (EDIT blk1,  CONT      ) => EDIT(S.S_IfThenElse(x, blk1, blk2))
                  | (CONT,       PRUNE blk2) => EDIT(S.S_IfThenElse(x, blk1, blk2))
                  | (PRUNE blk1, CONT      ) => EDIT(S.S_IfThenElse(x, blk1, blk2))
                  | (EXIT,       EDIT blk2 ) => EDIT(S.S_IfThenElse(x, blk1, blk2))
                  | (EDIT blk1,  EXIT      ) => EDIT(S.S_IfThenElse(x, blk1, blk2))
                  | (EDIT blk1,  EDIT blk2 ) => EDIT(S.S_IfThenElse(x, blk1, blk2))
                  | (EDIT blk1,  PRUNE blk2) => EDIT(S.S_IfThenElse(x, blk1, blk2))
                  | (PRUNE blk1, EDIT blk2 ) => EDIT(S.S_IfThenElse(x, blk1, blk2))
                  | (EXIT,       PRUNE blk2) => PRUNE(S.S_IfThenElse(x, blk1, blk2))
                  | (PRUNE blk1, EXIT      ) => PRUNE(S.S_IfThenElse(x, blk1, blk2))
                  | (PRUNE blk1, PRUNE blk2) => PRUNE(S.S_IfThenElse(x, blk1, blk2))
                (* end case *))
        (* prune a block; the argument of PRUNE/EDIT is a block *)
          and pruneBlk (S.Block stms) = (case pruneStms stms
                 of PRUNE stms' => PRUNE(S.Block stms')
                  | EDIT stms' => EDIT(S.Block stms')
                (* EXIT and CONT are rebuilt because the result has a different
                 * type argument than pruneStms's result.
                 *)
                  | EXIT => EXIT
                  | CONT => CONT
                (* end case *))
          in
            case pruneBlk blk
             of PRUNE blk' => blk'
              | EDIT blk' => blk'
              | _ => blk  (* EXIT or CONT: nothing was pruned *)
            (* end case *)
          end

  (* Convert an AST program to a Simple program.  Declarations are processed in
   * order, threading the variable environment through them; the pieces of the
   * program are accumulated in reverse order in refs and reversed at the end.
   * Raises Fail if there is no initially declaration or more than one.
   *)
    fun simplifyProgram (AST.Program{props, decls}) = let
          val inputs = ref []
          val inputInit = ref []
          val globals = ref []
          val globalInit = ref []
          val funcs = ref []
          val initially = ref NONE
          val strands = ref []
          fun setInitially init = (case !initially
                 of NONE => initially := SOME init
(* FIXME: the check for multiple initially decls should happen in type checking *)
                  | SOME _ => raise Fail "multiple initially declarations"
                (* end case *))
          fun simplifyDecl (dcl, env) = (case dcl
                 of AST.D_Input(x, desc, NONE) => let
                    (* input with no default value; image inputs get an image
                     * description built from their dimension and shape.
                     *)
                      val (x', env) = cvtVar(env, x)
                      val (ty, init) = (case SimpleVar.typeOf x'
                             of ty as SimpleTypes.T_Image{dim, shape} => let
                                  val info = ImageInfo.mkInfo(dim, shape)
                                  in
                                    (ty, SOME(InP.Image info))
                                  end
                              | ty => (ty, NONE)
                            (* end case *))
                      val inp = InP.INP{
                              ty = ty,
                              name = SimpleVar.nameOf x',
                              desc = desc,
                              init = init
                            }
                      in
                        inputs := (x', inp) :: !inputs;
                        env
                      end
                  | AST.D_Input(x, desc, SOME(AST.E_LoadNrrd(tvs, nrrd, ty))) => let
                    (* input whose default is loaded from a nrrd file *)
                      val (x', env) = cvtVar(env, x)
                    (* NOTE(review): the result of this call is unused (for images,
                     * inputImage fetches the nrrd info again); it is kept because
                     * it may have effects such as raising for a bad file -- confirm
                     * whether it is still needed.
                     *)
                      val _ = NrrdInfo.getInfo nrrd
                      val (ty, init) = (case SimpleVar.typeOf x'
                             of ty as SimpleTypes.T_DynSequence _ => (ty, InP.DynSeq nrrd)
                              | ty as SimpleTypes.T_Image{dim, shape} => (ty, inputImage(nrrd, dim, shape))
                              | _ => raise Fail "impossible"
                            (* end case *))
                      val inp = InP.INP{
                              ty = ty,
                              name = SimpleVar.nameOf x',
                              desc = desc,
                              init = SOME init
                            }
                      in
                        inputs := (x', inp) :: !inputs;
                        env
                      end
                  | AST.D_Input(x, desc, SOME e) => let
                    (* input with a default expression; the expression's code is
                     * added to the input-defaults block.
                     *)
                      val (x', env) = cvtVar(env, x)
                      val (stms, e') = simplifyExp (env, e, [])
                      val inp = InP.INP{
                              ty = SimpleVar.typeOf x',
                              name = SimpleVar.nameOf x',
                              desc = desc,
                              init = NONE
                            }
                      in
                        inputs := (x', inp) :: !inputs;
                      (* both stms and !inputInit are in reverse order *)
                        inputInit := S.S_Assign(x', e') :: (stms @ !inputInit);
                        env
                      end
                  | AST.D_Var(AST.VD_Decl(x, e)) => let
                      val (x', env) = cvtVar(env, x)
                      val (stms, e') = simplifyExp (env, e, [])
                      in
                        globals := x' :: !globals;
                      (* both stms and !globalInit are in reverse order *)
                        globalInit := S.S_Assign(x', e') :: (stms @ !globalInit);
                        env
                      end
                  | AST.D_Func(f, params, body) => let
                      val (f', env) = cvtVar(env, f)
                      val (params', env) = cvtVars (env, params)
                      val body' = pruneUnreachableCode (simplifyBlock(env, body))
                      in
                        funcs := S.Func{f=f', params=params', body=body'} :: !funcs;
                        env
                      end
                  | AST.D_Strand info => (
                      strands := simplifyStrand(env, info) :: !strands;
                      env)
                  | AST.D_InitialArray(creat, iters) => (
                      setInitially (simplifyInit(env, true, creat, iters));
                      env)
                  | AST.D_InitialCollection(creat, iters) => (
                      setInitially (simplifyInit(env, false, creat, iters));
                      env)
                (* end case *))
          val env = List.foldl simplifyDecl VMap.empty decls
          in
            S.Program{
                props = props,
                inputDefaults = mkBlock (!inputInit),
                inputs = List.rev(!inputs),
                globals = List.rev(!globals),
                globalInit = mkBlock (!globalInit),
                funcs = List.rev(!funcs),
                init = (case !initially
(* FIXME: the check for the initially block should really happen in typechecking *)
                   of NONE => raise Fail "missing initially declaration"
                    | SOME blk => blk
                  (* end case *)),
                strands = List.rev(!strands)
              }
          end

  (* Convert an initially declaration.  The iteration ranges are simplified in
   * order, and each iteration variable is in scope for the later ranges and for
   * the strand-creation arguments.  isArray is copied into the result.
   *)
    and simplifyInit (env, isArray, AST.C_Create(strand, exps), iters) = let
          fun simplifyIter (AST.I_Range(x, e1, e2), (env, iters, stms)) = let
                val (stms, lo) = simplifyExpToVar (env, e1, stms)
                val (stms, hi) = simplifyExpToVar (env, e2, stms)
                val (x', env) = cvtVar (env, x)
                in
                  (env, {param=x', lo=lo, hi=hi}::iters, stms)
                end
          val (env, iters, iterStms) = List.foldl simplifyIter (env, [], []) iters
          val (stms, xs) = simplifyExpsToVars (env, exps, [])
          val creat = S.C_Create{
                  argInit = mkBlock stms,
                  name = strand,
                  args = xs
                }
          in
            S.Initially{
                isArray = isArray,
                rangeInit = mkBlock iterStms,
                iters = List.rev iters,
                create = creat
              }
          end

  (* Convert a strand declaration.  Each state variable's initializer is simplified
   * before the variable itself is bound; the methods are converted in an
   * environment that includes the parameters and all of the state variables.
   *)
    and simplifyStrand (env, AST.Strand{name, params, state, methods}) = let
          val (params', env) = cvtVars (env, params)
          fun simplifyState (env, [], xs, stms) = (List.rev xs, mkBlock stms, env)
            | simplifyState (env, AST.VD_Decl(x, e) :: r, xs, stms) = let
                val (stms, e') = simplifyExp (env, e, stms)
                val (x', env) = cvtVar(env, x)
                in
                  simplifyState (env, r, x'::xs, S.S_Assign(x', e') :: stms)
                end
          val (xs, stm, env) = simplifyState (env, state, [], [])
          in
            S.Strand{
                name = name,
                params = params',
                state = xs, stateInit = stm,
                methods = List.map (simplifyMethod env) methods
              }
          end

  (* convert a strand method; unreachable code in its body is pruned *)
    and simplifyMethod env (AST.M_Method(name, body)) =
          S.Method(name, pruneUnreachableCode (simplifyBlock(env, body)))

  (* simplify a statement into a single statement (i.e., a block if it expands
   * into more than one new statement).
   *)
    and simplifyBlock (env, stm) = mkBlock (#1 (simplifyStmt (env, stm, [])))

  (* simplify the statement stm where stms is a reverse-order list of preceding simplified
   * statements.  This function returns a reverse-order list of simplified statements,
   * paired with the environment for the statements that follow.  Note that error
   * reporting is done in the typechecker, and unreachable code is not pruned here
   * (see pruneUnreachableCode).
   *)
    and simplifyStmt (env, stm, stms) = (case stm
           of AST.S_Block body => let
                fun simplify (_, [], stms) = stms
                  | simplify (env', stm::r, stms) = let
                      val (stms, env') = simplifyStmt (env', stm, stms)
                      in
                        simplify (env', r, stms)
                      end
                in
                (* return the outer env, so block-local declarations do not escape *)
                  (simplify (env, body, stms), env)
                end
            | AST.S_Decl(AST.VD_Decl(x, e)) => let
                val (stms, e') = simplifyExp (env, e, stms)
                val (x', env) = cvtVar(env, x)
                in
                  (S.S_Assign(x', e') :: stms, env)
                end
            | AST.S_IfThenElse(e, s1, s2) => let
                val (stms, x) = simplifyExpToVar (env, e, stms)
                val s1 = simplifyBlock (env, s1)
                val s2 = simplifyBlock (env, s2)
                in
                  (S.S_IfThenElse(x, s1, s2) :: stms, env)
                end
            | AST.S_Assign(x, e) => let
                val (stms, e') = simplifyExp (env, e, stms)
                in
                  (S.S_Assign(lookupVar(env, x), e') :: stms, env)
                end
            | AST.S_New(name, args) => let
                val (stms, xs) = simplifyExpsToVars (env, args, stms)
                in
                  (S.S_New(name, xs) :: stms, env)
                end
            | AST.S_Die => (S.S_Die :: stms, env)
            | AST.S_Stabilize => (S.S_Stabilize :: stms, env)
            | AST.S_Return e => let
                val (stms, x) = simplifyExpToVar (env, e, stms)
                in
                  (S.S_Return x :: stms, env)
                end
            | AST.S_Print args => let
                val (stms, xs) = simplifyExpsToVars (env, args, stms)
                in
                  (S.S_Print xs :: stms, env)
                end
          (* end case *))

  (* simplify the expression exp, where stms is a reverse-order list of preceding
   * simplified statements.  Returns the extended reverse-order statement list paired
   * with a Simple expression whose operands are variables.
   *)
    and simplifyExp (env, exp, stms) = (
          case exp
           of AST.E_Var x => (case Var.kindOf x
                 of Var.BasisVar => let
                    (* a reference to a basis variable becomes a nullary primitive *)
                      val ty = cvtTy(Var.monoTypeOf x)
                      val x' = newTemp ty
                      val stm = S.S_Assign(x', S.E_Prim(x, [], [], ty))
                      in
                        (stm::stms, S.E_Var x')
                      end
                  | _ => (stms, S.E_Var(lookupVar(env, x)))
                (* end case *))
            | AST.E_Lit lit => (stms, S.E_Lit lit)
            | AST.E_Tuple es => raise Fail "E_Tuple not yet implemented"
            | AST.E_Apply(f, tyArgs, args, ty) => let
                val (stms, xs) = simplifyExpsToVars (env, args, stms)
                in
                  case Var.kindOf f
                   of S.FunVar => (stms, S.E_Apply(lookupVar(env, f), xs, cvtTy ty))
                    | S.BasisVar => let
                        fun cvtTyArg (Types.TYPE tv) = S.TY(cvtTy(TU.resolve tv))
                          | cvtTyArg (Types.DIFF dv) = S.DIFF(TU.monoDiff(TU.resolveDiff dv))
                          | cvtTyArg (Types.SHAPE sv) = S.SHAPE(TU.monoShape(TU.resolveShape sv))
                          | cvtTyArg (Types.DIM dv) = S.DIM(TU.monoDim(TU.resolveDim dv))
                        val tyArgs = List.map cvtTyArg tyArgs
                        in
                          (stms, S.E_Prim(f, tyArgs, xs, cvtTy ty))
                        end
                    | _ => raise Fail "bogus application"
                  (* end case *)
                end
            | AST.E_Cons es => let
                val (stms, xs) = simplifyExpsToVars (env, es, stms)
                in
                  (stms, S.E_Cons xs)
                end
            | AST.E_Seq(es, ty) => let
                val (stms, xs) = simplifyExpsToVars (env, es, stms)
                in
                  (stms, S.E_Seq(xs, cvtTy ty))
                end
            | AST.E_Slice(e, indices, ty) => let (* tensor slicing *)
                val (stms, x) = simplifyExpToVar (env, e, stms)
                fun f ([], ys, stms) = (stms, List.rev ys)
                  | f (NONE::es, ys, stms) = f (es, NONE::ys, stms)
                  | f (SOME e::es, ys, stms) = let
                      val (stms, y) = simplifyExpToVar (env, e, stms)
                      in
                        f (es, SOME y::ys, stms)
                      end
                val (stms, indices) = f (indices, [], stms)
                in
                  (stms, S.E_Slice(x, indices, cvtTy ty))
                end
            | AST.E_Cond(e1, e2, e3, ty) => let
              (* a conditional expression gets turned into an if-then-else statement
               * that assigns to a fresh result variable in each branch.
               *)
                val result = newTemp(cvtTy ty)
                val (stms, x) = simplifyExpToVar (env, e1, S.S_Var result :: stms)
                fun simplifyBranch e = let
                      val (stms, e) = simplifyExp (env, e, [])
                      in
                        mkBlock (S.S_Assign(result, e)::stms)
                      end
                val s1 = simplifyBranch e2
                val s2 = simplifyBranch e3
                in
                  (S.S_IfThenElse(x, s1, s2) :: stms, S.E_Var result)
                end
            | AST.E_LoadNrrd(_, nrrd, ty) => (case cvtTy ty
                 of ty as SimpleTypes.T_DynSequence _ => (stms, S.E_LoadSeq(ty, nrrd))
                  | ty as SimpleTypes.T_Image{dim, shape} =>
                      (stms, S.E_LoadImage(ty, nrrd, loadImageInfo(nrrd, dim, shape)))
                  | _ => raise Fail "bogus type for E_LoadNrrd"
                (* end case *))
            | AST.E_Coerce{srcTy, dstTy, e} => let
                val (stms, x) = simplifyExpToVar (env, e, stms)
                val dstTy = cvtTy dstTy
                val result = newTemp dstTy
                val rhs = S.E_Coerce{srcTy = cvtTy srcTy, dstTy = dstTy, x = x}
                in
                  (S.S_Assign(result, rhs)::stms, S.E_Var result)
                end
          (* end case *))

  (* simplify exp and bind the result to a variable; if the simplified expression
   * is already a variable it is used directly, otherwise a fresh temporary is
   * assigned.  stms is in reverse order, as for simplifyExp.
   *)
    and simplifyExpToVar (env, exp, stms) = let
          val (stms, e) = simplifyExp (env, exp, stms)
          in
            case e
             of S.E_Var x => (stms, x)
              | _ => let
                  val x = newTemp (S.typeOf e)
                  in
                    (S.S_Assign(x, e)::stms, x)
                  end
            (* end case *)
          end

  (* simplify a list of expressions left to right, binding each to a variable;
   * the variables are returned in the same order as the expressions.
   *)
    and simplifyExpsToVars (env, exps, stms) = let
          fun f ([], xs, stms) = (stms, List.rev xs)
            | f (e::es, xs, stms) = let
                val (stms, x) = simplifyExpToVar (env, e, stms)
                in
                  f (es, x::xs, stms)
                end
          in
            f (exps, [], stms)
          end

  (* Simplify the program, log it, run the inliner, and log the result.
   * errStrm is only referenced by the (currently disabled) lifting pass.
   *)
    fun transform (errStrm, ast) = let
          val simple = simplifyProgram ast
          val _ = SimplePP.output (Log.logFile(), "simplify", simple)   (* DEBUG *)
          val simple = Inliner.transform simple
          val _ = SimplePP.output (Log.logFile(), "inlining", simple)   (* DEBUG *)
(*
          val simple = Lift.transform simple
                handle Eval.Error msg => (Error.error(errStrm, msg); simple)
          val _ = SimplePP.output (Log.logFile(), "lifting", simple)   (* DEBUG *)
*)
          in
            simple
          end

  end

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0