Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] Annotation of /branches/vis15/src/compiler/typechecker/unify.sml
ViewVC logotype

Annotation of /branches/vis15/src/compiler/typechecker/unify.sml

Parent Directory Parent Directory | Revision Log Revision Log


Revision 3423 - (view) (download)

1 : jhr 3402 (* unify.sml
2 :     *
3 :     * This code is part of the Diderot Project (http://diderot-language.cs.uchicago.edu)
4 :     *
5 :     * COPYRIGHT (c) 2015 The University of Chicago
6 :     * All rights reserved.
7 :     *)
8 :    
9 :     structure Unify : sig
10 :    
(* The result of matching two types ty1 and ty2:
 *   EQ     -- the types are equal
 *   COERCE -- ty2 can be coerced to match ty1 (e.g., int -> float, fixed seq -> dynamic seq)
 *   FAIL   -- the types do not match
 *)
datatype match = EQ | COERCE | FAIL

(* match the pair (ty1, ty2), where ty2 is allowed to be coerced to ty1 *)
val matchType : Types.ty * Types.ty -> match

(* like matchType, but meta-variable bindings are undone when the result is FAIL *)
val tryMatchType : Types.ty * Types.ty -> match

(* attempt to match a list of parameter types with a list of typed arguments; the
 * argument tuple is (parameter types, arguments, argument types).  Return the
 * arguments with any required coercions, or NONE on failure.
 *)
val matchArgs : Types.ty list * AST.expr list * Types.ty list -> AST.expr list option
val tryMatchArgs : Types.ty list * AST.expr list * Types.ty list -> AST.expr list option

(* unify two types (or two lists of types) to equality *)
val equalType : Types.ty * Types.ty -> bool
val equalTypes : Types.ty list * Types.ty list -> bool

(* like equalType/equalTypes, but meta-variable bindings are undone on failure *)
val tryEqualType : Types.ty * Types.ty -> bool
val tryEqualTypes : Types.ty list * Types.ty list -> bool

(* unify two dimensions to equality *)
val equalDim : Types.dim * Types.dim -> bool
36 :    
37 :     end = struct
38 :    
39 :     structure Ty = Types
40 :     structure MV = MetaVar
41 :     structure TU = TypeUtil
42 :    
43 :     datatype match = EQ | COERCE | FAIL
44 :    
45 :     (* a patch list tracks the meta variables that have been updated so that we can undo
46 :     * the effects of unification when just testing for a possible type match.
47 :     *)
48 :    
(* bind the type variable `tv` to `ty` and push `tv` onto the patch list `pl` so that
 * the binding can be reverted later.  Raises Fail when `tv` is not currently unbound.
 *)
fun bindTyVar (pl, tv, ty) = (case tv
       of Ty.TV{bind as ref NONE, ...} => (
            bind := SOME ty;
            pl := Ty.TYPE tv :: !pl)
        | _ => raise Fail "rebinding type variable"
      (* end case *))
53 :    
(* bind the differentiation variable `dv` to `diff` and push `dv` onto the patch list
 * `pl`.  Raises Fail when `dv` is not currently unbound.
 *)
fun bindDiffVar (pl, dv, diff) = (case dv
       of Ty.DfV{bind as ref NONE, ...} => (
            bind := SOME diff;
            pl := Ty.DIFF dv :: !pl)
        | _ => raise Fail "rebinding differentiation variable"
      (* end case *))
58 :    
(* bind the shape variable `sv` to `shape` and push `sv` onto the patch list `pl`.
 * Raises Fail when `sv` is not currently unbound.
 *)
fun bindShapeVar (pl, sv, shape) = (case sv
       of Ty.SV{bind as ref NONE, ...} => (
            bind := SOME shape;
            pl := Ty.SHAPE sv :: !pl)
        | _ => raise Fail "rebinding shape variable"
      (* end case *))
63 :    
(* bind the dimension variable `dv` to `dim` and push `dv` onto the patch list `pl`.
 * Raises Fail when `dv` is not currently unbound.
 *)
fun bindDimVar (pl, dv, dim) = (case dv
       of Ty.DV{bind as ref NONE, ...} => (
            bind := SOME dim;
            pl := Ty.DIM dv :: !pl)
        | _ => raise Fail "rebinding dimension variable"
      (* end case *))
68 :    
(* revert every meta-variable binding recorded in the patch list `pl` by resetting the
 * variable's binding to NONE, and then empty the patch list so that a second call on
 * the same list cannot disturb bindings made afterwards.
 *)
fun undo pl = let
      fun undo1 (Ty.TYPE(Ty.TV{bind, ...})) = bind := NONE
        | undo1 (Ty.DIFF(Ty.DfV{bind, ...})) = bind := NONE
        | undo1 (Ty.SHAPE(Ty.SV{bind, ...})) = bind := NONE
        | undo1 (Ty.DIM(Ty.DV{bind, ...})) = bind := NONE
      in
      (* List.app, since undo1 is run only for its effect; no result list is needed *)
        List.app undo1 (!pl);
        pl := []
      end
77 :    
(* FIXME: what about the bounds? *)
(* unify two differentiation levels to equality, recording any new bindings on the
 * patch list `pl`.  Returns true on success.  Raises Fail for two distinct variables
 * that both carry a non-zero offset, which is not implemented.
 *)
fun equalDiff (pl, diff1, diff2) = let
    (* unify the variable term (dv, i) with the constant k: the variable is bound to
     * the constant k+i, unless that value would be negative.
     * NOTE(review): the offset is added rather than subtracted; confirm this against
     * the intended meaning of Ty.DiffVar(dv, i).
     *)
      fun bindToConst (dv, i, k) = let
            val k' = k+i
            in
              if k' < 0 then false
              else (bindDiffVar(pl, dv, Ty.DiffConst k'); true)
            end
      in
        case (TU.pruneDiff diff1, TU.pruneDiff diff2)
         of (Ty.DiffConst k1, Ty.DiffConst k2) => (k1 = k2)
          | (Ty.DiffConst k, Ty.DiffVar(dv, i)) => bindToConst (dv, i, k)
          | (Ty.DiffVar(dv, i), Ty.DiffConst k) => bindToConst (dv, i, k)
          | (d1 as Ty.DiffVar(dv1, i1), d2 as Ty.DiffVar(dv2, i2)) =>
            (* the same variable on both sides must never be bound to itself (that
             * would create a cyclic binding); the terms are equal iff the offsets are.
             *)
              if MV.sameDiffVar(dv1, dv2)
                then (i1 = i2)
              else if (i1 = 0)
                then (bindDiffVar(pl, dv1, d2); true)
              else if (i2 = 0)
                then (bindDiffVar(pl, dv2, d1); true)
                else raise Fail(concat[
                    "equalDiff(", TU.diffToString diff1, ", ",
                    TU.diffToString diff2, ") unimplemented"
                  ])
        (* end case *)
      end
103 :    
(* match two differentiation levels, where the first is allowed to be less than the
 * second.  Two constants, or two terms over the same variable, are compared with <=;
 * every other combination is forced to equality via equalDiff.
 *)
fun matchDiff (pl, diff1, diff2) = let
      val lhs = TU.pruneDiff diff1
      val rhs = TU.pruneDiff diff2
      in
        case (lhs, rhs)
         of (Ty.DiffConst lo, Ty.DiffConst hi) => (lo <= hi)
          | (Ty.DiffVar(dvA, offA), Ty.DiffVar(dvB, offB)) =>
              if MV.sameDiffVar(dvA, dvB) then (offA <= offB) else false
          | _ => equalDiff (pl, diff1, diff2) (* force equality *)
        (* end case *)
      end
111 :    
(* unify two dimensions to equality, recording any new bindings on the patch list
 * `pl`.  Returns true on success.
 *)
fun equalDim (pl, dim1, dim2) = (case (TU.pruneDim dim1, TU.pruneDim dim2)
       of (Ty.DimConst d1, Ty.DimConst d2) => (d1 = d2)
        | (Ty.DimVar(dv1 as Ty.DV{bind=b1, ...}), dim2' as Ty.DimVar(Ty.DV{bind=b2, ...})) =>
          (* two occurrences of the same variable (identified by its binding ref) are
           * already equal; binding the variable to itself would create a cycle.
           *)
            (b1 = b2) orelse (bindDimVar(pl, dv1, dim2'); true)
        | (Ty.DimVar dv, dim2) => (bindDimVar(pl, dv, dim2); true)
        | (dim1, Ty.DimVar dv) => (bindDimVar(pl, dv, dim1); true)
      (* end case *))
117 :    
(* unify two tensor shapes to equality, recording any new bindings on the patch list
 * `pl`.  Returns true on success.
 *)
fun equalShape (pl, shape1, shape2) = (case (TU.pruneShape shape1, TU.pruneShape shape2)
       of (Ty.Shape dd1, Ty.Shape dd2) =>
          (* same length and pairwise-equal dimensions *)
            ListPair.allEq (fn (d1, d2) => equalDim(pl, d1, d2)) (dd1, dd2)
        | (Ty.Shape dd, Ty.ShapeExt(shape, d2)) => let
          (* split dd into its last dimension, which must equal d2, and the preceding
           * dimensions, which must match the extended shape; an empty dd cannot match.
           *)
            fun chk ([], _) = false
              | chk ([d], revDD) =
                  equalDim(pl, d, d2) andalso equalShape(pl, Ty.Shape(List.rev revDD), shape)
              | chk (d::dd, revDD) = chk(dd, d::revDD)
            in
              chk (dd, [])
            end
        | (Ty.ShapeVar(sv1 as Ty.SV{bind=b1, ...}), shape2' as Ty.ShapeVar(Ty.SV{bind=b2, ...})) =>
          (* two occurrences of the same variable (identified by its binding ref) are
           * already equal; binding the variable to itself would create a cycle.
           *)
            (b1 = b2) orelse (bindShapeVar (pl, sv1, shape2'); true)
        | (Ty.ShapeVar sv, shape) => (bindShapeVar (pl, sv, shape); true)
        | (Ty.ShapeExt(shape1, d1), Ty.ShapeExt(shape2, d2)) =>
            equalDim(pl, d1, d2) andalso equalShape(pl, shape1, shape2)
        (* the remaining combinations are handled by swapping the arguments *)
        | (shape1, shape2) => equalShape(pl, shape2, shape1)
      (* end case *))
139 :    
(* QUESTION: do we need an occurs check? *)
(* unify ty1 and ty2 to equality, recording any new meta-variable bindings on the
 * patch list `pl`.  Returns true when the types unify.
 *)
fun unifyType (pl, ty1, ty2) = let
    (* unify two type variables: identical variables are trivially equal; otherwise
     * unify their bindings, or bind whichever variable is still unbound.
     *)
      fun unifyVars (tvA as Ty.TV{id=idA, bind=bindA}, tvB as Ty.TV{id=idB, bind=bindB}) =
            Stamp.same(idA, idB) orelse (case (!bindA, !bindB)
               of (SOME tyA, SOME tyB) => unify(tyA, tyB)
                | (SOME tyA, NONE) => (bindTyVar (pl, tvB, tyA); true)
                | (NONE, SOME tyB) => (bindTyVar (pl, tvA, tyB); true)
                | (NONE, NONE) => (bindTyVar (pl, tvA, Ty.T_Var tvB); true)
              (* end case *))
    (* unify a type variable with a type that is not itself a variable *)
      and unifyVarWith (tv as Ty.TV{bind, ...}, ty) = (case !bind
             of SOME boundTy => unify(boundTy, ty)
              | NONE => (bindTyVar(pl, tv, ty); true)
            (* end case *))
    (* structural unification; the order of the clauses is significant *)
      and unify (Ty.T_Var tvA, Ty.T_Var tvB) = unifyVars(tvA, tvB)
        | unify (Ty.T_Var tvA, tyB) = unifyVarWith(tvA, tyB)
        | unify (tyA, Ty.T_Var tvB) = unifyVarWith(tvB, tyA)
        | unify (Ty.T_Bool, Ty.T_Bool) = true
        | unify (Ty.T_Int, Ty.T_Int) = true
        | unify (Ty.T_String, Ty.T_String) = true
        | unify (Ty.T_Sequence(elemA, NONE), Ty.T_Sequence(elemB, NONE)) = unify(elemA, elemB)
        | unify (Ty.T_Sequence(elemA, SOME lenA), Ty.T_Sequence(elemB, SOME lenB)) =
            equalDim(pl, lenA, lenB) andalso unify(elemA, elemB)
        | unify (Ty.T_Kernel kA, Ty.T_Kernel kB) = equalDiff (pl, kA, kB)
        | unify (Ty.T_Tensor shpA, Ty.T_Tensor shpB) = equalShape (pl, shpA, shpB)
        | unify (Ty.T_Image{dim=dA, shape=shpA}, Ty.T_Image{dim=dB, shape=shpB}) =
            equalDim (pl, dA, dB) andalso equalShape(pl, shpA, shpB)
        | unify (
            Ty.T_Field{diff=kA, dim=dA, shape=shpA},
            Ty.T_Field{diff=kB, dim=dB, shape=shpB}
          ) =
            equalDiff (pl, kA, kB) andalso equalDim (pl, dA, dB)
              andalso equalShape(pl, shpA, shpB)
        | unify (Ty.T_Fun(domA, rngA), Ty.T_Fun(domB, rngB)) =
            ListPair.allEq unify (domA, domB) andalso unify (rngA, rngB)
        (* the error type unifies with anything *)
        | unify (Ty.T_Error, _) = true
        | unify (_, Ty.T_Error) = true
        | unify _ = false
      val head1 = TU.pruneHead ty1
      val head2 = TU.pruneHead ty2
      in
        unify (head1, head2)
      end
178 :    
(* try to unify ty1 and ty2, where we allow ty2 to be coerced to ty1.  The result is
 * EQ when the types unify exactly, COERCE when they match up to a coercion of ty2,
 * and FAIL otherwise.  Bindings are recorded on the patch list `pl`.
 *)
fun unifyTypeWithCoercion (pl, ty1, ty2) = let
    (* map a success flag to the given result, or to FAIL *)
      fun orFail (ok, result) = if ok then result else FAIL
      in
        case (TU.pruneHead ty1, TU.pruneHead ty2)
         of (Ty.T_Tensor shp, Ty.T_Int) =>
            (* an int matches a tensor whose shape unifies with the empty shape *)
              orFail (equalShape (pl, Ty.Shape[], shp), COERCE)
          | (Ty.T_Sequence(elemTy1, NONE), Ty.T_Sequence(elemTy2, SOME _)) =>
              orFail (unifyType(pl, elemTy1, elemTy2), COERCE)
          | (Ty.T_Field{diff=k1, dim=d1, shape=s1}, Ty.T_Field{diff=k2, dim=d2, shape=s2}) =>
            (* exact unification first; otherwise allow the differentiation levels to
             * differ as permitted by matchDiff.
             *)
              if unifyType(pl, ty1, ty2)
                then EQ
                else orFail (
                  matchDiff (pl, k1, k2) andalso equalDim(pl, d1, d2)
                    andalso equalShape(pl, s1, s2),
                  COERCE)
          | (ty1', ty2') => orFail (unifyType(pl, ty1', ty2'), EQ)
        (* end case *)
      end
194 :     (* +DEBUG *
195 :     val unifyTypeWithCoercion = fn (pl, ty1, ty2) => let
196 :     val res = unifyTypeWithCoercion (pl, ty1, ty2)
197 :     val res' = (case res of EQ => "EQ" | COERCE => "COERCE" | FAIL => "FAIL")
198 :     in
199 :     print(concat["unifyTypeWithCoercion (_, ", TU.toString ty1, ", ", TU.toString ty2, ") = ", res', "\n"]);
200 :     res
201 :     end
202 :     * -DEBUG *)
203 :    
(* unify two lists of types pairwise to equality; lists of different lengths do not
 * unify.  Bindings made along the way are kept.
 *)
fun equalTypes (tys1, tys2) = let
      val pl = ref[]
      fun same (ty1, ty2) = unifyType(pl, ty1, ty2)
      in
        ListPair.allEq same (tys1, tys2)
      end
209 :    
(* unify two types to equality using a fresh patch list; bindings are kept *)
fun equalType (ty1, ty2) = let
      val pl = ref[]
      in
        unifyType (pl, ty1, ty2)
      end
211 :    
(* try to unify two types to equality; if we fail, all meta-variable bindings made
 * during the attempt are undone.
 *)
fun tryEqualType (ty1, ty2) = let
      val pl = ref[]
      in
        if unifyType(pl, ty1, ty2)
          then true
          else (undo pl; false)
      end
218 :    
(* try to unify two lists of types pairwise to equality; if we fail, all meta-variable
 * bindings made during the attempt are undone.
 *)
fun tryEqualTypes (tys1, tys2) = let
      val pl = ref[]
      fun same (ty1, ty2) = unifyType(pl, ty1, ty2)
      in
        if ListPair.allEq same (tys1, tys2)
          then true
          else (undo pl; false)
      end
226 :    
(* match ty2 against ty1, allowing coercion of ty2, using a fresh patch list; bindings
 * are kept whatever the result.
 *)
fun matchType (ty1, ty2) = let
      val pl = ref[]
      in
        unifyTypeWithCoercion (pl, ty1, ty2)
      end
228 :    
(* try to match ty2 against ty1, allowing coercion of ty2; if the result is FAIL, all
 * meta-variable bindings made during the attempt are undone.
 *)
fun tryMatchType (ty1, ty2) = let
      val pl = ref[]
      val result = unifyTypeWithCoercion (pl, ty1, ty2)
      in
        case result
         of FAIL => (undo pl; result)
          | _ => result
        (* end case *)
      end
238 :    
(* attempt to match a list of parameter types with a list of typed arguments. Return
 * the arguments with any required coercions, or NONE on failure.
 *)
local
  (* walk the parameter types, arguments, and argument types in lock step, wrapping an
   * argument in AST.E_Coerce whenever its type only matches by coercion.  Bindings are
   * recorded on `pl`; they are undone when a pair of types fails to match.  Lists of
   * different lengths yield NONE.
   *)
  fun matchArgs' (pl, paramTys, args, argTys) = let
        fun matchArgTys ([], [], [], args') = SOME(List.rev args')
          | matchArgTys (ty1::tys1, arg::args, ty2::tys2, args') = (
              case unifyTypeWithCoercion (pl, ty1, ty2)
               of EQ => matchArgTys (tys1, args, tys2, arg::args')
                | COERCE => matchArgTys (
                    tys1, args, tys2,
                    AST.E_Coerce{srcTy=ty2, dstTy=ty1, e=arg}::args')
                | _ => (undo pl; NONE)
              (* end case *))
          | matchArgTys _ = NONE
        in
          matchArgTys (paramTys, args, argTys, [])
        end
in
fun matchArgs (paramTys, args, argTys) = matchArgs' (ref[], paramTys, args, argTys)

(* like matchArgs, but every binding made during the attempt is undone on failure.
 * The patch list handed to matchArgs' must be the same one that is undone here;
 * otherwise the undo has nothing to revert.
 *)
fun tryMatchArgs (paramTys, args, argTys) = let
      val pl = ref[]
      in
        case matchArgs' (pl, paramTys, args, argTys)
         of NONE => (undo pl; NONE)
          | someResult => someResult
        (* end case *)
      end
end
266 :    
(* rebind equalDim without the patch-list argument: each call to the exported version
 * supplies a fresh patch list to the three-argument version defined earlier.
 *)
local
  val equalDimWithPatchList = equalDim
in
fun equalDim (dim1, dim2) = equalDimWithPatchList (ref [], dim1, dim2)
end
269 :    
270 :     end

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0