Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] Diff of /trunk/src/typechecker/util.sml
ViewVC logotype

Diff of /trunk/src/typechecker/util.sml

Parent Directory Parent Directory | Revision Log Revision Log | View Patch Patch

revision 84, Wed May 26 18:51:51 2010 UTC revision 85, Wed May 26 19:51:10 2010 UTC
# Line 6  Line 6 
6   * Utilities for typechecking   * Utilities for typechecking
7   *)   *)
8    
9  structure Util =  structure Util : sig
10    struct  
11        val prune : Types.ty -> Types.ty
12    
13        val matchType : Types.ty * Types.ty -> bool
14        val matchTypes : Types.ty list * Types.ty list -> bool
15    
16        val tryMatchType : Types.ty * Types.ty -> bool
17        val tryMatchTypes : Types.ty list * Types.ty list -> bool
18    
19        val instantiate : Types.scheme -> (Types.meta_var list * Types.ty)
20    
21      end = struct
22    
23      structure Ty = Types      structure Ty = Types
24      structure MV = MetaVar      structure MV = MetaVar
# Line 53  Line 64 
64              | _ => shape              | _ => shape
65            (* end case *))            (* end case *))
66    
67      (* a patch list tracks the meta variables that have been updated so that we can undo
68       * the effects of unification when just testing for a possible type match.
69       *)
70    
71        fun bindTyVar (pl, tv as Ty.TV{bind as ref NONE, ...}, ty) = (
72              bind := SOME ty;
73              pl := Ty.TYPE tv :: !pl)
74          | bindTyVar _ = raise Fail "rebinding type variable"
75    
76        fun bindDiffVar (pl, dv as Ty.DfV{bind as ref NONE, ...}, diff) = (
77              bind := SOME diff;
78              pl := Ty.DIFF dv :: !pl)
79          | bindDiffVar _ = raise Fail "rebinding differentiation variable"
80    
81        fun bindShapeVar (pl, sv as Ty.SV{bind as ref NONE, ...}, shape) = (
82              bind := SOME shape;
83              pl := Ty.SHAPE sv :: !pl)
84          | bindShapeVar _ = raise Fail "rebinding shape variable"
85    
86        fun bindDimVar (pl, dv as Ty.DV{bind as ref NONE, ...}, dim) = (
87              bind := SOME dim;
88              pl := Ty.DIM dv :: !pl)
89          | bindDimVar _ = raise Fail "rebinding dimension variable"
90    
91        fun undo pl = let
92              fun undo1 (Ty.TYPE(Ty.TV{bind, ...})) = bind := NONE
93                | undo1 (Ty.DIFF(Ty.DfV{bind, ...})) = bind := NONE
94                | undo1 (Ty.SHAPE(Ty.SV{bind, ...})) = bind := NONE
95                | undo1 (Ty.DIM(Ty.DV{bind, ...})) = bind := NONE
96              in
97                List.map undo1 (!pl)
98              end
99    
100  (* FIXME: what about the bounds? *)  (* FIXME: what about the bounds? *)
101      fun matchDiff (diff1, diff2) = (case (pruneDiff diff1, pruneDiff diff2)      fun matchDiff (pl, diff1, diff2) = (case (pruneDiff diff1, pruneDiff diff2)
102             of (Ty.DiffConst k1, Ty.DiffConst k2) => (k1 = k2)             of (Ty.DiffConst k1, Ty.DiffConst k2) => (k1 = k2)
103              | (Ty.DiffConst k, Ty.DiffVar(Ty.DfV{bind, bound, ...}, i)) => let              | (Ty.DiffConst k, Ty.DiffVar(dv, i)) => let
104                  val k' = k+i                  val k' = k+i
105                  in                  in
106                    if k' < 0 then false                    if k' < 0 then false
107                    else (bind := SOME(Ty.DiffConst k'); true)                    else (bindDiffVar(pl, dv, Ty.DiffConst k'); true)
108                  end                  end
109              | (Ty.DiffVar(Ty.DfV{bind, bound, ...}, i), Ty.DiffConst k) => let              | (Ty.DiffVar(dv, i), Ty.DiffConst k) => let
110                  val k' = k+i                  val k' = k+i
111                  in                  in
112                    if k' < 0 then false                    if k' < 0 then false
113                    else (bind := SOME(Ty.DiffConst k'); true)                    else (bindDiffVar(pl, dv, Ty.DiffConst k'); true)
114                  end                  end
115              | (Ty.DiffVar(dv1, i1), Ty.DiffVar(dv2, i2)) => raise Fail "unimplemented" (* FIXME *)              | (Ty.DiffVar(dv1, i1), Ty.DiffVar(dv2, i2)) => raise Fail "unimplemented" (* FIXME *)
116            (* end case *))            (* end case *))
117    
118      fun matchDim (dim1, dim2) = (case (pruneDim dim1, pruneDim dim2)      fun matchDim (pl, dim1, dim2) = (case (pruneDim dim1, pruneDim dim2)
119             of (Ty.DimConst d1, Ty.DimConst d2) => (d1 = d2)             of (Ty.DimConst d1, Ty.DimConst d2) => (d1 = d2)
120              | (Ty.DimVar(Ty.DV{bind, ...}), dim2) => (bind := SOME dim2; true)              | (Ty.DimVar dv, dim2) => (bindDimVar(pl, dv, dim2); true)
121              | (dim1, Ty.DimVar(Ty.DV{bind, ...})) => (bind := SOME dim1; true)              | (dim1, Ty.DimVar dv) => (bindDimVar(pl, dv, dim1); true)
122            (* end case *))            (* end case *))
123    
124      fun matchShape (shape1, shape2) = (case (pruneShape shape1, pruneShape shape2)      fun matchShape (pl, shape1, shape2) = (case (pruneShape shape1, pruneShape shape2)
125             of (Ty.Shape dd1, Ty.Shape dd2) => let             of (Ty.Shape dd1, Ty.Shape dd2) => let
126                  fun chk ([], []) = true                  fun chk ([], []) = true
127                    | chk (d1::dd1, d2::dd2) = matchDim(d1, d2) andalso chk (dd1, dd2)                    | chk (d1::dd1, d2::dd2) = matchDim(pl, d1, d2) andalso chk (dd1, dd2)
128                    | chk _ = false                    | chk _ = false
129                  in                  in
130                    chk (dd1, dd2)                    chk (dd1, dd2)
# Line 88  Line 132 
132              | (Ty.Shape dd, Ty.ShapeExt(shape, d2)) => let              | (Ty.Shape dd, Ty.ShapeExt(shape, d2)) => let
133                  fun chk ([], _) = false                  fun chk ([], _) = false
134                    | chk ([d], revDD) =                    | chk ([d], revDD) =
135                        matchDim(d, d2) andalso matchShape(Ty.Shape(List.rev revDD), shape)                        matchDim(pl, d, d2) andalso matchShape(pl, Ty.Shape(List.rev revDD), shape)
136                    | chk (d::dd, revDD) = chk(dd, d::revDD)                    | chk (d::dd, revDD) = chk(dd, d::revDD)
137                  in                  in
138                    chk (dd, [])                    chk (dd, [])
139                  end                  end
140              | (Ty.ShapeVar(Ty.SV{bind, ...}), shape) => (bind := SOME shape; true)              | (Ty.ShapeVar sv, shape) => (bindShapeVar (pl, sv, shape); true)
141              | (Ty.ShapeExt(shape1, d1), Ty.ShapeExt(shape2, d2)) =>              | (Ty.ShapeExt(shape1, d1), Ty.ShapeExt(shape2, d2)) =>
142                  matchDim(d1, d2) andalso matchShape(shape1, shape2)                  matchDim(pl, d1, d2) andalso matchShape(pl, shape1, shape2)
143              | (shape1, shape2) => matchShape(shape2, shape1)              | (shape1, shape2) => matchShape(pl, shape2, shape1)
144          (* end case *))          (* end case *))
145    
146  (* QUESTION: do we need an occurs check? *)  (* QUESTION: do we need an occurs check? *)
147      fun matchType (ty1, ty2) = let      fun unifyType (pl, ty1, ty2) = let
           fun setBind (Ty.TV{bind=ref(SOME _), ...}, _) = raise Fail "prune fail"  
             | setBind (Ty.TV{bind, ...}, ty) = bind := SOME ty  
148            fun matchVar (tv1 as Ty.TV{id=id1, ...}, tv2 as Ty.TV{id=id2, ...}) =            fun matchVar (tv1 as Ty.TV{id=id1, ...}, tv2 as Ty.TV{id=id2, ...}) =
149                  if Stamp.same(id1, id2)                  if Stamp.same(id1, id2)
150                    then ()                    then ()
151                    else setBind (tv1, Ty.T_Var tv2)                    else bindTyVar (pl, tv1, Ty.T_Var tv2)
152            fun match (Ty.T_Var tv1, Ty.T_Var tv2) = (matchVar(tv1, tv2); true)            fun match (Ty.T_Var tv1, Ty.T_Var tv2) = (matchVar(tv1, tv2); true)
153              | match (Ty.T_Var tv1, ty2) = (setBind(tv1, ty2); true)              | match (Ty.T_Var tv1, ty2) = (bindTyVar(pl, tv1, ty2); true)
154              | match (ty1, Ty.T_Var tv2) = (setBind(tv2, ty2); true)              | match (ty1, Ty.T_Var tv2) = (bindTyVar(pl, tv2, ty2); true)
155              | match (Ty.T_Bool, Ty.T_Bool) = true              | match (Ty.T_Bool, Ty.T_Bool) = true
156              | match (Ty.T_Int, Ty.T_Int) = true              | match (Ty.T_Int, Ty.T_Int) = true
157              | match (Ty.T_String, Ty.T_String) = true              | match (Ty.T_String, Ty.T_String) = true
158              | match (Ty.T_Kernel k1, Ty.T_Kernel k2) = matchDiff (k1, k2)              | match (Ty.T_Kernel k1, Ty.T_Kernel k2) = matchDiff (pl, k1, k2)
159              | match (Ty.T_Tensor s1, Ty.T_Tensor s2) = matchShape (s1, s2)              | match (Ty.T_Tensor s1, Ty.T_Tensor s2) = matchShape (pl, s1, s2)
160              | match (Ty.T_Image{dim=d1, shape=s1}, Ty.T_Image{dim=d2, shape=s2}) =              | match (Ty.T_Image{dim=d1, shape=s1}, Ty.T_Image{dim=d2, shape=s2}) =
161                  matchDim (d1, d2) andalso matchShape(s1, s2)                  matchDim (pl, d1, d2) andalso matchShape(pl, s1, s2)
162              | match (Ty.T_Field{diff=k1, dim=d1, shape=s1}, Ty.T_Field{diff=k2, dim=d2, shape=s2}) =              | match (Ty.T_Field{diff=k1, dim=d1, shape=s1}, Ty.T_Field{diff=k2, dim=d2, shape=s2}) =
163                  matchDiff (k1, k2) andalso matchDim (d1, d2) andalso matchShape(s1, s2)                  matchDiff (pl, k1, k2) andalso matchDim (pl, d1, d2) andalso matchShape(pl, s1, s2)
164              | match (Ty.T_Fun(tys11, ty12), Ty.T_Fun(tys21, ty22)) =              | match (Ty.T_Fun(tys11, ty12), Ty.T_Fun(tys21, ty22)) =
165                  matchTypes (tys11, tys21) andalso match (ty12, ty22)                  ListPair.allEq match (tys11, tys21) andalso match (ty12, ty22)
166              | match _ = false              | match _ = false
167            in            in
168              match (prune ty1, prune ty2)              match (prune ty1, prune ty2)
169            end            end
170    
171      and matchTypes (tys1, tys2) = ListPair.allEq matchType (tys1, tys2)      fun matchTypes (tys1, tys2) = let
172              val pl = ref[]
173              in
174                ListPair.allEq (fn (ty1, ty2) => unifyType(pl, ty1, ty2)) (tys1, tys2)
175              end
176    
177        fun matchType (ty1, ty2) = unifyType (ref[], ty1, ty2)
178    
179      (* try to match types; if we fail, all meta-variable bindings are undone *)
180        fun tryMatchType (ty1, ty2) = let
181              val pl = ref[]
182              in
183                unifyType(pl, ty1, ty2) orelse (undo pl; false)
184              end
185    
186      (* try to match types; if we fail, all meta-variable bindings are undone *)
187        fun tryMatchTypes (tys1, tys2) = let
188              val pl = ref[]
189              in
190                ListPair.allEq (fn (ty1, ty2) => unifyType(pl, ty1, ty2)) (tys1, tys2)
191                orelse (undo pl; false)
192              end
193    
194    (* instantiate a type scheme, returning the argument meta variables and the resulting type.    (* instantiate a type scheme, returning the argument meta variables and the resulting type.
195     * Note that we assume that the scheme is closed.     * Note that we assume that the scheme is closed.

Legend:
Removed from v.84  
changed lines
  Added in v.85

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0