SCM Repository
Annotation of /trunk/src/basis/basis.sml
Parent Directory
|
Revision Log
Revision 68 - (view) (download)
(* basis.sml
 *
 * COPYRIGHT (c) 2010 The Diderot Project (http://diderot.cs.uchicago.edu)
 * All rights reserved.
 *
 * Type definitions for Basis functions.
 *)
8 : | |||
(* Basis: tables that associate the names defined in BasisNames with their types. *)
structure Basis =
  struct

    local
      structure N = BasisNames
      structure Ty = Types
      structure TV = TypeVar

    (* infix constructor for function types: `[t1, ..., tn] --> t` is the T_Fun
     * with parameter list [t1, ..., tn] and the singleton result list [t].
     *)
      fun --> (tys1, ty) = Ty.T_Fun(tys1, [ty])
      infix -->

    (* natural-number constants used in the signatures below *)
      val N2 = Ty.NatConst 2
      val N3 = Ty.NatConst 3

    (* short names for kinds *)
      val NK = Ty.TK_NAT
      val SK = Ty.TK_SHAPE
      val TK = Ty.TK_TYPE

    (* pair a type with an empty list of type variables *)
      fun ty t = ([], t)

    (* create one fresh type variable per kind (via TV.new) and pair the list of
     * variables with the type that mkTy builds from them.  Callers pass a
     * function whose list pattern has exactly as many elements as `kinds`.
     *)
      fun all (kinds, mkTy) = let
            val tvs = List.map (fn k => TV.new k) kinds
            in
              (tvs, mkTy tvs)
            end

    (* variant of `all` for a single variable of kind NK; mkTy receives the
     * variable itself rather than a one-element list.
     *)
      fun allNK mkTy = let
            val tv = TV.new NK
            in
              ([tv], mkTy tv)
            end

    (* build a T_Field from its diff, dim, and shape components *)
      fun field (k, d, dd) = Ty.T_Field{diff=k, dim=d, shape=dd}

    (* build a T_Tensor whose shape is the given list of dimensions *)
      fun tensor ds = Ty.T_Tensor(Ty.Shape ds)

    in

  (* overloaded operators; no entries yet.  The commented-out names below are
   * kept from the original as the list of operators still to be given types.
   *)
    val overloads = [
(*
    val op_add = Atom.atom "+"
    val op_sub = Atom.atom "-"
    val op_mul = Atom.atom "*"
    val op_div = Atom.atom "/"
    val op_lt = Atom.atom "<"
    val op_lte = Atom.atom "<="
    val op_eql = Atom.atom "=="
    val op_neq = Atom.atom "!="
    val op_gte = Atom.atom ">="
    val op_gt = Atom.atom ">"
*)
          ]

  (* non-overloaded operators, etc. *)
    val basis = [
          (* operators *)
            (N.op_at, all([NK, NK, SK],
              fn [k, d, dd] => let
                  val k = Ty.NatVar k
                  val d = Ty.NatVar d
                  val dd = Ty.ShapeVar dd
                  in
                    [field(k, d, dd), tensor[d]] --> Ty.T_Tensor dd
                  end)),
          (* NOTE(review): this is a second entry keyed by N.op_at.  Its result
           * is a field built with Ty.NatExp(k, ~1) and Ty.ShapeExt(dd, d),
           * unlike the entry above, which returns a tensor.  Confirm which
           * operator name was intended for this entry.
           *)
            (N.op_at, all([NK, NK, SK],
              fn [k, d, dd] => let
                  val k = Ty.NatVar k
                  val d = Ty.NatVar d
                  val dd = Ty.ShapeVar dd
                  in
                    [field(k, d, dd), tensor[d]]
                      --> field(Ty.NatExp(k, ~1), d, Ty.ShapeExt(dd, d))
                  end)),
            (N.op_norm, all([SK],
              fn [dd] => [Ty.T_Tensor(Ty.ShapeVar dd)] --> Ty.realTy)),
          (* functions *)
            (N.fn_CL, ty([tensor[N3, N3]] --> Ty.vec3Ty)),
            (N.fn_convolve, all([NK, NK, SK],
              fn [k, d, dd] => let
                  val k = Ty.NatVar k
                  val d = Ty.NatVar d
                  val dd = Ty.ShapeVar dd
                  in
                    [Ty.T_Kernel k, Ty.T_Image{dim=d, shape=dd}]
                      --> field(k, d, dd)
                  end)),
            (N.fn_cos, ty([Ty.realTy] --> Ty.realTy)),
          (* dot product: two vectors of the same length yield a real.  The
           * previous entry took a single vector and returned a vector.
           *)
            (N.fn_dot, allNK(fn tv => let
                  val t = tensor[Ty.NatVar tv]
                  in
                    [t, t] --> Ty.realTy
                  end)),
            (N.fn_inside, all([NK, NK, SK],
              fn [k, d, dd] => let
                  val k = Ty.NatVar k
                  val d = Ty.NatVar d
                  val dd = Ty.ShapeVar dd
                  in
                    [tensor[d], field(k, d, dd)]
                      --> Ty.T_Bool
                  end)),
            (N.fn_load, all([NK, SK],
              fn [d, dd] => let
                  val d = Ty.NatVar d
                  val dd = Ty.ShapeVar dd
                  in
                    [Ty.T_String] --> Ty.T_Image{dim=d, shape=dd}
                  end)),
            (N.fn_pow, ty([Ty.realTy, Ty.realTy] --> Ty.realTy)),
(*
    val fn_principleEvec = Atom.atom "principleEvec"
*)
            (N.fn_sin, ty([Ty.realTy] --> Ty.realTy)),
          (* kernels *)
            (N.kn_bspln3, ty(Ty.T_Kernel N2)),
            (N.kn_tent, ty(Ty.T_Kernel(Ty.NatConst 0)))
          ]

    end (* local *)
  end
root@smlnj-gforge.cs.uchicago.edu | ViewVC Help |
Powered by ViewVC 1.0.0 |