Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] Annotation of /branches/vis12/src/compiler/ast/type-util.sml
ViewVC logotype

Annotation of /branches/vis12/src/compiler/ast/type-util.sml

Parent Directory Parent Directory | Revision Log Revision Log


Revision 1925 - (view) (download)

1 : jhr 63 (* type-util.sml
2 :     *
3 : jhr 435 * COPYRIGHT (c) 2010 The Diderot Project (http://diderot-language.cs.uchicago.edu)
4 : jhr 63 * All rights reserved.
5 :     *)
6 :    
structure TypeUtil : sig

  (* constructor for building a tensor type of known order, but unknown
   * dimensions (each dimension is a fresh dimension meta variable).
   *)
    val mkTensorTy : int -> Types.ty

  (* constructor for building a sequence type of unknown size (the size is
   * a fresh dimension meta variable).
   *)
    val mkSequenceTy : Types.ty -> Types.ty

  (* function to compute the slice of a tensor type based on a boolean
   * mask.  The value true in the mask means that the corresponding
   * dimension is being indexed, while false means that it is being
   * copied.  Raises Fail if the argument is not a tensor type with an
   * explicit shape.
   *)
    val slice : Types.ty * bool list -> Types.ty

  (* returns true if the (pruned) type is bool, int, string, a sequence
   * with a dimension, or a tensor; dynamic sequences and all other types
   * are not fixed size.
   *)
    val isFixedSizeType : Types.ty -> bool

  (* returns true if the type is a value type (bool, int, string, sequence,
   * dynamic sequence, or tensor)
   *)
    val isValueType : Types.ty -> bool

  (* prune out instantiated meta variables from a type.  We also normalize
   * tensor shapes (i.e., remove 1s).
   *)
    val prune : Types.ty -> Types.ty
    val pruneDiff : Types.diff -> Types.diff
    val pruneShape : Types.shape -> Types.shape
    val pruneDim : Types.dim -> Types.dim

  (* prune the head of a type (does not recurse into the element type of a
   * sequence or into the argument/result types of a function)
   *)
    val pruneHead : Types.ty -> Types.ty

  (* resolve meta variables to their instantiations (or else variable) *)
    val resolve : Types.ty_var -> Types.ty
    val resolveDiff : Types.diff_var -> Types.diff
    val resolveShape : Types.shape_var -> Types.shape
    val resolveDim : Types.dim_var -> Types.dim
    val resolveVar : Types.meta_var -> Types.var_bind

  (* equality testing; dimensions are pruned before they are compared *)
    val sameDim : Types.dim * Types.dim -> bool

  (* string representations of types, etc *)
    val toString : Types.ty -> string
    val diffToString : Types.diff -> string
    val shapeToString : Types.shape -> string
    val dimToString : Types.dim -> string

  end = struct

    structure Ty = Types
    structure MV = MetaVar

  (* constructor for building a tensor type of known order, but unknown
   * dimensions.
   *)
    fun mkTensorTy order =
          Ty.T_Tensor(
            Ty.Shape(List.tabulate(order, fn _ => Ty.DimVar(MV.newDimVar()))))

  (* constructor for building a sequence type whose size is a fresh
   * dimension meta variable.
   *)
    fun mkSequenceTy ty = Ty.T_Sequence(ty, Ty.DimVar(MV.newDimVar()))

  (* prune out instantiated meta variables from a type.  We also normalize
   * tensor dimensions (i.e., remove 1s).  Unlike pruneHead, this function
   * also recurses into sequence element types and function types.
   *)
    fun prune ty = (case ty
           of Ty.T_Var(Ty.TV{bind, ...}) => (case !bind
                 of NONE => ty
                  | SOME ty' => prune ty'
                (* end case *))
            | Ty.T_Sequence(ty', dim) => Ty.T_Sequence(prune ty', pruneDim dim)
            | Ty.T_DynSequence ty' => Ty.T_DynSequence(prune ty')
            | Ty.T_Kernel diff => Ty.T_Kernel(pruneDiff diff)
            | Ty.T_Tensor shape => Ty.T_Tensor(pruneShape shape)
            | Ty.T_Image{dim, shape} => Ty.T_Image{
                  dim = pruneDim dim,
                  shape = pruneShape shape
                }
            | Ty.T_Field{diff, dim, shape} => Ty.T_Field{
                  diff = pruneDiff diff,
                  dim = pruneDim dim,
                  shape = pruneShape shape
                }
            | Ty.T_Fun(tys1, ty2) => Ty.T_Fun(List.map prune tys1, prune ty2)
            | _ => ty
          (* end case *))

  (* a DiffVar(dv, i) denotes dv+i; when dv is bound, prune its instantiation
   * and add the offset i to whatever it resolves to.
   *)
    and pruneDiff (Ty.DiffVar(Ty.DfV{bind=ref(SOME diff), ...}, i)) = (
          case pruneDiff diff
           of Ty.DiffVar(dv, i') => Ty.DiffVar(dv, i+i')
            | Ty.DiffConst i' => Ty.DiffConst(i+i')
          (* end case *))
      | pruneDiff diff = diff

  (* follow the chain of bound dimension variables *)
    and pruneDim dim = (case dim
           of Ty.DimVar(Ty.DV{bind=ref(SOME dim'), ...}) => pruneDim dim'
            | _ => dim
          (* end case *))

  (* prune a dimension; returns NONE when it is the constant 1, so that such
   * dimensions are dropped from tensor shapes.
   *)
    and filterDim dim = (case pruneDim dim
           of Ty.DimConst 1 => NONE
            | dim' => SOME dim'
          (* end case *))

  (* prune a shape, removing any dimensions that are the constant 1 *)
    and pruneShape shape = (case shape
           of Ty.Shape dd => Ty.Shape(List.mapPartial filterDim dd)
            | Ty.ShapeVar(Ty.SV{bind=ref(SOME shape'), ...}) => pruneShape shape'
            | Ty.ShapeExt(shape', dim) => (case filterDim dim
                 of SOME dim' => Ty.shapeExt(pruneShape shape', dim')
                  | NONE => pruneShape shape'
                (* end case *))
            | _ => shape
          (* end case *))

  (* resolve meta variables to their instantiations (or else variable) *)
    fun resolve (tv as Ty.TV{bind, ...}) = (case !bind
           of NONE => Ty.T_Var tv
            | SOME ty => prune ty
          (* end case *))

    fun resolveDiff (dv as Ty.DfV{bind, ...}) = (case !bind
           of NONE => Ty.DiffVar(dv, 0)
            | SOME diff => pruneDiff diff
          (* end case *))

    fun resolveShape (sv as Ty.SV{bind, ...}) = (case !bind
           of NONE => Ty.ShapeVar sv
            | SOME shape => pruneShape shape
          (* end case *))

    fun resolveDim (dv as Ty.DV{bind, ...}) = (case !bind
           of NONE => Ty.DimVar dv
            | SOME dim => pruneDim dim
          (* end case *))

  (* resolve a meta variable of any kind, preserving its kind tag *)
    fun resolveVar (Ty.TYPE tv) = Ty.TYPE(resolve tv)
      | resolveVar (Ty.DIFF dv) = Ty.DIFF(resolveDiff dv)
      | resolveVar (Ty.SHAPE sv) = Ty.SHAPE(resolveShape sv)
      | resolveVar (Ty.DIM d) = Ty.DIM(resolveDim d)

  (* prune the head of a type: follow bound type variables and prune the
   * diff/shape/dim components of the outermost constructor, but do not
   * recurse into sequence element types or function types.
   *)
    fun pruneHead ty = let
          fun prune' (ty as Ty.T_Var(Ty.TV{bind, ...})) = (case !bind
                 of NONE => ty
                  | SOME ty' => prune' ty'
                (* end case *))
            | prune' (Ty.T_Sequence(ty, dim)) = Ty.T_Sequence(ty, pruneDim dim)
            | prune' (Ty.T_DynSequence ty) = Ty.T_DynSequence ty
            | prune' (Ty.T_Kernel diff) = Ty.T_Kernel(pruneDiff diff)
            | prune' (Ty.T_Tensor shape) = Ty.T_Tensor(pruneShape shape)
            | prune' (Ty.T_Image{dim, shape}) = Ty.T_Image{
                  dim = pruneDim dim,
                  shape = pruneShape shape
                }
            | prune' (Ty.T_Field{diff, dim, shape}) = Ty.T_Field{
                  diff = pruneDiff diff,
                  dim = pruneDim dim,
                  shape = pruneShape shape
                }
            | prune' ty = ty
          in
            prune' ty
          end

  (* types with a statically-known size; note that dynamic sequences are
   * excluded, while sequences with a dimension are included.
   *)
    fun isFixedSizeType ty = (case prune ty
           of Ty.T_Bool => true
            | Ty.T_Int => true
            | Ty.T_String => true
            | Ty.T_Sequence _ => true
            | Ty.T_Tensor _ => true
            | _ => false
          (* end case *))

  (* returns true if the type is a value type (bool, int, string, sequence,
   * dynamic sequence, or tensor)
   *)
    fun isValueType ty = (case prune ty
           of Ty.T_Bool => true
            | Ty.T_Int => true
            | Ty.T_String => true
            | Ty.T_Sequence _ => true
            | Ty.T_DynSequence _ => true
            | Ty.T_Tensor _ => true
            | _ => false
          (* end case *))

  (* equality testing.  Both dimensions are pruned first, so that a bound
   * dimension variable compares equal to its instantiation (and two
   * variables bound to the same constant compare equal).
   *)
    fun sameDim (d1, d2) = (case (pruneDim d1, pruneDim d2)
           of (Ty.DimConst n1, Ty.DimConst n2) => (n1 = n2)
            | (Ty.DimVar v1, Ty.DimVar v2) => MV.sameDimVar(v1, v2)
            | _ => false
          (* end case *))

    fun listToString fmt sep items = String.concatWith sep (List.map fmt items)

  (* a differentiation level is printed as a constant, a variable, or a
   * parenthesized "variable+offset" / "variable-offset".
   *)
    fun diffToString diff = (case pruneDiff diff
           of Ty.DiffConst n => Int.toString n
            | Ty.DiffVar(dv, 0) => MV.diffVarToString dv
            | Ty.DiffVar(dv, i) => if i < 0
                then String.concat["(", MV.diffVarToString dv, "-", Int.toString(~i), ")"]
                else String.concat["(", MV.diffVarToString dv, "+", Int.toString i, ")"]
          (* end case *))

  (* every form of shape is printed in square brackets so that the output is
   * uniform; a shape variable that is extended by dimensions is separated
   * from them by ";".
   *)
    fun shapeToString shape = (case pruneShape shape
           of Ty.Shape dd => concat["[", listToString dimToString "," dd, "]"]
            | Ty.ShapeVar sv => concat["[", MV.shapeVarToString sv, "]"]
            | Ty.ShapeExt(shape', d) => let
              (* render the prefix of an extended shape, including its trailing
               * separator; an empty prefix contributes nothing.
               *)
                fun toS (Ty.Shape []) = ""
                  | toS (Ty.Shape dd) = listToString dimToString "," dd ^ ","
                  | toS (Ty.ShapeVar sv) = MV.shapeVarToString sv ^ ";"
                  | toS (Ty.ShapeExt(inner, d')) = concat[toS inner, dimToString d', ","]
                in
                  concat["[", toS shape', dimToString d, "]"]
                end
          (* end case *))

    and dimToString dim = (case pruneDim dim
           of Ty.DimConst n => Int.toString n
            | Ty.DimVar v => MV.dimVarToString v
          (* end case *))

  (* render a type in Diderot surface syntax; small fixed-size tensors use
   * their "real"/"vecN" abbreviations.
   *)
    fun toString ty = (case pruneHead ty
           of Ty.T_Var tv => MV.tyVarToString tv
            | Ty.T_Bool => "bool"
            | Ty.T_Int => "int"
            | Ty.T_String => "string"
            | Ty.T_Sequence(ty', dim) => concat[toString ty', "{", dimToString dim, "}"]
            | Ty.T_DynSequence ty' => toString ty' ^ "{}"
            | Ty.T_Kernel n => "kernel#" ^ diffToString n
            | Ty.T_Tensor(Ty.Shape[]) => "real"
            | Ty.T_Tensor(Ty.Shape[Ty.DimConst 2]) => "vec2"
            | Ty.T_Tensor(Ty.Shape[Ty.DimConst 3]) => "vec3"
            | Ty.T_Tensor(Ty.Shape[Ty.DimConst 4]) => "vec4"
            | Ty.T_Tensor shape => "tensor" ^ shapeToString shape
            | Ty.T_Image{dim, shape} => concat[
                  "image(", dimToString dim, ")", shapeToString shape
                ]
            | Ty.T_Field{diff, dim, shape} => concat[
                  "field#", diffToString diff, "(", dimToString dim,
                  ")", shapeToString shape
                ]
            | Ty.T_Fun(tys1, ty2) => let
                fun tysToString [] = "()"
                  | tysToString [ty] = toString ty
                  | tysToString tys = String.concat[
                        "(", listToString toString " * " tys, ")"
                      ]
                in
                  String.concat[tysToString tys1, " -> ", toString ty2]
                end
          (* end case *))

  (* slice a tensor type: dimensions whose mask entry is true are indexed
   * (and so removed), while those whose entry is false are kept.  Only a
   * tensor with an explicit shape is accepted; anything else raises Fail.
   * NOTE(review): ListPair.foldr silently ignores excess elements when the
   * shape and mask have different lengths -- callers presumably guarantee
   * they match; confirm.
   *)
    fun slice (Ty.T_Tensor(Ty.Shape l), mask) = let
          fun keep (_, true, dd) = dd
            | keep (d, false, dd) = d :: dd
          in
            Ty.T_Tensor(Ty.Shape(ListPair.foldr keep [] (l, mask)))
          end
      | slice (ty, _) = raise Fail(concat["slice(", toString ty, ", _)"])

  end

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0