Sat Apr 2 19:49:07 2011 UTC (8 years, 6 months ago) by jhr
File size: 3223 byte(s)
Better printing of partials
(* partials.sml
*
* COPYRIGHT (c) 2010 The Diderot Project (http://diderot-language.cs.uchicago.edu)
*
* A symbolic representation of partial derivative operators.
*)

structure Partials :> sig

  (* abstract representation of a coordinate axis; axes are numbered from 0 *)
    eqtype axis

  (* single-letter name of an axis ("x", "y", "z", or "w") *)
    val axisToString : axis -> string

  (* `axis i` returns the ith axis (0-based); raises Size unless 0 <= i <= 3 *)
    val axis : int -> axis

  (* the 0-based index of an axis *)
    val index : axis -> int

  (* representation of a partial derivative operator, where the length of the
   * list defines the dimension of the space and each element specifies the
   * number of levels of differentiation along the corresponding axis.
   * For example, the operator d/dx in 2D space is represented by D[1, 0]
   * and the operator d/dxdz in 3D is represented by D[1, 0, 1].
   *)
    datatype partial = D of int list

  (* printable form of an operator: "ID" for the identity, otherwise
   * "d<order>/" followed by one "d<axis>[<count>]" term per differentiated
   * axis (the count is omitted when it is 1); e.g., D[1, 1] prints as
   * "d2/dxdy" and D[2, 0] prints as "d2/dx2".
   *)
    val partialToString : partial -> string

  (* the identity for the given dimension; raises Size if the dimension is
   * not in the supported range 0..4.
   *)
    val ident : int -> partial

  (* `del d a` is the first derivative along axis a in dimension d; raises
   * Size if d is not in 0..4 (checked when `del d` is applied) and
   * Subscript if a is not a valid axis of a d-dimensional space (index a >= d).
   *)
    val del : int -> axis -> partial

  (* multiply (i.e., compose) two partial derivative operators by adding their
   * differentiation counts axis-wise; raises ListPair.UnequalLengths if the
   * operators have different dimensions.
   *)
    val multiply : (partial * partial) -> partial

  (* the partial derivative operator for a given dimension.  Specifically,
   *
   *    partial d index
   *
   * returns the partial derivative in the tensor of partial derivatives that
   * is generated by differentiating k times (k = length index).  For example,
   *
   *    partial 2 [] = ID
   *    partial 2 [0] = d/dx
   *    partial 2 [1] = d/dy
   *    partial 2 [0, 0] = d/dx^2
   *    partial 2 [0, 1] = d/dxdy
   *    partial 2 [1, 0] = d/dxdy
   *    partial 2 [1, 1] = d/dy^2
   *
   * Raises Size if d is not in 0..4 (checked when `partial d` is applied)
   * and Subscript if any axis in the index list is not valid for dimension d.
   *)
    val partial : int -> axis list -> partial

  end = struct

    type axis = int

  (* axes are numbered 0..maxAxis ("x".."w"), so spaces of dimension up to
   * maxDim = maxAxis + 1 (i.e., 4) are supported.
   *)
    val maxAxis = 3
    val maxDim = maxAxis + 1

    fun axisToString i = String.substring("xyzw", i, 1)

    fun axis i = if (i < 0) orelse (maxAxis < i) then raise Size else i

    fun index (i : axis) = i

    datatype partial = D of int list

    fun partialToString (D l) = let
        (* one "d<axis>[<count>]" term per axis that is differentiated at all *)
          fun f (a, []) = []
            | f (a, 0::r) = f(a+1, r)
            | f (a, 1::r) = "d" :: axisToString(axis a) :: f(a+1, r)
            | f (a, i::r) = "d" :: axisToString(axis a) :: Int.toString i :: f(a+1, r)
        (* total order of the operator *)
          val n = List.foldl (op +) 0 l
          in
            case f (0, l)
             of [] => "ID"
              | terms => String.concat("d" :: Int.toString n :: "/" :: terms)
            (* end case *)
          end

  (* raise Size unless d is a supported dimension.  Note that the bound is
   * maxDim (a dimension), not maxAxis (an axis index).
   *)
    fun checkDim d = if (d < 0) orelse (maxDim < d) then raise Size else ()

  (* raise Subscript unless axis a exists in a d-dimensional space; valid axes
   * are 0..d-1, so a = d is out of range.  Axes are non-negative by
   * construction (see `axis`).
   *)
    fun checkAxis (d, a) = if (d <= a) then raise Subscript else ()

    fun uncheckedIdent d = D(List.tabulate(d, fn _ => 0))
    fun uncheckedDel d axis = D(List.tabulate(d, fn i => if i = axis then 1 else 0))

    fun ident d = (checkDim d; uncheckedIdent d)

  (* the dimension is validated on partial application, the axis on full
   * application, matching the staging of the curried interface.
   *)
    fun del d = (
          checkDim d;
          fn axis => (checkAxis (d, axis); uncheckedDel d axis))

  (* multiply two partial derivative operators *)
    fun multiply (D l1, D l2) = D(ListPair.mapEq (fn (i, j) => (i+j)) (l1, l2))

    fun partial dim = let
          val del = uncheckedDel dim
        (* fold the index list into the operator, validating each axis so
         * that an out-of-range axis is not silently dropped.
         *)
          fun loop ([], dd) = dd
            | loop (a::r, dd) = (
                checkAxis (dim, a);
                loop (r, multiply(del a, dd)))
          in
            checkDim dim;
            fn index => loop (index, uncheckedIdent dim)
          end

  end

