Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] View of /branches/vis15/src/lib/include/diderot/kdtree-inst.hxx
ViewVC logotype

View of /branches/vis15/src/lib/include/diderot/kdtree-inst.hxx

Parent Directory Parent Directory | Revision Log Revision Log


Revision 4372 - (download) (as text) (annotate)
Sat Aug 6 12:04:26 2016 UTC (3 years, 1 month ago) by jhr
File size: 7612 byte(s)
  Working on merge: spatial query support
/*! \file kdtree-inst.hxx
 *
 * \author John Reppy
 */

/*
 * This code is part of the Diderot Project (http://diderot-language.cs.uchicago.edu)
 *
 * COPYRIGHT (c) 2016 The University of Chicago
 * All rights reserved.
 */

#ifndef _DIDEROT_KDTREE_INST_HXX_
#define _DIDEROT_KDTREE_INST_HXX_

#ifndef _DIDEROT_KDTREE_HXX_
# error kdtree-inst.hxx should not be directly included
#endif

#include <cassert>
#include <cmath>
#include <stack>

namespace diderot {

    namespace __details {

      // generic test for two points being within a given radius^2
	template <const uint32_t D, typename REAL>
	bool within_sphere (const REAL pos1[D], const REAL pos2[D], REAL radius2)
	{
	    float sum = REAL(0);
	    for (uint32_t i = 0;  i < D;  i++) {
		REAL d = pos1[i] - pos2[i];
		sum += (d * d);
	    }
	    return sum <= radius2;
	}

      // generic test for two points being within a specified retangular radius?
	template <const uint32_t D, typename REAL>
	bool within_box (const REAL pos1[D], const REAL pos2[D], REAL radius)
	{
	    for (uint32_t i = 0;  i < D;  i++) {
		if (std::abs(pos1 - pos2) > radius) {
		    return false;
		}
	    }
	    return true;
	}

    } // namespace __details

    // Construct a kd-tree for (initially) nStrands strands.  The _parts permutation is
    // initialized to the identity and the node pool is allocated up front; the tree itself
    // is not built until rebuild is called with the strand array.
    template <const uint32_t D, typename REAL, typename S>
    kdtree<D,REAL,S>::kdtree (uint32_t nStrands)
	: _nStrands(nStrands), _poolSz(0),
	  _parts(nullptr), _pool(nullptr), _strands(nullptr)
    {
// FIXME: for "grid" programs, we don't need the extra space!
      // the _parts permutation starts as the identity, with 25% slack for later growth
	this->_partsSz = nStrands + (nStrands >> 2);
	this->_parts = new uint32_t[this->_partsSz];
	for (uint32_t i = 0;  i < nStrands;  i++) {
	    this->_parts[i] = i;
	}

      // Every tree node owns at least one strand that no other node owns (an interior
      // node owns its splitting strand and builder requires lo <= hi, so leaves are never
      // empty); hence a tree over nStrands strands has at most nStrands nodes.  An empty
      // tree still needs a single (empty) leaf.  _poolSz starts at 0 because _pool starts
      // out unallocated.
	uint32_t reqNumNodes = (nStrands == 0) ? 1 : nStrands;
	this->_pool = new node[reqNumNodes];
	this->_poolSz = reqNumNodes;
    }

    // Destructor: release the two heap arrays owned by the tree.  The strand array stored
    // in _strands by rebuild is only referenced, never freed here.
    template <const uint32_t D, typename REAL, typename S>
    kdtree<D,REAL,S>::~kdtree ()
    {
	delete[] this->_parts;	// permutation of strand IDs
	delete[] this->_pool;	// node storage
    }

    // Exchange the strand IDs stored at positions i and j of the _parts permutation.
    template <const uint32_t D, typename REAL, typename S>
    void kdtree<D,REAL,S>::swap_parts (uint32_t i, uint32_t j)
    {
	uint32_t *parts = this->_parts;
	const uint32_t saved = parts[i];
	parts[i] = parts[j];
	parts[j] = saved;
    }

  // Partition _parts[lo..hi] (inclusive) around the strand at position pivotIx using the
  // Lomuto scheme.  Let X be the initial value of _strands[_parts[pivotIx]].pos()[axis].
  // On return, the result ix satisfies (comparing positions on the given axis)
  //     _parts[lo..ix-1] <= X,  the strand at _parts[ix] has position X,  X < _parts[ix+1..hi]
  // Precondition: lo <= pivotIx <= hi.
  //
    template <const uint32_t D, typename REAL, typename S>
    uint32_t kdtree<D,REAL,S>::partition (uint32_t axis, uint32_t lo, uint32_t hi, uint32_t pivotIx)
    {
	assert ((lo <= pivotIx) && (pivotIx <= hi));

	REAL X = this->strand(pivotIx)->pos()[axis];

      // park the pivot at the end of the range so that the scan below never moves it
	this->swap_parts(pivotIx, hi);

      // scan every non-pivot element, i.e., positions lo..hi-1 inclusive
	uint32_t ix = lo;
	for (uint32_t jx = lo;  jx < hi;  jx++) {
	    if (this->strand(jx)->pos()[axis] <= X) {
		this->swap_parts (ix, jx);
		ix++;
	    }
	}

      // move the pivot from where it was parked (hi) into its final position
	this->swap_parts (ix, hi);
	return ix;
    }

  // Reorder _parts[lo..hi] (inclusive) and return an index m in [lo..hi] such that,
  // comparing strand positions on the given axis,
  //     _parts[lo..m-1] <= _parts[m] <= _parts[m+1..hi]
  // The strand at _parts[m] is an approximate median: m is within (mid-lo)/4 positions of
  // the mid-point mid of the interval.  This is the "Quick Select" method
  // (https://en.wikipedia.org/wiki/Quickselect) with an early exit once the split is
  // close enough to the middle.
  // Precondition: hi - lo >= STRANDS_PER_LEAF.
  //
    template <const uint32_t D, typename REAL, typename S>
    uint32_t kdtree<D,REAL,S>::median (uint32_t axis, uint32_t lo, uint32_t hi)
    {
	assert (hi - lo >= STRANDS_PER_LEAF);
      // the mid-point of the interval [lo..hi] (computed without forming hi+lo, which
      // could overflow)
	const uint32_t mid = lo + ((hi - lo) >> 1);

      // accept any split point within (mid - lo)/4 positions of the mid-point
	const uint32_t tol = ((mid - lo) >> 2);

	while (true) {
	  // INV: lo <= mid <= hi; entries of the original interval to the left of lo are
	  // <= every entry in [lo..hi], and entries to the right of hi are >= them
	    uint32_t pivotIx = this->partition (axis, lo, hi, mid);
	    uint32_t dist = (pivotIx < mid) ? (mid - pivotIx) : (pivotIx - mid);
	    if (dist <= tol) {
		return pivotIx;
	    }
	    else if (pivotIx < mid) {
		lo = pivotIx + 1;
	    }
	    else {
		hi = pivotIx - 1;
	    }
	}
    }

    // Recursively build the subtree for the strands indexed by _parts[lo..hi] (inclusive),
    // splitting on the given axis, and return the index of its root node in _pool.  Ranges
    // of at most STRANDS_PER_LEAF strands become a leaf (_lc == 0, covering [lo..hi]);
    // larger ranges become an interior node holding the approximate median strand on the
    // axis, whose children are built on the next axis (cycling through the D axes).
    template <const uint32_t D, typename REAL, typename S>
    uint32_t kdtree<D,REAL,S>::builder (uint32_t axis, uint32_t lo, uint32_t hi)
    {
	assert (lo <= hi);

      // allocate the node; the pool is sized by the caller before building starts, so an
      // undersized pool must fail loudly here rather than silently overrun the heap
	uint32_t nd = this->_nextNode++;
	assert (nd < this->_poolSz);

	uint32_t n = hi - lo + 1;
	if (n <= kdtree<D,REAL,S>::STRANDS_PER_LEAF){
	  // allocate a leaf
	    this->_pool[nd]._lc = 0;
	    this->_pool[nd]._u._leaf._first = lo;
	    this->_pool[nd]._u._leaf._last = hi;
	}
	else {
	    uint32_t mid = this->median (axis, lo, hi);
	    // INV: strands indexed by _parts[lo..mid-1] are <= strand[_parts[mid]]
	    // and strand[_parts[mid]] <= strands indexed by _parts[mid+1..hi]
	    this->_pool[nd]._u._nd._id = this->_parts[mid];
	    this->_pool[nd]._u._nd._axis = axis;
	    axis = (axis + 1) % D;
	    this->_pool[nd]._lc = builder (axis, lo, mid-1);
	    this->_pool[nd]._rc = builder (axis, mid+1, hi);
	}

	return nd;
    }

    // Rebuild the tree over the first nStrands entries of strands.  The strand array is
    // only referenced (stored in _strands), not copied or freed.  The _parts permutation and
    // the node pool are reused across calls and reallocated only when they are too small.
    template <const uint32_t D, typename REAL, typename S>
    void kdtree<D,REAL,S>::rebuild (uint32_t nStrands, const S *strands)
    {
	this->_strands = strands;

	if (this->_nStrands != nStrands) {
	  // the number of strands changed since the last build, so the permutation left in
	  // _parts (which may mention IDs >= nStrands) must be reset to the identity
	    if (this->_partsSz < nStrands) {
	      // grow the _parts array (with 25% slack); allocate before freeing the old
	      // array so that a failed allocation does not leave _parts dangling
		uint32_t newSz = nStrands + (nStrands >> 2);
		uint32_t *newParts = new uint32_t[newSz];
		delete[] this->_parts;
		this->_parts = newParts;
		this->_partsSz = newSz;
	    }
	    for (uint32_t i = 0;  i < nStrands;  i++) {
		this->_parts[i] = i;
	    }
	  // remember the count so that later calls compare against the latest build
	    this->_nStrands = nStrands;
	}

      // Every tree node owns at least one strand that no other node owns (an interior node
      // owns its splitting strand and builder requires lo <= hi, so leaves are never
      // empty); hence a tree over nStrands strands has at most nStrands nodes.  An empty
      // tree still needs a single (empty) leaf.
	uint32_t reqNumNodes = (nStrands == 0) ? 1 : nStrands;
	if (this->_poolSz < reqNumNodes) {
	    node *newPool = new node[reqNumNodes];
	    delete[] this->_pool;
	    this->_pool = newPool;
	    this->_poolSz = reqNumNodes;
	}

	this->_nextNode = 0;
	if (nStrands == 0) {
	  // represent the empty tree as a single leaf over the empty range [1..0], using the
	  // same leaf encoding as builder, so that queries visit no strands
	    uint32_t root = this->_nextNode++;
	    this->_pool[root]._lc = 0;
	    this->_pool[root]._u._leaf._first = 1;
	    this->_pool[root]._u._leaf._last = 0;
	}
	else {
	    uint32_t root = this->builder (0, 0, nStrands-1);
	    assert (root == 0);
	    (void) root;  // only used by the assertion
	}
    }

// FIXME: need to filter out the strand that is doing the query!!
  // Return the IDs (indices into the strand array passed to rebuild) of all strands whose
  // position is within the given radius of pos (Euclidean distance, boundary inclusive).
  // A negative radius yields an empty result.  The tree is walked iteratively with an
  // explicit stack; when the query sphere lies entirely on one side of a node's splitting
  // plane, only that side's child is visited.
  //
    template <const uint32_t D, typename REAL, typename S>
    dynseq<uint32_t> kdtree<D,REAL,S>::sphere_query (const REAL pos[D], REAL radius)
    {
	dynseq<uint32_t> result;

      // return an empty sequence for a negative radius
	if (radius < REAL(0)) {
	    return result;
	}

	const REAL radius2 = radius * radius;

      // stack of nodes that must still be visited
	std::stack<const node *> stk;
	stk.push (this->root());

	while (! stk.empty()) {
	    const node *nd = stk.top();
	    stk.pop();

	    if (nd->isLeaf()) {
	      // check strands in leaf to see if they are within the sphere
		for (uint32_t i = nd->_u._leaf._first;  i <= nd->_u._leaf._last;  i++) {
		    uint32_t id = this->_parts[i];
		    if (__details::within_sphere<D,REAL>(this->_strands[id].pos(), pos, radius2)) {
			result.append (id);
		    }
		}
	    }
	    else {
		uint32_t axis = nd->axis();
		REAL sPos = this->strand(nd)->pos()[axis];
		if (pos[axis] < sPos - radius) {
		  // the sphere lies entirely below the splitting plane
		    stk.push (this->left(nd));
		}
		else if (sPos + radius < pos[axis]) {
		  // the sphere lies entirely above the splitting plane
		    stk.push (this->right(nd));
		}
		else {
		  // the sphere straddles the plane: test the splitting strand itself and
		  // visit both children (left first, matching the previous result order)
		    if (__details::within_sphere<D,REAL>(this->strand(nd)->pos(), pos, radius2)) {
			result.append (nd->_u._nd._id);
		    }
		    stk.push (this->right(nd));
		    stk.push (this->left(nd));
		}
	    }
	}

	return result;
    }

} // namespace diderot

#endif // !_DIDEROT_KDTREE_INST_HXX_

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0