(* shape.sml
*
* This code is part of the Diderot Project (http://diderot-language.cs.uchicago.edu)
*
* COPYRIGHT (c) 2015 The University of Chicago
* All rights reserved.
*
* A tree representation of the shape of a tensor (or loop nest). The height
* of the tree corresponds to the order of the tensor (or nesting depth) plus
* one. For example, an order-0 tensor is represented by a leaf, an order-1 tensor
* is represented by ND(_, [LF _, ..., LF _]), and so on.
*)
structure Shape : sig

    datatype ('nd, 'lf) shape
      = LF of 'lf
      | ND of ('nd * ('nd, 'lf) shape list)

  (* create (h, w, labelNd, f, labelLf, root)
   * creates a uniform tree with nesting depth h (>= 0) in which every interior
   * node has exactly w (>= 1) children.  Following the height convention described
   * at the top of this file, the resulting tree has height h+1, so create (0, ...)
   * returns a single leaf.  Labels are computed top-down starting from root: a node
   * built from the argument x is labeled labelNd x (labelLf x for a leaf), and its
   * j'th child (counting from 0) is built from the argument f (j, x).
   * Raises Size if h < 0 or w < 1.
   *)
    val create : int * int * ('a -> 'nd) * (int * 'a -> 'a) * ('a -> 'lf) * 'a -> ('nd, 'lf) shape

  (* createFromShape (shape, labelNd, f, labelLf, root)
   * creates a shape tree from the given tensor shape (i.e., int list).  The height
   * of the tree will be length(shape)+1 and the number of children of each interior
   * node at depth i is the i'th element of shape.  Labels are computed from root
   * exactly as for create.  A 0 in shape yields interior nodes with no children;
   * a negative element raises Size.
   *)
    val createFromShape : int list * ('a -> 'nd) * (int * 'a -> 'a) * ('a -> 'lf) * 'a -> ('nd, 'lf) shape

  (* map (nd, lf) t applies nd to every interior-node label and lf to every leaf
   * label of t, preserving the structure of the tree.
   *)
    val map : ('a -> 'b) * ('c -> 'd) -> ('a,'c) shape -> ('b,'d) shape

  (* foldr f init t folds f over the LEAF labels of t from right to left, starting
   * with init; interior-node labels are ignored.  For example, foldr (op ::) [] t
   * returns the leaf labels in left-to-right order.
   *)
    val foldr : ('a * 'b -> 'b) -> 'b -> ('c,'a) shape -> 'b

  (* appPreOrder (ndFn, lfFn) t applies ndFn to each interior-node label and lfFn to
   * each leaf label, visiting a node before its children and the children from
   * left to right.
   *)
    val appPreOrder : ('nd -> unit) * ('lf -> unit) -> ('nd, 'lf) shape -> unit

  end = struct

    datatype ('nd, 'lf) shape
      = LF of 'lf
      | ND of ('nd * ('nd, 'lf) shape list)

  (* Build the tree level by level: the head of the remaining shape list is the
   * fan-out of the current interior node, and an empty list means "leaf".  The
   * node's label is computed from its argument before its children are built
   * (tuple components are evaluated left to right).  List.tabulate raises Size
   * on a negative dimension.
   *)
    fun createFromShape (shape, ndAttr, f, lfAttr, init) = let
          fun mk ([], arg) = LF(lfAttr arg)
            | mk (d::dd, arg) = ND(ndAttr arg, List.tabulate(d, fn j => mk(dd, f(j, arg))))
          in
            mk (shape, init)
          end

  (* A uniform tree is the special case of createFromShape whose shape repeats the
   * same width `height` times.  We validate the arguments here and then delegate,
   * so the two constructors share a single builder and cannot drift apart.
   *)
    fun create (height, width, ndAttr, f, lfAttr, init) =
          if (height < 0) orelse (width < 1)
            then raise Size
            else createFromShape (
              List.tabulate (height, fn _ => width),
              ndAttr, f, lfAttr, init)

  (* structure-preserving relabeling of interior nodes and leaves *)
    fun map (nd, lf) t = let
          fun mapf (LF x) = LF(lf x)
            | mapf (ND(i, kids)) = ND(nd i, List.map mapf kids)
          in
            mapf t
          end

  (* only leaves contribute to the fold; List.foldr visits the last child first,
   * which gives the right-to-left order over the leaves
   *)
    fun foldr f init t = let
          fun fold (LF x, acc) = f(x, acc)
            | fold (ND(_, kids), acc) = List.foldr fold acc kids
          in
            fold (t, init)
          end

  (* visit the node's own label first, then its children left to right *)
    fun appPreOrder (ndFn, lfFn) = let
          fun app (ND(attr, kids)) = (ndFn attr; List.app app kids)
            | app (LF attr) = lfFn attr
          in
            app
          end

  end