Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] Diff of /trunk/src/basis/basis.sml
ViewVC logotype

Diff of /trunk/src/basis/basis.sml

Parent Directory Parent Directory | Revision Log Revision Log | View Patch Patch

revision 47, Tue Apr 13 14:57:27 2010 UTC revision 78, Mon May 24 22:31:49 2010 UTC
# Line 6  Line 6 
6   * Type definitions for Basis functions.   * Type definitions for Basis functions.
7   *)   *)
8    
9  structure Basis =  structure Basis : sig
   struct  
10    
11        val env : Env.env
12    
13      end = struct
14        local
15          structure N = BasisNames
16          structure Ty = Types
17          structure MV = MetaVar
18    
19          fun --> (tys1, ty) = Ty.T_Fun(tys1, [ty])
20          infix -->
21    
22          val N2 = Ty.DimConst 2
23          val N3 = Ty.DimConst 3
24    
25        (* short names for kinds *)
26          val TK : unit -> Ty.meta_var = Ty.TYPE o MV.newTyVar
27          val DK : unit -> Ty.meta_var = Ty.DIFF o MV.newDiffVar
28          val SK : unit -> Ty.meta_var = Ty.SHAPE o MV.newShapeVar
29          val NK : unit -> Ty.meta_var = Ty.DIM o MV.newDimVar
30    
31          fun ty t = ([], t)
32          fun all (kinds, mkTy) = let
33                val tvs = List.map (fn mk => mk()) kinds
34                in
35                  (tvs, mkTy tvs)
36                end
37          fun allNK mkTy = let
38                val tv = MV.newDimVar()
39                in
40                  ([Ty.DIM tv], mkTy tv)
41                end
42    
43        fun field (k, d, dd) = Ty.T_Field{diff=k, dim=d, shape=dd}
44        fun tensor ds = Ty.T_Tensor(Ty.Shape ds)
45    
46        in
47    
48      (* overloaded operators *)
49        val overloads = [
50      (*
51        val op_add = Atom.atom "+"
52        val op_sub = Atom.atom "-"
53        val op_mul = Atom.atom "*"
54        val op_div = Atom.atom "/"
55        val op_lt = Atom.atom "<"
56        val op_lte = Atom.atom "<="
57        val op_eql = Atom.atom "=="
58        val op_neq = Atom.atom "!="
59        val op_gte = Atom.atom ">="
60        val op_gt = Atom.atom ">"
61    *)
62            ]
63    
64      (* non-overloaded operators, etc. *)
65        val basis = [
66            (* operators *)
67              (N.op_at, all([DK, NK, SK],
68                fn [Ty.DIFF k, Ty.DIM d, Ty.SHAPE dd] => let
69                    val k = Ty.DiffVar(k, 0)
70                    val d = Ty.DimVar d
71                    val dd = Ty.ShapeVar dd
72                    in
73                      [field(k, d, dd), tensor[d]] --> Ty.T_Tensor dd
74                    end)),
75              (N.op_D, all([DK, NK, SK],
76                fn [Ty.DIFF k, Ty.DIM d, Ty.SHAPE dd] => let
77                    val k0 = Ty.DiffVar(k, 0)
78                    val km1 = Ty.DiffVar(k, ~1)
79                    val d = Ty.DimVar d
80                    val dd = Ty.ShapeVar dd
81                    in
82                      [field(k0, d, dd), tensor[d]]
83                        --> field(km1, d, Ty.ShapeExt(dd, d))
84                    end)),
85              (N.op_norm, all([SK],
86                fn [Ty.SHAPE dd] => [Ty.T_Tensor(Ty.ShapeVar dd)] --> Ty.realTy)),
87            (* functions *)
88              (N.fn_CL,     ty([tensor[N3, N3]] --> Ty.vec3Ty)),
89              (N.fn_convolve, all([DK, NK, SK],
90                fn [Ty.DIFF k, Ty.DIM d, Ty.SHAPE dd] => let
91                    val k = Ty.DiffVar(k, 0)
92                    val d = Ty.DimVar d
93                    val dd = Ty.ShapeVar dd
94                    in
95                      [Ty.T_Kernel k, Ty.T_Image{dim=d, shape=dd}]
96                        --> field(k, d, dd)
97                    end)),
98              (N.fn_cos,    ty([Ty.realTy] --> Ty.realTy)),
99              (N.fn_dot,    allNK(fn tv => [tensor[Ty.DimVar tv]]
100                              --> tensor[Ty.DimVar tv])),
101              (N.fn_inside, all([DK, NK, SK],
102                fn [Ty.DIFF k, Ty.DIM d, Ty.SHAPE dd] => let
103                    val k = Ty.DiffVar(k, 0)
104                    val d = Ty.DimVar d
105                    val dd = Ty.ShapeVar dd
106                    in
107                      [Ty.T_Tensor(Ty.Shape[d]), field(k, d, dd)]
108                        --> Ty.T_Bool
109                    end)),
110              (N.fn_load,   all([NK, SK],
111                fn [Ty.DIM d, Ty.SHAPE dd] => let
112                    val d = Ty.DimVar d
113                    val dd = Ty.ShapeVar dd
114                    in
115                      [Ty.T_String] --> Ty.T_Image{dim=d, shape=dd}
116                    end)),
117              (N.fn_pow,    ty([Ty.realTy, Ty.realTy] --> Ty.realTy)),
118    (*
119        val fn_principleEvec = Atom.atom "principleEvec"
120    *)
121              (N.fn_sin,    ty([Ty.realTy] --> Ty.realTy)),
122            (* kernels *)
123              (N.kn_bspln3, ty(Ty.T_Kernel(Ty.DiffConst 2))),
124              (N.kn_tent,   ty(Ty.T_Kernel(Ty.DiffConst 0)))
125            ]
126    
127        (* seed the basis environment *)
128          val env = let
129                fun ins ((name, ty), env) = let
130                      val x = Var.newPoly (name, AST.BasisVar, ty)
131                      in
132                        Env.insertGlobal (env, name, x)
133                      end
134                in
135                  List.foldl ins (Env.new()) basis
136                end
137    
138        end (* local *)
139    end    end

Legend:
Removed from v.47  
changed lines
  Added in v.78

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0