Skip to content

Commit

Permalink
Merge pull request #18 from mastodon-sc/improve-kd-tree-performance
Browse files Browse the repository at this point in the history
Improve kd tree performance during initialization
  • Loading branch information
tinevez authored Oct 17, 2024
2 parents fcd9dfb + 1e0d1dc commit a79b576
Show file tree
Hide file tree
Showing 4 changed files with 256 additions and 120 deletions.
130 changes: 10 additions & 120 deletions src/main/java/org/mastodon/kdtree/KDTree.java
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,12 @@
import java.nio.ByteOrder;
import java.util.Arrays;
import java.util.Collection;
import java.util.function.IntToDoubleFunction;

import org.mastodon.RefPool;
import org.mastodon.collection.RefRefMap;
import org.mastodon.collection.ref.RefRefHashMap;
import org.mastodon.util.KthElement;
import org.mastodon.pool.DoubleMappedElement;
import org.mastodon.pool.DoubleMappedElementArray;
import org.mastodon.pool.MappedElement;
Expand Down Expand Up @@ -294,7 +296,14 @@ private int makeNode( final int i, final int j, final int d, final KDTreeNode< O
if ( j > i )
{
final int k = i + ( j - i ) / 2;
kthElement( i, j, k, d, n1, n2, n3 );
final IntToDoubleFunction rankMethod = index -> {
getObject( index, n1 );
return n1.getDoublePosition( d );
};
final KthElement.Swap swapMethod = ( l, m ) -> {
getMemPool().swap( l, m );
};
KthElement.kthElement( i, j, k, rankMethod, swapMethod );

final int dChild = ( d + 1 == n ) ? 0 : d + 1;
final int left = makeNode( i, k - 1, dChild, n1, n2, n3 );
Expand All @@ -319,125 +328,6 @@ else if ( j == i )
}
}

/**
* Partition a sublist of KDTreeNodes such that the k-th smallest value is
* at position {@code k}, elements before the k-th are smaller or equal and
* elements after the k-th are larger or equal. Elements are compared by
* their coordinate in the specified dimension.s
*
* Note, that is is assumed that the {@link KDTreeNode}s are stored with
* consecutive indices in the pool.
*
* @param i
* index of first element of the sublist
* @param j
* index of last element of the sublist
* @param k
* index for k-th smallest value. i &lt;= k &lt;= j.
* @param compare_d
* dimension by which to compare.
* @param pivot
* temporary {@link KDTreeNode} reference.
* @param ti
* temporary {@link KDTreeNode} reference.
* @param tj
* temporary {@link KDTreeNode} reference.
*/
private void kthElement( int i, int j, final int k, final int compare_d, final KDTreeNode< O, T > pivot, final KDTreeNode< O, T > ti, final KDTreeNode< O, T > tj )
{
while ( true )
{
final int pivotpos = partitionSubList( i, j, compare_d, pivot, ti, tj );
if ( pivotpos > k )
{
// partition lower half
j = pivotpos - 1;
}
else if ( pivotpos < k )
{
// partition upper half
i = pivotpos + 1;
}
else
break;
}
}

/**
* Partition a sublist of KDTreeNodes by their coordinate in the specified
* dimension.
*
* The element at index {@code j} is taken as the pivot value. The elements
* {@code [i,j]} are reordered, such that all elements before the pivot are
* smaller and all elements after the pivot are equal or larger than the
* pivot. The index of the pivot element is returned.
*
* Note, that is is assumed that the {@link KDTreeNode}s are stored with
* consecutive indices in the pool.
*
* @param i
* index of first element of the sublist
* @param j
* index of last element of the sublist
* @param compare_d
* dimension by which to order the sublist
* @param pivot
* temporary {@link KDTreeNode} reference.
* @param ti
* temporary {@link KDTreeNode} reference.
* @param tj
* temporary {@link KDTreeNode} reference.
* @return index of pivot element
*/
private int partitionSubList( int i, int j, final int compare_d, final KDTreeNode< O, T > pivot, final KDTreeNode< O, T > ti, final KDTreeNode< O, T > tj )
{
final int pivotIndex = j;
getObject( j--, pivot );
final double pivotPosition = pivot.getPosition( compare_d );

A: while ( true )
{
// move i forward while < pivot (and not at j)
while ( i <= j )
{
getObject( i, ti );
if ( ti.getPosition( compare_d ) >= pivotPosition )
break;
++i;
}
// now [i] is the place where the next value < pivot is to be
// inserted

if ( i > j )
break;

// move j backward while >= pivot (and not at i)
while ( true )
{
getObject( j, tj );
if ( tj.getPosition( compare_d ) < pivotPosition )
{
// swap [j] with [i]
getMemPool().swap( i++, j-- );
break;
}
else if ( j == i )
{
break A;
}
--j;
}
}

// we are done. put the pivot element here.
// check whether the element at iLastIndex is <
if ( i != pivotIndex )
{
getMemPool().swap( i, pivotIndex );
}
return i;
}

@Override
public int numDimensions()
{
Expand Down
94 changes: 94 additions & 0 deletions src/main/java/org/mastodon/util/KthElement.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
package org.mastodon.util;

import java.util.function.IntToDoubleFunction;

/**
* Class for partially sorting a list. The class is used in the KDTree.
*/
public class KthElement
{
private KthElement()
{
// prevent instantiation
}

/**
* Partially sort a sublist such that the element that would be at postition
* {@code k} in a sorted list is at position {@code k}, elements before the k-th
* are smaller or equal and elements after the k-th are larger or equal.
*
* @param i index of first element of the sublist
* @param j index of last element of the sublist
* @param k index for k-th smallest value. i &lt;= k &lt;= j.
* @param rankMethod method that returns i-th rank of the i-th value in the list
* @param swapMethod method that swaps entry i and j in the list
*/
public static void kthElement( int i, int j, final int k, final IntToDoubleFunction rankMethod, final Swap swapMethod )
{
while ( i < j )
{
final int pivotpos = partitionSubList( i, j, rankMethod, swapMethod );
if ( k < pivotpos )
{
// partition lower half
j = pivotpos - 1;
}
else //if ( k >= pivotpos )
{
// partition upper half
i = pivotpos;
}
}
}

/**
* Partition a sublist.
*
* The method does not swap entries for a correctly sorted list.
*
* @param left index of first element of the sublist
* @param right index of last element of the sublist
* @param rankMethod method that returns i-th rank of the i-th value in the list
* @param swapMethod method that swaps entry i and j in the list
* @return the index of the first element of the right partition
*/
static int partitionSubList( final int left, final int right, final IntToDoubleFunction rankMethod, final Swap swapMethod )
{
final double pivot = rankMethod.applyAsDouble( ( left + right ) / 2 );
int i = left;
int j = right;

while ( true )
{
double ivalue = rankMethod.applyAsDouble( i );
while ( ivalue < pivot )
{
++i;
ivalue = rankMethod.applyAsDouble( i );
}

double jvalue = rankMethod.applyAsDouble( j );
while ( pivot < jvalue )
{
--j;
jvalue = rankMethod.applyAsDouble( j );
}

if ( i <= j )
{
if ( ivalue > jvalue ) // this avoids unnecessary swaps in case of ivalue = jvalue = pivot
swapMethod.swap( i, j );
++i;
--j;
}

if ( i > j )
return i;
}
}

public interface Swap
{
void swap( int i, int j );
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
package org.mastodon.kdtree;

import net.imglib2.util.StopWatch;

import org.mastodon.collection.RefList;
import org.mastodon.collection.ref.RefArrayList;

/**
* Measure how long it takes to initialize a KDTree with 1_000_000 points
* if the points are not randomly distributed, but lie on a circle. This
* is a difficult case scenario for KDTree initialization.
*/
public class KDTreeInitializationBenchmark
{
public static void main( final String... args )
{
final int count = 1_000_000;

final RealPointPool vertexPool = new RealPointPool( 3, count );
final RefList< RealPoint > positions = pointsInACircle( vertexPool, count );

final StopWatch watch = StopWatch.createAndStart();
for ( int i = 0; i < 10; i++ )
KDTree.kdtree( positions, vertexPool );
System.out.println( watch );
}

private static RefList< RealPoint > pointsInACircle( final RealPointPool vertexPool, final int count )
{
final RefList< RealPoint > positions = new RefArrayList<>( vertexPool );
final RealPoint point = positions.createRef();
for ( int i = 0; i < count; i++ )
{
final double angle = 2 * Math.PI * i / count;
point.init( Math.sin( angle ), Math.cos( angle ), 1 );
positions.add( point );
}
return positions;
}
}
Loading

0 comments on commit a79b576

Please sign in to comment.