Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] Diff of /trunk/src/basis/basis.sml
ViewVC logotype

Diff of /trunk/src/basis/basis.sml

Parent Directory Parent Directory | Revision Log Revision Log | View Patch Patch

revision 65, Thu May 13 21:04:35 2010 UTC revision 78, Mon May 24 22:31:49 2010 UTC
# Line 6  Line 6 
6   * Type definitions for Basis functions.   * Type definitions for Basis functions.
7   *)   *)
8    
9  structure Basis =  structure Basis : sig
10    struct  
11        val env : Env.env
12    
13      end = struct
14      local      local
15        structure N = BasisNames        structure N = BasisNames
16        structure Ty = Types        structure Ty = Types
17        structure TV = TypeVar        structure MV = MetaVar
18    
19        fun --> (tys1, ty) = Ty.T_Fun(tys1, [ty])        fun --> (tys1, ty) = Ty.T_Fun(tys1, [ty])
20        infix -->        infix -->
21    
22        val N2 = Ty.NatConst 2        val N2 = Ty.DimConst 2
23        val N3 = Ty.NatConst 3        val N3 = Ty.DimConst 3
24    
25      (* short names for kinds *)      (* short names for kinds *)
26        val NK = Ty.TK_NAT        val TK : unit -> Ty.meta_var = Ty.TYPE o MV.newTyVar
27        val SK = Ty.TK_SHAPE        val DK : unit -> Ty.meta_var = Ty.DIFF o MV.newDiffVar
28        val TK = Ty.TK_TYPE        val SK : unit -> Ty.meta_var = Ty.SHAPE o MV.newShapeVar
29          val NK : unit -> Ty.meta_var = Ty.DIM o MV.newDimVar
30    
31        fun ty t = ([], t)        fun ty t = ([], t)
32        fun all (kinds, mkTy) = let        fun all (kinds, mkTy) = let
33              val tvs = List.map (fn k => TV.new k) kinds              val tvs = List.map (fn mk => mk()) kinds
34              in              in
35                (tvs, mkTy tvs)                (tvs, mkTy tvs)
36              end              end
37        fun allNK mkTy = let        fun allNK mkTy = let
38              val tv = TV.new NK              val tv = MV.newDimVar()
39              in              in
40                ([tv], mkTy tv)                ([Ty.DIM tv], mkTy tv)
41              end              end
42    
43      fun field (k, d, dd) = Ty.T_Field{diff=k, dim=d, shape=dd}      fun field (k, d, dd) = Ty.T_Field{diff=k, dim=d, shape=dd}
# Line 41  Line 45 
45    
46      in      in
47    
48      val basis = [    (* overloaded operators *)
49          (* operators *)      val overloads = [
           (N.op_at, all([NK, NK, SK],  
             fn [k, d, dd] => let  
                 val k = Ty.NatVar k  
                 val d = Ty.NatVar d  
                 val dd = Ty.ShapeVar dd  
                 in  
                   [field(k, d, dd), tensor[d]] --> Ty.T_Tensor dd  
                 end)),  
50  (*  (*
51      val op_add = Atom.atom "+"      val op_add = Atom.atom "+"
52      val op_sub = Atom.atom "-"      val op_sub = Atom.atom "-"
# Line 63  Line 59 
59      val op_gte = Atom.atom ">="      val op_gte = Atom.atom ">="
60      val op_gt = Atom.atom ">"      val op_gt = Atom.atom ">"
61  *)  *)
62            (N.op_at, all([NK, NK, SK],          ]
63              fn [k, d, dd] => let  
64                  val k = Ty.NatVar k    (* non-overloaded operators, etc. *)
65                  val d = Ty.NatVar d      val basis = [
66            (* operators *)
67              (N.op_at, all([DK, NK, SK],
68                fn [Ty.DIFF k, Ty.DIM d, Ty.SHAPE dd] => let
69                    val k = Ty.DiffVar(k, 0)
70                    val d = Ty.DimVar d
71                  val dd = Ty.ShapeVar dd                  val dd = Ty.ShapeVar dd
72                  in                  in
73                    [field(k, d, dd), tensor[d]]                    [field(k, d, dd), tensor[d]] --> Ty.T_Tensor dd
74                      --> field(Ty.NatExp(k, ~1), d, Ty.ShapeExt(dd, d))                  end)),
75              (N.op_D, all([DK, NK, SK],
76                fn [Ty.DIFF k, Ty.DIM d, Ty.SHAPE dd] => let
77                    val k0 = Ty.DiffVar(k, 0)
78                    val km1 = Ty.DiffVar(k, ~1)
79                    val d = Ty.DimVar d
80                    val dd = Ty.ShapeVar dd
81                    in
82                      [field(k0, d, dd), tensor[d]]
83                        --> field(km1, d, Ty.ShapeExt(dd, d))
84                  end)),                  end)),
 (*  
     val op_orelse = Atom.atom "||"  
     val op_andalso = Atom.atom "&&"  
 *)  
85            (N.op_norm, all([SK],            (N.op_norm, all([SK],
86              fn [dd] => [Ty.T_Tensor(Ty.ShapeVar dd)] --> Ty.realTy)),              fn [Ty.SHAPE dd] => [Ty.T_Tensor(Ty.ShapeVar dd)] --> Ty.realTy)),
87          (* functions *)          (* functions *)
88            (N.fn_CL,     ty([tensor[N3, N3]] --> Ty.vec3Ty)),            (N.fn_CL,     ty([tensor[N3, N3]] --> Ty.vec3Ty)),
89            (N.fn_convolve, all([NK, NK, SK],            (N.fn_convolve, all([DK, NK, SK],
90              fn [k, d, dd] => let              fn [Ty.DIFF k, Ty.DIM d, Ty.SHAPE dd] => let
91                  val k = Ty.NatVar k                  val k = Ty.DiffVar(k, 0)
92                  val d = Ty.NatVar d                  val d = Ty.DimVar d
93                  val dd = Ty.ShapeVar dd                  val dd = Ty.ShapeVar dd
94                  in                  in
95                    [Ty.T_Kernel k, Ty.T_Image{dim=d, shape=dd}]                    [Ty.T_Kernel k, Ty.T_Image{dim=d, shape=dd}]
96                      --> field(k, d, dd)                      --> field(k, d, dd)
97                  end)),                  end)),
98            (N.fn_cos,    ty([Ty.realTy] --> Ty.realTy)),            (N.fn_cos,    ty([Ty.realTy] --> Ty.realTy)),
99            (N.fn_dot,    allNK(fn tv => [tensor[Ty.NatVar tv]]            (N.fn_dot,    allNK(fn tv => [tensor[Ty.DimVar tv]]
100                            --> tensor[Ty.NatVar tv])),                            --> tensor[Ty.DimVar tv])),
101  (*            (N.fn_inside, all([DK, NK, SK],
102      val fn_inside = Atom.atom "inside"              fn [Ty.DIFF k, Ty.DIM d, Ty.SHAPE dd] => let
103  *)                  val k = Ty.DiffVar(k, 0)
104                    val d = Ty.DimVar d
105                    val dd = Ty.ShapeVar dd
106                    in
107                      [Ty.T_Tensor(Ty.Shape[d]), field(k, d, dd)]
108                        --> Ty.T_Bool
109                    end)),
110            (N.fn_load,   all([NK, SK],            (N.fn_load,   all([NK, SK],
111              fn [d, dd] => let              fn [Ty.DIM d, Ty.SHAPE dd] => let
112                  val d = Ty.NatVar d                  val d = Ty.DimVar d
113                  val dd = Ty.ShapeVar dd                  val dd = Ty.ShapeVar dd
114                  in                  in
115                    [Ty.T_String] --> Ty.T_Image{dim=d, shape=dd}                    [Ty.T_String] --> Ty.T_Image{dim=d, shape=dd}
# Line 108  Line 120 
120  *)  *)
121            (N.fn_sin,    ty([Ty.realTy] --> Ty.realTy)),            (N.fn_sin,    ty([Ty.realTy] --> Ty.realTy)),
122          (* kernels *)          (* kernels *)
123            (N.kn_bspln3, ty(Ty.T_Kernel(Ty.NatConst 2))),            (N.kn_bspln3, ty(Ty.T_Kernel(Ty.DiffConst 2))),
124            (N.kn_tent,   ty(Ty.T_Kernel(Ty.NatConst 0)))            (N.kn_tent,   ty(Ty.T_Kernel(Ty.DiffConst 0)))
125          ]          ]
126    
127        (* seed the basis environment *)
128          val env = let
129                fun ins ((name, ty), env) = let
130                      val x = Var.newPoly (name, AST.BasisVar, ty)
131                      in
132                        Env.insertGlobal (env, name, x)
133                      end
134                in
135                  List.foldl ins (Env.new()) basis
136                end
137    
138      end (* local *)      end (* local *)
139    end    end

Legend:
Removed from v.65  
changed lines
  Added in v.78

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0