Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] Annotation of /branches/vis15/src/compiler/ast/type-util.sml
ViewVC logotype

Annotation of /branches/vis15/src/compiler/ast/type-util.sml

Parent Directory Parent Directory | Revision Log Revision Log


Revision 5558 - (view) (download)

1 : jhr 3384 (* type-util.sml
2 :     *
3 :     * This code is part of the Diderot Project (http://diderot-language.cs.uchicago.edu)
4 :     *
5 : jhr 5558 * COPYRIGHT (c) 2018 The University of Chicago
6 : jhr 3384 * All rights reserved.
7 :     *)
8 :    
structure TypeUtil : sig

  (* constructor for building a tensor type of known order, but unknown
   * dimensions; each dimension is a fresh dimension meta variable.
   *)
    val mkTensorTy : int -> Types.ty

  (* constructor for building a fixed-size sequence type whose (unknown) size is
   * a fresh dimension meta variable.
   *)
    val mkSequenceTy : Types.ty -> Types.ty

  (* function to compute the slice of a tensor (or field) type based on a boolean
   * mask.  The value true in the mask means that the corresponding dimension is
   * being indexed (and is dropped from the result), while false means that it is
   * being copied.  Raises Fail if the type is not a tensor or field type with an
   * explicit shape.
   *)
    val slice : Types.ty * bool list -> Types.ty

  (* FIXME: make terminology be consistent between documentation and implementation w.r.t.
   * the kinds of types (e.g., concrete type, value type, ...)
   *)
  (* returns true if the type is a value type; i.e., a basic value type (bool, int,
   * string, or tensor), or a sequence of values.  A dynamic sequence only counts
   * when its elements are fixed-size.  The error type is also accepted.
   *)
    val isValueType : Types.ty -> bool

  (* returns true if the type is a value type, or a strand type, or a sequence of
   * such types.  The error type is also accepted.
   *)
    val isValueOrStrandType : Types.ty -> bool

  (* return true if the type is an image type (the error type is also accepted) *)
    val isImageType : Types.ty -> bool

  (* return true if the type is T_Error *)
    val isErrorType : Types.ty -> bool

  (* return the range (return type) of a function type; raises Fail if the
   * argument is not a function type.
   *)
    val rngOf : Types.ty -> Types.ty

  (* prune out instantiated meta variables from a type.  We also normalize
   * tensor shapes by removing any dimension that is the constant 1.
   *)
    val prune : Types.ty -> Types.ty
    val pruneDiff : Types.diff -> Types.diff
    val pruneShape : Types.shape -> Types.shape
    val pruneDim : Types.dim -> Types.dim

  (* prune the head of a type: follow instantiated type variables at the top and
   * prune the differentiability, dimension, and shape components of the outermost
   * constructor, without descending into component types (sequence element types
   * and function argument/result types are left as is).
   *)
    val pruneHead : Types.ty -> Types.ty

  (* resolve meta variables to their (pruned) instantiations, or else to the
   * variable itself when it has not been instantiated.
   *)
    val resolve : Types.ty_var -> Types.ty
    val resolveDiff : Types.diff_var -> Types.diff
    val resolveShape : Types.shape_var -> Types.shape
    val resolveDim : Types.dim_var -> Types.dim
    val resolveVar : Types.meta_var -> Types.var_bind

  (* equality testing on dimensions *)
    val sameDim : Types.dim * Types.dim -> bool

  (* string representations of types, etc. *)
    val toString : Types.ty -> string
    val diffToString : Types.diff -> string
    val shapeToString : Types.shape -> string
    val dimToString : Types.dim -> string

  (* convert to fully resolved monomorphic forms.  These functions raise Fail when
   * the argument is not constant after pruning.  Because of pruning, monoShape
   * omits dimensions of size 1, and monoDiff returns NONE for the DiffConst NONE
   * differentiability (which this module elsewhere treats as "infinite").
   *)
    val monoDim : Types.dim -> int
    val monoShape : Types.shape -> int list
    val monoDiff : Types.diff -> int option

  (* instantiate a type scheme, returning the fresh copies of the scheme's meta
   * variables and the resulting type.  Note that we assume that the scheme is closed.
   *)
    val instantiate : Types.scheme -> (Types.meta_var list * Types.ty)

  end = struct
84 :    
85 :     structure Ty = Types
86 :     structure MV = MetaVar
87 :    
(* build a tensor type whose order is known but whose dimensions are not:
 * the result has `order` dimensions, each a fresh dimension meta variable.
 *)
fun mkTensorTy order = let
      fun freshDim _ = Ty.DimVar(MetaVar.newDimVar())
      val dims = List.tabulate(order, freshDim)
      in
        Ty.T_Tensor(Ty.Shape dims)
      end
94 : jhr 3384
(* build a fixed-size sequence of the given element type whose size is a fresh
 * dimension meta variable.
 *)
fun mkSequenceTy elemTy = let
      val size = Ty.DimVar(MetaVar.newDimVar())
      in
        Ty.T_Sequence(elemTy, SOME size)
      end
96 : jhr 3384
(* prune out instantiated meta variables from a type.  We also normalize
 * tensor shapes by dropping any dimension that is the constant 1.  The
 * functions below are mutually recursive; pruning descends into every
 * component of the type.
 *)
fun prune (ty as Ty.T_Var(Ty.TV{bind, ...})) = (case !bind
       of SOME inst => prune inst
        | NONE => ty
      (* end case *))
  | prune (Ty.T_Sequence(elemTy, optDim)) =
      Ty.T_Sequence(prune elemTy, Option.map pruneDim optDim)
  | prune (Ty.T_Kernel diff) = Ty.T_Kernel(pruneDiff diff)
  | prune (Ty.T_Tensor shape) = Ty.T_Tensor(pruneShape shape)
  | prune (Ty.T_Image{dim, shape}) =
      Ty.T_Image{dim = pruneDim dim, shape = pruneShape shape}
  | prune (Ty.T_Field{diff, dim, shape}) =
      Ty.T_Field{diff = pruneDiff diff, dim = pruneDim dim, shape = pruneShape shape}
  | prune (Ty.T_Fun(argTys, resTy)) = Ty.T_Fun(List.map prune argTys, prune resTy)
  | prune ty = ty

(* prune a differentiability: an instantiated variable with offset i is replaced
 * by its pruned instantiation shifted by i; DiffConst NONE is left unshifted.
 *)
and pruneDiff (diff as Ty.DiffVar(Ty.DfV{bind, ...}, offset)) = (case !bind
       of NONE => diff
        | SOME inst => (case pruneDiff inst
             of Ty.DiffVar(dv, offset') => Ty.DiffVar(dv, offset + offset')
              | Ty.DiffConst k => Ty.DiffConst(Option.map (fn k' => offset + k') k)
            (* end case *))
      (* end case *))
  | pruneDiff diff = diff

(* follow a chain of instantiated dimension variables *)
and pruneDim (dim as Ty.DimVar(Ty.DV{bind, ...})) = (case !bind
       of SOME inst => pruneDim inst
        | NONE => dim
      (* end case *))
  | pruneDim dim = dim

(* prune a dimension, mapping the constant 1 to NONE so it can be dropped *)
and filterDim dim = (case pruneDim dim
       of Ty.DimConst 1 => NONE
        | d => SOME d
      (* end case *))

(* prune a shape, removing dimensions that are the constant 1 *)
and pruneShape (Ty.Shape dims) = Ty.Shape(List.mapPartial filterDim dims)
  | pruneShape (Ty.ShapeVar(Ty.SV{bind=ref(SOME inst), ...})) = pruneShape inst
  | pruneShape (Ty.ShapeExt(prefix, lastDim)) = (case filterDim lastDim
       of NONE => pruneShape prefix
        | SOME d => Ty.shapeExt(pruneShape prefix, d)
      (* end case *))
  | pruneShape shape = shape
149 : jhr 3384
(* resolve a type variable: its pruned instantiation if it has one, otherwise
 * the variable itself as a type.
 *)
fun resolve (tv as Ty.TV{bind, ...}) = (case !bind
       of SOME inst => prune inst
        | NONE => Ty.T_Var tv
      (* end case *))
155 : jhr 3384
(* resolve a differentiability variable: its pruned instantiation, or the
 * variable with a zero offset when uninstantiated.
 *)
fun resolveDiff (dv as Ty.DfV{bind, ...}) = (case !bind
       of SOME inst => pruneDiff inst
        | NONE => Ty.DiffVar(dv, 0)
      (* end case *))
160 : jhr 3384
(* resolve a shape variable: its pruned instantiation, or the variable itself *)
fun resolveShape (sv as Ty.SV{bind, ...}) = (case !bind
       of SOME inst => pruneShape inst
        | NONE => Ty.ShapeVar sv
      (* end case *))
165 : jhr 3384
(* resolve a dimension variable: its pruned instantiation, or the variable itself *)
fun resolveDim (dv as Ty.DV{bind, ...}) = (case !bind
       of SOME inst => pruneDim inst
        | NONE => Ty.DimVar dv
      (* end case *))
170 : jhr 3384
(* resolve any kind of meta variable, dispatching on its kind *)
fun resolveVar mv = (case mv
       of Ty.TYPE tv => Ty.TYPE(resolve tv)
        | Ty.DIFF dv => Ty.DIFF(resolveDiff dv)
        | Ty.SHAPE sv => Ty.SHAPE(resolveShape sv)
        | Ty.DIM dv => Ty.DIM(resolveDim dv)
      (* end case *))
175 :    
(* prune the head of a type: follow instantiated type variables at the top, then
 * prune the differentiability/dimension/shape components of the outermost
 * constructor.  Component types (sequence elements, function argument and
 * result types) are not pruned.
 *)
fun pruneHead ty = let
      fun head (t as Ty.T_Var(Ty.TV{bind, ...})) = (case !bind
             of SOME inst => head inst
              | NONE => t
            (* end case *))
        | head (Ty.T_Sequence(elemTy, optDim)) =
            Ty.T_Sequence(elemTy, Option.map pruneDim optDim)
        | head (Ty.T_Kernel diff) = Ty.T_Kernel(pruneDiff diff)
        | head (Ty.T_Tensor shape) = Ty.T_Tensor(pruneShape shape)
        | head (Ty.T_Image{dim, shape}) =
            Ty.T_Image{dim = pruneDim dim, shape = pruneShape shape}
        | head (Ty.T_Field{diff, dim, shape}) =
            Ty.T_Field{diff = pruneDiff diff, dim = pruneDim dim, shape = pruneShape shape}
        | head t = t
      in
        head ty
      end
199 : jhr 3384
(* helper for isValueType and isValueOrStrandType: is `ty` a fixed-size type that
 * may appear inside a dynamic sequence?  Nested dynamic sequences are rejected;
 * strand types are allowed only when `allowStrand` is true.  The argument is
 * expected to be pruned already.
 *)
fun isFixedSize (allowStrand, ty) = (case ty
       of Ty.T_Sequence(elemTy, SOME _) => isFixedSize (allowStrand, elemTy)
        | Ty.T_Strand _ => allowStrand
        | Ty.T_Tensor _ => true
        | Ty.T_Bool => true
        | Ty.T_Int => true
        | Ty.T_String => true
        | Ty.T_Error => true
        | _ => false
      (* end case *))
213 : jhr 3384
(* returns true if the type is a value type; i.e., a basic value type (bool, int,
 * string, or tensor), or a sequence of values.  A dynamic sequence qualifies only
 * when its element type is fixed-size (no strands).  The error type is accepted.
 *)
fun isValueType ty = (case prune ty
       of Ty.T_Sequence(elemTy, SOME _) => isValueType elemTy
        | Ty.T_Sequence(elemTy, NONE) => isFixedSize (false, elemTy)
        | Ty.T_Tensor _ => true
        | Ty.T_Bool => true
        | Ty.T_Int => true
        | Ty.T_String => true
        | Ty.T_Error => true
        | _ => false
      (* end case *))
227 : jhr 3384
(* returns true if the type is a value type, a strand type, or a sequence of such
 * types.  A dynamic sequence qualifies only when its element type is fixed-size
 * (strands allowed).  The error type is accepted.
 *)
fun isValueOrStrandType ty = (case prune ty
       of Ty.T_Sequence(elemTy, SOME _) => isValueOrStrandType elemTy
        | Ty.T_Sequence(elemTy, NONE) => isFixedSize (true, elemTy)
        | Ty.T_Strand _ => true
        | Ty.T_Tensor _ => true
        | Ty.T_Bool => true
        | Ty.T_Int => true
        | Ty.T_String => true
        | Ty.T_Error => true
        | _ => false
      (* end case *))
240 :    
(* returns true if the type, after pruning, is an image type (T_Image).  The
 * error type (T_Error) is also accepted; presumably so that an error already
 * reported does not trigger a second one -- TODO confirm.
 *)
fun isImageType ty = (case prune ty
       of Ty.T_Image _ => true
        | Ty.T_Error => true
        | _ => false
      (* end case *))
247 : jhr 3384
(* returns true exactly when the pruned type is the error type *)
fun isErrorType ty = let
      val pruned = prune ty
      in
        case pruned
         of Ty.T_Error => true
          | _ => false
      end
252 : jhr 3431
(* equality testing on dimensions.  Both arguments are pruned first, so an
 * instantiated dimension variable compares equal to its instantiation.  Two
 * constants are equal when they have the same value; two uninstantiated
 * variables are equal when MetaVar.sameDimVar says they are the same variable;
 * a constant and an uninstantiated variable are never equal.
 *)
fun sameDim (d1, d2) = (case (pruneDim d1, pruneDim d2)
       of (Ty.DimConst n1, Ty.DimConst n2) => (n1 = n2)
        | (Ty.DimVar v1, Ty.DimVar v2) => MetaVar.sameDimVar(v1, v2)
        | _ => false
      (* end case *))
257 :    
(* format each item with `fmt` and join the results with `sep` *)
fun listToString fmt sep items = let
      val strs = List.map fmt items
      in
        String.concatWith sep strs
      end
259 :    
(* string representation of a differentiability.  A constant prints as its value
 * and DiffConst NONE prints as the empty string; a variable with a nonzero
 * offset prints as "(k+i)" or "(k-i)".
 *)
fun diffToString diff = (case pruneDiff diff
       of Ty.DiffConst optK => (case optK
             of SOME n => Int.toString n
              | NONE => "" (* QUESTION: should we do something else here? *)
            (* end case *))
        | Ty.DiffVar(dv, 0) => MV.diffVarToString dv
        | Ty.DiffVar(dv, i) => let
            val (sgn, mag) = if i < 0 then ("-", ~i) else ("+", i)
            in
              concat["(", MV.diffVarToString dv, sgn, Int.toString mag, ")"]
            end
      (* end case *))
268 : jhr 3384
(* string representation of a (pruned) shape, enclosed in square brackets *)
fun shapeToString shape = let
      val dimsToString = listToString dimToString ","
      (* render the prefix of an extended shape; every piece ends with its separator.
       * NOTE(review): a shape variable is followed by ";" rather than "," --
       * confirm that this is the intended notation.
       *)
      fun prefixToString (Ty.Shape dims) = dimsToString dims ^ ","
        | prefixToString (Ty.ShapeVar sv) = MV.shapeVarToString sv ^ ";"
        | prefixToString (Ty.ShapeExt(inner, d)) = concat[prefixToString inner, dimToString d, ","]
      val body = (case pruneShape shape
             of Ty.Shape dims => dimsToString dims
              | Ty.ShapeVar sv => MV.shapeVarToString sv
              | Ty.ShapeExt(prefix, d) => prefixToString prefix ^ dimToString d
            (* end case *))
      in
        concat["[", body, "]"]
      end

(* string representation of a (pruned) dimension *)
and dimToString dim = (case pruneDim dim
       of Ty.DimVar v => MV.dimVarToString v
        | Ty.DimConst n => Int.toString n
      (* end case *))
285 : jhr 3384
(* string representation of a type.  Only the head of the type is pruned here;
 * component types are printed by recursive calls.  Common tensor types use their
 * surface-syntax names (real, vec2-4, mat2-4).  Raises Fail for a kernel whose
 * differentiability is DiffConst NONE.
 *)
fun toString ty = let
      (* tensor types with a dedicated name in the surface syntax *)
      fun tensorToString (Ty.Shape[]) = "real"
        | tensorToString (Ty.Shape[Ty.DimConst 2]) = "vec2"
        | tensorToString (Ty.Shape[Ty.DimConst 3]) = "vec3"
        | tensorToString (Ty.Shape[Ty.DimConst 4]) = "vec4"
        | tensorToString (Ty.Shape[Ty.DimConst 2, Ty.DimConst 2]) = "mat2"
        | tensorToString (Ty.Shape[Ty.DimConst 3, Ty.DimConst 3]) = "mat3"
        | tensorToString (Ty.Shape[Ty.DimConst 4, Ty.DimConst 4]) = "mat4"
        | tensorToString shape = "tensor" ^ shapeToString shape
      (* a field whose differentiability is DiffConst NONE prints without "#k" *)
      fun fieldHead (Ty.DiffConst NONE) = "field"
        | fieldHead diff = "field#" ^ diffToString diff
      (* function domain: "()" for no arguments, a parenthesized product for several *)
      fun domainToString [] = "()"
        | domainToString [argTy] = toString argTy
        | domainToString argTys = concat["(", listToString toString " * " argTys, ")"]
      in
        case pruneHead ty
         of Ty.T_Var(Ty.TV{bind=ref(SOME inst), ...}) => toString inst
          | Ty.T_Var tv => MV.tyVarToString tv
          | Ty.T_Bool => "bool"
          | Ty.T_Int => "int"
          | Ty.T_String => "string"
          | Ty.T_Sequence(elemTy, NONE) => toString elemTy ^ "[]"
          | Ty.T_Sequence(elemTy, SOME d) => concat[toString elemTy, "[", dimToString d, "]"]
          | Ty.T_Strand id => Atom.toString id
          | Ty.T_Kernel(Ty.DiffConst NONE) => raise Fail "unexpected infinite kernel"
          | Ty.T_Kernel diff => "kernel#" ^ diffToString diff
          | Ty.T_Tensor shape => tensorToString shape
          | Ty.T_Image{dim, shape} => concat["image(", dimToString dim, ")", shapeToString shape]
          | Ty.T_Field{diff, dim, shape} =>
              concat[fieldHead diff, "(", dimToString dim, ")", shapeToString shape]
          | Ty.T_Fun(argTys, resTy) => concat[domainToString argTys, " -> ", toString resTy]
          | Ty.T_Error => "<error-type>"
        (* end case *)
      end
326 : jhr 3384
(* return the range (return type) of a function type.  The head of the type is
 * pruned first, so a type variable instantiated to a function type is accepted.
 * The returned range is not itself pruned.  Raises Fail if the argument is not
 * a function type.
 *)
fun rngOf ty = (case pruneHead ty
       of Ty.T_Fun(_, rng) => rng
        | _ => raise Fail(concat["TypeUtil.rngOf(", toString ty, ")"])
      (* end case *))
330 :    
(* compute the slice of a tensor or field type from a boolean mask: a true entry
 * marks an indexed dimension (dropped), a false entry a copied dimension (kept).
 * The type is deliberately not pruned, because pruning removes dimensions of
 * size 1 and would misalign the mask.  Raises Fail for any other kind of type.
 * NOTE(review): ListPair.foldr silently ignores excess elements if the mask and
 * the shape have different lengths.
 *)
fun slice (ty, mask) = let
      fun keepCopied dims =
            ListPair.foldr
              (fn (d, indexed, kept) => if indexed then kept else d :: kept)
                [] (dims, mask)
      in
        case ty
         of Ty.T_Tensor(Ty.Shape dims) => Ty.T_Tensor(Ty.Shape(keepCopied dims))
          | Ty.T_Field{shape=Ty.Shape dims, diff, dim} =>
              Ty.T_Field{diff=diff, dim=dim, shape=Ty.Shape(keepCopied dims)}
          | _ => raise Fail(concat["slice(", toString ty, ", _)"])
        (* end case *)
      end
344 :    
(* convert to fully resolved monomorphic forms *)
(* the integer value of a dimension; raises Fail unless it prunes to a constant *)
fun monoDim dim = let
      val pruned = pruneDim dim
      in
        case pruned
         of Ty.DimConst d => d
          | _ => raise Fail(concat["dim ", dimToString pruned, " is not constant"])
      end
350 : jhr 3384
(* the list of integer dimensions of a shape (size-1 dimensions are removed by
 * pruning); raises Fail unless the shape prunes to an explicit list of constants.
 *)
fun monoShape shp = let
      val pruned = pruneShape shp
      in
        case pruned
         of Ty.Shape dims => List.map monoDim dims
          | _ => raise Fail(concat["shape ", shapeToString pruned, " is not constant"])
      end
355 : jhr 3384
(* the constant differentiability (NONE for DiffConst NONE); raises Fail unless
 * the differentiability prunes to a constant.
 *)
fun monoDiff diff = let
      val pruned = pruneDiff diff
      in
        case pruned
         of Ty.DiffConst k => k
          | _ => raise Fail(concat["diff ", diffToString pruned, " is not constant"])
      end
360 : jhr 3384
(* instantiate a type scheme, returning fresh copies of the scheme's meta variables
 * (in the scheme's order) and the scheme's type rewritten to use those copies.
 * Note that we assume that the scheme is closed: meeting a variable that the
 * scheme does not bind raises Fail "impossible".  A scheme with no variables is
 * returned unchanged.
 *)
fun instantiate ([], ty) = ([], ty)
  | instantiate (boundVars, ty) = let
      (* copy each bound variable and record the old-to-new mapping; the fold runs
       * right-to-left, so the last variable is copied first.
       *)
      fun addCopy (mv, (copies, subst)) = let
            val mv' = MV.copy mv
            in
              (mv' :: copies, MV.Map.insert(subst, mv, mv'))
            end
      val (copies, subst) = List.foldr addCopy ([], MV.Map.empty) boundVars
      fun impossible () = raise Fail "impossible"
      fun iDiff (Ty.DiffVar(k, offset)) = (case MV.Map.find(subst, Ty.DIFF k)
             of SOME(Ty.DIFF k') => Ty.DiffVar(k', offset)
              | _ => impossible ()
            (* end case *))
        | iDiff diff = diff
      fun iDim (Ty.DimVar dv) = (case MV.Map.find(subst, Ty.DIM dv)
             of SOME(Ty.DIM dv') => Ty.DimVar dv'
              | _ => impossible ()
            (* end case *))
        | iDim d = d
      fun iShape (Ty.ShapeVar sv) = (case MV.Map.find(subst, Ty.SHAPE sv)
             of SOME(Ty.SHAPE sv') => Ty.ShapeVar sv'
              | _ => impossible ()
            (* end case *))
        | iShape (Ty.ShapeExt(prefix, d)) = Ty.ShapeExt(iShape prefix, iDim d)
        | iShape (Ty.Shape dims) = Ty.Shape(List.map iDim dims)
      fun iTy (Ty.T_Var tv) = (case MV.Map.find(subst, Ty.TYPE tv)
             of SOME(Ty.TYPE tv') => Ty.T_Var tv'
              | _ => impossible ()
            (* end case *))
        | iTy Ty.T_Bool = Ty.T_Bool
        | iTy Ty.T_Int = Ty.T_Int
        | iTy Ty.T_String = Ty.T_String
        | iTy (Ty.T_Sequence(elemTy, optDim)) = Ty.T_Sequence(iTy elemTy, Option.map iDim optDim)
        | iTy (strandTy as Ty.T_Strand _) = strandTy
        | iTy (Ty.T_Kernel k) = Ty.T_Kernel(iDiff k)
        | iTy (Ty.T_Tensor shape) = Ty.T_Tensor(iShape shape)
        | iTy (Ty.T_Image{dim, shape}) = Ty.T_Image{dim=iDim dim, shape=iShape shape}
        | iTy (Ty.T_Field{diff, dim, shape}) =
            Ty.T_Field{diff=iDiff diff, dim=iDim dim, shape=iShape shape}
        | iTy (Ty.T_Fun(dom, rng)) = Ty.T_Fun(List.map iTy dom, iTy rng)
        | iTy Ty.T_Error = Ty.T_Error
      in
        (copies, iTy ty)
      end
408 : jhr 3405
409 : jhr 3384 end

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0