8 |
|
|
9 |
structure Util : sig |
structure Util : sig |
10 |
|
|
|
val prune : Types.ty -> Types.ty |
|
|
|
|
11 |
val matchType : Types.ty * Types.ty -> bool |
val matchType : Types.ty * Types.ty -> bool |
12 |
val matchTypes : Types.ty list * Types.ty list -> bool |
val matchTypes : Types.ty list * Types.ty list -> bool |
13 |
|
|
20 |
|
|
21 |
structure Ty = Types |
structure Ty = Types |
22 |
structure MV = MetaVar |
structure MV = MetaVar |
23 |
|
structure TU = TypeUtil |
|
(* prune out instantiated meta variables from a type *) |
|
|
fun prune ty = let |
|
|
fun prune' (ty as Ty.T_Var(Ty.TV{bind, ...})) = (case !bind |
|
|
of NONE => ty |
|
|
| SOME ty => prune' ty |
|
|
(* end case *)) |
|
|
| prune' (Ty.T_Kernel diff) = Ty.T_Kernel(pruneDiff diff) |
|
|
| prune' (Ty.T_Tensor shape) = Ty.T_Tensor(pruneShape shape) |
|
|
| prune' (Ty.T_Image{dim, shape}) = Ty.T_Image{ |
|
|
dim = pruneDim dim, |
|
|
shape = pruneShape shape |
|
|
} |
|
|
| prune' (Ty.T_Field{diff, dim, shape}) = Ty.T_Field{ |
|
|
diff = pruneDiff diff, |
|
|
dim = pruneDim dim, |
|
|
shape = pruneShape shape |
|
|
} |
|
|
| prune' (Ty.T_Fun(tys1, ty2)) = Ty.T_Fun(List.map prune' tys1, prune' ty2) |
|
|
| prune' ty = ty |
|
|
in |
|
|
prune' ty |
|
|
end |
|
|
|
|
|
and pruneDiff (Ty.DiffVar(Ty.DfV{bind=ref(SOME diff), ...}, i)) = ( |
|
|
case pruneDiff diff |
|
|
of Ty.DiffVar(dv, i') => Ty.DiffVar(dv, i+i') |
|
|
| Ty.DiffConst i' => Ty.DiffConst(i+i') |
|
|
(* end case *)) |
|
|
| pruneDiff diff = diff |
|
|
|
|
|
and pruneDim dim = (case dim |
|
|
of Ty.DimVar(Ty.DV{bind=ref(SOME dim), ...}) => pruneDim dim |
|
|
| dim => dim |
|
|
(* end case *)) |
|
|
|
|
|
and pruneShape shape = (case shape |
|
|
of Ty.ShapeVar(Ty.SV{bind=ref(SOME shape), ...}) => pruneShape shape |
|
|
| Ty.ShapeExt(shape, dim) => Ty.shapeExt(pruneShape shape, pruneDim dim) |
|
|
| _ => shape |
|
|
(* end case *)) |
|
24 |
|
|
25 |
(* a patch list tracks the meta variables that have been updated so that we can undo |
(* a patch list tracks the meta variables that have been updated so that we can undo |
26 |
* the effects of unification when just testing for a possible type match. |
* the effects of unification when just testing for a possible type match. |
56 |
end |
end |
57 |
|
|
58 |
(* FIXME: what about the bounds? *) |
(* FIXME: what about the bounds? *) |
59 |
fun matchDiff (pl, diff1, diff2) = (case (pruneDiff diff1, pruneDiff diff2) |
fun matchDiff (pl, diff1, diff2) = (case (TU.pruneDiff diff1, TU.pruneDiff diff2) |
60 |
of (Ty.DiffConst k1, Ty.DiffConst k2) => (k1 = k2) |
of (Ty.DiffConst k1, Ty.DiffConst k2) => (k1 = k2) |
61 |
| (Ty.DiffConst k, Ty.DiffVar(dv, i)) => let |
| (Ty.DiffConst k, Ty.DiffVar(dv, i)) => let |
62 |
val k' = k+i |
val k' = k+i |
73 |
| (Ty.DiffVar(dv1, i1), Ty.DiffVar(dv2, i2)) => raise Fail "unimplemented" (* FIXME *) |
| (Ty.DiffVar(dv1, i1), Ty.DiffVar(dv2, i2)) => raise Fail "unimplemented" (* FIXME *) |
74 |
(* end case *)) |
(* end case *)) |
75 |
|
|
76 |
fun matchDim (pl, dim1, dim2) = (case (pruneDim dim1, pruneDim dim2) |
fun matchDim (pl, dim1, dim2) = (case (TU.pruneDim dim1, TU.pruneDim dim2) |
77 |
of (Ty.DimConst d1, Ty.DimConst d2) => (d1 = d2) |
of (Ty.DimConst d1, Ty.DimConst d2) => (d1 = d2) |
78 |
| (Ty.DimVar dv, dim2) => (bindDimVar(pl, dv, dim2); true) |
| (Ty.DimVar dv, dim2) => (bindDimVar(pl, dv, dim2); true) |
79 |
| (dim1, Ty.DimVar dv) => (bindDimVar(pl, dv, dim1); true) |
| (dim1, Ty.DimVar dv) => (bindDimVar(pl, dv, dim1); true) |
80 |
(* end case *)) |
(* end case *)) |
81 |
|
|
82 |
fun matchShape (pl, shape1, shape2) = (case (pruneShape shape1, pruneShape shape2) |
fun matchShape (pl, shape1, shape2) = (case (TU.pruneShape shape1, TU.pruneShape shape2) |
83 |
of (Ty.Shape dd1, Ty.Shape dd2) => let |
of (Ty.Shape dd1, Ty.Shape dd2) => let |
84 |
fun chk ([], []) = true |
fun chk ([], []) = true |
85 |
| chk (d1::dd1, d2::dd2) = matchDim(pl, d1, d2) andalso chk (dd1, dd2) |
| chk (d1::dd1, d2::dd2) = matchDim(pl, d1, d2) andalso chk (dd1, dd2) |
123 |
ListPair.allEq match (tys11, tys21) andalso match (ty12, ty22) |
ListPair.allEq match (tys11, tys21) andalso match (ty12, ty22) |
124 |
| match _ = false |
| match _ = false |
125 |
in |
in |
126 |
match (prune ty1, prune ty2) |
match (TU.pruneHead ty1, TU.pruneHead ty2) |
127 |
end |
end |
128 |
|
|
129 |
fun matchTypes (tys1, tys2) = let |
fun matchTypes (tys1, tys2) = let |
149 |
orelse (undo pl; false) |
orelse (undo pl; false) |
150 |
end |
end |
151 |
|
|
152 |
|
(* QUESTION: perhaps this function belongs in the TypeUtil module? *) |
153 |
(* instantiate a type scheme, returning the argument meta variables and the resulting type. |
(* instantiate a type scheme, returning the argument meta variables and the resulting type. |
154 |
* Note that we assume that the scheme is closed. |
* Note that we assume that the scheme is closed. |
155 |
*) |
*) |