SCM Repository
Annotation of /trunk/src/compiler/fields/shape.sml
Parent Directory
|
Revision Log
Revision 349 - (view) (download)
1 : | jhr | 342 | (* shape.sml |
2 : | * | ||
3 : | * COPYRIGHT (c) 2010 The Diderot Project (http://diderot.cs.uchicago.edu) | ||
4 : | * All rights reserved. | ||
5 : | * | ||
6 : | jhr | 349 | * A tree representation of the shape of a tensor (or loop nest). The height |
7 : | * of the tree corresponds to the order of the tensor (or nesting depth) plus | ||
8 : | * one. I.e., a 0-order tensor is represented by a leaf, a 1-order tensor | ||
9 : | * will be ND(_, [LF _, ..., LF _]), etc. | ||
10 : | jhr | 342 | *) |
11 : | |||
structure Shape : sig

  (* A shape is a rose tree: interior (ND) nodes carry a label of type 'nd
   * and an ordered list of children; leaves (LF) carry a label of type 'lf.
   *)
    datatype ('nd, 'lf) shape
      = LF of 'lf
      | ND of ('nd * ('nd, 'lf) shape list)

  (* create (depth, wid, labelNd, f, labelLf, root)
   *   builds a complete tree with `depth` levels of ND nodes, each having
   *   exactly `wid` children, sitting on top of a level of LF leaves.
   *   Every node is generated from a "seed" value of type 'a:
   *     - the root's seed is `root`;
   *     - the seed of the j-th child (0-based) of a node with seed x is f(j, x);
   *     - an ND node with seed x is labelled `labelNd x`;
   *     - a leaf with seed x is labelled `labelLf x`.
   *   If depth <= 0 the result is the single leaf LF(labelLf root).
   *   Raises Size (from List.tabulate) if depth > 0 and wid < 0.
   *)
    val create : int * int * ('a -> 'nd) * (int * 'a -> 'a) * ('a -> 'lf) * 'a -> ('nd, 'lf) shape

  (* map (nd, lf) t
   *   relabels t, applying nd to every ND label and lf to every LF label;
   *   the structure of the tree is unchanged.
   *)
    val map : ('a -> 'b) * ('c -> 'd) -> ('a,'c) shape -> ('b,'d) shape

  (* foldr f init t
   *   folds f over the LF labels of t from right to left, starting with init;
   *   ND labels are ignored.  For example, foldr (op ::) [] t returns the
   *   leaf labels in left-to-right order.
   *)
    val foldr : ('a * 'b -> 'b) -> 'b -> ('c,'a) shape -> 'b

  end = struct

    datatype ('nd, 'lf) shape
      = LF of 'lf
      | ND of ('nd * ('nd, 'lf) shape list)

  (* creates a shape with the given depth and width at each level *)
    fun create (depth, wid, labelNd, f, labelLf, root) = let
        (* build the subtree whose root is at depth d and has the given seed;
         * the child index j is only needed to compute each child's seed.
         *)
          fun mk (d, seed) = if (d < depth)
                then ND(labelNd seed, List.tabulate(wid, fn j => mk(d+1, f(j, seed))))
                else LF(labelLf seed)
          in
            mk (0, root)
          end

    fun map (nd, lf) t = let
          fun mapf (LF x) = LF(lf x)
            | mapf (ND(i, kids)) = ND(nd i, List.map mapf kids)
          in
            mapf t
          end

    fun foldr f init t = let
        (* List.foldr visits the children right to left, so the leaves are
         * combined in right-to-left order overall.
         *)
          fun fold (LF x, acc) = f(x, acc)
            | fold (ND(_, kids), acc) = List.foldr fold acc kids
          in
            fold (t, init)
          end

  end
root@smlnj-gforge.cs.uchicago.edu | ViewVC Help |
Powered by ViewVC 1.0.0 |