Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] Annotation of /branches/vis12/src/compiler/typechecker/util.sml
ViewVC logotype

Annotation of /branches/vis12/src/compiler/typechecker/util.sml

Parent Directory Parent Directory | Revision Log Revision Log


Revision 3291 - (view) (download)

1 : jhr 80 (* util.sml
2 :     *
3 : jhr 3291 * This code is part of the Diderot Project (http://diderot-language.cs.uchicago.edu)
4 :     *
5 :     * COPYRIGHT (c) 2015 The University of Chicago
6 : jhr 80 * All rights reserved.
7 :     *
8 :     * Utilities for typechecking
9 :     *)
10 :    
structure Util : sig

  (* when matching two types (ty1 and ty2), there are three possible outcomes:
   *   EQ     -- types are equal
   *   COERCE -- ty2 can be coerced to match ty1 (e.g., int -> float, fixed seq -> dynamic seq)
   *   FAIL   -- types do not match
   *)
    datatype match = EQ | COERCE | FAIL

  (* match ty2 against ty1, allowing ty2 to be coerced to ty1.  Meta-variable bindings
   * made along the way are kept, even when the result is FAIL.
   *)
    val matchType : Types.ty * Types.ty -> match

  (* like matchType, but all meta-variable bindings are undone when the result is FAIL *)
    val tryMatchType : Types.ty * Types.ty -> match

  (* attempt to match a list of parameter types with a list of typed arguments.  Return
   * the arguments with any required coercions, or NONE on failure.  The try variant
   * undoes all meta-variable bindings on any failure, including a length mismatch.
   *)
    val matchArgs : Types.ty list * AST.expr list * Types.ty list -> AST.expr list option
    val tryMatchArgs : Types.ty list * AST.expr list * Types.ty list -> AST.expr list option

  (* unify two types (or two lists of types, which must have the same length) for
   * equality; no coercions are allowed.  Bindings are kept on failure.
   *)
    val equalType : Types.ty * Types.ty -> bool
    val equalTypes : Types.ty list * Types.ty list -> bool

  (* like equalType/equalTypes, but all meta-variable bindings are undone on failure *)
    val tryEqualType : Types.ty * Types.ty -> bool
    val tryEqualTypes : Types.ty list * Types.ty list -> bool

  (* instantiate a (closed) type scheme, returning the argument meta variables and the
   * resulting type.
   *)
    val instantiate : Types.scheme -> (Types.meta_var list * Types.ty)

  (* unify two dimensions; bindings are kept on failure *)
    val equalDim : Types.dim * Types.dim -> bool

  end = struct

    structure Ty = Types
    structure MV = MetaVar
    structure TU = TypeUtil

    datatype match = EQ | COERCE | FAIL

  (* A patch list (a ref holding a list of meta variables) tracks the meta variables that
   * have been bound, so that we can undo the effects of unification when just testing
   * for a possible type match.  Each bindXXX function below binds an unbound meta
   * variable and records it in the patch list; binding an already-bound variable is an
   * internal error.
   *)

    fun bindTyVar (pl, tv as Ty.TV{bind as ref NONE, ...}, ty) = (
          bind := SOME ty;
          pl := Ty.TYPE tv :: !pl)
      | bindTyVar _ = raise Fail "rebinding type variable"

    fun bindDiffVar (pl, dv as Ty.DfV{bind as ref NONE, ...}, diff) = (
          bind := SOME diff;
          pl := Ty.DIFF dv :: !pl)
      | bindDiffVar _ = raise Fail "rebinding differentiation variable"

    fun bindShapeVar (pl, sv as Ty.SV{bind as ref NONE, ...}, shape) = (
          bind := SOME shape;
          pl := Ty.SHAPE sv :: !pl)
      | bindShapeVar _ = raise Fail "rebinding shape variable"

    fun bindDimVar (pl, dv as Ty.DV{bind as ref NONE, ...}, dim) = (
          bind := SOME dim;
          pl := Ty.DIM dv :: !pl)
      | bindDimVar _ = raise Fail "rebinding dimension variable"

  (* reset every meta variable recorded in the patch list to unbound, then clear the
   * patch list so that undoing the same list a second time is a no-op.
   *)
    fun undo pl = let
          fun undo1 (Ty.TYPE(Ty.TV{bind, ...})) = bind := NONE
            | undo1 (Ty.DIFF(Ty.DfV{bind, ...})) = bind := NONE
            | undo1 (Ty.SHAPE(Ty.SV{bind, ...})) = bind := NONE
            | undo1 (Ty.DIM(Ty.DV{bind, ...})) = bind := NONE
          in
            List.app undo1 (!pl);
            pl := []
          end

  (* unify two differentiation levels, binding an unbound differentiation variable
   * against a constant.
   * FIXME: what about the bounds?
   * NOTE(review): a variable Ty.DiffVar(dv, i) is bound to the constant k+i here.  If
   * TU.pruneDiff interprets DiffVar(dv, i) as dv+i, the binding should be k-i instead;
   * the two agree when i = 0.  Confirm against TypeUtil.pruneDiff.
   *)
    fun equalDiff (pl, diff1, diff2) = let
          fun bindToConst (dv, i, k) = let
                val k' = k+i
                in
                  if k' < 0 then false
                  else (bindDiffVar(pl, dv, Ty.DiffConst k'); true)
                end
          in
            case (TU.pruneDiff diff1, TU.pruneDiff diff2)
             of (Ty.DiffConst k1, Ty.DiffConst k2) => (k1 = k2)
              | (Ty.DiffConst k, Ty.DiffVar(dv, i)) => bindToConst (dv, i, k)
              | (Ty.DiffVar(dv, i), Ty.DiffConst k) => bindToConst (dv, i, k)
              | (Ty.DiffVar(dv1, i1), Ty.DiffVar(dv2, i2)) => raise Fail "unimplemented" (* FIXME *)
            (* end case *)
          end

  (* match two differentiation constants where the first is allowed to be less than the
   * second.  Raises Fail "unimplemented" if either side is still a variable after pruning.
   *)
    fun matchDiff (diff1, diff2) = (case (TU.pruneDiff diff1, TU.pruneDiff diff2)
           of (Ty.DiffConst k1, Ty.DiffConst k2) => (k1 <= k2)
            | _ => raise Fail "unimplemented" (* FIXME *)
          (* end case *))

  (* unify two dimensions.  An unbound dimension variable is bound to the other side,
   * unless the other side is the very same variable, in which case the dimensions are
   * already equal (binding a variable to itself would create a cyclic binding).
   *)
    fun equalDim (pl, dim1, dim2) = (case (TU.pruneDim dim1, TU.pruneDim dim2)
           of (Ty.DimConst d1, Ty.DimConst d2) => (d1 = d2)
            | (Ty.DimVar(dv1 as Ty.DV{bind=b1, ...}), d2 as Ty.DimVar(Ty.DV{bind=b2, ...})) =>
              (* variables are identified by their (unique) bind cell *)
                (b1 = b2) orelse (bindDimVar(pl, dv1, d2); true)
            | (Ty.DimVar dv, d2) => (bindDimVar(pl, dv, d2); true)
            | (d1, Ty.DimVar dv) => (bindDimVar(pl, dv, d1); true)
          (* end case *))

  (* unify two shapes.  Explicit shapes are compared dimension by dimension; an extended
   * shape ShapeExt(shape, d) is matched against the last dimension of an explicit shape
   * and the remaining prefix.  An unbound shape variable is bound to the other side,
   * unless the other side is the same variable (which would create a cyclic binding).
   *)
    fun equalShape (pl, shape1, shape2) = (case (TU.pruneShape shape1, TU.pruneShape shape2)
           of (Ty.Shape dd1, Ty.Shape dd2) => let
                fun chk ([], []) = true
                  | chk (d1::dd1, d2::dd2) = equalDim(pl, d1, d2) andalso chk (dd1, dd2)
                  | chk _ = false
                in
                  chk (dd1, dd2)
                end
            | (Ty.Shape dd, Ty.ShapeExt(shape, d2)) => let
              (* walk to the last dimension of dd, accumulating the prefix in reverse *)
                fun chk ([], _) = false
                  | chk ([d], revDD) =
                      equalDim(pl, d, d2) andalso equalShape(pl, Ty.Shape(List.rev revDD), shape)
                  | chk (d::dd, revDD) = chk(dd, d::revDD)
                in
                  chk (dd, [])
                end
            | (Ty.ShapeVar(sv1 as Ty.SV{bind=b1, ...}), s2 as Ty.ShapeVar(Ty.SV{bind=b2, ...})) =>
              (* variables are identified by their (unique) bind cell *)
                (b1 = b2) orelse (bindShapeVar(pl, sv1, s2); true)
            | (Ty.ShapeVar sv, shape) => (bindShapeVar (pl, sv, shape); true)
            | (Ty.ShapeExt(shape1, d1), Ty.ShapeExt(shape2, d2)) =>
                equalDim(pl, d1, d2) andalso equalShape(pl, shape1, shape2)
          (* remaining combinations are the mirror images of cases handled above *)
            | (shape1, shape2) => equalShape(pl, shape2, shape1)
          (* end case *))

  (* unify two types for equality, recording any meta-variable bindings in pl.  Type
   * variables with the same stamp are equal; otherwise an unbound type variable is bound
   * to the other type.
   * QUESTION: do we need an occurs check?
   *)
    fun unifyType (pl, ty1, ty2) = let
          fun matchVar (tv1 as Ty.TV{id=id1, bind=b1}, tv2 as Ty.TV{id=id2, bind=b2}) =
                if Stamp.same(id1, id2)
                  then true
                  else (case (!b1, !b2)
                     of (SOME ty1, SOME ty2) => match(ty1, ty2)
                      | (SOME ty1, NONE) => (bindTyVar (pl, tv2, ty1); true)
                      | (NONE, SOME ty2) => (bindTyVar (pl, tv1, ty2); true)
                      | (NONE, NONE) => (bindTyVar (pl, tv1, Ty.T_Var tv2); true)
                    (* end case *))
          and matchVarTy (tv as Ty.TV{bind, ...}, ty) = (case !bind
                 of NONE => (bindTyVar(pl, tv, ty); true)
                  | SOME ty' => match(ty', ty)
                (* end case *))
          and match (Ty.T_Var tv1, Ty.T_Var tv2) = matchVar(tv1, tv2)
            | match (Ty.T_Var tv1, ty2) = matchVarTy(tv1, ty2)
            | match (ty1, Ty.T_Var tv2) = matchVarTy(tv2, ty1)
            | match (Ty.T_Bool, Ty.T_Bool) = true
            | match (Ty.T_Int, Ty.T_Int) = true
            | match (Ty.T_String, Ty.T_String) = true
            | match (Ty.T_Sequence(ty1, d1), Ty.T_Sequence(ty2, d2)) =
                equalDim(pl, d1, d2) andalso match(ty1, ty2)
            | match (Ty.T_DynSequence ty1, Ty.T_DynSequence ty2) = match(ty1, ty2)
            | match (Ty.T_Kernel k1, Ty.T_Kernel k2) = equalDiff (pl, k1, k2)
            | match (Ty.T_Tensor s1, Ty.T_Tensor s2) = equalShape (pl, s1, s2)
            | match (Ty.T_Image{dim=d1, shape=s1}, Ty.T_Image{dim=d2, shape=s2}) =
                equalDim (pl, d1, d2) andalso equalShape(pl, s1, s2)
            | match (Ty.T_Field{diff=k1, dim=d1, shape=s1}, Ty.T_Field{diff=k2, dim=d2, shape=s2}) =
                equalDiff (pl, k1, k2) andalso equalDim (pl, d1, d2) andalso equalShape(pl, s1, s2)
            | match (Ty.T_Fun(tys11, ty12), Ty.T_Fun(tys21, ty22)) =
                ListPair.allEq match (tys11, tys21) andalso match (ty12, ty22)
            | match _ = false
          in
            match (TU.pruneHead ty1, TU.pruneHead ty2)
          end

  (* match ty2 against ty1, allowing the coercions int -> scalar tensor, fixed-size
   * sequence -> dynamic sequence, and a field with more continuity -> a field with less.
   * Bindings are recorded in pl.  Note that in the field case a failed unifyType attempt
   * may leave partial bindings in pl before the coercion check is made.
   *)
    fun unifyTypeWithCoercion (pl, ty1, ty2) = (case (TU.pruneHead ty1, TU.pruneHead ty2)
           of (Ty.T_Tensor shp, Ty.T_Int) =>
                if equalShape (pl, Ty.Shape[], shp) then COERCE else FAIL
            | (Ty.T_DynSequence ty1, Ty.T_Sequence(ty2, _)) =>
                if unifyType(pl, ty1, ty2) then COERCE else FAIL
            | (Ty.T_Field{diff=k1, dim=d1, shape=s1}, Ty.T_Field{diff=k2, dim=d2, shape=s2}) =>
                if unifyType(pl, ty1, ty2)
                  then EQ
                else if matchDiff (k1, k2) andalso equalDim(pl, d1, d2)
                andalso equalShape(pl, s1, s2)
                  then COERCE
                  else FAIL
            | (ty1, ty2) => if unifyType(pl, ty1, ty2) then EQ else FAIL
          (* end case *))
(* +DEBUG *
    val unifyTypeWithCoercion = fn (pl, ty1, ty2) => let
          val res = unifyTypeWithCoercion (pl, ty1, ty2)
          val res' = (case res of EQ => "EQ" | COERCE => "COERCE" | FAIL => "FAIL")
          in
            print(concat["unifyTypeWithCoercion (_, ", TU.toString ty1, ", ", TU.toString ty2, ") = ", res', "\n"]);
            res
          end
* -DEBUG *)

  (* unify two lists of types for equality; false if the lengths differ *)
    fun equalTypes (tys1, tys2) = let
          val pl = ref[]
          in
            ListPair.allEq (fn (ty1, ty2) => unifyType(pl, ty1, ty2)) (tys1, tys2)
          end

  (* unify two types for equality *)
    fun equalType (ty1, ty2) = unifyType (ref[], ty1, ty2)

  (* try to unify two types to equality; if we fail, all meta-variable bindings are undone *)
    fun tryEqualType (ty1, ty2) = let
          val pl = ref[]
          in
            unifyType(pl, ty1, ty2) orelse (undo pl; false)
          end

  (* try to unify two type lists to equality; if we fail, all meta-variable bindings are undone *)
    fun tryEqualTypes (tys1, tys2) = let
          val pl = ref[]
          in
            ListPair.allEq (fn (ty1, ty2) => unifyType(pl, ty1, ty2)) (tys1, tys2)
              orelse (undo pl; false)
          end

  (* match two types, allowing ty2 to be coerced to ty1 *)
    fun matchType (ty1, ty2) = unifyTypeWithCoercion (ref[], ty1, ty2)

  (* try to match two types, allowing ty2 to be coerced to ty1; if we fail, all
   * meta-variable bindings are undone.
   *)
    fun tryMatchType (ty1, ty2) = let
          val pl = ref[]
          in
            case unifyTypeWithCoercion (pl, ty1, ty2)
             of FAIL => (undo pl; FAIL)
              | result => result
            (* end case *)
          end

  (* attempt to match a list of parameter types with a list of typed arguments.  Return
   * the arguments with any required coercions, or NONE on failure.
   *)
    local
    (* match argument by argument, wrapping coerced arguments in E_Coerce.  Bindings are
     * undone when an argument fails to match, but not on a length mismatch.
     *)
      fun matchArgs' (pl, paramTys, args, argTys) = let
            fun matchArgTys ([], [], [], args') = SOME(List.rev args')
              | matchArgTys (ty1::tys1, arg::args, ty2::tys2, args') = (
                  case unifyTypeWithCoercion (pl, ty1, ty2)
                   of EQ => matchArgTys (tys1, args, tys2, arg::args')
                    | COERCE => matchArgTys (tys1, args, tys2, AST.E_Coerce{srcTy=ty2, dstTy=ty1, e=arg}::args')
                    | _ => (undo pl; NONE)
                  (* end case *))
              | matchArgTys _ = NONE
            in
              matchArgTys (paramTys, args, argTys, [])
            end
    in
    fun matchArgs (paramTys, args, argTys) = matchArgs' (ref[], paramTys, args, argTys)
  (* the patch list must be shared with matchArgs' so that every binding it made
   * (including those made before a length mismatch) is undone on failure.
   *)
    fun tryMatchArgs (paramTys, args, argTys) = let
          val pl = ref[]
          in
            case matchArgs' (pl, paramTys, args, argTys)
             of NONE => (undo pl; NONE)
              | someResult => someResult
            (* end case *)
          end
    end

  (* rebind equalDim without patch-list argument *)
    val equalDim = fn (d1, d2) => equalDim(ref [], d1, d2)

  (* QUESTION: perhaps this function belongs in the TypeUtil module? *)
  (* instantiate a type scheme, returning the argument meta variables and the resulting type.
   * Each scheme meta variable is replaced by a copy (MV.copy); encountering a variable
   * that is not one of the scheme's variables raises Fail "impossible", since we assume
   * that the scheme is closed.
   *)
    fun instantiate ([], ty) = ([], ty)
      | instantiate (mvs, ty) = let
          fun instantiateVar (mv, (mvs, env)) = let
                val mv' = MV.copy mv
                in
                  (mv'::mvs, MV.Map.insert(env, mv, mv'))
                end
          val (mvs, env) = List.foldr instantiateVar ([], MV.Map.empty) mvs
          fun iDiff (Ty.DiffVar(k, i)) = (case MV.Map.find(env, Ty.DIFF k)
                 of SOME(Ty.DIFF k) => Ty.DiffVar(k, i)
                  | _ => raise Fail "impossible"
                (* end case *))
            | iDiff diff = diff
          fun iDim (Ty.DimVar dv) = (case MV.Map.find(env, Ty.DIM dv)
                 of SOME(Ty.DIM dv) => Ty.DimVar dv
                  | _ => raise Fail "impossible"
                (* end case *))
            | iDim dim = dim
          fun iShape (Ty.ShapeVar sv) = (case MV.Map.find(env, Ty.SHAPE sv)
                 of SOME(Ty.SHAPE sv) => Ty.ShapeVar sv
                  | _ => raise Fail "impossible"
                (* end case *))
            | iShape (Ty.ShapeExt(shape, dim)) = Ty.ShapeExt(iShape shape, iDim dim)
            | iShape (Ty.Shape dims) = Ty.Shape(List.map iDim dims)
          fun ity (Ty.T_Var tv) = (case MV.Map.find(env, Ty.TYPE tv)
                 of SOME(Ty.TYPE tv) => Ty.T_Var tv
                  | _ => raise Fail "impossible"
                (* end case *))
            | ity Ty.T_Bool = Ty.T_Bool
            | ity Ty.T_Int = Ty.T_Int
            | ity Ty.T_String = Ty.T_String
            | ity (Ty.T_Sequence(ty, d)) = Ty.T_Sequence(ity ty, iDim d)
            | ity (Ty.T_DynSequence ty) = Ty.T_DynSequence(ity ty)
            | ity (Ty.T_Kernel k) = Ty.T_Kernel(iDiff k)
            | ity (Ty.T_Tensor shape) = Ty.T_Tensor(iShape shape)
            | ity (Ty.T_Image{dim, shape}) = Ty.T_Image{dim=iDim dim, shape=iShape shape}
            | ity (Ty.T_Field{diff, dim, shape}) =
                Ty.T_Field{diff=iDiff diff, dim=iDim dim, shape=iShape shape}
            | ity (Ty.T_Fun(dom, rng)) = Ty.T_Fun(List.map ity dom, ity rng)
          in
            (mvs, ity ty)
          end

  end

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0