Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] View of /branches/vis12/src/compiler/typechecker/util.sml
ViewVC logotype

View of /branches/vis12/src/compiler/typechecker/util.sml

Parent Directory | Revision Log


Revision 1925 - (download) (annotate)
Sat Jun 23 14:16:09 2012 UTC (7 years, 1 month ago) by jhr
File size: 9795 byte(s)
  Added length function on dynamic sequences
(* util.sml
 *
 * COPYRIGHT (c) 2010 The Diderot Project (http://diderot-language.cs.uchicago.edu)
 * All rights reserved.
 *
 * Utilities for typechecking
 *
 * TODO:
 *      coercions to lift integers to reals
 *      coercions to lift fixed-sized sequences to dynamic-sized sequences
 *)

structure Util : sig

  (* when matching two types (ty1 and ty2), there are three possible outcomes:
   *   EQ       -- types are equal
   *   COERCE   -- ty2 can be coerced to match ty1 (e.g., int -> float, fixed seq -> dynamic seq)
   *   FAIL     -- types do not match
   *)
    datatype match = EQ | COERCE | FAIL

  (* classify how ty2 relates to ty1 (see the match datatype above).  Equality is decided
   * by unification, so meta variables may be bound as a side effect; those bindings are
   * never undone.
   *)
    val matchType : Types.ty * Types.ty -> match

  (* unify two types (resp. two lists of types, which must have equal length); the result
   * is false when they do not match.  Meta-variable bindings made along the way are kept
   * even when the result is false.
   *)
    val equalType : Types.ty * Types.ty -> bool
    val equalTypes : Types.ty list * Types.ty list -> bool

  (* like equalType/equalTypes, except that all meta-variable bindings are undone when
   * the result is false.
   *)
    val tryMatchType : Types.ty * Types.ty -> bool
    val tryMatchTypes : Types.ty list * Types.ty list -> bool

  (* instantiate a (closed) type scheme with copies of its meta variables (made by
   * MetaVar.copy), returning the copies, in scheme order, and the instantiated type.
   *)
    val instantiate : Types.scheme -> (Types.meta_var list * Types.ty)

  (* unify two dimensions; any binding that is made is kept *)
    val equalDim : Types.dim * Types.dim -> bool

  end = struct

    structure Ty = Types
    structure MV = MetaVar
    structure TU = TypeUtil

    datatype match = EQ | COERCE | FAIL

  (* a patch list tracks the meta variables that have been updated so that we can undo
   * the effects of unification when just testing for a possible type match.
   *
   * Each bindXXXVar function below binds a currently-unbound meta variable and pushes
   * it onto the patch list pl; attempting to bind an already-bound variable raises Fail.
   *)

    fun bindTyVar (pl, tv as Ty.TV{bind as ref NONE, ...}, ty) = (
          bind := SOME ty;
          pl := Ty.TYPE tv :: !pl)
      | bindTyVar _ = raise Fail "rebinding type variable"

    fun bindDiffVar (pl, dv as Ty.DfV{bind as ref NONE, ...}, diff) = (
          bind := SOME diff;
          pl := Ty.DIFF dv :: !pl)
      | bindDiffVar _ = raise Fail "rebinding differentiation variable"

    fun bindShapeVar (pl, sv as Ty.SV{bind as ref NONE, ...}, shape) = (
          bind := SOME shape;
          pl := Ty.SHAPE sv :: !pl)
      | bindShapeVar _ = raise Fail "rebinding shape variable"

    fun bindDimVar (pl, dv as Ty.DV{bind as ref NONE, ...}, dim) = (
          bind := SOME dim;
          pl := Ty.DIM dv :: !pl)
      | bindDimVar _ = raise Fail "rebinding dimension variable"

  (* reset every meta variable recorded on the patch list to unbound.  Only the side
   * effects matter, so List.app is used rather than building a unit list with List.map.
   *)
    fun undo pl = let
          fun undo1 (Ty.TYPE(Ty.TV{bind, ...})) = bind := NONE
            | undo1 (Ty.DIFF(Ty.DfV{bind, ...})) = bind := NONE
            | undo1 (Ty.SHAPE(Ty.SV{bind, ...})) = bind := NONE
            | undo1 (Ty.DIM(Ty.DV{bind, ...})) = bind := NONE
          in
            List.app undo1 (!pl)
          end

  (* unify two differentiation levels.  DiffVar(dv, i) is read as dv + i (the convention
   * used when pruning bound differentiation variables; NOTE(review): confirm against
   * TypeUtil.pruneDiff).  Under that reading, matching it against the constant k
   * requires dv = k - i, which must be non-negative.  Two occurrences of the same
   * variable match exactly when their offsets agree; distinct variables are not
   * handled yet.
   *)
(* FIXME: what about the bounds? *)
    fun equalDiff (pl, diff1, diff2) = let
          fun bindToConst (dv, i, k) = let
                val k' = k - i
                in
                  if k' < 0 then false
                  else (bindDiffVar(pl, dv, Ty.DiffConst k'); true)
                end
          in
            case (TU.pruneDiff diff1, TU.pruneDiff diff2)
             of (Ty.DiffConst k1, Ty.DiffConst k2) => (k1 = k2)
              | (Ty.DiffConst k, Ty.DiffVar(dv, i)) => bindToConst (dv, i, k)
              | (Ty.DiffVar(dv, i), Ty.DiffConst k) => bindToConst (dv, i, k)
              | (Ty.DiffVar(Ty.DfV{bind=b1, ...}, i1), Ty.DiffVar(Ty.DfV{bind=b2, ...}, i2)) =>
                (* the same variable (identified by its binding cell): dv+i1 = dv+i2 iff i1 = i2 *)
                  if b1 = b2
                    then (i1 = i2)
                    else raise Fail "unimplemented" (* FIXME *)
            (* end case *)
          end

  (* match two differentiation levels where the first is allowed to be less than the
   * second; only the case of two constants is implemented.
   *)
    fun matchDiff (diff1, diff2) = (case (TU.pruneDiff diff1, TU.pruneDiff diff2)
           of (Ty.DiffConst k1, Ty.DiffConst k2) => (k1 <= k2)
            | _ => raise Fail "unimplemented" (* FIXME *)
          (* end case *))

  (* unify two dimensions, binding an unbound dimension variable to the other side.
   * Two occurrences of the same unbound variable are trivially equal and are NOT bound,
   * since binding a variable to itself would create a cyclic binding (this mirrors the
   * Stamp.same check that unifyType performs for type variables).
   *)
    fun equalDim (pl, dim1, dim2) = (case (TU.pruneDim dim1, TU.pruneDim dim2)
           of (Ty.DimConst d1, Ty.DimConst d2) => (d1 = d2)
            | (Ty.DimVar(dv1 as Ty.DV{bind=b1, ...}), other as Ty.DimVar(Ty.DV{bind=b2, ...})) =>
                if b1 = b2
                  then true
                  else (bindDimVar(pl, dv1, other); true)
            | (Ty.DimVar dv, other) => (bindDimVar(pl, dv, other); true)
            | (other, Ty.DimVar dv) => (bindDimVar(pl, dv, other); true)
          (* end case *))

  (* unify two shapes.  ShapeExt(shape, d) is `shape` extended with the trailing
   * dimension d, so it is compared against a Shape by splitting off the Shape's last
   * dimension.  As with dimensions, the same unbound shape variable is never bound to
   * itself.
   *)
    fun equalShape (pl, shape1, shape2) = (case (TU.pruneShape shape1, TU.pruneShape shape2)
           of (Ty.Shape dd1, Ty.Shape dd2) =>
                ListPair.allEq (fn (d1, d2) => equalDim(pl, d1, d2)) (dd1, dd2)
            | (Ty.Shape dd, Ty.ShapeExt(shape, d2)) => (case List.rev dd
                 of [] => false
                  | dLast :: revPrefix =>
                      equalDim(pl, dLast, d2)
                      andalso equalShape(pl, Ty.Shape(List.rev revPrefix), shape)
                (* end case *))
            | (Ty.ShapeVar(sv1 as Ty.SV{bind=b1, ...}), other as Ty.ShapeVar(Ty.SV{bind=b2, ...})) =>
                if b1 = b2
                  then true
                  else (bindShapeVar(pl, sv1, other); true)
            | (Ty.ShapeVar sv, shape) => (bindShapeVar (pl, sv, shape); true)
            | (Ty.ShapeExt(shape1, d1), Ty.ShapeExt(shape2, d2)) =>
                equalDim(pl, d1, d2) andalso equalShape(pl, shape1, shape2)
          (* every remaining combination is the mirror image of a case above *)
            | (shape1, shape2) => equalShape(pl, shape2, shape1)
          (* end case *))

  (* unify ty1 and ty2, recording every meta-variable binding on the patch list pl.
   * May raise Fail "unimplemented" via equalDiff.
   *)
(* QUESTION: do we need an occurs check? *)
    fun unifyType (pl, ty1, ty2) = let
          fun matchVar (tv1 as Ty.TV{id=id1, bind=b1}, tv2 as Ty.TV{id=id2, bind=b2}) =
                if Stamp.same(id1, id2)
                  then true
                  else (case (!b1, !b2)
                     of (SOME ty1, SOME ty2) => match(ty1, ty2)
                      | (SOME ty1, NONE) => (bindTyVar (pl, tv2, ty1); true)
                      | (NONE, SOME ty2) => (bindTyVar (pl, tv1, ty2); true)
                      | (NONE, NONE) => (bindTyVar (pl, tv1, Ty.T_Var tv2); true)
                    (* end case *))
          and matchVarTy (tv as Ty.TV{bind, ...}, ty) = (case !bind
                 of NONE => (bindTyVar(pl, tv, ty); true)
                  | SOME ty' => match(ty', ty)
                (* end case *))
          and match (Ty.T_Var tv1, Ty.T_Var tv2) = matchVar(tv1, tv2)
            | match (Ty.T_Var tv1, ty2) = matchVarTy(tv1, ty2)
            | match (ty1, Ty.T_Var tv2) = matchVarTy(tv2, ty1)
            | match (Ty.T_Bool, Ty.T_Bool) = true
            | match (Ty.T_Int, Ty.T_Int) = true
            | match (Ty.T_String, Ty.T_String) = true
            | match (Ty.T_Sequence(ty1, d1), Ty.T_Sequence(ty2, d2)) =
                equalDim(pl, d1, d2) andalso match(ty1, ty2)
            | match (Ty.T_DynSequence ty1, Ty.T_DynSequence ty2) = match(ty1, ty2)
            | match (Ty.T_Kernel k1, Ty.T_Kernel k2) = equalDiff (pl, k1, k2)
            | match (Ty.T_Tensor s1, Ty.T_Tensor s2) = equalShape (pl, s1, s2)
            | match (Ty.T_Image{dim=d1, shape=s1}, Ty.T_Image{dim=d2, shape=s2}) =
                equalDim (pl, d1, d2) andalso equalShape(pl, s1, s2)
            | match (Ty.T_Field{diff=k1, dim=d1, shape=s1}, Ty.T_Field{diff=k2, dim=d2, shape=s2}) =
                equalDiff (pl, k1, k2) andalso equalDim (pl, d1, d2) andalso equalShape(pl, s1, s2)
            | match (Ty.T_Fun(tys11, ty12), Ty.T_Fun(tys21, ty22)) =
                ListPair.allEq match (tys11, tys21) andalso match (ty12, ty22)
            | match _ = false
          in
            match (TU.pruneHead ty1, TU.pruneHead ty2)
          end

    fun equalTypes (tys1, tys2) = let
          val pl = ref[]
          in
            ListPair.allEq (fn (ty1, ty2) => unifyType(pl, ty1, ty2)) (tys1, tys2)
          end

    fun equalType (ty1, ty2) = unifyType (ref[], ty1, ty2)

  (* decide whether ty2 matches ty1 exactly (EQ), can be coerced to ty1 (COERCE), or
   * neither (FAIL).  The coercions recognized here are: int to a scalar tensor; a
   * fixed-size sequence to a dynamic sequence with an equal element type; and a field to
   * a field of the same dimension and shape whose differentiation level is no greater.
   * The field case may raise Fail "unimplemented" (via matchDiff) when the levels are
   * not both constants.
   *)
(* TODO: add removal of differentiation *)
    fun matchType (ty1, ty2) = (case (TU.pruneHead ty1, TU.pruneHead ty2)
           of (Ty.T_Tensor(Ty.Shape[]), Ty.T_Int) => COERCE
            | (Ty.T_DynSequence elemTy1, Ty.T_Sequence(elemTy2, _)) =>
                if equalType(elemTy1, elemTy2) then COERCE else FAIL
            | (Ty.T_Field{diff=k1, dim=d1, shape=s1}, Ty.T_Field{diff=k2, dim=d2, shape=s2}) =>
              (* ty1/ty2 here are the original (unpruned) arguments *)
                if equalType (ty1, ty2)
                  then EQ
                else if matchDiff (k1, k2) andalso equalDim(ref[], d1, d2)
                andalso equalShape(ref[], s1, s2)
                  then COERCE
                  else FAIL
            | (ty1, ty2) => if equalType(ty1, ty2) then EQ else FAIL
          (* end case *))

  (* try to match types; if we fail, all meta-variable bindings are undone *)
    fun tryMatchType (ty1, ty2) = let
          val pl = ref[]
          in
            unifyType(pl, ty1, ty2) orelse (undo pl; false)
          end

  (* try to match types; if we fail, all meta-variable bindings are undone *)
    fun tryMatchTypes (tys1, tys2) = let
          val pl = ref[]
          in
            ListPair.allEq (fn (ty1, ty2) => unifyType(pl, ty1, ty2)) (tys1, tys2)
            orelse (undo pl; false)
          end

  (* rebind equalDim without patch-list argument *)
    val equalDim = fn (d1, d2) => equalDim(ref [], d1, d2)

(* QUESTION: perhaps this function belongs in the TypeUtil module? *)
  (* instantiate a type scheme, returning the argument meta variables and the resulting type.
   * Note that we assume that the scheme is closed: any meta variable that is not one of
   * the scheme's own variables causes Fail "impossible".
   *)
    fun instantiate ([], ty) = ([], ty)
      | instantiate (mvs, ty) = let
          fun instantiateVar (mv, (mvs, env)) = let
                val mv' = MV.copy mv
                in
                  (mv'::mvs, MV.Map.insert(env, mv, mv'))
                end
          val (mvs, env) = List.foldr instantiateVar ([], MV.Map.empty) mvs
          fun iDiff (Ty.DiffVar(k, i)) = (case MV.Map.find(env, Ty.DIFF k)
                 of SOME(Ty.DIFF k) => Ty.DiffVar(k, i)
                  | _ => raise Fail "impossible"
                (* end case *))
            | iDiff diff = diff
          fun iDim (Ty.DimVar dv) = (case MV.Map.find(env, Ty.DIM dv)
                 of SOME(Ty.DIM dv) => Ty.DimVar dv
                  | _ => raise Fail "impossible"
                (* end case *))
            | iDim dim = dim
          fun iShape (Ty.ShapeVar sv) = (case MV.Map.find(env, Ty.SHAPE sv)
                 of SOME(Ty.SHAPE sv) => Ty.ShapeVar sv
                  | _ => raise Fail "impossible"
                (* end case *))
            | iShape (Ty.ShapeExt(shape, dim)) = Ty.ShapeExt(iShape shape, iDim dim)
            | iShape (Ty.Shape dims) = Ty.Shape(List.map iDim dims)
          fun ity (Ty.T_Var tv) = (case MV.Map.find(env, Ty.TYPE tv)
                 of SOME(Ty.TYPE tv) => Ty.T_Var tv
                  | _ => raise Fail "impossible"
                (* end case *))
            | ity Ty.T_Bool = Ty.T_Bool
            | ity Ty.T_Int = Ty.T_Int
            | ity Ty.T_String = Ty.T_String
            | ity (Ty.T_Sequence(ty, d)) = Ty.T_Sequence(ity ty, iDim d)
            | ity (Ty.T_DynSequence ty) = Ty.T_DynSequence(ity ty)
            | ity (Ty.T_Kernel k) = Ty.T_Kernel(iDiff k)
            | ity (Ty.T_Tensor shape) = Ty.T_Tensor(iShape shape)
            | ity (Ty.T_Image{dim, shape}) = Ty.T_Image{dim=iDim dim, shape=iShape shape}
            | ity (Ty.T_Field{diff, dim, shape}) =
                Ty.T_Field{diff=iDiff diff, dim=iDim dim, shape=iShape shape}
            | ity (Ty.T_Fun(dom, rng)) = Ty.T_Fun(List.map ity dom, ity rng)
          in
            (mvs, ity ty)
          end

  end

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0