Home My Page Projects Code Snippets Project Openings diderot

# SCM Repository

[diderot] Diff of /branches/vis15/src/compiler/basis/basis-vars.sml
 [diderot] / branches / vis15 / src / compiler / basis / basis-vars.sml

# Diff of /branches/vis15/src/compiler/basis/basis-vars.sml

revision 3481, Fri Dec 4 21:59:49 2015 UTC revision 3482, Sat Dec 5 14:43:53 2015 UTC
# Line 234  Line 234
234      val dist3_t  = monoVar (N.fn_dist, vec3Ty)      val dist3_t  = monoVar (N.fn_dist, vec3Ty)
235      end (* local *)      end (* local *)
236
237    (* exponentiation; we distinguish between integer and real exponents to allow x^2 to be compiled    (* power; we distinguish between integer and real exponents to allow x^2 to be compiled
238     * as x*x.     * as x*x.  The power operation of fields is restricted by the typechecker to constant
239       * integer arguments.
240     *)     *)
241      val exp_ri = monoVar(N.op_exp, [Ty.realTy, Ty.T_Int] --> Ty.realTy)      val pow_ri = monoVar(N.op_pow, [Ty.realTy, Ty.T_Int] --> Ty.realTy)
242      val exp_rr = monoVar(N.op_exp, [Ty.realTy, Ty.realTy] --> Ty.realTy)      val pow_rr = monoVar(N.op_pow, [Ty.realTy, Ty.realTy] --> Ty.realTy)
243        val pow_si = polyVar (N.op_pow, all([DK, NK], fn [Ty.DIFF k, Ty.DIM d] => let
244              val k = Ty.DiffVar(k, 0)
245              val d = Ty.DimVar d
246              val fld = field(k, d, Ty.Shape[])
247              in
248                [fld, Ty.T_Int] --> fld
249              end))
250
251      val convolve_vk = polyVar (N.op_convolve, all([DK, NK, SK], fn [Ty.DIFF k, Ty.DIM d, Ty.SHAPE dd] => let      val convolve_vk = polyVar (N.op_convolve, all([DK, NK, SK], fn [Ty.DIFF k, Ty.DIM d, Ty.SHAPE dd] => let
252              val k = Ty.DiffVar(k, 0)              val k = Ty.DiffVar(k, 0)
# Line 291  Line 299
299      val neq_rr = monoVar(N.op_neq, [Ty.realTy, Ty.realTy] --> Ty.T_Bool)      val neq_rr = monoVar(N.op_neq, [Ty.realTy, Ty.realTy] --> Ty.T_Bool)
300
301      val neg_i = monoVar(N.op_neg, [Ty.T_Int] --> Ty.T_Int)      val neg_i = monoVar(N.op_neg, [Ty.T_Int] --> Ty.T_Int)
302      val neg_t = polyVar(N.op_neg, all([SK],      val neg_t = polyVar(N.op_neg, all([SK], fn [Ty.SHAPE dd] => let
fn [Ty.SHAPE dd] => let
303                val t = Ty.T_Tensor(Ty.ShapeVar dd)                val t = Ty.T_Tensor(Ty.ShapeVar dd)
304                in                in
305                  [t] --> t                  [t] --> t
306                end))                end))
307      val neg_f = polyVar(N.op_neg, all([DK, NK, SK],      val neg_f = polyVar(N.op_neg, all([DK, NK, SK], fn [Ty.DIFF k, Ty.DIM d, Ty.SHAPE dd] => let
fn [Ty.DIFF k, Ty.DIM d, Ty.SHAPE dd] => let
308                val k = Ty.DiffVar(k, 0)                val k = Ty.DiffVar(k, 0)
309                val d = Ty.DimVar d                val d = Ty.DimVar d
310                val dd = Ty.ShapeVar dd                val dd = Ty.ShapeVar dd
# Line 314  Line 320
320              [t, t, t] --> t              [t, t, t] --> t
321            end))            end))
322
323      val lerp3 = polyVar(N.fn_lerp, all([SK],      val lerp3 = polyVar(N.fn_lerp, all([SK], fn [Ty.SHAPE dd] => let
fn [Ty.SHAPE dd] => let
324                val t = Ty.T_Tensor(Ty.ShapeVar dd)                val t = Ty.T_Tensor(Ty.ShapeVar dd)
325                in                in
326                  [t, t, Ty.realTy] --> t                  [t, t, Ty.realTy] --> t
327                end))                end))
328      val lerp5 = polyVar(N.fn_lerp, all([SK],      val lerp5 = polyVar(N.fn_lerp, all([SK], fn [Ty.SHAPE dd] => let
fn [Ty.SHAPE dd] => let
329                val t = Ty.T_Tensor(Ty.ShapeVar dd)                val t = Ty.T_Tensor(Ty.ShapeVar dd)
330                in                in
331                  [t, t, Ty.realTy, Ty.realTy, Ty.realTy] --> t                  [t, t, Ty.realTy, Ty.realTy, Ty.realTy] --> t
# Line 581  Line 585
585      val fn_length = polyVar (N.fn_length, all([TK],      val fn_length = polyVar (N.fn_length, all([TK],
586              fn [Ty.TYPE tv] => [dynSeq(Ty.T_Var tv)] --> Ty.T_Int))              fn [Ty.TYPE tv] => [dynSeq(Ty.T_Var tv)] --> Ty.T_Int))
587
588      val fn_max = monoVar (N.fn_max, [Ty.realTy, Ty.realTy] --> Ty.realTy)      val fn_abs_i = monoVar (N.fn_abs, [Ty.T_Int] --> Ty.T_Int)
589      val fn_min = monoVar (N.fn_min, [Ty.realTy, Ty.realTy] --> Ty.realTy)      val fn_abs_r = monoVar (N.fn_abs, [Ty.realTy] --> Ty.realTy)
590        val fn_max_i = monoVar (N.fn_max, [Ty.T_Int, Ty.T_Int] --> Ty.T_Int)
591        val fn_max_r = monoVar (N.fn_max, [Ty.realTy, Ty.realTy] --> Ty.realTy)
592        val fn_min_i = monoVar (N.fn_min, [Ty.T_Int, Ty.T_Int] --> Ty.T_Int)
593        val fn_min_r = monoVar (N.fn_min, [Ty.realTy, Ty.realTy] --> Ty.realTy)
594
595      val fn_modulate = polyVar (N.fn_modulate, all([NK],      val fn_modulate = polyVar (N.fn_modulate, all([NK],
596            fn [Ty.DIM d] => let            fn [Ty.DIM d] => let
# Line 714  Line 722
722                [f] --> s                [f] --> s
723              end))              end))
724
725    (* lifted unary math functions *)    (* lifted unary math functions; these have both real and scalar-field forms *)
726      local      local
727        fun fn_t name = monoVar (name, [Ty.realTy] --> Ty.realTy)        fun fn_r name = monoVar (name, [Ty.realTy] --> Ty.realTy)
728        fun fn_f name = polyVar (N.fn_sqrt, all([DK,NK], fn [Ty.DIFF k, Ty.DIM d] => let        fun fn_s name = polyVar (N.fn_sqrt, all([DK,NK], fn [Ty.DIFF k, Ty.DIM d] => let
729              val k' = Ty.DiffVar(k, 0)              val k' = Ty.DiffVar(k, 0)
730              val d' = Ty.DimVar d              val d' = Ty.DimVar d
731              val f = field(k', d', Ty.Shape[])              val f = field(k', d', Ty.Shape[])
# Line 725  Line 733
733                [f] --> f                [f] --> f
734              end))              end))
735      in      in
736      val fn_sqrt_t = fn_t N.fn_sqrt      val fn_sqrt_r = fn_r N.fn_sqrt
737      val fn_sqrt_f = fn_f N.fn_sqrt      val fn_sqrt_s = fn_s N.fn_sqrt
738      val fn_cos_t = fn_t N.fn_cos      val fn_cos_r  = fn_r N.fn_cos
739      val fn_cos_f = fn_f N.fn_cos      val fn_cos_s  = fn_s N.fn_cos
740      val fn_acos_t = fn_t N.fn_acos      val fn_acos_r = fn_r N.fn_acos
741      val fn_acos_f = fn_f N.fn_acos      val fn_acos_s = fn_s N.fn_acos
742      val fn_sin_t = fn_t N.fn_sin      val fn_sin_r  = fn_r N.fn_sin
743      val fn_sin_f = fn_f N.fn_sin      val fn_sin_s  = fn_s N.fn_sin
744      val fn_asin_t = fn_t N.fn_asin      val fn_asin_r = fn_r N.fn_asin
745      val fn_asin_f = fn_f N.fn_asin      val fn_asin_s = fn_s N.fn_asin
746      val fn_tan_t = fn_t N.fn_tan      val fn_tan_r  = fn_r N.fn_tan
747      val fn_tan_f = fn_f N.fn_tan      val fn_tan_s  = fn_s N.fn_tan
748      val fn_atan_t = fn_t N.fn_atan      val fn_atan_r = fn_r N.fn_atan
749      val fn_atan_f = fn_f N.fn_atan      val fn_atan_s = fn_s N.fn_atan
750        val fn_exp_r  = fn_r N.fn_exp
751        val fn_exp_s  = fn_s N.fn_exp
752      end (* local *)      end (* local *)
753
754    (* Math functions that have not yet been lifted to work on fields *)    (* Math functions that have not yet been lifted to work on fields *)
755      local      local
756        fun mk (name, n) =        fun mk (name, n) =
757              monoVar(Atom.atom name, List.tabulate(n, fn _ => Ty.realTy) --> Ty.realTy)              monoVar(name, List.tabulate(n, fn _ => Ty.realTy) --> Ty.realTy)
758      in      in
759      val fn_atan2_tt = mk("atan2", 2)      val fn_atan2_rr = mk(N.fn_atan2, 2)
760      val fn_ceil_t = mk("ceil", 1)      val fn_ceil_r   = mk(N.fn_ceil, 1)
761      val fn_floor_t = mk("floor", 1)      val fn_floor_r  = mk(N.fn_floor, 1)
762      val fn_fmod_tt = mk("fmod", 2)      val fn_fmod_rr  = mk(N.fn_fmod, 2)
763      val fn_exp_t = mk("exp", 1)      val fn_erf_r    = mk(N.fn_erf, 1)
764      val fn_erf_t = mk("erf", 1)      val fn_erfc_r   = mk(N.fn_erfc, 1)
765      val fn_erfc_t = mk("erfc", 1)      val fn_log_r    = mk(N.fn_log, 1)
766      val fn_log_t = mk("log", 1)      val fn_log10_r  = mk(N.fn_log10, 1)
767      val fn_log10_t = mk("log10", 1)      val fn_log2_r   = mk(N.fn_log2, 1)
768      val fn_log2_t = mk("log2", 1)      val fn_pow_rr   = mk(N.op_pow, 2)  (* also used to implement ^ operator *)
val fn_pow_tt = mk("pow", 2)  (* also used to implement ^ operator *)
769      end (* local *)      end (* local *)
770
771    (* Query functions *)    (* Query functions *)
# Line 796  Line 805
805      in      in
806      val red_all         = reduction (N.fn_all, Ty.T_Bool)      val red_all         = reduction (N.fn_all, Ty.T_Bool)
807      val red_exists      = reduction (N.fn_exists, Ty.T_Bool)      val red_exists      = reduction (N.fn_exists, Ty.T_Bool)
808    (* FIXME: allow max on integers *)
809      val red_max         = reduction (N.fn_max, Ty.realTy)      val red_max         = reduction (N.fn_max, Ty.realTy)
810      val red_mean        = reduction (N.fn_mean, Ty.realTy)      val red_mean        = reduction (N.fn_mean, Ty.realTy)
811    (* FIXME: allow min on integers *)
812      val red_min         = reduction (N.fn_min, Ty.realTy)      val red_min         = reduction (N.fn_min, Ty.realTy)
813      val red_product     = reduction (N.fn_product, Ty.realTy)      val red_product     = reduction (N.fn_product, Ty.realTy)
814  (* FIXME: allow sum on int and tensor types *)  (* FIXME: allow sum on int and tensor types *)

Legend:
 Removed from v.3481 changed lines Added in v.3482