Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] View of /trunk/src/compiler/fields/shape.sml
ViewVC logotype

View of /trunk/src/compiler/fields/shape.sml

Parent Directory Parent Directory | Revision Log Revision Log


Revision 2356 - (download) (annotate)
Sun Apr 7 14:45:25 2013 UTC (6 years, 6 months ago) by jhr
File size: 2752 byte(s)
  Merging in bug fixes and language enhancements from the vis12 branch (via staging).
  Features include type promotion, the curl and colon operator, transpose, and functions.
(* shape.sml
 *
 * COPYRIGHT (c) 2010 The Diderot Project (http://diderot-language.cs.uchicago.edu)
 * All rights reserved.
 *
 * A tree representation of the shape of a tensor (or loop nest).  The height
 * of the tree corresponds to the order of the tensor (or nesting depth) plus
 * one.  For example, a 0-order tensor is represented by a leaf, a 1-order
 * tensor is ND(_, [LF _, ..., LF _]), etc.
 *)

structure Shape : sig

  (* A shape tree: interior nodes (ND) carry an 'nd label and an ordered list
   * of children; leaves (LF) carry an 'lf label.
   *)
    datatype ('nd, 'lf) shape
      = LF of 'lf
      | ND of ('nd * ('nd, 'lf) shape list)

  (* create (h, w, labelNd, f, labelLf, root)
   * creates a tree with h (>= 0) levels of interior nodes above the leaves
   * (i.e., a tree of height h+1), in which every interior node has w (>= 1)
   * children.  When h = 0 the result is the single leaf LF(labelLf root).
   * Construction starts from the argument root; the child at (0-based)
   * index j of a node built from argument x is built from f(j, x).  labelNd
   * and labelLf turn those arguments into node and leaf labels.
   * Raises Size if h < 0 or w < 1 (w is checked even when h = 0).
   *)
    val create : int * int * ('a -> 'nd) * (int * 'a -> 'a) * ('a -> 'lf) * 'a -> ('nd, 'lf) shape

  (* createFromShape (shape, labelNd, f, labelLf, root)
   * creates a shape tree from the given tensor shape (i.e., int list).  The
   * height of the tree will be length(shape)+1 and the number of children at
   * each level is given by the corresponding element of shape.  Arguments
   * are threaded exactly as in create.  A zero dimension yields interior
   * nodes with no children; a negative dimension raises Size (propagated
   * from List.tabulate).
   *)
    val createFromShape : int list * ('a -> 'nd) * (int * 'a -> 'a) * ('a -> 'lf) * 'a -> ('nd, 'lf) shape

  (* map (ndFn, lfFn) t
   * applies ndFn to every interior-node label and lfFn to every leaf label,
   * preserving the structure of t.
   *)
    val map : ('a -> 'b) * ('c -> 'd) -> ('a,'c) shape -> ('b,'d) shape

  (* foldr f init t
   * folds f over the leaf labels of t from right to left; interior-node
   * labels are not visited.
   *)
    val foldr : ('a * 'b -> 'b) -> 'b -> ('c,'a) shape -> 'b

  (* appPreOrder (ndFn, lfFn) t
   * applies ndFn to each interior-node label and lfFn to each leaf label in
   * pre-order: a node before its children, children from left to right.
   *)
    val appPreOrder : ('nd -> unit) * ('lf -> unit) -> ('nd, 'lf) shape -> unit

  end = struct

    datatype ('nd, 'lf) shape
      = LF of 'lf
      | ND of ('nd * ('nd, 'lf) shape list)

    fun create (height, width, ndAttr, f, lfAttr, init) = let
        (* mk (d, arg) builds a subtree with d interior levels from arg *)
          fun mk (0, arg) = LF(lfAttr arg)
            | mk (d, arg) = ND(ndAttr arg, List.tabulate(width, fn j => mk(d-1, f(j, arg))))
          in
          (* validate before building: a negative height would never reach
           * the mk (0, _) base case, and width < 1 is rejected even when
           * height = 0.
           *)
            if (height < 0) orelse (width < 1)
              then raise Size
              else mk (height, init)
          end

    fun createFromShape (shape, ndAttr, f, lfAttr, init) = let
        (* one interior level per remaining dimension; List.tabulate raises
         * Size for a negative dimension.
         *)
          fun mk ([], arg) = LF(lfAttr arg)
            | mk (d::dd, arg) = ND(ndAttr arg, List.tabulate(d, fn j => mk(dd, f(j, arg))))
          in
            mk (shape, init)
          end

    fun map (nd, lf) t = let
          fun mapf (LF x) = LF(lf x)
            | mapf (ND(i, kids)) = ND(nd i, List.map mapf kids)
          in
            mapf t
          end

    fun foldr f init t = let
        (* only leaves contribute; List.foldr visits the children right to left *)
          fun fold (LF x, acc) = f(x, acc)
            | fold (ND(_, kids), acc) = List.foldr fold acc kids
          in
            fold (t, init)
          end

    fun appPreOrder (ndFn, lfFn) = let
          fun app (ND(attr, kids)) = (ndFn attr; List.app app kids)
            | app (LF attr) = lfFn attr
          in
            app
          end

  end

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0