Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] View of /branches/charisee/src/compiler/basis/basis-vars.sml
ViewVC logotype

View of /branches/charisee/src/compiler/basis/basis-vars.sml

Parent Directory Parent Directory | Revision Log Revision Log


Revision 179 - (download) (annotate)
Tue Jul 27 20:43:23 2010 UTC (9 years, 2 months ago) by jhr
Original Path: trunk/src/compiler/basis/basis-vars.sml
File size: 7963 byte(s)
  Simplified AST code now uses default input values when no input is given.
(* basis-vars.sml
 *
 * COPYRIGHT (c) 2010 The Diderot Project (http://diderot.cs.uchicago.edu)
 * All rights reserved.
 *
 * This module defines the AST variables for the built-in operators, functions, kernels,
 * and internal (compiler-generated) variables.
 *)

structure BasisVars =
  struct
    local
      structure N = BasisNames
      structure Ty = Types
      structure MV = MetaVar

    (* infix function-type constructor: [arg types] --> result type *)
      fun --> (tys1, ty) = Ty.T_Fun(tys1, ty)
      infix -->

    (* constant dimension used by fixed-size signatures below *)
      val N3 = Ty.DimConst 3

    (* short names for kinds; each call invokes the corresponding MetaVar constructor
     * and wraps the result in the matching Ty.meta_var kind tag.
     *)
      val TK : unit -> Ty.meta_var = Ty.TYPE o MV.newTyVar
      fun DK () = Ty.DIFF(MV.newDiffVar 0)
      val SK : unit -> Ty.meta_var = Ty.SHAPE o MV.newShapeVar
      val NK : unit -> Ty.meta_var = Ty.DIM o MV.newDimVar

    (* a type scheme with no bound meta variables *)
      fun ty t = ([], t)

    (* all (kinds, mkTy) builds a type scheme: it calls each kind constructor in
     * `kinds` (in order) and passes the resulting meta-variable list to `mkTy`.
     * The list pattern in `mkTy` must therefore follow the order of `kinds`.
     *)
      fun all (kinds, mkTy : Ty.meta_var list -> Ty.ty) = let
            val tvs = List.map (fn mk => mk()) kinds
            in
              (tvs, mkTy tvs)
            end

    (* specialization of `all` for a scheme bound over a single dimension variable;
     * `mkTy` receives the raw dimension variable rather than a list.
     *)
      fun allNK mkTy = let
            val tv = MV.newDimVar()
            in
              ([Ty.DIM tv], mkTy tv)
            end

    (* type constructors for fields and tensors *)
      fun field (k, d, dd) = Ty.T_Field{diff=k, dim=d, shape=dd}
      fun tensor ds = Ty.T_Tensor(Ty.Shape ds)

    (* create basis variables with a monomorphic type or a polymorphic type scheme *)
      fun monoVar (name, ty) = Var.new (name, AST.BasisVar, ty)
      fun polyVar (name, scheme) = Var.newPoly (name, AST.BasisVar, scheme)
    in

(* TODO: I'm not sure how to extend + and - to fields, since the typing rules should allow
 * two fields with different differentiation levels to be added.
 *)

  (* overloaded operators; the naming convention is to use the operator name followed
   * by the argument type signature, where
   *    i  -- int
   *    b  -- bool
   *    s  -- string
   *    r  -- real (tensor[])
   *    t  -- tensor[shape]
   *    f  -- field
   *)

    val add_ii = monoVar(N.op_add, [Ty.T_Int, Ty.T_Int] --> Ty.T_Int)
    val add_tt = polyVar(N.op_add, all([SK], fn [Ty.SHAPE dd] => let
            val t = Ty.T_Tensor(Ty.ShapeVar dd)
            in
              [t, t] --> t
            end))

    val sub_ii = monoVar(N.op_sub, [Ty.T_Int, Ty.T_Int] --> Ty.T_Int)
    val sub_tt = polyVar(N.op_sub, all([SK], fn [Ty.SHAPE dd] => let
            val t = Ty.T_Tensor(Ty.ShapeVar dd)
            in
              [t, t] --> t
            end))

  (* note that we assume that operators are tested in the order defined here, so that mul_rr
   * takes precedence over mul_rt and mul_tr!
   *)
    val mul_ii = monoVar(N.op_mul, [Ty.T_Int, Ty.T_Int] --> Ty.T_Int)
    val mul_rr = monoVar(N.op_mul, [Ty.realTy, Ty.realTy] --> Ty.realTy)
    val mul_rt = polyVar(N.op_mul, all([SK], fn [Ty.SHAPE dd] => let
            val t = Ty.T_Tensor(Ty.ShapeVar dd)
            in
              [Ty.realTy, t] --> t
            end))
    val mul_tr = polyVar(N.op_mul, all([SK], fn [Ty.SHAPE dd] => let
            val t = Ty.T_Tensor(Ty.ShapeVar dd)
            in
              [t, Ty.realTy] --> t
            end))

    val div_ii = monoVar(N.op_div, [Ty.T_Int, Ty.T_Int] --> Ty.T_Int)
    val div_rr = monoVar(N.op_div, [Ty.realTy, Ty.realTy] --> Ty.realTy)
    val div_tr = polyVar(N.op_div, all([SK], fn [Ty.SHAPE dd] => let
            val t = Ty.T_Tensor(Ty.ShapeVar dd)
            in
              [t, Ty.realTy] --> t
            end))

  (* comparisons *)
    val lt_ii = monoVar(N.op_lt, [Ty.T_Int, Ty.T_Int] --> Ty.T_Bool)
    val lt_rr = monoVar(N.op_lt, [Ty.realTy, Ty.realTy] --> Ty.T_Bool)
    val lte_ii = monoVar(N.op_lte, [Ty.T_Int, Ty.T_Int] --> Ty.T_Bool)
    val lte_rr = monoVar(N.op_lte, [Ty.realTy, Ty.realTy] --> Ty.T_Bool)
    val gte_ii = monoVar(N.op_gte, [Ty.T_Int, Ty.T_Int] --> Ty.T_Bool)
    val gte_rr = monoVar(N.op_gte, [Ty.realTy, Ty.realTy] --> Ty.T_Bool)
    val gt_ii = monoVar(N.op_gt, [Ty.T_Int, Ty.T_Int] --> Ty.T_Bool)
    val gt_rr = monoVar(N.op_gt, [Ty.realTy, Ty.realTy] --> Ty.T_Bool)

  (* equality and inequality *)
    val equ_bb = monoVar(N.op_equ, [Ty.T_Bool, Ty.T_Bool] --> Ty.T_Bool)
    val equ_ii = monoVar(N.op_equ, [Ty.T_Int, Ty.T_Int] --> Ty.T_Bool)
    val equ_ss = monoVar(N.op_equ, [Ty.T_String, Ty.T_String] --> Ty.T_Bool)
    val equ_rr = monoVar(N.op_equ, [Ty.realTy, Ty.realTy] --> Ty.T_Bool)
    val neq_bb = monoVar(N.op_neq, [Ty.T_Bool, Ty.T_Bool] --> Ty.T_Bool)
    val neq_ii = monoVar(N.op_neq, [Ty.T_Int, Ty.T_Int] --> Ty.T_Bool)
    val neq_ss = monoVar(N.op_neq, [Ty.T_String, Ty.T_String] --> Ty.T_Bool)
    val neq_rr = monoVar(N.op_neq, [Ty.realTy, Ty.realTy] --> Ty.T_Bool)

  (* unary negation: each overload takes one argument and returns a value of the
   * same type.
   *)
    val neg_i = monoVar(N.op_neg, [Ty.T_Int] --> Ty.T_Int)
    val neg_t = polyVar(N.op_neg, all([SK],
          fn [Ty.SHAPE dd] => let
              val t = Ty.T_Tensor(Ty.ShapeVar dd)
              in
                [t] --> t
              end))
    val neg_f = polyVar(N.op_neg, all([DK, NK, SK],
          fn [Ty.DIFF k, Ty.DIM d, Ty.SHAPE dd] => let
              val k = Ty.DiffVar(k, 0)
              val d = Ty.DimVar d
              val dd = Ty.ShapeVar dd
              in
                [field(k, d, dd)] --> field(k, d, dd)
              end))


  (***** non-overloaded operators, etc. *****)

  (* probe: field@position, where the position is a tensor[d] matching the field's
   * dimension; the result has the field's shape.
   *)
    val op_at = polyVar (N.op_at, all([DK, NK, SK],
          fn [Ty.DIFF k, Ty.DIM d, Ty.SHAPE dd] => let
              val k = Ty.DiffVar(k, 0)
              val d = Ty.DimVar d
              val dd = Ty.ShapeVar dd
              in
                [field(k, d, dd), tensor[d]] --> Ty.T_Tensor dd
              end))

  (* differentiation: the result field's differentiability is the argument's offset
   * by ~1, and its shape is the argument's shape extended by the dimension d.
   *)
    val op_D = polyVar (N.op_D, all([DK, NK, SK],
          fn [Ty.DIFF k, Ty.DIM d, Ty.SHAPE dd] => let
              val k0 = Ty.DiffVar(k, 0)
              val km1 = Ty.DiffVar(k, ~1)
              val d = Ty.DimVar d
              val dd = Ty.ShapeVar dd
              in
                [field(k0, d, dd)]
                  --> field(km1, d, Ty.ShapeExt(dd, d))
              end))

    val op_norm = polyVar (N.op_norm, all([SK],
          fn [Ty.SHAPE dd] => [Ty.T_Tensor(Ty.ShapeVar dd)] --> Ty.realTy))

    val op_not = monoVar (N.op_not, [Ty.T_Bool] --> Ty.T_Bool)

  (* subscripting a tensor whose shape is dd extended by d with an int index yields
   * a tensor of shape dd.
   *)
    val op_subscript = polyVar (N.op_subscript, all([SK, NK],
          fn [Ty.SHAPE dd, Ty.DIM d] => let
              val dd = Ty.ShapeVar dd
              val d = Ty.DimVar d
              in
                [Ty.T_Tensor(Ty.ShapeExt(dd, d)), Ty.T_Int]
                  --> Ty.T_Tensor dd
              end))

  (* functions *)
    val fn_CL = polyVar (N.fn_CL, ty([tensor[N3, N3]] --> Ty.realTy))

    val fn_convolve = polyVar (N.fn_convolve, all([DK, NK, SK],
            fn [Ty.DIFF k, Ty.DIM d, Ty.SHAPE dd] => let
                val k = Ty.DiffVar(k, 0)
                val d = Ty.DimVar d
                val dd = Ty.ShapeVar dd
                in
                  [Ty.T_Kernel k, Ty.T_Image{dim=d, shape=dd}]
                    --> field(k, d, dd)
                end))

    val fn_cos = monoVar (N.fn_cos, [Ty.realTy] --> Ty.realTy)

    val fn_dot = polyVar (N.fn_dot, allNK(fn tv => let
          val t = tensor[Ty.DimVar tv]
          in
            [t, t] --> Ty.realTy
          end))

    val fn_inside = polyVar (N.fn_inside, all([DK, NK, SK],
            fn [Ty.DIFF k, Ty.DIM d, Ty.SHAPE dd] => let
                val k = Ty.DiffVar(k, 0)
                val d = Ty.DimVar d
                val dd = Ty.ShapeVar dd
                in
                  [Ty.T_Tensor(Ty.Shape[d]), field(k, d, dd)]
                    --> Ty.T_Bool
                end))

    val fn_load = polyVar (N.fn_load, all([NK, SK],
            fn [Ty.DIM d, Ty.SHAPE dd] => let
                val d = Ty.DimVar d
                val dd = Ty.ShapeVar dd
                in
                  [Ty.T_String] --> Ty.T_Image{dim=d, shape=dd}
                end))

    val fn_max = monoVar (N.fn_max, [Ty.realTy, Ty.realTy] --> Ty.realTy)
    val fn_min = monoVar (N.fn_min, [Ty.realTy, Ty.realTy] --> Ty.realTy)

    val fn_modulate = polyVar (N.fn_modulate, all([NK],
            fn [Ty.DIM d] => let
                val t = Ty.T_Tensor(Ty.Shape[Ty.DimVar d])
                in
                  [t, t] --> t
                end))

    val fn_pow = monoVar (N.fn_pow, [Ty.realTy, Ty.realTy] --> Ty.realTy)

    val fn_principleEvec = polyVar (N.fn_principleEvec, all([NK],
            fn [Ty.DIM d] => let
                val d = Ty.DimVar d
                in
                  [tensor[d,d]] --> tensor[d]
                end))

    val fn_sin = monoVar (N.fn_sin, [Ty.realTy] --> Ty.realTy)

  (* kernels *)
(* FIXME: we should really get the continuity info from the kernels themselves *)
    val kn_bspln3 = monoVar (N.kn_bspln3, Ty.T_Kernel(Ty.DiffConst 2))
    val kn_bspln5 = monoVar (N.kn_bspln5, Ty.T_Kernel(Ty.DiffConst 4))
    val kn_ctmr = monoVar (N.kn_ctmr, Ty.T_Kernel(Ty.DiffConst 2))
    val kn_tent = monoVar (N.kn_tent, Ty.T_Kernel(Ty.DiffConst 0))

  (* internal variables *)
    val i2r = monoVar (Atom.atom "$i2r", [Ty.T_Int] --> Ty.realTy)  (* integer to real conversion *)
  (* $input : (string, 'a) -> 'a; presumably the string names the input and the second
   * argument is its default value -- confirm against the simplifier.
   *)
    val input = polyVar (Atom.atom "$input", all([TK],
              fn [Ty.TYPE tv] => [Ty.T_String, Ty.T_Var tv] --> Ty.T_Var tv))
    end (* local *)
  end

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0