Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] Annotation of /branches/vis15/src/compiler/typechecker/check-expr.sml
ViewVC logotype

Annotation of /branches/vis15/src/compiler/typechecker/check-expr.sml

Parent Directory | Revision Log


Revision 3402 - (view) (download)

1 : jhr 3396 (* check-expr.sml
2 :     *
3 :     * The typechecker for expressions.
4 :     *
5 :     * This code is part of the Diderot Project (http://diderot-language.cs.uchicago.edu)
6 :     *
7 :     * COPYRIGHT (c) 2015 The University of Chicago
8 :     * All rights reserved.
9 :     *)
10 :    
structure CheckExpr : sig

  (* check (env, cxt, e) typechecks the parse-tree expression e and returns the
   * corresponding AST expression paired with its type.  When a type error is
   * detected, it is reported through TypeError.error and a placeholder expression
   * whose type is Types.T_Error is returned instead.
   *)
    val check : Env.env * Env.context * ParseTree.expr -> (AST.expr * Types.ty)

  end = struct

    structure PT = ParseTree
    structure L = Literal
    structure E = Env
    structure Ty = Types
    structure BV = BasisVars

  (* NOTE(review): this structure also refers to TU, withEnvAndContext, coerceType,
   * markUsed, checkShape, realType, and resolveOverload, none of which are defined
   * here; confirm that they are in scope at this point (TU looks like a missing
   * structure alias for a type-utility module -- TODO confirm).
   *)

  (* the placeholder result returned when there is a type error *)
    val bogusExp = (AST.E_Lit(L.Int 0), Ty.T_Error)

  (* report a type error and return the placeholder result *)
    fun err arg = (TypeError.error arg; bogusExp)

    val warn = TypeError.warning

  (* bring the error-message token constructors (S, A, V, TY, TYS, ...) into scope *)
    datatype token = datatype TypeError.token

  (* check the type of a literal *)
    fun checkLit lit = (case lit
           of (L.Int _) => (AST.E_Lit lit, Ty.T_Int)
            | (L.Real _) => (AST.E_Lit lit, Ty.realTy)
            | (L.String _) => (AST.E_Lit lit, Ty.T_String)
            | (L.Bool _) => (AST.E_Lit lit, Ty.T_Bool)
          (* end case *))

  (* check a constant-tensor expression (identity, zero, or nan).  The type of the
   * basis variable f is instantiated and its result type is required to agree with
   * the tensor shape described by dd; the result is f applied to no arguments.
   *)
    fun checkConstTensor (cxt, f, dd) = let
          val (tyArgs, Ty.T_Fun(_, rngTy)) = Util.instantiate(Var.typeOf f)
          in
            if Unify.equalType(Ty.T_Tensor(checkShape(cxt, dd)), rngTy)
              then (AST.E_Apply(f, tyArgs, [], rngTy), rngTy)
              else raise Fail "impossible"
          end

  (* check the type of an expression *)
    fun check (env, cxt, e) = (case e
           of PT.E_Mark m => check (withEnvAndContext (env, cxt, m))
            | PT.E_Cond(e1, cond, e2) => let
                val eTy1 = check (env, cxt, e1)
                val eTy2 = check (env, cxt, e2)
                in
                  case check (env, cxt, cond)
                   of (cond', Ty.T_Bool) => (case Util.coerceType2(eTy1, eTy2)
                         of SOME(e1', e2', ty) => (AST.E_Cond(cond', e1', e2', ty), ty)
                          | NONE => err (cxt, [
                                S "types do not match in conditional expression\n",
                                S " true branch: ", TY(#2 eTy1), S "\n",
                                S " false branch: ", TY(#2 eTy2)
                              ])
                        (* end case *))
                    | (_, ty') => err (cxt, [S "expected bool type, but found ", TY ty'])
                  (* end case *)
                end
            | PT.E_Range(e1, e2) => (case (check (env, cxt, e1), check (env, cxt, e2))
                 of ((e1', Ty.T_Int), (e2', Ty.T_Int)) => let
                      val resTy = Ty.T_Sequence(Ty.T_Int, NONE)
                      in
                        (AST.E_Apply(BV.range, [], [e1', e2'], resTy), resTy)
                      end
                  | ((_, Ty.T_Int), (_, ty2)) =>
                      err (cxt, [S "expected type 'int' on rhs of '..', but found ", TY ty2])
                  | ((_, ty1), (_, Ty.T_Int)) =>
                      err (cxt, [S "expected type 'int' on lhs of '..', but found ", TY ty1])
                  | ((_, ty1), (_, ty2)) => err (cxt, [
                        S "arguments of '..' must have type 'int', found ",
                        TY ty1, S " and ", TY ty2
                      ])
                (* end case *))
            | PT.E_OrElse(e1, e2) =>
                checkCondOp (env, cxt, e1, "||", e2,
                  fn (e1', e2') => AST.E_Cond(e1', AST.E_Lit(L.Bool true), e2', Ty.T_Bool))
            | PT.E_AndAlso(e1, e2) =>
                checkCondOp (env, cxt, e1, "&&", e2,
                  fn (e1', e2') => AST.E_Cond(e1', e2', AST.E_Lit(L.Bool false), Ty.T_Bool))
            | PT.E_BinOp(e1, rator, e2) => let
                val (e1', ty1) = check (env, cxt, e1)
                val (e2', ty2) = check (env, cxt, e2)
                in
                  if Atom.same(rator, BasisNames.op_dot)
                  (* we have to handle inner product as a special case, because our type
                   * system cannot express the constraint that the type is
                   * ALL[sigma1, d1, sigma2] . tensor[sigma1, d1] * tensor[d1, sigma2] -> tensor[sigma1, sigma2]
                   *)
                    then (case (TU.prune ty1, TU.prune ty2)
                       of (Ty.T_Tensor(Ty.Shape(dd1 as _::_)), Ty.T_Tensor(Ty.Shape(d2::dd2))) => let
                          (* split the lhs shape into its leading dimensions and its last one *)
                            val (dd1, d1) = let
                                  fun splitLast (prefix, [d]) = (List.rev prefix, d)
                                    | splitLast (prefix, d::dd) = splitLast (d::prefix, dd)
                                    | splitLast (_, []) = raise Fail "impossible"
                                  in
                                    splitLast ([], dd1)
                                  end
                            val (tyArgs, Ty.T_Fun(domTy, rngTy)) = Util.instantiate(Var.typeOf BV.op_inner)
                            val resTy = Ty.T_Tensor(Ty.Shape(dd1@dd2))
                            in
                              if Unify.equalDim(d1, d2)
                              andalso Unify.equalTypes(domTy, [ty1, ty2])
                              andalso Unify.equalType(rngTy, resTy)
                                then (AST.E_Apply(BV.op_inner, tyArgs, [e1', e2'], rngTy), rngTy)
                                else err (cxt, [
                                    S "type error for arguments of binary operator '•'\n",
                                    S " found: ", TYS[ty1, ty2], S "\n"
                                  ])
                            end
                        | (ty1, ty2) => err (cxt, [
                              S "type error for arguments of binary operator '•'\n",
                              S " found: ", TYS[ty1, ty2], S "\n"
                            ])
                      (* end case *))
                  else if Atom.same(rator, BasisNames.op_colon)
                    then (case (TU.prune ty1, TU.prune ty2)
                       of (Ty.T_Tensor(Ty.Shape(dd1 as _::_::_)), Ty.T_Tensor(Ty.Shape(d21::d22::dd2))) => let
                          (* split the lhs shape into its leading dimensions and its last two *)
                            val (dd1, d11, d12) = let
                                  fun splitLast (prefix, [d1, d2]) = (List.rev prefix, d1, d2)
                                    | splitLast (prefix, d::dd) = splitLast (d::prefix, dd)
                                    | splitLast (_, []) = raise Fail "impossible"
                                  in
                                    splitLast ([], dd1)
                                  end
                            val (tyArgs, Ty.T_Fun(domTy, rngTy)) = Util.instantiate(Var.typeOf BV.op_colon)
                            val resTy = Ty.T_Tensor(Ty.Shape(dd1@dd2))
                            in
                              if Unify.equalDim(d11, d21) andalso Unify.equalDim(d12, d22)
                              andalso Unify.equalTypes(domTy, [ty1, ty2])
                              andalso Unify.equalType(rngTy, resTy)
                                then (AST.E_Apply(BV.op_colon, tyArgs, [e1', e2'], rngTy), rngTy)
                                else err (cxt, [
                                    S "type error for arguments of binary operator ':'\n",
                                    S " found: ", TYS[ty1, ty2], S "\n"
                                  ])
                            end
                        | (ty1, ty2) => err (cxt, [
                              S "type error for arguments of binary operator ':'\n",
                              S " found: ", TYS[ty1, ty2], S "\n"
                            ])
                      (* end case *))
                  else (case Env.findFunc (#env env, rator)
                     of Env.PrimFun[rator] => let
                          val (tyArgs, Ty.T_Fun(domTy, rngTy)) = Util.instantiate(Var.typeOf rator)
                          in
                            case Unify.matchArgs(domTy, [e1', e2'], [ty1, ty2])
                             of SOME args => (AST.E_Apply(rator, tyArgs, args, rngTy), rngTy)
                              | NONE => err (cxt, [
                                    S "type error for binary operator '", V rator, S "'\n",
                                    S " expected: ", TYS domTy, S "\n",
                                    S " but found: ", TYS[ty1, ty2]
                                  ])
                            (* end case *)
                          end
                      | Env.PrimFun ovldList =>
                          resolveOverload (cxt, rator, [ty1, ty2], [e1', e2'], ovldList)
                      | _ => raise Fail "impossible"
                    (* end case *))
                end
            | PT.E_UnaryOp(rator, e) => let
                val (e', ty) = check (env, cxt, e)
                in
                  case Env.findFunc (#env env, rator)
                   of Env.PrimFun[rator] => let
                      (* use Util.instantiate, as every other operator case does *)
                        val (tyArgs, Ty.T_Fun([domTy], rngTy)) = Util.instantiate(Var.typeOf rator)
                        in
                          case coerceType (domTy, ty, e')
                           of SOME e' => (AST.E_Apply(rator, tyArgs, [e'], rngTy), rngTy)
                            | NONE => err (cxt, [
                                  S "type error for unary operator \"", V rator, S "\"\n",
                                  S " expected: ", TY domTy, S "\n",
                                  S " but found: ", TY ty
                                ])
                          (* end case *)
                        end
                    | Env.PrimFun ovldList => resolveOverload (cxt, rator, [ty], [e'], ovldList)
                    | _ => raise Fail "impossible"
                  (* end case *)
                end
            | PT.E_Apply(e, args) => raise Fail "FIXME"
            | PT.E_Subscript(e, indices) => (case (check(env, cxt, e), indices)
                 of ((e', Ty.T_Sequence(elemTy, _)), [SOME e2]) => raise Fail "FIXME"
                  | ((e', Ty.T_Tensor shape), _) => raise Fail "FIXME"
                  | ((_, ty), _) => err(cxt, [
                        S "expected sequence or tensor type for object of subscripting, but found ",
                        TY ty
                      ])
                (* end case *))
            | PT.E_Select(e, field) => (case check(env, cxt, e)
                 of (e', Ty.T_Strand strand) => (case Env.findStrand(#env env, strand)
                       of SOME(AST.Strand{name, state, ...}) => let
                            fun isField (AST.VD_Decl(AST.V{name, ...}, _)) = Atom.same(name, field)
                            in
                              case List.find isField state
                               of SOME(AST.VD_Decl(x', _)) => let
                                    val ty = Var.monoTypeOf x'
                                    in
                                      (AST.E_Selector(e', field, ty), ty)
                                    end
                                | NONE => err(cxt, [
                                      S "strand ", A name,
                                      S " does not have state variable ", A field
                                    ])
                              (* end case *)
                            end
                        | NONE => err(cxt, [S "unknown strand ", A strand])
                      (* end case *))
                  | (_, ty) => err (cxt, [
                        S "expected strand type, but found ", TY ty,
                        S " in selection of ", A field
                      ])
                (* end case *))
            | PT.E_Real e => (case check (env, cxt, e)
                 of (e', Ty.T_Int) =>
                      (AST.E_Apply(BV.i2r, [], [e'], Ty.realTy), Ty.realTy)
                  | (_, ty) => err(cxt, [
                        S "argument of 'real' must have type 'int', but found ",
                        TY ty
                      ])
                (* end case *))
          (* "load" expressions get their type from the fn_load basis variable *)
            | PT.E_Load nrrd => let
                val (tyArgs, Ty.T_Fun(_, rngTy)) = Util.instantiate(Var.typeOf(BV.fn_load))
                in
                  (AST.E_LoadNrrd(tyArgs, nrrd, rngTy), rngTy)
                end
          (* "image" expressions get their type from the fn_image basis variable *)
            | PT.E_Image nrrd => let
                val (tyArgs, Ty.T_Fun(_, rngTy)) = Util.instantiate(Var.typeOf(BV.fn_image))
                in
                  (AST.E_LoadNrrd(tyArgs, nrrd, rngTy), rngTy)
                end
            | PT.E_Var x => (case E.findVar (#env env, x)
                 of SOME x' => (
                      markUsed (x', true);
                      (AST.E_Var x', Var.monoTypeOf x'))
                  | NONE => err(cxt, [S "undeclared variable ", A x])
                (* end case *))
            | PT.E_Kernel(kern, dim) => raise Fail "FIXME"
            | PT.E_Lit lit => checkLit lit
            | PT.E_Id d => checkConstTensor (cxt, BV.identity, [d, d])
            | PT.E_Zero dd => checkConstTensor (cxt, BV.zero, dd)
            | PT.E_NaN dd => checkConstTensor (cxt, BV.nan, dd)
            | PT.E_Sequence exps => raise Fail "FIXME"
            | PT.E_SeqComp comp => raise Fail "FIXME"
            | PT.E_Cons args => let
              (* Note that we are guaranteed that args is non-empty *)
                val (args, tys) = checkList (env, cxt, args)
              (* extract the first non-error type in tys *)
                val ty = (case List.find (fn Ty.T_Error => false | _ => true) tys
                       of NONE => Ty.T_Error
                        | SOME ty => ty
                      (* end case *))
                in
                  case realType(TU.pruneHead ty)
                   of ty as Ty.T_Tensor shape => let
                        val Ty.Shape dd = TU.pruneShape shape (* NOTE: this may fail if we allow user polymorphism *)
                        val resTy = Ty.T_Tensor(Ty.Shape(Ty.DimConst(List.length args) :: dd))
                      (* coerce each argument to the element type ty, accumulating the
                       * coerced arguments in reverse order; the first mismatch reports
                       * an error and yields the placeholder result.
                       *)
                        fun chkArgs (arg::args, argTy::tys, args') = (case coerceType(ty, argTy, arg)
                               of SOME arg' => chkArgs (args, tys, arg'::args')
                                | NONE => err(cxt, [
                                      S "arguments of tensor construction must have same type"
                                    ])
                              (* end case *))
                          | chkArgs ([], [], args') = (AST.E_Cons(List.rev args', resTy), resTy)
                        (* checkList returns equal-length lists, so this cannot happen *)
                          | chkArgs _ = raise Fail "impossible"
                        in
                          chkArgs (args, tys, [])
                        end
                    | _ => err(cxt, [S "Invalid argument type for tensor construction"])
                  (* end case *)
                end
            | PT.E_Deprecate(msg, e) => (
                warn (cxt, [S msg]);
                check (env, cxt, e))
          (* end case *))

  (* check a conditional operator (e.g., || or &&); both operands must have type bool
   * and mk builds the resulting AST expression from the checked operands.
   *)
    and checkCondOp (env, cxt, e1, rator, e2, mk) = (
          case (check(env, cxt, e1), check(env, cxt, e2))
           of ((e1', Ty.T_Bool), (e2', Ty.T_Bool)) => (mk(e1', e2'), Ty.T_Bool)
            | ((_, Ty.T_Bool), (_, ty2)) =>
                err (cxt, [S "expected type 'bool' on rhs of '", S rator, S "', but found ", TY ty2])
            | ((_, ty1), (_, Ty.T_Bool)) =>
                err (cxt, [S "expected type 'bool' on lhs of '", S rator, S "', but found ", TY ty1])
            | ((_, ty1), (_, ty2)) => err (cxt, [
                  S "arguments of '", S rator, S "' must have type 'bool', but found ",
                  TY ty1, S " and ", TY ty2
                ])
          (* end case *))

  (* typecheck a list of expressions returning a list of AST expressions and a list
   * of the types of the expressions (both in the same order as the input).
   *)
    and checkList (env, cxt, exprs) = let
          fun chk (e, (es, tys)) = let
                val (e, ty) = check (env, cxt, e)
                in
                  (e::es, ty::tys)
                end
          in
            List.foldr chk ([], []) exprs
          end

  end

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0