Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] Diff of /trunk/src/basis/basis.sml
ViewVC logotype

Diff of /trunk/src/basis/basis.sml

Parent Directory Parent Directory | Revision Log Revision Log | View Patch Patch

revision 68, Tue May 18 16:51:30 2010 UTC revision 91, Thu May 27 15:16:36 2010 UTC
# Line 6  Line 6 
6   * Type definitions for Basis functions.   * Type definitions for Basis functions.
7   *)   *)
8    
9  structure Basis =  structure Basis : sig
   struct  
     local  
       structure N = BasisNames  
       structure Ty = Types  
       structure TV = TypeVar  
10    
11        fun --> (tys1, ty) = Ty.T_Fun(tys1, [ty])      val env : Env.env
       infix -->  
12    
13        val N2 = Ty.NatConst 2    (* find an operator by name; this returns a singleton list for regular operators (including
14        val N3 = Ty.NatConst 3     * type-index operators) and a list of variables for overloaded operators.
15       *)
16        val findOp : Atom.atom -> AST.var list
17    
18      (* short names for kinds *)    end = struct
       val NK = Ty.TK_NAT  
       val SK = Ty.TK_SHAPE  
       val TK = Ty.TK_TYPE  
   
       fun ty t = ([], t)  
       fun all (kinds, mkTy) = let  
             val tvs = List.map (fn k => TV.new k) kinds  
             in  
               (tvs, mkTy tvs)  
             end  
       fun allNK mkTy = let  
             val tv = TV.new NK  
             in  
               ([tv], mkTy tv)  
             end  
19    
20      fun field (k, d, dd) = Ty.T_Field{diff=k, dim=d, shape=dd}      structure N = BasisNames
21      fun tensor ds = Ty.T_Tensor(Ty.Shape ds)      structure BV = BasisVars
22        structure ATbl = AtomTable
23    
24      (* non-overloaded operators, etc. *)
25        val basis = [
26            (* non-overloaded operators *)
27              BV.op_at,
28              BV.op_D,
29              BV.op_norm,
30              BV.op_not,
31              BV.op_subscript,
32            (* functions *)
33              BV.fn_CL,
34              BV.fn_convolve,
35              BV.fn_cos,
36              BV.fn_dot,
37              BV.fn_inside,
38              BV.fn_load,
39              BV.fn_modulate,
40              BV.fn_pow,
41              BV.fn_principleEvec,
42              BV.fn_sin,
43            (* kernels *)
44              BV.kn_bspln3,
45              BV.kn_tent
46            ]
47    
48        (* seed the basis environment *)
49          val env = let
50                fun ins (x, env) = Env.insertGlobal(env, Atom.atom(Var.nameOf x), x)
51      in      in
52                  List.foldl ins (Env.new()) basis
53                end
54    
55    (* overloaded operators *)    (* overloaded operators *)
56      val overloads = [      val overloads = [
57    (*            (N.op_add, [BV.add_ii, BV.add_tt]),
58      val op_add = Atom.atom "+"            (N.op_sub, [BV.sub_ii, BV.sub_tt]),
59      val op_sub = Atom.atom "-"            (N.op_mul, [BV.mul_ii, BV.mul_rr, BV.mul_rt, BV.mul_tr]),
60      val op_mul = Atom.atom "*"            (N.op_div, [BV.div_ii, BV.div_rr, BV.div_tr]),
61      val op_div = Atom.atom "/"            (N.op_lt, [BV.lt_ii, BV.lt_rr]),
62      val op_lt = Atom.atom "<"            (N.op_lte, [BV.lte_ii, BV.lte_rr]),
63      val op_lte = Atom.atom "<="            (N.op_equ, [BV.equ_bb, BV.equ_ii, BV.equ_ss, BV.equ_rr]),
64      val op_eql = Atom.atom "=="            (N.op_neq, [BV.neq_bb, BV.neq_ii, BV.neq_ss, BV.neq_rr]),
65      val op_neq = Atom.atom "!="            (N.op_gte, [BV.gte_ii, BV.gte_rr]),
66      val op_gte = Atom.atom ">="            (N.op_gt, [BV.gt_ii, BV.gt_rr]),
67      val op_gt = Atom.atom ">"            (N.op_neg, [BV.neg_i, BV.neg_t, BV.neg_f])
 *)  
68          ]          ]
69    
70    (* non-overloaded operators, etc. *)      local
71      val basis = [        val find = let
72          (* operators *)              val tbl = ATbl.mkTable(64, Fail "op table")
           (N.op_at, all([NK, NK, SK],  
             fn [k, d, dd] => let  
                 val k = Ty.NatVar k  
                 val d = Ty.NatVar d  
                 val dd = Ty.ShapeVar dd  
                 in  
                   [field(k, d, dd), tensor[d]] --> Ty.T_Tensor dd  
                 end)),  
           (N.op_at, all([NK, NK, SK],  
             fn [k, d, dd] => let  
                 val k = Ty.NatVar k  
                 val d = Ty.NatVar d  
                 val dd = Ty.ShapeVar dd  
                 in  
                   [field(k, d, dd), tensor[d]]  
                     --> field(Ty.NatExp(k, ~1), d, Ty.ShapeExt(dd, d))  
                 end)),  
           (N.op_norm, all([SK],  
             fn [dd] => [Ty.T_Tensor(Ty.ShapeVar dd)] --> Ty.realTy)),  
         (* functions *)  
           (N.fn_CL,     ty([tensor[N3, N3]] --> Ty.vec3Ty)),  
           (N.fn_convolve, all([NK, NK, SK],  
             fn [k, d, dd] => let  
                 val k = Ty.NatVar k  
                 val d = Ty.NatVar d  
                 val dd = Ty.ShapeVar dd  
                 in  
                   [Ty.T_Kernel k, Ty.T_Image{dim=d, shape=dd}]  
                     --> field(k, d, dd)  
                 end)),  
           (N.fn_cos,    ty([Ty.realTy] --> Ty.realTy)),  
           (N.fn_dot,    allNK(fn tv => [tensor[Ty.NatVar tv]]  
                           --> tensor[Ty.NatVar tv])),  
           (N.fn_inside, all([NK, NK, SK],  
             fn [k, d, dd] => let  
                 val k = Ty.NatVar k  
                 val d = Ty.NatVar d  
                 val dd = Ty.ShapeVar dd  
73                  in                  in
74                    [Ty.T_Tensor(Ty.Shape[d]), field(k, d, dd)]                List.app (ATbl.insert tbl) overloads;
75                      --> Ty.T_Bool                ATbl.find tbl
76                  end)),              end
           (N.fn_load,   all([NK, SK],  
             fn [d, dd] => let  
                 val d = Ty.NatVar d  
                 val dd = Ty.ShapeVar dd  
77                  in                  in
78                    [Ty.T_String] --> Ty.T_Image{dim=d, shape=dd}      fun findOp name = (case Env.findVar(env, name)
79                  end)),             of SOME x => [x]
80            (N.fn_pow,    ty([Ty.realTy, Ty.realTy] --> Ty.realTy)),              | NONE => (case find name
81  (*                   of SOME xs => xs
82      val fn_principleEvec = Atom.atom "principleEvec"                    | NONE => raise Fail("unknown operator "^Atom.toString name)
83  *)                  (* end case *))
84            (N.fn_sin,    ty([Ty.realTy] --> Ty.realTy)),            (* end case *))
         (* kernels *)  
           (N.kn_bspln3, ty(Ty.T_Kernel(Ty.NatConst 2))),  
           (N.kn_tent,   ty(Ty.T_Kernel(Ty.NatConst 0)))  
         ]  
   
85      end (* local *)      end (* local *)
86    
87    end    end

Legend:
Removed from v.68  
changed lines
  Added in v.91

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0