Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] Annotation of /branches/vis15/src/compiler/basis/basis-vars.sml
ViewVC logotype

Annotation of /branches/vis15/src/compiler/basis/basis-vars.sml

Parent Directory Parent Directory | Revision Log Revision Log


Revision 5558 - (view) (download)

1 : jhr 3391 (* basis-vars.sml
2 :     *
3 :     * This code is part of the Diderot Project (http://diderot-language.cs.uchicago.edu)
4 :     *
5 : jhr 5307 * COPYRIGHT (c) 2017 The University of Chicago
6 : jhr 3391 * All rights reserved.
7 :     *
8 :     * This module defines the AST variables for the built in operators and functions.
9 :     *)
10 :    
11 :     structure BasisVars =
12 :     struct
13 :     local
14 :     structure N = BasisNames
15 :     structure Ty = Types
16 :     structure MV = MetaVar
17 :    
18 :     fun --> (tys, ty) = Ty.T_Fun(tys, ty)
19 :     infix -->
20 :    
21 :     val N2 = Ty.DimConst 2
22 :     val N3 = Ty.DimConst 3
23 :    
24 :     (* short names for kinds *)
25 :     val TK : unit -> Ty.meta_var = Ty.TYPE o MV.newTyVar
26 :     fun DK () : Ty.meta_var = Ty.DIFF(MV.newDiffVar 0)
27 :     val SK : unit -> Ty.meta_var = Ty.SHAPE o MV.newShapeVar
28 :     val NK : unit -> Ty.meta_var = Ty.DIM o MV.newDimVar
29 :    
30 :     fun ty t = ([], t)
31 :     fun all (kinds, mkTy : Ty.meta_var list -> Ty.ty) = let
32 :     val tvs = List.map (fn mk => mk()) kinds
33 :     in
34 :     (tvs, mkTy tvs)
35 :     end
36 :     fun allNK mkTy = let
37 :     val tv = MV.newDimVar()
38 :     in
39 :     ([Ty.DIM tv], mkTy tv)
40 :     end
41 :    
42 :     fun field (k, d, dd) = Ty.T_Field{diff=k, dim=d, shape=dd}
43 :     fun field' (k, d, dd) = field(k, Ty.DimConst d, Ty.Shape(List.map Ty.DimConst dd))
44 :     fun tensor ds = Ty.T_Tensor(Ty.Shape ds)
45 :     fun matrix d = tensor[d,d]
46 : jhr 3398 fun dynSeq ty = Ty.T_Sequence(ty, NONE)
47 : jhr 3391
48 : jhr 4060 fun monoVar (name, ty) = Var.newBasis (name, ([], ty))
49 : jhr 3407 fun polyVar arg = Var.newBasis arg
50 : jhr 3391 in
51 :    
52 :     (* overloaded operators; the naming convention is to use the operator name followed
53 :     * by the argument type signature, where
54 :     * i -- int
55 :     * b -- bool
56 :     * r -- real (tensor[])
57 :     * t -- tensor[shape]
58 : jhr 3392 * I -- image(d)[shape]
59 : jhr 3391 * f -- field#k(d)[shape]
60 :     * s -- field#k(d)[]
61 :     * d -- ty{}
62 :     * T -- ty
63 :     *)
64 :    
65 :     (* concatenation of sequences *)
66 :     val at_dT = polyVar (N.op_at, all([TK],
67 :     fn [Ty.TYPE tv] => let
68 : jhr 3398 val seqTyc = dynSeq(Ty.T_Var tv)
69 : jhr 3391 in
70 :     [seqTyc, Ty.T_Var tv] --> seqTyc
71 :     end))
72 :     val at_Td = polyVar (N.op_at, all([TK],
73 :     fn [Ty.TYPE tv] => let
74 : jhr 3398 val seqTyc = dynSeq(Ty.T_Var tv)
75 : jhr 3391 in
76 :     [Ty.T_Var tv, seqTyc] --> seqTyc
77 :     end))
78 :     val at_dd = polyVar (N.op_at, all([TK],
79 :     fn [Ty.TYPE tv] => let
80 : jhr 3398 val seqTyc = dynSeq(Ty.T_Var tv)
81 : jhr 3391 in
82 :     [seqTyc, seqTyc] --> seqTyc
83 :     end))
84 :    
85 :     val add_ii = monoVar(N.op_add, [Ty.T_Int, Ty.T_Int] --> Ty.T_Int)
86 :     val add_tt = polyVar(N.op_add, all([SK], fn [Ty.SHAPE dd] => let
87 :     val t = Ty.T_Tensor(Ty.ShapeVar dd)
88 :     in
89 :     [t, t] --> t
90 :     end))
91 :     val add_ff = polyVar(N.op_add, all([DK,NK,SK],
92 :     fn [Ty.DIFF k, Ty.DIM d, Ty.SHAPE dd] => let
93 :     val f = Ty.T_Field{diff = Ty.DiffVar(k, 0), dim = Ty.DimVar d, shape = Ty.ShapeVar dd}
94 :     in
95 :     [f, f] --> f
96 :     end))
97 :     val add_ft = polyVar(N.op_add, all([DK,NK,SK], (* field + scalar *)
98 :     fn [Ty.DIFF k, Ty.DIM d, Ty.SHAPE dd] => let
99 :     val f = Ty.T_Field{diff = Ty.DiffVar(k, 0), dim = Ty.DimVar d, shape = Ty.ShapeVar dd}
100 :     val t = Ty.T_Tensor(Ty.ShapeVar dd)
101 :     in
102 :     [f, t] --> f
103 :     end))
104 :     val add_tf = polyVar(N.op_add, all([DK,NK,SK], (* scalar + field *)
105 :     fn [Ty.DIFF k, Ty.DIM d, Ty.SHAPE dd] => let
106 :     val f = Ty.T_Field{diff = Ty.DiffVar(k, 0), dim = Ty.DimVar d, shape = Ty.ShapeVar dd}
107 :     val t = Ty.T_Tensor(Ty.ShapeVar dd)
108 :     in
109 :     [t, f] --> f
110 :     end))
111 :    
112 :     val sub_ii = monoVar(N.op_sub, [Ty.T_Int, Ty.T_Int] --> Ty.T_Int)
113 :     val sub_tt = polyVar(N.op_sub, all([SK], fn [Ty.SHAPE dd] => let
114 :     val t = Ty.T_Tensor(Ty.ShapeVar dd)
115 :     in
116 :     [t, t] --> t
117 :     end))
118 :     val sub_ff = polyVar(N.op_sub, all([DK,NK,SK],
119 :     fn [Ty.DIFF k, Ty.DIM d, Ty.SHAPE dd] => let
120 :     val f = Ty.T_Field{diff = Ty.DiffVar(k, 0), dim = Ty.DimVar d, shape = Ty.ShapeVar dd}
121 :     in
122 :     [f, f] --> f
123 :     end))
124 :     val sub_ft = polyVar(N.op_sub, all([DK,NK,SK], (* field - scalar *)
125 :     fn [Ty.DIFF k, Ty.DIM d,Ty.SHAPE dd] => let
126 :     val f = Ty.T_Field{diff = Ty.DiffVar(k, 0), dim = Ty.DimVar d, shape = Ty.ShapeVar dd}
127 :     val t = Ty.T_Tensor(Ty.ShapeVar dd)
128 :     in
129 :     [f, t] --> f
130 :     end))
131 :     val sub_tf = polyVar(N.op_sub, all([DK,NK,SK], (* scalar - field *)
132 :     fn [Ty.DIFF k, Ty.DIM d,Ty.SHAPE dd] => let
133 :     val f = Ty.T_Field{diff = Ty.DiffVar(k, 0), dim = Ty.DimVar d, shape = Ty.ShapeVar dd}
134 :     val t = Ty.T_Tensor(Ty.ShapeVar dd)
135 :     in
136 :     [t, f] --> f
137 :     end))
138 :    
139 :     (* note that we assume that operators are tested in the order defined here, so that mul_rr
140 :     * takes precedence over mul_rt and mul_tr!
141 :     *)
142 :     val mul_ii = monoVar(N.op_mul, [Ty.T_Int, Ty.T_Int] --> Ty.T_Int)
143 :     val mul_rr = monoVar(N.op_mul, [Ty.realTy, Ty.realTy] --> Ty.realTy)
144 :     val mul_rt = polyVar(N.op_mul, all([SK], fn [Ty.SHAPE dd] => let
145 :     val t = Ty.T_Tensor(Ty.ShapeVar dd)
146 :     in
147 :     [Ty.realTy, t] --> t
148 :     end))
149 :     val mul_tr = polyVar(N.op_mul, all([SK], fn [Ty.SHAPE dd] => let
150 :     val t = Ty.T_Tensor(Ty.ShapeVar dd)
151 :     in
152 :     [t, Ty.realTy] --> t
153 :     end))
154 :     val mul_rf = polyVar(N.op_mul, all([DK,NK,SK], fn [Ty.DIFF k, Ty.DIM d, Ty.SHAPE dd] => let
155 :     val t = Ty.T_Field{diff = Ty.DiffVar(k, 0), dim = Ty.DimVar d, shape = Ty.ShapeVar dd}
156 :     in
157 :     [Ty.realTy, t] --> t
158 :     end))
159 :     val mul_fr = polyVar(N.op_mul, all([DK,NK,SK], fn [Ty.DIFF k, Ty.DIM d, Ty.SHAPE dd] => let
160 :     val t = Ty.T_Field{diff = Ty.DiffVar(k, 0), dim = Ty.DimVar d, shape = Ty.ShapeVar dd}
161 :     in
162 :     [t, Ty.realTy] --> t
163 :     end))
164 :     val mul_st = polyVar(N.op_mul, all([DK,NK,SK], fn [Ty.DIFF k, Ty.DIM d, Ty.SHAPE dd] => let
165 : jhr 3392 val f = Ty.T_Field{diff = Ty.DiffVar(k, 0), dim = Ty.DimVar d, shape = Ty.Shape[]}
166 :     val t = Ty.T_Tensor(Ty.ShapeVar dd)
167 :     val g = Ty.T_Field{diff = Ty.DiffVar(k, 0), dim = Ty.DimVar d, shape = Ty.ShapeVar dd}
168 :     in
169 :     [f, t] --> g
170 :     end))
171 : jhr 3391 val mul_ts = polyVar(N.op_mul, all([DK,NK,SK], fn [Ty.DIFF k, Ty.DIM d, Ty.SHAPE dd] => let
172 : jhr 3392 val f = Ty.T_Field{diff = Ty.DiffVar(k, 0), dim = Ty.DimVar d, shape = Ty.Shape[]}
173 :     val t = Ty.T_Tensor(Ty.ShapeVar dd)
174 :     val g = Ty.T_Field{diff = Ty.DiffVar(k, 0), dim = Ty.DimVar d, shape = Ty.ShapeVar dd}
175 :     in
176 :     [t, f] --> g
177 :     end))
178 : jhr 3391 val mul_ss = polyVar(N.op_mul, all([DK,NK], fn [Ty.DIFF k, Ty.DIM d] => let
179 : jhr 3392 val t = Ty.T_Field{diff = Ty.DiffVar(k, 0), dim = Ty.DimVar d, shape = Ty.Shape []}
180 :     in
181 :     [t, t] --> t
182 :     end))
183 : jhr 3391 val mul_sf = polyVar(N.op_mul, all([DK,NK,SK], fn [Ty.DIFF k, Ty.DIM d, Ty.SHAPE dd] => let
184 : jhr 3392 val a = Ty.T_Field{diff = Ty.DiffVar(k, 0), dim = Ty.DimVar d, shape = Ty.Shape []}
185 :     val b = Ty.T_Field{diff = Ty.DiffVar(k, 0), dim = Ty.DimVar d, shape = Ty.ShapeVar dd}
186 :     in
187 :     [a, b] --> b
188 :     end))
189 : jhr 3391 val mul_fs = polyVar(N.op_mul, all([DK,NK,SK], fn [Ty.DIFF k, Ty.DIM d, Ty.SHAPE dd] => let
190 : jhr 3392 val a = Ty.T_Field{diff = Ty.DiffVar(k, 0), dim = Ty.DimVar d, shape = Ty.Shape []}
191 :     val b = Ty.T_Field{diff = Ty.DiffVar(k, 0), dim = Ty.DimVar d, shape = Ty.ShapeVar dd}
192 :     in
193 :     [b, a] --> b
194 :     end))
195 : jhr 3391
196 :     val div_ii = monoVar(N.op_div, [Ty.T_Int, Ty.T_Int] --> Ty.T_Int)
197 :     val div_rr = monoVar(N.op_div, [Ty.realTy, Ty.realTy] --> Ty.realTy)
198 :     val div_tr = polyVar(N.op_div, all([SK], fn [Ty.SHAPE dd] => let
199 :     val t = Ty.T_Tensor(Ty.ShapeVar dd)
200 :     in
201 :     [t, Ty.realTy] --> t
202 :     end))
203 :     val div_fr = polyVar(N.op_div, all([DK,NK,SK], fn [Ty.DIFF k, Ty.DIM d, Ty.SHAPE dd] => let
204 :     val t = Ty.T_Field{diff = Ty.DiffVar(k, 0), dim = Ty.DimVar d, shape = Ty.ShapeVar dd}
205 :     in
206 :     [t, Ty.realTy] --> t
207 :     end))
208 :     val div_ss = polyVar(N.op_mul, all([DK,NK], fn [Ty.DIFF k, Ty.DIM d] => let
209 :     val t = Ty.T_Field{diff = Ty.DiffVar(k, 0), dim = Ty.DimVar d, shape = Ty.Shape []}
210 :     in
211 : jhr 3392 [t, t] --> t
212 : jhr 3391 end))
213 :     val div_fs = polyVar(N.op_div, all([DK,DK,NK,SK], fn [Ty.DIFF k, Ty.DIFF k2, Ty.DIM d, Ty.SHAPE dd] => let
214 :     val f = Ty.T_Field{diff = Ty.DiffVar(k, 0), dim = Ty.DimVar d, shape = Ty.ShapeVar dd}
215 :     val s = Ty.T_Field{diff = Ty.DiffVar(k2, 0), dim = Ty.DimVar d, shape = Ty.Shape []}
216 :     in
217 : jhr 3392 [f, s] --> f
218 : jhr 3391 end))
219 : jhr 4002 val div_ts = polyVar(N.op_div, all([DK,NK,SK], fn [Ty.DIFF k, Ty.DIM d, Ty.SHAPE dd] => let
220 : cchiw 4000 val t = Ty.T_Tensor(Ty.ShapeVar dd)
221 :     val s = Ty.T_Field{diff = Ty.DiffVar(k, 0), dim = Ty.DimVar d, shape = Ty.Shape []}
222 :     val f = Ty.T_Field{diff = Ty.DiffVar(k, 0), dim = Ty.DimVar d, shape = Ty.ShapeVar dd}
223 :     in
224 :     [t, s] --> f
225 :     end))
226 : jhr 3391
227 : jhr 3482 (* power; we distinguish between integer and real exponents to allow x^2 to be compiled
228 :     * as x*x. The power operation of fields is restricted by the typechecker to constant
229 :     * integer arguments.
230 : jhr 3391 *)
231 : jhr 5230 (* FIXME: add pow_ii *)
232 : jhr 3482 val pow_ri = monoVar(N.op_pow, [Ty.realTy, Ty.T_Int] --> Ty.realTy)
233 :     val pow_rr = monoVar(N.op_pow, [Ty.realTy, Ty.realTy] --> Ty.realTy)
234 :     val pow_si = polyVar (N.op_pow, all([DK, NK], fn [Ty.DIFF k, Ty.DIM d] => let
235 : jhr 3519 val k = Ty.DiffVar(k, 0)
236 :     val d = Ty.DimVar d
237 :     val fld = field(k, d, Ty.Shape[])
238 :     in
239 :     [fld, Ty.T_Int] --> fld
240 :     end))
241 : jhr 5230
242 : jhr 3391 val convolve_vk = polyVar (N.op_convolve, all([DK, NK, SK], fn [Ty.DIFF k, Ty.DIM d, Ty.SHAPE dd] => let
243 : jhr 3519 val k = Ty.DiffVar(k, 0)
244 :     val d = Ty.DimVar d
245 :     val dd = Ty.ShapeVar dd
246 :     in
247 :     [Ty.T_Image{dim=d, shape=dd}, Ty.T_Kernel k] --> field(k, d, dd)
248 :     end))
249 : jhr 3391 val convolve_kv = polyVar (N.op_convolve, all([DK, NK, SK], fn [Ty.DIFF k, Ty.DIM d, Ty.SHAPE dd] => let
250 : jhr 3519 val k = Ty.DiffVar(k, 0)
251 :     val d = Ty.DimVar d
252 :     val dd = Ty.ShapeVar dd
253 :     in
254 :     [Ty.T_Kernel k, Ty.T_Image{dim=d, shape=dd}] --> field(k, d, dd)
255 :     end))
256 : jhr 3391
257 :     (* curl on 2d and 3d vector fields *)
258 :     local
259 :     fun field' (k, d, dd) = field(k, Ty.DimConst d, Ty.Shape(List.map Ty.DimConst dd))
260 :     in
261 :     val curl2D = polyVar (N.op_curl, all([DK], fn [Ty.DIFF k] => let
262 : jhr 3519 val km1 = Ty.DiffVar(k, ~1)
263 : jhr 3807 in
264 : jhr 3519 [field' (Ty.DiffVar(k, 0), 2, [2])] --> field' (km1, 2, [])
265 :     end))
266 : jhr 3391 val curl3D = polyVar (N.op_curl, all([DK], fn [Ty.DIFF k] =>let
267 : jhr 3519 val km1 = Ty.DiffVar(k, ~1)
268 : jhr 3807 in
269 : jhr 3519 [field' (Ty.DiffVar(k, 0), 3, [3])] --> field' (km1, 3, [3])
270 :     end))
271 : jhr 3391 end (* local *)
272 :    
273 :     val lt_ii = monoVar(N.op_lt, [Ty.T_Int, Ty.T_Int] --> Ty.T_Bool)
274 :     val lt_rr = monoVar(N.op_lt, [Ty.realTy, Ty.realTy] --> Ty.T_Bool)
275 :     val lte_ii = monoVar(N.op_lte, [Ty.T_Int, Ty.T_Int] --> Ty.T_Bool)
276 :     val lte_rr = monoVar(N.op_lte, [Ty.realTy, Ty.realTy] --> Ty.T_Bool)
277 :     val gte_ii = monoVar(N.op_gte, [Ty.T_Int, Ty.T_Int] --> Ty.T_Bool)
278 :     val gte_rr = monoVar(N.op_gte, [Ty.realTy, Ty.realTy] --> Ty.T_Bool)
279 :     val gt_ii = monoVar(N.op_gt, [Ty.T_Int, Ty.T_Int] --> Ty.T_Bool)
280 :     val gt_rr = monoVar(N.op_gt, [Ty.realTy, Ty.realTy] --> Ty.T_Bool)
281 :    
282 :     val equ_bb = monoVar(N.op_equ, [Ty.T_Bool, Ty.T_Bool] --> Ty.T_Bool)
283 :     val equ_ii = monoVar(N.op_equ, [Ty.T_Int, Ty.T_Int] --> Ty.T_Bool)
284 :     val equ_ss = monoVar(N.op_equ, [Ty.T_String, Ty.T_String] --> Ty.T_Bool)
285 :     val equ_rr = monoVar(N.op_equ, [Ty.realTy, Ty.realTy] --> Ty.T_Bool)
286 :     val neq_bb = monoVar(N.op_neq, [Ty.T_Bool, Ty.T_Bool] --> Ty.T_Bool)
287 :     val neq_ii = monoVar(N.op_neq, [Ty.T_Int, Ty.T_Int] --> Ty.T_Bool)
288 :     val neq_ss = monoVar(N.op_neq, [Ty.T_String, Ty.T_String] --> Ty.T_Bool)
289 :     val neq_rr = monoVar(N.op_neq, [Ty.realTy, Ty.realTy] --> Ty.T_Bool)
290 :    
291 :     val neg_i = monoVar(N.op_neg, [Ty.T_Int] --> Ty.T_Int)
292 : jhr 3482 val neg_t = polyVar(N.op_neg, all([SK], fn [Ty.SHAPE dd] => let
293 : jhr 3519 val t = Ty.T_Tensor(Ty.ShapeVar dd)
294 :     in
295 :     [t] --> t
296 :     end))
297 : jhr 3482 val neg_f = polyVar(N.op_neg, all([DK, NK, SK], fn [Ty.DIFF k, Ty.DIM d, Ty.SHAPE dd] => let
298 : jhr 3519 val k = Ty.DiffVar(k, 0)
299 :     val d = Ty.DimVar d
300 :     val dd = Ty.ShapeVar dd
301 :     in
302 :     [field(k, d, dd)] --> field(k, d, dd)
303 :     end))
304 : jhr 3391
305 : jhr 4128 (* clamp works on tensors, but there is also the boarder-control clamp function on images;
306 :     * the arguments are (lo, hi, value), which is different than found in OpenCL and OpenGL.
307 :     *)
308 :     val clamp_rrt = polyVar (N.fn_clamp, all([SK,NK], fn [Ty.SHAPE dd, Ty.DIM d] => let
309 : jhr 3830 val t = Ty.T_Tensor(Ty.ShapeExt(Ty.ShapeVar dd, Ty.DimVar d))
310 : jhr 3391 in
311 : jhr 4128 [Ty.realTy, Ty.realTy, t] --> t
312 : jhr 3830 end))
313 :     val clamp_ttt = polyVar (N.fn_clamp, all([SK], fn [Ty.SHAPE dd] => let
314 :     val t = Ty.T_Tensor(Ty.ShapeVar dd)
315 :     in
316 : jhr 3391 [t, t, t] --> t
317 :     end))
318 :    
319 : jhr 3482 val lerp3 = polyVar(N.fn_lerp, all([SK], fn [Ty.SHAPE dd] => let
320 : jhr 3519 val t = Ty.T_Tensor(Ty.ShapeVar dd)
321 :     in
322 :     [t, t, Ty.realTy] --> t
323 :     end))
324 : jhr 3482 val lerp5 = polyVar(N.fn_lerp, all([SK], fn [Ty.SHAPE dd] => let
325 : jhr 3519 val t = Ty.T_Tensor(Ty.ShapeVar dd)
326 :     in
327 :     [t, t, Ty.realTy, Ty.realTy, Ty.realTy] --> t
328 :     end))
329 : jhr 3391
330 : cchiw 5242 val clerp3 = polyVar(N.fn_clerp, all([SK], fn [Ty.SHAPE dd] => let
331 :     val t = Ty.T_Tensor(Ty.ShapeVar dd)
332 :     in
333 :     [t, t, Ty.realTy] --> t
334 :     end))
335 :     val clerp5 = polyVar(N.fn_clerp, all([SK], fn [Ty.SHAPE dd] => let
336 :     val t = Ty.T_Tensor(Ty.ShapeVar dd)
337 :     in
338 :     [t, t, Ty.realTy, Ty.realTy, Ty.realTy] --> t
339 :     end))
340 : jhr 4128
341 : jhr 3391 (* Eigenvalues/vectors of a matrix; we only support this operation on 2x2 and 3x3 matrices, so
342 :     * we overload the function.
343 :     *)
344 :     local
345 : jhr 3398 fun evals d = monoVar (N.fn_evals, [matrix d] --> Ty.T_Sequence(Ty.realTy, SOME d))
346 :     fun evecs d = monoVar (N.fn_evecs, [matrix d] --> Ty.T_Sequence(tensor[d], SOME d))
347 : jhr 3391 in
348 :     val evals2x2 = evals(Ty.DimConst 2)
349 :     val evecs2x2 = evecs(Ty.DimConst 2)
350 :     val evals3x3 = evals(Ty.DimConst 3)
351 :     val evecs3x3 = evecs(Ty.DimConst 3)
352 : jhr 3398 end (* local *)
353 : jhr 3391
354 :     (***** non-overloaded operators, etc. *****)
355 :    
356 :     (* integer modulo *)
357 :     val op_mod = monoVar(N.op_mod, [Ty.T_Int, Ty.T_Int] --> Ty.T_Int)
358 :    
359 :     (* pseudo-operator for probing a field *)
360 :     val op_probe = polyVar (N.op_at, all([DK, NK, SK],
361 :     fn [Ty.DIFF k, Ty.DIM d, Ty.SHAPE dd] => let
362 : jhr 3392 val k = Ty.DiffVar(k, 0)
363 :     val d = Ty.DimVar d
364 :     val dd = Ty.ShapeVar dd
365 :     in
366 :     [field(k, d, dd), tensor[d]] --> Ty.T_Tensor dd
367 :     end))
368 : jhr 3391
369 :     (* differentiation of scalar fields *)
370 :     val op_D = polyVar (N.op_D, all([DK, NK],
371 :     fn [Ty.DIFF k, Ty.DIM d] => let
372 : jhr 3392 val k0 = Ty.DiffVar(k, 0)
373 :     val km1 = Ty.DiffVar(k, ~1)
374 :     val d = Ty.DimVar d
375 :     in
376 :     [field(k0, d, Ty.Shape[])] --> field(km1, d, Ty.Shape[d])
377 :     end))
378 : jhr 3391
379 :     (* differentiation of higher-order tensor fields *)
380 :     val op_Dotimes = polyVar (N.op_Dotimes, all([DK, NK, SK, NK],
381 :     fn [Ty.DIFF k, Ty.DIM d, Ty.SHAPE dd, Ty.DIM d'] => let
382 : jhr 3392 val k0 = Ty.DiffVar(k, 0)
383 :     val km1 = Ty.DiffVar(k, ~1)
384 :     val d = Ty.DimVar d
385 :     val d' = Ty.DimVar d'
386 :     val dd = Ty.ShapeVar dd
387 :     in
388 :     [field(k0, d, Ty.ShapeExt(dd, d'))]
389 :     --> field(km1, d, Ty.ShapeExt(Ty.ShapeExt(dd, d'), d))
390 :     end))
391 : jhr 3391
392 :     (* divergence differentiation of higher-order tensor fields *)
393 :     val op_Ddot = polyVar (N.op_Ddot, all([DK, NK, SK, NK],
394 : jhr 3392 fn [Ty.DIFF k, Ty.DIM d, Ty.SHAPE dd, Ty.DIM d'] => let
395 :     val k0 = Ty.DiffVar(k, 0)
396 :     val km1 = Ty.DiffVar(k, ~1)
397 :     val d = Ty.DimVar d
398 :     val d' = Ty.DimVar d'
399 :     val dd' = Ty.ShapeVar dd
400 :     in
401 :     [field(k0, d, Ty.ShapeExt(dd', d'))] --> field(k0, d, dd')
402 :     end))
403 : jhr 3391
404 : jhr 5307 (* vector-norm and absolute value *)
405 :     val op_norm_i = monoVar (N.op_norm, [Ty.T_Int] --> Ty.T_Int)
406 : jhr 3391 val op_norm_t = polyVar (N.op_norm, all([SK],
407 :     fn [Ty.SHAPE dd] => [Ty.T_Tensor(Ty.ShapeVar dd)] --> Ty.realTy))
408 :     val op_norm_f = polyVar (N.op_norm, all([DK, NK, SK], fn [Ty.DIFF k,Ty.DIM d, Ty.SHAPE dd1] => let
409 : jhr 3392 val k = Ty.DiffVar(k, 0)
410 :     val d = Ty.DimVar d
411 :     val f1 = Ty.T_Field{diff = k, dim = d, shape = Ty.ShapeVar dd1}
412 :     val f2 = Ty.T_Field{diff = k, dim = d, shape = Ty.Shape []}
413 :     in
414 :     [f1] --> f2
415 :     end))
416 : jhr 3391
417 : jhr 3464 (* boolean operators; 'and' and 'or' are used to implement reductions *)
418 :     val op_and = monoVar (Atom.atom "$and", [Ty.T_Bool, Ty.T_Bool] --> Ty.T_Bool)
419 :     val op_or = monoVar (Atom.atom "$or", [Ty.T_Bool, Ty.T_Bool] --> Ty.T_Bool)
420 : jhr 3391 val op_not = monoVar (N.op_not, [Ty.T_Bool] --> Ty.T_Bool)
421 :    
422 :     (* cross product *)
423 :     local
424 :     val crossTy = let
425 :     val t = tensor[N3]
426 :     in
427 :     [t, t] --> t
428 :     end
429 :     val crossTy2 = let
430 :     val t = tensor[N2]
431 :     in
432 :     [t, t] --> Ty.realTy
433 :     end
434 :     in
435 :     val op_cross2_tt = monoVar (N.op_cross, crossTy2)
436 : jhr 3807 val op_cross3_tt = monoVar (N.op_cross, crossTy)
437 : jhr 3398 end (* local *)
438 : jhr 3391
439 :     val op_cross2_ff = polyVar (N.op_cross, all([DK], fn [Ty.DIFF k] => let
440 : jhr 3392 fun field' (k, d, dd) = field(k, Ty.DimConst d, Ty.Shape(List.map Ty.DimConst dd))
441 :     val k0 = Ty.DiffVar(k, 0)
442 :     val f = field' (k0, 2, [2])
443 :     val t1 = field' (k0, 2, [])
444 :     in
445 :     [f, f] --> t1
446 :     end))
447 : jhr 3391
448 :     val op_cross3_ff = polyVar (N.op_cross, all([DK], fn [Ty.DIFF k] => let
449 : jhr 3392 fun field' (k, d, dd) = field(k, Ty.DimConst d, Ty.Shape(List.map Ty.DimConst dd))
450 :     val f = field' (Ty.DiffVar(k, 0), 3, [3])
451 :     in
452 :     [f, f] --> f
453 :     end))
454 : jhr 3391
455 : jhr 4002 val op_cross2_ft = polyVar (N.op_cross, all([DK], fn [Ty.DIFF k] => let
456 : cchiw 4000 fun field' (k, d, dd) = field(k, Ty.DimConst d, Ty.Shape(List.map Ty.DimConst dd))
457 :     val k0 = Ty.DiffVar(k, 0)
458 :     val f = field' (k0, 2, [2])
459 :     val t = tensor[N2]
460 :     val t1 = field' (k0, 2, [])
461 :     in
462 :     [f, t] --> t1
463 :     end))
464 :    
465 : jhr 4002 val op_cross2_tf = polyVar (N.op_cross, all([DK], fn [Ty.DIFF k] => let
466 : cchiw 4000 fun field' (k, d, dd) = field(k, Ty.DimConst d, Ty.Shape(List.map Ty.DimConst dd))
467 :     val k0 = Ty.DiffVar(k, 0)
468 :     val f = field' (k0, 2, [2])
469 :     val t = tensor[N2]
470 :     val t1 = field' (k0, 2, [])
471 :     in
472 :     [t, f] --> t1
473 :     end))
474 :    
475 : jhr 4002 val op_cross3_ft = polyVar (N.op_cross, all([DK], fn [Ty.DIFF k] => let
476 : cchiw 4000 fun field' (k, d, dd) = field(k, Ty.DimConst d, Ty.Shape(List.map Ty.DimConst dd))
477 :     val f = field' (Ty.DiffVar(k, 0), 3, [3])
478 :     val t = tensor[N3]
479 :     in
480 :     [f, t] --> f
481 :     end))
482 :    
483 : jhr 4002 val op_cross3_tf = polyVar (N.op_cross, all([DK], fn [Ty.DIFF k] => let
484 : cchiw 4000 fun field' (k, d, dd) = field(k, Ty.DimConst d, Ty.Shape(List.map Ty.DimConst dd))
485 :     val f = field' (Ty.DiffVar(k, 0), 3, [3])
486 :     val t = tensor[N3]
487 :     in
488 :     [t, f] --> f
489 :     end))
490 :    
491 :    
492 : jhr 3391 (* the inner product operator (including dot product) is treated as a special case in the
493 :     * typechecker. It is not included in the basis environment, but we define its type scheme
494 :     * here. There is an implicit constraint on its type to have the following scheme:
495 :     *
496 :     * ALL[sigma1, d1, sigma2] . tensor[sigma1, d1] * tensor[d1, sigma2] -> tensor[sigma1, sigma2]
497 :     *)
498 :     val op_inner_tt = polyVar (N.op_dot, all([SK,SK,SK],
499 : jhr 3392 fn [Ty.SHAPE s1, Ty.SHAPE s2, Ty.SHAPE s3] =>
500 :     [Ty.T_Tensor(Ty.ShapeVar s1), Ty.T_Tensor(Ty.ShapeVar s2)]
501 :     --> Ty.T_Tensor(Ty.ShapeVar s3)))
502 : jhr 3518 val op_inner_tf = polyVar (N.op_dot, all([DK, NK, SK, SK, SK],
503 : jhr 3391 fn [Ty.DIFF k, Ty.DIM d, Ty.SHAPE dd1, Ty.SHAPE dd2, Ty.SHAPE dd3] => let
504 :     val k = Ty.DiffVar(k, 0)
505 :     val d = Ty.DimVar d
506 :     val t1 = Ty.T_Tensor(Ty.ShapeVar dd1)
507 :     val t2 = Ty.T_Field{diff = k, dim = d, shape = Ty.ShapeVar dd2}
508 : jhr 3807 val t3 = Ty.T_Field{diff = k, dim = d, shape = Ty.ShapeVar dd3}
509 : jhr 3391 in
510 :     [t1, t2] --> t3
511 :     end))
512 : jhr 3518 val op_inner_ft = polyVar (N.op_dot, all([DK, NK, SK, SK, SK],
513 : jhr 3391 fn [Ty.DIFF k, Ty.DIM d, Ty.SHAPE dd1, Ty.SHAPE dd2, Ty.SHAPE dd3] => let
514 :     val k = Ty.DiffVar(k, 0)
515 :     val d = Ty.DimVar d
516 :     val t1 = Ty.T_Field{diff = k, dim = d, shape = Ty.ShapeVar dd1}
517 :     val t2 = Ty.T_Tensor(Ty.ShapeVar dd2)
518 : jhr 3807 val t3 = Ty.T_Field{diff = k, dim = d, shape = Ty.ShapeVar dd3}
519 : jhr 3391 in
520 :     [t1, t2] --> t3
521 :     end))
522 : jhr 3518 val op_inner_ff = polyVar (N.op_dot, all([DK, DK, NK, SK, SK, SK],
523 : jhr 3391 fn [Ty.DIFF k1,Ty.DIFF k2, Ty.DIM d, Ty.SHAPE dd1, Ty.SHAPE dd2, Ty.SHAPE dd3] => let
524 :     val k1 = Ty.DiffVar(k1, 0)
525 :     val k2 = Ty.DiffVar(k2, 0)
526 :     val d = Ty.DimVar d
527 :     val t1 = Ty.T_Field{diff = k1, dim = d, shape = Ty.ShapeVar dd1}
528 :     val t2 = Ty.T_Field{diff = k2, dim = d, shape = Ty.ShapeVar dd2}
529 :     val t3 = Ty.T_Field{diff = k1, dim = d, shape = Ty.ShapeVar dd3}
530 :     in
531 :     [t1, t2] --> t3
532 :     end))
533 :    
534 : jhr 3807 (* the outer product operator is treated as a special case in the typechecker. It is not
535 :     * included in the basis environment, but we define its type scheme here. There is an
536 :     * implicit constraint on its type to have the following scheme:
537 :     *
538 :     * ALL[sigma1, sigma2] . tensor[sigma1] * tensor[sigma2] -> tensor[sigma1, sigma2]
539 :     *)
540 : jhr 3518 val op_outer_tt = polyVar (N.op_outer, all([SK, SK, SK],
541 : jhr 3519 fn [Ty.SHAPE s1, Ty.SHAPE s2, Ty.SHAPE s3] =>
542 : jhr 3518 [Ty.T_Tensor(Ty.ShapeVar s1), Ty.T_Tensor(Ty.ShapeVar s2)]
543 : jhr 3585 --> Ty.T_Tensor(Ty.ShapeVar s3)))
544 : cchiw 3514 val op_outer_tf = polyVar (N.op_outer, all([DK,NK,SK,SK,SK],
545 : jhr 3519 fn [Ty.DIFF k, Ty.DIM d, Ty.SHAPE dd1, Ty.SHAPE dd2, Ty.SHAPE dd3] => let
546 : cchiw 3514 val k = Ty.DiffVar(k, 0)
547 :     val d = Ty.DimVar d
548 :     val t1 = Ty.T_Tensor(Ty.ShapeVar dd1)
549 :     val t2 = Ty.T_Field{diff = k, dim = d, shape = Ty.ShapeVar dd2}
550 :     val t3 = Ty.T_Field{diff = k, dim = d, shape = Ty.ShapeVar dd3}
551 :     in
552 : jhr 3519 [t1, t2] --> t3
553 : cchiw 3514 end))
554 : jhr 3518 val op_outer_ft = polyVar (N.op_outer, all([DK, NK, SK, SK, SK],
555 : jhr 3519 fn [Ty.DIFF k, Ty.DIM d, Ty.SHAPE dd1, Ty.SHAPE dd2, Ty.SHAPE dd3] => let
556 :     val k = Ty.DiffVar(k, 0)
557 :     val d = Ty.DimVar d
558 :     val t1 = Ty.T_Field{diff = k, dim = d, shape = Ty.ShapeVar dd1}
559 :     val t2 = Ty.T_Tensor(Ty.ShapeVar dd2)
560 :     val t3 = Ty.T_Field{diff = k, dim = d, shape = Ty.ShapeVar dd3}
561 :     in
562 :     [t1, t2] --> t3
563 :     end))
564 : jhr 3518 val op_outer_ff = polyVar (N.op_outer, all([DK, DK, NK, SK, SK, SK],
565 : jhr 3519 fn [Ty.DIFF k1,Ty.DIFF k2, Ty.DIM d, Ty.SHAPE dd1, Ty.SHAPE dd2, Ty.SHAPE dd3] => let
566 :     val k1 = Ty.DiffVar(k1, 0)
567 :     val k2 = Ty.DiffVar(k2, 0)
568 :     val d = Ty.DimVar d
569 :     val t1 = Ty.T_Field{diff = k1, dim = d, shape = Ty.ShapeVar dd1}
570 :     val t2 = Ty.T_Field{diff = k2, dim = d, shape = Ty.ShapeVar dd2}
571 :     val t3 = Ty.T_Field{diff = k1, dim = d, shape = Ty.ShapeVar dd3}
572 :     in
573 :     [t1, t2] --> t3
574 :     end))
575 : cchiw 3514
576 : jhr 3391 (* the colon (or double-dot) product operator is treated as a special case in the
577 :     * typechecker. It is not included in the basis environment, but we define its type
578 :     * scheme here. There is an implicit constraint on its type to have the following scheme:
579 :     *
580 :     * ALL[sigma1, d1, d2, sigma2] .
581 :     * tensor[sigma1, d1, d2] * tensor[d1, d2, sigma2] -> tensor[sigma1, sigma2]
582 :     *)
583 :     val op_colon_tt = polyVar (N.op_colon, all([SK,SK,SK],
584 : jhr 3392 fn [Ty.SHAPE s1, Ty.SHAPE s2, Ty.SHAPE s3] =>
585 :     [Ty.T_Tensor(Ty.ShapeVar s1), Ty.T_Tensor(Ty.ShapeVar s2)]
586 :     --> Ty.T_Tensor(Ty.ShapeVar s3)))
587 : jhr 3391 val op_colon_ff = polyVar (N.op_colon, all([DK,SK,NK,SK,SK],
588 : jhr 3398 fn [Ty.DIFF k,Ty.SHAPE dd1, Ty.DIM d, Ty.SHAPE dd2,Ty.SHAPE dd3] => let
589 : jhr 3392 val k0 = Ty.DiffVar(k, 0)
590 :     val d' = Ty.DimVar d
591 :     val t1 = Ty.T_Field{diff = k0, dim = d', shape = Ty.ShapeVar dd1}
592 :     val t2 = Ty.T_Field{diff = k0, dim = d', shape = Ty.ShapeVar dd2}
593 :     val t3 = Ty.T_Field{diff = k0, dim = d', shape = Ty.ShapeVar dd3}
594 :     in
595 :     [t1,t2] --> t3
596 :     end))
597 : jhr 3391 val op_colon_ft = polyVar (N.op_colon, all([DK,SK,NK,SK,SK],
598 : jhr 3398 fn [Ty.DIFF k,Ty.SHAPE dd1, Ty.DIM d, Ty.SHAPE s2,Ty.SHAPE dd3] => let
599 : jhr 3392 val k0 = Ty.DiffVar(k, 0)
600 :     val d' = Ty.DimVar d
601 :     val t1 = Ty.T_Field{diff = k0, dim = d', shape = Ty.ShapeVar dd1}
602 :     val t2 = Ty.T_Tensor(Ty.ShapeVar s2)
603 :     val t3 = Ty.T_Field{diff = k0, dim = d', shape = Ty.ShapeVar dd3}
604 :     in
605 :     [t1, t2] --> t3
606 :     end))
607 : jhr 3391 val op_colon_tf = polyVar (N.op_colon, all([DK,SK,NK,SK,SK],
608 : jhr 3398 fn [Ty.DIFF k,Ty.SHAPE s1, Ty.DIM d, Ty.SHAPE dd2,Ty.SHAPE dd3] => let
609 : jhr 3392 val k0 = Ty.DiffVar(k, 0)
610 :     val d' = Ty.DimVar d
611 :     val t1 = Ty.T_Tensor(Ty.ShapeVar s1)
612 :     val t2 = Ty.T_Field{diff = k0, dim = d', shape = Ty.ShapeVar dd2}
613 :     val t3 = Ty.T_Field{diff = k0, dim = d', shape = Ty.ShapeVar dd3}
614 :     in
615 :     [t1,t2] --> t3
616 :     end))
617 : jhr 3391
618 :     (* image size operation *)
619 :     val fn_size = polyVar (N.fn_size, all([NK, SK],
620 : jhr 3392 fn [Ty.DIM d, Ty.SHAPE dd] => let
621 : jhr 3391 val d = Ty.DimVar d
622 :     val dd = Ty.ShapeVar dd
623 :     in
624 : jhr 3398 [Ty.T_Image{dim=d, shape=dd}] --> Ty.T_Sequence(Ty.T_Int, SOME d)
625 : jhr 3391 end))
626 :    
627 :     (* functions that handle the boundary behavior of an image *)
628 :     local
629 :     fun img2img f = polyVar (f, all([NK, SK],
630 : jhr 3392 fn [Ty.DIM d, Ty.SHAPE dd] => let
631 :     val imgTy = Ty.T_Image{dim=Ty.DimVar d, shape=Ty.ShapeVar dd}
632 : jhr 3391 in
633 :     [imgTy] --> imgTy
634 :     end))
635 :     in
636 :     val image_border = polyVar (N.fn_border, all([NK, SK],
637 : jhr 3392 fn [Ty.DIM d, Ty.SHAPE dd] => let
638 : jhr 3391 val d = Ty.DimVar d
639 :     val dd = Ty.ShapeVar dd
640 :     in
641 :     [Ty.T_Image{dim=d, shape=dd}, Ty.T_Tensor dd]
642 : jhr 3392 --> Ty.T_Image{dim=d, shape=dd}
643 : jhr 3391 end))
644 :     val image_clamp = img2img N.fn_clamp
645 :     val image_mirror = img2img N.fn_mirror
646 :     val image_wrap = img2img N.fn_wrap
647 :     end (* local *)
648 :    
649 :     (* is a point inside the domain of a field? *)
650 :     val fn_inside = polyVar (N.fn_inside, all([DK, NK, SK],
651 :     fn [Ty.DIFF k, Ty.DIM d, Ty.SHAPE dd] => let
652 :     val k = Ty.DiffVar(k, 0)
653 :     val d = Ty.DimVar d
654 :     val dd = Ty.ShapeVar dd
655 :     in
656 :     [Ty.T_Tensor(Ty.Shape[d]), field(k, d, dd)]
657 :     --> Ty.T_Bool
658 :     end))
659 :    
660 :     val fn_length = polyVar (N.fn_length, all([TK],
661 : jhr 3398 fn [Ty.TYPE tv] => [dynSeq(Ty.T_Var tv)] --> Ty.T_Int))
662 : jhr 3391
663 : jhr 3482 val fn_max_i = monoVar (N.fn_max, [Ty.T_Int, Ty.T_Int] --> Ty.T_Int)
664 :     val fn_max_r = monoVar (N.fn_max, [Ty.realTy, Ty.realTy] --> Ty.realTy)
665 : jhr 5307
666 : jhr 3482 val fn_min_i = monoVar (N.fn_min, [Ty.T_Int, Ty.T_Int] --> Ty.T_Int)
667 :     val fn_min_r = monoVar (N.fn_min, [Ty.realTy, Ty.realTy] --> Ty.realTy)
668 : jhr 3391
669 : cchiw 4700 val fn_modulate_tt = polyVar (N.fn_modulate, all([NK],
670 :     fn [Ty.DIM dd] => let
671 :     (*val t = Ty.T_Tensor(Ty.ShapeVar dd)*)
672 :     val t = tensor [Ty.DimVar dd]
673 : jhr 4286 in
674 :     [t, t] --> t
675 :     end))
676 : jhr 3391
677 : cchiw 4700 val fn_modulate_tf = polyVar (N.fn_modulate, all([DK ,NK, NK],
678 :     fn [Ty.DIFF k, Ty.DIM d, Ty.DIM dd] => let
679 : jhr 4286 val k = Ty.DiffVar(k, 0)
680 :     val d = Ty.DimVar d
681 : cchiw 4700 val t1 = tensor [Ty.DimVar dd]
682 :     val t2 = Ty.T_Field{diff = k, dim = d, shape = Ty.Shape([Ty.DimVar dd])}
683 : jhr 4286 in
684 :     [t1, t2] --> t2
685 :     end))
686 : cchiw 4281
687 : cchiw 4700 val fn_modulate_ft = polyVar (N.fn_modulate, all([DK, NK, NK],
688 :     fn [Ty.DIFF k, Ty.DIM d, Ty.DIM dd] => let
689 : jhr 4286 val k = Ty.DiffVar(k, 0)
690 :     val d = Ty.DimVar d
691 : cchiw 4700 val t1 = Ty.T_Field{diff = k, dim = d, shape = Ty.Shape([Ty.DimVar dd])}
692 :     val t2 = tensor [Ty.DimVar dd]
693 : jhr 4286 in
694 :     [t1, t2] --> t1
695 :     end))
696 : cchiw 4281
697 : cchiw 4700 val fn_modulate_ff = polyVar (N.fn_modulate, all([DK, NK, NK],
698 :     fn [Ty.DIFF k, Ty.DIM d, Ty.DIM dd] => let
699 : jhr 4286 val k0 = Ty.DiffVar(k, 0)
700 :     val d' = Ty.DimVar d
701 : cchiw 4700 val f1 = Ty.T_Field{diff = k0, dim = d', shape =Ty.Shape([Ty.DimVar dd])}
702 : jhr 4286 in
703 :     [f1,f1] --> f1
704 :     end))
705 : cchiw 4281
706 : jhr 3478 val fn_normalize_t = polyVar (N.fn_normalize, all([SK],
707 :     fn [Ty.SHAPE dd] => let
708 :     val t = Ty.T_Tensor(Ty.ShapeVar dd)
709 : jhr 3392 in
710 :     [t] --> t
711 :     end))
712 : jhr 3391 val fn_normalize_f = polyVar (N.fn_normalize, all([DK,NK,SK],
713 :     fn [Ty.DIFF k, Ty.DIM d, Ty.SHAPE dd1] => let
714 : jhr 3392 val k0 = Ty.DiffVar(k, 0)
715 :     val d' = Ty.DimVar d
716 :     val f1 = Ty.T_Field{diff = k0, dim = d', shape = Ty.ShapeVar dd1}
717 :     in
718 :     [f1] --> f1
719 :     end))
720 : jhr 3391
721 :     val fn_trace_t = polyVar (N.fn_trace, all([NK],
722 :     fn [Ty.DIM d] => [matrix(Ty.DimVar d)] --> Ty.realTy))
723 :    
724 : jhr 4286 (* generalized dimension *)
725 :     val fn_trace_f = polyVar (N.fn_trace, all([DK, NK, NK, SK],
726 :     fn [Ty.DIFF k, Ty.DIM d1, Ty.DIM d2, Ty.SHAPE dd1] => let
727 :     val k' = Ty.DiffVar(k, 0)
728 :     val dim1 = Ty.DimVar d1
729 :     val dim2 = Ty.DimVar d2
730 :     val dshape = Ty.ShapeVar dd1
731 :     val f = field(k', dim1, Ty.ShapeExt(Ty.ShapeExt(dshape, dim2), dim2))
732 :     val h = field(k', dim1, dshape)
733 :     in
734 :     [f] --> h
735 :     end))
736 : cchiw 4273
737 : jhr 3391 val fn_transpose_t = polyVar (N.fn_transpose, all([NK, NK],
738 :     fn [Ty.DIM d1, Ty.DIM d2] =>
739 :     [tensor[Ty.DimVar d1, Ty.DimVar d2]] --> tensor[Ty.DimVar d2, Ty.DimVar d1]))
740 : jhr 4286
741 : jhr 3807 val fn_transpose_f = polyVar (N.fn_transpose, all([DK,NK,NK,NK],
742 : jhr 3391 fn [Ty.DIFF k, Ty.DIM d,Ty.DIM a, Ty.DIM b] => let
743 :     val k0 = Ty.DiffVar(k, 0)
744 :     val d' = Ty.DimVar d
745 :     val a' = Ty.DimVar a
746 :     val b' = Ty.DimVar b
747 :     val f = field(k0, d', Ty.Shape[a',b'])
748 :     val h = field(k0, d', Ty.Shape[b',a'])
749 :     in
750 :     [f] --> h
751 :     end))
752 :    
753 :     (* determinant: restrict to 2x2 and 3x3*)
754 :     val fn_det2_t = monoVar (N.fn_det, [matrix N2] --> Ty.realTy)
755 : jhr 4286
756 : jhr 3391 val fn_det3_t = monoVar (N.fn_det, [matrix N3] --> Ty.realTy)
757 : jhr 4286
758 :     val fn_det2_f = polyVar (N.fn_det, all([DK, NK], fn [Ty.DIFF k, Ty.DIM d] => let
759 :     fun field' (k, dd) = field(k, Ty.DimVar d, Ty.Shape(List.map Ty.DimConst dd))
760 : jhr 3392 val k0 = Ty.DiffVar(k, 0)
761 : jhr 4286 val f = field' (k0, [2, 2])
762 :     val s = field' (k0, [])
763 : jhr 3392 in
764 :     [f] --> s
765 :     end))
766 : cchiw 4274
767 : jhr 4286 val fn_det3_f = polyVar (N.fn_det, all([DK, NK], fn [Ty.DIFF k, Ty.DIM d] => let
768 :     fun field' (k, dd) = field(k, Ty.DimVar d, Ty.Shape(List.map Ty.DimConst dd))
769 : jhr 3392 val k0 = Ty.DiffVar(k, 0)
770 : jhr 4286 val f = field' (k0, [3,3])
771 :     val s = field' (k0, [])
772 : jhr 3392 in
773 :     [f] --> s
774 :     end))
775 : jhr 3391
776 : jhr 5421 (* inverse: restrict to 1x1, 2x2, and 3x3 shapes *)
777 :     val fn_inv1_t = monoVar (N.fn_inv, [Ty.realTy] --> Ty.realTy)
778 :    
779 : jhr 4414 val fn_inv2_t = let
780 :     val t = matrix N2
781 :     in
782 :     monoVar (N.fn_inv, [t] --> t)
783 :     end
784 :    
785 : cchiw 5407 val fn_inv3_t = let
786 : jhr 5421 val t = matrix N3
787 :     in
788 :     monoVar (N.fn_inv, [t] --> t)
789 :     end
790 : cchiw 5407
791 :     val fn_inv1_f = polyVar (N.fn_inv, all([DK, NK], fn [Ty.DIFF k, Ty.DIM dim] => let
792 : jhr 5421 fun field' (k, d) = field(k, Ty.DimVar d, Ty.Shape([]))
793 :     val k0 = Ty.DiffVar(k, 0)
794 :     val f = field' (k0, dim)
795 :     in
796 :     [f] --> f
797 :     end))
798 : cchiw 5407
799 : jhr 4414 val fn_inv2_f = polyVar (N.fn_inv, all([DK, NK], fn [Ty.DIFF k, Ty.DIM dim] => let
800 : jhr 5421 fun field' (k, d, dd) = field(k, Ty.DimVar d, Ty.Shape(List.map Ty.DimConst dd))
801 :     val k0 = Ty.DiffVar(k, 0)
802 :     val f = field' (k0, dim, [2, 2])
803 :     in
804 :     [f] --> f
805 :     end))
806 : cchiw 4409
807 : jhr 4414 val fn_inv3_f = polyVar (N.fn_inv, all([DK, NK], fn [Ty.DIFF k, Ty.DIM dim] => let
808 : jhr 5421 fun field' (k, d, dd) = field(k, Ty.DimVar d, Ty.Shape(List.map Ty.DimConst dd))
809 :     val k0 = Ty.DiffVar(k, 0)
810 :     val f = field' (k0, dim, [3, 3])
811 :     in
812 :     [f] --> f
813 :     end))
814 : cchiw 4409
815 :    
816 : jhr 3482 (* lifted unary math functions; these have both real and scalar-field forms *)
817 : jhr 3478 local
818 : jhr 3482 fun fn_r name = monoVar (name, [Ty.realTy] --> Ty.realTy)
819 :     fun fn_s name = polyVar (N.fn_sqrt, all([DK,NK], fn [Ty.DIFF k, Ty.DIM d] => let
820 : jhr 3519 val k' = Ty.DiffVar(k, 0)
821 :     val d' = Ty.DimVar d
822 :     val f = field(k', d', Ty.Shape[])
823 :     in
824 :     [f] --> f
825 :     end))
826 : jhr 3478 in
827 : jhr 3482 val fn_sqrt_r = fn_r N.fn_sqrt
828 :     val fn_sqrt_s = fn_s N.fn_sqrt
829 :     val fn_cos_r = fn_r N.fn_cos
830 :     val fn_cos_s = fn_s N.fn_cos
831 :     val fn_acos_r = fn_r N.fn_acos
832 :     val fn_acos_s = fn_s N.fn_acos
833 :     val fn_sin_r = fn_r N.fn_sin
834 :     val fn_sin_s = fn_s N.fn_sin
835 :     val fn_asin_r = fn_r N.fn_asin
836 :     val fn_asin_s = fn_s N.fn_asin
837 :     val fn_tan_r = fn_r N.fn_tan
838 :     val fn_tan_s = fn_s N.fn_tan
839 :     val fn_atan_r = fn_r N.fn_atan
840 :     val fn_atan_s = fn_s N.fn_atan
841 :     val fn_exp_r = fn_r N.fn_exp
842 :     val fn_exp_s = fn_s N.fn_exp
843 : jhr 3478 end (* local *)
844 : jhr 3391
845 : jhr 3398 (* Math functions that have not yet been lifted to work on fields *)
846 :     local
847 :     fun mk (name, n) =
848 : jhr 3519 monoVar(name, List.tabulate(n, fn _ => Ty.realTy) --> Ty.realTy)
849 : jhr 3398 in
850 : jhr 3482 val fn_atan2_rr = mk(N.fn_atan2, 2)
851 :     val fn_ceil_r = mk(N.fn_ceil, 1)
852 : jhr 4298 val fn_erf_r = mk(N.fn_erf, 1)
853 :     val fn_erfc_r = mk(N.fn_erfc, 1)
854 : jhr 3482 val fn_floor_r = mk(N.fn_floor, 1)
855 :     val fn_fmod_rr = mk(N.fn_fmod, 2)
856 :     val fn_log_r = mk(N.fn_log, 1)
857 :     val fn_log10_r = mk(N.fn_log10, 1)
858 :     val fn_log2_r = mk(N.fn_log2, 1)
859 : jhr 3511 val fn_pow_rr = mk(N.fn_pow, 2)
860 : jhr 4298 val fn_round_r = mk(N.fn_round, 1)
861 :     val fn_trunc_r = mk(N.fn_trunc, 1)
862 : jhr 3398 end (* local *)
863 :    
864 : jhr 3807 (* Query functions *)
865 : jhr 3392 local
866 : jhr 3398 val implicit = fn [Ty.TYPE tv] => [Ty.realTy] --> dynSeq(Ty.T_Var tv)
867 :     val realTy = fn [Ty.TYPE tv] => [Ty.realTy, Ty.realTy] --> dynSeq(Ty.T_Var tv)
868 : jhr 3392 val vec2Ty = let
869 :     val t = tensor[N2]
870 :     in
871 : jhr 3398 fn [Ty.TYPE tv] => [t, Ty.realTy] --> dynSeq(Ty.T_Var tv)
872 : jhr 3392 end
873 :     val vec3Ty = let
874 :     val t = tensor[N3]
875 :     in
876 : jhr 3398 fn [Ty.TYPE tv] => [t, Ty.realTy] --> dynSeq(Ty.T_Var tv)
877 : jhr 3392 end
878 :     in
879 : jhr 4349 val fn_sphere_im = polyVar (N.fn_sphere, all([TK], fn [Ty.TYPE tv] =>
880 : jhr 4414 [Ty.realTy] --> dynSeq(Ty.T_Var tv)))
881 : jhr 4349 (* queries with an explicit position *)
882 :     val fn_sphere1_r = polyVar (N.fn_sphere, all([TK], fn [Ty.TYPE tv] =>
883 : jhr 4414 [Ty.realTy, Ty.realTy] --> dynSeq(Ty.T_Var tv)))
884 : jhr 4359 val fn_sphere2_t = polyVar (N.fn_sphere, all([TK], fn [Ty.TYPE tv] =>
885 : jhr 4414 [tensor[N2], Ty.realTy] --> dynSeq(Ty.T_Var tv)))
886 : jhr 4349 val fn_sphere3_t = polyVar (N.fn_sphere, all([TK], fn [Ty.TYPE tv] =>
887 : jhr 4414 [tensor[N3], Ty.realTy] --> dynSeq(Ty.T_Var tv)))
888 : jhr 3392 end (* local *)
889 :    
890 : jhr 3431 (* Sets of strands *)
891 :     local
892 :     fun mkSetFn name = polyVar (name, all([TK], fn [Ty.TYPE tv] => [] --> dynSeq(Ty.T_Var tv)))
893 :     in
894 :     val set_active = mkSetFn N.set_active
895 : jhr 3463 val set_all = mkSetFn N.set_all
896 : jhr 3431 val set_stable = mkSetFn N.set_stable
897 :     end
898 :    
899 : jhr 4588 (* functions for getting the number of strands in a set *)
900 :     local
901 :     fun mkNumberOf name = monoVar (name, [] --> Ty.T_Int)
902 :     in
903 :     val fn_numActive = mkNumberOf N.fn_numActive
904 :     val fn_numStable = mkNumberOf N.fn_numStable
905 :     val fn_numStrands = mkNumberOf N.fn_numStrands
906 :     end (* local *)
907 :    
908 : jhr 3392 (* reduction operators *)
909 :     local
910 :     fun reduction (name, elemTy) =
911 : jhr 3519 monoVar (name, [dynSeq elemTy] --> elemTy)
912 : jhr 3392 in
913 : jhr 3519 val red_all = reduction (N.fn_all, Ty.T_Bool)
914 :     val red_exists = reduction (N.fn_exists, Ty.T_Bool)
915 : jhr 4589 val red_max_i = reduction (N.fn_max, Ty.T_Int)
916 :     val red_max_r = reduction (N.fn_max, Ty.realTy)
917 : jhr 3519 val red_mean = reduction (N.fn_mean, Ty.realTy)
918 : jhr 4589 val red_min_i = reduction (N.fn_min, Ty.T_Int)
919 :     val red_min_r = reduction (N.fn_min, Ty.realTy)
920 : jhr 4588 val red_product_i = reduction (N.fn_product, Ty.T_Int)
921 :     val red_product_r = reduction (N.fn_product, Ty.realTy)
922 :     (* FIXME: allow sum on tensor types *)
923 :     val red_sum_i = reduction (N.fn_sum, Ty.T_Int)
924 :     val red_sum_r = reduction (N.fn_sum, Ty.realTy)
925 : jhr 3519 val red_variance = reduction (N.fn_variance, Ty.realTy)
926 : jhr 3392 end (* local *)
927 :    
928 : jhr 3391 (***** internal variables *****)
929 :    
930 : jhr 4496 (* load image from nrrd *)
931 :     val fn_load_image = polyVar (Atom.atom "$load_image", all([NK, SK],
932 :     fn [Ty.DIM d, Ty.SHAPE dd] => let
933 :     val d = Ty.DimVar d
934 :     val dd = Ty.ShapeVar dd
935 :     in
936 :     [Ty.T_String] --> Ty.T_Image{dim=d, shape=dd}
937 :     end))
938 :    
939 :     (* load dynamic sequence from nrrd *)
940 :     val fn_load_sequence = polyVar (Atom.atom "$load_seqeunce", all([TK],
941 :     fn [Ty.TYPE tv] => [Ty.T_String] --> dynSeq(Ty.T_Var tv)))
942 :    
943 : jhr 3391 (* integer to real conversion *)
944 :     val i2r = monoVar (Atom.atom "$i2r", [Ty.T_Int] --> Ty.realTy)
945 :    
946 :     (* identity matrix *)
947 :     val identity = polyVar (Atom.atom "$id", allNK (fn dv => [] --> matrix(Ty.DimVar dv)))
948 :    
949 :     (* zero tensor *)
950 :     val zero = polyVar (Atom.atom "$zero", all ([SK],
951 : jhr 3392 fn [Ty.SHAPE dd] => [] --> Ty.T_Tensor(Ty.ShapeVar dd)))
952 : jhr 3391
953 :     (* NaN tensor *)
954 :     val nan = polyVar (Atom.atom "$nan", all ([SK],
955 : jhr 3392 fn [Ty.SHAPE dd] => [] --> Ty.T_Tensor(Ty.ShapeVar dd)))
956 : jhr 3391
957 :     (* sequence subscript *)
958 :     val subscript = polyVar (Atom.atom "$sub", all ([TK, NK],
959 : jhr 3392 fn [Ty.TYPE tv, Ty.DIM d] =>
960 : jhr 3398 [Ty.T_Sequence(Ty.T_Var tv, SOME(Ty.DimVar d)), Ty.T_Int] --> Ty.T_Var tv))
961 : jhr 3391
962 :     val dynSubscript = polyVar (Atom.atom "$dynsub", all ([TK],
963 : jhr 3398 fn [Ty.TYPE tv] => [dynSeq(Ty.T_Var tv), Ty.T_Int] --> Ty.T_Var tv))
964 : jhr 3391
965 : jhr 4043 (* range expressions *)
966 : jhr 3398 val range = monoVar (Atom.atom "$range", [Ty.T_Int, Ty.T_Int] --> dynSeq Ty.T_Int)
967 : jhr 3391
968 : jhr 4043 (* boolean and *)
969 :     val and_b = monoVar (Atom.atom "$and", [Ty.T_Bool, Ty.T_Bool] --> Ty.T_Bool)
970 :    
971 : jhr 3391 end (* local *)
972 : jhr 4359
973 : jhr 3391 end

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0