Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] View of /branches/charisee_dev/src/compiler/basis/basis-vars.sml
ViewVC logotype

View of /branches/charisee_dev/src/compiler/basis/basis-vars.sml

Parent Directory | Revision Log


Revision 3604 - (download) (annotate)
Tue Jan 19 01:24:00 2016 UTC (3 years, 7 months ago) by cchiw
File size: 30894 byte(s)
clean up rules and generic outer product
(* basis-vars.sml
 *
 * This code is part of the Diderot Project (http://diderot-language.cs.uchicago.edu)
 *
 * COPYRIGHT (c) 2015 The University of Chicago
 * All rights reserved.
 *
 * This module defines the AST variables for the built in operators and functions.
 *)

structure BasisVars =
  struct
    local
      structure N = BasisNames
      structure Ty = Types
      structure MV = MetaVar

    (* function-type constructor: `[ty1, ..., tyn] --> ty` is Ty.T_Fun([ty1, ..., tyn], ty) *)
      fun --> (tys1, ty) = Ty.T_Fun(tys1, ty)
      infix -->

    (* constant dimensions used by the 2D/3D-specific operators *)
      val N2 = Ty.DimConst 2
      val N3 = Ty.DimConst 3

    (* short names for kinds; each call allocates a fresh meta variable of that kind *)
      val TK : unit -> Ty.meta_var = Ty.TYPE o MV.newTyVar
      fun DK () : Ty.meta_var = Ty.DIFF(MV.newDiffVar 0)
      val SK : unit -> Ty.meta_var = Ty.SHAPE o MV.newShapeVar
      val NK : unit -> Ty.meta_var = Ty.DIM o MV.newDimVar

    (* `all (kinds, mkTy)` allocates one fresh meta variable per kind constructor in `kinds`
     * and builds a type scheme by passing those meta variables (in order) to `mkTy`.
     *)
      fun all (kinds, mkTy : Ty.meta_var list -> Ty.ty) = let
            val tvs = List.map (fn mk => mk()) kinds
            in
              (tvs, mkTy tvs)
            end

    (* type scheme quantified over a single dimension variable *)
      fun allNK mkTy = let
            val tv = MV.newDimVar()
            in
              ([Ty.DIM tv], mkTy tv)
            end

    (* type constructors *)
      fun field (k, d, dd) = Ty.T_Field{diff=k, dim=d, shape=dd}
      fun tensor ds = Ty.T_Tensor(Ty.Shape ds)
      fun matrix d = tensor[d,d]
    (* field with a constant dimension and a constant shape;
     * e.g., constField(k, 3, [3]) is field#k(3)[3]
     *)
      fun constField (k, d, dd) = field(k, Ty.DimConst d, Ty.Shape(List.map Ty.DimConst dd))

    (* basis variables with monomorphic types and with polymorphic type schemes *)
      fun monoVar (name, ty) = Var.new (name, AST.BasisVar, ty)
      fun polyVar (name, scheme) = Var.newPoly (name, AST.BasisVar, scheme)
    in

(* TODO: I'm not sure how to extend + and - to fields, since the typing rules should allow
 * two fields with different differentiation levels to be added.
 *)

  (* overloaded operators; the naming convention is to use the operator name followed
   * by the argument type signature, where
   *    i  -- int
   *    b  -- bool
   *    r  -- real (tensor[])
   *    t  -- tensor[shape]
   *    f  -- field#k(d)[shape]
   *    s  -- field#k(d)[]  (except in equ_ss and neq_ss, where s is string)
   *)

    val add_ii = monoVar(N.op_add, [Ty.T_Int, Ty.T_Int] --> Ty.T_Int)
    val add_tt = polyVar(N.op_add, all([SK], fn [Ty.SHAPE dd] => let
            val t = Ty.T_Tensor(Ty.ShapeVar dd)
            in
              [t, t] --> t
            end))
    val add_ff = polyVar(N.op_add, all([DK,NK,SK],
          fn [Ty.DIFF k, Ty.DIM d, Ty.SHAPE dd] => let
            val f = Ty.T_Field{diff = Ty.DiffVar(k, 0), dim = Ty.DimVar d, shape = Ty.ShapeVar dd}
            in
              [f, f] --> f
            end))
    val add_ft = polyVar(N.op_add, all([DK,NK,SK], (* field + tensor of the field's shape *)
          fn [Ty.DIFF k, Ty.DIM d, Ty.SHAPE dd] => let
            val f = Ty.T_Field{diff = Ty.DiffVar(k, 0), dim = Ty.DimVar d, shape = Ty.ShapeVar dd}
            val t = Ty.T_Tensor(Ty.ShapeVar dd)
            in
              [f, t] --> f
            end))
    val add_tf = polyVar(N.op_add, all([DK,NK,SK], (* tensor of the field's shape + field *)
          fn [Ty.DIFF k, Ty.DIM d, Ty.SHAPE dd] => let
            val f = Ty.T_Field{diff = Ty.DiffVar(k, 0), dim = Ty.DimVar d, shape = Ty.ShapeVar dd}
            val t = Ty.T_Tensor(Ty.ShapeVar dd)
            in
              [t, f] --> f
            end))

    val sub_ii = monoVar(N.op_sub, [Ty.T_Int, Ty.T_Int] --> Ty.T_Int)
    val sub_tt = polyVar(N.op_sub, all([SK], fn [Ty.SHAPE dd] => let
            val t = Ty.T_Tensor(Ty.ShapeVar dd)
            in
              [t, t] --> t
            end))
    val sub_ff = polyVar(N.op_sub, all([DK,NK,SK],
          fn [Ty.DIFF k, Ty.DIM d, Ty.SHAPE dd] => let
            val f = Ty.T_Field{diff = Ty.DiffVar(k, 0), dim = Ty.DimVar d, shape = Ty.ShapeVar dd}
            in
              [f, f] --> f
            end))
    val sub_ft = polyVar(N.op_sub, all([DK,NK,SK], (* field - tensor of the field's shape *)
          fn [Ty.DIFF k, Ty.DIM d, Ty.SHAPE dd] => let
            val f = Ty.T_Field{diff = Ty.DiffVar(k, 0), dim = Ty.DimVar d, shape = Ty.ShapeVar dd}
            val t = Ty.T_Tensor(Ty.ShapeVar dd)
            in
              [f, t] --> f
            end))
    val sub_tf = polyVar(N.op_sub, all([DK,NK,SK], (* tensor of the field's shape - field *)
          fn [Ty.DIFF k, Ty.DIM d, Ty.SHAPE dd] => let
            val f = Ty.T_Field{diff = Ty.DiffVar(k, 0), dim = Ty.DimVar d, shape = Ty.ShapeVar dd}
            val t = Ty.T_Tensor(Ty.ShapeVar dd)
            in
              [t, f] --> f
            end))

  (* note that we assume that operators are tested in the order defined here, so that mul_rr
   * takes precedence over mul_rt and mul_tr!
   *)
    val mul_ii = monoVar(N.op_mul, [Ty.T_Int, Ty.T_Int] --> Ty.T_Int)
    val mul_rr = monoVar(N.op_mul, [Ty.realTy, Ty.realTy] --> Ty.realTy)
    val mul_rt = polyVar(N.op_mul, all([SK], fn [Ty.SHAPE dd] => let
            val t = Ty.T_Tensor(Ty.ShapeVar dd)
            in
              [Ty.realTy, t] --> t
            end))
    val mul_tr = polyVar(N.op_mul, all([SK], fn [Ty.SHAPE dd] => let
            val t = Ty.T_Tensor(Ty.ShapeVar dd)
            in
              [t, Ty.realTy] --> t
            end))
  (* real * field *)
    val mul_rf = polyVar(N.op_mul, all([DK,NK,SK],
          fn [Ty.DIFF k, Ty.DIM d, Ty.SHAPE dd] => let
            val f = Ty.T_Field{diff = Ty.DiffVar(k, 0), dim = Ty.DimVar d, shape = Ty.ShapeVar dd}
            in
              [Ty.realTy, f] --> f
            end))
  (* scalar field * tensor; the result is a field with the tensor's shape *)
    val mul_st = polyVar(N.op_mul, all([DK,NK,SK],
          fn [Ty.DIFF k, Ty.DIM d, Ty.SHAPE dd] => let
            val s = Ty.T_Field{diff = Ty.DiffVar(k, 0), dim = Ty.DimVar d, shape = Ty.Shape []}
            val t = Ty.T_Tensor(Ty.ShapeVar dd)
            val g = Ty.T_Field{diff = Ty.DiffVar(k, 0), dim = Ty.DimVar d, shape = Ty.ShapeVar dd}
            in
              [s, t] --> g
            end))
  (* tensor * scalar field; the result is a field with the tensor's shape *)
    val mul_ts = polyVar(N.op_mul, all([DK,NK,SK],
          fn [Ty.DIFF k, Ty.DIM d, Ty.SHAPE dd] => let
            val s = Ty.T_Field{diff = Ty.DiffVar(k, 0), dim = Ty.DimVar d, shape = Ty.Shape []}
            val t = Ty.T_Tensor(Ty.ShapeVar dd)
            val g = Ty.T_Field{diff = Ty.DiffVar(k, 0), dim = Ty.DimVar d, shape = Ty.ShapeVar dd}
            in
              [t, s] --> g
            end))
  (* field * real *)
    val mul_fr = polyVar(N.op_mul, all([DK,NK,SK],
          fn [Ty.DIFF k, Ty.DIM d, Ty.SHAPE dd] => let
            val f = Ty.T_Field{diff = Ty.DiffVar(k, 0), dim = Ty.DimVar d, shape = Ty.ShapeVar dd}
            in
              [f, Ty.realTy] --> f
            end))
  (* scalar field * scalar field *)
    val mul_ss = polyVar(N.op_mul, all([DK,NK],
          fn [Ty.DIFF k, Ty.DIM d] => let
            val s = Ty.T_Field{diff = Ty.DiffVar(k, 0), dim = Ty.DimVar d, shape = Ty.Shape []}
            in
              [s, s] --> s
            end))
  (* scalar field * field *)
    val mul_sf = polyVar(N.op_mul, all([DK,NK,SK],
          fn [Ty.DIFF k, Ty.DIM d, Ty.SHAPE dd] => let
            val a = Ty.T_Field{diff = Ty.DiffVar(k, 0), dim = Ty.DimVar d, shape = Ty.Shape []}
            val b = Ty.T_Field{diff = Ty.DiffVar(k, 0), dim = Ty.DimVar d, shape = Ty.ShapeVar dd}
            in
              [a, b] --> b
            end))
  (* field * scalar field *)
    val mul_fs = polyVar(N.op_mul, all([DK,NK,SK],
          fn [Ty.DIFF k, Ty.DIM d, Ty.SHAPE dd] => let
            val a = Ty.T_Field{diff = Ty.DiffVar(k, 0), dim = Ty.DimVar d, shape = Ty.Shape []}
            val b = Ty.T_Field{diff = Ty.DiffVar(k, 0), dim = Ty.DimVar d, shape = Ty.ShapeVar dd}
            in
              [b, a] --> b
            end))

    val div_ii = monoVar(N.op_div, [Ty.T_Int, Ty.T_Int] --> Ty.T_Int)
    val div_rr = monoVar(N.op_div, [Ty.realTy, Ty.realTy] --> Ty.realTy)
    val div_tr = polyVar(N.op_div, all([SK], fn [Ty.SHAPE dd] => let
            val t = Ty.T_Tensor(Ty.ShapeVar dd)
            in
              [t, Ty.realTy] --> t
            end))
  (* field / real *)
    val div_fr = polyVar(N.op_div, all([DK,NK,SK],
          fn [Ty.DIFF k, Ty.DIM d, Ty.SHAPE dd] => let
            val f = Ty.T_Field{diff = Ty.DiffVar(k, 0), dim = Ty.DimVar d, shape = Ty.ShapeVar dd}
            in
              [f, Ty.realTy] --> f
            end))
  (* scalar field / scalar field; this is a division operator, so it is named by op_div
   * (it was previously registered under op_mul by mistake).
   *)
    val div_ss = polyVar(N.op_div, all([DK,NK],
          fn [Ty.DIFF k, Ty.DIM d] => let
            val s = Ty.T_Field{diff = Ty.DiffVar(k, 0), dim = Ty.DimVar d, shape = Ty.Shape []}
            in
              [s, s] --> s
            end))
  (* field / scalar field; the divisor's differentiability (k2) is independent of the
   * numerator's (k), and the result takes the numerator's.
   *)
    val div_fs = polyVar(N.op_div, all([DK,DK,NK,SK],
          fn [Ty.DIFF k, Ty.DIFF k2, Ty.DIM d, Ty.SHAPE dd] => let
            val f = Ty.T_Field{diff = Ty.DiffVar(k, 0), dim = Ty.DimVar d, shape = Ty.ShapeVar dd}
            val s = Ty.T_Field{diff = Ty.DiffVar(k2, 0), dim = Ty.DimVar d, shape = Ty.Shape []}
            in
              [f, s] --> f
            end))

  (* exponentiation; we distinguish between integer and real exponents to allow x^2 to be compiled
   * as x*x.
   *)
    val exp_ri = monoVar(N.op_exp, [Ty.realTy, Ty.T_Int] --> Ty.realTy)
    val exp_rr = monoVar(N.op_exp, [Ty.realTy, Ty.realTy] --> Ty.realTy)

  (* convolution of an image with a kernel (in either argument order) yields a field whose
   * differentiability is that of the kernel.
   *)
    val convolve_vk = polyVar (N.op_convolve, all([DK, NK, SK],
            fn [Ty.DIFF k, Ty.DIM d, Ty.SHAPE dd] => let
                val k = Ty.DiffVar(k, 0)
                val d = Ty.DimVar d
                val dd = Ty.ShapeVar dd
                in
                  [Ty.T_Image{dim=d, shape=dd}, Ty.T_Kernel k]
                    --> field(k, d, dd)
                end))
    val convolve_kv = polyVar (N.op_convolve, all([DK, NK, SK],
            fn [Ty.DIFF k, Ty.DIM d, Ty.SHAPE dd] => let
                val k = Ty.DiffVar(k, 0)
                val d = Ty.DimVar d
                val dd = Ty.ShapeVar dd
                in
                  [Ty.T_Kernel k, Ty.T_Image{dim=d, shape=dd}]
                    --> field(k, d, dd)
                end))

  (* curl on 2d and 3d vector fields:
   *    field#k(2)[2] -> field#(k-1)(2)[]
   *    field#k(3)[3] -> field#(k-1)(3)[3]
   *)
(* FIXME: we want to be able to require that k > 0, but we don't have a way to do that! *)
    val curl2D = polyVar (N.op_curl, all([DK],
          fn [Ty.DIFF k] => let
              val km1 = Ty.DiffVar(k, ~1)
              in
                [constField (Ty.DiffVar(k, 0), 2, [2])] --> constField (km1, 2, [])
              end))
    val curl3D = polyVar (N.op_curl, all([DK],
          fn [Ty.DIFF k] => let
              val km1 = Ty.DiffVar(k, ~1)
              in
                [constField (Ty.DiffVar(k, 0), 3, [3])] --> constField (km1, 3, [3])
              end))

    val lt_ii = monoVar(N.op_lt, [Ty.T_Int, Ty.T_Int] --> Ty.T_Bool)
    val lt_rr = monoVar(N.op_lt, [Ty.realTy, Ty.realTy] --> Ty.T_Bool)
    val lte_ii = monoVar(N.op_lte, [Ty.T_Int, Ty.T_Int] --> Ty.T_Bool)
    val lte_rr = monoVar(N.op_lte, [Ty.realTy, Ty.realTy] --> Ty.T_Bool)
    val gte_ii = monoVar(N.op_gte, [Ty.T_Int, Ty.T_Int] --> Ty.T_Bool)
    val gte_rr = monoVar(N.op_gte, [Ty.realTy, Ty.realTy] --> Ty.T_Bool)
    val gt_ii = monoVar(N.op_gt, [Ty.T_Int, Ty.T_Int] --> Ty.T_Bool)
    val gt_rr = monoVar(N.op_gt, [Ty.realTy, Ty.realTy] --> Ty.T_Bool)

    val equ_bb = monoVar(N.op_equ, [Ty.T_Bool, Ty.T_Bool] --> Ty.T_Bool)
    val equ_ii = monoVar(N.op_equ, [Ty.T_Int, Ty.T_Int] --> Ty.T_Bool)
    val equ_ss = monoVar(N.op_equ, [Ty.T_String, Ty.T_String] --> Ty.T_Bool)
    val equ_rr = monoVar(N.op_equ, [Ty.realTy, Ty.realTy] --> Ty.T_Bool)
    val neq_bb = monoVar(N.op_neq, [Ty.T_Bool, Ty.T_Bool] --> Ty.T_Bool)
    val neq_ii = monoVar(N.op_neq, [Ty.T_Int, Ty.T_Int] --> Ty.T_Bool)
    val neq_ss = monoVar(N.op_neq, [Ty.T_String, Ty.T_String] --> Ty.T_Bool)
    val neq_rr = monoVar(N.op_neq, [Ty.realTy, Ty.realTy] --> Ty.T_Bool)

    val neg_i = monoVar(N.op_neg, [Ty.T_Int] --> Ty.T_Int)
    val neg_t = polyVar(N.op_neg, all([SK],
          fn [Ty.SHAPE dd] => let
              val t = Ty.T_Tensor(Ty.ShapeVar dd)
              in
                [t] --> t
              end))
    val neg_f = polyVar(N.op_neg, all([DK, NK, SK],
          fn [Ty.DIFF k, Ty.DIM d, Ty.SHAPE dd] => let
              val k = Ty.DiffVar(k, 0)
              val d = Ty.DimVar d
              val dd = Ty.ShapeVar dd
              in
                [field(k, d, dd)] --> field(k, d, dd)
              end))

  (* clamp is overloaded at scalars and vectors *)
    val clamp_rrr = monoVar(N.fn_clamp, [Ty.realTy, Ty.realTy, Ty.realTy] --> Ty.realTy)
    val clamp_vvv = polyVar (N.fn_clamp, allNK(fn tv => let
          val t = tensor[Ty.DimVar tv]
          in
            [t, t, t] --> t
          end))

  (* lerp with three arguments (a, b, real) and with five arguments (a, b, real, real, real) *)
    val lerp3 = polyVar(N.fn_lerp, all([SK],
          fn [Ty.SHAPE dd] => let
              val t = Ty.T_Tensor(Ty.ShapeVar dd)
              in
                [t, t, Ty.realTy] --> t
              end))
    val lerp5 = polyVar(N.fn_lerp, all([SK],
          fn [Ty.SHAPE dd] => let
              val t = Ty.T_Tensor(Ty.ShapeVar dd)
              in
                [t, t, Ty.realTy, Ty.realTy, Ty.realTy] --> t
              end))

  (* Eigenvalues/vectors of a matrix; we only support this operation on 2x2 and 3x3 matrices, so
   * we overload the function.
   *)
    local
      fun evals d = monoVar (N.fn_evals, [matrix d] --> Ty.T_Sequence(Ty.realTy, d))
      fun evecs d = monoVar (N.fn_evecs, [matrix d] --> Ty.T_Sequence(tensor[d], d))
    in
    val evals2x2 = evals N2
    val evecs2x2 = evecs N2
    val evals3x3 = evals N3
    val evecs3x3 = evecs N3
    end

  (***** non-overloaded operators, etc. *****)

  (* C math functions; each takes `MathFuns.arity n` reals and returns a real *)
    val mathFns : (MathFuns.name * Var.var) list = let
          fun fnTy n = List.tabulate(MathFuns.arity n, fn _ => Ty.realTy) --> Ty.realTy
          in
            List.map (fn n => (n, monoVar(MathFuns.toAtom n, fnTy n))) MathFuns.allFuns
          end

  (* pseudo-operator for probing a field: field#k(d)[sigma] * tensor[d] -> tensor[sigma] *)
    val op_probe = polyVar (N.op_at, all([DK, NK, SK],
          fn [Ty.DIFF k, Ty.DIM d, Ty.SHAPE dd] => let
              val k = Ty.DiffVar(k, 0)
              val d = Ty.DimVar d
              val dd = Ty.ShapeVar dd
              in
                [field(k, d, dd), tensor[d]] --> Ty.T_Tensor dd
              end))

  (* differentiation of scalar fields: field#k(d)[] -> field#(k-1)(d)[d] *)
    val op_D = polyVar (N.op_D, all([DK, NK],
          fn [Ty.DIFF k, Ty.DIM d] => let
              val k0 = Ty.DiffVar(k, 0)
              val km1 = Ty.DiffVar(k, ~1)
              val d = Ty.DimVar d
              in
                [field(k0, d, Ty.Shape[])]
                  --> field(km1, d, Ty.Shape[d])
              end))

  (* differentiation of higher-order tensor fields:
   *    field#k(d)[sigma, d'] -> field#(k-1)(d)[sigma, d', d]
   *)
    val op_Dotimes = polyVar (N.op_Dotimes, all([DK, NK, SK, NK],
          fn [Ty.DIFF k, Ty.DIM d, Ty.SHAPE dd, Ty.DIM d'] => let
              val k0 = Ty.DiffVar(k, 0)
              val km1 = Ty.DiffVar(k, ~1)
              val d = Ty.DimVar d
              val d' = Ty.DimVar d'
              val dd = Ty.ShapeVar dd
              in
                [field(k0, d, Ty.ShapeExt(dd, d'))]
                  --> field(km1, d, Ty.ShapeExt(Ty.ShapeExt(dd, d'), d))
              end))

  (* divergence differentiation of higher-order tensor fields:
   *    field#k(d)[sigma, d'] -> field#(k-1)(d)[sigma]
   * Like the other differentiation operators (op_D, op_Dotimes, curl), this consumes one
   * level of differentiability, so the result is k-1 differentiable (the original code
   * computed km1 but returned a k-differentiable field).
   * NOTE(review): d' is not tied to the domain dimension d here; presumably the typechecker
   * enforces d' = d elsewhere -- confirm.
   *)
    val op_Ddot = polyVar (N.op_Ddot, all([DK, NK, SK, NK],
          fn [Ty.DIFF k, Ty.DIM d, Ty.SHAPE dd, Ty.DIM d'] => let
              val k0 = Ty.DiffVar(k, 0)
              val km1 = Ty.DiffVar(k, ~1)
              val d = Ty.DimVar d
              val d' = Ty.DimVar d'
              val dd = Ty.ShapeVar dd
              in
                [field(k0, d, Ty.ShapeExt(dd, d'))]
                  --> field(km1, d, dd)
              end))

  (* norm of a tensor (a real) and of a field (a scalar field) *)
    val op_norm_t = polyVar (N.op_norm, all([SK],
          fn [Ty.SHAPE dd] => [Ty.T_Tensor(Ty.ShapeVar dd)] --> Ty.realTy))
    val op_norm_f = polyVar (N.op_norm, all([DK, NK, SK],
          fn [Ty.DIFF k, Ty.DIM d, Ty.SHAPE dd1] => let
              val k = Ty.DiffVar(k, 0)
              val d = Ty.DimVar d
              val f1 = Ty.T_Field{diff = k, dim = d, shape = Ty.ShapeVar dd1}
              val f2 = Ty.T_Field{diff = k, dim = d, shape = Ty.Shape []}
              in
                [f1] --> f2
              end))

    val op_not = monoVar (N.op_not, [Ty.T_Bool] --> Ty.T_Bool)

  (* functions *)

  (* cross product: tensor[2] * tensor[2] -> real and tensor[3] * tensor[3] -> tensor[3] *)
    local
      val crossTy = let
            val t = tensor[N3]
            in
              [t, t] --> t
            end
      val crossTy2 = let
            val t = tensor[N2]
            in
              [t, t] --> Ty.realTy
            end
    in
    val op_cross2_tt = monoVar (N.op_cross, crossTy2)
    val op_cross3_tt = monoVar (N.op_cross, crossTy)
    end

  (* cross product of 2D vector fields: field#k(2)[2] * field#k(2)[2] -> field#k(2)[] *)
    val op_cross2_ff = polyVar (N.op_cross, all([DK],
          fn [Ty.DIFF k] => let
              val k0 = Ty.DiffVar(k, 0)
              val f = constField (k0, 2, [2])
              val s = constField (k0, 2, [])
              in
                [f, f] --> s
              end))

  (* cross product of 3D vector fields: field#k(3)[3] * field#k(3)[3] -> field#k(3)[3] *)
    val op_cross3_ff = polyVar (N.op_cross, all([DK],
          fn [Ty.DIFF k] => let
              val f = constField (Ty.DiffVar(k, 0), 3, [3])
              in
                [f, f] --> f
              end))

  (* the inner product operator (including dot product) is treated as a special case in the
   * typechecker.  It is not included in the basis environment, but we define its type scheme
   * here.  There is an implicit constraint on its type to have the following scheme:
   *
   *     ALL[sigma1, d1, sigma2] . tensor[sigma1, d1] * tensor[d1, sigma2] -> tensor[sigma1, sigma2]
   *)
    val op_inner_tt = polyVar (N.op_dot, all([SK, SK, SK],
          fn [Ty.SHAPE s1, Ty.SHAPE s2, Ty.SHAPE s3] =>
              [Ty.T_Tensor(Ty.ShapeVar s1), Ty.T_Tensor(Ty.ShapeVar s2)]
                --> Ty.T_Tensor(Ty.ShapeVar s3)))

    val op_inner_tf = polyVar (N.op_dot, all([DK, NK, SK, SK, SK],
          fn [Ty.DIFF k, Ty.DIM d, Ty.SHAPE dd1, Ty.SHAPE dd2, Ty.SHAPE dd3] => let
              val k = Ty.DiffVar(k, 0)
              val d = Ty.DimVar d
              val t1 = Ty.T_Tensor(Ty.ShapeVar dd1)
              val t2 = Ty.T_Field{diff = k, dim = d, shape = Ty.ShapeVar dd2}
              val t3 = Ty.T_Field{diff = k, dim = d, shape = Ty.ShapeVar dd3}
              in
                [t1, t2] --> t3
              end))

    val op_inner_ft = polyVar (N.op_dot, all([DK, NK, SK, SK, SK],
          fn [Ty.DIFF k, Ty.DIM d, Ty.SHAPE dd1, Ty.SHAPE dd2, Ty.SHAPE dd3] => let
              val k = Ty.DiffVar(k, 0)
              val d = Ty.DimVar d
              val t1 = Ty.T_Field{diff = k, dim = d, shape = Ty.ShapeVar dd1}
              val t2 = Ty.T_Tensor(Ty.ShapeVar dd2)
              val t3 = Ty.T_Field{diff = k, dim = d, shape = Ty.ShapeVar dd3}
              in
                [t1, t2] --> t3
              end))

  (* inner product of two fields; the operands may have different differentiability
   * (k1, k2) and the result takes the first operand's (k1).
   *)
    val op_inner_ff = polyVar (N.op_dot, all([DK, DK, NK, SK, SK, SK],
          fn [Ty.DIFF k1, Ty.DIFF k2, Ty.DIM d, Ty.SHAPE dd1, Ty.SHAPE dd2, Ty.SHAPE dd3] => let
              val k1 = Ty.DiffVar(k1, 0)
              val k2 = Ty.DiffVar(k2, 0)
              val d = Ty.DimVar d
              val t1 = Ty.T_Field{diff = k1, dim = d, shape = Ty.ShapeVar dd1}
              val t2 = Ty.T_Field{diff = k2, dim = d, shape = Ty.ShapeVar dd2}
              val t3 = Ty.T_Field{diff = k1, dim = d, shape = Ty.ShapeVar dd3}
              in
                [t1, t2] --> t3
              end))

  (* the colon (or double-dot) product operator is treated as a special case in the
   * typechecker.  It is not included in the basis environment, but we define its type
   * scheme here.  There is an implicit constraint on its type to have the following scheme:
   *
   *     ALL[sigma1, d1, d2, sigma2] .
   *       tensor[sigma1, d1, d2] * tensor[d1, d2, sigma2] -> tensor[sigma1, sigma2]
   *)
    val op_colon_tt = polyVar (N.op_colon, all([SK, SK, SK],
          fn [Ty.SHAPE s1, Ty.SHAPE s2, Ty.SHAPE s3] =>
              [Ty.T_Tensor(Ty.ShapeVar s1), Ty.T_Tensor(Ty.ShapeVar s2)]
                --> Ty.T_Tensor(Ty.ShapeVar s3)))
    val op_colon_ff = polyVar (N.op_colon, all([DK, SK, NK, SK, SK],
          fn [Ty.DIFF k, Ty.SHAPE dd1, Ty.DIM d, Ty.SHAPE dd2, Ty.SHAPE dd3] => let
              val k0 = Ty.DiffVar(k, 0)
              val d' = Ty.DimVar d
              val t1 = Ty.T_Field{diff = k0, dim = d', shape = Ty.ShapeVar dd1}
              val t2 = Ty.T_Field{diff = k0, dim = d', shape = Ty.ShapeVar dd2}
              val t3 = Ty.T_Field{diff = k0, dim = d', shape = Ty.ShapeVar dd3}
              in
                [t1, t2] --> t3
              end))
    val op_colon_ft = polyVar (N.op_colon, all([DK, SK, NK, SK, SK],
          fn [Ty.DIFF k, Ty.SHAPE dd1, Ty.DIM d, Ty.SHAPE s2, Ty.SHAPE dd3] => let
              val k0 = Ty.DiffVar(k, 0)
              val d' = Ty.DimVar d
              val t1 = Ty.T_Field{diff = k0, dim = d', shape = Ty.ShapeVar dd1}
              val t2 = Ty.T_Tensor(Ty.ShapeVar s2)
              val t3 = Ty.T_Field{diff = k0, dim = d', shape = Ty.ShapeVar dd3}
              in
                [t1, t2] --> t3
              end))
    val op_colon_tf = polyVar (N.op_colon, all([DK, SK, NK, SK, SK],
          fn [Ty.DIFF k, Ty.SHAPE s1, Ty.DIM d, Ty.SHAPE dd2, Ty.SHAPE dd3] => let
              val k0 = Ty.DiffVar(k, 0)
              val d' = Ty.DimVar d
              val t1 = Ty.T_Tensor(Ty.ShapeVar s1)
              val t2 = Ty.T_Field{diff = k0, dim = d', shape = Ty.ShapeVar dd2}
              val t3 = Ty.T_Field{diff = k0, dim = d', shape = Ty.ShapeVar dd3}
              in
                [t1, t2] --> t3
              end))

  (* load image from nrrd *)
    val fn_image = polyVar (N.fn_image, all([NK, SK],
            fn [Ty.DIM d, Ty.SHAPE dd] => let
                val d = Ty.DimVar d
                val dd = Ty.ShapeVar dd
                in
                  [Ty.T_String] --> Ty.T_Image{dim=d, shape=dd}
                end))

  (* inside test: tensor[d] * field#k(d)[sigma] -> bool *)
    val fn_inside = polyVar (N.fn_inside, all([DK, NK, SK],
            fn [Ty.DIFF k, Ty.DIM d, Ty.SHAPE dd] => let
                val k = Ty.DiffVar(k, 0)
                val d = Ty.DimVar d
                val dd = Ty.ShapeVar dd
                in
                  [Ty.T_Tensor(Ty.Shape[d]), field(k, d, dd)]
                    --> Ty.T_Bool
                end))

    val fn_max = monoVar (N.fn_max, [Ty.realTy, Ty.realTy] --> Ty.realTy)
    val fn_min = monoVar (N.fn_min, [Ty.realTy, Ty.realTy] --> Ty.realTy)

    val fn_modulate = polyVar (N.fn_modulate, all([NK],
          fn [Ty.DIM d] => let
              val t = Ty.T_Tensor(Ty.Shape[Ty.DimVar d])
              in
                [t, t] --> t
              end))

  (* normalize on tensors and on fields; the result has the argument's type *)
    val fn_normalize_t = polyVar (N.fn_normalize, all([SK],
          fn [Ty.SHAPE dd] => let
              val t = Ty.T_Tensor(Ty.ShapeVar dd)
              in
                [t] --> t
              end))
    val fn_normalize_f = polyVar (N.fn_normalize, all([DK, NK, SK],
          fn [Ty.DIFF k, Ty.DIM d, Ty.SHAPE dd1] => let
              val k0 = Ty.DiffVar(k, 0)
              val d' = Ty.DimVar d
              val f1 = Ty.T_Field{diff = k0, dim = d', shape = Ty.ShapeVar dd1}
              in
                [f1] --> f1
              end))

  (* outer product *)
    val op_outer_tt = polyVar (N.op_outer, all([SK, SK, SK],
          fn [Ty.SHAPE s1, Ty.SHAPE s2, Ty.SHAPE s3] =>
              [Ty.T_Tensor(Ty.ShapeVar s1), Ty.T_Tensor(Ty.ShapeVar s2)]
                --> Ty.T_Tensor(Ty.ShapeVar s3)))
    val op_outer_tf = polyVar (N.op_outer, all([DK, NK, SK, SK, SK],
          fn [Ty.DIFF k, Ty.DIM d, Ty.SHAPE dd1, Ty.SHAPE dd2, Ty.SHAPE dd3] => let
              val k = Ty.DiffVar(k, 0)
              val d = Ty.DimVar d
              val t1 = Ty.T_Tensor(Ty.ShapeVar dd1)
              val t2 = Ty.T_Field{diff = k, dim = d, shape = Ty.ShapeVar dd2}
              val t3 = Ty.T_Field{diff = k, dim = d, shape = Ty.ShapeVar dd3}
              in
                [t1, t2] --> t3
              end))
    val op_outer_ft = polyVar (N.op_outer, all([DK, NK, SK, SK, SK],
          fn [Ty.DIFF k, Ty.DIM d, Ty.SHAPE dd1, Ty.SHAPE dd2, Ty.SHAPE dd3] => let
              val k = Ty.DiffVar(k, 0)
              val d = Ty.DimVar d
              val t1 = Ty.T_Field{diff = k, dim = d, shape = Ty.ShapeVar dd1}
              val t2 = Ty.T_Tensor(Ty.ShapeVar dd2)
              val t3 = Ty.T_Field{diff = k, dim = d, shape = Ty.ShapeVar dd3}
              in
                [t1, t2] --> t3
              end))
  (* outer product of two fields; as with op_inner_ff, the result takes the first
   * operand's differentiability (k1).
   *)
    val op_outer_ff = polyVar (N.op_outer, all([DK, DK, NK, SK, SK, SK],
          fn [Ty.DIFF k1, Ty.DIFF k2, Ty.DIM d, Ty.SHAPE dd1, Ty.SHAPE dd2, Ty.SHAPE dd3] => let
              val k1 = Ty.DiffVar(k1, 0)
              val k2 = Ty.DiffVar(k2, 0)
              val d = Ty.DimVar d
              val t1 = Ty.T_Field{diff = k1, dim = d, shape = Ty.ShapeVar dd1}
              val t2 = Ty.T_Field{diff = k2, dim = d, shape = Ty.ShapeVar dd2}
              val t3 = Ty.T_Field{diff = k1, dim = d, shape = Ty.ShapeVar dd3}
              in
                [t1, t2] --> t3
              end))

    val fn_principleEvec = polyVar (N.fn_principleEvec, all([NK],
            fn [Ty.DIM d] => let
                val d = Ty.DimVar d
                in
                  [matrix d] --> tensor[d]
                end))

  (* trace of a square matrix (a real), and of a field whose shape ends in two copies of
   * the domain dimension: field#k(d)[sigma, d, d] -> field#k(d)[sigma]
   *)
    val fn_trace_t = polyVar (N.fn_trace, all([NK],
          fn [Ty.DIM d] => [matrix(Ty.DimVar d)] --> Ty.realTy))
    val fn_trace_f = polyVar (N.fn_trace, all([DK, NK, SK],
          fn [Ty.DIFF k, Ty.DIM d, Ty.SHAPE dd1] => let
              val k' = Ty.DiffVar(k, 0)
              val d' = Ty.DimVar d
              val d1 = Ty.ShapeVar dd1
              val f = field(k', d', Ty.ShapeExt(Ty.ShapeExt(d1, d'), d'))
              val h = field(k', d', d1)
              in
                [f] --> h
              end))

  (* transpose: tensor[a, b] -> tensor[b, a] and field#k(d)[a, b] -> field#k(d)[b, a] *)
    val fn_transpose_t = polyVar (N.fn_transpose, all([NK, NK],
          fn [Ty.DIM d1, Ty.DIM d2] =>
              [tensor[Ty.DimVar d1, Ty.DimVar d2]] --> tensor[Ty.DimVar d2, Ty.DimVar d1]))
    val fn_transpose_f = polyVar (N.fn_transpose, all([DK, NK, NK, NK],
          fn [Ty.DIFF k, Ty.DIM d, Ty.DIM a, Ty.DIM b] => let
              val k0 = Ty.DiffVar(k, 0)
              val d' = Ty.DimVar d
              val a' = Ty.DimVar a
              val b' = Ty.DimVar b
              val f = field(k0, d', Ty.Shape[a', b'])
              val h = field(k0, d', Ty.Shape[b', a'])
              in
                [f] --> h
              end))

  (* determinant; restricted to 2x2 and 3x3 matrices (and matrix fields) *)
    local
      fun detTy n = [matrix n] --> Ty.realTy
    in
    val fn_det_t2 = monoVar (N.fn_det, detTy N2)
    val fn_det_t3 = monoVar (N.fn_det, detTy N3)
    end

  (* determinant of 2x2 matrix fields: field#k(2)[2,2] -> field#k(2)[] *)
    val fn_det_f2 = polyVar (N.fn_det, all([DK],
          fn [Ty.DIFF k] => let
              val k0 = Ty.DiffVar(k, 0)
              val f = constField (k0, 2, [2,2])
              val s = constField (k0, 2, [])
              in
                [f] --> s
              end))

  (* determinant of 3x3 matrix fields: field#k(3)[3,3] -> field#k(3)[] *)
    val fn_det_f3 = polyVar (N.fn_det, all([DK],
          fn [Ty.DIFF k] => let
              val k0 = Ty.DiffVar(k, 0)
              val f = constField (k0, 3, [3,3])
              val s = constField (k0, 3, [])
              in
                [f] --> s
              end))

  (* square root of a scalar field *)
    val fn_sqrt_f = polyVar (N.fn_sqrt, all([DK, NK],
          fn [Ty.DIFF k, Ty.DIM d] => let
              val k' = Ty.DiffVar(k, 0)
              val d' = Ty.DimVar d
              val f = field(k', d', Ty.Shape[])
              in
                [f] --> f
              end))

  (* square root of a real; expressed as a type scheme with no meta variables *)
    val fn_sqrt_t = polyVar (N.fn_sqrt, all([],
          fn [] => let
              val t = Ty.realTy
              in
                [t] --> t
              end))

  (* trigonometric functions on scalar fields: field#k(d)[] -> field#k(d)[] *)
    val fn_cos_f = polyVar (N.fn_cos, all([DK, NK],
          fn [Ty.DIFF k, Ty.DIM d] => let
              val k' = Ty.DiffVar(k, 0)
              val d' = Ty.DimVar d
              val f = field(k', d', Ty.Shape[])
              in
                [f] --> f
              end))
    val fn_acos_f = polyVar (N.fn_acos, all([DK, NK],
          fn [Ty.DIFF k, Ty.DIM d] => let
              val k' = Ty.DiffVar(k, 0)
              val d' = Ty.DimVar d
              val f = field(k', d', Ty.Shape[])
              in
                [f] --> f
              end))

    val fn_sin_f = polyVar (N.fn_sin, all([DK, NK],
          fn [Ty.DIFF k, Ty.DIM d] => let
              val k' = Ty.DiffVar(k, 0)
              val d' = Ty.DimVar d
              val f = field(k', d', Ty.Shape[])
              in
                [f] --> f
              end))

    val fn_asin_f = polyVar (N.fn_asin, all([DK, NK],
          fn [Ty.DIFF k, Ty.DIM d] => let
              val k' = Ty.DiffVar(k, 0)
              val d' = Ty.DimVar d
              val f = field(k', d', Ty.Shape[])
              in
                [f] --> f
              end))

    (*Post branch split*)

    val fn_tan_f = polyVar (N.fn_tan, all([DK, NK],
          fn [Ty.DIFF k, Ty.DIM d] => let
              val k' = Ty.DiffVar(k, 0)
              val d' = Ty.DimVar d
              val f = field(k', d', Ty.Shape[])
              in
                [f] --> f
              end))

    val fn_atan_f = polyVar (N.fn_atan, all([DK, NK],
          fn [Ty.DIFF k, Ty.DIM d] => let
              val k' = Ty.DiffVar(k, 0)
              val d' = Ty.DimVar d
              val f = field(k', d', Ty.Shape[])
              in
                [f] --> f
              end))

  (* exponential on scalar fields and on reals *)
    val fn_exp_f = polyVar (N.fn_exp, all([DK, NK],
          fn [Ty.DIFF k, Ty.DIM d] => let
              val k' = Ty.DiffVar(k, 0)
              val d' = Ty.DimVar d
              val f = field(k', d', Ty.Shape[])
              in
                [f] --> f
              end))

    val fn_exp_t = monoVar(N.fn_exp, [Ty.realTy] --> Ty.realTy)

  (* power of a scalar field with an integer exponent *)
    val fn_pow_f = polyVar (N.fn_pow_f, all([DK, NK],
          fn [Ty.DIFF k, Ty.DIM d] => let
              val k' = Ty.DiffVar(k, 0)
              val d' = Ty.DimVar d
              val f = field(k', d', Ty.Shape[])
              in
                [f, Ty.T_Int] --> f
              end))

  (* kernels *)
(* FIXME: we should really get the continuity info from the kernels themselves *)
    val kn_bspln3 = monoVar (N.kn_bspln3, Ty.T_Kernel(Ty.DiffConst 2))
    val kn_bspln5 = monoVar (N.kn_bspln5, Ty.T_Kernel(Ty.DiffConst 4))
    val kn_c4hexic = monoVar (N.kn_c4hexic, Ty.T_Kernel(Ty.DiffConst 4))
    val kn_ctmr = monoVar (N.kn_ctmr, Ty.T_Kernel(Ty.DiffConst 1))
    val kn_tent = monoVar (N.kn_tent, Ty.T_Kernel(Ty.DiffConst 0))
  (* kernels with false claims of differentiability, for pedagogy *)
    val kn_c1tent = monoVar (N.kn_c1tent, Ty.T_Kernel(Ty.DiffConst 1))
    val kn_c2ctmr = monoVar (N.kn_c2ctmr, Ty.T_Kernel(Ty.DiffConst 2))

  (***** internal variables *****)

  (* integer to real conversion *)
    val i2r = monoVar (Atom.atom "$i2r", [Ty.T_Int] --> Ty.realTy)

  (* identity matrix *)
    val identity = polyVar (Atom.atom "$id", allNK (fn dv => [] --> matrix(Ty.DimVar dv)))

  (* zero tensor *)
    val zero = polyVar (Atom.atom "$zero", all ([SK],
            fn [Ty.SHAPE dd] => [] --> Ty.T_Tensor(Ty.ShapeVar dd)))

  (* sequence subscript *)
    val subscript = polyVar (Atom.atom "$sub", all ([TK, NK],
            fn [Ty.TYPE tv, Ty.DIM d] =>
              [Ty.T_Sequence(Ty.T_Var tv, Ty.DimVar d), Ty.T_Int] --> Ty.T_Var tv))
    end (* local *)

  end

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0