Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] View of /trunk/src/compiler/ast/type-util.sml
ViewVC logotype

View of /trunk/src/compiler/ast/type-util.sml

Parent Directory Parent Directory | Revision Log Revision Log


Revision 241 - (download) (annotate)
Fri Aug 6 14:07:20 2010 UTC (9 years ago) by jhr
File size: 6024 byte(s)
  Bug fix: add missing case to pruneShape
(* type-util.sml
 *
 * COPYRIGHT (c) 2010 The Diderot Project (http://diderot.cs.uchicago.edu)
 * All rights reserved.
 *)

structure TypeUtil : sig

  (* returns true if the type is a value type (bool, int, string, or tensor) *)
    val isValueType : Types.ty -> bool

  (* prune out instantiated meta variables *)
    val prune : Types.ty -> Types.ty
    val pruneDiff : Types.diff -> Types.diff
    val pruneShape : Types.shape -> Types.shape
    val pruneDim : Types.dim -> Types.dim

  (* prune the head of a type *)
    val pruneHead : Types.ty -> Types.ty

  (* resolve meta variables to their instantiations (or else variable) *)
    val resolve : Types.ty_var -> Types.ty
    val resolveDiff : Types.diff_var -> Types.diff
    val resolveShape : Types.shape_var -> Types.shape
    val resolveDim : Types.dim_var -> Types.dim

  (* string representations of types, etc *)
    val toString : Types.ty -> string
    val diffToString : Types.diff -> string
    val shapeToString : Types.shape -> string
    val dimToString : Types.dim -> string

  end = struct

    structure Ty = Types
    structure MV = MetaVar

  (* prune out instantiated meta variables from a type.  Bound type variables
   * are replaced by (the pruned form of) their binding; the differentiability,
   * dimension, and shape components of kernel, tensor, image, and field types
   * are pruned, and function types are pruned in both argument and result
   * positions.  Unbound variables and other types are returned unchanged.
   *)
    fun prune ty = (case ty
	   of (ty as Ty.T_Var(Ty.TV{bind, ...})) => (case !bind
		 of NONE => ty
		  | SOME ty => prune ty
		(* end case *))
	    | (Ty.T_Kernel diff) => Ty.T_Kernel(pruneDiff diff)
	    | (Ty.T_Tensor shape) => Ty.T_Tensor(pruneShape shape)
	    | (Ty.T_Image{dim, shape}) => Ty.T_Image{
		  dim = pruneDim dim,
		  shape = pruneShape shape
		}
	    | (Ty.T_Field{diff, dim, shape}) => Ty.T_Field{
		  diff = pruneDiff diff,
		  dim = pruneDim dim,
		  shape = pruneShape shape
		}
	    | (Ty.T_Fun(tys1, ty2)) => Ty.T_Fun(List.map prune tys1, prune ty2)
	    | ty => ty
	  (* end case *))

  (* prune a differentiability term.  A bound diff variable with offset i is
   * replaced by its pruned binding with i added to that binding's offset
   * (or constant); anything else is returned unchanged.
   *)
    and pruneDiff (Ty.DiffVar(Ty.DfV{bind=ref(SOME diff), ...}, i)) = (
	  case pruneDiff diff
	   of Ty.DiffVar(dv, i') => Ty.DiffVar(dv, i+i')
	    | Ty.DiffConst i' => Ty.DiffConst(i+i')
	  (* end case *))
      | pruneDiff diff = diff

  (* prune a dimension by following a chain of bound dimension variables *)
    and pruneDim dim = (case dim
	   of Ty.DimVar(Ty.DV{bind=ref(SOME dim), ...}) => pruneDim dim
	    | dim => dim
	  (* end case *))

  (* prune a shape: each dimension of a concrete shape is pruned, bound shape
   * variables are followed, and extended shapes are rebuilt with Ty.shapeExt
   * (which presumably normalizes an extension of a concrete shape -- see
   * types.sml).  Unbound shape variables are returned unchanged.
   *)
    and pruneShape shape = (case shape
	   of Ty.Shape dd => Ty.Shape(List.map pruneDim dd)
	    | Ty.ShapeVar(Ty.SV{bind=ref(SOME shape), ...}) => pruneShape shape
	    | Ty.ShapeExt(shape, dim) => Ty.shapeExt(pruneShape shape, pruneDim dim)
	    | _ => shape
	  (* end case *))

  (* resolve meta variables to their instantiations (or else variable) *)
    fun resolve (tv as Ty.TV{bind, ...}) = (case !bind
	   of NONE => Ty.T_Var tv
	    | SOME ty => prune ty
	  (* end case *))

  (* an unbound diff variable resolves to itself with a zero offset *)
    fun resolveDiff (dv as Ty.DfV{bind, ...}) = (case !bind
	   of NONE => Ty.DiffVar(dv, 0)
	    | SOME diff => pruneDiff diff
	  (* end case *))

    fun resolveShape (sv as Ty.SV{bind, ...}) = (case !bind
	   of NONE => Ty.ShapeVar sv
	    | SOME shape => pruneShape shape
	  (* end case *))

    fun resolveDim (dv as Ty.DV{bind, ...}) = (case !bind
	   of NONE => Ty.DimVar dv
	    | SOME dim => pruneDim dim
	  (* end case *))

  (* prune the head of a type.  This behaves like prune except that the
   * argument and result types of a function type are NOT pruned; the
   * function type is returned as-is.
   *)
    fun pruneHead ty = let
	  fun prune' (ty as Ty.T_Var(Ty.TV{bind, ...})) = (case !bind
		 of NONE => ty
		  | SOME ty => prune' ty
		(* end case *))
	    | prune' (Ty.T_Kernel diff) = Ty.T_Kernel(pruneDiff diff)
	    | prune' (Ty.T_Tensor shape) = Ty.T_Tensor(pruneShape shape)
	    | prune' (Ty.T_Image{dim, shape}) = Ty.T_Image{
		  dim = pruneDim dim,
		  shape = pruneShape shape
		}
	    | prune' (Ty.T_Field{diff, dim, shape}) = Ty.T_Field{
		  diff = pruneDiff diff,
		  dim = pruneDim dim,
		  shape = pruneShape shape
		}
	    | prune' ty = ty
	  in
	    prune' ty
	  end

  (* format a list of items, separated by sep *)
    fun listToString fmt sep items = String.concatWith sep (List.map fmt items)

  (* a constant prints as a number; a variable with a non-zero offset prints
   * as "(v+i)" or "(v-i)"
   *)
    fun diffToString diff = (case pruneDiff diff
	   of Ty.DiffConst n => Int.toString n
	    | Ty.DiffVar(dv, 0) => MV.diffVarToString dv
	    | Ty.DiffVar(dv, i) => if i < 0
		then String.concat["(", MV.diffVarToString dv, "-", Int.toString(~i), ")"]
		else String.concat["(", MV.diffVarToString dv, "+", Int.toString i, ")"]
	  (* end case *))

  (* every form of shape is rendered in brackets: "[d1,...,dn]" for a concrete
   * shape, "[sv]" for a shape variable, and "[sv;d1,...,dn]" (or
   * "[d1,...,dn]") for an extended shape, so that the output of toString is
   * uniform regardless of how far inference has progressed.
   *)
    fun shapeToString shape = (case pruneShape shape
	   of Ty.Shape dims => concat["[", listToString dimToString "," dims, "]"]
	    | Ty.ShapeVar sv => concat["[", MV.shapeVarToString sv, "]"]
	    | Ty.ShapeExt(shape, d) => let
	      (* render the prefix of an extended shape, including the trailing
	       * separator; a shape variable is separated from the dimensions
	       * that follow it by ";", and an empty concrete prefix contributes
	       * nothing (avoiding a stray leading ",").
	       *)
		fun toS (Ty.Shape []) = ""
		  | toS (Ty.Shape dims) = (listToString dimToString "," dims) ^ ","
		  | toS (Ty.ShapeVar sv) = MV.shapeVarToString sv ^ ";"
		  | toS (Ty.ShapeExt(shape, d)) = concat[toS shape, dimToString d, ","]
		in
		  concat["[", toS shape, dimToString d, "]"]
		end
	  (* end case *))

    and dimToString dim = (case pruneDim dim
	   of Ty.DimConst n => Int.toString n
	    | Ty.DimVar v => MV.dimVarToString v
	  (* end case *))

  (* returns true if the type is a value type (bool, int, string, or tensor) *)
    fun isValueType ty = (case prune ty
	   of Ty.T_Bool => true
	    | Ty.T_Int => true
	    | Ty.T_String => true
	    | Ty.T_Tensor _ => true
	    | _ => false
	  (* end case *))

  (* render a type; order-0 tensors print as "real" and 1-D tensors of
   * dimension 2, 3, or 4 print as "vec2", "vec3", "vec4".  Function
   * arguments are written as a "*"-separated tuple when there is more
   * than one, and as "()" when there are none.
   *)
    fun toString ty = (case pruneHead ty
	   of Ty.T_Var tv => MV.tyVarToString tv
	    | Ty.T_Bool => "bool"
	    | Ty.T_Int => "int"
	    | Ty.T_String => "string"
	    | Ty.T_Kernel n => "kernel#" ^ diffToString n
	    | Ty.T_Tensor(Ty.Shape[]) => "real"
	    | Ty.T_Tensor(Ty.Shape[Ty.DimConst 2]) => "vec2"
	    | Ty.T_Tensor(Ty.Shape[Ty.DimConst 3]) => "vec3"
	    | Ty.T_Tensor(Ty.Shape[Ty.DimConst 4]) => "vec4"
	    | Ty.T_Tensor shape => "tensor" ^ shapeToString shape
	    | Ty.T_Image{dim, shape} => concat[
		  "image(", dimToString dim, ")", shapeToString shape
		]
	    | Ty.T_Field{diff, dim, shape} => concat[
		  "field#", diffToString diff, "(", dimToString dim,
		  ")", shapeToString shape
		]
	    | Ty.T_Fun(tys1, ty2) => let
		fun tysToString [] = "()"
		  | tysToString [ty] = toString ty
		  | tysToString tys = String.concat[
			"(", listToString toString " * " tys, ")"
		      ]
		in
		  String.concat[tysToString tys1, " -> ", toString ty2]
		end
	  (* end case *))

  end

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0