Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] Annotation of /branches/charisee/src/compiler/basis/basis-vars.sml
ViewVC logotype

Annotation of /branches/charisee/src/compiler/basis/basis-vars.sml

Parent Directory Parent Directory | Revision Log Revision Log


Revision 3055 - (view) (download)

1 : jhr 79 (* basis-vars.sml
2 :     *
3 : jhr 435 * COPYRIGHT (c) 2010 The Diderot Project (http://diderot-language.cs.uchicago.edu)
4 : jhr 79 * All rights reserved.
5 :     *
6 :     * This module defines the AST variables for the built in operators and functions.
7 :     *)
8 :    
9 :     structure BasisVars =
10 :     struct
11 :     local
12 :     structure N = BasisNames
13 :     structure Ty = Types
14 :     structure MV = MetaVar
15 :    
16 : jhr 81 fun --> (tys1, ty) = Ty.T_Fun(tys1, ty)
17 : jhr 79 infix -->
18 :    
19 :     val N2 = Ty.DimConst 2
20 :     val N3 = Ty.DimConst 3
21 :    
22 :     (* short names for kinds *)
23 :     val TK : unit -> Ty.meta_var = Ty.TYPE o MV.newTyVar
24 : jhr 470 fun DK () : Ty.meta_var = Ty.DIFF(MV.newDiffVar 0)
25 : jhr 79 val SK : unit -> Ty.meta_var = Ty.SHAPE o MV.newShapeVar
26 :     val NK : unit -> Ty.meta_var = Ty.DIM o MV.newDimVar
27 :    
28 :     fun ty t = ([], t)
29 :     fun all (kinds, mkTy : Ty.meta_var list -> Ty.ty) = let
30 :     val tvs = List.map (fn mk => mk()) kinds
31 :     in
32 :     (tvs, mkTy tvs)
33 :     end
34 :     fun allNK mkTy = let
35 :     val tv = MV.newDimVar()
36 :     in
37 :     ([Ty.DIM tv], mkTy tv)
38 :     end
39 :    
40 :     fun field (k, d, dd) = Ty.T_Field{diff=k, dim=d, shape=dd}
41 :     fun tensor ds = Ty.T_Tensor(Ty.Shape ds)
42 : jhr 1116 fun matrix d = tensor[d,d]
43 : jhr 79
44 :     fun monoVar (name, ty) = Var.new (name, AST.BasisVar, ty)
45 :     fun polyVar (name, scheme) = Var.newPoly (name, AST.BasisVar, scheme)
46 :     in
47 :    
48 :     (* TODO: I'm not sure how to extend + and - to fields, since the typing rules should allow
49 :     * two fields with different differentiation levels to be added.
50 :     *)
51 :    
52 :     (* overloaded operators; the naming convention is to use the operator name followed
53 :     * by the argument type signature, where
54 :     * i -- int
55 :     * b -- bool
56 :     * r -- real (tensor[])
57 :     * t -- tensor[shape]
58 : jhr 470 * f -- field#k(d)[shape]
59 : jhr 2926 * s -- field#k(d)[]
60 : jhr 79 *)
61 :    
62 :     val add_ii = monoVar(N.op_add, [Ty.T_Int, Ty.T_Int] --> Ty.T_Int)
63 :     val add_tt = polyVar(N.op_add, all([SK], fn [Ty.SHAPE dd] => let
64 :     val t = Ty.T_Tensor(Ty.ShapeVar dd)
65 :     in
66 :     [t, t] --> t
67 :     end))
68 : jhr 470 val add_ff = polyVar(N.op_add, all([DK,NK,SK],
69 :     fn [Ty.DIFF k, Ty.DIM d, Ty.SHAPE dd] => let
70 : cchiw 2906 val f = Ty.T_Field{diff = Ty.DiffVar(k, 0), dim = Ty.DimVar d, shape = Ty.ShapeVar dd}
71 : jhr 470 in
72 : cchiw 2906 [f, f] --> f
73 : jhr 470 end))
74 : jhr 79
75 : cchiw 2906 val add_ft = polyVar(N.op_add, all([DK,NK,SK], (* field + scalar *)
76 :     fn [Ty.DIFF k, Ty.DIM d,Ty.SHAPE dd] => let
77 :     val f = Ty.T_Field{diff = Ty.DiffVar(k, 0), dim = Ty.DimVar d, shape = Ty.ShapeVar dd}
78 :     val t = Ty.T_Tensor(Ty.ShapeVar dd)
79 :     in
80 :     [f, t] --> f
81 :     end))
82 :    
83 :     val add_tf = polyVar(N.op_add, all([DK,NK,SK], (* scalar + field *)
84 :     fn [Ty.DIFF k, Ty.DIM d,Ty.SHAPE dd] => let
85 :     val f = Ty.T_Field{diff = Ty.DiffVar(k, 0), dim = Ty.DimVar d, shape = Ty.ShapeVar dd}
86 :     val t = Ty.T_Tensor(Ty.ShapeVar dd)
87 :     in
88 :     [t, f] --> f
89 :     end))
90 :    
91 : jhr 79 val sub_ii = monoVar(N.op_sub, [Ty.T_Int, Ty.T_Int] --> Ty.T_Int)
92 :     val sub_tt = polyVar(N.op_sub, all([SK], fn [Ty.SHAPE dd] => let
93 :     val t = Ty.T_Tensor(Ty.ShapeVar dd)
94 :     in
95 :     [t, t] --> t
96 :     end))
97 : jhr 470 val sub_ff = polyVar(N.op_sub, all([DK,NK,SK],
98 :     fn [Ty.DIFF k, Ty.DIM d, Ty.SHAPE dd] => let
99 : cchiw 2906 val f = Ty.T_Field{diff = Ty.DiffVar(k, 0), dim = Ty.DimVar d, shape = Ty.ShapeVar dd}
100 : jhr 470 in
101 : cchiw 2906 [f, f] --> f
102 : jhr 470 end))
103 : cchiw 2906 val sub_ft = polyVar(N.op_add, all([DK,NK,SK], (* field - scalar *)
104 :     fn [Ty.DIFF k, Ty.DIM d,Ty.SHAPE dd] => let
105 :     val f = Ty.T_Field{diff = Ty.DiffVar(k, 0), dim = Ty.DimVar d, shape = Ty.ShapeVar dd}
106 :     val t = Ty.T_Tensor(Ty.ShapeVar dd)
107 :     in
108 :     [f, t] --> f
109 :     end))
110 : jhr 79
111 : cchiw 2906 val sub_tf = polyVar(N.op_add, all([DK,NK,SK], (* scalar - field *)
112 :     fn [Ty.DIFF k, Ty.DIM d,Ty.SHAPE dd] => let
113 :     val f = Ty.T_Field{diff = Ty.DiffVar(k, 0), dim = Ty.DimVar d, shape = Ty.ShapeVar dd}
114 :     val t = Ty.T_Tensor(Ty.ShapeVar dd)
115 :     in
116 :     [t, f] --> f
117 :     end))
118 :    
119 :    
120 : jhr 79 (* note that we assume that operators are tested in the order defined here, so that mul_rr
121 :     * takes precedence over mul_rt and mul_tr!
122 :     *)
123 :     val mul_ii = monoVar(N.op_mul, [Ty.T_Int, Ty.T_Int] --> Ty.T_Int)
124 :     val mul_rr = monoVar(N.op_mul, [Ty.realTy, Ty.realTy] --> Ty.realTy)
125 :     val mul_rt = polyVar(N.op_mul, all([SK], fn [Ty.SHAPE dd] => let
126 :     val t = Ty.T_Tensor(Ty.ShapeVar dd)
127 :     in
128 :     [Ty.realTy, t] --> t
129 :     end))
130 :     val mul_tr = polyVar(N.op_mul, all([SK], fn [Ty.SHAPE dd] => let
131 :     val t = Ty.T_Tensor(Ty.ShapeVar dd)
132 :     in
133 :     [t, Ty.realTy] --> t
134 :     end))
135 : jhr 470 val mul_rf = polyVar(N.op_mul, all([DK,NK,SK],
136 :     fn [Ty.DIFF k, Ty.DIM d, Ty.SHAPE dd] => let
137 :     val t = Ty.T_Field{diff = Ty.DiffVar(k, 0), dim = Ty.DimVar d, shape = Ty.ShapeVar dd}
138 :     in
139 :     [Ty.realTy, t] --> t
140 :     end))
141 :     val mul_fr = polyVar(N.op_mul, all([DK,NK,SK],
142 :     fn [Ty.DIFF k, Ty.DIM d, Ty.SHAPE dd] => let
143 :     val t = Ty.T_Field{diff = Ty.DiffVar(k, 0), dim = Ty.DimVar d, shape = Ty.ShapeVar dd}
144 :     in
145 :     [t, Ty.realTy] --> t
146 :     end))
147 : cchiw 2576 val mul_ss = polyVar(N.op_mul, all([DK,NK],
148 : jhr 2926 fn [Ty.DIFF k, Ty.DIM d] => let
149 :     val t = Ty.T_Field{diff = Ty.DiffVar(k, 0), dim = Ty.DimVar d, shape = Ty.Shape []}
150 :     in
151 :     [t, t] --> t
152 :     end))
153 : cchiw 2576 val mul_sf = polyVar(N.op_mul, all([DK,NK,SK],
154 : jhr 2926 fn [Ty.DIFF k, Ty.DIM d,Ty.SHAPE dd] => let
155 :     val a = Ty.T_Field{diff = Ty.DiffVar(k, 0), dim = Ty.DimVar d, shape = Ty.Shape []}
156 :     val b = Ty.T_Field{diff = Ty.DiffVar(k, 0), dim = Ty.DimVar d, shape = Ty.ShapeVar dd}
157 :     in
158 :     [a,b] --> b
159 :     end))
160 : cchiw 2845 val mul_fs = polyVar(N.op_mul, all([DK,NK,SK],
161 : jhr 2926 fn [Ty.DIFF k, Ty.DIM d,Ty.SHAPE dd] => let
162 :     val a = Ty.T_Field{diff = Ty.DiffVar(k, 0), dim = Ty.DimVar d, shape = Ty.Shape []}
163 :     val b = Ty.T_Field{diff = Ty.DiffVar(k, 0), dim = Ty.DimVar d, shape = Ty.ShapeVar dd}
164 :     in
165 :     [b,a] --> b
166 :     end))
167 : cchiw 2576
168 : jhr 79 val div_ii = monoVar(N.op_div, [Ty.T_Int, Ty.T_Int] --> Ty.T_Int)
169 :     val div_rr = monoVar(N.op_div, [Ty.realTy, Ty.realTy] --> Ty.realTy)
170 :     val div_tr = polyVar(N.op_div, all([SK], fn [Ty.SHAPE dd] => let
171 :     val t = Ty.T_Tensor(Ty.ShapeVar dd)
172 :     in
173 :     [t, Ty.realTy] --> t
174 :     end))
175 : jhr 470 val div_fr = polyVar(N.op_div, all([DK,NK,SK],
176 :     fn [Ty.DIFF k, Ty.DIM d, Ty.SHAPE dd] => let
177 :     val t = Ty.T_Field{diff = Ty.DiffVar(k, 0), dim = Ty.DimVar d, shape = Ty.ShapeVar dd}
178 :     in
179 :     [t, Ty.realTy] --> t
180 :     end))
181 : jhr 2926 val div_ss = polyVar(N.op_mul, all([DK,NK],
182 :     fn [Ty.DIFF k, Ty.DIM d] => let
183 :     val t = Ty.T_Field{diff = Ty.DiffVar(k, 0), dim = Ty.DimVar d, shape = Ty.Shape []}
184 :     in
185 :     [t, t] --> t
186 :     end))
187 : cchiw 2867 val div_fs = polyVar(N.op_div, all([DK,DK,NK,SK],
188 : jhr 2926 fn [Ty.DIFF k,Ty.DIFF k2, Ty.DIM d, Ty.SHAPE dd] => let
189 :     val f = Ty.T_Field{diff = Ty.DiffVar(k, 0), dim = Ty.DimVar d, shape = Ty.ShapeVar dd}
190 :     val s = Ty.T_Field{diff = Ty.DiffVar(k2, 0), dim = Ty.DimVar d, shape = Ty.Shape []}
191 :     in
192 :     [f,s] --> f
193 :     end))
194 : cchiw 2867
195 : jhr 1116 (* exponentiation; we distinguish between integer and real exponents to allow x^2 to be compiled
196 :     * as x*x.
197 :     *)
198 :     val exp_ri = monoVar(N.op_exp, [Ty.realTy, Ty.T_Int] --> Ty.realTy)
199 :     val exp_rr = monoVar(N.op_exp, [Ty.realTy, Ty.realTy] --> Ty.realTy)
200 :    
201 :     val convolve_vk = polyVar (N.op_convolve, all([DK, NK, SK],
202 :     fn [Ty.DIFF k, Ty.DIM d, Ty.SHAPE dd] => let
203 :     val k = Ty.DiffVar(k, 0)
204 :     val d = Ty.DimVar d
205 :     val dd = Ty.ShapeVar dd
206 :     in
207 :     [Ty.T_Image{dim=d, shape=dd}, Ty.T_Kernel k]
208 :     --> field(k, d, dd)
209 :     end))
210 :     val convolve_kv = polyVar (N.op_convolve, all([DK, NK, SK],
211 :     fn [Ty.DIFF k, Ty.DIM d, Ty.SHAPE dd] => let
212 :     val k = Ty.DiffVar(k, 0)
213 :     val d = Ty.DimVar d
214 :     val dd = Ty.ShapeVar dd
215 :     in
216 :     [Ty.T_Kernel k, Ty.T_Image{dim=d, shape=dd}]
217 :     --> field(k, d, dd)
218 :     end))
219 :    
220 : jhr 2356 (* curl on 2d and 3d vector fields *)
221 :     local
222 :     val diff0 = Ty.DiffConst 0
223 :     fun field' (k, d, dd) = field(k, Ty.DimConst d, Ty.Shape(List.map Ty.DimConst dd))
224 :     in
225 :     (* FIXME: we want to be able to require that k > 0, but we don't have a way to do that! *)
226 :     val curl2D = polyVar (N.op_curl, all([DK],
227 : jhr 2948 fn [Ty.DIFF k] => let
228 :     val km1 = Ty.DiffVar(k, ~1)
229 :     in
230 : cchiw 2900 [field' (Ty.DiffVar(k, 0), 2, [2])] --> field' (km1, 2, [])
231 : jhr 2948 end))
232 : jhr 2356 val curl3D = polyVar (N.op_curl, all([DK],
233 : jhr 2948 fn [Ty.DIFF k] =>let
234 :     val km1 = Ty.DiffVar(k, ~1)
235 :     in
236 :     [field' (Ty.DiffVar(k, 0), 3, [3])] --> field' (km1, 3, [3])
237 :     end))
238 : jhr 2356 end (* local *)
239 :    
240 : jhr 79 val lt_ii = monoVar(N.op_lt, [Ty.T_Int, Ty.T_Int] --> Ty.T_Bool)
241 :     val lt_rr = monoVar(N.op_lt, [Ty.realTy, Ty.realTy] --> Ty.T_Bool)
242 :     val lte_ii = monoVar(N.op_lte, [Ty.T_Int, Ty.T_Int] --> Ty.T_Bool)
243 :     val lte_rr = monoVar(N.op_lte, [Ty.realTy, Ty.realTy] --> Ty.T_Bool)
244 :     val gte_ii = monoVar(N.op_gte, [Ty.T_Int, Ty.T_Int] --> Ty.T_Bool)
245 :     val gte_rr = monoVar(N.op_gte, [Ty.realTy, Ty.realTy] --> Ty.T_Bool)
246 :     val gt_ii = monoVar(N.op_gt, [Ty.T_Int, Ty.T_Int] --> Ty.T_Bool)
247 :     val gt_rr = monoVar(N.op_gt, [Ty.realTy, Ty.realTy] --> Ty.T_Bool)
248 :     val equ_bb = monoVar(N.op_equ, [Ty.T_Bool, Ty.T_Bool] --> Ty.T_Bool)
249 :     val equ_ii = monoVar(N.op_equ, [Ty.T_Int, Ty.T_Int] --> Ty.T_Bool)
250 :     val equ_ss = monoVar(N.op_equ, [Ty.T_String, Ty.T_String] --> Ty.T_Bool)
251 :     val equ_rr = monoVar(N.op_equ, [Ty.realTy, Ty.realTy] --> Ty.T_Bool)
252 :     val neq_bb = monoVar(N.op_neq, [Ty.T_Bool, Ty.T_Bool] --> Ty.T_Bool)
253 :     val neq_ii = monoVar(N.op_neq, [Ty.T_Int, Ty.T_Int] --> Ty.T_Bool)
254 :     val neq_ss = monoVar(N.op_neq, [Ty.T_String, Ty.T_String] --> Ty.T_Bool)
255 :     val neq_rr = monoVar(N.op_neq, [Ty.realTy, Ty.realTy] --> Ty.T_Bool)
256 : jhr 1640 val neg_i = monoVar(N.op_neg, [Ty.T_Int] --> Ty.T_Int)
257 : jhr 79 val neg_t = polyVar(N.op_neg, all([SK],
258 :     fn [Ty.SHAPE dd] => let
259 :     val t = Ty.T_Tensor(Ty.ShapeVar dd)
260 :     in
261 :     [t] --> t
262 :     end))
263 :     val neg_f = polyVar(N.op_neg, all([DK, NK, SK],
264 :     fn [Ty.DIFF k, Ty.DIM d, Ty.SHAPE dd] => let
265 :     val k = Ty.DiffVar(k, 0)
266 :     val d = Ty.DimVar d
267 :     val dd = Ty.ShapeVar dd
268 :     in
269 :     [field(k, d, dd)] --> field(k, d, dd)
270 :     end))
271 :    
272 : jhr 1295 (* clamp is overloaded at scalars and vectors *)
273 :     val clamp_rrr = monoVar(N.fn_clamp, [Ty.realTy, Ty.realTy, Ty.realTy] --> Ty.realTy)
274 :     val clamp_vvv = polyVar (N.fn_clamp, allNK(fn tv => let
275 :     val t = tensor[Ty.DimVar tv]
276 :     in
277 :     [t, t, t] --> t
278 :     end))
279 :    
280 : jhr 1116 val lerp3 = polyVar(N.fn_lerp, all([SK],
281 :     fn [Ty.SHAPE dd] => let
282 :     val t = Ty.T_Tensor(Ty.ShapeVar dd)
283 :     in
284 :     [t, t, Ty.realTy] --> t
285 :     end))
286 :     val lerp5 = polyVar(N.fn_lerp, all([SK],
287 :     fn [Ty.SHAPE dd] => let
288 :     val t = Ty.T_Tensor(Ty.ShapeVar dd)
289 :     in
290 :     [t, t, Ty.realTy, Ty.realTy, Ty.realTy] --> t
291 :     end))
292 : jhr 79
293 : jhr 1640 (* Eigenvalues/vectors of a matrix; we only support this operation on 2x2 and 3x3 matrices, so
294 :     * we overload the function.
295 :     *)
296 :     local
297 :     fun evals d = monoVar (N.fn_evals, [matrix d] --> Ty.T_Sequence(Ty.realTy, d))
298 :     fun evecs d = monoVar (N.fn_evecs, [matrix d] --> Ty.T_Sequence(tensor[d], d))
299 :     in
300 :     val evals2x2 = evals(Ty.DimConst 2)
301 :     val evecs2x2 = evecs(Ty.DimConst 2)
302 :     val evals3x3 = evals(Ty.DimConst 3)
303 :     val evecs3x3 = evecs(Ty.DimConst 3)
304 :     end
305 : jhr 1296
306 : jhr 79 (***** non-overloaded operators, etc. *****)
307 :    
308 : jhr 1923 (* C math functions *)
309 :     val mathFns : (MathFuns.name * Var.var) list = let
310 :     fun ty n = List.tabulate(MathFuns.arity n, fn _ => Ty.realTy) --> Ty.realTy
311 :     in
312 :     List.map (fn n => (n, monoVar(MathFuns.toAtom n, ty n))) MathFuns.allFuns
313 :     end
314 :    
315 :     (* pseudo-operator for probing a field *)
316 :     val op_probe = polyVar (N.op_at, all([DK, NK, SK],
317 : jhr 79 fn [Ty.DIFF k, Ty.DIM d, Ty.SHAPE dd] => let
318 :     val k = Ty.DiffVar(k, 0)
319 :     val d = Ty.DimVar d
320 :     val dd = Ty.ShapeVar dd
321 :     in
322 :     [field(k, d, dd), tensor[d]] --> Ty.T_Tensor dd
323 :     end))
324 :    
325 : jhr 1383 (* differentiation of scalar fields *)
326 :     val op_D = polyVar (N.op_D, all([DK, NK],
327 :     fn [Ty.DIFF k, Ty.DIM d] => let
328 : jhr 79 val k0 = Ty.DiffVar(k, 0)
329 :     val km1 = Ty.DiffVar(k, ~1)
330 :     val d = Ty.DimVar d
331 : jhr 1383 in
332 :     [field(k0, d, Ty.Shape[])]
333 :     --> field(km1, d, Ty.Shape[d])
334 :     end))
335 : cchiw 2845
336 : cchiw 2585 (* differentiation of higher-order tensor fields *)
337 : jhr 1383 val op_Dotimes = polyVar (N.op_Dotimes, all([DK, NK, SK, NK],
338 :     fn [Ty.DIFF k, Ty.DIM d, Ty.SHAPE dd, Ty.DIM d'] => let
339 :     val k0 = Ty.DiffVar(k, 0)
340 :     val km1 = Ty.DiffVar(k, ~1)
341 :     val d = Ty.DimVar d
342 :     val d' = Ty.DimVar d'
343 : jhr 79 val dd = Ty.ShapeVar dd
344 :     in
345 : jhr 1383 [field(k0, d, Ty.ShapeExt(dd, d'))]
346 :     --> field(km1, d, Ty.ShapeExt(Ty.ShapeExt(dd, d'), d))
347 : jhr 79 end))
348 :    
349 : cchiw 2585 (* divergence differentiation of higher-order tensor fields *)
350 :     val op_Ddot = polyVar (N.op_Ddot, all([DK, NK, SK, NK],
351 :     fn [Ty.DIFF k, Ty.DIM d, Ty.SHAPE dd, Ty.DIM d'] => let
352 :     val k0 = Ty.DiffVar(k, 0)
353 :     val km1 = Ty.DiffVar(k, ~1)
354 :     val d = Ty.DimVar d
355 :     val d' = Ty.DimVar d'
356 : jhr 2948 val dd' = Ty.ShapeVar dd
357 : cchiw 2585 in
358 : jhr 2948 [field(k0, d, Ty.ShapeExt(dd', d'))]
359 :     --> field(k0, d, dd')
360 : cchiw 2585 end))
361 : cchiw 2522
362 : jhr 2948 val op_norm_t = polyVar (N.op_norm, all([SK],
363 :     fn [Ty.SHAPE dd] => [Ty.T_Tensor(Ty.ShapeVar dd)] --> Ty.realTy))
364 :     val op_norm_f = polyVar (N.op_norm, all([DK, NK, SK],
365 :     fn [Ty.DIFF k,Ty.DIM d, Ty.SHAPE dd1] => let
366 :     val k = Ty.DiffVar(k, 0)
367 :     val d = Ty.DimVar d
368 :     val f1 = Ty.T_Field{diff = k, dim = d, shape = Ty.ShapeVar dd1}
369 :     val f2 = Ty.T_Field{diff = k, dim = d, shape = Ty.Shape []}
370 :     in
371 :     [f1] --> f2
372 :     end))
373 : jhr 79
374 :     val op_not = monoVar (N.op_not, [Ty.T_Bool] --> Ty.T_Bool)
375 :    
376 :     (* functions *)
377 : jhr 1116 local
378 :     val crossTy = let
379 : jhr 2948 val t = tensor[N3]
380 :     in
381 :     [t, t] --> t
382 :     end
383 :     val crossTy2 = let
384 :     val t = tensor[N2]
385 :     in
386 :     [t, t] --> Ty.realTy
387 :     end
388 : jhr 1116 in
389 : jhr 2948 val op_cross2_tt = monoVar (N.op_cross, crossTy2)
390 :     val op_cross3_tt = monoVar (N.op_cross, crossTy)
391 : jhr 1116 end
392 :    
393 : cchiw 2906 val op_cross2_ff = polyVar (N.op_cross, all([DK],
394 : jhr 2926 fn [Ty.DIFF k] => let
395 :     fun field' (k, d, dd) = field(k, Ty.DimConst d, Ty.Shape(List.map Ty.DimConst dd))
396 : jhr 2948 val k0 = Ty.DiffVar(k, 0)
397 : jhr 2926 val f = field' (k0, 2, [2])
398 :     val t1 = field' (k0, 2, [])
399 :     in
400 :     [f, f] --> t1
401 :     end))
402 : cchiw 2847
403 : cchiw 2906 val op_cross3_ff = polyVar (N.op_cross, all([DK],
404 : jhr 2926 fn [Ty.DIFF k] => let
405 :     fun field' (k, d, dd) = field(k, Ty.DimConst d, Ty.Shape(List.map Ty.DimConst dd))
406 :     val f = field' (Ty.DiffVar(k, 0), 3, [3])
407 :     in
408 :     [f, f] --> f
409 :     end))
410 : cchiw 2603
411 : jhr 1116 (* the inner product operator (including dot product) is treated as a special case in the
412 :     * typechecker. It is not included in the basis environment, but we define its type scheme
413 :     * here. There is an implicit constraint on its type to have the following scheme:
414 :     *
415 :     * ALL[sigma1, d1, sigma2] . tensor[sigma1, d1] * tensor[d1, sigma2] -> tensor[sigma1, sigma2]
416 :     *)
417 : cchiw 2906 val op_inner_tt = polyVar (N.op_dot, all([SK, SK, SK],
418 : jhr 2926 fn [Ty.SHAPE s1, Ty.SHAPE s2, Ty.SHAPE s3] =>
419 :     [Ty.T_Tensor(Ty.ShapeVar s1), Ty.T_Tensor(Ty.ShapeVar s2)]
420 :     --> Ty.T_Tensor(Ty.ShapeVar s3)))
421 : jhr 1116
422 : jhr 2926 val op_inner_tf = polyVar (N.op_dot, all([DK ,NK, SK, SK, SK],
423 :     fn [Ty.DIFF k, Ty.DIM d, Ty.SHAPE dd1, Ty.SHAPE dd2, Ty.SHAPE dd3] => let
424 :     val k = Ty.DiffVar(k, 0)
425 :     val d = Ty.DimVar d
426 :     val t1 = Ty.T_Tensor(Ty.ShapeVar dd1)
427 :     val t2 = Ty.T_Field{diff = k, dim = d, shape = Ty.ShapeVar dd2}
428 :     val t3 = Ty.T_Field{diff = k, dim = d, shape = Ty.ShapeVar dd3}
429 :     in
430 :     [t1, t2] --> t3
431 :     end))
432 : cchiw 2584
433 : jhr 2926 val op_inner_ft = polyVar (N.op_dot, all([DK, NK, SK, SK, SK],
434 :     fn [Ty.DIFF k, Ty.DIM d, Ty.SHAPE dd1, Ty.SHAPE dd2, Ty.SHAPE dd3] => let
435 :     val k = Ty.DiffVar(k, 0)
436 :     val d = Ty.DimVar d
437 :     val t1 = Ty.T_Field{diff = k, dim = d, shape = Ty.ShapeVar dd1}
438 :     val t2 = Ty.T_Tensor(Ty.ShapeVar dd2)
439 :     val t3 = Ty.T_Field{diff = k, dim = d, shape = Ty.ShapeVar dd3}
440 :     in
441 :     [t1, t2] --> t3
442 :     end))
443 : cchiw 2584
444 : cchiw 2955 val op_inner_ff = polyVar (N.op_dot, all([DK,DK, NK, SK, SK, SK],
445 :     fn [Ty.DIFF k1,Ty.DIFF k2, Ty.DIM d, Ty.SHAPE dd1, Ty.SHAPE dd2, Ty.SHAPE dd3] => let
446 :     val k1 = Ty.DiffVar(k1, 0)
447 :     val k2 = Ty.DiffVar(k2, 0)
448 : jhr 2926 val d = Ty.DimVar d
449 : cchiw 2955 val t1 = Ty.T_Field{diff = k1, dim = d, shape = Ty.ShapeVar dd1}
450 :     val t2 = Ty.T_Field{diff = k2, dim = d, shape = Ty.ShapeVar dd2}
451 :     val t3 = Ty.T_Field{diff = k1, dim = d, shape = Ty.ShapeVar dd3}
452 : jhr 2926 in
453 :     [t1, t2] --> t3
454 :     end))
455 : cchiw 2906
456 : jhr 2356 (* the colon (or double-dot) product operator is treated as a special case in the
457 :     * typechecker. It is not included in the basis environment, but we define its type
458 :     * schemehere. There is an implicit constraint on its type to have the following scheme:
459 :     *
460 :     * ALL[sigma1, d1, d2, sigma2] .
461 :     * tensor[sigma1, d1, d2] * tensor[d1, d2, sigma2] -> tensor[sigma1, sigma2]
462 :     *)
463 : cchiw 2906 val op_colon_tt = polyVar (N.op_colon, all([SK, SK, SK],
464 : jhr 2948 fn [Ty.SHAPE s1, Ty.SHAPE s2, Ty.SHAPE s3] =>
465 :     [Ty.T_Tensor(Ty.ShapeVar s1), Ty.T_Tensor(Ty.ShapeVar s2)]
466 :     --> Ty.T_Tensor(Ty.ShapeVar s3)))
467 : cchiw 2906 val op_colon_ff = polyVar (N.op_colon, all([DK, SK,NK,SK,SK],
468 : jhr 2948 fn [Ty.DIFF k,Ty.SHAPE dd1, Ty.DIM d, Ty.SHAPE dd2,Ty.SHAPE dd3] =>let
469 :     val k0 = Ty.DiffVar(k, 0)
470 :     val d' = Ty.DimVar d
471 :     val t1 = Ty.T_Field{diff = k0, dim = d', shape = Ty.ShapeVar dd1}
472 :     val t2 = Ty.T_Field{diff = k0, dim = d', shape = Ty.ShapeVar dd2}
473 :     val t3 = Ty.T_Field{diff = k0, dim = d', shape = Ty.ShapeVar dd3}
474 :     in
475 :     [t1,t2] --> t3
476 :     end))
477 : cchiw 2611
478 : jhr 2492 (* load image from nrrd *)
479 :     val fn_image = polyVar (N.fn_image, all([NK, SK],
480 :     fn [Ty.DIM d, Ty.SHAPE dd] => let
481 :     val d = Ty.DimVar d
482 :     val dd = Ty.ShapeVar dd
483 :     in
484 :     [Ty.T_String] --> Ty.T_Image{dim=d, shape=dd}
485 :     end))
486 :    
487 : jhr 143 val fn_inside = polyVar (N.fn_inside, all([DK, NK, SK],
488 : jhr 79 fn [Ty.DIFF k, Ty.DIM d, Ty.SHAPE dd] => let
489 :     val k = Ty.DiffVar(k, 0)
490 :     val d = Ty.DimVar d
491 :     val dd = Ty.ShapeVar dd
492 :     in
493 :     [Ty.T_Tensor(Ty.Shape[d]), field(k, d, dd)]
494 :     --> Ty.T_Bool
495 :     end))
496 :    
497 : jhr 143 val fn_load = polyVar (N.fn_load, all([NK, SK],
498 : jhr 79 fn [Ty.DIM d, Ty.SHAPE dd] => let
499 :     val d = Ty.DimVar d
500 :     val dd = Ty.ShapeVar dd
501 :     in
502 :     [Ty.T_String] --> Ty.T_Image{dim=d, shape=dd}
503 :     end))
504 :    
505 : jhr 143 val fn_max = monoVar (N.fn_max, [Ty.realTy, Ty.realTy] --> Ty.realTy)
506 :     val fn_min = monoVar (N.fn_min, [Ty.realTy, Ty.realTy] --> Ty.realTy)
507 :    
508 : jhr 83 val fn_modulate = polyVar (N.fn_modulate, all([NK],
509 : jhr 2948 fn [Ty.DIM d] => let
510 :     val t = Ty.T_Tensor(Ty.Shape[Ty.DimVar d])
511 :     in
512 :     [t, t] --> t
513 :     end))
514 : jhr 83
515 : jhr 2948 val fn_normalize_t = polyVar (N.fn_normalize, all([NK],
516 :     fn [Ty.DIM d] => let
517 :     val t = Ty.T_Tensor(Ty.Shape[Ty.DimVar d])
518 :     in
519 :     [t] --> t
520 :     end))
521 :     val fn_normalize_f = polyVar (N.fn_normalize, all([DK,NK, SK],
522 :     fn [Ty.DIFF k, Ty.DIM d, Ty.SHAPE dd1] => let
523 :     val k0 = Ty.DiffVar(k, 0)
524 :     val d' = Ty.DimVar d
525 :     val f1 = Ty.T_Field{diff = k0, dim = d', shape = Ty.ShapeVar dd1}
526 :     in
527 : cchiw 2870 [f1] --> f1
528 : jhr 2948 end))
529 : cchiw 2870
530 : cchiw 2585 (* outer products*)
531 : jhr 1116 local
532 :     fun mkOuter [Ty.DIM d1, Ty.DIM d2] = let
533 :     val vt1 = Ty.T_Tensor(Ty.Shape[Ty.DimVar d1])
534 :     val vt2 = Ty.T_Tensor(Ty.Shape[Ty.DimVar d2])
535 :     val mt = Ty.T_Tensor(Ty.Shape[Ty.DimVar d1, Ty.DimVar d2])
536 :     in
537 :     [vt1, vt2] --> mt
538 :     end
539 :     in
540 : cchiw 2906 val op_outer_tt = polyVar (N.op_outer, all([NK, NK], mkOuter))
541 : jhr 1116 end
542 :    
543 : cchiw 2584 fun field' (k, d, dd) = field(k, Ty.DimConst d, Ty.Shape(List.map Ty.DimConst dd))
544 :    
545 : cchiw 2585 local
546 : cchiw 3017 fun mkOuterField [Ty.DIFF k, Ty.DIM d,Ty.DIM a, Ty.DIM b] = let
547 : jhr 2948 val k0 = Ty.DiffVar(k, 0)
548 :     val d' = Ty.DimVar d
549 : cchiw 3017
550 :     val a' = Ty.DimVar a
551 :     val b' = Ty.DimVar b
552 :     val f = field(k0, d', Ty.Shape[a'])
553 :     val g = field(k0, d', Ty.Shape[b'])
554 :     val h = field(k0, d', Ty.Shape[a',b'])
555 : jhr 2948 in
556 : cchiw 3017 [f, g] --> h
557 : jhr 2948 end
558 : cchiw 2585 in
559 : cchiw 3017 val op_outer_ff = polyVar (N.op_outer, all([DK,NK,NK,NK], mkOuterField))
560 : cchiw 2585 end
561 : jhr 2948
562 : jhr 91 val fn_principleEvec = polyVar (N.fn_principleEvec, all([NK],
563 :     fn [Ty.DIM d] => let
564 :     val d = Ty.DimVar d
565 :     in
566 : jhr 1116 [matrix d] --> tensor[d]
567 : jhr 91 end))
568 : jhr 79
569 : cchiw 2906 val fn_trace_t = polyVar (N.fn_trace, all([NK],
570 : jhr 2948 fn [Ty.DIM d] => [matrix(Ty.DimVar d)] --> Ty.realTy))
571 :     val fn_trace_f = polyVar (N.fn_trace, all([DK,NK,SK],
572 : jhr 2949 fn [Ty.DIFF k, Ty.DIM d, Ty.SHAPE dd1] => let
573 : jhr 2948 val k' = Ty.DiffVar(k, 0)
574 :     val d' = Ty.DimVar d
575 :     val d1 = Ty.ShapeVar dd1
576 :     val f = field(k', d', Ty.ShapeExt(Ty.ShapeExt(d1, d'), d'))
577 :     val h = field(k', d', d1)
578 :     in
579 :     [f] --> h
580 :     end))
581 : jhr 2949
582 : cchiw 2906 val fn_transpose_t = polyVar (N.fn_transpose, all([NK, NK],
583 : jhr 2949 fn [Ty.DIM d1, Ty.DIM d2] =>
584 : jhr 2356 [tensor[Ty.DimVar d1, Ty.DimVar d2]] --> tensor[Ty.DimVar d2, Ty.DimVar d1]))
585 : jhr 2949 val fn_transpose_f = polyVar (N.fn_transpose, all([DK,NK,NK,NK],
586 :     fn [Ty.DIFF k, Ty.DIM d,Ty.DIM a, Ty.DIM b] => let
587 :     val k0 = Ty.DiffVar(k, 0)
588 :     val d' = Ty.DimVar d
589 :     val a' = Ty.DimVar a
590 :     val b' = Ty.DimVar b
591 :     val f = field(k0, d', Ty.Shape[a',b'])
592 :     val h = field(k0, d', Ty.Shape[b',a'])
593 :     in
594 :     [f] --> h
595 :     end))
596 : jhr 2356
597 : cchiw 3054 (*restrict to 2x2 and 3x3*)
598 :     local
599 :     (*)val crossTy = let
600 :     val t = tensor[N3]
601 :     in
602 :     [t, t] --> t
603 : cchiw 3055 end
604 :     *)
605 :     val detT2 = let
606 : cchiw 3054 val t = matrix N2
607 :     in
608 :     [t] --> Ty.realTy
609 :     end
610 : cchiw 3055
611 :    
612 : cchiw 3054 in
613 :    
614 : cchiw 3055 val fn_det_t2 = monoVar (N.fn_det, detT2)
615 : cchiw 3054 end
616 : cchiw 3055 (*
617 :     val op_det2_f = polyVar (N.fn_det, all([DK],
618 :     fn [Ty.DIFF k] => let
619 :     fun field' (k, d, dd) = field(k, Ty.DimConst d, Ty.Shape(List.map Ty.DimConst dd))
620 :     val k0 = Ty.DiffVar(k, 0)
621 :     val f = field' (k0, 2, [2])
622 :     val t1 = field' (k0, 2, [])
623 :     in
624 :     [f, f] --> t1
625 :     end))
626 : cchiw 3054
627 :    
628 : cchiw 3055 *)
629 : cchiw 3054
630 : jhr 79 (* kernels *)
631 : jhr 169 (* FIXME: we should really get the continuity info from the kernels themselves *)
632 : jhr 83 val kn_bspln3 = monoVar (N.kn_bspln3, Ty.T_Kernel(Ty.DiffConst 2))
633 : jhr 169 val kn_bspln5 = monoVar (N.kn_bspln5, Ty.T_Kernel(Ty.DiffConst 4))
634 : jhr 2356 val kn_c4hexic = monoVar (N.kn_c4hexic, Ty.T_Kernel(Ty.DiffConst 4))
635 : jhr 1116 val kn_ctmr = monoVar (N.kn_ctmr, Ty.T_Kernel(Ty.DiffConst 1))
636 : jhr 83 val kn_tent = monoVar (N.kn_tent, Ty.T_Kernel(Ty.DiffConst 0))
637 : jhr 1116 (* kernels with false claims of differentiability, for pedagogy *)
638 :     val kn_c1tent = monoVar (N.kn_c1tent, Ty.T_Kernel(Ty.DiffConst 1))
639 :     val kn_c2ctmr = monoVar (N.kn_c2ctmr, Ty.T_Kernel(Ty.DiffConst 2))
640 : jhr 79
641 : jhr 1116 (***** internal variables *****)
642 : jhr 406
643 : jhr 1116 (* integer to real conversion *)
644 :     val i2r = monoVar (Atom.atom "$i2r", [Ty.T_Int] --> Ty.realTy)
645 :    
646 :     (* identity matrix *)
647 :     val identity = polyVar (Atom.atom "$id", allNK (fn dv => [] --> matrix(Ty.DimVar dv)))
648 :    
649 :     (* zero tensor *)
650 :     val zero = polyVar (Atom.atom "$zero", all ([SK],
651 :     fn [Ty.SHAPE dd] => [] --> Ty.T_Tensor(Ty.ShapeVar dd)))
652 :    
653 : jhr 1640 (* sequence subscript *)
654 :     val subscript = polyVar (Atom.atom "$sub", all ([TK, NK],
655 :     fn [Ty.TYPE tv, Ty.DIM d] =>
656 :     [Ty.T_Sequence(Ty.T_Var tv, Ty.DimVar d), Ty.T_Int] --> Ty.T_Var tv))
657 : jhr 79 end (* local *)
658 :     end

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0