10 |
struct |
struct |
11 |
|
|
12 |
structure Ty = Types |
structure Ty = Types |
13 |
|
structure MV = MetaVar |
14 |
|
|
15 |
(* prune out instantiated meta variables from a type *) |
(* prune out instantiated meta variables from a type *) |
16 |
fun prune ty = let |
fun prune ty = let |
|
fun pruneDiff (Ty.DiffVar(Dfv{bind=ref(SOME diff), ...}, i)) = ( |
|
|
case pruneDiff diff |
|
|
of Ty.DiffVar(dv, i') => Ty.DiffVar(dv, i+i') |
|
|
| Ty.DiffConst i' => Ty.DiffConst(i+i') |
|
|
(* end case *)) |
|
|
| prunDiff diff = diff |
|
|
fun pruneDim dim = (case dim |
|
|
of Ty.DimVar(Ty.DV{bind=ref(SOME dim), ...}) => pruneDim dim |
|
|
| dim => dim |
|
|
(* end case *)) |
|
|
fun pruneShape shape = (case shape |
|
|
of ShapeVar(Ty.SV{bind=ref(SOME shape), ...}) => pruneShape shape |
|
|
| ShapeExt(shape, dim) => ShapeExt(pruneShape shape, pruneDim dim) |
|
|
| _ => shape |
|
|
(* end case *)) |
|
17 |
fun prune' (ty as Ty.T_Var(Ty.TV{bind, ...})) = (case !bind |
fun prune' (ty as Ty.T_Var(Ty.TV{bind, ...})) = (case !bind |
18 |
of NONE => ty |
of NONE => ty |
19 |
| SOME ty => prune' ty |
| SOME ty => prune' ty |
29 |
dim = pruneDim dim, |
dim = pruneDim dim, |
30 |
shape = pruneShape shape |
shape = pruneShape shape |
31 |
} |
} |
32 |
| prune' (Ty.T_Fun(tys1, tys2)) = Ty.T_Fun(List.map prune' tys1, List.map prune' tys2) |
| prune' (Ty.T_Fun(tys1, ty2)) = Ty.T_Fun(List.map prune' tys1, prune' ty2) |
33 |
| prune' ty = ty |
| prune' ty = ty |
34 |
in |
in |
35 |
prune' ty |
prune' ty |
36 |
end |
end |
37 |
|
|
38 |
fun matchTypes (ty1, ty2) = let |
and pruneDiff (Ty.DiffVar(Ty.DfV{bind=ref(SOME diff), ...}, i)) = ( |
39 |
fun match (Ty.T_Var tv1, Ty.T_Var tv2) = |
case pruneDiff diff |
40 |
| match (Ty.T_Var tv1, ty2) = |
of Ty.DiffVar(dv, i') => Ty.DiffVar(dv, i+i') |
41 |
| match (ty1, Ty.T_Var tv2) = |
| Ty.DiffConst i' => Ty.DiffConst(i+i') |
42 |
|
(* end case *)) |
43 |
|
| pruneDiff diff = diff |
44 |
|
|
45 |
|
and pruneDim dim = (case dim |
46 |
|
of Ty.DimVar(Ty.DV{bind=ref(SOME dim), ...}) => pruneDim dim |
47 |
|
| dim => dim |
48 |
|
(* end case *)) |
49 |
|
|
50 |
|
and pruneShape shape = (case shape |
51 |
|
of Ty.ShapeVar(Ty.SV{bind=ref(SOME shape), ...}) => pruneShape shape |
52 |
|
| Ty.ShapeExt(shape, dim) => Ty.shapeExt(pruneShape shape, pruneDim dim) |
53 |
|
| _ => shape |
54 |
|
(* end case *)) |
55 |
|
|
56 |
|
(* FIXME: what about the bounds? *) |
57 |
|
fun matchDiff (diff1, diff2) = (case (pruneDiff diff1, pruneDiff diff2) |
58 |
|
of (Ty.DiffConst k1, Ty.DiffConst k2) => (k1 = k2) |
59 |
|
| (Ty.DiffConst k, Ty.DiffVar(Ty.DfV{bind, bound, ...}, i)) => let |
60 |
|
val k' = k+i |
61 |
|
in |
62 |
|
if k' < 0 then false |
63 |
|
else (bind := SOME(Ty.DiffConst k'); true) |
64 |
|
end |
65 |
|
| (Ty.DiffVar(Ty.DfV{bind, bound, ...}, i), Ty.DiffConst k) => let |
66 |
|
val k' = k+i |
67 |
|
in |
68 |
|
if k' < 0 then false |
69 |
|
else (bind := SOME(Ty.DiffConst k'); true) |
70 |
|
end |
71 |
|
| (Ty.DiffVar(dv1, i1), Ty.DiffVar(dv2, i2)) => raise Fail "unimplemented" (* FIXME *) |
72 |
|
(* end case *)) |
73 |
|
|
74 |
|
fun matchDim (dim1, dim2) = (case (pruneDim dim1, pruneDim dim2) |
75 |
|
of (Ty.DimConst d1, Ty.DimConst d2) => (d1 = d2) |
76 |
|
| (Ty.DimVar(Ty.DV{bind, ...}), dim2) => (bind := SOME dim2; true) |
77 |
|
| (dim1, Ty.DimVar(Ty.DV{bind, ...})) => (bind := SOME dim1; true) |
78 |
|
(* end case *)) |
79 |
|
|
80 |
|
fun matchShape (shape1, shape2) = (case (pruneShape shape1, pruneShape shape2) |
81 |
|
of (Ty.Shape dd1, Ty.Shape dd2) => let |
82 |
|
fun chk ([], []) = true |
83 |
|
| chk (d1::dd1, d2::dd2) = matchDim(d1, d2) andalso chk (dd1, dd2) |
84 |
|
| chk _ = false |
85 |
|
in |
86 |
|
chk (dd1, dd2) |
87 |
|
end |
88 |
|
| (Ty.Shape dd, Ty.ShapeExt(shape, d2)) => let |
89 |
|
fun chk ([], _) = false |
90 |
|
| chk ([d], revDD) = |
91 |
|
matchDim(d, d2) andalso matchShape(Ty.Shape(List.rev revDD), shape) |
92 |
|
| chk (d::dd, revDD) = chk(dd, d::revDD) |
93 |
|
in |
94 |
|
chk (dd, []) |
95 |
|
end |
96 |
|
| (Ty.ShapeVar(Ty.SV{bind, ...}), shape) => (bind := SOME shape; true) |
97 |
|
| (Ty.ShapeExt(shape1, d1), Ty.ShapeExt(shape2, d2)) => |
98 |
|
matchDim(d1, d2) andalso matchShape(shape1, shape2) |
99 |
|
| (shape1, shape2) => matchShape(shape2, shape1) |
100 |
|
(* end case *)) |
101 |
|
|
102 |
|
(* QUESTION: do we need an occurs check? *) |
103 |
|
fun matchType (ty1, ty2) = let |
104 |
|
fun setBind (Ty.TV{bind=ref(SOME _), ...}, _) = raise Fail "prune fail" |
105 |
|
| setBind (Ty.TV{bind, ...}, ty) = bind := SOME ty |
106 |
|
fun matchVar (tv1 as Ty.TV{id=id1, ...}, tv2 as Ty.TV{id=id2, ...}) = |
107 |
|
if Stamp.same(id1, id2) |
108 |
|
then () |
109 |
|
else setBind (tv1, Ty.T_Var tv2) |
110 |
|
fun match (Ty.T_Var tv1, Ty.T_Var tv2) = (matchVar(tv1, tv2); true) |
111 |
|
| match (Ty.T_Var tv1, ty2) = (setBind(tv1, ty2); true) |
112 |
|
| match (ty1, Ty.T_Var tv2) = (setBind(tv2, ty2); true) |
113 |
| match (Ty.T_Bool, Ty.T_Bool) = true |
| match (Ty.T_Bool, Ty.T_Bool) = true |
114 |
| match (Ty.T_Int, Ty.T_Int) = true |
| match (Ty.T_Int, Ty.T_Int) = true |
115 |
| match (Ty.T_String, Ty.T_String) = true |
| match (Ty.T_String, Ty.T_String) = true |
116 |
| match (Ty.T_Kernel d1, Ty.T_Kernel d2) = |
| match (Ty.T_Kernel k1, Ty.T_Kernel k2) = matchDiff (k1, k2) |
117 |
| match (Ty.T_Tensor s1, Ty.T_Tensor s2) = |
| match (Ty.T_Tensor s1, Ty.T_Tensor s2) = matchShape (s1, s2) |
118 |
| match (Ty.T_Image{dim=d1, shape=s1}, Ty.T_Image{dim=d2, shape=s2}) = |
| match (Ty.T_Image{dim=d1, shape=s1}, Ty.T_Image{dim=d2, shape=s2}) = |
119 |
|
matchDim (d1, d2) andalso matchShape(s1, s2) |
120 |
| match (Ty.T_Field{diff=k1, dim=d1, shape=s1}, Ty.T_Field{diff=k2, dim=d2, shape=s2}) = |
| match (Ty.T_Field{diff=k1, dim=d1, shape=s1}, Ty.T_Field{diff=k2, dim=d2, shape=s2}) = |
121 |
| match (Ty.T_Fun(tys11, tys22), Ty.T_Fun(tys21, tys22)) = |
matchDiff (k1, k2) andalso matchDim (d1, d2) andalso matchShape(s1, s2) |
122 |
|
| match (Ty.T_Fun(tys11, ty12), Ty.T_Fun(tys21, ty22)) = |
123 |
|
matchTypes (tys11, tys21) andalso match (ty12, ty22) |
124 |
| match _ = false |
| match _ = false |
125 |
in |
in |
126 |
match (prune ty1, prune ty2) |
match (prune ty1, prune ty2) |
127 |
end |
end |
128 |
|
|
129 |
|
and matchTypes (tys1, tys2) = ListPair.allEq matchType (tys1, tys2) |
130 |
|
|
131 |
|
(* instantiate a type scheme, returning the argument meta variables and the resulting type. |
132 |
|
* Note that we assume that the scheme is closed. |
133 |
|
*) |
134 |
|
fun instantiate ([], ty) = ([], ty) |
135 |
|
| instantiate (mvs, ty) = let |
136 |
|
fun instantiateVar (mv, (mvs, env)) = let |
137 |
|
val mv' = MV.copy mv |
138 |
|
in |
139 |
|
(mv'::mvs, MV.Map.insert(env, mv, mv')) |
140 |
|
end |
141 |
|
val (mvs, env) = List.foldl instantiateVar ([], MV.Map.empty) mvs |
142 |
|
fun ity (Ty.T_Var tv) = (case MV.Map.find(env, Ty.TYPE tv) |
143 |
|
of SOME(Ty.TYPE tv) => Ty.T_Var tv |
144 |
|
| _ => raise Fail "impossible" |
145 |
|
(* end case *)) |
146 |
|
in |
147 |
|
raise Fail "unimplemented" |
148 |
|
end |
149 |
|
|
150 |
end |
end |