SCM Repository
Annotation of /trunk/src/basis/basis.sml
Parent Directory
|
Revision Log
Revision 75 - (view) (download)
(* basis.sml
 *
 * COPYRIGHT (c) 2010 The Diderot Project (http://diderot.cs.uchicago.edu)
 * All rights reserved.
 *
 * Type signatures of the Basis operators, functions, and kernels.
 *)
8 : | |||
structure Basis =
  struct

    local
      structure N = BasisNames
      structure Ty = Types
      structure MV = MetaVar

    (* `argTys --> resTy` is the type of a function that takes arguments of
     * types `argTys` and returns a single result of type `resTy`.
     *)
      fun --> (argTys, resTy) = Ty.T_Fun(argTys, [resTy])
      infix -->

      val N3 = Ty.DimConst 3

    (* a monomorphic type scheme: no bound meta variables *)
      fun ty t = ([], t)

    (* Polymorphic type schemes.  Each helper below allocates fresh meta
     * variables of a fixed combination of kinds (DK = differentiability,
     * NK = dimension, SK = shape), passes the raw variables to `mkTy`, and
     * returns the kind-tagged variables (in allocation order) paired with the
     * type that `mkTy` builds.  Having one helper per kind combination means
     * the callbacks receive correctly-typed variables, so no partial
     * `fn [Ty.DIFF k, ...] => ...` matches (and nonexhaustive-match warnings)
     * are needed.
     *)
      fun allSK mkTy = let
            val dd = MV.newShapeVar()
            in
              ([Ty.SHAPE dd], mkTy dd)
            end
      fun allNK mkTy = let
            val d = MV.newDimVar()
            in
              ([Ty.DIM d], mkTy d)
            end
      fun allNKSK mkTy = let
            val d = MV.newDimVar()
            val dd = MV.newShapeVar()
            in
              ([Ty.DIM d, Ty.SHAPE dd], mkTy (d, dd))
            end
      fun allDKNKSK mkTy = let
            val k = MV.newDiffVar()
            val d = MV.newDimVar()
            val dd = MV.newShapeVar()
            in
              ([Ty.DIFF k, Ty.DIM d, Ty.SHAPE dd], mkTy (k, d, dd))
            end

    (* field type with differentiability `k`, dimension `d`, and shape `dd` *)
      fun field (k, d, dd) = Ty.T_Field{diff=k, dim=d, shape=dd}
    (* tensor type whose shape is the list of dimensions `ds` *)
      fun tensor ds = Ty.T_Tensor(Ty.Shape ds)

    in

  (* overloaded operators; none are defined yet.  Still to be added:
   *
   *    val op_add = Atom.atom "+"
   *    val op_sub = Atom.atom "-"
   *    val op_mul = Atom.atom "*"
   *    val op_div = Atom.atom "/"
   *    val op_lt = Atom.atom "<"
   *    val op_lte = Atom.atom "<="
   *    val op_eql = Atom.atom "=="
   *    val op_neq = Atom.atom "!="
   *    val op_gte = Atom.atom ">="
   *    val op_gt = Atom.atom ">"
   *)
    val overloads = [
      ]

  (* non-overloaded operators, functions, and kernels, each paired with its
   * (bound meta variables, type) scheme.
   *)
    val basis = [
        (* operators *)
          (* field of dimension d and shape dd, applied to a d-vector, yields a
           * tensor of shape dd.
           *)
          (N.op_at, allDKNKSK (fn (kv, dv, sv) => let
              val k = Ty.DiffVar(kv, 0)
              val d = Ty.DimVar dv
              val dd = Ty.ShapeVar sv
              in
                [field(k, d, dd), tensor[d]] --> Ty.T_Tensor dd
              end)),
          (* NOTE(review): second `N.op_at` entry.  Its result is a field with
           * differentiability k-1 and shape `dd` extended by `d`, which looks
           * like the type of a derivative operator rather than of `@`; confirm
           * the intended operator name and whether the `tensor[d]` argument
           * belongs here.
           *)
          (N.op_at, allDKNKSK (fn (kv, dv, sv) => let
              val k0 = Ty.DiffVar(kv, 0)
              val km1 = Ty.DiffVar(kv, ~1)
              val d = Ty.DimVar dv
              val dd = Ty.ShapeVar sv
              in
                [field(k0, d, dd), tensor[d]]
                  --> field(km1, d, Ty.ShapeExt(dd, d))
              end)),
          (N.op_norm, allSK (fn sv => [Ty.T_Tensor(Ty.ShapeVar sv)] --> Ty.realTy)),
        (* functions *)
          (N.fn_CL, ty([tensor[N3, N3]] --> Ty.vec3Ty)),
          (N.fn_convolve, allDKNKSK (fn (kv, dv, sv) => let
              val k = Ty.DiffVar(kv, 0)
              val d = Ty.DimVar dv
              val dd = Ty.ShapeVar sv
              in
                [Ty.T_Kernel k, Ty.T_Image{dim=d, shape=dd}]
                  --> field(k, d, dd)
              end)),
          (N.fn_cos, ty([Ty.realTy] --> Ty.realTy)),
          (* dot product: two d-vectors to a real (previously mistyped as taking
           * a single vector and returning a vector).
           *)
          (N.fn_dot, allNK (fn dv => let
              val v = tensor[Ty.DimVar dv]
              in
                [v, v] --> Ty.realTy
              end)),
          (N.fn_inside, allDKNKSK (fn (kv, dv, sv) => let
              val k = Ty.DiffVar(kv, 0)
              val d = Ty.DimVar dv
              val dd = Ty.ShapeVar sv
              in
                [tensor[d], field(k, d, dd)] --> Ty.T_Bool
              end)),
          (N.fn_load, allNKSK (fn (dv, sv) =>
              [Ty.T_String] --> Ty.T_Image{dim=Ty.DimVar dv, shape=Ty.ShapeVar sv})),
          (N.fn_pow, ty([Ty.realTy, Ty.realTy] --> Ty.realTy)),
        (*
          val fn_principleEvec = Atom.atom "principleEvec"
        *)
          (N.fn_sin, ty([Ty.realTy] --> Ty.realTy)),
        (* kernels *)
          (N.kn_bspln3, ty(Ty.T_Kernel(Ty.DiffConst 2))),
          (N.kn_tent, ty(Ty.T_Kernel(Ty.DiffConst 0)))
        ]

    end (* local *)

  end
root@smlnj-gforge.cs.uchicago.edu | ViewVC Help |
Powered by ViewVC 1.0.0 |