Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] Diff of /branches/staging/src/compiler/typechecker/typechecker.sml
ViewVC logotype

Diff of /branches/staging/src/compiler/typechecker/typechecker.sml

Parent Directory Parent Directory | Revision Log Revision Log | View Patch Patch

revision 2151, Sun Feb 17 19:07:20 2013 UTC revision 2152, Sun Feb 17 19:39:37 2013 UTC
# Line 1  Line 1 
1  (* typechecker.sml  (* typechecker.sml
2   *   *
3   * COPYRIGHT (c) 2010 The Diderot Project (http://diderot-language.cs.uchicago.edu)   * COPYRIGHT (c) 2013 The Diderot Project (http://diderot-language.cs.uchicago.edu)
4   * All rights reserved.   * All rights reserved.
5   *   *
6   * TODO:   * TODO:
7   *      check that variables are not redefined in the same scope   *      check for unreachable code and prune it (see simplify/simplify.sml)
8   *      int --> real type promotion   *      error recovery so that we can detect multiple errors in a single compile
9   *      sequence operations   *      check that functions have a return on all paths
10     *      check that the args of strand creation have the same type and number as the params
11   *)   *)
12    
13  structure Typechecker : sig  structure Typechecker : sig
14    
     exception Error  
   
15      val check : Error.err_stream -> ParseTree.program -> AST.program      val check : Error.err_stream -> ParseTree.program -> AST.program
16    
17    end = struct    end = struct
# Line 23  Line 22 
22      structure TU = TypeUtil      structure TU = TypeUtil
23      structure U = Util      structure U = Util
24    
25      datatype scope = GlobalScope | StrandScope | MethodScope | InitScope    (* exception to abort typechecking when we hit an error.  Eventually, we should continue
26       * checking for more errors and not use this.
27       *)
28        exception TypeError
29    
30        datatype scope
31          = GlobalScope
32          | FunctionScope of Ty.ty
33          | StrandScope
34          | MethodScope
35          | InitScope
36    
37      type env = {      type env = {
38          scope : scope,          scope : scope,
# Line 34  Line 43 
43      type context = Error.err_stream * Error.span      type context = Error.err_stream * Error.span
44    
45    (* start a new scope *)    (* start a new scope *)
46    (* QUESTION: do we want to restrict access to globals from a function? *)
47        fun functionScope ({scope, bindings, env}, ty) =
48              {scope=FunctionScope ty, bindings=AtomMap.empty, env=env}
49      fun strandScope {scope, bindings, env} =      fun strandScope {scope, bindings, env} =
50            {scope=StrandScope, bindings=AtomMap.empty, env=env}            {scope=StrandScope, bindings=AtomMap.empty, env=env}
51      fun methodScope {scope, bindings, env} =      fun methodScope {scope, bindings, env} =
# Line 47  Line 59 
59        | inStrand {scope=MethodScope, ...} = true        | inStrand {scope=MethodScope, ...} = true
60        | inStrand _ = false        | inStrand _ = false
61    
62        fun insertFunc ({scope, bindings, env}, cxt, f, f') = {
63                scope=scope,
64                bindings = AtomMap.insert(bindings, f, Error.location cxt),
65                env=Env.insertFunc(env, f, Env.UserFun f')
66              }
67      fun insertLocal ({scope, bindings, env}, cxt, x, x') = {      fun insertLocal ({scope, bindings, env}, cxt, x, x') = {
68              scope=scope,              scope=scope,
69              bindings = AtomMap.insert(bindings, x, Error.location cxt),              bindings = AtomMap.insert(bindings, x, Error.location cxt),
# Line 58  Line 75 
75              env=Env.insertGlobal(env, x, x')              env=Env.insertGlobal(env, x, x')
76            }            }
77    
     exception Error  
   
78      fun withContext ((errStrm, _), {span, tree}) =      fun withContext ((errStrm, _), {span, tree}) =
79            ((errStrm, span), tree)            ((errStrm, span), tree)
80      fun withEnvAndContext (env, (errStrm, _), {span, tree}) =      fun withEnvAndContext (env, (errStrm, _), {span, tree}) =
# Line 67  Line 82 
82    
83      fun error ((errStrm, span), msg) = (      fun error ((errStrm, span), msg) = (
84            Error.errorAt(errStrm, span, msg);            Error.errorAt(errStrm, span, msg);
85            raise Error)            raise TypeError)
86    
87      datatype token      datatype token
88        = S of string | A of Atom.atom        = S of string | A of Atom.atom
# Line 99  Line 114 
114    
115      val realZero = AST.E_Lit(Literal.Float(FloatLit.zero true))      val realZero = AST.E_Lit(Literal.Float(FloatLit.zero true))
116    
117    (* check a differentiation level, which muse be >= 0 *)    (* check a differentiation level, which must be >= 0 *)
118      fun checkDiff (cxt, k) =      fun checkDiff (cxt, k) =
119            if (k < 0)            if (k < 0)
120              then err (cxt, [S "differentiation must be >= 0"])              then err (cxt, [S "differentiation must be >= 0"])
121              else Ty.DiffConst(IntInf.toInt k)              else Ty.DiffConst(IntInf.toInt k)
122    
123      (* check a sequence dimension, which must be > 0 *)
124        fun checkSeqDim (cxt, d) =
125              if (d < 0)
126                then err (cxt, [S "invalid dimension; must be positive"])
127                else Ty.DimConst(IntInf.toInt d)
128    
129    (* check a dimension, which must be 1, 2 or 3 *)    (* check a dimension, which must be 1, 2 or 3 *)
130      fun checkDim (cxt, d) =      fun checkDim (cxt, d) =
131            if (d < 1) orelse (3 < d)            if (d < 1) orelse (3 < d)
# Line 145  Line 166 
166                  val ty = checkTy(cxt, ty)                  val ty = checkTy(cxt, ty)
167                  in                  in
168                    if TU.isFixedSizeType ty                    if TU.isFixedSizeType ty
169                      then Ty.T_Sequence(ty, checkDim (cxt, dim))                      then Ty.T_Sequence(ty, checkSeqDim (cxt, dim))
170                      else err(cxt, [S "elements of sequence types must be fixed-size types"])                      else err(cxt, [S "elements of sequence types must be fixed-size types"])
171                  end                  end
172            (* end case *))            (* end case *))
# Line 157  Line 178 
178              | (Literal.Bool _) => (AST.E_Lit lit, Ty.T_Bool)              | (Literal.Bool _) => (AST.E_Lit lit, Ty.T_Bool)
179            (* end case *))            (* end case *))
180    
181      fun coerceType (ty1, ty2, e) = (case U.matchType(ty1, ty2)      fun coerceExp (Ty.T_Tensor(Ty.Shape[]), Ty.T_Int, AST.E_Lit(Literal.Int n)) =
182              AST.E_Lit(Literal.Float(FloatLit.fromInt n))
183          | coerceExp (ty1, ty2, e) = AST.E_Coerce{srcTy=ty2, dstTy=ty1, e=e}
184    
185        fun coerceType (dstTy, srcTy, e) = (case U.matchType(dstTy, srcTy)
186             of U.EQ => SOME e             of U.EQ => SOME e
187              | U.COERCE => SOME(AST.E_Coerce{srcTy=ty2, dstTy=ty1, e=e})              | U.COERCE => SOME(coerceExp (dstTy, srcTy, e))
188              | U.FAIL => NONE              | U.FAIL => NONE
189            (* end case *))            (* end case *))
190    
# Line 239  Line 264 
264                          if U.equalType(ty1, ty2)                          if U.equalType(ty1, ty2)
265                            then (AST.E_Cond(cond', e1', e2', ty1), ty1)                            then (AST.E_Cond(cond', e1', e2', ty1), ty1)
266                            else err (cxt, [                            else err (cxt, [
267                                S "type do not match in conditional expression\n",                                S "types do not match in conditional expression\n",
268                                S "  true branch:  ", TY ty1, S "\n",                                S "  true branch:  ", TY ty1, S "\n",
269                                S "  false branch: ", TY ty2                                S "  false branch: ", TY ty2
270                              ])                              ])
# Line 309  Line 334 
334                              ])                              ])
335                        (* end case *))                        (* end case *))
336                      else (case Env.findFunc (#env env, rator)                      else (case Env.findFunc (#env env, rator)
337                         of [rator] => let                         of Env.PrimFun[rator] => let
338                              val (tyArgs, Ty.T_Fun(domTy, rngTy)) = Util.instantiate(Var.typeOf rator)                              val (tyArgs, Ty.T_Fun(domTy, rngTy)) = Util.instantiate(Var.typeOf rator)
339                              in                              in
340                                case U.matchArgs(domTy, [e1', e2'], [ty1, ty2])                                case U.matchArgs(domTy, [e1', e2'], [ty1, ty2])
# Line 321  Line 346 
346                                      ])                                      ])
347                                (* end case *)                                (* end case *)
348                              end                              end
349                          | ovldList => resolveOverload (cxt, rator, [ty1, ty2], [e1', e2'], ovldList)                          | Env.PrimFun ovldList =>
350                                resolveOverload (cxt, rator, [ty1, ty2], [e1', e2'], ovldList)
351                            | _ => raise Fail "impossible"
352                        (* end case *))                        (* end case *))
353                  end                  end
354              | PT.E_UnaryOp(rator, e) => let              | PT.E_UnaryOp(rator, e) => let
355                  val (e', ty) = checkExpr(env, cxt, e)                  val (e', ty) = checkExpr(env, cxt, e)
356                  in                  in
357                    case Env.findFunc (#env env, rator)                    case Env.findFunc (#env env, rator)
358                     of [rator] => let                     of Env.PrimFun[rator] => let
359                          val (tyArgs, Ty.T_Fun([domTy], rngTy)) = U.instantiate(Var.typeOf rator)                          val (tyArgs, Ty.T_Fun([domTy], rngTy)) = U.instantiate(Var.typeOf rator)
360                          in                          in
361                            case coerceType (domTy, ty, e')                            case coerceType (domTy, ty, e')
# Line 340  Line 367 
367                                  ])                                  ])
368                            (* end case *)                            (* end case *)
369                          end                          end
370                      | ovldList => resolveOverload (cxt, rator, [ty], [e'], ovldList)                      | Env.PrimFun ovldList => resolveOverload (cxt, rator, [ty], [e'], ovldList)
371                        | _ => raise Fail "impossible"
372                    (* end case *)                    (* end case *)
373                  end                  end
374              | PT.E_Slice(e, indices) => let              | PT.E_Slice(e, indices) => let
# Line 405  Line 433 
433                  fun stripMark (PT.E_Mark{tree, ...}) = stripMark tree                  fun stripMark (PT.E_Mark{tree, ...}) = stripMark tree
434                    | stripMark e = e                    | stripMark e = e
435                  val (args, tys) = checkExprList (env, cxt, args)                  val (args, tys) = checkExprList (env, cxt, args)
436                    fun checkFunApp f = (case Util.instantiate(Var.typeOf f)
437                           of (tyArgs, Ty.T_Fun(domTy, rngTy)) => (
438                                case U.matchArgs (domTy, args, tys)
439                                 of SOME args => (AST.E_Apply(f, tyArgs, args, rngTy), rngTy)
440                                  | NONE => err(cxt, [
441                                        S "type error in application of ", V f, S "\n",
442                                        S "  expected:  ", TYS domTy, S "\n",
443                                        S "  but found: ", TYS tys
444                                      ])
445                                (* end case *))
446                            | _ => err(cxt, [S "application of non-function ", V f])
447                          (* end case *))
448                  fun checkFieldApp (e1', ty1) = (case (args, tys)                  fun checkFieldApp (e1', ty1) = (case (args, tys)
449                         of ([e2'], [ty2]) => let                         of ([e2'], [ty2]) => let
450                              val (tyArgs, Ty.T_Fun([fldTy, domTy], rngTy)) =                              val (tyArgs, Ty.T_Fun([fldTy, domTy], rngTy)) =
# Line 429  Line 469 
469                     of PT.E_Var f => (case Env.findVar (#env env, f)                     of PT.E_Var f => (case Env.findVar (#env env, f)
470                           of SOME f' => checkFieldApp (AST.E_Var f', Var.monoTypeOf f')                           of SOME f' => checkFieldApp (AST.E_Var f', Var.monoTypeOf f')
471                            | NONE => (case Env.findFunc (#env env, f)                            | NONE => (case Env.findFunc (#env env, f)
472                                 of [] => err(cxt, [S "unknown function ", A f])                                 of Env.PrimFun[] => err(cxt, [S "unknown function ", A f])
473                                  | [f] =>                                  | Env.PrimFun[f'] =>
474                                      if (inStrand env) andalso (Basis.isRestricted f)                                      if (inStrand env) andalso (Basis.isRestricted f')
475                                        then err(cxt, [                                        then err(cxt, [
476                                            S "use of restricted operation ", V f,                                            S "use of restricted operation ", V f',
477                                            S " in strand body"                                            S " in strand body"
478                                          ])                                          ])
479                                        else (case Util.instantiate(Var.typeOf f)                                        else checkFunApp f'
480                                           of (tyArgs, Ty.T_Fun(domTy, rngTy)) => (                                  | Env.PrimFun ovldList =>
481                                                case U.matchArgs (domTy, args, tys)                                      resolveOverload (cxt, f, tys, args, ovldList)
482                                                 of SOME args => (AST.E_Apply(f, tyArgs, args, rngTy), rngTy)                                  | Env.UserFun f' => checkFunApp f'
                                                 | NONE => err(cxt, [  
                                                       S "type error in application of ", V f, S "\n",  
                                                       S "  expected:  ", TYS domTy, S "\n",  
                                                       S "  but found: ", TYS tys  
                                                     ])  
                                               (* end case *))  
                                           | _ => err(cxt, [S "application of non-function ", V f])  
                                         (* end case *))  
                                 | ovldList => resolveOverload (cxt, f, tys, args, ovldList)  
483                                (* end case *))                                (* end case *))
484                            (* end case *))                            (* end case *))
485                      | _ => checkFieldApp (checkExpr (env, cxt, e))                      | _ => checkFieldApp (checkExpr (env, cxt, e))
# Line 629  Line 660 
660                        val e1' = AST.E_Var x'                        val e1' = AST.E_Var x'
661                        val ty1 = Var.monoTypeOf x'                        val ty1 = Var.monoTypeOf x'
662                        val (e2', ty2) = checkExpr(env, cxt, e)                        val (e2', ty2) = checkExpr(env, cxt, e)
663                        val ovldList = Env.findFunc (#env env, rator)                        val Env.PrimFun ovldList = Env.findFunc (#env env, rator)
664                        val (rhs, _) = resolveOverload (cxt, rator, [ty1, ty2], [e1', e2'], ovldList)                        val (rhs, _) = resolveOverload (cxt, rator, [ty1, ty2], [e1', e2'], ovldList)
665                        in                        in
666                          (AST.S_Assign(x', rhs), env)                          (AST.S_Assign(x', rhs), env)
# Line 660  Line 691 
691                    | _ => err(cxt, [S "\"stabilize\" statment outside of method"])                    | _ => err(cxt, [S "\"stabilize\" statment outside of method"])
692                  (* end case *);                  (* end case *);
693                  (AST.S_Stabilize, env))                  (AST.S_Stabilize, env))
694                | PT.S_Return e => let
695                    val (e', ty) = checkExpr (env, cxt, e)
696                    in
697                      case #scope env
698                       of FunctionScope ty' => (case coerceType(ty', ty, e')
699                             of SOME e' => (AST.S_Return e', env)
700                              | NONE => err(cxt, [
701                                    S "type of return expression does not match function's return type\n",
702                                    S "  expected: ", TY ty', S "\n",
703                                    S "  but found: ", TY ty
704                                  ])
705                            (* end case *))
706                        | _ => err(cxt, [S "\"return\" statment outside of function"])
707                      (* end case *)
708                    end
709              | PT.S_Print args => let              | PT.S_Print args => let
710                  fun chkArg e = let                  fun chkArg e = let
711                        val (e', ty) = checkExpr (env, cxt, e)                        val (e', ty) = checkExpr (env, cxt, e)
# Line 683  Line 729 
729                    | PT.P_Param(ty, x) => let                    | PT.P_Param(ty, x) => let
730                        val x' = Var.new(x, AST.StrandParam, checkTy (cxt, ty))                        val x' = Var.new(x, AST.StrandParam, checkTy (cxt, ty))
731                        in                        in
732    (* FIXME: should use an empty bindings list for the parameters *)
733                          checkForRedef (env, cxt, x);                          checkForRedef (env, cxt, x);
734                          (x', insertLocal(env, cxt, x, x'))                          (x', insertLocal(env, cxt, x, x'))
735                        end                        end
# Line 693  Line 740 
740                    (x::xs, env)                    (x::xs, env)
741                  end                  end
742            in            in
 (* FIXME: need to check for multiple occurences of the same parameter name! *)  
743              List.foldr chk ([], env) params              List.foldr chk ([], env) params
744            end            end
745    
# Line 710  Line 756 
756          (* check the strand parameters *)          (* check the strand parameters *)
757            val (params, env) = checkParams (env, cxt, params)            val (params, env) = checkParams (env, cxt, params)
758          (* check the strand state variable definitions *)          (* check the strand state variable definitions *)
759            val (vds, env) = let            val (vds, hasOutput, env) = let
760                  fun checkStateVar ((isOut, vd), (vds, env)) = let                  fun checkStateVar ((isOut, vd), (vds, hasOut, env)) = let
761                        val kind = if isOut then AST.StrandOutputVar else AST.StrandStateVar                        val kind = if isOut then AST.StrandOutputVar else AST.StrandStateVar
762                        val (x, x', e') = checkVarDecl (env, cxt, kind, vd)                        val (x, x', e') = checkVarDecl (env, cxt, kind, vd)
763                        in                        in
# Line 723  Line 769 
769                              ])                              ])
770                            else ();                            else ();
771                          checkForRedef (env, cxt, x);                          checkForRedef (env, cxt, x);
772                          (AST.VD_Decl(x', e')::vds, insertLocal(env, cxt, x, x'))                          (AST.VD_Decl(x', e')::vds, hasOut orelse isOut, insertLocal(env, cxt, x, x'))
773                        end                        end
774                  val (vds, env) = List.foldl checkStateVar ([], env) state                  val (vds, hasOutput, env) = List.foldl checkStateVar ([], false, env) state
775                  in                  in
776                    (List.rev vds, env)                    (List.rev vds, hasOutput, env)
777                  end                  end
778          (* check the strand methods *)          (* check the strand methods *)
779            val methods = List.map (fn m => checkMethod (env, cxt, m)) methods            val methods = List.map (fn m => checkMethod (env, cxt, m)) methods
# Line 738  Line 784 
784                  then methods                  then methods
785                  else methods @ [AST.M_Method(StrandUtil.Stabilize, AST.S_Block[])]                  else methods @ [AST.M_Method(StrandUtil.Stabilize, AST.S_Block[])]
786            in            in
787    (* FIXME: once there are global outputs, then it should be okay to have not strand outputs! *)
788            (* check that there is at least one output variable *)
789                if not hasOutput
790                  then err(cxt, [S "strand ", A name, S " does not have any outputs"])
791                  else ();
792  (* FIXME: should check for duplicate method definitions *)  (* FIXME: should check for duplicate method definitions *)
793              if not(List.exists (fn StrandUtil.Update => true | _ => false) methodNames)              if not(List.exists (fn StrandUtil.Update => true | _ => false) methodNames)
794                then err(cxt, [S "strand ", A name, S " is missing an update method"])                then err(cxt, [S "strand ", A name, S " is missing an update method"])
# Line 749  Line 800 
800        | checkCreate (env, cxt, PT.C_Create(strand, args)) = let        | checkCreate (env, cxt, PT.C_Create(strand, args)) = let
801            val (args, tys) = checkExprList (env, cxt, args)            val (args, tys) = checkExprList (env, cxt, args)
802            in            in
803  (* FIXME: check against strand definition *)  (* FIXME: check args against strand definition *)
804              AST.C_Create(strand, args)              AST.C_Create(strand, args)
805            end            end
806    
# Line 819  Line 870 
870                    checkForRedef (env, cxt, x);                    checkForRedef (env, cxt, x);
871                    (AST.D_Var(AST.VD_Decl(x', e')), insertGlobal(env, cxt, x, x'))                    (AST.D_Var(AST.VD_Decl(x', e')), insertGlobal(env, cxt, x, x'))
872                  end                  end
873                | PT.D_Func(ty, f, params, body) => let
874                    val ty' = checkTy(cxt, ty)
875                    val (params', env') = checkParams (env, cxt, params)
876                    val body' = (case body
877                           of PT.FB_Expr e => let
878                                val (e', ty) = checkExpr (env', cxt, e)
879                                in
880                                  case coerceType(ty', ty, e')
881                                   of SOME e' => AST.S_Return e'
882                                    | NONE => err(cxt, [
883                                          S "type of function body does not match return type\n",
884                                          S "  expected: ", TY ty', S "\n",
885                                          S "  but found: ", TY ty
886                                        ])
887                                  (* end case *)
888                                end
889    (* FIXME: we need to check that there is a return on all control-flow paths *)
890                            | PT.FB_Stmt s => #1(checkStmt(functionScope (env', ty'), cxt, s))
891                          (* end case *))
892                    val fnTy = Ty.T_Fun(List.map Var.monoTypeOf params', ty')
893                    val f' = Var.new (f, AST.FunVar, fnTy)
894                    in
895    (* QUESTION: should we check for redefinition of the f? *)
896                      (AST.D_Func(f', params', body'), insertFunc(env, cxt, f, f'))
897                    end
898              | PT.D_Strand arg => (checkStrand(strandScope env, cxt, arg), env)              | PT.D_Strand arg => (checkStrand(strandScope env, cxt, arg), env)
899              | PT.D_InitialArray(create, iterators) => let              | PT.D_InitialArray(create, iterators) => let
900                  val env = initScope env                  val env = initScope env
# Line 856  Line 932 
932            in            in
933              chk ({scope=GlobalScope, bindings=AtomMap.empty, env=Basis.env}, tree, [])              chk ({scope=GlobalScope, bindings=AtomMap.empty, env=Basis.env}, tree, [])
934            end            end
935                handle TypeError => AST.Program[]
936    
937    end    end

Legend:
Removed from v.2151  
changed lines
  Added in v.2152

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0