Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] Annotation of /trunk/src/basis/basis.sml
ViewVC logotype

Annotation of /trunk/src/basis/basis.sml

Parent Directory Parent Directory | Revision Log Revision Log


Revision 78 - (view) (download)

1 : jhr 47 (* basis.sml
2 :     *
3 :     * COPYRIGHT (c) 2010 The Diderot Project (http://diderot.cs.uchicago.edu)
4 :     * All rights reserved.
5 :     *
6 :     * Type definitions for Basis functions.
7 :     *)
8 :    
9 : jhr 78 structure Basis : sig
10 :    
11 :     val env : Env.env
12 :    
13 :     end = struct
14 : jhr 63 local
15 :     structure N = BasisNames
16 :     structure Ty = Types
17 : jhr 75 structure MV = MetaVar
18 : jhr 47
19 : jhr 63 fun --> (tys1, ty) = Ty.T_Fun(tys1, [ty])
20 :     infix -->
21 :    
22 : jhr 75 val N2 = Ty.DimConst 2
23 :     val N3 = Ty.DimConst 3
24 : jhr 63
25 :     (* short names for kinds *)
26 : jhr 75 val TK : unit -> Ty.meta_var = Ty.TYPE o MV.newTyVar
27 :     val DK : unit -> Ty.meta_var = Ty.DIFF o MV.newDiffVar
28 :     val SK : unit -> Ty.meta_var = Ty.SHAPE o MV.newShapeVar
29 :     val NK : unit -> Ty.meta_var = Ty.DIM o MV.newDimVar
30 : jhr 63
31 :     fun ty t = ([], t)
32 :     fun all (kinds, mkTy) = let
33 : jhr 75 val tvs = List.map (fn mk => mk()) kinds
34 : jhr 63 in
35 :     (tvs, mkTy tvs)
36 :     end
37 :     fun allNK mkTy = let
38 : jhr 75 val tv = MV.newDimVar()
39 : jhr 63 in
40 : jhr 75 ([Ty.DIM tv], mkTy tv)
41 : jhr 63 end
42 :    
43 :     fun field (k, d, dd) = Ty.T_Field{diff=k, dim=d, shape=dd}
44 :     fun tensor ds = Ty.T_Tensor(Ty.Shape ds)
45 :    
46 :     in
47 :    
48 : jhr 68 (* overloaded operators *)
49 :     val overloads = [
50 :     (*
51 : jhr 63 val op_add = Atom.atom "+"
52 :     val op_sub = Atom.atom "-"
53 :     val op_mul = Atom.atom "*"
54 :     val op_div = Atom.atom "/"
55 :     val op_lt = Atom.atom "<"
56 :     val op_lte = Atom.atom "<="
57 :     val op_eql = Atom.atom "=="
58 :     val op_neq = Atom.atom "!="
59 :     val op_gte = Atom.atom ">="
60 :     val op_gt = Atom.atom ">"
61 :     *)
62 : jhr 68 ]
63 :    
64 :     (* non-overloaded operators, etc. *)
65 :     val basis = [
66 :     (* operators *)
67 : jhr 75 (N.op_at, all([DK, NK, SK],
68 :     fn [Ty.DIFF k, Ty.DIM d, Ty.SHAPE dd] => let
69 :     val k = Ty.DiffVar(k, 0)
70 :     val d = Ty.DimVar d
71 : jhr 63 val dd = Ty.ShapeVar dd
72 :     in
73 : jhr 68 [field(k, d, dd), tensor[d]] --> Ty.T_Tensor dd
74 :     end)),
75 : jhr 78 (N.op_D, all([DK, NK, SK],
76 : jhr 75 fn [Ty.DIFF k, Ty.DIM d, Ty.SHAPE dd] => let
77 :     val k0 = Ty.DiffVar(k, 0)
78 :     val km1 = Ty.DiffVar(k, ~1)
79 :     val d = Ty.DimVar d
80 : jhr 68 val dd = Ty.ShapeVar dd
81 :     in
82 : jhr 75 [field(k0, d, dd), tensor[d]]
83 :     --> field(km1, d, Ty.ShapeExt(dd, d))
84 : jhr 63 end)),
85 :     (N.op_norm, all([SK],
86 : jhr 75 fn [Ty.SHAPE dd] => [Ty.T_Tensor(Ty.ShapeVar dd)] --> Ty.realTy)),
87 : jhr 63 (* functions *)
88 :     (N.fn_CL, ty([tensor[N3, N3]] --> Ty.vec3Ty)),
89 : jhr 75 (N.fn_convolve, all([DK, NK, SK],
90 :     fn [Ty.DIFF k, Ty.DIM d, Ty.SHAPE dd] => let
91 :     val k = Ty.DiffVar(k, 0)
92 :     val d = Ty.DimVar d
93 : jhr 63 val dd = Ty.ShapeVar dd
94 :     in
95 :     [Ty.T_Kernel k, Ty.T_Image{dim=d, shape=dd}]
96 :     --> field(k, d, dd)
97 :     end)),
98 :     (N.fn_cos, ty([Ty.realTy] --> Ty.realTy)),
99 : jhr 75 (N.fn_dot, allNK(fn tv => [tensor[Ty.DimVar tv]]
100 :     --> tensor[Ty.DimVar tv])),
101 :     (N.fn_inside, all([DK, NK, SK],
102 :     fn [Ty.DIFF k, Ty.DIM d, Ty.SHAPE dd] => let
103 :     val k = Ty.DiffVar(k, 0)
104 :     val d = Ty.DimVar d
105 : jhr 68 val dd = Ty.ShapeVar dd
106 :     in
107 :     [Ty.T_Tensor(Ty.Shape[d]), field(k, d, dd)]
108 :     --> Ty.T_Bool
109 :     end)),
110 : jhr 63 (N.fn_load, all([NK, SK],
111 : jhr 75 fn [Ty.DIM d, Ty.SHAPE dd] => let
112 :     val d = Ty.DimVar d
113 : jhr 63 val dd = Ty.ShapeVar dd
114 :     in
115 :     [Ty.T_String] --> Ty.T_Image{dim=d, shape=dd}
116 :     end)),
117 :     (N.fn_pow, ty([Ty.realTy, Ty.realTy] --> Ty.realTy)),
118 :     (*
119 :     val fn_principleEvec = Atom.atom "principleEvec"
120 :     *)
121 :     (N.fn_sin, ty([Ty.realTy] --> Ty.realTy)),
122 :     (* kernels *)
123 : jhr 75 (N.kn_bspln3, ty(Ty.T_Kernel(Ty.DiffConst 2))),
124 :     (N.kn_tent, ty(Ty.T_Kernel(Ty.DiffConst 0)))
125 : jhr 63 ]
126 :    
127 : jhr 78 (* seed the basis environment *)
128 :     val env = let
129 :     fun ins ((name, ty), env) = let
130 :     val x = Var.newPoly (name, AST.BasisVar, ty)
131 :     in
132 :     Env.insertGlobal (env, name, x)
133 :     end
134 :     in
135 :     List.foldl ins (Env.new()) basis
136 :     end
137 :    
138 : jhr 63 end (* local *)
139 : jhr 47 end

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0