(* basis-vars.sml
 *
 * COPYRIGHT (c) 2010 The Diderot Project (http://diderot.cs.uchicago.edu)
 * All rights reserved.
 *
 * This module defines the AST variables for the built-in operators and functions.
 *)

structure BasisVars =
  struct

    local
      structure N = BasisNames
      structure Ty = Types
      structure MV = MetaVar

    (* (argTys --> resTy) builds a function type whose result list is the single type resTy *)
      fun --> (tys1, ty) = Ty.T_Fun(tys1, [ty])
      infix -->

      val N2 = Ty.DimConst 2
      val N3 = Ty.DimConst 3

    (* short names for kinds; each call allocates a fresh meta variable of that kind *)
      val TK : unit -> Ty.meta_var = Ty.TYPE o MV.newTyVar
      val DK : unit -> Ty.meta_var = Ty.DIFF o MV.newDiffVar
      val SK : unit -> Ty.meta_var = Ty.SHAPE o MV.newShapeVar
      val NK : unit -> Ty.meta_var = Ty.DIM o MV.newDimVar

    (* a type scheme with no bound meta variables *)
      fun ty t = ([], t)

    (* all (kinds, mkTy) allocates one fresh meta variable per kind constructor in
     * kinds (in list order) and pairs that list with the type mkTy builds from it.
     *)
      fun all (kinds, mkTy : Ty.meta_var list -> Ty.ty) = let
            val tvs = List.map (fn mk => mk()) kinds
            in
              (tvs, mkTy tvs)
            end

    (* a type scheme quantified over a single dimension variable; mkTy receives the
     * raw dimension meta variable (not wrapped in Ty.DIM).
     *)
      fun allNK mkTy = let
            val tv = MV.newDimVar()
            in
              ([Ty.DIM tv], mkTy tv)
            end

      fun field (k, d, dd) = Ty.T_Field{diff=k, dim=d, shape=dd}
      fun tensor ds = Ty.T_Tensor(Ty.Shape ds)

    (* create a basis variable with a monomorphic type / a polymorphic type scheme *)
      fun monoVar (name, ty) = Var.new (name, AST.BasisVar, ty)
      fun polyVar (name, scheme) = Var.newPoly (name, AST.BasisVar, scheme)
    in

    (* TODO: I'm not sure how to extend + and - to fields, since the typing rules should allow
     * two fields with different differentiation levels to be added.
     *)

    (* overloaded operators; the naming convention is to use the operator name followed
     * by the argument type signature, where
     *	i -- int
     *	b -- bool
     *	r -- real (tensor[])
     *	t -- tensor[shape]
     *	f -- field
     *)

      val add_ii = monoVar(N.op_add, [Ty.T_Int, Ty.T_Int] --> Ty.T_Int)
      val add_tt = polyVar(N.op_add, all([SK], fn [Ty.SHAPE dd] => let
            val t = Ty.T_Tensor(Ty.ShapeVar dd)
            in
              [t, t] --> t
            end))

      val sub_ii = monoVar(N.op_sub, [Ty.T_Int, Ty.T_Int] --> Ty.T_Int)
      val sub_tt = polyVar(N.op_sub, all([SK], fn [Ty.SHAPE dd] => let
            val t = Ty.T_Tensor(Ty.ShapeVar dd)
            in
              [t, t] --> t
            end))

    (* note that we assume that operators are tested in the order defined here, so that mul_rr
     * takes precedence over mul_rt and mul_tr!
     *)
      val mul_ii = monoVar(N.op_mul, [Ty.T_Int, Ty.T_Int] --> Ty.T_Int)
      val mul_rr = monoVar(N.op_mul, [Ty.realTy, Ty.realTy] --> Ty.realTy)
      val mul_rt = polyVar(N.op_mul, all([SK], fn [Ty.SHAPE dd] => let
            val t = Ty.T_Tensor(Ty.ShapeVar dd)
            in
              [Ty.realTy, t] --> t
            end))
      val mul_tr = polyVar(N.op_mul, all([SK], fn [Ty.SHAPE dd] => let
            val t = Ty.T_Tensor(Ty.ShapeVar dd)
            in
              [t, Ty.realTy] --> t
            end))

      val div_ii = monoVar(N.op_div, [Ty.T_Int, Ty.T_Int] --> Ty.T_Int)
      val div_rr = monoVar(N.op_div, [Ty.realTy, Ty.realTy] --> Ty.realTy)
      val div_tr = polyVar(N.op_div, all([SK], fn [Ty.SHAPE dd] => let
            val t = Ty.T_Tensor(Ty.ShapeVar dd)
            in
              [t, Ty.realTy] --> t
            end))

      val lt_ii = monoVar(N.op_lt, [Ty.T_Int, Ty.T_Int] --> Ty.T_Bool)
      val lt_rr = monoVar(N.op_lt, [Ty.realTy, Ty.realTy] --> Ty.T_Bool)
      val lte_ii = monoVar(N.op_lte, [Ty.T_Int, Ty.T_Int] --> Ty.T_Bool)
      val lte_rr = monoVar(N.op_lte, [Ty.realTy, Ty.realTy] --> Ty.T_Bool)
      val gte_ii = monoVar(N.op_gte, [Ty.T_Int, Ty.T_Int] --> Ty.T_Bool)
      val gte_rr = monoVar(N.op_gte, [Ty.realTy, Ty.realTy] --> Ty.T_Bool)
      val gt_ii = monoVar(N.op_gt, [Ty.T_Int, Ty.T_Int] --> Ty.T_Bool)
      val gt_rr = monoVar(N.op_gt, [Ty.realTy, Ty.realTy] --> Ty.T_Bool)

      val equ_bb = monoVar(N.op_equ, [Ty.T_Bool, Ty.T_Bool] --> Ty.T_Bool)
      val equ_ii = monoVar(N.op_equ, [Ty.T_Int, Ty.T_Int] --> Ty.T_Bool)
      val equ_ss = monoVar(N.op_equ, [Ty.T_String, Ty.T_String] --> Ty.T_Bool)
      val equ_rr = monoVar(N.op_equ, [Ty.realTy, Ty.realTy] --> Ty.T_Bool)
      val neq_bb = monoVar(N.op_neq, [Ty.T_Bool, Ty.T_Bool] --> Ty.T_Bool)
      val neq_ii = monoVar(N.op_neq, [Ty.T_Int, Ty.T_Int] --> Ty.T_Bool)
      val neq_ss = monoVar(N.op_neq, [Ty.T_String, Ty.T_String] --> Ty.T_Bool)
      val neq_rr = monoVar(N.op_neq, [Ty.realTy, Ty.realTy] --> Ty.T_Bool)

    (* unary negation: int -> int (was mistakenly typed as a binary int comparison) *)
      val neg_i = monoVar(N.op_neg, [Ty.T_Int] --> Ty.T_Int)
      val neg_t = polyVar(N.op_neg, all([SK],
            fn [Ty.SHAPE dd] => let
              val t = Ty.T_Tensor(Ty.ShapeVar dd)
              in
                [t] --> t
              end))
      val neg_f = polyVar(N.op_neg, all([DK, NK, SK],
            fn [Ty.DIFF k, Ty.DIM d, Ty.SHAPE dd] => let
              val k = Ty.DiffVar(k, 0)
              val d = Ty.DimVar d
              val dd = Ty.ShapeVar dd
              in
                [field(k, d, dd)] --> field(k, d, dd)
              end))


    (***** non-overloaded operators, etc. *****)

    (* F@x : probe a field F of dimension d at a position x : tensor[d] *)
      val op_at = polyVar (N.op_at, all([DK, NK, SK],
            fn [Ty.DIFF k, Ty.DIM d, Ty.SHAPE dd] => let
              val k = Ty.DiffVar(k, 0)
              val d = Ty.DimVar d
              val dd = Ty.ShapeVar dd
              in
                [field(k, d, dd), tensor[d]] --> Ty.T_Tensor dd
              end))

    (* D F : differentiation is a unary operator on fields; the result has one fewer
     * level of differentiability (k-1) and its shape is extended by the dimension d.
     *)
      val op_D = polyVar (N.op_D, all([DK, NK, SK],
            fn [Ty.DIFF k, Ty.DIM d, Ty.SHAPE dd] => let
              val k0 = Ty.DiffVar(k, 0)
              val km1 = Ty.DiffVar(k, ~1)
              val d = Ty.DimVar d
              val dd = Ty.ShapeVar dd
              in
                [field(k0, d, dd)]
                  --> field(km1, d, Ty.ShapeExt(dd, d))
              end))

      val op_norm = polyVar (N.op_norm, all([SK],
            fn [Ty.SHAPE dd] => [Ty.T_Tensor(Ty.ShapeVar dd)] --> Ty.realTy))

      val op_not = monoVar (N.op_not, [Ty.T_Bool] --> Ty.T_Bool)


    (* functions *)
      val fn_CL = polyVar (N.fn_CL, ty([tensor[N3, N3]] --> Ty.vec3Ty))

      val fn_convolve = polyVar (N.fn_convolve, all([DK, NK, SK],
            fn [Ty.DIFF k, Ty.DIM d, Ty.SHAPE dd] => let
              val k = Ty.DiffVar(k, 0)
              val d = Ty.DimVar d
              val dd = Ty.ShapeVar dd
              in
                [Ty.T_Kernel k, Ty.T_Image{dim=d, shape=dd}]
                  --> field(k, d, dd)
              end))

      val fn_cos = polyVar (N.fn_cos, ty([Ty.realTy] --> Ty.realTy))

    (* dot product of two vectors of the same dimension d, yielding a real *)
      val fn_dot = polyVar (N.fn_dot, allNK(fn tv => let
            val t = tensor[Ty.DimVar tv]
            in
              [t, t] --> Ty.realTy
            end))

      val fn_inside = polyVar (N.fn_inside, all([DK, NK, SK],
            fn [Ty.DIFF k, Ty.DIM d, Ty.SHAPE dd] => let
              val k = Ty.DiffVar(k, 0)
              val d = Ty.DimVar d
              val dd = Ty.ShapeVar dd
              in
                [Ty.T_Tensor(Ty.Shape[d]), field(k, d, dd)]
                  --> Ty.T_Bool
              end))

    (* load : string -> image(d)[dd]; the string presumably names an image file *)
      val fn_load = polyVar (N.fn_load, all([NK, SK],
            fn [Ty.DIM d, Ty.SHAPE dd] => let
              val d = Ty.DimVar d
              val dd = Ty.ShapeVar dd
              in
                [Ty.T_String] --> Ty.T_Image{dim=d, shape=dd}
              end))

      val fn_pow = polyVar (N.fn_pow, ty([Ty.realTy, Ty.realTy] --> Ty.realTy))

    (*
      val fn_principleEvec = Atom.atom "principleEvec"
    *)

      val fn_sin = polyVar (N.fn_sin, ty([Ty.realTy] --> Ty.realTy))

    (* kernels *)
      val kn_bspln3 = polyVar (N.kn_bspln3, ty(Ty.T_Kernel(Ty.DiffConst 2)))
      val kn_tent = polyVar (N.kn_tent, ty(Ty.T_Kernel(Ty.DiffConst 0)))

    end (* local *)

  end
|