SCM Repository
View of /trunk/src/compiler/fields/partials.sml
Parent Directory
|
Revision Log
Revision 349 -
(download)
(annotate)
Fri Sep 24 00:24:20 2010 UTC (10 years, 3 months ago) by jhr
File size: 3157 byte(s)
Fri Sep 24 00:24:20 2010 UTC (10 years, 3 months ago) by jhr
File size: 3157 byte(s)
working on HighIL to MidIL translation
(* partials.sml
 *
 * COPYRIGHT (c) 2010 The Diderot Project (http://diderot.cs.uchicago.edu)
 * All rights reserved.
 *
 * A symbolic representation of partial derivative operators.
 *)

structure Partials :> sig

    eqtype axis                         (* abstract representation of axis *)

  (* string representation of axis name ("x", "y", "z", or "w") *)
    val axisToString : axis -> string

  (* return ith axis name (0 base); raises Size if i is not in 0..3 *)
    val axis : int -> axis

  (* axis index (0 base) *)
    val index : axis -> int

  (* representation of a partial derivative operator, where the length of the
   * list defines the dimension of the space and each element specifies the
   * number of levels of differentiation along the corresponding axis.
   * For example, the operator d/dx in 2D space is represented by D[1, 0]
   * and the operator d/dxdz in 3D is represented by D[1, 0, 1].
   *)
    datatype partial = D of int list

  (* string representation of an operator: "ID" when every count is zero,
   * otherwise "d/" followed by one "d<axis>" term per differentiated axis;
   * a count other than 1 is appended to its term (e.g., D[2, 0] is "d/dx2").
   *)
    val partialToString : partial -> string

  (* the identity for the given dimension; raises Size if the dimension is
   * not in 0..4.
   *)
    val ident : int -> partial

  (* derivative for the given dimension and axis; raises Size if the dimension
   * is not in 0..4 and Subscript if the axis does not lie in a space of that
   * dimension.
   *)
    val del : int -> axis -> partial

  (* multiply two partial derivative operators; the per-axis counts are added.
   * Raises ListPair.UnequalLengths if the operators have different dimensions.
   *)
    val multiply : (partial * partial) -> partial

  (* the partial derivative operator for a given dimension.  Specifically,
   *
   *    partial d index
   *
   * returns the partial derivative in the tensor of partial derivatives that
   * is generated by differentiating k times (k = length index).
   * For example,
   *
   *    partial 2 []     = ID
   *    partial 2 [0]    = d/dx
   *    partial 2 [1]    = d/dy
   *    partial 2 [0, 0] = d/dx^2
   *    partial 2 [0, 1] = d/dxdy
   *    partial 2 [1, 0] = d/dxdy
   *    partial 2 [1, 1] = d/dy^2
   *
   * Raises Size if the dimension is not in 0..4 and Subscript if any axis in
   * the index does not lie in a space of that dimension.
   *)
    val partial : int -> axis list -> partial

  end = struct

    type axis = int

    val maxAxis = 3             (* largest axis index: x=0, y=1, z=2, w=3 *)
    val maxDim = maxAxis + 1    (* up to 4 dimensions supported *)

    fun axisToString i = String.substring("xyzw", i, 1)

    fun axis i = if (i < 0) orelse (maxAxis < i) then raise Size else i

    fun index (i : axis) = i

    datatype partial = D of int list

    fun partialToString (D l) = let
        (* walk the counts, tracking the current axis index, and emit one
         * term per axis that has a non-zero count.
         *)
          fun f (_, []) = []
            | f (a, 0::r) = f(a+1, r)
            | f (a, 1::r) = "d" :: axisToString(axis a) :: f(a+1, r)
            | f (a, i::r) = "d" :: axisToString(axis a) :: Int.toString i :: f(a+1, r)
          in
            case f (0, l)
             of [] => "ID"
              | terms => String.concat("d/" :: terms)
            (* end case *)
          end

  (* a dimension is valid when it can be spanned by the supported axes; note
   * that the bound is maxDim (a count), not maxAxis (an index).
   *)
    fun validDim d = (0 <= d) andalso (d <= maxDim)

    fun uncheckedIdent d = D(List.tabulate(d, fn _ => 0))

    fun uncheckedDel d a = D(List.tabulate(d, fn i => if i = a then 1 else 0))

  (* like uncheckedDel, but rejects an axis that is outside a d-dimensional
   * space; valid axis indices are 0..d-1, so a = d must be rejected too.
   *)
    fun checkedDel d a =
          if (a < 0) orelse (d <= a) then raise Subscript else uncheckedDel d a

    fun ident d = if validDim d then uncheckedIdent d else raise Size

  (* the dimension is validated as soon as it is supplied; the axis is
   * validated when the resulting function is applied.
   *)
    fun del d = if validDim d then (fn a => checkedDel d a) else raise Size

  (* multiply two partial derivative operators *)
    fun multiply (D l1, D l2) = D(ListPair.mapEq (fn (i, j) => (i+j)) (l1, l2))

    fun partial dim =
          if validDim dim
            then let
              val del = checkedDel dim
            (* fold the axes into the accumulated operator, starting from
             * the identity.
             *)
              fun loop ([], dd) = dd
                | loop (a::r, dd) = loop (r, multiply(del a, dd))
              in
                fn axes => loop (axes, uncheckedIdent dim)
              end
            else raise Size

  end
root@smlnj-gforge.cs.uchicago.edu | ViewVC Help |
Powered by ViewVC 1.0.0 |