Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] Diff of /branches/vis15/src/compiler/basis/basis-vars.sml
ViewVC logotype

Diff of /branches/vis15/src/compiler/basis/basis-vars.sml

Parent Directory Parent Directory | Revision Log Revision Log | View Patch Patch

revision 3477, Thu Dec 3 15:08:09 2015 UTC revision 3478, Thu Dec 3 17:26:24 2015 UTC
# Line 591  Line 591 
591                [t, t] --> t                [t, t] --> t
592              end))              end))
593    
594      val fn_normalize_t = polyVar (N.fn_normalize, all([NK],      val fn_normalize_t = polyVar (N.fn_normalize, all([SK],
595            fn [Ty.DIM d] => let            fn [Ty.SHAPE dd] => let
596              val t = Ty.T_Tensor(Ty.Shape[Ty.DimVar d])              val t = Ty.T_Tensor(Ty.ShapeVar dd)
597              in              in
598                [t] --> t                [t] --> t
599              end))              end))
# Line 620  Line 620 
620      end (* local *)      end (* local *)
621    
622      local      local
623          fun mkOuter2 [Ty.DIM d1, Ty.DIM d2, Ty.DIM d3] = let
624                val vt1 = Ty.T_Tensor(Ty.Shape[Ty.DimVar d1])
625                val vt2 = Ty.T_Tensor(Ty.Shape[Ty.DimVar d2, Ty.DimVar d3])
626                val mt = Ty.T_Tensor(Ty.Shape[Ty.DimVar d1, Ty.DimVar d2, Ty.DimVar d3])
627                in
628                  [vt1, vt2] --> mt
629                end
630        in
631          val op_outer_tm = polyVar (N.op_outer, all([NK, NK, NK], mkOuter2))
632        end (* local *)
633    
634        local
635          fun mkOuter3 [Ty.DIM d1, Ty.DIM d2, Ty.DIM d3] = let
636                val vt1 = Ty.T_Tensor(Ty.Shape[Ty.DimVar d1,Ty.DimVar d2])
637                val vt2 = Ty.T_Tensor(Ty.Shape[Ty.DimVar d3])
638                val mt = Ty.T_Tensor(Ty.Shape[Ty.DimVar d1, Ty.DimVar d2,Ty.DimVar d3])
639                in
640                  [vt1, vt2] --> mt
641                end
642        in
643          val op_outer_mt = polyVar (N.op_outer, all([NK, NK,NK], mkOuter3))
644        end (* local *)
645    
646        local
647        fun mkOuterField [Ty.DIFF k, Ty.DIM d,Ty.DIM a, Ty.DIM b] = let        fun mkOuterField [Ty.DIFF k, Ty.DIM d,Ty.DIM a, Ty.DIM b] = let
648              val k0 = Ty.DiffVar(k, 0)              val k0 = Ty.DiffVar(k, 0)
649              val d' = Ty.DimVar d              val d' = Ty.DimVar d
# Line 690  Line 714 
714                [f] --> s                [f] --> s
715              end))              end))
716    
717    (* sqrt *)    (* lifted unary math functions *)
718      val fn_sqrt_t = monoVar (N.fn_sqrt, [Ty.realTy] --> Ty.realTy)      local
719      val fn_sqrt_f = polyVar (N.fn_sqrt, all([DK,NK], fn [Ty.DIFF k, Ty.DIM d] => let        fun fn_t name = monoVar (name, [Ty.realTy] --> Ty.realTy)
720            val k' = Ty.DiffVar(k, 0)        fun fn_f name = polyVar (N.fn_sqrt, all([DK,NK], fn [Ty.DIFF k, Ty.DIM d] => let
           val d' = Ty.DimVar d  
           val f = field(k', d', Ty.Shape[])  
           in  
             [f] --> f  
           end))  
   
   (* cosine *)  
     val fn_cos_t = monoVar (N.fn_cos, [Ty.realTy] --> Ty.realTy)  
     val fn_cos_f = polyVar (N.fn_cos, all([DK,NK], fn [Ty.DIFF k, Ty.DIM d] => let  
             val k' = Ty.DiffVar(k, 0)  
             val d' = Ty.DimVar d  
             val f = field(k', d', Ty.Shape[])  
             in  
               [f] --> f  
             end))  
   
   (* arc cosine *)  
     val fn_acos_t = monoVar (N.fn_acos, [Ty.realTy] --> Ty.realTy)  
     val fn_acos_f = polyVar (N.fn_acos, all([DK,NK], fn [Ty.DIFF k, Ty.DIM d] => let  
             val k' = Ty.DiffVar(k, 0)  
             val d' = Ty.DimVar d  
             val f = field(k', d', Ty.Shape[])  
             in  
               [f] --> f  
             end))  
   
   (* sine *)  
     val fn_sin_t = monoVar (N.fn_sin, [Ty.realTy] --> Ty.realTy)  
     val fn_sin_f = polyVar (N.fn_sin, all([DK,NK], fn [Ty.DIFF k, Ty.DIM d] => let  
721              val k' = Ty.DiffVar(k, 0)              val k' = Ty.DiffVar(k, 0)
722              val d' = Ty.DimVar d              val d' = Ty.DimVar d
723              val f = field(k', d', Ty.Shape[])              val f = field(k', d', Ty.Shape[])
724              in              in
725                [f] --> f                [f] --> f
726              end))              end))
   
   (* arc sine *)  
     val fn_asin_t = monoVar (N.fn_asin, [Ty.realTy] --> Ty.realTy)  
     val fn_asin_f = polyVar (N.fn_asin, all([DK,NK], fn [Ty.DIFF k, Ty.DIM d] => let  
             val k' = Ty.DiffVar(k, 0)  
             val d' = Ty.DimVar d  
             val f = field(k', d', Ty.Shape[])  
727              in              in
728                [f] --> f      val fn_sqrt_t = fn_t N.fn_sqrt
729              end))      val fn_sqrt_f = fn_f N.fn_sqrt
730        val fn_cos_t = fn_t N.fn_cos
731        val fn_cos_f = fn_f N.fn_cos
732        val fn_acos_t = fn_t N.fn_acos
733        val fn_acos_f = fn_f N.fn_acos
734        val fn_sin_t = fn_t N.fn_sin
735        val fn_sin_f = fn_f N.fn_sin
736        val fn_asin_t = fn_t N.fn_asin
737        val fn_asin_f = fn_f N.fn_asin
738        val fn_tan_t = fn_t N.fn_tan
739        val fn_tan_f = fn_f N.fn_tan
740        val fn_atan_t = fn_t N.fn_atan
741        val fn_atan_f = fn_f N.fn_atan
742        end (* local *)
743    
744    (* Math functions that have not yet been lifted to work on fields *)    (* Math functions that have not yet been lifted to work on fields *)
745      local      local
746        fun mk (name, n) =        fun mk (name, n) =
747              monoVar(Atom.atom name, List.tabulate(n, fn _ => Ty.realTy) --> Ty.realTy)              monoVar(Atom.atom name, List.tabulate(n, fn _ => Ty.realTy) --> Ty.realTy)
748      in      in
     val fn_atan_t = mk("atan", 1)  
749      val fn_atan2_tt = mk("atan2", 2)      val fn_atan2_tt = mk("atan2", 2)
750      val fn_ceil_t = mk("ceil", 1)      val fn_ceil_t = mk("ceil", 1)
751      val fn_floor_t = mk("floor", 1)      val fn_floor_t = mk("floor", 1)
# Line 757  Line 757 
757      val fn_log10_t = mk("log10", 1)      val fn_log10_t = mk("log10", 1)
758      val fn_log2_t = mk("log2", 1)      val fn_log2_t = mk("log2", 1)
759      val fn_pow_tt = mk("pow", 2)  (* also used to implement ^ operator *)      val fn_pow_tt = mk("pow", 2)  (* also used to implement ^ operator *)
     val fn_tan_t = mk("tan", 1)  
760      end (* local *)      end (* local *)
761    
762    (* Query functions *)    (* Query functions *)

Legend:
Removed from v.3477  
changed lines
  Added in v.3478

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0