SCM Repository
Annotation of /trunk/src/compiler/fields/shape.sml
Parent Directory
|
Revision Log
Revision 1116 - (view) (download)
1 : | jhr | 342 | (* shape.sml |
2 : | * | ||
3 : | jhr | 435 | * COPYRIGHT (c) 2010 The Diderot Project (http://diderot-language.cs.uchicago.edu) |
4 : | jhr | 342 | * All rights reserved. |
5 : | * | ||
6 : | jhr | 349 | * A tree representation of the shape of a tensor (or loop nest). The height |
7 : | * of the tree corresponds to the order of the tensor (or nesting depth) plus | ||
8 : | * one. For example, a 0-order tensor is represented by a leaf, a 1-order tensor | ||
9 : | * is represented by ND(_, [LF _, ..., LF _]), etc. | ||
10 : | jhr | 1116 | * |
11 : | * | ||
12 : | jhr | 342 | *) |
13 : | |||
structure Shape : sig

    datatype ('nd, 'lf) shape
      = LF of 'lf
      | ND of ('nd * ('nd, 'lf) shape list)

  (* create (h, w, labelNd, f, labelLf, root)
   * builds a uniform tree with h (>= 0) levels of interior nodes, where every
   * interior node has exactly w (>= 1) children.  Thus h is the order of the
   * tensor (or the loop-nest depth) and the resulting tree has height h+1; in
   * particular, h = 0 yields a single leaf.
   *
   * Every node is built from a seed value of type 'a:  the root's seed is root,
   * and the j'th child (counting from 0) of a node with seed a has seed f(j, a).
   * An interior node with seed a is labeled (labelNd a) and a leaf with seed a
   * is labeled (labelLf a).
   *
   * Raises Size if h < 0 or w < 1.
   *)
    val create : int * int * ('a -> 'nd) * (int * 'a -> 'a) * ('a -> 'lf) * 'a -> ('nd, 'lf) shape

  (* map (ndFn, lfFn) t
   * returns a tree with the same structure as t, where ndFn has been applied
   * to every interior-node label and lfFn to every leaf label.
   *)
    val map : ('a -> 'b) * ('c -> 'd) -> ('a,'c) shape -> ('b,'d) shape

  (* foldr f init t
   * right-to-left fold of f over the LEAF labels of t (the rightmost leaf is
   * combined with init first).  Interior-node labels are not visited.
   *)
    val foldr : ('a * 'b -> 'b) -> 'b -> ('c,'a) shape -> 'b

  (* appPreOrder (ndFn, lfFn) t
   * pre-order traversal of t (a node is visited before its children, and
   * children are visited left to right), applying ndFn to interior-node
   * labels and lfFn to leaf labels.
   *)
    val appPreOrder : ('nd -> unit) * ('lf -> unit) -> ('nd, 'lf) shape -> unit

  end = struct

    datatype ('nd, 'lf) shape
      = LF of 'lf
      | ND of ('nd * ('nd, 'lf) shape list)

  (* creates a shape with the given number of interior levels (height) and the
   * given width at each level; see the signature for the seed/label protocol.
   *)
    fun create (height, width, ndAttr, f, lfAttr, init) = let
        (* mk (d, arg) builds a subtree with d interior levels from seed arg;
         * d = 0 is the leaf level.
         *)
          fun mk (0, arg) = LF(lfAttr arg)
            | mk (d, arg) = ND(ndAttr arg, List.tabulate(width, fn j => mk(d-1, f(j, arg))))
          in
          (* height = 0 is allowed (a 0-order tensor is a single leaf); width must
           * be positive so that every interior node has at least one child.
           *)
            if (height < 0) orelse (width < 1)
              then raise Size
              else mk (height, init)
          end

  (* structure-preserving map over node and leaf labels *)
    fun map (nd, lf) t = let
          fun mapf (LF x) = LF(lf x)
            | mapf (ND(i, kids)) = ND(nd i, List.map mapf kids)
          in
            mapf t
          end

  (* right-to-left fold over the leaf labels only; node labels are skipped *)
    fun foldr f init t = let
          fun fold (LF x, acc) = f(x, acc)
            | fold (ND(_, kids), acc) = List.foldr fold acc kids
          in
            fold (t, init)
          end

  (* pre-order traversal:  node label first, then the children left to right *)
    fun appPreOrder (ndFn, lfFn) = let
          fun app (ND(attr, kids)) = (ndFn attr; List.app app kids)
            | app (LF attr) = lfFn attr
          in
            app
          end

  end
root@smlnj-gforge.cs.uchicago.edu | ViewVC Help |
Powered by ViewVC 1.0.0 |