SCM Repository
[diderot] / trunk / src / compiler / typechecker / typechecker.sml |
View of /trunk/src/compiler/typechecker/typechecker.sml
Parent Directory
|
Revision Log
Revision 110 -
(download)
(annotate)
Wed Jun 23 19:28:48 2010 UTC (12 years ago) by jhr
File size: 15525 byte(s)
Wed Jun 23 19:28:48 2010 UTC (12 years ago) by jhr
File size: 15525 byte(s)
Moving compiler sources into src/compiler
(* typechecker.sml
 *
 * COPYRIGHT (c) 2010 The Diderot Project (http://diderot.cs.uchicago.edu)
 * All rights reserved.
 *
 * Translates a parse tree into a typed AST.  Type errors are reported on the
 * error stream (with the span of the enclosing marked node) and then abort
 * checking by raising the Error exception.
 *)

structure Typechecker : sig

    exception Error

    val check : Error.err_stream -> ParseTree.program -> AST.program

  end = struct

    structure PT = ParseTree
    structure Ty = Types
    structure TU = TypeUtil
    structure U = Util

    exception Error

  (* a checking context is the error stream paired with the span of the
   * innermost enclosing marked parse-tree node.
   *)
    type context = Error.err_stream * Error.span

  (* unwrap a marked node: replace the span in the context with the mark's
   * span and return the underlying tree.
   *)
    fun withContext ((errStrm, _), {span, tree}) =
          ((errStrm, span), tree)

  (* same as withContext, but threads an environment through as well *)
    fun withEnvAndContext (env, (errStrm, _), {span, tree}) =
          (env, (errStrm, span), tree)

  (* report a message (a list of strings) at the context's span and abort
   * by raising Error; this function never returns normally.
   *)
    fun error ((errStrm, span), msg) = (
          Error.errorAt(errStrm, span, msg);
          raise Error)

  (* pieces of an error message *)
    datatype token
      = S of string | A of Atom.atom
      | V of AST.var | TY of Types.ty | TYS of Types.ty list

  (* convert a list of message tokens to strings and report them with error;
   * a list of types is printed as "()", a single type, or a "*"-separated
   * parenthesized tuple.
   *)
    fun err (cxt, toks) = let
          fun tok2str (S s) = s
            | tok2str (A a) = Atom.toString a
            | tok2str (V x) = Var.nameOf x
            | tok2str (TY ty) = TU.toString ty
            | tok2str (TYS []) = "()"
            | tok2str (TYS[ty]) = TU.toString ty
            | tok2str (TYS tys) = String.concat[
                  "(", String.concatWith " * " (List.map TU.toString tys), ")"
                ]
          in
            error(cxt, List.map tok2str toks)
          end

    val realZero = AST.E_Lit(Literal.Float(FloatLit.zero true))

  (* check a differentiation level, which must be >= 0; an invalid level is
   * reported at the context's span.
   *)
    fun checkDiff (cxt, k) =
          if (k < 0)
            then err (cxt, [S "differentiation must be >= 0"])
            else Ty.DiffConst(IntInf.toInt k)

  (* check a dimension, which must be > 0; an invalid dimension is reported
   * at the context's span.
   *)
    fun checkDim (cxt, d) =
          if (d <= 0)
            then err (cxt, [S "invalid dimension; must be > 0"])
            else Ty.DimConst(IntInf.toInt d)

  (* check a shape; every dimension of the shape must be >= 1 *)
    fun checkShape (cxt, shape) = let
          fun chkDim d =
                if (d < 1)
                  then err (cxt, [S "invalid shape dimension; must be >= 1"])
                  else Ty.DimConst(IntInf.toInt d)
          in
            Ty.Shape(List.map chkDim shape)
          end

  (* check the well-formedness of a type and translate it to an AST type *)
    fun checkTy (cxt, ty) = (case ty
           of PT.T_Mark m => checkTy(withContext(cxt, m))
            | PT.T_Bool => Ty.T_Bool
            | PT.T_Int => Ty.T_Int
            | PT.T_Real => Ty.realTy
            | PT.T_String => Ty.T_String
            | PT.T_Vec n => (* NOTE: the parser guarantees that 2 <= n <= 4 *)
                Ty.vecTy(IntInf.toInt n)
            | PT.T_Kernel k => Ty.T_Kernel(checkDiff(cxt, k))
            | PT.T_Field{diff, dim, shape} => Ty.T_Field{
                  diff = checkDiff (cxt, diff),
                  dim = checkDim (cxt, dim),
                  shape = checkShape (cxt, shape)
                }
            | PT.T_Tensor shape => Ty.T_Tensor(checkShape(cxt, shape))
            | PT.T_Image{dim, shape} => Ty.T_Image{
                  dim = checkDim (cxt, dim),
                  shape = checkShape (cxt, shape)
                }
            | PT.T_Array(ty, dims) => raise Fail "Array type not supported"
          (* end case *))

  (* pair a literal expression with the type determined by its form *)
    fun checkLit lit = (case lit
           of (Literal.Int _) => (AST.E_Lit lit, Ty.T_Int)
            | (Literal.Float _) => (AST.E_Lit lit, Ty.realTy)
            | (Literal.String s) => (AST.E_Lit lit, Ty.T_String)
            | (Literal.Bool _) => (AST.E_Lit lit, Ty.T_Bool)
          (* end case *))

  (* resolve overloading: we use a simple scheme that selects the first operator in the
   * list that matches the argument types.  An empty (or exhausted) candidate list is
   * reported as an unresolved operator.
   *)
    fun resolveOverload (cxt, rator, argTys, args, candidates) = let
          fun tryCandidates [] = err(cxt, [
                  S "unable to resolve overloaded operator \"", A rator, S "\"\n",
                  S " argument type is: ", TYS argTys, S "\n"
                ])
            | tryCandidates (x::xs) = let
                val (tyArgs, Ty.T_Fun(domTy, rngTy)) = Util.instantiate(Var.typeOf x)
                in
                  if U.tryMatchTypes(domTy, argTys)
                    then (AST.E_Apply(x, tyArgs, args, rngTy), rngTy)
                    else tryCandidates xs
                end
          in
            tryCandidates candidates
          end

  (* typecheck an expression and translate it to AST; the result is the AST
   * expression paired with its type.
   *)
    fun checkExpr (env, cxt, e) = (case e
           of PT.E_Mark m => checkExpr (withEnvAndContext (env, cxt, m))
            | PT.E_Var x => (case Env.findVar (env, x)
                 of SOME x' => let
                      val (args, ty) = Util.instantiate(Var.typeOf x')
                      in
                        (AST.E_Var(x', args, ty), ty)
                      end
                  | NONE => err(cxt, [S "undeclared variable ", A x])
                (* end case *))
            | PT.E_Lit lit => checkLit lit
            | PT.E_OrElse(e1, e2) => let
                val (e1', ty1) = checkExpr(env, cxt, e1)
                val (e2', ty2) = checkExpr(env, cxt, e2)
                in
                  case (ty1, ty2)
                   of (Ty.T_Bool, Ty.T_Bool) =>
                      (* "e1 || e2" becomes "if e1 then true else e2" *)
                        (AST.E_Cond(e1', AST.E_Lit(Literal.Bool true), e2'), Ty.T_Bool)
                    | _ => err (cxt, [S "arguments to \"||\" must have bool type"])
                  (* end case *)
                end
            | PT.E_AndAlso(e1, e2) => let
                val (e1', ty1) = checkExpr(env, cxt, e1)
                val (e2', ty2) = checkExpr(env, cxt, e2)
                in
                  case (ty1, ty2)
                   of (Ty.T_Bool, Ty.T_Bool) =>
                      (* "e1 && e2" becomes "if e1 then e2 else false" *)
                        (AST.E_Cond(e1', e2', AST.E_Lit(Literal.Bool false)), Ty.T_Bool)
                    | _ => err (cxt, [S "arguments to \"&&\" must have bool type"])
                  (* end case *)
                end
            | PT.E_BinOp(e1, rator, e2) => let
                val (e1', ty1) = checkExpr(env, cxt, e1)
                val (e2', ty2) = checkExpr(env, cxt, e2)
                in
                  case Basis.findOp rator
                   of [rator] => let
                        val (tyArgs, Ty.T_Fun(domTy, rngTy)) =
                              Util.instantiate(Var.typeOf rator)
                        in
                          if U.matchTypes(domTy, [ty1, ty2])
                            then (AST.E_Apply(rator, tyArgs, [e1', e2'], rngTy), rngTy)
                            else err (cxt, [
                                S "type error for binary operator \"", V rator, S "\"\n",
                                S " expected: ", TYS domTy, S "\n",
                                S " but found: ", TYS[ty1, ty2], S "\n"
                              ])
                        end
                    | ovldList => resolveOverload (cxt, rator, [ty1, ty2], [e1', e2'], ovldList)
                  (* end case *)
                end
            | PT.E_UnaryOp(rator, e) => let
                val (e', ty) = checkExpr(env, cxt, e)
                in
                  case Basis.findOp rator
                   of [rator] => let
                        val (tyArgs, Ty.T_Fun([domTy], rngTy)) =
                              Util.instantiate(Var.typeOf rator)
                        in
                          if U.matchType(domTy, ty)
                            then (AST.E_Apply(rator, tyArgs, [e'], rngTy), rngTy)
                            else err (cxt, [
                                S "type error for unary operator \"", V rator, S "\"\n",
                                S " expected: ", TY domTy, S "\n",
                                S " but found: ", TY ty, S "\n"
                              ])
                        end
                    | ovldList => resolveOverload (cxt, rator, [ty], [e'], ovldList)
                  (* end case *)
                end
            | PT.E_Tuple args => let
              (* the components are still checked, so errors in them are reported first *)
                val (args, tys) = checkExprList (env, cxt, args)
                in
                  raise Fail "E_Tuple not yet implemented"
                end
            | PT.E_Apply(f, args) => let
                val (args, tys) = checkExprList (env, cxt, args)
                in
                  case Env.findVar (env, f)
                   of SOME f => (case Util.instantiate(Var.typeOf f)
                         of (tyArgs, Ty.T_Fun(domTy, rngTy)) =>
                              if U.matchTypes(domTy, tys)
                                then (AST.E_Apply(f, tyArgs, args, rngTy), rngTy)
                                else err(cxt, [
                                    S "type error in application of ", V f, S "\n",
                                    S " expected: ", TYS domTy, S "\n",
                                    S " but found: ", TYS tys, S "\n"
                                  ])
                          | _ => err(cxt, [S "application of non-function ", V f])
                        (* end case *))
                    | NONE => err(cxt, [S "unknown function ", A f])
                  (* end case *)
                end
            | PT.E_Cons args => (case checkExprList (env, cxt, args)
                 of (_, []) =>
                    (* an empty constructor has no element type to build a tensor from *)
                      err(cxt, [S "tensor construction requires at least one argument"])
                  | (args, ty::tys) => (case TU.pruneHead ty
                       of Ty.T_Tensor shape => let
                            fun chkTy ty' = U.matchType(ty, ty')
                          (* the result shape is the element shape extended by the
                           * number of arguments.
                           *)
                            val resTy = Ty.T_Tensor(
                                  Ty.shapeExt(shape, Ty.DimConst(List.length args)))
                            in
                              if List.all chkTy tys
                                then (AST.E_Cons args, resTy)
                                else err(cxt, [
                                    S "arguments of tensor construction must have same type"
                                  ])
                            end
                        | _ => err(cxt, [S "Invalid argument type for tensor construction"])
                      (* end case *))
                (* end case *))
            | PT.E_Real e => (case checkExpr (env, cxt, e)
                 of (e', Ty.T_Int) =>
                      (AST.E_Apply(BasisVars.i2r, [], [e'], Ty.realTy), Ty.realTy)
                  | _ => err(cxt, [S "argument of real conversion must be int"])
                (* end case *))
          (* end case *))

  (* typecheck a list of expressions returning a list of AST expressions and a list
   * of types of the expressions.
   *)
    and checkExprList (env, cxt, exprs) = let
          fun chk (e, (es, tys)) = let
                val (e, ty) = checkExpr (env, cxt, e)
                in
                  (e::es, ty::tys)
                end
          in
            List.foldr chk ([], []) exprs
          end

  (* typecheck a variable declaration of the given kind; returns the source
   * name, the new AST variable, and the checked initializer.  The caller is
   * responsible for adding the binding to the environment.
   *)
    fun checkVarDecl (env, cxt, kind, d) = (case d
           of PT.VD_Mark m => checkVarDecl (env, (#1 cxt, #span m), kind, #tree m)
            | PT.VD_Decl(ty, x, e) => let
                val ty = checkTy (cxt, ty)
                val x' = Var.new (x, kind, ty)
                val (e', ty') = checkExpr (env, cxt, e)
                in
(* FIXME: this check is not flexible enough; should allow lhs type to support
 * fewer levels of differentiation than rhs provides.
 *)
                  if U.matchType(ty, ty')
                    then (x, x', e')
                    else err(cxt, [
                        S "type of variable ", A x,
                        S " does not match type of initializer\n",
                        S " expected: ", TY ty, S "\n",
                        S " but found: ", TY ty', S "\n"
                      ])
                end
          (* end case *))

  (* typecheck a statement and translate it to AST; the result is the AST
   * statement paired with the environment for the statements that follow it.
   *)
    fun checkStmt (env, cxt, s) = (case s
           of PT.S_Mark m => checkStmt (withEnvAndContext (env, cxt, m))
            | PT.S_Block stms => let
              (* thread the environment through the block's statements *)
                fun chk (_, [], stms) = AST.S_Block(List.rev stms)
                  | chk (env, s::ss, stms) = let
                      val (s', env') = checkStmt (env, cxt, s)
                      in
                        chk (env', ss, s'::stms)
                      end
                in
                (* bindings made inside the block do not escape: return the original env *)
                  (chk (env, stms, []), env)
                end
            | PT.S_Decl vd => let
                val (x, x', e) = checkVarDecl (env, cxt, Var.LocalVar, vd)
                in
                  (AST.S_Decl(AST.VD_Decl(x', e)), Env.insertLocal(env, x, x'))
                end
            | PT.S_IfThen(e, s) => let
                val (e', ty) = checkExpr (env, cxt, e)
                val (s', _) = checkStmt (env, cxt, s)
                in
                (* check that condition has bool type *)
                  case ty
                   of Ty.T_Bool => ()
                    | _ => err(cxt, [S "condition not boolean type"])
                  (* end case *);
                (* a missing else branch becomes an empty block *)
                  (AST.S_IfThenElse(e', s', AST.S_Block[]), env)
                end
            | PT.S_IfThenElse(e, s1, s2) => let
                val (e', ty) = checkExpr (env, cxt, e)
                val (s1', _) = checkStmt (env, cxt, s1)
                val (s2', _) = checkStmt (env, cxt, s2)
                in
                (* check that condition has bool type *)
                  case ty
                   of Ty.T_Bool => ()
                    | _ => err(cxt, [S "condition not boolean type"])
                  (* end case *);
                  (AST.S_IfThenElse(e', s1', s2'), env)
                end
            | PT.S_Assign(x, e) => (case Env.findVar (env, x)
                 of NONE => err(cxt, [
                        S "undefined variable ", A x
                      ])
                  | SOME x' => let
(* FIXME: check for polymorphic variables *)
                      val ([], ty) = Var.typeOf x'
                      val (e', ty') = checkExpr (env, cxt, e)
                      in
                      (* check that the rhs type matches the variable's type *)
                        if U.matchType(ty, ty')
                          then ()
                          else err(cxt, [
                              S "type of assigned variable ", A x,
                              S " does not match type of rhs\n",
                              S " expected: ", TY ty, S "\n",
                              S " but found: ", TY ty', S "\n"
                            ]);
                      (* check that x' is mutable *)
                        case Var.kindOf x'
                         of Var.ActorStateVar => ()
                          | Var.LocalVar => ()
                          | _ => err(cxt, [
                                S "assignment to immutable variable ", A x
                              ])
                        (* end case *);
                        (AST.S_Assign(x', e'), env)
                      end
                (* end case *))
            | PT.S_New(actor, args) => let
                val argsAndTys' = List.map (fn e => checkExpr(env, cxt, e)) args
                val (args', tys') = ListPair.unzip argsAndTys'
                in
(* FIXME: check that actor is defined and has the argument types match *)
                  (AST.S_New(actor, args'), env)
                end
            | PT.S_Die => (AST.S_Die, env)
            | PT.S_Stabilize => (AST.S_Stabilize, env)
          (* end case *))

  (* typecheck actor parameters; returns the list of parameter variables and
   * the environment extended with them.
   *)
    fun checkParams (env, cxt, params) = let
          fun chkParam (env, cxt, param) = (case param
                 of PT.P_Mark m => chkParam (withEnvAndContext (env, cxt, m))
                  | PT.P_Param(ty, x) => let
                      val x' = Var.new(x, AST.ActorParam, checkTy (cxt, ty))
                      in
                        (x', Env.insertLocal(env, x, x'))
                      end
                (* end case *))
          fun chk (param, (xs, env)) = let
                val (x, env) = chkParam (env, cxt, param)
                in
                  (x::xs, env)
                end
          in
(* FIXME: need to check for multiple occurences of the same parameter name! *)
            List.foldr chk ([], env) params
          end

  (* typecheck a method definition; the environment produced by the body is discarded *)
    fun checkMethod (env, cxt, meth) = (case meth
           of PT.M_Mark m => checkMethod (withEnvAndContext (env, cxt, m))
            | PT.M_Method(name, body) => let
                val (body, _) = checkStmt(env, cxt, body)
                in
                  AST.M_Method(name, body)
                end
          (* end case *))

  (* typecheck an actor definition: parameters first, then state variables
   * (each visible to the ones after it), then the methods in the resulting
   * environment.
   *)
    fun checkActor (env, cxt, {name, params, state, methods}) = let
        (* check the actor parameters *)
          val (params, env) = checkParams (env, cxt, params)
        (* check the actor state variable definitions *)
          val (vds, env) = let
                fun checkStateVar (vd, (vds, env)) = let
                      val (x, x', e') = checkVarDecl (env, cxt, AST.ActorStateVar, vd)
                      in
                        (AST.VD_Decl(x', e')::vds, Env.insertLocal(env, x, x'))
                      end
                val (vds, env) = List.foldl checkStateVar ([], env) state
                in
                (* foldl accumulates in reverse, so restore declaration order *)
                  (List.rev vds, env)
                end
        (* check the actor methods *)
          val methods = List.map (fn m => checkMethod (env, cxt, m)) methods
          in
            AST.D_Actor{name = name, params = params, state = vds, methods = methods}
          end

  (* typecheck the actor-creation part of an initially declaration *)
    fun checkCreate (env, cxt, PT.C_Mark m) = checkCreate (withEnvAndContext (env, cxt, m))
      | checkCreate (env, cxt, PT.C_Create(actor, args)) = let
          val (args, tys) = checkExprList (env, cxt, args)
          in
(* FIXME: check against actor definition *)
            AST.C_Create(actor, args)
          end

  (* typecheck one iterator; both bounds must have integer type.  Returns the
   * AST iterator and the environment extended with the iteration variable.
   *)
    fun checkIter (env, cxt, PT.I_Mark m) = checkIter (withEnvAndContext (env, cxt, m))
      | checkIter (env, cxt, PT.I_Range(x, e1, e2)) = let
        (* the bounds are checked in the incoming env, before x is bound *)
          val (e1', ty1) = checkExpr (env, cxt, e1)
          val (e2', ty2) = checkExpr (env, cxt, e2)
          val x' = Var.new(x, Var.LocalVar, Ty.T_Int)
          val env' = Env.insertLocal(env, x, x')
          in
            case (ty1, ty2)
             of (Ty.T_Int, Ty.T_Int) => (AST.I_Range(x', e1', e2'), env')
              | _ => err(cxt, [
                    S "range expressions must have integer type\n",
                    S " but found: ", TY ty1, S " .. ", TY ty2, S "\n"
                  ])
            (* end case *)
          end

  (* typecheck a list of iterators left to right, threading the environment
   * so that each iterator sees the variables of the preceding ones.
   *)
    fun checkIters (env, cxt, iters) = let
          fun chk (env, [], iters) = (List.rev iters, env)
            | chk (env, iter::rest, iters) = let
                val (iter, env) = checkIter (env, cxt, iter)
                in
                  chk (env, rest, iter::iters)
                end
          in
            chk (env, iters, [])
          end

  (* typecheck a top-level declaration; returns the AST declaration and the
   * environment for the declarations that follow it.
   *)
    fun checkDecl (env, cxt, d) = (case d
           of PT.D_Mark m => checkDecl (withEnvAndContext (env, cxt, m))
            | PT.D_Input(ty, x, optExp) => let
                val ty = checkTy(cxt, ty)
                val x' = Var.new(x, Var.InputVar, ty)
                val dcl = (case optExp
                       of NONE => AST.D_Input(x', NONE)
                        | SOME e => let
                            val (e', ty') = checkExpr (env, cxt, e)
                            in
                              if U.matchType (ty, ty')
                                then AST.D_Input(x', SOME e')
                                else err(cxt, [
                                    S "definition of ", V x', S " has wrong type\n",
                                    S " expected: ", TY ty, S "\n",
                                    S " but found: ", TY ty', S "\n"
                                  ])
                            end
                      (* end case *))
                in
                  (dcl, Env.insertGlobal(env, x, x'))
                end
            | PT.D_Var vd => let
                val (x, x', e') = checkVarDecl (env, cxt, Var.GlobalVar, vd)
                in
                  (AST.D_Var(AST.VD_Decl(x', e')), Env.insertGlobal(env, x, x'))
                end
            | PT.D_Actor arg => (checkActor(env, cxt, arg), env)
            | PT.D_InitialArray(create, iterators) => let
                val (iterators, env') = checkIters (env, cxt, iterators)
                val create = checkCreate (env', cxt, create)
                in
                (* the iteration variables are local to the declaration *)
                  (AST.D_InitialArray(create, iterators), env)
                end
            | PT.D_InitialCollection(create, iterators) => let
                val (iterators, env') = checkIters (env, cxt, iterators)
                val create = checkCreate (env', cxt, create)
                in
                (* the iteration variables are local to the declaration *)
                  (AST.D_InitialCollection(create, iterators), env)
                end
          (* end case *))

  (* typecheck a whole program, starting from the basis environment and
   * threading the environment through the top-level declarations in order.
   *)
    fun check errStrm (PT.Program{span, tree}) = let
          val cxt = (errStrm, span)
          fun chk (env, [], dcls') = AST.Program(List.rev dcls')
            | chk (env, dcl::dcls, dcls') = let
                val (dcl', env) = checkDecl (env, cxt, dcl)
                in
                  chk (env, dcls, dcl'::dcls')
                end
          in
            chk (Basis.env, tree, [])
          end

  end
root@smlnj-gforge.cs.uchicago.edu | ViewVC Help |
Powered by ViewVC 1.0.0 |