Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] Diff of /branches/vis15/src/compiler/simplify/simplify.sml
ViewVC logotype

Diff of /branches/vis15/src/compiler/simplify/simplify.sml

Parent Directory Parent Directory | Revision Log Revision Log | View Patch Patch

revision 3465, Sun Nov 29 20:04:16 2015 UTC revision 4359, Thu Aug 4 01:30:23 2016 UTC
# Line 5  Line 5 
5   * COPYRIGHT (c) 2015 The University of Chicago   * COPYRIGHT (c) 2015 The University of Chicago
6   * All rights reserved.   * All rights reserved.
7   *   *
8   * Simplify the AST representation.   * Simplify the AST representation.  This phase involves the following transformations:
9     *
10     *      - types are simplified by removing meta variables (which will have been resolved)
11     *
12     *      - expressions are simplified to involve a single operation on variables
13     *
14     *      - global reductions are converted to MapReduce statements
15     *
16     *      - other comprehensions and reductions are converted to foreach loops
17     *
18     *      - unreachable code is pruned
19     *
20     *      - negation of literal integers and reals are constant folded
21   *)   *)
22    
23  structure Simplify : sig  structure Simplify : sig
# Line 19  Line 31 
31      structure STy = SimpleTypes      structure STy = SimpleTypes
32      structure Ty = Types      structure Ty = Types
33      structure VMap = Var.Map      structure VMap = Var.Map
34        structure II = ImageInfo
35        structure BV = BasisVars
36    
37    (* convert a Types.ty to a SimpleTypes.ty *)    (* convert a Types.ty to a SimpleTypes.ty *)
38      fun cvtTy ty = (case ty      fun cvtTy ty = (case ty
# Line 31  Line 45 
45              | Ty.T_String => STy.T_String              | Ty.T_String => STy.T_String
46              | Ty.T_Sequence(ty, NONE) => STy.T_Sequence(cvtTy ty, NONE)              | Ty.T_Sequence(ty, NONE) => STy.T_Sequence(cvtTy ty, NONE)
47              | Ty.T_Sequence(ty, SOME dim) => STy.T_Sequence(cvtTy ty, SOME(TU.monoDim dim))              | Ty.T_Sequence(ty, SOME dim) => STy.T_Sequence(cvtTy ty, SOME(TU.monoDim dim))
48              | Ty.T_Named id => STy.T_Named id              | Ty.T_Strand id => STy.T_Strand id
49              | Ty.T_Kernel n => STy.T_Kernel(TU.monoDiff n)              | Ty.T_Kernel _ => STy.T_Kernel
50              | Ty.T_Tensor shape => STy.T_Tensor(TU.monoShape shape)              | Ty.T_Tensor shape => STy.T_Tensor(TU.monoShape shape)
51              | Ty.T_Image{dim, shape} => STy.T_Image{              | Ty.T_Image{dim, shape} =>
52                    dim = TU.monoDim dim,                  STy.T_Image(II.mkInfo(TU.monoDim dim, TU.monoShape shape))
                   shape = TU.monoShape shape  
                 }  
53              | Ty.T_Field{diff, dim, shape} => STy.T_Field{              | Ty.T_Field{diff, dim, shape} => STy.T_Field{
54                    diff = TU.monoDiff diff,                    diff = TU.monoDiff diff,
55                    dim = TU.monoDim dim,                    dim = TU.monoDim dim,
56                    shape = TU.monoShape shape                    shape = TU.monoShape shape
57                  }                  }
58              | Ty.T_Fun(tys1, ty2) => STy.T_Fun(List.map cvtTy tys1, cvtTy ty2)              | Ty.T_Fun(tys1, ty2) => raise Fail "unexpected T_Fun in Simplify"
59              | Ty.T_Error => raise Fail "unexpected T_Error in Simplify"              | Ty.T_Error => raise Fail "unexpected T_Error in Simplify"
60            (* end case *))            (* end case *))
61    
62      fun newTemp ty = SimpleVar.new ("_t", SimpleVar.LocalVar, ty)      fun apiTypeOf x = let
63              fun cvtTy STy.T_Bool = APITypes.BoolTy
64                | cvtTy STy.T_Int = APITypes.IntTy
65                | cvtTy STy.T_String = APITypes.StringTy
66                | cvtTy (STy.T_Sequence(ty, len)) = APITypes.SeqTy(cvtTy ty, len)
67                | cvtTy (STy.T_Tensor shape) = APITypes.TensorTy shape
68                | cvtTy (STy.T_Image info) =
69                    APITypes.ImageTy(II.dim info, II.voxelShape info)
70                | cvtTy ty = raise Fail "bogus API type"
71              in
72                cvtTy (SimpleVar.typeOf x)
73              end
74    
75        fun newTemp (ty as STy.T_Image _) = SimpleVar.new ("img", SimpleVar.LocalVar, ty)
76          | newTemp ty = SimpleVar.new ("_t", SimpleVar.LocalVar, ty)
77    
78      (* a property to map AST function variables to SimpleAST functions *)
79        local
80          fun cvt x = let
81                val Ty.T_Fun(paramTys, resTy) = Var.monoTypeOf x
82                in
83                  SimpleFunc.new (Var.nameOf x, cvtTy resTy, List.map cvtTy paramTys)
84                end
85        in
86        val {getFn = cvtFunc, ...} = Var.newProp cvt
87        end
88    
89    (* a property to map AST variables to SimpleAST variables *)    (* a property to map AST variables to SimpleAST variables *)
90      local      local
91        fun cvt x = SimpleVar.new (Var.nameOf x, Var.kindOf x, cvtTy(Var.monoTypeOf x))        fun cvt x = SimpleVar.new (Var.nameOf x, Var.kindOf x, cvtTy(Var.monoTypeOf x))
92          val {getFn, setFn, ...} = Var.newProp cvt
93      in      in
94      val {getFn = cvtVar, ...} = Var.newProp cvt      val cvtVar = getFn
95        fun newVarWithType (x, ty) = let
96              val x' = SimpleVar.new (Var.nameOf x, Var.kindOf x, ty)
97              in
98                setFn (x, x');
99                x'
100              end
101      end      end
102    
103      fun cvtVars xs = List.map cvtVar xs      fun cvtVars xs = List.map cvtVar xs
104    
105    (* make a block out of a list of statements that are in reverse order *)    (* make a block out of a list of statements that are in reverse order *)
106      fun mkBlock stms = S.Block(List.rev stms)      fun mkBlock stms = S.Block{props = PropList.newHolder(), code = List.rev stms}
   
     fun inputImage (errStrm, nrrd, dim, shape) = (  
           case ImageInfo.fromNrrd(NrrdInfo.getInfo(errStrm, nrrd), dim, shape)  
            of NONE => raise Fail(concat["nrrd file \"", nrrd, "\" does not have expected type"])  
             | SOME info => S.Proxy(nrrd, info)  
           (* end case *))  
   
     datatype 'a ctl_flow_info  
       = EXIT                    (* stm sequence always exits; no pruning so far *)  
       | PRUNE of 'a             (* stm sequence always exits at last stm in argument, which  
                                  * is either a block or stm list *)  
       | CONT                    (* stm sequence falls through *)  
       | EDIT of 'a              (* pruned code that has non-exiting paths *)  
   
     fun pruneUnreachableCode (blk as S.Block stms) = let  
           fun isExit S.S_Die = true  
             | isExit S.S_Stabilize = true  
             | isExit (S.S_Return _) = true  
             | isExit _ = false  
           fun pruneStms [] = CONT  
             | pruneStms [S.S_IfThenElse(x, blk1, blk2)] = (  
                 case pruneIf(x, blk1, blk2)  
                  of EXIT => EXIT  
                   | PRUNE stm => PRUNE[stm]  
                   | CONT => CONT  
                   | EDIT stm => EDIT[stm]  
                 (* end case *))  
             | pruneStms [stm] = if isExit stm then EXIT else CONT  
             | pruneStms ((stm as S.S_IfThenElse(x, blk1, blk2))::stms) = (  
                 case pruneIf(x, blk1, blk2)  
                  of EXIT => PRUNE[stm]  
                   | PRUNE stm => PRUNE[stm]  
                   | CONT => (case pruneStms stms  
                        of PRUNE stms => PRUNE(stm::stms)  
                         | EDIT stms => EDIT(stm::stms)  
                         | EXIT => EXIT (* different instances of ctl_flow_info *)  
                         | CONT => CONT  
                       (* end case *))  
                   | EDIT stm => (case pruneStms stms  
                        of PRUNE stms => PRUNE(stm::stms)  
                         | EDIT stms => EDIT(stm::stms)  
                         | _ => EDIT(stm::stms)  
                       (* end case *))  
                 (* end case *))  
             | pruneStms (stm::stms) = if isExit stm  
                 then PRUNE[stm]  
                 else (case pruneStms stms  
                    of PRUNE stms => PRUNE(stm::stms)  
                     | EDIT stms => EDIT(stm::stms)  
                     | info => info  
                   (* end case *))  
           and pruneIf (x, blk1, blk2) = (case (pruneBlk blk1, pruneBlk blk2)  
                  of (EXIT,       EXIT      ) => EXIT  
                   | (CONT,       CONT      ) => CONT  
                   | (CONT,       EXIT      ) => CONT  
                   | (EXIT,       CONT      ) => CONT  
                   | (CONT,       EDIT blk2 ) => EDIT(S.S_IfThenElse(x, blk1, blk2))  
                   | (EDIT blk1,  CONT      ) => EDIT(S.S_IfThenElse(x, blk1, blk2))  
                   | (CONT,       PRUNE blk2) => EDIT(S.S_IfThenElse(x, blk1, blk2))  
                   | (PRUNE blk1, CONT      ) => EDIT(S.S_IfThenElse(x, blk1, blk2))  
                   | (EXIT,       EDIT blk2 ) => EDIT(S.S_IfThenElse(x, blk1, blk2))  
                   | (EDIT blk1,  EXIT      ) => EDIT(S.S_IfThenElse(x, blk1, blk2))  
                   | (EDIT blk1,  EDIT blk2 ) => EDIT(S.S_IfThenElse(x, blk1, blk2))  
                   | (EDIT blk1,  PRUNE blk2) => EDIT(S.S_IfThenElse(x, blk1, blk2))  
                   | (PRUNE blk1, EDIT blk2 ) => EDIT(S.S_IfThenElse(x, blk1, blk2))  
                   | (EXIT,       PRUNE blk2) => PRUNE(S.S_IfThenElse(x, blk1, blk2))  
                   | (PRUNE blk1, EXIT      ) => PRUNE(S.S_IfThenElse(x, blk1, blk2))  
                   | (PRUNE blk1, PRUNE blk2) => PRUNE(S.S_IfThenElse(x, blk1, blk2))  
                 (* end case *))  
           and pruneBlk (S.Block stms) = (case pruneStms stms  
                  of PRUNE stms => PRUNE(S.Block stms)  
                   | EDIT stms => EDIT(S.Block stms)  
                   | EXIT => EXIT (* different instances of ctl_flow_info *)  
                   | CONT => CONT  
                 (* end case *))  
           in  
             case pruneBlk blk  
              of PRUNE blk => blk  
               | EDIT blk => blk  
               | _=> blk  
             (* end case *)  
           end  
107    
108    (* simplify a statement into a single statement (i.e., a block if it expands    (* simplify a statement into a single statement (i.e., a block if it expands
109     * into more than one new statement).     * into more than one new statement).
110     *)     *)
111      fun simplifyBlock errStrm stm = mkBlock (simplifyStmt (errStrm, stm, []))      fun simplifyBlock (errStrm, stm) = mkBlock (simplifyStmt (errStrm, stm, []))
112    
113      (* convert the lhs variable of a var decl or assignment; if the rhs is a LoadImage,
114       * then we use the info from the proxy image to determine the type of the lhs
115       * variable.
116       *)
117        and cvtLHS (lhs, S.E_LoadImage(_, _, info)) = newVarWithType(lhs, STy.T_Image info)
118          | cvtLHS (lhs, _) = cvtVar lhs
119    
120    (* simplify the statement stm where stms is a reverse-order list of preceeding simplified    (* simplify the statement stm where stms is a reverse-order list of preceeding simplified
121     * statements.  This function returns a reverse-order list of simplified statements.     * statements.  This function returns a reverse-order list of simplified statements.
# Line 167  Line 136 
136                  end                  end
137              | AST.S_Decl(x, SOME e) => let              | AST.S_Decl(x, SOME e) => let
138                  val (stms, e') = simplifyExp (errStrm, e, stms)                  val (stms, e') = simplifyExp (errStrm, e, stms)
139                  val x' = cvtVar x                  val x' = cvtLHS (x, e')
140                  in                  in
141                    S.S_Var(x', SOME e') :: stms                    S.S_Var(x', SOME e') :: stms
142                  end                  end
143    (* FIXME: we should also define a "boolean negate" operation on AST expressions so that we can
144     * handle both cases!
145     *)
146                | AST.S_IfThenElse(AST.E_Orelse(e1, e2), s1 as AST.S_Block[], s2) =>
147                    simplifyStmt (errStrm, AST.S_IfThenElse(e1, s1, AST.S_IfThenElse(e2, s1, s2)), stms)
148                | AST.S_IfThenElse(AST.E_Andalso(e1, e2), s1, s2 as AST.S_Block[]) =>
149                    simplifyStmt (errStrm, AST.S_IfThenElse(e1, AST.S_IfThenElse(e2, s1, s2), s2), stms)
150              | AST.S_IfThenElse(e, s1, s2) => let              | AST.S_IfThenElse(e, s1, s2) => let
151                  val (stms, x) = simplifyExpToVar (errStrm, e, stms)                  val (stms, x) = simplifyExpToVar (errStrm, e, stms)
152                  val s1 = simplifyBlock errStrm s1                  val s1 = simplifyBlock (errStrm, s1)
153                  val s2 = simplifyBlock errStrm s2                  val s2 = simplifyBlock (errStrm, s2)
154                  in                  in
155                    S.S_IfThenElse(x, s1, s2) :: stms                    S.S_IfThenElse(x, s1, s2) :: stms
156                  end                  end
157              | AST.S_Foreach((x, e), body) => let              | AST.S_Foreach((x, e), body) => let
158                  val (stms, xs') = simplifyExpToVar (errStrm, e, stms)                  val (stms, xs') = simplifyExpToVar (errStrm, e, stms)
159                  val body' = simplifyBlock errStrm body                  val body' = simplifyBlock (errStrm, body)
160                  in                  in
161                    S.S_Foreach(cvtVar x, xs', body') :: stms                    S.S_Foreach(cvtVar x, xs', body') :: stms
162                  end                  end
163              | AST.S_Assign((x, _), e) => let              | AST.S_Assign((x, _), e) => let
164                  val (stms, e') = simplifyExp (errStrm, e, stms)                  val (stms, e') = simplifyExp (errStrm, e, stms)
165                    val x' = cvtLHS (x, e')
166                  in                  in
167                    S.S_Assign(cvtVar x, e') :: stms                    S.S_Assign(x', e') :: stms
168                  end                  end
169              | AST.S_New(name, args) => let              | AST.S_New(name, args) => let
170                  val (stms, xs) = simplifyExpsToVars (errStrm, args, stms)                  val (stms, xs) = simplifyExpsToVars (errStrm, args, stms)
# Line 210  Line 187 
187            (* end case *))            (* end case *))
188    
189      and simplifyExp (errStrm, exp, stms) = let      and simplifyExp (errStrm, exp, stms) = let
190              fun doBorderCtl (f, args) = let
191                    val (ctl, arg) = if Var.same(BV.image_border, f)
192                            then (BorderCtl.Default(hd args), hd(tl args))
193                          else if Var.same(BV.image_clamp, f)
194                            then (BorderCtl.Clamp, hd args)
195                          else if Var.same(BV.image_mirror, f)
196                            then (BorderCtl.Mirror, hd args)
197                          else if Var.same(BV.image_wrap, f)
198                            then (BorderCtl.Wrap, hd args)
199                            else raise Fail "impossible"
200                    in
201                      S.E_BorderCtl(ctl, arg)
202                    end
203            fun doPrimApply (f, tyArgs, args, ty) = let            fun doPrimApply (f, tyArgs, args, ty) = let
204                  val (stms, xs) = simplifyExpsToVars (errStrm, args, stms)                  val (stms, xs) = simplifyExpsToVars (errStrm, args, stms)
205                  in                  in
206                    case Var.kindOf f                    if Basis.isBorderCtl f
207                        then (stms, doBorderCtl (f, xs))
208                      else if Var.same(f, BV.fn_sphere_im)
209                        then raise Fail "FIXME: implicit sphere query"
210                        else (case Var.kindOf f
211                     of Var.BasisVar => let                     of Var.BasisVar => let
212                          fun cvtTyArg (Types.TYPE tv) = S.TY(cvtTy(TU.resolve tv))                          fun cvtTyArg (Types.TYPE tv) = S.TY(cvtTy(TU.resolve tv))
213                            | cvtTyArg (Types.DIFF dv) = S.DIFF(TU.monoDiff(TU.resolveDiff dv))                            | cvtTyArg (Types.DIFF dv) = S.DIFF(TU.monoDiff(TU.resolveDiff dv))
# Line 224  Line 218 
218                            (stms, S.E_Prim(f, tyArgs, xs, cvtTy ty))                            (stms, S.E_Prim(f, tyArgs, xs, cvtTy ty))
219                          end                          end
220                      | _ => raise Fail "bogus prim application"                      | _ => raise Fail "bogus prim application"
221                    (* end case *)                        (* end case *))
222                  end                  end
223            in            in
224              case exp              case exp
# Line 239  Line 233 
233                      | _ => (stms, S.E_Var(cvtVar x))                      | _ => (stms, S.E_Var(cvtVar x))
234                    (* end case *))                    (* end case *))
235                | AST.E_Lit lit => (stms, S.E_Lit lit)                | AST.E_Lit lit => (stms, S.E_Lit lit)
236                  | AST.E_Kernel h => (stms, S.E_Kernel h)
237                | AST.E_Select(e, (fld, _)) => let                | AST.E_Select(e, (fld, _)) => let
238                    val (stms, x) = simplifyExpToVar (errStrm, e, stms)                    val (stms, x) = simplifyExpToVar (errStrm, e, stms)
239                    in                    in
240                      (stms, S.E_Select(x, cvtVar fld))                      (stms, S.E_Select(x, cvtVar fld))
241                    end                    end
242                | AST.E_Prim(rator, tyArgs, args as [e], ty) => (case e                | AST.E_Prim(rator, tyArgs, args as [e], ty) => (case e
243                     of AST.E_Lit(Literal.Int n) => if Var.same(BasisVars.neg_i, rator)                     of AST.E_Lit(Literal.Int n) => if Var.same(BV.neg_i, rator)
244                          then (stms, S.E_Lit(Literal.Int(~n))) (* constant-fold negation of integer literals *)                          then (stms, S.E_Lit(Literal.Int(~n))) (* constant-fold negation of integer literals *)
245                          else doPrimApply (rator, tyArgs, args, ty)                          else doPrimApply (rator, tyArgs, args, ty)
246                      | AST.E_Lit(Literal.Real f) =>                      | AST.E_Lit(Literal.Real f) =>
247                          if Var.same(BasisVars.neg_t, rator)                          if Var.same(BV.neg_t, rator)
248                            then (stms, S.E_Lit(Literal.Real(RealLit.negate f))) (* constant-fold negation of real literals *)                            then (stms, S.E_Lit(Literal.Real(RealLit.negate f))) (* constant-fold negation of real literals *)
249                            else doPrimApply (rator, tyArgs, args, ty)                            else doPrimApply (rator, tyArgs, args, ty)
250                      | AST.E_Comprehension(e', (x, e''), seqTy) => if Basis.isReductionOp rator                      | AST.E_Comprehension(e', (x, e''), seqTy) => if Basis.isReductionOp rator
# Line 270  Line 265 
265                      | AST.E_ParallelMap(e', x, xs, _) =>                      | AST.E_ParallelMap(e', x, xs, _) =>
266                          if Basis.isReductionOp rator                          if Basis.isReductionOp rator
267                            then let                            then let
268                            (* parallel map-reduce *)                              val (result, stm) = simplifyReduction (errStrm, rator, e', x, xs, ty)
                             val result = SimpleVar.new ("res", Var.LocalVar, cvtTy ty)  
                             val (bodyStms, bodyResult) = simplifyExpToVar (errStrm, e', [])  
                             val (func, args) = Util.makeFunction(  
                                   Var.nameOf rator, mkBlock(S.S_Return bodyResult :: bodyStms),  
                                   SimpleVar.typeOf bodyResult)  
                             val mapReduceStm = S.S_MapReduce{  
                                     results = [result],  
                                     reductions = [rator],  
                                     body = func,  
                                     args = args,  
                                     source = xs  
                                   }  
269                              in                              in
270                                (mapReduceStm :: stms, S.E_Var result)                                (stm :: stms, S.E_Var result)
271                              end                              end
272                            else raise Fail "unsupported operation on parallel map"                            else raise Fail "unsupported operation on parallel map"
273                      | _ => doPrimApply (rator, tyArgs, args, ty)                      | _ => doPrimApply (rator, tyArgs, args, ty)
# Line 294  Line 277 
277                    val (stms, xs) = simplifyExpsToVars (errStrm, args, stms)                    val (stms, xs) = simplifyExpsToVars (errStrm, args, stms)
278                    in                    in
279                      case Var.kindOf f                      case Var.kindOf f
280                       of Var.FunVar => (stms, S.E_Apply(cvtVar f, xs, cvtTy ty))                       of Var.FunVar => (stms, S.E_Apply(SimpleFunc.use(cvtFunc f), xs))
281                        | _ => raise Fail "bogus application"                        | _ => raise Fail "bogus application"
282                      (* end case *)                      (* end case *)
283                    end                    end
# Line 306  Line 289 
289                    val acc = SimpleVar.new ("accum", Var.LocalVar, seqTy')                    val acc = SimpleVar.new ("accum", Var.LocalVar, seqTy')
290                    val initStm = S.S_Var(acc, SOME(S.E_Seq([], seqTy')))                    val initStm = S.S_Var(acc, SOME(S.E_Seq([], seqTy')))
291                    val updateStm = S.S_Assign(acc,                    val updateStm = S.S_Assign(acc,
292                          S.E_Prim(BasisVars.at_dT, [S.TY elemTy], [acc, bodyResult], seqTy'))                          S.E_Prim(BV.at_dT, [S.TY elemTy], [acc, bodyResult], seqTy'))
293                    val foreachStm = S.S_Foreach(cvtVar x, xs, mkBlock(updateStm :: bodyStms))                    val foreachStm = S.S_Foreach(cvtVar x, xs, mkBlock(updateStm :: bodyStms))
294                    in                    in
295                      (foreachStm :: initStm :: stms, S.E_Var acc)                      (foreachStm :: initStm :: stms, S.E_Var acc)
296                    end                    end
297                | AST.E_ParallelMap(e, x, xs, ty) => raise Fail "FIXME"                | AST.E_ParallelMap(e, x, xs, ty) => raise Fail "FIXME: ParallelMap"
298                | AST.E_Tensor(es, ty) => let                | AST.E_Tensor(es, ty) => let
299                    val (stms, xs) = simplifyExpsToVars (errStrm, es, stms)                    val (stms, xs) = simplifyExpsToVars (errStrm, es, stms)
300                    in                    in
# Line 324  Line 307 
307                    end                    end
308                | AST.E_Slice(e, indices, ty) => let (* tensor slicing *)                | AST.E_Slice(e, indices, ty) => let (* tensor slicing *)
309                    val (stms, x) = simplifyExpToVar (errStrm, e, stms)                    val (stms, x) = simplifyExpToVar (errStrm, e, stms)
310                    fun f ([], ys, stms) = (stms, List.rev ys)                    fun f NONE = NONE
311                      | f (NONE::es, ys, stms) = f (es, NONE::ys, stms)                      | f (SOME(AST.E_Lit(Literal.Int i))) = SOME(Int.fromLarge i)
312                      | f (SOME e::es, ys, stms) = let                      | f _ = raise Fail "expected integer literal in slice"
313                          val (stms, y) = simplifyExpToVar (errStrm, e, stms)                    val indices = List.map f indices
                         in  
                           f (es, SOME y::ys, stms)  
                         end  
                   val (stms, indices) = f (indices, [], stms)  
314                    in                    in
315                      (stms, S.E_Slice(x, indices, cvtTy ty))                      (stms, S.E_Slice(x, indices, cvtTy ty))
316                    end                    end
# Line 349  Line 328 
328                    in                    in
329                      (S.S_IfThenElse(x, s1, s2) :: stms, S.E_Var result)                      (S.S_IfThenElse(x, s1, s2) :: stms, S.E_Var result)
330                    end                    end
331                  | AST.E_Orelse(e1, e2) => simplifyExp (
332                      errStrm,
333                      AST.E_Cond(e1, AST.E_Lit(Literal.Bool true), e2, Ty.T_Bool),
334                      stms)
335                  | AST.E_Andalso(e1, e2) => simplifyExp (
336                      errStrm,
337                      AST.E_Cond(e1, e2, AST.E_Lit(Literal.Bool false), Ty.T_Bool),
338                      stms)
339                | AST.E_LoadNrrd(_, nrrd, ty) => (case cvtTy ty                | AST.E_LoadNrrd(_, nrrd, ty) => (case cvtTy ty
340                     of ty as SimpleTypes.T_Sequence(_, NONE) => (stms, S.E_LoadSeq(ty, nrrd))                     of ty as STy.T_Sequence(_, NONE) => (stms, S.E_LoadSeq(ty, nrrd))
341                      | ty as SimpleTypes.T_Image{dim, shape} => (                      | ty as STy.T_Image info => let
342                          case ImageInfo.fromNrrd(NrrdInfo.getInfo(errStrm, nrrd), dim, shape)                          val dim = II.dim info
343                           of NONE => raise Fail(concat[                          val shape = II.voxelShape info
344                            in
345                              case NrrdInfo.getInfo (errStrm, nrrd)
346                               of SOME nrrdInfo => (case II.fromNrrd(nrrdInfo, dim, shape)
347                                     of NONE => (
348                                          Error.error (errStrm, [
349                                  "nrrd file \"", nrrd, "\" does not have expected type"                                  "nrrd file \"", nrrd, "\" does not have expected type"
350                                ])                                          ]);
351                            | SOME info => (stms, S.E_LoadImage(ty, nrrd, info))                                        (stms, S.E_LoadImage(ty, nrrd, II.mkInfo(dim, shape))))
352                          (* end case *))                                    | SOME imgInfo =>
353                                          (stms, S.E_LoadImage(STy.T_Image imgInfo, nrrd, imgInfo))
354                                    (* end case *))
355                                | NONE => (
356                                    Error.warning (errStrm, [
357                                        "nrrd file \"", nrrd, "\" does not exist"
358                                      ]);
359                                    (stms, S.E_LoadImage(ty, nrrd, II.mkInfo(dim, shape))))
360                              (* end case *)
361                            end
362                      | _ => raise Fail "bogus type for E_LoadNrrd"                      | _ => raise Fail "bogus type for E_LoadNrrd"
363                    (* end case *))                    (* end case *))
364                  | AST.E_Coerce{dstTy, e=AST.E_Lit(Literal.Int n), ...} => (case cvtTy dstTy
365                       of SimpleTypes.T_Tensor[] => (stms, S.E_Lit(Literal.Real(RealLit.fromInt n)))
366                        | _ => raise Fail "impossible: bad coercion"
367                      (* end case *))
368                | AST.E_Coerce{srcTy, dstTy, e} => let                | AST.E_Coerce{srcTy, dstTy, e} => let
369                    val (stms, x) = simplifyExpToVar (errStrm, e, stms)                    val (stms, x) = simplifyExpToVar (errStrm, e, stms)
370                    val dstTy = cvtTy dstTy                    val dstTy = cvtTy dstTy
# Line 395  Line 400 
400              f (exps, [], stms)              f (exps, [], stms)
401            end            end
402    
403      fun simplifyStrand (errStrm, AST.Strand{name, params, state, initM, updateM, stabilizeM}) = let    (* simplify a parallel map-reduce *)
404        and simplifyReduction (errStrm, rator, e, x, xs, resTy) = let
405                val rator' = if Var.same(BV.red_all, rator) then Reductions.ALL
406                      else if Var.same(BV.red_exists, rator) then Reductions.EXISTS
407                      else if Var.same(BV.red_max, rator) then Reductions.MAX
408                      else if Var.same(BV.red_mean, rator) then raise Fail "FIXME: mean reduction"
409                      else if Var.same(BV.red_min, rator) then Reductions.MIN
410                      else if Var.same(BV.red_product, rator) then Reductions.PRODUCT
411                      else if Var.same(BV.red_sum, rator) then Reductions.SUM
412                      else if Var.same(BV.red_variance, rator) then raise Fail "FIXME: variance reduction"
413                        else raise Fail "impossible: not a reduction"
414                val x' = cvtVar x
415                val result = SimpleVar.new ("res", Var.LocalVar, cvtTy resTy)
416                val (bodyStms, bodyResult) = simplifyExpToVar (errStrm, e, [])
417    (* FIXME: need to handle reductions over active/stable subsets of strands *)
418                val (func, args) = Util.makeFunction(
419                      Var.nameOf rator, mkBlock(S.S_Return bodyResult :: bodyStms),
420                      SimpleVar.typeOf bodyResult)
421                val mapReduceStm = S.S_MapReduce{
422                        results = [result],
423                        reductions = [rator'],
424                        body = func,
425                        args = args,
426                        source = x'
427                      }
428                in
429                  (result, mapReduceStm)
430                end
431    
432    
433      (* simplify a block and then prune unreachable and dead code *)
434        fun simplifyAndPruneBlock errStrm blk =
435              DeadCode.eliminate (simplifyBlock (errStrm, blk))
436    
437        fun simplifyStrand (errStrm, strand) = let
438              val AST.Strand{name, params, state, stateInit, initM, updateM, stabilizeM} = strand
439            val params' = cvtVars params            val params' = cvtVars params
440            fun simplifyState ([], xs, stms) = (List.rev xs, mkBlock stms)            fun simplifyState ([], xs, stms) = (List.rev xs, mkBlock stms)
441              | simplifyState ((x, optE) :: r, xs, stms) = let              | simplifyState ((x, optE) :: r, xs, stms) = let
# Line 406  Line 446 
446                      | SOME e => let                      | SOME e => let
447                          val (stms, e') = simplifyExp (errStrm, e, stms)                          val (stms, e') = simplifyExp (errStrm, e, stms)
448                          in                          in
449                            simplifyState (r, x'::xs, S.S_Var(x', SOME e') :: stms)                            simplifyState (r, x'::xs, S.S_Assign(x', e') :: stms)
450                          end                          end
451                    (* end case *)                    (* end case *)
452                  end                  end
# Line 417  Line 457 
457                  params = params',                  params = params',
458                  state = xs,                  state = xs,
459                  stateInit = stm,                  stateInit = stm,
460                  initM = Option.map (simplifyBlock errStrm) initM,                  initM = Option.map (simplifyAndPruneBlock errStrm) initM,
461                  updateM = simplifyBlock errStrm updateM,                  updateM = simplifyAndPruneBlock errStrm updateM,
462                  stabilizeM = Option.map (simplifyBlock errStrm) stabilizeM                  stabilizeM = Option.map (simplifyAndPruneBlock errStrm) stabilizeM
463                }                }
464            end            end
465    
     fun simplifyCreate (errStrm, AST.C_Grid(dim, stm)) = S.C_Grid(dim, simplifyBlock errStrm stm)  
       | simplifyCreate (errStrm, AST.C_Collection stm) = S.C_Collection(simplifyBlock errStrm stm)  
   
466      fun transform (errStrm, prog) = let      fun transform (errStrm, prog) = let
467            val AST.Program{            val AST.Program{
468                    props, const_dcls, input_dcls, globals, init, strand, create, update                    props, const_dcls, input_dcls, globals, globInit, strand, create, init, update
469                  } = prog                  } = prog
470            val consts' = ref[]            val consts' = ref[]
471            val constInit = ref[]            val constInit = ref[]
# Line 446  Line 483 
483            fun simplifyInputDcl ((x, NONE), desc) = let            fun simplifyInputDcl ((x, NONE), desc) = let
484                  val x' = cvtVar x                  val x' = cvtVar x
485                  val init = (case SimpleVar.typeOf x'                  val init = (case SimpleVar.typeOf x'
486                         of SimpleTypes.T_Image{dim, shape} => let                         of STy.T_Image info => S.Image info
                             val info = ImageInfo.mkInfo(dim, shape)  
                             in  
                               S.Image info  
                             end  
487                          | _ => S.NoDefault                          | _ => S.NoDefault
488                        (* end case *))                        (* end case *))
489                  val inp = S.INP{                  val inp = S.INP{
490                          var = x',                          var = x',
491                            name = Var.nameOf x,
492                            ty =  apiTypeOf x',
493                          desc = desc,                          desc = desc,
494                          init = init                          init = init
495                        }                        }
# Line 462  Line 497 
497                    inputs' := inp :: !inputs'                    inputs' := inp :: !inputs'
498                  end                  end
499              | simplifyInputDcl ((x, SOME(AST.E_LoadNrrd(tvs, nrrd, ty))), desc) = let              | simplifyInputDcl ((x, SOME(AST.E_LoadNrrd(tvs, nrrd, ty))), desc) = let
500                  val x' = cvtVar x                  val (x', init) = (case Var.monoTypeOf x
501                (* load the nrrd proxy here *)                         of Ty.T_Sequence(_, NONE) => (cvtVar x, S.LoadSeq nrrd)
502                  val info = NrrdInfo.getInfo (errStrm, nrrd)                          | Ty.T_Image{dim, shape} => let
503                  val init = (case SimpleVar.typeOf x'                              val dim = TU.monoDim dim
504                         of SimpleTypes.T_Sequence(_, NONE) => S.LoadSeq nrrd                              val shape = TU.monoShape shape
505                          | SimpleTypes.T_Image{dim, shape} => inputImage(errStrm, nrrd, dim, shape)                              in
506                                  case NrrdInfo.getInfo (errStrm, nrrd)
507                                   of SOME nrrdInfo => (case II.fromNrrd(nrrdInfo, dim, shape)
508                                         of NONE => (
509                                              Error.error (errStrm, [
510                                                  "proxy nrrd file \"", nrrd,
511                                                  "\" does not have expected type"
512                                                ]);
513                                              (cvtVar x, S.Image(II.mkInfo(dim, shape))))
514                                          | SOME info =>
515                                              (newVarWithType(x, STy.T_Image info), S.Proxy(nrrd, info))
516                                        (* end case *))
517                                    | NONE => (
518                                        Error.warning (errStrm, [
519                                            "proxy nrrd file \"", nrrd, "\" does not exist"
520                                          ]);
521                                        (cvtVar x, S.Image(II.mkInfo(dim, shape))))
522                                  (* end case *)
523                                end
524                          | _ => raise Fail "impossible"                          | _ => raise Fail "impossible"
525                        (* end case *))                        (* end case *))
526                  val inp = S.INP{                  val inp = S.INP{
527                          var = x',                          var = x',
528                            name = Var.nameOf x,
529                            ty = apiTypeOf x',
530                          desc = desc,                          desc = desc,
531                          init = init                          init = init
532                        }                        }
# Line 483  Line 538 
538                  val (stms, e') = simplifyExp (errStrm, e, [])                  val (stms, e') = simplifyExp (errStrm, e, [])
539                  val inp = S.INP{                  val inp = S.INP{
540                          var = x',                          var = x',
541                            name = Var.nameOf x,
542                            ty = apiTypeOf x',
543                          desc = desc,                          desc = desc,
544                          init = S.ConstExpr                          init = S.ConstExpr
545                        }                        }
# Line 490  Line 547 
547                    inputs' := inp :: !inputs';                    inputs' := inp :: !inputs';
548                    constInit := S.S_Assign(x', e') :: (stms @ !constInit)                    constInit := S.S_Assign(x', e') :: (stms @ !constInit)
549                  end                  end
550            fun simplifyGlobalDcl (AST.D_Var(x, optE)) = let            fun simplifyGlobalDcl (AST.D_Var(x, NONE)) = globals' := cvtVar x :: !globals'
551                  val x' = cvtVar x              | simplifyGlobalDcl (AST.D_Var(x, SOME e)) = let
                 in  
                   case optE  
                     of NONE => globals' := x' :: !globals'  
                      | SOME e => let  
552                           val (stms, e') = simplifyExp (errStrm, e, [])                           val (stms, e') = simplifyExp (errStrm, e, [])
553                    val x' = cvtLHS (x, e')
554                           in                           in
555                             globals' := x' :: !globals';                             globals' := x' :: !globals';
556                             globalInit := S.S_Assign(x', e') :: (stms @ !globalInit)                             globalInit := S.S_Assign(x', e') :: (stms @ !globalInit)
557                           end                           end
                   (* end case *)  
                 end  
558              | simplifyGlobalDcl (AST.D_Func(f, params, body)) = let              | simplifyGlobalDcl (AST.D_Func(f, params, body)) = let
559                  val f' = cvtVar f                  val f' = cvtFunc f
560                  val params' = cvtVars params                  val params' = cvtVars params
561                  val body' = pruneUnreachableCode (simplifyBlock errStrm body)                  val body' = simplifyAndPruneBlock errStrm body
562                  in                  in
563                    funcs := S.Func{f=f', params=params', body=body'} :: !funcs                    funcs := S.Func{f=f', params=params', body=body'} :: !funcs
564                  end                  end
565            in            val () = (
566              List.app simplifyConstDcl const_dcls;              List.app simplifyConstDcl const_dcls;
567              List.app simplifyInputDcl input_dcls;              List.app simplifyInputDcl input_dcls;
568              List.app simplifyGlobalDcl globals;                  List.app simplifyGlobalDcl globals)
569            (* make the global-initialization block *)
570              val globInit = (case globInit
571                     of SOME stm => mkBlock (simplifyStmt (errStrm, stm, !globalInit))
572                      | NONE => mkBlock (!globalInit)
573                    (* end case *))
574            (* if the globInit block is non-empty, record the fact in the property list *)
575              val props = (case globInit
576                     of S.Block{code=[], ...} => props
577                      | _ => Properties.GlobalInit :: props
578                    (* end case *))
579              in
580              S.Program{              S.Program{
581                  props = props,                  props = props,
582                  consts = List.rev(!consts'),                  consts = List.rev(!consts'),
583                  inputs = List.rev(!inputs'),                  inputs = List.rev(!inputs'),
584                  constInit = mkBlock (!constInit),                  constInit = mkBlock (!constInit),
585                  globals = List.rev(!globals'),                  globals = List.rev(!globals'),
586                  init = mkBlock (!globalInit),                  globInit = globInit,
587                  funcs = List.rev(!funcs),                  funcs = List.rev(!funcs),
588                  strand = simplifyStrand (errStrm, strand),                  strand = simplifyStrand (errStrm, strand),
589                  create = simplifyCreate (errStrm, create),                  create = Create.map (simplifyAndPruneBlock errStrm) create,
590                  update = Option.map (simplifyBlock errStrm) update                  init = Option.map (simplifyAndPruneBlock errStrm) init,
591                    update = Option.map (simplifyAndPruneBlock errStrm) update
592                }                }
593            end            end
594    

Legend:
Removed from v.3465  
changed lines
  Added in v.4359

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0