[pypy-commit] pypy default: fix searchsorted with multidim targets

bdkearns noreply at buildbot.pypy.org
Thu Oct 9 21:44:08 CEST 2014


Author: Brian Kearns <bdkearns at gmail.com>
Branch: 
Changeset: r73868:188aa764d2e7
Date: 2014-10-09 15:18 -0400
http://bitbucket.org/pypy/pypy/changeset/188aa764d2e7/

Log:	fix searchsorted with multidim targets

diff --git a/pypy/module/micronumpy/ndarray.py b/pypy/module/micronumpy/ndarray.py
--- a/pypy/module/micronumpy/ndarray.py
+++ b/pypy/module/micronumpy/ndarray.py
@@ -738,8 +738,6 @@
         if len(self.get_shape()) > 1:
             raise oefmt(space.w_ValueError, "a must be a 1-d array")
         v = convert_to_array(space, w_v)
-        if len(v.get_shape()) > 1:
-            raise oefmt(space.w_ValueError, "v must be a 1-d array-like")
         ret = W_NDimArray.from_shape(
             space, v.get_shape(), descriptor.get_dtype_cache(space).w_longdtype)
         app_searchsort(space, self, v, space.wrap(side), ret)
diff --git a/pypy/module/micronumpy/selection.py b/pypy/module/micronumpy/selection.py
--- a/pypy/module/micronumpy/selection.py
+++ b/pypy/module/micronumpy/selection.py
@@ -375,9 +375,8 @@
             op = operator.lt
         else:
             op = operator.le
-        if v.size < 2:
-            result[...] = _searchsort(a, op, v)
-        else:
-            for i in range(v.size):
-                result[i] = _searchsort(a, op, v[i])
+        v = v.flat
+        result = result.flat
+        for i in xrange(len(v)):
+            result[i] = _searchsort(a, op, v[i])
 """, filename=__file__).interphook('searchsort')
diff --git a/pypy/module/micronumpy/test/test_selection.py b/pypy/module/micronumpy/test/test_selection.py
--- a/pypy/module/micronumpy/test/test_selection.py
+++ b/pypy/module/micronumpy/test/test_selection.py
@@ -354,25 +354,36 @@
         import numpy as np
         import sys
         a = np.arange(1, 6)
+
         ret = a.searchsorted(3)
         assert ret == 2
         assert isinstance(ret, np.generic)
+
         ret = a.searchsorted(np.array(3))
         assert ret == 2
         assert isinstance(ret, np.generic)
+
         ret = a.searchsorted(np.array([3]))
         assert ret == 2
         assert isinstance(ret, np.ndarray)
+
+        ret = a.searchsorted(np.array([[2, 3]]))
+        assert (ret == [1, 2]).all()
+        assert ret.shape == (1, 2)
+
         ret = a.searchsorted(3, side='right')
         assert ret == 3
         assert isinstance(ret, np.generic)
+
         exc = raises(ValueError, a.searchsorted, 3, side=None)
         assert str(exc.value) == "expected nonempty string for keyword 'side'"
         exc = raises(ValueError, a.searchsorted, 3, side='')
         assert str(exc.value) == "expected nonempty string for keyword 'side'"
         exc = raises(ValueError, a.searchsorted, 3, side=2)
         assert str(exc.value) == "expected nonempty string for keyword 'side'"
+
         ret = a.searchsorted([-10, 10, 2, 3])
         assert (ret == [0, 5, 1, 2]).all()
+
         if '__pypy__' in sys.builtin_module_names:
             raises(NotImplementedError, "a.searchsorted(3, sorter=range(6))")


More information about the pypy-commit mailing list