Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] Annotation of /branches/vis15/src/compiler/ast/type-util.sml
ViewVC logotype

Annotation of /branches/vis15/src/compiler/ast/type-util.sml

Parent Directory | Revision Log


Revision 3398 - (view) (download)

1 : jhr 3384 (* type-util.sml
2 :     *
3 :     * This code is part of the Diderot Project (http://diderot-language.cs.uchicago.edu)
4 :     *
5 :     * COPYRIGHT (c) 2015 The University of Chicago
6 :     * All rights reserved.
7 :     *)
8 :    
structure TypeUtil : sig

  (* constructor for building a tensor type of known order, but unknown
   * dimensions.
   *)
    val mkTensorTy : int -> Types.ty

  (* constructor for building a sequence type of unknown (but fixed) size;
   * the size is a fresh dimension meta variable.
   *)
    val mkSequenceTy : Types.ty -> Types.ty

  (* function to compute the slice of a tensor type based on a boolean
   * mask.  The value true in the mask means that the corresponding
   * dimension is being indexed, while false means that it is being
   * copied.
   *)
    val slice : Types.ty * bool list -> Types.ty

  (* returns true if the type has a static size (i.e., it is not a dynamic
   * sequence)
   *)
    val isFixedSizeType : Types.ty -> bool

  (* returns true if the type is a value type (bool, int, string, tensor, seq, ...) *)
    val isValueType : Types.ty -> bool

  (* return true if the type is an image type *)
    val isImageType : Types.ty -> bool

  (* return the range (return type) of a function type; instantiated meta
   * variables at the head of the type are resolved first.
   *)
    val rngOf : Types.ty -> Types.ty

  (* prune out instantiated meta variables from a type.  We also normalize
   * tensor shapes (i.e., remove 1s).
   *)
    val prune : Types.ty -> Types.ty
    val pruneDiff : Types.diff -> Types.diff
    val pruneShape : Types.shape -> Types.shape
    val pruneDim : Types.dim -> Types.dim

  (* prune the head of a type *)
    val pruneHead : Types.ty -> Types.ty

  (* resolve meta variables to their instantiations (or else variable) *)
    val resolve : Types.ty_var -> Types.ty
    val resolveDiff : Types.diff_var -> Types.diff
    val resolveShape : Types.shape_var -> Types.shape
    val resolveDim : Types.dim_var -> Types.dim
    val resolveVar : Types.meta_var -> Types.var_bind

  (* equality testing (both dimensions are pruned before comparison) *)
    val sameDim : Types.dim * Types.dim -> bool

  (* string representations of types, etc *)
    val toString : Types.ty -> string
    val diffToString : Types.diff -> string
    val shapeToString : Types.shape -> string
    val dimToString : Types.dim -> string

  (* convert to fully resolved monomorphic forms; these raise Fail if the
   * argument is not constant after pruning.
   *)
    val monoDim : Types.dim -> int
    val monoShape : Types.shape -> int list
    val monoDiff : Types.diff -> int

  end = struct

    structure Ty = Types
    structure MV = MetaVar

  (* constructor for building a tensor type of known order, but unknown
   * dimensions; each dimension is a fresh dimension meta variable.
   *)
    fun mkTensorTy order =
          Ty.T_Tensor(Ty.Shape(List.tabulate(order, fn _ => Ty.DimVar(MV.newDimVar()))))

  (* constructor for a sequence type whose size is a fresh dimension meta
   * variable.  The size component of T_Sequence is a `dim option` (NONE
   * marks a dynamic sequence), so the fresh dimension is wrapped in SOME.
   *)
    fun mkSequenceTy ty = Ty.T_Sequence(ty, SOME(Ty.DimVar(MV.newDimVar())))

  (* prune out instantiated meta variables from a type.  We also normalize
   * tensor shapes (i.e., remove 1s).
   *)
    fun prune ty = (case ty
           of (ty as Ty.T_Var(Ty.TV{bind, ...})) => (case !bind
                 of NONE => ty
                  | SOME ty => prune ty
                (* end case *))
            | Ty.T_Sequence(ty, optDim) => Ty.T_Sequence(prune ty, Option.map pruneDim optDim)
            | Ty.T_Kernel diff => Ty.T_Kernel(pruneDiff diff)
            | Ty.T_Tensor shape => Ty.T_Tensor(pruneShape shape)
            | Ty.T_Image{dim, shape} => Ty.T_Image{
                  dim = pruneDim dim,
                  shape = pruneShape shape
                }
            | Ty.T_Field{diff, dim, shape} => Ty.T_Field{
                  diff = pruneDiff diff,
                  dim = pruneDim dim,
                  shape = pruneShape shape
                }
            | Ty.T_Fun(tys1, ty2) => Ty.T_Fun(List.map prune tys1, prune ty2)
            | ty => ty
          (* end case *))

  (* follow an instantiated differentiation variable, adding its offset to
   * the offset of whatever it is bound to.
   *)
    and pruneDiff (Ty.DiffVar(Ty.DfV{bind=ref(SOME diff), ...}, i)) = (
          case pruneDiff diff
           of Ty.DiffVar(dv, i') => Ty.DiffVar(dv, i+i')
            | Ty.DiffConst i' => Ty.DiffConst(i+i')
          (* end case *))
      | pruneDiff diff = diff

    and pruneDim dim = (case dim
           of Ty.DimVar(Ty.DV{bind=ref(SOME dim), ...}) => pruneDim dim
            | dim => dim
          (* end case *))

  (* prune a dimension; a constant 1 maps to NONE so that it is dropped
   * from shapes.
   *)
    and filterDim dim = (case pruneDim dim
           of Ty.DimConst 1 => NONE
            | dim => SOME dim
          (* end case *))

    and pruneShape shape = (case shape
           of Ty.Shape dd => Ty.Shape(List.mapPartial filterDim dd)
            | Ty.ShapeVar(Ty.SV{bind=ref(SOME shape), ...}) => pruneShape shape
            | Ty.ShapeExt(shape, dim) => (case filterDim dim
                 of SOME dim => Ty.shapeExt(pruneShape shape, dim)
                  | NONE => pruneShape shape
                (* end case *))
            | _ => shape
          (* end case *))

  (* resolve meta variables to their instantiations (or else variable) *)
    fun resolve (tv as Ty.TV{bind, ...}) = (case !bind
           of NONE => Ty.T_Var tv
            | SOME ty => prune ty
          (* end case *))

    fun resolveDiff (dv as Ty.DfV{bind, ...}) = (case !bind
           of NONE => Ty.DiffVar(dv, 0)
            | SOME diff => pruneDiff diff
          (* end case *))

    fun resolveShape (sv as Ty.SV{bind, ...}) = (case !bind
           of NONE => Ty.ShapeVar sv
            | SOME shape => pruneShape shape
          (* end case *))

    fun resolveDim (dv as Ty.DV{bind, ...}) = (case !bind
           of NONE => Ty.DimVar dv
            | SOME dim => pruneDim dim
          (* end case *))

    fun resolveVar (Ty.TYPE tv) = Ty.TYPE(resolve tv)
      | resolveVar (Ty.DIFF dv) = Ty.DIFF(resolveDiff dv)
      | resolveVar (Ty.SHAPE sv) = Ty.SHAPE(resolveShape sv)
      | resolveVar (Ty.DIM d) = Ty.DIM(resolveDim d)

  (* prune the head of a type: follow instantiated type variables and prune
   * the kind/shape/dimension components of the head constructor, but do not
   * recurse into nested types (sequence elements, function arguments).
   *)
    fun pruneHead ty = let
          fun prune' (ty as Ty.T_Var(Ty.TV{bind, ...})) = (case !bind
                 of NONE => ty
                  | SOME ty => prune' ty
                (* end case *))
            | prune' (Ty.T_Sequence(ty, optDim)) = Ty.T_Sequence(ty, Option.map pruneDim optDim)
            | prune' (Ty.T_Kernel diff) = Ty.T_Kernel(pruneDiff diff)
            | prune' (Ty.T_Tensor shape) = Ty.T_Tensor(pruneShape shape)
            | prune' (Ty.T_Image{dim, shape}) = Ty.T_Image{
                  dim = pruneDim dim,
                  shape = pruneShape shape
                }
            | prune' (Ty.T_Field{diff, dim, shape}) = Ty.T_Field{
                  diff = pruneDiff diff,
                  dim = pruneDim dim,
                  shape = pruneShape shape
                }
            | prune' ty = ty
          in
            prune' ty
          end

  (* returns true if the type has a static size.  A sequence is fixed-size
   * only when its size component is present (SOME); dynamic sequences
   * (NONE) are not.  NOTE(review): the element type of a fixed-size
   * sequence is not examined.
   *)
    fun isFixedSizeType ty = (case pruneHead ty
           of Ty.T_Bool => true
            | Ty.T_Int => true
            | Ty.T_String => true
            | Ty.T_Sequence(_, SOME _) => true
            | Ty.T_Sequence(_, NONE) => false
            | Ty.T_Tensor _ => true
            | _ => false
          (* end case *))

  (* returns true if the type is a value type (bool, int, string, sequence,
   * or tensor)
   *)
    fun isValueType ty = (case pruneHead ty
           of Ty.T_Bool => true
            | Ty.T_Int => true
            | Ty.T_String => true
            | Ty.T_Sequence _ => true
            | Ty.T_Tensor _ => true
            | _ => false
          (* end case *))

  (* returns true if the type is an image type *)
    fun isImageType ty = (case pruneHead ty
           of Ty.T_Image _ => true
            | _ => false
          (* end case *))

  (* equality testing; both dimensions are pruned first so that a bound
   * dimension variable compares equal to its instantiation.
   *)
    fun sameDim (d1, d2) = (case (pruneDim d1, pruneDim d2)
           of (Ty.DimConst n1, Ty.DimConst n2) => (n1 = n2)
            | (Ty.DimVar v1, Ty.DimVar v2) => MV.sameDimVar(v1, v2)
            | _ => false
          (* end case *))

  (* map `fmt` over `items` and join the results with `sep` *)
    fun listToString fmt sep items = String.concatWith sep (List.map fmt items)

    fun diffToString diff = (case pruneDiff diff
           of Ty.DiffConst n => Int.toString n
            | Ty.DiffVar(dv, 0) => MV.diffVarToString dv
            | Ty.DiffVar(dv, i) => if i < 0
                then String.concat["(", MV.diffVarToString dv, "-", Int.toString(~i), ")"]
                else String.concat["(", MV.diffVarToString dv, "+", Int.toString i, ")"]
          (* end case *))

  (* render a shape in brackets; every form of shape (constant, variable,
   * extended) is bracketed so that e.g. "tensor" ^ shapeToString is readable.
   *)
    fun shapeToString shape = (case pruneShape shape
           of Ty.Shape dd => concat["[", listToString dimToString "," dd, "]"]
            | Ty.ShapeVar sv => concat["[", MV.shapeVarToString sv, "]"]
            | Ty.ShapeExt(shape, d) => let
              (* render the prefix of an extended shape, including its trailing
               * separator; an empty constant prefix contributes nothing.
               *)
                fun toS (Ty.Shape []) = ""
                  | toS (Ty.Shape dd) = listToString dimToString "," dd ^ ","
                  | toS (Ty.ShapeVar sv) = MV.shapeVarToString sv ^ ";"
                  | toS (Ty.ShapeExt(shape, d)) = concat[toS shape, dimToString d, ","]
                in
                  concat["[", toS shape, dimToString d, "]"]
                end
          (* end case *))

    and dimToString dim = (case pruneDim dim
           of Ty.DimConst n => Int.toString n
            | Ty.DimVar v => MV.dimVarToString v
          (* end case *))

    fun toString ty = (case pruneHead ty
           of Ty.T_Var tv => MV.tyVarToString tv
            | Ty.T_Bool => "bool"
            | Ty.T_Int => "int"
            | Ty.T_String => "string"
            | Ty.T_Sequence(ty, NONE) => concat[toString ty, "[]"]
            | Ty.T_Sequence(ty, SOME dim) => concat[toString ty, "[", dimToString dim, "]"]
            | Ty.T_Named id => Atom.toString id
            | Ty.T_Kernel n => "kernel#" ^ diffToString n
            | Ty.T_Tensor(Ty.Shape[]) => "real"
            | Ty.T_Tensor(Ty.Shape[Ty.DimConst 2]) => "vec2"
            | Ty.T_Tensor(Ty.Shape[Ty.DimConst 3]) => "vec3"
            | Ty.T_Tensor(Ty.Shape[Ty.DimConst 4]) => "vec4"
            | Ty.T_Tensor(Ty.Shape[Ty.DimConst 2, Ty.DimConst 2]) => "mat2"
            | Ty.T_Tensor(Ty.Shape[Ty.DimConst 3, Ty.DimConst 3]) => "mat3"
            | Ty.T_Tensor(Ty.Shape[Ty.DimConst 4, Ty.DimConst 4]) => "mat4"
            | Ty.T_Tensor shape => "tensor" ^ shapeToString shape
            | Ty.T_Image{dim, shape} => concat[
                  "image(", dimToString dim, ")", shapeToString shape
                ]
            | Ty.T_Field{diff, dim, shape} => concat[
                  "field#", diffToString diff, "(", dimToString dim,
                  ")", shapeToString shape
                ]
            | Ty.T_Fun(tys1, ty2) => let
                fun tysToString [] = "()"
                  | tysToString [ty] = toString ty
                  | tysToString tys = String.concat[
                        "(", listToString toString " * " tys, ")"
                      ]
                in
                  String.concat[tysToString tys1, " -> ", toString ty2]
                end
          (* end case *))

  (* return the range (return type) of a function type; the head is pruned
   * first so that an instantiated type variable bound to a function type
   * is accepted.  Raises Fail for non-function types.
   *)
    fun rngOf ty = (case pruneHead ty
           of Ty.T_Fun(_, rng) => rng
            | _ => raise Fail(concat["TypeUtil.rngOf(", toString ty, ")"])
          (* end case *))

  (* slice a tensor type: dimensions whose mask entry is true (indexed) are
   * dropped, those whose entry is false (copied) are kept in order.
   * NOTE(review): ListPair.foldr stops at the shorter of the shape and the
   * mask, so a length mismatch is not reported.  Raises Fail if the type
   * is not a tensor with a constant-length shape.
   *)
    fun slice (Ty.T_Tensor(Ty.Shape dd), mask) = let
          fun keep (_, true, kept) = kept
            | keep (d, false, kept) = d :: kept
          in
            Ty.T_Tensor(Ty.Shape(ListPair.foldr keep [] (dd, mask)))
          end
      | slice (ty, _) = raise Fail(concat["slice(", toString ty, ", _)"])

  (* convert to fully resolved monomorphic forms *)
    fun monoDim dim = (case pruneDim dim
           of Ty.DimConst d => d
            | dim => raise Fail(concat["dim ", dimToString dim, " is not constant"])
          (* end case *))

    fun monoShape shp = (case pruneShape shp
           of Ty.Shape shp => List.map monoDim shp
            | shp => raise Fail(concat["shape ", shapeToString shp, " is not constant"])
          (* end case *))

    fun monoDiff diff = (case pruneDiff diff
           of Ty.DiffConst k => k
            | diff => raise Fail(concat["diff ", diffToString diff, " is not constant"])
          (* end case *))

  end

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0