SCM Repository
Annotation of /trunk/src/compiler/typechecker/util.sml
Parent Directory
|
Revision Log
Revision 110 - (view) (download)
1 : | jhr | 80 | (* util.sml |
2 : | * | ||
3 : | * COPYRIGHT (c) 2010 The Diderot Project (http://diderot.cs.uchicago.edu) | ||
4 : | * All rights reserved. | ||
5 : | * | ||
6 : | * Utilities for typechecking | ||
7 : | *) | ||
8 : | |||
structure Util : sig

  (* unify two types (or pairwise-unify two lists of types), returning true on success.
   * Any meta-variable bindings made during the attempt are kept, even when the match
   * fails.
   *)
    val matchType : Types.ty * Types.ty -> bool
    val matchTypes : Types.ty list * Types.ty list -> bool

  (* like matchType/matchTypes, but if the match fails, all meta-variable bindings
   * made during the attempt are undone.
   *)
    val tryMatchType : Types.ty * Types.ty -> bool
    val tryMatchTypes : Types.ty list * Types.ty list -> bool

  (* instantiate a (closed) type scheme with fresh copies of its meta variables,
   * returning the copies and the instantiated type.
   *)
    val instantiate : Types.scheme -> (Types.meta_var list * Types.ty)

  end = struct

    structure Ty = Types
    structure MV = MetaVar
    structure TU = TypeUtil

  (* a patch list tracks the meta variables that have been updated so that we can undo
   * the effects of unification when just testing for a possible type match.
   *)

  (* bind the unbound type variable tv to ty and record tv on the patch list pl;
   * binding an already-bound variable raises Fail.
   *)
    fun bindTyVar (pl, tv as Ty.TV{bind as ref NONE, ...}, ty) = (
          bind := SOME ty;
          pl := Ty.TYPE tv :: !pl)
      | bindTyVar _ = raise Fail "rebinding type variable"

  (* bind the unbound differentiation variable dv and record it on pl *)
    fun bindDiffVar (pl, dv as Ty.DfV{bind as ref NONE, ...}, diff) = (
          bind := SOME diff;
          pl := Ty.DIFF dv :: !pl)
      | bindDiffVar _ = raise Fail "rebinding differentiation variable"

  (* bind the unbound shape variable sv and record it on pl *)
    fun bindShapeVar (pl, sv as Ty.SV{bind as ref NONE, ...}, shape) = (
          bind := SOME shape;
          pl := Ty.SHAPE sv :: !pl)
      | bindShapeVar _ = raise Fail "rebinding shape variable"

  (* bind the unbound dimension variable dv and record it on pl *)
    fun bindDimVar (pl, dv as Ty.DV{bind as ref NONE, ...}, dim) = (
          bind := SOME dim;
          pl := Ty.DIM dv :: !pl)
      | bindDimVar _ = raise Fail "rebinding dimension variable"

  (* reset every meta variable recorded on the patch list pl to unbound, then clear pl *)
    fun undo pl = let
          fun undo1 (Ty.TYPE(Ty.TV{bind, ...})) = bind := NONE
            | undo1 (Ty.DIFF(Ty.DfV{bind, ...})) = bind := NONE
            | undo1 (Ty.SHAPE(Ty.SV{bind, ...})) = bind := NONE
            | undo1 (Ty.DIM(Ty.DV{bind, ...})) = bind := NONE
          in
            List.app undo1 (!pl);
            pl := []
          end

  (* identity tests for meta variables: two variables are treated as the same when they
   * share the same `bind` ref cell (ref equality is pointer identity in SML).  These
   * let us avoid binding a variable to itself, which would create a cycle.
   *)
    fun sameDiffVar (Ty.DfV{bind=b1, ...}, Ty.DfV{bind=b2, ...}) = (b1 = b2)

    fun isDimVar (Ty.DV{bind=b1, ...}, Ty.DimVar(Ty.DV{bind=b2, ...})) = (b1 = b2)
      | isDimVar _ = false

    fun isShapeVar (Ty.SV{bind=b1, ...}, Ty.ShapeVar(Ty.SV{bind=b2, ...})) = (b1 = b2)
      | isShapeVar _ = false

  (* FIXME: what about the bounds? *)
  (* match two differentiation levels, binding differentiation variables as needed.
   * `DiffVar(dv, i)` pairs a variable with an integer offset i.
   *)
    fun matchDiff (pl, diff1, diff2) = let
        (* make DiffVar(dv, i) agree with DiffConst k by binding dv to the constant k+i;
         * returns false (without binding) if that constant would be negative.
         * NOTE(review): confirm that the k+i direction agrees with how
         * TypeUtil.pruneDiff applies the offset i.
         *)
          fun bindConst (dv, i, k) = let
                val k' = k+i
                in
                  if k' < 0 then false
                  else (bindDiffVar(pl, dv, Ty.DiffConst k'); true)
                end
          in
            case (TU.pruneDiff diff1, TU.pruneDiff diff2)
             of (Ty.DiffConst k1, Ty.DiffConst k2) => (k1 = k2)
              | (Ty.DiffConst k, Ty.DiffVar(dv, i)) => bindConst (dv, i, k)
              | (Ty.DiffVar(dv, i), Ty.DiffConst k) => bindConst (dv, i, k)
              | (Ty.DiffVar(dv1, i1), Ty.DiffVar(dv2, i2)) =>
                (* the same variable matches itself exactly when the offsets agree *)
                  if sameDiffVar(dv1, dv2)
                    then (i1 = i2)
                    else raise Fail "unimplemented" (* FIXME *)
            (* end case *)
          end

  (* match two dimensions, binding dimension variables as needed *)
    fun matchDim (pl, dim1, dim2) = (case (TU.pruneDim dim1, TU.pruneDim dim2)
           of (Ty.DimConst d1, Ty.DimConst d2) => (d1 = d2)
            | (Ty.DimVar dv, dim2) => (
              (* a variable trivially matches itself; binding it would create a cycle *)
                if isDimVar(dv, dim2) then () else bindDimVar(pl, dv, dim2);
                true)
            | (dim1, Ty.DimVar dv) => (bindDimVar(pl, dv, dim1); true)
          (* end case *))

  (* match two shapes, binding shape and dimension variables as needed.  A
   * `ShapeExt(shape, d)` is matched against a `Shape` by matching d with the last
   * dimension and shape with the remaining prefix.
   *)
    fun matchShape (pl, shape1, shape2) = (case (TU.pruneShape shape1, TU.pruneShape shape2)
           of (Ty.Shape dd1, Ty.Shape dd2) =>
              (* same length and pairwise-matching dimensions *)
                ListPair.allEq (fn (d1, d2) => matchDim(pl, d1, d2)) (dd1, dd2)
            | (Ty.Shape dd, Ty.ShapeExt(shape, d2)) => let
              (* walk to the last dimension of dd, accumulating the prefix in reverse *)
                fun chk ([], _) = false
                  | chk ([d], revDD) =
                      matchDim(pl, d, d2) andalso matchShape(pl, Ty.Shape(List.rev revDD), shape)
                  | chk (d::dd, revDD) = chk(dd, d::revDD)
                in
                  chk (dd, [])
                end
            | (Ty.ShapeVar sv, shape) => (
              (* a variable trivially matches itself; binding it would create a cycle *)
                if isShapeVar(sv, shape) then () else bindShapeVar(pl, sv, shape);
                true)
            | (Ty.ShapeExt(shape1, d1), Ty.ShapeExt(shape2, d2)) =>
                matchDim(pl, d1, d2) andalso matchShape(pl, shape1, shape2)
          (* remaining combinations are the mirror images of the cases above *)
            | (shape1, shape2) => matchShape(pl, shape2, shape1)
          (* end case *))

  (* QUESTION: do we need an occurs check? *)
  (* unify ty1 and ty2, recording any meta-variable bindings on the patch list pl.
   * Types are pruned at every level of the traversal, so type variables that were bound
   * earlier (including earlier in this same unification) are seen through rather than
   * being rebound.
   *)
    fun unifyType (pl, ty1, ty2) = let
          fun matchVar (tv1 as Ty.TV{id=id1, ...}, tv2 as Ty.TV{id=id2, ...}) =
                if Stamp.same(id1, id2)
                  then ()
                  else bindTyVar (pl, tv1, Ty.T_Var tv2)
        (* prune both types, then match them structurally *)
          fun match (ty1, ty2) = matchPruned (TU.pruneHead ty1, TU.pruneHead ty2)
          and matchPruned (Ty.T_Var tv1, Ty.T_Var tv2) = (matchVar(tv1, tv2); true)
            | matchPruned (Ty.T_Var tv1, ty2) = (bindTyVar(pl, tv1, ty2); true)
          (* bind tv2 to the *other* type ty1 (not to itself) *)
            | matchPruned (ty1, Ty.T_Var tv2) = (bindTyVar(pl, tv2, ty1); true)
            | matchPruned (Ty.T_Bool, Ty.T_Bool) = true
            | matchPruned (Ty.T_Int, Ty.T_Int) = true
            | matchPruned (Ty.T_String, Ty.T_String) = true
            | matchPruned (Ty.T_Kernel k1, Ty.T_Kernel k2) = matchDiff (pl, k1, k2)
            | matchPruned (Ty.T_Tensor s1, Ty.T_Tensor s2) = matchShape (pl, s1, s2)
            | matchPruned (Ty.T_Image{dim=d1, shape=s1}, Ty.T_Image{dim=d2, shape=s2}) =
                matchDim (pl, d1, d2) andalso matchShape(pl, s1, s2)
            | matchPruned (Ty.T_Field{diff=k1, dim=d1, shape=s1}, Ty.T_Field{diff=k2, dim=d2, shape=s2}) =
                matchDiff (pl, k1, k2) andalso matchDim (pl, d1, d2) andalso matchShape(pl, s1, s2)
            | matchPruned (Ty.T_Fun(tys11, ty12), Ty.T_Fun(tys21, ty22)) =
                ListPair.allEq match (tys11, tys21) andalso match (ty12, ty22)
            | matchPruned _ = false
          in
            match (ty1, ty2)
          end

  (* pairwise unification of two type lists; false if the lengths differ.  Bindings made
   * before a failure are kept.
   *)
    fun matchTypes (tys1, tys2) = let
          val pl = ref[]
          in
            ListPair.allEq (fn (ty1, ty2) => unifyType(pl, ty1, ty2)) (tys1, tys2)
          end

  (* unify two types; bindings made before a failure are kept *)
    fun matchType (ty1, ty2) = unifyType (ref[], ty1, ty2)

  (* try to match types; if we fail, all meta-variable bindings are undone *)
    fun tryMatchType (ty1, ty2) = let
          val pl = ref[]
          in
            unifyType(pl, ty1, ty2) orelse (undo pl; false)
          end

  (* try to match types; if we fail, all meta-variable bindings are undone *)
    fun tryMatchTypes (tys1, tys2) = let
          val pl = ref[]
          in
            ListPair.allEq (fn (ty1, ty2) => unifyType(pl, ty1, ty2)) (tys1, tys2)
              orelse (undo pl; false)
          end

  (* QUESTION: perhaps this function belongs in the TypeUtil module? *)
  (* instantiate a type scheme, returning the argument meta variables and the resulting type.
   * The returned meta variables are fresh copies (made with MV.copy), listed in the same
   * order as the scheme's bound variables.  Note that we assume that the scheme is closed:
   * a meta variable in the type that is not bound by the scheme raises Fail "impossible".
   *)
    fun instantiate ([], ty) = ([], ty)
      | instantiate (mvs, ty) = let
          fun instantiateVar (mv, (mvs, env)) = let
                val mv' = MV.copy mv
                in
                  (mv'::mvs, MV.Map.insert(env, mv, mv'))
                end
        (* foldl conses the copies in reverse order; the result restores the original order *)
          val (revMVs, env) = List.foldl instantiateVar ([], MV.Map.empty) mvs
          fun iDiff (Ty.DiffVar(k, i)) = (case MV.Map.find(env, Ty.DIFF k)
                 of SOME(Ty.DIFF k') => Ty.DiffVar(k', i)
                  | _ => raise Fail "impossible"
                (* end case *))
            | iDiff diff = diff
          fun iDim (Ty.DimVar dv) = (case MV.Map.find(env, Ty.DIM dv)
                 of SOME(Ty.DIM dv') => Ty.DimVar dv'
                  | _ => raise Fail "impossible"
                (* end case *))
            | iDim dim = dim
          fun iShape (Ty.ShapeVar sv) = (case MV.Map.find(env, Ty.SHAPE sv)
                 of SOME(Ty.SHAPE sv') => Ty.ShapeVar sv'
                  | _ => raise Fail "impossible"
                (* end case *))
            | iShape (Ty.ShapeExt(shape, dim)) = Ty.ShapeExt(iShape shape, iDim dim)
            | iShape (Ty.Shape dims) = Ty.Shape(List.map iDim dims)
          fun ity (Ty.T_Var tv) = (case MV.Map.find(env, Ty.TYPE tv)
                 of SOME(Ty.TYPE tv') => Ty.T_Var tv'
                  | _ => raise Fail "impossible"
                (* end case *))
            | ity (Ty.T_Kernel k) = Ty.T_Kernel(iDiff k)
            | ity (Ty.T_Tensor shape) = Ty.T_Tensor(iShape shape)
            | ity (Ty.T_Image{dim, shape}) = Ty.T_Image{dim=iDim dim, shape=iShape shape}
            | ity (Ty.T_Field{diff, dim, shape}) =
                Ty.T_Field{diff=iDiff diff, dim=iDim dim, shape=iShape shape}
            | ity (Ty.T_Fun(dom, rng)) = Ty.T_Fun(List.map ity dom, ity rng)
            | ity ty = ty
          in
            (List.rev revMVs, ity ty)
          end

  end
root@smlnj-gforge.cs.uchicago.edu | ViewVC Help |
Powered by ViewVC 1.0.0 |