Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] Annotation of /branches/vis15/src/compiler/simplify/util.sml
ViewVC logotype

Annotation of /branches/vis15/src/compiler/simplify/util.sml

Parent Directory Parent Directory | Revision Log Revision Log


Revision 4588 - (view) (download)

1 : jhr 3464 (* util.sml
2 :     *
3 :     * Utility code for Simplification.
4 :     *
5 :     * This code is part of the Diderot Project (http://diderot-language.cs.uchicago.edu)
6 :     *
7 :     * COPYRIGHT (c) 2015 The University of Chicago
8 :     * All rights reserved.
9 :     *)
10 :    
11 :     structure Util : sig
12 :    
13 : jhr 4393 (* the standard reductions extended with pseudo reductions *)
14 :     datatype reduction = MEAN | VARIANCE | RED of Reductions.t
15 :    
16 :     (* identify a basis variable that specifies a reduction *)
17 :     val identifyReduction : Var.t -> reduction
18 :    
19 :     (* return information about how to compute a reduction operator *)
20 :     val reductionInfo : Reductions.t -> {
21 : jhr 4317 rator : Var.t, (* primitive operator *)
22 :     init : Literal.t, (* identity element to use for initialization *)
23 :     mvs : SimpleTypes.meta_arg list (* meta-variable arguments for primitive application *)
24 :     }
25 : jhr 3464
26 : jhr 3465 (* convert a block into a function by closing over its free variables *)
27 : jhr 4163 val makeFunction : string * Simple.block * SimpleTypes.ty -> Simple.func_def * Simple.var list
28 : jhr 3465
29 : jhr 4441 (* return true if an AST constant expression is "small" *)
30 :     val isSmallExp : AST.expr -> bool
31 :    
32 : jhr 3464 end = struct
33 :    
34 : jhr 3465 structure S = Simple
35 : jhr 3464 structure BV = BasisVars
36 :     structure L = Literal
37 :     structure R = RealLit
38 : jhr 3465 structure VMap = SimpleVar.Map
39 : jhr 3464
40 : jhr 4393 datatype reduction = MEAN | VARIANCE | RED of Reductions.t
41 :    
42 :     fun identifyReduction rator =
43 : jhr 4317 if Var.same(BV.red_all, rator)
44 : jhr 4393 then RED(Reductions.ALL)
45 : jhr 4317 else if Var.same(BV.red_exists, rator)
46 : jhr 4393 then RED(Reductions.EXISTS)
47 : jhr 4588 else if Var.same(BV.red_max_i, rator)
48 :     then RED(Reductions.IMAX)
49 :     else if Var.same(BV.red_max_r, rator)
50 :     then RED(Reductions.RMAX)
51 : jhr 4317 else if Var.same(BV.red_mean, rator)
52 : jhr 4393 then MEAN
53 : jhr 4588 else if Var.same(BV.red_min_i, rator)
54 :     then RED(Reductions.IMIN)
55 :     else if Var.same(BV.red_min_r, rator)
56 :     then RED(Reductions.RMIN)
57 :     else if Var.same(BV.red_product_i, rator)
58 :     then RED(Reductions.IPRODUCT)
59 :     else if Var.same(BV.red_product_r, rator)
60 :     then RED(Reductions.RPRODUCT)
61 :     else if Var.same(BV.red_sum_i, rator)
62 :     then RED(Reductions.ISUM)
63 :     else if Var.same(BV.red_sum_r, rator)
64 :     then RED(Reductions.RSUM)
65 :     (* FIXME: variance not supported yet
66 : jhr 4317 else if Var.same(BV.red_variance, rator)
67 : jhr 4393 then VARIANCE
68 : jhr 4588 *)
69 : jhr 4317 else raise Fail(Var.uniqueNameOf rator ^ " is not a reduction operator")
70 : jhr 3464
71 : jhr 4588 val mvsReal = [SimpleTypes.SHAPE[]]
72 : jhr 4393
73 : jhr 4588 fun reductionInfo redOp = let
74 :     val id = Reductions.identity redOp
75 :     in
76 :     case redOp
77 :     of Reductions.ALL => {rator = BV.op_and, init = id, mvs = []}
78 :     | Reductions.EXISTS => {rator = BV.op_or, init = id, mvs = []}
79 :     | Reductions.IMAX => {rator = BV.fn_max_i, init = id, mvs = []}
80 :     | Reductions.RMAX => {rator = BV.fn_max_r, init = id, mvs = []}
81 :     | Reductions.IMIN => {rator = BV.fn_min_i, init = id, mvs = []}
82 :     | Reductions.RMIN => {rator = BV.fn_min_r, init = id, mvs = []}
83 :     | Reductions.IPRODUCT => {rator = BV.mul_ii, init = id, mvs = []}
84 :     | Reductions.RPRODUCT => {rator = BV.mul_rr, init = id, mvs = []}
85 :     | Reductions.ISUM => {rator = BV.add_ii, init = id, mvs = mvsReal}
86 :     | Reductions.RSUM => {rator = BV.add_tt, init = id, mvs = mvsReal}
87 :     (* end case *)
88 :     end
89 :    
90 : jhr 3465 local
91 :     val n = ref 0
92 :     fun mkFuncId (name, ty) = let val id = !n
93 : jhr 4317 in
94 :     n := id + 1;
95 :     SimpleVar.new(name ^ Int.toString id, SimpleVar.FunVar, ty)
96 :     end
97 : jhr 3465 in
98 :     fun makeFunction (name, blk, resTy) = let
99 : jhr 4317 val freeVars = ref []
100 :     fun cvtVar (env, x) = (case VMap.find(env, x)
101 :     of SOME x' => (env, x')
102 :     | NONE => let
103 :     val x' = SimpleVar.copy(x, SimpleVar.FunParam)
104 :     in
105 :     freeVars := (x, x') :: !freeVars;
106 :     (VMap.insert(env, x, x'), x')
107 :     end
108 :     (* end case *))
109 :     fun cvtVars (env, xs) = let
110 :     fun cvt (x, (env, xs')) = let
111 :     val (env, x') = cvtVar (env, x)
112 :     in
113 :     (env, x'::xs')
114 :     end
115 :     in
116 :     List.foldr cvt (env, []) xs
117 :     end
118 :     fun newVar (env, x) = let
119 :     val x' = SimpleVar.copy(x, SimpleVar.LocalVar)
120 :     in
121 :     (VMap.insert(env, x, x'), x')
122 :     end
123 :     fun cvtBlock (env, S.Block{props, code}) = let
124 :     fun cvtStms (env, [], stms') = (env, S.Block{props = props, code = List.rev stms'})
125 :     | cvtStms (env, stm::stms, stms') = let
126 :     val (env, stm') = cvtStm (env, stm)
127 :     in
128 :     cvtStms (env, stms, stm'::stms')
129 :     end
130 :     in
131 :     cvtStms (env, code, [])
132 :     end
133 :     and cvtStm (env, stm) = (case stm
134 :     of S.S_Var(x, NONE) => let
135 :     val (env, x') = newVar (env, x)
136 :     in
137 :     (env, S.S_Var(x', NONE))
138 :     end
139 :     | S.S_Var(x, SOME e) => let
140 :     val (env, e') = cvtExp (env, e)
141 :     val (env, x') = newVar (env, x)
142 :     in
143 :     (env, S.S_Var(x', SOME e'))
144 :     end
145 :     | S.S_Assign(x, e) => let
146 :     val (env, e') = cvtExp (env, e)
147 :     val (env, x') = cvtVar (env, x)
148 :     in
149 :     (env, S.S_Assign(x', e'))
150 :     end
151 :     | S.S_IfThenElse(x, b1, b2) => let
152 :     val (env, x') = cvtVar (env, x)
153 :     val (env, b1') = cvtBlock (env, b1)
154 :     val (env, b2') = cvtBlock (env, b2)
155 :     in
156 :     (env, S.S_IfThenElse(x', b1', b2'))
157 :     end
158 :     | S.S_Foreach(x, xs, b) => let
159 :     val (env, x') = cvtVar (env, x)
160 :     val (env, xs') = cvtVar (env, xs)
161 :     val (env, b') = cvtBlock (env, b)
162 :     in
163 :     (env, S.S_Foreach(x', xs', b'))
164 :     end
165 :     | S.S_New(name, args) => let
166 :     val (env, args') = cvtVars (env, args)
167 :     in
168 :     (env, S.S_New(name, args'))
169 :     end
170 : jhr 4480 | S.S_StabilizeAll => (env, stm)
171 : jhr 4317 | S.S_Continue => (env, stm)
172 :     | S.S_Die => (env, stm)
173 :     | S.S_Stabilize => (env, stm)
174 :     | S.S_Return x => let
175 :     val (env, x') = cvtVar (env, x)
176 :     in
177 :     (env, S.S_Return x')
178 :     end
179 :     | S.S_Print xs => let
180 :     val (env, xs') = cvtVars (env, xs)
181 :     in
182 :     (env, S.S_Print xs')
183 :     end
184 :     | S.S_MapReduce _ => raise Fail "unexpected nested MapReduce"
185 :     (* end case *))
186 :     and cvtExp (env, exp) = (case exp
187 :     of S.E_Var x => let
188 :     val (env, x') = cvtVar (env, x)
189 :     in
190 :     (env, S.E_Var x')
191 :     end
192 :     | S.E_Lit _ => (env, exp)
193 :     | S.E_Kernel _ => (env, exp)
194 :     | S.E_Select(x, fld) => let
195 :     val (env, x') = cvtVar (env, x)
196 :     in
197 :     (env, S.E_Select(x', fld))
198 :     end
199 :     | S.E_Apply(f, args) => let
200 :     val (env, args') = cvtVars (env, args)
201 :     in
202 :     (env, S.E_Apply(f, args'))
203 :     end
204 :     | S.E_Prim(f, mvs, args, ty) => let
205 :     val (env, args') = cvtVars (env, args)
206 :     in
207 :     (env, S.E_Prim(f, mvs, args', ty))
208 :     end
209 :     | S.E_Tensor(args, ty) => let
210 :     val (env, args') = cvtVars (env, args)
211 :     in
212 :     (env, S.E_Tensor(args', ty))
213 :     end
214 :     | S.E_Seq(args, ty) => let
215 :     val (env, args') = cvtVars (env, args)
216 :     in
217 :     (env, S.E_Seq(args', ty))
218 :     end
219 :     | S.E_Tuple xs => let
220 :     val (env, xs') = cvtVars (env, xs)
221 :     in
222 :     (env, S.E_Tuple xs')
223 :     end
224 :     | S.E_Project(x, i) => let
225 :     val (env, x') = cvtVar (env, x)
226 :     in
227 :     (env, S.E_Project(x', i))
228 :     end
229 :     | S.E_Slice(x, indices, ty) => let
230 :     fun cvt (NONE, (env, idxs)) = (env, NONE::idxs)
231 :     | cvt (SOME x, (env, idxs)) = let
232 :     val (env, x') = cvtVar (env, x)
233 :     in
234 :     (env, SOME x' :: idxs)
235 :     end
236 :     val (env, x') = cvtVar (env, x)
237 :     in
238 :     (env, S.E_Slice(x', indices, ty))
239 :     end
240 :     | S.E_Coerce{srcTy, dstTy, x} => let
241 :     val (env, x') = cvtVar (env, x)
242 :     in
243 :     (env, S.E_Coerce{srcTy=srcTy, dstTy=dstTy, x=x'})
244 :     end
245 :     | S.E_BorderCtl(BorderCtl.Default x, y) => let
246 :     val (env, x') = cvtVar (env, x)
247 :     val (env, y') = cvtVar (env, y)
248 :     in
249 :     (env, S.E_BorderCtl(BorderCtl.Default x', y'))
250 :     end
251 :     | S.E_BorderCtl(ctl, x) => let
252 :     val (env, x') = cvtVar (env, x)
253 :     in
254 :     (env, S.E_BorderCtl(ctl, x'))
255 :     end
256 :     | S.E_LoadSeq _ => (env, exp)
257 :     | S.E_LoadImage _ => (env, exp)
258 :     | S.E_InsideImage(pos, img, s) => let
259 :     val (env, pos') = cvtVar (env, pos)
260 :     val (env, img') = cvtVar (env, img)
261 :     in
262 :     (env, S.E_InsideImage(pos', img', s))
263 :     end
264 :     (* end case *))
265 :     val (env, blk) = cvtBlock (VMap.empty, blk)
266 :     val (args, params) = ListPair.unzip (List.rev (! freeVars))
267 :     val f = SimpleFunc.new (name, resTy, List.map SimpleVar.typeOf params)
268 :     in
269 :     (S.Func{f=f, params=params, body=blk}, args)
270 :     end
271 : jhr 3465 end (* local *)
272 :    
273 : jhr 4441 fun isSmallExp (AST.E_Lit _) = true
274 :     | isSmallExp (AST.E_Tensor(exps, _)) = (List.length exps <= 4)
275 :     | isSmallExp _ = false
276 :    
277 : jhr 3464 end

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0