Home My Page Projects Code Snippets Project Openings diderot
 Summary Activity Tracker Tasks SCM

# SCM Repository

[diderot] Annotation of /trunk/src/compiler/fields/shape.sml
 [diderot] / trunk / src / compiler / fields / shape.sml

# Annotation of /trunk/src/compiler/fields/shape.sml

Revision 3349 - (view) (download)

 (* shape.sml
 *
 * This code is part of the Diderot Project (http://diderot-language.cs.uchicago.edu)
 *
 * COPYRIGHT (c) 2015 The University of Chicago
 * All rights reserved.
 *
 * A tree representation of the shape of a tensor (or loop nest).  The height
 * of the tree corresponds to the order of the tensor (or nesting depth) plus
 * one.  I.e., a 0-order tensor is represented by a leaf, a 1-order tensor
 * will be ND(_, [LF _, ..., LF _]), etc.
 *)

structure Shape : sig

    datatype ('nd, 'lf) shape
      = LF of 'lf
      | ND of ('nd * ('nd, 'lf) shape list)

  (* create (h, w, labelNd, f, labelLf, root)
   * creates the shape tree of an order-h tensor (or depth-h loop nest) in which
   * every dimension is w, i.e., every interior node has exactly w children.
   * h may be 0 (h >= 0), in which case the result is a single leaf; in general
   * the resulting tree has height h+1.  Each node is derived from a value of
   * type 'a: the root's value is root, and the j'th child (0 <= j < w) of a node
   * whose value is x gets the value f(j, x).  Interior nodes are labeled by
   * applying labelNd to their value, leaves by applying labelLf.
   * Raises Size if h < 0 or w < 1 (w is checked even when h = 0).
   *)
    val create : int * int * ('a -> 'nd) * (int * 'a -> 'a) * ('a -> 'lf) * 'a -> ('nd, 'lf) shape

  (* createFromShape (shape, labelNd, f, labelLf, root)
   * creates a shape tree from the given tensor shape (i.e., int list).  The height
   * of the tree will be length(shape)+1, and every node at depth i has as many
   * children as the i'th element of shape.  Node values are propagated exactly
   * as for create: the j'th child of a node with value x gets f(j, x).  A zero
   * dimension yields interior nodes with no children; a negative dimension
   * raises Size (from List.tabulate).
   *)
    val createFromShape : int list * ('a -> 'nd) * (int * 'a -> 'a) * ('a -> 'lf) * 'a -> ('nd, 'lf) shape

  (* map (ndFn, lfFn) t
   * returns a tree of the same shape as t with ndFn applied to every interior
   * label and lfFn applied to every leaf label.
   *)
    val map : ('a -> 'b) * ('c -> 'd) -> ('a,'c) shape -> ('b,'d) shape

  (* foldr f init t
   * right-to-left fold over the LEAF labels of t; interior-node labels are not
   * visited.  E.g., foldr (op ::) [] t lists the leaves in left-to-right order.
   *)
    val foldr : ('a * 'b -> 'b) -> 'b -> ('c,'a) shape -> 'b

  (* appPreOrder (ndFn, lfFn) t
   * pre-order traversal: ndFn is applied to an interior node's label before its
   * children are visited (left to right), and lfFn is applied to each leaf label.
   *)
    val appPreOrder : ('nd -> unit) * ('lf -> unit) -> ('nd, 'lf) shape -> unit

  end = struct

    datatype ('nd, 'lf) shape
      = LF of 'lf
      | ND of ('nd * ('nd, 'lf) shape list)

  (* build the tree level by level from the list of dimensions; the node label is
   * computed before its children (tuple components evaluate left to right).
   *)
    fun createFromShape (shape, ndAttr, f, lfAttr, init) = let
          fun mk ([], arg) = LF(lfAttr arg)
            | mk (d::dd, arg) = ND(ndAttr arg, List.tabulate(d, fn j => mk(dd, f(j, arg))))
          in
            mk (shape, init)
          end

  (* a uniform tree of depth height is exactly the shape tree for the dimension
   * list [width, ..., width] (height copies), so delegate to createFromShape
   * after validating the arguments.
   *)
    fun create (height, width, ndAttr, f, lfAttr, init) =
          if (height < 0) orelse (width < 1)
            then raise Size
            else createFromShape (List.tabulate (height, fn _ => width), ndAttr, f, lfAttr, init)

    fun map (nd, lf) t = let
          fun mapf (LF x) = LF(lf x)
            | mapf (ND(i, kids)) = ND(nd i, List.map mapf kids)
          in
            mapf t
          end

  (* only leaves contribute; the rightmost subtree is folded first *)
    fun foldr f init t = let
          fun fold (LF x, acc) = f(x, acc)
            | fold (ND(_, kids), acc) = List.foldr fold acc kids
          in
            fold (t, init)
          end

    fun appPreOrder (ndFn, lfFn) = let
          fun app (ND(attr, kids)) = (ndFn attr; List.app app kids)
            | app (LF attr) = lfFn attr
          in
            app
          end

  end

 root@smlnj-gforge.cs.uchicago.edu ViewVC Help Powered by ViewVC 1.0.0