Home My Page Projects Code Snippets Project Openings diderot
 Summary Activity Tracker Tasks SCM

# SCM Repository

[diderot] View of /trunk/src/basis/basis.sml
 [diderot] / trunk / src / basis / basis.sml

# View of /trunk/src/basis/basis.sml

Revision 78 - (download) (annotate)
Mon May 24 22:31:49 2010 UTC (11 years, 1 month ago) by jhr
File size: 3569 byte(s)
```  Working on type checker
```
```(* basis.sml
*
* COPYRIGHT (c) 2010 The Diderot Project (http://diderot.cs.uchicago.edu)
* All rights reserved.
*
* Type definitions for Basis functions.
*)

structure Basis : sig

  (* the environment obtained by inserting every entry of the basis table
   * (operators, functions, and kernels) into a fresh environment.
   *)
    val env : Env.env

  end = struct
    local
      structure N = BasisNames
      structure Ty = Types
      structure MV = MetaVar

    (* `[t1, ..., tn] --> t` builds the function type whose argument types are
     * the given list and whose result list is the single type `t`.  The
     * function is defined before the infix declaration so that the definition
     * itself uses ordinary prefix syntax.
     *)
      fun --> (tys1, ty) = Ty.T_Fun(tys1, [ty])
      infix -->

    (* constant dimensions *)
      val N2 = Ty.DimConst 2
      val N3 = Ty.DimConst 3

    (* short names for kinds; each is a thunk that allocates a fresh meta
     * variable of the given kind and wraps it as a `Ty.meta_var`.
     *)
      val TK : unit -> Ty.meta_var = Ty.TYPE o MV.newTyVar
      val DK : unit -> Ty.meta_var = Ty.DIFF o MV.newDiffVar
      val SK : unit -> Ty.meta_var = Ty.SHAPE o MV.newShapeVar
      val NK : unit -> Ty.meta_var = Ty.DIM o MV.newDimVar

    (* pair a type with an empty list of bound meta variables *)
      fun ty t = ([], t)

    (* `all (kinds, mkTy)` allocates one fresh meta variable per element of
     * `kinds` (in order) and returns them paired with the type that `mkTy`
     * builds from that same list.  Because `mkTy` always receives exactly the
     * list built here, callers pattern match on a list of the same length and
     * kinds as `kinds`; those matches are deliberately non-exhaustive.
     *)
      fun all (kinds, mkTy) = let
            val tvs = List.map (fn mk => mk()) kinds
            in
              (tvs, mkTy tvs)
            end

    (* special case of `all` for a single dimension variable; `mkTy` is handed
     * the raw dimension variable rather than a one-element list.
     *)
      fun allNK mkTy = let
            val tv = MV.newDimVar()
            in
              ([Ty.DIM tv], mkTy tv)
            end

    (* shorthand constructors for field and tensor types *)
      fun field (k, d, dd) = Ty.T_Field{diff=k, dim=d, shape=dd}
      fun tensor ds = Ty.T_Tensor(Ty.Shape ds)

    in

  (* overloaded operators; currently empty -- the commented-out names below
   * are the operators that are expected to be listed here.
   *)
    val overloads = [
(*
    val op_add = Atom.atom "+"
    val op_sub = Atom.atom "-"
    val op_mul = Atom.atom "*"
    val op_div = Atom.atom "/"
    val op_lt = Atom.atom "<"
    val op_lte = Atom.atom "<="
    val op_eql = Atom.atom "=="
    val op_neq = Atom.atom "!="
    val op_gte = Atom.atom ">="
    val op_gt = Atom.atom ">"
*)
          ]

  (* non-overloaded operators, etc.  Each entry pairs a name from BasisNames
   * with the result of `ty`, `all`, or `allNK`: a list of bound meta variables
   * together with a type.
   *)
    val basis = [
          (* operators *)
            (N.op_at, all([DK, NK, SK],
              fn [Ty.DIFF k, Ty.DIM d, Ty.SHAPE dd] => let
                  val k = Ty.DiffVar(k, 0)
                  val d = Ty.DimVar d
                  val dd = Ty.ShapeVar dd
                  in
                    [field(k, d, dd), tensor[d]] --> Ty.T_Tensor dd
                  end)),
          (* NOTE(review): the second argument (tensor[d]) mirrors op_at; confirm
           * that op_D is really meant to take a position in addition to the
           * field.  The result field has differentiability index k-1 and its
           * shape extended by d.
           *)
            (N.op_D, all([DK, NK, SK],
              fn [Ty.DIFF k, Ty.DIM d, Ty.SHAPE dd] => let
                  val k0 = Ty.DiffVar(k, 0)
                  val km1 = Ty.DiffVar(k, ~1)
                  val d = Ty.DimVar d
                  val dd = Ty.ShapeVar dd
                  in
                    [field(k0, d, dd), tensor[d]]
                      --> field(km1, d, Ty.ShapeExt(dd, d))
                  end)),
            (N.op_norm, all([SK],
              fn [Ty.SHAPE dd] => [Ty.T_Tensor(Ty.ShapeVar dd)] --> Ty.realTy)),
          (* functions *)
            (N.fn_CL, ty([tensor[N3, N3]] --> Ty.vec3Ty)),
            (N.fn_convolve, all([DK, NK, SK],
              fn [Ty.DIFF k, Ty.DIM d, Ty.SHAPE dd] => let
                  val k = Ty.DiffVar(k, 0)
                  val d = Ty.DimVar d
                  val dd = Ty.ShapeVar dd
                  in
                    [Ty.T_Kernel k, Ty.T_Image{dim=d, shape=dd}]
                      --> field(k, d, dd)
                  end)),
            (N.fn_cos, ty([Ty.realTy] --> Ty.realTy)),
          (* dot product: two tensors of the same dimension yield a real.  (The
           * previous entry took a single tensor and returned a tensor.)
           *)
            (N.fn_dot, allNK(fn tv => let
                  val t = tensor[Ty.DimVar tv]
                  in
                    [t, t] --> Ty.realTy
                  end)),
            (N.fn_inside, all([DK, NK, SK],
              fn [Ty.DIFF k, Ty.DIM d, Ty.SHAPE dd] => let
                  val k = Ty.DiffVar(k, 0)
                  val d = Ty.DimVar d
                  val dd = Ty.ShapeVar dd
                  in
                    [tensor[d], field(k, d, dd)]
                      --> Ty.T_Bool
                  end)),
            (N.fn_load, all([NK, SK],
              fn [Ty.DIM d, Ty.SHAPE dd] => let
                  val d = Ty.DimVar d
                  val dd = Ty.ShapeVar dd
                  in
                    [Ty.T_String] --> Ty.T_Image{dim=d, shape=dd}
                  end)),
            (N.fn_pow, ty([Ty.realTy, Ty.realTy] --> Ty.realTy)),
(*
    val fn_principleEvec = Atom.atom "principleEvec"
*)
            (N.fn_sin, ty([Ty.realTy] --> Ty.realTy)),
          (* kernels *)
            (N.kn_bspln3, ty(Ty.T_Kernel(Ty.DiffConst 2))),
            (N.kn_tent, ty(Ty.T_Kernel(Ty.DiffConst 0)))
          ]

  (* seed the basis environment: fold over the basis table, creating a
   * variable of kind AST.BasisVar for each entry and inserting it as a global
   * into an initially fresh environment.
   *)
    val env = let
          fun ins ((name, ty), env) = let
                val x = Var.newPoly (name, AST.BasisVar, ty)
                in
                  Env.insertGlobal (env, name, x)
                end
          in
            List.foldl ins (Env.new()) basis
          end

    end (* local *)
  end
```

 root@smlnj-gforge.cs.uchicago.edu ViewVC Help Powered by ViewVC 1.0.0