Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] Annotation of /branches/vis15/src/compiler/typechecker/check-expr.sml
ViewVC logotype

Annotation of /branches/vis15/src/compiler/typechecker/check-expr.sml

Parent Directory | Revision Log


Revision 3408 - (view) (download)

1 : jhr 3396 (* check-expr.sml
2 :     *
3 :     * The typechecker for expressions.
4 :     *
5 :     * This code is part of the Diderot Project (http://diderot-language.cs.uchicago.edu)
6 :     *
7 :     * COPYRIGHT (c) 2015 The University of Chicago
8 :     * All rights reserved.
9 :     *)
10 :    
structure CheckExpr : sig

  (* typecheck a parse-tree expression in the given environment and error context,
   * returning the typed AST expression paired with its type.
   *)
    val check : Env.t * Env.context * ParseTree.expr -> AST.expr * Types.ty

  (* `resolveOverload (cxt, rator, tys, args, candidates)` picks the definition of the
   * overloaded operator `rator` that applies to the arguments `args` (whose types are
   * `tys`) from the list of candidate definitions `candidates`, and returns the
   * resulting application paired with its type.
   *)
    val resolveOverload : Env.context * Atom.atom * Types.ty list * AST.expr list * Var.t list
          -> AST.expr * Types.ty

  (* evaluate a dimension that is specified by a constant expression; NONE is returned
   * when its value cannot be determined.
   *)
    val checkDim : Env.t * Env.context * ParseTree.expr -> IntLit.t option

  (* convert a list of constant-expression dimensions into a tensor shape *)
    val checkShape : Env.t * Env.context * ParseTree.expr list -> Types.shape

  end = struct
30 :    
structure PT = ParseTree
structure L = Literal
structure E = Env
structure Ty = Types
structure BV = BasisVars
structure TU = TypeUtil

(* the placeholder expression used as the result of an ill-typed expression, and the
 * placeholder expression paired with the error type.
 *)
val bogusExp = AST.E_Lit(L.Int 0)
val bogusExpTy = (bogusExp, Ty.T_Error)

(* report a type error and yield the placeholder expression/type pair *)
fun err (cxt, msg) = (
      TypeError.error (cxt, msg);
      bogusExpTy)

(* report a warning *)
fun warn (cxt, msg) = TypeError.warning (cxt, msg)

(* the tokens used to build error and warning messages *)
datatype token = datatype TypeError.token

(* pair a variable with the source location of its use (taken from the context) *)
fun useVar (cxt, x) = let
      val loc = Error.location cxt
      in
        (x, loc)
      end
49 :    
(* resolve overloading using a simple two-pass scheme over the candidate list (in list
 * order): first look for a candidate whose domain types are exactly equal to the
 * argument types; if there is none, take the first candidate whose domain matches
 * once argument coercions are allowed.  An error is reported when no candidate
 * matches, and Fail is raised when the candidate list is empty.
 * FIXME: a candidate may be instantiated twice (once per pass); the coercion matches
 * could be remembered during the first pass instead.
 *)
fun resolveOverload (_, rator, _, _, []) = raise Fail(concat[
        "resolveOverload: \"", Atom.toString rator, "\" has no candidates"
      ])
  | resolveOverload (cxt, rator, argTys, args, candidates) = let
    (* instantiate a candidate's type scheme, which must be a function type *)
      fun instFn f = let
            val (tyArgs, Ty.T_Fun(domTy, rngTy)) = TU.instantiate(Var.typeOf f)
            in
              (tyArgs, domTy, rngTy)
            end
    (* build the application of the primitive `f` *)
      fun mkApp (f, tyArgs, args', rngTy) = (AST.E_Prim(f, tyArgs, args', rngTy), rngTy)
    (* second pass: allow the arguments to be coerced *)
      fun coercionPass [] = err(cxt, [
              S "unable to resolve overloaded operator ", A rator, S "\n",
              S " argument type is: ", TYS argTys, S "\n"
            ])
        | coercionPass (f::rest) = let
            val (tyArgs, domTy, rngTy) = instFn f
            in
              case Unify.tryMatchArgs (domTy, args, argTys)
               of NONE => coercionPass rest
                | SOME args' => mkApp (f, tyArgs, args', rngTy)
              (* end case *)
            end
    (* first pass: require the domain to be exactly the argument types *)
      fun exactPass [] = coercionPass candidates
        | exactPass (f::rest) = let
            val (tyArgs, domTy, rngTy) = instFn f
            in
              if Unify.tryEqualTypes(domTy, argTys)
                then mkApp (f, tyArgs, args, rngTy)
                else exactPass rest
            end
      in
        exactPass candidates
      end
84 :    
(* typecheck a literal, pairing the literal expression with the literal's type *)
fun checkLit lit = let
      val ty = (case lit
             of L.Int _ => Ty.T_Int
              | L.Real _ => Ty.realTy
              | L.String _ => Ty.T_String
              | L.Bool _ => Ty.T_Bool
            (* end case *))
      in
        (AST.E_Lit lit, ty)
      end
92 :    
(* type check an inner (dot) product, which has the constraint:
 *	ALL[sigma1, d1, sigma2] . tensor[sigma1, d1] * tensor[d1, sigma2] -> tensor[sigma1, sigma2]
 * and similarly when either or both of the arguments are fields.  The arguments are
 * the error context, the two checked operands, and their types.  If either operand
 * already has the error type, the bogus expression is returned without reporting a
 * second (cascading) error.
 *)
fun chkInnerProduct (cxt, e1, ty1, e2, ty2) = let
    (* check the shapes of the two arguments to verify that the last dimension of the
     * first matches the first dimension of the second; on success, return the shape
     * of the result (the remaining dimensions concatenated).
     *)
      fun chkShape (Ty.Shape(dd1 as _::_), Ty.Shape(d2::dd2)) = let
            val (dd1, d1) = let
                  fun splitLast (prefix, [d]) = (List.rev prefix, d)
                    | splitLast (prefix, d::dd) = splitLast (d::prefix, dd)
                    | splitLast (_, []) = raise Fail "impossible"
                  in
                    splitLast ([], dd1)
                  end
            in
              if Unify.equalDim(d1, d2)
                then SOME(Ty.Shape(dd1@dd2))
                else NONE
            end
        | chkShape _ = NONE
      fun error () = err (cxt, [
              S "type error for arguments of binary operator '•'\n",
              S " found: ", TYS[ty1, ty2], S "\n"
            ])
    (* build the application of the inner-product primitive `rator` whose expected
     * result type is `resTy`.  The extra check `precond` runs after instantiation and
     * before the domain/range checks; this is the same order as the per-case code it
     * replaces, which matters if the Unify operations bind meta-variables
     * (NOTE(review): assumption -- confirm).
     *)
      fun mkPrim (rator, resTy, precond) = let
            val (tyArgs, Ty.T_Fun(domTy, rngTy)) = TU.instantiate(Var.typeOf rator)
            in
              if precond ()
              andalso Unify.equalTypes(domTy, [ty1, ty2])
              andalso Unify.equalType(rngTy, resTy)
                then (AST.E_Prim(rator, tyArgs, [e1, e2], rngTy), rngTy)
                else error()
            end
      fun always () = true
    (* check the argument shapes and, if they are compatible, build the result *)
      fun withShape (s1, s2, mk) = (case chkShape(s1, s2)
             of SOME shp => mk shp
              | NONE => error()
            (* end case *))
      in
        case (TU.prune ty1, TU.prune ty2)
         of (Ty.T_Error, _) => bogusExpTy
          | (_, Ty.T_Error) => bogusExpTy
        (* tensor * tensor inner product *)
          | (Ty.T_Tensor s1, Ty.T_Tensor s2) => withShape (s1, s2, fn shp =>
              mkPrim (BV.op_inner_tt, Ty.T_Tensor shp, always))
        (* tensor * field inner product *)
          | (Ty.T_Tensor s1, Ty.T_Field{diff, dim, shape=s2}) => withShape (s1, s2, fn shp =>
              mkPrim (BV.op_inner_tf, Ty.T_Field{diff=diff, dim=dim, shape=shp}, always))
        (* field * tensor inner product *)
          | (Ty.T_Field{diff, dim, shape=s1}, Ty.T_Tensor s2) => withShape (s1, s2, fn shp =>
              mkPrim (BV.op_inner_ft, Ty.T_Field{diff=diff, dim=dim, shape=shp}, always))
        (* field * field inner product; the two fields must have the same dimension.
         * FIXME: the resulting differentiation should be the minimum of the two
         * arguments' differentiation levels.
         *)
          | (Ty.T_Field{diff=k1, dim=dim1, shape=s1}, Ty.T_Field{diff=_, dim=dim2, shape=s2}) =>
              withShape (s1, s2, fn shp =>
                mkPrim (BV.op_inner_ff, Ty.T_Field{diff=k1, dim=dim1, shape=shp},
                  fn () => Unify.equalDim(dim1, dim2)))
          | _ => error()
        (* end case *)
      end
176 :    
(* type check a colon (double-contraction) product, which has the constraint:
 *	ALL[sigma1, d1, d2, sigma2] . tensor[sigma1, d1, d2] * tensor[d1, d2, sigma2] -> tensor[sigma1, sigma2]
 * and similarly when either or both of the arguments are fields.  The arguments are
 * the error context, the two checked operands, and their types.  If either operand
 * already has the error type, the bogus expression is returned without reporting a
 * second (cascading) error.
 *)
fun chkColonProduct (cxt, e1, ty1, e2, ty2) = let
    (* check the shapes of the two arguments to verify that the last two dimensions of
     * the first match (in order) the first two dimensions of the second; on success,
     * return the shape of the result (the remaining dimensions concatenated).
     *)
      fun chkShape (Ty.Shape(dd1 as _::_::_), Ty.Shape(d21::d22::dd2)) = let
            val (dd1, d11, d12) = let
                  fun splitLast2 (prefix, [d1, d2]) = (List.rev prefix, d1, d2)
                    | splitLast2 (prefix, d::dd) = splitLast2 (d::prefix, dd)
                    | splitLast2 (_, []) = raise Fail "impossible"
                  in
                    splitLast2 ([], dd1)
                  end
            in
              if Unify.equalDim(d11, d21) andalso Unify.equalDim(d12, d22)
                then SOME(Ty.Shape(dd1@dd2))
                else NONE
            end
        | chkShape _ = NONE
      fun error () = err (cxt, [
              S "type error for arguments of binary operator \":\"\n",
              S " found: ", TYS[ty1, ty2], S "\n"
            ])
    (* build the application of the colon-product primitive `rator` whose expected
     * result type is `resTy`.  The extra check `precond` runs after instantiation and
     * before the domain/range checks; this is the same order as the per-case code it
     * replaces, which matters if the Unify operations bind meta-variables
     * (NOTE(review): assumption -- confirm).
     *)
      fun mkPrim (rator, resTy, precond) = let
            val (tyArgs, Ty.T_Fun(domTy, rngTy)) = TU.instantiate(Var.typeOf rator)
            in
              if precond ()
              andalso Unify.equalTypes(domTy, [ty1, ty2])
              andalso Unify.equalType(rngTy, resTy)
                then (AST.E_Prim(rator, tyArgs, [e1, e2], rngTy), rngTy)
                else error()
            end
      fun always () = true
    (* check the argument shapes and, if they are compatible, build the result *)
      fun withShape (s1, s2, mk) = (case chkShape(s1, s2)
             of SOME shp => mk shp
              | NONE => error()
            (* end case *))
      in
        case (TU.prune ty1, TU.prune ty2)
         of (Ty.T_Error, _) => bogusExpTy
          | (_, Ty.T_Error) => bogusExpTy
        (* tensor * tensor colon product *)
          | (Ty.T_Tensor s1, Ty.T_Tensor s2) => withShape (s1, s2, fn shp =>
              mkPrim (BV.op_colon_tt, Ty.T_Tensor shp, always))
        (* field * tensor colon product *)
          | (Ty.T_Field{diff, dim, shape=s1}, Ty.T_Tensor s2) => withShape (s1, s2, fn shp =>
              mkPrim (BV.op_colon_ft, Ty.T_Field{diff=diff, dim=dim, shape=shp}, always))
        (* tensor * field colon product *)
          | (Ty.T_Tensor s1, Ty.T_Field{diff, dim, shape=s2}) => withShape (s1, s2, fn shp =>
              mkPrim (BV.op_colon_tf, Ty.T_Field{diff=diff, dim=dim, shape=shp}, always))
        (* field * field colon product; the two fields must have the same dimension.
         * FIXME: the resulting differentiation should be the minimum of the two
         * arguments' differentiation levels.
         *)
          | (Ty.T_Field{diff=k1, dim=dim1, shape=s1}, Ty.T_Field{diff=_, dim=dim2, shape=s2}) =>
              withShape (s1, s2, fn shp =>
                mkPrim (BV.op_colon_ff, Ty.T_Field{diff=k1, dim=dim1, shape=shp},
                  fn () => Unify.equalDim(dim1, dim2)))
          | _ => error()
        (* end case *)
      end
259 :    
(* typecheck an expression in the given environment and error context, returning the
 * corresponding AST expression paired with its type.  When a type error is detected
 * it is reported and the bogus expression (with type T_Error) is returned.
 *)
fun check (env, cxt, e) = (case e
       of PT.E_Mark m => check (E.withEnvAndContext (env, cxt, m))
        | PT.E_Cond(e1, cond, e2) => let
            val eTy1 = check (env, cxt, e1)
            val eTy2 = check (env, cxt, e2)
            in
              case check(env, cxt, cond)
               of (cond', Ty.T_Bool) => (case Util.coerceType2(eTy1, eTy2)
                     of SOME(e1', e2', ty) => (AST.E_Cond(cond', e1', e2', ty), ty)
                      | NONE => err (cxt, [
                            S "types do not match in conditional expression\n",
                            S " true branch: ", TY(#2 eTy1), S "\n",
                            S " false branch: ", TY(#2 eTy2)
                          ])
                    (* end case *))
                | (_, ty') => err (cxt, [S "expected bool type, but found ", TY ty'])
              (* end case *)
            end
        | PT.E_Range(e1, e2) => (case (check (env, cxt, e1), check (env, cxt, e2))
             of ((e1', Ty.T_Int), (e2', Ty.T_Int)) => let
                  val resTy = Ty.T_Sequence(Ty.T_Int, NONE)
                  in
                    (AST.E_Prim(BV.range, [], [e1', e2'], resTy), resTy)
                  end
              | ((_, Ty.T_Int), (_, ty2)) =>
                  err (cxt, [S "expected type 'int' on rhs of '..', but found ", TY ty2])
              | ((_, ty1), (_, Ty.T_Int)) =>
                  err (cxt, [S "expected type 'int' on lhs of '..', but found ", TY ty1])
              | ((_, ty1), (_, ty2)) => err (cxt, [
                    S "arguments of '..' must have type 'int', found ",
                    TY ty1, S " and ", TY ty2
                  ])
            (* end case *))
        | PT.E_OrElse(e1, e2) =>
            checkCondOp (env, cxt, e1, "||", e2,
              fn (e1', e2') => AST.E_Cond(e1', AST.E_Lit(L.Bool true), e2', Ty.T_Bool))
        | PT.E_AndAlso(e1, e2) =>
            checkCondOp (env, cxt, e1, "&&", e2,
              fn (e1', e2') => AST.E_Cond(e1', e2', AST.E_Lit(L.Bool false), Ty.T_Bool))
        | PT.E_BinOp(e1, rator, e2) => let
            val (e1', ty1) = check (env, cxt, e1)
            val (e2', ty2) = check (env, cxt, e2)
            in
              if Atom.same(rator, BasisNames.op_dot)
                then chkInnerProduct (cxt, e1', ty1, e2', ty2)
              else if Atom.same(rator, BasisNames.op_colon)
                then chkColonProduct (cxt, e1', ty1, e2', ty2)
                else (case Env.findFunc (env, rator)
                   of Env.PrimFun[rator] => let
                        val (tyArgs, Ty.T_Fun(domTy, rngTy)) = TU.instantiate(Var.typeOf rator)
                        in
                          case Unify.matchArgs(domTy, [e1', e2'], [ty1, ty2])
                           of SOME args => (AST.E_Prim(rator, tyArgs, args, rngTy), rngTy)
                            | NONE => err (cxt, [
                                  S "type error for binary operator '", V rator, S "'\n",
                                  S " expected: ", TYS domTy, S "\n",
                                  S " but found: ", TYS[ty1, ty2]
                                ])
                          (* end case *)
                        end
                    | Env.PrimFun ovldList =>
                        resolveOverload (cxt, rator, [ty1, ty2], [e1', e2'], ovldList)
                    | _ => raise Fail "impossible"
                  (* end case *))
            end
        | PT.E_UnaryOp(rator, e) => let
            val eTy = check(env, cxt, e)
            in
              case Env.findFunc (env, rator)
               of Env.PrimFun[rator] => let
                    val (tyArgs, Ty.T_Fun([domTy], rngTy)) = TU.instantiate(Var.typeOf rator)
                    in
                      case Util.coerceType (domTy, eTy)
                       of SOME(e', ty) => (AST.E_Prim(rator, tyArgs, [e'], rngTy), rngTy)
                        | NONE => err (cxt, [
                              S "type error for unary operator \"", V rator, S "\"\n",
                              S " expected: ", TY domTy, S "\n",
                              S " but found: ", TY (#2 eTy)
                            ])
                      (* end case *)
                    end
                | Env.PrimFun ovldList => resolveOverload (cxt, rator, [#2 eTy], [#1 eTy], ovldList)
                | _ => raise Fail "impossible"
              (* end case *)
            end
        | PT.E_Apply(e, args) => let
            fun stripMark (_, PT.E_Mark{span, tree}) = stripMark(span, tree)
              | stripMark (span, e) = (span, e)
            val (args, tys) = checkList (env, cxt, args)
            fun appTyError (f, paramTys, argTys) = err(cxt, [
                    S "type error in application of ", V f, S "\n",
                    S " expected: ", TYS paramTys, S "\n",
                    S " but found: ", TYS argTys
                  ])
          (* check the application of a (non-overloaded) primitive function *)
            fun checkPrimApp f = if Var.isPrim f
                  then (case TU.instantiate(Var.typeOf f)
                     of (tyArgs, Ty.T_Fun(domTy, rngTy)) => (
                          case Unify.matchArgs (domTy, args, tys)
                           of SOME args => (AST.E_Prim(f, tyArgs, args, rngTy), rngTy)
                            | NONE => appTyError (f, domTy, tys)
                          (* end case *))
                      | _ => err(cxt, [S "application of non-function/field ", V f])
                    (* end case *))
                  else raise Fail "unexpected user function"
          (* check the application of a user-defined function *)
            fun checkFunApp (cxt, f) = if Var.isPrim f
                  then raise Fail "unexpected primitive function"
                  else (case Var.monoTypeOf f
                     of Ty.T_Fun(domTy, rngTy) => (
                          case Unify.matchArgs (domTy, args, tys)
                           of SOME args => (AST.E_Apply(useVar(cxt, f), args, rngTy), rngTy)
                            | NONE => appTyError (f, domTy, tys)
                          (* end case *))
                      | _ => err(cxt, [S "application of non-function/field ", V f])
                    (* end case *))
          (* check the application of a field (i.e., a probe) to a single argument *)
            fun checkFieldApp (e1', ty1) = (case (args, tys)
                   of ([e2'], [ty2]) => let
                        val (tyArgs, Ty.T_Fun([fldTy, domTy], rngTy)) =
                              TU.instantiate(Var.typeOf BV.op_probe)
                        fun tyError () = err (cxt, [
                                S "type error for field application\n",
                                S " expected: ", TYS[fldTy, domTy], S "\n",
                                S " but found: ", TYS[ty1, ty2]
                              ])
                        in
                          if Unify.equalType(fldTy, ty1)
                            then (case Util.coerceType(domTy, (e2', ty2))
                               of SOME(e2', _) =>
                                    (AST.E_Prim(BV.op_probe, tyArgs, [e1', e2'], rngTy), rngTy)
                                | NONE => tyError()
                              (* end case *))
                            else tyError()
                        end
                    | _ => err(cxt, [S "badly formed field application"])
                  (* end case *))
            in
              case stripMark(#2 cxt, e)
               of (span, PT.E_Var f) => (case Env.findVar (env, f)
                     of SOME f' => checkFieldApp (
                          AST.E_Var(useVar((#1 cxt, span), f')),
                          Var.monoTypeOf f')
                      | NONE => (case Env.findFunc (env, f)
                           of Env.PrimFun[] => err(cxt, [S "unknown function ", A f])
                            | Env.PrimFun[f'] => checkPrimApp f'
                            | Env.PrimFun ovldList =>
                                resolveOverload ((#1 cxt, span), f, tys, args, ovldList)
                            | Env.UserFun f' => checkFunApp((#1 cxt, span), f')
                          (* end case *))
                    (* end case *))
                | _ => checkFieldApp (check (env, cxt, e))
              (* end case *)
            end
        | PT.E_Subscript(e, indices) => (case (check(env, cxt, e), indices)
             of ((e', Ty.T_Sequence(elemTy, _)), [SOME e2]) => raise Fail "FIXME"
              | ((e', Ty.T_Tensor shape), _) => raise Fail "FIXME"
              | ((_, ty), _) => err(cxt, [
                  (* the trailing space separates the message from the printed type *)
                    S "expected sequence or tensor type for object of subscripting, but found ",
                    TY ty
                  ])
            (* end case *))
        | PT.E_Select(e, field) => (case check(env, cxt, e)
             of (e', Ty.T_Named strand) => (case Env.findStrand(env, strand)
                   of SOME sEnv => (case StrandEnv.findStateVar(sEnv, field)
                         of SOME x' => let
                              val ty = Var.monoTypeOf x'
                              in
                                (AST.E_Select(e', useVar(cxt, x')), ty)
                              end
                          | NONE => err(cxt, [
                                S "strand '", A strand,
                                S "' does not have state variable '", A field, S "'"
                              ])
                        (* end case *))
                    | NONE => err(cxt, [S "unknown strand '", A strand, S "'"])
                  (* end case *))
              | (_, ty) => err (cxt, [
                    S "expected strand type, but found ", TY ty,
                    S " in selection of '", A field, S "'"
                  ])
            (* end case *))
        | PT.E_Real e => (case check (env, cxt, e)
             of (e', Ty.T_Int) =>
                  (AST.E_Prim(BV.i2r, [], [e'], Ty.realTy), Ty.realTy)
              | (_, ty) => err(cxt, [
                    S "argument of 'real' must have type 'int', but found ",
                    TY ty
                  ])
            (* end case *))
      (* `load` expressions get their result type from the `load` basis function *)
        | PT.E_Load nrrd => let
            val (tyArgs, Ty.T_Fun(_, rngTy)) = TU.instantiate(Var.typeOf(BV.fn_load))
            in
              case chkStringConstExpr (env, cxt, nrrd)
               of SOME nrrd => (AST.E_LoadNrrd(tyArgs, nrrd, rngTy), rngTy)
                | NONE => (bogusExp, rngTy)
              (* end case *)
            end
      (* `image` expressions get their result type from the `image` basis function *)
        | PT.E_Image nrrd => let
            val (tyArgs, Ty.T_Fun(_, rngTy)) = TU.instantiate(Var.typeOf(BV.fn_image))
            in
              case chkStringConstExpr (env, cxt, nrrd)
               of SOME nrrd => (AST.E_LoadNrrd(tyArgs, nrrd, rngTy), rngTy)
                | NONE => (bogusExp, rngTy)
              (* end case *)
            end
        | PT.E_Var x => (case E.findVar (env, x)
             of SOME x' => (AST.E_Var(useVar(cxt, x')), Var.monoTypeOf x')
              | NONE => err(cxt, [S "undeclared variable ", A x])
            (* end case *))
        | PT.E_Kernel(kern, dim) => raise Fail "FIXME"
        | PT.E_Lit lit => checkLit lit
        | PT.E_Id d => let
            val (tyArgs, Ty.T_Fun(_, rngTy)) =
                  TU.instantiate(Var.typeOf(BV.identity))
            in
              if Unify.equalType(Ty.T_Tensor(checkShape(env, cxt, [d, d])), rngTy)
                then (AST.E_Prim(BV.identity, tyArgs, [], rngTy), rngTy)
                else raise Fail "impossible"
            end
        | PT.E_Zero dd => let
            val (tyArgs, Ty.T_Fun(_, rngTy)) =
                  TU.instantiate(Var.typeOf(BV.zero))
            in
              if Unify.equalType(Ty.T_Tensor(checkShape(env, cxt, dd)), rngTy)
                then (AST.E_Prim(BV.zero, tyArgs, [], rngTy), rngTy)
                else raise Fail "impossible"
            end
        | PT.E_NaN dd => let
            val (tyArgs, Ty.T_Fun(_, rngTy)) =
                  TU.instantiate(Var.typeOf(BV.nan))
            in
              if Unify.equalType(Ty.T_Tensor(checkShape(env, cxt, dd)), rngTy)
                then (AST.E_Prim(BV.nan, tyArgs, [], rngTy), rngTy)
                else raise Fail "impossible"
            end
        | PT.E_Sequence exps => raise Fail "FIXME"
        | PT.E_SeqComp comp => raise Fail "FIXME"
        | PT.E_Cons args => let
          (* Note that we are guaranteed that args is non-empty *)
            val (args, tys) = checkList (env, cxt, args)
          (* extract the first non-error type in tys *)
            val ty = (case List.find (fn Ty.T_Error => false | _ => true) tys
                   of NONE => Ty.T_Error
                    | SOME ty => ty
                  (* end case *))
          (* process the arguments checking that they all have the expected type *)
            fun chkArgs (ty, shape) = let
                  val Ty.Shape dd = TU.pruneShape shape (* NOTE: this may fail if we allow user polymorphism *)
                  val resTy = Ty.T_Tensor(Ty.Shape(Ty.DimConst(List.length args) :: dd))
                  fun chkArgs (arg::args, argTy::tys, args') = (
                        case Util.coerceType(ty, (arg, argTy))
                         of SOME(arg', _) => chkArgs (args, tys, arg'::args')
                          | NONE => (
                              TypeError.error(cxt, [
                                  S "arguments of tensor construction must have same type"
                                ]);
                              chkArgs (args, tys, bogusExp::args'))
                        (* end case *))
                    | chkArgs (_, _, args') = (AST.E_Tensor(List.rev args', resTy), resTy)
                  in
                    chkArgs (args, tys, [])
                  end
            in
              case TU.pruneHead ty
               of Ty.T_Int => chkArgs(Ty.realTy, Ty.Shape[]) (* coerce integers to reals *)
                | ty as Ty.T_Tensor shape => chkArgs(ty, shape)
              (* the element type is the error type, so the error was already reported
               * when the arguments were checked; avoid a cascading error here.
               *)
                | Ty.T_Error => bogusExpTy
                | _ => err(cxt, [S "Invalid argument type for tensor construction"])
              (* end case *)
            end
        | PT.E_Deprecate(msg, e) => (
            warn (cxt, [S msg]);
            check (env, cxt, e))
      (* end case *))
532 :    
(* check a short-circuit boolean operator (e.g., || or &&).  Both operands must have
 * type bool; `rator` is the operator's name (used only in error messages) and `mk`
 * builds the AST for the checked operands.
 *)
and checkCondOp (env, cxt, e1, rator, e2, mk) = let
      val (e1', ty1) = check (env, cxt, e1)
      val (e2', ty2) = check (env, cxt, e2)
      in
        case (ty1, ty2)
         of (Ty.T_Bool, Ty.T_Bool) => (mk(e1', e2'), Ty.T_Bool)
          | (Ty.T_Bool, _) =>
              err (cxt, [S "expected type 'bool' on rhs of '", S rator, S "', but found ", TY ty2])
          | (_, Ty.T_Bool) =>
              err (cxt, [S "expected type 'bool' on lhs of '", S rator, S "', but found ", TY ty1])
          | _ => err (cxt, [
                S "arguments of '", S rator, S "' must have type 'bool', but found ",
                TY ty1, S " and ", TY ty2
              ])
        (* end case *)
      end
546 :    
(* typecheck a list of expressions, returning the list of AST expressions and the list
 * of their types (both in the same order as `exprs`).  The expressions are checked
 * from left to right (List.map applies its function left to right), so any type
 * errors are reported in source order; the previous List.foldr version checked the
 * last expression first.
 *)
and checkList (env, cxt, exprs) =
      ListPair.unzip (List.map (fn e => check (env, cxt, e)) exprs)
559 :    
(* check that an expression is a constant expression of type string and return its
 * value.  NONE is returned when the value cannot be determined; an error is reported
 * here when the expression does not have type string.
 *)
and chkStringConstExpr (env, cxt, PT.E_Mark m) =
      chkStringConstExpr (E.withEnvAndContext (env, cxt, m))
  | chkStringConstExpr (env, cxt, e) = let
    (* evaluate a string-typed AST expression as a constant *)
      fun evalString e' = (case ConstExpr.eval (cxt, e')
             of NONE => NONE
              | SOME(ConstExpr.String s) => SOME s
              | SOME(ConstExpr.Expr _) => raise Fail "FIXME"
              | SOME _ => raise Fail "impossible: wrong type for constant expr"
            (* end case *))
      in
        case check (env, cxt, e)
         of (e', Ty.T_String) => evalString e'
          | (_, ty) => (
              TypeError.error (cxt, [
                  S "expected constant expression of type 'string', but found '",
                  TY ty, S "'"
                ]);
              NONE)
        (* end case *)
      end
577 :    
(* check a dimension that is given by a constant expression of type int and return its
 * value.  NONE is returned when the value cannot be determined; an error is reported
 * here when the expression has the wrong type or cannot be evaluated.
 *)
and checkDim (env, cxt, dim) = let
    (* report an error and produce no dimension *)
      fun fail msg = (TypeError.error (cxt, msg); NONE)
      in
        case check (env, cxt, dim)
         of (e', Ty.T_Int) => (case ConstExpr.eval (cxt, e')
               of NONE => NONE
                | SOME(ConstExpr.Int d) => SOME d
                | SOME(ConstExpr.Expr _) =>
                    fail [S "unable to evaluate constant dimension expression"]
                | SOME _ => raise Fail "impossible: wrong type for constant expr"
              (* end case *))
          | (_, ty) => fail [
                S "expected constant expression of type 'int', but found '",
                TY ty, S "'"
              ]
        (* end case *)
      end
595 :    
(* check a tensor shape, where the dimensions are given by constant expressions.
 * Each dimension must be greater than one; an error is reported otherwise.  A
 * dimension whose value cannot be determined, or that is too large to represent as
 * an int, is represented by DimConst ~1.
 *)
and checkShape (env, cxt, shape) = let
    (* convert a dimension value to a shape dimension.  IntInf.toInt raises Overflow
     * for values outside the range of int; a huge dimension in the source program is
     * reported as a type error instead of escaping as an uncaught exception.
     *)
      fun toDim d = Ty.DimConst(IntInf.toInt d)
            handle Overflow => (
              TypeError.error (cxt, [
                  S "tensor-shape dimension is too large: ", S (IntLit.toString d)
                ]);
              Ty.DimConst ~1)
      fun checkDim' e = (case checkDim (env, cxt, e)
             of SOME d => (
                  if (d <= 1)
                    then TypeError.error (cxt, [
                        S "invalid tensor-shape dimension; must be > 1, but found ",
                        S (IntLit.toString d)
                      ])
                    else ();
                  toDim d)
              | NONE => Ty.DimConst ~1
            (* end case *))
      in
        Ty.Shape(List.map checkDim' shape)
      end
612 :    
613 : jhr 3396 end

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0