Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] View of /trunk/src/typechecker/util.sml
ViewVC logotype

View of /trunk/src/typechecker/util.sml

Parent Directory Parent Directory | Revision Log Revision Log


Revision 82 - (download) (annotate)
Wed May 26 18:20:49 2010 UTC (9 years, 1 month ago) by jhr
File size: 6286 byte(s)
  Working on typechecker
(* util.sml
 *
 * COPYRIGHT (c) 2010 The Diderot Project (http://diderot.cs.uchicago.edu)
 * All rights reserved.
 *
 * Utilities for typechecking
 *)

structure Util =
  struct

    structure Ty = Types
    structure MV = MetaVar

  (* prune out instantiated meta variables from a type.  Bound type variables are
   * replaced by (the pruned form of) their instance.  The differentiability,
   * dimension, and shape components of kernel, tensor, image, and field types are
   * pruned, and so are the argument and result types of function types.  Unbound
   * meta variables are left in place.
   *)
    fun prune ty = let
          fun prune' (ty as Ty.T_Var(Ty.TV{bind, ...})) = (case !bind
                 of NONE => ty
                  | SOME ty => prune' ty
                (* end case *))
            | prune' (Ty.T_Kernel diff) = Ty.T_Kernel(pruneDiff diff)
            | prune' (Ty.T_Tensor shape) = Ty.T_Tensor(pruneShape shape)
            | prune' (Ty.T_Image{dim, shape}) = Ty.T_Image{
                  dim = pruneDim dim,
                  shape = pruneShape shape
                }
            | prune' (Ty.T_Field{diff, dim, shape}) = Ty.T_Field{
                  diff = pruneDiff diff,
                  dim = pruneDim dim,
                  shape = pruneShape shape
                }
            | prune' (Ty.T_Fun(tys1, ty2)) = Ty.T_Fun(List.map prune' tys1, prune' ty2)
            | prune' ty = ty
          in
            prune' ty
          end

  (* prune a differentiability level.  DiffVar(dv, i) stands for "the value of dv
   * plus the offset i", so when dv is bound the offsets are accumulated.
   *)
    and pruneDiff (Ty.DiffVar(Ty.DfV{bind=ref(SOME diff), ...}, i)) = (
          case pruneDiff diff
           of Ty.DiffVar(dv, i') => Ty.DiffVar(dv, i+i')
            | Ty.DiffConst i' => Ty.DiffConst(i+i')
          (* end case *))
      | pruneDiff diff = diff

  (* prune a dimension by following the chain of bound dimension variables *)
    and pruneDim dim = (case dim
           of Ty.DimVar(Ty.DV{bind=ref(SOME dim), ...}) => pruneDim dim
            | dim => dim
          (* end case *))

  (* prune a shape.  Bound shape variables are followed, and the dimensions of both
   * extended shapes and explicit shape lists are pruned.  Extended shapes are rebuilt
   * with Ty.shapeExt (presumably a normalizing constructor in Types; verify there).
   *)
    and pruneShape shape = (case shape
           of Ty.ShapeVar(Ty.SV{bind=ref(SOME shape), ...}) => pruneShape shape
            | Ty.ShapeExt(shape, dim) => Ty.shapeExt(pruneShape shape, pruneDim dim)
            | Ty.Shape dd => Ty.Shape(List.map pruneDim dd)
            | _ => shape
          (* end case *))

  (* match two differentiability levels, binding meta variables as needed; returns
   * true on success.  Because DiffVar(dv, i) denotes dv + i, matching it against the
   * constant k binds dv to k - i, which is rejected if negative.
   *)
(* FIXME: what about the bounds? *)
    fun matchDiff (diff1, diff2) = let
        (* solve dv + i = k for dv *)
          fun bindToConst (Ty.DfV{bind, ...}, i, k) = let
                val k' = k - i
                in
                  if k' < 0 then false
                  else (bind := SOME(Ty.DiffConst k'); true)
                end
          in
            case (pruneDiff diff1, pruneDiff diff2)
             of (Ty.DiffConst k1, Ty.DiffConst k2) => (k1 = k2)
              | (Ty.DiffConst k, Ty.DiffVar(dv, i)) => bindToConst (dv, i, k)
              | (Ty.DiffVar(dv, i), Ty.DiffConst k) => bindToConst (dv, i, k)
              | (Ty.DiffVar(dv1, i1), Ty.DiffVar(dv2, i2)) => raise Fail "unimplemented" (* FIXME *)
            (* end case *)
          end

  (* match two dimensions, binding a dimension variable to the other side if needed.
   * A variable matched against itself is left unbound; binding it to itself would
   * create a cycle that makes pruneDim loop forever.
   *)
    fun matchDim (dim1, dim2) = (case (pruneDim dim1, pruneDim dim2)
           of (Ty.DimConst d1, Ty.DimConst d2) => (d1 = d2)
            | (Ty.DimVar(Ty.DV{bind=b1, ...}), d2 as Ty.DimVar(Ty.DV{bind=b2, ...})) => (
                (* refs compare by identity, so b1 = b2 means "same variable" *)
                if b1 = b2 then () else b1 := SOME d2;
                true)
            | (Ty.DimVar(Ty.DV{bind, ...}), d2) => (bind := SOME d2; true)
            | (d1, Ty.DimVar(Ty.DV{bind, ...})) => (bind := SOME d1; true)
          (* end case *))

  (* match two shapes, binding shape and dimension variables as needed.  An extended
   * shape ShapeExt(s, d) matches an explicit shape whose last dimension matches d and
   * whose remaining prefix matches s.  Pairs not handled directly are retried with the
   * arguments swapped.
   * QUESTION: like matchType, there is no occurs check when binding a shape variable.
   *)
    fun matchShape (shape1, shape2) = (case (pruneShape shape1, pruneShape shape2)
           of (Ty.Shape dd1, Ty.Shape dd2) => ListPair.allEq matchDim (dd1, dd2)
            | (Ty.Shape dd, Ty.ShapeExt(shape, d2)) => let
              (* walk to the last dimension, accumulating the reversed prefix *)
                fun chk ([], _) = false
                  | chk ([d], revDD) =
                      matchDim(d, d2) andalso matchShape(Ty.Shape(List.rev revDD), shape)
                  | chk (d::dd, revDD) = chk(dd, d::revDD)
                in
                  chk (dd, [])
                end
            | (Ty.ShapeVar(Ty.SV{bind=b1, ...}), s2 as Ty.ShapeVar(Ty.SV{bind=b2, ...})) => (
                (* avoid binding a shape variable to itself *)
                if b1 = b2 then () else b1 := SOME s2;
                true)
            | (Ty.ShapeVar(Ty.SV{bind, ...}), shape) => (bind := SOME shape; true)
            | (Ty.ShapeExt(s1, d1), Ty.ShapeExt(s2, d2)) =>
                matchDim(d1, d2) andalso matchShape(s1, s2)
            | (s1, s2) => matchShape(s2, s1)
          (* end case *))

  (* match (unify) two types, binding meta variables as needed; returns true on success.
   * Bindings made before a mismatch is discovered are not undone.
   *)
(* QUESTION: do we need an occurs check? *)
    fun matchType (ty1, ty2) = let
          fun setBind (Ty.TV{bind=ref(SOME _), ...}, _) = raise Fail "prune fail"
            | setBind (Ty.TV{bind, ...}, ty) = bind := SOME ty
          fun matchVar (tv1 as Ty.TV{id=id1, ...}, tv2 as Ty.TV{id=id2, ...}) =
                if Stamp.same(id1, id2)
                  then ()
                  else setBind (tv1, Ty.T_Var tv2)
        (* the arguments of match are assumed to be pruned *)
          fun match (Ty.T_Var tv1, Ty.T_Var tv2) = (matchVar(tv1, tv2); true)
            | match (Ty.T_Var tv1, t2) = (setBind(tv1, t2); true)
            | match (t1, Ty.T_Var tv2) = (setBind(tv2, t1); true)
            | match (Ty.T_Bool, Ty.T_Bool) = true
            | match (Ty.T_Int, Ty.T_Int) = true
            | match (Ty.T_String, Ty.T_String) = true
            | match (Ty.T_Kernel k1, Ty.T_Kernel k2) = matchDiff (k1, k2)
            | match (Ty.T_Tensor s1, Ty.T_Tensor s2) = matchShape (s1, s2)
            | match (Ty.T_Image{dim=d1, shape=s1}, Ty.T_Image{dim=d2, shape=s2}) =
                matchDim (d1, d2) andalso matchShape(s1, s2)
            | match (Ty.T_Field{diff=k1, dim=d1, shape=s1}, Ty.T_Field{diff=k2, dim=d2, shape=s2}) =
                matchDiff (k1, k2) andalso matchDim (d1, d2) andalso matchShape(s1, s2)
            | match (Ty.T_Fun(tys11, ty12), Ty.T_Fun(tys21, ty22)) =
              (* use matchType (not match) on the results so that they are re-pruned:
               * matching the argument types may have bound variables they contain
               *)
                matchTypes (tys11, tys21) andalso matchType (ty12, ty22)
            | match _ = false
          in
            match (prune ty1, prune ty2)
          end

  (* match two lists of types pairwise; false if the lengths differ *)
    and matchTypes (tys1, tys2) = ListPair.allEq matchType (tys1, tys2)

  (* instantiate a type scheme, returning the argument meta variables and the resulting type.
   * Each meta variable in mvs is replaced by a fresh copy (MV.copy).  The returned list
   * holds the copies in the same order as mvs.
   * Note that we assume that the scheme is closed: a meta variable in ty that is not in
   * mvs causes Fail "impossible".
   *)
    fun instantiate ([], ty) = ([], ty)
      | instantiate (mvs, ty) = let
          fun instantiateVar (mv, (mvs, env)) = let
                val mv' = MV.copy mv
                in
                  (mv'::mvs, MV.Map.insert(env, mv, mv'))
                end
        (* foldr (rather than foldl) keeps the fresh variables in the order of mvs *)
          val (mvs, env) = List.foldr instantiateVar ([], MV.Map.empty) mvs
          fun iDiff (Ty.DiffVar(k, i)) = (case MV.Map.find(env, Ty.DIFF k)
                 of SOME(Ty.DIFF k) => Ty.DiffVar(k, i)
                  | _ => raise Fail "impossible"
                (* end case *))
            | iDiff diff = diff
          fun iDim (Ty.DimVar dv) = (case MV.Map.find(env, Ty.DIM dv)
                 of SOME(Ty.DIM dv) => Ty.DimVar dv
                  | _ => raise Fail "impossible"
                (* end case *))
            | iDim dim = dim
          fun iShape (Ty.ShapeVar sv) = (case MV.Map.find(env, Ty.SHAPE sv)
                 of SOME(Ty.SHAPE sv) => Ty.ShapeVar sv
                  | _ => raise Fail "impossible"
                (* end case *))
            | iShape (Ty.ShapeExt(shape, dim)) = Ty.ShapeExt(iShape shape, iDim dim)
            | iShape (Ty.Shape dims) = Ty.Shape(List.map iDim dims)
          fun ity (Ty.T_Var tv) = (case MV.Map.find(env, Ty.TYPE tv)
                 of SOME(Ty.TYPE tv) => Ty.T_Var tv
                  | _ => raise Fail "impossible"
                (* end case *))
            | ity (Ty.T_Kernel k) = Ty.T_Kernel(iDiff k)
            | ity (Ty.T_Tensor shape) = Ty.T_Tensor(iShape shape)
            | ity (Ty.T_Image{dim, shape}) = Ty.T_Image{dim=iDim dim, shape=iShape shape}
            | ity (Ty.T_Field{diff, dim, shape}) =
                Ty.T_Field{diff=iDiff diff, dim=iDim dim, shape=iShape shape}
            | ity (Ty.T_Fun(dom, rng)) = Ty.T_Fun(List.map ity dom, ity rng)
            | ity ty = ty
          in
            (mvs, ity ty)
          end

  end

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0