Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] Diff of /branches/vis15/src/compiler/simplify/simplify.sml
ViewVC logotype

Diff of /branches/vis15/src/compiler/simplify/simplify.sml

Parent Directory Parent Directory | Revision Log Revision Log | View Patch Patch

revision 4349, Tue Aug 2 18:14:48 2016 UTC revision 4441, Sun Aug 21 15:55:21 2016 UTC
# Line 22  Line 22 
22    
23  structure Simplify : sig  structure Simplify : sig
24    
25      val transform : Error.err_stream * AST.program -> Simple.program      val transform : Error.err_stream * AST.program * GlobalEnv.t -> Simple.program
26    
27    end = struct    end = struct
28    
# Line 32  Line 32 
32      structure Ty = Types      structure Ty = Types
33      structure VMap = Var.Map      structure VMap = Var.Map
34      structure II = ImageInfo      structure II = ImageInfo
35        structure BV = BasisVars
36    
37      (* environment for mapping small global constants to AST expressions *)
38        type const_env = AST.expr VMap.map
39    
40      (* context for simplification *)
41        datatype context = Cxt of {
42            errStrm : Error.err_stream,
43            gEnv : GlobalEnv.t,
44            cEnv : const_env
45          }
46    
47        fun getNrrdInfo (Cxt{errStrm, ...}, nrrd) = NrrdInfo.getInfo (errStrm, nrrd)
48    
49        fun findStrand (Cxt{gEnv, ...}, s) = GlobalEnv.findStrand(gEnv, s)
50    
51        fun insertConst (Cxt{errStrm, gEnv, cEnv}, x, e) = Cxt{
52                errStrm = errStrm, gEnv = gEnv, cEnv = VMap.insert(cEnv, x, e)
53              }
54    
55        fun findConst (Cxt{cEnv, ...}, x) = VMap.find(cEnv, x)
56    
57        fun error (Cxt{errStrm, ...}, msg) = Error.error (errStrm, msg)
58        fun warning (Cxt{errStrm, ...}, msg) = Error.warning (errStrm, msg)
59    
60      (* error message for when a nrrd image file is incompatible with the declared image type *)
61        fun badImageNrrd (cxt, nrrdFile, nrrdInfo, expectedDim, expectedShp) = let
62              val NrrdInfo.NrrdInfo{dim, nElems, ...} = nrrdInfo
63              val expectedNumElems = List.foldl (op * ) 1 expectedShp
64              val prefix = String.concat[
65                      "image file \"", nrrdFile, "\"  is incompatible with expected type image(",
66                      Int.toString expectedDim, ")[",
67                      String.concatWithMap "," Int.toString expectedShp, "]"
68                    ]
69              in
70                case (dim = expectedDim, nElems = expectedNumElems)
71                 of (false, true) => error (cxt, [
72                        prefix, "; its dimension is ", Int.toString dim
73                      ])
74                  | (true, false) => error (cxt, [
75                        prefix, "; it has ", Int.toString nElems, " sample per voxel"
76                      ])
77                  | _ =>  error (cxt, [
78                        prefix, ";  its dimension is ", Int.toString dim, " and it has ",
79                        Int.toString nElems, " sample per voxel"
80                      ])
81                (* end case *)
82              end
83    
84    (* convert a Types.ty to a SimpleTypes.ty *)    (* convert a Types.ty to a SimpleTypes.ty *)
85      fun cvtTy ty = (case ty      fun cvtTy ty = (case ty
# Line 104  Line 152 
152    (* make a block out of a list of statements that are in reverse order *)    (* make a block out of a list of statements that are in reverse order *)
153      fun mkBlock stms = S.Block{props = PropList.newHolder(), code = List.rev stms}      fun mkBlock stms = S.Block{props = PropList.newHolder(), code = List.rev stms}
154    
155      (* make a variable definition *)
156        fun mkDef (x, e) = S.S_Var(x, SOME e)
157    
158        fun mkRDiv (res, a, b) = mkDef (res, S.E_Prim(BV.div_rr, [], [a, b], STy.realTy))
159        fun mkToReal (res, a) =
160              mkDef (res, S.E_Coerce{srcTy = STy.T_Int, dstTy = STy.realTy, x = a})
161        fun mkLength (res, elemTy, xs) =
162              mkDef (res, S.E_Prim(BV.fn_length, [STy.TY elemTy], [xs], STy.T_Int))
163    
164    (* simplify a statement into a single statement (i.e., a block if it expands    (* simplify a statement into a single statement (i.e., a block if it expands
165     * into more than one new statement).     * into more than one new statement).
166     *)     *)
167      fun simplifyBlock (errStrm, stm) = mkBlock (simplifyStmt (errStrm, stm, []))      fun simplifyBlock (cxt, stm) = mkBlock (simplifyStmt (cxt, stm, []))
168    
169    (* convert the lhs variable of a var decl or assignment; if the rhs is a LoadImage,    (* convert the lhs variable of a var decl or assignment; if the rhs is a LoadImage,
170     * then we use the info from the proxy image to determine the type of the lhs     * then we use the info from the proxy image to determine the type of the lhs
# Line 121  Line 178 
178     * Note that error reporting is done in the typechecker, but it does not prune unreachable     * Note that error reporting is done in the typechecker, but it does not prune unreachable
179     * code.     * code.
180     *)     *)
181      and simplifyStmt (errStrm, stm, stms) : S.stmt list = (case stm      and simplifyStmt (cxt : context, stm, stms) : S.stmt list = (case stm
182             of AST.S_Block body => let             of AST.S_Block body => let
183                  fun simplify ([], stms) = stms                  fun simplify ([], stms) = stms
184                    | simplify (stm::r, stms) = simplify (r, simplifyStmt (errStrm, stm, stms))                    | simplify (stm::r, stms) = simplify (r, simplifyStmt (cxt, stm, stms))
185                  in                  in
186                    simplify (body, stms)                    simplify (body, stms)
187                  end                  end
# Line 134  Line 191 
191                    S.S_Var(x', NONE) :: stms                    S.S_Var(x', NONE) :: stms
192                  end                  end
193              | AST.S_Decl(x, SOME e) => let              | AST.S_Decl(x, SOME e) => let
194                  val (stms, e') = simplifyExp (errStrm, e, stms)                  val (stms, e') = simplifyExp (cxt, e, stms)
195                  val x' = cvtLHS (x, e')                  val x' = cvtLHS (x, e')
196                  in                  in
197                    S.S_Var(x', SOME e') :: stms                    S.S_Var(x', SOME e') :: stms
# Line 143  Line 200 
200   * handle both cases!   * handle both cases!
201   *)   *)
202              | AST.S_IfThenElse(AST.E_Orelse(e1, e2), s1 as AST.S_Block[], s2) =>              | AST.S_IfThenElse(AST.E_Orelse(e1, e2), s1 as AST.S_Block[], s2) =>
203                  simplifyStmt (errStrm, AST.S_IfThenElse(e1, s1, AST.S_IfThenElse(e2, s1, s2)), stms)                  simplifyStmt (cxt, AST.S_IfThenElse(e1, s1, AST.S_IfThenElse(e2, s1, s2)), stms)
204              | AST.S_IfThenElse(AST.E_Andalso(e1, e2), s1, s2 as AST.S_Block[]) =>              | AST.S_IfThenElse(AST.E_Andalso(e1, e2), s1, s2 as AST.S_Block[]) =>
205                  simplifyStmt (errStrm, AST.S_IfThenElse(e1, AST.S_IfThenElse(e2, s1, s2), s2), stms)                  simplifyStmt (cxt, AST.S_IfThenElse(e1, AST.S_IfThenElse(e2, s1, s2), s2), stms)
206              | AST.S_IfThenElse(e, s1, s2) => let              | AST.S_IfThenElse(e, s1, s2) => let
207                  val (stms, x) = simplifyExpToVar (errStrm, e, stms)                  val (stms, x) = simplifyExpToVar (cxt, e, stms)
208                  val s1 = simplifyBlock (errStrm, s1)                  val s1 = simplifyBlock (cxt, s1)
209                  val s2 = simplifyBlock (errStrm, s2)                  val s2 = simplifyBlock (cxt, s2)
210                  in                  in
211                    S.S_IfThenElse(x, s1, s2) :: stms                    S.S_IfThenElse(x, s1, s2) :: stms
212                  end                  end
213              | AST.S_Foreach((x, e), body) => let              | AST.S_Foreach((x, e), body) => let
214                  val (stms, xs') = simplifyExpToVar (errStrm, e, stms)                  val (stms, xs') = simplifyExpToVar (cxt, e, stms)
215                  val body' = simplifyBlock (errStrm, body)                  val body' = simplifyBlock (cxt, body)
216                  in                  in
217                    S.S_Foreach(cvtVar x, xs', body') :: stms                    S.S_Foreach(cvtVar x, xs', body') :: stms
218                  end                  end
219              | AST.S_Assign((x, _), e) => let              | AST.S_Assign((x, _), e) => let
220                  val (stms, e') = simplifyExp (errStrm, e, stms)                  val (stms, e') = simplifyExp (cxt, e, stms)
221                  val x' = cvtLHS (x, e')                  val x' = cvtLHS (x, e')
222                  in                  in
223                    S.S_Assign(x', e') :: stms                    S.S_Assign(x', e') :: stms
224                  end                  end
225              | AST.S_New(name, args) => let              | AST.S_New(name, args) => let
226                  val (stms, xs) = simplifyExpsToVars (errStrm, args, stms)                  val (stms, xs) = simplifyExpsToVars (cxt, args, stms)
227                  in                  in
228                    S.S_New(name, xs) :: stms                    S.S_New(name, xs) :: stms
229                  end                  end
# Line 174  Line 231 
231              | AST.S_Die => S.S_Die :: stms              | AST.S_Die => S.S_Die :: stms
232              | AST.S_Stabilize => S.S_Stabilize :: stms              | AST.S_Stabilize => S.S_Stabilize :: stms
233              | AST.S_Return e => let              | AST.S_Return e => let
234                  val (stms, x) = simplifyExpToVar (errStrm, e, stms)                  val (stms, x) = simplifyExpToVar (cxt, e, stms)
235                  in                  in
236                    S.S_Return x :: stms                    S.S_Return x :: stms
237                  end                  end
238              | AST.S_Print args => let              | AST.S_Print args => let
239                  val (stms, xs) = simplifyExpsToVars (errStrm, args, stms)                  val (stms, xs) = simplifyExpsToVars (cxt, args, stms)
240                  in                  in
241                    S.S_Print xs :: stms                    S.S_Print xs :: stms
242                  end                  end
243            (* end case *))            (* end case *))
244    
245      and simplifyExp (errStrm, exp, stms) = let      and simplifyExp (cxt, exp, stms) = let
246            fun doBorderCtl (f, args) = let            fun doBorderCtl (f, args) = let
247                  val (ctl, arg) = if Var.same(BasisVars.image_border, f)                  val (ctl, arg) = if Var.same(BV.image_border, f)
248                          then (BorderCtl.Default(hd args), hd(tl args))                          then (BorderCtl.Default(hd args), hd(tl args))
249                        else if Var.same(BasisVars.image_clamp, f)                        else if Var.same(BV.image_clamp, f)
250                          then (BorderCtl.Clamp, hd args)                          then (BorderCtl.Clamp, hd args)
251                        else if Var.same(BasisVars.image_mirror, f)                        else if Var.same(BV.image_mirror, f)
252                          then (BorderCtl.Mirror, hd args)                          then (BorderCtl.Mirror, hd args)
253                        else if Var.same(BasisVars.image_wrap, f)                        else if Var.same(BV.image_wrap, f)
254                          then (BorderCtl.Wrap, hd args)                          then (BorderCtl.Wrap, hd args)
255                          else raise Fail "impossible"                          else raise Fail "impossible"
256                  in                  in
257                    S.E_BorderCtl(ctl, arg)                    S.E_BorderCtl(ctl, arg)
258                  end                  end
259            fun doPrimApply (f, tyArgs, args, ty) = let            fun doPrimApply (f, tyArgs, args, ty) = let
260                  val (stms, xs) = simplifyExpsToVars (errStrm, args, stms)                  val (stms, xs) = simplifyExpsToVars (cxt, args, stms)
261                    fun cvtTyArg (Types.TYPE tv) = S.TY(cvtTy(TU.resolve tv))
262                      | cvtTyArg (Types.DIFF dv) = S.DIFF(TU.monoDiff(TU.resolveDiff dv))
263                      | cvtTyArg (Types.SHAPE sv) = S.SHAPE(TU.monoShape(TU.resolveShape sv))
264                      | cvtTyArg (Types.DIM dv) = S.DIM(TU.monoDim(TU.resolveDim dv))
265                  in                  in
266                    if Basis.isBorderCtl f                    if Basis.isBorderCtl f
267                      then (stms, doBorderCtl (f, xs))                      then (stms, doBorderCtl (f, xs))
268                    else if Var.same(f, BasisVars.fn_sphere_im)                    else if Var.same(f, BV.fn_sphere_im)
269                      then raise Fail "FIXME: implicit sphere query"                      then let
270                        (* get the strand type for the query *)
271                          val tyArgs as [S.TY(STy.T_Strand strand)] = List.map cvtTyArg tyArgs
272                        (* get the strand environment for the strand *)
273                          val SOME sEnv = findStrand(cxt, strand)
274                          fun result (query, pos) =
275                                (stms, S.E_Prim(query, tyArgs, cvtVar pos::xs, cvtTy ty))
276                          in
277                          (* extract the position variable and spatial dimension *)
278                            case (StrandEnv.findPosVar sEnv, StrandEnv.getSpaceDim sEnv)
279                             of (SOME pos, SOME 1) => result (BV.fn_sphere1_r, pos)
280                              | (SOME pos, SOME 2) => result (BV.fn_sphere2_t, pos)
281                              | (SOME pos, SOME 3) => result (BV.fn_sphere3_t, pos)
282                              | _ => raise Fail "impossible"
283                            (* end case *)
284                          end
285                      else (case Var.kindOf f                      else (case Var.kindOf f
286                         of Var.BasisVar => let                         of Var.BasisVar => let
                             fun cvtTyArg (Types.TYPE tv) = S.TY(cvtTy(TU.resolve tv))  
                               | cvtTyArg (Types.DIFF dv) = S.DIFF(TU.monoDiff(TU.resolveDiff dv))  
                               | cvtTyArg (Types.SHAPE sv) = S.SHAPE(TU.monoShape(TU.resolveShape sv))  
                               | cvtTyArg (Types.DIM dv) = S.DIM(TU.monoDim(TU.resolveDim dv))  
287                              val tyArgs = List.map cvtTyArg tyArgs                              val tyArgs = List.map cvtTyArg tyArgs
288                              in                              in
289                                (stms, S.E_Prim(f, tyArgs, xs, cvtTy ty))                                (stms, S.E_Prim(f, tyArgs, xs, cvtTy ty))
# Line 229  Line 301 
301                          in                          in
302                            (stm::stms, S.E_Var x')                            (stm::stms, S.E_Var x')
303                          end                          end
304                        | Var.ConstVar => (case findConst(cxt, x)
305                             of SOME e => let
306                                  val (stms, x') = simplifyExpToVar (cxt, e, stms)
307                                  in
308                                    (stms, S.E_Var x')
309                                  end
310                              | NONE => (stms, S.E_Var(cvtVar x))
311                            (* end case *))
312                      | _ => (stms, S.E_Var(cvtVar x))                      | _ => (stms, S.E_Var(cvtVar x))
313                    (* end case *))                    (* end case *))
314                | AST.E_Lit lit => (stms, S.E_Lit lit)                | AST.E_Lit lit => (stms, S.E_Lit lit)
315                | AST.E_Kernel h => (stms, S.E_Kernel h)                | AST.E_Kernel h => (stms, S.E_Kernel h)
316                | AST.E_Select(e, (fld, _)) => let                | AST.E_Select(e, (fld, _)) => let
317                    val (stms, x) = simplifyExpToVar (errStrm, e, stms)                    val (stms, x) = simplifyExpToVar (cxt, e, stms)
318                    in                    in
319                      (stms, S.E_Select(x, cvtVar fld))                      (stms, S.E_Select(x, cvtVar fld))
320                    end                    end
321                | AST.E_Prim(rator, tyArgs, args as [e], ty) => (case e                | AST.E_Prim(rator, tyArgs, args as [e], ty) => (case e
322                     of AST.E_Lit(Literal.Int n) => if Var.same(BasisVars.neg_i, rator)                     of AST.E_Lit(Literal.Int n) => if Var.same(BV.neg_i, rator)
323                          then (stms, S.E_Lit(Literal.Int(~n))) (* constant-fold negation of integer literals *)                          then (stms, S.E_Lit(Literal.Int(~n))) (* constant-fold negation of integer literals *)
324                          else doPrimApply (rator, tyArgs, args, ty)                          else doPrimApply (rator, tyArgs, args, ty)
325                      | AST.E_Lit(Literal.Real f) =>                      | AST.E_Lit(Literal.Real f) =>
326                          if Var.same(BasisVars.neg_t, rator)                          if Var.same(BV.neg_t, rator)
327                            then (stms, S.E_Lit(Literal.Real(RealLit.negate f))) (* constant-fold negation of real literals *)                            then (stms, S.E_Lit(Literal.Real(RealLit.negate f))) (* constant-fold negation of real literals *)
328                            else doPrimApply (rator, tyArgs, args, ty)                            else doPrimApply (rator, tyArgs, args, ty)
329    (* QUESTION: is there common code in handling a reduction over a sequence of strands vs. over a strand set? *)
330                      | AST.E_Comprehension(e', (x, e''), seqTy) => if Basis.isReductionOp rator                      | AST.E_Comprehension(e', (x, e''), seqTy) => if Basis.isReductionOp rator
331                          then let                          then let
332                            val {rator, init, mvs} = Util.reductionInfo rator                            val (stms, xs) = simplifyExpToVar (cxt, e'', stms)
333                            val (stms, xs) = simplifyExpToVar (errStrm, e'', stms)                            val (bodyStms, bodyResult) = simplifyExpToVar (cxt, e', [])
                           val (bodyStms, bodyResult) = simplifyExpToVar (errStrm, e, [])  
                           val acc = SimpleVar.new ("accum", Var.LocalVar, cvtTy ty)  
334                            val seqTy' as STy.T_Sequence(elemTy, NONE) = cvtTy seqTy                            val seqTy' as STy.T_Sequence(elemTy, NONE) = cvtTy seqTy
335                              fun mkReductionLoop (redOp, bodyStms, bodyResult, stms) = let
336                                    val {rator, init, mvs} = Util.reductionInfo redOp
337                                    val acc = SimpleVar.new ("accum", Var.LocalVar, cvtTy ty)
338                            val initStm = S.S_Var(acc, SOME(S.E_Lit init))                            val initStm = S.S_Var(acc, SOME(S.E_Lit init))
339                            val updateStm = S.S_Assign(acc,                            val updateStm = S.S_Assign(acc,
340                                  S.E_Prim(rator, mvs, [acc, bodyResult], seqTy'))                                  S.E_Prim(rator, mvs, [acc, bodyResult], seqTy'))
341                            val foreachStm = S.S_Foreach(cvtVar x, xs, mkBlock(updateStm :: bodyStms))                                  val foreachStm = S.S_Foreach(cvtVar x, xs,
342                                          mkBlock(updateStm :: bodyStms))
343                            in                            in
344                              (foreachStm :: initStm :: stms, S.E_Var acc)                              (foreachStm :: initStm :: stms, S.E_Var acc)
345                            end                            end
346                              in
347                                case Util.identifyReduction rator
348                                 of Util.MEAN => let
349                                      val (stms, S.E_Var resultV) = mkReductionLoop (
350                                            Reductions.SUM, bodyStms, bodyResult, stms)
351                                      val num = SimpleVar.new ("num", Var.LocalVar, STy.T_Int)
352                                      val rNum = SimpleVar.new ("rNum", Var.LocalVar, STy.realTy)
353                                      val mean = SimpleVar.new ("mean", Var.LocalVar, STy.realTy)
354                                      val stms =
355                                            mkRDiv (mean, resultV, rNum) ::
356                                            mkToReal (rNum, num) ::
357                                            mkLength (num, elemTy, xs) ::
358                                            stms
359                                      in
360                                        (stms, S.E_Var mean)
361                                      end
362                                  | Util.VARIANCE => raise Fail "FIXME: VARIANCE"
363                                  | Util.RED red => mkReductionLoop (red, bodyStms, bodyResult, stms)
364                                (* end case *)
365                              end
366                          else doPrimApply (rator, tyArgs, args, ty)                          else doPrimApply (rator, tyArgs, args, ty)
367                      | AST.E_ParallelMap(e', x, xs, _) =>                      | AST.E_ParallelMap(e', x, xs, _) =>
368                          if Basis.isReductionOp rator                          if Basis.isReductionOp rator
369                            then let                            then let
370                            (* parallel map-reduce *)                              val (result, stms) = simplifyReduction (cxt, rator, e', x, xs, ty, stms)
                             val x' = cvtVar x  
                             val result = SimpleVar.new ("res", Var.LocalVar, cvtTy ty)  
                             val (bodyStms, bodyResult) = simplifyExpToVar (errStrm, e', [])  
                             val (func, args) = Util.makeFunction(  
                                   Var.nameOf rator, mkBlock(S.S_Return bodyResult :: bodyStms),  
                                   SimpleVar.typeOf bodyResult)  
                             val mapReduceStm = S.S_MapReduce{  
                                     results = [result],  
                                     reductions = [rator],  
                                     body = func,  
                                     args = args,  
                                     source = [(x', xs)]  
                                   }  
371                              in                              in
372                                (mapReduceStm :: stms, S.E_Var result)                                (stms, S.E_Var result)
373                              end                              end
374                            else raise Fail "unsupported operation on parallel map"                            else raise Fail "unsupported operation on parallel map"
375                      | _ => doPrimApply (rator, tyArgs, args, ty)                      | _ => doPrimApply (rator, tyArgs, args, ty)
376                    (* end case *))                    (* end case *))
377                | AST.E_Prim(f, tyArgs, args, ty) => doPrimApply (f, tyArgs, args, ty)                | AST.E_Prim(f, tyArgs, args, ty) => doPrimApply (f, tyArgs, args, ty)
378                | AST.E_Apply((f, _), args, ty) => let                | AST.E_Apply((f, _), args, ty) => let
379                    val (stms, xs) = simplifyExpsToVars (errStrm, args, stms)                    val (stms, xs) = simplifyExpsToVars (cxt, args, stms)
380                    in                    in
381                      case Var.kindOf f                      case Var.kindOf f
382                       of Var.FunVar => (stms, S.E_Apply(SimpleFunc.use(cvtFunc f), xs))                       of Var.FunVar => (stms, S.E_Apply(SimpleFunc.use(cvtFunc f), xs))
# Line 295  Line 385 
385                    end                    end
386                | AST.E_Comprehension(e, (x, e'), seqTy) => let                | AST.E_Comprehension(e, (x, e'), seqTy) => let
387                  (* convert a comprehension to a foreach loop over the sequence defined by e' *)                  (* convert a comprehension to a foreach loop over the sequence defined by e' *)
388                    val (stms, xs) = simplifyExpToVar (errStrm, e', stms)                    val (stms, xs) = simplifyExpToVar (cxt, e', stms)
389                    val (bodyStms, bodyResult) = simplifyExpToVar (errStrm, e, [])                    val (bodyStms, bodyResult) = simplifyExpToVar (cxt, e, [])
390                    val seqTy' as STy.T_Sequence(elemTy, NONE) = cvtTy seqTy                    val seqTy' as STy.T_Sequence(elemTy, NONE) = cvtTy seqTy
391                    val acc = SimpleVar.new ("accum", Var.LocalVar, seqTy')                    val acc = SimpleVar.new ("accum", Var.LocalVar, seqTy')
392                    val initStm = S.S_Var(acc, SOME(S.E_Seq([], seqTy')))                    val initStm = S.S_Var(acc, SOME(S.E_Seq([], seqTy')))
393                    val updateStm = S.S_Assign(acc,                    val updateStm = S.S_Assign(acc,
394                          S.E_Prim(BasisVars.at_dT, [S.TY elemTy], [acc, bodyResult], seqTy'))                          S.E_Prim(BV.at_dT, [S.TY elemTy], [acc, bodyResult], seqTy'))
395                    val foreachStm = S.S_Foreach(cvtVar x, xs, mkBlock(updateStm :: bodyStms))                    val foreachStm = S.S_Foreach(cvtVar x, xs, mkBlock(updateStm :: bodyStms))
396                    in                    in
397                      (foreachStm :: initStm :: stms, S.E_Var acc)                      (foreachStm :: initStm :: stms, S.E_Var acc)
398                    end                    end
399                | AST.E_ParallelMap(e, x, xs, ty) => raise Fail "FIXME: ParallelMap"                | AST.E_ParallelMap(e, x, xs, ty) => raise Fail "FIXME: ParallelMap"
400                | AST.E_Tensor(es, ty) => let                | AST.E_Tensor(es, ty) => let
401                    val (stms, xs) = simplifyExpsToVars (errStrm, es, stms)                    val (stms, xs) = simplifyExpsToVars (cxt, es, stms)
402                    in                    in
403                      (stms, S.E_Tensor(xs, cvtTy ty))                      (stms, S.E_Tensor(xs, cvtTy ty))
404                    end                    end
405                | AST.E_Seq(es, ty) => let                | AST.E_Seq(es, ty) => let
406                    val (stms, xs) = simplifyExpsToVars (errStrm, es, stms)                    val (stms, xs) = simplifyExpsToVars (cxt, es, stms)
407                    in                    in
408                      (stms, S.E_Seq(xs, cvtTy ty))                      (stms, S.E_Seq(xs, cvtTy ty))
409                    end                    end
410                | AST.E_Slice(e, indices, ty) => let (* tensor slicing *)                | AST.E_Slice(e, indices, ty) => let (* tensor slicing *)
411                    val (stms, x) = simplifyExpToVar (errStrm, e, stms)                    val (stms, x) = simplifyExpToVar (cxt, e, stms)
412                    fun f NONE = NONE                    fun f NONE = NONE
413                      | f (SOME(AST.E_Lit(Literal.Int i))) = SOME(Int.fromLarge i)                      | f (SOME(AST.E_Lit(Literal.Int i))) = SOME(Int.fromLarge i)
414                      | f _ = raise Fail "expected integer literal in slice"                      | f _ = raise Fail "expected integer literal in slice"
# Line 329  Line 419 
419                | AST.E_Cond(e1, e2, e3, ty) => let                | AST.E_Cond(e1, e2, e3, ty) => let
420                  (* a conditional expression gets turned into an if-then-else statememt *)                  (* a conditional expression gets turned into an if-then-else statememt *)
421                    val result = newTemp(cvtTy ty)                    val result = newTemp(cvtTy ty)
422                    val (stms, x) = simplifyExpToVar (errStrm, e1, S.S_Var(result, NONE) :: stms)                    val (stms, x) = simplifyExpToVar (cxt, e1, S.S_Var(result, NONE) :: stms)
423                    fun simplifyBranch e = let                    fun simplifyBranch e = let
424                          val (stms, e) = simplifyExp (errStrm, e, [])                          val (stms, e) = simplifyExp (cxt, e, [])
425                          in                          in
426                            mkBlock (S.S_Assign(result, e)::stms)                            mkBlock (S.S_Assign(result, e)::stms)
427                          end                          end
# Line 341  Line 431 
431                      (S.S_IfThenElse(x, s1, s2) :: stms, S.E_Var result)                      (S.S_IfThenElse(x, s1, s2) :: stms, S.E_Var result)
432                    end                    end
433                | AST.E_Orelse(e1, e2) => simplifyExp (                | AST.E_Orelse(e1, e2) => simplifyExp (
434                    errStrm,                    cxt,
435                    AST.E_Cond(e1, AST.E_Lit(Literal.Bool true), e2, Ty.T_Bool),                    AST.E_Cond(e1, AST.E_Lit(Literal.Bool true), e2, Ty.T_Bool),
436                    stms)                    stms)
437                | AST.E_Andalso(e1, e2) => simplifyExp (                | AST.E_Andalso(e1, e2) => simplifyExp (
438                    errStrm,                    cxt,
439                    AST.E_Cond(e1, e2, AST.E_Lit(Literal.Bool false), Ty.T_Bool),                    AST.E_Cond(e1, e2, AST.E_Lit(Literal.Bool false), Ty.T_Bool),
440                    stms)                    stms)
441                | AST.E_LoadNrrd(_, nrrd, ty) => (case cvtTy ty                | AST.E_LoadNrrd(_, nrrd, ty) => (case cvtTy ty
# Line 354  Line 444 
444                          val dim = II.dim info                          val dim = II.dim info
445                          val shape = II.voxelShape info                          val shape = II.voxelShape info
446                          in                          in
447                            case NrrdInfo.getInfo (errStrm, nrrd)                            case getNrrdInfo (cxt, nrrd)
448                             of SOME nrrdInfo => (case II.fromNrrd(nrrdInfo, dim, shape)                             of SOME nrrdInfo => (case II.fromNrrd(nrrdInfo, dim, shape)
449                                   of NONE => (                                   of NONE => (
450                                        Error.error (errStrm, [                                        badImageNrrd (cxt, nrrd, nrrdInfo, dim, shape);
                                           "nrrd file \"", nrrd, "\" does not have expected type"  
                                         ]);  
451                                        (stms, S.E_LoadImage(ty, nrrd, II.mkInfo(dim, shape))))                                        (stms, S.E_LoadImage(ty, nrrd, II.mkInfo(dim, shape))))
452                                    | SOME imgInfo =>                                    | SOME imgInfo =>
453                                        (stms, S.E_LoadImage(STy.T_Image imgInfo, nrrd, imgInfo))                                        (stms, S.E_LoadImage(STy.T_Image imgInfo, nrrd, imgInfo))
454                                  (* end case *))                                  (* end case *))
455                              | NONE => (                              | NONE => (
456                                  Error.warning (errStrm, [                                  error (cxt, [
457                                      "nrrd file \"", nrrd, "\" does not exist"                                      "proxy-image file \"", nrrd, "\" does not exist"
458                                    ]);                                    ]);
459                                  (stms, S.E_LoadImage(ty, nrrd, II.mkInfo(dim, shape))))                                  (stms, S.E_LoadImage(ty, nrrd, II.mkInfo(dim, shape))))
460                            (* end case *)                            (* end case *)
# Line 378  Line 466 
466                      | _ => raise Fail "impossible: bad coercion"                      | _ => raise Fail "impossible: bad coercion"
467                    (* end case *))                    (* end case *))
468                | AST.E_Coerce{srcTy, dstTy, e} => let                | AST.E_Coerce{srcTy, dstTy, e} => let
469                    val (stms, x) = simplifyExpToVar (errStrm, e, stms)                    val (stms, x) = simplifyExpToVar (cxt, e, stms)
470                    val dstTy = cvtTy dstTy                    val dstTy = cvtTy dstTy
471                    val result = newTemp dstTy                    val result = newTemp dstTy
472                    val rhs = S.E_Coerce{srcTy = cvtTy srcTy, dstTy = dstTy, x = x}                    val rhs = S.E_Coerce{srcTy = cvtTy srcTy, dstTy = dstTy, x = x}
# Line 388  Line 476 
476              (* end case *)              (* end case *)
477            end            end
478    
479      and simplifyExpToVar (errStrm, exp, stms) = let      and simplifyExpToVar (cxt, exp, stms) = let
480            val (stms, e) = simplifyExp (errStrm, exp, stms)            val (stms, e) = simplifyExp (cxt, exp, stms)
481            in            in
482              case e              case e
483               of S.E_Var x => (stms, x)               of S.E_Var x => (stms, x)
# Line 401  Line 489 
489              (* end case *)              (* end case *)
490            end            end
491    
492      and simplifyExpsToVars (errStrm, exps, stms) = let      and simplifyExpsToVars (cxt, exps, stms) = let
493            fun f ([], xs, stms) = (stms, List.rev xs)            fun f ([], xs, stms) = (stms, List.rev xs)
494              | f (e::es, xs, stms) = let              | f (e::es, xs, stms) = let
495                  val (stms, x) = simplifyExpToVar (errStrm, e, stms)                  val (stms, x) = simplifyExpToVar (cxt, e, stms)
496                  in                  in
497                    f (es, x::xs, stms)                    f (es, x::xs, stms)
498                  end                  end
# Line 412  Line 500 
500              f (exps, [], stms)              f (exps, [], stms)
501            end            end
502    
503      (* simplify a parallel map-reduce *)
504        and simplifyReduction (cxt, rator, e, x, xs, resTy, stms) = let
505                val result = SimpleVar.new ("res", Var.LocalVar, cvtTy resTy)
506                val x' = cvtVar x
507                val (bodyStms, bodyResult) = simplifyExpToVar (cxt, e, [])
508              (* convert the domain from a variable to a StrandSets.t value *)
509                val domain = if Var.same(BV.set_active, xs) then StrandSets.ACTIVE
510                      else if Var.same(BV.set_all, xs) then StrandSets.ALL
511                      else if Var.same(BV.set_stable, xs) then StrandSets.STABLE
512                        else raise Fail "impossible: not a strand set"
513                val (func, args) = Util.makeFunction(
514                      Var.nameOf rator, mkBlock(S.S_Return bodyResult :: bodyStms),
515                      SimpleVar.typeOf bodyResult)
516    
517                in
518                  case Util.identifyReduction rator
519                   of Util.MEAN => let
520                        val mapReduceStm = S.S_MapReduce[
521                                S.MapReduce{
522                                    result = result, reduction = Reductions.SUM, mapf = func,
523                                    args = args, source = x', domain = domain
524                                  }]
525                        val num = SimpleVar.new ("num", Var.LocalVar, STy.T_Int)
526                        val rNum = SimpleVar.new ("rNum", Var.LocalVar, STy.realTy)
527                        val mean = SimpleVar.new ("mean", Var.LocalVar, STy.realTy)
528                        val numStrandsOp = (case domain
529                               of StrandSets.ACTIVE => BV.numActive
530                                | StrandSets.ALL => BV.numStrands
531                                | StrandSets.STABLE => BV.numStable
532                              (* end case *))
533                        val stms =
534                              mkRDiv (mean, result, rNum) ::
535                              mkToReal (rNum, num) ::
536                              mkDef (num, S.E_Prim(numStrandsOp, [], [], STy.T_Int)) ::
537                              mapReduceStm ::
538                              stms
539                        in
540                          (mean, stms)
541                        end
542                    | Util.VARIANCE => raise Fail "FIXME: variance reduction"
543                    | Util.RED rator' => let
544                        val mapReduceStm = S.S_MapReduce[
545                                S.MapReduce{
546                                    result = result, reduction = rator', mapf = func, args = args,
547                                    source = x', domain = domain
548                                  }]
549                        in
550                          (result, mapReduceStm :: stms)
551                        end
552                  (* end case *)
553                end
554    
555    (* simplify a block and then prune unreachable and dead code *)    (* simplify a block and then prune unreachable and dead code *)
556      fun simplifyAndPruneBlock errStrm blk =      fun simplifyAndPruneBlock cxt blk =
557            DeadCode.eliminate (simplifyBlock (errStrm, blk))            DeadCode.eliminate (simplifyBlock (cxt, blk))
558    
559      fun simplifyStrand (errStrm, strand) = let      fun simplifyStrand (cxt, strand) = let
560            val AST.Strand{name, params, state, stateInit, initM, updateM, stabilizeM} = strand            val AST.Strand{
561                      name, params, spatialDim, state, stateInit, initM, updateM, stabilizeM
562                    } = strand
563            val params' = cvtVars params            val params' = cvtVars params
564            fun simplifyState ([], xs, stms) = (List.rev xs, mkBlock stms)            fun simplifyState ([], xs, stms) = (List.rev xs, mkBlock stms)
565              | simplifyState ((x, optE) :: r, xs, stms) = let              | simplifyState ((x, optE) :: r, xs, stms) = let
# Line 426  Line 568 
568                    case optE                    case optE
569                     of NONE => simplifyState (r, x'::xs, stms)                     of NONE => simplifyState (r, x'::xs, stms)
570                      | SOME e => let                      | SOME e => let
571                          val (stms, e') = simplifyExp (errStrm, e, stms)                          val (stms, e') = simplifyExp (cxt, e, stms)
572                          in                          in
573                            simplifyState (r, x'::xs, S.S_Assign(x', e') :: stms)                            simplifyState (r, x'::xs, S.S_Assign(x', e') :: stms)
574                          end                          end
# Line 437  Line 579 
579              S.Strand{              S.Strand{
580                  name = name,                  name = name,
581                  params = params',                  params = params',
582                    spatialDim = spatialDim,
583                  state = xs,                  state = xs,
584                  stateInit = stm,                  stateInit = stm,
585                  initM = Option.map (simplifyAndPruneBlock errStrm) initM,                  initM = Option.map (simplifyAndPruneBlock cxt) initM,
586                  updateM = simplifyAndPruneBlock errStrm updateM,                  updateM = simplifyAndPruneBlock cxt updateM,
587                  stabilizeM = Option.map (simplifyAndPruneBlock errStrm) stabilizeM                  stabilizeM = Option.map (simplifyAndPruneBlock cxt) stabilizeM
588                }                }
589            end            end
590    
591      fun transform (errStrm, prog) = let      fun transform (errStrm, prog, gEnv) = let
592            val AST.Program{            val AST.Program{
593                    props, const_dcls, input_dcls, globals, globInit, strand, create, init, update                    props, const_dcls, input_dcls, globals, globInit, strand, create, init, update
594                  } = prog                  } = prog
# Line 455  Line 598 
598            val globals' = ref[]            val globals' = ref[]
599            val globalInit = ref[]            val globalInit = ref[]
600            val funcs = ref[]            val funcs = ref[]
601            fun simplifyConstDcl (x, SOME e) = let          (* simplify the constant dcls: the small constants will be added to the context
602                  val (stms, e') = simplifyExp (errStrm, e, [])           * while the large constants will be added to the const' list.
603             *)
604              val cxt = let
605                    val cxt = Cxt{errStrm = errStrm, gEnv = gEnv, cEnv = VMap.empty}
606                    fun simplifyConstDcl ((x, SOME e), cxt) = if Util.isSmallExp e
607                          then insertConst (cxt, x, e)
608                          else let
609                            val (stms, e') = simplifyExp (cxt, e, [])
610                  val x' = cvtVar x                  val x' = cvtVar x
611                  in                  in
612                    consts' := x' :: !consts';                    consts' := x' :: !consts';
613                    constInit := S.S_Assign(x', e') :: (stms @ !constInit)                            constInit := S.S_Assign(x', e') :: (stms @ !constInit);
614                              cxt
615                            end
616                      | simplifyConstDcl _ = raise Fail "impossble"
617                    in
618                      List.foldl simplifyConstDcl cxt const_dcls
619                  end                  end
620            fun simplifyInputDcl ((x, NONE), desc) = let            fun simplifyInputDcl ((x, NONE), desc) = let
621                  val x' = cvtVar x                  val x' = cvtVar x
622                  val init = (case SimpleVar.typeOf x'                  val init = (case SimpleVar.typeOf x'
623                         of STy.T_Image info => S.Image info                         of STy.T_Image info => (
624                                warning(cxt, [
625                                    "assuming a sample type of ", RawTypes.toString(II.sampleTy info),
626                                    " for '", SimpleVar.nameOf x',
627                                    "'; specify a proxy-image file to override the default sample type"
628                                  ]);
629                                S.Image info)
630                          | _ => S.NoDefault                          | _ => S.NoDefault
631                        (* end case *))                        (* end case *))
632                  val inp = S.INP{                  val inp = S.INP{
# Line 485  Line 646 
646                              val dim = TU.monoDim dim                              val dim = TU.monoDim dim
647                              val shape = TU.monoShape shape                              val shape = TU.monoShape shape
648                              in                              in
649                                case NrrdInfo.getInfo (errStrm, nrrd)                                case getNrrdInfo (cxt, nrrd)
650                                 of SOME nrrdInfo => (case II.fromNrrd(nrrdInfo, dim, shape)                                 of SOME nrrdInfo => (case II.fromNrrd(nrrdInfo, dim, shape)
651                                       of NONE => (                                       of NONE => (
652                                            Error.error (errStrm, [                                            badImageNrrd (cxt, nrrd, nrrdInfo, dim, shape);
                                               "proxy nrrd file \"", nrrd,  
                                               "\" does not have expected type"  
                                             ]);  
653                                            (cvtVar x, S.Image(II.mkInfo(dim, shape))))                                            (cvtVar x, S.Image(II.mkInfo(dim, shape))))
654                                        | SOME info =>                                        | SOME info =>
655                                            (newVarWithType(x, STy.T_Image info), S.Proxy(nrrd, info))                                            (newVarWithType(x, STy.T_Image info), S.Proxy(nrrd, info))
656                                      (* end case *))                                      (* end case *))
657                                  | NONE => (                                  | NONE => (
658                                      Error.warning (errStrm, [                                      error (cxt, [
659                                          "proxy nrrd file \"", nrrd, "\" does not exist"                                          "proxy-image file \"", nrrd, "\" does not exist"
660                                        ]);                                        ]);
661                                      (cvtVar x, S.Image(II.mkInfo(dim, shape))))                                      (cvtVar x, S.Image(II.mkInfo(dim, shape))))
662                                (* end case *)                                (* end case *)
# Line 517  Line 675 
675                  end                  end
676              | simplifyInputDcl ((x, SOME e), desc) = let              | simplifyInputDcl ((x, SOME e), desc) = let
677                  val x' = cvtVar x                  val x' = cvtVar x
678                  val (stms, e') = simplifyExp (errStrm, e, [])                  val (stms, e') = simplifyExp (cxt, e, [])
679                  val inp = S.INP{                  val inp = S.INP{
680                          var = x',                          var = x',
681                          name = Var.nameOf x,                          name = Var.nameOf x,
# Line 531  Line 689 
689                  end                  end
690            fun simplifyGlobalDcl (AST.D_Var(x, NONE)) = globals' := cvtVar x :: !globals'            fun simplifyGlobalDcl (AST.D_Var(x, NONE)) = globals' := cvtVar x :: !globals'
691              | simplifyGlobalDcl (AST.D_Var(x, SOME e)) = let              | simplifyGlobalDcl (AST.D_Var(x, SOME e)) = let
692                  val (stms, e') = simplifyExp (errStrm, e, [])                  val (stms, e') = simplifyExp (cxt, e, [])
693                  val x' = cvtLHS (x, e')                  val x' = cvtLHS (x, e')
694                  in                  in
695                    globals' := x' :: !globals';                    globals' := x' :: !globals';
# Line 540  Line 698 
698              | simplifyGlobalDcl (AST.D_Func(f, params, body)) = let              | simplifyGlobalDcl (AST.D_Func(f, params, body)) = let
699                  val f' = cvtFunc f                  val f' = cvtFunc f
700                  val params' = cvtVars params                  val params' = cvtVars params
701                  val body' = simplifyAndPruneBlock errStrm body                  val body' = simplifyAndPruneBlock cxt body
702                  in                  in
703                    funcs := S.Func{f=f', params=params', body=body'} :: !funcs                    funcs := S.Func{f=f', params=params', body=body'} :: !funcs
704                  end                  end
705            val () = (            val () = (
                 List.app simplifyConstDcl const_dcls;  
706                  List.app simplifyInputDcl input_dcls;                  List.app simplifyInputDcl input_dcls;
707                  List.app simplifyGlobalDcl globals)                  List.app simplifyGlobalDcl globals)
708          (* make the global-initialization block *)          (* make the global-initialization block *)
709            val globInit = (case globInit            val globInit = (case globInit
710                   of SOME stm => mkBlock (simplifyStmt (errStrm, stm, !globalInit))                   of SOME stm => mkBlock (simplifyStmt (cxt, stm, !globalInit))
711                    | NONE => mkBlock (!globalInit)                    | NONE => mkBlock (!globalInit)
712                  (* end case *))                  (* end case *))
713          (* if the globInit block is non-empty, record the fact in the property list *)          (* if the globInit block is non-empty, record the fact in the property list *)
# Line 567  Line 724 
724                  globals = List.rev(!globals'),                  globals = List.rev(!globals'),
725                  globInit = globInit,                  globInit = globInit,
726                  funcs = List.rev(!funcs),                  funcs = List.rev(!funcs),
727                  strand = simplifyStrand (errStrm, strand),                  strand = simplifyStrand (cxt, strand),
728                  create = Create.map (simplifyAndPruneBlock errStrm) create,                  create = Create.map (simplifyAndPruneBlock cxt) create,
729                  init = Option.map (simplifyAndPruneBlock errStrm) init,                  init = Option.map (simplifyAndPruneBlock cxt) init,
730                  update = Option.map (simplifyAndPruneBlock errStrm) update                  update = Option.map (simplifyAndPruneBlock cxt) update
731                }                }
732            end            end
733    

Legend:
Removed from v.4349  
changed lines
  Added in v.4441

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0