Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] Diff of /branches/vis15/src/compiler/typechecker/check-expr.sml
ViewVC logotype

Diff of /branches/vis15/src/compiler/typechecker/check-expr.sml

Parent Directory Parent Directory | Revision Log Revision Log | View Patch Patch

revision 3396, Tue Nov 10 18:45:38 2015 UTC revision 3431, Sat Nov 14 14:03:58 2015 UTC
# Line 10  Line 10 
10    
11  structure CheckExpr : sig  structure CheckExpr : sig
12    
13      val check : Env.env * Env.context * ParseTree.expr -> (AST.expr * Types.ty)    (* type check an expression *)
14        val check : Env.t * Env.context * ParseTree.expr -> (AST.expr * Types.ty)
15    
16      (* type check a list of expressions *)
17        val checkList : Env.t * Env.context * ParseTree.expr list -> (AST.expr list * Types.ty list)
18    
19      (* type check an iteration expression (i.e., "x 'in' expr"), returning the iterator
20       * and the environment extended with a binding for x.
21       *)
22        val checkIter : Env.t * Env.context * ParseTree.iterator -> ((AST.var * AST.expr) * Env.t)
23    
24      (* type check a dimension that is given by a constant expression *)
25        val checkDim : Env.t * Env.context * ParseTree.expr -> IntLit.t option
26    
27      (* type check a tensor shape, where the dimensions are given by constant expressions *)
28        val checkShape : Env.t * Env.context * ParseTree.expr list -> Types.shape
29    
30      (* `resolveOverload (cxt, rator, tys, args, candidates)` resolves the application of
31       * the overloaded operator `rator` to `args`, where `tys` are the types of the arguments
32       * and `candidates` is the list of candidate definitions.
33       *)
34        val resolveOverload : Env.context * Atom.atom * Types.ty list * AST.expr list * Var.t list
35              -> (AST.expr * Types.ty)
36    
37    end = struct    end = struct
38    
# Line 19  Line 41 
41      structure E = Env      structure E = Env
42      structure Ty = Types      structure Ty = Types
43      structure BV = BasisVars      structure BV = BasisVars
44        structure TU = TypeUtil
45    
46    (* an expression to return when there is a type error *)    (* an expression to return when there is a type error *)
47      val bogusExp = (AST.E_Lit(L.Int 0), Ty.T_Error)      val bogusExp = AST.E_Lit(L.Int 0)
48        val bogusExpTy = (bogusExp, Ty.T_Error)
49    
50      fun err arg = (TypeError.error arg; bogusExp)      fun err arg = (TypeError.error arg; bogusExpTy)
51      val warn = TypeError.warning      val warn = TypeError.warning
52    
53      datatype tokens = datatype TypeError.tokens      datatype token = datatype TypeError.token
54    
55      (* mark a variable use with its location *)
56        fun useVar (cxt : Env.context, x) = (x, #2 cxt)
57    
58      (* strip any marks that enclose an expression and return the span and the expression *)
59        fun stripMark (_, PT.E_Mark{span, tree}) = stripMark(span, tree)
60          | stripMark (span, e) = (span, e)
61    
62      (* resolve overloading: we use a simple scheme that selects the first operator in the
63       * list that matches the argument types.
64       *)
65        fun resolveOverload (_, rator, _, _, []) = raise Fail(concat[
66                "resolveOverload: \"", Atom.toString rator, "\" has no candidates"
67              ])
68          | resolveOverload (cxt, rator, argTys, args, candidates) = let
69    (* FIXME: we could be more efficient by just checking for coercion matchs the first pass
70     * and remembering those that are not pure EQ matches.
71     *)
72            (* try to match candidates while allowing type coercions *)
73              fun tryMatchCandidates [] = err(cxt, [
74                      S "unable to resolve overloaded operator ", A rator, S "\n",
75                      S "  argument type is: ", TYS argTys, S "\n"
76                    ])
77                | tryMatchCandidates (x::xs) = let
78                    val (tyArgs, Ty.T_Fun(domTy, rngTy)) = TU.instantiate(Var.typeOf x)
79                    in
80                      case Unify.tryMatchArgs (domTy, args, argTys)
81                       of SOME args' => (AST.E_Prim(x, tyArgs, args', rngTy), rngTy)
82                        | NONE => tryMatchCandidates xs
83                      (* end case *)
84                    end
85              fun tryCandidates [] = tryMatchCandidates candidates
86                | tryCandidates (x::xs) = let
87                    val (tyArgs, Ty.T_Fun(domTy, rngTy)) = TU.instantiate(Var.typeOf x)
88                    in
89                      if Unify.tryEqualTypes(domTy, argTys)
90                        then (AST.E_Prim(x, tyArgs, args, rngTy), rngTy)
91                        else tryCandidates xs
92                    end
93              in
94                tryCandidates candidates
95              end
96    
97    (* check the type of a literal *)    (* check the type of a literal *)
98      fun checkLit lit = (case lit      fun checkLit lit = (case lit
# Line 36  Line 102 
102              | (L.Bool _) => (AST.E_Lit lit, Ty.T_Bool)              | (L.Bool _) => (AST.E_Lit lit, Ty.T_Bool)
103            (* end case *))            (* end case *))
104    
105      (* type check a dot product, which has the constraint:
106       *     ALL[sigma1, d1, sigma2] . tensor[sigma1, d1] * tensor[d1, sigma2] -> tensor[sigma1, sigma2]
107       * and similarly for fields.
108       *)
109        fun chkInnerProduct (cxt, e1, ty1, e2, ty2) = let
110            (* check the shape of the two arguments to verify that the inner constraint matches *)
111              fun chkShape (Ty.Shape(dd1 as _::_), Ty.Shape(d2::dd2)) = let
112                    val (dd1, d1) = let
113                          fun splitLast (prefix, [d]) = (List.rev prefix, d)
114                            | splitLast (prefix, d::dd) = splitLast (d::prefix, dd)
115                            | splitLast (_, []) = raise Fail "impossible"
116                          in
117                            splitLast ([], dd1)
118                          end
119                    in
120                      if Unify.equalDim(d1, d2)
121                        then SOME(Ty.Shape(dd1@dd2))
122                        else NONE
123                    end
124                | chkShape _ = NONE
125              fun error () = err (cxt, [
126                      S "type error for arguments of binary operator '•'\n",
127                      S "  found: ", TYS[ty1, ty2], S "\n"
128                    ])
129              in
130                case (TU.prune ty1, TU.prune ty2)
131                (* tensor * tensor inner product *)
132                 of (Ty.T_Tensor s1, Ty.T_Tensor s2) => (case chkShape(s1, s2)
133                       of SOME shp => let
134                            val (tyArgs, Ty.T_Fun(domTy, rngTy)) = TU.instantiate(Var.typeOf BV.op_inner_tt)
135                            val resTy = Ty.T_Tensor shp
136                            in
137                              if Unify.equalTypes(domTy, [ty1, ty2]) andalso Unify.equalType(rngTy, resTy)
138                                then (AST.E_Prim(BV.op_inner_tt, tyArgs, [e1, e2], rngTy), rngTy)
139                                else error()
140                            end
141                        | NONE => error()
142                      (* end case *))
143                (* tensor * field inner product *)
144                  | (Ty.T_Tensor s1, Ty.T_Field{diff, dim, shape=s2}) => (case chkShape(s1, s2)
145                       of SOME shp => let
146                            val (tyArgs, Ty.T_Fun(domTy, rngTy)) = TU.instantiate(Var.typeOf BV.op_inner_tf)
147                            val resTy = Ty.T_Field{diff=diff, dim=dim, shape=shp}
148                            in
149                              if Unify.equalTypes(domTy, [ty1, ty2])
150                              andalso Unify.equalType(rngTy, resTy)
151                                then (AST.E_Prim(BV.op_inner_tf, tyArgs, [e1, e2], rngTy), rngTy)
152                                else error()
153                            end
154                        | NONE => error()
155                      (* end case *))
156                (* field * tensor inner product *)
157                  | (Ty.T_Field{diff, dim, shape=s1}, Ty.T_Tensor s2) => (case chkShape(s1, s2)
158                       of SOME shp => let
159                            val (tyArgs, Ty.T_Fun(domTy, rngTy)) = TU.instantiate(Var.typeOf BV.op_inner_ft)
160                            val resTy = Ty.T_Field{diff=diff, dim=dim, shape=shp}
161                            in
162                              if Unify.equalTypes(domTy, [ty1, ty2])
163                              andalso Unify.equalType(rngTy, resTy)
164                                then (AST.E_Prim(BV.op_inner_ft, tyArgs, [e1, e2], rngTy), rngTy)
165                                else error()
166                            end
167                        | NONE => error()
168                      (* end case *))
169                (* field * field inner product *)
170                  | (Ty.T_Field{diff=k1, dim=dim1, shape=s1}, Ty.T_Field{diff=k2, dim=dim2, shape=s2}) => (
171                      case chkShape(s1, s2)
172                       of SOME shp => let
173                            val (tyArgs, Ty.T_Fun(domTy, rngTy)) = TU.instantiate(Var.typeOf BV.op_inner_ff)
174                            val resTy = Ty.T_Field{diff=k1, dim=dim1, shape=shp}
175                            in
176    (* FIXME: the resulting differentiation should be the minimum of k1 and k2 *)
177                              if Unify.equalDim(dim1, dim2)
178                              andalso Unify.equalTypes(domTy, [ty1, ty2])
179                              andalso Unify.equalType(rngTy, resTy)
180                                then (AST.E_Prim(BV.op_inner_ff, tyArgs, [e1, e2], rngTy), rngTy)
181                                else error()
182                            end
183                        | NONE => error()
184                      (* end case *))
185                  | (ty1, ty2) => error()
186                (* end case *)
187              end
188    
189      (* type check a colon product, which has the constraint:
190       *     ALL[sigma1, d1, d2, sigma2] . tensor[sigma1, d1, d2] * tensor[d2, d1, sigma2] -> tensor[sigma1, sigma2]
191       * and similarly for fields.
192       *)
193        fun chkColonProduct (cxt, e1, ty1, e2, ty2) = let
194            (* check the shape of the two arguments to verify that the inner constraint matches *)
195              fun chkShape (Ty.Shape(dd1 as _::_::_), Ty.Shape(d21::d22::dd2)) = let
196                    val (dd1, d11, d12) = let
197                          fun splitLast2 (prefix, [d1, d2]) = (List.rev prefix, d1, d2)
198                            | splitLast2 (prefix, d::dd) = splitLast2 (d::prefix, dd)
199                            | splitLast2 (_, []) = raise Fail "impossible"
200                          in
201                            splitLast2 ([], dd1)
202                          end
203                    in
204                      if Unify.equalDim(d11, d21) andalso Unify.equalDim(d12, d22)
205                        then SOME(Ty.Shape(dd1@dd2))
206                        else NONE
207                    end
208                | chkShape _ = NONE
209              fun error () = err (cxt, [
210                      S "type error for arguments of binary operator \":\"\n",
211                      S "  found: ", TYS[ty1, ty2], S "\n"
212                    ])
213              in
214                case (TU.prune ty1, TU.prune ty2)
215                (* tensor * tensor colon product *)
216                 of (Ty.T_Tensor s1, Ty.T_Tensor s2) => (case chkShape(s1, s2)
217                       of SOME shp => let
218                            val (tyArgs, Ty.T_Fun(domTy, rngTy)) = TU.instantiate(Var.typeOf BV.op_colon_tt)
219                            val resTy = Ty.T_Tensor shp
220                            in
221                              if Unify.equalTypes(domTy, [ty1, ty2])
222                              andalso Unify.equalType(rngTy, resTy)
223                                then (AST.E_Prim(BV.op_colon_tt, tyArgs, [e1, e2], rngTy), rngTy)
224                                else error()
225                            end
226                        | NONE => error()
227                      (* end case *))
228                (* field * tensor colon product *)
229                  | (Ty.T_Field{diff, dim, shape=s1}, Ty.T_Tensor s2) => (case chkShape(s1, s2)
230                       of SOME shp => let
231                            val (tyArgs, Ty.T_Fun(domTy, rngTy)) = TU.instantiate(Var.typeOf BV.op_colon_ft)
232                            val resTy = Ty.T_Field{diff=diff, dim=dim, shape=shp}
233                            in
234                              if Unify.equalTypes(domTy, [ty1, ty2]) andalso Unify.equalType(rngTy, resTy)
235                                then (AST.E_Prim(BV.op_colon_ft, tyArgs, [e1, e2], rngTy), rngTy)
236                                else error()
237                            end
238                        | NONE => error()
239                      (* end case *))
240                (* tensor * field colon product *)
241                  | (Ty.T_Tensor s1, Ty.T_Field{diff=diff, dim=dim, shape=s2}) => (case chkShape(s1, s2)
242                       of SOME shp => let
243                            val (tyArgs, Ty.T_Fun(domTy, rngTy)) = TU.instantiate(Var.typeOf BV.op_colon_tf)
244                            val resTy = Ty.T_Field{diff=diff, dim=dim, shape=shp}
245                            in
246                              if Unify.equalTypes(domTy, [ty1, ty2]) andalso Unify.equalType(rngTy, resTy)
247                                then (AST.E_Prim(BV.op_colon_tf, tyArgs, [e1, e2], rngTy), rngTy)
248                                else error()
249                            end
250                        | NONE => error()
251                      (* end case *))
252                (* field * field colon product *)
253                  | (Ty.T_Field{diff=k1, dim=dim1, shape=s1}, Ty.T_Field{diff=k2, dim=dim2, shape=s2}) => (
254                      case chkShape(s1, s2)
255                       of SOME shp => let
256                            val (tyArgs, Ty.T_Fun(domTy, rngTy)) = TU.instantiate(Var.typeOf BV.op_colon_ff)
257                            val resTy = Ty.T_Field{diff=k1, dim=dim1, shape=shp}
258                            in
259    (* FIXME: the resulting differentiation should be the minimum of k1 and k2 *)
260                              if Unify.equalDim(dim1, dim2)
261                              andalso Unify.equalTypes(domTy, [ty1, ty2])
262                              andalso Unify.equalType(rngTy, resTy)
263                                then (AST.E_Prim(BV.op_colon_ff, tyArgs, [e1, e2], rngTy), rngTy)
264                                else error()
265                            end
266                        | NONE => error()
267                      (* end case *))
268                  | (ty1, ty2) => error()
269                (* end case *)
270              end
271    
272    (* check the type of an expression *)    (* check the type of an expression *)
273      fun check (env, cxt, e) = (case e      fun check (env, cxt, e) = (case e
274             of PT.E_Mark m => check (withEnvAndContext (env, cxt, m))             of PT.E_Mark m => check (E.withEnvAndContext (env, cxt, m))
275              | PT.E_Cond(e1, cond, e2) => let              | PT.E_Cond(e1, cond, e2) => let
276                  val eTy1 = check (env, cxt, e1)                  val eTy1 = check (env, cxt, e1)
277                  val eTy2 = check (env, cxt, e2)                  val eTy2 = check (env, cxt, e2)
278                  in                  in
279                    case checkExpr(env, cxt, cond)                    case checkAndPrune(env, cxt, cond)
280                     of (cond', Ty.T_Bool) => (case Util.coerceType2(eTy1, eTy2)                     of (cond', Ty.T_Bool) => (case Util.coerceType2(eTy1, eTy2)
281                           of SOME(e1, e2, ty) => (AST.E_Cond(cond', e1', e2', ty), ty)                           of SOME(e1', e2', ty) => (AST.E_Cond(cond', e1', e2', ty), ty)
282                            | NONE => err (cxt, [                            | NONE => err (cxt, [
283                                S "types do not match in conditional expression\n",                                S "types do not match in conditional expression\n",
284                                S "  true branch:  ", TY(#2 eTy1), S "\n",                                S "  true branch:  ", TY(#2 eTy1), S "\n",
285                                S "  false branch: ", TY(#2 eTy2)                                S "  false branch: ", TY(#2 eTy2)
286                              ])                              ])
287                            (* end case *))
288                        | (_, Ty.T_Error) => bogusExpTy
289                      | (_, ty') => err (cxt, [S "expected bool type, but found ", TY ty'])                      | (_, ty') => err (cxt, [S "expected bool type, but found ", TY ty'])
290                    (* end case *)                    (* end case *)
291                  end                  end
292              | PT.E_Range(e1, e2) => (case (check (env, cxt, e1), check (env, cxt, e2))              | PT.E_Range(e1, e2) => (case (check (env, cxt, e1), check (env, cxt, e2))
293                   of ((e1', Ty.T_Int), (e2', Ty.T_Int)) => let                   of ((e1', Ty.T_Int), (e2', Ty.T_Int)) => let
294                        val resTy = Ty.T_DynSequence Ty.T_Int                        val resTy = Ty.T_Sequence(Ty.T_Int, NONE)
295                        in                        in
296                          (AST.E_Apply(BV.range, [], [e1', e2'], resTy), resTy)                          (AST.E_Prim(BV.range, [], [e1', e2'], resTy), resTy)
297                        end                        end
298                    | ((_, Ty.T_Int), (_, ty2)) =>                    | ((_, Ty.T_Int), (_, ty2)) =>
299                        err (cxt, [S "expected type 'int' on rhs of '..', but found ", TY ty2])                        err (cxt, [S "expected type 'int' on rhs of '..', but found ", TY ty2])
# Line 80  Line 315 
315                  val (e2', ty2) = check (env, cxt, e2)                  val (e2', ty2) = check (env, cxt, e2)
316                  in                  in
317                    if Atom.same(rator, BasisNames.op_dot)                    if Atom.same(rator, BasisNames.op_dot)
318                    (* we have to handle inner product as a special case, because our type                      then chkInnerProduct (cxt, e1', ty1, e2', ty2)
319                     * system cannot express the constraint that the type is                    else if Atom.same(rator, BasisNames.op_colon)
320                     *     ALL[sigma1, d1, sigma2] . tensor[sigma1, d1] * tensor[d1, sigma2] -> tensor[sigma1, sigma2]                      then chkColonProduct (cxt, e1', ty1, e2', ty2)
321                     *)                      else (case Env.findFunc (env, rator)
322                      then (case (TU.prune ty1, TU.prune ty2)                         of Env.PrimFun[rator] => let
323                         of (Ty.T_Tensor(s1 as Ty.Shape(dd1 as _::_)), Ty.T_Tensor(s2 as Ty.Shape(d2::dd2))) => let                              val (tyArgs, Ty.T_Fun(domTy, rngTy)) = TU.instantiate(Var.typeOf rator)
                             val (dd1, d1) = let  
                                   fun splitLast (prefix, [d]) = (List.rev prefix, d)  
                                     | splitLast (prefix, d::dd) = splitLast (d::prefix, dd)  
                                     | splitLast (_, []) = raise Fail "impossible"  
324                                    in                                    in
325                                      splitLast ([], dd1)                                case Unify.matchArgs(domTy, [e1', e2'], [ty1, ty2])
326                                   of SOME args => (AST.E_Prim(rator, tyArgs, args, rngTy), rngTy)
327                                    | NONE => err (cxt, [
328                                          S "type error for binary operator ", V rator, S "\n",
329                                          S "  expected:  ", TYS domTy, S "\n",
330                                          S "  but found: ", TYS[ty1, ty2]
331                                        ])
332                                  (* end case *)
333                                    end                                    end
334                              val (tyArgs, Ty.T_Fun(domTy, rngTy)) = Util.instantiate(Var.typeOf BV.op_inner)                          | Env.PrimFun ovldList =>
335                              val resTy = Ty.T_Tensor(Ty.Shape(dd1@dd2))                              resolveOverload (cxt, rator, [ty1, ty2], [e1', e2'], ovldList)
336                            | _ => raise Fail "impossible"
337                          (* end case *))
338                    end
339                | PT.E_UnaryOp(rator, e) => let
340                    val eTy = check(env, cxt, e)
341                              in                              in
342                                if U.equalDim(d1, d2)                    case Env.findFunc (env, rator)
343                                andalso U.equalTypes(domTy, [ty1, ty2])                     of Env.PrimFun[rator] => let
344                                andalso U.equalType(rngTy, resTy)                          val (tyArgs, Ty.T_Fun([domTy], rngTy)) = TU.instantiate(Var.typeOf rator)
345                                  then (AST.E_Apply(BV.op_inner, tyArgs, [e1', e2'], rngTy), rngTy)                          in
346                                  else err (cxt, [                            case Util.coerceType (domTy, eTy)
347                                      S "type error for arguments of binary operator '•'\n",                             of SOME e' => (AST.E_Prim(rator, tyArgs, [e'], rngTy), rngTy)
348                                      S "  found: ", TYS[ty1, ty2], S "\n"                              | NONE => err (cxt, [
349                                      S "type error for unary operator ", V rator, S "\n",
350                                      S "  expected:  ", TY domTy, S "\n",
351                                      S "  but found: ", TY (#2 eTy)
352                                    ])                                    ])
353                              (* end case *)
354                              end                              end
355                         | (ty1, ty2) => err (cxt, [                      | Env.PrimFun ovldList => resolveOverload (cxt, rator, [#2 eTy], [#1 eTy], ovldList)
356                                S "type error for arguments of binary operator '•'\n",                      | _ => raise Fail "impossible"
357                                S "  found: ", TYS[ty1, ty2], S "\n"                    (* end case *)
358                    end
359                | PT.E_Apply(e, args) => let
360                    val (args, tys) = checkList (env, cxt, args)
361                    fun appTyError (f, paramTys, argTys) = err(cxt, [
362                            S "type error in application of ", V f, S "\n",
363                            S "  expected:  ", TYS paramTys, S "\n",
364                            S "  but found: ", TYS argTys
365                          ])
366                    fun checkPrimApp f = if Var.isPrim f
367                          then (case TU.instantiate(Var.typeOf f)
368                             of (tyArgs, Ty.T_Fun(domTy, rngTy)) => (
369                                  case Unify.matchArgs (domTy, args, tys)
370                                   of SOME args => (AST.E_Prim(f, tyArgs, args, rngTy), rngTy)
371                                    | NONE => appTyError (f, domTy, tys)
372                                  (* end case *))
373                              | _ => err(cxt, [S "application of non-function/field ", V f])
374                            (* end case *))
375                          else raise Fail "unexpected user function"
376                  (* check the application of a user-defined function *)
377                    fun checkFunApp (cxt, f) = if Var.isPrim f
378                          then raise Fail "unexpected primitive function"
379                          else (case Var.monoTypeOf f
380                             of Ty.T_Fun(domTy, rngTy) => (
381                                  case Unify.matchArgs (domTy, args, tys)
382                                   of SOME args => (AST.E_Apply(useVar(cxt, f), args, rngTy), rngTy)
383                                    | NONE => appTyError (f, domTy, tys)
384                                  (* end case *))
385                              | _ => err(cxt, [S "application of non-function/field ", V f])
386                            (* end case *))
387                    fun checkFieldApp (e1', ty1) = (case (args, tys)
388                           of ([e2'], [ty2]) => let
389                                val (tyArgs, Ty.T_Fun([fldTy, domTy], rngTy)) =
390                                      TU.instantiate(Var.typeOf BV.op_probe)
391                                fun tyError () = err (cxt, [
392                                        S "type error for field application\n",
393                                        S "  expected:  ", TYS[fldTy, domTy], S "\n",
394                                        S "  but found: ", TYS[ty1, ty2]
395                              ])                              ])
                       (* end case *))  
                   else if Atom.same(rator, BasisNames.op_colon)  
                     then (case (TU.prune ty1, TU.prune ty2)  
                        of (Ty.T_Tensor(s1 as Ty.Shape(dd1 as _::_::_)), Ty.T_Tensor(s2 as Ty.Shape(d21::d22::dd2))) => let  
                             val (dd1, d11, d12) = let  
                                   fun splitLast (prefix, [d1, d2]) = (List.rev prefix, d1, d2)  
                                     | splitLast (prefix, d::dd) = splitLast (d::prefix, dd)  
                                     | splitLast (_, []) = raise Fail "impossible"  
396                                    in                                    in
397                                      splitLast ([], dd1)                                if Unify.equalType(fldTy, ty1)
398                                    then (case Util.coerceType(domTy, (e2', ty2))
399                                       of SOME e2' => (AST.E_Prim(BV.op_probe, tyArgs, [e1', e2'], rngTy), rngTy)
400                                        | NONE => tyError()
401                                      (* end case *))
402                                    else tyError()
403                                end
404                            | _ => err(cxt, [S "badly formed field application"])
405                          (* end case *))
406                    in
407                      case stripMark(#2 cxt, e)
408                       of (span, PT.E_Var f) => (case Env.findVar (env, f)
409                             of SOME f' => checkFieldApp (
410                                  AST.E_Var(useVar((#1 cxt, span), f')),
411                                  Var.monoTypeOf f')
412                              | NONE => (case Env.findFunc (env, f)
413                                   of Env.PrimFun[] => err(cxt, [S "unknown function ", A f])
414                                    | Env.PrimFun[f'] => checkPrimApp f'
415                                    | Env.PrimFun ovldList =>
416                                        resolveOverload ((#1 cxt, span), f, tys, args, ovldList)
417                                    | Env.UserFun f' => checkFunApp((#1 cxt, span), f')
418                                  (* end case *))
419                              (* end case *))
420                        | _ => checkFieldApp (check (env, cxt, e))
421                      (* end case *)
422                                    end                                    end
423                              val (tyArgs, Ty.T_Fun(domTy, rngTy)) = Util.instantiate(Var.typeOf BV.op_colon)              | PT.E_Subscript(e, indices) => let
424                              val resTy = Ty.T_Tensor(Ty.Shape(dd1@dd2))                  fun expectedTensor ty = err(cxt, [
425                            S "expected tensor type for slicing, but found ", TY ty
426                          ])
427                    fun chkIndex e = let
428                          val eTy as (_, ty) = check(env, cxt, e)
429                              in                              in
430                                if U.equalDim(d11, d21) andalso U.equalDim(d12, d22)                          if Unify.equalType(ty, Ty.T_Int)
431                                andalso U.equalTypes(domTy, [ty1, ty2])                            then eTy
                               andalso U.equalType(rngTy, resTy)  
                                 then (AST.E_Apply(BV.op_colon, tyArgs, [e1', e2'], rngTy), rngTy)  
432                                  else err (cxt, [                                  else err (cxt, [
433                                      S "type error for arguments of binary operator ':'\n",                                S "expected type 'int' for index, but found ", TY ty
                                     S "  found: ", TYS[ty1, ty2], S "\n"  
434                                    ])                                    ])
435                              end                              end
                        | (ty1, ty2) => err (cxt, [  
                               S "type error for arguments of binary operator ':'\n",  
                               S "  found: ", TYS[ty1, ty2], S "\n"  
                             ])  
                       (* end case *))  
                     else (case Env.findFunc (#env env, rator)  
                        of Env.PrimFun[rator] => let  
                             val (tyArgs, Ty.T_Fun(domTy, rngTy)) = Util.instantiate(Var.typeOf rator)  
436                              in                              in
437                                case U.matchArgs(domTy, [e1', e2'], [ty1, ty2])                    case (check(env, cxt, e), indices)
438                                 of SOME args => (AST.E_Apply(rator, tyArgs, args, rngTy), rngTy)                     of ((e', Ty.T_Error), _) => (
439                                  | NONE => err (cxt, [                          List.app (ignore o Option.map chkIndex) indices;
440                                        S "type error for binary operator '", V rator, S "'\n",                          bogusExpTy)
441                                        S "  expected:  ", TYS domTy, S "\n",                      | ((e1', ty1 as Ty.T_Sequence(elemTy, optDim)), [SOME e2]) => let
442                                        S "  but found: ", TYS[ty1, ty2]                          val (e2', ty2) = chkIndex e2
443                            val rator = if isSome optDim
444                                  then BV.subscript
445                                  else BV.dynSubscript
446                            val (tyArgs, Ty.T_Fun(domTy, rngTy)) = TU.instantiate(Var.typeOf rator)
447                            in
448                              if Unify.equalTypes(domTy, [ty1, ty2])
449                                then let
450                                  val exp = AST.E_Prim(rator, tyArgs, [e1', e2'], rngTy)
451                                  in
452                                    (exp, rngTy)
453                                  end
454                                else raise Fail "unexpected unification failure"
455                            end
456                        | ((e', ty as Ty.T_Sequence _), [NONE]) => expectedTensor ty
457                        | ((e', ty as Ty.T_Sequence _), _) => expectedTensor ty
458                        | ((e', ty as Ty.T_Tensor shape), _) => let
459                            val indices' = List.map (Option.map (#1 o chkIndex)) indices
460                            val order = List.length indices'
461                            val expectedTy = TU.mkTensorTy order
462                            val resultTy = TU.slice(expectedTy, List.map Option.isSome indices')
463                            in
464                              if Unify.equalType(ty, expectedTy)
465                                then (AST.E_Slice(e', indices', resultTy), resultTy)
466                                else err (cxt, [
467                                    S "type error in slice operation\n",
468                                    S "  expected:  ", S(Int.toString order), S "-order tensor\n",
469                                    S "  but found: ", TY ty
470                                      ])                                      ])
471                            end
472                        | ((_, ty), _) => expectedTensor ty
473                                (* end case *)                                (* end case *)
474                              end                              end
475                          | Env.PrimFun ovldList =>              | PT.E_Select(e, field) => (case stripMark(#2 cxt, e)
476                              resolveOverload (cxt, rator, [ty1, ty2], [e1', e2'], ovldList)                   of (_, PT.E_Var x) => (case E.findStrand (env, x)
477                          | _ => raise Fail "impossible"                         of SOME _ => if E.inGlobalUpdate env
478                        (* end case *))                              then (case E.findSetFn (env, field)
479                                   of SOME setFn => let
480                                        val (mvs, ty) = TU.instantiate (Var.typeOf setFn)
481                                        val resTy = Ty.T_Sequence(Ty.T_Named x, NONE)
482                                        in
483                                          E.recordProp (env, Properties.StrandSets);
484                                          if Unify.equalType(ty, Ty.T_Fun([], resTy))
485                                            then (AST.E_Prim(setFn, mvs, [], resTy), resTy)
486                                            else raise Fail "impossible"
487                  end                  end
488              | PT.E_UnaryOp of var * expr                        (* <op> e *)                                  | _ => err (cxt, [
489              | PT.E_Apply of expr * expr list             (* field/function/reduction application *)                                        S "unknown strand-set specifier ", A field
490              | PT.E_Subscript of expr * expr option list  (* sequence/tensor indexing; NONE for ':' *)                                      ])
491              | PT.E_Select of expr * field               (* e '.' <field> *)                                (* end case *))
492              | PT.E_Real e => (case check (env, cxt, e)                              else err (cxt, [
493                                    S "illegal strand set specification in ",
494                                    S(E.scopeToString(E.currentScope env))
495                                  ])
496                            | _ => checkSelect (env, cxt, e, field)
497                          (* end case *))
498                      | _ => checkSelect (env, cxt, e, field)
499                    (* end case *))
500                | PT.E_Real e => (case checkAndPrune (env, cxt, e)
501                   of (e', Ty.T_Int) =>                   of (e', Ty.T_Int) =>
502                        (AST.E_Apply(BV.i2r, [], [e'], Ty.realTy), Ty.realTy)                        (AST.E_Prim(BV.i2r, [], [e'], Ty.realTy), Ty.realTy)
503                      | (e', Ty.T_Error) => bogusExpTy
504                    | (_, ty) => err(cxt, [                    | (_, ty) => err(cxt, [
505                          S "argument of 'real' must have type 'int', but found ",                          S "argument of 'real' must have type 'int', but found ",
506                          TY ty                          TY ty
507                        ])                        ])
508                  (* end case *))                  (* end case *))
509              | PT.E_Load nrrd => let              | PT.E_Load nrrd => let
510                  val (tyArgs, Ty.T_Fun(_, rngTy)) = Util.instantiate(Var.typeOf(BV.fn_image))                  val (tyArgs, Ty.T_Fun(_, rngTy)) = TU.instantiate(Var.typeOf(BV.fn_load))
511                  in                  in
512                    (AST.E_LoadNrrd(tyArgs, nrrd, rngTy), rngTy)                    case chkStringConstExpr (env, cxt, nrrd)
513                       of SOME nrrd => (AST.E_LoadNrrd(tyArgs, nrrd, rngTy), rngTy)
514                        | NONE => (bogusExp, rngTy)
515                      (* end case *)
516                  end                  end
517              | PT.E_Image nrrd => let              | PT.E_Image nrrd => let
518                  val (tyArgs, Ty.T_Fun(_, rngTy)) = Util.instantiate(Var.typeOf(BV.fn_load))                  val (tyArgs, Ty.T_Fun(_, rngTy)) = TU.instantiate(Var.typeOf(BV.fn_image))
519                  in                  in
520                    (AST.E_LoadNrrd(tyArgs, nrrd, rngTy), rngTy)                    case chkStringConstExpr (env, cxt, nrrd)
521                       of SOME nrrd => (AST.E_LoadNrrd(tyArgs, nrrd, rngTy), rngTy)
522                        | NONE => (bogusExp, rngTy)
523                      (* end case *)
524                  end                  end
525              | PT.E_Var (case E.findVar (#env env, x)              | PT.E_Var x => (case E.findVar (env, x)
526                   of SOME x' => (                   of SOME x' => (AST.E_Var(useVar(cxt, x')), Var.monoTypeOf x')
                       markUsed (x', true);  
                       (AST.E_Var x', Var.monoTypeOf x'))  
527                    | NONE => err(cxt, [S "undeclared variable ", A x])                    | NONE => err(cxt, [S "undeclared variable ", A x])
528                  (* end case *))                  (* end case *))
529              | PT.E_Kernel of var * dim                  (* kernel '#' dim *)              | PT.E_Kernel(kern, dim) => (case E.findVar (env, kern)
530                     of SOME kern' => (case Var.monoTypeOf kern'
531                           of ty as Ty.T_Kernel(Ty.DiffConst k) => let
532                                val k' = Int.fromLarge dim handle Overflow => 1073741823
533                                val e = AST.E_Var(useVar(cxt, kern'))
534                                in
535                                  if (k = k')
536                                    then (e, ty)
537                                    else let
538                                      val ty' = Ty.T_Kernel(Ty.DiffConst k')
539                                      in
540                                        (AST.E_Coerce{srcTy = ty, dstTy = ty', e = e}, ty')
541                                      end
542                                end
543                            | _ => err(cxt, [S "expected kernel, but found ", S(Var.kindToString kern')])
544                          (* end case *))
545                      | NONE => err(cxt, [S "unknown kernel ", A kern])
546                    (* end case *))
547              | PT.E_Lit lit => checkLit lit              | PT.E_Lit lit => checkLit lit
548              | PT.E_Id d => let              | PT.E_Id d => let
549                  val (tyArgs, Ty.T_Fun(_, rngTy)) =                  val (tyArgs, Ty.T_Fun(_, rngTy)) =
550                        Util.instantiate(Var.typeOf(BV.identity))                        TU.instantiate(Var.typeOf(BV.identity))
551                  in                  in
552                    if U.equalType(Ty.T_Tensor(checkShape(cxt, [d,d])), rngTy)                    if Unify.equalType(Ty.T_Tensor(checkShape(env, cxt, [d, d])), rngTy)
553                      then (AST.E_Apply(BV.identity, tyArgs, [], rngTy), rngTy)                      then (AST.E_Prim(BV.identity, tyArgs, [], rngTy), rngTy)
554                      else raise Fail "impossible"                      else raise Fail "impossible"
555                  end                  end
556              | PT.E_Zero dd => let              | PT.E_Zero dd => let
557                  val (tyArgs, Ty.T_Fun(_, rngTy)) =                  val (tyArgs, Ty.T_Fun(_, rngTy)) =
558                        Util.instantiate(Var.typeOf(BV.zero))                        TU.instantiate(Var.typeOf(BV.zero))
559                  in                  in
560                    if U.equalType(Ty.T_Tensor(checkShape(cxt, dd)), rngTy)                    if Unify.equalType(Ty.T_Tensor(checkShape(env, cxt, dd)), rngTy)
561                      then (AST.E_Apply(BV.zero, tyArgs, [], rngTy), rngTy)                      then (AST.E_Prim(BV.zero, tyArgs, [], rngTy), rngTy)
562                      else raise Fail "impossible"                      else raise Fail "impossible"
563                  end                  end
564              | PT.E_NaN dd => let              | PT.E_NaN dd => let
565                  val (tyArgs, Ty.T_Fun(_, rngTy)) =                  val (tyArgs, Ty.T_Fun(_, rngTy)) =
566                        Util.instantiate(Var.typeOf(BV.nan))                        TU.instantiate(Var.typeOf(BV.nan))
567                  in                  in
568                    if U.equalType(Ty.T_Tensor(checkShape(cxt, dd)), rngTy)                    if Unify.equalType(Ty.T_Tensor(checkShape(env, cxt, dd)), rngTy)
569                      then (AST.E_Apply(BV.nan, tyArgs, [], rngTy), rngTy)                      then (AST.E_Prim(BV.nan, tyArgs, [], rngTy), rngTy)
570                      else raise Fail "impossible"                      else raise Fail "impossible"
571                  end                  end
572              | PT.E_Sequence of expr list                 (* sequence construction *)              | PT.E_Sequence exps => (case checkList (env, cxt, exps)
573              | PT.E_SeqComp of comprehension             (* sequence comprehension *)  (* FIXME: need kind for concrete types here! *)
574                     of ([], _) => let
575                          val ty = Ty.T_Sequence(Ty.T_Var(MetaVar.newTyVar()), SOME(Ty.DimConst 0))
576                          in
577                            (AST.E_Seq([], ty), ty)
578                          end
579                      | (args, tys) => (case Util.coerceTypes(List.map TU.pruneHead tys)
580                           of SOME ty => if TU.isValueType ty
581                                then let
582                                  fun doExp eTy = valOf(Util.coerceType (ty, eTy))
583                                  val resTy = Ty.T_Sequence(ty, SOME(Ty.DimConst(List.length args)))
584                                  val args = ListPair.map doExp (args, tys)
585                                  in
586                                    (AST.E_Seq(args, resTy), resTy)
587                                  end
588                                else err(cxt, [S "sequence expression of non-value argument type"])
589                            | NONE => err(cxt, [S "arguments of sequence expression must have same type"])
590                          (* end case *))
591                    (* end case *))
592                | PT.E_SeqComp comp => chkComprehension (env, cxt, comp)
593              | PT.E_Cons args => let              | PT.E_Cons args => let
594                (* Note that we are guaranteed that args is non-empty *)                (* Note that we are guaranteed that args is non-empty *)
595                  val (args, tys) = checkList (env, cxt, args)                  val (args, tys) = checkList (env, cxt, args)
# Line 219  Line 598 
598                         of NONE => Ty.T_Error                         of NONE => Ty.T_Error
599                          | SOME ty => ty                          | SOME ty => ty
600                        (* end case *))                        (* end case *))
601                  in                (* process the arguments checking that they all have the expected type *)
602                    case realType(TU.pruneHead ty)                  fun chkArgs (ty, shape) = let
                    of ty as Ty.T_Tensor shape => let  
603                          val Ty.Shape dd = TU.pruneShape shape (* NOTE: this may fail if we allow user polymorphism *)                          val Ty.Shape dd = TU.pruneShape shape (* NOTE: this may fail if we allow user polymorphism *)
604                          val resTy = Ty.T_Tensor(Ty.Shape(Ty.DimConst(List.length args) :: dd))                          val resTy = Ty.T_Tensor(Ty.Shape(Ty.DimConst(List.length args) :: dd))
605                          fun chkArgs (arg::args, argTy::tys, args') = (case coerceType(ty, argTy, arg)                        fun chkArgs (arg::args, argTy::tys, args') = (
606                                case Util.coerceType(ty, (arg, argTy))
607                                 of SOME arg' => chkArgs (args, tys, arg'::args')                                 of SOME arg' => chkArgs (args, tys, arg'::args')
608                                  | NONE => (                                  | NONE => (
609                                      TypeError.error(cxt, [                                      TypeError.error(cxt, [
610                                          S "arguments of tensor construction must have same type"                                          S "arguments of tensor construction must have same type"
611                                        ]);                                        ]);
612                                      ??)                                    chkArgs (args, tys, bogusExp::args'))
613                                (* end case *))                                (* end case *))
614                            | chkArgs ([], [], args') = (AST.E_Cons(List.rev args', resTy), resTy)                          | chkArgs (_, _, args') = (AST.E_Tensor(List.rev args', resTy), resTy)
615                          in                          in
616                            chkArgs (args, tys, [])                            chkArgs (args, tys, [])
617                          end                          end
618                    in
619                      case TU.pruneHead ty
620                       of Ty.T_Int => chkArgs(Ty.realTy, Ty.Shape[]) (* coerce integers to reals *)
621                        | ty as Ty.T_Tensor shape => chkArgs(ty, shape)
622                      | _ => err(cxt, [S "Invalid argument type for tensor construction"])                      | _ => err(cxt, [S "Invalid argument type for tensor construction"])
623                    (* end case *)                    (* end case *)
624                  end                  end
625              | PT.E_Deprecate(msg, e) => (              | PT.E_Deprecate(msg, e) => (
626                  warn (cxt, [S msg]);                  warn (cxt, [S msg]);
627                  chk (env, cxt, e))                  check (env, cxt, e))
628            (* end case *))            (* end case *))
629    
630      (* typecheck and the prune the result *)
631        and checkAndPrune (env, cxt, e) = let
632              val (e, ty) = check (env, cxt, e)
633              in
634                (e, TU.prune ty)
635              end
636    
637    (* check a conditional operator (e.g., || or &&) *)    (* check a conditional operator (e.g., || or &&) *)
638      and checkCondOp (env, cxt, e1, rator, e2, mk) = (      and checkCondOp (env, cxt, e1, rator, e2, mk) = (
639            case (check(env, cxt, e1), check(env, cxt, e2))            case (check(env, cxt, e1), check(env, cxt, e2))
640             of ((e1', Ty.T_Bool), (e2', Ty.T_Bool)) => (mk(e1', e2'), Ty.T_Bool)             of ((e1', Ty.T_Bool), (e2', Ty.T_Bool)) => (mk(e1', e2'), Ty.T_Bool)
641              | ((_, Ty.T_Bool), (_, ty2)) =>              | ((_, Ty.T_Bool), (_, ty2)) =>
642                  err (cxt, [S "expected type 'bool' on rhs of '", S rator, "', but found ", TY ty2])                  err (cxt, [S "expected type 'bool' on rhs of '", S rator, S "', but found ", TY ty2])
643              | ((_, ty1), (_, Ty.T_Bool)) =>              | ((_, ty1), (_, Ty.T_Bool)) =>
644                  err (cxt, [S "expected type 'bool' on lhs of '", S rator, "', but found ", TY ty1])                  err (cxt, [S "expected type 'bool' on lhs of '", S rator, S "', but found ", TY ty1])
645              | ((_, ty1), (_, ty2)) => err (cxt, [              | ((_, ty1), (_, ty2)) => err (cxt, [
646                    S "arguments of '", S rator, "' must have type 'bool', but found ",                    S "arguments of '", S rator, S "' must have type 'bool', but found ",
647                    TY ty1, S " and ", TY ty2                    TY ty1, S " and ", TY ty2
648                  ])                  ])
649            (* end case *))            (* end case *))
650    
651      (* check a field select that is _not_ a strand-set *)
652        and checkSelect (env, cxt, e, field) = (case checkAndPrune (env, cxt, e)
653               of (e', Ty.T_Named strand) => (case Env.findStrand(env, strand)
654                     of SOME sEnv => (case StrandEnv.findStateVar(sEnv, field)
655                           of SOME x' => let
656                                val ty = Var.monoTypeOf x'
657                                in
658                                  (AST.E_Select(e', useVar(cxt, x')), ty)
659                                end
660                            | NONE => err(cxt, [
661                                  S "strand ", A strand,
662                                  S " does not have state variable ", A field
663                                ])
664                          (* end case *))
665                      | NONE => err(cxt, [S "unknown strand ", A strand])
666                    (* end case *))
667                | (_, Ty.T_Error) => bogusExpTy
668                | (_, ty) => err (cxt, [
669                      S "expected strand type, but found ", TY ty,
670                      S " in selection of ", A field
671                    ])
672              (* end case *))
673    
674        and chkComprehension (env, cxt, PT.COMP_Mark m) =
675              chkComprehension(E.withEnvAndContext(env, cxt, m))
676          | chkComprehension (env, cxt, PT.COMP_Comprehension(e, [iter])) = let
677              val (iter', env') = checkIter (E.blockScope env, cxt, iter)
678              val (e', ty) = check (env', cxt, e)
679              val resTy = Ty.T_Sequence(ty, NONE)
680              in
681                (AST.E_Comprehension(e', iter', resTy), resTy)
682              end
683          | chkComprehension _ = raise Fail "impossible"
684    
685        and checkIter (env, cxt, PT.I_Mark m) = checkIter (E.withEnvAndContext (env, cxt, m))
686          | checkIter (env, cxt, PT.I_Iterator({span, tree=x}, e)) = (
687              case checkAndPrune (env, cxt, e)
688               of (e', ty as Ty.T_Sequence(elemTy, _)) => let
689                    val x' = Var.new(x, Error.location(#1 cxt, span), Var.LocalVar, elemTy)
690                    in
691                      ((x', e'), E.insertLocal(env, cxt, x, x'))
692                    end
693                | (e', ty) => let
694                    val x' = Var.new(x, Error.UNKNOWN, Var.IterVar, Ty.T_Error)
695                    in
696                      if TU.isErrorType ty
697                        then ()
698                        else TypeError.error (cxt, [
699                            S "expected sequence type in iteration, but found '", TY ty, S "'"
700                          ]);
701                      ((x', bogusExp), E.insertLocal(env, cxt, x, x'))
702                    end
703              (* end case *))
704    
705    (* typecheck a list of expressions returning a list of AST expressions and a list    (* typecheck a list of expressions returning a list of AST expressions and a list
706     * of the types of the expressions.     * of the types of the expressions.
707     *)     *)
708      and checkList (env, cxt, exprs) = let      and checkList (env, cxt, exprs) = let
709            fun chk (e, (es, tys)) = let            fun chk (e, (es, tys)) = let
710                  val (e, ty) = checkExpr (env, cxt, e)                  val (e, ty) = checkAndPrune (env, cxt, e)
711                  in                  in
712                    (e::es, ty::tys)                    (e::es, ty::tys)
713                  end                  end
# Line 271  Line 715 
715              List.foldr chk ([], []) exprs              List.foldr chk ([], []) exprs
716            end            end
717    
718      (* check a string that is specified as a constant expression *)
719        and chkStringConstExpr (env, cxt, PT.E_Mark m) =
720              chkStringConstExpr (E.withEnvAndContext (env, cxt, m))
721          | chkStringConstExpr (env, cxt, e) = (case checkAndPrune (env, cxt, e)
722               of (e', Ty.T_String) => (case ConstExpr.eval (cxt, e')
723                     of SOME(ConstExpr.String s) => SOME s
724                      | SOME(ConstExpr.Expr e) => raise Fail "FIXME"
725                      | NONE => NONE
726                      | _ => raise Fail "impossible: wrong type for constant expr"
727                    (* end case *))
728                | (_, Ty.T_Error) => NONE
729                | (_, ty) => (
730                    TypeError.error (cxt, [
731                        S "expected constant expression of type 'string', but found '",
732                        TY ty, S "'"
733                      ]);
734                    NONE)
735              (* end case *))
736    
737      (* check a dimension that is given by a constant expression *)
738        and checkDim (env, cxt, dim) = (case checkAndPrune (env, cxt, dim)
739               of (e', Ty.T_Int) => (case ConstExpr.eval (cxt, e')
740                     of SOME(ConstExpr.Int d) => SOME d
741                      | SOME(ConstExpr.Expr e) => (
742                          TypeError.error (cxt, [S "unable to evaluate constant dimension expression"]);
743                          NONE)
744                      | NONE => NONE
745                      | _ => raise Fail "impossible: wrong type for constant expr"
746                    (* end case *))
747                | (_, Ty.T_Error) => NONE
748                | (_, ty) => (
749                    TypeError.error (cxt, [
750                        S "expected constant expression of type 'int', but found '",
751                        TY ty, S "'"
752                      ]);
753                    NONE)
754              (* end case *))
755    
756      (* check a tensor shape, where the dimensions are given by constant expressions *)
757        and checkShape (env, cxt, shape) = let
758              fun checkDim' e = (case checkDim (env, cxt, e)
759                     of SOME d => (
760                          if (d <= 1)
761                            then TypeError.error (cxt, [
762                                S "invalid tensor-shape dimension; must be > 1, but found ",
763                                S (IntLit.toString d)
764                              ])
765                            else ();
766                          Ty.DimConst(IntInf.toInt d))
767                      | NONE => Ty.DimConst ~1
768                    (* end case *))
769              in
770                Ty.Shape(List.map checkDim' shape)
771              end
772    
773    end    end

Legend:
Removed from v.3396  
changed lines
  Added in v.3431

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0