Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] Annotation of /branches/vis15/src/compiler/ast/type-util.sml
ViewVC logotype

Annotation of /branches/vis15/src/compiler/ast/type-util.sml

Parent Directory | Revision Log


Revision 3400 - (view) (download)

1 : jhr 3384 (* type-util.sml
2 :     *
3 :     * This code is part of the Diderot Project (http://diderot-language.cs.uchicago.edu)
4 :     *
5 :     * COPYRIGHT (c) 2015 The University of Chicago
6 :     * All rights reserved.
7 :     *)
8 :    
(* TypeUtil: utility operations on the AST representation of types: construction of
 * types with fresh meta variables, pruning of instantiated meta variables, type
 * predicates, equality testing, and conversion to strings / monomorphic forms.
 *)
structure TypeUtil : sig

  (* constructor for building a tensor type of known order, but unknown
   * dimensions.
   *)
    val mkTensorTy : int -> Types.ty

  (* constructor for building a fixed-size sequence type of unknown size *)
    val mkSequenceTy : Types.ty -> Types.ty

  (* function to compute the slice of a tensor type based on a boolean
   * mask.  The value true in the mask means that the corresponding
   * dimension is being indexed, while false means that it is being
   * copied.
   *)
    val slice : Types.ty * bool list -> Types.ty

  (* returns true if the type has a static size (i.e., it is a fixed-size value
   * type that does not contain a dynamic sequence)
   *)
    val isFixedSizeType : Types.ty -> bool

  (* returns true if the type is a value type (bool, int, string, tensor, seq, ...) *)
    val isValueType : Types.ty -> bool

  (* return true if the type is an image type *)
    val isImageType : Types.ty -> bool

  (* return the range (return type) of a function type *)
    val rngOf : Types.ty -> Types.ty

  (* prune out instantiated meta variables from a type.  We also normalize
   * tensor shapes (i.e., remove 1s).
   *)
    val prune : Types.ty -> Types.ty
    val pruneDiff : Types.diff -> Types.diff
    val pruneShape : Types.shape -> Types.shape
    val pruneDim : Types.dim -> Types.dim

  (* prune the head of a type *)
    val pruneHead : Types.ty -> Types.ty

  (* resolve meta variables to their instantiations (or else variable) *)
    val resolve : Types.ty_var -> Types.ty
    val resolveDiff : Types.diff_var -> Types.diff
    val resolveShape : Types.shape_var -> Types.shape
    val resolveDim : Types.dim_var -> Types.dim
    val resolveVar : Types.meta_var -> Types.var_bind

  (* equality testing (modulo instantiated meta variables) *)
    val sameDim : Types.dim * Types.dim -> bool

  (* string representations of types, etc *)
    val toString : Types.ty -> string
    val diffToString : Types.diff -> string
    val shapeToString : Types.shape -> string
    val dimToString : Types.dim -> string

  (* convert to fully resolved monomorphic forms *)
    val monoDim : Types.dim -> int
    val monoShape : Types.shape -> int list
    val monoDiff : Types.diff -> int

  end = struct

    structure Ty = Types
    structure MV = MetaVar

  (* constructor for building a tensor type of known order, but unknown
   * dimensions; each dimension is a fresh dimension meta variable.
   *)
    fun mkTensorTy order =
          Ty.T_Tensor(
            Ty.Shape(List.tabulate(order, fn _ => Ty.DimVar(MV.newDimVar()))))

  (* constructor for building a fixed-size sequence type whose size is a fresh
   * dimension meta variable.
   *)
    fun mkSequenceTy ty = Ty.T_Sequence(ty, SOME(Ty.DimVar(MV.newDimVar())))

  (* prune out instantiated meta variables from a type.  We also normalize
   * tensor dimensions (i.e., remove 1s).
   *)
    fun prune ty = (case ty
           of (ty as Ty.T_Var(Ty.TV{bind, ...})) => (case !bind
                 of NONE => ty
                  | SOME ty => prune ty
                (* end case *))
          (* the size of a dynamic sequence is NONE; Option.map leaves it alone *)
            | Ty.T_Sequence(ty, optDim) => Ty.T_Sequence(prune ty, Option.map pruneDim optDim)
            | Ty.T_Kernel diff => Ty.T_Kernel(pruneDiff diff)
            | Ty.T_Tensor shape => Ty.T_Tensor(pruneShape shape)
            | Ty.T_Image{dim, shape} => Ty.T_Image{
                  dim = pruneDim dim,
                  shape = pruneShape shape
                }
            | Ty.T_Field{diff, dim, shape} => Ty.T_Field{
                  diff = pruneDiff diff,
                  dim = pruneDim dim,
                  shape = pruneShape shape
                }
            | Ty.T_Fun(tys1, ty2) => Ty.T_Fun(List.map prune tys1, prune ty2)
            | ty => ty
          (* end case *))

  (* prune a differentiation level; a bound variable with offset i is replaced
   * by its pruned instantiation with i added to that instantiation's offset.
   *)
    and pruneDiff (Ty.DiffVar(Ty.DfV{bind=ref(SOME diff), ...}, i)) = (
          case pruneDiff diff
           of Ty.DiffVar(dv, i') => Ty.DiffVar(dv, i+i')
            | Ty.DiffConst i' => Ty.DiffConst(i+i')
          (* end case *))
      | pruneDiff diff = diff

  (* follow a chain of bound dimension variables *)
    and pruneDim dim = (case dim
           of Ty.DimVar(Ty.DV{bind=ref(SOME dim), ...}) => pruneDim dim
            | dim => dim
          (* end case *))

  (* prune a dimension, returning NONE for a dimension of 1 (which shapes omit) *)
    and filterDim dim = (case pruneDim dim
           of Ty.DimConst 1 => NONE
            | dim => SOME dim
          (* end case *))

  (* prune a shape, dropping any dimensions that are 1 *)
    and pruneShape shape = (case shape
           of Ty.Shape dd => Ty.Shape(List.mapPartial filterDim dd)
            | Ty.ShapeVar(Ty.SV{bind=ref(SOME shape), ...}) => pruneShape shape
            | Ty.ShapeExt(shape, dim) => (case filterDim dim
                 of SOME dim => Ty.shapeExt(pruneShape shape, dim)
                  | NONE => pruneShape shape
                (* end case *))
            | _ => shape
          (* end case *))

  (* resolve meta variables to their instantiations (or else variable) *)
    fun resolve (tv as Ty.TV{bind, ...}) = (case !bind
           of NONE => Ty.T_Var tv
            | SOME ty => prune ty
          (* end case *))

    fun resolveDiff (dv as Ty.DfV{bind, ...}) = (case !bind
           of NONE => Ty.DiffVar(dv, 0)
            | SOME diff => pruneDiff diff
          (* end case *))

    fun resolveShape (sv as Ty.SV{bind, ...}) = (case !bind
           of NONE => Ty.ShapeVar sv
            | SOME shape => pruneShape shape
          (* end case *))

    fun resolveDim (dv as Ty.DV{bind, ...}) = (case !bind
           of NONE => Ty.DimVar dv
            | SOME dim => pruneDim dim
          (* end case *))

    fun resolveVar (Ty.TYPE tv) = Ty.TYPE(resolve tv)
      | resolveVar (Ty.DIFF dv) = Ty.DIFF(resolveDiff dv)
      | resolveVar (Ty.SHAPE sv) = Ty.SHAPE(resolveShape sv)
      | resolveVar (Ty.DIM d) = Ty.DIM(resolveDim d)

  (* prune the head of a type: bound type variables at the top are followed, and
   * the immediate diff/dim/shape components are pruned, but nested types (e.g.,
   * sequence elements or function argument/result types) are left unchanged.
   *)
    fun pruneHead ty = let
          fun prune' (ty as Ty.T_Var(Ty.TV{bind, ...})) = (case !bind
                 of NONE => ty
                  | SOME ty => prune' ty
                (* end case *))
            | prune' (Ty.T_Sequence(ty, optDim)) = Ty.T_Sequence(ty, Option.map pruneDim optDim)
            | prune' (Ty.T_Kernel diff) = Ty.T_Kernel(pruneDiff diff)
            | prune' (Ty.T_Tensor shape) = Ty.T_Tensor(pruneShape shape)
            | prune' (Ty.T_Image{dim, shape}) = Ty.T_Image{
                  dim = pruneDim dim,
                  shape = pruneShape shape
                }
            | prune' (Ty.T_Field{diff, dim, shape}) = Ty.T_Field{
                  diff = pruneDiff diff,
                  dim = pruneDim dim,
                  shape = pruneShape shape
                }
            | prune' ty = ty
          in
            prune' ty
          end

  (* returns true if the type has a static size.  Dynamic sequences do not, and a
   * fixed-size sequence has a static size only when its element type does.
   *)
    fun isFixedSizeType ty = (case prune ty
           of Ty.T_Bool => true
            | Ty.T_Int => true
            | Ty.T_String => true
            | Ty.T_Sequence(_, NONE) => false
            | Ty.T_Sequence(elemTy, SOME _) => isFixedSizeType elemTy
            | Ty.T_Tensor _ => true
            | _ => false
          (* end case *))

  (* returns true if the type is a value type (bool, int, string, sequence, or tensor) *)
    fun isValueType ty = (case prune ty
           of Ty.T_Bool => true
            | Ty.T_Int => true
            | Ty.T_String => true
            | Ty.T_Sequence _ => true
            | Ty.T_Tensor _ => true
            | _ => false
          (* end case *))

  (* returns true if the type is an image type *)
    fun isImageType ty = (case prune ty
           of Ty.T_Image _ => true
            | _ => false
          (* end case *))

  (* equality testing; both dimensions are pruned first so that an instantiated
   * dimension variable compares equal to its instantiation.
   *)
    fun sameDim (dim1, dim2) = (case (pruneDim dim1, pruneDim dim2)
           of (Ty.DimConst d1, Ty.DimConst d2) => (d1 = d2)
            | (Ty.DimVar v1, Ty.DimVar v2) => MV.sameDimVar(v1, v2)
            | _ => false
          (* end case *))

  (* render a list of items using fmt, separated by sep *)
    fun listToString fmt sep items = String.concatWith sep (List.map fmt items)

    fun diffToString diff = (case pruneDiff diff
           of Ty.DiffConst n => Int.toString n
            | Ty.DiffVar(dv, 0) => MV.diffVarToString dv
            | Ty.DiffVar(dv, i) => if i < 0
                then String.concat["(", MV.diffVarToString dv, "-", Int.toString(~i), ")"]
                else String.concat["(", MV.diffVarToString dv, "+", Int.toString i, ")"]
          (* end case *))

  (* render a shape; every form of shape is enclosed in "[" ... "]" *)
    fun shapeToString shape = (case pruneShape shape
           of Ty.Shape dims => concat["[", listToString dimToString "," dims, "]"]
            | Ty.ShapeVar sv => concat["[", MV.shapeVarToString sv, "]"]
            | Ty.ShapeExt(shape', d) => let
              (* render the prefix of an extended shape; each non-empty component is
               * followed by a separator so that the last dimension can be appended.
               * A shape variable is followed by ";" to set it apart from dimensions.
               *)
                fun toS (Ty.Shape []) = ""
                  | toS (Ty.Shape dims) = (listToString dimToString "," dims) ^ ","
                  | toS (Ty.ShapeVar sv) = MV.shapeVarToString sv ^ ";"
                  | toS (Ty.ShapeExt(s, d')) = concat[toS s, dimToString d', ","]
                in
                  concat["[", toS shape', dimToString d, "]"]
                end
          (* end case *))

    and dimToString dim = (case pruneDim dim
           of Ty.DimConst n => Int.toString n
            | Ty.DimVar v => MV.dimVarToString v
          (* end case *))

  (* render a type; pruneHead guarantees that any T_Var seen here is unbound *)
    fun toString ty = (case pruneHead ty
           of Ty.T_Var tv => MV.tyVarToString tv
            | Ty.T_Bool => "bool"
            | Ty.T_Int => "int"
            | Ty.T_String => "string"
            | Ty.T_Sequence(ty, NONE) => concat[toString ty, "[]"]
            | Ty.T_Sequence(ty, SOME dim) => concat[toString ty, "[", dimToString dim, "]"]
            | Ty.T_Named id => Atom.toString id
            | Ty.T_Kernel n => "kernel#" ^ diffToString n
            | Ty.T_Tensor(Ty.Shape[]) => "real"
            | Ty.T_Tensor(Ty.Shape[Ty.DimConst 2]) => "vec2"
            | Ty.T_Tensor(Ty.Shape[Ty.DimConst 3]) => "vec3"
            | Ty.T_Tensor(Ty.Shape[Ty.DimConst 4]) => "vec4"
            | Ty.T_Tensor(Ty.Shape[Ty.DimConst 2, Ty.DimConst 2]) => "mat2"
            | Ty.T_Tensor(Ty.Shape[Ty.DimConst 3, Ty.DimConst 3]) => "mat3"
            | Ty.T_Tensor(Ty.Shape[Ty.DimConst 4, Ty.DimConst 4]) => "mat4"
            | Ty.T_Tensor shape => "tensor" ^ shapeToString shape
            | Ty.T_Image{dim, shape} => concat[
                  "image(", dimToString dim, ")", shapeToString shape
                ]
            | Ty.T_Field{diff, dim, shape} => concat[
                  "field#", diffToString diff, "(", dimToString dim,
                  ")", shapeToString shape
                ]
            | Ty.T_Fun(tys1, ty2) => let
                fun tysToString [] = "()"
                  | tysToString [ty] = toString ty
                  | tysToString tys = String.concat[
                        "(", listToString toString " * " tys, ")"
                      ]
                in
                  String.concat[tysToString tys1, " -> ", toString ty2]
                end
            | Ty.T_Error => "<error-type>"
          (* end case *))

  (* return the range (return type) of a function type; the head is pruned first so
   * that a type variable instantiated to a function type is also accepted.
   *)
    fun rngOf ty = (case pruneHead ty
           of Ty.T_Fun(_, rng) => rng
            | ty' => raise Fail(concat["TypeUtil.rngOf(", toString ty', ")"])
          (* end case *))

  (* slice a tensor type: dimensions whose mask entry is true (indexed) are removed
   * and those whose entry is false (copied) are kept.
   * NOTE(review): ListPair.foldr stops at the shorter of the shape and mask, so a
   * length mismatch is silently truncated rather than reported.
   *)
    fun slice (Ty.T_Tensor(Ty.Shape l), mask) = let
          fun f (_, true, dd) = dd
            | f (d, false, dd) = d::dd
          in
            Ty.T_Tensor(Ty.Shape(ListPair.foldr f [] (l, mask)))
          end
      | slice (ty, _) = raise Fail(concat["slice(", toString ty, ", _)"])

  (* convert to fully resolved monomorphic forms; each raises Fail if the value
   * still contains an unresolved meta variable.
   *)
    fun monoDim dim = (case pruneDim dim
           of Ty.DimConst d => d
            | dim => raise Fail(concat["dim ", dimToString dim, " is not constant"])
          (* end case *))

    fun monoShape shp = (case pruneShape shp
           of Ty.Shape shp => List.map monoDim shp
            | shp => raise Fail(concat["shape ", shapeToString shp, " is not constant"])
          (* end case *))

    fun monoDiff diff = (case pruneDiff diff
           of Ty.DiffConst k => k
            | diff => raise Fail(concat["diff ", diffToString diff, " is not constant"])
          (* end case *))

  end

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0