Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] Diff of /branches/vis15/src/compiler/typechecker/check-expr.sml
ViewVC logotype

Diff of /branches/vis15/src/compiler/typechecker/check-expr.sml

Parent Directory Parent Directory | Revision Log Revision Log | View Patch Patch

revision 3398, Wed Nov 11 01:17:58 2015 UTC revision 3418, Fri Nov 13 00:00:26 2015 UTC
# Line 10  Line 10 
10    
11  structure CheckExpr : sig  structure CheckExpr : sig
12    
13      val check : Env.env * Env.context * ParseTree.expr -> (AST.expr * Types.ty)    (* type check an expression *)
14        val check : Env.t * Env.context * ParseTree.expr -> (AST.expr * Types.ty)
15    
16      (* type check a list of expressions *)
17        val checkList : Env.t * Env.context * ParseTree.expr list -> (AST.expr list * Types.ty list)
18    
19      (* type check a dimension that is given by a constant expression *)
20        val checkDim : Env.t * Env.context * ParseTree.expr -> IntLit.t option
21    
22      (* type check a tensor shape, where the dimensions are given by constant expressions *)
23        val checkShape : Env.t * Env.context * ParseTree.expr list -> Types.shape
24    
25      (* `resolveOverload (cxt, rator, tys, args, candidates)` resolves the application of
26       * the overloaded operator `rator` to `args`, where `tys` are the types of the arguments
27       * and `candidates` is the list of candidate definitions.
28       *)
29        val resolveOverload : Env.context * Atom.atom * Types.ty list * AST.expr list * Var.t list
30              -> (AST.expr * Types.ty)
31    
32    end = struct    end = struct
33    
# Line 19  Line 36 
36      structure E = Env      structure E = Env
37      structure Ty = Types      structure Ty = Types
38      structure BV = BasisVars      structure BV = BasisVars
39        structure TU = TypeUtil
40    
41    (* an expression to return when there is a type error *)    (* an expression to return when there is a type error *)
42      val bogusExp = (AST.E_Lit(L.Int 0), Ty.T_Error)      val bogusExp = AST.E_Lit(L.Int 0)
43        val bogusExpTy = (bogusExp, Ty.T_Error)
44    
45      fun err arg = (TypeError.error arg; bogusExp)      fun err arg = (TypeError.error arg; bogusExpTy)
46      val warn = TypeError.warning      val warn = TypeError.warning
47    
48      datatype tokens = datatype TypeError.tokens      datatype token = datatype TypeError.token
49    
50      (* mark a variable use with its location *)
51        fun useVar (cxt : Env.context, x) = (x, #2 cxt)
52    
53      (* resolve overloading: we use a simple scheme that selects the first operator in the
54       * list that matches the argument types.
55       *)
56        fun resolveOverload (_, rator, _, _, []) = raise Fail(concat[
57                "resolveOverload: \"", Atom.toString rator, "\" has no candidates"
58              ])
59          | resolveOverload (cxt, rator, argTys, args, candidates) = let
60    (* FIXME: we could be more efficient by just checking for coercion matchs the first pass
61     * and remembering those that are not pure EQ matches.
62     *)
63            (* try to match candidates while allowing type coercions *)
64              fun tryMatchCandidates [] = err(cxt, [
65                      S "unable to resolve overloaded operator ", A rator, S "\n",
66                      S "  argument type is: ", TYS argTys, S "\n"
67                    ])
68                | tryMatchCandidates (x::xs) = let
69                    val (tyArgs, Ty.T_Fun(domTy, rngTy)) = TU.instantiate(Var.typeOf x)
70                    in
71                      case Unify.tryMatchArgs (domTy, args, argTys)
72                       of SOME args' => (AST.E_Prim(x, tyArgs, args', rngTy), rngTy)
73                        | NONE => tryMatchCandidates xs
74                      (* end case *)
75                    end
76              fun tryCandidates [] = tryMatchCandidates candidates
77                | tryCandidates (x::xs) = let
78                    val (tyArgs, Ty.T_Fun(domTy, rngTy)) = TU.instantiate(Var.typeOf x)
79                    in
80                      if Unify.tryEqualTypes(domTy, argTys)
81                        then (AST.E_Prim(x, tyArgs, args, rngTy), rngTy)
82                        else tryCandidates xs
83                    end
84              in
85                tryCandidates candidates
86              end
87    
88    (* check the type of a literal *)    (* check the type of a literal *)
89      fun checkLit lit = (case lit      fun checkLit lit = (case lit
# Line 36  Line 93 
93              | (L.Bool _) => (AST.E_Lit lit, Ty.T_Bool)              | (L.Bool _) => (AST.E_Lit lit, Ty.T_Bool)
94            (* end case *))            (* end case *))
95    
96      (* type check a dot product, which has the constraint:
97       *     ALL[sigma1, d1, sigma2] . tensor[sigma1, d1] * tensor[d1, sigma2] -> tensor[sigma1, sigma2]
98       * and similarly for fields.
99       *)
100        fun chkInnerProduct (cxt, e1, ty1, e2, ty2) = let
101            (* check the shape of the two arguments to verify that the inner constraint matches *)
102              fun chkShape (Ty.Shape(dd1 as _::_), Ty.Shape(d2::dd2)) = let
103                    val (dd1, d1) = let
104                          fun splitLast (prefix, [d]) = (List.rev prefix, d)
105                            | splitLast (prefix, d::dd) = splitLast (d::prefix, dd)
106                            | splitLast (_, []) = raise Fail "impossible"
107                          in
108                            splitLast ([], dd1)
109                          end
110                    in
111                      if Unify.equalDim(d1, d2)
112                        then SOME(Ty.Shape(dd1@dd2))
113                        else NONE
114                    end
115                | chkShape _ = NONE
116              fun error () = err (cxt, [
117                      S "type error for arguments of binary operator '•'\n",
118                      S "  found: ", TYS[ty1, ty2], S "\n"
119                    ])
120              in
121                case (TU.prune ty1, TU.prune ty2)
122                (* tensor * tensor inner product *)
123                 of (Ty.T_Tensor s1, Ty.T_Tensor s2) => (case chkShape(s1, s2)
124                       of SOME shp => let
125                            val (tyArgs, Ty.T_Fun(domTy, rngTy)) = TU.instantiate(Var.typeOf BV.op_inner_tt)
126                            val resTy = Ty.T_Tensor shp
127                            in
128                              if Unify.equalTypes(domTy, [ty1, ty2]) andalso Unify.equalType(rngTy, resTy)
129                                then (AST.E_Prim(BV.op_inner_tt, tyArgs, [e1, e2], rngTy), rngTy)
130                                else error()
131                            end
132                        | NONE => error()
133                      (* end case *))
134                (* tensor * field inner product *)
135                  | (Ty.T_Tensor s1, Ty.T_Field{diff, dim, shape=s2}) => (case chkShape(s1, s2)
136                       of SOME shp => let
137                            val (tyArgs, Ty.T_Fun(domTy, rngTy)) = TU.instantiate(Var.typeOf BV.op_inner_tf)
138                            val resTy = Ty.T_Field{diff=diff, dim=dim, shape=shp}
139                            in
140                              if Unify.equalTypes(domTy, [ty1, ty2])
141                              andalso Unify.equalType(rngTy, resTy)
142                                then (AST.E_Prim(BV.op_inner_tf, tyArgs, [e1, e2], rngTy), rngTy)
143                                else error()
144                            end
145                        | NONE => error()
146                      (* end case *))
147                (* field * tensor inner product *)
148                  | (Ty.T_Field{diff, dim, shape=s1}, Ty.T_Tensor s2) => (case chkShape(s1, s2)
149                       of SOME shp => let
150                            val (tyArgs, Ty.T_Fun(domTy, rngTy)) = TU.instantiate(Var.typeOf BV.op_inner_ft)
151                            val resTy = Ty.T_Field{diff=diff, dim=dim, shape=shp}
152                            in
153                              if Unify.equalTypes(domTy, [ty1, ty2])
154                              andalso Unify.equalType(rngTy, resTy)
155                                then (AST.E_Prim(BV.op_inner_ft, tyArgs, [e1, e2], rngTy), rngTy)
156                                else error()
157                            end
158                        | NONE => error()
159                      (* end case *))
160                (* field * field inner product *)
161                  | (Ty.T_Field{diff=k1, dim=dim1, shape=s1}, Ty.T_Field{diff=k2, dim=dim2, shape=s2}) => (
162                      case chkShape(s1, s2)
163                       of SOME shp => let
164                            val (tyArgs, Ty.T_Fun(domTy, rngTy)) = TU.instantiate(Var.typeOf BV.op_inner_ff)
165                            val resTy = Ty.T_Field{diff=k1, dim=dim1, shape=shp}
166                            in
167    (* FIXME: the resulting differentiation should be the minimum of k1 and k2 *)
168                              if Unify.equalDim(dim1, dim2)
169                              andalso Unify.equalTypes(domTy, [ty1, ty2])
170                              andalso Unify.equalType(rngTy, resTy)
171                                then (AST.E_Prim(BV.op_inner_ff, tyArgs, [e1, e2], rngTy), rngTy)
172                                else error()
173                            end
174                        | NONE => error()
175                      (* end case *))
176                  | (ty1, ty2) => error()
177                (* end case *)
178              end
179    
180      (* type check a colon product, which has the constraint:
181       *     ALL[sigma1, d1, d2, sigma2] . tensor[sigma1, d1, d2] * tensor[d2, d1, sigma2] -> tensor[sigma1, sigma2]
182       * and similarly for fields.
183       *)
184        fun chkColonProduct (cxt, e1, ty1, e2, ty2) = let
185            (* check the shape of the two arguments to verify that the inner constraint matches *)
186              fun chkShape (Ty.Shape(dd1 as _::_::_), Ty.Shape(d21::d22::dd2)) = let
187                    val (dd1, d11, d12) = let
188                          fun splitLast2 (prefix, [d1, d2]) = (List.rev prefix, d1, d2)
189                            | splitLast2 (prefix, d::dd) = splitLast2 (d::prefix, dd)
190                            | splitLast2 (_, []) = raise Fail "impossible"
191                          in
192                            splitLast2 ([], dd1)
193                          end
194                    in
195                      if Unify.equalDim(d11, d21) andalso Unify.equalDim(d12, d22)
196                        then SOME(Ty.Shape(dd1@dd2))
197                        else NONE
198                    end
199                | chkShape _ = NONE
200              fun error () = err (cxt, [
201                      S "type error for arguments of binary operator \":\"\n",
202                      S "  found: ", TYS[ty1, ty2], S "\n"
203                    ])
204              in
205                case (TU.prune ty1, TU.prune ty2)
206                (* tensor * tensor colon product *)
207                 of (Ty.T_Tensor s1, Ty.T_Tensor s2) => (case chkShape(s1, s2)
208                       of SOME shp => let
209                            val (tyArgs, Ty.T_Fun(domTy, rngTy)) = TU.instantiate(Var.typeOf BV.op_colon_tt)
210                            val resTy = Ty.T_Tensor shp
211                            in
212                              if Unify.equalTypes(domTy, [ty1, ty2])
213                              andalso Unify.equalType(rngTy, resTy)
214                                then (AST.E_Prim(BV.op_colon_tt, tyArgs, [e1, e2], rngTy), rngTy)
215                                else error()
216                            end
217                        | NONE => error()
218                      (* end case *))
219                (* field * tensor colon product *)
220                  | (Ty.T_Field{diff, dim, shape=s1}, Ty.T_Tensor s2) => (case chkShape(s1, s2)
221                       of SOME shp => let
222                            val (tyArgs, Ty.T_Fun(domTy, rngTy)) = TU.instantiate(Var.typeOf BV.op_colon_ft)
223                            val resTy = Ty.T_Field{diff=diff, dim=dim, shape=shp}
224                            in
225                              if Unify.equalTypes(domTy, [ty1, ty2]) andalso Unify.equalType(rngTy, resTy)
226                                then (AST.E_Prim(BV.op_colon_ft, tyArgs, [e1, e2], rngTy), rngTy)
227                                else error()
228                            end
229                        | NONE => error()
230                      (* end case *))
231                (* tensor * field colon product *)
232                  | (Ty.T_Tensor s1, Ty.T_Field{diff=diff, dim=dim, shape=s2}) => (case chkShape(s1, s2)
233                       of SOME shp => let
234                            val (tyArgs, Ty.T_Fun(domTy, rngTy)) = TU.instantiate(Var.typeOf BV.op_colon_tf)
235                            val resTy = Ty.T_Field{diff=diff, dim=dim, shape=shp}
236                            in
237                              if Unify.equalTypes(domTy, [ty1, ty2]) andalso Unify.equalType(rngTy, resTy)
238                                then (AST.E_Prim(BV.op_colon_tf, tyArgs, [e1, e2], rngTy), rngTy)
239                                else error()
240                            end
241                        | NONE => error()
242                      (* end case *))
243                (* field * field colon product *)
244                  | (Ty.T_Field{diff=k1, dim=dim1, shape=s1}, Ty.T_Field{diff=k2, dim=dim2, shape=s2}) => (
245                      case chkShape(s1, s2)
246                       of SOME shp => let
247                            val (tyArgs, Ty.T_Fun(domTy, rngTy)) = TU.instantiate(Var.typeOf BV.op_colon_ff)
248                            val resTy = Ty.T_Field{diff=k1, dim=dim1, shape=shp}
249                            in
250    (* FIXME: the resulting differentiation should be the minimum of k1 and k2 *)
251                              if Unify.equalDim(dim1, dim2)
252                              andalso Unify.equalTypes(domTy, [ty1, ty2])
253                              andalso Unify.equalType(rngTy, resTy)
254                                then (AST.E_Prim(BV.op_colon_ff, tyArgs, [e1, e2], rngTy), rngTy)
255                                else error()
256                            end
257                        | NONE => error()
258                      (* end case *))
259                  | (ty1, ty2) => error()
260                (* end case *)
261              end
262    
263    (* check the type of an expression *)    (* check the type of an expression *)
264      fun check (env, cxt, e) = (case e      fun check (env, cxt, e) = (case e
265             of PT.E_Mark m => check (withEnvAndContext (env, cxt, m))             of PT.E_Mark m => check (E.withEnvAndContext (env, cxt, m))
266              | PT.E_Cond(e1, cond, e2) => let              | PT.E_Cond(e1, cond, e2) => let
267                  val eTy1 = check (env, cxt, e1)                  val eTy1 = check (env, cxt, e1)
268                  val eTy2 = check (env, cxt, e2)                  val eTy2 = check (env, cxt, e2)
269                  in                  in
270                    case checkExpr(env, cxt, cond)                    case check(env, cxt, cond)
271                     of (cond', Ty.T_Bool) => (case Util.coerceType2(eTy1, eTy2)                     of (cond', Ty.T_Bool) => (case Util.coerceType2(eTy1, eTy2)
272                           of SOME(e1, e2, ty) => (AST.E_Cond(cond', e1', e2', ty), ty)                           of SOME(e1', e2', ty) => (AST.E_Cond(cond', e1', e2', ty), ty)
273                            | NONE => err (cxt, [                            | NONE => err (cxt, [
274                                S "types do not match in conditional expression\n",                                S "types do not match in conditional expression\n",
275                                S "  true branch:  ", TY(#2 eTy1), S "\n",                                S "  true branch:  ", TY(#2 eTy1), S "\n",
# Line 59  Line 283 
283                   of ((e1', Ty.T_Int), (e2', Ty.T_Int)) => let                   of ((e1', Ty.T_Int), (e2', Ty.T_Int)) => let
284                        val resTy = Ty.T_Sequence(Ty.T_Int, NONE)                        val resTy = Ty.T_Sequence(Ty.T_Int, NONE)
285                        in                        in
286                          (AST.E_Apply(BV.range, [], [e1', e2'], resTy), resTy)                          (AST.E_Prim(BV.range, [], [e1', e2'], resTy), resTy)
287                        end                        end
288                    | ((_, Ty.T_Int), (_, ty2)) =>                    | ((_, Ty.T_Int), (_, ty2)) =>
289                        err (cxt, [S "expected type 'int' on rhs of '..', but found ", TY ty2])                        err (cxt, [S "expected type 'int' on rhs of '..', but found ", TY ty2])
# Line 81  Line 305 
305                  val (e2', ty2) = check (env, cxt, e2)                  val (e2', ty2) = check (env, cxt, e2)
306                  in                  in
307                    if Atom.same(rator, BasisNames.op_dot)                    if Atom.same(rator, BasisNames.op_dot)
308                    (* we have to handle inner product as a special case, because our type                      then chkInnerProduct (cxt, e1', ty1, e2', ty2)
                    * system cannot express the constraint that the type is  
                    *     ALL[sigma1, d1, sigma2] . tensor[sigma1, d1] * tensor[d1, sigma2] -> tensor[sigma1, sigma2]  
                    *)  
                     then (case (TU.prune ty1, TU.prune ty2)  
                        of (Ty.T_Tensor(s1 as Ty.Shape(dd1 as _::_)), Ty.T_Tensor(s2 as Ty.Shape(d2::dd2))) => let  
                             val (dd1, d1) = let  
                                   fun splitLast (prefix, [d]) = (List.rev prefix, d)  
                                     | splitLast (prefix, d::dd) = splitLast (d::prefix, dd)  
                                     | splitLast (_, []) = raise Fail "impossible"  
                                   in  
                                     splitLast ([], dd1)  
                                   end  
                             val (tyArgs, Ty.T_Fun(domTy, rngTy)) = Util.instantiate(Var.typeOf BV.op_inner)  
                             val resTy = Ty.T_Tensor(Ty.Shape(dd1@dd2))  
                             in  
                               if U.equalDim(d1, d2)  
                               andalso U.equalTypes(domTy, [ty1, ty2])  
                               andalso U.equalType(rngTy, resTy)  
                                 then (AST.E_Apply(BV.op_inner, tyArgs, [e1', e2'], rngTy), rngTy)  
                                 else err (cxt, [  
                                     S "type error for arguments of binary operator '•'\n",  
                                     S "  found: ", TYS[ty1, ty2], S "\n"  
                                   ])  
                             end  
                        | (ty1, ty2) => err (cxt, [  
                               S "type error for arguments of binary operator '•'\n",  
                               S "  found: ", TYS[ty1, ty2], S "\n"  
                             ])  
                       (* end case *))  
309                    else if Atom.same(rator, BasisNames.op_colon)                    else if Atom.same(rator, BasisNames.op_colon)
310                      then (case (TU.prune ty1, TU.prune ty2)                      then chkColonProduct (cxt, e1', ty1, e2', ty2)
311                         of (Ty.T_Tensor(s1 as Ty.Shape(dd1 as _::_::_)), Ty.T_Tensor(s2 as Ty.Shape(d21::d22::dd2))) => let                      else (case Env.findFunc (env, rator)
                             val (dd1, d11, d12) = let  
                                   fun splitLast (prefix, [d1, d2]) = (List.rev prefix, d1, d2)  
                                     | splitLast (prefix, d::dd) = splitLast (d::prefix, dd)  
                                     | splitLast (_, []) = raise Fail "impossible"  
                                   in  
                                     splitLast ([], dd1)  
                                   end  
                             val (tyArgs, Ty.T_Fun(domTy, rngTy)) = Util.instantiate(Var.typeOf BV.op_colon)  
                             val resTy = Ty.T_Tensor(Ty.Shape(dd1@dd2))  
                             in  
                               if U.equalDim(d11, d21) andalso U.equalDim(d12, d22)  
                               andalso U.equalTypes(domTy, [ty1, ty2])  
                               andalso U.equalType(rngTy, resTy)  
                                 then (AST.E_Apply(BV.op_colon, tyArgs, [e1', e2'], rngTy), rngTy)  
                                 else err (cxt, [  
                                     S "type error for arguments of binary operator ':'\n",  
                                     S "  found: ", TYS[ty1, ty2], S "\n"  
                                   ])  
                             end  
                        | (ty1, ty2) => err (cxt, [  
                               S "type error for arguments of binary operator ':'\n",  
                               S "  found: ", TYS[ty1, ty2], S "\n"  
                             ])  
                       (* end case *))  
                     else (case Env.findFunc (#env env, rator)  
312                         of Env.PrimFun[rator] => let                         of Env.PrimFun[rator] => let
313                              val (tyArgs, Ty.T_Fun(domTy, rngTy)) = Util.instantiate(Var.typeOf rator)                              val (tyArgs, Ty.T_Fun(domTy, rngTy)) = TU.instantiate(Var.typeOf rator)
314                              in                              in
315                                case U.matchArgs(domTy, [e1', e2'], [ty1, ty2])                                case Unify.matchArgs(domTy, [e1', e2'], [ty1, ty2])
316                                 of SOME args => (AST.E_Apply(rator, tyArgs, args, rngTy), rngTy)                                 of SOME args => (AST.E_Prim(rator, tyArgs, args, rngTy), rngTy)
317                                  | NONE => err (cxt, [                                  | NONE => err (cxt, [
318                                        S "type error for binary operator '", V rator, S "'\n",                                        S "type error for binary operator ", V rator, S "\n",
319                                        S "  expected:  ", TYS domTy, S "\n",                                        S "  expected:  ", TYS domTy, S "\n",
320                                        S "  but found: ", TYS[ty1, ty2]                                        S "  but found: ", TYS[ty1, ty2]
321                                      ])                                      ])
# Line 157  Line 327 
327                        (* end case *))                        (* end case *))
328                  end                  end
329              | PT.E_UnaryOp(rator, e) => let              | PT.E_UnaryOp(rator, e) => let
330                  val (e', ty) = checkExpr(env, cxt, e)                  val eTy = check(env, cxt, e)
331                  in                  in
332                    case Env.findFunc (#env env, rator)                    case Env.findFunc (env, rator)
333                     of Env.PrimFun[rator] => let                     of Env.PrimFun[rator] => let
334                          val (tyArgs, Ty.T_Fun([domTy], rngTy)) = U.instantiate(Var.typeOf rator)                          val (tyArgs, Ty.T_Fun([domTy], rngTy)) = TU.instantiate(Var.typeOf rator)
335                          in                          in
336                            case coerceType (domTy, ty, e')                            case Util.coerceType (domTy, eTy)
337                             of SOME e' => (AST.E_Apply(rator, tyArgs, [e'], rngTy), rngTy)                             of SOME e' => (AST.E_Prim(rator, tyArgs, [e'], rngTy), rngTy)
338                              | NONE => err (cxt, [                              | NONE => err (cxt, [
339                                    S "type error for unary operator \"", V rator, S "\"\n",                                    S "type error for unary operator ", V rator, S "\n",
340                                    S "  expected:  ", TY domTy, S "\n",                                    S "  expected:  ", TY domTy, S "\n",
341                                    S "  but found: ", TY ty                                    S "  but found: ", TY (#2 eTy)
342                                  ])                                  ])
343                            (* end case *)                            (* end case *)
344                          end                          end
345                      | Env.PrimFun ovldList => resolveOverload (cxt, rator, [ty], [e'], ovldList)                      | Env.PrimFun ovldList => resolveOverload (cxt, rator, [#2 eTy], [#1 eTy], ovldList)
346                      | _ => raise Fail "impossible"                      | _ => raise Fail "impossible"
347                    (* end case *)                    (* end case *)
348                  end                  end
349              | PT.E_Apply(e, args) => raise Fail "FIXME"              | PT.E_Apply(e, args) => let
350                    fun stripMark (_, PT.E_Mark{span, tree}) = stripMark(span, tree)
351                      | stripMark (span, e) = (span, e)
352                    val (args, tys) = checkList (env, cxt, args)
353                    fun appTyError (f, paramTys, argTys) = err(cxt, [
354                            S "type error in application of ", V f, S "\n",
355                            S "  expected:  ", TYS paramTys, S "\n",
356                            S "  but found: ", TYS argTys
357                          ])
358                    fun checkPrimApp f = if Var.isPrim f
359                          then (case TU.instantiate(Var.typeOf f)
360                             of (tyArgs, Ty.T_Fun(domTy, rngTy)) => (
361                                  case Unify.matchArgs (domTy, args, tys)
362                                   of SOME args => (AST.E_Prim(f, tyArgs, args, rngTy), rngTy)
363                                    | NONE => appTyError (f, domTy, tys)
364                                  (* end case *))
365                              | _ => err(cxt, [S "application of non-function/field ", V f])
366                            (* end case *))
367                          else raise Fail "unexpected user function"
368                  (* check the application of a user-defined function *)
369                    fun checkFunApp (cxt, f) = if Var.isPrim f
370                          then raise Fail "unexpected primitive function"
371                          else (case Var.monoTypeOf f
372                             of Ty.T_Fun(domTy, rngTy) => (
373                                  case Unify.matchArgs (domTy, args, tys)
374                                   of SOME args => (AST.E_Apply(useVar(cxt, f), args, rngTy), rngTy)
375                                    | NONE => appTyError (f, domTy, tys)
376                                  (* end case *))
377                              | _ => err(cxt, [S "application of non-function/field ", V f])
378                            (* end case *))
379                    fun checkFieldApp (e1', ty1) = (case (args, tys)
380                           of ([e2'], [ty2]) => let
381                                val (tyArgs, Ty.T_Fun([fldTy, domTy], rngTy)) =
382                                      TU.instantiate(Var.typeOf BV.op_probe)
383                                fun tyError () = err (cxt, [
384                                        S "type error for field application\n",
385                                        S "  expected:  ", TYS[fldTy, domTy], S "\n",
386                                        S "  but found: ", TYS[ty1, ty2]
387                                      ])
388                                in
389                                  if Unify.equalType(fldTy, ty1)
390                                    then (case Util.coerceType(domTy, (e2', ty2))
391                                       of SOME e2' => (AST.E_Prim(BV.op_probe, tyArgs, [e1', e2'], rngTy), rngTy)
392                                        | NONE => tyError()
393                                      (* end case *))
394                                    else tyError()
395                                end
396                            | _ => err(cxt, [S "badly formed field application"])
397                          (* end case *))
398                    in
399                      case stripMark(#2 cxt, e)
400                       of (span, PT.E_Var f) => (case Env.findVar (env, f)
401                             of SOME f' => checkFieldApp (
402                                  AST.E_Var(useVar((#1 cxt, span), f')),
403                                  Var.monoTypeOf f')
404                              | NONE => (case Env.findFunc (env, f)
405                                   of Env.PrimFun[] => err(cxt, [S "unknown function ", A f])
406                                    | Env.PrimFun[f'] => checkPrimApp f'
407                                    | Env.PrimFun ovldList =>
408                                        resolveOverload ((#1 cxt, span), f, tys, args, ovldList)
409                                    | Env.UserFun f' => checkFunApp((#1 cxt, span), f')
410                                  (* end case *))
411                              (* end case *))
412                        | _ => checkFieldApp (check (env, cxt, e))
413                      (* end case *)
414                    end
415              | PT.E_Subscript(e, indices) => (case (check(env, cxt, e), indices)              | PT.E_Subscript(e, indices) => (case (check(env, cxt, e), indices)
416                   of ((e', Ty.T_Sequence(elemTy, _)), [SOME e2]) => raise Fail "FIXME"                   of ((e', Ty.T_Sequence(elemTy, _)), [SOME e2]) => raise Fail "FIXME"
417                    | ((e', Ty.T_Tensor shape), _) => raise Fail "FIXME"                    | ((e', Ty.T_Tensor shape), _) => raise Fail "FIXME"
# Line 186  Line 421 
421                        ])                        ])
422                  (* end case *))                  (* end case *))
423              | PT.E_Select(e, field) => (case check(env, cxt, e)              | PT.E_Select(e, field) => (case check(env, cxt, e)
424                   of (e', Ty.T_Strand strand) => (case Env.findStrand(#env env, strand)                   of (e', Ty.T_Named strand) => (case Env.findStrand(env, strand)
425                         of SOME(AST.Strand{name, state, ...}) => let                         of SOME sEnv => (case StrandEnv.findStateVar(sEnv, field)
426                              fun isField (AST.VD_Decl(AST.V{name, ...}, _)) = Atom.same(name, field)                               of SOME x' => let
                             in  
                               case List.find isField state  
                                of SOME(AST.VD_Decl(x', _)) => let  
427                                      val ty = Var.monoTypeOf x'                                      val ty = Var.monoTypeOf x'
428                                      in                                      in
429                                        (AST.E_Selector(e', field, ty), ty)                                      (AST.E_Select(e', useVar(cxt, x')), ty)
430                                      end                                      end
431                                  | NONE => err(cxt, [                                  | NONE => err(cxt, [
432                                        S "strand ", A name,                                      S "strand ", A strand,
433                                        S " does not have state variable ", A field                                        S " does not have state variable ", A field
434                                      ])                                      ])
435                                (* end case *)                              (* end case *))
                             end  
436                          | NONE => err(cxt, [S "unknown strand ", A strand])                          | NONE => err(cxt, [S "unknown strand ", A strand])
437                        (* end case *))                        (* end case *))
438                    | (_, ty) => err (cxt, [                    | (_, ty) => err (cxt, [
# Line 211  Line 442 
442                  (* end case *))                  (* end case *))
443              | PT.E_Real e => (case check (env, cxt, e)              | PT.E_Real e => (case check (env, cxt, e)
444                   of (e', Ty.T_Int) =>                   of (e', Ty.T_Int) =>
445                        (AST.E_Apply(BV.i2r, [], [e'], Ty.realTy), Ty.realTy)                        (AST.E_Prim(BV.i2r, [], [e'], Ty.realTy), Ty.realTy)
446                    | (_, ty) => err(cxt, [                    | (_, ty) => err(cxt, [
447                          S "argument of 'real' must have type 'int', but found ",                          S "argument of 'real' must have type 'int', but found ",
448                          TY ty                          TY ty
449                        ])                        ])
450                  (* end case *))                  (* end case *))
451              | PT.E_Load nrrd => let              | PT.E_Load nrrd => let
452                  val (tyArgs, Ty.T_Fun(_, rngTy)) = Util.instantiate(Var.typeOf(BV.fn_image))                  val (tyArgs, Ty.T_Fun(_, rngTy)) = TU.instantiate(Var.typeOf(BV.fn_load))
453                  in                  in
454                    (AST.E_LoadNrrd(tyArgs, nrrd, rngTy), rngTy)                    case chkStringConstExpr (env, cxt, nrrd)
455                       of SOME nrrd => (AST.E_LoadNrrd(tyArgs, nrrd, rngTy), rngTy)
456                        | NONE => (bogusExp, rngTy)
457                      (* end case *)
458                  end                  end
459              | PT.E_Image nrrd => let              | PT.E_Image nrrd => let
460                  val (tyArgs, Ty.T_Fun(_, rngTy)) = Util.instantiate(Var.typeOf(BV.fn_load))                  val (tyArgs, Ty.T_Fun(_, rngTy)) = TU.instantiate(Var.typeOf(BV.fn_image))
461                  in                  in
462                    (AST.E_LoadNrrd(tyArgs, nrrd, rngTy), rngTy)                    case chkStringConstExpr (env, cxt, nrrd)
463                       of SOME nrrd => (AST.E_LoadNrrd(tyArgs, nrrd, rngTy), rngTy)
464                        | NONE => (bogusExp, rngTy)
465                      (* end case *)
466                  end                  end
467              | PT.E_Var x => (case E.findVar (#env env, x)              | PT.E_Var x => (case E.findVar (env, x)
468                   of SOME x' => (                   of SOME x' => (AST.E_Var(useVar(cxt, x')), Var.monoTypeOf x')
                       markUsed (x', true);  
                       (AST.E_Var x', Var.monoTypeOf x'))  
469                    | NONE => err(cxt, [S "undeclared variable ", A x])                    | NONE => err(cxt, [S "undeclared variable ", A x])
470                  (* end case *))                  (* end case *))
471              | PT.E_Kernel(kern, dim) => raise Fail "FIXME"              | PT.E_Kernel(kern, dim) => raise Fail "FIXME"
472              | PT.E_Lit lit => checkLit lit              | PT.E_Lit lit => checkLit lit
473              | PT.E_Id d => let              | PT.E_Id d => let
474                  val (tyArgs, Ty.T_Fun(_, rngTy)) =                  val (tyArgs, Ty.T_Fun(_, rngTy)) =
475                        Util.instantiate(Var.typeOf(BV.identity))                        TU.instantiate(Var.typeOf(BV.identity))
476                  in                  in
477                    if U.equalType(Ty.T_Tensor(checkShape(cxt, [d,d])), rngTy)                    if Unify.equalType(Ty.T_Tensor(checkShape(env, cxt, [d, d])), rngTy)
478                      then (AST.E_Apply(BV.identity, tyArgs, [], rngTy), rngTy)                      then (AST.E_Prim(BV.identity, tyArgs, [], rngTy), rngTy)
479                      else raise Fail "impossible"                      else raise Fail "impossible"
480                  end                  end
481              | PT.E_Zero dd => let              | PT.E_Zero dd => let
482                  val (tyArgs, Ty.T_Fun(_, rngTy)) =                  val (tyArgs, Ty.T_Fun(_, rngTy)) =
483                        Util.instantiate(Var.typeOf(BV.zero))                        TU.instantiate(Var.typeOf(BV.zero))
484                  in                  in
485                    if U.equalType(Ty.T_Tensor(checkShape(cxt, dd)), rngTy)                    if Unify.equalType(Ty.T_Tensor(checkShape(env, cxt, dd)), rngTy)
486                      then (AST.E_Apply(BV.zero, tyArgs, [], rngTy), rngTy)                      then (AST.E_Prim(BV.zero, tyArgs, [], rngTy), rngTy)
487                      else raise Fail "impossible"                      else raise Fail "impossible"
488                  end                  end
489              | PT.E_NaN dd => let              | PT.E_NaN dd => let
490                  val (tyArgs, Ty.T_Fun(_, rngTy)) =                  val (tyArgs, Ty.T_Fun(_, rngTy)) =
491                        Util.instantiate(Var.typeOf(BV.nan))                        TU.instantiate(Var.typeOf(BV.nan))
492                  in                  in
493                    if U.equalType(Ty.T_Tensor(checkShape(cxt, dd)), rngTy)                    if Unify.equalType(Ty.T_Tensor(checkShape(env, cxt, dd)), rngTy)
494                      then (AST.E_Apply(BV.nan, tyArgs, [], rngTy), rngTy)                      then (AST.E_Prim(BV.nan, tyArgs, [], rngTy), rngTy)
495                      else raise Fail "impossible"                      else raise Fail "impossible"
496                  end                  end
497              | PT.E_Sequence exps => raise Fail "FIXME"              | PT.E_Sequence exps => raise Fail "FIXME"
# Line 269  Line 504 
504                         of NONE => Ty.T_Error                         of NONE => Ty.T_Error
505                          | SOME ty => ty                          | SOME ty => ty
506                        (* end case *))                        (* end case *))
507                  in                (* process the arguments checking that they all have the expected type *)
508                    case realType(TU.pruneHead ty)                  fun chkArgs (ty, shape) = let
                    of ty as Ty.T_Tensor shape => let  
509                          val Ty.Shape dd = TU.pruneShape shape (* NOTE: this may fail if we allow user polymorphism *)                          val Ty.Shape dd = TU.pruneShape shape (* NOTE: this may fail if we allow user polymorphism *)
510                          val resTy = Ty.T_Tensor(Ty.Shape(Ty.DimConst(List.length args) :: dd))                          val resTy = Ty.T_Tensor(Ty.Shape(Ty.DimConst(List.length args) :: dd))
511                          fun chkArgs (arg::args, argTy::tys, args') = (case coerceType(ty, argTy, arg)                        fun chkArgs (arg::args, argTy::tys, args') = (
512                                case Util.coerceType(ty, (arg, argTy))
513                                 of SOME arg' => chkArgs (args, tys, arg'::args')                                 of SOME arg' => chkArgs (args, tys, arg'::args')
514                                  | NONE => (                                  | NONE => (
515                                      TypeError.error(cxt, [                                      TypeError.error(cxt, [
516                                          S "arguments of tensor construction must have same type"                                          S "arguments of tensor construction must have same type"
517                                        ]);                                        ]);
518                                      ??)                                    chkArgs (args, tys, bogusExp::args'))
519                                (* end case *))                                (* end case *))
520                            | chkArgs ([], [], args') = (AST.E_Cons(List.rev args', resTy), resTy)                          | chkArgs (_, _, args') = (AST.E_Tensor(List.rev args', resTy), resTy)
521                          in                          in
522                            chkArgs (args, tys, [])                            chkArgs (args, tys, [])
523                          end                          end
524                    in
525                      case TU.pruneHead ty
526                       of Ty.T_Int => chkArgs(Ty.realTy, Ty.Shape[]) (* coerce integers to reals *)
527                        | ty as Ty.T_Tensor shape => chkArgs(ty, shape)
528                      | _ => err(cxt, [S "Invalid argument type for tensor construction"])                      | _ => err(cxt, [S "Invalid argument type for tensor construction"])
529                    (* end case *)                    (* end case *)
530                  end                  end
531              | PT.E_Deprecate(msg, e) => (              | PT.E_Deprecate(msg, e) => (
532                  warn (cxt, [S msg]);                  warn (cxt, [S msg]);
533                  chk (env, cxt, e))                  check (env, cxt, e))
534            (* end case *))            (* end case *))
535    
536    (* check a conditional operator (e.g., || or &&) *)    (* check a conditional operator (e.g., || or &&) *)
# Line 299  Line 538 
538            case (check(env, cxt, e1), check(env, cxt, e2))            case (check(env, cxt, e1), check(env, cxt, e2))
539             of ((e1', Ty.T_Bool), (e2', Ty.T_Bool)) => (mk(e1', e2'), Ty.T_Bool)             of ((e1', Ty.T_Bool), (e2', Ty.T_Bool)) => (mk(e1', e2'), Ty.T_Bool)
540              | ((_, Ty.T_Bool), (_, ty2)) =>              | ((_, Ty.T_Bool), (_, ty2)) =>
541                  err (cxt, [S "expected type 'bool' on rhs of '", S rator, "', but found ", TY ty2])                  err (cxt, [S "expected type 'bool' on rhs of '", S rator, S "', but found ", TY ty2])
542              | ((_, ty1), (_, Ty.T_Bool)) =>              | ((_, ty1), (_, Ty.T_Bool)) =>
543                  err (cxt, [S "expected type 'bool' on lhs of '", S rator, "', but found ", TY ty1])                  err (cxt, [S "expected type 'bool' on lhs of '", S rator, S "', but found ", TY ty1])
544              | ((_, ty1), (_, ty2)) => err (cxt, [              | ((_, ty1), (_, ty2)) => err (cxt, [
545                    S "arguments of '", S rator, "' must have type 'bool', but found ",                    S "arguments of '", S rator, S "' must have type 'bool', but found ",
546                    TY ty1, S " and ", TY ty2                    TY ty1, S " and ", TY ty2
547                  ])                  ])
548            (* end case *))            (* end case *))
# Line 313  Line 552 
552     *)     *)
553      and checkList (env, cxt, exprs) = let      and checkList (env, cxt, exprs) = let
554            fun chk (e, (es, tys)) = let            fun chk (e, (es, tys)) = let
555                  val (e, ty) = checkExpr (env, cxt, e)                  val (e, ty) = check (env, cxt, e)
556                  in                  in
557                    (e::es, ty::tys)                    (e::es, ty::tys)
558                  end                  end
# Line 321  Line 560 
560              List.foldr chk ([], []) exprs              List.foldr chk ([], []) exprs
561            end            end
562    
563      (* check a string that is specified as a constant expression *)
564        and chkStringConstExpr (env, cxt, PT.E_Mark m) =
565              chkStringConstExpr (E.withEnvAndContext (env, cxt, m))
566          | chkStringConstExpr (env, cxt, e) = (case check (env, cxt, e)
567               of (e', Ty.T_String) => (case ConstExpr.eval (cxt, e')
568                     of SOME(ConstExpr.String s) => SOME s
569                      | SOME(ConstExpr.Expr e) => raise Fail "FIXME"
570                      | NONE => NONE
571                      | _ => raise Fail "impossible: wrong type for constant expr"
572                    (* end case *))
573                | (_, ty) => (
574                    TypeError.error (cxt, [
575                        S "expected constant expression of type 'string', but found '",
576                        TY ty, S "'"
577                      ]);
578                    NONE)
579              (* end case *))
580    
581      (* check a dimension that is given by a constant expression *)
582        and checkDim (env, cxt, dim) = (case check (env, cxt, dim)
583               of (e', Ty.T_Int) => (case ConstExpr.eval (cxt, e')
584                     of SOME(ConstExpr.Int d) => SOME d
585                      | SOME(ConstExpr.Expr e) => (
586                          TypeError.error (cxt, [S "unable to evaluate constant dimension expression"]);
587                          NONE)
588                      | NONE => NONE
589                      | _ => raise Fail "impossible: wrong type for constant expr"
590                    (* end case *))
591                | (_, ty) => (
592                    TypeError.error (cxt, [
593                        S "expected constant expression of type 'int', but found '",
594                        TY ty, S "'"
595                      ]);
596                    NONE)
597              (* end case *))
598    
599      (* check a tensor shape, where the dimensions are given by constant expressions *)
600        and checkShape (env, cxt, shape) = let
601              fun checkDim' e = (case checkDim (env, cxt, e)
602                     of SOME d => (
603                          if (d <= 1)
604                            then TypeError.error (cxt, [
605                                S "invalid tensor-shape dimension; must be > 1, but found ",
606                                S (IntLit.toString d)
607                              ])
608                            else ();
609                          Ty.DimConst(IntInf.toInt d))
610                      | NONE => Ty.DimConst ~1
611                    (* end case *))
612              in
613                Ty.Shape(List.map checkDim' shape)
614              end
615    
616    end    end

Legend:
Removed from v.3398  
changed lines
  Added in v.3418

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0