SCM Repository
Annotation of /trunk/src/compiler/ast/type-util.sml
Parent Directory
|
Revision Log
Revision 399 - (view) (download)
1 : | jhr | 63 | (* type-util.sml |
2 : | * | ||
3 : | * COPYRIGHT (c) 2010 The Diderot Project (http://diderot.cs.uchicago.edu) | ||
4 : | * All rights reserved. | ||
5 : | *) | ||
6 : | |||
structure TypeUtil : sig

  (* constructor for building a tensor type of known order, but unknown
   * dimensions (each dimension is a fresh dimension meta-variable).
   *)
    val mkTensorTy : int -> Types.ty

  (* function to compute the slice of a tensor type based on a boolean
   * mask.  The value true in the mask means that the corresponding
   * dimension is being indexed (and so is removed from the result),
   * while false means that it is being copied.  The mask must have
   * exactly one entry per dimension; otherwise Fail is raised.  Fail is
   * also raised if the type is not a tensor with a concrete shape.
   *)
    val slice : Types.ty * bool list -> Types.ty

  (* returns true if the type is a value type (bool, int, string, or tensor) *)
    val isValueType : Types.ty -> bool

  (* prune out instantiated meta variables from a type.  We also normalize
   * tensor shapes (i.e., remove 1s).
   *)
    val prune : Types.ty -> Types.ty
    val pruneDiff : Types.diff -> Types.diff
    val pruneShape : Types.shape -> Types.shape
    val pruneDim : Types.dim -> Types.dim

  (* prune the head of a type (does not descend into function types) *)
    val pruneHead : Types.ty -> Types.ty

  (* resolve meta variables to their instantiations (or else variable) *)
    val resolve : Types.ty_var -> Types.ty
    val resolveDiff : Types.diff_var -> Types.diff
    val resolveShape : Types.shape_var -> Types.shape
    val resolveDim : Types.dim_var -> Types.dim

  (* string representations of types, etc *)
    val toString : Types.ty -> string
    val diffToString : Types.diff -> string
    val shapeToString : Types.shape -> string
    val dimToString : Types.dim -> string

  end = struct

    structure Ty = Types
    structure MV = MetaVar

  (* constructor for building a tensor type of known order, but unknown
   * dimensions.  Each dimension is a fresh dimension meta-variable.
   * List.tabulate raises Size if order is negative.
   *)
    fun mkTensorTy order =
          Ty.T_Tensor(
            Ty.Shape(List.tabulate(order, fn _ => Ty.DimVar(MV.newDimVar()))))

  (* prune out instantiated meta variables from a type.  We also normalize
   * tensor shapes (i.e., remove 1s).
   *)
    fun prune ty = (case ty
           of (ty as Ty.T_Var(Ty.TV{bind, ...})) => (case !bind
                 of NONE => ty
                  | SOME ty => prune ty
                (* end case *))
            | (Ty.T_Kernel diff) => Ty.T_Kernel(pruneDiff diff)
            | (Ty.T_Tensor shape) => Ty.T_Tensor(pruneShape shape)
            | (Ty.T_Image{dim, shape}) => Ty.T_Image{
                  dim = pruneDim dim,
                  shape = pruneShape shape
                }
            | (Ty.T_Field{diff, dim, shape}) => Ty.T_Field{
                  diff = pruneDiff diff,
                  dim = pruneDim dim,
                  shape = pruneShape shape
                }
            | (Ty.T_Fun(tys1, ty2)) => Ty.T_Fun(List.map prune tys1, prune ty2)
            | ty => ty
          (* end case *))

  (* a bound differentiation variable with offset i is replaced by its
   * (recursively pruned) binding, with i added to the binding's offset
   * or constant.
   *)
    and pruneDiff (Ty.DiffVar(Ty.DfV{bind=ref(SOME diff), ...}, i)) = (
          case pruneDiff diff
           of Ty.DiffVar(dv, i') => Ty.DiffVar(dv, i+i')
            | Ty.DiffConst i' => Ty.DiffConst(i+i')
          (* end case *))
      | pruneDiff diff = diff

    and pruneDim dim = (case dim
           of Ty.DimVar(Ty.DV{bind=ref(SOME dim), ...}) => pruneDim dim
            | dim => dim
          (* end case *))

  (* prune a dimension, mapping the constant dimension 1 to NONE so that
   * it can be dropped from a shape.
   *)
    and filterDim dim = (case pruneDim dim
           of Ty.DimConst 1 => NONE
            | dim => SOME dim
          (* end case *))

    and pruneShape shape = (case shape
           of Ty.Shape dd => Ty.Shape(List.mapPartial filterDim dd)
            | Ty.ShapeVar(Ty.SV{bind=ref(SOME shape), ...}) => pruneShape shape
            | Ty.ShapeExt(shape, dim) => (case filterDim dim
                (* NOTE(review): Ty.shapeExt (lower case) is a function from
                 * Types, not the ShapeExt constructor; presumably a smart
                 * constructor that folds the dimension into a concrete
                 * Shape when possible -- confirm in types.sml.
                 *)
                 of SOME dim => Ty.shapeExt(pruneShape shape, dim)
                  | NONE => pruneShape shape
                (* end case *))
            | _ => shape
          (* end case *))

  (* resolve meta variables to their instantiations (or else variable) *)
    fun resolve (tv as Ty.TV{bind, ...}) = (case !bind
           of NONE => Ty.T_Var tv
            | SOME ty => prune ty
          (* end case *))

    fun resolveDiff (dv as Ty.DfV{bind, ...}) = (case !bind
           of NONE => Ty.DiffVar(dv, 0)
            | SOME diff => pruneDiff diff
          (* end case *))

    fun resolveShape (sv as Ty.SV{bind, ...}) = (case !bind
           of NONE => Ty.ShapeVar sv
            | SOME shape => pruneShape shape
          (* end case *))

    fun resolveDim (dv as Ty.DV{bind, ...}) = (case !bind
           of NONE => Ty.DimVar dv
            | SOME dim => pruneDim dim
          (* end case *))

  (* prune the head of a type: follow bound type variables and prune the
   * diff/dim/shape components of kernel, tensor, image, and field types.
   * Unlike prune, the argument and result types of a function type are
   * left untouched.
   *)
    fun pruneHead ty = let
          fun prune' (ty as Ty.T_Var(Ty.TV{bind, ...})) = (case !bind
                 of NONE => ty
                  | SOME ty => prune' ty
                (* end case *))
            | prune' (Ty.T_Kernel diff) = Ty.T_Kernel(pruneDiff diff)
            | prune' (Ty.T_Tensor shape) = Ty.T_Tensor(pruneShape shape)
            | prune' (Ty.T_Image{dim, shape}) = Ty.T_Image{
                  dim = pruneDim dim,
                  shape = pruneShape shape
                }
            | prune' (Ty.T_Field{diff, dim, shape}) = Ty.T_Field{
                  diff = pruneDiff diff,
                  dim = pruneDim dim,
                  shape = pruneShape shape
                }
            | prune' ty = ty
          in
            prune' ty
          end

  (* format each item with fmt and join the results with sep *)
    fun listToString fmt sep items = String.concatWith sep (List.map fmt items)

    fun diffToString diff = (case pruneDiff diff
           of Ty.DiffConst n => Int.toString n
            | Ty.DiffVar(dv, 0) => MV.diffVarToString dv
            | Ty.DiffVar(dv, i) => if i < 0
                then String.concat["(", MV.diffVarToString dv, "-", Int.toString(~i), ")"]
                else String.concat["(", MV.diffVarToString dv, "+", Int.toString i, ")"]
          (* end case *))

  (* every form of shape is rendered inside "[...]" so that callers such as
   * toString ("tensor" ^ shapeToString shape) produce uniform text.
   *)
    fun shapeToString shape = (case pruneShape shape
           of Ty.Shape dd => concat["[", listToString dimToString "," dd, "]"]
            | Ty.ShapeVar sv => concat["[", MV.shapeVarToString sv, "]"]
            | Ty.ShapeExt(shape', d) => let
              (* render the prefix of an extended shape; each non-empty piece
               * ends with its separator so that the next dimension can be
               * appended directly.  A shape variable is followed by ";".
               *)
                fun toS (Ty.Shape []) = ""
                  | toS (Ty.Shape dd) = listToString dimToString "," dd ^ ","
                  | toS (Ty.ShapeVar sv) = MV.shapeVarToString sv ^ ";"
                  | toS (Ty.ShapeExt(s, d)) = concat[toS s, dimToString d, ","]
                in
                  concat["[", toS shape', dimToString d, "]"]
                end
          (* end case *))

    and dimToString dim = (case pruneDim dim
           of Ty.DimConst n => Int.toString n
            | Ty.DimVar v => MV.dimVarToString v
          (* end case *))

  (* returns true if the type is a value type (bool, int, string, or tensor) *)
    fun isValueType ty = (case prune ty
           of Ty.T_Bool => true
            | Ty.T_Int => true
            | Ty.T_String => true
            | Ty.T_Tensor _ => true
            | _ => false
          (* end case *))

    fun toString ty = (case pruneHead ty
           of Ty.T_Var tv => MV.tyVarToString tv
            | Ty.T_Bool => "bool"
            | Ty.T_Int => "int"
            | Ty.T_String => "string"
            | Ty.T_Kernel n => "kernel#" ^ diffToString n
            | Ty.T_Tensor(Ty.Shape[]) => "real"
            | Ty.T_Tensor(Ty.Shape[Ty.DimConst 2]) => "vec2"
            | Ty.T_Tensor(Ty.Shape[Ty.DimConst 3]) => "vec3"
            | Ty.T_Tensor(Ty.Shape[Ty.DimConst 4]) => "vec4"
            | Ty.T_Tensor shape => "tensor" ^ shapeToString shape
            | Ty.T_Image{dim, shape} => concat[
                  "image(", dimToString dim, ")", shapeToString shape
                ]
            | Ty.T_Field{diff, dim, shape} => concat[
                  "field#", diffToString diff, "(", dimToString dim,
                  ")", shapeToString shape
                ]
            | Ty.T_Fun(tys1, ty2) => let
                fun tysToString [] = "()"
                  | tysToString [ty] = toString ty
                  | tysToString tys = String.concat[
                        "(", listToString toString " * " tys, ")"
                      ]
                in
                  String.concat[tysToString tys1, " -> ", toString ty2]
                end
          (* end case *))

  (* compute the slice of a tensor type: dimensions whose mask entry is true
   * (indexed) are removed, those whose entry is false (copied) are kept in
   * their original order.  The mask must match the tensor's order exactly;
   * ListPair.foldrEq is used (rather than foldr, which silently truncates)
   * so that a length mismatch is reported instead of yielding a wrong type.
   *)
    fun slice (Ty.T_Tensor(Ty.Shape dims), mask) = let
          fun keep (_, true, kept) = kept
            | keep (d, false, kept) = d :: kept
          val kept = ListPair.foldrEq keep [] (dims, mask)
                handle ListPair.UnequalLengths => raise Fail(concat[
                    "slice: mask has ", Int.toString(length mask),
                    " entries for a tensor of order ", Int.toString(length dims)
                  ])
          in
            Ty.T_Tensor(Ty.Shape kept)
          end
      | slice (ty, _) = raise Fail(concat["slice(", toString ty, ", _)"])

  end
root@smlnj-gforge.cs.uchicago.edu | ViewVC Help |
Powered by ViewVC 1.0.0 |