Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] Annotation of /branches/chiw17/src/compiler/basis/basis.sml
ViewVC logotype

Annotation of /branches/chiw17/src/compiler/basis/basis.sml

Parent Directory Parent Directory | Revision Log Revision Log


Revision 5030 - (view) (download)

1 : jhr 3399 (* basis.sml
2 :     *
3 :     * This code is part of the Diderot Project (http://diderot-language.cs.uchicago.edu)
4 :     *
5 :     * COPYRIGHT (c) 2015 The University of Chicago
6 :     * All rights reserved.
7 :     *
8 :     * Defining the Diderot Basis environment.
9 :     *)
10 :    
11 :     structure Basis : sig
12 :    
13 :     val env : unit -> GlobalEnv.t
14 :    
15 :     (* operations that are allowed in constant expressions *)
16 :     val allowedInConstExp : AST.var -> bool
17 :    
18 : jhr 4349 (* spatial queries *)
19 :     val isSpatialQueryOp : AST.var -> bool
20 :    
21 : jhr 3464 (* reduction operators *)
22 :     val isReductionOp : AST.var -> bool
23 :    
24 : jhr 3463 (* global sets of strands *)
25 :     val isStrandSet : AST.var -> bool
26 :    
27 : jhr 4228 (* border-control operations *)
28 :     val isBorderCtl : AST.var -> bool
29 :    
30 : jhr 3399 end = struct
31 :    
32 :     structure N = BasisNames
33 :     structure BV = BasisVars
34 :     structure ATbl = AtomTable
35 :     structure GEnv = GlobalEnv
36 :    
37 :     (* non-overloaded operators, etc. *)
38 :     val basisFunctions = [
39 :     (* non-overloaded operators *)
40 : jhr 4504 BV.op_mod,
41 : jhr 3399 BV.op_D,
42 : cchiw 5010 (* BV.op_Dotimes,*)
43 : jhr 3399 BV.op_Ddot,
44 :     BV.op_not,
45 :     (* functions *)
46 : jhr 3519 BV.image_border,
47 : jhr 3399 BV.fn_inside,
48 :     BV.fn_length,
49 : jhr 3519 BV.image_mirror,
50 : jhr 4589 BV.fn_numActive,
51 :     BV.fn_numStable,
52 :     BV.fn_numStrands,
53 : jhr 3519 BV.fn_size,
54 :     BV.image_wrap,
55 : jhr 4588 (* non-overloaded reductions *)
56 : jhr 3519 BV.red_all,
57 :     BV.red_exists,
58 :     BV.red_mean,
59 : jhr 4588 (* FIXME: variance not yet supported
60 : jhr 3519 BV.red_variance,
61 : jhr 4588 *)
62 : jhr 3519 (* Math functions that have not yet been lifted to work on fields *)
63 :     BV.fn_atan2_rr,
64 :     BV.fn_ceil_r,
65 :     BV.fn_floor_r,
66 :     BV.fn_fmod_rr,
67 :     BV.fn_erf_r,
68 :     BV.fn_erfc_r,
69 :     BV.fn_log_r,
70 :     BV.fn_log10_r,
71 :     BV.fn_log2_r,
72 : jhr 4586 BV.fn_pow_rr,
73 :     BV.fn_round_r
74 : jhr 3399 ]
75 :    
76 : jhr 4043 (* kernels *)
77 :     val basisKernels = [
78 :     Kernel.bspln3,
79 :     Kernel.bspln5,
80 :     Kernel.c4hexic,
81 :     Kernel.ctmr,
82 : glk 4559 Kernel.tent,
83 :     Kernel.c1tent, (* for backwards compatibility to vis12 *)
84 :     Kernel.c2tent, (* for backwards compatibility to vis12 *)
85 :     Kernel.c2ctmr (* for backwards compatibility to vis12 *)
86 : jhr 3399 ]
87 :    
88 :     (* overloaded operators and functions *)
89 :     val overloads = [
90 : jhr 3519 (* overloaded operators *)
91 : jhr 3399 (N.op_at, [BV.at_Td, BV.at_dT, BV.at_dd]),
92 :     (N.op_lte, [BV.lte_ii, BV.lte_rr]),
93 :     (N.op_equ, [BV.equ_bb, BV.equ_ii, BV.equ_ss, BV.equ_rr]),
94 :     (N.op_neq, [BV.neq_bb, BV.neq_ii, BV.neq_ss, BV.neq_rr]),
95 :     (N.op_gte, [BV.gte_ii, BV.gte_rr]),
96 :     (N.op_gt, [BV.gt_ii, BV.gt_rr]),
97 :     (N.op_add, [BV.add_ii, BV.add_tt, BV.add_ff, BV.add_ft, BV.add_tf]),
98 :     (N.op_sub, [BV.sub_ii, BV.sub_tt, BV.sub_ff, BV.sub_ft, BV.sub_tf]),
99 :     (N.op_mul, [
100 : jhr 3519 BV.mul_ii, BV.mul_rr, BV.mul_rt, BV.mul_tr, BV.mul_rf, BV.mul_fr,
101 :     BV.mul_ss, BV.mul_sf, BV.mul_fs, BV.mul_st, BV.mul_ts
102 :     ]),
103 : cchiw 4000 (N.op_div, [BV.div_ii, BV.div_rr, BV.div_tr, BV.div_tr, BV.div_fr, BV.div_ss, BV.div_fs, BV.div_ts]),
104 : cchiw 3979 (N.op_pow, [BV.pow_ri, BV.pow_rr, BV.pow_si]),
105 : jhr 3399 (N.op_curl, [BV.curl2D, BV.curl3D]),
106 :     (N.op_convolve, [BV.convolve_vk, BV.convolve_kv]),
107 :     (N.op_lt, [BV.lt_ii, BV.lt_rr]),
108 :     (N.op_neg, [BV.neg_i, BV.neg_t, BV.neg_f]),
109 : jhr 4002 (N.op_cross, [
110 : jhr 4317 BV.op_cross2_tt, BV.op_cross3_tt, BV.op_cross2_ff, BV.op_cross3_ff,
111 :     BV.op_cross2_tf, BV.op_cross3_tf, BV.op_cross2_ft, BV.op_cross3_ft
112 :     ]),
113 : jhr 3399 (N.op_norm, [BV.op_norm_t, BV.op_norm_f]),
114 : jhr 3519 (* overloaded functions *)
115 :     (N.fn_abs, [BV.fn_abs_i, BV.fn_abs_r]),
116 : jhr 3482 (N.fn_acos, [BV.fn_acos_r, BV.fn_acos_s]),
117 :     (N.fn_asin, [BV.fn_asin_r, BV.fn_asin_s]),
118 :     (N.fn_atan, [BV.fn_atan_r, BV.fn_atan_s]),
119 : jhr 4128 (N.fn_clamp, [BV.clamp_rrt, BV.clamp_ttt, BV.image_clamp]),
120 : jhr 3482 (N.fn_cos, [BV.fn_cos_r, BV.fn_cos_s]),
121 : jhr 3399 (N.fn_det, [BV.fn_det2_t, BV.fn_det3_t, BV.fn_det2_f, BV.fn_det3_f]),
122 : jhr 3807 (N.fn_dist, [BV.dist2_t, BV.dist3_t]),
123 : jhr 3399 (N.fn_evals, [BV.evals2x2, BV.evals3x3]),
124 :     (N.fn_evecs, [BV.evecs2x2, BV.evecs3x3]),
125 : jhr 3482 (N.fn_exp, [BV.fn_exp_r, BV.fn_exp_s]),
126 : jhr 4414 (N.fn_inv, [BV.fn_inv2_t, BV.fn_inv3_t, BV.fn_inv2_f, BV.fn_inv3_f]),
127 : jhr 3399 (N.fn_lerp, [BV.lerp5, BV.lerp3]),
128 : jhr 4588 (N.fn_max, [BV.fn_max_i, BV.fn_max_r, BV.red_max_i, BV.red_max_r]),
129 :     (N.fn_min, [BV.fn_min_i, BV.fn_min_r, BV.red_min_i, BV.red_min_r]),
130 : jhr 3399 (N.fn_normalize, [BV.fn_normalize_t, BV.fn_normalize_f]),
131 : cchiw 4281 (N.fn_modulate, [BV.fn_modulate_tt, BV.fn_modulate_ff, BV.fn_modulate_tf, BV.fn_modulate_ft]),
132 : jhr 4588 (N.fn_product, [BV.red_product_i, BV.red_product_r]),
133 : jhr 3482 (N.fn_sin, [BV.fn_sin_r, BV.fn_sin_s]),
134 : jhr 3807 (N.fn_sphere, [BV.fn_sphere_im, BV.fn_sphere1_r, BV.fn_sphere2_t, BV.fn_sphere3_t]),
135 : jhr 3482 (N.fn_sqrt, [BV.fn_sqrt_r, BV.fn_sqrt_s]),
136 : jhr 4588 (N.fn_sum, [BV.red_sum_i, BV.red_sum_r]),
137 : jhr 3482 (N.fn_tan, [BV.fn_tan_r, BV.fn_tan_s]),
138 : jhr 3399 (N.fn_trace, [BV.fn_trace_t, BV.fn_trace_f]),
139 :     (N.fn_transpose, [BV.fn_transpose_t, BV.fn_transpose_f]),
140 :     (* assignment operators are bound to the corresponding binary operator *)
141 :     (N.asgn_add, [BV.add_ii, BV.add_tt, BV.add_ff, BV.add_ft]),
142 :     (N.asgn_sub, [BV.sub_ii, BV.sub_tt, BV.sub_ff, BV.sub_ft]),
143 :     (N.asgn_mul, [BV.mul_ii, BV.mul_rr, BV.mul_tr, BV.mul_fr]),
144 :     (N.asgn_div, [BV.div_ii, BV.div_rr, BV.div_tr, BV.div_tr]),
145 : cchiw 4936 (N.asgn_mod, [BV.op_mod]),
146 : cchiw 4953 (N.fn_concat, [BV.fn_concat_fv2, BV.fn_concat_fm2, BV.fn_concat_ft2, BV.fn_concat_fs3, BV.fn_concat_fm3, BV.fn_concat_ft3]),
147 : cchiw 5000 (N.fn_comp, [BV.fn_comp]),
148 :     (N.op_comp, [BV.comp]),
149 : cchiw 4999 (N.fn_poly, [BV.fn_poly]),
150 : cchiw 5010 (N.fn_inst, [BV.fn_inst]),
151 :     (N.op_D, [BV.op_D, BV.op_DPoly]),
152 : cchiw 5030 (N.op_Dotimes, [BV.op_Dotimes, BV.op_DotimesPoly]),
153 :     (N.fn_fem, [BV.fn_fem])
154 : cchiw 4936 ]
155 : jhr 3399
156 :     (* seed the basis environment *)
157 :     fun env () = let
158 : jhr 3519 val gEnv = GEnv.new()
159 : jhr 3399 fun insF x = GEnv.insertFunc(gEnv, Atom.atom(Var.nameOf x), GEnv.PrimFun[x])
160 : jhr 4043 fun insK k = GEnv.insertKernel(gEnv, Atom.atom(Kernel.name k), k)
161 : jhr 3399 fun insOvld (f, fns) = GEnv.insertFunc(gEnv, f, GEnv.PrimFun fns)
162 :     in
163 :     List.app insF basisFunctions;
164 : jhr 4043 List.app insK basisKernels;
165 : jhr 3519 List.app insOvld overloads;
166 :     gEnv
167 : jhr 3399 end
168 :    
169 :     (* operations that are allowed in constant expressions; we basically allow any operations
170 :     * on integers, booleans, or tensors. Operations on fields, images, sequences, or kernels
171 :     * are not allowed.
172 :     *)
173 :     local
174 :     val allowed = List.foldl Var.Set.add' Var.Set.empty [
175 : jhr 3519 BV.op_mod,
176 : jhr 3399 BV.op_cross2_tt, BV.op_cross3_tt,
177 :     BV.op_outer_tt,
178 :     BV.op_norm_t,
179 :     BV.op_not,
180 : jhr 3519 BV.fn_abs_i, BV.fn_abs_r,
181 : jhr 3482 BV.fn_max_i, BV.fn_max_r,
182 :     BV.fn_min_i, BV.fn_min_r,
183 : jhr 4286 BV.fn_modulate_tt,
184 : jhr 3399 BV.fn_normalize_t,
185 :     BV.fn_trace_t,
186 :     BV.fn_transpose_t,
187 :     BV.lte_ii, BV.lte_rr,
188 :     BV.equ_bb, BV.equ_ii, BV.equ_ss, BV.equ_rr,
189 :     BV.neq_bb, BV.neq_ii, BV.neq_ss, BV.neq_rr,
190 :     BV.gte_ii, BV.gte_rr,
191 :     BV.lt_ii, BV.lt_rr,
192 :     BV.gt_ii, BV.gt_rr,
193 :     BV.add_ii, BV.add_tt,
194 :     BV.sub_ii, BV.sub_tt,
195 :     BV.mul_ii, BV.mul_rr, BV.mul_rt, BV.mul_tr,
196 :     BV.div_ii, BV.div_rr, BV.div_tr, BV.div_tr,
197 : jhr 3482 BV.pow_ri, BV.pow_rr,
198 : jhr 3399 BV.neg_i, BV.neg_t,
199 : jhr 4128 BV.clamp_rrt, BV.clamp_ttt,
200 : jhr 3399 BV.lerp5, BV.lerp3,
201 : jhr 3519 BV.fn_acos_r,
202 :     BV.fn_asin_r,
203 :     BV.fn_atan_r,
204 :     BV.fn_atan2_rr,
205 :     BV.fn_ceil_r,
206 :     BV.fn_cos_r,
207 :     BV.fn_erf_r,
208 :     BV.fn_erfc_r,
209 :     BV.fn_exp_r,
210 :     BV.fn_floor_r,
211 :     BV.fn_fmod_rr,
212 :     BV.fn_log_r,
213 :     BV.fn_log10_r,
214 :     BV.fn_log2_r,
215 : jhr 4317 BV.fn_round_r,
216 : jhr 3519 BV.fn_sin_r,
217 :     BV.fn_sqrt_r,
218 : jhr 4298 BV.fn_tan_r,
219 : jhr 4317 BV.fn_trunc_r
220 : jhr 3399 ]
221 :     in
222 :     fun allowedInConstExp x = Var.Set.member (allowed, x)
223 :     end (* local *)
224 :    
225 : jhr 4349 (* spatial queries *)
226 :     local
227 :     val qOps = List.foldl Var.Set.add' Var.Set.empty [
228 :     BV.fn_sphere_im,
229 :     BV.fn_sphere1_r,
230 :     BV.fn_sphere2_t,
231 :     BV.fn_sphere3_t
232 :     ]
233 :     in
234 :     fun isSpatialQueryOp x = Var.Set.member (qOps, x)
235 :     end
236 :    
237 : jhr 3464 (* the reduction operators *)
238 :     local
239 :     val redOps = List.foldl Var.Set.add' Var.Set.empty [
240 : jhr 3519 BV.red_all,
241 :     BV.red_exists,
242 : jhr 4588 BV.red_max_i,
243 :     BV.red_max_r,
244 : jhr 3519 BV.red_mean,
245 : jhr 4588 BV.red_min_i,
246 :     BV.red_min_r,
247 :     BV.red_product_i,
248 :     BV.red_product_r,
249 :     BV.red_sum_i,
250 :     BV.red_sum_r
251 :     (* FIXME: variance not supported yet
252 : jhr 3519 BV.red_variance
253 : jhr 4588 *)
254 : jhr 3519 ]
255 : jhr 3464 in
256 :     fun isReductionOp x = Var.Set.member (redOps, x)
257 :     end (* local *)
258 :    
259 : jhr 3463 (* the sets of strands are only allowed in global initialization and update blocks *)
260 :     local
261 :     val strandSets = List.foldl Var.Set.add' Var.Set.empty [
262 : jhr 3519 BV.set_active,
263 :     BV.set_all,
264 :     BV.set_stable
265 :     ]
266 : jhr 3463 in
267 :     fun isStrandSet x = Var.Set.member (strandSets, x)
268 :     end (* end local *)
269 :    
270 : jhr 4228 (* border-control operations *)
271 :     local
272 :     val borderCtl = List.foldl Var.Set.add' Var.Set.empty [
273 : jhr 4317 BV.image_border, BV.image_clamp, BV.image_mirror, BV.image_wrap
274 :     ]
275 : jhr 4228 in
276 :     fun isBorderCtl x = Var.Set.member (borderCtl, x)
277 :     end
278 :    
279 : jhr 3399 end

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0