Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] Diff of /branches/vis15/src/compiler/basis/basis-vars.sml
ViewVC logotype

Diff of /branches/vis15/src/compiler/basis/basis-vars.sml

Parent Directory Parent Directory | Revision Log Revision Log | View Patch Patch

revision 3407, Wed Nov 11 18:53:18 2015 UTC revision 3511, Fri Dec 18 17:43:38 2015 UTC
# Line 217  Line 217 
217                [f, s] --> f                [f, s] --> f
218              end))              end))
219    
220    (* vector distance function *)    (* power; we distinguish between integer and real exponents to allow x^2 to be compiled
221      local     * as x*x.  The power operation of fields is restricted by the typechecker to constant
222        val vec2Ty = let     * integer arguments.
             val t = tensor[N2]  
             in  
               [t, t] --> Ty.realTy  
             end  
       val vec3Ty = let  
             val t = tensor[N3]  
             in  
               [t, t] --> Ty.realTy  
             end  
     in  
     val dist2_t  = monoVar (N.fn_dist, vec2Ty)  
     val dist3_t  = monoVar (N.fn_dist, vec3Ty)  
     end (* local *)  
   
   (* exponentiation; we distinguish between integer and real exponents to allow x^2 to be compiled  
    * as x*x.  
223     *)     *)
224      val exp_ri = monoVar(N.op_exp, [Ty.realTy, Ty.T_Int] --> Ty.realTy)      val pow_ri = monoVar(N.op_pow, [Ty.realTy, Ty.T_Int] --> Ty.realTy)
225      val exp_rr = monoVar(N.op_exp, [Ty.realTy, Ty.realTy] --> Ty.realTy)      val pow_rr = monoVar(N.op_pow, [Ty.realTy, Ty.realTy] --> Ty.realTy)
226        val pow_si = polyVar (N.op_pow, all([DK, NK], fn [Ty.DIFF k, Ty.DIM d] => let
227              val k = Ty.DiffVar(k, 0)
228              val d = Ty.DimVar d
229              val fld = field(k, d, Ty.Shape[])
230              in
231                [fld, Ty.T_Int] --> fld
232              end))
233    
234      val convolve_vk = polyVar (N.op_convolve, all([DK, NK, SK], fn [Ty.DIFF k, Ty.DIM d, Ty.SHAPE dd] => let      val convolve_vk = polyVar (N.op_convolve, all([DK, NK, SK], fn [Ty.DIFF k, Ty.DIM d, Ty.SHAPE dd] => let
235              val k = Ty.DiffVar(k, 0)              val k = Ty.DiffVar(k, 0)
# Line 291  Line 282 
282      val neq_rr = monoVar(N.op_neq, [Ty.realTy, Ty.realTy] --> Ty.T_Bool)      val neq_rr = monoVar(N.op_neq, [Ty.realTy, Ty.realTy] --> Ty.T_Bool)
283    
284      val neg_i = monoVar(N.op_neg, [Ty.T_Int] --> Ty.T_Int)      val neg_i = monoVar(N.op_neg, [Ty.T_Int] --> Ty.T_Int)
285      val neg_t = polyVar(N.op_neg, all([SK],      val neg_t = polyVar(N.op_neg, all([SK], fn [Ty.SHAPE dd] => let
           fn [Ty.SHAPE dd] => let  
286                val t = Ty.T_Tensor(Ty.ShapeVar dd)                val t = Ty.T_Tensor(Ty.ShapeVar dd)
287                in                in
288                  [t] --> t                  [t] --> t
289                end))                end))
290      val neg_f = polyVar(N.op_neg, all([DK, NK, SK],      val neg_f = polyVar(N.op_neg, all([DK, NK, SK], fn [Ty.DIFF k, Ty.DIM d, Ty.SHAPE dd] => let
           fn [Ty.DIFF k, Ty.DIM d, Ty.SHAPE dd] => let  
291                val k = Ty.DiffVar(k, 0)                val k = Ty.DiffVar(k, 0)
292                val d = Ty.DimVar d                val d = Ty.DimVar d
293                val dd = Ty.ShapeVar dd                val dd = Ty.ShapeVar dd
# Line 314  Line 303 
303              [t, t, t] --> t              [t, t, t] --> t
304            end))            end))
305    
306      val lerp3 = polyVar(N.fn_lerp, all([SK],      val lerp3 = polyVar(N.fn_lerp, all([SK], fn [Ty.SHAPE dd] => let
           fn [Ty.SHAPE dd] => let  
307                val t = Ty.T_Tensor(Ty.ShapeVar dd)                val t = Ty.T_Tensor(Ty.ShapeVar dd)
308                in                in
309                  [t, t, Ty.realTy] --> t                  [t, t, Ty.realTy] --> t
310                end))                end))
311      val lerp5 = polyVar(N.fn_lerp, all([SK],      val lerp5 = polyVar(N.fn_lerp, all([SK], fn [Ty.SHAPE dd] => let
           fn [Ty.SHAPE dd] => let  
312                val t = Ty.T_Tensor(Ty.ShapeVar dd)                val t = Ty.T_Tensor(Ty.ShapeVar dd)
313                in                in
314                  [t, t, Ty.realTy, Ty.realTy, Ty.realTy] --> t                  [t, t, Ty.realTy, Ty.realTy, Ty.realTy] --> t
# Line 401  Line 388 
388                [f1] --> f2                [f1] --> f2
389              end))              end))
390    
391      (* boolean operators; 'and' and 'or' are used to implement reductions *)
392        val op_and = monoVar (Atom.atom "$and", [Ty.T_Bool, Ty.T_Bool] --> Ty.T_Bool)
393        val op_or = monoVar (Atom.atom "$or", [Ty.T_Bool, Ty.T_Bool] --> Ty.T_Bool)
394      val op_not = monoVar (N.op_not, [Ty.T_Bool] --> Ty.T_Bool)      val op_not = monoVar (N.op_not, [Ty.T_Bool] --> Ty.T_Bool)
395    
396    (* cross product *)    (* cross product *)
# Line 578  Line 568 
568      val fn_length = polyVar (N.fn_length, all([TK],      val fn_length = polyVar (N.fn_length, all([TK],
569              fn [Ty.TYPE tv] => [dynSeq(Ty.T_Var tv)] --> Ty.T_Int))              fn [Ty.TYPE tv] => [dynSeq(Ty.T_Var tv)] --> Ty.T_Int))
570    
571      val fn_max = monoVar (N.fn_max, [Ty.realTy, Ty.realTy] --> Ty.realTy)      val fn_abs_i = monoVar (N.fn_abs, [Ty.T_Int] --> Ty.T_Int)
572      val fn_min = monoVar (N.fn_min, [Ty.realTy, Ty.realTy] --> Ty.realTy)      val fn_abs_r = monoVar (N.fn_abs, [Ty.realTy] --> Ty.realTy)
573        val fn_max_i = monoVar (N.fn_max, [Ty.T_Int, Ty.T_Int] --> Ty.T_Int)
574        val fn_max_r = monoVar (N.fn_max, [Ty.realTy, Ty.realTy] --> Ty.realTy)
575        val fn_min_i = monoVar (N.fn_min, [Ty.T_Int, Ty.T_Int] --> Ty.T_Int)
576        val fn_min_r = monoVar (N.fn_min, [Ty.realTy, Ty.realTy] --> Ty.realTy)
577    
578      val fn_modulate = polyVar (N.fn_modulate, all([NK],      val fn_modulate = polyVar (N.fn_modulate, all([NK],
579            fn [Ty.DIM d] => let            fn [Ty.DIM d] => let
# Line 588  Line 582 
582                [t, t] --> t                [t, t] --> t
583              end))              end))
584    
585      val fn_normalize_t = polyVar (N.fn_normalize, all([NK],      val fn_normalize_t = polyVar (N.fn_normalize, all([SK],
586            fn [Ty.DIM d] => let            fn [Ty.SHAPE dd] => let
587              val t = Ty.T_Tensor(Ty.Shape[Ty.DimVar d])              val t = Ty.T_Tensor(Ty.ShapeVar dd)
588              in              in
589                [t] --> t                [t] --> t
590              end))              end))
# Line 617  Line 611 
611      end (* local *)      end (* local *)
612    
613      local      local
614          fun mkOuter2 [Ty.DIM d1, Ty.DIM d2, Ty.DIM d3] = let
615                val vt1 = Ty.T_Tensor(Ty.Shape[Ty.DimVar d1])
616                val vt2 = Ty.T_Tensor(Ty.Shape[Ty.DimVar d2, Ty.DimVar d3])
617                val mt = Ty.T_Tensor(Ty.Shape[Ty.DimVar d1, Ty.DimVar d2, Ty.DimVar d3])
618                in
619                  [vt1, vt2] --> mt
620                end
621        in
622          val op_outer_tm = polyVar (N.op_outer, all([NK, NK, NK], mkOuter2))
623        end (* local *)
624    
625        local
626          fun mkOuter3 [Ty.DIM d1, Ty.DIM d2, Ty.DIM d3] = let
627                val vt1 = Ty.T_Tensor(Ty.Shape[Ty.DimVar d1,Ty.DimVar d2])
628                val vt2 = Ty.T_Tensor(Ty.Shape[Ty.DimVar d3])
629                val mt = Ty.T_Tensor(Ty.Shape[Ty.DimVar d1, Ty.DimVar d2,Ty.DimVar d3])
630                in
631                  [vt1, vt2] --> mt
632                end
633        in
634          val op_outer_mt = polyVar (N.op_outer, all([NK, NK,NK], mkOuter3))
635        end (* local *)
636    
637        local
638        fun mkOuterField [Ty.DIFF k, Ty.DIM d,Ty.DIM a, Ty.DIM b] = let        fun mkOuterField [Ty.DIFF k, Ty.DIM d,Ty.DIM a, Ty.DIM b] = let
639              val k0 = Ty.DiffVar(k, 0)              val k0 = Ty.DiffVar(k, 0)
640              val d' = Ty.DimVar d              val d' = Ty.DimVar d
# Line 687  Line 705 
705                [f] --> s                [f] --> s
706              end))              end))
707    
708    (* sqrt *)    (* lifted unary math functions; these have both real and scalar-field forms *)
709      val fn_sqrt_t = monoVar (N.fn_sqrt, [Ty.realTy] --> Ty.realTy)      local
710      val fn_sqrt_f = polyVar (N.fn_sqrt, all([DK,NK], fn [Ty.DIFF k, Ty.DIM d] => let        fun fn_r name = monoVar (name, [Ty.realTy] --> Ty.realTy)
711            val k' = Ty.DiffVar(k, 0)        fun fn_s name = polyVar (N.fn_sqrt, all([DK,NK], fn [Ty.DIFF k, Ty.DIM d] => let
           val d' = Ty.DimVar d  
           val f = field(k', d', Ty.Shape[])  
           in  
             [f] --> f  
           end))  
   
   (* cosine *)  
     val fn_cos_t = monoVar (N.fn_cos, [Ty.realTy] --> Ty.realTy)  
     val fn_cos_f = polyVar (N.fn_cos, all([DK,NK], fn [Ty.DIFF k, Ty.DIM d] => let  
             val k' = Ty.DiffVar(k, 0)  
             val d' = Ty.DimVar d  
             val f = field(k', d', Ty.Shape[])  
             in  
               [f] --> f  
             end))  
   
   (* arc cosine *)  
     val fn_acos_t = monoVar (N.fn_acos, [Ty.realTy] --> Ty.realTy)  
     val fn_acos_f = polyVar (N.fn_acos, all([DK,NK], fn [Ty.DIFF k, Ty.DIM d] => let  
             val k' = Ty.DiffVar(k, 0)  
             val d' = Ty.DimVar d  
             val f = field(k', d', Ty.Shape[])  
             in  
               [f] --> f  
             end))  
   
   (* sine *)  
     val fn_sin_t = monoVar (N.fn_sin, [Ty.realTy] --> Ty.realTy)  
     val fn_sin_f = polyVar (N.fn_sin, all([DK,NK], fn [Ty.DIFF k, Ty.DIM d] => let  
712              val k' = Ty.DiffVar(k, 0)              val k' = Ty.DiffVar(k, 0)
713              val d' = Ty.DimVar d              val d' = Ty.DimVar d
714              val f = field(k', d', Ty.Shape[])              val f = field(k', d', Ty.Shape[])
715              in              in
716                [f] --> f                [f] --> f
717              end))              end))
   
   (* arc sine *)  
     val fn_asin_t = monoVar (N.fn_asin, [Ty.realTy] --> Ty.realTy)  
     val fn_asin_f = polyVar (N.fn_asin, all([DK,NK], fn [Ty.DIFF k, Ty.DIM d] => let  
             val k' = Ty.DiffVar(k, 0)  
             val d' = Ty.DimVar d  
             val f = field(k', d', Ty.Shape[])  
718              in              in
719                [f] --> f      val fn_sqrt_r = fn_r N.fn_sqrt
720              end))      val fn_sqrt_s = fn_s N.fn_sqrt
721        val fn_cos_r  = fn_r N.fn_cos
722        val fn_cos_s  = fn_s N.fn_cos
723        val fn_acos_r = fn_r N.fn_acos
724        val fn_acos_s = fn_s N.fn_acos
725        val fn_sin_r  = fn_r N.fn_sin
726        val fn_sin_s  = fn_s N.fn_sin
727        val fn_asin_r = fn_r N.fn_asin
728        val fn_asin_s = fn_s N.fn_asin
729        val fn_tan_r  = fn_r N.fn_tan
730        val fn_tan_s  = fn_s N.fn_tan
731        val fn_atan_r = fn_r N.fn_atan
732        val fn_atan_s = fn_s N.fn_atan
733        val fn_exp_r  = fn_r N.fn_exp
734        val fn_exp_s  = fn_s N.fn_exp
735        end (* local *)
736    
737    (* Math functions that have not yet been lifted to work on fields *)    (* Math functions that have not yet been lifted to work on fields *)
738      local      local
739        fun mk (name, n) =        fun mk (name, n) =
740              monoVar(Atom.atom name, List.tabulate(n, fn _ => Ty.realTy) --> Ty.realTy)              monoVar(name, List.tabulate(n, fn _ => Ty.realTy) --> Ty.realTy)
741      in      in
742      val fn_atan_t = mk("atan", 1)      val fn_atan2_rr = mk(N.fn_atan2, 2)
743      val fn_atan2_tt = mk("atan2", 2)      val fn_ceil_r   = mk(N.fn_ceil, 1)
744      val fn_ceil_t = mk("ceil", 1)      val fn_floor_r  = mk(N.fn_floor, 1)
745      val fn_floor_t = mk("floor", 1)      val fn_fmod_rr  = mk(N.fn_fmod, 2)
746      val fn_fmod_tt = mk("fmod", 2)      val fn_erf_r    = mk(N.fn_erf, 1)
747      val fn_exp_t = mk("exp", 1)      val fn_erfc_r   = mk(N.fn_erfc, 1)
748      val fn_erf_t = mk("erf", 1)      val fn_log_r    = mk(N.fn_log, 1)
749      val fn_erfc_t = mk("erfc", 1)      val fn_log10_r  = mk(N.fn_log10, 1)
750      val fn_log_t = mk("log", 1)      val fn_log2_r   = mk(N.fn_log2, 1)
751      val fn_log10_t = mk("log10", 1)      val fn_pow_rr   = mk(N.fn_pow, 2)
     val fn_log2_t = mk("log2", 1)  
     val fn_pow_tt = mk("pow", 2)  (* also used to implement ^ operator *)  
     val fn_tan_t = mk("tan", 1)  
752      end (* local *)      end (* local *)
753    
754    (* Query functions *)    (* Query functions *)
# Line 778  Line 772 
772      val fn_sphere3_t = polyVar (N.fn_sphere, all([TK], vec3Ty))      val fn_sphere3_t = polyVar (N.fn_sphere, all([TK], vec3Ty))
773      end (* local *)      end (* local *)
774    
775      (* vector distance function *)
776        local
777          val vec2Ty = let
778                val t = tensor[N2]
779                in
780                  [t, t] --> Ty.realTy
781                end
782          val vec3Ty = let
783                val t = tensor[N3]
784                in
785                  [t, t] --> Ty.realTy
786                end
787        in
788        val dist2_t  = monoVar (N.fn_dist, vec2Ty)
789        val dist3_t  = monoVar (N.fn_dist, vec3Ty)
790        end (* local *)
791    
792      (* Sets of strands *)
793        local
794          fun mkSetFn name = polyVar (name, all([TK], fn [Ty.TYPE tv] => [] --> dynSeq(Ty.T_Var tv)))
795        in
796        val set_active = mkSetFn N.set_active
797        val set_all    = mkSetFn N.set_all
798        val set_stable = mkSetFn N.set_stable
799        end
800    
801    (* reduction operators *)    (* reduction operators *)
802      local      local
803        fun reduction (name, elemTy) =        fun reduction (name, elemTy) =
# Line 785  Line 805 
805      in      in
806      val red_all         = reduction (N.fn_all, Ty.T_Bool)      val red_all         = reduction (N.fn_all, Ty.T_Bool)
807      val red_exists      = reduction (N.fn_exists, Ty.T_Bool)      val red_exists      = reduction (N.fn_exists, Ty.T_Bool)
808    (* FIXME: allow max on integers *)
809      val red_max         = reduction (N.fn_max, Ty.realTy)      val red_max         = reduction (N.fn_max, Ty.realTy)
810      val red_mean        = reduction (N.fn_mean, Ty.realTy)      val red_mean        = reduction (N.fn_mean, Ty.realTy)
811    (* FIXME: allow min on integers *)
812      val red_min         = reduction (N.fn_min, Ty.realTy)      val red_min         = reduction (N.fn_min, Ty.realTy)
813      val red_product     = reduction (N.fn_product, Ty.realTy)      val red_product     = reduction (N.fn_product, Ty.realTy)
814    (* FIXME: allow sum on int and tensor types *)
815      val red_sum         = reduction (N.fn_sum, Ty.realTy)      val red_sum         = reduction (N.fn_sum, Ty.realTy)
816      val red_variance    = reduction (N.fn_variance, Ty.realTy)      val red_variance    = reduction (N.fn_variance, Ty.realTy)
817      end (* local *)      end (* local *)

Legend:
Removed from v.3407  
changed lines
  Added in v.3511

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0