Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] View of /trunk/src/compiler/ast/type-util.sml
ViewVC logotype

View of /trunk/src/compiler/ast/type-util.sml

Parent Directory Parent Directory | Revision Log Revision Log


Revision 2356 - (download) (annotate)
Sun Apr 7 14:45:25 2013 UTC (6 years, 6 months ago) by jhr
File size: 9267 byte(s)
  Merging in bug fixes and language enhancements from the vis12 branch (via staging).
  Features include type promotion, the curl and colon operator, transpose, and functions.
(* type-util.sml
 *
 * COPYRIGHT (c) 2010 The Diderot Project (http://diderot-language.cs.uchicago.edu)
 * All rights reserved.
 *)

structure TypeUtil : sig

  (* constructor for building a tensor type of known order, but unknown
   * dimensions.
   *)
    val mkTensorTy : int -> Types.ty

  (* constructor for building a sequence type of unknown size *)
    val mkSequenceTy : Types.ty -> Types.ty

  (* function to compute the slice of a tensor type based on a boolean
   * mask.  The value true in the mask means that the corresponding
   * dimension is being indexed, while false means that it is being
   * copied.  Raises Fail if the type is not a tensor type with an
   * explicit shape.
   *)
    val slice : Types.ty * bool list -> Types.ty

  (* returns true if the type is a value type (bool, int, string, sequence,
   * or tensor)
   *)
    val isValueType : Types.ty -> bool

  (* return the range (return type) of a function type.  Instantiated meta
   * variables at the head of the type are followed first; raises Fail if
   * the type is not a function type.
   *)
    val rngOf : Types.ty -> Types.ty

  (* prune out instantiated meta variables from a type.  We also normalize
   * tensor shapes (i.e., remove 1s).
   *)
    val prune : Types.ty -> Types.ty
    val pruneDiff : Types.diff -> Types.diff
    val pruneShape : Types.shape -> Types.shape
    val pruneDim : Types.dim -> Types.dim

  (* prune the head of a type *)
    val pruneHead : Types.ty -> Types.ty

  (* resolve meta variables to their instantiations (or else variable) *)
    val resolve : Types.ty_var -> Types.ty
    val resolveDiff : Types.diff_var -> Types.diff
    val resolveShape : Types.shape_var -> Types.shape
    val resolveDim : Types.dim_var -> Types.dim
    val resolveVar : Types.meta_var -> Types.var_bind

  (* equality testing; both dimensions are pruned before they are compared *)
    val sameDim : Types.dim * Types.dim -> bool

  (* string representations of types, etc *)
    val toString : Types.ty -> string
    val diffToString : Types.diff -> string
    val shapeToString : Types.shape -> string
    val dimToString : Types.dim -> string

  (* convert to fully resolved monomorphic forms *)
    val monoDim : Types.dim -> int
    val monoShape : Types.shape -> int list
    val monoDiff : Types.diff -> int

  end = struct

    structure Ty = Types
    structure MV = MetaVar

  (* constructor for building a tensor type of known order, but unknown
   * dimensions.  Each dimension is a fresh dimension meta variable.
   *)
    fun mkTensorTy order =
	  Ty.T_Tensor(
	    Ty.Shape(List.tabulate(order, fn _ => Ty.DimVar(MetaVar.newDimVar()))))

  (* constructor for building a sequence type whose element type is ty and
   * whose size is a fresh dimension meta variable.
   *)
    fun mkSequenceTy ty =
          Ty.T_Sequence(ty, Ty.DimVar(MetaVar.newDimVar()))

  (* prune out instantiated meta variables from a type.  We also normalize
   * tensor dimensions (i.e., remove 1s).  Pruning is applied recursively to
   * all component types, dimensions, differentiabilities, and shapes.
   *)
    fun prune ty = (case ty
	   of (ty as Ty.T_Var(Ty.TV{bind, ...})) => (case !bind
		 of NONE => ty
		  | SOME ty => prune ty
		(* end case *))
	    | Ty.T_Sequence(ty, dim) => Ty.T_Sequence(prune ty, pruneDim dim)
	    | (Ty.T_Kernel diff) => Ty.T_Kernel(pruneDiff diff)
	    | (Ty.T_Tensor shape) => Ty.T_Tensor(pruneShape shape)
	    | (Ty.T_Image{dim, shape}) => Ty.T_Image{
		  dim = pruneDim dim,
		  shape = pruneShape shape
		}
	    | (Ty.T_Field{diff, dim, shape}) => Ty.T_Field{
		  diff = pruneDiff diff,
		  dim = pruneDim dim,
		  shape = pruneShape shape
		}
	    | (Ty.T_Fun(tys1, ty2)) => Ty.T_Fun(List.map prune tys1, prune ty2)
	    | ty => ty
	  (* end case *))

  (* prune a differentiability: a bound diff variable with offset i is
   * replaced by its (pruned) binding, with i added to the binding's offset.
   *)
    and pruneDiff (Ty.DiffVar(Ty.DfV{bind=ref(SOME diff), ...}, i)) = (
	  case pruneDiff diff
	   of Ty.DiffVar(dv, i') => Ty.DiffVar(dv, i+i')
	    | Ty.DiffConst i' => Ty.DiffConst(i+i')
	  (* end case *))
      | pruneDiff diff = diff

  (* prune a dimension by following bound dimension variables *)
    and pruneDim dim = (case dim
	   of Ty.DimVar(Ty.DV{bind=ref(SOME dim), ...}) => pruneDim dim
	    | dim => dim
	  (* end case *))

  (* prune a dimension, returning NONE when it is the constant 1 so that it
   * can be dropped from a shape.
   *)
    and filterDim dim = (case pruneDim dim
	   of Ty.DimConst 1 => NONE
	    | dim => SOME dim
	  (* end case *))

  (* prune a shape: follow bound shape variables and drop dimensions that are
   * the constant 1.  Extended shapes are rebuilt with Ty.shapeExt (defined in
   * Types; presumably a smart constructor that flattens explicit shapes).
   *)
    and pruneShape shape = (case shape
	   of Ty.Shape dd => Ty.Shape(List.mapPartial filterDim dd)
	    | Ty.ShapeVar(Ty.SV{bind=ref(SOME shape), ...}) => pruneShape shape
	    | Ty.ShapeExt(shape, dim) => (case filterDim dim
		 of SOME dim => Ty.shapeExt(pruneShape shape, dim)
		  | NONE => pruneShape shape
		(* end case *))
	    | _ => shape
	  (* end case *))

  (* resolve meta variables to their instantiations (or else variable) *)
    fun resolve (tv as Ty.TV{bind, ...}) = (case !bind
	   of NONE => Ty.T_Var tv
	    | SOME ty => prune ty
	  (* end case *))

    fun resolveDiff (dv as Ty.DfV{bind, ...}) = (case !bind
	   of NONE => Ty.DiffVar(dv, 0)
	    | SOME diff => pruneDiff diff
	  (* end case *))

    fun resolveShape (sv as Ty.SV{bind, ...}) = (case !bind
	   of NONE => Ty.ShapeVar sv
	    | SOME shape => pruneShape shape
	  (* end case *))

    fun resolveDim (dv as Ty.DV{bind, ...}) = (case !bind
	   of NONE => Ty.DimVar dv
	    | SOME dim => pruneDim dim
	  (* end case *))

    fun resolveVar (Ty.TYPE tv) = Ty.TYPE(resolve tv)
      | resolveVar (Ty.DIFF dv) = Ty.DIFF(resolveDiff dv)
      | resolveVar (Ty.SHAPE sv) = Ty.SHAPE(resolveShape sv)
      | resolveVar (Ty.DIM d) = Ty.DIM(resolveDim d)

  (* prune the head of a type: follow bound type variables and prune the
   * dimensions, differentiabilities, and shapes stored directly in the head
   * constructor.  Component types (sequence elements, function argument and
   * result types) are left unpruned.
   *)
    fun pruneHead ty = let
	  fun prune' (ty as Ty.T_Var(Ty.TV{bind, ...})) = (case !bind
		 of NONE => ty
		  | SOME ty => prune' ty
		(* end case *))
	    | prune' (Ty.T_Sequence(ty, dim)) = Ty.T_Sequence(ty, pruneDim dim)
	    | prune' (Ty.T_Kernel diff) = Ty.T_Kernel(pruneDiff diff)
	    | prune' (Ty.T_Tensor shape) = Ty.T_Tensor(pruneShape shape)
	    | prune' (Ty.T_Image{dim, shape}) = Ty.T_Image{
		  dim = pruneDim dim,
		  shape = pruneShape shape
		}
	    | prune' (Ty.T_Field{diff, dim, shape}) = Ty.T_Field{
		  diff = pruneDiff diff,
		  dim = pruneDim dim,
		  shape = pruneShape shape
		}
	    | prune' ty = ty
	  in
	    prune' ty
	  end

  (* returns true if the type is a value type (bool, int, string, sequence,
   * or tensor).  Only the head constructor is inspected, so pruning the
   * head is sufficient.
   *)
    fun isValueType ty = (case pruneHead ty
	   of Ty.T_Bool => true
	    | Ty.T_Int => true
	    | Ty.T_String => true
	    | Ty.T_Sequence _ => true
	    | Ty.T_Tensor _ => true
	    | _ => false
	  (* end case *))

  (* equality testing.  Both dimensions are pruned first, so that a meta
   * variable that has been bound to a constant compares equal to that
   * constant (and to other variables bound to the same constant).
   * Unbound variables are equal only to themselves.
   *)
    fun sameDim (d1, d2) = (case (pruneDim d1, pruneDim d2)
	   of (Ty.DimConst n1, Ty.DimConst n2) => (n1 = n2)
	    | (Ty.DimVar v1, Ty.DimVar v2) => MetaVar.sameDimVar(v1, v2)
	    | _ => false
	  (* end case *))

  (* format each item with fmt and join the results with sep *)
    fun listToString fmt sep items = String.concatWith sep (List.map fmt items)

  (* string representation of a differentiability; a variable with a nonzero
   * offset is shown as "(v+i)" or "(v-i)".
   *)
    fun diffToString diff = (case pruneDiff diff
	   of Ty.DiffConst n => Int.toString n
	    | Ty.DiffVar(dv, 0) => MV.diffVarToString dv
	    | Ty.DiffVar(dv, i) => if i < 0
		then String.concat["(", MV.diffVarToString dv, "-", Int.toString(~i), ")"]
		else String.concat["(", MV.diffVarToString dv, "+", Int.toString i, ")"]
	  (* end case *))

  (* string representation of a shape.  Every form of shape is enclosed in
   * "[" ... "]" so that tensor, image, and field types print uniformly.
   * Explicit dimensions are separated by ","; a shape variable that has been
   * extended with further dimensions is followed by ";".
   *)
    fun shapeToString shape = (case pruneShape shape
	   of Ty.Shape dd => concat["[", listToString dimToString "," dd, "]"]
	    | Ty.ShapeVar sv => concat["[", MV.shapeVarToString sv, "]"]
	    | Ty.ShapeExt(base, d) => let
	      (* render the prefix of an extended shape, including the trailing
	       * separator that precedes the next dimension.
	       *)
		fun toS (Ty.Shape []) = ""
		  | toS (Ty.Shape dd) = listToString dimToString "," dd ^ ","
		  | toS (Ty.ShapeVar sv) = MV.shapeVarToString sv ^ ";"
		  | toS (Ty.ShapeExt(shp, d')) = concat[toS shp, dimToString d', ","]
		in
		  concat["[", toS base, dimToString d, "]"]
		end
	  (* end case *))

  (* string representation of a (pruned) dimension *)
    and dimToString dim = (case pruneDim dim
	   of Ty.DimConst n => Int.toString n
	    | Ty.DimVar v => MV.dimVarToString v
	  (* end case *))

  (* string representation of a type.  Tensors of order 0 and of shape [2],
   * [3], or [4] are printed using the "real"/"vecN" names.
   *)
    fun toString ty = (case pruneHead ty
	   of Ty.T_Var tv => MV.tyVarToString tv
	    | Ty.T_Bool => "bool"
	    | Ty.T_Int => "int"
	    | Ty.T_String => "string"
	    | Ty.T_Sequence(ty, dim) => concat[toString ty, "{", dimToString dim, "}"]
	    | Ty.T_Kernel n => "kernel#" ^ diffToString n
	    | Ty.T_Tensor(Ty.Shape[]) => "real"
	    | Ty.T_Tensor(Ty.Shape[Ty.DimConst 2]) => "vec2"
	    | Ty.T_Tensor(Ty.Shape[Ty.DimConst 3]) => "vec3"
	    | Ty.T_Tensor(Ty.Shape[Ty.DimConst 4]) => "vec4"
	    | Ty.T_Tensor shape => "tensor" ^ shapeToString shape
	    | Ty.T_Image{dim, shape} => concat[
		  "image(", dimToString dim, ")", shapeToString shape
		]
	    | Ty.T_Field{diff, dim, shape} => concat[
		  "field#", diffToString diff, "(", dimToString dim,
		  ")", shapeToString shape
		]
	    | Ty.T_Fun(tys1, ty2) => let
		fun tysToString [] = "()"
		  | tysToString [ty] = toString ty
		  | tysToString tys = String.concat[
			"(", listToString toString " * " tys, ")"
		      ]
		in
		  String.concat[tysToString tys1, " -> ", toString ty2]
		end
	  (* end case *))

  (* return the range (return type) of a function type.  The head of the type
   * is pruned first so that an instantiated type variable bound to a
   * function type is handled; the range itself is returned unpruned.
   *)
    fun rngOf ty = (case pruneHead ty
	   of Ty.T_Fun(_, rng) => rng
	    | ty' => raise Fail(concat["TypeUtil.rngOf(", toString ty', ")"])
	  (* end case *))

  (* slice a tensor type: dimensions whose mask entry is true (indexed) are
   * dropped and those whose entry is false (copied) are kept.  The shape is
   * deliberately not pruned, because dropping 1s would misalign it with the
   * mask.  ListPair.foldr stops at the end of the shorter list.
   *)
    fun slice (Ty.T_Tensor(Ty.Shape l), mask) = let
	  fun f (_, true, dd) = dd
	    | f (d, false, dd) = d::dd
	  in
	    Ty.T_Tensor(Ty.Shape(ListPair.foldr f [] (l, mask)))
	  end
      | slice (ty, _) = raise Fail(concat["slice(", toString ty, ", _)"])

  (* convert to fully resolved monomorphic forms; each raises Fail when the
   * pruned value still contains a meta variable.
   *)
    fun monoDim dim = (case pruneDim dim
	   of Ty.DimConst d => d
	    | dim => raise Fail(concat["dim ", dimToString dim, " is not constant"])
	  (* end case *))

    fun monoShape shp = (case pruneShape shp
	   of Ty.Shape shp => List.map monoDim shp
	    | shp => raise Fail(concat["shape ", shapeToString shp, " is not constant"])
	  (* end case *))

    fun monoDiff diff = (case pruneDiff diff
	   of Ty.DiffConst k => k
	    | diff => raise Fail(concat["diff ", diffToString diff, " is not constant"])
	  (* end case *))

  end

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0