Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] Annotation of /branches/vis15/src/compiler/tree-ir/check-tree.sml
ViewVC logotype

Annotation of /branches/vis15/src/compiler/tree-ir/check-tree.sml

Parent Directory Parent Directory | Revision Log Revision Log


Revision 3830 - (view) (download)

1 : jhr 3754 (* check-tree.sml
2 :     *
3 :     * This code is part of the Diderot Project (http://diderot-language.cs.uchicago.edu)
4 :     *
5 :     * COPYRIGHT (c) 2016 The University of Chicago
6 :     * All rights reserved.
7 :     *
8 :     * TODO: check global and state variable consistency
9 :     *)
10 :    
11 : jhr 3757 (* FIXME: this module needs to be parameterized over the vector layout of the target *)
12 :    
13 : jhr 3754 structure CheckTree : sig
14 :    
15 :     val check : string * TreeIR.program -> bool
16 :    
17 :     end = struct
18 :    
19 :     structure IR = TreeIR
20 :     structure Op = TreeOps
21 :     structure GVar = IR.GlobalVar
22 :     structure SVar = IR.StateVar
23 :     structure Var = IR.Var
24 :     structure VSet = Var.Set
25 :     structure Ty = IR.Ty
26 :    
27 :     datatype token
28 :     = NL | S of string | A of Atom.atom | V of IR.var
29 : jhr 3810 | TY of Ty.t | TYS of Ty.t list
30 : jhr 3754
31 :     fun error errBuf toks = let
32 :     fun tok2str NL = "\n ** "
33 :     | tok2str (S s) = s
34 :     | tok2str (A s) = Atom.toString s
35 :     | tok2str (V x) = Var.toString x
36 :     | tok2str (TY ty) = Ty.toString ty
37 :     | tok2str (TYS []) = "()"
38 :     | tok2str (TYS[ty]) = Ty.toString ty
39 :     | tok2str (TYS tys) = String.concat[
40 :     "(", String.concatWith " * " (List.map Ty.toString tys), ")"
41 :     ]
42 :     in
43 :     errBuf := concat ("**** Error: " :: List.map tok2str toks)
44 :     :: !errBuf
45 :     end
46 :    
47 : jhr 3757 (* utility function for synthesizing eigenvector/eigenvalue signature *)
48 :     fun eigenSig dim = let
49 :     val tplTy = Ty.TupleTy[
50 :     Ty.SeqTy(Ty.realTy, SOME dim),
51 : jhr 3767 Ty.SeqTy(Ty.vecTy dim, SOME dim)
52 : jhr 3757 ]
53 :     in
54 : jhr 3767 (* FIXME: what about pieces? *)
55 :     (tplTy, [Ty.TensorTy(dim, Ty.vecTy dim)])
56 : jhr 3757 end
57 :    
58 :     (* Return the signature of a TreeIR operator. *)
59 :     fun sigOf rator = (case rator
60 :     of Op.IAdd => (Ty.IntTy, [Ty.IntTy, Ty.IntTy])
61 :     | Op.ISub => (Ty.IntTy, [Ty.IntTy, Ty.IntTy])
62 :     | Op.IMul => (Ty.IntTy, [Ty.IntTy, Ty.IntTy])
63 :     | Op.IDiv => (Ty.IntTy, [Ty.IntTy, Ty.IntTy])
64 :     | Op.IMod => (Ty.IntTy, [Ty.IntTy, Ty.IntTy])
65 :     | Op.INeg => (Ty.IntTy, [Ty.IntTy])
66 :     | Op.RAdd => (Ty.realTy, [Ty.realTy, Ty.realTy])
67 :     | Op.RSub => (Ty.realTy, [Ty.realTy, Ty.realTy])
68 :     | Op.RMul => (Ty.realTy, [Ty.realTy, Ty.realTy])
69 :     | Op.RDiv => (Ty.realTy, [Ty.realTy, Ty.realTy])
70 :     | Op.RNeg => (Ty.realTy, [Ty.realTy])
71 : jhr 3830 | Op.RClamp => (Ty.realTy, [Ty.realTy, Ty.realTy, Ty.realTy])
72 :     | Op.RLerp => (Ty.realTy, [Ty.realTy, Ty.realTy, Ty.realTy])
73 : jhr 3757 | Op.LT ty => (Ty.BoolTy, [ty, ty])
74 :     | Op.LTE ty => (Ty.BoolTy, [ty, ty])
75 :     | Op.EQ ty => (Ty.BoolTy, [ty, ty])
76 :     | Op.NEQ ty => (Ty.BoolTy, [ty, ty])
77 :     | Op.GT ty => (Ty.BoolTy, [ty, ty])
78 :     | Op.GTE ty => (Ty.BoolTy, [ty, ty])
79 :     | Op.Not => (Ty.BoolTy, [Ty.BoolTy])
80 :     | Op.Abs ty => (ty, [ty])
81 :     | Op.Max ty => (ty, [ty, ty])
82 :     | Op.Min ty => (ty, [ty, ty])
83 : jhr 3767 | Op.VAdd d => (Ty.vecTy d, [Ty.vecTy d, Ty.vecTy d])
84 :     | Op.VSub d => (Ty.vecTy d, [Ty.vecTy d, Ty.vecTy d])
85 :     | Op.VScale d => (Ty.vecTy d, [Ty.realTy, Ty.vecTy d])
86 :     | Op.VMul d => (Ty.vecTy d, [Ty.vecTy d, Ty.vecTy d])
87 :     | Op.VNeg d => (Ty.vecTy d, [Ty.vecTy d])
88 :     | Op.VSum d => (Ty.realTy, [Ty.vecTy d])
89 : jhr 3830 | Op.VClamp d => (Ty.vecTy d, [Ty.vecTy d, Ty.realTy, Ty.realTy])
90 :     | Op.VMapClamp d => (Ty.vecTy d, [Ty.vecTy d, Ty.vecTy d, Ty.vecTy d])
91 :     | Op.VLerp d => (Ty.vecTy d, [Ty.vecTy d, Ty.vecTy d, Ty.realTy])
92 : jhr 3766 (*
93 :     TensorIndex
94 :     *)
95 : jhr 3757 | Op.EigenVecs2x2 => eigenSig 2
96 :     | Op.EigenVecs3x3 => eigenSig 3
97 : jhr 3767 (* FIXME: what about pieces? *)
98 :     | Op.EigenVals2x2 => (Ty.SeqTy(Ty.realTy, SOME 2), [Ty.TensorTy(2, Ty.vecTy 2)])
99 :     | Op.EigenVals3x3 => (Ty.SeqTy(Ty.realTy, SOME 3), [Ty.TensorTy(3, Ty.vecTy 3)])
100 : jhr 3757 | Op.Zero ty => (ty, [])
101 :     | Op.Select(ty as Ty.TupleTy tys, i) => (List.nth(tys, i-1), [ty])
102 :     | Op.Subscript(ty as Ty.SeqTy(elemTy, _)) => (elemTy, [ty, Ty.intTy])
103 :     | Op.MkDynamic(ty, n) => (Ty.SeqTy(ty, NONE), [Ty.SeqTy(ty, SOME n)])
104 :     | Op.Prepend ty => (Ty.SeqTy(ty, NONE), [ty, Ty.SeqTy(ty, NONE)])
105 :     | Op.Append ty => (Ty.SeqTy(ty, NONE), [Ty.SeqTy(ty, NONE), ty])
106 :     | Op.Concat ty => (Ty.SeqTy(ty, NONE), [Ty.SeqTy(ty, NONE), Ty.SeqTy(ty, NONE)])
107 :     | Op.Range => (Ty.SeqTy(Ty.intTy, NONE), [Ty.IntTy, Ty.IntTy])
108 :     | Op.Length ty => (Ty.intTy, [Ty.SeqTy(ty, NONE)])
109 :     | Op.SphereQuery(ptTy, strandTy) => (Ty.SeqTy(strandTy, NONE), [ptTy, Ty.realTy])
110 :     | Op.Sqrt => (Ty.realTy, [Ty.realTy])
111 :     | Op.Cos => (Ty.realTy, [Ty.realTy])
112 :     | Op.ArcCos => (Ty.realTy, [Ty.realTy])
113 :     | Op.Sine => (Ty.realTy, [Ty.realTy])
114 :     | Op.ArcSin => (Ty.realTy, [Ty.realTy])
115 :     | Op.Tan => (Ty.realTy, [Ty.realTy])
116 :     | Op.ArcTan => (Ty.realTy, [Ty.realTy])
117 :     | Op.Exp => (Ty.realTy, [Ty.realTy])
118 :     | Op.Ceiling d => (Ty.vecTy d, [Ty.vecTy d])
119 :     | Op.Floor d => (Ty.vecTy d, [Ty.vecTy d])
120 :     | Op.Round d => (Ty.vecTy d, [Ty.vecTy d])
121 :     | Op.Trunc d => (Ty.vecTy d, [Ty.vecTy d])
122 :     | Op.IntToReal => (Ty.realTy, [Ty.intTy])
123 :     | Op.RealToInt 1 => (Ty.IntTy, [Ty.realTy])
124 : jhr 3767 | Op.RealToInt d => (Ty.SeqTy(Ty.IntTy, SOME d), [Ty.vecTy d])
125 : jhr 3757 (* not sure if we will need these
126 :     | R_All of ty
127 :     | R_Exists of ty
128 :     | R_Max of ty
129 :     | R_Min of ty
130 :     | R_Sum of ty
131 :     | R_Product of ty
132 :     | R_Mean of ty
133 :     | R_Variance of ty
134 :     *)
135 : jhr 3766 (* FIXME: these should probably be compiled down to lower-level operartions at this point!
136 : jhr 3757 | Op.Transform info => let
137 :     val dim = ImageInfo.dim info
138 :     in
139 :     if (dim = 1)
140 : jhr 3766 then (Ty.realTy, [Ty.ImageTy info])
141 :     else (Ty.matrixTy(dim, dim), [Ty.ImageTy info])
142 : jhr 3757 end
143 :     | Op.Translate info => let
144 :     val dim = ImageInfo.dim info
145 :     in
146 :     if (dim = 1)
147 : jhr 3766 then (Ty.realTy, [Ty.ImageTy info])
148 :     else (Ty.matrixTy(dim, dim), [Ty.ImageTy info])
149 : jhr 3757 end
150 :     *)
151 : jhr 3766 | Op.ControlIndex(info, _, _) => (Ty.IntTy, [Ty.ImageTy info, Ty.IntTy])
152 : jhr 3757 | Op.Inside(info, _) => (Ty.BoolTy, [Ty.vecTy(ImageInfo.dim info), Ty.ImageTy info])
153 :     | Op.ImageDim(info, _) => (Ty.IntTy, [Ty.ImageTy info])
154 :     | Op.LoadSeq(ty, _) => (ty, [])
155 :     | Op.LoadImage(ty, _) => (ty, [])
156 :     | Op.MathFn f => MathFns.sigOf (Ty.RealTy, f)
157 :     | _ => raise Fail("sigOf: invalid operator " ^ Op.toString rator)
158 :     (* end case *))
159 :    
160 : jhr 3754 fun check (phase, prog) = let
161 :     val IR.Program{
162 :     props, consts, inputs, constInit, globals, globalInit,
163 :     strand, create, update
164 :     } = prog
165 :     val errBuf = ref []
166 :     val errFn = error errBuf
167 :     fun final () = (case !errBuf
168 :     of [] => false
169 :     | errs => (
170 :     Log.msg ["********** IR Errors detected after ", phase, " **********\n"];
171 :     List.app (fn msg => Log.msg [msg, "\n"]) (List.rev errs);
172 :     true)
173 :     (* end case *))
174 :     (* check a variable use *)
175 :     fun checkVar (bvs, x) = if VSet.member(bvs, x)
176 :     then ()
177 :     else errFn [S "variable ", V x, S " is not bound"]
178 :     fun chkBlock (bvs, IR.Block{locals, body}) = let
179 :     fun chkExp (bvs, e) = let
180 :     fun chk e = (case e
181 :     of IR.E_Global gv => GVar.ty gv
182 :     | IR.E_State sv => SVar.ty sv
183 :     | IR.E_Var x => Var.ty x
184 :     | IR.E_Lit(Literal.Int _) => Ty.IntTy
185 :     | IR.E_Lit(Literal.Real _) => Ty.realTy
186 :     | IR.E_Lit(Literal.String _) => Ty.StringTy
187 :     | IR.E_Lit(Literal.Bool _) => Ty.BoolTy
188 :     | IR.E_Op(rator, args) => let
189 : jhr 3757 val (resTy, paramTys) = sigOf rator
190 : jhr 3754 val argTys = List.map chk args
191 :     in
192 :     if ListPair.allEq Ty.same (paramTys, argTys)
193 :     then ()
194 :     else errFn [
195 :     S "argument type mismatch in application of ",
196 :     S(Op.toString rator),
197 :     NL, S "expected: ", TYS paramTys,
198 :     NL, S "found: ", TYS argTys
199 :     ];
200 :     resTy
201 :     end
202 :     | IR.E_Cons([], ty) => (
203 :     errFn [S "empty cons"];
204 :     ty)
205 :     | IR.E_Cons(es, consTy as Ty.TensorTy(d, ty)) => (
206 :     if (length es <> d)
207 :     then errFn [
208 :     S "cons has incorrect number of elements",
209 :     NL, S " expected: ", S(Int.toString d),
210 :     NL, S " found: ", S(Int.toString(length es))
211 :     ]
212 :     else ();
213 :     chkElems ("cons", ty, es);
214 :     consTy)
215 : jhr 3768 | IR.E_Cons(es, ty) => (
216 :     errFn [S "unexpected type for cons: ", TY ty];
217 :     ty)
218 : jhr 3754 | IR.E_Seq([], ty as Ty.SeqTy(_, SOME 0)) => ty
219 :     | IR.E_Seq([], ty as Ty.SeqTy(_, SOME n)) => (
220 :     errFn [S "empty sequence, but expected ", TY ty];
221 :     ty)
222 :     | IR.E_Seq(es, seqTy as Ty.SeqTy(ty, NONE)) => (
223 :     chkElems ("sequence", ty, es);
224 :     seqTy)
225 :     | IR.E_Seq(es, seqTy as Ty.SeqTy(ty, SOME n)) => (
226 :     if (length es <> n)
227 :     then errFn [
228 :     S "sequence has incorrect number of elements",
229 :     NL, S " expected: ", S(Int.toString n),
230 :     NL, S " found: ", S(Int.toString(length es))
231 :     ]
232 :     else ();
233 :     chkElems ("sequence", ty, es);
234 :     seqTy)
235 :     | IR.E_Seq(es, ty) => (
236 :     errFn [S "unexpected type for sequence: ", TY ty];
237 :     ty)
238 : jhr 3768 | IR.E_Pack es => raise Fail "FIXME"
239 : jhr 3754 (* end case *))
240 :     and chkElems (cxt, ty, []) = ()
241 :     | chkElems (cxt, ty, e::es) = let
242 :     val ty' = chk e
243 :     in
244 :     if Ty.same(ty, ty')
245 :     then ()
246 :     else errFn [
247 :     S "element of ", S cxt, S " has incorrect type",
248 :     NL, S "expected: ", TY ty,
249 :     NL, S "found: ", TY ty'
250 :     ];
251 :     chkElems (cxt, ty, es)
252 :     end
253 :     in
254 :     chk e
255 :     end
256 :     fun chkStm (stm, bvs) = (case stm
257 :     of IR.S_Comment _ => bvs
258 : jhr 3767 | IR.S_Unpack(xs, e) => let
259 : jhr 3754 fun chkVar (x, ty) = if Ty.same(Var.ty x, ty)
260 :     then ()
261 :     else errFn[
262 :     S "type mismatch in assignment to ", S(Var.name x),
263 :     NL, S "lhs: ", TY(Var.ty x),
264 :     NL, S "rhs: ", TY ty
265 :     ]
266 :     in
267 :     case (xs, chkExp (bvs, e))
268 : jhr 3767 of (_, Ty.VecTy(w, _, dd)) => (
269 : jhr 3754 if (List.length xs <> List.length dd)
270 :     then errFn [
271 :     S "arity mismatch in assigning composite vector",
272 :     NL, S" lhs arity: ", S(Int.toString(List.length xs)),
273 :     NL, S" rhs arity: ", S(Int.toString(List.length dd))
274 :     ]
275 :     else ();
276 : jhr 3767 ListPair.app (fn (x, d) => chkVar(x, Ty.vecTy d)) (xs, dd))
277 :     | ([x], ty) => chkVar (x, ty)
278 : jhr 3754 | (_::_, ty) => errFn [
279 :     S "assignment of non-composite value to (",
280 :     S(String.concatWithMap "," Var.name xs), S ")"
281 :     ]
282 : jhr 3817 | _ => errFn [S "empty lhs for Unpack"]
283 : jhr 3754 (* end case *);
284 :     bvs
285 :     end
286 : jhr 3767 | IR.S_Assign(x, e) => let
287 :     val ty = chkExp (bvs, e)
288 :     in
289 :     if Ty.same(Var.ty x, ty)
290 :     then ()
291 :     else errFn[
292 :     S "type mismatch in assignment to local ", S(Var.name x),
293 :     NL, S "lhs: ", TY(Var.ty x),
294 :     NL, S "rhs: ", TY ty
295 :     ];
296 :     bvs
297 :     end
298 : jhr 3754 | IR.S_GAssign(gv, e) => let
299 :     val ty = chkExp (bvs, e)
300 :     in
301 :     if Ty.same(GVar.ty gv, ty)
302 :     then ()
303 :     else errFn[
304 :     S "type mismatch in assignment to global ", S(GVar.name gv),
305 :     NL, S "lhs: ", TY(GVar.ty gv),
306 :     NL, S "rhs: ", TY ty
307 :     ];
308 :     bvs
309 :     end
310 :     | IR.S_IfThen(e, b) => let
311 :     val ty = chkExp (bvs, e)
312 :     in
313 :     if Ty.same(ty, Ty.BoolTy)
314 :     then ()
315 :     else errFn[
316 :     S "expected bool for if-then, but found ", TY ty
317 :     ];
318 :     chkBlock (bvs, b);
319 :     bvs
320 :     end
321 :     | IR.S_IfThenElse(e, b1, b2) => let
322 :     val ty = chkExp (bvs, e)
323 :     in
324 :     if Ty.same(ty, Ty.BoolTy)
325 :     then ()
326 :     else errFn[
327 :     S "expected bool for if-then-else, but found ", TY ty
328 :     ];
329 :     chkBlock (bvs, b1);
330 :     chkBlock (bvs, b2);
331 :     bvs
332 :     end
333 :     | IR.S_Foreach(x, e, b) => (
334 :     case chkExp (bvs, e)
335 :     of Ty.SeqTy(ty, _) =>
336 :     if Ty.same(ty, Var.ty x)
337 :     then ()
338 :     else errFn [
339 :     S "type mismatch in foreach ", V x,
340 :     NL, S "variable type: ", TY(Var.ty x),
341 :     NL, S "domain type: ", TY ty
342 :     ]
343 :     | ty => errFn [
344 :     S "domain of foreach is not sequence type; found ", TY ty
345 :     ]
346 :     (* end case *);
347 :     ignore (chkBlock (VSet.add(bvs, x), b));
348 :     bvs)
349 :     | IR.S_LoadNrrd(x, name) => bvs (* FIXME: check type if x *)
350 :     | IR.S_Input(gv, _, _, NONE) => bvs
351 :     | IR.S_Input(gv, _, _, SOME e) => let
352 :     val ty = chkExp (bvs, e)
353 :     in
354 :     if Ty.same(GVar.ty gv, ty)
355 :     then ()
356 :     else errFn[
357 :     S "type mismatch in default for input ", S(GVar.name gv),
358 :     NL, S "expected: ", TY(GVar.ty gv),
359 :     NL, S "found: ", TY ty
360 :     ];
361 :     bvs
362 :     end
363 :     | IR.S_InputNrrd(gv, _, _, _) => (
364 :     case GVar.ty gv
365 :     of Ty.SeqTy(_, NONE) => ()
366 :     | Ty.ImageTy _ => ()
367 :     | ty => errFn [
368 :     S "input variable ", S(GVar.name gv), S " has bogus type ",
369 :     TY ty, S " for lhs for InputNrrd"
370 :     ]
371 :     (* end case *);
372 :     bvs)
373 :     | IR.S_New(_, es) => (
374 :     List.app (fn e => ignore (chkExp(bvs, e))) es;
375 :     bvs)
376 : jhr 3767 | IR.S_Save(sv, e) => let
377 :     val ty = chkExp (bvs, e)
378 : jhr 3754 in
379 : jhr 3767 if Ty.same(SVar.ty sv, ty)
380 :     then ()
381 :     else errFn[
382 :     S "type mismatch in assignment to state variable ",
383 :     S(SVar.name sv),
384 :     NL, S "lhs: ", TY(SVar.ty sv),
385 :     NL, S "rhs: ", TY ty
386 :     ];
387 : jhr 3754 bvs
388 :     end
389 :     | IR.S_Exit es => (
390 :     List.app (fn e => ignore (chkExp(bvs, e))) es;
391 :     bvs)
392 : jhr 3768 | IR.S_Print(tys, es) => (
393 :     if (length tys <> length es)
394 :     then errFn [
395 :     ]
396 :     else ();
397 :     ListPair.appi
398 :     (fn (i, ty, e) => let val ty' = chkExp(bvs, e)
399 :     in
400 :     if Ty.same(ty, ty')
401 :     then ()
402 :     else errFn[
403 :     S "type mismatch in argument ", S(Int.toString i),
404 :     S " of print",
405 :     NL, S "expected: ", TY ty,
406 :     NL, S "but found: ", TY ty'
407 :     ]
408 :     end)
409 :     (tys, es);
410 :     bvs)
411 : jhr 3754 | IR.S_Active => bvs
412 :     | IR.S_Stabilize => bvs
413 :     | IR.S_Die => bvs
414 :     (* end case *))
415 :     val bvs = List.foldl VSet.add' bvs locals
416 :     in
417 :     ignore (List.foldl chkStm bvs body)
418 :     end
419 :     fun chkOptBlock (_, NONE) = ()
420 :     | chkOptBlock (bvs, SOME blk) = ignore (chkBlock (bvs, blk))
421 :     fun chkStrand (IR.Strand{name, params, state, stateInit, initM, updateM, stabilizeM}) = (
422 :     ignore (chkBlock (VSet.fromList params, stateInit));
423 :     chkOptBlock (VSet.empty, initM);
424 :     ignore (chkBlock (VSet.empty, updateM));
425 :     chkOptBlock (VSet.empty, stabilizeM))
426 :     in
427 :     ignore (chkBlock (VSet.empty, constInit));
428 :     ignore (chkBlock (VSet.empty, globalInit));
429 :     chkStrand strand;
430 :     case create of IR.Create{code, ...} => ignore (chkBlock (VSet.empty, code));
431 :     chkOptBlock (VSet.empty, update);
432 :     final ()
433 :     end
434 :    
435 :     end

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0