SCM Repository
Annotation of /trunk/src/compiler/basis/basis-vars.sml
Parent Directory
|
Revision Log
Revision 185 - (view) (download)
(* basis-vars.sml
 *
 * COPYRIGHT (c) 2010 The Diderot Project (http://diderot.cs.uchicago.edu)
 * All rights reserved.
 *
 * This module defines the AST variables for the built-in operators and functions.
 *)

structure BasisVars =
  struct

    local
      structure N = BasisNames
      structure Ty = Types
      structure MV = MetaVar

    (* `tys --> ty` builds the function type whose argument types are `tys` and
     * whose result type is `ty`; it is declared infix so that the signatures
     * below read like arrows.
     *)
      fun --> (tys1, ty) = Ty.T_Fun(tys1, ty)
      infix -->

    (* constant dimensions *)
      val N2 = Ty.DimConst 2
      val N3 = Ty.DimConst 3

    (* short names for kinds: each call allocates a fresh meta variable of the
     * corresponding kind (type, differentiation, shape, or dimension) and wraps
     * it as a Ty.meta_var.
     *)
      val TK : unit -> Ty.meta_var = Ty.TYPE o MV.newTyVar
      fun DK () = Ty.DIFF(MV.newDiffVar 0)
      val SK : unit -> Ty.meta_var = Ty.SHAPE o MV.newShapeVar
      val NK : unit -> Ty.meta_var = Ty.DIM o MV.newDimVar

    (* a type scheme that binds no meta variables *)
      fun ty t = ([], t)

    (* all (kinds, mkTy) builds a polymorphic type scheme.  It calls each kind
     * constructor in `kinds` (in order) to get fresh meta variables and passes
     * that list to `mkTy` to construct the type.  Every `mkTy` below pattern
     * matches the list against the kinds it was given, so the two must agree.
     * The result pairs the bound meta variables with the constructed type.
     *)
      fun all (kinds, mkTy : Ty.meta_var list -> Ty.ty) = let
            val tvs = List.map (fn mk => mk()) kinds
            in
              (tvs, mkTy tvs)
            end

    (* allNK mkTy is a special case of `all` for schemes that bind a single
     * dimension variable; `mkTy` receives the raw dimension variable (it is
     * wrapped in Ty.DIM only in the list of bound variables).
     *)
      fun allNK mkTy = let
            val tv = MV.newDimVar()
            in
              ([Ty.DIM tv], mkTy tv)
            end

    (* type constructors: field (k, d, dd) is a field type with differentiation
     * level k, dimension d, and shape dd; tensor ds is the tensor type whose
     * shape is the list of dimensions ds.
     *)
      fun field (k, d, dd) = Ty.T_Field{diff=k, dim=d, shape=dd}
      fun tensor ds = Ty.T_Tensor(Ty.Shape ds)

    (* create a basis variable from a monomorphic type or a polymorphic scheme *)
      fun monoVar (name, ty) = Var.new (name, AST.BasisVar, ty)
      fun polyVar (name, scheme) = Var.newPoly (name, AST.BasisVar, scheme)
    in

  (* TODO: I'm not sure how to extend + and - to fields, since the typing rules should allow
   * two fields with different differentiation levels to be added.
   *)

  (* overloaded operators; the naming convention is to use the operator name followed
   * by the argument type signature, where
   *    i  -- int
   *    b  -- bool
   *    s  -- string
   *    r  -- real (tensor[])
   *    t  -- tensor[shape]
   *    f  -- field
   *)

  (* addition and subtraction *)
    val add_ii = monoVar(N.op_add, [Ty.T_Int, Ty.T_Int] --> Ty.T_Int)
    val add_tt = polyVar(N.op_add, all([SK], fn [Ty.SHAPE dd] => let
          val t = Ty.T_Tensor(Ty.ShapeVar dd)
          in
            [t, t] --> t
          end))

    val sub_ii = monoVar(N.op_sub, [Ty.T_Int, Ty.T_Int] --> Ty.T_Int)
    val sub_tt = polyVar(N.op_sub, all([SK], fn [Ty.SHAPE dd] => let
          val t = Ty.T_Tensor(Ty.ShapeVar dd)
          in
            [t, t] --> t
          end))

  (* note that we assume that operators are tested in the order defined here, so that mul_rr
   * takes precedence over mul_rt and mul_tr!
   *)
    val mul_ii = monoVar(N.op_mul, [Ty.T_Int, Ty.T_Int] --> Ty.T_Int)
    val mul_rr = monoVar(N.op_mul, [Ty.realTy, Ty.realTy] --> Ty.realTy)
    val mul_rt = polyVar(N.op_mul, all([SK], fn [Ty.SHAPE dd] => let
          val t = Ty.T_Tensor(Ty.ShapeVar dd)
          in
            [Ty.realTy, t] --> t
          end))
    val mul_tr = polyVar(N.op_mul, all([SK], fn [Ty.SHAPE dd] => let
          val t = Ty.T_Tensor(Ty.ShapeVar dd)
          in
            [t, Ty.realTy] --> t
          end))

  (* division; only the tensor-by-scalar form is provided for tensors *)
    val div_ii = monoVar(N.op_div, [Ty.T_Int, Ty.T_Int] --> Ty.T_Int)
    val div_rr = monoVar(N.op_div, [Ty.realTy, Ty.realTy] --> Ty.realTy)
    val div_tr = polyVar(N.op_div, all([SK], fn [Ty.SHAPE dd] => let
          val t = Ty.T_Tensor(Ty.ShapeVar dd)
          in
            [t, Ty.realTy] --> t
          end))

  (* ordering comparisons *)
    val lt_ii = monoVar(N.op_lt, [Ty.T_Int, Ty.T_Int] --> Ty.T_Bool)
    val lt_rr = monoVar(N.op_lt, [Ty.realTy, Ty.realTy] --> Ty.T_Bool)
    val lte_ii = monoVar(N.op_lte, [Ty.T_Int, Ty.T_Int] --> Ty.T_Bool)
    val lte_rr = monoVar(N.op_lte, [Ty.realTy, Ty.realTy] --> Ty.T_Bool)
    val gte_ii = monoVar(N.op_gte, [Ty.T_Int, Ty.T_Int] --> Ty.T_Bool)
    val gte_rr = monoVar(N.op_gte, [Ty.realTy, Ty.realTy] --> Ty.T_Bool)
    val gt_ii = monoVar(N.op_gt, [Ty.T_Int, Ty.T_Int] --> Ty.T_Bool)
    val gt_rr = monoVar(N.op_gt, [Ty.realTy, Ty.realTy] --> Ty.T_Bool)

  (* equality and inequality *)
    val equ_bb = monoVar(N.op_equ, [Ty.T_Bool, Ty.T_Bool] --> Ty.T_Bool)
    val equ_ii = monoVar(N.op_equ, [Ty.T_Int, Ty.T_Int] --> Ty.T_Bool)
    val equ_ss = monoVar(N.op_equ, [Ty.T_String, Ty.T_String] --> Ty.T_Bool)
    val equ_rr = monoVar(N.op_equ, [Ty.realTy, Ty.realTy] --> Ty.T_Bool)
    val neq_bb = monoVar(N.op_neq, [Ty.T_Bool, Ty.T_Bool] --> Ty.T_Bool)
    val neq_ii = monoVar(N.op_neq, [Ty.T_Int, Ty.T_Int] --> Ty.T_Bool)
    val neq_ss = monoVar(N.op_neq, [Ty.T_String, Ty.T_String] --> Ty.T_Bool)
    val neq_rr = monoVar(N.op_neq, [Ty.realTy, Ty.realTy] --> Ty.T_Bool)

  (* unary negation; every variant maps a single argument to a value of the
   * same type.  (neg_i was previously typed [int, int] -> bool, which made
   * integer negation ill-typed.)
   *)
    val neg_i = monoVar(N.op_neg, [Ty.T_Int] --> Ty.T_Int)
    val neg_t = polyVar(N.op_neg, all([SK],
          fn [Ty.SHAPE dd] => let
            val t = Ty.T_Tensor(Ty.ShapeVar dd)
            in
              [t] --> t
            end))
    val neg_f = polyVar(N.op_neg, all([DK, NK, SK],
          fn [Ty.DIFF k, Ty.DIM d, Ty.SHAPE dd] => let
            val k = Ty.DiffVar(k, 0)
            val d = Ty.DimVar d
            val dd = Ty.ShapeVar dd
            in
              [field(k, d, dd)] --> field(k, d, dd)
            end))


  (***** non-overloaded operators, etc. *****)

  (* probe: field(k, d, dd) @ tensor[d] --> tensor[dd] *)
    val op_at = polyVar (N.op_at, all([DK, NK, SK],
          fn [Ty.DIFF k, Ty.DIM d, Ty.SHAPE dd] => let
            val k = Ty.DiffVar(k, 0)
            val d = Ty.DimVar d
            val dd = Ty.ShapeVar dd
            in
              [field(k, d, dd), tensor[d]] --> Ty.T_Tensor dd
            end))

  (* differentiation: the result field's differentiation variable is offset
   * by ~1 from the argument's (km1 presumably denotes k-1), and its shape is
   * the argument's shape extended by the dimension d.
   *)
    val op_D = polyVar (N.op_D, all([DK, NK, SK],
          fn [Ty.DIFF k, Ty.DIM d, Ty.SHAPE dd] => let
            val k0 = Ty.DiffVar(k, 0)
            val km1 = Ty.DiffVar(k, ~1)
            val d = Ty.DimVar d
            val dd = Ty.ShapeVar dd
            in
              [field(k0, d, dd)]
                --> field(km1, d, Ty.ShapeExt(dd, d))
            end))

  (* norm of a tensor of any shape, yielding a real *)
    val op_norm = polyVar (N.op_norm, all([SK],
          fn [Ty.SHAPE dd] => [Ty.T_Tensor(Ty.ShapeVar dd)] --> Ty.realTy))

    val op_not = monoVar (N.op_not, [Ty.T_Bool] --> Ty.T_Bool)

  (* subscripting with an int peels off the last dimension d of the shape *)
    val op_subscript = polyVar (N.op_subscript, all([SK, NK],
          fn [Ty.SHAPE dd, Ty.DIM d] => let
            val dd = Ty.ShapeVar dd
            val d = Ty.DimVar d
            in
              [Ty.T_Tensor(Ty.ShapeExt(dd, d)), Ty.T_Int]
                --> Ty.T_Tensor dd
            end))

  (* functions *)
    val fn_CL = monoVar (N.fn_CL, [tensor[N3, N3]] --> Ty.realTy)

  (* convolving a kernel with an image yields a field whose differentiation
   * level is that of the kernel and whose dimension/shape are the image's
   *)
    val fn_convolve = polyVar (N.fn_convolve, all([DK, NK, SK],
          fn [Ty.DIFF k, Ty.DIM d, Ty.SHAPE dd] => let
            val k = Ty.DiffVar(k, 0)
            val d = Ty.DimVar d
            val dd = Ty.ShapeVar dd
            in
              [Ty.T_Kernel k, Ty.T_Image{dim=d, shape=dd}]
                --> field(k, d, dd)
            end))

    val fn_cos = monoVar (N.fn_cos, [Ty.realTy] --> Ty.realTy)

  (* dot product of two vectors of the same dimension *)
    val fn_dot = polyVar (N.fn_dot, allNK(fn tv => let
          val t = tensor[Ty.DimVar tv]
          in
            [t, t] --> Ty.realTy
          end))

  (* test whether a position (a d-vector) lies inside a field's domain *)
    val fn_inside = polyVar (N.fn_inside, all([DK, NK, SK],
          fn [Ty.DIFF k, Ty.DIM d, Ty.SHAPE dd] => let
            val k = Ty.DiffVar(k, 0)
            val d = Ty.DimVar d
            val dd = Ty.ShapeVar dd
            in
              [tensor[d], field(k, d, dd)]
                --> Ty.T_Bool
            end))

  (* load an image; the string argument is presumably a file name *)
    val fn_load = polyVar (N.fn_load, all([NK, SK],
          fn [Ty.DIM d, Ty.SHAPE dd] => let
            val d = Ty.DimVar d
            val dd = Ty.ShapeVar dd
            in
              [Ty.T_String] --> Ty.T_Image{dim=d, shape=dd}
            end))

    val fn_max = monoVar (N.fn_max, [Ty.realTy, Ty.realTy] --> Ty.realTy)
    val fn_min = monoVar (N.fn_min, [Ty.realTy, Ty.realTy] --> Ty.realTy)

  (* combines two vectors of the same dimension into a third *)
    val fn_modulate = polyVar (N.fn_modulate, all([NK],
          fn [Ty.DIM d] => let
            val t = tensor[Ty.DimVar d]
            in
              [t, t] --> t
            end))

    val fn_pow = monoVar (N.fn_pow, [Ty.realTy, Ty.realTy] --> Ty.realTy)

  (* maps a d x d matrix to a d-vector *)
    val fn_principleEvec = polyVar (N.fn_principleEvec, all([NK],
          fn [Ty.DIM d] => let
            val d = Ty.DimVar d
            in
              [tensor[d, d]] --> tensor[d]
            end))

    val fn_sin = monoVar (N.fn_sin, [Ty.realTy] --> Ty.realTy)

  (* kernels *)
  (* FIXME: we should really get the continuity info from the kernels themselves *)
    val kn_bspln3 = monoVar (N.kn_bspln3, Ty.T_Kernel(Ty.DiffConst 2))
    val kn_bspln5 = monoVar (N.kn_bspln5, Ty.T_Kernel(Ty.DiffConst 4))
    val kn_ctmr = monoVar (N.kn_ctmr, Ty.T_Kernel(Ty.DiffConst 2))
    val kn_tent = monoVar (N.kn_tent, Ty.T_Kernel(Ty.DiffConst 0))

  (* internal variables; their names start with "$" *)
    val i2r = monoVar (Atom.atom "$i2r", [Ty.T_Int] --> Ty.realTy)  (* integer to real conversion *)
  (* $input takes a string (presumably the input's name) and yields a value of
   * any type; $optional-input additionally takes a value of that type, which is
   * presumably the default.
   *)
    val input = polyVar (Atom.atom "$input", all([TK],
          fn [Ty.TYPE tv] => [Ty.T_String] --> Ty.T_Var tv))
    val optInput = polyVar (Atom.atom "$optional-input", all([TK],
          fn [Ty.TYPE tv] => [Ty.T_String, Ty.T_Var tv] --> Ty.T_Var tv))
    end (* local *)
  end
root@smlnj-gforge.cs.uchicago.edu | ViewVC Help |
Powered by ViewVC 1.0.0 |