Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] Annotation of /trunk/src/compiler/fields/shape.sml
ViewVC logotype

Annotation of /trunk/src/compiler/fields/shape.sml

Parent Directory | Revision Log


Revision 349 - (view) (download)

1 : jhr 342 (* shape.sml
2 :     *
3 :     * COPYRIGHT (c) 2010 The Diderot Project (http://diderot.cs.uchicago.edu)
4 :     * All rights reserved.
5 :     *
6 : jhr 349 * A tree representation of the shape of a tensor (or loop nest). The height
7 :     * of the tree corresponds to the order of the tensor (or nesting depth) plus
8 :     * one. I.e., a 0-order tensor is represented by a leaf, a 1-order tensor
9 :     * will be ND(_, [LF _, ..., LF _]), etc.
10 : jhr 342 *)
11 :    
12 : jhr 349 structure Shape : sig
13 : jhr 342
          (* A tree whose interior nodes (ND) carry a label of type 'nd together
           * with a list of child shapes, and whose leaves (LF) carry a value of
           * type 'lf.
           *)
14 :     datatype ('nd, 'lf) shape
15 :     = LF of 'lf
16 :     | ND of ('nd * ('nd, 'lf) shape list)
17 :    
18 : jhr 349 (* create (depth, wid, labelNd, f, labelLf, root)
           * Builds a tree top-down, threading a value of type 'a through it.
           * The root is at level 0 and receives `root` as its value.  A node at
           * a level less than `depth` becomes ND(labelNd v, kids), where v is
           * the node's value and kids has `wid` children; child j (counting
           * from 0) receives the value f(j, v).  A node at level `depth`
           * becomes LF(labelLf v).  If depth <= 0, the result is the single
           * leaf LF(labelLf root).
19 :     *)
20 :     val create : int * int * ('a -> 'nd) * (int * 'a -> 'a) * ('a -> 'lf) * 'a -> ('nd, 'lf) shape
21 :    
          (* map (nd, lf) t
           * Returns a tree with the same structure as t in which every ND label
           * has been transformed by nd and every LF value by lf.
           *)
22 :     val map : ('a -> 'b) * ('c -> 'd) -> ('a,'c) shape -> ('b,'d) shape
23 :    
          (* foldr f init t
           * Folds f over the leaf values of t from right to left, starting with
           * init; ND labels are ignored.
           *)
24 :     val foldr : ('a * 'b -> 'b) -> 'b -> ('c,'a) shape -> 'b
25 :    
26 :     end = struct
27 :    
          (* same datatype as in the signature: LF is a leaf carrying a value,
           * ND is an interior node carrying a label and its list of children.
           *)
28 :     datatype ('nd, 'lf) shape
29 :     = LF of 'lf
30 :     | ND of ('nd * ('nd, 'lf) shape list)
31 :    
32 : jhr 342 (* creates a shape with the given depth and width at each level *)
          (* mk (d, i, arg) builds the subtree rooted at level d whose threaded
           * value is arg.  While d < depth it produces an ND labelled
           * (ndAttr arg) with `width` children, where child j is built from
           * f(j, arg); otherwise it produces the leaf LF(lfAttr arg).
           * The parameter i receives the child's index but is not used in the
           * body of mk.
           * NOTE: per the SML Basis Library, List.tabulate raises Size when
           * width is negative; that call is only reached when depth > 0.
           *)
33 :     fun create (depth, width, ndAttr, f, lfAttr, init) = let
34 :     fun mk (d, i, arg) = if (d < depth)
35 :     then ND(ndAttr arg, List.tabulate(width, fn j => mk(d+1, j, f(j, arg))))
36 :     else LF(lfAttr arg)
37 :     in
38 :     mk (0, 0, init)
39 :     end
40 :    
          (* map (nd, lf) t rebuilds t with the same structure, applying lf to
           * the value of each leaf and nd to the label of each interior node.
           *)
41 :     fun map (nd, lf) t = let
42 :     fun mapf (LF x) = LF(lf x)
43 :     | mapf (ND(i, kids)) = ND(nd i, List.map mapf kids)
44 :     in
45 :     mapf t
46 :     end
47 :    
          (* foldr f init t combines the leaf values of t using f.  Interior
           * labels are ignored, and children are visited with List.foldr, so
           * the rightmost leaf is combined with init first; for leaves
           * x1, ..., xn in left-to-right order the result is
           * f(x1, f(x2, ... f(xn, init) ...)).
           *)
48 :     fun foldr f init t = let
49 :     fun fold (LF x, acc) = f(x, acc)
50 :     | fold (ND(_, kids), acc) = List.foldr fold acc kids
51 :     in
52 : jhr 349 fold (t, init)
53 : jhr 342 end
54 :    
55 :     end

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0