(* util.sml
 *
 * COPYRIGHT (c) 2010 The Diderot Project (http://diderot.cs.uchicago.edu)
 * All rights reserved.
 *
 * Utilities for typechecking: pruning, matching, and instantiation of types.
 *)

structure Util =
  struct

    structure Ty = Types
    structure MV = MetaVar

  (***** Pruning *****)

  (* prune out instantiated meta variables from a type.  Every bound type,
   * differentiation, dimension, and shape variable reachable from `ty` is
   * replaced by (the pruned form of) the value it is bound to; unbound
   * variables are left in place.  No variable bindings are modified.
   *)
    fun prune ty = let
          fun prune' (ty as Ty.T_Var(Ty.TV{bind, ...})) = (case !bind
                 of NONE => ty
                  | SOME ty => prune' ty
                (* end case *))
            | prune' (Ty.T_Kernel diff) = Ty.T_Kernel(pruneDiff diff)
            | prune' (Ty.T_Tensor shape) = Ty.T_Tensor(pruneShape shape)
            | prune' (Ty.T_Image{dim, shape}) = Ty.T_Image{
                  dim = pruneDim dim,
                  shape = pruneShape shape
                }
            | prune' (Ty.T_Field{diff, dim, shape}) = Ty.T_Field{
                  diff = pruneDiff diff,
                  dim = pruneDim dim,
                  shape = pruneShape shape
                }
            | prune' (Ty.T_Fun(tys1, ty2)) = Ty.T_Fun(List.map prune' tys1, prune' ty2)
          (* remaining types (e.g., T_Bool, T_Int, T_String) contain no variables *)
            | prune' ty = ty
          in
            prune' ty
          end

  (* prune a differentiation term.  The term `DiffVar(dv, i)` stands for the
   * value of `dv` plus the offset `i`, so when `dv` is bound its (pruned) value
   * has `i` added to it: offsets accumulate along a chain of variables and are
   * added to the constant that the chain resolves to.
   *)
    and pruneDiff (Ty.DiffVar(Ty.DfV{bind=ref(SOME diff), ...}, i)) = (
          case pruneDiff diff
           of Ty.DiffVar(dv, i') => Ty.DiffVar(dv, i+i')
            | Ty.DiffConst i' => Ty.DiffConst(i+i')
          (* end case *))
      | pruneDiff diff = diff

  (* prune a dimension by following the bindings of dimension variables *)
    and pruneDim dim = (case dim
           of Ty.DimVar(Ty.DV{bind=ref(SOME dim), ...}) => pruneDim dim
            | dim => dim
          (* end case *))

  (* prune a shape: follow bound shape variables and prune the dimensions that
   * the shape contains.  Extended shapes are rebuilt with `Ty.shapeExt` (a
   * function from Types, presumably a smart constructor -- confirm there)
   * rather than with the `ShapeExt` constructor.
   *)
    and pruneShape shape = (case shape
           of Ty.ShapeVar(Ty.SV{bind=ref(SOME shape), ...}) => pruneShape shape
            | Ty.ShapeExt(shape, dim) => Ty.shapeExt(pruneShape shape, pruneDim dim)
            | Ty.Shape dd => Ty.Shape(List.map pruneDim dd)
            | _ => shape
          (* end case *))

  (***** Matching *****)

  (* match two differentiation terms, binding differentiation variables as
   * needed; returns true if the terms can be made equal.
   * FIXME: what about the bounds?
   *)
    fun matchDiff (diff1, diff2) = let
        (* bind the variable whose binding cell is `bind`, and which occurs in a
         * term with offset `i`, so that the term equals the constant `k`.  Since
         * `DiffVar(dv, i)` denotes dv + i (see pruneDiff), dv must be k - i;
         * a negative value is rejected.
         *)
          fun bindToConst (bind, i, k) = let
                val k' = k - i
                in
                  if k' < 0
                    then false
                    else (bind := SOME(Ty.DiffConst k'); true)
                end
          in
            case (pruneDiff diff1, pruneDiff diff2)
             of (Ty.DiffConst k1, Ty.DiffConst k2) => (k1 = k2)
              | (Ty.DiffConst k, Ty.DiffVar(Ty.DfV{bind, ...}, i)) => bindToConst (bind, i, k)
              | (Ty.DiffVar(Ty.DfV{bind, ...}, i), Ty.DiffConst k) => bindToConst (bind, i, k)
              | (Ty.DiffVar(Ty.DfV{bind=b1, ...}, i1), Ty.DiffVar(Ty.DfV{bind=b2, ...}, i2)) =>
                (* the same variable on both sides: v + i1 = v + i2 iff i1 = i2 *)
                  if b1 = b2
                    then (i1 = i2)
                    else raise Fail "unimplemented" (* FIXME *)
            (* end case *)
          end

  (* match two dimensions, binding dimension variables as needed *)
    fun matchDim (dim1, dim2) = (case (pruneDim dim1, pruneDim dim2)
           of (Ty.DimConst d1, Ty.DimConst d2) => (d1 = d2)
            | (Ty.DimVar(Ty.DV{bind=b1, ...}), d2 as Ty.DimVar(Ty.DV{bind=b2, ...})) => (
              (* never bind a variable to itself: pruneDim would loop on the cycle *)
                (if b1 = b2 then () else b1 := SOME d2);
                true)
            | (Ty.DimVar(Ty.DV{bind, ...}), d2) => (bind := SOME d2; true)
            | (d1, Ty.DimVar(Ty.DV{bind, ...})) => (bind := SOME d1; true)
          (* end case *))

  (* match two shapes, binding shape and dimension variables as needed *)
    fun matchShape (shape1, shape2) = (case (pruneShape shape1, pruneShape shape2)
           of (Ty.Shape dd1, Ty.Shape dd2) => let
                fun chk ([], []) = true
                  | chk (d1::dd1, d2::dd2) = matchDim(d1, d2) andalso chk (dd1, dd2)
                  | chk _ = false
                in
                  chk (dd1, dd2)
                end
            | (Ty.Shape dd, Ty.ShapeExt(shape, d2)) => let
              (* the last dimension of dd must match d2 and the remaining prefix
               * must match shape; revDD accumulates that prefix in reverse *)
                fun chk ([], _) = false
                  | chk ([d], revDD) =
                      matchDim(d, d2) andalso matchShape(Ty.Shape(List.rev revDD), shape)
                  | chk (d::dd, revDD) = chk(dd, d::revDD)
                in
                  chk (dd, [])
                end
            | (Ty.ShapeVar(Ty.SV{bind=b1, ...}), s2 as Ty.ShapeVar(Ty.SV{bind=b2, ...})) => (
              (* never bind a variable to itself: pruneShape would loop on the cycle *)
                (if b1 = b2 then () else b1 := SOME s2);
                true)
            | (Ty.ShapeVar(Ty.SV{bind, ...}), shape) => (bind := SOME shape; true)
            | (Ty.ShapeExt(shape1, d1), Ty.ShapeExt(shape2, d2)) =>
                matchDim(d1, d2) andalso matchShape(shape1, shape2)
          (* the remaining combinations are mirror images of the cases above *)
            | (shape1, shape2) => matchShape(shape2, shape1)
          (* end case *))

  (* match two types, binding meta variables as needed; returns true if the
   * types can be made equal.
   * QUESTION: do we need an occurs check?
   *)
    fun matchType (ty1, ty2) = let
        (* bind an unbound type variable; a bound variable here means that the
         * type was not pruned, which is an internal error *)
          fun setBind (Ty.TV{bind=ref(SOME _), ...}, _) = raise Fail "prune fail"
            | setBind (Ty.TV{bind, ...}, ty) = bind := SOME ty
        (* match two type variables; a variable is never bound to itself *)
          fun matchVar (tv1 as Ty.TV{id=id1, ...}, tv2 as Ty.TV{id=id2, ...}) =
                if Stamp.same(id1, id2)
                  then ()
                  else setBind (tv1, Ty.T_Var tv2)
        (* match two pruned types *)
          fun match (Ty.T_Var tv1, Ty.T_Var tv2) = (matchVar(tv1, tv2); true)
            | match (Ty.T_Var tv1, ty2) = (setBind(tv1, ty2); true)
            | match (ty1, Ty.T_Var tv2) = (setBind(tv2, ty1); true)
            | match (Ty.T_Bool, Ty.T_Bool) = true
            | match (Ty.T_Int, Ty.T_Int) = true
            | match (Ty.T_String, Ty.T_String) = true
            | match (Ty.T_Kernel k1, Ty.T_Kernel k2) = matchDiff (k1, k2)
            | match (Ty.T_Tensor s1, Ty.T_Tensor s2) = matchShape (s1, s2)
            | match (Ty.T_Image{dim=d1, shape=s1}, Ty.T_Image{dim=d2, shape=s2}) =
                matchDim (d1, d2) andalso matchShape(s1, s2)
            | match (Ty.T_Field{diff=k1, dim=d1, shape=s1}, Ty.T_Field{diff=k2, dim=d2, shape=s2}) =
                matchDiff (k1, k2) andalso matchDim (d1, d2) andalso matchShape(s1, s2)
            | match (Ty.T_Fun(tys11, ty12), Ty.T_Fun(tys21, ty22)) =
              (* use matchType (which prunes) for the result types: matching the
               * argument types may have bound variables that occur in them *)
                matchTypes (tys11, tys21) andalso matchType (ty12, ty22)
            | match _ = false
          in
            match (prune ty1, prune ty2)
          end

  (* match two lists of types pairwise; lists of different lengths do not match *)
    and matchTypes (tys1, tys2) = ListPair.allEq matchType (tys1, tys2)

  (***** Instantiation *****)

  (* instantiate a type scheme, returning the argument meta variables and the resulting type.
   * Each of the scheme's variables is replaced by a fresh copy (made by MV.copy);
   * the fresh variables are returned in the same order as the scheme's variables.
   * Note that we assume that the scheme is closed: a variable in `ty` that is not
   * one of the scheme's variables causes Fail "impossible" to be raised.
   *)
    fun instantiate ([], ty) = ([], ty)
      | instantiate (mvs, ty) = let
        (* copy one scheme variable, recording the substitution in env *)
          fun instantiateVar (mv, (mvs, env)) = let
                val mv' = MV.copy mv
                in
                  (mv'::mvs, MV.Map.insert(env, mv, mv'))
                end
        (* foldr (rather than foldl) so that the fresh variables come out in the
         * scheme's order instead of reversed *)
          val (mvs, env) = List.foldr instantiateVar ([], MV.Map.empty) mvs
        (* substitute the fresh variables for the scheme's variables *)
          fun iDiff (Ty.DiffVar(k, i)) = (case MV.Map.find(env, Ty.DIFF k)
                 of SOME(Ty.DIFF k) => Ty.DiffVar(k, i)
                  | _ => raise Fail "impossible"
                (* end case *))
            | iDiff diff = diff
          fun iDim (Ty.DimVar dv) = (case MV.Map.find(env, Ty.DIM dv)
                 of SOME(Ty.DIM dv) => Ty.DimVar dv
                  | _ => raise Fail "impossible"
                (* end case *))
            | iDim dim = dim
          fun iShape (Ty.ShapeVar sv) = (case MV.Map.find(env, Ty.SHAPE sv)
                 of SOME(Ty.SHAPE sv) => Ty.ShapeVar sv
                  | _ => raise Fail "impossible"
                (* end case *))
            | iShape (Ty.ShapeExt(shape, dim)) = Ty.ShapeExt(iShape shape, iDim dim)
            | iShape (Ty.Shape dims) = Ty.Shape(List.map iDim dims)
          fun ity (Ty.T_Var tv) = (case MV.Map.find(env, Ty.TYPE tv)
                 of SOME(Ty.TYPE tv) => Ty.T_Var tv
                  | _ => raise Fail "impossible"
                (* end case *))
            | ity (Ty.T_Kernel k) = Ty.T_Kernel(iDiff k)
            | ity (Ty.T_Tensor shape) = Ty.T_Tensor(iShape shape)
            | ity (Ty.T_Image{dim, shape}) = Ty.T_Image{dim=iDim dim, shape=iShape shape}
            | ity (Ty.T_Field{diff, dim, shape}) =
                Ty.T_Field{diff=iDiff diff, dim=iDim dim, shape=iShape shape}
            | ity (Ty.T_Fun(dom, rng)) = Ty.T_Fun(List.map ity dom, ity rng)
            | ity ty = ty
          in
            (mvs, ity ty)
          end

  end