Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] View of /branches/vis15/src/compiler/basis/basis.sml
ViewVC logotype

View of /branches/vis15/src/compiler/basis/basis.sml

Parent Directory | Revision Log


Revision 3830 - (download) (annotate)
Thu May 5 22:13:46 2016 UTC (3 years, 2 months ago) by jhr
File size: 7924 byte(s)
  Working on merge: getting clamp and lerp sorted
(* basis.sml
 *
 * This code is part of the Diderot Project (http://diderot-language.cs.uchicago.edu)
 *
 * COPYRIGHT (c) 2015 The University of Chicago
 * All rights reserved.
 *
 * Defining the Diderot Basis environment.
 *)

structure Basis : sig

  (* create a fresh global environment seeded with the basis functions,
   * basis variables, and overloaded operators/functions listed in this
   * structure
   *)
    val env : unit -> GlobalEnv.t

  (* operations that are allowed in constant expressions *)
    val allowedInConstExp : AST.var -> bool

  (* reduction operators *)
    val isReductionOp : AST.var -> bool

  (* global sets of strands *)
    val isStrandSet : AST.var -> bool

  end = struct

    structure N = BasisNames
    structure BV = BasisVars
    structure GEnv = GlobalEnv

  (* build a variable set from a list of variables (duplicates collapse) *)
    fun mkSet xs = List.foldl Var.Set.add' Var.Set.empty xs

  (* non-overloaded operators, etc.; each is bound under its own name *)
    val basisFunctions = [
        (* non-overloaded operators *)
          BV.op_D,
          BV.op_Dotimes,
          BV.op_Ddot,
          BV.op_not,
        (* functions *)
          BV.image_border,
          BV.fn_inside,
          BV.fn_length,
          BV.image_mirror,
          BV.fn_modulate,
(* unimplemented
          BV.fn_principleEvec,
*)
          BV.fn_size,
          BV.image_wrap,
        (* reductions *)
          BV.red_all,
          BV.red_exists,
          BV.red_max,
          BV.red_mean,
          BV.red_min,
          BV.red_product,
          BV.red_sum,
          BV.red_variance,
        (* Math functions that have not yet been lifted to work on fields *)
          BV.fn_atan2_rr,
          BV.fn_ceil_r,
          BV.fn_floor_r,
          BV.fn_fmod_rr,
          BV.fn_erf_r,
          BV.fn_erfc_r,
          BV.fn_log_r,
          BV.fn_log10_r,
          BV.fn_log2_r,
          BV.fn_pow_rr
        ]

  (* predefined variables, bound under their own names *)
    val basisVars = [
        (* kernels *)
          BV.kn_bspln3,
          BV.kn_bspln5,
          BV.kn_c4hexic,
          BV.kn_ctmr,
          BV.kn_tent
        ]

  (* overloaded operators and functions: each name is paired with the list of
   * candidate primitives that it may resolve to.  Each candidate should appear
   * at most once per list.
   *)
    val overloads = [
        (* overloaded operators *)
          (N.op_at, [BV.at_Td, BV.at_dT, BV.at_dd]),
          (N.op_lte, [BV.lte_ii, BV.lte_rr]),
          (N.op_equ, [BV.equ_bb, BV.equ_ii, BV.equ_ss, BV.equ_rr]),
          (N.op_neq, [BV.neq_bb, BV.neq_ii, BV.neq_ss, BV.neq_rr]),
          (N.op_gte, [BV.gte_ii, BV.gte_rr]),
          (N.op_gt, [BV.gt_ii, BV.gt_rr]),
          (N.op_add, [BV.add_ii, BV.add_tt, BV.add_ff, BV.add_ft, BV.add_tf]),
          (N.op_sub, [BV.sub_ii, BV.sub_tt, BV.sub_ff, BV.sub_ft, BV.sub_tf]),
          (N.op_mul, [
              BV.mul_ii, BV.mul_rr, BV.mul_rt, BV.mul_tr, BV.mul_rf, BV.mul_fr,
              BV.mul_ss, BV.mul_sf, BV.mul_fs, BV.mul_st, BV.mul_ts
            ]),
          (N.op_div, [BV.div_ii, BV.div_rr, BV.div_tr, BV.div_fr, BV.div_ss, BV.div_fs]),
          (N.op_pow, [BV.pow_ri, BV.pow_rr, BV.pow_si]),
          (N.op_curl, [BV.curl2D, BV.curl3D]),
          (N.op_convolve, [BV.convolve_vk, BV.convolve_kv]),
          (N.op_lt, [BV.lt_ii, BV.lt_rr]),
          (N.op_neg, [BV.neg_i, BV.neg_t, BV.neg_f]),
          (N.op_cross, [BV.op_cross2_tt, BV.op_cross3_tt, BV.op_cross2_ff, BV.op_cross3_ff]),
          (N.op_norm, [BV.op_norm_t, BV.op_norm_f]),
        (* overloaded functions *)
          (N.fn_abs, [BV.fn_abs_i, BV.fn_abs_r]),
          (N.fn_acos, [BV.fn_acos_r, BV.fn_acos_s]),
          (N.fn_asin, [BV.fn_asin_r, BV.fn_asin_s]),
          (N.fn_atan, [BV.fn_atan_r, BV.fn_atan_s]),
          (N.fn_clamp, [BV.clamp_trr, BV.clamp_ttt, BV.image_clamp]),
          (N.fn_cos, [BV.fn_cos_r, BV.fn_cos_s]),
          (N.fn_det, [BV.fn_det2_t, BV.fn_det3_t, BV.fn_det2_f, BV.fn_det3_f]),
          (N.fn_dist, [BV.dist2_t, BV.dist3_t]),
          (N.fn_evals, [BV.evals2x2, BV.evals3x3]),
          (N.fn_evecs, [BV.evecs2x2, BV.evecs3x3]),
          (N.fn_exp, [BV.fn_exp_r, BV.fn_exp_s]),
          (N.fn_lerp, [BV.lerp5, BV.lerp3]),
          (N.fn_max, [BV.fn_max_i, BV.fn_max_r, BV.red_max]),
          (N.fn_min, [BV.fn_min_i, BV.fn_min_r, BV.red_min]),
          (N.fn_normalize, [BV.fn_normalize_t, BV.fn_normalize_f]),
          (N.fn_sin, [BV.fn_sin_r, BV.fn_sin_s]),
          (N.fn_sphere, [BV.fn_sphere_im, BV.fn_sphere1_r, BV.fn_sphere2_t, BV.fn_sphere3_t]),
          (N.fn_sqrt, [BV.fn_sqrt_r, BV.fn_sqrt_s]),
          (N.fn_tan, [BV.fn_tan_r, BV.fn_tan_s]),
          (N.fn_trace, [BV.fn_trace_t, BV.fn_trace_f]),
          (N.fn_transpose, [BV.fn_transpose_t, BV.fn_transpose_f]),
        (* assignment operators are bound to the corresponding binary operator;
         * the division list mirrors the multiplication list (int, real,
         * tensor/real, and field/real variants).
         *)
          (N.asgn_add, [BV.add_ii, BV.add_tt, BV.add_ff, BV.add_ft]),
          (N.asgn_sub, [BV.sub_ii, BV.sub_tt, BV.sub_ff, BV.sub_ft]),
          (N.asgn_mul, [BV.mul_ii, BV.mul_rr, BV.mul_tr, BV.mul_fr]),
          (N.asgn_div, [BV.div_ii, BV.div_rr, BV.div_tr, BV.div_fr]),
          (N.asgn_mod, [BV.op_mod])
        ]

  (* seed the basis environment.  Each call allocates a new global environment
   * and then inserts, in order: the non-overloaded functions (each as a
   * singleton primitive-function binding under its own name), the basis
   * variables, and the overloaded operator/function groups.
   *)
    fun env () = let
          val gEnv = GEnv.new()
          fun insF x = GEnv.insertFunc(gEnv, Atom.atom(Var.nameOf x), GEnv.PrimFun[x])
          fun insV x = GEnv.insertVar(gEnv, Atom.atom(Var.nameOf x), x)
          fun insOvld (f, fns) = GEnv.insertFunc(gEnv, f, GEnv.PrimFun fns)
          in
            List.app insF basisFunctions;
            List.app insV basisVars;
            List.app insOvld overloads;
            gEnv
          end

  (* operations that are allowed in constant expressions; we basically allow any operations
   * on integers, booleans, or tensors.  Operations on fields, images, sequences, or kernels
   * are not allowed.
   *)
    local
      val allowed = mkSet [
              BV.op_mod,
              BV.op_cross2_tt, BV.op_cross3_tt,
              BV.op_outer_tt,
              BV.op_norm_t,
              BV.op_not,
              BV.fn_abs_i, BV.fn_abs_r,
              BV.fn_max_i, BV.fn_max_r,
              BV.fn_min_i, BV.fn_min_r,
              BV.fn_modulate,
              BV.fn_normalize_t,
(* unimplemented
              BV.fn_principleEvec,
*)
              BV.fn_trace_t,
              BV.fn_transpose_t,
              BV.lte_ii, BV.lte_rr,
              BV.equ_bb, BV.equ_ii, BV.equ_ss, BV.equ_rr,
              BV.neq_bb, BV.neq_ii, BV.neq_ss, BV.neq_rr,
              BV.gte_ii, BV.gte_rr,
              BV.lt_ii, BV.lt_rr,
              BV.gt_ii, BV.gt_rr,
              BV.add_ii, BV.add_tt,
              BV.sub_ii, BV.sub_tt,
              BV.mul_ii, BV.mul_rr, BV.mul_rt, BV.mul_tr,
              BV.div_ii, BV.div_rr, BV.div_tr,
              BV.pow_ri, BV.pow_rr,
              BV.neg_i, BV.neg_t,
              BV.clamp_trr, BV.clamp_ttt,
              BV.lerp5, BV.lerp3,
              BV.fn_acos_r,
              BV.fn_asin_r,
              BV.fn_atan_r,
              BV.fn_atan2_rr,
              BV.fn_ceil_r,
              BV.fn_cos_r,
              BV.fn_erf_r,
              BV.fn_erfc_r,
              BV.fn_exp_r,
              BV.fn_floor_r,
              BV.fn_fmod_rr,
              BV.fn_log_r,
              BV.fn_log10_r,
              BV.fn_log2_r,
              BV.fn_sin_r,
              BV.fn_sqrt_r,
              BV.fn_tan_r
            ]
    in
    fun allowedInConstExp x = Var.Set.member (allowed, x)
    end (* local *)

  (* the reduction operators *)
    local
      val redOps = mkSet [
              BV.red_all,
              BV.red_exists,
              BV.red_max,
              BV.red_mean,
              BV.red_min,
              BV.red_product,
              BV.red_sum,
              BV.red_variance
            ]
    in
    fun isReductionOp x = Var.Set.member (redOps, x)
    end (* local *)

  (* the sets of strands are only allowed in global initialization and update blocks *)
    local
      val strandSets = mkSet [
              BV.set_active,
              BV.set_all,
              BV.set_stable
            ]
    in
    fun isStrandSet x = Var.Set.member (strandSets, x)
    end (* end local *)

  end

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0