(* util.sml
 *
 * COPYRIGHT (c) 2010 The Diderot Project (http://diderot.cs.uchicago.edu)
 * All rights reserved.
 *
 * Utilities for typechecking
 *)

structure Util : sig

  (* replace instantiated meta variables in a type by their bindings *)
    val prune : Types.ty -> Types.ty

  (* unify two types (or two equal-length lists of types); meta-variable bindings
   * made before a failure are NOT undone.
   *)
    val matchType : Types.ty * Types.ty -> bool
    val matchTypes : Types.ty list * Types.ty list -> bool

  (* like matchType/matchTypes, but all meta-variable bindings are undone on failure *)
    val tryMatchType : Types.ty * Types.ty -> bool
    val tryMatchTypes : Types.ty list * Types.ty list -> bool

  (* instantiate a type scheme with fresh copies of its meta variables, returning the
   * fresh meta variables (in the scheme's order) and the instantiated type.
   *)
    val instantiate : Types.scheme -> (Types.meta_var list * Types.ty)

  end = struct

    structure Ty = Types
    structure MV = MetaVar

  (* prune out instantiated meta variables from a type.  A bound type variable is
   * replaced by the pruned form of its binding, and the differentiation, dimension,
   * and shape components of kernel, tensor, image, and field types are pruned, as
   * are the argument and result types of function types.  Other types are returned
   * unchanged.
   *)
    fun prune ty = let
          fun prune' (ty as Ty.T_Var(Ty.TV{bind, ...})) = (case !bind
                 of NONE => ty
                  | SOME ty => prune' ty
                (* end case *))
            | prune' (Ty.T_Kernel diff) = Ty.T_Kernel(pruneDiff diff)
            | prune' (Ty.T_Tensor shape) = Ty.T_Tensor(pruneShape shape)
            | prune' (Ty.T_Image{dim, shape}) = Ty.T_Image{
                  dim = pruneDim dim,
                  shape = pruneShape shape
                }
            | prune' (Ty.T_Field{diff, dim, shape}) = Ty.T_Field{
                  diff = pruneDiff diff,
                  dim = pruneDim dim,
                  shape = pruneShape shape
                }
            | prune' (Ty.T_Fun(tys1, ty2)) = Ty.T_Fun(List.map prune' tys1, prune' ty2)
            | prune' ty = ty
          in
            prune' ty
          end

  (* prune a differentiation level.  DiffVar(dv, i) denotes "dv + i", so when dv is
   * bound, the offset i is added to the pruned binding.
   *)
    and pruneDiff (Ty.DiffVar(Ty.DfV{bind=ref(SOME diff), ...}, i)) = (
          case pruneDiff diff
           of Ty.DiffVar(dv, i') => Ty.DiffVar(dv, i+i')
            | Ty.DiffConst i' => Ty.DiffConst(i+i')
          (* end case *))
      | pruneDiff diff = diff

  (* prune a dimension by following bound dimension variables *)
    and pruneDim dim = (case dim
           of Ty.DimVar(Ty.DV{bind=ref(SOME dim), ...}) => pruneDim dim
            | dim => dim
          (* end case *))

  (* prune a shape by following bound shape variables and pruning extended shapes;
   * extended shapes are rebuilt with Ty.shapeExt.
   *)
    and pruneShape shape = (case shape
           of Ty.ShapeVar(Ty.SV{bind=ref(SOME shape), ...}) => pruneShape shape
            | Ty.ShapeExt(shape, dim) => Ty.shapeExt(pruneShape shape, pruneDim dim)
            | _ => shape
          (* end case *))

  (* a patch list tracks the meta variables that have been updated so that we can undo
   * the effects of unification when just testing for a possible type match.
   *)

  (* bind an unbound type variable and record it in the patch list *)
    fun bindTyVar (pl, tv as Ty.TV{bind as ref NONE, ...}, ty) = (
          bind := SOME ty;
          pl := Ty.TYPE tv :: !pl)
      | bindTyVar _ = raise Fail "rebinding type variable"

  (* bind an unbound differentiation variable and record it in the patch list *)
    fun bindDiffVar (pl, dv as Ty.DfV{bind as ref NONE, ...}, diff) = (
          bind := SOME diff;
          pl := Ty.DIFF dv :: !pl)
      | bindDiffVar _ = raise Fail "rebinding differentiation variable"

  (* bind an unbound shape variable and record it in the patch list *)
    fun bindShapeVar (pl, sv as Ty.SV{bind as ref NONE, ...}, shape) = (
          bind := SOME shape;
          pl := Ty.SHAPE sv :: !pl)
      | bindShapeVar _ = raise Fail "rebinding shape variable"

  (* bind an unbound dimension variable and record it in the patch list *)
    fun bindDimVar (pl, dv as Ty.DV{bind as ref NONE, ...}, dim) = (
          bind := SOME dim;
          pl := Ty.DIM dv :: !pl)
      | bindDimVar _ = raise Fail "rebinding dimension variable"

  (* undo the bindings recorded in a patch list by resetting each recorded meta
   * variable to unbound.
   *)
    fun undo pl = let
          fun undo1 (Ty.TYPE(Ty.TV{bind, ...})) = bind := NONE
            | undo1 (Ty.DIFF(Ty.DfV{bind, ...})) = bind := NONE
            | undo1 (Ty.SHAPE(Ty.SV{bind, ...})) = bind := NONE
            | undo1 (Ty.DIM(Ty.DV{bind, ...})) = bind := NONE
          in
            List.app undo1 (!pl)
          end

  (* unify two differentiation levels.  DiffVar(dv, i) denotes "dv + i" (see
   * pruneDiff), so matching it against the constant k requires binding dv to k - i,
   * which must be non-negative.
   * FIXME: what about the bounds?
   *)
    fun matchDiff (pl, diff1, diff2) = (case (pruneDiff diff1, pruneDiff diff2)
           of (Ty.DiffConst k1, Ty.DiffConst k2) => (k1 = k2)
            | (Ty.DiffConst k, Ty.DiffVar(dv, i)) => let
                val k' = k-i
                in
                  if k' < 0 then false
                  else (bindDiffVar(pl, dv, Ty.DiffConst k'); true)
                end
            | (Ty.DiffVar(dv, i), Ty.DiffConst k) => let
                val k' = k-i
                in
                  if k' < 0 then false
                  else (bindDiffVar(pl, dv, Ty.DiffConst k'); true)
                end
            | (Ty.DiffVar(dv1 as Ty.DfV{bind=b1, ...}, i1), Ty.DiffVar(dv2 as Ty.DfV{bind=b2, ...}, i2)) =>
              (* dv1 + i1 = dv2 + i2.  The same variable (identified by its binding
               * ref) only matches itself at the same offset; otherwise bind the
               * variable with the smaller offset so that its binding has a
               * non-negative offset.
               *)
                if (b1 = b2)
                  then (i1 = i2)
                else if (i1 <= i2)
                  then (bindDiffVar(pl, dv1, Ty.DiffVar(dv2, i2-i1)); true)
                  else (bindDiffVar(pl, dv2, Ty.DiffVar(dv1, i1-i2)); true)
          (* end case *))

  (* unify two dimensions; a variable matched against itself is left unbound, since
   * binding it to itself would make pruneDim loop forever.
   *)
    fun matchDim (pl, dim1, dim2) = (case (pruneDim dim1, pruneDim dim2)
           of (Ty.DimConst d1, Ty.DimConst d2) => (d1 = d2)
            | (Ty.DimVar(dv1 as Ty.DV{bind=b1, ...}), dim2 as Ty.DimVar(Ty.DV{bind=b2, ...})) =>
                (b1 = b2) orelse (bindDimVar(pl, dv1, dim2); true)
            | (Ty.DimVar dv, dim2) => (bindDimVar(pl, dv, dim2); true)
            | (dim1, Ty.DimVar dv) => (bindDimVar(pl, dv, dim1); true)
          (* end case *))

  (* unify two shapes.  An explicit shape matches an extended shape ShapeExt(s, d)
   * when its last dimension matches d and its prefix matches s.  A shape variable
   * matched against itself is left unbound (binding it to itself would make
   * pruneShape loop forever).
   *)
    fun matchShape (pl, shape1, shape2) = (case (pruneShape shape1, pruneShape shape2)
           of (Ty.Shape dd1, Ty.Shape dd2) => let
                fun chk ([], []) = true
                  | chk (d1::dd1, d2::dd2) = matchDim(pl, d1, d2) andalso chk (dd1, dd2)
                  | chk _ = false
                in
                  chk (dd1, dd2)
                end
            | (Ty.Shape dd, Ty.ShapeExt(shape, d2)) => let
              (* walk to the last dimension, accumulating the prefix in reverse *)
                fun chk ([], _) = false
                  | chk ([d], revDD) =
                      matchDim(pl, d, d2) andalso matchShape(pl, Ty.Shape(List.rev revDD), shape)
                  | chk (d::dd, revDD) = chk(dd, d::revDD)
                in
                  chk (dd, [])
                end
            | (Ty.ShapeVar(sv1 as Ty.SV{bind=b1, ...}), shape2 as Ty.ShapeVar(Ty.SV{bind=b2, ...})) =>
                (b1 = b2) orelse (bindShapeVar(pl, sv1, shape2); true)
            | (Ty.ShapeVar sv, shape) => (bindShapeVar (pl, sv, shape); true)
            | (Ty.ShapeExt(shape1, d1), Ty.ShapeExt(shape2, d2)) =>
                matchDim(pl, d1, d2) andalso matchShape(pl, shape1, shape2)
          (* remaining combinations are handled by swapping the arguments *)
            | (shape1, shape2) => matchShape(pl, shape2, shape1)
          (* end case *))

  (* unify two types, recording any meta-variable bindings in the patch list pl.
   * Returns true if the types were unified; bindings made before a failure are NOT
   * undone here (see tryMatchType/tryMatchTypes).
   * QUESTION: do we need an occurs check?
   *)
    fun unifyType (pl, ty1, ty2) = let
          fun matchVar (tv1 as Ty.TV{id=id1, ...}, tv2 as Ty.TV{id=id2, ...}) =
                if Stamp.same(id1, id2)
                  then ()
                  else bindTyVar (pl, tv1, Ty.T_Var tv2)
        (* prune both sides at every step: bindings made while matching earlier
         * components (e.g., earlier function arguments) may instantiate variables
         * that occur in later components.
         *)
          fun match (ty1, ty2) = match' (prune ty1, prune ty2)
          and match' (Ty.T_Var tv1, Ty.T_Var tv2) = (matchVar(tv1, tv2); true)
            | match' (Ty.T_Var tv1, ty2) = (bindTyVar(pl, tv1, ty2); true)
            | match' (ty1, Ty.T_Var tv2) = (bindTyVar(pl, tv2, ty1); true)
            | match' (Ty.T_Bool, Ty.T_Bool) = true
            | match' (Ty.T_Int, Ty.T_Int) = true
            | match' (Ty.T_String, Ty.T_String) = true
            | match' (Ty.T_Kernel k1, Ty.T_Kernel k2) = matchDiff (pl, k1, k2)
            | match' (Ty.T_Tensor s1, Ty.T_Tensor s2) = matchShape (pl, s1, s2)
            | match' (Ty.T_Image{dim=d1, shape=s1}, Ty.T_Image{dim=d2, shape=s2}) =
                matchDim (pl, d1, d2) andalso matchShape(pl, s1, s2)
            | match' (Ty.T_Field{diff=k1, dim=d1, shape=s1}, Ty.T_Field{diff=k2, dim=d2, shape=s2}) =
                matchDiff (pl, k1, k2) andalso matchDim (pl, d1, d2) andalso matchShape(pl, s1, s2)
            | match' (Ty.T_Fun(tys11, ty12), Ty.T_Fun(tys21, ty22)) =
                ListPair.allEq match (tys11, tys21) andalso match (ty12, ty22)
            | match' _ = false
          in
            match (ty1, ty2)
          end

  (* unify two lists of types pairwise; lists of different lengths do not match *)
    fun matchTypes (tys1, tys2) = let
          val pl = ref[]
          in
            ListPair.allEq (fn (ty1, ty2) => unifyType(pl, ty1, ty2)) (tys1, tys2)
          end

  (* unify two types *)
    fun matchType (ty1, ty2) = unifyType (ref[], ty1, ty2)

  (* try to match types; if we fail, all meta-variable bindings are undone *)
    fun tryMatchType (ty1, ty2) = let
          val pl = ref[]
          in
            unifyType(pl, ty1, ty2) orelse (undo pl; false)
          end

  (* try to match types; if we fail, all meta-variable bindings are undone *)
    fun tryMatchTypes (tys1, tys2) = let
          val pl = ref[]
          in
            ListPair.allEq (fn (ty1, ty2) => unifyType(pl, ty1, ty2)) (tys1, tys2)
              orelse (undo pl; false)
          end

  (* instantiate a type scheme, returning the argument meta variables and the resulting type.
   * The fresh meta variables are returned in the same order as the scheme's.
   * Note that we assume that the scheme is closed; a meta variable that is not one of
   * the scheme's raises Fail "impossible".
   *)
    fun instantiate ([], ty) = ([], ty)
      | instantiate (mvs, ty) = let
          fun instantiateVar (mv, (mvs, env)) = let
                val mv' = MV.copy mv
                in
                  (mv'::mvs, MV.Map.insert(env, mv, mv'))
                end
        (* foldr, so that the fresh variables come out in the scheme's order *)
          val (mvs, env) = List.foldr instantiateVar ([], MV.Map.empty) mvs
          fun iDiff (Ty.DiffVar(k, i)) = (case MV.Map.find(env, Ty.DIFF k)
                 of SOME(Ty.DIFF k) => Ty.DiffVar(k, i)
                  | _ => raise Fail "impossible"
                (* end case *))
            | iDiff diff = diff
          fun iDim (Ty.DimVar dv) = (case MV.Map.find(env, Ty.DIM dv)
                 of SOME(Ty.DIM dv) => Ty.DimVar dv
                  | _ => raise Fail "impossible"
                (* end case *))
            | iDim dim = dim
          fun iShape (Ty.ShapeVar sv) = (case MV.Map.find(env, Ty.SHAPE sv)
                 of SOME(Ty.SHAPE sv) => Ty.ShapeVar sv
                  | _ => raise Fail "impossible"
                (* end case *))
            | iShape (Ty.ShapeExt(shape, dim)) = Ty.ShapeExt(iShape shape, iDim dim)
            | iShape (Ty.Shape dims) = Ty.Shape(List.map iDim dims)
          fun ity (Ty.T_Var tv) = (case MV.Map.find(env, Ty.TYPE tv)
                 of SOME(Ty.TYPE tv) => Ty.T_Var tv
                  | _ => raise Fail "impossible"
                (* end case *))
            | ity (Ty.T_Kernel k) = Ty.T_Kernel(iDiff k)
            | ity (Ty.T_Tensor shape) = Ty.T_Tensor(iShape shape)
            | ity (Ty.T_Image{dim, shape}) = Ty.T_Image{dim=iDim dim, shape=iShape shape}
            | ity (Ty.T_Field{diff, dim, shape}) =
                Ty.T_Field{diff=iDiff diff, dim=iDim dim, shape=iShape shape}
            | ity (Ty.T_Fun(dom, rng)) = Ty.T_Fun(List.map ity dom, ity rng)
            | ity ty = ty
          in
            (mvs, ity ty)
          end

  end