Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] Diff of /trunk/src/compiler/basis/basis.sml
ViewVC logotype

Diff of /trunk/src/compiler/basis/basis.sml

Parent Directory Parent Directory | Revision Log Revision Log | View Patch Patch

revision 78, Mon May 24 22:31:49 2010 UTC revision 79, Tue May 25 01:55:48 2010 UTC
# Line 10  Line 10 
10    
11      val env : Env.env      val env : Env.env
12    
13    end = struct    (* find an operator by name; this returns a singleton list for regular operators (including
14      local     * type-index operators) and a list of variables for overloaded operators.
15        structure N = BasisNames     *)
16        structure Ty = Types      val findOp : Atom.atom -> AST.var list
       structure MV = MetaVar  
   
       fun --> (tys1, ty) = Ty.T_Fun(tys1, [ty])  
       infix -->  
   
       val N2 = Ty.DimConst 2  
       val N3 = Ty.DimConst 3  
   
     (* short names for kinds *)  
       val TK : unit -> Ty.meta_var = Ty.TYPE o MV.newTyVar  
       val DK : unit -> Ty.meta_var = Ty.DIFF o MV.newDiffVar  
       val SK : unit -> Ty.meta_var = Ty.SHAPE o MV.newShapeVar  
       val NK : unit -> Ty.meta_var = Ty.DIM o MV.newDimVar  
   
       fun ty t = ([], t)  
       fun all (kinds, mkTy) = let  
             val tvs = List.map (fn mk => mk()) kinds  
             in  
               (tvs, mkTy tvs)  
             end  
       fun allNK mkTy = let  
             val tv = MV.newDimVar()  
             in  
               ([Ty.DIM tv], mkTy tv)  
             end  
   
     fun field (k, d, dd) = Ty.T_Field{diff=k, dim=d, shape=dd}  
     fun tensor ds = Ty.T_Tensor(Ty.Shape ds)  
17    
18      in    end = struct
19    
20    (* overloaded operators *)      structure N = BasisNames
21      val overloads = [      structure BV = BasisVars
22    (*      structure ATbl = AtomTable
     val op_add = Atom.atom "+"  
     val op_sub = Atom.atom "-"  
     val op_mul = Atom.atom "*"  
     val op_div = Atom.atom "/"  
     val op_lt = Atom.atom "<"  
     val op_lte = Atom.atom "<="  
     val op_eql = Atom.atom "=="  
     val op_neq = Atom.atom "!="  
     val op_gte = Atom.atom ">="  
     val op_gt = Atom.atom ">"  
 *)  
         ]  
23    
24    (* non-overloaded operators, etc. *)    (* non-overloaded operators, etc. *)
25      val basis = [      val basis = [
26          (* operators *)          (* non-overloaded operators *)
27            (N.op_at, all([DK, NK, SK],            BV.op_at,
28              fn [Ty.DIFF k, Ty.DIM d, Ty.SHAPE dd] => let            BV.op_D,
29                  val k = Ty.DiffVar(k, 0)            BV.op_norm,
30                  val d = Ty.DimVar d            BV.op_not,
                 val dd = Ty.ShapeVar dd  
                 in  
                   [field(k, d, dd), tensor[d]] --> Ty.T_Tensor dd  
                 end)),  
           (N.op_D, all([DK, NK, SK],  
             fn [Ty.DIFF k, Ty.DIM d, Ty.SHAPE dd] => let  
                 val k0 = Ty.DiffVar(k, 0)  
                 val km1 = Ty.DiffVar(k, ~1)  
                 val d = Ty.DimVar d  
                 val dd = Ty.ShapeVar dd  
                 in  
                   [field(k0, d, dd), tensor[d]]  
                     --> field(km1, d, Ty.ShapeExt(dd, d))  
                 end)),  
           (N.op_norm, all([SK],  
             fn [Ty.SHAPE dd] => [Ty.T_Tensor(Ty.ShapeVar dd)] --> Ty.realTy)),  
31          (* functions *)          (* functions *)
32            (N.fn_CL,     ty([tensor[N3, N3]] --> Ty.vec3Ty)),            BV.fn_CL,
33            (N.fn_convolve, all([DK, NK, SK],            BV.fn_convolve,
34              fn [Ty.DIFF k, Ty.DIM d, Ty.SHAPE dd] => let            BV.fn_cos,
35                  val k = Ty.DiffVar(k, 0)            BV.fn_dot,
36                  val d = Ty.DimVar d            BV.fn_inside,
37                  val dd = Ty.ShapeVar dd            BV.fn_load,
38                  in            BV.fn_pow,
                   [Ty.T_Kernel k, Ty.T_Image{dim=d, shape=dd}]  
                     --> field(k, d, dd)  
                 end)),  
           (N.fn_cos,    ty([Ty.realTy] --> Ty.realTy)),  
           (N.fn_dot,    allNK(fn tv => [tensor[Ty.DimVar tv]]  
                           --> tensor[Ty.DimVar tv])),  
           (N.fn_inside, all([DK, NK, SK],  
             fn [Ty.DIFF k, Ty.DIM d, Ty.SHAPE dd] => let  
                 val k = Ty.DiffVar(k, 0)  
                 val d = Ty.DimVar d  
                 val dd = Ty.ShapeVar dd  
                 in  
                   [Ty.T_Tensor(Ty.Shape[d]), field(k, d, dd)]  
                     --> Ty.T_Bool  
                 end)),  
           (N.fn_load,   all([NK, SK],  
             fn [Ty.DIM d, Ty.SHAPE dd] => let  
                 val d = Ty.DimVar d  
                 val dd = Ty.ShapeVar dd  
                 in  
                   [Ty.T_String] --> Ty.T_Image{dim=d, shape=dd}  
                 end)),  
           (N.fn_pow,    ty([Ty.realTy, Ty.realTy] --> Ty.realTy)),  
39  (*  (*
40      val fn_principleEvec = Atom.atom "principleEvec"      val fn_principleEvec = Atom.atom "principleEvec"
41  *)  *)
42            (N.fn_sin,    ty([Ty.realTy] --> Ty.realTy)),            BV.fn_sin,
43          (* kernels *)          (* kernels *)
44            (N.kn_bspln3, ty(Ty.T_Kernel(Ty.DiffConst 2))),            BV.kn_bspln3,
45            (N.kn_tent,   ty(Ty.T_Kernel(Ty.DiffConst 0)))            BV.kn_tent
46          ]          ]
47    
48      (* seed the basis environment *)      (* seed the basis environment *)
49        val env = let        val env = let
50              fun ins ((name, ty), env) = let              fun ins (x, env) = Env.insertGlobal(env, Atom.atom(Var.nameOf x), x)
                   val x = Var.newPoly (name, AST.BasisVar, ty)  
                   in  
                     Env.insertGlobal (env, name, x)  
                   end  
51              in              in
52                List.foldl ins (Env.new()) basis                List.foldl ins (Env.new()) basis
53              end              end
54    
55      (* overloaded operators *)
56        val overloads = [
57              (N.op_add, [BV.add_ii, BV.add_tt]),
58              (N.op_sub, [BV.sub_ii, BV.sub_tt]),
59              (N.op_mul, [BV.mul_ii, BV.mul_rr, BV.mul_rt, BV.mul_tr]),
60              (N.op_div, [BV.div_ii, BV.div_rr, BV.div_tr]),
61              (N.op_lt, [BV.lt_ii, BV.lt_rr]),
62              (N.op_lte, [BV.lte_ii, BV.lte_rr]),
63              (N.op_equ, [BV.equ_bb, BV.equ_ii, BV.equ_ss, BV.equ_rr]),
64              (N.op_neq, [BV.neq_bb, BV.neq_ii, BV.neq_ss, BV.neq_rr]),
65              (N.op_gte, [BV.gte_ii, BV.gte_rr]),
66              (N.op_gt, [BV.gt_ii, BV.gt_rr]),
67              (N.op_neg, [BV.neg_i, BV.neg_t, BV.neg_f])
68            ]
69    
70        local
71          val lookup = let
72                val tbl = ATbl.mkTable(64, Fail "op table")
73                in
74                  List.app (ATbl.insert tbl) overloads;
75                  ATbl.lookup tbl
76                end
77        in
78        fun findOp name = (case Env.findVar(env, name)
79               of SOME x => [x]
80                | NONE => lookup name
81              (* end case *))
82      end (* local *)      end (* local *)
83    
84    end    end

Legend:
Removed from v.78  
changed lines
  Added in v.79

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0