Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] View of /branches/lamont/src/compiler/simplify/simplify.sml
ViewVC logotype

View of /branches/lamont/src/compiler/simplify/simplify.sml

Parent Directory Parent Directory | Revision Log Revision Log


Revision 2871 - (download) (annotate)
Thu Feb 26 14:49:09 2015 UTC (4 years, 7 months ago) by lamonts
File size: 21308 byte(s)
Added variance back to compiler
(* simplify.sml
 *
 * COPYRIGHT (c) 2010 The Diderot Project (http://diderot-language.cs.uchicago.edu)
 * All rights reserved.
 *
 * Simplify the AST representation.
 *)

structure Simplify : sig

    val transform : Error.err_stream * AST.program -> Simple.program

  end = struct

    structure BV = BasisVars
    structure Ty = Types
    structure S = Simple
    structure InP = Inputs
    structure RLU = ReductionLiftUtil

  (* Types are shared between the AST and Simple representations, so converting
   * a type only requires pruning out any meta variables it contains.
   *)
    fun cvtTy ty = TypeUtil.prune ty

    local
      val tempName = Atom.atom "_t"
    in
  (* create a new local variable named "_t" whose type is the pruned form of ty;
   * used for compiler-introduced temporaries.  (The former genName/cnt helpers
   * were unused and invisible outside this local block, so they were removed.)
   *)
    fun newTemp ty = Var.new (tempName, AST.LocalVar, cvtTy ty)
    end

  (* build a Simple block from a list of statements accumulated most-recent-first;
   * folding with cons reverses the list back into program order.
   *)
    fun mkBlock revStms = S.Block(List.foldl (op ::) [] revStms)

  (* convert an AST expression to an input initialization.  Note that the Diderot grammar
   * limits the forms of expression that we might encounter in this context.
   *
   * ty is the type of the input variable; it is only consulted for tensor constructions
   * (E_Cons), where it supplies the tensor shape.  Raises Fail for any unsupported form,
   * including a tensor shape that contains a non-constant dimension.
   *)
    fun expToInit (ty, exp) = (case exp
           of AST.E_Lit(Literal.Int n) => InP.Int n
            | AST.E_Lit(Literal.Float f) => InP.Real f
            | AST.E_Lit(Literal.String s) => InP.String s
            | AST.E_Lit(Literal.Bool b) => InP.Bool b
            | AST.E_Tuple _ => raise Fail "E_Tuple not yet implemented"
            | AST.E_Cons es => let
              (* every dimension of the shape must be a known constant here *)
                fun dimToInt (Ty.DimConst d) = d
                  | dimToInt _ = raise Fail "non-constant tensor dimension"
                val shp = (case ty
                       of Ty.T_Tensor(Ty.Shape shp) => List.map dimToInt shp
                        | _ => raise Fail "not tensor type"
                      (* end case *))
              (* flatten the (possibly nested) construction into its scalar components in
               * left-to-right textual order; integer literals are converted to reals.
               *)
                fun flatten (AST.E_Lit(Literal.Int n), l) = FloatLit.fromInt n :: l
                  | flatten (AST.E_Lit(Literal.Float f), l) = f :: l
                  | flatten (AST.E_Coerce{e, ...}, l) = flatten(e, l)
                  | flatten (AST.E_Cons es, l) = flattenList (es, l)
                  | flatten _ = raise Fail "impossible"
                and flattenList ([], l) = l
                  | flattenList (x::xs, l) = flatten(x, flattenList(xs, l))
                in
                  InP.Tensor(shp, Vector.fromList(flattenList (es, [])))
                end
(*
            | AST.E_Seq es => ??
            | AST.E_Coerce{srcTy, dstTy, e} => ??
*)
            | _ => raise Fail "impossible initialization expression"
          (* end case *))

    fun inputImage (nrrd, dim, shape) = let
        (* Build an input proxy for an image input that is loaded from the nrrd file named
         * by nrrd.  The image's dimension and shape are resolved first, then checked
         * against the file's header information; Fail is raised on a mismatch.
         *)
          val monoDim = TypeUtil.monoDim dim
          val monoShp = TypeUtil.monoShape shape
          val nrrdInfo = NrrdInfo.getInfo nrrd
          in
            case ImageInfo.fromNrrd(nrrdInfo, monoDim, monoShp)
             of SOME imgInfo => InP.Proxy(nrrd, imgInfo)
              | NONE => raise Fail(concat["nrrd file \"", nrrd, "\" does not have expected type"])
            (* end case *)
          end

  (* true when control can fall through stm to the statement that syntactically follows
   * it.  die, stabilize, and return never fall through; a block falls through only when
   * each of its statements does; a conditional falls through when either branch does.
   *)
    fun contIsNext stm = (case stm
           of AST.S_Block stms => List.all contIsNext stms
            | AST.S_IfThenElse(_, thenStm, elseStm) => List.exists contIsNext [thenStm, elseStm]
            | AST.S_Die => false
            | AST.S_Stabilize => false
            | AST.S_Return _ => false
            | _ => true
          (* end case *))

    fun simplifyProgram (AST.Program{props, decls}) = let
        (* Accumulators for the pieces of the Simple program.  Each list is built in
         * reverse order and is reversed (directly or via mkBlock) when the program is
         * assembled at the end of this function.
         *)
          val inputs = ref []
          val globals = ref []
          val globalBlock = ref []
          val globalInit = ref []
          val funcs = ref []
          val reductionBlock = ref []
          val initially = ref NONE
          val strands = ref []
          fun setInitially init = (case !initially
                 of NONE => initially := SOME init
(* FIXME: the check for multiple initially decls should happen in type checking *)
                  | SOME _ => raise Fail "multiple initially declarations"
                (* end case *))
        (* simplify a single top-level declaration, recording its result in the
         * accumulators above.
         *)
          fun simplifyDecl dcl = (case dcl
                 of AST.D_Input(x, desc, NONE) => let
                      val (ty, init) = (case Var.monoTypeOf x
                             of ty as Ty.T_Image{dim, shape} => let
                                  val info = ImageInfo.mkInfo(TypeUtil.monoDim dim, TypeUtil.monoShape shape)
                                  in
                                    (ty, SOME(InP.Image info))
                                  end
                              | ty => (ty, NONE)
                            (* end case *))
                      val inp = InP.INP{
                              ty = ty,
                              name = Var.nameOf x,
                              desc = desc,
                              init = init
                            }
                      in
                        inputs := (x, inp) :: !inputs
                      end
                  | AST.D_Input(x, desc, SOME(AST.E_LoadNrrd(_, nrrd, _))) => let
                    (* load the nrrd proxy here.
                     * NOTE(review): the result of this getInfo call is unused (the image
                     * case calls it again inside inputImage); it is kept in case it matters
                     * for its side effects (e.g., checking the file) -- confirm.
                     *)
                      val _ = NrrdInfo.getInfo nrrd
                      val (ty, init) = (case Var.monoTypeOf x
                             of ty as Ty.T_DynSequence _ => (ty, InP.DynSeq nrrd)
                              | ty as Ty.T_Image{dim, shape} => (ty, inputImage(nrrd, dim, shape))
                              | _ => raise Fail "impossible"
                            (* end case *))
                      val inp = InP.INP{
                              ty = ty,
                              name = Var.nameOf x,
                              desc = desc,
                              init = SOME init
                            }
                      in
                        inputs := (x, inp) :: !inputs
                      end
                  | AST.D_Input(x, desc, SOME e) => let
                      val ty = Var.monoTypeOf x
                      val inp = InP.INP{
                              ty = ty,
                              name = Var.nameOf x,
                              desc = desc,
                              init = SOME(expToInit(ty, e))
                            }
                      in
                        inputs := (x, inp) :: !inputs
                      end
                  | AST.D_Var(AST.VD_Decl(x, e)) => let
                      val (stms, e') = simplifyExp (e, [])
                      in
                        globals := x :: !globals;
                        globalInit := S.S_Assign(x, e') :: (stms @ !globalInit)
                      end
                  | AST.D_Func(f, params, body) =>
                      funcs := S.Func{f=f, params=params, body=simplifyBlock body} :: !funcs
                  | AST.D_Strand info => strands := simplifyStrand info :: !strands
                  | AST.D_InitialArray(creat, iters) =>
                      setInitially (simplifyInit(true, creat, iters))
                  | AST.D_InitialCollection(creat, iters) =>
                      setInitially (simplifyInit(false, creat, iters))
                  | AST.D_Global(AST.S_Block body) => let
                    (* simplify stm onto stms, first lifting out any reductions it contains
                     * (RLU.lift).  When something was lifted, the statements reported by
                     * RLU.getReduceStmts are simplified onto the reduction block.
                     *)
                      fun reduceCheck (stm, stms) = let
                            val (found, stm') = RLU.lift stm
                            in
                              if found
                                then (
                                  reductionBlock :=
                                    List.foldl simplifyStmt (!reductionBlock) (RLU.getReduceStmts());
                                  simplifyStmt (stm', stms))
                                else simplifyStmt (stm, stms)
                            end
                    (* A variable declared in the global block is lifted to global status:
                     * it is added to globals and its initialization goes into globalInit,
                     * using the expression supplied by RLU.initExp.  When RLU.initExp
                     * reports found (presumably: e contains a reduction), the original
                     * initializer is re-processed as an assignment to the lifted variable.
                     *)
                      fun simplifyGlobal (AST.S_Decl(AST.VD_Decl(x, e)), stms) = let
                            val x' = Var.updateKind(x, AST.GlobalVar)
                            val (found, initExp) = RLU.initExp e
                            val (stms', e') = simplifyExp (initExp, [])
                            in
                              globals := x' :: !globals;
                            (* NOTE(review): this assigns to x rather than the GlobalVar-kinded
                             * x' that is recorded in globals; confirm that downstream passes
                             * treat the two as the same variable.
                             *)
                              globalInit := S.S_Assign(x, e') :: (stms' @ !globalInit);
                              if found
                                then reduceCheck (AST.S_Assign(x', e), stms)
                                else stms
                            end
                        | simplifyGlobal (stm, stms) = reduceCheck (stm, stms)
                    (* As in the S_Block case of simplifyStmt, statements that follow one
                     * whose continuation is not the next statement are unreachable and are
                     * dropped.  The last reachable statement still goes through
                     * simplifyGlobal so that any reductions it contains are lifted.
                     *)
                      fun simplify ([], stms) = stms
                        | simplify (stm::r, stms) = if contIsNext stm
                            then simplify (r, simplifyGlobal (stm, stms))
                            else simplifyGlobal (stm, stms)  (* prune unreachable statements *)
                      in
                        globalBlock := simplify (body, [])
                      end
                (* end case *))
          in
            List.app simplifyDecl decls;
            S.Program{
                props = props,
                inputs = List.rev(!inputs),
                globals = List.rev(RLU.getReduceGlobals() @ !globals),
                globalInit = mkBlock(RLU.getReduceInit() @ !globalInit),
                funcs = List.rev(!funcs),
                globalBlock = mkBlock(!globalBlock),
                reductionBlock = (mkBlock(!reductionBlock), RLU.getPhases()),
                init = (case !initially
(* FIXME: the check for the initially block should really happen in typechecking *)
                   of NONE => raise Fail "missing initially declaration"
                    | SOME blk => blk
                  (* end case *)),
                strands = List.rev(!strands)
              }
          end

    and simplifyInit (isArray, AST.C_Create(strand, exps), iters) = let
        (* the strand-creation arguments are bound to variables in their own block *)
          val (argStms, argVars) = simplifyExpsToVars (exps, [])
          val creat = S.C_Create{argInit = mkBlock argStms, name = strand, args = argVars}
        (* bind the bounds of each range iterator to variables; the bound computations
         * for all iterators share one range-initialization block.  Iterators are
         * returned in their original order.
         *)
          fun cvtIters ([], ranges, stms) = (List.rev ranges, stms)
            | cvtIters (AST.I_Range(x, e1, e2) :: rest, ranges, stms) = let
                val (stms, lo) = simplifyExpToVar (e1, stms)
                val (stms, hi) = simplifyExpToVar (e2, stms)
                in
                  cvtIters (rest, {param=x, lo=lo, hi=hi} :: ranges, stms)
                end
          val (ranges, rangeStms) = cvtIters (iters, [], [])
          in
            S.Initially{
                isArray = isArray,
                rangeInit = mkBlock rangeStms,
                iters = ranges,
                create = creat
              }
          end

    and simplifyStrand (AST.Strand{name, params, state, methods}) = let
        (* each state-variable declaration contributes the variable itself plus an
         * assignment of its simplified initializer to the state-initialization block;
         * declarations are processed in order.
         *)
          fun cvtState (AST.VD_Decl(x, e), (revVars, stms)) = let
                val (stms, e') = simplifyExp (e, stms)
                in
                  (x :: revVars, S.S_Assign(x, e') :: stms)
                end
          val (revVars, initStms) = List.foldl cvtState ([], []) state
          val initBlk = mkBlock initStms
          in
            S.Strand{
                name = name,
                params = params,
                state = List.rev revVars,
                stateInit = initBlk,
                methods = List.map simplifyMethod methods
              }
          end

    (* a strand method keeps its name; only its body is simplified *)
    and simplifyMethod (AST.M_Method(name, body)) = let
          val body' = simplifyBlock body
          in
            S.Method(name, body')
          end

  (* simplify stm starting from an empty statement context and package whatever it
   * expands into as a single Simple block (in program order).
   *)
    and simplifyBlock stm = let
          val revStms = simplifyStmt (stm, [])
          in
            mkBlock revStms
          end

  (* simplify the statement stm where stms is a reverse-order list of preceding simplified
   * statements.  This function returns a reverse-order list of simplified statements.
   * Note that error reporting is done in the typechecker, but it does not prune unreachable
   * code; the S_Block case below drops the statements that follow one whose continuation
   * is not the next statement (see contIsNext).
   *)
    and simplifyStmt (stm, stms) = (case stm
           of AST.S_Block body => let
                fun simplify ([], stms) = stms
                  | simplify (stm::r, stms) = if contIsNext stm
                      then simplify (r, simplifyStmt (stm, stms))
                      else simplifyStmt (stm, stms)  (* prune unreachable statements *)
                in
                  simplify (body, stms)
                end
            | AST.S_Decl(AST.VD_Decl(x, e)) => let
                val (stms, e') = simplifyExp (e, stms)
                in
                  S.S_Assign(x, e') :: stms
                end
            | AST.S_Foreach(x, ty, e, s) => let
              (* The loop is lowered to an explicit index walk over a dynamic sequence.
               * In program order the generated statements are:
               *
               *     addOne = 1;  idx = -1;  seq = e';
               *     S_Foreach(seq, e', { idx = idx + addOne;  x = seq[idx];  body })
               *
               * NOTE(review): e' occurs both in the assignment to seq and in the
               * S_Foreach node; if e' is not a variable it may be evaluated twice,
               * depending on how S_Foreach is translated.
               *)
                val (stms, e') = simplifyExp (e, stms)
                val iteratorVar = newTemp Ty.T_Int
                val dynSeqVar = newTemp (Ty.T_DynSequence ty)
                val addOneVar = newTemp Ty.T_Int
                val dynSeqSub = S.E_Apply(BV.dynSubscript, [Ty.TYPE(MetaVar.newFromType ty)], [dynSeqVar, iteratorVar], ty)
                val iterExp = S.E_Lit(Literal.Int ~1)
                val iterOneExp = S.E_Lit(Literal.Int 1)
                val iterIncExp = S.E_Apply(BV.add_ii, [], [iteratorVar, addOneVar], Ty.T_Int)
              (* the simplified loop body, already in program order *)
                val S.Block body' = simplifyBlock s
                in
                  S.S_Foreach(
                      dynSeqVar, e',
                      S.Block(
                        S.S_Assign(iteratorVar, iterIncExp) ::
                        S.S_Assign(x, dynSeqSub) :: body')
                    ) ::
                  S.S_Assign(dynSeqVar, e') ::
                  S.S_Assign(iteratorVar, iterExp) ::
                  S.S_Assign(addOneVar, iterOneExp) :: stms
                end
            | AST.S_IfThenElse(e, s1, s2) => let
                val (stms, x) = simplifyExpToVar (e, stms)
                val s1 = simplifyBlock s1
                val s2 = simplifyBlock s2
                in
                  S.S_IfThenElse(x, s1, s2) :: stms
                end
            | AST.S_Assign(x, e) => let
                val (stms, e') = simplifyExp (e, stms)
                in
                  S.S_Assign(x, e') :: stms
                end
            | AST.S_New(name, args) => let
                val (stms, xs) = simplifyExpsToVars (args, stms)
                in
                  S.S_New(name, xs) :: stms
                end
            | AST.S_Die => S.S_Die :: stms
            | AST.S_Stabilize => S.S_Stabilize :: stms
            | AST.S_Return e => let
                val (stms, x) = simplifyExpToVar (e, stms)
                in
                  S.S_Return x :: stms
                end
            | AST.S_Print args => let
                val (stms, xs) = simplifyExpsToVars (args, stms)
                in
                  S.S_Print xs :: stms
                end
          (* end case *))
    (* simplify the AST expression exp, where stms is a reverse-order list of preceding
     * simplified statements.  Returns the (possibly extended) reverse-order statement list
     * paired with the simplified expression, whose operands are variables (subexpressions
     * are bound via simplifyExpToVar/simplifyExpsToVars).
     *)
    and simplifyExp (exp, stms) = (
          case exp
           of AST.E_Var x => (case Var.kindOf x
                 of Var.BasisVar => let
                    (* a reference to a basis variable becomes a zero-argument application of
                     * it, bound to a fresh temporary.
                     *)
                      val ty = Var.monoTypeOf x
                      val x' = newTemp ty
                      val stm = S.S_Assign(x', S.E_Apply(x, [], [], ty))
                      in
                        (stm::stms, S.E_Var x')
                      end
                  | _ => (stms, S.E_Var x)
                (* end case *))
            | AST.E_Lit lit => (stms, S.E_Lit lit)
            | AST.E_Selector(x, f, ty) => let
                val (stms,x') = simplifyExpToVar(x, stms)
                in
                  (stms, S.E_Selector(x', f, ty))
                end
            | AST.E_Reduction(rVar, e, setVar,kind)  => let
              (* Reduction: setVar is first assigned 0, and the reduction kind (an atom) is
               * materialized as a string-literal temporary kind'.  Both are passed as extra
               * arguments to the reduction function applied below.
               * NOTE(review): rVarTy, argTy, and args are unused; the binding only serves to
               * extract rRetTy and raises Bind unless rVar's type is a function of at least
               * one argument.
               *)
                val (stms,kind') = simplifyExpToVar(AST.E_Lit(Literal.String (Atom.toString kind)), S.S_Assign(setVar,S.E_Lit(Literal.Int 0))::stms)
                val rVarTy as Ty.T_Fun(argTy::args,rRetTy) =  Var.monoTypeOf rVar
                val (argsTy,_) = Var.typeOf rVar
              (* the "variance" reduction is expanded in terms of fn_rMean; every other
               * reduction applies rVar directly to the simplified operand.
               *)
                fun isVariance() = case (Var.nameOf rVar)
                  of "variance" =>  let
                    (* NOTE(review): meanBasisVar is never used.  As written this branch
                     * computes rMean(e - rMean(e)), which is identically zero if fn_rMean is
                     * an arithmetic mean; a variance needs the difference squared before the
                     * outer mean.  Also, e is simplified twice (here and again inside the
                     * E_Apply below), and sub_tt is given rVar's scheme type arguments
                     * (argsTy) -- confirm these are intended.
                     *)
                    val meanBasisVar = newTemp (Var.monoTypeOf rVar)
                    val meanLocalVar = newTemp rRetTy
                    val (stms', e') = simplifyExpToVar(e,[])
                    val meanStm = S.S_Assign(meanLocalVar,S.E_Apply(BasisVars.fn_rMean, argsTy, [e',setVar,kind'], rRetTy))
                    val subVar = BasisVars.sub_tt
                    val (stms, e') = simplifyExpToVar(AST.E_Apply(subVar,argsTy,[e,AST.E_Var(meanLocalVar)],rRetTy),(meanStm::stms')@ stms)
                  in
                    (stms,S.E_Apply(BasisVars.fn_rMean, argsTy, [e',setVar,kind'], rRetTy))
                  end
                   | _ => let
                      val (stms, e') = simplifyExpToVar(e,stms)
                   in
                      (stms,S.E_Apply(rVar, argsTy, [e',setVar,kind'], rRetTy))
                   end
                in
                   isVariance()
                end
            | AST.E_Tuple es => raise Fail "E_Tuple not yet implemented"
            | AST.E_Apply(f, tyArgs, args, ty) => let
                val (stms, xs) = simplifyExpsToVars (args, stms)
                in
                  (stms, S.E_Apply(f, tyArgs, xs, ty))
                end
            | AST.E_Cons es => let
                val (stms, xs) = simplifyExpsToVars (es, stms)
                in
                  (stms, S.E_Cons xs)
                end
            | AST.E_Seq es => let
                val (stms, xs) = simplifyExpsToVars (es, stms)
                in
                  (stms, S.E_Seq xs)
                end
            | AST.E_Slice(e, indices, ty) => let (* tensor slicing *)
              (* NONE entries in indices are kept as NONE; each SOME index expression is
               * bound to a variable, left to right.
               *)
                val (stms, x) = simplifyExpToVar (e, stms)
                fun f ([], ys, stms) = (stms, List.rev ys)
                  | f (NONE::es, ys, stms) = f (es, NONE::ys, stms)
                  | f (SOME e::es, ys, stms) = let
                      val (stms, y) = simplifyExpToVar (e, stms)
                      in
                        f (es, SOME y::ys, stms)
                      end
                val (stms, indices) = f (indices, [], stms)
                in
                  (stms, S.E_Slice(x, indices, ty))
                end
            | AST.E_Cond(e1, e2, e3, ty) => let
              (* a conditional expression gets turned into an if-then-else statement: the
               * result temporary is declared (S_Var) before the test, and each branch
               * assigns its simplified value to it.
               *)
                val result = newTemp ty
                val (stms, x) = simplifyExpToVar (e1, S.S_Var result :: stms)
                fun simplifyBranch e = let
                      val (stms, e) = simplifyExp (e, [])
                      in
                        mkBlock (S.S_Assign(result, e)::stms)
                      end
                val s1 = simplifyBranch e2
                val s2 = simplifyBranch e3
                in
                  (S.S_IfThenElse(x, s1, s2) :: stms, S.E_Var result)
                end
            | AST.E_LoadNrrd(_, nrrd, ty) => (case TypeUtil.prune ty
              (* dynamic sequences are loaded by name; images are checked against the
               * nrrd file's header information and raise Fail on a mismatch.
               *)
                 of ty as Ty.T_DynSequence _ => (stms, S.E_LoadSeq(ty, nrrd))
                  | ty as Ty.T_Image{dim, shape} => let
                      val dim = TypeUtil.monoDim dim
                      val shp = TypeUtil.monoShape shape
                      in
                        case ImageInfo.fromNrrd(NrrdInfo.getInfo nrrd, dim, shp)
                         of NONE => raise Fail(concat[
                                "nrrd file \"", nrrd, "\" does not have expected type"
                              ])
                          | SOME info => (stms, S.E_LoadImage(ty, nrrd, info))
                        (* end case *)
                      end
                  | _ => raise Fail "bogus type for E_LoadNrrd"
                (* end case *))
            | AST.E_Coerce{srcTy, dstTy, e} => let
              (* the coerced value is always bound to a fresh temporary of type dstTy *)
                val (stms, x) = simplifyExpToVar (e, stms)
                val result = newTemp dstTy
                val rhs = S.E_Coerce{srcTy = cvtTy srcTy, dstTy = cvtTy dstTy, x = x}
                in
                  (S.S_Assign(result, rhs)::stms, S.E_Var result)
                end
          (* end case *))

    (* simplify exp and ensure the result is a variable: a simplified result that is
     * already a variable is returned as is; any other result is assigned to a new
     * temporary of the expression's type.
     *)
    and simplifyExpToVar (exp, stms) = (case simplifyExp (exp, stms)
           of (stms', S.E_Var x) => (stms', x)
            | (stms', e) => let
                val tmp = newTemp (S.typeOf e)
                in
                  (S.S_Assign(tmp, e) :: stms', tmp)
                end
          (* end case *))

    and simplifyExpsToVars (exps, stms) = let
        (* bind each expression to a variable, threading the reverse-order statement
         * list through the expressions left to right; the variables are returned in
         * the same order as exps.
         *)
          fun bindOne (exp, (acc, revVars)) = let
                val (acc', x) = simplifyExpToVar (exp, acc)
                in
                  (acc', x :: revVars)
                end
          val (stms', revVars) = List.foldl bindOne (stms, []) exps
          in
            (stms', List.rev revVars)
          end

    fun transform (errStrm, ast) = let
        (* Simplify the program, then run the inliner over it.  A snapshot of the program
         * is written to the log file after each pass (DEBUG).  errStrm is only needed by
         * the disabled Lift pass below.
         *)
          fun dump (label, prog) = SimplePP.output (Log.logFile(), label, prog)
          val simple = simplifyProgram ast
          val _ = dump ("simplify", simple)
          val inlined = Inliner.transform simple
          val _ = dump ("inlining", inlined)
(*
          val inlined = Lift.transform inlined
                handle Eval.Error msg => (Error.error(errStrm, msg); inlined)
*)
          in
            inlined
          end

  end

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0