3 |
* COPYRIGHT (c) 2010 The Diderot Project (http://diderot.cs.uchicago.edu) |
* COPYRIGHT (c) 2010 The Diderot Project (http://diderot.cs.uchicago.edu) |
4 |
* All rights reserved. |
* All rights reserved. |
5 |
* |
* |
6 |
* A tree representation of the shape of a tensor. The height of the tree |
* A tree representation of the shape of a tensor (or loop nest). The height |
7 |
* corresponds to the order of the tensor plus one. I.e., a 0-order tensor |
* of the tree corresponds to the order of the tensor (or nesting depth) plus |
8 |
* is represented by a leaf, a 1-order tensor will be ND(_, [LF _, ..., LF _]), |
* one. I.e., a 0-order tensor is represented by a leaf, a 1-order tensor |
9 |
* etc. |
* will be ND(_, [LF _, ..., LF _]), etc. |
10 |
*) |
*) |
11 |
|
|
12 |
structure Shape = |
structure Shape : sig |
13 |
struct |
|
14 |
|
datatype ('nd, 'lf) shape |
15 |
|
= LF of 'lf |
16 |
|
| ND of ('nd * ('nd, 'lf) shape list) |
17 |
|
|
18 |
|
(* create (depth, wid, labelNd, f, labelLf, root) |
19 |
|
*) |
20 |
|
val create : int * int * ('a -> 'nd) * (int * 'a -> 'a) * ('a -> 'lf) * 'a -> ('nd, 'lf) shape |
21 |
|
|
22 |
|
val map : ('a -> 'b) * ('c -> 'd) -> ('a,'c) shape -> ('b,'d) shape |
23 |
|
|
24 |
|
val foldr : ('a * 'b -> 'b) -> 'b -> ('c,'a) shape -> 'b |
25 |
|
|
26 |
|
end = struct |
27 |
|
|
28 |
datatype ('nd, 'lf) shape |
datatype ('nd, 'lf) shape |
29 |
= LF of 'lf |
= LF of 'lf |
49 |
fun fold (LF x, acc) = f(x, acc) |
fun fold (LF x, acc) = f(x, acc) |
50 |
| fold (ND(_, kids), acc) = List.foldr fold acc kids |
| fold (ND(_, kids), acc) = List.foldr fold acc kids |
51 |
in |
in |
52 |
fold t |
fold (t, init) |
53 |
end |
end |
54 |
|
|
55 |
end |
end |