Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] Annotation of /trunk/src/basis/basis.sml
ViewVC logotype

Annotation of /trunk/src/basis/basis.sml

Parent Directory Parent Directory | Revision Log Revision Log


Revision 75 - (view) (download)

1 : jhr 47 (* basis.sml
2 :     *
3 :     * COPYRIGHT (c) 2010 The Diderot Project (http://diderot.cs.uchicago.edu)
4 :     * All rights reserved.
5 :     *
6 :     * Type definitions for Basis functions.
7 :     *)
8 :    
9 :     structure Basis =
10 :     struct
11 : jhr 63 local
12 :     structure N = BasisNames
13 :     structure Ty = Types
14 : jhr 75 structure MV = MetaVar
15 : jhr 47
16 : jhr 63 fun --> (tys1, ty) = Ty.T_Fun(tys1, [ty])
17 :     infix -->
18 :    
19 : jhr 75 val N2 = Ty.DimConst 2
20 :     val N3 = Ty.DimConst 3
21 : jhr 63
22 :     (* short names for kinds *)
23 : jhr 75 val TK : unit -> Ty.meta_var = Ty.TYPE o MV.newTyVar
24 :     val DK : unit -> Ty.meta_var = Ty.DIFF o MV.newDiffVar
25 :     val SK : unit -> Ty.meta_var = Ty.SHAPE o MV.newShapeVar
26 :     val NK : unit -> Ty.meta_var = Ty.DIM o MV.newDimVar
27 : jhr 63
28 :     fun ty t = ([], t)
29 :     fun all (kinds, mkTy) = let
30 : jhr 75 val tvs = List.map (fn mk => mk()) kinds
31 : jhr 63 in
32 :     (tvs, mkTy tvs)
33 :     end
34 :     fun allNK mkTy = let
35 : jhr 75 val tv = MV.newDimVar()
36 : jhr 63 in
37 : jhr 75 ([Ty.DIM tv], mkTy tv)
38 : jhr 63 end
39 :    
40 :     fun field (k, d, dd) = Ty.T_Field{diff=k, dim=d, shape=dd}
41 :     fun tensor ds = Ty.T_Tensor(Ty.Shape ds)
42 :    
43 :     in
44 :    
45 : jhr 68 (* overloaded operators *)
46 :     val overloads = [
47 :     (*
48 : jhr 63 val op_add = Atom.atom "+"
49 :     val op_sub = Atom.atom "-"
50 :     val op_mul = Atom.atom "*"
51 :     val op_div = Atom.atom "/"
52 :     val op_lt = Atom.atom "<"
53 :     val op_lte = Atom.atom "<="
54 :     val op_eql = Atom.atom "=="
55 :     val op_neq = Atom.atom "!="
56 :     val op_gte = Atom.atom ">="
57 :     val op_gt = Atom.atom ">"
58 :     *)
59 : jhr 68 ]
60 :    
61 :     (* non-overloaded operators, etc. *)
62 :     val basis = [
63 :     (* operators *)
64 : jhr 75 (N.op_at, all([DK, NK, SK],
65 :     fn [Ty.DIFF k, Ty.DIM d, Ty.SHAPE dd] => let
66 :     val k = Ty.DiffVar(k, 0)
67 :     val d = Ty.DimVar d
68 : jhr 63 val dd = Ty.ShapeVar dd
69 :     in
70 : jhr 68 [field(k, d, dd), tensor[d]] --> Ty.T_Tensor dd
71 :     end)),
72 : jhr 75 (N.op_at, all([DK, NK, SK],
73 :     fn [Ty.DIFF k, Ty.DIM d, Ty.SHAPE dd] => let
74 :     val k0 = Ty.DiffVar(k, 0)
75 :     val km1 = Ty.DiffVar(k, ~1)
76 :     val d = Ty.DimVar d
77 : jhr 68 val dd = Ty.ShapeVar dd
78 :     in
79 : jhr 75 [field(k0, d, dd), tensor[d]]
80 :     --> field(km1, d, Ty.ShapeExt(dd, d))
81 : jhr 63 end)),
82 :     (N.op_norm, all([SK],
83 : jhr 75 fn [Ty.SHAPE dd] => [Ty.T_Tensor(Ty.ShapeVar dd)] --> Ty.realTy)),
84 : jhr 63 (* functions *)
85 :     (N.fn_CL, ty([tensor[N3, N3]] --> Ty.vec3Ty)),
86 : jhr 75 (N.fn_convolve, all([DK, NK, SK],
87 :     fn [Ty.DIFF k, Ty.DIM d, Ty.SHAPE dd] => let
88 :     val k = Ty.DiffVar(k, 0)
89 :     val d = Ty.DimVar d
90 : jhr 63 val dd = Ty.ShapeVar dd
91 :     in
92 :     [Ty.T_Kernel k, Ty.T_Image{dim=d, shape=dd}]
93 :     --> field(k, d, dd)
94 :     end)),
95 :     (N.fn_cos, ty([Ty.realTy] --> Ty.realTy)),
96 : jhr 75 (N.fn_dot, allNK(fn tv => [tensor[Ty.DimVar tv]]
97 :     --> tensor[Ty.DimVar tv])),
98 :     (N.fn_inside, all([DK, NK, SK],
99 :     fn [Ty.DIFF k, Ty.DIM d, Ty.SHAPE dd] => let
100 :     val k = Ty.DiffVar(k, 0)
101 :     val d = Ty.DimVar d
102 : jhr 68 val dd = Ty.ShapeVar dd
103 :     in
104 :     [Ty.T_Tensor(Ty.Shape[d]), field(k, d, dd)]
105 :     --> Ty.T_Bool
106 :     end)),
107 : jhr 63 (N.fn_load, all([NK, SK],
108 : jhr 75 fn [Ty.DIM d, Ty.SHAPE dd] => let
109 :     val d = Ty.DimVar d
110 : jhr 63 val dd = Ty.ShapeVar dd
111 :     in
112 :     [Ty.T_String] --> Ty.T_Image{dim=d, shape=dd}
113 :     end)),
114 :     (N.fn_pow, ty([Ty.realTy, Ty.realTy] --> Ty.realTy)),
115 :     (*
116 :     val fn_principleEvec = Atom.atom "principleEvec"
117 :     *)
118 :     (N.fn_sin, ty([Ty.realTy] --> Ty.realTy)),
119 :     (* kernels *)
120 : jhr 75 (N.kn_bspln3, ty(Ty.T_Kernel(Ty.DiffConst 2))),
121 :     (N.kn_tent, ty(Ty.T_Kernel(Ty.DiffConst 0)))
122 : jhr 63 ]
123 :    
124 :     end (* local *)
125 : jhr 47 end

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0