Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] Diff of /branches/vis15/src/compiler/simplify/simplify.sml
ViewVC logotype

Diff of /branches/vis15/src/compiler/simplify/simplify.sml

Parent Directory Parent Directory | Revision Log Revision Log | View Patch Patch

revision 4317, Sat Jul 30 14:12:14 2016 UTC revision 4426, Wed Aug 17 20:00:23 2016 UTC
# Line 22  Line 22 
22    
23  structure Simplify : sig  structure Simplify : sig
24    
25      val transform : Error.err_stream * AST.program -> Simple.program      val transform : Error.err_stream * AST.program * GlobalEnv.t -> Simple.program
26    
27    end = struct    end = struct
28    
# Line 32  Line 32 
32      structure Ty = Types      structure Ty = Types
33      structure VMap = Var.Map      structure VMap = Var.Map
34      structure II = ImageInfo      structure II = ImageInfo
35        structure BV = BasisVars
36    
37      (* context for simplification *)
38        type context = {errStrm : Error.err_stream, gEnv : GlobalEnv.t}
39    
40        fun error ({errStrm, gEnv}, msg) = Error.error (errStrm, msg)
41        fun warning ({errStrm, gEnv}, msg) = Error.warning (errStrm, msg)
42    
43    (* convert a Types.ty to a SimpleTypes.ty *)    (* convert a Types.ty to a SimpleTypes.ty *)
44      fun cvtTy ty = (case ty      fun cvtTy ty = (case ty
# Line 104  Line 111 
111    (* make a block out of a list of statements that are in reverse order *)    (* make a block out of a list of statements that are in reverse order *)
112      fun mkBlock stms = S.Block{props = PropList.newHolder(), code = List.rev stms}      fun mkBlock stms = S.Block{props = PropList.newHolder(), code = List.rev stms}
113    
114      (* make a variable definition *)
115        fun mkDef (x, e) = S.S_Var(x, SOME e)
116    
117        fun mkRDiv (res, a, b) = mkDef (res, S.E_Prim(BV.div_rr, [], [a, b], STy.realTy))
118        fun mkToReal (res, a) =
119              mkDef (res, S.E_Coerce{srcTy = STy.T_Int, dstTy = STy.realTy, x = a})
120        fun mkLength (res, elemTy, xs) =
121              mkDef (res, S.E_Prim(BV.fn_length, [STy.TY elemTy], [xs], STy.T_Int))
122    
123    (* simplify a statement into a single statement (i.e., a block if it expands    (* simplify a statement into a single statement (i.e., a block if it expands
124     * into more than one new statement).     * into more than one new statement).
125     *)     *)
126      fun simplifyBlock (errStrm, stm) = mkBlock (simplifyStmt (errStrm, stm, []))      fun simplifyBlock (cxt, stm) = mkBlock (simplifyStmt (cxt, stm, []))
127    
128    (* convert the lhs variable of a var decl or assignment; if the rhs is a LoadImage,    (* convert the lhs variable of a var decl or assignment; if the rhs is a LoadImage,
129     * then we use the info from the proxy image to determine the type of the lhs     * then we use the info from the proxy image to determine the type of the lhs
# Line 121  Line 137 
137     * Note that error reporting is done in the typechecker, but it does not prune unreachable     * Note that error reporting is done in the typechecker, but it does not prune unreachable
138     * code.     * code.
139     *)     *)
140      and simplifyStmt (errStrm, stm, stms) : S.stmt list = (case stm      and simplifyStmt (cxt, stm, stms) : S.stmt list = (case stm
141             of AST.S_Block body => let             of AST.S_Block body => let
142                  fun simplify ([], stms) = stms                  fun simplify ([], stms) = stms
143                    | simplify (stm::r, stms) = simplify (r, simplifyStmt (errStrm, stm, stms))                    | simplify (stm::r, stms) = simplify (r, simplifyStmt (cxt, stm, stms))
144                  in                  in
145                    simplify (body, stms)                    simplify (body, stms)
146                  end                  end
# Line 134  Line 150 
150                    S.S_Var(x', NONE) :: stms                    S.S_Var(x', NONE) :: stms
151                  end                  end
152              | AST.S_Decl(x, SOME e) => let              | AST.S_Decl(x, SOME e) => let
153                  val (stms, e') = simplifyExp (errStrm, e, stms)                  val (stms, e') = simplifyExp (cxt, e, stms)
154                  val x' = cvtLHS (x, e')                  val x' = cvtLHS (x, e')
155                  in                  in
156                    S.S_Var(x', SOME e') :: stms                    S.S_Var(x', SOME e') :: stms
# Line 143  Line 159 
159   * handle both cases!   * handle both cases!
160   *)   *)
161              | AST.S_IfThenElse(AST.E_Orelse(e1, e2), s1 as AST.S_Block[], s2) =>              | AST.S_IfThenElse(AST.E_Orelse(e1, e2), s1 as AST.S_Block[], s2) =>
162                  simplifyStmt (errStrm, AST.S_IfThenElse(e1, s1, AST.S_IfThenElse(e2, s1, s2)), stms)                  simplifyStmt (cxt, AST.S_IfThenElse(e1, s1, AST.S_IfThenElse(e2, s1, s2)), stms)
163              | AST.S_IfThenElse(AST.E_Andalso(e1, e2), s1, s2 as AST.S_Block[]) =>              | AST.S_IfThenElse(AST.E_Andalso(e1, e2), s1, s2 as AST.S_Block[]) =>
164                  simplifyStmt (errStrm, AST.S_IfThenElse(e1, AST.S_IfThenElse(e2, s1, s2), s2), stms)                  simplifyStmt (cxt, AST.S_IfThenElse(e1, AST.S_IfThenElse(e2, s1, s2), s2), stms)
165              | AST.S_IfThenElse(e, s1, s2) => let              | AST.S_IfThenElse(e, s1, s2) => let
166                  val (stms, x) = simplifyExpToVar (errStrm, e, stms)                  val (stms, x) = simplifyExpToVar (cxt, e, stms)
167                  val s1 = simplifyBlock (errStrm, s1)                  val s1 = simplifyBlock (cxt, s1)
168                  val s2 = simplifyBlock (errStrm, s2)                  val s2 = simplifyBlock (cxt, s2)
169                  in                  in
170                    S.S_IfThenElse(x, s1, s2) :: stms                    S.S_IfThenElse(x, s1, s2) :: stms
171                  end                  end
172              | AST.S_Foreach((x, e), body) => let              | AST.S_Foreach((x, e), body) => let
173                  val (stms, xs') = simplifyExpToVar (errStrm, e, stms)                  val (stms, xs') = simplifyExpToVar (cxt, e, stms)
174                  val body' = simplifyBlock (errStrm, body)                  val body' = simplifyBlock (cxt, body)
175                  in                  in
176                    S.S_Foreach(cvtVar x, xs', body') :: stms                    S.S_Foreach(cvtVar x, xs', body') :: stms
177                  end                  end
178              | AST.S_Assign((x, _), e) => let              | AST.S_Assign((x, _), e) => let
179                  val (stms, e') = simplifyExp (errStrm, e, stms)                  val (stms, e') = simplifyExp (cxt, e, stms)
180                  val x' = cvtLHS (x, e')                  val x' = cvtLHS (x, e')
181                  in                  in
182                    S.S_Assign(x', e') :: stms                    S.S_Assign(x', e') :: stms
183                  end                  end
184              | AST.S_New(name, args) => let              | AST.S_New(name, args) => let
185                  val (stms, xs) = simplifyExpsToVars (errStrm, args, stms)                  val (stms, xs) = simplifyExpsToVars (cxt, args, stms)
186                  in                  in
187                    S.S_New(name, xs) :: stms                    S.S_New(name, xs) :: stms
188                  end                  end
# Line 174  Line 190 
190              | AST.S_Die => S.S_Die :: stms              | AST.S_Die => S.S_Die :: stms
191              | AST.S_Stabilize => S.S_Stabilize :: stms              | AST.S_Stabilize => S.S_Stabilize :: stms
192              | AST.S_Return e => let              | AST.S_Return e => let
193                  val (stms, x) = simplifyExpToVar (errStrm, e, stms)                  val (stms, x) = simplifyExpToVar (cxt, e, stms)
194                  in                  in
195                    S.S_Return x :: stms                    S.S_Return x :: stms
196                  end                  end
197              | AST.S_Print args => let              | AST.S_Print args => let
198                  val (stms, xs) = simplifyExpsToVars (errStrm, args, stms)                  val (stms, xs) = simplifyExpsToVars (cxt, args, stms)
199                  in                  in
200                    S.S_Print xs :: stms                    S.S_Print xs :: stms
201                  end                  end
202            (* end case *))            (* end case *))
203    
204      and simplifyExp (errStrm, exp, stms) = let      and simplifyExp (cxt, exp, stms) = let
205            fun doBorderCtl (f, args) = let            fun doBorderCtl (f, args) = let
206                  val (ctl, arg) = if Var.same(BasisVars.image_border, f)                  val (ctl, arg) = if Var.same(BV.image_border, f)
207                          then (BorderCtl.Default(hd args), hd(tl args))                          then (BorderCtl.Default(hd args), hd(tl args))
208                        else if Var.same(BasisVars.image_clamp, f)                        else if Var.same(BV.image_clamp, f)
209                          then (BorderCtl.Clamp, hd args)                          then (BorderCtl.Clamp, hd args)
210                        else if Var.same(BasisVars.image_mirror, f)                        else if Var.same(BV.image_mirror, f)
211                          then (BorderCtl.Mirror, hd args)                          then (BorderCtl.Mirror, hd args)
212                        else if Var.same(BasisVars.image_wrap, f)                        else if Var.same(BV.image_wrap, f)
213                          then (BorderCtl.Wrap, hd args)                          then (BorderCtl.Wrap, hd args)
214                          else raise Fail "impossible"                          else raise Fail "impossible"
215                  in                  in
216                    S.E_BorderCtl(ctl, arg)                    S.E_BorderCtl(ctl, arg)
217                  end                  end
218            fun doPrimApply (f, tyArgs, args, ty) = let            fun doPrimApply (f, tyArgs, args, ty) = let
219                  val (stms, xs) = simplifyExpsToVars (errStrm, args, stms)                  val (stms, xs) = simplifyExpsToVars (cxt, args, stms)
220                    fun cvtTyArg (Types.TYPE tv) = S.TY(cvtTy(TU.resolve tv))
221                      | cvtTyArg (Types.DIFF dv) = S.DIFF(TU.monoDiff(TU.resolveDiff dv))
222                      | cvtTyArg (Types.SHAPE sv) = S.SHAPE(TU.monoShape(TU.resolveShape sv))
223                      | cvtTyArg (Types.DIM dv) = S.DIM(TU.monoDim(TU.resolveDim dv))
224                  in                  in
225                    if Basis.isBorderCtl f                    if Basis.isBorderCtl f
226                      then (stms, doBorderCtl (f, xs))                      then (stms, doBorderCtl (f, xs))
227                      else if Var.same(f, BV.fn_sphere_im)
228                        then let
229                        (* get the strand type for the query *)
230                          val tyArgs as [S.TY(STy.T_Strand strand)] = List.map cvtTyArg tyArgs
231                        (* get the strand environment for the strand *)
232                          val SOME sEnv = GlobalEnv.findStrand(#gEnv cxt, strand)
233                          fun result (query, pos) =
234                                (stms, S.E_Prim(query, tyArgs, cvtVar pos::xs, cvtTy ty))
235                          in
236                          (* extract the position variable and spatial dimension *)
237                            case (StrandEnv.findPosVar sEnv, StrandEnv.getSpaceDim sEnv)
238                             of (SOME pos, SOME 1) => result (BV.fn_sphere1_r, pos)
239                              | (SOME pos, SOME 2) => result (BV.fn_sphere2_t, pos)
240                              | (SOME pos, SOME 3) => result (BV.fn_sphere3_t, pos)
241                              | _ => raise Fail "impossible"
242                            (* end case *)
243                          end
244                      else (case Var.kindOf f                      else (case Var.kindOf f
245                         of Var.BasisVar => let                         of Var.BasisVar => let
                             fun cvtTyArg (Types.TYPE tv) = S.TY(cvtTy(TU.resolve tv))  
                               | cvtTyArg (Types.DIFF dv) = S.DIFF(TU.monoDiff(TU.resolveDiff dv))  
                               | cvtTyArg (Types.SHAPE sv) = S.SHAPE(TU.monoShape(TU.resolveShape sv))  
                               | cvtTyArg (Types.DIM dv) = S.DIM(TU.monoDim(TU.resolveDim dv))  
246                              val tyArgs = List.map cvtTyArg tyArgs                              val tyArgs = List.map cvtTyArg tyArgs
247                              in                              in
248                                (stms, S.E_Prim(f, tyArgs, xs, cvtTy ty))                                (stms, S.E_Prim(f, tyArgs, xs, cvtTy ty))
# Line 232  Line 265 
265                | AST.E_Lit lit => (stms, S.E_Lit lit)                | AST.E_Lit lit => (stms, S.E_Lit lit)
266                | AST.E_Kernel h => (stms, S.E_Kernel h)                | AST.E_Kernel h => (stms, S.E_Kernel h)
267                | AST.E_Select(e, (fld, _)) => let                | AST.E_Select(e, (fld, _)) => let
268                    val (stms, x) = simplifyExpToVar (errStrm, e, stms)                    val (stms, x) = simplifyExpToVar (cxt, e, stms)
269                    in                    in
270                      (stms, S.E_Select(x, cvtVar fld))                      (stms, S.E_Select(x, cvtVar fld))
271                    end                    end
272                | AST.E_Prim(rator, tyArgs, args as [e], ty) => (case e                | AST.E_Prim(rator, tyArgs, args as [e], ty) => (case e
273                     of AST.E_Lit(Literal.Int n) => if Var.same(BasisVars.neg_i, rator)                     of AST.E_Lit(Literal.Int n) => if Var.same(BV.neg_i, rator)
274                          then (stms, S.E_Lit(Literal.Int(~n))) (* constant-fold negation of integer literals *)                          then (stms, S.E_Lit(Literal.Int(~n))) (* constant-fold negation of integer literals *)
275                          else doPrimApply (rator, tyArgs, args, ty)                          else doPrimApply (rator, tyArgs, args, ty)
276                      | AST.E_Lit(Literal.Real f) =>                      | AST.E_Lit(Literal.Real f) =>
277                          if Var.same(BasisVars.neg_t, rator)                          if Var.same(BV.neg_t, rator)
278                            then (stms, S.E_Lit(Literal.Real(RealLit.negate f))) (* constant-fold negation of real literals *)                            then (stms, S.E_Lit(Literal.Real(RealLit.negate f))) (* constant-fold negation of real literals *)
279                            else doPrimApply (rator, tyArgs, args, ty)                            else doPrimApply (rator, tyArgs, args, ty)
280    (* QUESTION: is there common code in handling a reduction over a sequence of strands vs. over a strand set? *)
281                      | AST.E_Comprehension(e', (x, e''), seqTy) => if Basis.isReductionOp rator                      | AST.E_Comprehension(e', (x, e''), seqTy) => if Basis.isReductionOp rator
282                          then let                          then let
283                            val {rator, init, mvs} = Util.reductionInfo rator                            val (stms, xs) = simplifyExpToVar (cxt, e'', stms)
284                            val (stms, xs) = simplifyExpToVar (errStrm, e'', stms)                            val (bodyStms, bodyResult) = simplifyExpToVar (cxt, e', [])
                           val (bodyStms, bodyResult) = simplifyExpToVar (errStrm, e, [])  
                           val acc = SimpleVar.new ("accum", Var.LocalVar, cvtTy ty)  
285                            val seqTy' as STy.T_Sequence(elemTy, NONE) = cvtTy seqTy                            val seqTy' as STy.T_Sequence(elemTy, NONE) = cvtTy seqTy
286                              fun mkReductionLoop (redOp, bodyStms, bodyResult, stms) = let
287                                    val {rator, init, mvs} = Util.reductionInfo redOp
288                                    val acc = SimpleVar.new ("accum", Var.LocalVar, cvtTy ty)
289                            val initStm = S.S_Var(acc, SOME(S.E_Lit init))                            val initStm = S.S_Var(acc, SOME(S.E_Lit init))
290                            val updateStm = S.S_Assign(acc,                            val updateStm = S.S_Assign(acc,
291                                  S.E_Prim(rator, mvs, [acc, bodyResult], seqTy'))                                  S.E_Prim(rator, mvs, [acc, bodyResult], seqTy'))
292                            val foreachStm = S.S_Foreach(cvtVar x, xs, mkBlock(updateStm :: bodyStms))                                  val foreachStm = S.S_Foreach(cvtVar x, xs,
293                                          mkBlock(updateStm :: bodyStms))
294                            in                            in
295                              (foreachStm :: initStm :: stms, S.E_Var acc)                              (foreachStm :: initStm :: stms, S.E_Var acc)
296                            end                            end
297                              in
298                                case Util.identifyReduction rator
299                                 of Util.MEAN => let
300                                      val (stms, S.E_Var resultV) = mkReductionLoop (
301                                            Reductions.SUM, bodyStms, bodyResult, stms)
302                                      val num = SimpleVar.new ("num", Var.LocalVar, STy.T_Int)
303                                      val rNum = SimpleVar.new ("rNum", Var.LocalVar, STy.realTy)
304                                      val mean = SimpleVar.new ("mean", Var.LocalVar, STy.realTy)
305                                      val stms =
306                                            mkRDiv (mean, resultV, rNum) ::
307                                            mkToReal (rNum, num) ::
308                                            mkLength (num, elemTy, xs) ::
309                                            stms
310                                      in
311                                        (stms, S.E_Var mean)
312                                      end
313                                  | Util.VARIANCE => raise Fail "FIXME: VARIANCE"
314                                  | Util.RED red => mkReductionLoop (red, bodyStms, bodyResult, stms)
315                                (* end case *)
316                              end
317                          else doPrimApply (rator, tyArgs, args, ty)                          else doPrimApply (rator, tyArgs, args, ty)
318                      | AST.E_ParallelMap(e', x, xs, _) =>                      | AST.E_ParallelMap(e', x, xs, _) =>
319                          if Basis.isReductionOp rator                          if Basis.isReductionOp rator
320                            then let                            then let
321                            (* parallel map-reduce *)                              val (result, stms) = simplifyReduction (cxt, rator, e', x, xs, ty, stms)
                             val result = SimpleVar.new ("res", Var.LocalVar, cvtTy ty)  
                             val (bodyStms, bodyResult) = simplifyExpToVar (errStrm, e', [])  
                             val (func, args) = Util.makeFunction(  
                                   Var.nameOf rator, mkBlock(S.S_Return bodyResult :: bodyStms),  
                                   SimpleVar.typeOf bodyResult)  
                             val mapReduceStm = S.S_MapReduce{  
                                     results = [result],  
                                     reductions = [rator],  
                                     body = func,  
                                     args = args,  
                                     source = xs  
                                   }  
322                              in                              in
323                                (mapReduceStm :: stms, S.E_Var result)                                (stms, S.E_Var result)
324                              end                              end
325                            else raise Fail "unsupported operation on parallel map"                            else raise Fail "unsupported operation on parallel map"
326                      | _ => doPrimApply (rator, tyArgs, args, ty)                      | _ => doPrimApply (rator, tyArgs, args, ty)
327                    (* end case *))                    (* end case *))
328                | AST.E_Prim(f, tyArgs, args, ty) => doPrimApply (f, tyArgs, args, ty)                | AST.E_Prim(f, tyArgs, args, ty) => doPrimApply (f, tyArgs, args, ty)
329                | AST.E_Apply((f, _), args, ty) => let                | AST.E_Apply((f, _), args, ty) => let
330                    val (stms, xs) = simplifyExpsToVars (errStrm, args, stms)                    val (stms, xs) = simplifyExpsToVars (cxt, args, stms)
331                    in                    in
332                      case Var.kindOf f                      case Var.kindOf f
333                       of Var.FunVar => (stms, S.E_Apply(SimpleFunc.use(cvtFunc f), xs))                       of Var.FunVar => (stms, S.E_Apply(SimpleFunc.use(cvtFunc f), xs))
# Line 292  Line 336 
336                    end                    end
337                | AST.E_Comprehension(e, (x, e'), seqTy) => let                | AST.E_Comprehension(e, (x, e'), seqTy) => let
338                  (* convert a comprehension to a foreach loop over the sequence defined by e' *)                  (* convert a comprehension to a foreach loop over the sequence defined by e' *)
339                    val (stms, xs) = simplifyExpToVar (errStrm, e', stms)                    val (stms, xs) = simplifyExpToVar (cxt, e', stms)
340                    val (bodyStms, bodyResult) = simplifyExpToVar (errStrm, e, [])                    val (bodyStms, bodyResult) = simplifyExpToVar (cxt, e, [])
341                    val seqTy' as STy.T_Sequence(elemTy, NONE) = cvtTy seqTy                    val seqTy' as STy.T_Sequence(elemTy, NONE) = cvtTy seqTy
342                    val acc = SimpleVar.new ("accum", Var.LocalVar, seqTy')                    val acc = SimpleVar.new ("accum", Var.LocalVar, seqTy')
343                    val initStm = S.S_Var(acc, SOME(S.E_Seq([], seqTy')))                    val initStm = S.S_Var(acc, SOME(S.E_Seq([], seqTy')))
344                    val updateStm = S.S_Assign(acc,                    val updateStm = S.S_Assign(acc,
345                          S.E_Prim(BasisVars.at_dT, [S.TY elemTy], [acc, bodyResult], seqTy'))                          S.E_Prim(BV.at_dT, [S.TY elemTy], [acc, bodyResult], seqTy'))
346                    val foreachStm = S.S_Foreach(cvtVar x, xs, mkBlock(updateStm :: bodyStms))                    val foreachStm = S.S_Foreach(cvtVar x, xs, mkBlock(updateStm :: bodyStms))
347                    in                    in
348                      (foreachStm :: initStm :: stms, S.E_Var acc)                      (foreachStm :: initStm :: stms, S.E_Var acc)
349                    end                    end
350                | AST.E_ParallelMap(e, x, xs, ty) => raise Fail "FIXME: ParallelMap"                | AST.E_ParallelMap(e, x, xs, ty) => raise Fail "FIXME: ParallelMap"
351                | AST.E_Tensor(es, ty) => let                | AST.E_Tensor(es, ty) => let
352                    val (stms, xs) = simplifyExpsToVars (errStrm, es, stms)                    val (stms, xs) = simplifyExpsToVars (cxt, es, stms)
353                    in                    in
354                      (stms, S.E_Tensor(xs, cvtTy ty))                      (stms, S.E_Tensor(xs, cvtTy ty))
355                    end                    end
356                | AST.E_Seq(es, ty) => let                | AST.E_Seq(es, ty) => let
357                    val (stms, xs) = simplifyExpsToVars (errStrm, es, stms)                    val (stms, xs) = simplifyExpsToVars (cxt, es, stms)
358                    in                    in
359                      (stms, S.E_Seq(xs, cvtTy ty))                      (stms, S.E_Seq(xs, cvtTy ty))
360                    end                    end
361                | AST.E_Slice(e, indices, ty) => let (* tensor slicing *)                | AST.E_Slice(e, indices, ty) => let (* tensor slicing *)
362                    val (stms, x) = simplifyExpToVar (errStrm, e, stms)                    val (stms, x) = simplifyExpToVar (cxt, e, stms)
363                    fun f NONE = NONE                    fun f NONE = NONE
364                      | f (SOME(AST.E_Lit(Literal.Int i))) = SOME(Int.fromLarge i)                      | f (SOME(AST.E_Lit(Literal.Int i))) = SOME(Int.fromLarge i)
365                      | f _ = raise Fail "expected integer literal in slice"                      | f _ = raise Fail "expected integer literal in slice"
# Line 326  Line 370 
370                | AST.E_Cond(e1, e2, e3, ty) => let                | AST.E_Cond(e1, e2, e3, ty) => let
371                  (* a conditional expression gets turned into an if-then-else statememt *)                  (* a conditional expression gets turned into an if-then-else statememt *)
372                    val result = newTemp(cvtTy ty)                    val result = newTemp(cvtTy ty)
373                    val (stms, x) = simplifyExpToVar (errStrm, e1, S.S_Var(result, NONE) :: stms)                    val (stms, x) = simplifyExpToVar (cxt, e1, S.S_Var(result, NONE) :: stms)
374                    fun simplifyBranch e = let                    fun simplifyBranch e = let
375                          val (stms, e) = simplifyExp (errStrm, e, [])                          val (stms, e) = simplifyExp (cxt, e, [])
376                          in                          in
377                            mkBlock (S.S_Assign(result, e)::stms)                            mkBlock (S.S_Assign(result, e)::stms)
378                          end                          end
# Line 338  Line 382 
382                      (S.S_IfThenElse(x, s1, s2) :: stms, S.E_Var result)                      (S.S_IfThenElse(x, s1, s2) :: stms, S.E_Var result)
383                    end                    end
384                | AST.E_Orelse(e1, e2) => simplifyExp (                | AST.E_Orelse(e1, e2) => simplifyExp (
385                    errStrm,                    cxt,
386                    AST.E_Cond(e1, AST.E_Lit(Literal.Bool true), e2, Ty.T_Bool),                    AST.E_Cond(e1, AST.E_Lit(Literal.Bool true), e2, Ty.T_Bool),
387                    stms)                    stms)
388                | AST.E_Andalso(e1, e2) => simplifyExp (                | AST.E_Andalso(e1, e2) => simplifyExp (
389                    errStrm,                    cxt,
390                    AST.E_Cond(e1, e2, AST.E_Lit(Literal.Bool false), Ty.T_Bool),                    AST.E_Cond(e1, e2, AST.E_Lit(Literal.Bool false), Ty.T_Bool),
391                    stms)                    stms)
392                | AST.E_LoadNrrd(_, nrrd, ty) => (case cvtTy ty                | AST.E_LoadNrrd(_, nrrd, ty) => (case cvtTy ty
# Line 351  Line 395 
395                          val dim = II.dim info                          val dim = II.dim info
396                          val shape = II.voxelShape info                          val shape = II.voxelShape info
397                          in                          in
398                            case NrrdInfo.getInfo (errStrm, nrrd)                            case NrrdInfo.getInfo (#errStrm cxt, nrrd)
399                             of SOME nrrdInfo => (case II.fromNrrd(nrrdInfo, dim, shape)                             of SOME nrrdInfo => (case II.fromNrrd(nrrdInfo, dim, shape)
400                                   of NONE => (                                   of NONE => (
401                                        Error.error (errStrm, [                                        error (cxt, [
402                                            "nrrd file \"", nrrd, "\" does not have expected type"                                            "nrrd file \"", nrrd, "\" does not have expected type"
403                                          ]);                                          ]);
404                                        (stms, S.E_LoadImage(ty, nrrd, II.mkInfo(dim, shape))))                                        (stms, S.E_LoadImage(ty, nrrd, II.mkInfo(dim, shape))))
# Line 362  Line 406 
406                                        (stms, S.E_LoadImage(STy.T_Image imgInfo, nrrd, imgInfo))                                        (stms, S.E_LoadImage(STy.T_Image imgInfo, nrrd, imgInfo))
407                                  (* end case *))                                  (* end case *))
408                              | NONE => (                              | NONE => (
409                                  Error.warning (errStrm, [                                  warning (cxt, [
410                                      "nrrd file \"", nrrd, "\" does not exist"                                      "nrrd file \"", nrrd, "\" does not exist"
411                                    ]);                                    ]);
412                                  (stms, S.E_LoadImage(ty, nrrd, II.mkInfo(dim, shape))))                                  (stms, S.E_LoadImage(ty, nrrd, II.mkInfo(dim, shape))))
# Line 375  Line 419 
419                      | _ => raise Fail "impossible: bad coercion"                      | _ => raise Fail "impossible: bad coercion"
420                    (* end case *))                    (* end case *))
421                | AST.E_Coerce{srcTy, dstTy, e} => let                | AST.E_Coerce{srcTy, dstTy, e} => let
422                    val (stms, x) = simplifyExpToVar (errStrm, e, stms)                    val (stms, x) = simplifyExpToVar (cxt, e, stms)
423                    val dstTy = cvtTy dstTy                    val dstTy = cvtTy dstTy
424                    val result = newTemp dstTy                    val result = newTemp dstTy
425                    val rhs = S.E_Coerce{srcTy = cvtTy srcTy, dstTy = dstTy, x = x}                    val rhs = S.E_Coerce{srcTy = cvtTy srcTy, dstTy = dstTy, x = x}
# Line 385  Line 429 
429              (* end case *)              (* end case *)
430            end            end
431    
432      and simplifyExpToVar (errStrm, exp, stms) = let      and simplifyExpToVar (cxt, exp, stms) = let
433            val (stms, e) = simplifyExp (errStrm, exp, stms)            val (stms, e) = simplifyExp (cxt, exp, stms)
434            in            in
435              case e              case e
436               of S.E_Var x => (stms, x)               of S.E_Var x => (stms, x)
# Line 398  Line 442 
442              (* end case *)              (* end case *)
443            end            end
444    
445      and simplifyExpsToVars (errStrm, exps, stms) = let      and simplifyExpsToVars (cxt, exps, stms) = let
446            fun f ([], xs, stms) = (stms, List.rev xs)            fun f ([], xs, stms) = (stms, List.rev xs)
447              | f (e::es, xs, stms) = let              | f (e::es, xs, stms) = let
448                  val (stms, x) = simplifyExpToVar (errStrm, e, stms)                  val (stms, x) = simplifyExpToVar (cxt, e, stms)
449                  in                  in
450                    f (es, x::xs, stms)                    f (es, x::xs, stms)
451                  end                  end
# Line 409  Line 453 
453              f (exps, [], stms)              f (exps, [], stms)
454            end            end
455    
456      (* simplify a parallel map-reduce *)
457        and simplifyReduction (cxt, rator, e, x, xs, resTy, stms) = let
458                val result = SimpleVar.new ("res", Var.LocalVar, cvtTy resTy)
459                val x' = cvtVar x
460                val (bodyStms, bodyResult) = simplifyExpToVar (cxt, e, [])
461              (* convert the domain from a variable to a StrandSets.t value *)
462                val domain = if Var.same(BV.set_active, xs) then StrandSets.ACTIVE
463                      else if Var.same(BV.set_all, xs) then StrandSets.ALL
464                      else if Var.same(BV.set_stable, xs) then StrandSets.STABLE
465                        else raise Fail "impossible: not a strand set"
466                val (func, args) = Util.makeFunction(
467                      Var.nameOf rator, mkBlock(S.S_Return bodyResult :: bodyStms),
468                      SimpleVar.typeOf bodyResult)
469    
470                in
471                  case Util.identifyReduction rator
472                   of Util.MEAN => let
473                        val mapReduceStm = S.S_MapReduce[
474                                S.MapReduce{
475                                    result = result, reduction = Reductions.SUM, mapf = func,
476                                    args = args, source = x', domain = domain
477                                  }]
478                        val num = SimpleVar.new ("num", Var.LocalVar, STy.T_Int)
479                        val rNum = SimpleVar.new ("rNum", Var.LocalVar, STy.realTy)
480                        val mean = SimpleVar.new ("mean", Var.LocalVar, STy.realTy)
481                        val numStrandsOp = (case domain
482                               of StrandSets.ACTIVE => BV.numActive
483                                | StrandSets.ALL => BV.numStrands
484                                | StrandSets.STABLE => BV.numStable
485                              (* end case *))
486                        val stms =
487                              mkRDiv (mean, result, rNum) ::
488                              mkToReal (rNum, num) ::
489                              mkDef (num, S.E_Prim(numStrandsOp, [], [], STy.T_Int)) ::
490                              mapReduceStm ::
491                              stms
492                        in
493                          (mean, stms)
494                        end
495                    | Util.VARIANCE => raise Fail "FIXME: variance reduction"
496                    | Util.RED rator' => let
497                        val mapReduceStm = S.S_MapReduce[
498                                S.MapReduce{
499                                    result = result, reduction = rator', mapf = func, args = args,
500                                    source = x', domain = domain
501                                  }]
502                        in
503                          (result, mapReduceStm :: stms)
504                        end
505                  (* end case *)
506                end
507    
508    (* simplify a block and then prune unreachable and dead code *)    (* simplify a block and then prune unreachable and dead code *)
509      fun simplifyAndPruneBlock errStrm blk =      fun simplifyAndPruneBlock cxt blk =
510            DeadCode.eliminate (simplifyBlock (errStrm, blk))            DeadCode.eliminate (simplifyBlock (cxt, blk))
511    
512      fun simplifyStrand (errStrm, strand) = let      fun simplifyStrand (cxt, strand) = let
513            val AST.Strand{name, params, state, stateInit, initM, updateM, stabilizeM} = strand            val AST.Strand{
514                      name, params, spatialDim, state, stateInit, initM, updateM, stabilizeM
515                    } = strand
516            val params' = cvtVars params            val params' = cvtVars params
517            fun simplifyState ([], xs, stms) = (List.rev xs, mkBlock stms)            fun simplifyState ([], xs, stms) = (List.rev xs, mkBlock stms)
518              | simplifyState ((x, optE) :: r, xs, stms) = let              | simplifyState ((x, optE) :: r, xs, stms) = let
# Line 423  Line 521 
521                    case optE                    case optE
522                     of NONE => simplifyState (r, x'::xs, stms)                     of NONE => simplifyState (r, x'::xs, stms)
523                      | SOME e => let                      | SOME e => let
524                          val (stms, e') = simplifyExp (errStrm, e, stms)                          val (stms, e') = simplifyExp (cxt, e, stms)
525                          in                          in
526                            simplifyState (r, x'::xs, S.S_Assign(x', e') :: stms)                            simplifyState (r, x'::xs, S.S_Assign(x', e') :: stms)
527                          end                          end
# Line 434  Line 532 
532              S.Strand{              S.Strand{
533                  name = name,                  name = name,
534                  params = params',                  params = params',
535                    spatialDim = spatialDim,
536                  state = xs,                  state = xs,
537                  stateInit = stm,                  stateInit = stm,
538                  initM = Option.map (simplifyAndPruneBlock errStrm) initM,                  initM = Option.map (simplifyAndPruneBlock cxt) initM,
539                  updateM = simplifyAndPruneBlock errStrm updateM,                  updateM = simplifyAndPruneBlock cxt updateM,
540                  stabilizeM = Option.map (simplifyAndPruneBlock errStrm) stabilizeM                  stabilizeM = Option.map (simplifyAndPruneBlock cxt) stabilizeM
541                }                }
542            end            end
543    
544      fun transform (errStrm, prog) = let      fun transform (errStrm, prog, gEnv) = let
545            val AST.Program{            val AST.Program{
546                    props, const_dcls, input_dcls, globals, globInit, strand, create, init, update                    props, const_dcls, input_dcls, globals, globInit, strand, create, init, update
547                  } = prog                  } = prog
548              val cxt = {errStrm = errStrm, gEnv = gEnv}
549            val consts' = ref[]            val consts' = ref[]
550            val constInit = ref[]            val constInit = ref[]
551            val inputs' = ref[]            val inputs' = ref[]
# Line 453  Line 553 
553            val globalInit = ref[]            val globalInit = ref[]
554            val funcs = ref[]            val funcs = ref[]
555            fun simplifyConstDcl (x, SOME e) = let            fun simplifyConstDcl (x, SOME e) = let
556                  val (stms, e') = simplifyExp (errStrm, e, [])                  val (stms, e') = simplifyExp (cxt, e, [])
557                  val x' = cvtVar x                  val x' = cvtVar x
558                  in                  in
559                    consts' := x' :: !consts';                    consts' := x' :: !consts';
# Line 482  Line 582 
582                              val dim = TU.monoDim dim                              val dim = TU.monoDim dim
583                              val shape = TU.monoShape shape                              val shape = TU.monoShape shape
584                              in                              in
585                                case NrrdInfo.getInfo (errStrm, nrrd)                                case NrrdInfo.getInfo (#errStrm cxt, nrrd)
586                                 of SOME nrrdInfo => (case II.fromNrrd(nrrdInfo, dim, shape)                                 of SOME nrrdInfo => (case II.fromNrrd(nrrdInfo, dim, shape)
587                                       of NONE => (                                       of NONE => (
588                                            Error.error (errStrm, [                                            error (cxt, [
589                                                "proxy nrrd file \"", nrrd,                                                "proxy input file \"", nrrd,
590                                                "\" does not have expected type"                                                "\" does not have expected type"
591                                              ]);                                              ]);
592                                            (cvtVar x, S.Image(II.mkInfo(dim, shape))))                                            (cvtVar x, S.Image(II.mkInfo(dim, shape))))
# Line 494  Line 594 
594                                            (newVarWithType(x, STy.T_Image info), S.Proxy(nrrd, info))                                            (newVarWithType(x, STy.T_Image info), S.Proxy(nrrd, info))
595                                      (* end case *))                                      (* end case *))
596                                  | NONE => (                                  | NONE => (
597                                      Error.warning (errStrm, [                                      warning (cxt, [
598                                          "proxy nrrd file \"", nrrd, "\" does not exist"                                          "proxy input file \"", nrrd, "\" does not exist"
599                                        ]);                                        ]);
600                                      (cvtVar x, S.Image(II.mkInfo(dim, shape))))                                      (cvtVar x, S.Image(II.mkInfo(dim, shape))))
601                                (* end case *)                                (* end case *)
# Line 514  Line 614 
614                  end                  end
615              | simplifyInputDcl ((x, SOME e), desc) = let              | simplifyInputDcl ((x, SOME e), desc) = let
616                  val x' = cvtVar x                  val x' = cvtVar x
617                  val (stms, e') = simplifyExp (errStrm, e, [])                  val (stms, e') = simplifyExp (cxt, e, [])
618                  val inp = S.INP{                  val inp = S.INP{
619                          var = x',                          var = x',
620                          name = Var.nameOf x,                          name = Var.nameOf x,
# Line 528  Line 628 
628                  end                  end
629            fun simplifyGlobalDcl (AST.D_Var(x, NONE)) = globals' := cvtVar x :: !globals'            fun simplifyGlobalDcl (AST.D_Var(x, NONE)) = globals' := cvtVar x :: !globals'
630              | simplifyGlobalDcl (AST.D_Var(x, SOME e)) = let              | simplifyGlobalDcl (AST.D_Var(x, SOME e)) = let
631                  val (stms, e') = simplifyExp (errStrm, e, [])                  val (stms, e') = simplifyExp (cxt, e, [])
632                  val x' = cvtLHS (x, e')                  val x' = cvtLHS (x, e')
633                  in                  in
634                    globals' := x' :: !globals';                    globals' := x' :: !globals';
# Line 537  Line 637 
637              | simplifyGlobalDcl (AST.D_Func(f, params, body)) = let              | simplifyGlobalDcl (AST.D_Func(f, params, body)) = let
638                  val f' = cvtFunc f                  val f' = cvtFunc f
639                  val params' = cvtVars params                  val params' = cvtVars params
640                  val body' = simplifyAndPruneBlock errStrm body                  val body' = simplifyAndPruneBlock cxt body
641                  in                  in
642                    funcs := S.Func{f=f', params=params', body=body'} :: !funcs                    funcs := S.Func{f=f', params=params', body=body'} :: !funcs
643                  end                  end
# Line 547  Line 647 
647                  List.app simplifyGlobalDcl globals)                  List.app simplifyGlobalDcl globals)
648          (* make the global-initialization block *)          (* make the global-initialization block *)
649            val globInit = (case globInit            val globInit = (case globInit
650                   of SOME stm => mkBlock (simplifyStmt (errStrm, stm, !globalInit))                   of SOME stm => mkBlock (simplifyStmt (cxt, stm, !globalInit))
651                    | NONE => mkBlock (!globalInit)                    | NONE => mkBlock (!globalInit)
652                  (* end case *))                  (* end case *))
653          (* if the globInit block is non-empty, record the fact in the property list *)          (* if the globInit block is non-empty, record the fact in the property list *)
# Line 564  Line 664 
664                  globals = List.rev(!globals'),                  globals = List.rev(!globals'),
665                  globInit = globInit,                  globInit = globInit,
666                  funcs = List.rev(!funcs),                  funcs = List.rev(!funcs),
667                  strand = simplifyStrand (errStrm, strand),                  strand = simplifyStrand (cxt, strand),
668                  create = Create.map (simplifyAndPruneBlock errStrm) create,                  create = Create.map (simplifyAndPruneBlock cxt) create,
669                  init = Option.map (simplifyAndPruneBlock errStrm) init,                  init = Option.map (simplifyAndPruneBlock cxt) init,
670                  update = Option.map (simplifyAndPruneBlock errStrm) update                  update = Option.map (simplifyAndPruneBlock cxt) update
671                }                }
672            end            end
673    

Legend:
Removed from v.4317  
changed lines
  Added in v.4426

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0