Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] Annotation of /branches/lamont/src/compiler/basis/basis-vars.sml
ViewVC logotype

Annotation of /branches/lamont/src/compiler/basis/basis-vars.sml

Parent Directory Parent Directory | Revision Log Revision Log


Revision 247 - (view) (download)
Original Path: trunk/src/compiler/basis/basis-vars.sml

1 : jhr 79 (* basis-vars.sml
2 :     *
3 :     * COPYRIGHT (c) 2010 The Diderot Project (http://diderot.cs.uchicago.edu)
4 :     * All rights reserved.
5 :     *
6 :     * This module defines the AST variables for the built in operators and functions.
7 :     *)
8 :    
9 :     structure BasisVars =
10 :     struct
11 :     local
12 :     structure N = BasisNames
13 :     structure Ty = Types
14 :     structure MV = MetaVar
15 :    
16 : jhr 81 fun --> (tys1, ty) = Ty.T_Fun(tys1, ty)
17 : jhr 79 infix -->
18 :    
19 :     val N2 = Ty.DimConst 2
20 :     val N3 = Ty.DimConst 3
21 :    
22 :     (* short names for kinds *)
23 :     val TK : unit -> Ty.meta_var = Ty.TYPE o MV.newTyVar
24 : jhr 81 fun DK () = Ty.DIFF(MV.newDiffVar 0)
25 : jhr 79 val SK : unit -> Ty.meta_var = Ty.SHAPE o MV.newShapeVar
26 :     val NK : unit -> Ty.meta_var = Ty.DIM o MV.newDimVar
27 :    
28 :     fun ty t = ([], t)
29 :     fun all (kinds, mkTy : Ty.meta_var list -> Ty.ty) = let
30 :     val tvs = List.map (fn mk => mk()) kinds
31 :     in
32 :     (tvs, mkTy tvs)
33 :     end
34 :     fun allNK mkTy = let
35 :     val tv = MV.newDimVar()
36 :     in
37 :     ([Ty.DIM tv], mkTy tv)
38 :     end
39 :    
40 :     fun field (k, d, dd) = Ty.T_Field{diff=k, dim=d, shape=dd}
41 :     fun tensor ds = Ty.T_Tensor(Ty.Shape ds)
42 :    
43 :     fun monoVar (name, ty) = Var.new (name, AST.BasisVar, ty)
44 :     fun polyVar (name, scheme) = Var.newPoly (name, AST.BasisVar, scheme)
45 :     in
46 :    
47 :     (* TODO: I'm not sure how to extend + and - to fields, since the typing rules should allow
48 :     * two fields with different differentiation levels to be added.
49 :     *)
50 :    
51 :     (* overloaded operators; the naming convention is to use the operator name followed
52 :     * by the argument type signature, where
53 :     * i -- int
54 :     * b -- bool
55 :     * r -- real (tensor[])
56 :     * t -- tensor[shape]
57 :     *)
58 :    
59 :     val add_ii = monoVar(N.op_add, [Ty.T_Int, Ty.T_Int] --> Ty.T_Int)
60 :     val add_tt = polyVar(N.op_add, all([SK], fn [Ty.SHAPE dd] => let
61 :     val t = Ty.T_Tensor(Ty.ShapeVar dd)
62 :     in
63 :     [t, t] --> t
64 :     end))
65 :    
66 :     val sub_ii = monoVar(N.op_sub, [Ty.T_Int, Ty.T_Int] --> Ty.T_Int)
67 :     val sub_tt = polyVar(N.op_sub, all([SK], fn [Ty.SHAPE dd] => let
68 :     val t = Ty.T_Tensor(Ty.ShapeVar dd)
69 :     in
70 :     [t, t] --> t
71 :     end))
72 :    
73 :     (* note that we assume that operators are tested in the order defined here, so that mul_rr
74 :     * takes precedence over mul_rt and mul_tr!
75 :     *)
76 :     val mul_ii = monoVar(N.op_mul, [Ty.T_Int, Ty.T_Int] --> Ty.T_Int)
77 :     val mul_rr = monoVar(N.op_mul, [Ty.realTy, Ty.realTy] --> Ty.realTy)
78 :     val mul_rt = polyVar(N.op_mul, all([SK], fn [Ty.SHAPE dd] => let
79 :     val t = Ty.T_Tensor(Ty.ShapeVar dd)
80 :     in
81 :     [Ty.realTy, t] --> t
82 :     end))
83 :     val mul_tr = polyVar(N.op_mul, all([SK], fn [Ty.SHAPE dd] => let
84 :     val t = Ty.T_Tensor(Ty.ShapeVar dd)
85 :     in
86 :     [t, Ty.realTy] --> t
87 :     end))
88 :    
89 :     val div_ii = monoVar(N.op_div, [Ty.T_Int, Ty.T_Int] --> Ty.T_Int)
90 :     val div_rr = monoVar(N.op_div, [Ty.realTy, Ty.realTy] --> Ty.realTy)
91 :     val div_tr = polyVar(N.op_div, all([SK], fn [Ty.SHAPE dd] => let
92 :     val t = Ty.T_Tensor(Ty.ShapeVar dd)
93 :     in
94 :     [t, Ty.realTy] --> t
95 :     end))
96 :    
97 :     val lt_ii = monoVar(N.op_lt, [Ty.T_Int, Ty.T_Int] --> Ty.T_Bool)
98 :     val lt_rr = monoVar(N.op_lt, [Ty.realTy, Ty.realTy] --> Ty.T_Bool)
99 :     val lte_ii = monoVar(N.op_lte, [Ty.T_Int, Ty.T_Int] --> Ty.T_Bool)
100 :     val lte_rr = monoVar(N.op_lte, [Ty.realTy, Ty.realTy] --> Ty.T_Bool)
101 :     val gte_ii = monoVar(N.op_gte, [Ty.T_Int, Ty.T_Int] --> Ty.T_Bool)
102 :     val gte_rr = monoVar(N.op_gte, [Ty.realTy, Ty.realTy] --> Ty.T_Bool)
103 :     val gt_ii = monoVar(N.op_gt, [Ty.T_Int, Ty.T_Int] --> Ty.T_Bool)
104 :     val gt_rr = monoVar(N.op_gt, [Ty.realTy, Ty.realTy] --> Ty.T_Bool)
105 :    
106 :     val equ_bb = monoVar(N.op_equ, [Ty.T_Bool, Ty.T_Bool] --> Ty.T_Bool)
107 :     val equ_ii = monoVar(N.op_equ, [Ty.T_Int, Ty.T_Int] --> Ty.T_Bool)
108 :     val equ_ss = monoVar(N.op_equ, [Ty.T_String, Ty.T_String] --> Ty.T_Bool)
109 :     val equ_rr = monoVar(N.op_equ, [Ty.realTy, Ty.realTy] --> Ty.T_Bool)
110 :     val neq_bb = monoVar(N.op_neq, [Ty.T_Bool, Ty.T_Bool] --> Ty.T_Bool)
111 :     val neq_ii = monoVar(N.op_neq, [Ty.T_Int, Ty.T_Int] --> Ty.T_Bool)
112 :     val neq_ss = monoVar(N.op_neq, [Ty.T_String, Ty.T_String] --> Ty.T_Bool)
113 :     val neq_rr = monoVar(N.op_neq, [Ty.realTy, Ty.realTy] --> Ty.T_Bool)
114 :    
115 :    
116 :     val neg_i = monoVar(N.op_neg, [Ty.T_Int, Ty.T_Int] --> Ty.T_Bool)
117 :     val neg_t = polyVar(N.op_neg, all([SK],
118 :     fn [Ty.SHAPE dd] => let
119 :     val t = Ty.T_Tensor(Ty.ShapeVar dd)
120 :     in
121 :     [t] --> t
122 :     end))
123 :     val neg_f = polyVar(N.op_neg, all([DK, NK, SK],
124 :     fn [Ty.DIFF k, Ty.DIM d, Ty.SHAPE dd] => let
125 :     val k = Ty.DiffVar(k, 0)
126 :     val d = Ty.DimVar d
127 :     val dd = Ty.ShapeVar dd
128 :     in
129 :     [field(k, d, dd)] --> field(k, d, dd)
130 :     end))
131 :    
132 :    
133 :     (***** non-overloaded operators, etc. *****)
134 :    
135 :     val op_at = polyVar (N.op_at, all([DK, NK, SK],
136 :     fn [Ty.DIFF k, Ty.DIM d, Ty.SHAPE dd] => let
137 :     val k = Ty.DiffVar(k, 0)
138 :     val d = Ty.DimVar d
139 :     val dd = Ty.ShapeVar dd
140 :     in
141 :     [field(k, d, dd), tensor[d]] --> Ty.T_Tensor dd
142 :     end))
143 :    
144 : jhr 247 (* NOTE: this should be overloaded to allow both v*h and h*v orders *)
145 :     val op_convolve = polyVar (N.op_convolve, all([DK, NK, SK],
146 :     fn [Ty.DIFF k, Ty.DIM d, Ty.SHAPE dd] => let
147 :     val k = Ty.DiffVar(k, 0)
148 :     val d = Ty.DimVar d
149 :     val dd = Ty.ShapeVar dd
150 :     in
151 :     [Ty.T_Image{dim=d, shape=dd}, Ty.T_Kernel k]
152 :     --> field(k, d, dd)
153 :     end))
154 :    
155 : jhr 79 val op_D = polyVar (N.op_D, all([DK, NK, SK],
156 :     fn [Ty.DIFF k, Ty.DIM d, Ty.SHAPE dd] => let
157 :     val k0 = Ty.DiffVar(k, 0)
158 :     val km1 = Ty.DiffVar(k, ~1)
159 :     val d = Ty.DimVar d
160 :     val dd = Ty.ShapeVar dd
161 :     in
162 : jhr 85 [field(k0, d, dd)]
163 : jhr 79 --> field(km1, d, Ty.ShapeExt(dd, d))
164 :     end))
165 :    
166 :     val op_norm = polyVar (N.op_norm, all([SK],
167 :     fn [Ty.SHAPE dd] => [Ty.T_Tensor(Ty.ShapeVar dd)] --> Ty.realTy))
168 :    
169 :     val op_not = monoVar (N.op_not, [Ty.T_Bool] --> Ty.T_Bool)
170 :    
171 : jhr 86 val op_subscript = polyVar (N.op_subscript, all([SK, NK],
172 :     fn [Ty.SHAPE dd, Ty.DIM d] => let
173 :     val dd = Ty.ShapeVar dd
174 :     val d = Ty.DimVar d
175 :     in
176 :     [Ty.T_Tensor(Ty.ShapeExt(dd, d)), Ty.T_Int]
177 :     --> Ty.T_Tensor dd
178 :     end))
179 : jhr 79
180 :     (* functions *)
181 : jhr 185 val fn_CL = monoVar (N.fn_CL, [tensor[N3, N3]] --> Ty.realTy)
182 : jhr 79
183 : jhr 247 (* the following is depreciated in favor of the infix operator *)
184 : jhr 79 val fn_convolve = polyVar (N.fn_convolve, all([DK, NK, SK],
185 :     fn [Ty.DIFF k, Ty.DIM d, Ty.SHAPE dd] => let
186 :     val k = Ty.DiffVar(k, 0)
187 :     val d = Ty.DimVar d
188 :     val dd = Ty.ShapeVar dd
189 :     in
190 :     [Ty.T_Kernel k, Ty.T_Image{dim=d, shape=dd}]
191 :     --> field(k, d, dd)
192 :     end))
193 :    
194 : jhr 143 val fn_cos = monoVar (N.fn_cos, [Ty.realTy] --> Ty.realTy)
195 : jhr 79
196 : jhr 88 val fn_dot = polyVar (N.fn_dot, allNK(fn tv => let
197 :     val t = tensor[Ty.DimVar tv]
198 :     in
199 : jhr 91 [t, t] --> Ty.realTy
200 : jhr 88 end))
201 : jhr 79
202 : jhr 143 val fn_inside = polyVar (N.fn_inside, all([DK, NK, SK],
203 : jhr 79 fn [Ty.DIFF k, Ty.DIM d, Ty.SHAPE dd] => let
204 :     val k = Ty.DiffVar(k, 0)
205 :     val d = Ty.DimVar d
206 :     val dd = Ty.ShapeVar dd
207 :     in
208 :     [Ty.T_Tensor(Ty.Shape[d]), field(k, d, dd)]
209 :     --> Ty.T_Bool
210 :     end))
211 :    
212 : jhr 143 val fn_load = polyVar (N.fn_load, all([NK, SK],
213 : jhr 79 fn [Ty.DIM d, Ty.SHAPE dd] => let
214 :     val d = Ty.DimVar d
215 :     val dd = Ty.ShapeVar dd
216 :     in
217 :     [Ty.T_String] --> Ty.T_Image{dim=d, shape=dd}
218 :     end))
219 :    
220 : jhr 143 val fn_max = monoVar (N.fn_max, [Ty.realTy, Ty.realTy] --> Ty.realTy)
221 :     val fn_min = monoVar (N.fn_min, [Ty.realTy, Ty.realTy] --> Ty.realTy)
222 :    
223 : jhr 83 val fn_modulate = polyVar (N.fn_modulate, all([NK],
224 :     fn [Ty.DIM d] => let
225 :     val t = Ty.T_Tensor(Ty.Shape[Ty.DimVar d])
226 :     in
227 :     [t, t] --> t
228 :     end))
229 :    
230 : jhr 143 val fn_pow = monoVar (N.fn_pow, [Ty.realTy, Ty.realTy] --> Ty.realTy)
231 : jhr 79
232 : jhr 91 val fn_principleEvec = polyVar (N.fn_principleEvec, all([NK],
233 :     fn [Ty.DIM d] => let
234 :     val d = Ty.DimVar d
235 :     in
236 :     [tensor[d,d]] --> tensor[d]
237 :     end))
238 : jhr 79
239 : jhr 143 val fn_sin = monoVar (N.fn_sin, [Ty.realTy] --> Ty.realTy)
240 : jhr 79
241 :     (* kernels *)
242 : jhr 169 (* FIXME: we should really get the continuity info from the kernels themselves *)
243 : jhr 83 val kn_bspln3 = monoVar (N.kn_bspln3, Ty.T_Kernel(Ty.DiffConst 2))
244 : jhr 169 val kn_bspln5 = monoVar (N.kn_bspln5, Ty.T_Kernel(Ty.DiffConst 4))
245 :     val kn_ctmr = monoVar (N.kn_ctmr, Ty.T_Kernel(Ty.DiffConst 2))
246 : jhr 83 val kn_tent = monoVar (N.kn_tent, Ty.T_Kernel(Ty.DiffConst 0))
247 : jhr 79
248 : jhr 83 (* internal variables *)
249 :     val i2r = monoVar (Atom.atom "$i2r", [Ty.T_Int] --> Ty.realTy) (* integer to real conversion *)
250 : jhr 179 val input = polyVar (Atom.atom "$input", all([TK],
251 : jhr 185 fn [Ty.TYPE tv] => [Ty.T_String] --> Ty.T_Var tv))
252 :     val optInput = polyVar (Atom.atom "$optional-input", all([TK],
253 : jhr 179 fn [Ty.TYPE tv] => [Ty.T_String, Ty.T_Var tv] --> Ty.T_Var tv))
254 : jhr 79 end (* local *)
255 :     end

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0