Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] Annotation of /branches/vis15/src/compiler/typechecker/check-expr.sml
ViewVC logotype

Annotation of /branches/vis15/src/compiler/typechecker/check-expr.sml

Parent Directory Parent Directory | Revision Log Revision Log


Revision 5574 - (view) (download)

1 : jhr 3396 (* check-expr.sml
2 :     *
3 :     * The typechecker for expressions.
4 :     *
5 :     * This code is part of the Diderot Project (http://diderot-language.cs.uchicago.edu)
6 :     *
7 :     * COPYRIGHT (c) 2015 The University of Chicago
8 :     * All rights reserved.
9 :     *)
10 :    
11 :     structure CheckExpr : sig
12 :    
13 : jhr 3407 (* type check an expression *)
14 :     val check : Env.t * Env.context * ParseTree.expr -> (AST.expr * Types.ty)
15 : jhr 3396
16 : jhr 3410 (* type check a list of expressions *)
17 :     val checkList : Env.t * Env.context * ParseTree.expr list -> (AST.expr list * Types.ty list)
18 :    
19 : jhr 3424 (* type check an iteration expression (i.e., "x 'in' expr"), returning the iterator
20 :     * and the environment extended with a binding for x.
21 :     *)
22 :     val checkIter : Env.t * Env.context * ParseTree.iterator -> ((AST.var * AST.expr) * Env.t)
23 :    
24 : jhr 3410 (* type check a dimension that is given by a constant expression *)
25 :     val checkDim : Env.t * Env.context * ParseTree.expr -> IntLit.t option
26 :    
27 :     (* type check a tensor shape, where the dimensions are given by constant expressions *)
28 :     val checkShape : Env.t * Env.context * ParseTree.expr list -> Types.shape
29 :    
30 : jhr 3407 (* `resolveOverload (cxt, rator, tys, args, candidates)` resolves the application of
31 :     * the overloaded operator `rator` to `args`, where `tys` are the types of the arguments
32 :     * and `candidates` is the list of candidate definitions.
33 :     *)
34 :     val resolveOverload : Env.context * Atom.atom * Types.ty list * AST.expr list * Var.t list
35 : jhr 4317 -> (AST.expr * Types.ty)
36 : jhr 3407
37 : jhr 3396 end = struct
38 :    
39 :     structure PT = ParseTree
40 :     structure L = Literal
41 :     structure E = Env
42 :     structure Ty = Types
43 :     structure BV = BasisVars
44 : jhr 3405 structure TU = TypeUtil
45 : jhr 3396
(* the placeholder expression, and the placeholder (expression, type) pair, that are
 * returned when there is a type error
 *)
val bogusExp = AST.E_Lit(L.Int 0)
val bogusExpTy = (bogusExp, Ty.T_Error)

(* report a type error using `TypeError.error` and then return the placeholder pair *)
fun err arg = let
      val _ = TypeError.error arg
      in
        bogusExpTy
      end

(* report a warning; this is just `TypeError.warning` *)
val warn = TypeError.warning

(* bring the constructors of `TypeError.token` into scope unqualified (presumably the
 * S, A, V, TY, and TYS tags that are used to build messages in this file)
 *)
datatype token = datatype TypeError.token
54 : jhr 3396
(* mark a use of the variable `x` by pairing it with the second component (the source
 * span) of the context
 *)
fun useVar ((_, span) : Env.context, x) = (x, span)
57 : jhr 3407
(* peel off any `E_Mark` wrappers that enclose an expression; returns the innermost
 * unmarked expression paired with the span of the innermost mark (or with the given
 * span, if the expression is not marked).
 *)
fun stripMark (span, e) = (case e
       of PT.E_Mark{span=span', tree} => stripMark (span', tree)
        | _ => (span, e)
      (* end case *))
61 :    
(* resolve overloading: we use a simple scheme that selects the first operator in the
 * list of candidates that matches the argument types.  Two passes are made over the
 * candidates: the first pass accepts a candidate only if `Unify.tryEqualTypes` succeeds
 * on its domain and the argument types (no coercions); if no candidate is accepted,
 * the second pass uses `Unify.tryMatchArgs`, which allows type coercions and returns
 * the (possibly rewritten) arguments.  A type error is reported if both passes fail.
 * Calling this function with an empty candidate list raises `Fail`.
 *)
fun resolveOverload (_, rator, _, _, []) = raise Fail(concat[
        "resolveOverload: \"", Atom.toString rator, "\" has no candidates"
      ])
  | resolveOverload (cxt, rator, argTys, args, candidates) = let
    (* FIXME: we could be more efficient by just checking for coercion matches in the
     * first pass and remembering those that are not pure EQ matches.
     *)
    (* build the result for the selected candidate `f` *)
      fun mkResult (f, tyArgs, args', resTy) = let
            fun prim es = (AST.E_Prim(f, tyArgs, es, resTy), resTy)
            in
              if not (Var.same(f, BV.pow_si))
                then prim args'
                else let
                (* for `pow_si`, the second argument must be a constant expression *)
                  val [base, exponent] = args'
                  in
                    case CheckConst.eval (cxt, false, exponent)
                     of NONE => err(cxt, [
                            S "constant-integer exponent is required when lhs of '^' is a field"
                          ])
                      | SOME v => prim [base, ConstExpr.valueToExpr v]
                    (* end case *)
                  end
            end
    (* second pass: try to match candidates while allowing type coercions *)
      fun withCoercions cands = (case cands
             of [] => err(cxt, [
                    S "unable to resolve overloaded operator ", A rator, S "\n",
                    S " argument type is: ", TYS argTys, S "\n"
                  ])
              | f::rest => let
                  val (tyArgs, Ty.T_Fun(domTy, rngTy)) = TU.instantiate(Var.typeOf f)
                  in
                    case Unify.tryMatchArgs (domTy, args, argTys)
                     of NONE => withCoercions rest
                      | SOME args' => mkResult(f, tyArgs, args', rngTy)
                    (* end case *)
                  end
            (* end case *))
    (* first pass: try to match candidates without type coercions *)
      fun withoutCoercions cands = (case cands
             of [] => withCoercions candidates
              | f::rest => let
                  val (tyArgs, Ty.T_Fun(domTy, rngTy)) = TU.instantiate(Var.typeOf f)
                  in
                    if Unify.tryEqualTypes(domTy, argTys)
                      then mkResult(f, tyArgs, args, rngTy)
                      else withoutCoercions rest
                  end
            (* end case *))
      in
        withoutCoercions candidates
      end
110 :    
(* check a literal: returns the literal as an AST expression paired with the type
 * that `TypeOf.literal` assigns to it.
 *)
fun checkLit lit = let
      val ty = TypeOf.literal lit
      in
        (AST.E_Lit lit, ty)
      end
113 : jhr 3396
(* type check a dot product, which has the constraint:
 *     ALL[sigma1, d1, sigma2] . tensor[sigma1, d1] * tensor[d1, sigma2] -> tensor[sigma1, sigma2]
 * and similarly for fields.  The arguments `e1` and `e2` have already been checked and
 * have types `ty1` and `ty2`.  Returns the application of the appropriate inner-product
 * primitive paired with its result type, or reports a type error.
 *)
fun chkInnerProduct (cxt, e1, ty1, e2, ty2) = let
    (* check that the last dimension of the first shape equals the first dimension of
     * the second shape.  On success, the result is the shape formed from the remaining
     * dimensions of the two arguments.
     *)
      fun chkShape (Ty.Shape(dd1 as _::_), Ty.Shape(d2::dd2)) = let
            val prefix = List.take(dd1, List.length dd1 - 1)
            val d1 = List.last dd1
            in
              if Unify.equalDim(d1, d2)
                then SOME(Ty.Shape(prefix @ dd2))
                else NONE
            end
        | chkShape _ = NONE
      fun error () = err (cxt, [
              S "type error for arguments of binary operator '•'\n",
              S " found: ", TYS[ty1, ty2], S "\n"
            ])
    (* instantiate the type of the primitive `rator` and, if the extra condition `pre`
     * holds and the instantiated type unifies with the argument types and the
     * expected result type `resTy`, build the application of `rator`.  The checks are
     * done in this order because unification has side effects.
     *)
      fun mkApp (rator, resTy, pre) = let
            val (tyArgs, Ty.T_Fun(domTy, rngTy)) = TU.instantiate(Var.typeOf rator)
            in
              if pre()
              andalso Unify.equalTypes(domTy, [ty1, ty2])
              andalso Unify.equalType(rngTy, resTy)
                then (AST.E_Prim(rator, tyArgs, [e1, e2], rngTy), rngTy)
                else error()
            end
    (* check the argument shapes and then build the application *)
      fun chkWith (s1, s2, rator, mkResTy, pre) = (case chkShape(s1, s2)
             of SOME shp => mkApp (rator, mkResTy shp, pre)
              | NONE => error()
            (* end case *))
      fun noExtraCheck () = true
      fun fieldTy (diff, dim) shp = Ty.T_Field{diff=diff, dim=dim, shape=shp}
      in
        case (TU.prune ty1, TU.prune ty2)
        (* tensor * tensor inner product *)
         of (Ty.T_Tensor s1, Ty.T_Tensor s2) =>
              chkWith (s1, s2, BV.op_inner_tt, Ty.T_Tensor, noExtraCheck)
        (* tensor * field inner product *)
          | (Ty.T_Tensor s1, Ty.T_Field{diff, dim, shape=s2}) =>
              chkWith (s1, s2, BV.op_inner_tf, fieldTy(diff, dim), noExtraCheck)
        (* field * tensor inner product *)
          | (Ty.T_Field{diff, dim, shape=s1}, Ty.T_Tensor s2) =>
              chkWith (s1, s2, BV.op_inner_ft, fieldTy(diff, dim), noExtraCheck)
        (* field * field inner product; the two fields must have the same dimension *)
(* FIXME: the resulting differentiation should be the minimum of k1 and k2 *)
          | (Ty.T_Field{diff=k1, dim=dim1, shape=s1}, Ty.T_Field{diff=k2, dim=dim2, shape=s2}) =>
              chkWith (
                s1, s2, BV.op_inner_ff, fieldTy(k1, dim1),
                fn () => Unify.equalDim(dim1, dim2))
          | _ => error()
        (* end case *)
      end
197 : jhr 3405
(* type check an outer product, which has the constraint:
 *     ALL[sigma1, sigma2] . tensor[sigma1] * tensor[sigma2] -> tensor[sigma1, sigma2]
 * and similarly for fields.  The arguments `e1` and `e2` have already been checked and
 * have types `ty1` and `ty2`.  Returns the application of the appropriate outer-product
 * primitive paired with its result type, or reports a type error.
 *)
fun chkOuterProduct (cxt, e1, ty1, e2, ty2) = let
    (* the result shape is the concatenation of the argument shapes; this function
     * returns NONE when either shape is not an explicit list of dimensions.
     *)
      fun mergeShp shapes = (case shapes
             of (Ty.Shape dd1, Ty.Shape dd2) => SOME(Ty.Shape(dd1 @ dd2))
              | _ => NONE
            (* end case *))
      fun shapeError () = err (cxt, [
              S "unable to determine result shape of outer product\n",
              S " found: ", TYS[ty1, ty2], S "\n"
            ])
      fun error () = err (cxt, [
              S "type error for arguments of binary operator \"⊗\"\n",
              S " found: ", TYS[ty1, ty2], S "\n"
            ])
    (* instantiate the type of the primitive `rator` and, if the extra condition `pre`
     * holds and the instantiated type unifies with the argument types and the
     * expected result type `resTy`, build the application of `rator`.  The checks are
     * done in this order because unification has side effects.
     *)
      fun mkApp (rator, resTy, pre) = let
            val (tyArgs, Ty.T_Fun(domTy, rngTy)) = TU.instantiate(Var.typeOf rator)
            in
              if pre()
              andalso Unify.equalTypes(domTy, [ty1, ty2])
              andalso Unify.equalType(rngTy, resTy)
                then (AST.E_Prim(rator, tyArgs, [e1, e2], rngTy), rngTy)
                else error()
            end
    (* compute the result shape and then build the application *)
      fun chkWith (s1, s2, rator, mkResTy, pre) = (case mergeShp(s1, s2)
             of SOME shp => mkApp (rator, mkResTy shp, pre)
              | NONE => shapeError()
            (* end case *))
      fun noExtraCheck () = true
      fun fieldTy (diff, dim) shp = Ty.T_Field{diff=diff, dim=dim, shape=shp}
      in
        case (TU.prune ty1, TU.prune ty2)
        (* tensor * tensor outer product *)
         of (Ty.T_Tensor s1, Ty.T_Tensor s2) =>
              chkWith (s1, s2, BV.op_outer_tt, Ty.T_Tensor, noExtraCheck)
        (* field * tensor outer product *)
          | (Ty.T_Field{diff, dim, shape=s1}, Ty.T_Tensor s2) =>
              chkWith (s1, s2, BV.op_outer_ft, fieldTy(diff, dim), noExtraCheck)
        (* tensor * field outer product *)
          | (Ty.T_Tensor s1, Ty.T_Field{diff, dim, shape=s2}) =>
              chkWith (s1, s2, BV.op_outer_tf, fieldTy(diff, dim), noExtraCheck)
        (* field * field outer product; the two fields must have the same dimension *)
(* FIXME: the resulting differentiation should be the minimum of k1 and k2 *)
          | (Ty.T_Field{diff=k1, dim=dim1, shape=s1}, Ty.T_Field{diff=k2, dim=dim2, shape=s2}) =>
              chkWith (
                s1, s2, BV.op_outer_ff, fieldTy(k1, dim1),
                fn () => Unify.equalDim(dim1, dim2))
          | _ => error()
        (* end case *)
      end
271 : jhr 3807
(* type check a colon product, which has the constraint:
 *     ALL[sigma1, d1, d2, sigma2] . tensor[sigma1, d1, d2] * tensor[d2, d1, sigma2] -> tensor[sigma1, sigma2]
 * and similarly for fields.
 * NOTE(review): the constraint above lists the leading dimensions of the second
 * argument as [d2, d1], but `chkShape` below compares the last two dimensions of the
 * first argument with the first two dimensions of the second argument in the same
 * order (i.e., [d1, d2]); confirm which is intended.
 * The arguments `e1` and `e2` have already been checked and have types `ty1` and `ty2`.
 * Returns the application of the appropriate colon-product primitive paired with its
 * result type, or reports a type error.
 *)
fun chkColonProduct (cxt, e1, ty1, e2, ty2) = let
    (* check that the last two dimensions of the first shape equal (pairwise and in
     * order) the first two dimensions of the second shape.  On success, the result is
     * the shape formed from the remaining dimensions of the two arguments.
     *)
      fun chkShape (Ty.Shape(dd1 as _::_::_), Ty.Shape(d21::d22::dd2)) = let
            val nPrefix = List.length dd1 - 2
            val prefix = List.take(dd1, nPrefix)
            val (d11, d12) = (case List.drop(dd1, nPrefix)
                   of [d1, d2] => (d1, d2)
                    | _ => raise Fail "impossible"
                  (* end case *))
            in
              if Unify.equalDim(d11, d21) andalso Unify.equalDim(d12, d22)
                then SOME(Ty.Shape(prefix @ dd2))
                else NONE
            end
        | chkShape _ = NONE
      fun error () = err (cxt, [
              S "type error for arguments of binary operator \":\"\n",
              S " found: ", TYS[ty1, ty2], S "\n"
            ])
    (* instantiate the type of the primitive `rator` and, if the extra condition `pre`
     * holds and the instantiated type unifies with the argument types and the
     * expected result type `resTy`, build the application of `rator`.  The checks are
     * done in this order because unification has side effects.
     *)
      fun mkApp (rator, resTy, pre) = let
            val (tyArgs, Ty.T_Fun(domTy, rngTy)) = TU.instantiate(Var.typeOf rator)
            in
              if pre()
              andalso Unify.equalTypes(domTy, [ty1, ty2])
              andalso Unify.equalType(rngTy, resTy)
                then (AST.E_Prim(rator, tyArgs, [e1, e2], rngTy), rngTy)
                else error()
            end
    (* check the argument shapes and then build the application *)
      fun chkWith (s1, s2, rator, mkResTy, pre) = (case chkShape(s1, s2)
             of SOME shp => mkApp (rator, mkResTy shp, pre)
              | NONE => error()
            (* end case *))
      fun noExtraCheck () = true
      fun fieldTy (diff, dim) shp = Ty.T_Field{diff=diff, dim=dim, shape=shp}
      in
        case (TU.prune ty1, TU.prune ty2)
        (* tensor * tensor colon product *)
         of (Ty.T_Tensor s1, Ty.T_Tensor s2) =>
              chkWith (s1, s2, BV.op_colon_tt, Ty.T_Tensor, noExtraCheck)
        (* field * tensor colon product *)
          | (Ty.T_Field{diff, dim, shape=s1}, Ty.T_Tensor s2) =>
              chkWith (s1, s2, BV.op_colon_ft, fieldTy(diff, dim), noExtraCheck)
        (* tensor * field colon product *)
          | (Ty.T_Tensor s1, Ty.T_Field{diff, dim, shape=s2}) =>
              chkWith (s1, s2, BV.op_colon_tf, fieldTy(diff, dim), noExtraCheck)
        (* field * field colon product; the two fields must have the same dimension *)
(* FIXME: the resulting differentiation should be the minimum of k1 and k2 *)
          | (Ty.T_Field{diff=k1, dim=dim1, shape=s1}, Ty.T_Field{diff=k2, dim=dim2, shape=s2}) =>
              chkWith (
                s1, s2, BV.op_colon_ff, fieldTy(k1, dim1),
                fn () => Unify.equalDim(dim1, dim2))
          | _ => error()
        (* end case *)
      end
354 : jhr 3405
(* check the well-formedness of a spatial query `e`, which has already been typechecked.
 * The arguments `tyArgs` and `rngTy` are the type arguments and result type of the
 * query primitive; `tyArgs` must be a single type variable, which is matched against
 * the type of the enclosing strand.
 *
 * The result is `(e, rngTy)` when the query is well formed.  An error is reported when
 *   - the query does not occur inside a strand,
 *   - the strand does not define a 'pos' variable, or
 *   - the type of 'pos' is not real, vec2, or vec3.
 * The first well-formed query in a strand records the `StrandCommunication` property
 * in the environment and the dimension of the position in the strand's environment;
 * later queries see the recorded dimension and skip these steps.
 *)
fun checkSpatialQuery (env, cxt, e, tyArgs, rngTy) = (case Env.strandTy env
       of SOME(strandTy, sEnv) => (case StrandEnv.findPosVar sEnv
             of SOME p => let
                  val [Ty.TYPE tv] = tyArgs
                (* record the dimension of the strand's position and return the query *)
                  fun result dim = (
                        StrandEnv.recordSpaceDim (sEnv, dim);
                        (e, rngTy))
                  in
                  (* instantiate the query's type to the strand type *)
                    ignore (Unify.matchType (Ty.T_Var tv, strandTy));
                  (* check that the strand supports spatial queries *)
                    case StrandEnv.getSpaceDim sEnv
                     of SOME _ => (e, rngTy) (* we have already processed a spatial query *)
                      | NONE => (
                          Env.recordProp (env, Properties.StrandCommunication);
                        (* check the type of the position; should be 1D, 2D, or 3D *)
                          case TU.prune (Var.monoTypeOf p)
                           of Ty.T_Tensor(Ty.Shape[]) => result 1
                            | Ty.T_Tensor(Ty.Shape[Ty.DimConst 2]) => result 2
                            | Ty.T_Tensor(Ty.Shape[Ty.DimConst 3]) => result 3
                            | ty => err(cxt, [
                                  (* the message used to begin with a stray apostrophe *)
                                  S "expected one of real, vec2, or vec3 for type of 'pos',\n",
                                  S " but found: ", TY ty
                                ])
                          (* end case *))
                    (* end case *)
                  end
              | NONE => err(cxt, [
                    S "spatial queries require defining a 'pos' variable of suitable type"
                  ])
            (* end case *))
        | NONE => err(cxt, [
              S "spatial queries are only allowed inside strands"
            ])
      (* end case *))
391 :    
392 : jhr 3396 (* check the type of an expression *)
393 :     fun check (env, cxt, e) = (case e
394 : jhr 4317 of PT.E_Mark m => check (E.withEnvAndContext (env, cxt, m))
395 :     | PT.E_Cond(e1, cond, e2) => let
396 : jhr 3396 val eTy1 = check (env, cxt, e1)
397 :     val eTy2 = check (env, cxt, e2)
398 :     in
399 : jhr 3431 case checkAndPrune(env, cxt, cond)
400 : jhr 3396 of (cond', Ty.T_Bool) => (case Util.coerceType2(eTy1, eTy2)
401 : jhr 4317 of SOME(e1', e2', ty) =>
402 :     if TU.isValueType ty
403 :     then (AST.E_Cond(cond', e1', e2', ty), ty)
404 :     else err (cxt, [
405 :     S "result of conditional expression must be value type,\n",
406 :     S " but found ", TY ty
407 :     ])
408 :     | NONE => err (cxt, [
409 : jhr 3396 S "types do not match in conditional expression\n",
410 :     S " true branch: ", TY(#2 eTy1), S "\n",
411 :     S " false branch: ", TY(#2 eTy2)
412 :     ])
413 : jhr 4317 (* end case *))
414 :     | (_, Ty.T_Error) => bogusExpTy
415 : jhr 3396 | (_, ty') => err (cxt, [S "expected bool type, but found ", TY ty'])
416 :     (* end case *)
417 :     end
418 : jhr 4317 | PT.E_Range(e1, e2) => (case (check (env, cxt, e1), check (env, cxt, e2))
419 :     of ((e1', Ty.T_Int), (e2', Ty.T_Int)) => let
420 :     val resTy = Ty.T_Sequence(Ty.T_Int, NONE)
421 :     in
422 :     (AST.E_Prim(BV.range, [], [e1', e2'], resTy), resTy)
423 :     end
424 :     | ((_, Ty.T_Int), (_, ty2)) =>
425 :     err (cxt, [S "expected type 'int' on rhs of '..', but found ", TY ty2])
426 :     | ((_, ty1), (_, Ty.T_Int)) =>
427 :     err (cxt, [S "expected type 'int' on lhs of '..', but found ", TY ty1])
428 :     | ((_, ty1), (_, ty2)) => err (cxt, [
429 :     S "arguments of '..' must have type 'int', found ",
430 :     TY ty1, S " and ", TY ty2
431 :     ])
432 :     (* end case *))
433 :     | PT.E_OrElse(e1, e2) => checkCondOp (env, cxt, e1, "||", e2, AST.E_Orelse)
434 :     | PT.E_AndAlso(e1, e2) => checkCondOp (env, cxt, e1, "&&", e2, AST.E_Andalso)
435 :     | PT.E_BinOp(e1, rator, e2) => let
436 : jhr 3396 val (e1', ty1) = check (env, cxt, e1)
437 :     val (e2', ty2) = check (env, cxt, e2)
438 :     in
439 :     if Atom.same(rator, BasisNames.op_dot)
440 : jhr 4317 then chkInnerProduct (cxt, e1', ty1, e2', ty2)
441 :     else if Atom.same(rator, BasisNames.op_outer)
442 :     then chkOuterProduct (cxt, e1', ty1, e2', ty2)
443 : jhr 3396 else if Atom.same(rator, BasisNames.op_colon)
444 : jhr 4317 then chkColonProduct (cxt, e1', ty1, e2', ty2)
445 :     else (case Env.findFunc (env, rator)
446 : jhr 3396 of Env.PrimFun[rator] => let
447 : jhr 3405 val (tyArgs, Ty.T_Fun(domTy, rngTy)) = TU.instantiate(Var.typeOf rator)
448 : jhr 3396 in
449 : jhr 3402 case Unify.matchArgs(domTy, [e1', e2'], [ty1, ty2])
450 : jhr 3407 of SOME args => (AST.E_Prim(rator, tyArgs, args, rngTy), rngTy)
451 : jhr 3396 | NONE => err (cxt, [
452 : jhr 3418 S "type error for binary operator ", V rator, S "\n",
453 : jhr 4073 S " expected: ", TYS domTy, S "\n",
454 :     S " found: ", TYS[ty1, ty2]
455 : jhr 3396 ])
456 :     (* end case *)
457 :     end
458 :     | Env.PrimFun ovldList =>
459 :     resolveOverload (cxt, rator, [ty1, ty2], [e1', e2'], ovldList)
460 :     | _ => raise Fail "impossible"
461 :     (* end case *))
462 :     end
463 : jhr 4317 | PT.E_UnaryOp(rator, e) => let
464 : jhr 3405 val eTy = check(env, cxt, e)
465 : jhr 3398 in
466 : jhr 3405 case Env.findFunc (env, rator)
467 : jhr 3398 of Env.PrimFun[rator] => let
468 : jhr 3405 val (tyArgs, Ty.T_Fun([domTy], rngTy)) = TU.instantiate(Var.typeOf rator)
469 : jhr 3398 in
470 : jhr 3405 case Util.coerceType (domTy, eTy)
471 : jhr 3410 of SOME e' => (AST.E_Prim(rator, tyArgs, [e'], rngTy), rngTy)
472 : jhr 3398 | NONE => err (cxt, [
473 : jhr 3418 S "type error for unary operator ", V rator, S "\n",
474 : jhr 4073 S " expected: ", TY domTy, S "\n",
475 :     S " found: ", TY (#2 eTy)
476 : jhr 3398 ])
477 :     (* end case *)
478 :     end
479 : jhr 3405 | Env.PrimFun ovldList => resolveOverload (cxt, rator, [#2 eTy], [#1 eTy], ovldList)
480 : jhr 3398 | _ => raise Fail "impossible"
481 :     (* end case *)
482 :     end
483 : jhr 4317 | PT.E_Apply(e, args) => let
484 : jhr 3407 val (args, tys) = checkList (env, cxt, args)
485 : jhr 4317 fun appTyError (f, paramTys, argTys) = err(cxt, [
486 :     S "type error in application of ", V f, S "\n",
487 :     S " expected: ", TYS paramTys, S "\n",
488 :     S " found: ", TYS argTys
489 :     ])
490 : jhr 3407 fun checkPrimApp f = if Var.isPrim f
491 : jhr 4317 then (case TU.instantiate(Var.typeOf f)
492 :     of (tyArgs, Ty.T_Fun(domTy, rngTy)) => (
493 :     case Unify.matchArgs (domTy, args, tys)
494 :     of SOME args => (AST.E_Prim(f, tyArgs, args, rngTy), rngTy)
495 :     | NONE => appTyError (f, domTy, tys)
496 :     (* end case *))
497 :     | _ => err(cxt, [S "application of non-function/field ", V f])
498 :     (* end case *))
499 :     else raise Fail "unexpected user function"
500 :     (* check the application of a user-defined function *)
501 : jhr 3407 fun checkFunApp (cxt, f) = if Var.isPrim f
502 : jhr 4317 then raise Fail "unexpected primitive function"
503 :     else (case Var.monoTypeOf f
504 :     of Ty.T_Fun(domTy, rngTy) => (
505 :     case Unify.matchArgs (domTy, args, tys)
506 :     of SOME args => (AST.E_Apply(useVar(cxt, f), args, rngTy), rngTy)
507 :     | NONE => appTyError (f, domTy, tys)
508 :     (* end case *))
509 :     | _ => err(cxt, [S "application of non-function/field ", V f])
510 :     (* end case *))
511 : jhr 3407 fun checkFieldApp (e1', ty1) = (case (args, tys)
512 :     of ([e2'], [ty2]) => let
513 :     val (tyArgs, Ty.T_Fun([fldTy, domTy], rngTy)) =
514 :     TU.instantiate(Var.typeOf BV.op_probe)
515 :     fun tyError () = err (cxt, [
516 :     S "type error for field application\n",
517 : jhr 4073 S " expected: ", TYS[fldTy, domTy], S "\n",
518 :     S " found: ", TYS[ty1, ty2]
519 : jhr 3407 ])
520 :     in
521 :     if Unify.equalType(fldTy, ty1)
522 :     then (case Util.coerceType(domTy, (e2', ty2))
523 : jhr 3410 of SOME e2' => (AST.E_Prim(BV.op_probe, tyArgs, [e1', e2'], rngTy), rngTy)
524 : jhr 3407 | NONE => tyError()
525 :     (* end case *))
526 :     else tyError()
527 :     end
528 :     | _ => err(cxt, [S "badly formed field application"])
529 :     (* end case *))
530 :     in
531 :     case stripMark(#2 cxt, e)
532 :     of (span, PT.E_Var f) => (case Env.findVar (env, f)
533 :     of SOME f' => checkFieldApp (
534 : jhr 4317 AST.E_Var(useVar((#1 cxt, span), f')),
535 :     Var.monoTypeOf f')
536 : jhr 3407 | NONE => (case Env.findFunc (env, f)
537 :     of Env.PrimFun[] => err(cxt, [S "unknown function ", A f])
538 :     | Env.PrimFun[f'] => checkPrimApp f'
539 : jhr 4399 | Env.PrimFun ovldList => (
540 : jhr 5085 case resolveOverload ((#1 cxt, span), f, tys, args, ovldList)
541 :     of (e' as AST.E_Prim(f', tyArgs, _, _), rngTy) =>
542 : jhr 4368 (* NOTE: if/when we switch to matching type patterns (instead of unification),
543 :     * we can use a "Self" type pattern to handle spatial queries.
544 : jhr 4349 *)
545 : jhr 5085 if Basis.isSpatialQueryOp f'
546 :     then checkSpatialQuery (env, cxt, e', tyArgs, rngTy)
547 :     else (e', rngTy)
548 :     | badResult => badResult
549 :     (* end case *))
550 : jhr 3407 | Env.UserFun f' => checkFunApp((#1 cxt, span), f')
551 :     (* end case *))
552 :     (* end case *))
553 :     | _ => checkFieldApp (check (env, cxt, e))
554 :     (* end case *)
555 :     end
556 : jhr 4317 | PT.E_Subscript(e, indices) => let
557 :     fun expectedTensor ty = err(cxt, [
558 :     S "expected tensor type for slicing, but found ", TY ty
559 :     ])
560 :     fun chkIndex e = let
561 :     val eTy as (_, ty) = check(env, cxt, e)
562 :     in
563 :     if Unify.equalType(ty, Ty.T_Int)
564 :     then eTy
565 :     else err (cxt, [
566 :     S "expected type 'int' for index, but found ", TY ty
567 :     ])
568 :     end
569 : jhr 5112 val (e', ty) = check(env, cxt, e)
570 : jhr 4317 in
571 : jhr 5112 case (TU.pruneHead ty, indices)
572 :     of (Ty.T_Error, _) => (
573 : jhr 4317 List.app (ignore o Option.map chkIndex) indices;
574 :     bogusExpTy)
575 : jhr 5112 | (ty1 as Ty.T_Sequence(elemTy, optDim), [SOME e2]) => let
576 : jhr 4317 val (e2', ty2) = chkIndex e2
577 :     val rator = if isSome optDim
578 :     then BV.subscript
579 :     else BV.dynSubscript
580 :     val (tyArgs, Ty.T_Fun(domTy, rngTy)) = TU.instantiate(Var.typeOf rator)
581 : jhr 3424 in
582 : jhr 4317 if Unify.equalTypes(domTy, [ty1, ty2])
583 :     then let
584 : jhr 5112 val exp = AST.E_Prim(rator, tyArgs, [e', e2'], rngTy)
585 : jhr 4317 in
586 :     (exp, rngTy)
587 :     end
588 :     else raise Fail "unexpected unification failure"
589 :     end
590 : jhr 5112 | (ty as Ty.T_Sequence _, [NONE]) => expectedTensor ty
591 :     | (ty as Ty.T_Sequence _, _) => expectedTensor ty
592 :     | (ty, _) => let
593 : jhr 4317 (* for tensor/field slicing/indexing, the indices must be constant expressions *)
594 :     fun chkConstIndex NONE = NONE
595 :     | chkConstIndex (SOME e) = (case chkIndex e
596 :     of (_, Ty.T_Error) => SOME bogusExp
597 :     | (e', _) => (case CheckConst.eval (cxt, false, e')
598 : jhr 3797 (* FIXME: should check that index is in range for type! *)
599 : jhr 4317 of SOME cexp => SOME(ConstExpr.valueToExpr cexp)
600 :     | NONE => SOME e' (* use e' to preserve variable uses *)
601 :     (* end case *))
602 :     (* end case *))
603 :     val indices' = List.map chkConstIndex indices
604 :     val order = List.length indices'
605 :     (* val expectedTy = TU.mkTensorTy order*)
606 : jhr 3992 (* QUESTION: perhaps we should lift this case up above (i.e., one case for tensors and one for fields)? *)
607 :     val expectedTy = (case ty
608 : jhr 4317 of Ty.T_Field{diff, dim, shape=s as Ty.Shape(d2::dd2)} =>
609 :     Ty.T_Field{diff=diff, dim=dim, shape=s}
610 :     | Ty.T_Tensor shape => TU.mkTensorTy order
611 :     | Ty.T_Field _ => raise Fail "unknown field type"
612 : jhr 5107 | ty => raise Fail("unexpected type for subscript: " ^ TU.toString ty)
613 : jhr 4317 (* end case *))
614 :     val resultTy = TU.slice(expectedTy, List.map Option.isSome indices')
615 :     in
616 :     if Unify.equalType(ty, expectedTy)
617 :     then (AST.E_Slice(e', indices', resultTy), resultTy)
618 :     else err (cxt, [
619 :     S "type error in slice operation\n",
620 :     S " expected: ", S(Int.toString order), S "-order tensor\n",
621 :     S " found: ", TY ty
622 :     ])
623 :     end
624 :     (* end case *)
625 :     end
626 :     | PT.E_Select(e, field) => (case stripMark(#2 cxt, e)
627 :     of (_, PT.E_Var x) => (case E.findStrand (env, x)
628 : jhr 4480 of SOME _ => if E.inGlobalBlock env
629 : jhr 4317 then (case E.findSetFn (env, field)
630 :     of SOME setFn => let
631 :     val (mvs, ty) = TU.instantiate (Var.typeOf setFn)
632 :     val resTy = Ty.T_Sequence(Ty.T_Strand x, NONE)
633 :     in
634 : jhr 4368 (* QUESTION: does it make sense to allow strand sets outside of reductions? *)
635 : jhr 4317 E.recordProp (env, Properties.StrandSets);
636 :     if Unify.equalType(ty, Ty.T_Fun([], resTy))
637 :     then (AST.E_Prim(setFn, mvs, [], resTy), resTy)
638 :     else raise Fail "impossible"
639 :     end
640 :     | _ => err (cxt, [
641 :     S "unknown strand-set specifier ", A field
642 :     ])
643 :     (* end case *))
644 :     else err (cxt, [
645 :     S "illegal strand set specification in ",
646 :     S(E.scopeToString(E.currentScope env))
647 :     ])
648 :     | _ => checkSelect (env, cxt, e, field)
649 :     (* end case *))
650 :     | _ => checkSelect (env, cxt, e, field)
651 :     (* end case *))
652 :     | PT.E_Real e => (case checkAndPrune (env, cxt, e)
653 : jhr 3396 of (e', Ty.T_Int) =>
654 : jhr 3407 (AST.E_Prim(BV.i2r, [], [e'], Ty.realTy), Ty.realTy)
655 : jhr 4317 | (e', Ty.T_Error) => bogusExpTy
656 : jhr 3396 | (_, ty) => err(cxt, [
657 : jhr 4317 S "argument of 'real' must have type 'int', but found ",
658 :     TY ty
659 :     ])
660 : jhr 3396 (* end case *))
661 : jhr 4491 | PT.E_LoadSeq nrrd => let
662 : jhr 4496 val (tyArgs, Ty.T_Fun(_, rngTy)) = TU.instantiate(Var.typeOf(BV.fn_load_sequence))
663 : jhr 3396 in
664 : jhr 4317 case chkStringConstExpr (env, cxt, nrrd)
665 :     of SOME nrrd => (AST.E_LoadNrrd(tyArgs, nrrd, rngTy), rngTy)
666 :     | NONE => (bogusExp, rngTy)
667 :     (* end case *)
668 : jhr 3396 end
669 : jhr 4491 | PT.E_LoadImage nrrd => let
670 : jhr 4496 val (tyArgs, Ty.T_Fun(_, rngTy)) = TU.instantiate(Var.typeOf(BV.fn_load_image))
671 : jhr 3396 in
672 : jhr 3407 case chkStringConstExpr (env, cxt, nrrd)
673 : jhr 4317 of SOME nrrd => (AST.E_LoadNrrd(tyArgs, nrrd, rngTy), rngTy)
674 :     | NONE => (bogusExp, rngTy)
675 :     (* end case *)
676 : jhr 3396 end
677 : jhr 4317 | PT.E_Var x => (case E.findVar (env, x)
678 : jhr 3407 of SOME x' => (AST.E_Var(useVar(cxt, x')), Var.monoTypeOf x')
679 : jhr 4043 | NONE => (case E.findKernel (env, x)
680 : jhr 5558 of SOME h => (AST.E_Kernel h, TypeOf.kernel h)
681 : jhr 4317 | NONE => err(cxt, [S "undeclared variable ", A x])
682 :     (* end case *))
683 : jhr 3421 (* end case *))
684 : jhr 4317 | PT.E_Kernel(kern, dim) => let
685 :     val k' = Int.fromLarge dim handle Overflow => 1073741823
686 :     fun mkExp (e, k, ty) = if (k = k')
687 :     then (e, ty)
688 :     else let
689 : jhr 5558 val ty' = Ty.T_Kernel(Ty.DiffConst(SOME k'))
690 : jhr 4317 in
691 :     (AST.E_Coerce{srcTy = ty, dstTy = ty', e = e}, ty')
692 :     end
693 :     in
694 :     case E.findVar (env, kern)
695 :     of SOME h => (case Var.monoTypeOf h
696 : jhr 5558 of ty as Ty.T_Kernel(Ty.DiffConst(SOME k)) =>
697 : jhr 4317 mkExp (AST.E_Var(useVar(cxt, h)), k, ty)
698 :     | _ => err(cxt, [
699 :     S "expected kernel, but found ", S(Var.kindToString h)
700 :     ])
701 :     (* end case *))
702 :     | NONE => (case E.findKernel (env, kern)
703 :     of SOME h => let
704 :     val k = Kernel.continuity h
705 :     in
706 : jhr 5558 mkExp (AST.E_Kernel h, k, TypeOf.kernel h)
707 : jhr 4317 end
708 :     | NONE => err(cxt, [S "unknown kernel ", A kern])
709 :     (* end case *))
710 :     (* end case *)
711 :     end
712 :     | PT.E_Lit lit => checkLit lit
713 :     | PT.E_Id d => let
714 : jhr 3396 val (tyArgs, Ty.T_Fun(_, rngTy)) =
715 : jhr 3405 TU.instantiate(Var.typeOf(BV.identity))
716 : jhr 3396 in
717 : jhr 3407 if Unify.equalType(Ty.T_Tensor(checkShape(env, cxt, [d, d])), rngTy)
718 :     then (AST.E_Prim(BV.identity, tyArgs, [], rngTy), rngTy)
719 : jhr 3396 else raise Fail "impossible"
720 :     end
721 : jhr 4317 | PT.E_Zero dd => let
722 : jhr 3396 val (tyArgs, Ty.T_Fun(_, rngTy)) =
723 : jhr 3405 TU.instantiate(Var.typeOf(BV.zero))
724 : jhr 3396 in
725 : jhr 3407 if Unify.equalType(Ty.T_Tensor(checkShape(env, cxt, dd)), rngTy)
726 :     then (AST.E_Prim(BV.zero, tyArgs, [], rngTy), rngTy)
727 : jhr 3396 else raise Fail "impossible"
728 :     end
729 : jhr 4317 | PT.E_NaN dd => let
730 : jhr 3396 val (tyArgs, Ty.T_Fun(_, rngTy)) =
731 : jhr 3405 TU.instantiate(Var.typeOf(BV.nan))
732 : jhr 3396 in
733 : jhr 3407 if Unify.equalType(Ty.T_Tensor(checkShape(env, cxt, dd)), rngTy)
734 :     then (AST.E_Prim(BV.nan, tyArgs, [], rngTy), rngTy)
735 : jhr 3396 else raise Fail "impossible"
736 :     end
737 : jhr 4317 | PT.E_Sequence exps => (case checkList (env, cxt, exps)
738 : jhr 3422 of ([], _) => let
739 : jhr 5085 (* FIXME: metavar should have kind for concrete types here! *)
740 : jhr 3422 val ty = Ty.T_Sequence(Ty.T_Var(MetaVar.newTyVar()), SOME(Ty.DimConst 0))
741 :     in
742 :     (AST.E_Seq([], ty), ty)
743 :     end
744 :     | (args, tys) => (case Util.coerceTypes(List.map TU.pruneHead tys)
745 : jhr 4317 of SOME ty => if TU.isValueType ty
746 :     then let
747 :     fun doExp eTy = valOf(Util.coerceType (ty, eTy))
748 :     val resTy = Ty.T_Sequence(ty, SOME(Ty.DimConst(List.length args)))
749 :     val args = ListPair.map doExp (args, tys)
750 :     in
751 :     (AST.E_Seq(args, resTy), resTy)
752 :     end
753 :     else err(cxt, [S "sequence expression of non-value argument type"])
754 :     | NONE => err(cxt, [S "arguments of sequence expression must have same type"])
755 :     (* end case *))
756 : jhr 3422 (* end case *))
757 : jhr 4317 | PT.E_SeqComp comp => chkComprehension (env, cxt, comp)
758 :     | PT.E_Cons args => let
759 :     (* Note that we are guaranteed that args is non-empty *)
760 : jhr 3396 val (args, tys) = checkList (env, cxt, args)
761 : jhr 4317 (* extract the first non-error type in tys *)
762 :     val ty = (case List.find (fn Ty.T_Error => false | _ => true) tys
763 :     of NONE => Ty.T_Error
764 :     | SOME ty => ty
765 :     (* end case *))
766 :     (* process the arguments checking that they all have the expected type *)
767 :     fun chkArgs (ty, shape) = let
768 :     val Ty.Shape dd = TU.pruneShape shape (* NOTE: this may fail if we allow user polymorphism *)
769 :     val resTy = Ty.T_Tensor(Ty.Shape(Ty.DimConst(List.length args) :: dd))
770 :     fun chkArgs (arg::args, argTy::tys, args') = (
771 :     case Util.coerceType(ty, (arg, argTy))
772 :     of SOME arg' => chkArgs (args, tys, arg'::args')
773 :     | NONE => (
774 :     TypeError.error(cxt, [
775 :     S "arguments of tensor construction must have same type"
776 : jhr 5574 (* FIXME: add types to error message *)
777 : jhr 4317 ]);
778 :     chkArgs (args, tys, bogusExp::args'))
779 :     (* end case *))
780 :     | chkArgs (_, _, args') = (AST.E_Tensor(List.rev args', resTy), resTy)
781 :     in
782 :     chkArgs (args, tys, [])
783 :     end
784 : jhr 5574 fun chkArgsF (ty, diff, dim, shape) = let
785 :     val Ty.Shape dd = TU.pruneShape shape
786 :     val resTy = Ty.T_Field{diff=diff ,dim=dim, shape=Ty.Shape(Ty.DimConst(List.length args) :: dd)}
787 :     fun chkArgsF (arg::args, argTy::tys, args') = (case Util.coerceType(ty, (arg, argTy))
788 :     of SOME arg' => chkArgsF (args, tys, arg'::args')
789 :     | NONE => (
790 :     TypeError.error(cxt, [
791 :     S "arguments of tensor construction must have same type"
792 :     (* FIXME: add types to error message *)
793 :     ]);
794 :     chkArgsF (args, tys, bogusExp::args'))
795 :     (* end case *))
796 :     | chkArgsF (_, _, args') = (AST.E_Field(List.rev args', resTy), resTy)
797 :     in
798 :     chkArgsF (args, tys, [])
799 :     end
800 : jhr 3396 in
801 : jhr 3405 case TU.pruneHead ty
802 : jhr 3407 of Ty.T_Int => chkArgs(Ty.realTy, Ty.Shape[]) (* coerce integers to reals *)
803 : jhr 4317 | ty as Ty.T_Tensor shape => chkArgs(ty, shape)
804 : jhr 5574 | ty as Ty.T_Field{diff, dim, shape} => chkArgsF(ty, diff, dim, shape)
805 : jhr 3396 | _ => err(cxt, [S "Invalid argument type for tensor construction"])
806 :     (* end case *)
807 :     end
808 : jhr 4317 (* end case *))
809 : jhr 3396
(* Typecheck the expression `e` with `check` and then prune the resulting type
 * with `TU.prune`.  Returns the AST expression paired with the pruned type.
 *)
and checkAndPrune (env, cxt, e) = (case check (env, cxt, e)
       of (e', ty) => (e', TU.prune ty)
      (* end case *))
816 : jhr 3431
(* Check a conditional operator (e.g., || or &&).  Both operands are checked,
 * left operand first.  When both have type `bool`, the result expression is
 * built by applying `mk` to the checked operands; otherwise an error that
 * names the operator `rator` and the offending operand type(s) is reported
 * via `err`.
 *)
and checkCondOp (env, cxt, e1, rator, e2, mk) = let
      val (lhs, lhsTy) = check (env, cxt, e1)
      val (rhs, rhsTy) = check (env, cxt, e2)
      in
        case (lhsTy, rhsTy)
         of (Ty.T_Bool, Ty.T_Bool) => (mk(lhs, rhs), Ty.T_Bool)
          | (Ty.T_Bool, _) =>
              (* only the right-hand side is ill-typed *)
              err (cxt, [S "expected type 'bool' on rhs of '", S rator, S "', but found ", TY rhsTy])
          | (_, Ty.T_Bool) =>
              (* only the left-hand side is ill-typed *)
              err (cxt, [S "expected type 'bool' on lhs of '", S rator, S "', but found ", TY lhsTy])
          | _ => err (cxt, [
                S "arguments of '", S rator, S "' must have type 'bool', but found ",
                TY lhsTy, S " and ", TY rhsTy
              ])
        (* end case *)
      end
830 : jhr 3396
(* Check a field select that is _not_ a strand-set.  The selected expression
 * `e` must have a strand type, and `field` must name a state variable of that
 * strand; the result is an `AST.E_Select` paired with the state variable's
 * monomorphic type.  An expression whose type is already `Ty.T_Error` yields
 * `bogusExpTy` without a further error message.
 *)
and checkSelect (env, cxt, e, field) = let
      val (e', ty) = checkAndPrune (env, cxt, e)
    (* look up `field` in the environment of the named strand *)
      fun selectFrom strand = (case Env.findStrand(env, strand)
             of NONE => err(cxt, [S "unknown strand ", A strand])
              | SOME sEnv => (case StrandEnv.findStateVar(sEnv, field)
                   of NONE => err(cxt, [
                          S "strand ", A strand,
                          S " does not have state variable ", A field
                        ])
                    | SOME x' => let
                        val fieldTy = Var.monoTypeOf x'
                        in
                          (AST.E_Select(e', useVar(cxt, x')), fieldTy)
                        end
                  (* end case *))
            (* end case *))
      in
        case ty
         of Ty.T_Strand strand => selectFrom strand
          | Ty.T_Error => bogusExpTy
          | _ => err (cxt, [
                S "expected strand type, but found ", TY ty,
                S " in selection of ", A field
              ])
        (* end case *)
      end
853 : jhr 3431
(* Check a sequence comprehension.  Source marks are unwrapped first.  A
 * comprehension with exactly one iterator is checked by binding the iterator
 * in a new block scope and then checking the body in the extended environment;
 * the result type is a sequence of the body's type with no fixed length.
 * When the iterator's source is a nullary primitive for which
 * `Basis.isStrandSet` holds, the `Properties.GlobalReduce` property is
 * recorded and the comprehension becomes an `AST.E_ParallelMap`; this is
 * rejected outside a global block or inside a loop.  Any other argument form
 * raises `Fail "impossible"`.
 *)
and chkComprehension (env, cxt, PT.COMP_Mark m) =
      chkComprehension(E.withEnvAndContext(env, cxt, m))
  | chkComprehension (env, cxt, PT.COMP_Comprehension(e, [iter])) = let
      val (iter', env') = checkIter (E.blockScope env, cxt, iter)
      val (body, elemTy) = check (env', cxt, e)
      val seqTy = Ty.T_Sequence(elemTy, NONE)
    (* the ordinary (non-strand-set) comprehension *)
      fun sequential () = (AST.E_Comprehension(body, iter', seqTy), seqTy)
    (* a comprehension over the strand-set primitive `f` with bound variable `x` *)
      fun parallel (x, f) = (
            Env.recordProp (env, Properties.GlobalReduce);
            if Env.inGlobalBlock env
              then if Env.inLoop env
                then err (cxt, [
                    S "strand comprehension inside loop"
                  ])
                else (AST.E_ParallelMap(body, x, f, seqTy), seqTy)
              else err (cxt, [
                  S "strand comprehension outside of global initialization or update"
                ]))
      in
        case iter'
         of (x, AST.E_Prim(f, _, [], _)) =>
              if Basis.isStrandSet f then parallel (x, f) else sequential ()
          | _ => sequential ()
        (* end case *)
      end
  | chkComprehension _ = raise Fail "impossible"
879 :    
(* Type check an iteration expression (i.e., "x 'in' expr").  Source marks are
 * unwrapped first.  Returns the pair of the new iteration variable and the
 * checked source expression, together with the environment extended with a
 * binding for `x`.
 *
 * The iteration variable is always created with kind `Var.IterVar`:
 *   - if the source expression has a sequence type, the variable gets the
 *     sequence's element type;
 *   - otherwise the variable gets type `Ty.T_Error`, the source expression is
 *     replaced by `bogusExp`, and an error is reported unless the source's
 *     type is already an error type.
 *)
and checkIter (env, cxt, PT.I_Mark m) = checkIter (E.withEnvAndContext (env, cxt, m))
  | checkIter (env, cxt, PT.I_Iterator({span, tree=x}, e)) = (
      case checkAndPrune (env, cxt, e)
       of (e', Ty.T_Sequence(elemTy, _)) => let
            val x' = Var.new(x, span, Var.IterVar, elemTy)
            in
              ((x', e'), E.insertLocal(env, cxt, x, x'))
            end
        | (_, ty) => let
            val x' = Var.new(x, span, Var.IterVar, Ty.T_Error)
            in
            (* avoid a cascading message when the source is already ill-typed *)
              if TU.isErrorType ty
                then ()
                else TypeError.error (cxt, [
                    S "expected sequence type in iteration, but found '", TY ty, S "'"
                  ]);
              ((x', bogusExp), E.insertLocal(env, cxt, x, x'))
            end
      (* end case *))
899 : jhr 3424
(* Typecheck a list of expressions, returning the list of AST expressions and
 * the list of their (pruned) types.  The fold runs from the right, so the
 * expressions are checked last-to-first, while both result lists keep the
 * order of `exprs`.
 *)
and checkList (env, cxt, exprs) = let
      fun prepend (exp, (exps, tys)) = (case checkAndPrune (env, cxt, exp)
             of (exp', ty) => (exp'::exps, ty::tys)
            (* end case *))
      in
        List.foldr prepend ([], []) exprs
      end
912 :    
(* Check a string that is specified as a constant expression.  Source marks
 * are unwrapped first.  Returns `SOME s` when the expression has type
 * `string` and `CheckConst.eval` produces a string constant, and `NONE` when
 * evaluation yields `NONE`, when the expression's type is already an error,
 * or (after reporting an error) when the expression has some other type.
 * A constant that evaluates to a residual expression is not handled yet and
 * raises `Fail "FIXME"`.
 *)
and chkStringConstExpr (env, cxt, PT.E_Mark m) =
      chkStringConstExpr (E.withEnvAndContext (env, cxt, m))
  | chkStringConstExpr (env, cxt, e) = let
      val (e', ty) = checkAndPrune (env, cxt, e)
      in
        case ty
         of Ty.T_String => (case CheckConst.eval (cxt, false, e')
               of NONE => NONE
                | SOME(ConstExpr.String s) => SOME s
                | SOME(ConstExpr.Expr _) => raise Fail "FIXME"
                | SOME _ => raise Fail "impossible: wrong type for constant expr"
              (* end case *))
          | Ty.T_Error => NONE
          | _ => (
              TypeError.error (cxt, [
                  S "expected constant expression of type 'string', but found '",
                  TY ty, S "'"
                ]);
              NONE)
        (* end case *)
      end
931 : jhr 3407
(* Check a dimension that is given by a constant expression.  Returns
 * `SOME d` when the expression has type `int` and `CheckConst.eval` produces
 * an integer constant.  Returns `NONE` when evaluation yields `NONE`, when
 * the constant cannot be reduced to a value (an error is reported), when the
 * expression's type is already an error, or (after reporting an error) when
 * the expression has some other type.
 *)
and checkDim (env, cxt, dim) = let
      val (e', ty) = checkAndPrune (env, cxt, dim)
      in
        case ty
         of Ty.T_Int => (case CheckConst.eval (cxt, false, e')
               of NONE => NONE
                | SOME(ConstExpr.Int d) => SOME d
                | SOME(ConstExpr.Expr _) => (
                    TypeError.error (cxt, [S "unable to evaluate constant dimension expression"]);
                    NONE)
                | SOME _ => raise Fail "impossible: wrong type for constant expr"
              (* end case *))
          | Ty.T_Error => NONE
          | _ => (
              TypeError.error (cxt, [
                  S "expected constant expression of type 'int', but found '",
                  TY ty, S "'"
                ]);
              NONE)
        (* end case *)
      end
950 : jhr 3407
(* Check a tensor shape, where the dimensions are given by constant
 * expressions.  Each dimension is checked with `checkDim`:
 *   - a dimension that could not be determined becomes `Ty.DimConst ~1`;
 *   - a dimension that is <= 1 is reported as an error, but its value is
 *     still used in the resulting shape;
 *   - a dimension that is too large to be represented as an `int` is
 *     reported as an error and becomes `Ty.DimConst ~1`, instead of letting
 *     the `Overflow` exception from `IntInf.toInt` escape.
 *)
and checkShape (env, cxt, shape) = let
    (* convert a dimension literal to a DimConst, guarding against overflow *)
      fun toDim d = Ty.DimConst(IntInf.toInt d)
            handle Overflow => (
              TypeError.error (cxt, [
                  S "tensor-shape dimension is too large: ",
                  S (IntLit.toString d)
                ]);
              Ty.DimConst ~1)
      fun checkDim' e = (case checkDim (env, cxt, e)
             of SOME d => (
                  if (d <= 1)
                    then TypeError.error (cxt, [
                        S "invalid tensor-shape dimension; must be > 1, but found ",
                        S (IntLit.toString d)
                      ])
                    else ();
                  toDim d)
              | NONE => Ty.DimConst ~1
            (* end case *))
      in
        Ty.Shape(List.map checkDim' shape)
      end
967 :    
968 : jhr 3396 end

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0