Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] Annotation of /branches/vis15/src/compiler/simplify/util.sml
ViewVC logotype

Annotation of /branches/vis15/src/compiler/simplify/util.sml

Parent Directory Parent Directory | Revision Log Revision Log


Revision 4441 - (view) (download)

1 : jhr 3464 (* util.sml
2 :     *
3 :     * Utility code for Simplification.
4 :     *
5 :     * This code is part of the Diderot Project (http://diderot-language.cs.uchicago.edu)
6 :     *
7 :     * COPYRIGHT (c) 2015 The University of Chicago
8 :     * All rights reserved.
9 :     *)
10 :    
11 :     structure Util : sig
12 :    
13 : jhr 4393 (* the standard reductions extended with pseudo reductions *)
14 :     datatype reduction = MEAN | VARIANCE | RED of Reductions.t
15 :    
16 :     (* identify a basis variable that specifies a reduction *)
17 :     val identifyReduction : Var.t -> reduction
18 :    
19 :     (* return information about how to compute a reduction operator *)
20 :     val reductionInfo : Reductions.t -> {
21 : jhr 4317 rator : Var.t, (* primitive operator *)
22 :     init : Literal.t, (* identity element to use for initialization *)
23 :     mvs : SimpleTypes.meta_arg list (* meta-variable arguments for primitive application *)
24 :     }
25 : jhr 3464
26 : jhr 3465 (* convert a block into a function by closing over its free variables *)
27 : jhr 4163 val makeFunction : string * Simple.block * SimpleTypes.ty -> Simple.func_def * Simple.var list
28 : jhr 3465
29 : jhr 4441 (* return true if an AST constant expression is "small" *)
30 :     val isSmallExp : AST.expr -> bool
31 :    
32 : jhr 3464 end = struct
33 :    
34 : jhr 3465 structure S = Simple
35 : jhr 3464 structure BV = BasisVars
36 :     structure L = Literal
37 :     structure R = RealLit
38 : jhr 3465 structure VMap = SimpleVar.Map
39 : jhr 3464
40 : jhr 4393 datatype reduction = MEAN | VARIANCE | RED of Reductions.t
41 :    
42 :     fun identifyReduction rator =
43 : jhr 4317 if Var.same(BV.red_all, rator)
44 : jhr 4393 then RED(Reductions.ALL)
45 : jhr 4317 else if Var.same(BV.red_exists, rator)
46 : jhr 4393 then RED(Reductions.EXISTS)
47 : jhr 4317 else if Var.same(BV.red_max, rator)
48 : jhr 4393 then RED(Reductions.MAX)
49 : jhr 4317 else if Var.same(BV.red_mean, rator)
50 : jhr 4393 then MEAN
51 : jhr 4317 else if Var.same(BV.red_min, rator)
52 : jhr 4393 then RED(Reductions.MIN)
53 : jhr 4317 else if Var.same(BV.red_product, rator)
54 : jhr 4393 then RED(Reductions.PRODUCT)
55 : jhr 4317 else if Var.same(BV.red_sum, rator)
56 : jhr 4393 then RED(Reductions.SUM)
57 : jhr 4317 else if Var.same(BV.red_variance, rator)
58 : jhr 4393 then VARIANCE
59 : jhr 4317 else raise Fail(Var.uniqueNameOf rator ^ " is not a reduction operator")
60 : jhr 3464
61 : jhr 4393 fun reductionInfo redOp = (case redOp
62 :     of Reductions.ALL => {rator = BV.op_and, init = L.Bool true, mvs = []}
63 :     | Reductions.EXISTS => {rator = BV.op_or, init = L.Bool false, mvs = []}
64 :     | Reductions.MAX => {rator = BV.fn_max_r, init = L.Real R.negInf, mvs = []}
65 :     | Reductions.MIN => {rator = BV.fn_min_r, init = L.Real R.posInf, mvs = []}
66 :     | Reductions.PRODUCT => {rator = BV.mul_rr, init = L.Real R.one, mvs = []}
67 :     | Reductions.SUM => {rator = BV.add_tt, init = L.Real R.one, mvs = [SimpleTypes.SHAPE[]]}
68 :     (* end case *))
69 :    
70 : jhr 3465 local
71 :     val n = ref 0
72 :     fun mkFuncId (name, ty) = let val id = !n
73 : jhr 4317 in
74 :     n := id + 1;
75 :     SimpleVar.new(name ^ Int.toString id, SimpleVar.FunVar, ty)
76 :     end
77 : jhr 3465 in
78 :     fun makeFunction (name, blk, resTy) = let
79 : jhr 4317 val freeVars = ref []
80 :     fun cvtVar (env, x) = (case VMap.find(env, x)
81 :     of SOME x' => (env, x')
82 :     | NONE => let
83 :     val x' = SimpleVar.copy(x, SimpleVar.FunParam)
84 :     in
85 :     freeVars := (x, x') :: !freeVars;
86 :     (VMap.insert(env, x, x'), x')
87 :     end
88 :     (* end case *))
89 :     fun cvtVars (env, xs) = let
90 :     fun cvt (x, (env, xs')) = let
91 :     val (env, x') = cvtVar (env, x)
92 :     in
93 :     (env, x'::xs')
94 :     end
95 :     in
96 :     List.foldr cvt (env, []) xs
97 :     end
98 :     fun newVar (env, x) = let
99 :     val x' = SimpleVar.copy(x, SimpleVar.LocalVar)
100 :     in
101 :     (VMap.insert(env, x, x'), x')
102 :     end
103 :     fun cvtBlock (env, S.Block{props, code}) = let
104 :     fun cvtStms (env, [], stms') = (env, S.Block{props = props, code = List.rev stms'})
105 :     | cvtStms (env, stm::stms, stms') = let
106 :     val (env, stm') = cvtStm (env, stm)
107 :     in
108 :     cvtStms (env, stms, stm'::stms')
109 :     end
110 :     in
111 :     cvtStms (env, code, [])
112 :     end
113 :     and cvtStm (env, stm) = (case stm
114 :     of S.S_Var(x, NONE) => let
115 :     val (env, x') = newVar (env, x)
116 :     in
117 :     (env, S.S_Var(x', NONE))
118 :     end
119 :     | S.S_Var(x, SOME e) => let
120 :     val (env, e') = cvtExp (env, e)
121 :     val (env, x') = newVar (env, x)
122 :     in
123 :     (env, S.S_Var(x', SOME e'))
124 :     end
125 :     | S.S_Assign(x, e) => let
126 :     val (env, e') = cvtExp (env, e)
127 :     val (env, x') = cvtVar (env, x)
128 :     in
129 :     (env, S.S_Assign(x', e'))
130 :     end
131 :     | S.S_IfThenElse(x, b1, b2) => let
132 :     val (env, x') = cvtVar (env, x)
133 :     val (env, b1') = cvtBlock (env, b1)
134 :     val (env, b2') = cvtBlock (env, b2)
135 :     in
136 :     (env, S.S_IfThenElse(x', b1', b2'))
137 :     end
138 :     | S.S_Foreach(x, xs, b) => let
139 :     val (env, x') = cvtVar (env, x)
140 :     val (env, xs') = cvtVar (env, xs)
141 :     val (env, b') = cvtBlock (env, b)
142 :     in
143 :     (env, S.S_Foreach(x', xs', b'))
144 :     end
145 :     | S.S_New(name, args) => let
146 :     val (env, args') = cvtVars (env, args)
147 :     in
148 :     (env, S.S_New(name, args'))
149 :     end
150 :     | S.S_Continue => (env, stm)
151 :     | S.S_Die => (env, stm)
152 :     | S.S_Stabilize => (env, stm)
153 :     | S.S_Return x => let
154 :     val (env, x') = cvtVar (env, x)
155 :     in
156 :     (env, S.S_Return x')
157 :     end
158 :     | S.S_Print xs => let
159 :     val (env, xs') = cvtVars (env, xs)
160 :     in
161 :     (env, S.S_Print xs')
162 :     end
163 :     | S.S_MapReduce _ => raise Fail "unexpected nested MapReduce"
164 :     (* end case *))
165 :     and cvtExp (env, exp) = (case exp
166 :     of S.E_Var x => let
167 :     val (env, x') = cvtVar (env, x)
168 :     in
169 :     (env, S.E_Var x')
170 :     end
171 :     | S.E_Lit _ => (env, exp)
172 :     | S.E_Kernel _ => (env, exp)
173 :     | S.E_Select(x, fld) => let
174 :     val (env, x') = cvtVar (env, x)
175 :     in
176 :     (env, S.E_Select(x', fld))
177 :     end
178 :     | S.E_Apply(f, args) => let
179 :     val (env, args') = cvtVars (env, args)
180 :     in
181 :     (env, S.E_Apply(f, args'))
182 :     end
183 :     | S.E_Prim(f, mvs, args, ty) => let
184 :     val (env, args') = cvtVars (env, args)
185 :     in
186 :     (env, S.E_Prim(f, mvs, args', ty))
187 :     end
188 :     | S.E_Tensor(args, ty) => let
189 :     val (env, args') = cvtVars (env, args)
190 :     in
191 :     (env, S.E_Tensor(args', ty))
192 :     end
193 :     | S.E_Seq(args, ty) => let
194 :     val (env, args') = cvtVars (env, args)
195 :     in
196 :     (env, S.E_Seq(args', ty))
197 :     end
198 :     | S.E_Tuple xs => let
199 :     val (env, xs') = cvtVars (env, xs)
200 :     in
201 :     (env, S.E_Tuple xs')
202 :     end
203 :     | S.E_Project(x, i) => let
204 :     val (env, x') = cvtVar (env, x)
205 :     in
206 :     (env, S.E_Project(x', i))
207 :     end
208 :     | S.E_Slice(x, indices, ty) => let
209 :     fun cvt (NONE, (env, idxs)) = (env, NONE::idxs)
210 :     | cvt (SOME x, (env, idxs)) = let
211 :     val (env, x') = cvtVar (env, x)
212 :     in
213 :     (env, SOME x' :: idxs)
214 :     end
215 :     val (env, x') = cvtVar (env, x)
216 :     in
217 :     (env, S.E_Slice(x', indices, ty))
218 :     end
219 :     | S.E_Coerce{srcTy, dstTy, x} => let
220 :     val (env, x') = cvtVar (env, x)
221 :     in
222 :     (env, S.E_Coerce{srcTy=srcTy, dstTy=dstTy, x=x'})
223 :     end
224 :     | S.E_BorderCtl(BorderCtl.Default x, y) => let
225 :     val (env, x') = cvtVar (env, x)
226 :     val (env, y') = cvtVar (env, y)
227 :     in
228 :     (env, S.E_BorderCtl(BorderCtl.Default x', y'))
229 :     end
230 :     | S.E_BorderCtl(ctl, x) => let
231 :     val (env, x') = cvtVar (env, x)
232 :     in
233 :     (env, S.E_BorderCtl(ctl, x'))
234 :     end
235 :     | S.E_LoadSeq _ => (env, exp)
236 :     | S.E_LoadImage _ => (env, exp)
237 :     | S.E_InsideImage(pos, img, s) => let
238 :     val (env, pos') = cvtVar (env, pos)
239 :     val (env, img') = cvtVar (env, img)
240 :     in
241 :     (env, S.E_InsideImage(pos', img', s))
242 :     end
243 :     (* end case *))
244 :     val (env, blk) = cvtBlock (VMap.empty, blk)
245 :     val (args, params) = ListPair.unzip (List.rev (! freeVars))
246 :     val f = SimpleFunc.new (name, resTy, List.map SimpleVar.typeOf params)
247 :     in
248 :     (S.Func{f=f, params=params, body=blk}, args)
249 :     end
250 : jhr 3465 end (* local *)
251 :    
252 : jhr 4441 fun isSmallExp (AST.E_Lit _) = true
253 :     | isSmallExp (AST.E_Tensor(exps, _)) = (List.length exps <= 4)
254 :     | isSmallExp _ = false
255 :    
256 : jhr 3464 end

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0