Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] Diff of /branches/vis15/src/compiler/typechecker/check-expr.sml
ViewVC logotype

Diff of /branches/vis15/src/compiler/typechecker/check-expr.sml

Parent Directory Parent Directory | Revision Log Revision Log | View Patch Patch

revision 3402, Wed Nov 11 02:54:23 2015 UTC revision 3407, Wed Nov 11 18:53:18 2015 UTC
# Line 10  Line 10 
10    
11  structure CheckExpr : sig  structure CheckExpr : sig
12    
13      val check : Env.env * Env.context * ParseTree.expr -> (AST.expr * Types.ty)    (* type check an expression *)
14        val check : Env.t * Env.context * ParseTree.expr -> (AST.expr * Types.ty)
15    
16      (* `resolveOverload (cxt, rator, tys, args, candidates)` resolves the application of
17       * the overloaded operator `rator` to `args`, where `tys` are the types of the arguments
18       * and `candidates` is the list of candidate definitions.
19       *)
20        val resolveOverload : Env.context * Atom.atom * Types.ty list * AST.expr list * Var.t list
21              -> (AST.expr * Types.ty)
22    
23      (* check a dimension that is given by a constant expression *)
24        val checkDim : Env.t * Env.context * ParseTree.expr -> IntLit.t option
25    
26      (* check a tensor shape, where the dimensions are given by constant expressions *)
27        val checkShape : Env.t * Env.context * ParseTree.expr list -> Types.shape
28    
29    end = struct    end = struct
30    
# Line 19  Line 33 
33      structure E = Env      structure E = Env
34      structure Ty = Types      structure Ty = Types
35      structure BV = BasisVars      structure BV = BasisVars
36        structure TU = TypeUtil
37    
38    (* an expression to return when there is a type error *)    (* an expression to return when there is a type error *)
39      val bogusExp = (AST.E_Lit(L.Int 0), Ty.T_Error)      val bogusExp = AST.E_Lit(L.Int 0)
40        val bogusExpTy = (bogusExp, Ty.T_Error)
41    
42      fun err arg = (TypeError.error arg; bogusExp)      fun err arg = (TypeError.error arg; bogusExpTy)
43      val warn = TypeError.warning      val warn = TypeError.warning
44    
45      datatype token = datatype TypeError.token      datatype token = datatype TypeError.token
46    
47      (* mark a variable use with its location *)
48        fun useVar (cxt, x) = (x, Error.location cxt)
49    
50      (* resolve overloading: we use a simple scheme that selects the first operator in the
51       * list that matches the argument types.
52       *)
53        fun resolveOverload (_, rator, _, _, []) = raise Fail(concat[
54                "resolveOverload: \"", Atom.toString rator, "\" has no candidates"
55              ])
56          | resolveOverload (cxt, rator, argTys, args, candidates) = let
57    (* FIXME: we could be more efficient by just checking for coercion matchs the first pass
58     * and remembering those that are not pure EQ matches.
59     *)
60            (* try to match candidates while allowing type coercions *)
61              fun tryMatchCandidates [] = err(cxt, [
62                      S "unable to resolve overloaded operator ", A rator, S "\n",
63                      S "  argument type is: ", TYS argTys, S "\n"
64                    ])
65                | tryMatchCandidates (x::xs) = let
66                    val (tyArgs, Ty.T_Fun(domTy, rngTy)) = TU.instantiate(Var.typeOf x)
67                    in
68                      case Unify.tryMatchArgs (domTy, args, argTys)
69                       of SOME args' => (AST.E_Prim(x, tyArgs, args', rngTy), rngTy)
70                        | NONE => tryMatchCandidates xs
71                      (* end case *)
72                    end
73              fun tryCandidates [] = tryMatchCandidates candidates
74                | tryCandidates (x::xs) = let
75                    val (tyArgs, Ty.T_Fun(domTy, rngTy)) = TU.instantiate(Var.typeOf x)
76                    in
77                      if Unify.tryEqualTypes(domTy, argTys)
78                        then (AST.E_Prim(x, tyArgs, args, rngTy), rngTy)
79                        else tryCandidates xs
80                    end
81              in
82                tryCandidates candidates
83              end
84    
85    (* check the type of a literal *)    (* check the type of a literal *)
86      fun checkLit lit = (case lit      fun checkLit lit = (case lit
87             of (L.Int _) => (AST.E_Lit lit, Ty.T_Int)             of (L.Int _) => (AST.E_Lit lit, Ty.T_Int)
# Line 36  Line 90 
90              | (L.Bool _) => (AST.E_Lit lit, Ty.T_Bool)              | (L.Bool _) => (AST.E_Lit lit, Ty.T_Bool)
91            (* end case *))            (* end case *))
92    
93      (* type check a dot product, which has the constraint:
94       *     ALL[sigma1, d1, sigma2] . tensor[sigma1, d1] * tensor[d1, sigma2] -> tensor[sigma1, sigma2]
95       * and similarly for fields.
96       *)
97        fun chkInnerProduct (cxt, e1, ty1, e2, ty2) = let
98            (* check the shape of the two arguments to verify that the inner constraint matches *)
99              fun chkShape (Ty.Shape(dd1 as _::_), Ty.Shape(d2::dd2)) = let
100                    val (dd1, d1) = let
101                          fun splitLast (prefix, [d]) = (List.rev prefix, d)
102                            | splitLast (prefix, d::dd) = splitLast (d::prefix, dd)
103                            | splitLast (_, []) = raise Fail "impossible"
104                          in
105                            splitLast ([], dd1)
106                          end
107                    in
108                      if Unify.equalDim(d1, d2)
109                        then SOME(Ty.Shape(dd1@dd2))
110                        else NONE
111                    end
112                | chkShape _ = NONE
113              fun error () = err (cxt, [
114                      S "type error for arguments of binary operator '•'\n",
115                      S "  found: ", TYS[ty1, ty2], S "\n"
116                    ])
117              in
118                case (TU.prune ty1, TU.prune ty2)
119                (* tensor * tensor inner product *)
120                 of (Ty.T_Tensor s1, Ty.T_Tensor s2) => (case chkShape(s1, s2)
121                       of SOME shp => let
122                            val (tyArgs, Ty.T_Fun(domTy, rngTy)) = TU.instantiate(Var.typeOf BV.op_inner_tt)
123                            val resTy = Ty.T_Tensor shp
124                            in
125                              if Unify.equalTypes(domTy, [ty1, ty2]) andalso Unify.equalType(rngTy, resTy)
126                                then (AST.E_Prim(BV.op_inner_tt, tyArgs, [e1, e2], rngTy), rngTy)
127                                else error()
128                            end
129                        | NONE => error()
130                      (* end case *))
131                (* tensor * field inner product *)
132                  | (Ty.T_Tensor s1, Ty.T_Field{diff, dim, shape=s2}) => (case chkShape(s1, s2)
133                       of SOME shp => let
134                            val (tyArgs, Ty.T_Fun(domTy, rngTy)) = TU.instantiate(Var.typeOf BV.op_inner_tf)
135                            val resTy = Ty.T_Field{diff=diff, dim=dim, shape=shp}
136                            in
137                              if Unify.equalTypes(domTy, [ty1, ty2])
138                              andalso Unify.equalType(rngTy, resTy)
139                                then (AST.E_Prim(BV.op_inner_tf, tyArgs, [e1, e2], rngTy), rngTy)
140                                else error()
141                            end
142                        | NONE => error()
143                      (* end case *))
144                (* field * tensor inner product *)
145                  | (Ty.T_Field{diff, dim, shape=s1}, Ty.T_Tensor s2) => (case chkShape(s1, s2)
146                       of SOME shp => let
147                            val (tyArgs, Ty.T_Fun(domTy, rngTy)) = TU.instantiate(Var.typeOf BV.op_inner_ft)
148                            val resTy = Ty.T_Field{diff=diff, dim=dim, shape=shp}
149                            in
150                              if Unify.equalTypes(domTy, [ty1, ty2])
151                              andalso Unify.equalType(rngTy, resTy)
152                                then (AST.E_Prim(BV.op_inner_ft, tyArgs, [e1, e2], rngTy), rngTy)
153                                else error()
154                            end
155                        | NONE => error()
156                      (* end case *))
157                (* field * field inner product *)
158                  | (Ty.T_Field{diff=k1, dim=dim1, shape=s1}, Ty.T_Field{diff=k2, dim=dim2, shape=s2}) => (
159                      case chkShape(s1, s2)
160                       of SOME shp => let
161                            val (tyArgs, Ty.T_Fun(domTy, rngTy)) = TU.instantiate(Var.typeOf BV.op_inner_ff)
162                            val resTy = Ty.T_Field{diff=k1, dim=dim1, shape=shp}
163                            in
164    (* FIXME: the resulting differentiation should be the minimum of k1 and k2 *)
165                              if Unify.equalDim(dim1, dim2)
166                              andalso Unify.equalTypes(domTy, [ty1, ty2])
167                              andalso Unify.equalType(rngTy, resTy)
168                                then (AST.E_Prim(BV.op_inner_ff, tyArgs, [e1, e2], rngTy), rngTy)
169                                else error()
170                            end
171                        | NONE => error()
172                      (* end case *))
173                  | (ty1, ty2) => error()
174                (* end case *)
175              end
176    
177      (* type check a colon product, which has the constraint:
178       *     ALL[sigma1, d1, d2, sigma2] . tensor[sigma1, d1, d2] * tensor[d2, d1, sigma2] -> tensor[sigma1, sigma2]
179       * and similarly for fields.
180       *)
181        fun chkColonProduct (cxt, e1, ty1, e2, ty2) = let
182            (* check the shape of the two arguments to verify that the inner constraint matches *)
183              fun chkShape (Ty.Shape(dd1 as _::_::_), Ty.Shape(d21::d22::dd2)) = let
184                    val (dd1, d11, d12) = let
185                          fun splitLast2 (prefix, [d1, d2]) = (List.rev prefix, d1, d2)
186                            | splitLast2 (prefix, d::dd) = splitLast2 (d::prefix, dd)
187                            | splitLast2 (_, []) = raise Fail "impossible"
188                          in
189                            splitLast2 ([], dd1)
190                          end
191                    in
192                      if Unify.equalDim(d11, d21) andalso Unify.equalDim(d12, d22)
193                        then SOME(Ty.Shape(dd1@dd2))
194                        else NONE
195                    end
196                | chkShape _ = NONE
197              fun error () = err (cxt, [
198                      S "type error for arguments of binary operator \":\"\n",
199                      S "  found: ", TYS[ty1, ty2], S "\n"
200                    ])
201              in
202                case (TU.prune ty1, TU.prune ty2)
203                (* tensor * tensor colon product *)
204                 of (Ty.T_Tensor s1, Ty.T_Tensor s2) => (case chkShape(s1, s2)
205                       of SOME shp => let
206                            val (tyArgs, Ty.T_Fun(domTy, rngTy)) = TU.instantiate(Var.typeOf BV.op_colon_tt)
207                            val resTy = Ty.T_Tensor shp
208                            in
209                              if Unify.equalTypes(domTy, [ty1, ty2])
210                              andalso Unify.equalType(rngTy, resTy)
211                                then (AST.E_Prim(BV.op_colon_tt, tyArgs, [e1, e2], rngTy), rngTy)
212                                else error()
213                            end
214                        | NONE => error()
215                      (* end case *))
216                (* field * tensor colon product *)
217                  | (Ty.T_Field{diff, dim, shape=s1}, Ty.T_Tensor s2) => (case chkShape(s1, s2)
218                       of SOME shp => let
219                            val (tyArgs, Ty.T_Fun(domTy, rngTy)) = TU.instantiate(Var.typeOf BV.op_colon_ft)
220                            val resTy = Ty.T_Field{diff=diff, dim=dim, shape=shp}
221                            in
222                              if Unify.equalTypes(domTy, [ty1, ty2]) andalso Unify.equalType(rngTy, resTy)
223                                then (AST.E_Prim(BV.op_colon_ft, tyArgs, [e1, e2], rngTy), rngTy)
224                                else error()
225                            end
226                        | NONE => error()
227                      (* end case *))
228                (* tensor * field colon product *)
229                  | (Ty.T_Tensor s1, Ty.T_Field{diff=diff, dim=dim, shape=s2}) => (case chkShape(s1, s2)
230                       of SOME shp => let
231                            val (tyArgs, Ty.T_Fun(domTy, rngTy)) = TU.instantiate(Var.typeOf BV.op_colon_tf)
232                            val resTy = Ty.T_Field{diff=diff, dim=dim, shape=shp}
233                            in
234                              if Unify.equalTypes(domTy, [ty1, ty2]) andalso Unify.equalType(rngTy, resTy)
235                                then (AST.E_Prim(BV.op_colon_tf, tyArgs, [e1, e2], rngTy), rngTy)
236                                else error()
237                            end
238                        | NONE => error()
239                      (* end case *))
240                (* field * field colon product *)
241                  | (Ty.T_Field{diff=k1, dim=dim1, shape=s1}, Ty.T_Field{diff=k2, dim=dim2, shape=s2}) => (
242                      case chkShape(s1, s2)
243                       of SOME shp => let
244                            val (tyArgs, Ty.T_Fun(domTy, rngTy)) = TU.instantiate(Var.typeOf BV.op_colon_ff)
245                            val resTy = Ty.T_Field{diff=k1, dim=dim1, shape=shp}
246                            in
247    (* FIXME: the resulting differentiation should be the minimum of k1 and k2 *)
248                              if Unify.equalDim(dim1, dim2)
249                              andalso Unify.equalTypes(domTy, [ty1, ty2])
250                              andalso Unify.equalType(rngTy, resTy)
251                                then (AST.E_Prim(BV.op_colon_ff, tyArgs, [e1, e2], rngTy), rngTy)
252                                else error()
253                            end
254                        | NONE => error()
255                      (* end case *))
256                  | (ty1, ty2) => error()
257                (* end case *)
258              end
259    
260    (* check the type of an expression *)    (* check the type of an expression *)
261      fun check (env, cxt, e) = (case e      fun check (env, cxt, e) = (case e
262             of PT.E_Mark m => check (withEnvAndContext (env, cxt, m))             of PT.E_Mark m => check (E.withEnvAndContext (env, cxt, m))
263              | PT.E_Cond(e1, cond, e2) => let              | PT.E_Cond(e1, cond, e2) => let
264                  val eTy1 = check (env, cxt, e1)                  val eTy1 = check (env, cxt, e1)
265                  val eTy2 = check (env, cxt, e2)                  val eTy2 = check (env, cxt, e2)
266                  in                  in
267                    case check(env, cxt, cond)                    case check(env, cxt, cond)
268                     of (cond', Ty.T_Bool) => (case Util.coerceType2(eTy1, eTy2)                     of (cond', Ty.T_Bool) => (case Util.coerceType2(eTy1, eTy2)
269                           of SOME(e1, e2, ty) => (AST.E_Cond(cond', e1', e2', ty), ty)                           of SOME(e1', e2', ty) => (AST.E_Cond(cond', e1', e2', ty), ty)
270                            | NONE => err (cxt, [                            | NONE => err (cxt, [
271                                S "types do not match in conditional expression\n",                                S "types do not match in conditional expression\n",
272                                S "  true branch:  ", TY(#2 eTy1), S "\n",                                S "  true branch:  ", TY(#2 eTy1), S "\n",
# Line 59  Line 280 
280                   of ((e1', Ty.T_Int), (e2', Ty.T_Int)) => let                   of ((e1', Ty.T_Int), (e2', Ty.T_Int)) => let
281                        val resTy = Ty.T_Sequence(Ty.T_Int, NONE)                        val resTy = Ty.T_Sequence(Ty.T_Int, NONE)
282                        in                        in
283                          (AST.E_Apply(BV.range, [], [e1', e2'], resTy), resTy)                          (AST.E_Prim(BV.range, [], [e1', e2'], resTy), resTy)
284                        end                        end
285                    | ((_, Ty.T_Int), (_, ty2)) =>                    | ((_, Ty.T_Int), (_, ty2)) =>
286                        err (cxt, [S "expected type 'int' on rhs of '..', but found ", TY ty2])                        err (cxt, [S "expected type 'int' on rhs of '..', but found ", TY ty2])
# Line 81  Line 302 
302                  val (e2', ty2) = check (env, cxt, e2)                  val (e2', ty2) = check (env, cxt, e2)
303                  in                  in
304                    if Atom.same(rator, BasisNames.op_dot)                    if Atom.same(rator, BasisNames.op_dot)
305                    (* we have to handle inner product as a special case, because our type                      then chkInnerProduct (cxt, e1', ty1, e2', ty2)
                    * system cannot express the constraint that the type is  
                    *     ALL[sigma1, d1, sigma2] . tensor[sigma1, d1] * tensor[d1, sigma2] -> tensor[sigma1, sigma2]  
                    *)  
                     then (case (TU.prune ty1, TU.prune ty2)  
                        of (Ty.T_Tensor(s1 as Ty.Shape(dd1 as _::_)), Ty.T_Tensor(s2 as Ty.Shape(d2::dd2))) => let  
                             val (dd1, d1) = let  
                                   fun splitLast (prefix, [d]) = (List.rev prefix, d)  
                                     | splitLast (prefix, d::dd) = splitLast (d::prefix, dd)  
                                     | splitLast (_, []) = raise Fail "impossible"  
                                   in  
                                     splitLast ([], dd1)  
                                   end  
                             val (tyArgs, Ty.T_Fun(domTy, rngTy)) = Util.instantiate(Var.typeOf BV.op_inner)  
                             val resTy = Ty.T_Tensor(Ty.Shape(dd1@dd2))  
                             in  
                               if Unify.equalDim(d1, d2)  
                               andalso Unify.equalTypes(domTy, [ty1, ty2])  
                               andalso Unify.equalType(rngTy, resTy)  
                                 then (AST.E_Apply(BV.op_inner, tyArgs, [e1', e2'], rngTy), rngTy)  
                                 else err (cxt, [  
                                     S "type error for arguments of binary operator '•'\n",  
                                     S "  found: ", TYS[ty1, ty2], S "\n"  
                                   ])  
                             end  
                        | (ty1, ty2) => err (cxt, [  
                               S "type error for arguments of binary operator '•'\n",  
                               S "  found: ", TYS[ty1, ty2], S "\n"  
                             ])  
                       (* end case *))  
306                    else if Atom.same(rator, BasisNames.op_colon)                    else if Atom.same(rator, BasisNames.op_colon)
307                      then (case (TU.prune ty1, TU.prune ty2)                      then chkColonProduct (cxt, e1', ty1, e2', ty2)
308                         of (Ty.T_Tensor(s1 as Ty.Shape(dd1 as _::_::_)), Ty.T_Tensor(s2 as Ty.Shape(d21::d22::dd2))) => let                      else (case Env.findFunc (env, rator)
                             val (dd1, d11, d12) = let  
                                   fun splitLast (prefix, [d1, d2]) = (List.rev prefix, d1, d2)  
                                     | splitLast (prefix, d::dd) = splitLast (d::prefix, dd)  
                                     | splitLast (_, []) = raise Fail "impossible"  
                                   in  
                                     splitLast ([], dd1)  
                                   end  
                             val (tyArgs, Ty.T_Fun(domTy, rngTy)) = Util.instantiate(Var.typeOf BV.op_colon)  
                             val resTy = Ty.T_Tensor(Ty.Shape(dd1@dd2))  
                             in  
                               if Unify.equalDim(d11, d21) andalso Unify.equalDim(d12, d22)  
                               andalso Unify.equalTypes(domTy, [ty1, ty2])  
                               andalso Unify.equalType(rngTy, resTy)  
                                 then (AST.E_Apply(BV.op_colon, tyArgs, [e1', e2'], rngTy), rngTy)  
                                 else err (cxt, [  
                                     S "type error for arguments of binary operator ':'\n",  
                                     S "  found: ", TYS[ty1, ty2], S "\n"  
                                   ])  
                             end  
                        | (ty1, ty2) => err (cxt, [  
                               S "type error for arguments of binary operator ':'\n",  
                               S "  found: ", TYS[ty1, ty2], S "\n"  
                             ])  
                       (* end case *))  
                     else (case Env.findFunc (#env env, rator)  
309                         of Env.PrimFun[rator] => let                         of Env.PrimFun[rator] => let
310                              val (tyArgs, Ty.T_Fun(domTy, rngTy)) = Util.instantiate(Var.typeOf rator)                              val (tyArgs, Ty.T_Fun(domTy, rngTy)) = TU.instantiate(Var.typeOf rator)
311                              in                              in
312                                case Unify.matchArgs(domTy, [e1', e2'], [ty1, ty2])                                case Unify.matchArgs(domTy, [e1', e2'], [ty1, ty2])
313                                 of SOME args => (AST.E_Apply(rator, tyArgs, args, rngTy), rngTy)                                 of SOME args => (AST.E_Prim(rator, tyArgs, args, rngTy), rngTy)
314                                  | NONE => err (cxt, [                                  | NONE => err (cxt, [
315                                        S "type error for binary operator '", V rator, S "'\n",                                        S "type error for binary operator '", V rator, S "'\n",
316                                        S "  expected:  ", TYS domTy, S "\n",                                        S "  expected:  ", TYS domTy, S "\n",
# Line 157  Line 324 
324                        (* end case *))                        (* end case *))
325                  end                  end
326              | PT.E_UnaryOp(rator, e) => let              | PT.E_UnaryOp(rator, e) => let
327                  val (e', ty) = check(env, cxt, e)                  val eTy = check(env, cxt, e)
328                  in                  in
329                    case Env.findFunc (#env env, rator)                    case Env.findFunc (env, rator)
330                     of Env.PrimFun[rator] => let                     of Env.PrimFun[rator] => let
331                          val (tyArgs, Ty.T_Fun([domTy], rngTy)) = U.instantiate(Var.typeOf rator)                          val (tyArgs, Ty.T_Fun([domTy], rngTy)) = TU.instantiate(Var.typeOf rator)
332                          in                          in
333                            case coerceType (domTy, ty, e')                            case Util.coerceType (domTy, eTy)
334                             of SOME e' => (AST.E_Apply(rator, tyArgs, [e'], rngTy), rngTy)                             of SOME(e', ty) => (AST.E_Prim(rator, tyArgs, [e'], rngTy), rngTy)
335                              | NONE => err (cxt, [                              | NONE => err (cxt, [
336                                    S "type error for unary operator \"", V rator, S "\"\n",                                    S "type error for unary operator \"", V rator, S "\"\n",
337                                    S "  expected:  ", TY domTy, S "\n",                                    S "  expected:  ", TY domTy, S "\n",
338                                    S "  but found: ", TY ty                                    S "  but found: ", TY (#2 eTy)
339                                  ])                                  ])
340                            (* end case *)                            (* end case *)
341                          end                          end
342                      | Env.PrimFun ovldList => resolveOverload (cxt, rator, [ty], [e'], ovldList)                      | Env.PrimFun ovldList => resolveOverload (cxt, rator, [#2 eTy], [#1 eTy], ovldList)
343                      | _ => raise Fail "impossible"                      | _ => raise Fail "impossible"
344                    (* end case *)                    (* end case *)
345                  end                  end
346              | PT.E_Apply(e, args) => raise Fail "FIXME"              | PT.E_Apply(e, args) => let
347                    fun stripMark (_, PT.E_Mark{span, tree}) = stripMark(span, tree)
348                      | stripMark (span, e) = (span, e)
349                    val (args, tys) = checkList (env, cxt, args)
350                    fun appTyError (f, paramTys, argTys) = err(cxt, [
351                            S "type error in application of ", V f, S "\n",
352                            S "  expected:  ", TYS paramTys, S "\n",
353                            S "  but found: ", TYS argTys
354                          ])
355                    fun checkPrimApp f = if Var.isPrim f
356                          then (case TU.instantiate(Var.typeOf f)
357                             of (tyArgs, Ty.T_Fun(domTy, rngTy)) => (
358                                  case Unify.matchArgs (domTy, args, tys)
359                                   of SOME args => (AST.E_Prim(f, tyArgs, args, rngTy), rngTy)
360                                    | NONE => appTyError (f, domTy, tys)
361                                  (* end case *))
362                              | _ => err(cxt, [S "application of non-function/field ", V f])
363                            (* end case *))
364                          else raise Fail "unexpected user function"
365                  (* check the application of a user-defined function *)
366                    fun checkFunApp (cxt, f) = if Var.isPrim f
367                          then raise Fail "unexpected primitive function"
368                          else (case Var.monoTypeOf f
369                             of Ty.T_Fun(domTy, rngTy) => (
370                                  case Unify.matchArgs (domTy, args, tys)
371                                   of SOME args => (AST.E_Apply(useVar(cxt, f), args, rngTy), rngTy)
372                                    | NONE => appTyError (f, domTy, tys)
373                                  (* end case *))
374                              | _ => err(cxt, [S "application of non-function/field ", V f])
375                            (* end case *))
376                    fun checkFieldApp (e1', ty1) = (case (args, tys)
377                           of ([e2'], [ty2]) => let
378                                val (tyArgs, Ty.T_Fun([fldTy, domTy], rngTy)) =
379                                      TU.instantiate(Var.typeOf BV.op_probe)
380                                fun tyError () = err (cxt, [
381                                        S "type error for field application\n",
382                                        S "  expected:  ", TYS[fldTy, domTy], S "\n",
383                                        S "  but found: ", TYS[ty1, ty2]
384                                      ])
385                                in
386                                  if Unify.equalType(fldTy, ty1)
387                                    then (case Util.coerceType(domTy, (e2', ty2))
388                                       of SOME(e2', _) => (AST.E_Prim(BV.op_probe, tyArgs, [e1', e2'], rngTy), rngTy)
389                                        | NONE => tyError()
390                                      (* end case *))
391                                    else tyError()
392                                end
393                            | _ => err(cxt, [S "badly formed field application"])
394                          (* end case *))
395                    in
396                      case stripMark(#2 cxt, e)
397                       of (span, PT.E_Var f) => (case Env.findVar (env, f)
398                             of SOME f' => checkFieldApp (
399                                  AST.E_Var(useVar((#1 cxt, span), f')),
400                                  Var.monoTypeOf f')
401                              | NONE => (case Env.findFunc (env, f)
402                                   of Env.PrimFun[] => err(cxt, [S "unknown function ", A f])
403                                    | Env.PrimFun[f'] => checkPrimApp f'
404                                    | Env.PrimFun ovldList =>
405                                        resolveOverload ((#1 cxt, span), f, tys, args, ovldList)
406                                    | Env.UserFun f' => checkFunApp((#1 cxt, span), f')
407                                  (* end case *))
408                              (* end case *))
409                        | _ => checkFieldApp (check (env, cxt, e))
410                      (* end case *)
411                    end
412              | PT.E_Subscript(e, indices) => (case (check(env, cxt, e), indices)              | PT.E_Subscript(e, indices) => (case (check(env, cxt, e), indices)
413                   of ((e', Ty.T_Sequence(elemTy, _)), [SOME e2]) => raise Fail "FIXME"                   of ((e', Ty.T_Sequence(elemTy, _)), [SOME e2]) => raise Fail "FIXME"
414                    | ((e', Ty.T_Tensor shape), _) => raise Fail "FIXME"                    | ((e', Ty.T_Tensor shape), _) => raise Fail "FIXME"
# Line 186  Line 418 
418                        ])                        ])
419                  (* end case *))                  (* end case *))
420              | PT.E_Select(e, field) => (case check(env, cxt, e)              | PT.E_Select(e, field) => (case check(env, cxt, e)
421                   of (e', Ty.T_Strand strand) => (case Env.findStrand(#env env, strand)                   of (e', Ty.T_Named strand) => (case Env.findStrand(env, strand)
422                         of SOME(AST.Strand{name, state, ...}) => let                         of SOME sEnv => (case StrandEnv.findStateVar(sEnv, field)
423                              fun isField (AST.VD_Decl(AST.V{name, ...}, _)) = Atom.same(name, field)                               of SOME x' => let
                             in  
                               case List.find isField state  
                                of SOME(AST.VD_Decl(x', _)) => let  
424                                      val ty = Var.monoTypeOf x'                                      val ty = Var.monoTypeOf x'
425                                      in                                      in
426                                        (AST.E_Selector(e', field, ty), ty)                                      (AST.E_Select(e', useVar(cxt, x')), ty)
427                                      end                                      end
428                                  | NONE => err(cxt, [                                  | NONE => err(cxt, [
429                                        S "strand ", A name,                                      S "strand '", A strand,
430                                        S " does not have state variable ", A field                                      S "' does not have state variable '", A field, S "'"
431                                      ])                                      ])
432                                (* end case *)                              (* end case *))
433                              end                          | NONE => err(cxt, [S "unknown strand '", A strand, S "'"])
                         | NONE => err(cxt, [S "unknown strand ", A strand])  
434                        (* end case *))                        (* end case *))
435                    | (_, ty) => err (cxt, [                    | (_, ty) => err (cxt, [
436                          S "expected strand type, but found ", TY ty,                          S "expected strand type, but found ", TY ty,
437                          S " in selection of ", A field                          S " in selection of '", A field, S "'"
438                        ])                        ])
439                  (* end case *))                  (* end case *))
440              | PT.E_Real e => (case check (env, cxt, e)              | PT.E_Real e => (case check (env, cxt, e)
441                   of (e', Ty.T_Int) =>                   of (e', Ty.T_Int) =>
442                        (AST.E_Apply(BV.i2r, [], [e'], Ty.realTy), Ty.realTy)                        (AST.E_Prim(BV.i2r, [], [e'], Ty.realTy), Ty.realTy)
443                    | (_, ty) => err(cxt, [                    | (_, ty) => err(cxt, [
444                          S "argument of 'real' must have type 'int', but found ",                          S "argument of 'real' must have type 'int', but found ",
445                          TY ty                          TY ty
446                        ])                        ])
447                  (* end case *))                  (* end case *))
448              | PT.E_Load nrrd => let              | PT.E_Load nrrd => let
449                  val (tyArgs, Ty.T_Fun(_, rngTy)) = Util.instantiate(Var.typeOf(BV.fn_image))                  val (tyArgs, Ty.T_Fun(_, rngTy)) = TU.instantiate(Var.typeOf(BV.fn_image))
450                  in                  in
451                    (AST.E_LoadNrrd(tyArgs, nrrd, rngTy), rngTy)                    case chkStringConstExpr (env, cxt, nrrd)
452                       of SOME nrrd => (AST.E_LoadNrrd(tyArgs, nrrd, rngTy), rngTy)
453                        | NONE => (bogusExp, rngTy)
454                      (* end case *)
455                  end                  end
456              | PT.E_Image nrrd => let              | PT.E_Image nrrd => let
457                  val (tyArgs, Ty.T_Fun(_, rngTy)) = Util.instantiate(Var.typeOf(BV.fn_load))                  val (tyArgs, Ty.T_Fun(_, rngTy)) = TU.instantiate(Var.typeOf(BV.fn_load))
458                  in                  in
459                    (AST.E_LoadNrrd(tyArgs, nrrd, rngTy), rngTy)                    case chkStringConstExpr (env, cxt, nrrd)
460                       of SOME nrrd => (AST.E_LoadNrrd(tyArgs, nrrd, rngTy), rngTy)
461                        | NONE => (bogusExp, rngTy)
462                      (* end case *)
463                  end                  end
464              | PT.E_Var x => (case E.findVar (#env env, x)              | PT.E_Var x => (case E.findVar (env, x)
465                   of SOME x' => (                   of SOME x' => (AST.E_Var(useVar(cxt, x')), Var.monoTypeOf x')
                       markUsed (x', true);  
                       (AST.E_Var x', Var.monoTypeOf x'))  
466                    | NONE => err(cxt, [S "undeclared variable ", A x])                    | NONE => err(cxt, [S "undeclared variable ", A x])
467                  (* end case *))                  (* end case *))
468              | PT.E_Kernel(kern, dim) => raise Fail "FIXME"              | PT.E_Kernel(kern, dim) => raise Fail "FIXME"
469              | PT.E_Lit lit => checkLit lit              | PT.E_Lit lit => checkLit lit
470              | PT.E_Id d => let              | PT.E_Id d => let
471                  val (tyArgs, Ty.T_Fun(_, rngTy)) =                  val (tyArgs, Ty.T_Fun(_, rngTy)) =
472                        Util.instantiate(Var.typeOf(BV.identity))                        TU.instantiate(Var.typeOf(BV.identity))
473                  in                  in
474                    if Unify.equalType(Ty.T_Tensor(checkShape(cxt, [d,d])), rngTy)                    if Unify.equalType(Ty.T_Tensor(checkShape(env, cxt, [d, d])), rngTy)
475                      then (AST.E_Apply(BV.identity, tyArgs, [], rngTy), rngTy)                      then (AST.E_Prim(BV.identity, tyArgs, [], rngTy), rngTy)
476                      else raise Fail "impossible"                      else raise Fail "impossible"
477                  end                  end
478              | PT.E_Zero dd => let              | PT.E_Zero dd => let
479                  val (tyArgs, Ty.T_Fun(_, rngTy)) =                  val (tyArgs, Ty.T_Fun(_, rngTy)) =
480                        Util.instantiate(Var.typeOf(BV.zero))                        TU.instantiate(Var.typeOf(BV.zero))
481                  in                  in
482                    if Unify.equalType(Ty.T_Tensor(checkShape(cxt, dd)), rngTy)                    if Unify.equalType(Ty.T_Tensor(checkShape(env, cxt, dd)), rngTy)
483                      then (AST.E_Apply(BV.zero, tyArgs, [], rngTy), rngTy)                      then (AST.E_Prim(BV.zero, tyArgs, [], rngTy), rngTy)
484                      else raise Fail "impossible"                      else raise Fail "impossible"
485                  end                  end
486              | PT.E_NaN dd => let              | PT.E_NaN dd => let
487                  val (tyArgs, Ty.T_Fun(_, rngTy)) =                  val (tyArgs, Ty.T_Fun(_, rngTy)) =
488                        Util.instantiate(Var.typeOf(BV.nan))                        TU.instantiate(Var.typeOf(BV.nan))
489                  in                  in
490                    if Unify.equalType(Ty.T_Tensor(checkShape(cxt, dd)), rngTy)                    if Unify.equalType(Ty.T_Tensor(checkShape(env, cxt, dd)), rngTy)
491                      then (AST.E_Apply(BV.nan, tyArgs, [], rngTy), rngTy)                      then (AST.E_Prim(BV.nan, tyArgs, [], rngTy), rngTy)
492                      else raise Fail "impossible"                      else raise Fail "impossible"
493                  end                  end
494              | PT.E_Sequence exps => raise Fail "FIXME"              | PT.E_Sequence exps => raise Fail "FIXME"
# Line 269  Line 501 
501                         of NONE => Ty.T_Error                         of NONE => Ty.T_Error
502                          | SOME ty => ty                          | SOME ty => ty
503                        (* end case *))                        (* end case *))
504                  in                (* process the arguments checking that they all have the expected type *)
505                    case realType(TU.pruneHead ty)                  fun chkArgs (ty, shape) = let
                    of ty as Ty.T_Tensor shape => let  
506                          val Ty.Shape dd = TU.pruneShape shape (* NOTE: this may fail if we allow user polymorphism *)                          val Ty.Shape dd = TU.pruneShape shape (* NOTE: this may fail if we allow user polymorphism *)
507                          val resTy = Ty.T_Tensor(Ty.Shape(Ty.DimConst(List.length args) :: dd))                          val resTy = Ty.T_Tensor(Ty.Shape(Ty.DimConst(List.length args) :: dd))
508                          fun chkArgs (arg::args, argTy::tys, args') = (case coerceType(ty, argTy, arg)                        fun chkArgs (arg::args, argTy::tys, args') = (
509                                 of SOME arg' => chkArgs (args, tys, arg'::args')                              case Util.coerceType(ty, (arg, argTy))
510                                 of SOME(arg', _) => chkArgs (args, tys, arg'::args')
511                                  | NONE => (                                  | NONE => (
512                                      TypeError.error(cxt, [                                      TypeError.error(cxt, [
513                                          S "arguments of tensor construction must have same type"                                          S "arguments of tensor construction must have same type"
514                                        ]);                                        ]);
515                                      ??)                                    chkArgs (args, tys, bogusExp::args'))
516                                (* end case *))                                (* end case *))
517                            | chkArgs ([], [], args') = (AST.E_Cons(List.rev args', resTy), resTy)                          | chkArgs ([], [], args') = (AST.E_Tensor(List.rev args', resTy), resTy)
518                          in                          in
519                            chkArgs (args, tys, [])                            chkArgs (args, tys, [])
520                          end                          end
521                    in
522                      case TU.pruneHead ty
523                       of Ty.T_Int => chkArgs(Ty.realTy, Ty.Shape[]) (* coerce integers to reals *)
524                        | ty as Ty.T_Tensor shape => chkArgs(ty, shape)
525                      | _ => err(cxt, [S "Invalid argument type for tensor construction"])                      | _ => err(cxt, [S "Invalid argument type for tensor construction"])
526                    (* end case *)                    (* end case *)
527                  end                  end
# Line 299  Line 535 
535            case (check(env, cxt, e1), check(env, cxt, e2))            case (check(env, cxt, e1), check(env, cxt, e2))
536             of ((e1', Ty.T_Bool), (e2', Ty.T_Bool)) => (mk(e1', e2'), Ty.T_Bool)             of ((e1', Ty.T_Bool), (e2', Ty.T_Bool)) => (mk(e1', e2'), Ty.T_Bool)
537              | ((_, Ty.T_Bool), (_, ty2)) =>              | ((_, Ty.T_Bool), (_, ty2)) =>
538                  err (cxt, [S "expected type 'bool' on rhs of '", S rator, "', but found ", TY ty2])                  err (cxt, [S "expected type 'bool' on rhs of '", S rator, S "', but found ", TY ty2])
539              | ((_, ty1), (_, Ty.T_Bool)) =>              | ((_, ty1), (_, Ty.T_Bool)) =>
540                  err (cxt, [S "expected type 'bool' on lhs of '", S rator, "', but found ", TY ty1])                  err (cxt, [S "expected type 'bool' on lhs of '", S rator, S "', but found ", TY ty1])
541              | ((_, ty1), (_, ty2)) => err (cxt, [              | ((_, ty1), (_, ty2)) => err (cxt, [
542                    S "arguments of '", S rator, "' must have type 'bool', but found ",                    S "arguments of '", S rator, S "' must have type 'bool', but found ",
543                    TY ty1, S " and ", TY ty2                    TY ty1, S " and ", TY ty2
544                  ])                  ])
545            (* end case *))            (* end case *))
# Line 321  Line 557 
557              List.foldr chk ([], []) exprs              List.foldr chk ([], []) exprs
558            end            end
559    
560      (* check a string that is specified as a constant expression *)
561        and chkStringConstExpr (env, cxt, PT.E_Mark m) =
562              chkStringConstExpr (E.withEnvAndContext (env, cxt, m))
563          | chkStringConstExpr (env, cxt, e) = (case check (env, cxt, e)
564               of (e', Ty.T_String) => (case ConstExpr.eval (cxt, e')
565                     of SOME(ConstExpr.String s) => SOME s
566                      | SOME(ConstExpr.Expr e) => raise Fail "FIXME"
567                      | NONE => NONE
568                      | _ => raise Fail "impossible: wrong type for constant expr"
569                    (* end case *))
570                | (_, ty) => (
571                    TypeError.error (cxt, [
572                        S "expected constant expression of type 'string', but found '",
573                        TY ty, S "'"
574                      ]);
575                    NONE)
576              (* end case *))
577    
578      (* check a dimension that is given by a constant expression *)
579        and checkDim (env, cxt, dim) = (case check (env, cxt, dim)
580               of (e', Ty.T_Int) => (case ConstExpr.eval (cxt, e')
581                     of SOME(ConstExpr.Int d) => SOME d
582                      | SOME(ConstExpr.Expr e) => (
583                          TypeError.error (cxt, [S "unable to evaluate constant dimension expression"]);
584                          NONE)
585                      | NONE => NONE
586                      | _ => raise Fail "impossible: wrong type for constant expr"
587                    (* end case *))
588                | (_, ty) => (
589                    TypeError.error (cxt, [
590                        S "expected constant expression of type 'int', but found '",
591                        TY ty, S "'"
592                      ]);
593                    NONE)
594              (* end case *))
595    
596      (* check a tensor shape, where the dimensions are given by constant expressions *)
597        and checkShape (env, cxt, shape) = let
598              fun checkDim' e = (case checkDim (env, cxt, e)
599                     of SOME d => (
600                          if (d <= 1)
601                            then TypeError.error (cxt, [
602                                S "invalid tensor-shape dimension; must be > 1, but found ",
603                                S (IntLit.toString d)
604                              ])
605                            else ();
606                          Ty.DimConst(IntInf.toInt d))
607                      | NONE => Ty.DimConst ~1
608                    (* end case *))
609              in
610                Ty.Shape(List.map checkDim' shape)
611              end
612    
613    end    end

Legend:
Removed from v.3402  
changed lines
  Added in v.3407

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0