Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] View of /branches/vis15/src/lib/include/diderot/kdtree-inst.hxx
ViewVC logotype

View of /branches/vis15/src/lib/include/diderot/kdtree-inst.hxx

Parent Directory Parent Directory | Revision Log Revision Log


Revision 4516 - (download) (as text) (annotate)
Mon Sep 5 16:19:47 2016 UTC (3 years ago) by jhr
File size: 8024 byte(s)
  Working on merge: strand arrays and spatial queries
/*! \file kdtree-inst.hxx
 *
 * \author John Reppy
 */

/*
 * This code is part of the Diderot Project (http://diderot-language.cs.uchicago.edu)
 *
 * COPYRIGHT (c) 2016 The University of Chicago
 * All rights reserved.
 */

#ifndef _DIDEROT_KDTREE_INST_HXX_
#define _DIDEROT_KDTREE_INST_HXX_

#ifndef _DIDEROT_KDTREE_HXX_
# error kdtree-inst.hxx should not be directly included
#endif

#include <stack>

namespace diderot {

    namespace __details {

      // generic test for two points being within a given radius^2
	template <const uint32_t D, typename REAL>
	bool within_sphere (const REAL pos1[D], const REAL pos2[D], REAL radius2)
	{
	    float sum = REAL(0);
	    for (uint32_t i = 0;  i < D;  i++) {
		REAL d = pos1[i] - pos2[i];
		sum += (d * d);
	    }
	    return sum <= radius2;
	}

      // generic test for two points being within a specified rectangular radius?
	template <const uint32_t D, typename REAL>
	bool within_box (const REAL pos1[D], const REAL pos2[D], REAL radius)
	{
	    for (uint32_t i = 0;  i < D;  i++) {
		if (std::abs(pos1 - pos2) > radius) {
		    return false;
		}
	    }
	    return true;
	}

    } // namespace __details

    template <const uint32_t D, typename REAL, typename SA>
    kdtree<D,REAL,SA>::kdtree (const SA *strands)
	: _nStrands(strands->num_alive()), _poolSz(0),
	  _parts(nullptr), _pool(nullptr), _strands(strands)
    {
// FIXME: for "grid" programs, we don't need the extra space!
	const uint32_t nStrands = strands->num_alive();

      // _parts is the permutation of strand ids that the build reorders; it is sized
      // with 25% headroom so that rebuild() only reallocates when the strand count
      // exceeds _partsSz.  It starts out as the identity permutation.
	this->_partsSz = nStrands + (nStrands >> 2);
	this->_parts = new uint32_t[this->_partsSz];
	for (uint32_t i = 0;  i < nStrands;  i++) {
	    this->_parts[i] = i;
	}

      // a conservative bound on the number of nodes is 2*ceil(nStrands / MIN_LEAF_SIZE) - 1;
      // clamp to at least one node so that an empty strand array does not make the
      // unsigned expression wrap around to 2^32-1
	const uint32_t nLeaves = (nStrands + MIN_LEAF_SIZE - 1) / MIN_LEAF_SIZE;
	const uint32_t reqNumNodes = (nLeaves == 0) ? 1 : 2 * nLeaves - 1;
	this->_pool = new node[reqNumNodes];
	this->_poolSz = reqNumNodes;
    }

    template <const uint32_t D, typename REAL, typename SA>
    kdtree<D,REAL,SA>::~kdtree ()
    {
      // both arrays were obtained with new[]; the order of release does not matter
	delete[] this->_pool;
	delete[] this->_parts;
    }

    template <const uint32_t D, typename REAL, typename SA>
    void kdtree<D,REAL,SA>::swap_parts (uint32_t i, uint32_t j)
    {
      // exchange the strand ids stored at positions i and j of the permutation array
	uint32_t *parts = this->_parts;
	const uint32_t saved = parts[j];
	parts[j] = parts[i];
	parts[i] = saved;
    }

  // partition _parts[lo..hi] (inclusive, requires lo <= pivotIx <= hi) using the Lomuto
  // scheme.  Let X be the axis coordinate of the strand whose id is initially stored at
  // _parts[pivotIx].  On return, with m the returned index, _parts[m] holds that strand's
  // id, the strands indexed by _parts[lo..m-1] have coordinate <= X, and the strands
  // indexed by _parts[m+1..hi] have coordinate > X.
  //
    template <const uint32_t D, typename REAL, typename SA>
    uint32_t kdtree<D,REAL,SA>::partition (uint32_t axis, uint32_t lo, uint32_t hi, uint32_t pivotIx)
    {
	uint32_t inIdx = this->_strands->in_state_index();
      // _parts holds strand ids; strand() is looked up by id, not by position in _parts
	REAL X = this->strand(this->_parts[pivotIx])->pos(inIdx)[axis];

      // park the pivot element at the end of the range
	this->swap_parts(pivotIx, hi);

      // scan every non-pivot element _parts[lo..hi-1], moving those <= X to the front
	uint32_t ix = lo;
	for (uint32_t jx = lo;  jx < hi;  jx++) {
	    if (this->strand(this->_parts[jx])->pos(inIdx)[axis] <= X) {
		this->swap_parts (ix, jx);
		ix++;
	    }
	}

      // move the pivot (parked at hi) to its final position between the two groups
	this->swap_parts (ix, hi);
	return ix;
    }

  // Choose a split index m for _parts[lo..hi] on the given axis using repeated calls to
  // partition() ("Quick Select", https://en.wikipedia.org/wiki/Quickselect).  The search
  // stops as soon as a partition lands within a quarter of the half-width of the range
  // from the mid-point, so m is an approximate median.  The ordering of _parts[lo..m-1],
  // _parts[m] and _parts[m+1..hi] is whatever partition() guarantees.
  //
    template <const uint32_t D, typename REAL, typename SA>
    uint32_t kdtree<D,REAL,SA>::median (uint32_t axis, uint32_t lo, uint32_t hi)
    {
	assert (hi - lo >= STRANDS_PER_LEAF);

	const uint32_t target = (hi + lo) >> 1;		// mid-point of [lo..hi]
	const uint32_t slack = (target - lo) >> 2;	// accepted distance from target

	for (;;) {
	    const uint32_t split = this->partition (axis, lo, hi, target);
	    const uint32_t dist = (split < target) ? (target - split) : (split - target);
	    if (dist <= slack) {
		return split;
	    }
	  // keep searching only on the side of the split that contains target
	    if (split < target) {
		lo = split + 1;
	    } else {
		hi = split - 1;
	    }
	}
    }

    template <const uint32_t D, typename REAL, typename SA>
    uint32_t kdtree<D,REAL,SA>::builder (uint32_t axis, uint32_t lo, uint32_t hi)
    {
	assert (lo <= hi);

      // claim the next free node from the pool
	const uint32_t ix = this->_nextNode++;
	assert (ix < this->_poolSz);

	const uint32_t count = hi - lo + 1;
	if (count > kdtree<D,REAL,SA>::STRANDS_PER_LEAF) {
	  // interior node: split _parts[lo..hi] at the index chosen by median() on `axis`,
	  // record the splitting strand's id, and build the two halves on the next axis
	    const uint32_t split = this->median (axis, lo, hi);
	    this->_pool[ix]._u._nd._id = this->_parts[split];
	    this->_pool[ix]._u._nd._axis = axis;
	    const uint32_t nextAxis = (axis + 1) % D;
	    this->_pool[ix]._lc = this->builder (nextAxis, lo, split-1);
	    this->_pool[ix]._rc = this->builder (nextAxis, split+1, hi);
	    return ix;
	}

      // leaf node: a zero left-child index marks the node as a leaf covering _parts[lo..hi]
	this->_pool[ix]._lc = 0;
	this->_pool[ix]._u._leaf._first = lo;
	this->_pool[ix]._u._leaf._last = hi;
	assert (MIN_LEAF_SIZE <= count);
	return ix;
    }

    template <const uint32_t D, typename REAL, typename SA>
    void kdtree<D,REAL,SA>::rebuild ()
    {
	const uint32_t nStrands = this->_strands->num_alive();

	if (this->_nStrands != nStrands) {
	  // The number of live strands changed since _parts was last initialized, so reset it
	  // to the identity permutation on 0..nStrands-1.  When the count is unchanged, _parts
	  // is still a permutation of those ids from the previous build and is reused as-is.
	    if (this->_partsSz < nStrands) {
	      // grow _parts with 25% headroom; allocate before freeing so that a failed
	      // allocation does not leave a dangling pointer for the destructor
		const uint32_t newSz = nStrands + (nStrands >> 2);
		uint32_t *newParts = new uint32_t[newSz];
		delete[] this->_parts;
		this->_parts = newParts;
		this->_partsSz = newSz;
	    }
	    for (uint32_t i = 0;  i < nStrands;  i++) {
		this->_parts[i] = i;
	    }

	  // a conservative bound on the number of nodes is 2*ceil(nStrands / MIN_LEAF_SIZE) - 1;
	  // clamp to one node so that nStrands == 0 does not wrap around
	    const uint32_t nLeaves = (nStrands + MIN_LEAF_SIZE - 1) / MIN_LEAF_SIZE;
	    const uint32_t reqNumNodes = (nLeaves == 0) ? 1 : 2 * nLeaves - 1;
	    if (this->_poolSz < reqNumNodes) {
	      // grow the pool of nodes (allocate first, for the same reason as above)
		node *newPool = new node[reqNumNodes];
		delete[] this->_pool;
		this->_pool = newPool;
		this->_poolSz = reqNumNodes;
	    }

	  // record the strand count that _parts now describes, so that later size changes
	  // (e.g., a shrink followed by growth back to a previous count) are detected
	    this->_nStrands = nStrands;
	}

	this->_nextNode = 0;

	if (nStrands == 0) {
	  // builder() cannot take an empty range (nStrands-1 would wrap), so make the root
	  // an empty leaf, using the same leaf encoding as builder() (_lc == 0), whose
	  // inclusive range [_first.._last] = [1..0] contains no entries
	    assert (this->_poolSz > 0);
	    const uint32_t root = this->_nextNode++;
	    this->_pool[root]._lc = 0;
	    this->_pool[root]._u._leaf._first = 1;
	    this->_pool[root]._u._leaf._last = 0;
	    return;
	}

	uint32_t root = this->builder (0, 0, nStrands-1);
	assert (root == 0);
    }

    template <const uint32_t D, typename REAL, typename SA>
    dynseq<typename kdtree<D,REAL,SA>::index_t> kdtree<D,REAL,SA>::sphere_query (
	const kdtree<D,REAL,SA>::strand_t *self, const REAL pos[D], REAL radius)
    {
	dynseq<index_t> hits;

      // a negative radius denotes an empty sphere, so return an empty sequence
	if (radius < 0.0) {
	    return hits;
	}

	const REAL r2 = radius * radius;
	const uint32_t inIdx = this->_strands->in_state_index();

      // subtrees that still have to be searched
	std::stack<const node *> pending;
	const node *cur = this->root();

	while (true) {
	    if (! cur->isLeaf()) {
	      // interior node: compare the query sphere against the node's splitting plane
		const uint32_t axis = cur->axis();
		const strand_t *pivot = this->strand(cur);
		const REAL split = pivot->pos(inIdx)[axis];
		if (pos[axis] < split - radius) {
		  // the sphere is entirely on the low side: only the left subtree can match
		    cur = this->left(cur);
		}
		else if (split + radius < pos[axis]) {
		  // the sphere is entirely on the high side: only the right subtree can match
		    cur = this->right(cur);
		}
		else {
		  // the sphere reaches the plane: test the node's own strand (unless it is
		  // the querying strand) and search both subtrees, left first
		    if ((self != pivot)
		    && __details::within_sphere<D,REAL>(pivot->pos(inIdx), pos, r2)) {
			hits.append (cur->_u._nd._id);
		    }
		    pending.push (this->right(cur));
		    cur = this->left(cur);
		}
		continue;
	    }

	  // leaf: test every strand it covers, skipping the querying strand
	    for (uint32_t ix = cur->_u._leaf._first;  ix <= cur->_u._leaf._last;  ix++) {
		const uint32_t id = this->_parts[ix];
		const strand_t *s = this->strand(id);
		if ((self != s)
		&& __details::within_sphere<D,REAL>(s->pos(inIdx), pos, r2)) {
		    hits.append (id);
		}
	    }

	  // resume with the most recently deferred subtree, or stop when none remain
	    if (pending.empty()) {
		break;
	    }
	    cur = pending.top();
	    pending.pop();
	}

	return hits;
    }

} // namespace diderot

#endif // !_DIDEROT_KDTREE_INST_HXX_

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0