Home My Page Projects Code Snippets Project Openings SML/NJ
Summary Activity Forums Tracker Lists Tasks Docs Surveys News SCM Files

SCM Repository

[smlnj] View of /sml/trunk/src/MLRISC/block-placement/weighted-block-placement-fn.sml
ViewVC logotype

View of /sml/trunk/src/MLRISC/block-placement/weighted-block-placement-fn.sml

Parent Directory Parent Directory | Revision Log Revision Log


Revision 1083 - (download) (annotate)
Thu Feb 21 18:52:10 2002 UTC (17 years, 5 months ago) by jhr
File size: 8052 byte(s)
  New file: implementation of Pettis-Hansen block placement.
(* weighted-block-placement-fn.sml
 *
 * COPYRIGHT (c) 2002 Bell Labs, Lucent Technologies
 *
 * This functor implements the bottom-up block-placement algorithm of
 * Pettis and Hansen (PLDI 1990).
 *
 * TODO
 *	remove low-weight edges to break cycles in the chain-placement graph
 *)

functor WeightedBlockPlacementFn (
    structure CFG : CONTROL_FLOW_GRAPH
    structure InsnProps : INSN_PROPERTIES
    sharing CFG.I = InsnProps.I
  ) : BLOCK_PLACEMENT = struct

    structure CFG=CFG
    structure IP = InsnProps
    structure G = Graph
    structure ITbl = IntHashTable

  (* A priority queue of CFG edges keyed on their weights (Freq.compare).
   * It is used to examine the edges in priority order when forming chains
   * (PQ.next returns the highest-priority item first).
   *)
    structure PQ = LeftPriorityQFn (
      struct
	type priority = Freq.freq
	val compare = Freq.compare
	type item = CFG.edge
	fun priority (_, _, CFG.EDGE{w, ...}) = !w
      end)

  (* sequences with constant-time concatenation *)
    datatype 'a seq
      = ONE of 'a
      | SEQ of ('a seq * 'a seq)

  (* a chain of blocks that should be placed in order; hd and tl are the
   * first and last blocks of the chain.
   *)
    datatype chain = CHAIN of {
	blocks : CFG.node seq,
	hd : CFG.node,
	tl : CFG.node
      }

  (* the node IDs of the first and last blocks of a chain *)
    fun head (CHAIN{hd, ...}) = #1 hd
    fun tail (CHAIN{tl, ...}) = #1 tl
    fun id (CHAIN{hd, ...}) = #1 hd	(* use node ID of head to identify chains *)

  (* two chains are the same if their heads have the same node ID.  Comparing
   * IDs (rather than whole CFG.node values) does not require CFG.block to be
   * an equality type, and is consistent with how `id` identifies chains.
   *)
    fun sameChain (c1, c2) = (id c1 = id c2)

  (* join two chains: the blocks of the first chain precede those of the
   * second, so the result has the first chain's head and the second's tail.
   *)
    fun joinChains (CHAIN{blocks=b1, hd, ...}, CHAIN{blocks=b2, tl, ...}) =
	  CHAIN{blocks=SEQ(b1, b2), hd=hd, tl=tl}

  (* merge the chains referred to by two chain pointers (first chain first).
   * NOTE: per the SML/NJ Library URef documentation, the result is false
   * when the two pointers already refer to the same chain (no merge).
   *)
    val unifyChainPtrs = URef.unify joinChains

  (* chain pointers provide a union-find structure for chains *)
    type chain_ptr = chain URef.uref

  (* maps block IDs to (a pointer to) the chain containing the block *)
    type block_chain_tbl = chain_ptr ITbl.hash_table

  (* a directed graph representing the placement ordering on chains. An edge
   * from chain c1 to c2 means that we should place c1 before c2.  The graph
   * may be cyclic, so we weight the edges and remove the low-cost edge
   * on any cycle.
   *)
    datatype node = ND of {
	chain : chain,
	mark : bool ref,	(* set once the DFS has visited this node *)
	kids : edge list ref
      }
    and edge = E of {
	w : CFG.weight,
	dst : node,
	ign : bool ref		(* if set, then ignore this edge.  We use this *)
				(* flag to break cycles. *)
      }

    fun mkNode c = ND{chain = c, mark = ref false, kids = ref []}
    fun mkEdge (w, dst) = E{w = w, dst = dst, ign = ref false}

  (* given a table that maps block IDs to chain pointers, construct a table that
   * maps block IDs to their chain-placement graph nodes.  Every block of a
   * chain maps to the same node.  Returns the list of distinct nodes (one per
   * chain) paired with the new table.
   *)
    fun mkChainPlacementGraph (tbl : block_chain_tbl) = let
	  val gTbl = ITbl.mkTable (ITbl.numItems tbl, Fail "graph table")
	(* NOTE: lookups and insertions are on the graph table, not on the
	 * block->chain table that we are folding over.
	 *)
	  val find = ITbl.find gTbl
	  val insert = ITbl.insert gTbl
	(* given a block ID and the chain pointer corresponding to the block
	 * add the chain node to the graph table (this may involve creating
	 * the node if it doesn't already exist).
	 *)
	  fun blockToNd (blkId, cptr, nodes) = let
		val chain = URef.!! cptr
		val chainId = id chain
		in
		  case find chainId
		   of NONE => let
			val nd = mkNode chain
			in
			  insert (chainId, nd);
			  if (blkId <> chainId)
			    then insert (blkId, nd)
			    else ();
			  nd :: nodes
			end
		    | SOME nd => (insert (blkId, nd); nodes)
		  (* end case *)
		end
	  in
	    (ITbl.foldi blockToNd [] tbl, gTbl)
	  end

  (* Compute a placement (ordering) of the blocks of cfg:
   *   1. merge blocks into chains by examining the CFG edges in priority
   *      order and joining a source chain's tail to a destination chain's head;
   *   2. order the chains by a DFS on the chain-placement graph, starting
   *      with the entry node's chain and placing the exit node's chain last;
   *   3. patch jumps and conditional branches so that the control transfers
   *      agree with the new layout.
   * Returns the blocks in placement order.  As a side effect, step 3 updates
   * the out-edges of cfg, block instruction lists, and (possibly) block labels.
   *)
    fun blockPlacement (cfg as G.GRAPH graph) = let
	(* a map from block IDs to their chain; initially each block is
	 * its own singleton chain.
	 *)
	  val blkTbl : chain_ptr ITbl.hash_table = let
		val tbl = ITbl.mkTable (#size graph (), Fail "blkTbl")
		val insert = ITbl.insert tbl
		fun ins (b : CFG.node) =
		      insert (#1 b, URef.uRef(CHAIN{blocks = ONE b, hd = b, tl = b}))
		in
		  #forall_nodes graph ins;
		  tbl
		end
	  val lookupChain = ITbl.lookup blkTbl
	(* given an edge that connects two blocks, attempt to merge their chains.
	 * We return true if a merge occurred.  EXIT edges are never joined, so
	 * that the exit node stays in a chain of its own; otherwise the entry
	 * and exit nodes could end up in the same chain, and since the exit
	 * chain is forced to the end, the entry chain would not be placed first.
	 *)
	  fun join (_, _, CFG.EDGE{k=CFG.EXIT, ...}) = false
	    | join (src, dst, _) = let
		val cptr1 = lookupChain src
		val chain1 = URef.!! cptr1
		val cptr2 = lookupChain dst
		val chain2 = URef.!! cptr2
		in
		  if (tail chain1 = src) andalso (dst = head chain2)
		    then
		    (* the source block is the tail of its chain and the
		     * destination block is the head of its chain, so we can
		     * join the chains (unless they are already the same chain,
		     * in which case no merge happens).
		     *)
		      unifyChainPtrs (cptr1, cptr2)
		    else false (* we cannot join these chains *)
		end
	(* merge chains until all of the edges have been examined; the remaining
	 * edges cannot be fall-through.
	 *)
	  fun loop (pq, edges) = (case PQ.next pq
		 of SOME(edge, pq) => if join edge
		      then loop (pq, edges)
		      else loop (pq, edge::edges)
		  | NONE => edges
		(* end case *))
	  val edges = loop (PQ.fromList (#edges graph ()), [])
	(* construct a chain placement graph *)
	  val (chainNodes, grTbl) = mkChainPlacementGraph blkTbl
	  val lookupNd = ITbl.lookup grTbl
	(* add an edge to the chain-placement graph for a CFG edge that was not
	 * used to join chains; SWITCH and FLOWSTO edges and edges within a
	 * single chain are ignored.
	 *)
	  fun addCFGEdge (src, dst, CFG.EDGE{k, w, ...}) = (case k
(* NOTE: there may be icache benefits to including SWITCH edges. *)
		 of CFG.SWITCH _ => ()
		  | CFG.FLOWSTO => ()
		  | _ => let
		      val ND{chain=c1, kids, ...} = lookupNd src
		      val dstNd as ND{chain=c2, ...} = lookupNd dst
		      in
			if sameChain(c1, c2)
			  then ()
			  else kids := mkEdge (!w, dstNd) :: !kids
		      end
		(* end case *))
	  val _ = List.app addCFGEdge edges
(* FIXME: we should remove low-weight edges to break cycles *)
	(* now we construct an ordering on the chains by doing a DFS on the
	 * chain graph.  Each newly visited chain is prepended to l before its
	 * successors are visited, so l is in reverse placement order.  Marking
	 * the node before visiting its kids guarantees termination on cycles.
	 *)
	  fun dfs (ND{mark = ref true, ...}, l) = l
	    | dfs (ND{mark, chain, kids, ...}, l) = let
		fun addKid (E{ign=ref true, ...}, l) = l
		  | addKid (E{dst, ...}, l) = dfs (dst, l)
		in
		  mark := true;
		  List.foldl addKid (chain::l) (!kids)
		end
	(* mark the exit node, since it should be last *)
	  val exitChain = let
		val ND{chain, mark, ...} = lookupNd(hd(#exits graph ()))
		in
		  mark := true;
		  chain
		end
	(* start with the entry node *)
	  val chains = dfs (lookupNd(hd(#entries graph ())), [])
	(* place the rest of the nodes and add the exit node *)
	  val chains = List.foldl dfs chains chainNodes
	  val chains = exitChain :: chains
	(* extract the list of blocks from the chains list; the chains list is
	 * in reverse order.
	 *)
	  fun addChain (CHAIN{blocks, ...}, blks) = let
		fun addSeq (ONE b, blks) = b::blks
		  | addSeq (SEQ(s1, s2), blks) = addSeq(s1, addSeq(s2, blks))
		in
		  addSeq (blocks, blks)
		end
	  val blocks = List.foldl addChain [] chains
	(* copy an edge's weight and annotations, replacing its kind *)
	  fun updEdge (CFG.EDGE{w, a, ...}, k) = CFG.EDGE{w=w, a=a, k=k}
	(* replace the head instruction of insns by its negated conditional
	 * branch targeting lab
	 *)
	  fun flipJmp (insns as ref(i::r), lab) =
		insns := IP.negateConditional(i, lab) :: r
	  val setEdges = #set_out_edges graph
	(* map a block ID to a label, creating a fresh label for the block if
	 * it does not have one.
	 *)
	  fun labelOf blkId = (case #node_info graph blkId
		 of CFG.BLOCK{labels=ref(lab::_), ...} => lab
		  | CFG.BLOCK{labels, ...} => let
		      val lab = Label.anon()
		      in
			labels := [lab];
			lab
		      end
		(* end case *))
	(* walk the placed blocks, fixing up the control transfer at the end of
	 * each NORMAL block so that it agrees with its successor in the layout.
	 * NOTE(review): the jump insertion/removal below assumes the jump is the
	 * head of the insns list (i.e., the list is stored last-instruction
	 * first) -- confirm against CFG.
	 *)
	  fun patch (
		(blkId, CFG.BLOCK{kind=CFG.NORMAL, insns, ...}),
		(next as (blkId', _)) :: rest
	      ) = (
		case #out_edges graph blkId
		 of [(_, dst, e as CFG.EDGE{k, ...})] => (case (dst = blkId', k)
		       of (false, CFG.FALLSTHRU) => (
			    (* rewrite edge as JUMP and add jump insn *)
			    setEdges (blkId, [(blkId, dst, updEdge(e, CFG.JUMP))]);
			    insns := IP.jump(labelOf dst) :: !insns)
			| (true, CFG.JUMP) => (
			    (* rewrite edge as FALLSTHRU and remove jump insn *)
			    setEdges (blkId,
			      [(blkId, dst, updEdge(e, CFG.FALLSTHRU))]);
			    insns := tl(!insns))
			| _ => ()
		      (* end case *))
		  | [(_, dst1, e1 as CFG.EDGE{k=CFG.BRANCH b, ...}), (_, dst2, e2)] => (
		    (* NOTE(review): if neither dst1 nor dst2 is the next block,
		     * flipping the branch cannot make the fall-through edge
		     * reach it; an explicit jump would be required.  That case
		     * is not handled here.
		     *)
		      case (dst1 = blkId', b)
		       of (true, true) => (
			    (* the taken target is the next block: negate the
			     * branch so that it falls through to dst1.
			     *)
			    setEdges (blkId, [
				(blkId, dst1, updEdge(e1, CFG.BRANCH false)),
				(blkId, dst2, updEdge(e2, CFG.BRANCH true))
			      ]);
			    flipJmp (insns, labelOf dst2))
			| (false, false) => (
			    (* the fall-through target is not the next block:
			     * negate the branch so that it jumps to dst1.
			     *)
			    setEdges (blkId, [
				(blkId, dst1, updEdge(e1, CFG.BRANCH true)),
				(blkId, dst2, updEdge(e2, CFG.BRANCH false))
			      ]);
			    flipJmp (insns, labelOf dst1))
			| _ => ()
		      (* end case *))
		  | _ => ()
		(* end case *);
		(* continue with the rest of the layout *)
		patch (next, rest))
	    | patch (_, next::rest) = patch(next, rest)
	    | patch (_, []) = ()
	  in
	    (case blocks
	      of b :: bs => patch (b, bs)
	       | [] => ()
	     (* end case *));
	    blocks
	  end

  end

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0