Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] Annotation of /branches/charisee/src/compiler/basis/basis-vars.sml
ViewVC logotype

Annotation of /branches/charisee/src/compiler/basis/basis-vars.sml

Parent Directory Parent Directory | Revision Log Revision Log


Revision 2608 - (view) (download)

1 : jhr 79 (* basis-vars.sml
2 :     *
3 : jhr 435 * COPYRIGHT (c) 2010 The Diderot Project (http://diderot-language.cs.uchicago.edu)
4 : jhr 79 * All rights reserved.
5 :     *
6 :     * This module defines the AST variables for the built in operators and functions.
7 :     *)
8 :    
9 :     structure BasisVars =
10 :     struct
11 :     local
12 :     structure N = BasisNames
13 :     structure Ty = Types
14 :     structure MV = MetaVar
15 :    
16 : jhr 81 fun --> (tys1, ty) = Ty.T_Fun(tys1, ty)
17 : jhr 79 infix -->
18 :    
19 :     val N2 = Ty.DimConst 2
20 :     val N3 = Ty.DimConst 3
21 :    
22 :     (* short names for kinds *)
23 :     val TK : unit -> Ty.meta_var = Ty.TYPE o MV.newTyVar
24 : jhr 470 fun DK () : Ty.meta_var = Ty.DIFF(MV.newDiffVar 0)
25 : jhr 79 val SK : unit -> Ty.meta_var = Ty.SHAPE o MV.newShapeVar
26 :     val NK : unit -> Ty.meta_var = Ty.DIM o MV.newDimVar
27 :    
28 :     fun ty t = ([], t)
29 :     fun all (kinds, mkTy : Ty.meta_var list -> Ty.ty) = let
30 :     val tvs = List.map (fn mk => mk()) kinds
31 :     in
32 :     (tvs, mkTy tvs)
33 :     end
34 :     fun allNK mkTy = let
35 :     val tv = MV.newDimVar()
36 :     in
37 :     ([Ty.DIM tv], mkTy tv)
38 :     end
39 :    
40 :     fun field (k, d, dd) = Ty.T_Field{diff=k, dim=d, shape=dd}
41 :     fun tensor ds = Ty.T_Tensor(Ty.Shape ds)
42 : jhr 1116 fun matrix d = tensor[d,d]
43 : jhr 79
44 :     fun monoVar (name, ty) = Var.new (name, AST.BasisVar, ty)
45 :     fun polyVar (name, scheme) = Var.newPoly (name, AST.BasisVar, scheme)
46 :     in
47 :    
48 :     (* TODO: I'm not sure how to extend + and - to fields, since the typing rules should allow
49 :     * two fields with different differentiation levels to be added.
50 :     *)
51 :    
52 :     (* overloaded operators; the naming convention is to use the operator name followed
53 :     * by the argument type signature, where
54 :     * i -- int
55 :     * b -- bool
56 :     * r -- real (tensor[])
57 :     * t -- tensor[shape]
58 : jhr 470 * f -- field#k(d)[shape]
59 : jhr 79 *)
60 :    
61 :     val add_ii = monoVar(N.op_add, [Ty.T_Int, Ty.T_Int] --> Ty.T_Int)
62 :     val add_tt = polyVar(N.op_add, all([SK], fn [Ty.SHAPE dd] => let
63 :     val t = Ty.T_Tensor(Ty.ShapeVar dd)
64 :     in
65 :     [t, t] --> t
66 :     end))
67 : jhr 470 val add_ff = polyVar(N.op_add, all([DK,NK,SK],
68 :     fn [Ty.DIFF k, Ty.DIM d, Ty.SHAPE dd] => let
69 :     val t = Ty.T_Field{diff = Ty.DiffVar(k, 0), dim = Ty.DimVar d, shape = Ty.ShapeVar dd}
70 :     in
71 :     [t, t] --> t
72 :     end))
73 : jhr 2356 val add_fr = polyVar(N.op_add, all([DK,NK], (* field + scalar *)
74 :     fn [Ty.DIFF k, Ty.DIM d] => let
75 :     val t = Ty.T_Field{diff = Ty.DiffVar(k, 0), dim = Ty.DimVar d, shape = Ty.Shape[]}
76 :     in
77 :     [t, Ty.realTy] --> t
78 :     end))
79 :     val add_rf = polyVar(N.op_add, all([DK,NK], (* scalar + field *)
80 :     fn [Ty.DIFF k, Ty.DIM d] => let
81 :     val t = Ty.T_Field{diff = Ty.DiffVar(k, 0), dim = Ty.DimVar d, shape = Ty.Shape[]}
82 :     in
83 :     [Ty.realTy, t] --> t
84 :     end))
85 : jhr 79
86 :     val sub_ii = monoVar(N.op_sub, [Ty.T_Int, Ty.T_Int] --> Ty.T_Int)
87 :     val sub_tt = polyVar(N.op_sub, all([SK], fn [Ty.SHAPE dd] => let
88 :     val t = Ty.T_Tensor(Ty.ShapeVar dd)
89 :     in
90 :     [t, t] --> t
91 :     end))
92 : jhr 470 val sub_ff = polyVar(N.op_sub, all([DK,NK,SK],
93 :     fn [Ty.DIFF k, Ty.DIM d, Ty.SHAPE dd] => let
94 :     val t = Ty.T_Field{diff = Ty.DiffVar(k, 0), dim = Ty.DimVar d, shape = Ty.ShapeVar dd}
95 :     in
96 :     [t, t] --> t
97 :     end))
98 : jhr 2356 val sub_fr = polyVar(N.op_sub, all([DK,NK], (* field - scalar *)
99 :     fn [Ty.DIFF k, Ty.DIM d] => let
100 :     val t = Ty.T_Field{diff = Ty.DiffVar(k, 0), dim = Ty.DimVar d, shape = Ty.Shape[]}
101 :     in
102 :     [t, Ty.realTy] --> t
103 :     end))
104 :     val sub_rf = polyVar(N.op_sub, all([DK,NK], (* scalar - field *)
105 :     fn [Ty.DIFF k, Ty.DIM d] => let
106 :     val t = Ty.T_Field{diff = Ty.DiffVar(k, 0), dim = Ty.DimVar d, shape = Ty.Shape[]}
107 :     in
108 :     [Ty.realTy, t] --> t
109 :     end))
110 : jhr 79
111 :     (* note that we assume that operators are tested in the order defined here, so that mul_rr
112 :     * takes precedence over mul_rt and mul_tr!
113 :     *)
114 :     val mul_ii = monoVar(N.op_mul, [Ty.T_Int, Ty.T_Int] --> Ty.T_Int)
115 :     val mul_rr = monoVar(N.op_mul, [Ty.realTy, Ty.realTy] --> Ty.realTy)
116 :     val mul_rt = polyVar(N.op_mul, all([SK], fn [Ty.SHAPE dd] => let
117 :     val t = Ty.T_Tensor(Ty.ShapeVar dd)
118 :     in
119 :     [Ty.realTy, t] --> t
120 :     end))
121 :     val mul_tr = polyVar(N.op_mul, all([SK], fn [Ty.SHAPE dd] => let
122 :     val t = Ty.T_Tensor(Ty.ShapeVar dd)
123 :     in
124 :     [t, Ty.realTy] --> t
125 :     end))
126 : jhr 470 val mul_rf = polyVar(N.op_mul, all([DK,NK,SK],
127 :     fn [Ty.DIFF k, Ty.DIM d, Ty.SHAPE dd] => let
128 :     val t = Ty.T_Field{diff = Ty.DiffVar(k, 0), dim = Ty.DimVar d, shape = Ty.ShapeVar dd}
129 :     in
130 :     [Ty.realTy, t] --> t
131 :     end))
132 :     val mul_fr = polyVar(N.op_mul, all([DK,NK,SK],
133 :     fn [Ty.DIFF k, Ty.DIM d, Ty.SHAPE dd] => let
134 :     val t = Ty.T_Field{diff = Ty.DiffVar(k, 0), dim = Ty.DimVar d, shape = Ty.ShapeVar dd}
135 :     in
136 :     [t, Ty.realTy] --> t
137 :     end))
138 : jhr 79
139 : cchiw 2576
140 :     (****MARK *)
141 :     val mul_ss = polyVar(N.op_mul, all([DK,NK],
142 :     fn [Ty.DIFF k, Ty.DIM d] => let
143 :     val t = Ty.T_Field{diff = Ty.DiffVar(k, 0), dim = Ty.DimVar d, shape = Ty.Shape []}
144 :     in
145 :     [t, t] --> t
146 :     end))
147 :     (*Right now assumes same level of k*)
148 :     val mul_sf = polyVar(N.op_mul, all([DK,NK,SK],
149 :     fn [Ty.DIFF k, Ty.DIM d,Ty.SHAPE dd] => let
150 :     val a = Ty.T_Field{diff = Ty.DiffVar(k, 0), dim = Ty.DimVar d, shape = Ty.Shape []}
151 :     val b = Ty.T_Field{diff = Ty.DiffVar(k, 0), dim = Ty.DimVar d, shape = Ty.ShapeVar dd}
152 :     in
153 :     [a,b] --> b
154 :     end))
155 :    
156 :     val mul_fs = polyVar(N.op_mul, all([DK,NK,SK],
157 :     fn [Ty.DIFF k, Ty.DIM d,Ty.SHAPE dd] => let
158 :     val a = Ty.T_Field{diff = Ty.DiffVar(k, 0), dim = Ty.DimVar d, shape = Ty.Shape []}
159 :     val b = Ty.T_Field{diff = Ty.DiffVar(k, 0), dim = Ty.DimVar d, shape = Ty.ShapeVar dd}
160 :     in
161 :     [b,a] --> b
162 :     end))
163 :    
164 :    
165 : jhr 79 val div_ii = monoVar(N.op_div, [Ty.T_Int, Ty.T_Int] --> Ty.T_Int)
166 :     val div_rr = monoVar(N.op_div, [Ty.realTy, Ty.realTy] --> Ty.realTy)
167 :     val div_tr = polyVar(N.op_div, all([SK], fn [Ty.SHAPE dd] => let
168 :     val t = Ty.T_Tensor(Ty.ShapeVar dd)
169 :     in
170 :     [t, Ty.realTy] --> t
171 :     end))
172 : jhr 470 val div_fr = polyVar(N.op_div, all([DK,NK,SK],
173 :     fn [Ty.DIFF k, Ty.DIM d, Ty.SHAPE dd] => let
174 :     val t = Ty.T_Field{diff = Ty.DiffVar(k, 0), dim = Ty.DimVar d, shape = Ty.ShapeVar dd}
175 :     in
176 :     [t, Ty.realTy] --> t
177 :     end))
178 : jhr 79
179 : cchiw 2608 val div_ss = polyVar(N.op_mul, all([DK,NK],
180 :     fn [Ty.DIFF k, Ty.DIM d] => let
181 :     val t = Ty.T_Field{diff = Ty.DiffVar(k, 0), dim = Ty.DimVar d, shape = Ty.Shape []}
182 :     in
183 :     [t, t] --> t
184 :     end))
185 :    
186 :    
187 : jhr 1116 (* exponentiation; we distinguish between integer and real exponents to allow x^2 to be compiled
188 :     * as x*x.
189 :     *)
190 :     val exp_ri = monoVar(N.op_exp, [Ty.realTy, Ty.T_Int] --> Ty.realTy)
191 :     val exp_rr = monoVar(N.op_exp, [Ty.realTy, Ty.realTy] --> Ty.realTy)
192 :    
193 :     val convolve_vk = polyVar (N.op_convolve, all([DK, NK, SK],
194 :     fn [Ty.DIFF k, Ty.DIM d, Ty.SHAPE dd] => let
195 :     val k = Ty.DiffVar(k, 0)
196 :     val d = Ty.DimVar d
197 :     val dd = Ty.ShapeVar dd
198 :     in
199 :     [Ty.T_Image{dim=d, shape=dd}, Ty.T_Kernel k]
200 :     --> field(k, d, dd)
201 :     end))
202 :     val convolve_kv = polyVar (N.op_convolve, all([DK, NK, SK],
203 :     fn [Ty.DIFF k, Ty.DIM d, Ty.SHAPE dd] => let
204 :     val k = Ty.DiffVar(k, 0)
205 :     val d = Ty.DimVar d
206 :     val dd = Ty.ShapeVar dd
207 :     in
208 :     [Ty.T_Kernel k, Ty.T_Image{dim=d, shape=dd}]
209 :     --> field(k, d, dd)
210 :     end))
211 :    
212 : jhr 2356 (* curl on 2d and 3d vector fields *)
213 :     local
214 :     val diff0 = Ty.DiffConst 0
215 :     fun field' (k, d, dd) = field(k, Ty.DimConst d, Ty.Shape(List.map Ty.DimConst dd))
216 :     in
217 :     (* FIXME: we want to be able to require that k > 0, but we don't have a way to do that! *)
218 :     val curl2D = polyVar (N.op_curl, all([DK],
219 :     fn [Ty.DIFF k] =>
220 :     [field' (Ty.DiffVar(k, 0), 2, [2])] --> field' (diff0, 2, [])))
221 :     val curl3D = polyVar (N.op_curl, all([DK],
222 :     fn [Ty.DIFF k] =>
223 :     [field' (Ty.DiffVar(k, 0), 3, [3])] --> field' (diff0, 2, [3])))
224 :     end (* local *)
225 :    
226 : jhr 79 val lt_ii = monoVar(N.op_lt, [Ty.T_Int, Ty.T_Int] --> Ty.T_Bool)
227 :     val lt_rr = monoVar(N.op_lt, [Ty.realTy, Ty.realTy] --> Ty.T_Bool)
228 :     val lte_ii = monoVar(N.op_lte, [Ty.T_Int, Ty.T_Int] --> Ty.T_Bool)
229 :     val lte_rr = monoVar(N.op_lte, [Ty.realTy, Ty.realTy] --> Ty.T_Bool)
230 :     val gte_ii = monoVar(N.op_gte, [Ty.T_Int, Ty.T_Int] --> Ty.T_Bool)
231 :     val gte_rr = monoVar(N.op_gte, [Ty.realTy, Ty.realTy] --> Ty.T_Bool)
232 :     val gt_ii = monoVar(N.op_gt, [Ty.T_Int, Ty.T_Int] --> Ty.T_Bool)
233 :     val gt_rr = monoVar(N.op_gt, [Ty.realTy, Ty.realTy] --> Ty.T_Bool)
234 :    
235 :     val equ_bb = monoVar(N.op_equ, [Ty.T_Bool, Ty.T_Bool] --> Ty.T_Bool)
236 :     val equ_ii = monoVar(N.op_equ, [Ty.T_Int, Ty.T_Int] --> Ty.T_Bool)
237 :     val equ_ss = monoVar(N.op_equ, [Ty.T_String, Ty.T_String] --> Ty.T_Bool)
238 :     val equ_rr = monoVar(N.op_equ, [Ty.realTy, Ty.realTy] --> Ty.T_Bool)
239 :     val neq_bb = monoVar(N.op_neq, [Ty.T_Bool, Ty.T_Bool] --> Ty.T_Bool)
240 :     val neq_ii = monoVar(N.op_neq, [Ty.T_Int, Ty.T_Int] --> Ty.T_Bool)
241 :     val neq_ss = monoVar(N.op_neq, [Ty.T_String, Ty.T_String] --> Ty.T_Bool)
242 :     val neq_rr = monoVar(N.op_neq, [Ty.realTy, Ty.realTy] --> Ty.T_Bool)
243 :    
244 :    
245 : jhr 1640 val neg_i = monoVar(N.op_neg, [Ty.T_Int] --> Ty.T_Int)
246 : jhr 79 val neg_t = polyVar(N.op_neg, all([SK],
247 :     fn [Ty.SHAPE dd] => let
248 :     val t = Ty.T_Tensor(Ty.ShapeVar dd)
249 :     in
250 :     [t] --> t
251 :     end))
252 :     val neg_f = polyVar(N.op_neg, all([DK, NK, SK],
253 :     fn [Ty.DIFF k, Ty.DIM d, Ty.SHAPE dd] => let
254 :     val k = Ty.DiffVar(k, 0)
255 :     val d = Ty.DimVar d
256 :     val dd = Ty.ShapeVar dd
257 :     in
258 :     [field(k, d, dd)] --> field(k, d, dd)
259 :     end))
260 :    
261 : jhr 1295 (* clamp is overloaded at scalars and vectors *)
262 :     val clamp_rrr = monoVar(N.fn_clamp, [Ty.realTy, Ty.realTy, Ty.realTy] --> Ty.realTy)
263 :     val clamp_vvv = polyVar (N.fn_clamp, allNK(fn tv => let
264 :     val t = tensor[Ty.DimVar tv]
265 :     in
266 :     [t, t, t] --> t
267 :     end))
268 :    
269 : jhr 1116 val lerp3 = polyVar(N.fn_lerp, all([SK],
270 :     fn [Ty.SHAPE dd] => let
271 :     val t = Ty.T_Tensor(Ty.ShapeVar dd)
272 :     in
273 :     [t, t, Ty.realTy] --> t
274 :     end))
275 :     val lerp5 = polyVar(N.fn_lerp, all([SK],
276 :     fn [Ty.SHAPE dd] => let
277 :     val t = Ty.T_Tensor(Ty.ShapeVar dd)
278 :     in
279 :     [t, t, Ty.realTy, Ty.realTy, Ty.realTy] --> t
280 :     end))
281 : jhr 79
282 : jhr 1640 (* Eigenvalues/vectors of a matrix; we only support this operation on 2x2 and 3x3 matrices, so
283 :     * we overload the function.
284 :     *)
285 :     local
286 :     fun evals d = monoVar (N.fn_evals, [matrix d] --> Ty.T_Sequence(Ty.realTy, d))
287 :     fun evecs d = monoVar (N.fn_evecs, [matrix d] --> Ty.T_Sequence(tensor[d], d))
288 :     in
289 :     val evals2x2 = evals(Ty.DimConst 2)
290 :     val evecs2x2 = evecs(Ty.DimConst 2)
291 :     val evals3x3 = evals(Ty.DimConst 3)
292 :     val evecs3x3 = evecs(Ty.DimConst 3)
293 :     end
294 : jhr 1296
295 : jhr 1640
296 : jhr 79 (***** non-overloaded operators, etc. *****)
297 :    
298 : jhr 1923 (* C math functions *)
299 :     val mathFns : (MathFuns.name * Var.var) list = let
300 :     fun ty n = List.tabulate(MathFuns.arity n, fn _ => Ty.realTy) --> Ty.realTy
301 :     in
302 :     List.map (fn n => (n, monoVar(MathFuns.toAtom n, ty n))) MathFuns.allFuns
303 :     end
304 :    
305 :     (* pseudo-operator for probing a field *)
306 :     val op_probe = polyVar (N.op_at, all([DK, NK, SK],
307 : jhr 79 fn [Ty.DIFF k, Ty.DIM d, Ty.SHAPE dd] => let
308 :     val k = Ty.DiffVar(k, 0)
309 :     val d = Ty.DimVar d
310 :     val dd = Ty.ShapeVar dd
311 :     in
312 :     [field(k, d, dd), tensor[d]] --> Ty.T_Tensor dd
313 :     end))
314 :    
315 : jhr 1383 (* differentiation of scalar fields *)
316 :     val op_D = polyVar (N.op_D, all([DK, NK],
317 :     fn [Ty.DIFF k, Ty.DIM d] => let
318 : jhr 79 val k0 = Ty.DiffVar(k, 0)
319 :     val km1 = Ty.DiffVar(k, ~1)
320 :     val d = Ty.DimVar d
321 : jhr 1383 in
322 :     [field(k0, d, Ty.Shape[])]
323 :     --> field(km1, d, Ty.Shape[d])
324 :     end))
325 : cchiw 2585 (* differentiation of higher-order tensor fields *)
326 : jhr 1383 val op_Dotimes = polyVar (N.op_Dotimes, all([DK, NK, SK, NK],
327 :     fn [Ty.DIFF k, Ty.DIM d, Ty.SHAPE dd, Ty.DIM d'] => let
328 :     val k0 = Ty.DiffVar(k, 0)
329 :     val km1 = Ty.DiffVar(k, ~1)
330 :     val d = Ty.DimVar d
331 :     val d' = Ty.DimVar d'
332 : jhr 79 val dd = Ty.ShapeVar dd
333 :     in
334 : jhr 1383 [field(k0, d, Ty.ShapeExt(dd, d'))]
335 :     --> field(km1, d, Ty.ShapeExt(Ty.ShapeExt(dd, d'), d))
336 : jhr 79 end))
337 :    
338 : cchiw 2515
339 :    
340 : cchiw 2585 (* divergence differentiation of higher-order tensor fields *)
341 :     val op_Ddot = polyVar (N.op_Ddot, all([DK, NK, SK, NK],
342 :     fn [Ty.DIFF k, Ty.DIM d, Ty.SHAPE dd, Ty.DIM d'] => let
343 :     val k0 = Ty.DiffVar(k, 0)
344 :     val km1 = Ty.DiffVar(k, ~1)
345 :     val d = Ty.DimVar d
346 :     val d' = Ty.DimVar d'
347 :     val ddK = Ty.ShapeVar dd
348 :     in
349 :     [field(k0, d, Ty.ShapeExt(ddK, d'))]
350 :     --> field(k0, d, Ty.ShapeVar dd)
351 :     end))
352 : cchiw 2522
353 : jhr 79 val op_norm = polyVar (N.op_norm, all([SK],
354 :     fn [Ty.SHAPE dd] => [Ty.T_Tensor(Ty.ShapeVar dd)] --> Ty.realTy))
355 :    
356 :     val op_not = monoVar (N.op_not, [Ty.T_Bool] --> Ty.T_Bool)
357 :    
358 :     (* functions *)
359 : jhr 1116 local
360 :     val crossTy = let
361 :     val t = tensor[N3]
362 :     in
363 :     [t, t] --> t
364 :     end
365 :     in
366 :     val op_cross = monoVar (N.op_cross, crossTy)
367 :     end
368 :    
369 : cchiw 2603
370 :    
371 :     val op_crossField = polyVar (N.op_cross, all([DK],
372 :     fn [Ty.DIFF k] => let
373 :     fun field' (k, d, dd) = field(k, Ty.DimConst d, Ty.Shape(List.map Ty.DimConst dd))
374 :     (*Hard-coded 2, should be 3 *)
375 :     val f = field' (Ty.DiffVar(k, 0), 2, [2])
376 :     in
377 :     [f, f] --> f
378 :     end))
379 :    
380 : jhr 1116 (* the inner product operator (including dot product) is treated as a special case in the
381 :     * typechecker. It is not included in the basis environment, but we define its type scheme
382 :     * here. There is an implicit constraint on its type to have the following scheme:
383 :     *
384 :     * ALL[sigma1, d1, sigma2] . tensor[sigma1, d1] * tensor[d1, sigma2] -> tensor[sigma1, sigma2]
385 :     *)
386 :     val op_inner = polyVar (N.op_dot, all([SK, SK, SK],
387 :     fn [Ty.SHAPE s1, Ty.SHAPE s2, Ty.SHAPE s3] =>
388 :     [Ty.T_Tensor(Ty.ShapeVar s1), Ty.T_Tensor(Ty.ShapeVar s2)]
389 :     --> Ty.T_Tensor(Ty.ShapeVar s3)))
390 :    
391 : cchiw 2584
392 : cchiw 2585 val op_innerField = polyVar (N.op_dot, all([DK, SK,NK, SK,SK],
393 :     fn [Ty.DIFF k,Ty.SHAPE dd1, Ty.DIM d, Ty.SHAPE dd2,Ty.SHAPE dd3] => let
394 : cchiw 2606
395 : cchiw 2585 val k0=Ty.DiffVar(k, 0)
396 :     val d' = Ty.DimVar d
397 : cchiw 2606
398 :    
399 : cchiw 2585 val _ = print(String.concat["\n Basis Var Inner product Field. "])
400 : cchiw 2584
401 : cchiw 2606 val t1 = Ty.T_Field{diff = k0, dim = Ty.DimVar d, shape = Ty.ShapeVar dd1}
402 :     val t2 = Ty.T_Field{diff = k0, dim = Ty.DimVar d, shape = Ty.ShapeVar dd2}
403 :     val t3 = Ty.T_Field{diff = k0, dim = Ty.DimVar d, shape = Ty.ShapeVar dd3}
404 :    
405 : cchiw 2585 (*
406 :     val f = field(k0, d', Ty.Shape[ Ty.DimConst 2])
407 : cchiw 2584
408 : cchiw 2606 val h = field(Ty.DiffConst(1), Ty.DimConst 2 , Ty.Shape[])
409 :    
410 :     val t = Ty.T_Field{diff = Ty.DiffVar(k, 0), dim = Ty.DimVar d, shape = Ty.Shape []}*)
411 : cchiw 2585 in
412 : cchiw 2584
413 : cchiw 2606 [t1,t2] --> t3
414 : cchiw 2585 end))
415 : cchiw 2584
416 :    
417 : cchiw 2585
418 :    
419 :    
420 : jhr 2356 (* the colon (or double-dot) product operator is treated as a special case in the
421 :     * typechecker. It is not included in the basis environment, but we define its type
422 :     * schemehere. There is an implicit constraint on its type to have the following scheme:
423 :     *
424 :     * ALL[sigma1, d1, d2, sigma2] .
425 :     * tensor[sigma1, d1, d2] * tensor[d1, d2, sigma2] -> tensor[sigma1, sigma2]
426 :     *)
427 :     val op_colon = polyVar (N.op_colon, all([SK, SK, SK],
428 :     fn [Ty.SHAPE s1, Ty.SHAPE s2, Ty.SHAPE s3] =>
429 :     [Ty.T_Tensor(Ty.ShapeVar s1), Ty.T_Tensor(Ty.ShapeVar s2)]
430 :     --> Ty.T_Tensor(Ty.ShapeVar s3)))
431 :    
432 : jhr 2492 (* load image from nrrd *)
433 :     val fn_image = polyVar (N.fn_image, all([NK, SK],
434 :     fn [Ty.DIM d, Ty.SHAPE dd] => let
435 :     val d = Ty.DimVar d
436 :     val dd = Ty.ShapeVar dd
437 :     in
438 :     [Ty.T_String] --> Ty.T_Image{dim=d, shape=dd}
439 :     end))
440 :    
441 : jhr 143 val fn_inside = polyVar (N.fn_inside, all([DK, NK, SK],
442 : jhr 79 fn [Ty.DIFF k, Ty.DIM d, Ty.SHAPE dd] => let
443 :     val k = Ty.DiffVar(k, 0)
444 :     val d = Ty.DimVar d
445 :     val dd = Ty.ShapeVar dd
446 :     in
447 :     [Ty.T_Tensor(Ty.Shape[d]), field(k, d, dd)]
448 :     --> Ty.T_Bool
449 :     end))
450 :    
451 : jhr 143 val fn_load = polyVar (N.fn_load, all([NK, SK],
452 : jhr 79 fn [Ty.DIM d, Ty.SHAPE dd] => let
453 :     val d = Ty.DimVar d
454 :     val dd = Ty.ShapeVar dd
455 :     in
456 :     [Ty.T_String] --> Ty.T_Image{dim=d, shape=dd}
457 :     end))
458 :    
459 : jhr 143 val fn_max = monoVar (N.fn_max, [Ty.realTy, Ty.realTy] --> Ty.realTy)
460 :     val fn_min = monoVar (N.fn_min, [Ty.realTy, Ty.realTy] --> Ty.realTy)
461 :    
462 : jhr 83 val fn_modulate = polyVar (N.fn_modulate, all([NK],
463 :     fn [Ty.DIM d] => let
464 :     val t = Ty.T_Tensor(Ty.Shape[Ty.DimVar d])
465 :     in
466 :     [t, t] --> t
467 :     end))
468 :    
469 : jhr 1116 val fn_normalize = polyVar (N.fn_normalize, all([NK],
470 :     fn [Ty.DIM d] => let
471 :     val t = Ty.T_Tensor(Ty.Shape[Ty.DimVar d])
472 :     in
473 :     [t] --> t
474 :     end))
475 :    
476 : cchiw 2585 (* outer products*)
477 :    
478 : jhr 1116 local
479 :     fun mkOuter [Ty.DIM d1, Ty.DIM d2] = let
480 :     val vt1 = Ty.T_Tensor(Ty.Shape[Ty.DimVar d1])
481 :     val vt2 = Ty.T_Tensor(Ty.Shape[Ty.DimVar d2])
482 :     val mt = Ty.T_Tensor(Ty.Shape[Ty.DimVar d1, Ty.DimVar d2])
483 : cchiw 2585
484 :    
485 :    
486 :    
487 : jhr 1116 in
488 :     [vt1, vt2] --> mt
489 :     end
490 :     in
491 :     val op_outer = polyVar (N.op_outer, all([NK, NK], mkOuter))
492 :     end
493 :    
494 : cchiw 2584
495 : cchiw 2585
496 :    
497 :    
498 : cchiw 2584 fun field' (k, d, dd) = field(k, Ty.DimConst d, Ty.Shape(List.map Ty.DimConst dd))
499 :    
500 : cchiw 2585 local
501 :     fun mkOuterField [Ty.DIFF k, Ty.DIM d] = let
502 :     val k0=Ty.DiffVar(k, 0)
503 :     val d' = Ty.DimVar d
504 :     val f = field(k0, d', Ty.Shape[d'])
505 :     (* val h = field(k0, d', Ty.Shape[d',d'])*)
506 : cchiw 2603 val h = field(k0, d', Ty.Shape[ d', d'])
507 : cchiw 2585 in
508 :     [f, f] --> h
509 :     end
510 :     in
511 :     val op_outerField = polyVar (N.op_outer, all([DK,NK], mkOuterField))
512 :     end
513 :     (*
514 :     local
515 :     fun mkOuterFieldTen [Ty.DIFF k, Ty.DIM d] = let
516 :     val k0=Ty.DiffVar(k, 0)
517 :     val d' = Ty.DimVar d
518 : cchiw 2584
519 : cchiw 2585 val t = Ty.T_Tensor(Ty.Shape[d'])
520 :     val f = field(k0, d', Ty.Shape[d'])
521 :     val h = field(k0, d', Ty.Shape[d',d'])
522 : cchiw 2584
523 : cchiw 2585 in
524 :     [f, t] --> h
525 : cchiw 2584 end
526 :     in
527 : cchiw 2585 val op_outerFieldTen = polyVar (N.op_outer, all([DK,NK], mkOuterFieldTen))
528 : cchiw 2584 end
529 :     *)
530 :    
531 : cchiw 2585
532 :    
533 :     (* Fields of any shape, and dimension
534 :     fun mkOuterField [Ty.DIFF k, Ty.DIM d, Ty.SHAPE dd1, Ty.SHAPE dd2] = let
535 :     val d1 = Ty.ShapeVar dd1
536 :     val d2 = Ty.ShapeVar dd2
537 :     val k0=Ty.DiffVar(k, 0)
538 :     val d' = Ty.DimVar d
539 :    
540 :     val f = field(k0, d', d1)
541 :     val g = field(k0, d', d2)
542 :     val h = field'(k0, 2, [2,2])
543 :    
544 :    
545 :     in
546 :     [f, g] --> h
547 :     end
548 :     in
549 :     val op_outerField = polyVar (N.op_outer, all([DK,NK, SK,SK], mkOuterField))
550 :     end
551 :    
552 :    
553 :     *)
554 :    
555 : jhr 91 val fn_principleEvec = polyVar (N.fn_principleEvec, all([NK],
556 :     fn [Ty.DIM d] => let
557 :     val d = Ty.DimVar d
558 :     in
559 : jhr 1116 [matrix d] --> tensor[d]
560 : jhr 91 end))
561 : jhr 79
562 : jhr 1116 val fn_trace = polyVar (N.fn_trace, all([NK],
563 : jhr 1640 fn [Ty.DIM d] => [matrix(Ty.DimVar d)] --> Ty.realTy))
564 : jhr 1116
565 : cchiw 2603 local
566 :     fun mktraceField [Ty.DIFF k, Ty.DIM d,Ty.SHAPE dd1] = let
567 :     val k0=Ty.DiffVar(k, 0)
568 :     val d' = Ty.DimVar d
569 :     val d1 = Ty.ShapeVar dd1
570 :     val f = field(k0, d', Ty.ShapeExt(Ty.ShapeExt(d1,d'),d'))
571 :     val h = field(k0, d', d1)
572 :     in
573 :     [f] --> h
574 :     end
575 :     in
576 :     val fn_traceField = polyVar (N.fn_trace, all([DK,NK,SK], mktraceField))
577 :     end
578 :    
579 :    
580 :    
581 : jhr 2356 val fn_transpose = polyVar (N.fn_transpose, all([NK, NK],
582 :     fn [Ty.DIM d1, Ty.DIM d2] =>
583 :     [tensor[Ty.DimVar d1, Ty.DimVar d2]] --> tensor[Ty.DimVar d2, Ty.DimVar d1]))
584 :    
585 : cchiw 2603
586 :     local
587 :     fun mktransposeField [Ty.DIFF k, Ty.DIM d,Ty.DIM a, Ty.DIM b] = let
588 :     val k0=Ty.DiffVar(k, 0)
589 :     val d' = Ty.DimVar d
590 :     val a' = Ty.DimVar a
591 :     val b' = Ty.DimVar b
592 :     val f = field(k0, d', Ty.Shape[a',b'])
593 :     val h = field(k0, d', Ty.Shape[b',a'])
594 :     in
595 :     [f] --> h
596 :     end
597 :     in
598 :     val fn_transposeField = polyVar (N.fn_transpose, all([DK,NK,NK,NK], mktransposeField))
599 :     end
600 :    
601 :    
602 : jhr 79 (* kernels *)
603 : jhr 169 (* FIXME: we should really get the continuity info from the kernels themselves *)
604 : jhr 83 val kn_bspln3 = monoVar (N.kn_bspln3, Ty.T_Kernel(Ty.DiffConst 2))
605 : jhr 169 val kn_bspln5 = monoVar (N.kn_bspln5, Ty.T_Kernel(Ty.DiffConst 4))
606 : jhr 2356 val kn_c4hexic = monoVar (N.kn_c4hexic, Ty.T_Kernel(Ty.DiffConst 4))
607 : jhr 1116 val kn_ctmr = monoVar (N.kn_ctmr, Ty.T_Kernel(Ty.DiffConst 1))
608 : jhr 83 val kn_tent = monoVar (N.kn_tent, Ty.T_Kernel(Ty.DiffConst 0))
609 : jhr 1116 (* kernels with false claims of differentiability, for pedagogy *)
610 :     val kn_c1tent = monoVar (N.kn_c1tent, Ty.T_Kernel(Ty.DiffConst 1))
611 :     val kn_c2ctmr = monoVar (N.kn_c2ctmr, Ty.T_Kernel(Ty.DiffConst 2))
612 : jhr 79
613 : jhr 1116 (***** internal variables *****)
614 : jhr 406
615 : jhr 1116 (* integer to real conversion *)
616 :     val i2r = monoVar (Atom.atom "$i2r", [Ty.T_Int] --> Ty.realTy)
617 :    
618 :     (* identity matrix *)
619 :     val identity = polyVar (Atom.atom "$id", allNK (fn dv => [] --> matrix(Ty.DimVar dv)))
620 :    
621 :     (* zero tensor *)
622 :     val zero = polyVar (Atom.atom "$zero", all ([SK],
623 :     fn [Ty.SHAPE dd] => [] --> Ty.T_Tensor(Ty.ShapeVar dd)))
624 :    
625 : jhr 1640 (* sequence subscript *)
626 :     val subscript = polyVar (Atom.atom "$sub", all ([TK, NK],
627 :     fn [Ty.TYPE tv, Ty.DIM d] =>
628 :     [Ty.T_Sequence(Ty.T_Var tv, Ty.DimVar d), Ty.T_Int] --> Ty.T_Var tv))
629 : jhr 79 end (* local *)
630 :     end

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0