(* typechecker.sml
*
* COPYRIGHT (c) 2010 The Diderot Project (http://diderot.cs.uchicago.edu)
* All rights reserved.
*)

structure Typechecker : sig

  (* typecheck a parse-tree program and translate it to an AST program *)
    val check : ParseTree.program -> AST.program

  end = struct

    structure PT = ParseTree
    structure Ty = Types

  (* check a differentiation level, which must be >= 0, and return it as a
   * Ty.NatConst.  `cxt` is the current error context (checkTy passes the
   * span of the nearest enclosing T_Mark); it is not used yet.
   * Raises Fail if k < 0.  IntInf.toInt raises Overflow if k does not fit
   * in the default int type.
   *)
    fun checkDiff (cxt, k) =
          if (k < 0)
            then raise Fail "differentiation must be >= 0"
            else Ty.NatConst(IntInf.toInt k)

  (* check a dimension, which must be 2 or 3, and return it as a Ty.NatConst.
   * `cxt` is not used yet.  Raises Fail for any other value.
   *)
    fun checkDim (cxt, d) =
          if (d < 2) orelse (3 < d)
            then raise Fail "invalid dimension; must be 2 or 3"
            else Ty.NatConst(IntInf.toInt d)

  (* check a shape: every dimension must be >= 1.  Returns a Ty.Shape whose
   * dimensions are the corresponding Ty.NatConst values, in the same order.
   * An empty shape list is accepted and yields Ty.Shape [].  `cxt` is not
   * used yet.  Raises Fail if any dimension is < 1.
   *)
    fun checkShape (cxt, shape) = let
          fun chkDim d = if (d < 1)
                then raise Fail "invalid shape dimension; must be >= 1"
                else Ty.NatConst(IntInf.toInt d)
          in
            Ty.Shape(List.map chkDim shape)
          end

  (* check the well-formedness of a parse-tree type and translate it to the
   * corresponding Types (Ty) representation.  A T_Mark node replaces the
   * context with the mark's span before checking the wrapped type.
   * Raises Fail for a bad differentiation level, dimension, or shape (see
   * checkDiff, checkDim, and checkShape), and for array types, which are not
   * supported yet.
   *)
    fun checkTy (cxt, ty) = (case ty
           of PT.T_Mark m => checkTy(#span m, #tree m)
            | PT.T_Bool => Ty.T_Bool
            | PT.T_Int => Ty.T_Int
            | PT.T_Real => Ty.realTy
            | PT.T_String => Ty.T_String
            | PT.T_Vec n => (* NOTE: the parser guarantees that 2 <= n <= 4 *)
                Ty.vecTy(IntInf.toInt n)
            | PT.T_Kernel k => Ty.T_Kernel(checkDiff(cxt, k))
            | PT.T_Field{diff, dim, shape} => Ty.T_Field{
                  diff = checkDiff (cxt, diff),
                  dim = checkDim (cxt, dim),
                  shape = checkShape (cxt, shape)
                }
            | PT.T_Tensor shape => Ty.T_Tensor(checkShape(cxt, shape))
            | PT.T_Image{dim, shape} => Ty.T_Image{
                  dim = checkDim (cxt, dim),
                  shape = checkShape (cxt, shape)
                }
            | PT.T_Array(ty, dims) => raise Fail "Array type"
          (* end case *))

  (* typecheck a literal: returns the literal wrapped as an AST expression,
   * paired with its type (int, real, string, or bool).
   *)
    fun checkLit lit = (case lit
           of (Literal.Int _) => (AST.E_Lit lit, Ty.T_Int)
            | (Literal.Float _) => (AST.E_Lit lit, Ty.realTy)
            | (Literal.String _) => (AST.E_Lit lit, Ty.T_String)
            | (Literal.Bool _) => (AST.E_Lit lit, Ty.T_Bool)
          (* end case *))

  (* NOTE: the following is an unfinished sketch of the expression, statement,
   * and declaration checkers.  It is commented out because it does not compile
   * yet: several case arms are still copies of the ParseTree constructor
   * declarations (for example, "PT.E_BinOp of expr * var * expr").
   *)
(*
  (* typecheck an expression and translate it to AST *)
    fun checkExpr (env, cxt, e) = (case e
           of PT.E_Mark m => checkExpr (env, #span m, #tree m)
            | PT.E_Var x => (case Env.findVar (env, x)
                 of SOME x' => (case Var.typeOf x'
                       of ([], ty) => (E_Var x', ty)
                        | (tvs, ty) => raise Fail "unimplemented"
                      (* end case *))
                  | NONE => raise Fail "undefined variable"
                (* end case *))
            | PT.E_Lit lit => checkLit lit
            | PT.E_BinOp of expr * var * expr
            | PT.E_UnaryOp of var * expr
            | PT.E_Tuple of expr list
            | PT.E_Apply of var * expr list
            | PT.E_Cons of ty * expr list
            | PT.E_Diff of expr
            | PT.E_Norm of expr
          (* end case *))

  (* typecheck a statement and translate it to AST *)
    fun checkStmt (env, cxt, s) = (case s
           of PT.S_Mark m => checkStmt (env, #span m, #tree m)
            | PT.S_Block of stmt list
            | PT.S_Decl of var_decl
            | PT.S_IfThen of expr * stmt
            | PT.S_IfThenElse of expr * stmt * stmt
            | PT.S_Assign of var * expr
            | PT.S_New of var * expr list
            | PT.S_Die
            | PT.S_Stabilize
          (* end case *))

    fun checkDecl (env, cxt, d) = (case d
           of PT.D_Mark m => checkDecl (env, #span m, #tree m)
            | PT.D_Input(ty, x, optExp) => let
                val ty = checkTy(cxt, ty)
                val x' = Var.new(x, Var.InputVar, ty)
                val dcl = (case optExp
                       of NONE => AST.D_Input(x', NONE)
                        | SOME e => let
                            val (e', ty') = checkExpr (env, cxt, e)
                            in
                            (* FIXME: check types *)
                              AST.D_Input(x', SOME e')
                            end
                      (* end case *))
                in
                  (dcl, Env.insertGlobal(env, x, x'))
                end
            | PT.D_Var of var_decl			(* global variable decl *)
            | PT.D_Actor of {				(* actor decl *)
                  name : var,
                  params : param list,
                  state : var_decl list,
                  methods : method list
                }
            | PT.D_InitialArray of create * iter list
            | PT.D_InitialCollection of create * iter list
          (* end case *))
*)

  (* typecheck a parsed program and translate it to the AST.
   * NOTE: this is currently a stub: the declarations are ignored and an
   * empty AST program is always returned.
   *)
    fun check (PT.Program dcls) = AST.Program[]

  end

