Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] Annotation of /branches/vis15/src/compiler/typechecker/check-expr.sml
ViewVC logotype

Annotation of /branches/vis15/src/compiler/typechecker/check-expr.sml

Parent Directory Parent Directory | Revision Log Revision Log


Revision 3396 - (view) (download)

(* check-expr.sml
 *
 * The typechecker for expressions.
 *
 * This code is part of the Diderot Project (http://diderot-language.cs.uchicago.edu)
 *
 * COPYRIGHT (c) 2015 The University of Chicago
 * All rights reserved.
 *)

structure CheckExpr : sig

  (* `check (env, cxt, e)` typechecks the parse-tree expression `e` and returns its
   * AST translation paired with its type.  Type errors are reported through
   * TypeError.error, in which case a placeholder expression of type Ty.T_Error is
   * returned so that checking can continue.
   *)
    val check : Env.env * Env.context * ParseTree.expr -> (AST.expr * Types.ty)

  end = struct

    structure PT = ParseTree
    structure L = Literal
    structure E = Env
    structure Ty = Types
    structure BV = BasisVars

  (* NOTE(review): TU, U, withEnvAndContext, realType, coerceType, checkShape, markUsed
   * and resolveOverload are referenced below but are not defined in this file;
   * presumably they come from other modules/bindings -- confirm before building.
   *)

  (* an expression to return when there is a type error *)
    val bogusExp = (AST.E_Lit(L.Int 0), Ty.T_Error)

  (* report a type error and return bogusExp so that checking can continue *)
    fun err arg = (TypeError.error arg; bogusExp)
    val warn = TypeError.warning

    datatype tokens = datatype TypeError.tokens

  (* used for parse-tree forms that this checker does not handle yet; the result
   * type is polymorphic because the function always raises.
   *)
    fun notYet form = raise Fail(concat["CheckExpr.check: ", form, " not yet implemented"])

  (* check the type of a literal; every literal translates to an AST.E_Lit *)
    fun checkLit lit = let
          val ty = (case lit
                 of L.Int _ => Ty.T_Int
                  | L.Real _ => Ty.realTy
                  | L.String _ => Ty.T_String
                  | L.Bool _ => Ty.T_Bool
                (* end case *))
          in
            (AST.E_Lit lit, ty)
          end

  (* check the type of an expression *)
    fun check (env, cxt, e) = (case e
           of PT.E_Mark m => check (withEnvAndContext (env, cxt, m))
            | PT.E_Cond(e1, cond, e2) => let
              (* both branches are checked before the condition *)
                val eTy1 = check (env, cxt, e1)
                val eTy2 = check (env, cxt, e2)
                in
                  case check (env, cxt, cond)
                   of (cond', Ty.T_Bool) => (case Util.coerceType2(eTy1, eTy2)
                         of SOME(e1', e2', ty) => (AST.E_Cond(cond', e1', e2', ty), ty)
                          | NONE => err (cxt, [
                                S "types do not match in conditional expression\n",
                                S " true branch: ", TY(#2 eTy1), S "\n",
                                S " false branch: ", TY(#2 eTy2)
                              ])
                        (* end case *))
                    | (_, ty') => err (cxt, [S "expected bool type, but found ", TY ty'])
                  (* end case *)
                end
            | PT.E_Range(e1, e2) => (case (check (env, cxt, e1), check (env, cxt, e2))
                 of ((e1', Ty.T_Int), (e2', Ty.T_Int)) => let
                      val resTy = Ty.T_DynSequence Ty.T_Int
                      in
                        (AST.E_Apply(BV.range, [], [e1', e2'], resTy), resTy)
                      end
                  | ((_, Ty.T_Int), (_, ty2)) =>
                      err (cxt, [S "expected type 'int' on rhs of '..', but found ", TY ty2])
                  | ((_, ty1), (_, Ty.T_Int)) =>
                      err (cxt, [S "expected type 'int' on lhs of '..', but found ", TY ty1])
                  | ((_, ty1), (_, ty2)) => err (cxt, [
                        S "arguments of '..' must have type 'int', found ",
                        TY ty1, S " and ", TY ty2
                      ])
                (* end case *))
          (* `a || b` is translated to `if a then true else b` *)
            | PT.E_OrElse(e1, e2) =>
                checkCondOp (env, cxt, e1, "||", e2,
                  fn (e1', e2') => AST.E_Cond(e1', AST.E_Lit(L.Bool true), e2', Ty.T_Bool))
          (* `a && b` is translated to `if a then b else false` *)
            | PT.E_AndAlso(e1, e2) =>
                checkCondOp (env, cxt, e1, "&&", e2,
                  fn (e1', e2') => AST.E_Cond(e1', e2', AST.E_Lit(L.Bool false), Ty.T_Bool))
            | PT.E_BinOp(e1, rator, e2) => let
                val (e1', ty1) = check (env, cxt, e1)
                val (e2', ty2) = check (env, cxt, e2)
                in
                  if Atom.same(rator, BasisNames.op_dot)
                  (* we have to handle inner product as a special case, because our type
                   * system cannot express the constraint that the type is
                   *   ALL[sigma1, d1, sigma2] . tensor[sigma1, d1] * tensor[d1, sigma2] -> tensor[sigma1, sigma2]
                   *)
                    then (case (TU.prune ty1, TU.prune ty2)
                       of (Ty.T_Tensor(Ty.Shape(dd1 as _::_)), Ty.T_Tensor(Ty.Shape(d2::dd2))) => let
                          (* split the lhs shape into its leading dimensions and its last one *)
                            val (dd1, d1) = let
                                  fun splitLast (prefix, [d]) = (List.rev prefix, d)
                                    | splitLast (prefix, d::dd) = splitLast (d::prefix, dd)
                                    | splitLast (_, []) = raise Fail "impossible"
                                  in
                                    splitLast ([], dd1)
                                  end
                            val (tyArgs, Ty.T_Fun(domTy, rngTy)) = Util.instantiate(Var.typeOf BV.op_inner)
                            val resTy = Ty.T_Tensor(Ty.Shape(dd1@dd2))
                            in
                              if U.equalDim(d1, d2)
                              andalso U.equalTypes(domTy, [ty1, ty2])
                              andalso U.equalType(rngTy, resTy)
                                then (AST.E_Apply(BV.op_inner, tyArgs, [e1', e2'], rngTy), rngTy)
                                else err (cxt, [
                                    S "type error for arguments of binary operator '•'\n",
                                    S " found: ", TYS[ty1, ty2], S "\n"
                                  ])
                            end
                        | (ty1, ty2) => err (cxt, [
                              S "type error for arguments of binary operator '•'\n",
                              S " found: ", TYS[ty1, ty2], S "\n"
                            ])
                      (* end case *))
                  else if Atom.same(rator, BasisNames.op_colon)
                  (* double contraction: the last two dimensions of the lhs must match the
                   * first two dimensions of the rhs
                   *)
                    then (case (TU.prune ty1, TU.prune ty2)
                       of (Ty.T_Tensor(Ty.Shape(dd1 as _::_::_)), Ty.T_Tensor(Ty.Shape(d21::d22::dd2))) => let
                          (* split the lhs shape into its leading dimensions and its last two *)
                            val (dd1, d11, d12) = let
                                  fun splitLast (prefix, [d1, d2]) = (List.rev prefix, d1, d2)
                                    | splitLast (prefix, d::dd) = splitLast (d::prefix, dd)
                                    | splitLast (_, []) = raise Fail "impossible"
                                  in
                                    splitLast ([], dd1)
                                  end
                            val (tyArgs, Ty.T_Fun(domTy, rngTy)) = Util.instantiate(Var.typeOf BV.op_colon)
                            val resTy = Ty.T_Tensor(Ty.Shape(dd1@dd2))
                            in
                              if U.equalDim(d11, d21) andalso U.equalDim(d12, d22)
                              andalso U.equalTypes(domTy, [ty1, ty2])
                              andalso U.equalType(rngTy, resTy)
                                then (AST.E_Apply(BV.op_colon, tyArgs, [e1', e2'], rngTy), rngTy)
                                else err (cxt, [
                                    S "type error for arguments of binary operator ':'\n",
                                    S " found: ", TYS[ty1, ty2], S "\n"
                                  ])
                            end
                        | (ty1, ty2) => err (cxt, [
                              S "type error for arguments of binary operator ':'\n",
                              S " found: ", TYS[ty1, ty2], S "\n"
                            ])
                      (* end case *))
                  else (case Env.findFunc (#env env, rator)
                     of Env.PrimFun[f] => let
                        (* a single (non-overloaded) primitive for this operator *)
                          val (tyArgs, Ty.T_Fun(domTy, rngTy)) = Util.instantiate(Var.typeOf f)
                          in
                            case U.matchArgs(domTy, [e1', e2'], [ty1, ty2])
                             of SOME args => (AST.E_Apply(f, tyArgs, args, rngTy), rngTy)
                              | NONE => err (cxt, [
                                    S "type error for binary operator '", V f, S "'\n",
                                    S " expected: ", TYS domTy, S "\n",
                                    S " but found: ", TYS[ty1, ty2]
                                  ])
                            (* end case *)
                          end
                      | Env.PrimFun ovldList =>
                          resolveOverload (cxt, rator, [ty1, ty2], [e1', e2'], ovldList)
                      | _ => raise Fail "impossible"
                    (* end case *))
                end
          (* TODO: the following forms are not checked yet *)
            | PT.E_UnaryOp _ => notYet "E_UnaryOp"        (* <op> e *)
            | PT.E_Apply _ => notYet "E_Apply"            (* field/function/reduction application *)
            | PT.E_Subscript _ => notYet "E_Subscript"    (* sequence/tensor indexing; NONE for ':' *)
            | PT.E_Select _ => notYet "E_Select"          (* e '.' <field> *)
            | PT.E_Real e => (case check (env, cxt, e)
                 of (e', Ty.T_Int) => (AST.E_Apply(BV.i2r, [], [e'], Ty.realTy), Ty.realTy)
                  | (_, ty) => err (cxt, [
                        S "argument of 'real' must have type 'int', but found ",
                        TY ty
                      ])
                (* end case *))
          (* each nrrd-loading form takes its result type from the basis variable of
           * the same name (`load` -> BV.fn_load, `image` -> BV.fn_image)
           *)
            | PT.E_Load nrrd => let
                val (tyArgs, Ty.T_Fun(_, rngTy)) = Util.instantiate(Var.typeOf(BV.fn_load))
                in
                  (AST.E_LoadNrrd(tyArgs, nrrd, rngTy), rngTy)
                end
            | PT.E_Image nrrd => let
                val (tyArgs, Ty.T_Fun(_, rngTy)) = Util.instantiate(Var.typeOf(BV.fn_image))
                in
                  (AST.E_LoadNrrd(tyArgs, nrrd, rngTy), rngTy)
                end
            | PT.E_Var x => (case E.findVar (#env env, x)
                 of SOME x' => (
                      markUsed (x', true);
                      (AST.E_Var x', Var.monoTypeOf x'))
                  | NONE => err (cxt, [S "undeclared variable ", A x])
                (* end case *))
            | PT.E_Kernel _ => notYet "E_Kernel"          (* kernel '#' dim *)
            | PT.E_Lit lit => checkLit lit
            | PT.E_Id d => let
                val (tyArgs, Ty.T_Fun(_, rngTy)) =
                      Util.instantiate(Var.typeOf(BV.identity))
                in
                  if U.equalType(Ty.T_Tensor(checkShape(cxt, [d,d])), rngTy)
                    then (AST.E_Apply(BV.identity, tyArgs, [], rngTy), rngTy)
                    else raise Fail "impossible"
                end
            | PT.E_Zero dd => let
                val (tyArgs, Ty.T_Fun(_, rngTy)) =
                      Util.instantiate(Var.typeOf(BV.zero))
                in
                  if U.equalType(Ty.T_Tensor(checkShape(cxt, dd)), rngTy)
                    then (AST.E_Apply(BV.zero, tyArgs, [], rngTy), rngTy)
                    else raise Fail "impossible"
                end
            | PT.E_NaN dd => let
                val (tyArgs, Ty.T_Fun(_, rngTy)) =
                      Util.instantiate(Var.typeOf(BV.nan))
                in
                  if U.equalType(Ty.T_Tensor(checkShape(cxt, dd)), rngTy)
                    then (AST.E_Apply(BV.nan, tyArgs, [], rngTy), rngTy)
                    else raise Fail "impossible"
                end
            | PT.E_Sequence _ => notYet "E_Sequence"      (* sequence construction *)
            | PT.E_SeqComp _ => notYet "E_SeqComp"        (* sequence comprehension *)
            | PT.E_Cons args => let
              (* Note that we are guaranteed that args is non-empty *)
                val (args, tys) = checkList (env, cxt, args)
              (* extract the first non-error type in tys *)
                val ty = (case List.find (fn Ty.T_Error => false | _ => true) tys
                       of NONE => Ty.T_Error
                        | SOME ty => ty
                      (* end case *))
                in
                  case realType(TU.pruneHead ty)
                   of ty as Ty.T_Tensor shape => let
                        val Ty.Shape dd = TU.pruneShape shape (* NOTE: this may fail if we allow user polymorphism *)
                        val resTy = Ty.T_Tensor(Ty.Shape(Ty.DimConst(List.length args) :: dd))
                      (* coerce each argument to the element type `ty`; the first
                       * argument that cannot be coerced is reported as an error
                       *)
                        fun chkArgs (arg::args, argTy::tys, args') = (case coerceType(ty, argTy, arg)
                               of SOME arg' => chkArgs (args, tys, arg'::args')
                                | NONE => err (cxt, [
                                      S "arguments of tensor construction must have same type"
                                    ])
                              (* end case *))
                          | chkArgs ([], [], args') = (AST.E_Cons(List.rev args', resTy), resTy)
                          | chkArgs _ = raise Fail "impossible: argument and type lists differ in length"
                        in
                          chkArgs (args, tys, [])
                        end
                    | _ => err (cxt, [S "Invalid argument type for tensor construction"])
                  (* end case *)
                end
            | PT.E_Deprecate(msg, e) => (
                warn (cxt, [S msg]);
                check (env, cxt, e))
          (* end case *))

  (* check a conditional operator (e.g., || or &&).  `rator` is the operator's
   * spelling (used in error messages) and `mk` builds the AST for the operator
   * from the two checked operands.
   *)
    and checkCondOp (env, cxt, e1, rator, e2, mk) = (
          case (check (env, cxt, e1), check (env, cxt, e2))
           of ((e1', Ty.T_Bool), (e2', Ty.T_Bool)) => (mk(e1', e2'), Ty.T_Bool)
            | ((_, Ty.T_Bool), (_, ty2)) =>
                err (cxt, [S "expected type 'bool' on rhs of '", S rator, S "', but found ", TY ty2])
            | ((_, ty1), (_, Ty.T_Bool)) =>
                err (cxt, [S "expected type 'bool' on lhs of '", S rator, S "', but found ", TY ty1])
            | ((_, ty1), (_, ty2)) => err (cxt, [
                  S "arguments of '", S rator, S "' must have type 'bool', but found ",
                  TY ty1, S " and ", TY ty2
                ])
          (* end case *))

  (* typecheck a list of expressions returning a list of AST expressions and a list
   * of the types of the expressions.  The expressions are checked left to right
   * (List.map applies its function in list order), so errors are reported in
   * source order.
   *)
    and checkList (env, cxt, exprs) =
          ListPair.unzip (List.map (fn e => check (env, cxt, e)) exprs)

  end

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0