Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] Diff of /trunk/src/basis/basis.sml
ViewVC logotype

Diff of /trunk/src/basis/basis.sml

Parent Directory Parent Directory | Revision Log Revision Log | View Patch Patch

revision 65, Thu May 13 21:04:35 2010 UTC revision 91, Thu May 27 15:16:36 2010 UTC
# Line 6  Line 6 
6   * Type definitions for Basis functions.   * Type definitions for Basis functions.
7   *)   *)
8    
9  structure Basis =  structure Basis : sig
   struct  
     local  
       structure N = BasisNames  
       structure Ty = Types  
       structure TV = TypeVar  
10    
11        fun --> (tys1, ty) = Ty.T_Fun(tys1, [ty])      val env : Env.env
       infix -->  
12    
13        val N2 = Ty.NatConst 2    (* find an operator by name; this returns a singleton list for regular operators (including
14        val N3 = Ty.NatConst 3     * type-index operators) and a list of variables for overloaded operators.
15       *)
16      (* short names for kinds *)      val findOp : Atom.atom -> AST.var list
       val NK = Ty.TK_NAT  
       val SK = Ty.TK_SHAPE  
       val TK = Ty.TK_TYPE  
   
       fun ty t = ([], t)  
       fun all (kinds, mkTy) = let  
             val tvs = List.map (fn k => TV.new k) kinds  
             in  
               (tvs, mkTy tvs)  
             end  
       fun allNK mkTy = let  
             val tv = TV.new NK  
             in  
               ([tv], mkTy tv)  
             end  
17    
18      fun field (k, d, dd) = Ty.T_Field{diff=k, dim=d, shape=dd}    end = struct
     fun tensor ds = Ty.T_Tensor(Ty.Shape ds)  
19    
20      in      structure N = BasisNames
21        structure BV = BasisVars
22        structure ATbl = AtomTable
23    
24      (* non-overloaded operators, etc. *)
25      val basis = [      val basis = [
26          (* operators *)          (* non-overloaded operators *)
27            (N.op_at, all([NK, NK, SK],            BV.op_at,
28              fn [k, d, dd] => let            BV.op_D,
29                  val k = Ty.NatVar k            BV.op_norm,
30                  val d = Ty.NatVar d            BV.op_not,
31                  val dd = Ty.ShapeVar dd            BV.op_subscript,
                 in  
                   [field(k, d, dd), tensor[d]] --> Ty.T_Tensor dd  
                 end)),  
 (*  
     val op_add = Atom.atom "+"  
     val op_sub = Atom.atom "-"  
     val op_mul = Atom.atom "*"  
     val op_div = Atom.atom "/"  
     val op_lt = Atom.atom "<"  
     val op_lte = Atom.atom "<="  
     val op_eql = Atom.atom "=="  
     val op_neq = Atom.atom "!="  
     val op_gte = Atom.atom ">="  
     val op_gt = Atom.atom ">"  
 *)  
           (N.op_at, all([NK, NK, SK],  
             fn [k, d, dd] => let  
                 val k = Ty.NatVar k  
                 val d = Ty.NatVar d  
                 val dd = Ty.ShapeVar dd  
                 in  
                   [field(k, d, dd), tensor[d]]  
                     --> field(Ty.NatExp(k, ~1), d, Ty.ShapeExt(dd, d))  
                 end)),  
 (*  
     val op_orelse = Atom.atom "||"  
     val op_andalso = Atom.atom "&&"  
 *)  
           (N.op_norm, all([SK],  
             fn [dd] => [Ty.T_Tensor(Ty.ShapeVar dd)] --> Ty.realTy)),  
32          (* functions *)          (* functions *)
33            (N.fn_CL,     ty([tensor[N3, N3]] --> Ty.vec3Ty)),            BV.fn_CL,
34            (N.fn_convolve, all([NK, NK, SK],            BV.fn_convolve,
35              fn [k, d, dd] => let            BV.fn_cos,
36                  val k = Ty.NatVar k            BV.fn_dot,
37                  val d = Ty.NatVar d            BV.fn_inside,
38                  val dd = Ty.ShapeVar dd            BV.fn_load,
39                  in            BV.fn_modulate,
40                    [Ty.T_Kernel k, Ty.T_Image{dim=d, shape=dd}]            BV.fn_pow,
41                      --> field(k, d, dd)            BV.fn_principleEvec,
42                  end)),            BV.fn_sin,
           (N.fn_cos,    ty([Ty.realTy] --> Ty.realTy)),  
           (N.fn_dot,    allNK(fn tv => [tensor[Ty.NatVar tv]]  
                           --> tensor[Ty.NatVar tv])),  
 (*  
     val fn_inside = Atom.atom "inside"  
 *)  
           (N.fn_load,   all([NK, SK],  
             fn [d, dd] => let  
                 val d = Ty.NatVar d  
                 val dd = Ty.ShapeVar dd  
                 in  
                   [Ty.T_String] --> Ty.T_Image{dim=d, shape=dd}  
                 end)),  
           (N.fn_pow,    ty([Ty.realTy, Ty.realTy] --> Ty.realTy)),  
 (*  
     val fn_principleEvec = Atom.atom "principleEvec"  
 *)  
           (N.fn_sin,    ty([Ty.realTy] --> Ty.realTy)),  
43          (* kernels *)          (* kernels *)
44            (N.kn_bspln3, ty(Ty.T_Kernel(Ty.NatConst 2))),            BV.kn_bspln3,
45            (N.kn_tent,   ty(Ty.T_Kernel(Ty.NatConst 0)))            BV.kn_tent
46            ]
47    
48        (* seed the basis environment *)
49          val env = let
50                fun ins (x, env) = Env.insertGlobal(env, Atom.atom(Var.nameOf x), x)
51                in
52                  List.foldl ins (Env.new()) basis
53                end
54    
55      (* overloaded operators *)
56        val overloads = [
57              (N.op_add, [BV.add_ii, BV.add_tt]),
58              (N.op_sub, [BV.sub_ii, BV.sub_tt]),
59              (N.op_mul, [BV.mul_ii, BV.mul_rr, BV.mul_rt, BV.mul_tr]),
60              (N.op_div, [BV.div_ii, BV.div_rr, BV.div_tr]),
61              (N.op_lt, [BV.lt_ii, BV.lt_rr]),
62              (N.op_lte, [BV.lte_ii, BV.lte_rr]),
63              (N.op_equ, [BV.equ_bb, BV.equ_ii, BV.equ_ss, BV.equ_rr]),
64              (N.op_neq, [BV.neq_bb, BV.neq_ii, BV.neq_ss, BV.neq_rr]),
65              (N.op_gte, [BV.gte_ii, BV.gte_rr]),
66              (N.op_gt, [BV.gt_ii, BV.gt_rr]),
67              (N.op_neg, [BV.neg_i, BV.neg_t, BV.neg_f])
68          ]          ]
69    
70        local
71          val find = let
72                val tbl = ATbl.mkTable(64, Fail "op table")
73                in
74                  List.app (ATbl.insert tbl) overloads;
75                  ATbl.find tbl
76                end
77        in
78        fun findOp name = (case Env.findVar(env, name)
79               of SOME x => [x]
80                | NONE => (case find name
81                     of SOME xs => xs
82                      | NONE => raise Fail("unknown operator "^Atom.toString name)
83                    (* end case *))
84              (* end case *))
85      end (* local *)      end (* local *)
86    
87    end    end

Legend:
Removed from v.65  
changed lines
  Added in v.91

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0