Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] Annotation of /branches/vis15/src/compiler/basis/basis.sml
ViewVC logotype

Annotation of /branches/vis15/src/compiler/basis/basis.sml

Parent Directory Parent Directory | Revision Log Revision Log


Revision 3807 - (view) (download)

1 : jhr 3399 (* basis.sml
2 :     *
3 :     * This code is part of the Diderot Project (http://diderot-language.cs.uchicago.edu)
4 :     *
5 :     * COPYRIGHT (c) 2015 The University of Chicago
6 :     * All rights reserved.
7 :     *
8 :     * Defining the Diderot Basis environment.
9 :     *)
10 :    
11 :     structure Basis : sig
12 :    
13 :     val env : unit -> GlobalEnv.t
14 :    
15 :     (* operations that are allowed in constant expressions *)
16 :     val allowedInConstExp : AST.var -> bool
17 :    
18 : jhr 3464 (* reduction operators *)
19 :     val isReductionOp : AST.var -> bool
20 :    
21 : jhr 3463 (* global sets of strands *)
22 :     val isStrandSet : AST.var -> bool
23 :    
24 : jhr 3399 end = struct
25 :    
26 :     structure N = BasisNames
27 :     structure BV = BasisVars
28 :     structure ATbl = AtomTable
29 :     structure GEnv = GlobalEnv
30 :    
31 :     (* non-overloaded operators, etc. *)
32 :     val basisFunctions = [
33 :     (* non-overloaded operators *)
34 :     BV.op_D,
35 :     BV.op_Dotimes,
36 :     BV.op_Ddot,
37 :     BV.op_not,
38 :     (* functions *)
39 : jhr 3519 BV.image_border,
40 : jhr 3399 BV.fn_inside,
41 :     BV.fn_length,
42 : jhr 3519 BV.image_mirror,
43 : jhr 3399 BV.fn_modulate,
44 : jhr 3628 (* unimplemented
45 : jhr 3399 BV.fn_principleEvec,
46 : jhr 3628 *)
47 : jhr 3519 BV.fn_size,
48 :     BV.image_wrap,
49 : jhr 3399 (* reductions *)
50 : jhr 3519 BV.red_all,
51 :     BV.red_exists,
52 :     BV.red_max,
53 :     BV.red_mean,
54 :     BV.red_min,
55 :     BV.red_product,
56 :     BV.red_sum,
57 :     BV.red_variance,
58 :     (* Math functions that have not yet been lifted to work on fields *)
59 :     BV.fn_atan2_rr,
60 :     BV.fn_ceil_r,
61 :     BV.fn_floor_r,
62 :     BV.fn_fmod_rr,
63 :     BV.fn_erf_r,
64 :     BV.fn_erfc_r,
65 :     BV.fn_log_r,
66 :     BV.fn_log10_r,
67 :     BV.fn_log2_r,
68 :     BV.fn_pow_rr
69 : jhr 3399 ]
70 :    
71 :     val basisVars = [
72 :     (* kernels *)
73 :     BV.kn_bspln3,
74 :     BV.kn_bspln5,
75 :     BV.kn_c4hexic,
76 :     BV.kn_ctmr,
77 :     BV.kn_tent
78 :     ]
79 :    
80 :     (* overloaded operators and functions *)
81 :     val overloads = [
82 : jhr 3519 (* overloaded operators *)
83 : jhr 3399 (N.op_at, [BV.at_Td, BV.at_dT, BV.at_dd]),
84 :     (N.op_lte, [BV.lte_ii, BV.lte_rr]),
85 :     (N.op_equ, [BV.equ_bb, BV.equ_ii, BV.equ_ss, BV.equ_rr]),
86 :     (N.op_neq, [BV.neq_bb, BV.neq_ii, BV.neq_ss, BV.neq_rr]),
87 :     (N.op_gte, [BV.gte_ii, BV.gte_rr]),
88 :     (N.op_gt, [BV.gt_ii, BV.gt_rr]),
89 :     (N.op_add, [BV.add_ii, BV.add_tt, BV.add_ff, BV.add_ft, BV.add_tf]),
90 :     (N.op_sub, [BV.sub_ii, BV.sub_tt, BV.sub_ff, BV.sub_ft, BV.sub_tf]),
91 :     (N.op_mul, [
92 : jhr 3519 BV.mul_ii, BV.mul_rr, BV.mul_rt, BV.mul_tr, BV.mul_rf, BV.mul_fr,
93 :     BV.mul_ss, BV.mul_sf, BV.mul_fs, BV.mul_st, BV.mul_ts
94 :     ]),
95 : jhr 3399 (N.op_div, [BV.div_ii, BV.div_rr, BV.div_tr, BV.div_tr, BV.div_fr, BV.div_ss, BV.div_fs]),
96 : jhr 3482 (N.op_pow, [BV.pow_ri, BV.pow_rr, BV.pow_si]),
97 : jhr 3399 (N.op_curl, [BV.curl2D, BV.curl3D]),
98 :     (N.op_convolve, [BV.convolve_vk, BV.convolve_kv]),
99 :     (N.op_lt, [BV.lt_ii, BV.lt_rr]),
100 :     (N.op_neg, [BV.neg_i, BV.neg_t, BV.neg_f]),
101 :     (N.op_cross, [BV.op_cross2_tt, BV.op_cross3_tt, BV.op_cross2_ff, BV.op_cross3_ff]),
102 :     (N.op_norm, [BV.op_norm_t, BV.op_norm_f]),
103 : jhr 3519 (* overloaded functions *)
104 :     (N.fn_abs, [BV.fn_abs_i, BV.fn_abs_r]),
105 : jhr 3482 (N.fn_acos, [BV.fn_acos_r, BV.fn_acos_s]),
106 :     (N.fn_asin, [BV.fn_asin_r, BV.fn_asin_s]),
107 :     (N.fn_atan, [BV.fn_atan_r, BV.fn_atan_s]),
108 : jhr 3399 (N.fn_clamp, [BV.clamp_rrr, BV.clamp_vvv, BV.image_clamp]),
109 : jhr 3482 (N.fn_cos, [BV.fn_cos_r, BV.fn_cos_s]),
110 : jhr 3399 (N.fn_det, [BV.fn_det2_t, BV.fn_det3_t, BV.fn_det2_f, BV.fn_det3_f]),
111 : jhr 3807 (N.fn_dist, [BV.dist2_t, BV.dist3_t]),
112 : jhr 3399 (N.fn_evals, [BV.evals2x2, BV.evals3x3]),
113 :     (N.fn_evecs, [BV.evecs2x2, BV.evecs3x3]),
114 : jhr 3482 (N.fn_exp, [BV.fn_exp_r, BV.fn_exp_s]),
115 : jhr 3399 (N.fn_lerp, [BV.lerp5, BV.lerp3]),
116 : jhr 3519 (N.fn_max, [BV.fn_max_i, BV.fn_max_r, BV.red_max]),
117 :     (N.fn_min, [BV.fn_min_i, BV.fn_min_r, BV.red_min]),
118 : jhr 3399 (N.fn_normalize, [BV.fn_normalize_t, BV.fn_normalize_f]),
119 : jhr 3482 (N.fn_sin, [BV.fn_sin_r, BV.fn_sin_s]),
120 : jhr 3807 (N.fn_sphere, [BV.fn_sphere_im, BV.fn_sphere1_r, BV.fn_sphere2_t, BV.fn_sphere3_t]),
121 : jhr 3482 (N.fn_sqrt, [BV.fn_sqrt_r, BV.fn_sqrt_s]),
122 :     (N.fn_tan, [BV.fn_tan_r, BV.fn_tan_s]),
123 : jhr 3399 (N.fn_trace, [BV.fn_trace_t, BV.fn_trace_f]),
124 :     (N.fn_transpose, [BV.fn_transpose_t, BV.fn_transpose_f]),
125 :     (* assignment operators are bound to the corresponding binary operator *)
126 :     (N.asgn_add, [BV.add_ii, BV.add_tt, BV.add_ff, BV.add_ft]),
127 :     (N.asgn_sub, [BV.sub_ii, BV.sub_tt, BV.sub_ff, BV.sub_ft]),
128 :     (N.asgn_mul, [BV.mul_ii, BV.mul_rr, BV.mul_tr, BV.mul_fr]),
129 :     (N.asgn_div, [BV.div_ii, BV.div_rr, BV.div_tr, BV.div_tr]),
130 : jhr 3519 (N.asgn_mod, [BV.op_mod])
131 : jhr 3399 ]
132 :    
133 :     (* seed the basis environment *)
134 :     fun env () = let
135 : jhr 3519 val gEnv = GEnv.new()
136 : jhr 3399 fun insF x = GEnv.insertFunc(gEnv, Atom.atom(Var.nameOf x), GEnv.PrimFun[x])
137 :     fun insV x = GEnv.insertVar(gEnv, Atom.atom(Var.nameOf x), x)
138 :     fun insOvld (f, fns) = GEnv.insertFunc(gEnv, f, GEnv.PrimFun fns)
139 :     in
140 :     List.app insF basisFunctions;
141 : jhr 3519 List.app insV basisVars;
142 :     List.app insOvld overloads;
143 :     gEnv
144 : jhr 3399 end
145 :    
146 :     (* operations that are allowed in constant expressions; we basically allow any operations
147 :     * on integers, booleans, or tensors. Operations on fields, images, sequences, or kernels
148 :     * are not allowed.
149 :     *)
150 :     local
151 :     val allowed = List.foldl Var.Set.add' Var.Set.empty [
152 : jhr 3519 BV.op_mod,
153 : jhr 3399 BV.op_cross2_tt, BV.op_cross3_tt,
154 :     BV.op_outer_tt,
155 :     BV.op_norm_t,
156 :     BV.op_not,
157 : jhr 3519 BV.fn_abs_i, BV.fn_abs_r,
158 : jhr 3482 BV.fn_max_i, BV.fn_max_r,
159 :     BV.fn_min_i, BV.fn_min_r,
160 : jhr 3399 BV.fn_modulate,
161 :     BV.fn_normalize_t,
162 : jhr 3641 (* unimplemented
163 : jhr 3399 BV.fn_principleEvec,
164 : jhr 3641 *)
165 : jhr 3399 BV.fn_trace_t,
166 :     BV.fn_transpose_t,
167 :     BV.lte_ii, BV.lte_rr,
168 :     BV.equ_bb, BV.equ_ii, BV.equ_ss, BV.equ_rr,
169 :     BV.neq_bb, BV.neq_ii, BV.neq_ss, BV.neq_rr,
170 :     BV.gte_ii, BV.gte_rr,
171 :     BV.lt_ii, BV.lt_rr,
172 :     BV.gt_ii, BV.gt_rr,
173 :     BV.add_ii, BV.add_tt,
174 :     BV.sub_ii, BV.sub_tt,
175 :     BV.mul_ii, BV.mul_rr, BV.mul_rt, BV.mul_tr,
176 :     BV.div_ii, BV.div_rr, BV.div_tr, BV.div_tr,
177 : jhr 3482 BV.pow_ri, BV.pow_rr,
178 : jhr 3399 BV.neg_i, BV.neg_t,
179 :     BV.clamp_rrr, BV.clamp_vvv,
180 :     BV.lerp5, BV.lerp3,
181 : jhr 3519 BV.fn_acos_r,
182 :     BV.fn_asin_r,
183 :     BV.fn_atan_r,
184 :     BV.fn_atan2_rr,
185 :     BV.fn_ceil_r,
186 :     BV.fn_cos_r,
187 :     BV.fn_erf_r,
188 :     BV.fn_erfc_r,
189 :     BV.fn_exp_r,
190 :     BV.fn_floor_r,
191 :     BV.fn_fmod_rr,
192 :     BV.fn_log_r,
193 :     BV.fn_log10_r,
194 :     BV.fn_log2_r,
195 :     BV.fn_sin_r,
196 :     BV.fn_sqrt_r,
197 :     BV.fn_tan_r
198 : jhr 3399 ]
199 :     in
200 :     fun allowedInConstExp x = Var.Set.member (allowed, x)
201 :     end (* local *)
202 :    
203 : jhr 3464 (* the reduction operators *)
204 :     local
205 :     val redOps = List.foldl Var.Set.add' Var.Set.empty [
206 : jhr 3519 BV.red_all,
207 :     BV.red_exists,
208 :     BV.red_max,
209 :     BV.red_mean,
210 :     BV.red_min,
211 :     BV.red_product,
212 :     BV.red_sum,
213 :     BV.red_variance
214 :     ]
215 : jhr 3464 in
216 :     fun isReductionOp x = Var.Set.member (redOps, x)
217 :     end (* local *)
218 :    
219 : jhr 3463 (* the sets of strands are only allowed in global initialization and update blocks *)
220 :     local
221 :     val strandSets = List.foldl Var.Set.add' Var.Set.empty [
222 : jhr 3519 BV.set_active,
223 :     BV.set_all,
224 :     BV.set_stable
225 :     ]
226 : jhr 3463 in
227 :     fun isStrandSet x = Var.Set.member (strandSets, x)
228 :     end (* end local *)
229 :    
230 : jhr 3399 end

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0