(* shape.sml
*
* COPYRIGHT (c) 2010 The Diderot Project (http://diderot.cs.uchicago.edu)
* All rights reserved.
*
* A tree representation of the shape of a tensor. The height of the tree
* corresponds to the order of the tensor plus one. I.e., a 0-order tensor
* is represented by a leaf, a 1-order tensor will be ND(_, [LF _, ..., LF _]),
* etc.
*)
structure Shape =
  struct

  (* A shape is a tree.  A leaf `LF` carries a leaf attribute; an interior
   * node `ND` carries a node attribute and the ordered list of its children.
   *)
    datatype ('nd, 'lf) shape
      = LF of 'lf
      | ND of ('nd * ('nd, 'lf) shape list)

  (* create (depth, width, ndAttr, f, lfAttr, init)
   *
   * Builds a uniform shape with `depth` levels of interior nodes, each of
   * which has `width` children; the positions below the last level of
   * nodes are leaves.  A value is threaded from the root downwards: the root
   * is given `init`, and the j'th child (counting from 0) of a node that was
   * given `arg` is given `f(j, arg)`.  A node's attribute is `ndAttr arg` and
   * a leaf's attribute is `lfAttr arg`.
   *
   * If `depth` <= 0 the result is the single leaf `LF(lfAttr init)`.
   * If `depth` > 0 and `width` < 0, List.tabulate raises Size.
   *)
    fun create (depth, width, ndAttr, f, lfAttr, init) = let
          fun mk (d, arg) = if (d < depth)
                then ND(
                  ndAttr arg,
                  List.tabulate (width, fn j => mk (d+1, f (j, arg))))
                else LF(lfAttr arg)
          in
            mk (0, init)
          end

  (* map (nd, lf) t
   *
   * Returns a shape with the same structure as `t` in which every node
   * attribute has been transformed by `nd` and every leaf attribute by `lf`.
   *)
    fun map (nd, lf) t = let
          fun mapf (LF x) = LF(lf x)
            | mapf (ND(i, kids)) = ND(nd i, List.map mapf kids)
          in
            mapf t
          end

  (* foldr f init t
   *
   * Folds `f` over the leaf attributes of `t` from right to left, starting
   * with `init`; i.e., for leaves x1, ..., xn in left-to-right order the
   * result is f(x1, f(x2, ... f(xn, init) ...)).  Node attributes are
   * ignored, and a node without children contributes nothing.
   *)
    fun foldr f init t = let
          fun fold (LF x, acc) = f(x, acc)
            | fold (ND(_, kids), acc) = List.foldr fold acc kids
          in
          (* the traversal must be seeded with the caller's initial value *)
            fold (t, init)
          end

  end