Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] Diff of /trunk/src/compiler/typechecker/util.sml
ViewVC logotype

Diff of /trunk/src/compiler/typechecker/util.sml

Parent Directory Parent Directory | Revision Log Revision Log | View Patch Patch

revision 2355, Sun Apr 7 11:35:08 2013 UTC revision 2356, Sun Apr 7 14:45:25 2013 UTC
# Line 8  Line 8 
8    
9  structure Util : sig  structure Util : sig
10    
11      val matchType : Types.ty * Types.ty -> bool    (* when matching two types (ty1 and ty2), there are three possible outcomes:
12      val matchTypes : Types.ty list * Types.ty list -> bool     *   EQ       -- types are equal
13       *   COERCE   -- ty2 can be coerced to match ty1 (e.g., int -> float)
14       *   FAIL     -- types do not match
15       *)
16        datatype match = EQ | COERCE | FAIL
17    
18        val matchType : Types.ty * Types.ty -> match
19    
20        val tryMatchType : Types.ty * Types.ty -> match
21    
22      (* attempt to match a list of parameter types with a list of typed arguments.  Return
23       * the arguments with any required coercions, or NONE on failure.
24       *)
25        val matchArgs : Types.ty list * AST.expr list * Types.ty list -> AST.expr list option
26        val tryMatchArgs : Types.ty list * AST.expr list * Types.ty list -> AST.expr list option
27    
28        val equalType : Types.ty * Types.ty -> bool
29        val equalTypes : Types.ty list * Types.ty list -> bool
30    
31      val tryMatchType : Types.ty * Types.ty -> bool      val tryEqualType : Types.ty * Types.ty -> bool
32      val tryMatchTypes : Types.ty list * Types.ty list -> bool      val tryEqualTypes : Types.ty list * Types.ty list -> bool
33    
34      val instantiate : Types.scheme -> (Types.meta_var list * Types.ty)      val instantiate : Types.scheme -> (Types.meta_var list * Types.ty)
35    
36      val matchDim : Types.dim * Types.dim -> bool      val equalDim : Types.dim * Types.dim -> bool
37    
38    end = struct    end = struct
39    
# Line 24  Line 41 
41      structure MV = MetaVar      structure MV = MetaVar
42      structure TU = TypeUtil      structure TU = TypeUtil
43    
44        datatype match = EQ | COERCE | FAIL
45    
46    (* a patch list tracks the meta variables that have been updated so that we can undo    (* a patch list tracks the meta variables that have been updated so that we can undo
47     * the effects of unification when just testing for a possible type match.     * the effects of unification when just testing for a possible type match.
48     *)     *)
# Line 58  Line 77 
77            end            end
78    
79  (* FIXME: what about the bounds? *)  (* FIXME: what about the bounds? *)
80      fun matchDiff (pl, diff1, diff2) = (case (TU.pruneDiff diff1, TU.pruneDiff diff2)      fun equalDiff (pl, diff1, diff2) = (case (TU.pruneDiff diff1, TU.pruneDiff diff2)
81             of (Ty.DiffConst k1, Ty.DiffConst k2) => (k1 = k2)             of (Ty.DiffConst k1, Ty.DiffConst k2) => (k1 = k2)
82              | (Ty.DiffConst k, Ty.DiffVar(dv, i)) => let              | (Ty.DiffConst k, Ty.DiffVar(dv, i)) => let
83                  val k' = k+i                  val k' = k+i
# Line 75  Line 94 
94              | (Ty.DiffVar(dv1, i1), Ty.DiffVar(dv2, i2)) => raise Fail "unimplemented" (* FIXME *)              | (Ty.DiffVar(dv1, i1), Ty.DiffVar(dv2, i2)) => raise Fail "unimplemented" (* FIXME *)
95            (* end case *))            (* end case *))
96    
97      fun matchDim (pl, dim1, dim2) = (case (TU.pruneDim dim1, TU.pruneDim dim2)    (* match two differentiation constants where the first is allowed to be less than the second *)
98        fun matchDiff (diff1, diff2) = (case (TU.pruneDiff diff1, TU.pruneDiff diff2)
99               of (Ty.DiffConst k1, Ty.DiffConst k2) => (k1 <= k2)
100                | _ => raise Fail "unimplemented" (* FIXME *)
101              (* end case *))
102    
103        fun equalDim (pl, dim1, dim2) = (case (TU.pruneDim dim1, TU.pruneDim dim2)
104             of (Ty.DimConst d1, Ty.DimConst d2) => (d1 = d2)             of (Ty.DimConst d1, Ty.DimConst d2) => (d1 = d2)
105              | (Ty.DimVar dv, dim2) => (bindDimVar(pl, dv, dim2); true)              | (Ty.DimVar dv, dim2) => (bindDimVar(pl, dv, dim2); true)
106              | (dim1, Ty.DimVar dv) => (bindDimVar(pl, dv, dim1); true)              | (dim1, Ty.DimVar dv) => (bindDimVar(pl, dv, dim1); true)
107            (* end case *))            (* end case *))
108    
109      fun matchShape (pl, shape1, shape2) = (case (TU.pruneShape shape1, TU.pruneShape shape2)      fun equalShape (pl, shape1, shape2) = (case (TU.pruneShape shape1, TU.pruneShape shape2)
110             of (Ty.Shape dd1, Ty.Shape dd2) => let             of (Ty.Shape dd1, Ty.Shape dd2) => let
111                  fun chk ([], []) = true                  fun chk ([], []) = true
112                    | chk (d1::dd1, d2::dd2) = matchDim(pl, d1, d2) andalso chk (dd1, dd2)                    | chk (d1::dd1, d2::dd2) = equalDim(pl, d1, d2) andalso chk (dd1, dd2)
113                    | chk _ = false                    | chk _ = false
114                  in                  in
115                    chk (dd1, dd2)                    chk (dd1, dd2)
# Line 92  Line 117 
117              | (Ty.Shape dd, Ty.ShapeExt(shape, d2)) => let              | (Ty.Shape dd, Ty.ShapeExt(shape, d2)) => let
118                  fun chk ([], _) = false                  fun chk ([], _) = false
119                    | chk ([d], revDD) =                    | chk ([d], revDD) =
120                        matchDim(pl, d, d2) andalso matchShape(pl, Ty.Shape(List.rev revDD), shape)                        equalDim(pl, d, d2) andalso equalShape(pl, Ty.Shape(List.rev revDD), shape)
121                    | chk (d::dd, revDD) = chk(dd, d::revDD)                    | chk (d::dd, revDD) = chk(dd, d::revDD)
122                  in                  in
123                    chk (dd, [])                    chk (dd, [])
124                  end                  end
125              | (Ty.ShapeVar sv, shape) => (bindShapeVar (pl, sv, shape); true)              | (Ty.ShapeVar sv, shape) => (bindShapeVar (pl, sv, shape); true)
126              | (Ty.ShapeExt(shape1, d1), Ty.ShapeExt(shape2, d2)) =>              | (Ty.ShapeExt(shape1, d1), Ty.ShapeExt(shape2, d2)) =>
127                  matchDim(pl, d1, d2) andalso matchShape(pl, shape1, shape2)                  equalDim(pl, d1, d2) andalso equalShape(pl, shape1, shape2)
128              | (shape1, shape2) => matchShape(pl, shape2, shape1)              | (shape1, shape2) => equalShape(pl, shape2, shape1)
129          (* end case *))          (* end case *))
130    
131  (* QUESTION: do we need an occurs check? *)  (* QUESTION: do we need an occurs check? *)
# Line 125  Line 150 
150              | match (Ty.T_Int, Ty.T_Int) = true              | match (Ty.T_Int, Ty.T_Int) = true
151              | match (Ty.T_String, Ty.T_String) = true              | match (Ty.T_String, Ty.T_String) = true
152              | match (Ty.T_Sequence(ty1, d1), Ty.T_Sequence(ty2, d2)) =              | match (Ty.T_Sequence(ty1, d1), Ty.T_Sequence(ty2, d2)) =
153                  matchDim(pl, d1, d2) andalso match(ty1, ty2)                  equalDim(pl, d1, d2) andalso match(ty1, ty2)
154              | match (Ty.T_Kernel k1, Ty.T_Kernel k2) = matchDiff (pl, k1, k2)              | match (Ty.T_Kernel k1, Ty.T_Kernel k2) = equalDiff (pl, k1, k2)
155              | match (Ty.T_Tensor s1, Ty.T_Tensor s2) = matchShape (pl, s1, s2)              | match (Ty.T_Tensor s1, Ty.T_Tensor s2) = equalShape (pl, s1, s2)
156              | match (Ty.T_Image{dim=d1, shape=s1}, Ty.T_Image{dim=d2, shape=s2}) =              | match (Ty.T_Image{dim=d1, shape=s1}, Ty.T_Image{dim=d2, shape=s2}) =
157                  matchDim (pl, d1, d2) andalso matchShape(pl, s1, s2)                  equalDim (pl, d1, d2) andalso equalShape(pl, s1, s2)
158              | match (Ty.T_Field{diff=k1, dim=d1, shape=s1}, Ty.T_Field{diff=k2, dim=d2, shape=s2}) =              | match (Ty.T_Field{diff=k1, dim=d1, shape=s1}, Ty.T_Field{diff=k2, dim=d2, shape=s2}) =
159                  matchDiff (pl, k1, k2) andalso matchDim (pl, d1, d2) andalso matchShape(pl, s1, s2)                  equalDiff (pl, k1, k2) andalso equalDim (pl, d1, d2) andalso equalShape(pl, s1, s2)
160              | match (Ty.T_Fun(tys11, ty12), Ty.T_Fun(tys21, ty22)) =              | match (Ty.T_Fun(tys11, ty12), Ty.T_Fun(tys21, ty22)) =
161                  ListPair.allEq match (tys11, tys21) andalso match (ty12, ty22)                  ListPair.allEq match (tys11, tys21) andalso match (ty12, ty22)
162              | match _ = false              | match _ = false
# Line 139  Line 164 
164              match (TU.pruneHead ty1, TU.pruneHead ty2)              match (TU.pruneHead ty1, TU.pruneHead ty2)
165            end            end
166    
167      fun matchTypes (tys1, tys2) = let      fun unifyTypeWithCoercion (pl, ty1, ty2) = (case (TU.pruneHead ty1, TU.pruneHead ty2)
168               of (Ty.T_Tensor shp, Ty.T_Int) =>
169                    if equalShape (pl, Ty.Shape[], shp) then COERCE else FAIL
170                | (Ty.T_Field{diff=k1, dim=d1, shape=s1}, Ty.T_Field{diff=k2, dim=d2, shape=s2}) =>
171                    if unifyType(pl, ty1, ty2)
172                      then EQ
173                    else if matchDiff (k1, k2) andalso equalDim(pl, d1, d2)
174                    andalso equalShape(pl, s1, s2)
175                      then COERCE
176                      else FAIL
177                | (ty1, ty2) => if unifyType(pl, ty1, ty2) then EQ else FAIL
178              (* end case *))
179    (* +DEBUG *
180    val unifyTypeWithCoercion = fn (pl, ty1, ty2) => let
181      val res = unifyTypeWithCoercion (pl, ty1, ty2)
182      val res' = (case res of EQ => "EQ" | COERCE => "COERCE" | FAIL => "FAIL")
183      in
184        print(concat["unifyTypeWithCoercion (_, ", TU.toString ty1, ", ", TU.toString ty2, ") = ", res', "\n"]);
185        res
186      end
187    * -DEBUG *)
188    
189        fun equalTypes (tys1, tys2) = let
190            val pl = ref[]            val pl = ref[]
191            in            in
192              ListPair.allEq (fn (ty1, ty2) => unifyType(pl, ty1, ty2)) (tys1, tys2)              ListPair.allEq (fn (ty1, ty2) => unifyType(pl, ty1, ty2)) (tys1, tys2)
193            end            end
194    
195      fun matchType (ty1, ty2) = unifyType (ref[], ty1, ty2)      fun equalType (ty1, ty2) = unifyType (ref[], ty1, ty2)
196    
197    (* try to match types; if we fail, all meta-variable bindings are undone *)    (* try to match types; if we fail, all meta-variable bindings are undone *)
198      fun tryMatchType (ty1, ty2) = let      fun tryEqualType (ty1, ty2) = let
199            val pl = ref[]            val pl = ref[]
200            in            in
201              unifyType(pl, ty1, ty2) orelse (undo pl; false)              unifyType(pl, ty1, ty2) orelse (undo pl; false)
202            end            end
203    
204    (* try to match types; if we fail, all meta-variable bindings are undone *)    (* try to unify two types to equality; if we fail, all meta-variable bindings are undone *)
205      fun tryMatchTypes (tys1, tys2) = let      fun tryEqualTypes (tys1, tys2) = let
206            val pl = ref[]            val pl = ref[]
207            in            in
208              ListPair.allEq (fn (ty1, ty2) => unifyType(pl, ty1, ty2)) (tys1, tys2)              ListPair.allEq (fn (ty1, ty2) => unifyType(pl, ty1, ty2)) (tys1, tys2)
209              orelse (undo pl; false)              orelse (undo pl; false)
210            end            end
211    
212    (* rebind matchDim without patch-list argument *)      fun matchType (ty1, ty2) = unifyTypeWithCoercion (ref[], ty1, ty2)
213      val matchDim = fn (d1, d2) => matchDim(ref [], d1, d2)  
214      (* try to unify two type lists to equality; if we fail, all meta-variable bindings are undone *)
215        fun tryMatchType (ty1, ty2) = let
216              val pl = ref[]
217              in
218                case unifyTypeWithCoercion (pl, ty1, ty2)
219                 of FAIL => (undo pl; FAIL)
220                  | result => result
221                (* end case *)
222              end
223    
224      (* attempt to match a list of parameter types with a list of typed arguments.  Return
225       * the arguments with any required coercions, or NONE on failure.
226       *)
227        local
228          fun matchArgs' (pl, paramTys, args, argTys) = let
229                fun matchArgTys ([], [], [], args') = SOME(List.rev args')
230                  | matchArgTys (ty1::tys1, arg::args, ty2::tys2, args') = (
231                      case unifyTypeWithCoercion (pl, ty1, ty2)
232                       of EQ => matchArgTys (tys1, args, tys2, arg::args')
233                        | COERCE => matchArgTys (tys1, args, tys2, AST.E_Coerce{srcTy=ty2, dstTy=ty1, e=arg}::args')
234                        | _ => (undo pl; NONE)
235                      (* end case *))
236                    | matchArgTys _ = NONE
237                in
238                  matchArgTys (paramTys, args, argTys, [])
239                end
240        in
241        fun matchArgs (paramTys, args, argTys) = matchArgs' (ref[], paramTys, args, argTys)
242        fun tryMatchArgs (paramTys, args, argTys) = let
243              val pl = ref[]
244              in
245                case matchArgs' (ref[], paramTys, args, argTys)
246                 of NONE => (undo pl; NONE)
247                  | someResult => someResult
248                (* end case *)
249              end
250        end
251    
252      (* rebind equalDim without patch-list argument *)
253        val equalDim = fn (d1, d2) => equalDim(ref [], d1, d2)
254    
255  (* QUESTION: perhaps this function belongs in the TypeUtil module? *)  (* QUESTION: perhaps this function belongs in the TypeUtil module? *)
256    (* instantiate a type scheme, returning the argument meta variables and the resulting type.    (* instantiate a type scheme, returning the argument meta variables and the resulting type.
# Line 197  Line 284 
284                   of SOME(Ty.TYPE tv) => Ty.T_Var tv                   of SOME(Ty.TYPE tv) => Ty.T_Var tv
285                    | _ => raise Fail "impossible"                    | _ => raise Fail "impossible"
286                  (* end case *))                  (* end case *))
287                | ity Ty.T_Bool = Ty.T_Bool
288                | ity Ty.T_Int = Ty.T_Int
289                | ity Ty.T_String = Ty.T_String
290              | ity (Ty.T_Sequence(ty, d)) = Ty.T_Sequence(ity ty, iDim d)              | ity (Ty.T_Sequence(ty, d)) = Ty.T_Sequence(ity ty, iDim d)
291              | ity (Ty.T_Kernel k) = Ty.T_Kernel(iDiff k)              | ity (Ty.T_Kernel k) = Ty.T_Kernel(iDiff k)
292              | ity (Ty.T_Tensor shape) = Ty.T_Tensor(iShape shape)              | ity (Ty.T_Tensor shape) = Ty.T_Tensor(iShape shape)
# Line 204  Line 294 
294              | ity (Ty.T_Field{diff, dim, shape}) =              | ity (Ty.T_Field{diff, dim, shape}) =
295                  Ty.T_Field{diff=iDiff diff, dim=iDim dim, shape=iShape shape}                  Ty.T_Field{diff=iDiff diff, dim=iDim dim, shape=iShape shape}
296              | ity (Ty.T_Fun(dom, rng)) = Ty.T_Fun(List.map ity dom, ity rng)              | ity (Ty.T_Fun(dom, rng)) = Ty.T_Fun(List.map ity dom, ity rng)
             | ity ty = ty  
297            in            in
298              (mvs, ity ty)              (mvs, ity ty)
299            end            end

Legend:
Removed from v.2355  
changed lines
  Added in v.2356

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0