Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] Annotation of /trunk/src/typechecker/util.sml
ViewVC logotype

Annotation of /trunk/src/typechecker/util.sml

Parent Directory | Revision Log


Revision 85 - (view) (download)

(* util.sml
 *
 * COPYRIGHT (c) 2010 The Diderot Project (http://diderot.cs.uchicago.edu)
 * All rights reserved.
 *
 * Utilities for typechecking
 *)
8 :    
structure Util : sig

  (* prune out instantiated meta variables from a type *)
    val prune : Types.ty -> Types.ty

  (* unify two types (or two lists of types); returns false on a mismatch.  Note that
   * meta-variable bindings made before the mismatch was detected are NOT undone.
   *)
    val matchType : Types.ty * Types.ty -> bool
    val matchTypes : Types.ty list * Types.ty list -> bool

  (* like matchType/matchTypes, but all meta-variable bindings are undone on failure *)
    val tryMatchType : Types.ty * Types.ty -> bool
    val tryMatchTypes : Types.ty list * Types.ty list -> bool

  (* instantiate a (closed) type scheme with fresh copies of its meta variables *)
    val instantiate : Types.scheme -> (Types.meta_var list * Types.ty)

  end = struct

    structure Ty = Types
    structure MV = MetaVar

  (* prune out instantiated meta variables from a type *)
    fun prune ty = let
          fun prune' (ty as Ty.T_Var(Ty.TV{bind, ...})) = (case !bind
                 of NONE => ty
                  | SOME ty => prune' ty
                (* end case *))
            | prune' (Ty.T_Kernel diff) = Ty.T_Kernel(pruneDiff diff)
            | prune' (Ty.T_Tensor shape) = Ty.T_Tensor(pruneShape shape)
            | prune' (Ty.T_Image{dim, shape}) = Ty.T_Image{
                  dim = pruneDim dim,
                  shape = pruneShape shape
                }
            | prune' (Ty.T_Field{diff, dim, shape}) = Ty.T_Field{
                  diff = pruneDiff diff,
                  dim = pruneDim dim,
                  shape = pruneShape shape
                }
            | prune' (Ty.T_Fun(tys1, ty2)) = Ty.T_Fun(List.map prune' tys1, prune' ty2)
            | prune' ty = ty
          in
            prune' ty
          end

  (* prune a differentiation index.  DiffVar(dv, i) stands for the value of dv plus
   * the offset i, so the offsets accumulate as bound variables are followed.
   *)
    and pruneDiff (Ty.DiffVar(Ty.DfV{bind=ref(SOME diff), ...}, i)) = (
          case pruneDiff diff
           of Ty.DiffVar(dv, i') => Ty.DiffVar(dv, i+i')
            | Ty.DiffConst i' => Ty.DiffConst(i+i')
          (* end case *))
      | pruneDiff diff = diff

  (* prune a dimension by following bound dimension variables *)
    and pruneDim dim = (case dim
           of Ty.DimVar(Ty.DV{bind=ref(SOME dim), ...}) => pruneDim dim
            | dim => dim
          (* end case *))

  (* prune a shape; an extended shape is rebuilt with Ty.shapeExt (presumably a
   * normalizing constructor -- NOTE(review): confirm in Types)
   *)
    and pruneShape shape = (case shape
           of Ty.ShapeVar(Ty.SV{bind=ref(SOME shape), ...}) => pruneShape shape
            | Ty.ShapeExt(shape, dim) => Ty.shapeExt(pruneShape shape, pruneDim dim)
            | _ => shape
          (* end case *))

  (* a patch list tracks the meta variables that have been updated so that we can undo
   * the effects of unification when just testing for a possible type match.
   *)

  (* each bindXXXVar function binds an unbound meta variable and records it in the
   * patch list pl; attempting to bind an already-bound variable raises Fail.
   *)
    fun bindTyVar (pl, tv as Ty.TV{bind as ref NONE, ...}, ty) = (
          bind := SOME ty;
          pl := Ty.TYPE tv :: !pl)
      | bindTyVar _ = raise Fail "rebinding type variable"

    fun bindDiffVar (pl, dv as Ty.DfV{bind as ref NONE, ...}, diff) = (
          bind := SOME diff;
          pl := Ty.DIFF dv :: !pl)
      | bindDiffVar _ = raise Fail "rebinding differentiation variable"

    fun bindShapeVar (pl, sv as Ty.SV{bind as ref NONE, ...}, shape) = (
          bind := SOME shape;
          pl := Ty.SHAPE sv :: !pl)
      | bindShapeVar _ = raise Fail "rebinding shape variable"

    fun bindDimVar (pl, dv as Ty.DV{bind as ref NONE, ...}, dim) = (
          bind := SOME dim;
          pl := Ty.DIM dv :: !pl)
      | bindDimVar _ = raise Fail "rebinding dimension variable"

  (* two dimension (resp. shape) variables are the same variable iff they share the
   * same binding cell; refs compare by identity in SML.
   *)
    fun sameDimVar (Ty.DV{bind=b1, ...}, Ty.DV{bind=b2, ...}) = (b1 = b2)
    fun sameShapeVar (Ty.SV{bind=b1, ...}, Ty.SV{bind=b2, ...}) = (b1 = b2)

  (* reset every meta variable recorded in the patch list to unbound, then empty the
   * list so that a second undo is harmless.
   *)
    fun undo pl = let
          fun undo1 (Ty.TYPE(Ty.TV{bind, ...})) = bind := NONE
            | undo1 (Ty.DIFF(Ty.DfV{bind, ...})) = bind := NONE
            | undo1 (Ty.SHAPE(Ty.SV{bind, ...})) = bind := NONE
            | undo1 (Ty.DIM(Ty.DV{bind, ...})) = bind := NONE
          in
            List.app undo1 (!pl);
            pl := []
          end

  (* match two differentiation indices.  Since DiffVar(dv, i) denotes dv+i (see
   * pruneDiff), matching it against the constant k requires binding dv to k-i;
   * the match fails if that value would be negative.
   * FIXME: what about the bounds?
   *)
    fun matchDiff (pl, diff1, diff2) = let
          fun bindConst (dv, i, k) = let
                val k' = k - i
                in
                  if k' < 0 then false
                  else (bindDiffVar(pl, dv, Ty.DiffConst k'); true)
                end
          in
            case (pruneDiff diff1, pruneDiff diff2)
             of (Ty.DiffConst k1, Ty.DiffConst k2) => (k1 = k2)
              | (Ty.DiffConst k, Ty.DiffVar(dv, i)) => bindConst (dv, i, k)
              | (Ty.DiffVar(dv, i), Ty.DiffConst k) => bindConst (dv, i, k)
              | (Ty.DiffVar _, Ty.DiffVar _) => raise Fail "unimplemented" (* FIXME *)
            (* end case *)
          end

  (* match two dimensions, binding dimension variables as needed *)
    fun matchDim (pl, dim1, dim2) = (case (pruneDim dim1, pruneDim dim2)
           of (Ty.DimConst d1, Ty.DimConst d2) => (d1 = d2)
            | (Ty.DimVar dv1, Ty.DimVar dv2) =>
              (* binding a variable to itself would create a cycle that makes pruneDim loop *)
                sameDimVar(dv1, dv2) orelse (bindDimVar(pl, dv1, Ty.DimVar dv2); true)
            | (Ty.DimVar dv, dim2) => (bindDimVar(pl, dv, dim2); true)
            | (dim1, Ty.DimVar dv) => (bindDimVar(pl, dv, dim1); true)
          (* end case *))

  (* match two shapes, binding shape and dimension variables as needed *)
    fun matchShape (pl, shape1, shape2) = (case (pruneShape shape1, pruneShape shape2)
           of (Ty.Shape dd1, Ty.Shape dd2) => let
                fun chk ([], []) = true
                  | chk (d1::dd1, d2::dd2) = matchDim(pl, d1, d2) andalso chk (dd1, dd2)
                  | chk _ = false
                in
                  chk (dd1, dd2)
                end
            | (Ty.Shape dd, Ty.ShapeExt(shape, d2)) => let
              (* match the last dimension of dd against d2 and the remaining prefix
               * (accumulated in reverse in revDD) against shape
               *)
                fun chk ([], _) = false
                  | chk ([d], revDD) =
                      matchDim(pl, d, d2) andalso matchShape(pl, Ty.Shape(List.rev revDD), shape)
                  | chk (d::dd, revDD) = chk(dd, d::revDD)
                in
                  chk (dd, [])
                end
            | (Ty.ShapeVar sv1, Ty.ShapeVar sv2) =>
              (* binding a variable to itself would create a cycle that makes pruneShape loop *)
                sameShapeVar(sv1, sv2) orelse (bindShapeVar (pl, sv1, Ty.ShapeVar sv2); true)
            | (Ty.ShapeVar sv, shape) => (bindShapeVar (pl, sv, shape); true)
            | (Ty.ShapeExt(shape1, d1), Ty.ShapeExt(shape2, d2)) =>
                matchDim(pl, d1, d2) andalso matchShape(pl, shape1, shape2)
          (* every remaining combination is one of the cases above with the arguments swapped *)
            | (shape1, shape2) => matchShape(pl, shape2, shape1)
          (* end case *))

  (* unify two types, recording any meta-variable bindings in the patch list pl.
   * QUESTION: do we need an occurs check?
   *)
    fun unifyType (pl, ty1, ty2) = let
          fun matchVar (tv1 as Ty.TV{id=id1, ...}, tv2 as Ty.TV{id=id2, ...}) =
                if Stamp.same(id1, id2)
                  then ()
                  else bindTyVar (pl, tv1, Ty.T_Var tv2)
        (* prune before every comparison: bindings made while matching one component
         * (e.g., an earlier function argument) may instantiate variables that occur
         * in later components, and bindTyVar must only ever see unbound variables.
         *)
          fun unify (ty1, ty2) = match (prune ty1, prune ty2)
          and match (Ty.T_Var tv1, Ty.T_Var tv2) = (matchVar(tv1, tv2); true)
            | match (Ty.T_Var tv1, ty2) = (bindTyVar(pl, tv1, ty2); true)
            | match (ty1, Ty.T_Var tv2) = (bindTyVar(pl, tv2, ty1); true)
            | match (Ty.T_Bool, Ty.T_Bool) = true
            | match (Ty.T_Int, Ty.T_Int) = true
            | match (Ty.T_String, Ty.T_String) = true
            | match (Ty.T_Kernel k1, Ty.T_Kernel k2) = matchDiff (pl, k1, k2)
            | match (Ty.T_Tensor s1, Ty.T_Tensor s2) = matchShape (pl, s1, s2)
            | match (Ty.T_Image{dim=d1, shape=s1}, Ty.T_Image{dim=d2, shape=s2}) =
                matchDim (pl, d1, d2) andalso matchShape(pl, s1, s2)
            | match (Ty.T_Field{diff=k1, dim=d1, shape=s1}, Ty.T_Field{diff=k2, dim=d2, shape=s2}) =
                matchDiff (pl, k1, k2) andalso matchDim (pl, d1, d2) andalso matchShape(pl, s1, s2)
            | match (Ty.T_Fun(tys11, ty12), Ty.T_Fun(tys21, ty22)) =
                ListPair.allEq unify (tys11, tys21) andalso unify (ty12, ty22)
            | match _ = false
          in
            unify (ty1, ty2)
          end

  (* unify two lists of types pairwise; lists of different lengths do not match *)
    fun matchTypes (tys1, tys2) = let
          val pl = ref[]
          in
            ListPair.allEq (fn (ty1, ty2) => unifyType(pl, ty1, ty2)) (tys1, tys2)
          end

    fun matchType (ty1, ty2) = unifyType (ref[], ty1, ty2)

  (* try to match types; if we fail (or unification raises an exception), all
   * meta-variable bindings are undone.
   *)
    fun tryMatchType (ty1, ty2) = let
          val pl = ref[]
          in
            (unifyType(pl, ty1, ty2) handle ex => (undo pl; raise ex))
              orelse (undo pl; false)
          end

  (* try to match lists of types; if we fail (or unification raises an exception),
   * all meta-variable bindings are undone.
   *)
    fun tryMatchTypes (tys1, tys2) = let
          val pl = ref[]
          val ok = ListPair.allEq (fn (ty1, ty2) => unifyType(pl, ty1, ty2)) (tys1, tys2)
                handle ex => (undo pl; raise ex)
          in
            ok orelse (undo pl; false)
          end

  (* instantiate a type scheme, returning the argument meta variables and the resulting type.
   * Note that we assume that the scheme is closed: every meta variable in ty must be
   * one of the scheme's variables, otherwise Fail "impossible" is raised.
   * NOTE(review): List.foldl conses each fresh copy onto the front, so the returned
   * list is in the reverse order of the scheme's variables -- confirm callers expect this.
   *)
    fun instantiate ([], ty) = ([], ty)
      | instantiate (mvs, ty) = let
          fun instantiateVar (mv, (mvs, env)) = let
                val mv' = MV.copy mv
                in
                  (mv'::mvs, MV.Map.insert(env, mv, mv'))
                end
          val (mvs, env) = List.foldl instantiateVar ([], MV.Map.empty) mvs
        (* substitute the fresh copies for the scheme's variables *)
          fun iDiff (Ty.DiffVar(k, i)) = (case MV.Map.find(env, Ty.DIFF k)
                 of SOME(Ty.DIFF k) => Ty.DiffVar(k, i)
                  | _ => raise Fail "impossible"
                (* end case *))
            | iDiff diff = diff
          fun iDim (Ty.DimVar dv) = (case MV.Map.find(env, Ty.DIM dv)
                 of SOME(Ty.DIM dv) => Ty.DimVar dv
                  | _ => raise Fail "impossible"
                (* end case *))
            | iDim dim = dim
          fun iShape (Ty.ShapeVar sv) = (case MV.Map.find(env, Ty.SHAPE sv)
                 of SOME(Ty.SHAPE sv) => Ty.ShapeVar sv
                  | _ => raise Fail "impossible"
                (* end case *))
            | iShape (Ty.ShapeExt(shape, dim)) = Ty.ShapeExt(iShape shape, iDim dim)
            | iShape (Ty.Shape dims) = Ty.Shape(List.map iDim dims)
          fun ity (Ty.T_Var tv) = (case MV.Map.find(env, Ty.TYPE tv)
                 of SOME(Ty.TYPE tv) => Ty.T_Var tv
                  | _ => raise Fail "impossible"
                (* end case *))
            | ity (Ty.T_Kernel k) = Ty.T_Kernel(iDiff k)
            | ity (Ty.T_Tensor shape) = Ty.T_Tensor(iShape shape)
            | ity (Ty.T_Image{dim, shape}) = Ty.T_Image{dim=iDim dim, shape=iShape shape}
            | ity (Ty.T_Field{diff, dim, shape}) =
                Ty.T_Field{diff=iDiff diff, dim=iDim dim, shape=iShape shape}
            | ity (Ty.T_Fun(dom, rng)) = Ty.T_Fun(List.map ity dom, ity rng)
            | ity ty = ty
          in
            (mvs, ity ty)
          end

  end

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0