Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] Annotation of /branches/vis15/src/compiler/typechecker/check-expr.sml
ViewVC logotype

Annotation of /branches/vis15/src/compiler/typechecker/check-expr.sml

Parent Directory Parent Directory | Revision Log Revision Log


Revision 3431 - (view) (download)

1 : jhr 3396 (* check-expr.sml
2 :     *
3 :     * The typechecker for expressions.
4 :     *
5 :     * This code is part of the Diderot Project (http://diderot-language.cs.uchicago.edu)
6 :     *
7 :     * COPYRIGHT (c) 2015 The University of Chicago
8 :     * All rights reserved.
9 :     *)
10 :    
structure CheckExpr : sig

  (* `check (env, cxt, e)` typechecks the parse-tree expression `e`, yielding its AST form
   * paired with its type.
   *)
    val check : Env.t * Env.context * ParseTree.expr -> (AST.expr * Types.ty)

  (* `checkList (env, cxt, es)` typechecks every expression in `es`, yielding the list of
   * AST forms paired with the list of their types.
   *)
    val checkList : Env.t * Env.context * ParseTree.expr list -> (AST.expr list * Types.ty list)

  (* `checkIter (env, cxt, iter)` typechecks an iteration of the form "x 'in' expr"; the
   * result pairs the iteration variable with the checked expression, together with `env`
   * extended by a binding for x.
   *)
    val checkIter : Env.t * Env.context * ParseTree.iterator -> ((AST.var * AST.expr) * Env.t)

  (* `checkDim (env, cxt, e)` typechecks a dimension written as a constant expression;
   * `NONE` means that no dimension value could be determined.
   *)
    val checkDim : Env.t * Env.context * ParseTree.expr -> IntLit.t option

  (* `checkShape (env, cxt, dims)` typechecks a tensor shape whose dimensions are written as
   * constant expressions.
   *)
    val checkShape : Env.t * Env.context * ParseTree.expr list -> Types.shape

  (* `resolveOverload (cxt, rator, tys, args, candidates)` picks, from the candidate
   * definitions `candidates`, an instance of the overloaded operator `rator` that applies to
   * the arguments `args`, whose types are `tys`.
   *)
    val resolveOverload : Env.context * Atom.atom * Types.ty list * AST.expr list * Var.t list
          -> (AST.expr * Types.ty)

  end = struct
38 :    
structure PT = ParseTree
structure L = Literal
structure E = Env
structure Ty = Types
structure BV = BasisVars
structure TU = TypeUtil

(* placeholder AST expression, and the same expression paired with the error type, that is
 * returned after a type error has been reported
 *)
val bogusExpTy as (bogusExp, _) = (AST.E_Lit(L.Int 0), Ty.T_Error)

(* report a type error (via TypeError.error) and yield the error-typed placeholder *)
fun err msg = (TypeError.error msg; bogusExpTy)

(* report a warning (via TypeError.warning) *)
val warn = TypeError.warning

(* make TypeError's message-token constructors (S, A, V, TY, TYS, ...) usable unqualified *)
datatype token = datatype TypeError.token
54 : jhr 3396
(* tag a variable use with the source span held in the context *)
fun useVar ((_, span) : Env.context, x) = (x, span)
57 : jhr 3407
(* `stripMark (span, e)` peels off every `E_Mark` wrapper around `e`, returning the span of
 * the innermost mark (or `span` itself when `e` is unmarked) and the unmarked expression.
 *)
fun stripMark (span, e) = (case e
       of PT.E_Mark{span=innerSpan, tree} => stripMark (innerSpan, tree)
        | _ => (span, e)
      (* end case *))
61 :    
(* `resolveOverload (cxt, rator, argTys, args, candidates)` resolves an application of the
 * overloaded operator `rator` with a simple two-pass scheme:
 *   1. pick the first candidate whose (instantiated) domain is exactly `argTys`;
 *   2. otherwise pick the first candidate that matches once type coercions are allowed.
 * If no candidate matches, an error is reported -- unless one of the argument types is
 * already an error type, in which case an error has already been reported and the
 * error-typed placeholder is returned silently (avoiding cascading diagnostics).
 * Raises `Fail` if `candidates` is empty.
 *)
fun resolveOverload (_, rator, _, _, []) = raise Fail(concat[
        "resolveOverload: \"", Atom.toString rator, "\" has no candidates"
      ])
  | resolveOverload (cxt, rator, argTys, args, candidates) = let
    (* FIXME: we could be more efficient by just checking for coercion matchs the first pass
     * and remembering those that are not pure EQ matches.
     *)
    (* second pass: try to match candidates while allowing type coercions *)
      fun tryMatchCandidates [] =
            if List.exists TU.isErrorType argTys
              then bogusExpTy
              else err(cxt, [
                  S "unable to resolve overloaded operator ", A rator, S "\n",
                  S " argument type is: ", TYS argTys, S "\n"
                ])
        | tryMatchCandidates (x::xs) = let
            val (tyArgs, Ty.T_Fun(domTy, rngTy)) = TU.instantiate(Var.typeOf x)
            in
              case Unify.tryMatchArgs (domTy, args, argTys)
               of SOME args' => (AST.E_Prim(x, tyArgs, args', rngTy), rngTy)
                | NONE => tryMatchCandidates xs
              (* end case *)
            end
    (* first pass: look for a candidate whose domain equals the argument types exactly *)
      fun tryCandidates [] = tryMatchCandidates candidates
        | tryCandidates (x::xs) = let
            val (tyArgs, Ty.T_Fun(domTy, rngTy)) = TU.instantiate(Var.typeOf x)
            in
              if Unify.tryEqualTypes(domTy, argTys)
                then (AST.E_Prim(x, tyArgs, args, rngTy), rngTy)
                else tryCandidates xs
            end
      in
        tryCandidates candidates
      end
96 :    
(* typecheck a literal: integers have type `int`, reals `Ty.realTy`, strings `string`, and
 * booleans `bool`; the AST form is the literal itself.
 *)
fun checkLit lit = let
      val ty = (case lit
             of L.Int _ => Ty.T_Int
              | L.Real _ => Ty.realTy
              | L.String _ => Ty.T_String
              | L.Bool _ => Ty.T_Bool
            (* end case *))
      in
        (AST.E_Lit lit, ty)
      end
104 :    
(* `chkInnerProduct (cxt, e1, ty1, e2, ty2)` typechecks the inner (dot) product of `e1` and
 * `e2`, whose types are `ty1` and `ty2`.  The operator has the constraint
 *
 *      ALL[sigma1, d1, sigma2] . tensor[sigma1, d1] * tensor[d1, sigma2] -> tensor[sigma1, sigma2]
 *
 * and analogous constraints when either (or both) argument is a field.
 *)
fun chkInnerProduct (cxt, e1, ty1, e2, ty2) = let
    (* result shape: the left shape without its last dimension followed by the right shape
     * without its first dimension; NONE when either shape is empty or the contracted
     * dimensions do not unify.
     *)
      fun chkShape (Ty.Shape dd1, Ty.Shape(d2::dd2)) = (case List.rev dd1
             of d1::revPrefix => if Unify.equalDim(d1, d2)
                  then SOME(Ty.Shape(List.rev revPrefix @ dd2))
                  else NONE
              | [] => NONE
            (* end case *))
        | chkShape _ = NONE
      fun error () = err (cxt, [
              S "type error for arguments of binary operator '•'\n",
              S " found: ", TYS[ty1, ty2], S "\n"
            ])
    (* instantiate the type of the primitive `rator`; if `extraOk ()` holds and the
     * instantiated signature unifies with the argument types and `resTy`, build the
     * primitive application, otherwise report an error.
     *)
      fun mkPrim (rator, resTy, extraOk) = let
            val (tyArgs, Ty.T_Fun(domTy, rngTy)) = TU.instantiate(Var.typeOf rator)
            in
              if extraOk ()
              andalso Unify.equalTypes(domTy, [ty1, ty2])
              andalso Unify.equalType(rngTy, resTy)
                then (AST.E_Prim(rator, tyArgs, [e1, e2], rngTy), rngTy)
                else error()
            end
      fun always () = true
    (* continue with `k` on the result shape, or report an error if the shapes do not fit *)
      fun withShape (s1, s2, k) = (case chkShape(s1, s2)
             of SOME shp => k shp
              | NONE => error()
            (* end case *))
      in
        case (TU.prune ty1, TU.prune ty2)
         of (Ty.T_Tensor s1, Ty.T_Tensor s2) =>
              withShape (s1, s2, fn shp =>
                mkPrim (BV.op_inner_tt, Ty.T_Tensor shp, always))
          | (Ty.T_Tensor s1, Ty.T_Field{diff, dim, shape=s2}) =>
              withShape (s1, s2, fn shp =>
                mkPrim (BV.op_inner_tf, Ty.T_Field{diff=diff, dim=dim, shape=shp}, always))
          | (Ty.T_Field{diff, dim, shape=s1}, Ty.T_Tensor s2) =>
              withShape (s1, s2, fn shp =>
                mkPrim (BV.op_inner_ft, Ty.T_Field{diff=diff, dim=dim, shape=shp}, always))
          | (Ty.T_Field{diff=k1, dim=dim1, shape=s1}, Ty.T_Field{diff=k2, dim=dim2, shape=s2}) =>
            (* FIXME: the resulting differentiation should be the minimum of k1 and k2 *)
              withShape (s1, s2, fn shp =>
                mkPrim (BV.op_inner_ff, Ty.T_Field{diff=k1, dim=dim1, shape=shp},
                  fn () => Unify.equalDim(dim1, dim2)))
          | _ => error()
        (* end case *)
      end
188 :    
(* `chkColonProduct (cxt, e1, ty1, e2, ty2)` typechecks the colon product of `e1` and `e2`,
 * whose types are `ty1` and `ty2`.  The shape check below requires the last two dimensions of
 * the left argument to equal, in order, the first two dimensions of the right argument:
 *
 *      tensor[sigma1, d1, d2] * tensor[d1, d2, sigma2] -> tensor[sigma1, sigma2]
 *
 * with analogous forms when either (or both) argument is a field.
 * NOTE(review): confirm this dimension ordering against the types of BV.op_colon_*.
 *)
fun chkColonProduct (cxt, e1, ty1, e2, ty2) = let
    (* result shape: the left shape without its last two dimensions followed by the right
     * shape without its first two; NONE when either shape has fewer than two dimensions or
     * the contracted dimensions do not unify.
     *)
      fun chkShape (Ty.Shape dd1, Ty.Shape(d21::d22::dd2)) = (case List.rev dd1
             of d12::d11::revPrefix =>
                  if Unify.equalDim(d11, d21) andalso Unify.equalDim(d12, d22)
                    then SOME(Ty.Shape(List.rev revPrefix @ dd2))
                    else NONE
              | _ => NONE
            (* end case *))
        | chkShape _ = NONE
      fun error () = err (cxt, [
              S "type error for arguments of binary operator \":\"\n",
              S " found: ", TYS[ty1, ty2], S "\n"
            ])
    (* instantiate the type of the primitive `rator`; if `extraOk ()` holds and the
     * instantiated signature unifies with the argument types and `resTy`, build the
     * primitive application, otherwise report an error.
     *)
      fun mkPrim (rator, resTy, extraOk) = let
            val (tyArgs, Ty.T_Fun(domTy, rngTy)) = TU.instantiate(Var.typeOf rator)
            in
              if extraOk ()
              andalso Unify.equalTypes(domTy, [ty1, ty2])
              andalso Unify.equalType(rngTy, resTy)
                then (AST.E_Prim(rator, tyArgs, [e1, e2], rngTy), rngTy)
                else error()
            end
      fun always () = true
    (* continue with `k` on the result shape, or report an error if the shapes do not fit *)
      fun withShape (s1, s2, k) = (case chkShape(s1, s2)
             of SOME shp => k shp
              | NONE => error()
            (* end case *))
      in
        case (TU.prune ty1, TU.prune ty2)
         of (Ty.T_Tensor s1, Ty.T_Tensor s2) =>
              withShape (s1, s2, fn shp =>
                mkPrim (BV.op_colon_tt, Ty.T_Tensor shp, always))
          | (Ty.T_Field{diff, dim, shape=s1}, Ty.T_Tensor s2) =>
              withShape (s1, s2, fn shp =>
                mkPrim (BV.op_colon_ft, Ty.T_Field{diff=diff, dim=dim, shape=shp}, always))
          | (Ty.T_Tensor s1, Ty.T_Field{diff, dim, shape=s2}) =>
              withShape (s1, s2, fn shp =>
                mkPrim (BV.op_colon_tf, Ty.T_Field{diff=diff, dim=dim, shape=shp}, always))
          | (Ty.T_Field{diff=k1, dim=dim1, shape=s1}, Ty.T_Field{diff=k2, dim=dim2, shape=s2}) =>
            (* FIXME: the resulting differentiation should be the minimum of k1 and k2 *)
              withShape (s1, s2, fn shp =>
                mkPrim (BV.op_colon_ff, Ty.T_Field{diff=k1, dim=dim1, shape=shp},
                  fn () => Unify.equalDim(dim1, dim2)))
          | _ => error()
        (* end case *)
      end
271 :    
(* `check (env, cxt, e)` typechecks the parse-tree expression `e` in the environment `env`,
 * returning the AST form of `e` paired with its type.  Type errors are reported through
 * TypeError; after an error the result is typically the error-typed placeholder.
 *)
fun check (env, cxt, e) = (case e
       of PT.E_Mark m => check (E.withEnvAndContext (env, cxt, m))
        | PT.E_Cond(e1, cond, e2) => let
            val eTy1 = check (env, cxt, e1)
            val eTy2 = check (env, cxt, e2)
            in
              case checkAndPrune(env, cxt, cond)
               of (cond', Ty.T_Bool) => (case Util.coerceType2(eTy1, eTy2)
                     of SOME(e1', e2', ty) => (AST.E_Cond(cond', e1', e2', ty), ty)
                      | NONE => err (cxt, [
                            S "types do not match in conditional expression\n",
                            S " true branch: ", TY(#2 eTy1), S "\n",
                            S " false branch: ", TY(#2 eTy2)
                          ])
                    (* end case *))
                | (_, Ty.T_Error) => bogusExpTy
                | (_, ty') => err (cxt, [S "expected bool type, but found ", TY ty'])
              (* end case *)
            end
        | PT.E_Range(e1, e2) => (
          (* prune the operand types so that an `int` reached through a meta variable is
           * recognized, and do not report another error for an operand whose type is
           * already an error type.
           *)
            case (checkAndPrune (env, cxt, e1), checkAndPrune (env, cxt, e2))
             of ((e1', Ty.T_Int), (e2', Ty.T_Int)) => let
                  val resTy = Ty.T_Sequence(Ty.T_Int, NONE)
                  in
                    (AST.E_Prim(BV.range, [], [e1', e2'], resTy), resTy)
                  end
              | ((_, Ty.T_Error), _) => bogusExpTy
              | (_, (_, Ty.T_Error)) => bogusExpTy
              | ((_, Ty.T_Int), (_, ty2)) =>
                  err (cxt, [S "expected type 'int' on rhs of '..', but found ", TY ty2])
              | ((_, ty1), (_, Ty.T_Int)) =>
                  err (cxt, [S "expected type 'int' on lhs of '..', but found ", TY ty1])
              | ((_, ty1), (_, ty2)) => err (cxt, [
                    S "arguments of '..' must have type 'int', found ",
                    TY ty1, S " and ", TY ty2
                  ])
            (* end case *))
        | PT.E_OrElse(e1, e2) =>
            checkCondOp (env, cxt, e1, "||", e2,
              fn (e1', e2') => AST.E_Cond(e1', AST.E_Lit(L.Bool true), e2', Ty.T_Bool))
        | PT.E_AndAlso(e1, e2) =>
            checkCondOp (env, cxt, e1, "&&", e2,
              fn (e1', e2') => AST.E_Cond(e1', e2', AST.E_Lit(L.Bool false), Ty.T_Bool))
        | PT.E_BinOp(e1, rator, e2) => let
            val (e1', ty1) = check (env, cxt, e1)
            val (e2', ty2) = check (env, cxt, e2)
            in
              if Atom.same(rator, BasisNames.op_dot)
                then chkInnerProduct (cxt, e1', ty1, e2', ty2)
              else if Atom.same(rator, BasisNames.op_colon)
                then chkColonProduct (cxt, e1', ty1, e2', ty2)
                else (case Env.findFunc (env, rator)
                   of Env.PrimFun[rator] => let
                        val (tyArgs, Ty.T_Fun(domTy, rngTy)) = TU.instantiate(Var.typeOf rator)
                        in
                          case Unify.matchArgs(domTy, [e1', e2'], [ty1, ty2])
                           of SOME args => (AST.E_Prim(rator, tyArgs, args, rngTy), rngTy)
                            | NONE => err (cxt, [
                                  S "type error for binary operator ", V rator, S "\n",
                                  S " expected: ", TYS domTy, S "\n",
                                  S " but found: ", TYS[ty1, ty2]
                                ])
                          (* end case *)
                        end
                    | Env.PrimFun ovldList =>
                        resolveOverload (cxt, rator, [ty1, ty2], [e1', e2'], ovldList)
                    | _ => raise Fail "impossible"
                  (* end case *))
            end
        | PT.E_UnaryOp(rator, e) => let
            val eTy = check(env, cxt, e)
            in
              case Env.findFunc (env, rator)
               of Env.PrimFun[rator] => let
                    val (tyArgs, Ty.T_Fun([domTy], rngTy)) = TU.instantiate(Var.typeOf rator)
                    in
                      case Util.coerceType (domTy, eTy)
                       of SOME e' => (AST.E_Prim(rator, tyArgs, [e'], rngTy), rngTy)
                        | NONE => err (cxt, [
                              S "type error for unary operator ", V rator, S "\n",
                              S " expected: ", TY domTy, S "\n",
                              S " but found: ", TY (#2 eTy)
                            ])
                      (* end case *)
                    end
                | Env.PrimFun ovldList => resolveOverload (cxt, rator, [#2 eTy], [#1 eTy], ovldList)
                | _ => raise Fail "impossible"
              (* end case *)
            end
        | PT.E_Apply(e, args) => let
            val (args, tys) = checkList (env, cxt, args)
            fun appTyError (f, paramTys, argTys) = err(cxt, [
                    S "type error in application of ", V f, S "\n",
                    S " expected: ", TYS paramTys, S "\n",
                    S " but found: ", TYS argTys
                  ])
          (* check the application of a (non-overloaded) primitive function *)
            fun checkPrimApp f = if Var.isPrim f
                  then (case TU.instantiate(Var.typeOf f)
                     of (tyArgs, Ty.T_Fun(domTy, rngTy)) => (
                          case Unify.matchArgs (domTy, args, tys)
                           of SOME args => (AST.E_Prim(f, tyArgs, args, rngTy), rngTy)
                            | NONE => appTyError (f, domTy, tys)
                          (* end case *))
                      | _ => err(cxt, [S "application of non-function/field ", V f])
                    (* end case *))
                  else raise Fail "unexpected user function"
          (* check the application of a user-defined function *)
            fun checkFunApp (cxt, f) = if Var.isPrim f
                  then raise Fail "unexpected primitive function"
                  else (case Var.monoTypeOf f
                     of Ty.T_Fun(domTy, rngTy) => (
                          case Unify.matchArgs (domTy, args, tys)
                           of SOME args => (AST.E_Apply(useVar(cxt, f), args, rngTy), rngTy)
                            | NONE => appTyError (f, domTy, tys)
                          (* end case *))
                      | _ => err(cxt, [S "application of non-function/field ", V f])
                    (* end case *))
          (* check the application of a field (a probe) to a single argument *)
            fun checkFieldApp (e1', ty1) = (case (args, tys)
                   of ([e2'], [ty2]) => let
                        val (tyArgs, Ty.T_Fun([fldTy, domTy], rngTy)) =
                              TU.instantiate(Var.typeOf BV.op_probe)
                        fun tyError () = err (cxt, [
                                S "type error for field application\n",
                                S " expected: ", TYS[fldTy, domTy], S "\n",
                                S " but found: ", TYS[ty1, ty2]
                              ])
                        in
                          if Unify.equalType(fldTy, ty1)
                            then (case Util.coerceType(domTy, (e2', ty2))
                               of SOME e2' => (AST.E_Prim(BV.op_probe, tyArgs, [e1', e2'], rngTy), rngTy)
                                | NONE => tyError()
                              (* end case *))
                            else tyError()
                        end
                    | _ => err(cxt, [S "badly formed field application"])
                  (* end case *))
            in
              case stripMark(#2 cxt, e)
               of (span, PT.E_Var f) => (case Env.findVar (env, f)
                     of SOME f' => checkFieldApp (
                          AST.E_Var(useVar((#1 cxt, span), f')),
                          Var.monoTypeOf f')
                      | NONE => (case Env.findFunc (env, f)
                           of Env.PrimFun[] => err(cxt, [S "unknown function ", A f])
                            | Env.PrimFun[f'] => checkPrimApp f'
                            | Env.PrimFun ovldList =>
                                resolveOverload ((#1 cxt, span), f, tys, args, ovldList)
                            | Env.UserFun f' => checkFunApp((#1 cxt, span), f')
                          (* end case *))
                    (* end case *))
                | _ => checkFieldApp (check (env, cxt, e))
              (* end case *)
            end
        | PT.E_Subscript(e, indices) => let
            fun expectedTensor ty = err(cxt, [
                    S "expected tensor type for slicing, but found ", TY ty
                  ])
            fun chkIndex e = let
                  val eTy as (_, ty) = check(env, cxt, e)
                  in
                    if Unify.equalType(ty, Ty.T_Int)
                      then eTy
                      else err (cxt, [
                          S "expected type 'int' for index, but found ", TY ty
                        ])
                  end
            in
            (* prune the type of the subscripted expression, since it is matched against type
             * constructors below.
             *)
              case (checkAndPrune(env, cxt, e), indices)
               of ((e', Ty.T_Error), _) => (
                    List.app (ignore o Option.map chkIndex) indices;
                    bogusExpTy)
                | ((e1', ty1 as Ty.T_Sequence(elemTy, optDim)), [SOME e2]) => let
                    val (e2', ty2) = chkIndex e2
                    val rator = if isSome optDim
                          then BV.subscript
                          else BV.dynSubscript
                    val (tyArgs, Ty.T_Fun(domTy, rngTy)) = TU.instantiate(Var.typeOf rator)
                    in
                      if Unify.equalTypes(domTy, [ty1, ty2])
                        then (AST.E_Prim(rator, tyArgs, [e1', e2'], rngTy), rngTy)
                        else raise Fail "unexpected unification failure"
                    end
                | ((_, ty as Ty.T_Sequence _), _) => expectedTensor ty
                | ((e', ty as Ty.T_Tensor shape), _) => let
                    val indices' = List.map (Option.map (#1 o chkIndex)) indices
                    val order = List.length indices'
                    val expectedTy = TU.mkTensorTy order
                    val resultTy = TU.slice(expectedTy, List.map Option.isSome indices')
                    in
                      if Unify.equalType(ty, expectedTy)
                        then (AST.E_Slice(e', indices', resultTy), resultTy)
                        else err (cxt, [
                            S "type error in slice operation\n",
                            S " expected: ", S(Int.toString order), S "-order tensor\n",
                            S " but found: ", TY ty
                          ])
                    end
                | ((_, ty), _) => expectedTensor ty
              (* end case *)
            end
        | PT.E_Select(e, field) => (case stripMark(#2 cxt, e)
             of (_, PT.E_Var x) => (case E.findStrand (env, x)
                   of SOME _ => if E.inGlobalUpdate env
                        then (case E.findSetFn (env, field)
                           of SOME setFn => let
                                val (mvs, ty) = TU.instantiate (Var.typeOf setFn)
                                val resTy = Ty.T_Sequence(Ty.T_Named x, NONE)
                                in
                                  E.recordProp (env, Properties.StrandSets);
                                  if Unify.equalType(ty, Ty.T_Fun([], resTy))
                                    then (AST.E_Prim(setFn, mvs, [], resTy), resTy)
                                    else raise Fail "impossible"
                                end
                            | _ => err (cxt, [
                                  S "unknown strand-set specifier ", A field
                                ])
                          (* end case *))
                        else err (cxt, [
                            S "illegal strand set specification in ",
                            S(E.scopeToString(E.currentScope env))
                          ])
                    | _ => checkSelect (env, cxt, e, field)
                  (* end case *))
              | _ => checkSelect (env, cxt, e, field)
            (* end case *))
        | PT.E_Real e => (case checkAndPrune (env, cxt, e)
             of (e', Ty.T_Int) =>
                  (AST.E_Prim(BV.i2r, [], [e'], Ty.realTy), Ty.realTy)
              | (e', Ty.T_Error) => bogusExpTy
              | (_, ty) => err(cxt, [
                    S "argument of 'real' must have type 'int', but found ",
                    TY ty
                  ])
            (* end case *))
        | PT.E_Load nrrd => let
            val (tyArgs, Ty.T_Fun(_, rngTy)) = TU.instantiate(Var.typeOf(BV.fn_load))
            in
              case chkStringConstExpr (env, cxt, nrrd)
               of SOME nrrd => (AST.E_LoadNrrd(tyArgs, nrrd, rngTy), rngTy)
                | NONE => (bogusExp, rngTy)
              (* end case *)
            end
        | PT.E_Image nrrd => let
            val (tyArgs, Ty.T_Fun(_, rngTy)) = TU.instantiate(Var.typeOf(BV.fn_image))
            in
              case chkStringConstExpr (env, cxt, nrrd)
               of SOME nrrd => (AST.E_LoadNrrd(tyArgs, nrrd, rngTy), rngTy)
                | NONE => (bogusExp, rngTy)
              (* end case *)
            end
        | PT.E_Var x => (case E.findVar (env, x)
             of SOME x' => (AST.E_Var(useVar(cxt, x')), Var.monoTypeOf x')
              | NONE => err(cxt, [S "undeclared variable ", A x])
            (* end case *))
        | PT.E_Kernel(kern, dim) => (case E.findVar (env, kern)
             of SOME kern' => (case Var.monoTypeOf kern'
                   of ty as Ty.T_Kernel(Ty.DiffConst k) => let
                        val k' = Int.fromLarge dim handle Overflow => 1073741823
                        val e = AST.E_Var(useVar(cxt, kern'))
                        in
                          if (k = k')
                            then (e, ty)
                            else let
                              val ty' = Ty.T_Kernel(Ty.DiffConst k')
                              in
                                (AST.E_Coerce{srcTy = ty, dstTy = ty', e = e}, ty')
                              end
                        end
                    | _ => err(cxt, [S "expected kernel, but found ", S(Var.kindToString kern')])
                  (* end case *))
              | NONE => err(cxt, [S "unknown kernel ", A kern])
            (* end case *))
        | PT.E_Lit lit => checkLit lit
        | PT.E_Id d => let
            val (tyArgs, Ty.T_Fun(_, rngTy)) = TU.instantiate(Var.typeOf(BV.identity))
            in
              if Unify.equalType(Ty.T_Tensor(checkShape(env, cxt, [d, d])), rngTy)
                then (AST.E_Prim(BV.identity, tyArgs, [], rngTy), rngTy)
                else raise Fail "impossible"
            end
        | PT.E_Zero dd => let
            val (tyArgs, Ty.T_Fun(_, rngTy)) = TU.instantiate(Var.typeOf(BV.zero))
            in
              if Unify.equalType(Ty.T_Tensor(checkShape(env, cxt, dd)), rngTy)
                then (AST.E_Prim(BV.zero, tyArgs, [], rngTy), rngTy)
                else raise Fail "impossible"
            end
        | PT.E_NaN dd => let
            val (tyArgs, Ty.T_Fun(_, rngTy)) = TU.instantiate(Var.typeOf(BV.nan))
            in
              if Unify.equalType(Ty.T_Tensor(checkShape(env, cxt, dd)), rngTy)
                then (AST.E_Prim(BV.nan, tyArgs, [], rngTy), rngTy)
                else raise Fail "impossible"
            end
        | PT.E_Sequence exps => (case checkList (env, cxt, exps)
          (* FIXME: need kind for concrete types here! *)
             of ([], _) => let
                  val ty = Ty.T_Sequence(Ty.T_Var(MetaVar.newTyVar()), SOME(Ty.DimConst 0))
                  in
                    (AST.E_Seq([], ty), ty)
                  end
              | (args, tys) => (case Util.coerceTypes(List.map TU.pruneHead tys)
                   of SOME ty => if TU.isValueType ty
                        then let
                          fun doExp eTy = valOf(Util.coerceType (ty, eTy))
                          val resTy = Ty.T_Sequence(ty, SOME(Ty.DimConst(List.length args)))
                          val args = ListPair.map doExp (args, tys)
                          in
                            (AST.E_Seq(args, resTy), resTy)
                          end
                        else err(cxt, [S "sequence expression of non-value argument type"])
                    | NONE => err(cxt, [S "arguments of sequence expression must have same type"])
                  (* end case *))
            (* end case *))
        | PT.E_SeqComp comp => chkComprehension (env, cxt, comp)
        | PT.E_Cons args => let
          (* Note that we are guaranteed that args is non-empty *)
            val (args, tys) = checkList (env, cxt, args)
          (* extract the first non-error type in tys *)
            val ty = (case List.find (fn Ty.T_Error => false | _ => true) tys
                   of NONE => Ty.T_Error
                    | SOME ty => ty
                  (* end case *))
          (* process the arguments checking that they all have the expected type *)
            fun chkArgs (ty, shape) = let
                  val Ty.Shape dd = TU.pruneShape shape (* NOTE: this may fail if we allow user polymorphism *)
                  val resTy = Ty.T_Tensor(Ty.Shape(Ty.DimConst(List.length args) :: dd))
                  fun chkArgs (arg::args, argTy::tys, args') = (
                        case Util.coerceType(ty, (arg, argTy))
                         of SOME arg' => chkArgs (args, tys, arg'::args')
                          | NONE => (
                              TypeError.error(cxt, [
                                  S "arguments of tensor construction must have same type"
                                ]);
                              chkArgs (args, tys, bogusExp::args'))
                        (* end case *))
                    | chkArgs (_, _, args') = (AST.E_Tensor(List.rev args', resTy), resTy)
                  in
                    chkArgs (args, tys, [])
                  end
            in
              case TU.pruneHead ty
               of Ty.T_Int => chkArgs(Ty.realTy, Ty.Shape[]) (* coerce integers to reals *)
                | ty as Ty.T_Tensor shape => chkArgs(ty, shape)
              (* every argument already has an error type, so the error has been reported *)
                | Ty.T_Error => bogusExpTy
                | _ => err(cxt, [S "Invalid argument type for tensor construction"])
              (* end case *)
            end
        | PT.E_Deprecate(msg, e) => (
            warn (cxt, [S msg]);
            check (env, cxt, e))
      (* end case *))
629 :    
(* `checkAndPrune (env, cxt, e)` is `check` followed by `TU.prune` on the resulting type, so
 * that callers can pattern match directly on the type constructor.
 *)
and checkAndPrune (env, cxt, e) = (case check (env, cxt, e)
       of (e', ty) => (e', TU.prune ty)
     (* end case *))
636 :    
(* `checkCondOp (env, cxt, e1, rator, e2, mk)` typechecks the conditional operator `rator`
 * (e.g., "||" or "&&") applied to `e1` and `e2`.  Both operands must have type `bool`; on
 * success `mk` builds the AST for the operation from the checked operands and the result
 * has type `bool`.  Operand types are pruned before they are inspected (so a `bool` reached
 * through a meta variable is accepted), and no further error is reported when an operand
 * already has an error type.
 *)
and checkCondOp (env, cxt, e1, rator, e2, mk) = (
      case (checkAndPrune(env, cxt, e1), checkAndPrune(env, cxt, e2))
       of ((e1', Ty.T_Bool), (e2', Ty.T_Bool)) => (mk(e1', e2'), Ty.T_Bool)
        | ((_, Ty.T_Error), _) => bogusExpTy
        | (_, (_, Ty.T_Error)) => bogusExpTy
        | ((_, Ty.T_Bool), (_, ty2)) =>
            err (cxt, [S "expected type 'bool' on rhs of '", S rator, S "', but found ", TY ty2])
        | ((_, ty1), (_, Ty.T_Bool)) =>
            err (cxt, [S "expected type 'bool' on lhs of '", S rator, S "', but found ", TY ty1])
        | ((_, ty1), (_, ty2)) => err (cxt, [
              S "arguments of '", S rator, S "' must have type 'bool', but found ",
              TY ty1, S " and ", TY ty2
            ])
      (* end case *))
650 :    
(* `checkSelect (env, cxt, e, field)` typechecks the selection of the state variable `field`
 * from `e` (strand-set selections are handled elsewhere).  The pruned type of `e` must be a
 * strand type whose strand declares `field` as a state variable.
 *)
and checkSelect (env, cxt, e, field) = let
      val (e', ty) = checkAndPrune (env, cxt, e)
    (* look up `field` among the state variables of `strand` *)
      fun selectFrom strand = (case Env.findStrand(env, strand)
             of NONE => err(cxt, [S "unknown strand ", A strand])
              | SOME sEnv => (case StrandEnv.findStateVar(sEnv, field)
                   of NONE => err(cxt, [
                          S "strand ", A strand,
                          S " does not have state variable ", A field
                        ])
                    | SOME x' => (AST.E_Select(e', useVar(cxt, x')), Var.monoTypeOf x')
                  (* end case *))
            (* end case *))
      in
        case ty
         of Ty.T_Named strand => selectFrom strand
          | Ty.T_Error => bogusExpTy
          | _ => err (cxt, [
                S "expected strand type, but found ", TY ty,
                S " in selection of ", A field
              ])
        (* end case *)
      end
673 :    
(* `chkComprehension (env, cxt, comp)` typechecks a sequence comprehension with a single
 * iterator.  The body is checked in a new block scope extended with the iteration variable,
 * and the result type is a sequence of the body's type with no fixed length (`NONE`).
 * Comprehensions with any other number of iterators raise `Fail`.
 *)
and chkComprehension (env, cxt, comp) = (case comp
       of PT.COMP_Mark m => chkComprehension (E.withEnvAndContext (env, cxt, m))
        | PT.COMP_Comprehension(body, [iter]) => let
            val (iter', bodyEnv) = checkIter (E.blockScope env, cxt, iter)
            val (body', elemTy) = check (bodyEnv, cxt, body)
            val seqTy = Ty.T_Sequence(elemTy, NONE)
            in
              (AST.E_Comprehension(body', iter', seqTy), seqTy)
            end
        | _ => raise Fail "impossible"
      (* end case *))
684 :    
(* `checkIter (env, cxt, iter)` typechecks the iterator "x 'in' e".  The pruned type of `e`
 * must be a sequence type; the result is the pair of the new iteration variable and the
 * checked `e`, together with `env` extended with a binding for x.  If `e` does not have a
 * sequence type, x is still bound (with the error type) so that later uses of x are not
 * reported as undeclared; the variable keeps x's real source location in both cases.
 *)
and checkIter (env, cxt, PT.I_Mark m) = checkIter (E.withEnvAndContext (env, cxt, m))
  | checkIter (env, cxt, PT.I_Iterator({span, tree=x}, e)) = let
      val loc = Error.location(#1 cxt, span)
      in
        case checkAndPrune (env, cxt, e)
         of (e', Ty.T_Sequence(elemTy, _)) => let
              val x' = Var.new(x, loc, Var.LocalVar, elemTy)
              in
                ((x', e'), E.insertLocal(env, cxt, x, x'))
              end
          | (_, ty) => let
            (* NOTE(review): this path uses Var.IterVar while the success path uses
             * Var.LocalVar; confirm which kind later passes expect for iteration variables.
             *)
              val x' = Var.new(x, loc, Var.IterVar, Ty.T_Error)
              in
                if TU.isErrorType ty
                  then ()
                  else TypeError.error (cxt, [
                      S "expected sequence type in iteration, but found '", TY ty, S "'"
                    ]);
                ((x', bogusExp), E.insertLocal(env, cxt, x, x'))
              end
        (* end case *)
      end
704 :    
(* `checkList (env, cxt, exprs)` typechecks each expression in `exprs` (pruning each resulting
 * type) and returns the list of AST expressions together with the list of their types.
 * The expressions are checked from left to right (`List.map` applies its function in list
 * order), so diagnostics and unification happen in source order; the previous `List.foldr`
 * version checked them right-to-left.
 *)
and checkList (env, cxt, exprs) =
      ListPair.unzip (List.map (fn e => checkAndPrune (env, cxt, e)) exprs)
717 :    
(* `chkStringConstExpr (env, cxt, e)` typechecks `e`, which must be a constant expression of
 * type `string`, and returns `SOME s` when it evaluates to the string `s`.  Otherwise it
 * returns `NONE`:
 *   - if `e` has the error type (already reported);
 *   - if `e` has some other non-string type (reported here);
 *   - if `e` cannot be reduced to a literal string (reported here);
 *   - if `ConstExpr.eval` yields `NONE` (no error is reported here; presumably reported by
 *     ConstExpr itself -- TODO confirm).
 *)
and chkStringConstExpr (env, cxt, PT.E_Mark m) =
      chkStringConstExpr (E.withEnvAndContext (env, cxt, m))
  | chkStringConstExpr (env, cxt, e) = (case checkAndPrune (env, cxt, e)
       of (e', Ty.T_String) => (case ConstExpr.eval (cxt, e')
             of SOME(ConstExpr.String s) => SOME s
            (* a constant that cannot be reduced to a literal string is reported as an
             * ordinary error (as `checkDim` does for dimensions) instead of aborting the
             * compiler with an exception.
             *)
              | SOME(ConstExpr.Expr _) => (
                  TypeError.error (cxt, [S "unable to evaluate constant string expression"]);
                  NONE)
              | NONE => NONE
              | _ => raise Fail "impossible: wrong type for constant expr"
            (* end case *))
        | (_, Ty.T_Error) => NONE
        | (_, ty) => (
            TypeError.error (cxt, [
                S "expected constant expression of type 'string', but found '",
                TY ty, S "'"
              ]);
            NONE)
      (* end case *))
736 :    
(* `checkDim (env, cxt, dim)` typechecks `dim`, which must be a constant expression of type
 * `int`, and returns `SOME d` when it evaluates to the integer `d`.  It returns `NONE` when
 * `dim` has the error type, when `ConstExpr.eval` yields `NONE`, or after reporting that the
 * expression has the wrong type or cannot be evaluated to an integer.
 *)
and checkDim (env, cxt, dim) = let
      fun reportNone msg = (TypeError.error (cxt, msg); NONE)
      in
        case checkAndPrune (env, cxt, dim)
         of (_, Ty.T_Error) => NONE
          | (dim', Ty.T_Int) => (case ConstExpr.eval (cxt, dim')
               of SOME(ConstExpr.Int d) => SOME d
                | SOME(ConstExpr.Expr _) =>
                    reportNone [S "unable to evaluate constant dimension expression"]
                | NONE => NONE
                | _ => raise Fail "impossible: wrong type for constant expr"
              (* end case *))
          | (_, ty) => reportNone [
                S "expected constant expression of type 'int', but found '",
                TY ty, S "'"
              ]
        (* end case *)
      end
755 :    
(* `checkShape (env, cxt, shape)` typechecks the constant dimension expressions in `shape`
 * and returns the corresponding tensor shape.
 *   - A dimension <= 1 is reported as an error but still appears in the result.
 *   - A dimension for which `checkDim` yields `NONE` becomes `DimConst ~1`.
 *   - A dimension too large for a native `int` is reported and also becomes `DimConst ~1`.
 *)
and checkShape (env, cxt, shape) = let
      fun checkDim' e = (case checkDim (env, cxt, e)
             of SOME d => (
                  if (d <= 1)
                    then TypeError.error (cxt, [
                        S "invalid tensor-shape dimension; must be > 1, but found ",
                        S (IntLit.toString d)
                      ])
                    else ();
                (* `IntInf.toInt` raises Overflow for huge constants; report that as a type
                 * error rather than letting the exception escape the typechecker.
                 *)
                  Ty.DimConst(IntInf.toInt d)
                    handle Overflow => (
                      TypeError.error (cxt, [
                          S "invalid tensor-shape dimension; value out of range: ",
                          S (IntLit.toString d)
                        ]);
                      Ty.DimConst ~1))
              | NONE => Ty.DimConst ~1
            (* end case *))
      in
        Ty.Shape(List.map checkDim' shape)
      end
772 :    
773 : jhr 3396 end

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0