Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] View of /branches/vis15/src/compiler/typechecker/check-expr.sml
ViewVC logotype

View of /branches/vis15/src/compiler/typechecker/check-expr.sml

Parent Directory | Revision Log


Revision 3398 - (download) (annotate)
Wed Nov 11 01:17:58 2015 UTC (3 years, 10 months ago) by jhr
File size: 15648 byte(s)
working on merge
(* check-expr.sml
 *
 * The typechecker for expressions.
 *
 * This code is part of the Diderot Project (http://diderot-language.cs.uchicago.edu)
 *
 * COPYRIGHT (c) 2015 The University of Chicago
 * All rights reserved.
 *)

structure CheckExpr : sig

  (* `check (env, cxt, e)` typechecks the parse-tree expression `e` and returns the
   * corresponding AST expression paired with its type.  Type errors are reported
   * through `TypeError.error`; in that case the result is a dummy expression whose
   * type is `Types.T_Error`.
   *)
    val check : Env.env * Env.context * ParseTree.expr -> (AST.expr * Types.ty)

  end = struct

    structure PT = ParseTree
    structure L = Literal
    structure E = Env
    structure Ty = Types
    structure BV = BasisVars
  (* short names used by the code below.  The body calls both `U.instantiate` and
   * `Util.instantiate` for the same operation, so `U` abbreviates `Util`.
   * NOTE(review): `TU` (prune, pruneHead, pruneShape) is assumed to be `TypeUtil` -- confirm.
   *)
    structure U = Util
    structure TU = TypeUtil

  (* an expression to return when there is a type error *)
    val bogusExp = (AST.E_Lit(L.Int 0), Ty.T_Error)

  (* report a type error and return `bogusExp` in place of the erroneous expression *)
    fun err arg = (TypeError.error arg; bogusExp)
    val warn = TypeError.warning

    datatype tokens = datatype TypeError.tokens

  (* check the type of a literal *)
    fun checkLit lit = (case lit
           of (L.Int _) => (AST.E_Lit lit, Ty.T_Int)
            | (L.Real _) => (AST.E_Lit lit, Ty.realTy)
            | (L.String _) => (AST.E_Lit lit, Ty.T_String)
            | (L.Bool _) => (AST.E_Lit lit, Ty.T_Bool)
          (* end case *))

  (* check the type of an expression, returning the AST expression and its type.
   *
   * NOTE(review): `realType`, `coerceType`, `checkShape`, `markUsed` and
   * `resolveOverload` are not defined in this structure; presumably they are to be
   * supplied as part of the ongoing merge -- confirm.
   *)
    fun check (env, cxt, e) = (case e
           of PT.E_Mark m => check (E.withEnvAndContext (env, cxt, m))
            | PT.E_Cond(e1, cond, e2) => let
                val eTy1 = check (env, cxt, e1)
                val eTy2 = check (env, cxt, e2)
                in
                  case check (env, cxt, cond)
                   of (cond', Ty.T_Bool) => (case Util.coerceType2(eTy1, eTy2)
                         of SOME(e1', e2', ty) => (AST.E_Cond(cond', e1', e2', ty), ty)
                          | NONE => err (cxt, [
                                S "types do not match in conditional expression\n",
                                S "  true branch:  ", TY(#2 eTy1), S "\n",
                                S "  false branch: ", TY(#2 eTy2)
                              ])
                        (* end case *))
                    | (_, ty') => err (cxt, [S "expected bool type, but found ", TY ty'])
                  (* end case *)
                end
            | PT.E_Range(e1, e2) => (case (check (env, cxt, e1), check (env, cxt, e2))
                 of ((e1', Ty.T_Int), (e2', Ty.T_Int)) => let
                      val resTy = Ty.T_Sequence(Ty.T_Int, NONE)
                      in
                        (AST.E_Apply(BV.range, [], [e1', e2'], resTy), resTy)
                      end
                  | ((_, Ty.T_Int), (_, ty2)) =>
                      err (cxt, [S "expected type 'int' on rhs of '..', but found ", TY ty2])
                  | ((_, ty1), (_, Ty.T_Int)) =>
                      err (cxt, [S "expected type 'int' on lhs of '..', but found ", TY ty1])
                  | ((_, ty1), (_, ty2)) => err (cxt, [
                        S "arguments of '..' must have type 'int', found ",
                        TY ty1, S " and ", TY ty2
                      ])
                (* end case *))
            | PT.E_OrElse(e1, e2) =>
                checkCondOp (env, cxt, e1, "||", e2,
                  fn (e1', e2') => AST.E_Cond(e1', AST.E_Lit(L.Bool true), e2', Ty.T_Bool))
            | PT.E_AndAlso(e1, e2) =>
                checkCondOp (env, cxt, e1, "&&", e2,
                  fn (e1', e2') => AST.E_Cond(e1', e2', AST.E_Lit(L.Bool false), Ty.T_Bool))
            | PT.E_BinOp(e1, rator, e2) => let
                val (e1', ty1) = check (env, cxt, e1)
                val (e2', ty2) = check (env, cxt, e2)
                in
                  if Atom.same(rator, BasisNames.op_dot)
                  (* we have to handle inner product as a special case, because our type
                   * system cannot express the constraint that the type is
                   *     ALL[sigma1, d1, sigma2] . tensor[sigma1, d1] * tensor[d1, sigma2] -> tensor[sigma1, sigma2]
                   *)
                    then (case (TU.prune ty1, TU.prune ty2)
                       of (Ty.T_Tensor(Ty.Shape(dd1 as _::_)), Ty.T_Tensor(Ty.Shape(d2::dd2))) => let
                          (* split the lhs shape into its prefix and its last dimension *)
                            val (dd1, d1) = let
                                  fun splitLast (prefix, [d]) = (List.rev prefix, d)
                                    | splitLast (prefix, d::dd) = splitLast (d::prefix, dd)
                                    | splitLast (_, []) = raise Fail "impossible"
                                  in
                                    splitLast ([], dd1)
                                  end
                            val (tyArgs, Ty.T_Fun(domTy, rngTy)) = Util.instantiate(Var.typeOf BV.op_inner)
                            val resTy = Ty.T_Tensor(Ty.Shape(dd1@dd2))
                            in
                              if U.equalDim(d1, d2)
                              andalso U.equalTypes(domTy, [ty1, ty2])
                              andalso U.equalType(rngTy, resTy)
                                then (AST.E_Apply(BV.op_inner, tyArgs, [e1', e2'], rngTy), rngTy)
                                else err (cxt, [
                                    S "type error for arguments of binary operator '•'\n",
                                    S "  found: ", TYS[ty1, ty2], S "\n"
                                  ])
                            end
                        | (ty1, ty2) => err (cxt, [
                              S "type error for arguments of binary operator '•'\n",
                              S "  found: ", TYS[ty1, ty2], S "\n"
                            ])
                      (* end case *))
                  else if Atom.same(rator, BasisNames.op_colon)
                    then (case (TU.prune ty1, TU.prune ty2)
                       of (Ty.T_Tensor(Ty.Shape(dd1 as _::_::_)), Ty.T_Tensor(Ty.Shape(d21::d22::dd2))) => let
                          (* split the lhs shape into its prefix and its last two dimensions *)
                            val (dd1, d11, d12) = let
                                  fun splitLast (prefix, [d1, d2]) = (List.rev prefix, d1, d2)
                                    | splitLast (prefix, d::dd) = splitLast (d::prefix, dd)
                                    | splitLast (_, []) = raise Fail "impossible"
                                  in
                                    splitLast ([], dd1)
                                  end
                            val (tyArgs, Ty.T_Fun(domTy, rngTy)) = Util.instantiate(Var.typeOf BV.op_colon)
                            val resTy = Ty.T_Tensor(Ty.Shape(dd1@dd2))
                            in
                              if U.equalDim(d11, d21) andalso U.equalDim(d12, d22)
                              andalso U.equalTypes(domTy, [ty1, ty2])
                              andalso U.equalType(rngTy, resTy)
                                then (AST.E_Apply(BV.op_colon, tyArgs, [e1', e2'], rngTy), rngTy)
                                else err (cxt, [
                                    S "type error for arguments of binary operator ':'\n",
                                    S "  found: ", TYS[ty1, ty2], S "\n"
                                  ])
                            end
                        | (ty1, ty2) => err (cxt, [
                              S "type error for arguments of binary operator ':'\n",
                              S "  found: ", TYS[ty1, ty2], S "\n"
                            ])
                      (* end case *))
                    else (case Env.findFunc (#env env, rator)
                       of Env.PrimFun[rator] => let
                            val (tyArgs, Ty.T_Fun(domTy, rngTy)) = Util.instantiate(Var.typeOf rator)
                            in
                              case U.matchArgs(domTy, [e1', e2'], [ty1, ty2])
                               of SOME args => (AST.E_Apply(rator, tyArgs, args, rngTy), rngTy)
                                | NONE => err (cxt, [
                                      S "type error for binary operator '", V rator, S "'\n",
                                      S "  expected:  ", TYS domTy, S "\n",
                                      S "  but found: ", TYS[ty1, ty2]
                                    ])
                              (* end case *)
                            end
                        | Env.PrimFun ovldList =>
                            resolveOverload (cxt, rator, [ty1, ty2], [e1', e2'], ovldList)
                        | _ => raise Fail "impossible"
                      (* end case *))
                end
            | PT.E_UnaryOp(rator, e) => let
                val (e', ty) = check (env, cxt, e)
                in
                  case Env.findFunc (#env env, rator)
                   of Env.PrimFun[rator] => let
                        val (tyArgs, Ty.T_Fun([domTy], rngTy)) = U.instantiate(Var.typeOf rator)
                        in
                          case coerceType (domTy, ty, e')
                           of SOME e' => (AST.E_Apply(rator, tyArgs, [e'], rngTy), rngTy)
                            | NONE => err (cxt, [
                                  S "type error for unary operator \"", V rator, S "\"\n",
                                  S "  expected:  ", TY domTy, S "\n",
                                  S "  but found: ", TY ty
                                ])
                          (* end case *)
                        end
                    | Env.PrimFun ovldList => resolveOverload (cxt, rator, [ty], [e'], ovldList)
                    | _ => raise Fail "impossible"
                  (* end case *)
                end
            | PT.E_Apply(e, args) => raise Fail "FIXME"
            | PT.E_Subscript(e, indices) => (case (check(env, cxt, e), indices)
                 of ((e', Ty.T_Sequence(elemTy, _)), [SOME e2]) => raise Fail "FIXME"
                  | ((e', Ty.T_Tensor shape), _) => raise Fail "FIXME"
                  | ((_, ty), _) => err(cxt, [
                        S "expected sequence or tensor type for object of subscripting, but found ",
                        TY ty
                      ])
                (* end case *))
            | PT.E_Select(e, field) => (case check(env, cxt, e)
                 of (e', Ty.T_Strand strand) => (case Env.findStrand(#env env, strand)
                       of SOME(AST.Strand{name, state, ...}) => let
                          (* is the state-variable declaration named `field`? *)
                            fun isField (AST.VD_Decl(AST.V{name=x, ...}, _)) = Atom.same(x, field)
                            in
                              case List.find isField state
                               of SOME(AST.VD_Decl(x', _)) => let
                                    val ty = Var.monoTypeOf x'
                                    in
                                      (AST.E_Selector(e', field, ty), ty)
                                    end
                                | NONE => err(cxt, [
                                      S "strand ", A name,
                                      S " does not have state variable ", A field
                                    ])
                              (* end case *)
                            end
                        | NONE => err(cxt, [S "unknown strand ", A strand])
                      (* end case *))
                  | (_, ty) => err (cxt, [
                        S "expected strand type, but found ", TY ty,
                        S " in selection of ", A field
                      ])
                (* end case *))
            | PT.E_Real e => (case check (env, cxt, e)
                 of (e', Ty.T_Int) =>
                      (AST.E_Apply(BV.i2r, [], [e'], Ty.realTy), Ty.realTy)
                  | (_, ty) => err(cxt, [
                        S "argument of 'real' must have type 'int', but found ",
                        TY ty
                      ])
                (* end case *))
          (* `load(...)` takes its type from the `load` basis function and `image(...)`
           * from the `image` basis function.
           *)
            | PT.E_Load nrrd => let
                val (tyArgs, Ty.T_Fun(_, rngTy)) = Util.instantiate(Var.typeOf(BV.fn_load))
                in
                  (AST.E_LoadNrrd(tyArgs, nrrd, rngTy), rngTy)
                end
            | PT.E_Image nrrd => let
                val (tyArgs, Ty.T_Fun(_, rngTy)) = Util.instantiate(Var.typeOf(BV.fn_image))
                in
                  (AST.E_LoadNrrd(tyArgs, nrrd, rngTy), rngTy)
                end
            | PT.E_Var x => (case E.findVar (#env env, x)
                 of SOME x' => (
                      markUsed (x', true);
                      (AST.E_Var x', Var.monoTypeOf x'))
                  | NONE => err(cxt, [S "undeclared variable ", A x])
                (* end case *))
            | PT.E_Kernel(kern, dim) => raise Fail "FIXME"
            | PT.E_Lit lit => checkLit lit
            | PT.E_Id d => let
                val (tyArgs, Ty.T_Fun(_, rngTy)) =
                      Util.instantiate(Var.typeOf(BV.identity))
                in
                  if U.equalType(Ty.T_Tensor(checkShape(cxt, [d,d])), rngTy)
                    then (AST.E_Apply(BV.identity, tyArgs, [], rngTy), rngTy)
                    else raise Fail "impossible"
                end
            | PT.E_Zero dd => let
                val (tyArgs, Ty.T_Fun(_, rngTy)) =
                      Util.instantiate(Var.typeOf(BV.zero))
                in
                  if U.equalType(Ty.T_Tensor(checkShape(cxt, dd)), rngTy)
                    then (AST.E_Apply(BV.zero, tyArgs, [], rngTy), rngTy)
                    else raise Fail "impossible"
                end
            | PT.E_NaN dd => let
                val (tyArgs, Ty.T_Fun(_, rngTy)) =
                      Util.instantiate(Var.typeOf(BV.nan))
                in
                  if U.equalType(Ty.T_Tensor(checkShape(cxt, dd)), rngTy)
                    then (AST.E_Apply(BV.nan, tyArgs, [], rngTy), rngTy)
                    else raise Fail "impossible"
                end
            | PT.E_Sequence exps => raise Fail "FIXME"
            | PT.E_SeqComp comp => raise Fail "FIXME"
            | PT.E_Cons args => let
              (* Note that we are guaranteed that args is non-empty *)
                val (args, tys) = checkList (env, cxt, args)
              (* extract the first non-error type in tys *)
                val ty = (case List.find (fn Ty.T_Error => false | _ => true) tys
                       of NONE => Ty.T_Error
                        | SOME ty => ty
                      (* end case *))
                in
                  case realType(TU.pruneHead ty)
                   of ty as Ty.T_Tensor shape => let
                        val Ty.Shape dd = TU.pruneShape shape (* NOTE: this may fail if we allow user polymorphism *)
                        val resTy = Ty.T_Tensor(Ty.Shape(Ty.DimConst(List.length args) :: dd))
                      (* coerce each argument to the element type `ty`, accumulating the
                       * coerced arguments in reverse order; the first argument that cannot
                       * be coerced is reported as an error.
                       *)
                        fun chkArgs (arg::args, argTy::tys, args') = (case coerceType(ty, argTy, arg)
                               of SOME arg' => chkArgs (args, tys, arg'::args')
                                | NONE => err (cxt, [
                                      S "arguments of tensor construction must have same type"
                                    ])
                              (* end case *))
                          | chkArgs ([], [], args') = (AST.E_Cons(List.rev args', resTy), resTy)
                        (* checkList returns one type per expression, so the lists always
                         * have the same length
                         *)
                          | chkArgs _ = raise Fail "impossible"
                        in
                          chkArgs (args, tys, [])
                        end
                    | _ => err(cxt, [S "Invalid argument type for tensor construction"])
                  (* end case *)
                end
            | PT.E_Deprecate(msg, e) => (
                warn (cxt, [S msg]);
                check (env, cxt, e))
          (* end case *))

  (* check a conditional operator (e.g., || or &&).  Both operands must have type
   * bool; `mk` builds the AST for the operator from the checked operands and
   * `rator` is the operator's name, used in error messages.
   *)
    and checkCondOp (env, cxt, e1, rator, e2, mk) = (
          case (check(env, cxt, e1), check(env, cxt, e2))
           of ((e1', Ty.T_Bool), (e2', Ty.T_Bool)) => (mk(e1', e2'), Ty.T_Bool)
            | ((_, Ty.T_Bool), (_, ty2)) =>
                err (cxt, [S "expected type 'bool' on rhs of '", S rator, S "', but found ", TY ty2])
            | ((_, ty1), (_, Ty.T_Bool)) =>
                err (cxt, [S "expected type 'bool' on lhs of '", S rator, S "', but found ", TY ty1])
            | ((_, ty1), (_, ty2)) => err (cxt, [
                  S "arguments of '", S rator, S "' must have type 'bool', but found ",
                  TY ty1, S " and ", TY ty2
                ])
          (* end case *))

  (* typecheck a list of expressions returning a list of AST expressions and a list
   * of the types of the expressions (both in the same order as the input).
   *)
    and checkList (env, cxt, exprs) = let
          fun chk (e, (es, tys)) = let
                val (e', ty) = check (env, cxt, e)
                in
                  (e'::es, ty::tys)
                end
          in
            List.foldr chk ([], []) exprs
          end

  end

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0