Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] Diff of /trunk/src/compiler/typechecker/typechecker.sml
ViewVC logotype

Diff of /trunk/src/compiler/typechecker/typechecker.sml

Parent Directory Parent Directory | Revision Log Revision Log | View Patch Patch

revision 70, Sat May 22 14:23:32 2010 UTC revision 96, Thu May 27 17:57:31 2010 UTC
# Line 6  Line 6 
6    
7  structure Typechecker : sig  structure Typechecker : sig
8    
9      val check : ParseTree.program -> AST.program      exception Error
10    
11        val check : Error.err_stream -> ParseTree.program -> AST.program
12    
13    end = struct    end = struct
14    
15      structure PT = ParseTree      structure PT = ParseTree
16      structure Ty = Types      structure Ty = Types
17        structure TU = TypeUtil
18        structure U = Util
19    
20        exception Error
21    
22        type context = Error.err_stream * Error.span
23    
24        fun withContext ((errStrm, _), {span, tree}) =
25              ((errStrm, span), tree)
26        fun withEnvAndContext (env, (errStrm, _), {span, tree}) =
27              (env, (errStrm, span), tree)
28    
29        fun error ((errStrm, span), msg) = (
30              Error.errorAt(errStrm, span, msg);
31              raise Error)
32    
33        datatype token
34          = S of string | A of Atom.atom
35          | V of AST.var | TY of Types.ty | TYS of Types.ty list
36    
37        fun err (cxt, toks) = let
38              fun tok2str (S s) = s
39                | tok2str (A a) = Atom.toString a
40                | tok2str (V x) = Var.nameOf x
41                | tok2str (TY ty) = TU.toString ty
42                | tok2str (TYS []) = "()"
43                | tok2str (TYS[ty]) = TU.toString ty
44                | tok2str (TYS tys) = String.concat[
45                      "(", String.concatWith " * " (List.map TU.toString tys), ")"
46                    ]
47              in
48                error(cxt, List.map tok2str toks)
49              end
50    
51        val realZero = AST.E_Lit(Literal.Float(FloatLit.zero true))
52    
53    (* check a differentiation level, which muse be >= 0 *)    (* check a differentiation level, which muse be >= 0 *)
54      fun checkDiff (cxt, k) =      fun checkDiff (cxt, k) =
55            if (k < 0)            if (k < 0)
56              then raise Fail "differentiation must be >= 0"              then raise Fail "differentiation must be >= 0"
57              else Ty.NatConst(IntInf.toInt k)              else Ty.DiffConst(IntInf.toInt k)
58    
59    (* check a dimension, which must be 2 or 3 *)    (* check a dimension, which must be 2 or 3 *)
60      fun checkDim (cxt, d) =      fun checkDim (cxt, d) =
61            if (d < 2) orelse (3 < d)            if (d <= 0)
62              then raise Fail "invalid dimension; must be 2 or 3"              then raise Fail "invalid dimension; must be > 0"
63              else Ty.NatConst(IntInf.toInt d)              else Ty.DimConst(IntInf.toInt d)
64    
65    (* check a shape *)    (* check a shape *)
66      fun checkShape (cxt, shape) = let      fun checkShape (cxt, shape) = let
67            fun chkDim d = if (d < 1)            fun chkDim d = if (d < 1)
68                  then raise Fail "invalid shape dimension; must be >= 1"                  then raise Fail "invalid shape dimension; must be >= 1"
69                  else Ty.NatConst(IntInf.toInt d)                  else Ty.DimConst(IntInf.toInt d)
70            in            in
71              Ty.Shape(List.map chkDim shape)              Ty.Shape(List.map chkDim shape)
72            end            end
73    
74    (* check the well-formedness of a type and translate it to an AST type *)    (* check the well-formedness of a type and translate it to an AST type *)
75      fun checkTy (cxt, ty) = (case ty      fun checkTy (cxt, ty) = (case ty
76             of PT.T_Mark m => checkTy(cxt, #tree m)  (* FIXME track context *)             of PT.T_Mark m => checkTy(withContext(cxt, m))
77              | PT.T_Bool => Ty.T_Bool              | PT.T_Bool => Ty.T_Bool
78              | PT.T_Int => Ty.T_Int              | PT.T_Int => Ty.T_Int
79              | PT.T_Real => Ty.realTy              | PT.T_Real => Ty.realTy
# Line 57  Line 94 
94              | PT.T_Array(ty, dims) => raise Fail "Array type"              | PT.T_Array(ty, dims) => raise Fail "Array type"
95            (* end case *))            (* end case *))
96    
97  (*      fun checkLit lit = (case lit
98               of (Literal.Int _) => (AST.E_Lit lit, Ty.T_Int)
99                | (Literal.Float _) => (AST.E_Lit lit, Ty.realTy)
100                | (Literal.String s) => (AST.E_Lit lit, Ty.T_String)
101                | (Literal.Bool _) => (AST.E_Lit lit, Ty.T_Bool)
102              (* end case *))
103    
104      (* resolve overloading: we use a simple scheme that selects the first operator in the
105       * list that matches the argument types.
106       *)
107        fun resolveOverload (cxt, rator, argTys, args, candidates) = let
108              fun tryCandidates [] = err(cxt, [
109                      S "unable to resolve overloaded operator \"", A rator, S "\"\n",
110                      S "  argument type is: ", TYS argTys, S "\n"
111                    ])
112                | tryCandidates (x::xs) = let
113                    val (tyArgs, Ty.T_Fun(domTy, rngTy)) = Util.instantiate(Var.typeOf x)
114                    in
115                      if U.tryMatchTypes(domTy, argTys)
116                        then (AST.E_Apply(x, tyArgs, args, rngTy), rngTy)
117                        else tryCandidates xs
118                    end
119              in
120                tryCandidates candidates
121              end
122    
123    (* typecheck an expression and translate it to AST *)    (* typecheck an expression and translate it to AST *)
124      fun checkExpr (env, e) = (case e      fun checkExpr (env, cxt, e) = (case e
125             of PT.E_Mark of expr mark             of PT.E_Mark m => checkExpr (withEnvAndContext (env, cxt, m))
126              | PT.E_Var of var              | PT.E_Var x => (case Env.findVar (env, x)
127              | PT.E_Lit of Literal.literal                   of SOME x' => let
128              | PT.E_BinOp of expr * var * expr                        val (args, ty) = Util.instantiate(Var.typeOf x')
129              | PT.E_UnaryOp of var * expr                        in
130              | PT.E_Tuple of expr list                          (AST.E_Var(x', args, ty), ty)
131              | PT.E_Apply of var * expr list                        end
132              | PT.E_Cons of ty * expr list                    | NONE => err(cxt, [S "undeclared variable ", A x])
133              | PT.E_Diff of expr                  (* end case *))
134              | PT.E_Norm of expr              | PT.E_Lit lit => checkLit lit
135                | PT.E_OrElse(e1, e2) => let
136                    val (e1', ty1) = checkExpr(env, cxt, e1)
137                    val (e2', ty2) = checkExpr(env, cxt, e2)
138                    in
139                      case (ty1, ty2)
140                       of (Ty.T_Bool, Ty.T_Bool) =>
141                            (AST.E_Cond(e1', AST.E_Lit(Literal.Bool true), e2'), Ty.T_Bool)
142                        | _ => raise Fail "arguments to \"||\" must have bool type"
143                      (* end case *)
144                    end
145                | PT.E_AndAlso(e1, e2) => let
146                    val (e1', ty1) = checkExpr(env, cxt, e1)
147                    val (e2', ty2) = checkExpr(env, cxt, e2)
148                    in
149                      case (ty1, ty2)
150                       of (Ty.T_Bool, Ty.T_Bool) =>
151                            (AST.E_Cond(e1', e2', AST.E_Lit(Literal.Bool false)), Ty.T_Bool)
152                        | _ => raise Fail "arguments to \"||\" must have bool type"
153                      (* end case *)
154                    end
155                | PT.E_BinOp(e1, rator, e2) => let
156                    val (e1', ty1) = checkExpr(env, cxt, e1)
157                    val (e2', ty2) = checkExpr(env, cxt, e2)
158                    in
159                      case Basis.findOp rator
160                       of [rator] => let
161                            val (tyArgs, Ty.T_Fun(domTy, rngTy)) = Util.instantiate(Var.typeOf rator)
162                            in
163                              if U.matchTypes(domTy, [ty1, ty2])
164                                then (AST.E_Apply(rator, tyArgs, [e1', e2'], rngTy), rngTy)
165                                else err (cxt, [
166                                    S "type error for binary operator \"", V rator, S "\"\n",
167                                    S "  expected:  ", TYS domTy, S "\n",
168                                    S "  but found: ", TYS[ty1, ty2], S "\n"
169                                  ])
170                            end
171                        | ovldList => resolveOverload (cxt, rator, [ty1, ty2], [e1', e2'], ovldList)
172                      (* end case *)
173                    end
174                | PT.E_UnaryOp(rator, e) => let
175                    val (e', ty) = checkExpr(env, cxt, e)
176                    in
177                      case Basis.findOp rator
178                       of [rator] => let
179                            val (tyArgs, Ty.T_Fun([domTy], rngTy)) = Util.instantiate(Var.typeOf rator)
180                            in
181                              if U.matchType(domTy, ty)
182                                then (AST.E_Apply(rator, tyArgs, [e'], rngTy), rngTy)
183                                else err (cxt, [
184                                    S "type error for unary operator \"", V rator, S "\"\n",
185                                    S "  expected:  ", TY domTy, S "\n",
186                                    S "  but found: ", TY ty, S "\n"
187                                  ])
188                            end
189                        | ovldList => resolveOverload (cxt, rator, [ty], [e'], ovldList)
190                      (* end case *)
191                    end
192                | PT.E_Tuple args => let
193                    val (args, tys) = checkExprList (env, cxt, args)
194                    in
195                      raise Fail "E_Tuple not yet implemented"
196                    end
197                | PT.E_Apply(f, args) => let
198                    val (args, tys) = checkExprList (env, cxt, args)
199                    in
200                      case Env.findVar (env, f)
201                       of SOME f => (case Util.instantiate(Var.typeOf f)
202                             of (tyArgs, Ty.T_Fun(domTy, rngTy)) =>
203                                  if U.matchTypes(domTy, tys)
204                                    then (AST.E_Apply(f, tyArgs, args, rngTy), rngTy)
205                                    else err(cxt, [
206                                        S "type error in application of ", V f, S "\n",
207                                        S "  expected:  ", TYS domTy, S "\n",
208                                        S "  but found: ", TYS tys, S "\n"
209                                      ])
210                              | _ => raise Fail "application of non-function"
211                            (* end case *))
212                        | NONE => raise Fail "unknown function"
213                      (* end case *)
214                    end
215                | PT.E_Cons args => let
216                    val (args, ty::tys) = checkExprList (env, cxt, args)
217                    in
218                      case TU.pruneHead ty
219                       of Ty.T_Tensor shape => let
220                            fun chkTy ty' = U.matchType(ty, ty')
221                            val resTy = Ty.T_Tensor(Ty.shapeExt(shape, Ty.DimConst(List.length args)))
222                            in
223                              if List.all chkTy tys
224                                then (AST.E_Cons args, resTy)
225                                else raise Fail "arguments of tensor construction must have same type"
226                            end
227                        | _ => raise Fail "Invalid argument type for tensor construction"
228                      (* end case *)
229                    end
230                | PT.E_Real e => (case checkExpr (env, cxt, e)
231                     of (e', Ty.T_Int) =>
232                          (AST.E_Apply(BasisVars.i2r, [], [e'], Ty.realTy), Ty.realTy)
233                      | _ => raise Fail "argument of real conversion must be int"
234                    (* end case *))
235              (* end case *))
236    
237      (* typecheck a list of expressions returning a list of AST expressions and a list
238       * of types of the expressions.
239       *)
240        and checkExprList (env, cxt, exprs) = let
241              fun chk (e, (es, tys)) = let
242                    val (e, ty) = checkExpr (env, cxt, e)
243                    in
244                      (e::es, ty::tys)
245                    end
246              in
247                List.foldr chk ([], []) exprs
248              end
249    
250        fun checkVarDecl (env, cxt, kind, d) = (case d
251               of PT.VD_Mark m => checkVarDecl (env, (#1 cxt, #span m), kind, #tree m)
252                | PT.VD_Decl(ty, x, e) => let
253                    val ty = checkTy (cxt, ty)
254                    val x' = Var.new (x, kind, ty)
255                    val (e', ty') = checkExpr (env, cxt, e)
256                    in
257    (* FIXME: check types *)
258                      (x, x', e')
259                    end
260            (* end case *))            (* end case *))
261    
262    (* typecheck a statement and translate it to AST *)    (* typecheck a statement and translate it to AST *)
263      fun checkStmt (env, s) = (case s      fun checkStmt (env, cxt, s) = (case s
264             of PT.S_Mark of stmt mark             of PT.S_Mark m => checkStmt (withEnvAndContext (env, cxt, m))
265              | PT.S_Block of stmt list              | PT.S_Block stms => let
266              | PT.S_Decl of var_decl                  fun chk (_, [], stms) = AST.S_Block(List.rev stms)
267              | PT.S_IfThen of expr * stmt                    | chk (env, s::ss, stms) = let
268              | PT.S_IfThenElse of expr * stmt * stmt                        val (s', env') = checkStmt (env, cxt, s)
269              | PT.S_Assign of var * expr                        in
270              | PT.S_New of var * expr list                          chk (env', ss, s'::stms)
271              | PT.S_Die                        end
272              | PT.S_Stabilize                  in
273            (* end case *))                    (chk (env, stms, []), env)
274                    end
275      fun checkDecl (env, d) = (case d              | PT.S_Decl vd => let
276             of PT.D_Mark of decl mark                  val (x, x', e) = checkVarDecl (env, cxt, Var.LocalVar, vd)
277              | PT.D_Input of ty * var * expr option      (* input variable decl with optional default *)                  in
278              | PT.D_Var of var_decl                      (* global variable decl *)                    (AST.S_Decl(AST.VD_Decl(x', e)), Env.insertLocal(env, x, x'))
279              | PT.D_Actor of {                           (* actor decl *)                  end
280                    name : var,              | PT.S_IfThen(e, s) => let
281                    params : param list,                  val (e', ty) = checkExpr (env, cxt, e)
282                    state : var_decl list,                  val (s', _) = checkStmt (env, cxt, s)
283                    methods : method list                  in
284                  }                  (* check that condition has bool type *)
285              | PT.D_InitialArray of create * iter list                    case ty
286              | PT.D_InitialCollection of create * iter list                     of Ty.T_Bool => ()
287                        | _ => raise Fail "condition not boolean type"
288                      (* end case *);
289                      (AST.S_IfThenElse(e', s', AST.S_Block[]), env)
290                    end
291                | PT.S_IfThenElse(e, s1, s2) => let
292                    val (e', ty) = checkExpr (env, cxt, e)
293                    val (s1', _) = checkStmt (env, cxt, s1)
294                    val (s2', _) = checkStmt (env, cxt, s2)
295                    in
296                    (* check that condition has bool type *)
297                      case ty
298                       of Ty.T_Bool => ()
299                        | _ => raise Fail "condition not boolean type"
300                      (* end case *);
301                      (AST.S_IfThenElse(e', s1', s2'), env)
302                    end
303                | PT.S_Assign(x, e) => (case Env.findVar (env, x)
304                     of NONE => raise Fail "undefined variable"
305                      | SOME x' => let
306                          val (e', ty) = checkExpr (env, cxt, e)
307                          in
308    (* FIXME: check types *)
309                          (* check that x' is mutable *)
310                            case Var.kindOf x'
311                             of Var.ActorStateVar => ()
312                              | Var.LocalVar => ()
313                              | _ => raise Fail "assignment to immutable variable"
314                            (* end case *);
315                            (AST.S_Assign(x', e'), env)
316                          end
317            (* end case *))            (* end case *))
318  *)              | PT.S_New(actor, args) => let
319                    val argsAndTys' = List.map (fn e => checkExpr(env, cxt, e)) args
320                    val (args', tys') = ListPair.unzip argsAndTys'
321                    in
322    (* FIXME: check that actor is defined and has the argument types match *)
323                      (AST.S_New(actor, args'), env)
324                    end
325                | PT.S_Die => (AST.S_Die, env)
326                | PT.S_Stabilize => (AST.S_Stabilize, env)
327              (* end case *))
328    
329        fun checkParams (env, cxt, params) = let
330              fun chkParam (env, cxt, param) = (case param
331                     of PT.P_Mark m => chkParam (withEnvAndContext (env, cxt, m))
332                      | PT.P_Param(ty, x) => let
333                          val x' = Var.new(x, AST.ActorParam, checkTy (cxt, ty))
334                          in
335                            (x', Env.insertLocal(env, x, x'))
336                          end
337                    (* end case *))
338              fun chk (param, (xs, env)) = let
339                    val (x, env) = chkParam (env, cxt, param)
340                    in
341                      (x::xs, env)
342                    end
343              in
344    (* FIXME: need to check for multiple occurences of the same parameter name! *)
345                List.foldr chk ([], env) params
346              end
347    
348        fun checkMethod (env, cxt, meth) = (case meth
349               of PT.M_Mark m => checkMethod (withEnvAndContext (env, cxt, m))
350                | PT.M_Method(name, body) => let
351                    val (body, _) = checkStmt(env, cxt, body)
352                    in
353                      AST.M_Method(name, body)
354                    end
355              (* end case *))
356    
357        fun checkActor (env, cxt, {name, params, state, methods}) = let
358            (* check the actor parameters *)
359              val (params, env) = checkParams (env, cxt, params)
360            (* check the actor state variable definitions *)
361              val (vds, env) = let
362                    fun checkStateVar (vd, (vds, env)) = let
363                          val (x, x', e') = checkVarDecl (env, cxt, AST.ActorStateVar, vd)
364                          in
365                            (AST.VD_Decl(x', e')::vds, Env.insertLocal(env, x, x'))
366                          end
367                    val (vds, env) = List.foldl checkStateVar ([], env) state
368                    in
369                      (List.rev vds, env)
370                    end
371            (* check the actor methods *)
372              val methods = List.map (fn m => checkMethod (env, cxt, m)) methods
373              in
374                AST.D_Actor{name = name, params = params, state = vds, methods = methods}
375              end
376    
377      fun check (PT.Program dcls) = AST.Program()      fun checkCreate (env, cxt, PT.C_Mark m) = checkCreate (withEnvAndContext (env, cxt, m))
378          | checkCreate (env, cxt, PT.C_Create(actor, args)) = let
379              val (args, tys) = checkExprList (env, cxt, args)
380              in
381    (* FIXME: check against actor definition *)
382                AST.C_Create(actor, args)
383              end
384    
385        fun checkIter (env, cxt, PT.I_Mark m) = checkIter (withEnvAndContext (env, cxt, m))
386          | checkIter (env, cxt, PT.I_Range(x, e1, e2)) = let
387              val (e1', ty1) = checkExpr (env, cxt, e1)
388              val (e2', ty2) = checkExpr (env, cxt, e2)
389              val x' = Var.new(x, Var.LocalVar, Ty.T_Int)
390              val env' = Env.insertLocal(env, x, x')
391              in
392                case (ty1, ty2)
393                 of (Ty.T_Int, Ty.T_Int) => (AST.I_Range(x', e1', e2'), env')
394                  | _ => err(cxt, [
395                        S "range expressions must have integer type\n",
396                        S "  but found: ", TY ty1, S " .. ", TY ty2, S "\n"
397                      ])
398                (* end case *)
399              end
400    
401        fun checkIters (env, cxt, iters) = let
402              fun chk (env, [], iters) = (List.rev iters, env)
403                | chk (env, iter::rest, iters) = let
404                    val (iter, env) = checkIter (env, cxt, iter)
405                    in
406                      chk (env, rest, iter::iters)
407                    end
408              in
409                chk (env, iters, [])
410              end
411    
412        fun checkDecl (env, cxt, d) = (case d
413               of PT.D_Mark m => checkDecl (withEnvAndContext (env, cxt, m))
414                | PT.D_Input(ty, x, optExp) => let
415                    val ty = checkTy(cxt, ty)
416                    val x' = Var.new(x, Var.InputVar, ty)
417                    val dcl = (case optExp
418                           of NONE => AST.D_Input(x', NONE)
419                            | SOME e => let
420                                val (e', ty') = checkExpr (env, cxt, e)
421                                in
422                                  if U.matchType (ty, ty')
423                                    then AST.D_Input(x', SOME e')
424                                    else err(cxt, [
425                                        S "definition of ", V x', S " has wrong type\n",
426                                        S "  expected:  ", TY ty, S "\n",
427                                        S "  but found: ", TY ty', S "\n"
428                                      ])
429                                end
430                          (* end case *))
431                    in
432                      (dcl, Env.insertGlobal(env, x, x'))
433                    end
434                | PT.D_Var vd => let
435                    val (x, x', e') = checkVarDecl (env, cxt, Var.GlobalVar, vd)
436                    in
437                      (AST.D_Var(AST.VD_Decl(x', e')), Env.insertGlobal(env, x, x'))
438                    end
439                | PT.D_Actor arg => (checkActor(env, cxt, arg), env)
440                | PT.D_InitialArray(create, iterators) => let
441                    val (iterators, env') = checkIters (env, cxt, iterators)
442                    val create = checkCreate (env', cxt, create)
443                    in
444                      (AST.D_InitialArray(create, iterators), env)
445                    end
446                | PT.D_InitialCollection(create, iterators) => let
447                    val (iterators, env') = checkIters (env, cxt, iterators)
448                    val create = checkCreate (env', cxt, create)
449                    in
450                      (AST.D_InitialCollection(create, iterators), env)
451                    end
452              (* end case *))
453    
454        fun check errStrm (PT.Program{span, tree}) = let
455              val cxt = (errStrm, span)
456              fun chk (env, [], dcls') = AST.Program(List.rev dcls')
457                | chk (env, dcl::dcls, dcls') = let
458                    val (dcl', env) = checkDecl (env, cxt, dcl)
459                    in
460                      chk (env, dcls, dcl'::dcls')
461                    end
462              in
463                chk (Basis.env, tree, [])
464              end
465    
466    end    end

Legend:
Removed from v.70  
changed lines
  Added in v.96

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0