Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] Annotation of /branches/vis15/src/compiler/ast/type-util.sml
ViewVC logotype

Annotation of /branches/vis15/src/compiler/ast/type-util.sml

Parent Directory Parent Directory | Revision Log Revision Log


Revision 3384 - (view) (download)

1 : jhr 3384 (* type-util.sml
2 :     *
3 :     * This code is part of the Diderot Project (http://diderot-language.cs.uchicago.edu)
4 :     *
5 :     * COPYRIGHT (c) 2015 The University of Chicago
6 :     * All rights reserved.
7 :     *)
8 :    
structure TypeUtil : sig

  (* constructor for building a tensor type of known order, but unknown
   * dimensions.
   *)
    val mkTensorTy : int -> Types.ty

  (* constructor for building a sequence type of unknown size *)
    val mkSequenceTy : Types.ty -> Types.ty

  (* function to compute the slice of a tensor type based on a boolean
   * mask.  The value true in the mask means that the corresponding
   * dimension is being indexed, while false means that it is being
   * copied.
   *)
    val slice : Types.ty * bool list -> Types.ty

  (* returns true if the type has a static size: bool, int, string, a
   * fixed-size sequence, or a tensor.  Dynamic sequences (and all other
   * types) yield false.
   *)
    val isFixedSizeType : Types.ty -> bool

  (* returns true if the type is a value type (bool, int, string, tensor,
   * or a fixed-size or dynamic sequence)
   *)
    val isValueType : Types.ty -> bool

  (* return true if the type is an image type *)
    val isImageType : Types.ty -> bool

  (* return the range (return type) of a function type *)
    val rngOf : Types.ty -> Types.ty

  (* prune out instantiated meta variables from a type.  We also normalize
   * tensor shapes (i.e., remove 1s).
   *)
    val prune : Types.ty -> Types.ty
    val pruneDiff : Types.diff -> Types.diff
    val pruneShape : Types.shape -> Types.shape
    val pruneDim : Types.dim -> Types.dim

  (* prune the head of a type *)
    val pruneHead : Types.ty -> Types.ty

  (* resolve meta variables to their instantiations (or else variable) *)
    val resolve : Types.ty_var -> Types.ty
    val resolveDiff : Types.diff_var -> Types.diff
    val resolveShape : Types.shape_var -> Types.shape
    val resolveDim : Types.dim_var -> Types.dim
    val resolveVar : Types.meta_var -> Types.var_bind

  (* equality testing (modulo instantiated meta variables) *)
    val sameDim : Types.dim * Types.dim -> bool

  (* string representations of types, etc *)
    val toString : Types.ty -> string
    val diffToString : Types.diff -> string
    val shapeToString : Types.shape -> string
    val dimToString : Types.dim -> string

  (* convert to fully resolved monomorphic forms *)
    val monoDim : Types.dim -> int
    val monoShape : Types.shape -> int list
    val monoDiff : Types.diff -> int

  end = struct

    structure Ty = Types
    structure MV = MetaVar

  (* constructor for building a tensor type of known order, but unknown
   * dimensions: each of the `order` dimensions is a fresh dimension meta
   * variable.  (List.tabulate raises Size if order is negative.)
   *)
    fun mkTensorTy order =
          Ty.T_Tensor(Ty.Shape(List.tabulate(order, fn _ => Ty.DimVar(MV.newDimVar()))))

  (* constructor for a sequence of `ty` whose size is a fresh dimension
   * meta variable.
   *)
    fun mkSequenceTy ty = Ty.T_Sequence(ty, Ty.DimVar(MV.newDimVar()))

  (* prune out instantiated meta variables from a type.  We also normalize
   * tensor shapes (i.e., remove dimensions equal to 1); this normalization
   * applies to every shape that is pruned (tensors, images, and fields).
   * Unbound meta variables are left in place.
   *)
    fun prune ty = (case ty
           of (ty as Ty.T_Var(Ty.TV{bind, ...})) => (case !bind
                 of NONE => ty
                  | SOME ty => prune ty
                (* end case *))
            | Ty.T_Sequence(ty, dim) => Ty.T_Sequence(prune ty, pruneDim dim)
            | Ty.T_DynSequence ty => Ty.T_DynSequence(prune ty)
            | (Ty.T_Kernel diff) => Ty.T_Kernel(pruneDiff diff)
            | (Ty.T_Tensor shape) => Ty.T_Tensor(pruneShape shape)
            | (Ty.T_Image{dim, shape}) => Ty.T_Image{
                  dim = pruneDim dim,
                  shape = pruneShape shape
                }
            | (Ty.T_Field{diff, dim, shape}) => Ty.T_Field{
                  diff = pruneDiff diff,
                  dim = pruneDim dim,
                  shape = pruneShape shape
                }
            | (Ty.T_Fun(tys1, ty2)) => Ty.T_Fun(List.map prune tys1, prune ty2)
            | ty => ty
          (* end case *))

  (* prune a differentiability expression.  A bound variable with offset i
   * is replaced by its (pruned) instantiation with i added to the result's
   * offset/constant.
   *)
    and pruneDiff (Ty.DiffVar(Ty.DfV{bind=ref(SOME diff), ...}, i)) = (
          case pruneDiff diff
           of Ty.DiffVar(dv, i') => Ty.DiffVar(dv, i+i')
            | Ty.DiffConst i' => Ty.DiffConst(i+i')
          (* end case *))
      | pruneDiff diff = diff

  (* follow bound dimension meta variables to their instantiation; note that
   * this does NOT drop 1s (that happens only inside shapes).
   *)
    and pruneDim dim = (case dim
           of Ty.DimVar(Ty.DV{bind=ref(SOME dim), ...}) => pruneDim dim
            | dim => dim
          (* end case *))

  (* prune a dimension and drop it (NONE) when it is the constant 1 *)
    and filterDim dim = (case pruneDim dim
           of Ty.DimConst 1 => NONE
            | dim => SOME dim
          (* end case *))

  (* prune a shape, removing dimensions equal to 1.  Extended shapes are
   * rebuilt with Ty.shapeExt (defined in Types; not visible here).
   *)
    and pruneShape shape = (case shape
           of Ty.Shape dd => Ty.Shape(List.mapPartial filterDim dd)
            | Ty.ShapeVar(Ty.SV{bind=ref(SOME shape), ...}) => pruneShape shape
            | Ty.ShapeExt(shape, dim) => (case filterDim dim
                 of SOME dim => Ty.shapeExt(pruneShape shape, dim)
                  | NONE => pruneShape shape
                (* end case *))
            | _ => shape
          (* end case *))

  (* resolve meta variables to their instantiations (or else variable) *)
    fun resolve (tv as Ty.TV{bind, ...}) = (case !bind
           of NONE => Ty.T_Var tv
            | SOME ty => prune ty
          (* end case *))

    fun resolveDiff (dv as Ty.DfV{bind, ...}) = (case !bind
           of NONE => Ty.DiffVar(dv, 0)
            | SOME diff => pruneDiff diff
          (* end case *))

    fun resolveShape (sv as Ty.SV{bind, ...}) = (case !bind
           of NONE => Ty.ShapeVar sv
            | SOME shape => pruneShape shape
          (* end case *))

    fun resolveDim (dv as Ty.DV{bind, ...}) = (case !bind
           of NONE => Ty.DimVar dv
            | SOME dim => pruneDim dim
          (* end case *))

  (* resolve any kind of meta variable, preserving its kind tag *)
    fun resolveVar (Ty.TYPE tv) = Ty.TYPE(resolve tv)
      | resolveVar (Ty.DIFF dv) = Ty.DIFF(resolveDiff dv)
      | resolveVar (Ty.SHAPE sv) = Ty.SHAPE(resolveShape sv)
      | resolveVar (Ty.DIM d) = Ty.DIM(resolveDim d)

  (* prune the head of a type: bound type variables are followed, and the
   * immediate diff/dim/shape components of the head constructor are pruned,
   * but nested types (sequence elements, function arguments/results) are not.
   *)
    fun pruneHead ty = let
          fun prune' (ty as Ty.T_Var(Ty.TV{bind, ...})) = (case !bind
                 of NONE => ty
                  | SOME ty => prune' ty
                (* end case *))
            | prune' (Ty.T_Sequence(ty, dim)) = Ty.T_Sequence(ty, pruneDim dim)
            | prune' (Ty.T_DynSequence ty) = Ty.T_DynSequence ty
            | prune' (Ty.T_Kernel diff) = Ty.T_Kernel(pruneDiff diff)
            | prune' (Ty.T_Tensor shape) = Ty.T_Tensor(pruneShape shape)
            | prune' (Ty.T_Image{dim, shape}) = Ty.T_Image{
                  dim = pruneDim dim,
                  shape = pruneShape shape
                }
            | prune' (Ty.T_Field{diff, dim, shape}) = Ty.T_Field{
                  diff = pruneDiff diff,
                  dim = pruneDim dim,
                  shape = pruneShape shape
                }
            | prune' ty = ty
          in
            prune' ty
          end

  (* returns true for bool, int, string, fixed-size sequences, and tensors.
   * Only the head constructor matters, so pruning the head is sufficient.
   *)
    fun isFixedSizeType ty = (case pruneHead ty
           of Ty.T_Bool => true
            | Ty.T_Int => true
            | Ty.T_String => true
            | Ty.T_Sequence _ => true
            | Ty.T_Tensor _ => true
            | _ => false
          (* end case *))

  (* returns true if the type is a value type (bool, int, string, tensor, or
   * a fixed-size or dynamic sequence)
   *)
    fun isValueType ty = (case pruneHead ty
           of Ty.T_Bool => true
            | Ty.T_Int => true
            | Ty.T_String => true
            | Ty.T_Sequence _ => true
            | Ty.T_DynSequence _ => true
            | Ty.T_Tensor _ => true
            | _ => false
          (* end case *))

  (* returns true if the type is an image type *)
    fun isImageType ty = (case pruneHead ty
           of Ty.T_Image _ => true
            | _ => false
          (* end case *))

  (* equality testing on dimensions.  Both sides are pruned first, so a
   * dimension meta variable that has been instantiated compares equal to its
   * instantiation; remaining (unbound) variables are compared with
   * MetaVar.sameDimVar, and a constant never equals an unbound variable.
   *)
    fun sameDim (d1, d2) = (case (pruneDim d1, pruneDim d2)
           of (Ty.DimConst n1, Ty.DimConst n2) => (n1 = n2)
            | (Ty.DimVar v1, Ty.DimVar v2) => MV.sameDimVar(v1, v2)
            | _ => false
          (* end case *))

  (* render a list with `fmt`, separating the items by `sep` *)
    fun listToString fmt sep items = String.concatWith sep (List.map fmt items)

  (* render a differentiability: a constant, a variable, or "(v+i)"/"(v-i)" *)
    fun diffToString diff = (case pruneDiff diff
           of Ty.DiffConst n => Int.toString n
            | Ty.DiffVar(dv, 0) => MV.diffVarToString dv
            | Ty.DiffVar(dv, i) => if i < 0
                then String.concat["(", MV.diffVarToString dv, "-", Int.toString(~i), ")"]
                else String.concat["(", MV.diffVarToString dv, "+", Int.toString i, ")"]
          (* end case *))

  (* render a shape in brackets.  Every form is bracketed so that callers
   * such as toString ("tensor" ^ shape) produce consistent output:
   *   Shape [2,3]            => "[2,3]"
   *   ShapeVar s             => "[s]"
   *   ShapeExt(ShapeVar s,3) => "[s;3]"
   *)
    fun shapeToString shape = (case pruneShape shape
           of Ty.Shape dd => concat["[", listToString dimToString "," dd, "]"]
            | Ty.ShapeVar sv => concat["[", MV.shapeVarToString sv, "]"]
            | Ty.ShapeExt(shape, d) => let
              (* render the prefix of an extended shape together with the
               * separator that precedes the next dimension; an empty fixed
               * prefix contributes nothing (avoids a stray leading ",").
               *)
                fun toS (Ty.Shape []) = ""
                  | toS (Ty.Shape dd) = (listToString dimToString "," dd) ^ ","
                  | toS (Ty.ShapeVar sv) = MV.shapeVarToString sv ^ ";"
                  | toS (Ty.ShapeExt(shape, d)) = concat[toS shape, dimToString d, ","]
                in
                  concat["[", toS shape, dimToString d, "]"]
                end
          (* end case *))

  (* render a (pruned) dimension *)
    and dimToString dim = (case pruneDim dim
           of Ty.DimConst n => Int.toString n
            | Ty.DimVar v => MV.dimVarToString v
          (* end case *))

  (* render a type in Diderot surface syntax; common tensor shapes use the
   * built-in abbreviations (real, vecN, matN).
   *)
    fun toString ty = (case pruneHead ty
           of Ty.T_Var tv => MV.tyVarToString tv
            | Ty.T_Bool => "bool"
            | Ty.T_Int => "int"
            | Ty.T_String => "string"
            | Ty.T_Sequence(ty, dim) => concat[toString ty, "[", dimToString dim, "]"]
            | Ty.T_DynSequence ty => toString ty ^ "[]"
            | Ty.T_Named id => Atom.toString id
            | Ty.T_Kernel n => "kernel#" ^ diffToString n
            | Ty.T_Tensor(Ty.Shape[]) => "real"
            | Ty.T_Tensor(Ty.Shape[Ty.DimConst 2]) => "vec2"
            | Ty.T_Tensor(Ty.Shape[Ty.DimConst 3]) => "vec3"
            | Ty.T_Tensor(Ty.Shape[Ty.DimConst 4]) => "vec4"
            | Ty.T_Tensor(Ty.Shape[Ty.DimConst 2, Ty.DimConst 2]) => "mat2"
            | Ty.T_Tensor(Ty.Shape[Ty.DimConst 3, Ty.DimConst 3]) => "mat3"
            | Ty.T_Tensor(Ty.Shape[Ty.DimConst 4, Ty.DimConst 4]) => "mat4"
            | Ty.T_Tensor shape => "tensor" ^ shapeToString shape
            | Ty.T_Image{dim, shape} => concat[
                  "image(", dimToString dim, ")", shapeToString shape
                ]
            | Ty.T_Field{diff, dim, shape} => concat[
                  "field#", diffToString diff, "(", dimToString dim,
                  ")", shapeToString shape
                ]
            | Ty.T_Fun(tys1, ty2) => let
                fun tysToString [] = "()"
                  | tysToString [ty] = toString ty
                  | tysToString tys = String.concat[
                        "(", listToString toString " * " tys, ")"
                      ]
                in
                  String.concat[tysToString tys1, " -> ", toString ty2]
                end
          (* end case *))

  (* return the range (return type) of a function type.  The head is pruned
   * first so that a type variable instantiated to a function type also works.
   * Raises Fail for non-function types.
   *)
    fun rngOf ty = (case pruneHead ty
           of Ty.T_Fun(_, rngTy) => rngTy
            | _ => raise Fail(concat["TypeUtil.rngOf(", toString ty, ")"])
          (* end case *))

  (* compute the slice of a tensor type: dimensions whose mask entry is true
   * (indexed) are removed, those whose entry is false (copied) are kept, in
   * order.  The type is deliberately NOT pruned here, since pruning would
   * drop 1-dimensions and misalign the mask; the argument must already be a
   * tensor with an explicit shape, otherwise Fail is raised.
   * NOTE(review): ListPair.foldr ignores the excess tail of the longer list,
   * so a mask whose length differs from the tensor order is silently
   * truncated -- callers presumably pass a mask of matching length.
   *)
    fun slice (Ty.T_Tensor(Ty.Shape l), mask) = let
          fun f (_, true, dd) = dd
            | f (d, false, dd) = d::dd
          in
            Ty.T_Tensor(Ty.Shape(ListPair.foldr f [] (l, mask)))
          end
      | slice (ty, _) = raise Fail(concat["slice(", toString ty, ", _)"])

  (* convert to fully resolved monomorphic forms; each raises Fail if a meta
   * variable remains after pruning.
   *)
    fun monoDim dim = (case pruneDim dim
           of Ty.DimConst d => d
            | dim => raise Fail(concat["dim ", dimToString dim, " is not constant"])
          (* end case *))

  (* note: pruneShape removes dimensions equal to 1, so they do not appear in
   * the resulting list.
   *)
    fun monoShape shp = (case pruneShape shp
           of Ty.Shape shp => List.map monoDim shp
            | shp => raise Fail(concat["shape ", shapeToString shp, " is not constant"])
          (* end case *))

    fun monoDiff diff = (case pruneDiff diff
           of Ty.DiffConst k => k
            | diff => raise Fail(concat["diff ", diffToString diff, " is not constant"])
          (* end case *))

  end

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0