Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] Annotation of /branches/vis15/src/compiler/simplify/util.sml
ViewVC logotype

Annotation of /branches/vis15/src/compiler/simplify/util.sml

Parent Directory Parent Directory | Revision Log Revision Log


Revision 5564 - (view) (download)

1 : jhr 3464 (* util.sml
2 :     *
3 :     * Utility code for Simplification.
4 :     *
5 :     * This code is part of the Diderot Project (http://diderot-language.cs.uchicago.edu)
6 :     *
7 :     * COPYRIGHT (c) 2015 The University of Chicago
8 :     * All rights reserved.
9 :     *)
10 :    
11 :     structure Util : sig
12 :    
13 : jhr 4393 (* the standard reductions extended with pseudo reductions *)
14 :     datatype reduction = MEAN | VARIANCE | RED of Reductions.t
15 :    
16 :     (* identify a basis variable that specifies a reduction *)
17 :     val identifyReduction : Var.t -> reduction
18 :    
19 :     (* return information about how to compute a reduction operator *)
20 :     val reductionInfo : Reductions.t -> {
21 : jhr 4317 rator : Var.t, (* primitive operator *)
22 :     init : Literal.t, (* identity element to use for initialization *)
23 :     mvs : SimpleTypes.meta_arg list (* meta-variable arguments for primitive application *)
24 :     }
25 : jhr 3464
26 : jhr 4591 (* convert a block, which is the map part of a map-reduce, into a function by closing over
27 :     * its free variables. The simple variable argument is the strand object, which should always
28 :     * be the first parameter (even if it is not referenced in the block).
29 :     *)
30 :     val makeFunction : string * Simple.var * Simple.block * SimpleTypes.ty
31 : jhr 4628 -> Simple.func_def * Simple.var list
32 : jhr 3465
33 : jhr 4441 (* return true if an AST constant expression is "small" *)
34 :     val isSmallExp : AST.expr -> bool
35 :    
36 : jhr 3464 end = struct
37 :    
38 : jhr 3465 structure S = Simple
39 : jhr 3464 structure BV = BasisVars
40 :     structure L = Literal
41 :     structure R = RealLit
42 : jhr 3465 structure VMap = SimpleVar.Map
43 : jhr 3464
44 : jhr 4393 datatype reduction = MEAN | VARIANCE | RED of Reductions.t
45 :    
46 :     fun identifyReduction rator =
47 : jhr 4317 if Var.same(BV.red_all, rator)
48 : jhr 4393 then RED(Reductions.ALL)
49 : jhr 4317 else if Var.same(BV.red_exists, rator)
50 : jhr 4393 then RED(Reductions.EXISTS)
51 : jhr 4588 else if Var.same(BV.red_max_i, rator)
52 :     then RED(Reductions.IMAX)
53 :     else if Var.same(BV.red_max_r, rator)
54 :     then RED(Reductions.RMAX)
55 : jhr 4317 else if Var.same(BV.red_mean, rator)
56 : jhr 4393 then MEAN
57 : jhr 4588 else if Var.same(BV.red_min_i, rator)
58 :     then RED(Reductions.IMIN)
59 :     else if Var.same(BV.red_min_r, rator)
60 :     then RED(Reductions.RMIN)
61 :     else if Var.same(BV.red_product_i, rator)
62 :     then RED(Reductions.IPRODUCT)
63 :     else if Var.same(BV.red_product_r, rator)
64 :     then RED(Reductions.RPRODUCT)
65 :     else if Var.same(BV.red_sum_i, rator)
66 :     then RED(Reductions.ISUM)
67 :     else if Var.same(BV.red_sum_r, rator)
68 :     then RED(Reductions.RSUM)
69 :     (* FIXME: variance not supported yet
70 : jhr 4317 else if Var.same(BV.red_variance, rator)
71 : jhr 4393 then VARIANCE
72 : jhr 4588 *)
73 : jhr 4317 else raise Fail(Var.uniqueNameOf rator ^ " is not a reduction operator")
74 : jhr 3464
75 : jhr 4588 val mvsReal = [SimpleTypes.SHAPE[]]
76 : jhr 4393
77 : jhr 4588 fun reductionInfo redOp = let
78 : jhr 4589 val id = Reductions.identity redOp
79 :     in
80 :     case redOp
81 :     of Reductions.ALL => {rator = BV.op_and, init = id, mvs = []}
82 :     | Reductions.EXISTS => {rator = BV.op_or, init = id, mvs = []}
83 :     | Reductions.IMAX => {rator = BV.fn_max_i, init = id, mvs = []}
84 :     | Reductions.RMAX => {rator = BV.fn_max_r, init = id, mvs = []}
85 :     | Reductions.IMIN => {rator = BV.fn_min_i, init = id, mvs = []}
86 :     | Reductions.RMIN => {rator = BV.fn_min_r, init = id, mvs = []}
87 :     | Reductions.IPRODUCT => {rator = BV.mul_ii, init = id, mvs = []}
88 : jhr 5219 | Reductions.RPRODUCT => {rator = BV.mul_rr, init = id, mvs = mvsReal}
89 :     | Reductions.ISUM => {rator = BV.add_ii, init = id, mvs = []}
90 : jhr 4589 | Reductions.RSUM => {rator = BV.add_tt, init = id, mvs = mvsReal}
91 :     (* end case *)
92 :     end
93 : jhr 4588
94 : jhr 3465 local
95 :     val n = ref 0
96 :     fun mkFuncId (name, ty) = let val id = !n
97 : jhr 4317 in
98 :     n := id + 1;
99 :     SimpleVar.new(name ^ Int.toString id, SimpleVar.FunVar, ty)
100 :     end
101 : jhr 3465 in
102 : jhr 4591 fun makeFunction (name, strand, blk, resTy) = let
103 : jhr 4317 val freeVars = ref []
104 :     fun cvtVar (env, x) = (case VMap.find(env, x)
105 :     of SOME x' => (env, x')
106 :     | NONE => let
107 :     val x' = SimpleVar.copy(x, SimpleVar.FunParam)
108 :     in
109 :     freeVars := (x, x') :: !freeVars;
110 :     (VMap.insert(env, x, x'), x')
111 :     end
112 :     (* end case *))
113 :     fun cvtVars (env, xs) = let
114 :     fun cvt (x, (env, xs')) = let
115 :     val (env, x') = cvtVar (env, x)
116 :     in
117 :     (env, x'::xs')
118 :     end
119 :     in
120 :     List.foldr cvt (env, []) xs
121 :     end
122 :     fun newVar (env, x) = let
123 :     val x' = SimpleVar.copy(x, SimpleVar.LocalVar)
124 :     in
125 :     (VMap.insert(env, x, x'), x')
126 :     end
127 :     fun cvtBlock (env, S.Block{props, code}) = let
128 :     fun cvtStms (env, [], stms') = (env, S.Block{props = props, code = List.rev stms'})
129 :     | cvtStms (env, stm::stms, stms') = let
130 :     val (env, stm') = cvtStm (env, stm)
131 :     in
132 :     cvtStms (env, stms, stm'::stms')
133 :     end
134 :     in
135 :     cvtStms (env, code, [])
136 :     end
137 :     and cvtStm (env, stm) = (case stm
138 :     of S.S_Var(x, NONE) => let
139 :     val (env, x') = newVar (env, x)
140 :     in
141 :     (env, S.S_Var(x', NONE))
142 :     end
143 :     | S.S_Var(x, SOME e) => let
144 :     val (env, e') = cvtExp (env, e)
145 :     val (env, x') = newVar (env, x)
146 :     in
147 :     (env, S.S_Var(x', SOME e'))
148 :     end
149 :     | S.S_Assign(x, e) => let
150 :     val (env, e') = cvtExp (env, e)
151 :     val (env, x') = cvtVar (env, x)
152 :     in
153 :     (env, S.S_Assign(x', e'))
154 :     end
155 :     | S.S_IfThenElse(x, b1, b2) => let
156 :     val (env, x') = cvtVar (env, x)
157 :     val (env, b1') = cvtBlock (env, b1)
158 :     val (env, b2') = cvtBlock (env, b2)
159 :     in
160 :     (env, S.S_IfThenElse(x', b1', b2'))
161 :     end
162 :     | S.S_Foreach(x, xs, b) => let
163 :     val (env, x') = cvtVar (env, x)
164 :     val (env, xs') = cvtVar (env, xs)
165 :     val (env, b') = cvtBlock (env, b)
166 :     in
167 :     (env, S.S_Foreach(x', xs', b'))
168 :     end
169 :     | S.S_New(name, args) => let
170 :     val (env, args') = cvtVars (env, args)
171 :     in
172 :     (env, S.S_New(name, args'))
173 :     end
174 : jhr 4628 | S.S_KillAll => (env, stm)
175 : jhr 4480 | S.S_StabilizeAll => (env, stm)
176 : jhr 4317 | S.S_Continue => (env, stm)
177 :     | S.S_Die => (env, stm)
178 :     | S.S_Stabilize => (env, stm)
179 :     | S.S_Return x => let
180 :     val (env, x') = cvtVar (env, x)
181 :     in
182 :     (env, S.S_Return x')
183 :     end
184 :     | S.S_Print xs => let
185 :     val (env, xs') = cvtVars (env, xs)
186 :     in
187 :     (env, S.S_Print xs')
188 :     end
189 :     | S.S_MapReduce _ => raise Fail "unexpected nested MapReduce"
190 :     (* end case *))
191 :     and cvtExp (env, exp) = (case exp
192 :     of S.E_Var x => let
193 :     val (env, x') = cvtVar (env, x)
194 :     in
195 :     (env, S.E_Var x')
196 :     end
197 :     | S.E_Lit _ => (env, exp)
198 :     | S.E_Kernel _ => (env, exp)
199 :     | S.E_Select(x, fld) => let
200 :     val (env, x') = cvtVar (env, x)
201 :     in
202 :     (env, S.E_Select(x', fld))
203 :     end
204 :     | S.E_Apply(f, args) => let
205 :     val (env, args') = cvtVars (env, args)
206 :     in
207 :     (env, S.E_Apply(f, args'))
208 :     end
209 :     | S.E_Prim(f, mvs, args, ty) => let
210 :     val (env, args') = cvtVars (env, args)
211 :     in
212 :     (env, S.E_Prim(f, mvs, args', ty))
213 :     end
214 :     | S.E_Tensor(args, ty) => let
215 :     val (env, args') = cvtVars (env, args)
216 :     in
217 :     (env, S.E_Tensor(args', ty))
218 :     end
219 :     | S.E_Seq(args, ty) => let
220 :     val (env, args') = cvtVars (env, args)
221 :     in
222 :     (env, S.E_Seq(args', ty))
223 :     end
224 :     | S.E_Tuple xs => let
225 :     val (env, xs') = cvtVars (env, xs)
226 :     in
227 :     (env, S.E_Tuple xs')
228 :     end
229 :     | S.E_Project(x, i) => let
230 :     val (env, x') = cvtVar (env, x)
231 :     in
232 :     (env, S.E_Project(x', i))
233 :     end
234 :     | S.E_Slice(x, indices, ty) => let
235 :     fun cvt (NONE, (env, idxs)) = (env, NONE::idxs)
236 :     | cvt (SOME x, (env, idxs)) = let
237 :     val (env, x') = cvtVar (env, x)
238 :     in
239 :     (env, SOME x' :: idxs)
240 :     end
241 :     val (env, x') = cvtVar (env, x)
242 :     in
243 :     (env, S.E_Slice(x', indices, ty))
244 :     end
245 :     | S.E_Coerce{srcTy, dstTy, x} => let
246 :     val (env, x') = cvtVar (env, x)
247 :     in
248 :     (env, S.E_Coerce{srcTy=srcTy, dstTy=dstTy, x=x'})
249 :     end
250 :     | S.E_BorderCtl(BorderCtl.Default x, y) => let
251 :     val (env, x') = cvtVar (env, x)
252 :     val (env, y') = cvtVar (env, y)
253 :     in
254 :     (env, S.E_BorderCtl(BorderCtl.Default x', y'))
255 :     end
256 :     | S.E_BorderCtl(ctl, x) => let
257 :     val (env, x') = cvtVar (env, x)
258 :     in
259 :     (env, S.E_BorderCtl(ctl, x'))
260 :     end
261 :     | S.E_LoadSeq _ => (env, exp)
262 :     | S.E_LoadImage _ => (env, exp)
263 :     | S.E_InsideImage(pos, img, s) => let
264 :     val (env, pos') = cvtVar (env, pos)
265 :     val (env, img') = cvtVar (env, img)
266 :     in
267 :     (env, S.E_InsideImage(pos', img', s))
268 :     end
269 : jhr 5564 | S.E_FieldFn _ => (env, exp)
270 : jhr 4317 (* end case *))
271 : jhr 4628 (* the initial environment always includes the strand variable *)
272 :     val (env, _) = cvtVar (VMap.empty, strand)
273 : jhr 4591 val (env, blk) = cvtBlock (env, blk)
274 : jhr 4317 val (args, params) = ListPair.unzip (List.rev (! freeVars))
275 :     val f = SimpleFunc.new (name, resTy, List.map SimpleVar.typeOf params)
276 :     in
277 :     (S.Func{f=f, params=params, body=blk}, args)
278 :     end
279 : jhr 3465 end (* local *)
280 :    
281 : jhr 4441 fun isSmallExp (AST.E_Lit _) = true
282 :     | isSmallExp (AST.E_Tensor(exps, _)) = (List.length exps <= 4)
283 :     | isSmallExp _ = false
284 :    
285 : jhr 3464 end

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0