Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] Diff of /branches/vis15/src/compiler/simplify/simplify.sml
ViewVC logotype

Diff of /branches/vis15/src/compiler/simplify/simplify.sml

Parent Directory Parent Directory | Revision Log Revision Log | View Patch Patch

revision 4371, Sat Aug 6 11:48:16 2016 UTC revision 4441, Sun Aug 21 15:55:21 2016 UTC
# Line 34  Line 34 
34      structure II = ImageInfo      structure II = ImageInfo
35      structure BV = BasisVars      structure BV = BasisVars
36    
37      (* environment for mapping small global constants to AST expressions *)
38        type const_env = AST.expr VMap.map
39    
40    (* context for simplification *)    (* context for simplification *)
41      type context = {errStrm : Error.err_stream, gEnv : GlobalEnv.t}      datatype context = Cxt of {
42            errStrm : Error.err_stream,
43            gEnv : GlobalEnv.t,
44            cEnv : const_env
45          }
46    
47        fun getNrrdInfo (Cxt{errStrm, ...}, nrrd) = NrrdInfo.getInfo (errStrm, nrrd)
48    
49        fun findStrand (Cxt{gEnv, ...}, s) = GlobalEnv.findStrand(gEnv, s)
50    
51        fun insertConst (Cxt{errStrm, gEnv, cEnv}, x, e) = Cxt{
52                errStrm = errStrm, gEnv = gEnv, cEnv = VMap.insert(cEnv, x, e)
53              }
54    
55      fun error ({errStrm, gEnv}, msg) = Error.error (errStrm, msg)      fun findConst (Cxt{cEnv, ...}, x) = VMap.find(cEnv, x)
56      fun warning ({errStrm, gEnv}, msg) = Error.warning (errStrm, msg)  
57        fun error (Cxt{errStrm, ...}, msg) = Error.error (errStrm, msg)
58        fun warning (Cxt{errStrm, ...}, msg) = Error.warning (errStrm, msg)
59    
60      (* error message for when a nrrd image file is incompatible with the declared image type *)
61        fun badImageNrrd (cxt, nrrdFile, nrrdInfo, expectedDim, expectedShp) = let
62              val NrrdInfo.NrrdInfo{dim, nElems, ...} = nrrdInfo
63              val expectedNumElems = List.foldl (op * ) 1 expectedShp
64              val prefix = String.concat[
65                      "image file \"", nrrdFile, "\"  is incompatible with expected type image(",
66                      Int.toString expectedDim, ")[",
67                      String.concatWithMap "," Int.toString expectedShp, "]"
68                    ]
69              in
70                case (dim = expectedDim, nElems = expectedNumElems)
71                 of (false, true) => error (cxt, [
72                        prefix, "; its dimension is ", Int.toString dim
73                      ])
74                  | (true, false) => error (cxt, [
75                        prefix, "; it has ", Int.toString nElems, " sample per voxel"
76                      ])
77                  | _ =>  error (cxt, [
78                        prefix, ";  its dimension is ", Int.toString dim, " and it has ",
79                        Int.toString nElems, " sample per voxel"
80                      ])
81                (* end case *)
82              end
83    
84    (* convert a Types.ty to a SimpleTypes.ty *)    (* convert a Types.ty to a SimpleTypes.ty *)
85      fun cvtTy ty = (case ty      fun cvtTy ty = (case ty
# Line 111  Line 152 
152    (* make a block out of a list of statements that are in reverse order *)    (* make a block out of a list of statements that are in reverse order *)
153      fun mkBlock stms = S.Block{props = PropList.newHolder(), code = List.rev stms}      fun mkBlock stms = S.Block{props = PropList.newHolder(), code = List.rev stms}
154    
155      (* make a variable definition *)
156        fun mkDef (x, e) = S.S_Var(x, SOME e)
157    
158        fun mkRDiv (res, a, b) = mkDef (res, S.E_Prim(BV.div_rr, [], [a, b], STy.realTy))
159        fun mkToReal (res, a) =
160              mkDef (res, S.E_Coerce{srcTy = STy.T_Int, dstTy = STy.realTy, x = a})
161        fun mkLength (res, elemTy, xs) =
162              mkDef (res, S.E_Prim(BV.fn_length, [STy.TY elemTy], [xs], STy.T_Int))
163    
164    (* simplify a statement into a single statement (i.e., a block if it expands    (* simplify a statement into a single statement (i.e., a block if it expands
165     * into more than one new statement).     * into more than one new statement).
166     *)     *)
# Line 128  Line 178 
178     * Note that error reporting is done in the typechecker, but it does not prune unreachable     * Note that error reporting is done in the typechecker, but it does not prune unreachable
179     * code.     * code.
180     *)     *)
181      and simplifyStmt (cxt, stm, stms) : S.stmt list = (case stm      and simplifyStmt (cxt : context, stm, stms) : S.stmt list = (case stm
182             of AST.S_Block body => let             of AST.S_Block body => let
183                  fun simplify ([], stms) = stms                  fun simplify ([], stms) = stms
184                    | simplify (stm::r, stms) = simplify (r, simplifyStmt (cxt, stm, stms))                    | simplify (stm::r, stms) = simplify (r, simplifyStmt (cxt, stm, stms))
# Line 220  Line 270 
270                      (* get the strand type for the query *)                      (* get the strand type for the query *)
271                        val tyArgs as [S.TY(STy.T_Strand strand)] = List.map cvtTyArg tyArgs                        val tyArgs as [S.TY(STy.T_Strand strand)] = List.map cvtTyArg tyArgs
272                      (* get the strand environment for the strand *)                      (* get the strand environment for the strand *)
273                        val SOME sEnv = GlobalEnv.findStrand(#gEnv cxt, strand)                        val SOME sEnv = findStrand(cxt, strand)
274                        fun result (query, pos) =                        fun result (query, pos) =
275                              (stms, S.E_Prim(query, tyArgs, cvtVar pos::xs, cvtTy ty))                              (stms, S.E_Prim(query, tyArgs, cvtVar pos::xs, cvtTy ty))
276                        in                        in
# Line 251  Line 301 
301                          in                          in
302                            (stm::stms, S.E_Var x')                            (stm::stms, S.E_Var x')
303                          end                          end
304                        | Var.ConstVar => (case findConst(cxt, x)
305                             of SOME e => let
306                                  val (stms, x') = simplifyExpToVar (cxt, e, stms)
307                                  in
308                                    (stms, S.E_Var x')
309                                  end
310                              | NONE => (stms, S.E_Var(cvtVar x))
311                            (* end case *))
312                      | _ => (stms, S.E_Var(cvtVar x))                      | _ => (stms, S.E_Var(cvtVar x))
313                    (* end case *))                    (* end case *))
314                | AST.E_Lit lit => (stms, S.E_Lit lit)                | AST.E_Lit lit => (stms, S.E_Lit lit)
# Line 271  Line 329 
329  (* QUESTION: is there common code in handling a reduction over a sequence of strands vs. over a strand set? *)  (* QUESTION: is there common code in handling a reduction over a sequence of strands vs. over a strand set? *)
330                      | AST.E_Comprehension(e', (x, e''), seqTy) => if Basis.isReductionOp rator                      | AST.E_Comprehension(e', (x, e''), seqTy) => if Basis.isReductionOp rator
331                          then let                          then let
                           val {rator, init, mvs} = Util.reductionInfo rator  
332                            val (stms, xs) = simplifyExpToVar (cxt, e'', stms)                            val (stms, xs) = simplifyExpToVar (cxt, e'', stms)
333                            val (bodyStms, bodyResult) = simplifyExpToVar (cxt, e, [])                            val (bodyStms, bodyResult) = simplifyExpToVar (cxt, e', [])
                           val acc = SimpleVar.new ("accum", Var.LocalVar, cvtTy ty)  
334                            val seqTy' as STy.T_Sequence(elemTy, NONE) = cvtTy seqTy                            val seqTy' as STy.T_Sequence(elemTy, NONE) = cvtTy seqTy
335                              fun mkReductionLoop (redOp, bodyStms, bodyResult, stms) = let
336                                    val {rator, init, mvs} = Util.reductionInfo redOp
337                                    val acc = SimpleVar.new ("accum", Var.LocalVar, cvtTy ty)
338                            val initStm = S.S_Var(acc, SOME(S.E_Lit init))                            val initStm = S.S_Var(acc, SOME(S.E_Lit init))
339                            val updateStm = S.S_Assign(acc,                            val updateStm = S.S_Assign(acc,
340                                  S.E_Prim(rator, mvs, [acc, bodyResult], seqTy'))                                  S.E_Prim(rator, mvs, [acc, bodyResult], seqTy'))
341                            val foreachStm = S.S_Foreach(cvtVar x, xs, mkBlock(updateStm :: bodyStms))                                  val foreachStm = S.S_Foreach(cvtVar x, xs,
342                                          mkBlock(updateStm :: bodyStms))
343                            in                            in
344                              (foreachStm :: initStm :: stms, S.E_Var acc)                              (foreachStm :: initStm :: stms, S.E_Var acc)
345                            end                            end
346                              in
347                                case Util.identifyReduction rator
348                                 of Util.MEAN => let
349                                      val (stms, S.E_Var resultV) = mkReductionLoop (
350                                            Reductions.SUM, bodyStms, bodyResult, stms)
351                                      val num = SimpleVar.new ("num", Var.LocalVar, STy.T_Int)
352                                      val rNum = SimpleVar.new ("rNum", Var.LocalVar, STy.realTy)
353                                      val mean = SimpleVar.new ("mean", Var.LocalVar, STy.realTy)
354                                      val stms =
355                                            mkRDiv (mean, resultV, rNum) ::
356                                            mkToReal (rNum, num) ::
357                                            mkLength (num, elemTy, xs) ::
358                                            stms
359                                      in
360                                        (stms, S.E_Var mean)
361                                      end
362                                  | Util.VARIANCE => raise Fail "FIXME: VARIANCE"
363                                  | Util.RED red => mkReductionLoop (red, bodyStms, bodyResult, stms)
364                                (* end case *)
365                              end
366                          else doPrimApply (rator, tyArgs, args, ty)                          else doPrimApply (rator, tyArgs, args, ty)
367                      | AST.E_ParallelMap(e', x, xs, _) =>                      | AST.E_ParallelMap(e', x, xs, _) =>
368                          if Basis.isReductionOp rator                          if Basis.isReductionOp rator
369                            then let                            then let
370                              val (result, stm) = simplifyReduction (cxt, rator, e', x, xs, ty)                              val (result, stms) = simplifyReduction (cxt, rator, e', x, xs, ty, stms)
371                              in                              in
372                                (stm :: stms, S.E_Var result)                                (stms, S.E_Var result)
373                              end                              end
374                            else raise Fail "unsupported operation on parallel map"                            else raise Fail "unsupported operation on parallel map"
375                      | _ => doPrimApply (rator, tyArgs, args, ty)                      | _ => doPrimApply (rator, tyArgs, args, ty)
# Line 364  Line 444 
444                          val dim = II.dim info                          val dim = II.dim info
445                          val shape = II.voxelShape info                          val shape = II.voxelShape info
446                          in                          in
447                            case NrrdInfo.getInfo (#errStrm cxt, nrrd)                            case getNrrdInfo (cxt, nrrd)
448                             of SOME nrrdInfo => (case II.fromNrrd(nrrdInfo, dim, shape)                             of SOME nrrdInfo => (case II.fromNrrd(nrrdInfo, dim, shape)
449                                   of NONE => (                                   of NONE => (
450                                        error (cxt, [                                        badImageNrrd (cxt, nrrd, nrrdInfo, dim, shape);
                                           "nrrd file \"", nrrd, "\" does not have expected type"  
                                         ]);  
451                                        (stms, S.E_LoadImage(ty, nrrd, II.mkInfo(dim, shape))))                                        (stms, S.E_LoadImage(ty, nrrd, II.mkInfo(dim, shape))))
452                                    | SOME imgInfo =>                                    | SOME imgInfo =>
453                                        (stms, S.E_LoadImage(STy.T_Image imgInfo, nrrd, imgInfo))                                        (stms, S.E_LoadImage(STy.T_Image imgInfo, nrrd, imgInfo))
454                                  (* end case *))                                  (* end case *))
455                              | NONE => (                              | NONE => (
456                                  warning (cxt, [                                  error (cxt, [
457                                      "nrrd file \"", nrrd, "\" does not exist"                                      "proxy-image file \"", nrrd, "\" does not exist"
458                                    ]);                                    ]);
459                                  (stms, S.E_LoadImage(ty, nrrd, II.mkInfo(dim, shape))))                                  (stms, S.E_LoadImage(ty, nrrd, II.mkInfo(dim, shape))))
460                            (* end case *)                            (* end case *)
# Line 423  Line 501 
501            end            end
502    
503    (* simplify a parallel map-reduce *)    (* simplify a parallel map-reduce *)
504      and simplifyReduction (cxt, rator, e, x, xs, resTy) = let      and simplifyReduction (cxt, rator, e, x, xs, resTy, stms) = let
             val rator' = if Var.same(BV.red_all, rator) then Reductions.ALL  
                   else if Var.same(BV.red_exists, rator) then Reductions.EXISTS  
                   else if Var.same(BV.red_max, rator) then Reductions.MAX  
                   else if Var.same(BV.red_mean, rator) then raise Fail "FIXME: mean reduction"  
                   else if Var.same(BV.red_min, rator) then Reductions.MIN  
                   else if Var.same(BV.red_product, rator) then Reductions.PRODUCT  
                   else if Var.same(BV.red_sum, rator) then Reductions.SUM  
                   else if Var.same(BV.red_variance, rator) then raise Fail "FIXME: variance reduction"  
                     else raise Fail "impossible: not a reduction"  
             val x' = cvtVar x  
505              val result = SimpleVar.new ("res", Var.LocalVar, cvtTy resTy)              val result = SimpleVar.new ("res", Var.LocalVar, cvtTy resTy)
506                val x' = cvtVar x
507              val (bodyStms, bodyResult) = simplifyExpToVar (cxt, e, [])              val (bodyStms, bodyResult) = simplifyExpToVar (cxt, e, [])
508  (* FIXME: need to handle reductions over active/stable subsets of strands *)            (* convert the domain from a variable to a StrandSets.t value *)
509                val domain = if Var.same(BV.set_active, xs) then StrandSets.ACTIVE
510                      else if Var.same(BV.set_all, xs) then StrandSets.ALL
511                      else if Var.same(BV.set_stable, xs) then StrandSets.STABLE
512                        else raise Fail "impossible: not a strand set"
513              val (func, args) = Util.makeFunction(              val (func, args) = Util.makeFunction(
514                    Var.nameOf rator, mkBlock(S.S_Return bodyResult :: bodyStms),                    Var.nameOf rator, mkBlock(S.S_Return bodyResult :: bodyStms),
515                    SimpleVar.typeOf bodyResult)                    SimpleVar.typeOf bodyResult)
516              val mapReduceStm = S.S_MapReduce{  
                     results = [result],  
                     reductions = [rator'],  
                     body = func,  
                     args = args,  
                     source = x'  
                   }  
517              in              in
518                (result, mapReduceStm)                case Util.identifyReduction rator
519                   of Util.MEAN => let
520                        val mapReduceStm = S.S_MapReduce[
521                                S.MapReduce{
522                                    result = result, reduction = Reductions.SUM, mapf = func,
523                                    args = args, source = x', domain = domain
524                                  }]
525                        val num = SimpleVar.new ("num", Var.LocalVar, STy.T_Int)
526                        val rNum = SimpleVar.new ("rNum", Var.LocalVar, STy.realTy)
527                        val mean = SimpleVar.new ("mean", Var.LocalVar, STy.realTy)
528                        val numStrandsOp = (case domain
529                               of StrandSets.ACTIVE => BV.numActive
530                                | StrandSets.ALL => BV.numStrands
531                                | StrandSets.STABLE => BV.numStable
532                              (* end case *))
533                        val stms =
534                              mkRDiv (mean, result, rNum) ::
535                              mkToReal (rNum, num) ::
536                              mkDef (num, S.E_Prim(numStrandsOp, [], [], STy.T_Int)) ::
537                              mapReduceStm ::
538                              stms
539                        in
540                          (mean, stms)
541                        end
542                    | Util.VARIANCE => raise Fail "FIXME: variance reduction"
543                    | Util.RED rator' => let
544                        val mapReduceStm = S.S_MapReduce[
545                                S.MapReduce{
546                                    result = result, reduction = rator', mapf = func, args = args,
547                                    source = x', domain = domain
548                                  }]
549                        in
550                          (result, mapReduceStm :: stms)
551                        end
552                  (* end case *)
553              end              end
554    
555    (* simplify a block and then prune unreachable and dead code *)    (* simplify a block and then prune unreachable and dead code *)
# Line 491  Line 592 
592            val AST.Program{            val AST.Program{
593                    props, const_dcls, input_dcls, globals, globInit, strand, create, init, update                    props, const_dcls, input_dcls, globals, globInit, strand, create, init, update
594                  } = prog                  } = prog
           val cxt = {errStrm = errStrm, gEnv = gEnv}  
595            val consts' = ref[]            val consts' = ref[]
596            val constInit = ref[]            val constInit = ref[]
597            val inputs' = ref[]            val inputs' = ref[]
598            val globals' = ref[]            val globals' = ref[]
599            val globalInit = ref[]            val globalInit = ref[]
600            val funcs = ref[]            val funcs = ref[]
601            fun simplifyConstDcl (x, SOME e) = let          (* simplify the constant dcls: the small constants will be added to the context
602             * while the large constants will be added to the const' list.
603             *)
604              val cxt = let
605                    val cxt = Cxt{errStrm = errStrm, gEnv = gEnv, cEnv = VMap.empty}
606                    fun simplifyConstDcl ((x, SOME e), cxt) = if Util.isSmallExp e
607                          then insertConst (cxt, x, e)
608                          else let
609                  val (stms, e') = simplifyExp (cxt, e, [])                  val (stms, e') = simplifyExp (cxt, e, [])
610                  val x' = cvtVar x                  val x' = cvtVar x
611                  in                  in
612                    consts' := x' :: !consts';                    consts' := x' :: !consts';
613                    constInit := S.S_Assign(x', e') :: (stms @ !constInit)                            constInit := S.S_Assign(x', e') :: (stms @ !constInit);
614                              cxt
615                            end
616                      | simplifyConstDcl _ = raise Fail "impossble"
617                    in
618                      List.foldl simplifyConstDcl cxt const_dcls
619                  end                  end
620            fun simplifyInputDcl ((x, NONE), desc) = let            fun simplifyInputDcl ((x, NONE), desc) = let
621                  val x' = cvtVar x                  val x' = cvtVar x
622                  val init = (case SimpleVar.typeOf x'                  val init = (case SimpleVar.typeOf x'
623                         of STy.T_Image info => S.Image info                         of STy.T_Image info => (
624                                warning(cxt, [
625                                    "assuming a sample type of ", RawTypes.toString(II.sampleTy info),
626                                    " for '", SimpleVar.nameOf x',
627                                    "'; specify a proxy-image file to override the default sample type"
628                                  ]);
629                                S.Image info)
630                          | _ => S.NoDefault                          | _ => S.NoDefault
631                        (* end case *))                        (* end case *))
632                  val inp = S.INP{                  val inp = S.INP{
# Line 528  Line 646 
646                              val dim = TU.monoDim dim                              val dim = TU.monoDim dim
647                              val shape = TU.monoShape shape                              val shape = TU.monoShape shape
648                              in                              in
649                                case NrrdInfo.getInfo (#errStrm cxt, nrrd)                                case getNrrdInfo (cxt, nrrd)
650                                 of SOME nrrdInfo => (case II.fromNrrd(nrrdInfo, dim, shape)                                 of SOME nrrdInfo => (case II.fromNrrd(nrrdInfo, dim, shape)
651                                       of NONE => (                                       of NONE => (
652                                            error (cxt, [                                            badImageNrrd (cxt, nrrd, nrrdInfo, dim, shape);
                                               "proxy nrrd file \"", nrrd,  
                                               "\" does not have expected type"  
                                             ]);  
653                                            (cvtVar x, S.Image(II.mkInfo(dim, shape))))                                            (cvtVar x, S.Image(II.mkInfo(dim, shape))))
654                                        | SOME info =>                                        | SOME info =>
655                                            (newVarWithType(x, STy.T_Image info), S.Proxy(nrrd, info))                                            (newVarWithType(x, STy.T_Image info), S.Proxy(nrrd, info))
656                                      (* end case *))                                      (* end case *))
657                                  | NONE => (                                  | NONE => (
658                                      warning (cxt, [                                      error (cxt, [
659                                          "proxy nrrd file \"", nrrd, "\" does not exist"                                          "proxy-image file \"", nrrd, "\" does not exist"
660                                        ]);                                        ]);
661                                      (cvtVar x, S.Image(II.mkInfo(dim, shape))))                                      (cvtVar x, S.Image(II.mkInfo(dim, shape))))
662                                (* end case *)                                (* end case *)
# Line 588  Line 703 
703                    funcs := S.Func{f=f', params=params', body=body'} :: !funcs                    funcs := S.Func{f=f', params=params', body=body'} :: !funcs
704                  end                  end
705            val () = (            val () = (
                 List.app simplifyConstDcl const_dcls;  
706                  List.app simplifyInputDcl input_dcls;                  List.app simplifyInputDcl input_dcls;
707                  List.app simplifyGlobalDcl globals)                  List.app simplifyGlobalDcl globals)
708          (* make the global-initialization block *)          (* make the global-initialization block *)

Legend:
Removed from v.4371  
changed lines
  Added in v.4441

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0