[Scipy-svn] r4777 - in branches/spatial/scipy/spatial: . tests

scipy-svn at scipy.org scipy-svn at scipy.org
Sun Oct 5 01:52:04 EDT 2008


Author: peridot
Date: 2008-10-05 00:51:47 -0500 (Sun, 05 Oct 2008)
New Revision: 4777

Modified:
   branches/spatial/scipy/spatial/kdtree.py
   branches/spatial/scipy/spatial/tests/test_kdtree.py
Log:
Added all-neighbors code, both array-of-points and between two kd-trees. Not optimized.


Modified: branches/spatial/scipy/spatial/kdtree.py
===================================================================
--- branches/spatial/scipy/spatial/kdtree.py	2008-10-05 05:02:32 UTC (rev 4776)
+++ branches/spatial/scipy/spatial/kdtree.py	2008-10-05 05:51:47 UTC (rev 4777)
@@ -4,6 +4,14 @@
 from heapq import heappush, heappop
 
 def distance_p(x,y,p=2):
+    """Compute the pth power of the L**p distance between x and y
+    
+    For efficiency, this function computes the L**p distance but does
+    not extract the pth root. If p is 1 or infinity, this is equal to
+    the actual L**p distance.
+    """
+    x = np.asarray(x)
+    y = np.asarray(y)
     if p==np.inf:
         return np.amax(np.abs(y-x),axis=-1)
     elif p==1:
@@ -11,12 +19,59 @@
     else:
         return np.sum(np.abs(y-x)**p,axis=-1)
 def distance(x,y,p=2):
+    """Compute the L**p distance between x and y"""
+    x = np.asarray(x)
+    y = np.asarray(y)
     if p==np.inf or p==1:
         return distance_p(x,y,p)
     else:
         return distance_p(x,y,p)**(1./p)
 
+class Rectangle(object):
+    """Hyperrectangle class.
 
+    Represents a Cartesian product of intervals.
+    """
+    def __init__(self, maxes, mins):
+        """Construct a hyperrectangle; each axis is ordered so mins<=maxes."""
+        self.maxes = np.maximum(maxes,mins).astype(float)
+        self.mins = np.minimum(maxes,mins).astype(float)
+        self.k, = self.maxes.shape
+
+    def __repr__(self):
+        return "<Rectangle %s>" % list(zip(self.mins, self.maxes))
+
+    def volume(self):
+        """Total volume (product of the side lengths)."""
+        return np.prod(self.maxes-self.mins)
+    
+    def split(self, d, split):
+        """Return (less, greater): this rectangle cut at coordinate split on axis d."""
+        mid = np.copy(self.maxes)
+        mid[d] = split
+        less = Rectangle(self.mins, mid)
+        mid = np.copy(self.mins)
+        mid[d] = split
+        greater = Rectangle(mid, self.maxes)
+        return less, greater
+
+    def min_distance_point(self, x, p=2.):
+        """Compute the minimum L**p distance from x to the hyperrectangle (0 if x is inside)."""
+        return distance(0, np.maximum(0,np.maximum(self.mins-x,x-self.maxes)),p)
+
+    def max_distance_point(self, x, p=2.):
+        """Compute the maximum L**p distance between x and a point in the hyperrectangle."""
+        return distance(0, np.maximum(self.maxes-x,x-self.mins),p)
+
+    def min_distance_rectangle(self, other, p=2.):
+        """Compute the minimum L**p distance between points in the two hyperrectangles (0 if they overlap)."""
+        return distance(0, np.maximum(0,np.maximum(self.mins-other.maxes,other.mins-self.maxes)),p)
+
+    def max_distance_rectangle(self, other, p=2.):
+        """Compute the maximum L**p distance between points in the two hyperrectangles."""
+        return distance(0, np.maximum(self.maxes-other.mins,other.maxes-self.mins),p)
+
+
 class KDTree(object):
     """kd-tree for quick nearest-neighbor lookup
 
@@ -42,6 +97,11 @@
     For large dimensions (20 is already large) do not expect this to run 
     significantly faster than brute force. High-dimensional nearest-neighbor
     queries are a substantial open problem in computer science.
+
+    The tree also supports all-neighbors queries, both with arrays of points
+    and with other kd-trees. These do use a reasonably efficient algorithm,
+    but the kd-tree is not necessarily the best data structure for this
+    sort of calculation.
     """
 
     def __init__(self, data, leafsize=10):
@@ -302,4 +362,135 @@
                 raise ValueError("Requested %s nearest neighbors; acceptable numbers are integers greater than or equal to one, or None")
 
 
+    def __query_ball_point(self, x, r, p=2., eps=0):
+        R = Rectangle(self.maxes, self.mins)
 
+        def traverse_checking(node, rect):
+            if rect.min_distance_point(x,p)>r/(1.+eps):  # strict: a cell at exactly r may still hold a hit
+                return []
+            elif rect.max_distance_point(x,p)<r*(1.+eps):
+                return traverse_no_checking(node)
+            elif isinstance(node, KDTree.leafnode):
+                d = self.data[node.idx]
+                return node.idx[distance(d,x,p)<=r].tolist()
+            else:
+                less, greater = rect.split(node.split_dim, node.split)
+                return traverse_checking(node.less, less)+traverse_checking(node.greater, greater)
+        def traverse_no_checking(node):
+            if isinstance(node, KDTree.leafnode):
+                # the whole cell was accepted above: take every index unchecked
+                return node.idx.tolist()
+            else:
+                return traverse_no_checking(node.less)+traverse_no_checking(node.greater)
+
+        return traverse_checking(self.tree, R)
+
+    def query_ball_point(self, x, r, p=2., eps=0):
+        """Find all points within distance r of point(s) x
+
+        Parameters
+        ==========
+
+        x : array_like, shape tuple + (self.k,)
+            The point or points to search for neighbors of
+        r : positive float
+            The radius of points to return
+        p : float 1<=p<=infinity
+            Which Minkowski p-norm to use
+        eps : nonnegative float
+            Approximate search. Branches of the tree are not explored
+            if their nearest points are further than r/(1+eps), and branches
+            are added in bulk if their furthest points are nearer than r*(1+eps).
+
+        Returns
+        =======
+
+        results : list or array of lists
+            If x is a single point, returns a list of the indices of the neighbors
+            of x. If x is an array of points, returns an object array of shape tuple
+            containing lists of neighbors.
+
+
+        Note: if you have many points whose neighbors you want to find, you may save
+        substantial amounts of time by putting them in a KDTree and using query_ball_tree.
+        """
+        x = np.asarray(x)
+        if x.shape[-1]!=self.k:
+            raise ValueError("Searching for a %d-dimensional point in a %d-dimensional KDTree" % (x.shape[-1],self.k))
+        if len(x.shape)==1:
+            return self.__query_ball_point(x,r,p,eps)
+        else:
+            retshape = x.shape[:-1]
+            result = np.empty(retshape,dtype=object)
+            for c in np.ndindex(retshape):
+                result[c] = self.__query_ball_point(x[c], r, p=p, eps=eps)
+            return result
+
+    def query_ball_tree(self, other, r, p=2., eps=0):
+        """Find all pairs of points whose distance is at most r
+
+        Parameters
+        ==========
+
+        other : KDTree
+            The tree containing points to search against
+        r : positive float
+            The maximum distance
+        p : float 1<=p<=infinity
+            Which Minkowski norm to use
+        eps : nonnegative float
+            Approximate search. Branches of the tree are not explored
+            if their nearest points are further than r/(1+eps), and branches
+            are added in bulk if their furthest points are nearer than r*(1+eps).
+        
+        Returns
+        =======
+
+        results : list of lists
+            For each element self.data[i] of this tree, results[i] is a list of the
+            indices of its neighbors in other.data.
+        """
+        results = [[] for i in range(self.n)]
+        def traverse_checking(node1, rect1, node2, rect2):
+            if rect1.min_distance_rectangle(rect2, p)>r/(1.+eps):
+                return
+            elif rect1.max_distance_rectangle(rect2, p)<r*(1.+eps):
+                traverse_no_checking(node1, node2)
+            elif isinstance(node1, KDTree.leafnode):
+                if isinstance(node2, KDTree.leafnode):
+                    d = other.data[node2.idx]
+                    for i in node1.idx:
+                        results[i] += node2.idx[distance(d,self.data[i],p)<=r].tolist()
+                else:
+                    less, greater = rect2.split(node2.split_dim, node2.split)
+                    traverse_checking(node1,rect1,node2.less,less)
+                    traverse_checking(node1,rect1,node2.greater,greater)
+            elif isinstance(node2, KDTree.leafnode):
+                less, greater = rect1.split(node1.split_dim, node1.split)
+                traverse_checking(node1.less,less,node2,rect2)
+                traverse_checking(node1.greater,greater,node2,rect2)
+            else:
+                less1, greater1 = rect1.split(node1.split_dim, node1.split)
+                less2, greater2 = rect2.split(node2.split_dim, node2.split)
+                traverse_checking(node1.less,less1,node2.less,less2)
+                traverse_checking(node1.less,less1,node2.greater,greater2)
+                traverse_checking(node1.greater,greater1,node2.less,less2)
+                traverse_checking(node1.greater,greater1,node2.greater,greater2)
+
+        def traverse_no_checking(node1, node2):
+            if isinstance(node1, KDTree.leafnode):
+                if isinstance(node2, KDTree.leafnode):
+                    for i in node1.idx:
+                        results[i] += node2.idx.tolist()
+                else:
+                    traverse_no_checking(node1, node2.less)
+                    traverse_no_checking(node1, node2.greater)
+            else:
+                traverse_no_checking(node1.less, node2)
+                traverse_no_checking(node1.greater, node2)
+
+        traverse_checking(self.tree, Rectangle(self.maxes, self.mins),
+                          other.tree, Rectangle(other.maxes, other.mins))
+        return results
+
+        

Modified: branches/spatial/scipy/spatial/tests/test_kdtree.py
===================================================================
--- branches/spatial/scipy/spatial/tests/test_kdtree.py	2008-10-05 05:02:32 UTC (rev 4776)
+++ branches/spatial/scipy/spatial/tests/test_kdtree.py	2008-10-05 05:51:47 UTC (rev 4777)
@@ -3,7 +3,7 @@
 from numpy.testing import *
 
 import numpy as np
-from scipy.spatial import KDTree, distance
+from scipy.spatial import KDTree, distance, Rectangle
 
 class ConsistencyTests:
     def test_nearest(self):
@@ -126,7 +126,7 @@
         self.m = 4
 
 
-class CheckVectorization(NumpyTestCase):
+class test_vectorization:
     def setUp(self):
         self.data = np.array([[0,0,0],
                               [0,0,1],
@@ -176,4 +176,138 @@
         assert isinstance(d[0,0],list)
         assert isinstance(i[0,0],list)
 
+class ball_consistency:
 
+    def test_in_ball(self):
+        l = self.T.query_ball_point(self.x, self.d, p=self.p, eps=self.eps)
+        for i in l:
+            assert distance(self.data[i],self.x,self.p)<=self.d*(1.+self.eps)
+
+    def test_found_all(self):
+        c = np.ones(self.T.n,dtype=bool)
+        l = self.T.query_ball_point(self.x, self.d, p=self.p, eps=self.eps)
+        c[l] = False
+        assert np.all(distance(self.data[c],self.x,self.p)>=self.d/(1.+self.eps))
+
+class test_random_ball(ball_consistency):
+
+    def setUp(self):
+        n = 1000
+        k = 4
+        self.data = np.random.randn(n,k)
+        self.T = KDTree(self.data)
+        self.x = np.random.randn(k)
+        self.p = 2.
+        self.eps = 0
+        self.d = 0.2
+
+class test_random_ball_approx(test_random_ball):
+
+    def setUp(self):
+        test_random_ball.setUp(self)
+        self.eps = 0.1
+
+class test_random_ball_far(test_random_ball):
+
+    def setUp(self):
+        test_random_ball.setUp(self)
+        self.d = 2.
+
+class test_random_ball_l1(test_random_ball):
+
+    def setUp(self):
+        test_random_ball.setUp(self)
+        self.p = 1
+
+class test_random_ball_linf(test_random_ball):
+
+    def setUp(self):
+        test_random_ball.setUp(self)
+        self.p = np.inf
+
+def test_random_ball_vectorized():
+
+    n = 20
+    k = 5
+    T = KDTree(np.random.randn(n,k))
+    
+    r = T.query_ball_point(np.random.randn(2,3,k),1)
+    assert_equal(r.shape,(2,3))
+    assert isinstance(r[0,0],list)
+
+class two_trees_consistency:
+
+    def test_all_in_ball(self):
+        r = self.T1.query_ball_tree(self.T2, self.d, p=self.p, eps=self.eps)
+        for i, l in enumerate(r):
+            for j in l:
+                assert distance(self.data1[i],self.data2[j],self.p)<=self.d*(1.+self.eps)
+    def test_found_all(self):
+        r = self.T1.query_ball_tree(self.T2, self.d, p=self.p, eps=self.eps)
+        for i, l in enumerate(r):
+            c = np.ones(self.T2.n,dtype=bool)
+            c[l] = False
+            assert np.all(distance(self.data2[c],self.data1[i],self.p)>=self.d/(1.+self.eps))
+
+class test_two_random_trees(two_trees_consistency):
+
+    def setUp(self):
+        n = 100
+        k = 4
+        self.data1 = np.random.randn(n,k)
+        self.T1 = KDTree(self.data1,leafsize=2)
+        self.data2 = np.random.randn(n,k)
+        self.T2 = KDTree(self.data2,leafsize=2)
+        self.p = 2.
+        self.eps = 0
+        self.d = 0.2
+
+class test_two_random_trees_far(test_two_random_trees):
+
+    def setUp(self):
+        test_two_random_trees.setUp(self)
+        self.d = 2
+
+class test_two_random_trees_linf(test_two_random_trees):
+
+    def setUp(self):
+        test_two_random_trees.setUp(self)
+        self.p = np.inf
+
+
+class test_rectangle:
+
+    def setUp(self):
+        self.rect = Rectangle([0,0],[1,1])
+
+    def test_min_inside(self):
+        assert_almost_equal(self.rect.min_distance_point([0.5,0.5]),0)
+    def test_min_one_side(self):
+        assert_almost_equal(self.rect.min_distance_point([0.5,1.5]),0.5)
+    def test_min_two_sides(self):
+        assert_almost_equal(self.rect.min_distance_point([2,2]),np.sqrt(2))
+    def test_max_inside(self):
+        assert_almost_equal(self.rect.max_distance_point([0.5,0.5]),1/np.sqrt(2))
+    def test_max_one_side(self):
+        assert_almost_equal(self.rect.max_distance_point([0.5,1.5]),np.hypot(0.5,1.5))
+    def test_max_two_sides(self):
+        assert_almost_equal(self.rect.max_distance_point([2,2]),2*np.sqrt(2))
+
+    def test_split(self):
+        less, greater = self.rect.split(0,0.1)
+        assert_array_equal(less.maxes,[0.1,1])
+        assert_array_equal(less.mins,[0,0])
+        assert_array_equal(greater.maxes,[1,1])
+        assert_array_equal(greater.mins,[0.1,0])
+
+
+def test_distance_l2():
+    assert_almost_equal(distance([0,0],[1,1],2),np.sqrt(2))
+def test_distance_l1():
+    assert_almost_equal(distance([0,0],[1,1],1),2)
+def test_distance_linf():
+    assert_almost_equal(distance([0,0],[1,1],np.inf),1)
+def test_distance_vectorization():
+    x = np.random.randn(10,1,3)
+    y = np.random.randn(1,7,3)
+    assert_equal(distance(x,y).shape,(10,7))




More information about the Scipy-svn mailing list