Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] Annotation of /branches/charisee/src/compiler/basis/basis-vars.sml
ViewVC logotype

Annotation of /branches/charisee/src/compiler/basis/basis-vars.sml

Parent Directory Parent Directory | Revision Log Revision Log


Revision 1296 - (view) (download)
Original Path: trunk/src/compiler/basis/basis-vars.sml

1 : jhr 79 (* basis-vars.sml
2 :     *
3 : jhr 435 * COPYRIGHT (c) 2010 The Diderot Project (http://diderot-language.cs.uchicago.edu)
4 : jhr 79 * All rights reserved.
5 :     *
6 :     * This module defines the AST variables for the built in operators and functions.
7 :     *)
8 :    
9 :     structure BasisVars =
10 :     struct
11 :     local
12 :     structure N = BasisNames
13 :     structure Ty = Types
14 :     structure MV = MetaVar
15 :    
16 : jhr 81 fun --> (tys1, ty) = Ty.T_Fun(tys1, ty)
17 : jhr 79 infix -->
18 :    
19 :     val N2 = Ty.DimConst 2
20 :     val N3 = Ty.DimConst 3
21 :    
22 :     (* short names for kinds *)
23 :     val TK : unit -> Ty.meta_var = Ty.TYPE o MV.newTyVar
24 : jhr 470 fun DK () : Ty.meta_var = Ty.DIFF(MV.newDiffVar 0)
25 : jhr 79 val SK : unit -> Ty.meta_var = Ty.SHAPE o MV.newShapeVar
26 :     val NK : unit -> Ty.meta_var = Ty.DIM o MV.newDimVar
27 :    
28 :     fun ty t = ([], t)
29 :     fun all (kinds, mkTy : Ty.meta_var list -> Ty.ty) = let
30 :     val tvs = List.map (fn mk => mk()) kinds
31 :     in
32 :     (tvs, mkTy tvs)
33 :     end
34 :     fun allNK mkTy = let
35 :     val tv = MV.newDimVar()
36 :     in
37 :     ([Ty.DIM tv], mkTy tv)
38 :     end
39 :    
40 :     fun field (k, d, dd) = Ty.T_Field{diff=k, dim=d, shape=dd}
41 :     fun tensor ds = Ty.T_Tensor(Ty.Shape ds)
42 : jhr 1116 fun matrix d = tensor[d,d]
43 : jhr 79
44 :     fun monoVar (name, ty) = Var.new (name, AST.BasisVar, ty)
45 :     fun polyVar (name, scheme) = Var.newPoly (name, AST.BasisVar, scheme)
46 :     in
47 :    
48 :     (* TODO: I'm not sure how to extend + and - to fields, since the typing rules should allow
49 :     * two fields with different differentiation levels to be added.
50 :     *)
51 :    
52 :     (* overloaded operators; the naming convention is to use the operator name followed
53 :     * by the argument type signature, where
54 :     * i -- int
55 :     * b -- bool
56 :     * r -- real (tensor[])
57 :     * t -- tensor[shape]
58 : jhr 470 * f -- field#k(d)[shape]
59 : jhr 79 *)
60 :    
61 :     val add_ii = monoVar(N.op_add, [Ty.T_Int, Ty.T_Int] --> Ty.T_Int)
62 :     val add_tt = polyVar(N.op_add, all([SK], fn [Ty.SHAPE dd] => let
63 :     val t = Ty.T_Tensor(Ty.ShapeVar dd)
64 :     in
65 :     [t, t] --> t
66 :     end))
67 : jhr 470 val add_ff = polyVar(N.op_add, all([DK,NK,SK],
68 :     fn [Ty.DIFF k, Ty.DIM d, Ty.SHAPE dd] => let
69 :     val t = Ty.T_Field{diff = Ty.DiffVar(k, 0), dim = Ty.DimVar d, shape = Ty.ShapeVar dd}
70 :     in
71 :     [t, t] --> t
72 :     end))
73 : jhr 79
74 :     val sub_ii = monoVar(N.op_sub, [Ty.T_Int, Ty.T_Int] --> Ty.T_Int)
75 :     val sub_tt = polyVar(N.op_sub, all([SK], fn [Ty.SHAPE dd] => let
76 :     val t = Ty.T_Tensor(Ty.ShapeVar dd)
77 :     in
78 :     [t, t] --> t
79 :     end))
80 : jhr 470 val sub_ff = polyVar(N.op_sub, all([DK,NK,SK],
81 :     fn [Ty.DIFF k, Ty.DIM d, Ty.SHAPE dd] => let
82 :     val t = Ty.T_Field{diff = Ty.DiffVar(k, 0), dim = Ty.DimVar d, shape = Ty.ShapeVar dd}
83 :     in
84 :     [t, t] --> t
85 :     end))
86 : jhr 79
87 :     (* note that we assume that operators are tested in the order defined here, so that mul_rr
88 :     * takes precedence over mul_rt and mul_tr!
89 :     *)
90 :     val mul_ii = monoVar(N.op_mul, [Ty.T_Int, Ty.T_Int] --> Ty.T_Int)
91 :     val mul_rr = monoVar(N.op_mul, [Ty.realTy, Ty.realTy] --> Ty.realTy)
92 :     val mul_rt = polyVar(N.op_mul, all([SK], fn [Ty.SHAPE dd] => let
93 :     val t = Ty.T_Tensor(Ty.ShapeVar dd)
94 :     in
95 :     [Ty.realTy, t] --> t
96 :     end))
97 :     val mul_tr = polyVar(N.op_mul, all([SK], fn [Ty.SHAPE dd] => let
98 :     val t = Ty.T_Tensor(Ty.ShapeVar dd)
99 :     in
100 :     [t, Ty.realTy] --> t
101 :     end))
102 : jhr 470 val mul_rf = polyVar(N.op_mul, all([DK,NK,SK],
103 :     fn [Ty.DIFF k, Ty.DIM d, Ty.SHAPE dd] => let
104 :     val t = Ty.T_Field{diff = Ty.DiffVar(k, 0), dim = Ty.DimVar d, shape = Ty.ShapeVar dd}
105 :     in
106 :     [Ty.realTy, t] --> t
107 :     end))
108 :     val mul_fr = polyVar(N.op_mul, all([DK,NK,SK],
109 :     fn [Ty.DIFF k, Ty.DIM d, Ty.SHAPE dd] => let
110 :     val t = Ty.T_Field{diff = Ty.DiffVar(k, 0), dim = Ty.DimVar d, shape = Ty.ShapeVar dd}
111 :     in
112 :     [t, Ty.realTy] --> t
113 :     end))
114 : jhr 79
115 :     val div_ii = monoVar(N.op_div, [Ty.T_Int, Ty.T_Int] --> Ty.T_Int)
116 :     val div_rr = monoVar(N.op_div, [Ty.realTy, Ty.realTy] --> Ty.realTy)
117 :     val div_tr = polyVar(N.op_div, all([SK], fn [Ty.SHAPE dd] => let
118 :     val t = Ty.T_Tensor(Ty.ShapeVar dd)
119 :     in
120 :     [t, Ty.realTy] --> t
121 :     end))
122 : jhr 470 val div_fr = polyVar(N.op_div, all([DK,NK,SK],
123 :     fn [Ty.DIFF k, Ty.DIM d, Ty.SHAPE dd] => let
124 :     val t = Ty.T_Field{diff = Ty.DiffVar(k, 0), dim = Ty.DimVar d, shape = Ty.ShapeVar dd}
125 :     in
126 :     [t, Ty.realTy] --> t
127 :     end))
128 : jhr 79
129 : jhr 1116 (* exponentiation; we distinguish between integer and real exponents to allow x^2 to be compiled
130 :     * as x*x.
131 :     *)
132 :     val exp_ri = monoVar(N.op_exp, [Ty.realTy, Ty.T_Int] --> Ty.realTy)
133 :     val exp_rr = monoVar(N.op_exp, [Ty.realTy, Ty.realTy] --> Ty.realTy)
134 :    
135 :     val convolve_vk = polyVar (N.op_convolve, all([DK, NK, SK],
136 :     fn [Ty.DIFF k, Ty.DIM d, Ty.SHAPE dd] => let
137 :     val k = Ty.DiffVar(k, 0)
138 :     val d = Ty.DimVar d
139 :     val dd = Ty.ShapeVar dd
140 :     in
141 :     [Ty.T_Image{dim=d, shape=dd}, Ty.T_Kernel k]
142 :     --> field(k, d, dd)
143 :     end))
144 :     val convolve_kv = polyVar (N.op_convolve, all([DK, NK, SK],
145 :     fn [Ty.DIFF k, Ty.DIM d, Ty.SHAPE dd] => let
146 :     val k = Ty.DiffVar(k, 0)
147 :     val d = Ty.DimVar d
148 :     val dd = Ty.ShapeVar dd
149 :     in
150 :     [Ty.T_Kernel k, Ty.T_Image{dim=d, shape=dd}]
151 :     --> field(k, d, dd)
152 :     end))
153 :    
154 : jhr 79 val lt_ii = monoVar(N.op_lt, [Ty.T_Int, Ty.T_Int] --> Ty.T_Bool)
155 :     val lt_rr = monoVar(N.op_lt, [Ty.realTy, Ty.realTy] --> Ty.T_Bool)
156 :     val lte_ii = monoVar(N.op_lte, [Ty.T_Int, Ty.T_Int] --> Ty.T_Bool)
157 :     val lte_rr = monoVar(N.op_lte, [Ty.realTy, Ty.realTy] --> Ty.T_Bool)
158 :     val gte_ii = monoVar(N.op_gte, [Ty.T_Int, Ty.T_Int] --> Ty.T_Bool)
159 :     val gte_rr = monoVar(N.op_gte, [Ty.realTy, Ty.realTy] --> Ty.T_Bool)
160 :     val gt_ii = monoVar(N.op_gt, [Ty.T_Int, Ty.T_Int] --> Ty.T_Bool)
161 :     val gt_rr = monoVar(N.op_gt, [Ty.realTy, Ty.realTy] --> Ty.T_Bool)
162 :    
163 :     val equ_bb = monoVar(N.op_equ, [Ty.T_Bool, Ty.T_Bool] --> Ty.T_Bool)
164 :     val equ_ii = monoVar(N.op_equ, [Ty.T_Int, Ty.T_Int] --> Ty.T_Bool)
165 :     val equ_ss = monoVar(N.op_equ, [Ty.T_String, Ty.T_String] --> Ty.T_Bool)
166 :     val equ_rr = monoVar(N.op_equ, [Ty.realTy, Ty.realTy] --> Ty.T_Bool)
167 :     val neq_bb = monoVar(N.op_neq, [Ty.T_Bool, Ty.T_Bool] --> Ty.T_Bool)
168 :     val neq_ii = monoVar(N.op_neq, [Ty.T_Int, Ty.T_Int] --> Ty.T_Bool)
169 :     val neq_ss = monoVar(N.op_neq, [Ty.T_String, Ty.T_String] --> Ty.T_Bool)
170 :     val neq_rr = monoVar(N.op_neq, [Ty.realTy, Ty.realTy] --> Ty.T_Bool)
171 :    
172 :    
173 :     val neg_i = monoVar(N.op_neg, [Ty.T_Int, Ty.T_Int] --> Ty.T_Bool)
174 :     val neg_t = polyVar(N.op_neg, all([SK],
175 :     fn [Ty.SHAPE dd] => let
176 :     val t = Ty.T_Tensor(Ty.ShapeVar dd)
177 :     in
178 :     [t] --> t
179 :     end))
180 :     val neg_f = polyVar(N.op_neg, all([DK, NK, SK],
181 :     fn [Ty.DIFF k, Ty.DIM d, Ty.SHAPE dd] => let
182 :     val k = Ty.DiffVar(k, 0)
183 :     val d = Ty.DimVar d
184 :     val dd = Ty.ShapeVar dd
185 :     in
186 :     [field(k, d, dd)] --> field(k, d, dd)
187 :     end))
188 :    
189 : jhr 1295 (* clamp is overloaded at scalars and vectors *)
190 :     val clamp_rrr = monoVar(N.fn_clamp, [Ty.realTy, Ty.realTy, Ty.realTy] --> Ty.realTy)
191 :     val clamp_vvv = polyVar (N.fn_clamp, allNK(fn tv => let
192 :     val t = tensor[Ty.DimVar tv]
193 :     in
194 :     [t, t, t] --> t
195 :     end))
196 :    
197 : jhr 1116 val lerp3 = polyVar(N.fn_lerp, all([SK],
198 :     fn [Ty.SHAPE dd] => let
199 :     val t = Ty.T_Tensor(Ty.ShapeVar dd)
200 :     in
201 :     [t, t, Ty.realTy] --> t
202 :     end))
203 :     val lerp5 = polyVar(N.fn_lerp, all([SK],
204 :     fn [Ty.SHAPE dd] => let
205 :     val t = Ty.T_Tensor(Ty.ShapeVar dd)
206 :     in
207 :     [t, t, Ty.realTy, Ty.realTy, Ty.realTy] --> t
208 :     end))
209 : jhr 79
210 : jhr 1296
211 : jhr 79 (***** non-overloaded operators, etc. *****)
212 :    
213 :     val op_at = polyVar (N.op_at, all([DK, NK, SK],
214 :     fn [Ty.DIFF k, Ty.DIM d, Ty.SHAPE dd] => let
215 :     val k = Ty.DiffVar(k, 0)
216 :     val d = Ty.DimVar d
217 :     val dd = Ty.ShapeVar dd
218 :     in
219 :     [field(k, d, dd), tensor[d]] --> Ty.T_Tensor dd
220 :     end))
221 :    
222 :     val op_D = polyVar (N.op_D, all([DK, NK, SK],
223 :     fn [Ty.DIFF k, Ty.DIM d, Ty.SHAPE dd] => let
224 :     val k0 = Ty.DiffVar(k, 0)
225 :     val km1 = Ty.DiffVar(k, ~1)
226 :     val d = Ty.DimVar d
227 :     val dd = Ty.ShapeVar dd
228 :     in
229 : jhr 85 [field(k0, d, dd)]
230 : jhr 79 --> field(km1, d, Ty.ShapeExt(dd, d))
231 :     end))
232 :    
233 :     val op_norm = polyVar (N.op_norm, all([SK],
234 :     fn [Ty.SHAPE dd] => [Ty.T_Tensor(Ty.ShapeVar dd)] --> Ty.realTy))
235 :    
236 :     val op_not = monoVar (N.op_not, [Ty.T_Bool] --> Ty.T_Bool)
237 :    
238 :     (* functions *)
239 : jhr 1116 val fn_atan2 = monoVar (N.fn_atan2, [Ty.realTy, Ty.realTy] --> Ty.realTy)
240 :    
241 : jhr 185 val fn_CL = monoVar (N.fn_CL, [tensor[N3, N3]] --> Ty.realTy)
242 : jhr 79
243 : jhr 247 (* the following is depreciated in favor of the infix operator *)
244 : jhr 79 val fn_convolve = polyVar (N.fn_convolve, all([DK, NK, SK],
245 :     fn [Ty.DIFF k, Ty.DIM d, Ty.SHAPE dd] => let
246 :     val k = Ty.DiffVar(k, 0)
247 :     val d = Ty.DimVar d
248 :     val dd = Ty.ShapeVar dd
249 :     in
250 :     [Ty.T_Kernel k, Ty.T_Image{dim=d, shape=dd}]
251 :     --> field(k, d, dd)
252 :     end))
253 :    
254 : jhr 143 val fn_cos = monoVar (N.fn_cos, [Ty.realTy] --> Ty.realTy)
255 : jhr 79
256 : jhr 1116 local
257 :     val crossTy = let
258 :     val t = tensor[N3]
259 :     in
260 :     [t, t] --> t
261 :     end
262 :     in
263 :     val op_cross = monoVar (N.op_cross, crossTy)
264 :     val fn_cross = monoVar (N.fn_cross, crossTy)
265 :     end
266 :    
267 :     (* the depriciated 'dot' function *)
268 : jhr 88 val fn_dot = polyVar (N.fn_dot, allNK(fn tv => let
269 :     val t = tensor[Ty.DimVar tv]
270 :     in
271 : jhr 91 [t, t] --> Ty.realTy
272 : jhr 88 end))
273 : jhr 79
274 : jhr 1116 (* the inner product operator (including dot product) is treated as a special case in the
275 :     * typechecker. It is not included in the basis environment, but we define its type scheme
276 :     * here. There is an implicit constraint on its type to have the following scheme:
277 :     *
278 :     * ALL[sigma1, d1, sigma2] . tensor[sigma1, d1] * tensor[d1, sigma2] -> tensor[sigma1, sigma2]
279 :     *)
280 :     val op_inner = polyVar (N.op_dot, all([SK, SK, SK],
281 :     fn [Ty.SHAPE s1, Ty.SHAPE s2, Ty.SHAPE s3] =>
282 :     [Ty.T_Tensor(Ty.ShapeVar s1), Ty.T_Tensor(Ty.ShapeVar s2)]
283 :     --> Ty.T_Tensor(Ty.ShapeVar s3)))
284 :    
285 :     (* Eigenvalues of a matrix *)
286 :     val fn_evals = polyVar (N.fn_trace, all([NK],
287 :     fn [Ty.DIM d] => let
288 :     val d = Ty.DimVar d
289 :     in
290 :     [matrix d] --> Ty.T_Sequence(Ty.realTy, d)
291 :     end))
292 :    
293 :     (* Eigenvectors of a matrix *)
294 :     val fn_evecs = polyVar (N.fn_trace, all([NK],
295 :     fn [Ty.DIM d] => let
296 :     val d = Ty.DimVar d
297 :     in
298 :     [matrix d] --> Ty.T_Sequence(tensor[d], d)
299 :     end))
300 :    
301 : jhr 143 val fn_inside = polyVar (N.fn_inside, all([DK, NK, SK],
302 : jhr 79 fn [Ty.DIFF k, Ty.DIM d, Ty.SHAPE dd] => let
303 :     val k = Ty.DiffVar(k, 0)
304 :     val d = Ty.DimVar d
305 :     val dd = Ty.ShapeVar dd
306 :     in
307 :     [Ty.T_Tensor(Ty.Shape[d]), field(k, d, dd)]
308 :     --> Ty.T_Bool
309 :     end))
310 :    
311 : jhr 143 val fn_load = polyVar (N.fn_load, all([NK, SK],
312 : jhr 79 fn [Ty.DIM d, Ty.SHAPE dd] => let
313 :     val d = Ty.DimVar d
314 :     val dd = Ty.ShapeVar dd
315 :     in
316 :     [Ty.T_String] --> Ty.T_Image{dim=d, shape=dd}
317 :     end))
318 :    
319 : jhr 143 val fn_max = monoVar (N.fn_max, [Ty.realTy, Ty.realTy] --> Ty.realTy)
320 :     val fn_min = monoVar (N.fn_min, [Ty.realTy, Ty.realTy] --> Ty.realTy)
321 :    
322 : jhr 83 val fn_modulate = polyVar (N.fn_modulate, all([NK],
323 :     fn [Ty.DIM d] => let
324 :     val t = Ty.T_Tensor(Ty.Shape[Ty.DimVar d])
325 :     in
326 :     [t, t] --> t
327 :     end))
328 :    
329 : jhr 1116 val fn_normalize = polyVar (N.fn_normalize, all([NK],
330 :     fn [Ty.DIM d] => let
331 :     val t = Ty.T_Tensor(Ty.Shape[Ty.DimVar d])
332 :     in
333 :     [t] --> t
334 :     end))
335 :    
336 :     (* outer product *)
337 :     local
338 :     fun mkOuter [Ty.DIM d1, Ty.DIM d2] = let
339 :     val vt1 = Ty.T_Tensor(Ty.Shape[Ty.DimVar d1])
340 :     val vt2 = Ty.T_Tensor(Ty.Shape[Ty.DimVar d2])
341 :     val mt = Ty.T_Tensor(Ty.Shape[Ty.DimVar d1, Ty.DimVar d2])
342 :     in
343 :     [vt1, vt2] --> mt
344 :     end
345 :     in
346 :     val fn_outer = polyVar (N.fn_outer, all([NK, NK], mkOuter))
347 :     val op_outer = polyVar (N.op_outer, all([NK, NK], mkOuter))
348 :     end
349 :    
350 : jhr 143 val fn_pow = monoVar (N.fn_pow, [Ty.realTy, Ty.realTy] --> Ty.realTy)
351 : jhr 79
352 : jhr 91 val fn_principleEvec = polyVar (N.fn_principleEvec, all([NK],
353 :     fn [Ty.DIM d] => let
354 :     val d = Ty.DimVar d
355 :     in
356 : jhr 1116 [matrix d] --> tensor[d]
357 : jhr 91 end))
358 : jhr 79
359 : jhr 143 val fn_sin = monoVar (N.fn_sin, [Ty.realTy] --> Ty.realTy)
360 : jhr 79
361 : jhr 1116 val fn_sqrt = monoVar (N.fn_sqrt, [Ty.realTy] --> Ty.realTy)
362 :    
363 :     val fn_tan = monoVar (N.fn_tan, [Ty.realTy] --> Ty.realTy)
364 :    
365 :     val fn_trace = polyVar (N.fn_trace, all([NK],
366 :     fn [Ty.DIM d] => let
367 :     val d = Ty.DimVar d
368 :     in
369 :     [matrix d] --> Ty.realTy
370 :     end))
371 :    
372 : jhr 79 (* kernels *)
373 : jhr 169 (* FIXME: we should really get the continuity info from the kernels themselves *)
374 : jhr 83 val kn_bspln3 = monoVar (N.kn_bspln3, Ty.T_Kernel(Ty.DiffConst 2))
375 : jhr 169 val kn_bspln5 = monoVar (N.kn_bspln5, Ty.T_Kernel(Ty.DiffConst 4))
376 : jhr 1116 val kn_ctmr = monoVar (N.kn_ctmr, Ty.T_Kernel(Ty.DiffConst 1))
377 : jhr 83 val kn_tent = monoVar (N.kn_tent, Ty.T_Kernel(Ty.DiffConst 0))
378 : jhr 1116 (* kernels with false claims of differentiability, for pedagogy *)
379 :     val kn_c1tent = monoVar (N.kn_c1tent, Ty.T_Kernel(Ty.DiffConst 1))
380 :     val kn_c2ctmr = monoVar (N.kn_c2ctmr, Ty.T_Kernel(Ty.DiffConst 2))
381 : jhr 79
382 : jhr 1116 (***** internal variables *****)
383 : jhr 406
384 : jhr 1116 (* integer to real conversion *)
385 :     val i2r = monoVar (Atom.atom "$i2r", [Ty.T_Int] --> Ty.realTy)
386 :    
387 :     (* identity matrix *)
388 :     val identity = polyVar (Atom.atom "$id", allNK (fn dv => [] --> matrix(Ty.DimVar dv)))
389 :    
390 :     (* zero tensor *)
391 :     val zero = polyVar (Atom.atom "$zero", all ([SK],
392 :     fn [Ty.SHAPE dd] => [] --> Ty.T_Tensor(Ty.ShapeVar dd)))
393 :    
394 : jhr 79 end (* local *)
395 :     end

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0