Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] Annotation of /trunk/src/compiler/ast/type-util.sml
ViewVC logotype

Annotation of /trunk/src/compiler/ast/type-util.sml

Parent Directory Parent Directory | Revision Log Revision Log


Revision 399 - (view) (download)

1 : jhr 63 (* type-util.sml
2 :     *
3 :     * COPYRIGHT (c) 2010 The Diderot Project (http://diderot.cs.uchicago.edu)
4 :     * All rights reserved.
5 :     *)
6 :    
structure TypeUtil : sig

  (* constructor for building a tensor type of known order, but unknown
   * dimensions.
   *)
    val mkTensorTy : int -> Types.ty

  (* function to compute the slice of a tensor type based on a boolean
   * mask.  The value true in the mask means that the corresponding
   * dimension is being indexed, while false means that it is being
   * copied.
   *)
    val slice : Types.ty * bool list -> Types.ty

  (* returns true if the type is a value type (bool, int, string, or tensor) *)
    val isValueType : Types.ty -> bool

  (* prune out instantiated meta variables from a type.  We also normalize
   * tensor shapes (i.e., remove dimensions that are the constant 1).
   *)
    val prune : Types.ty -> Types.ty
    val pruneDiff : Types.diff -> Types.diff
    val pruneShape : Types.shape -> Types.shape
    val pruneDim : Types.dim -> Types.dim

  (* prune the head of a type; the argument and result types of a function
   * type are left as they are.
   *)
    val pruneHead : Types.ty -> Types.ty

  (* resolve meta variables to their instantiations (or else variable) *)
    val resolve : Types.ty_var -> Types.ty
    val resolveDiff : Types.diff_var -> Types.diff
    val resolveShape : Types.shape_var -> Types.shape
    val resolveDim : Types.dim_var -> Types.dim

  (* string representations of types, etc *)
    val toString : Types.ty -> string
    val diffToString : Types.diff -> string
    val shapeToString : Types.shape -> string
    val dimToString : Types.dim -> string

  end = struct

    structure Ty = Types
    structure MV = MetaVar

  (* constructor for building a tensor type of known order, but unknown
   * dimensions: each of the `order` dimensions is a fresh dimension
   * meta variable.
   *)
    fun mkTensorTy order =
          Ty.T_Tensor(Ty.Shape(List.tabulate(order, fn _ => Ty.DimVar(MV.newDimVar()))))

  (* prune out instantiated meta variables from a type.  We also normalize
   * tensor shapes (i.e., remove dimensions that are the constant 1).
   * Function types are traversed recursively.
   *)
    fun prune ty = (case ty
           of Ty.T_Var(Ty.TV{bind, ...}) => (case !bind
                 of NONE => ty
                  | SOME ty' => prune ty'
                (* end case *))
            | Ty.T_Kernel diff => Ty.T_Kernel(pruneDiff diff)
            | Ty.T_Tensor shape => Ty.T_Tensor(pruneShape shape)
            | Ty.T_Image{dim, shape} => Ty.T_Image{
                  dim = pruneDim dim,
                  shape = pruneShape shape
                }
            | Ty.T_Field{diff, dim, shape} => Ty.T_Field{
                  diff = pruneDiff diff,
                  dim = pruneDim dim,
                  shape = pruneShape shape
                }
            | Ty.T_Fun(tys1, ty2) => Ty.T_Fun(List.map prune tys1, prune ty2)
            | _ => ty
          (* end case *))

  (* prune a differentiation level: a bound variable with offset i is
   * replaced by its (pruned) instantiation with i added to its offset
   * (for a variable) or to its value (for a constant).
   *)
    and pruneDiff (Ty.DiffVar(Ty.DfV{bind=ref(SOME diff), ...}, i)) = (
          case pruneDiff diff
           of Ty.DiffVar(dv, i') => Ty.DiffVar(dv, i+i')
            | Ty.DiffConst i' => Ty.DiffConst(i+i')
          (* end case *))
      | pruneDiff diff = diff

  (* follow bound dimension variables to their instantiation *)
    and pruneDim dim = (case dim
           of Ty.DimVar(Ty.DV{bind=ref(SOME dim'), ...}) => pruneDim dim'
            | _ => dim
          (* end case *))

  (* prune a dimension, returning NONE when it is the constant 1 so that
   * such dimensions are dropped from shapes.
   *)
    and filterDim dim = (case pruneDim dim
           of Ty.DimConst 1 => NONE
            | dim' => SOME dim'
          (* end case *))

  (* prune a shape: follow bound shape variables and drop constant-1
   * dimensions (including the extension dimension of a ShapeExt).
   *)
    and pruneShape shape = (case shape
           of Ty.Shape dd => Ty.Shape(List.mapPartial filterDim dd)
            | Ty.ShapeVar(Ty.SV{bind=ref(SOME shape'), ...}) => pruneShape shape'
            | Ty.ShapeExt(shape', dim) => (case filterDim dim
                 of SOME dim' => Ty.shapeExt(pruneShape shape', dim')
                  | NONE => pruneShape shape'
                (* end case *))
            | _ => shape
          (* end case *))

  (* resolve meta variables to their instantiations (or else variable) *)
    fun resolve (tv as Ty.TV{bind, ...}) = (case !bind
           of NONE => Ty.T_Var tv
            | SOME ty => prune ty
          (* end case *))

    fun resolveDiff (dv as Ty.DfV{bind, ...}) = (case !bind
           of NONE => Ty.DiffVar(dv, 0)
            | SOME diff => pruneDiff diff
          (* end case *))

    fun resolveShape (sv as Ty.SV{bind, ...}) = (case !bind
           of NONE => Ty.ShapeVar sv
            | SOME shape => pruneShape shape
          (* end case *))

    fun resolveDim (dv as Ty.DV{bind, ...}) = (case !bind
           of NONE => Ty.DimVar dv
            | SOME dim => pruneDim dim
          (* end case *))

  (* prune the head of a type: follow bound type variables and prune the
   * components of kernel, tensor, image, and field types.  This is exactly
   * `prune`, except that the argument and result types of a function type
   * are not traversed.
   *)
    fun pruneHead ty = (case ty
           of Ty.T_Var(Ty.TV{bind=ref(SOME ty'), ...}) => pruneHead ty'
            | Ty.T_Fun _ => ty
            | _ => prune ty
          (* end case *))

  (* format a list of items with `fmt`, separated by `sep` *)
    fun listToString fmt sep items = String.concatWith sep (List.map fmt items)

  (* a constant prints as its value; a variable with a non-zero offset prints
   * as "(v+i)" or "(v-i)"
   *)
    fun diffToString diff = (case pruneDiff diff
           of Ty.DiffConst n => Int.toString n
            | Ty.DiffVar(dv, 0) => MV.diffVarToString dv
            | Ty.DiffVar(dv, i) => if i < 0
                then String.concat["(", MV.diffVarToString dv, "-", Int.toString(~i), ")"]
                else String.concat["(", MV.diffVarToString dv, "+", Int.toString i, ")"]
          (* end case *))

  (* every shape form is wrapped in brackets: "[d1,d2]" for a concrete shape,
   * "[σ]" for a shape variable, and "[σ;d1,d2]" / "[d1,d2]" for extended
   * shapes.
   *)
    fun shapeToString shape = (case pruneShape shape
           of Ty.Shape dd => concat["[", listToString dimToString "," dd, "]"]
            | Ty.ShapeVar sv => concat["[", MV.shapeVarToString sv, "]"]
            | Ty.ShapeExt(shape', d) => let
              (* render the prefix of an extended shape, including its trailing
               * separator; an empty concrete prefix contributes nothing.
               *)
                fun toS (Ty.Shape []) = ""
                  | toS (Ty.Shape dd) = (listToString dimToString "," dd) ^ ","
                  | toS (Ty.ShapeVar sv) = MV.shapeVarToString sv ^ ";"
                  | toS (Ty.ShapeExt(s, d')) = concat[toS s, dimToString d', ","]
                in
                  concat["[", toS shape', dimToString d, "]"]
                end
          (* end case *))

    and dimToString dim = (case pruneDim dim
           of Ty.DimConst n => Int.toString n
            | Ty.DimVar v => MV.dimVarToString v
          (* end case *))

  (* returns true if the type is a value type (bool, int, string, or tensor);
   * only the head constructor matters, so pruning the head suffices.
   *)
    fun isValueType ty = (case pruneHead ty
           of Ty.T_Bool => true
            | Ty.T_Int => true
            | Ty.T_String => true
            | Ty.T_Tensor _ => true
            | _ => false
          (* end case *))

  (* string representation of a type; small tensor types get their surface
   * names (real, vec2, vec3, vec4).
   *)
    fun toString ty = (case pruneHead ty
           of Ty.T_Var tv => MV.tyVarToString tv
            | Ty.T_Bool => "bool"
            | Ty.T_Int => "int"
            | Ty.T_String => "string"
            | Ty.T_Kernel n => "kernel#" ^ diffToString n
            | Ty.T_Tensor(Ty.Shape[]) => "real"
            | Ty.T_Tensor(Ty.Shape[Ty.DimConst 2]) => "vec2"
            | Ty.T_Tensor(Ty.Shape[Ty.DimConst 3]) => "vec3"
            | Ty.T_Tensor(Ty.Shape[Ty.DimConst 4]) => "vec4"
            | Ty.T_Tensor shape => "tensor" ^ shapeToString shape
            | Ty.T_Image{dim, shape} => concat[
                  "image(", dimToString dim, ")", shapeToString shape
                ]
            | Ty.T_Field{diff, dim, shape} => concat[
                  "field#", diffToString diff, "(", dimToString dim,
                  ")", shapeToString shape
                ]
            | Ty.T_Fun(tys1, ty2) => let
                fun tysToString [] = "()"
                  | tysToString [ty] = toString ty
                  | tysToString tys = String.concat[
                        "(", listToString toString " * " tys, ")"
                      ]
                in
                  String.concat[tysToString tys1, " -> ", toString ty2]
                end
          (* end case *))

  (* compute the slice of a tensor type: dimensions whose mask entry is true
   * (indexed) are dropped, those whose entry is false (copied) are kept.
   * Raises Fail for a non-tensor type or a tensor whose shape is not a
   * concrete Shape list.
   * NOTE(review): ListPair.foldr stops at the shorter of the shape and the
   * mask, so a length mismatch silently drops trailing dimensions; this is
   * kept as-is because pruned shapes omit size-1 dimensions — confirm what
   * callers expect before tightening it.
   *)
    fun slice (Ty.T_Tensor(Ty.Shape dims), mask) = let
          fun keep (_, true, dd) = dd
            | keep (d, false, dd) = d :: dd
          in
            Ty.T_Tensor(Ty.Shape(ListPair.foldr keep [] (dims, mask)))
          end
      | slice (ty, _) = raise Fail(concat["slice(", toString ty, ", _)"])

  end

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0