Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] Annotation of /trunk/src/compiler/fields/shape.sml
ViewVC logotype

Annotation of /trunk/src/compiler/fields/shape.sml

Parent Directory Parent Directory | Revision Log Revision Log


Revision 3349 - (view) (download)

(* shape.sml
 *
 * This code is part of the Diderot Project (http://diderot-language.cs.uchicago.edu)
 *
 * COPYRIGHT (c) 2015 The University of Chicago
 * All rights reserved.
 *
 * A tree representation of the shape of a tensor (or loop nest). The height
 * of the tree corresponds to the order of the tensor (or nesting depth) plus
 * one. I.e., a 0-order tensor is represented by a leaf, a 1-order tensor
 * will be ND(_, [LF _, ..., LF _]), etc.
 *)
13 :    
structure Shape : sig

  (* A shape tree: leaves carry a 'lf label; interior nodes carry an 'nd label
   * and an ordered list of subtrees.
   *)
    datatype ('nd, 'lf) shape
      = LF of 'lf
      | ND of ('nd * ('nd, 'lf) shape list)

  (* create (h, w, labelNd, f, labelLf, root)
   *    creates a tree with h (>= 0) levels of interior nodes above its leaves,
   *    where every interior node has w (>= 1) children.  Thus create(0, ...)
   *    is a single leaf, create(1, ...) is ND(_, [LF _, ..., LF _]), and in
   *    general the result has h+1 levels.
   *    The root is built from the state `root`; the state of the j'th child
   *    (counting from 0) of a node with state s is f(j, s).  A node with
   *    state s is labeled `labelNd s` if it is an interior node and
   *    `labelLf s` if it is a leaf.
   *    Raises Size if h < 0 or w < 1 (the width is checked even when h = 0).
   *)
    val create : int * int * ('a -> 'nd) * (int * 'a -> 'a) * ('a -> 'lf) * 'a -> ('nd, 'lf) shape

  (* createFromShape (shape, labelNd, f, labelLf, root)
   *    creates a shape tree from the given tensor shape (i.e., int list).  The
   *    height of the tree will be length(shape)+1 and the number of children at
   *    each level is defined by the corresponding element of shape.  States and
   *    labels are computed as for create.
   *    The dimensions are not validated up front: a dimension of 0 produces an
   *    interior node with no children, and a negative dimension makes
   *    List.tabulate raise Size when that level is constructed.
   *)
    val createFromShape : int list * ('a -> 'nd) * (int * 'a -> 'a) * ('a -> 'lf) * 'a -> ('nd, 'lf) shape

  (* map (nd, lf) t
   *    maps nd over the interior-node labels and lf over the leaf labels of t,
   *    preserving the structure of the tree.
   *)
    val map : ('a -> 'b) * ('c -> 'd) -> ('a,'c) shape -> ('b,'d) shape

  (* foldr f init t
   *    folds f over the leaf labels of t in a right-to-left traversal;
   *    interior-node labels are ignored.
   *)
    val foldr : ('a * 'b -> 'b) -> 'b -> ('c,'a) shape -> 'b

  (* apply a node function and a leaf function to the tree in a pre-order
   * traversal (a node is visited before its children; children are visited
   * from left to right).
   *)
    val appPreOrder : ('nd -> unit) * ('lf -> unit) -> ('nd, 'lf) shape -> unit

  end = struct

    datatype ('nd, 'lf) shape
      = LF of 'lf
      | ND of ('nd * ('nd, 'lf) shape list)

  (* creates a shape with `height` levels of interior nodes, each of which has
   * `width` children; raises Size on a negative height or non-positive width.
   *)
    fun create (height, width, ndAttr, f, lfAttr, init) = let
        (* mk (d, arg) builds the subtree that has d interior levels remaining,
         * using arg as the state of its root.  The node label is computed
         * before the children are built.
         *)
          fun mk (0, arg) = LF(lfAttr arg)
            | mk (d, arg) = ND(ndAttr arg, List.tabulate(width, fn j => mk(d-1, f(j, arg))))
          in
          (* a height of 0 is legal and yields a single leaf *)
            if (height < 0) orelse (width < 1)
              then raise Size
              else mk (height, init)
          end

  (* creates a shape whose fan-out at each level is given by the corresponding
   * element of the int list `shape`; the empty list yields a single leaf.
   *)
    fun createFromShape (shape, ndAttr, f, lfAttr, init) = let
        (* mk (dims, arg) consumes one dimension per level of the tree *)
          fun mk ([], arg) = LF(lfAttr arg)
            | mk (d::dd, arg) = ND(ndAttr arg, List.tabulate(d, fn j => mk(dd, f(j, arg))))
          in
            mk (shape, init)
          end

  (* structure-preserving map over node and leaf labels *)
    fun map (nd, lf) t = let
          fun mapf (LF x) = LF(lf x)
            | mapf (ND(i, kids)) = ND(nd i, List.map mapf kids)
          in
            mapf t
          end

  (* right-to-left fold over the leaves; node labels do not contribute *)
    fun foldr f init t = let
          fun fold (LF x, acc) = f(x, acc)
            | fold (ND(_, kids), acc) = List.foldr fold acc kids
          in
            fold (t, init)
          end

  (* pre-order application: the node function runs before the node's children
   * are visited.
   *)
    fun appPreOrder (ndFn, lfFn) = let
          fun app (ND(attr, kids)) = (ndFn attr; List.app app kids)
            | app (LF attr) = lfFn attr
          in
            app
          end

  end

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0