SCM Repository
View of /trunk/src/compiler/typechecker/util.sml
Parent Directory
|
Revision Log
Revision 1113 -
(download)
(annotate)
Thu May 5 04:11:52 2011 UTC (9 years, 8 months ago) by jhr
File size: 7272 byte(s)
Thu May 5 04:11:52 2011 UTC (9 years, 8 months ago) by jhr
File size: 7272 byte(s)
Starting to merge pure-cfg changes back into trunk
(* util.sml
 *
 * COPYRIGHT (c) 2010 The Diderot Project (http://diderot-language.cs.uchicago.edu)
 * All rights reserved.
 *
 * Utilities for typechecking
 *)

structure Util : sig

  (* unify two types; meta-variable bindings made along the way are kept even
   * if the match fails.
   *)
    val matchType : Types.ty * Types.ty -> bool

  (* unify two lists of types pairwise (false when the lengths differ); bindings
   * are kept even if the match fails.
   *)
    val matchTypes : Types.ty list * Types.ty list -> bool

  (* like matchType, but all meta-variable bindings are undone on failure *)
    val tryMatchType : Types.ty * Types.ty -> bool

  (* like matchTypes, but all meta-variable bindings are undone on failure *)
    val tryMatchTypes : Types.ty list * Types.ty list -> bool

  (* instantiate a type scheme, returning copies (MV.copy) of its meta variables
   * and the type rewritten to use those copies.
   *)
    val instantiate : Types.scheme -> (Types.meta_var list * Types.ty)

  (* unify two dimensions; bindings are kept even if the match fails *)
    val matchDim : Types.dim * Types.dim -> bool

  end = struct

    structure Ty = Types
    structure MV = MetaVar
    structure TU = TypeUtil

  (* a patch list tracks the meta variables that have been updated so that we can undo
   * the effects of unification when just testing for a possible type match.
   *)
    fun bindTyVar (pl, tv as Ty.TV{bind as ref NONE, ...}, ty) = (
          bind := SOME ty;
          pl := Ty.TYPE tv :: !pl)
      | bindTyVar _ = raise Fail "rebinding type variable"

    fun bindDiffVar (pl, dv as Ty.DfV{bind as ref NONE, ...}, diff) = (
          bind := SOME diff;
          pl := Ty.DIFF dv :: !pl)
      | bindDiffVar _ = raise Fail "rebinding differentiation variable"

    fun bindShapeVar (pl, sv as Ty.SV{bind as ref NONE, ...}, shape) = (
          bind := SOME shape;
          pl := Ty.SHAPE sv :: !pl)
      | bindShapeVar _ = raise Fail "rebinding shape variable"

    fun bindDimVar (pl, dv as Ty.DV{bind as ref NONE, ...}, dim) = (
          bind := SOME dim;
          pl := Ty.DIM dv :: !pl)
      | bindDimVar _ = raise Fail "rebinding dimension variable"

  (* reset every meta variable recorded in the patch list to unbound, then clear
   * the list.  (List.app, not List.map: we only want the side effects.)
   *)
    fun undo pl = let
          fun undo1 (Ty.TYPE(Ty.TV{bind, ...})) = bind := NONE
            | undo1 (Ty.DIFF(Ty.DfV{bind, ...})) = bind := NONE
            | undo1 (Ty.SHAPE(Ty.SV{bind, ...})) = bind := NONE
            | undo1 (Ty.DIM(Ty.DV{bind, ...})) = bind := NONE
          in
            List.app undo1 (!pl);
            pl := []
          end

  (* two meta variables of the same kind are the same variable exactly when they
   * share the same binding cell; SML compares refs by identity.  Binding a
   * variable to itself would create a cyclic binding, so the matchers below use
   * these tests to skip that case (as matchVar in unifyType does for type
   * variables).
   *)
    fun sameDiffVar (Ty.DfV{bind=b1, ...}, Ty.DfV{bind=b2, ...}) = (b1 = b2)
    fun sameDimVar (Ty.DV{bind=b1, ...}, Ty.DV{bind=b2, ...}) = (b1 = b2)
    fun sameShapeVar (Ty.SV{bind=b1, ...}, Ty.SV{bind=b2, ...}) = (b1 = b2)

  (* bind the differentiation variable dv to the constant k'; fails (returning
   * false, with no binding) when k' is negative.
   *)
    fun bindDiffConst (pl, dv, k') =
          if k' < 0
            then false
            else (bindDiffVar(pl, dv, Ty.DiffConst k'); true)

  (* FIXME: what about the bounds? *)
  (* NOTE(review): both mixed cases bind dv to k+i.  If DiffVar(dv, i) denotes
   * "dv + i" (TypeUtil.pruneDiff is not visible here), the solution would be
   * k-i instead; confirm against pruneDiff before relying on non-zero offsets.
   *)
    fun matchDiff (pl, diff1, diff2) = (
          case (TU.pruneDiff diff1, TU.pruneDiff diff2)
           of (Ty.DiffConst k1, Ty.DiffConst k2) => (k1 = k2)
            | (Ty.DiffConst k, Ty.DiffVar(dv, i)) => bindDiffConst (pl, dv, k+i)
            | (Ty.DiffVar(dv, i), Ty.DiffConst k) => bindDiffConst (pl, dv, k+i)
            | (Ty.DiffVar(dv1, i1), Ty.DiffVar(dv2, i2)) =>
              (* identical terms trivially match; the general case is still open *)
                if sameDiffVar(dv1, dv2) andalso (i1 = i2)
                  then true
                  else raise Fail "unimplemented" (* FIXME *)
          (* end case *))

  (* unify two dimensions, recording any variable bindings in pl *)
    fun matchDim (pl, dim1, dim2) = (
          case (TU.pruneDim dim1, TU.pruneDim dim2)
           of (Ty.DimConst d1, Ty.DimConst d2) => (d1 = d2)
            | (Ty.DimVar dv1, Ty.DimVar dv2) =>
                sameDimVar(dv1, dv2) orelse (bindDimVar(pl, dv1, Ty.DimVar dv2); true)
            | (Ty.DimVar dv, dim2) => (bindDimVar(pl, dv, dim2); true)
            | (dim1, Ty.DimVar dv) => (bindDimVar(pl, dv, dim1); true)
          (* end case *))

  (* unify two shapes, recording any variable bindings in pl.  ShapeExt(shape, d)
   * is matched against a concrete shape by pairing d with the last dimension and
   * shape with the remaining prefix.
   *)
    fun matchShape (pl, shape1, shape2) = (
          case (TU.pruneShape shape1, TU.pruneShape shape2)
           of (Ty.Shape dd1, Ty.Shape dd2) => let
                fun chk ([], []) = true
                  | chk (d1::dd1, d2::dd2) = matchDim(pl, d1, d2) andalso chk (dd1, dd2)
                  | chk _ = false
                in
                  chk (dd1, dd2)
                end
            | (Ty.Shape dd, Ty.ShapeExt(shape, d2)) => let
                (* walk to the last dimension, accumulating the prefix in reverse *)
                fun chk ([], _) = false
                  | chk ([d], revDD) =
                      matchDim(pl, d, d2)
                        andalso matchShape(pl, Ty.Shape(List.rev revDD), shape)
                  | chk (d::dd, revDD) = chk(dd, d::revDD)
                in
                  chk (dd, [])
                end
            | (Ty.ShapeVar sv1, Ty.ShapeVar sv2) =>
                sameShapeVar(sv1, sv2) orelse (bindShapeVar(pl, sv1, Ty.ShapeVar sv2); true)
            | (Ty.ShapeVar sv, shape) => (bindShapeVar (pl, sv, shape); true)
            | (Ty.ShapeExt(shape1, d1), Ty.ShapeExt(shape2, d2)) =>
                matchDim(pl, d1, d2) andalso matchShape(pl, shape1, shape2)
            | (shape1, shape2) => matchShape(pl, shape2, shape1)
          (* end case *))

  (* QUESTION: do we need an occurs check? *)
  (* unify two types, recording any variable bindings in pl *)
    fun unifyType (pl, ty1, ty2) = let
          fun matchVar (tv1 as Ty.TV{id=id1, ...}, tv2 as Ty.TV{id=id2, ...}) =
                if Stamp.same(id1, id2)
                  then ()
                  else bindTyVar (pl, tv1, Ty.T_Var tv2)
          (* prune the heads of both types before comparing them.  This must happen
           * at every level: a nested type may be a bound meta variable, and
           * handing that to bindTyVar would raise "rebinding type variable".
           *)
          fun unify (ty1, ty2) = match (TU.pruneHead ty1, TU.pruneHead ty2)
          and match (Ty.T_Var tv1, Ty.T_Var tv2) = (matchVar(tv1, tv2); true)
            | match (Ty.T_Var tv1, ty2) = (bindTyVar(pl, tv1, ty2); true)
            | match (ty1, Ty.T_Var tv2) = (bindTyVar(pl, tv2, ty1); true)
            | match (Ty.T_Bool, Ty.T_Bool) = true
            | match (Ty.T_Int, Ty.T_Int) = true
            | match (Ty.T_String, Ty.T_String) = true
            | match (Ty.T_Kernel k1, Ty.T_Kernel k2) = matchDiff (pl, k1, k2)
            | match (Ty.T_Tensor s1, Ty.T_Tensor s2) = matchShape (pl, s1, s2)
            | match (Ty.T_Image{dim=d1, shape=s1}, Ty.T_Image{dim=d2, shape=s2}) =
                matchDim (pl, d1, d2) andalso matchShape(pl, s1, s2)
            | match (Ty.T_Field{diff=k1, dim=d1, shape=s1}, Ty.T_Field{diff=k2, dim=d2, shape=s2}) =
                matchDiff (pl, k1, k2) andalso matchDim (pl, d1, d2)
                  andalso matchShape(pl, s1, s2)
            | match (Ty.T_Fun(tys11, ty12), Ty.T_Fun(tys21, ty22)) =
                ListPair.allEq unify (tys11, tys21) andalso unify (ty12, ty22)
            | match _ = false
          in
            unify (ty1, ty2)
          end

    fun matchTypes (tys1, tys2) = let
          val pl = ref []
          in
            ListPair.allEq (fn (ty1, ty2) => unifyType(pl, ty1, ty2)) (tys1, tys2)
          end

    fun matchType (ty1, ty2) = unifyType (ref [], ty1, ty2)

  (* try to match types; if we fail, all meta-variable bindings are undone *)
    fun tryMatchType (ty1, ty2) = let
          val pl = ref []
          in
            unifyType(pl, ty1, ty2) orelse (undo pl; false)
          end

  (* try to match types; if we fail, all meta-variable bindings are undone *)
    fun tryMatchTypes (tys1, tys2) = let
          val pl = ref []
          in
            ListPair.allEq (fn (ty1, ty2) => unifyType(pl, ty1, ty2)) (tys1, tys2)
              orelse (undo pl; false)
          end

  (* rebind matchDim without patch-list argument *)
    val matchDim = fn (d1, d2) => matchDim(ref [], d1, d2)

  (* QUESTION: perhaps this function belongs in the TypeUtil module? *)
  (* instantiate a type scheme, returning the argument meta variables and the resulting type.
   * Note that we assume that the scheme is closed.
   *)
    fun instantiate ([], ty) = ([], ty)
      | instantiate (mvs, ty) = let
          fun instantiateVar (mv, (mvs, env)) = let
                val mv' = MV.copy mv
                in
                  (mv'::mvs, MV.Map.insert(env, mv, mv'))
                end
          val (mvs, env) = List.foldr instantiateVar ([], MV.Map.empty) mvs
          fun iDiff (Ty.DiffVar(k, i)) = (case MV.Map.find(env, Ty.DIFF k)
                 of SOME(Ty.DIFF k) => Ty.DiffVar(k, i)
                  | _ => raise Fail "impossible"
                (* end case *))
            | iDiff diff = diff
          fun iDim (Ty.DimVar dv) = (case MV.Map.find(env, Ty.DIM dv)
                 of SOME(Ty.DIM dv) => Ty.DimVar dv
                  | _ => raise Fail "impossible"
                (* end case *))
            | iDim dim = dim
          fun iShape (Ty.ShapeVar sv) = (case MV.Map.find(env, Ty.SHAPE sv)
                 of SOME(Ty.SHAPE sv) => Ty.ShapeVar sv
                  | _ => raise Fail "impossible"
                (* end case *))
            | iShape (Ty.ShapeExt(shape, dim)) = Ty.ShapeExt(iShape shape, iDim dim)
            | iShape (Ty.Shape dims) = Ty.Shape(List.map iDim dims)
          fun ity (Ty.T_Var tv) = (case MV.Map.find(env, Ty.TYPE tv)
                 of SOME(Ty.TYPE tv) => Ty.T_Var tv
                  | _ => raise Fail "impossible"
                (* end case *))
            | ity (Ty.T_Kernel k) = Ty.T_Kernel(iDiff k)
            | ity (Ty.T_Tensor shape) = Ty.T_Tensor(iShape shape)
            | ity (Ty.T_Image{dim, shape}) = Ty.T_Image{dim=iDim dim, shape=iShape shape}
            | ity (Ty.T_Field{diff, dim, shape}) =
                Ty.T_Field{diff=iDiff diff, dim=iDim dim, shape=iShape shape}
            | ity (Ty.T_Fun(dom, rng)) = Ty.T_Fun(List.map ity dom, ity rng)
            | ity ty = ty
          in
            (mvs, ity ty)
          end

  end
root@smlnj-gforge.cs.uchicago.edu | ViewVC Help |
Powered by ViewVC 1.0.0 |