Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] Diff of /branches/vis15/src/compiler/simplify/simplify.sml
ViewVC logotype

Diff of /branches/vis15/src/compiler/simplify/simplify.sml

Parent Directory Parent Directory | Revision Log Revision Log | View Patch Patch

revision 3465, Sun Nov 29 20:04:16 2015 UTC revision 4130, Fri Jul 1 12:11:34 2016 UTC
# Line 5  Line 5 
5   * COPYRIGHT (c) 2015 The University of Chicago   * COPYRIGHT (c) 2015 The University of Chicago
6   * All rights reserved.   * All rights reserved.
7   *   *
8   * Simplify the AST representation.   * Simplify the AST representation.  This phase involves the following transformations:
9     *
10     *      - types are simplified by removing meta variables (which will have been resolved)
11     *
12     *      - expressions are simplified to involve a single operation on variables
13     *
14     *      - global reductions are converted to MapReduce statements
15     *
16     *      - other comprehensions and reductions are converted to foreach loops
17     *
18     *      - unreachable code is pruned
19     *
20     *      - negation of literal integers and reals are constant folded
21   *)   *)
22    
23  structure Simplify : sig  structure Simplify : sig
# Line 31  Line 43 
43              | Ty.T_String => STy.T_String              | Ty.T_String => STy.T_String
44              | Ty.T_Sequence(ty, NONE) => STy.T_Sequence(cvtTy ty, NONE)              | Ty.T_Sequence(ty, NONE) => STy.T_Sequence(cvtTy ty, NONE)
45              | Ty.T_Sequence(ty, SOME dim) => STy.T_Sequence(cvtTy ty, SOME(TU.monoDim dim))              | Ty.T_Sequence(ty, SOME dim) => STy.T_Sequence(cvtTy ty, SOME(TU.monoDim dim))
46              | Ty.T_Named id => STy.T_Named id              | Ty.T_Strand id => STy.T_Strand id
47              | Ty.T_Kernel n => STy.T_Kernel(TU.monoDiff n)              | Ty.T_Kernel n => STy.T_Kernel(TU.monoDiff n)
48              | Ty.T_Tensor shape => STy.T_Tensor(TU.monoShape shape)              | Ty.T_Tensor shape => STy.T_Tensor(TU.monoShape shape)
49              | Ty.T_Image{dim, shape} => STy.T_Image{              | Ty.T_Image{dim, shape} => STy.T_Image{
# Line 47  Line 59 
59              | Ty.T_Error => raise Fail "unexpected T_Error in Simplify"              | Ty.T_Error => raise Fail "unexpected T_Error in Simplify"
60            (* end case *))            (* end case *))
61    
62      fun newTemp ty = SimpleVar.new ("_t", SimpleVar.LocalVar, ty)      fun apiTypeOf x = let
63              fun cvtTy STy.T_Bool = APITypes.BoolTy
64                | cvtTy STy.T_Int = APITypes.IntTy
65                | cvtTy STy.T_String = APITypes.StringTy
66                | cvtTy (STy.T_Sequence(ty, len)) = APITypes.SeqTy(cvtTy ty, len)
67                | cvtTy (STy.T_Tensor shape) = APITypes.TensorTy shape
68                | cvtTy (STy.T_Image{dim, shape}) = APITypes.ImageTy(dim, shape)
69                | cvtTy ty = raise Fail "bogus API type"
70              in
71                cvtTy (SimpleVar.typeOf x)
72              end
73    
74        fun newTemp (ty as STy.T_Image _) = SimpleVar.new ("img", SimpleVar.LocalVar, ty)
75          | newTemp ty = SimpleVar.new ("_t", SimpleVar.LocalVar, ty)
76    
77    (* a property to map AST variables to SimpleAST variables *)    (* a property to map AST variables to SimpleAST variables *)
78      local      local
# Line 59  Line 84 
84      fun cvtVars xs = List.map cvtVar xs      fun cvtVars xs = List.map cvtVar xs
85    
86    (* make a block out of a list of statements that are in reverse order *)    (* make a block out of a list of statements that are in reverse order *)
87      fun mkBlock stms = S.Block(List.rev stms)      fun mkBlock stms = S.Block{props = PropList.newHolder(), code = List.rev stms}
   
     fun inputImage (errStrm, nrrd, dim, shape) = (  
           case ImageInfo.fromNrrd(NrrdInfo.getInfo(errStrm, nrrd), dim, shape)  
            of NONE => raise Fail(concat["nrrd file \"", nrrd, "\" does not have expected type"])  
             | SOME info => S.Proxy(nrrd, info)  
           (* end case *))  
   
     datatype 'a ctl_flow_info  
       = EXIT                    (* stm sequence always exits; no pruning so far *)  
       | PRUNE of 'a             (* stm sequence always exits at last stm in argument, which  
                                  * is either a block or stm list *)  
       | CONT                    (* stm sequence falls through *)  
       | EDIT of 'a              (* pruned code that has non-exiting paths *)  
   
     fun pruneUnreachableCode (blk as S.Block stms) = let  
           fun isExit S.S_Die = true  
             | isExit S.S_Stabilize = true  
             | isExit (S.S_Return _) = true  
             | isExit _ = false  
           fun pruneStms [] = CONT  
             | pruneStms [S.S_IfThenElse(x, blk1, blk2)] = (  
                 case pruneIf(x, blk1, blk2)  
                  of EXIT => EXIT  
                   | PRUNE stm => PRUNE[stm]  
                   | CONT => CONT  
                   | EDIT stm => EDIT[stm]  
                 (* end case *))  
             | pruneStms [stm] = if isExit stm then EXIT else CONT  
             | pruneStms ((stm as S.S_IfThenElse(x, blk1, blk2))::stms) = (  
                 case pruneIf(x, blk1, blk2)  
                  of EXIT => PRUNE[stm]  
                   | PRUNE stm => PRUNE[stm]  
                   | CONT => (case pruneStms stms  
                        of PRUNE stms => PRUNE(stm::stms)  
                         | EDIT stms => EDIT(stm::stms)  
                         | EXIT => EXIT (* different instances of ctl_flow_info *)  
                         | CONT => CONT  
                       (* end case *))  
                   | EDIT stm => (case pruneStms stms  
                        of PRUNE stms => PRUNE(stm::stms)  
                         | EDIT stms => EDIT(stm::stms)  
                         | _ => EDIT(stm::stms)  
                       (* end case *))  
                 (* end case *))  
             | pruneStms (stm::stms) = if isExit stm  
                 then PRUNE[stm]  
                 else (case pruneStms stms  
                    of PRUNE stms => PRUNE(stm::stms)  
                     | EDIT stms => EDIT(stm::stms)  
                     | info => info  
                   (* end case *))  
           and pruneIf (x, blk1, blk2) = (case (pruneBlk blk1, pruneBlk blk2)  
                  of (EXIT,       EXIT      ) => EXIT  
                   | (CONT,       CONT      ) => CONT  
                   | (CONT,       EXIT      ) => CONT  
                   | (EXIT,       CONT      ) => CONT  
                   | (CONT,       EDIT blk2 ) => EDIT(S.S_IfThenElse(x, blk1, blk2))  
                   | (EDIT blk1,  CONT      ) => EDIT(S.S_IfThenElse(x, blk1, blk2))  
                   | (CONT,       PRUNE blk2) => EDIT(S.S_IfThenElse(x, blk1, blk2))  
                   | (PRUNE blk1, CONT      ) => EDIT(S.S_IfThenElse(x, blk1, blk2))  
                   | (EXIT,       EDIT blk2 ) => EDIT(S.S_IfThenElse(x, blk1, blk2))  
                   | (EDIT blk1,  EXIT      ) => EDIT(S.S_IfThenElse(x, blk1, blk2))  
                   | (EDIT blk1,  EDIT blk2 ) => EDIT(S.S_IfThenElse(x, blk1, blk2))  
                   | (EDIT blk1,  PRUNE blk2) => EDIT(S.S_IfThenElse(x, blk1, blk2))  
                   | (PRUNE blk1, EDIT blk2 ) => EDIT(S.S_IfThenElse(x, blk1, blk2))  
                   | (EXIT,       PRUNE blk2) => PRUNE(S.S_IfThenElse(x, blk1, blk2))  
                   | (PRUNE blk1, EXIT      ) => PRUNE(S.S_IfThenElse(x, blk1, blk2))  
                   | (PRUNE blk1, PRUNE blk2) => PRUNE(S.S_IfThenElse(x, blk1, blk2))  
                 (* end case *))  
           and pruneBlk (S.Block stms) = (case pruneStms stms  
                  of PRUNE stms => PRUNE(S.Block stms)  
                   | EDIT stms => EDIT(S.Block stms)  
                   | EXIT => EXIT (* different instances of ctl_flow_info *)  
                   | CONT => CONT  
                 (* end case *))  
           in  
             case pruneBlk blk  
              of PRUNE blk => blk  
               | EDIT blk => blk  
               | _=> blk  
             (* end case *)  
           end  
88    
89    (* simplify a statement into a single statement (i.e., a block if it expands    (* simplify a statement into a single statement (i.e., a block if it expands
90     * into more than one new statement).     * into more than one new statement).
91     *)     *)
92      fun simplifyBlock errStrm stm = mkBlock (simplifyStmt (errStrm, stm, []))      fun simplifyBlock (errStrm, stm) = mkBlock (simplifyStmt (errStrm, stm, []))
93    
94    (* simplify the statement stm where stms is a reverse-order list of preceeding simplified    (* simplify the statement stm where stms is a reverse-order list of preceeding simplified
95     * statements.  This function returns a reverse-order list of simplified statements.     * statements.  This function returns a reverse-order list of simplified statements.
# Line 173  Line 116 
116                  end                  end
117              | AST.S_IfThenElse(e, s1, s2) => let              | AST.S_IfThenElse(e, s1, s2) => let
118                  val (stms, x) = simplifyExpToVar (errStrm, e, stms)                  val (stms, x) = simplifyExpToVar (errStrm, e, stms)
119                  val s1 = simplifyBlock errStrm s1                  val s1 = simplifyBlock (errStrm, s1)
120                  val s2 = simplifyBlock errStrm s2                  val s2 = simplifyBlock (errStrm, s2)
121                  in                  in
122                    S.S_IfThenElse(x, s1, s2) :: stms                    S.S_IfThenElse(x, s1, s2) :: stms
123                  end                  end
124              | AST.S_Foreach((x, e), body) => let              | AST.S_Foreach((x, e), body) => let
125                  val (stms, xs') = simplifyExpToVar (errStrm, e, stms)                  val (stms, xs') = simplifyExpToVar (errStrm, e, stms)
126                  val body' = simplifyBlock errStrm body                  val body' = simplifyBlock (errStrm, body)
127                  in                  in
128                    S.S_Foreach(cvtVar x, xs', body') :: stms                    S.S_Foreach(cvtVar x, xs', body') :: stms
129                  end                  end
# Line 239  Line 182 
182                      | _ => (stms, S.E_Var(cvtVar x))                      | _ => (stms, S.E_Var(cvtVar x))
183                    (* end case *))                    (* end case *))
184                | AST.E_Lit lit => (stms, S.E_Lit lit)                | AST.E_Lit lit => (stms, S.E_Lit lit)
185                  | AST.E_Kernel h => (stms, S.E_Kernel h)
186                | AST.E_Select(e, (fld, _)) => let                | AST.E_Select(e, (fld, _)) => let
187                    val (stms, x) = simplifyExpToVar (errStrm, e, stms)                    val (stms, x) = simplifyExpToVar (errStrm, e, stms)
188                    in                    in
# Line 311  Line 255 
255                    in                    in
256                      (foreachStm :: initStm :: stms, S.E_Var acc)                      (foreachStm :: initStm :: stms, S.E_Var acc)
257                    end                    end
258                | AST.E_ParallelMap(e, x, xs, ty) => raise Fail "FIXME"                | AST.E_ParallelMap(e, x, xs, ty) => raise Fail "FIXME: ParallelMap"
259                | AST.E_Tensor(es, ty) => let                | AST.E_Tensor(es, ty) => let
260                    val (stms, xs) = simplifyExpsToVars (errStrm, es, stms)                    val (stms, xs) = simplifyExpsToVars (errStrm, es, stms)
261                    in                    in
# Line 324  Line 268 
268                    end                    end
269                | AST.E_Slice(e, indices, ty) => let (* tensor slicing *)                | AST.E_Slice(e, indices, ty) => let (* tensor slicing *)
270                    val (stms, x) = simplifyExpToVar (errStrm, e, stms)                    val (stms, x) = simplifyExpToVar (errStrm, e, stms)
271                    fun f ([], ys, stms) = (stms, List.rev ys)                    fun f NONE = NONE
272                      | f (NONE::es, ys, stms) = f (es, NONE::ys, stms)                      | f (SOME(AST.E_Lit(Literal.Int i))) = SOME(Int.fromLarge i)
273                      | f (SOME e::es, ys, stms) = let                      | f _ = raise Fail "expected integer literal in slice"
274                          val (stms, y) = simplifyExpToVar (errStrm, e, stms)                    val indices = List.map f indices
                         in  
                           f (es, SOME y::ys, stms)  
                         end  
                   val (stms, indices) = f (indices, [], stms)  
275                    in                    in
276                      (stms, S.E_Slice(x, indices, cvtTy ty))                      (stms, S.E_Slice(x, indices, cvtTy ty))
277                    end                    end
# Line 352  Line 292 
292                | AST.E_LoadNrrd(_, nrrd, ty) => (case cvtTy ty                | AST.E_LoadNrrd(_, nrrd, ty) => (case cvtTy ty
293                     of ty as SimpleTypes.T_Sequence(_, NONE) => (stms, S.E_LoadSeq(ty, nrrd))                     of ty as SimpleTypes.T_Sequence(_, NONE) => (stms, S.E_LoadSeq(ty, nrrd))
294                      | ty as SimpleTypes.T_Image{dim, shape} => (                      | ty as SimpleTypes.T_Image{dim, shape} => (
295                          case ImageInfo.fromNrrd(NrrdInfo.getInfo(errStrm, nrrd), dim, shape)                          case NrrdInfo.getInfo (errStrm, nrrd)
296                           of NONE => raise Fail(concat[                           of SOME info => (case ImageInfo.fromNrrd(info, dim, shape)
297                                   of NONE => (
298                                        Error.error (errStrm, [
299                                  "nrrd file \"", nrrd, "\" does not have expected type"                                  "nrrd file \"", nrrd, "\" does not have expected type"
300                                ])                                        ]);
301                                        (stms, S.E_LoadImage(ty, nrrd, ImageInfo.mkInfo(dim, shape))))
302                            | SOME info => (stms, S.E_LoadImage(ty, nrrd, info))                            | SOME info => (stms, S.E_LoadImage(ty, nrrd, info))
303                          (* end case *))                          (* end case *))
304                              | NONE => (
305                                  Error.warning (errStrm, [
306                                      "nrrd file \"", nrrd, "\" does not exist"
307                                    ]);
308                                  (stms, S.E_LoadImage(ty, nrrd, ImageInfo.mkInfo(dim, shape))))
309                            (* end case *))
310                      | _ => raise Fail "bogus type for E_LoadNrrd"                      | _ => raise Fail "bogus type for E_LoadNrrd"
311                    (* end case *))                    (* end case *))
312                  | AST.E_Coerce{dstTy, e=AST.E_Lit(Literal.Int n), ...} => (case cvtTy dstTy
313                       of SimpleTypes.T_Tensor[] => (stms, S.E_Lit(Literal.Real(RealLit.fromInt n)))
314                        | _ => raise Fail "impossible: bad coercion"
315                      (* end case *))
316                | AST.E_Coerce{srcTy, dstTy, e} => let                | AST.E_Coerce{srcTy, dstTy, e} => let
317                    val (stms, x) = simplifyExpToVar (errStrm, e, stms)                    val (stms, x) = simplifyExpToVar (errStrm, e, stms)
318                    val dstTy = cvtTy dstTy                    val dstTy = cvtTy dstTy
# Line 395  Line 348 
348              f (exps, [], stms)              f (exps, [], stms)
349            end            end
350    
351      fun simplifyStrand (errStrm, AST.Strand{name, params, state, initM, updateM, stabilizeM}) = let    (* simplify a block and then prune unreachable and dead code *)
352        fun simplifyAndPruneBlock errStrm blk =
353              DeadCode.eliminate (simplifyBlock (errStrm, blk))
354    
355        fun simplifyStrand (errStrm, strand) = let
356              val AST.Strand{name, params, state, stateInit, initM, updateM, stabilizeM} = strand
357            val params' = cvtVars params            val params' = cvtVars params
358            fun simplifyState ([], xs, stms) = (List.rev xs, mkBlock stms)            fun simplifyState ([], xs, stms) = (List.rev xs, mkBlock stms)
359              | simplifyState ((x, optE) :: r, xs, stms) = let              | simplifyState ((x, optE) :: r, xs, stms) = let
# Line 406  Line 364 
364                      | SOME e => let                      | SOME e => let
365                          val (stms, e') = simplifyExp (errStrm, e, stms)                          val (stms, e') = simplifyExp (errStrm, e, stms)
366                          in                          in
367                            simplifyState (r, x'::xs, S.S_Var(x', SOME e') :: stms)                            simplifyState (r, x'::xs, S.S_Assign(x', e') :: stms)
368                          end                          end
369                    (* end case *)                    (* end case *)
370                  end                  end
# Line 417  Line 375 
375                  params = params',                  params = params',
376                  state = xs,                  state = xs,
377                  stateInit = stm,                  stateInit = stm,
378                  initM = Option.map (simplifyBlock errStrm) initM,                  initM = Option.map (simplifyAndPruneBlock errStrm) initM,
379                  updateM = simplifyBlock errStrm updateM,                  updateM = simplifyAndPruneBlock errStrm updateM,
380                  stabilizeM = Option.map (simplifyBlock errStrm) stabilizeM                  stabilizeM = Option.map (simplifyAndPruneBlock errStrm) stabilizeM
381                }                }
382            end            end
383    
     fun simplifyCreate (errStrm, AST.C_Grid(dim, stm)) = S.C_Grid(dim, simplifyBlock errStrm stm)  
       | simplifyCreate (errStrm, AST.C_Collection stm) = S.C_Collection(simplifyBlock errStrm stm)  
   
384      fun transform (errStrm, prog) = let      fun transform (errStrm, prog) = let
385            val AST.Program{            val AST.Program{
386                    props, const_dcls, input_dcls, globals, init, strand, create, update                    props, const_dcls, input_dcls, globals, globInit, strand, create, init, update
387                  } = prog                  } = prog
388            val consts' = ref[]            val consts' = ref[]
389            val constInit = ref[]            val constInit = ref[]
# Line 446  Line 401 
401            fun simplifyInputDcl ((x, NONE), desc) = let            fun simplifyInputDcl ((x, NONE), desc) = let
402                  val x' = cvtVar x                  val x' = cvtVar x
403                  val init = (case SimpleVar.typeOf x'                  val init = (case SimpleVar.typeOf x'
404                         of SimpleTypes.T_Image{dim, shape} => let                         of STy.T_Image{dim, shape} => let
405                              val info = ImageInfo.mkInfo(dim, shape)                              val info = ImageInfo.mkInfo(dim, shape)
406                              in                              in
407                                S.Image info                                S.Image info
# Line 455  Line 410 
410                        (* end case *))                        (* end case *))
411                  val inp = S.INP{                  val inp = S.INP{
412                          var = x',                          var = x',
413                            name = Var.nameOf x,
414                            ty =  apiTypeOf x',
415                          desc = desc,                          desc = desc,
416                          init = init                          init = init
417                        }                        }
# Line 463  Line 420 
420                  end                  end
421              | simplifyInputDcl ((x, SOME(AST.E_LoadNrrd(tvs, nrrd, ty))), desc) = let              | simplifyInputDcl ((x, SOME(AST.E_LoadNrrd(tvs, nrrd, ty))), desc) = let
422                  val x' = cvtVar x                  val x' = cvtVar x
               (* load the nrrd proxy here *)  
                 val info = NrrdInfo.getInfo (errStrm, nrrd)  
423                  val init = (case SimpleVar.typeOf x'                  val init = (case SimpleVar.typeOf x'
424                         of SimpleTypes.T_Sequence(_, NONE) => S.LoadSeq nrrd                         of SimpleTypes.T_Sequence(_, NONE) => S.LoadSeq nrrd
425                          | SimpleTypes.T_Image{dim, shape} => inputImage(errStrm, nrrd, dim, shape)                          | SimpleTypes.T_Image{dim, shape} => (
426                                case NrrdInfo.getInfo (errStrm, nrrd)
427                                 of SOME info => (case ImageInfo.fromNrrd(info, dim, shape)
428                                       of NONE => (
429                                            Error.error (errStrm, [
430                                                "proxy nrrd file \"", nrrd,
431                                                "\" does not have expected type"
432                                              ]);
433                                            S.Image(ImageInfo.mkInfo(dim, shape)))
434                                        | SOME info => S.Proxy(nrrd, info)
435                                      (* end case *))
436                                  | NONE => (
437                                      Error.warning (errStrm, [
438                                          "proxy nrrd file \"", nrrd, "\" does not exist"
439                                        ]);
440                                      S.Image(ImageInfo.mkInfo(dim, shape)))
441                                (* end case *))
442                          | _ => raise Fail "impossible"                          | _ => raise Fail "impossible"
443                        (* end case *))                        (* end case *))
444                  val inp = S.INP{                  val inp = S.INP{
445                          var = x',                          var = x',
446                            name = Var.nameOf x,
447                            ty = apiTypeOf x',
448                          desc = desc,                          desc = desc,
449                          init = init                          init = init
450                        }                        }
# Line 483  Line 456 
456                  val (stms, e') = simplifyExp (errStrm, e, [])                  val (stms, e') = simplifyExp (errStrm, e, [])
457                  val inp = S.INP{                  val inp = S.INP{
458                          var = x',                          var = x',
459                            name = Var.nameOf x,
460                            ty = apiTypeOf x',
461                          desc = desc,                          desc = desc,
462                          init = S.ConstExpr                          init = S.ConstExpr
463                        }                        }
# Line 506  Line 481 
481              | simplifyGlobalDcl (AST.D_Func(f, params, body)) = let              | simplifyGlobalDcl (AST.D_Func(f, params, body)) = let
482                  val f' = cvtVar f                  val f' = cvtVar f
483                  val params' = cvtVars params                  val params' = cvtVars params
484                  val body' = pruneUnreachableCode (simplifyBlock errStrm body)                  val body' = simplifyAndPruneBlock errStrm body
485                  in                  in
486                    funcs := S.Func{f=f', params=params', body=body'} :: !funcs                    funcs := S.Func{f=f', params=params', body=body'} :: !funcs
487                  end                  end
488            in            val () = (
489              List.app simplifyConstDcl const_dcls;              List.app simplifyConstDcl const_dcls;
490              List.app simplifyInputDcl input_dcls;              List.app simplifyInputDcl input_dcls;
491              List.app simplifyGlobalDcl globals;                  List.app simplifyGlobalDcl globals)
492            (* make the global-initialization block *)
493              val globInit = (case globInit
494                     of SOME stm => mkBlock (simplifyStmt (errStrm, stm, !globalInit))
495                      | NONE => mkBlock (!globalInit)
496                    (* end case *))
497            (* if the globInit block is non-empty, record the fact in the property list *)
498              val props = (case globInit
499                     of S.Block{code=[], ...} => props
500                      | _ => Properties.GlobalInit :: props
501                    (* end case *))
502              in
503              S.Program{              S.Program{
504                  props = props,                  props = props,
505                  consts = List.rev(!consts'),                  consts = List.rev(!consts'),
506                  inputs = List.rev(!inputs'),                  inputs = List.rev(!inputs'),
507                  constInit = mkBlock (!constInit),                  constInit = mkBlock (!constInit),
508                  globals = List.rev(!globals'),                  globals = List.rev(!globals'),
509                  init = mkBlock (!globalInit),                  globInit = globInit,
510                  funcs = List.rev(!funcs),                  funcs = List.rev(!funcs),
511                  strand = simplifyStrand (errStrm, strand),                  strand = simplifyStrand (errStrm, strand),
512                  create = simplifyCreate (errStrm, create),                  create = Create.map (simplifyAndPruneBlock errStrm) create,
513                  update = Option.map (simplifyBlock errStrm) update                  init = Option.map (simplifyAndPruneBlock errStrm) init,
514                    update = Option.map (simplifyAndPruneBlock errStrm) update
515                }                }
516            end            end
517    

Legend:
Removed from v.3465  
changed lines
  Added in v.4130

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0