[pypy-commit] pypy pypy-pyarray: ndarray.nonzero(): Deal with scalars, add test.
shmuller
noreply at buildbot.pypy.org
Mon Aug 26 22:03:03 CEST 2013
Author: Stefan H. Muller <shmueller2 at gmail.com>
Branch: pypy-pyarray
Changeset: r66349:01ebe4a52002
Date: 2013-08-11 21:09 +0200
http://bitbucket.org/pypy/pypy/changeset/01ebe4a52002/
Log: ndarray.nonzero(): Handle scalar (0-d) arrays and add a test.
- lib_pypy/numpypy/core/fromnumeric.py: Restore the original numpy
implementation of nonzero().
diff --git a/lib_pypy/numpypy/core/fromnumeric.py b/lib_pypy/numpypy/core/fromnumeric.py
--- a/lib_pypy/numpypy/core/fromnumeric.py
+++ b/lib_pypy/numpypy/core/fromnumeric.py
@@ -1133,7 +1133,13 @@
(array([1, 1, 1, 2, 2, 2]), array([0, 1, 2, 0, 1, 2]))
"""
- raise NotImplementedError('Waiting on interp level method')
+ try:
+ nonzero = a.nonzero
+ except AttributeError:
+ res = _wrapit(a, 'nonzero')
+ else:
+ res = nonzero()
+ return res
def shape(a):
diff --git a/pypy/module/micronumpy/interp_numarray.py b/pypy/module/micronumpy/interp_numarray.py
--- a/pypy/module/micronumpy/interp_numarray.py
+++ b/pypy/module/micronumpy/interp_numarray.py
@@ -352,15 +352,21 @@
return self.descr_reshape(space, [space.wrap(-1)])
def descr_nonzero(self, space):
+ s = loop.count_all_true(self)
+ index_type = interp_dtype.get_dtype_cache(space).w_int64dtype
+ box = index_type.itemtype.box
+
+ if self.is_scalar():
+ w_res = W_NDimArray.from_shape(space, [s], index_type)
+ if s == 1:
+ w_res.implementation.setitem(0, box(0))
+ return space.newtuple([w_res])
+
impl = self.implementation
arr_iter = iter.MultiDimViewIterator(impl, impl.dtype, 0,
impl.strides, impl.backstrides, impl.shape)
- index_type = interp_dtype.get_dtype_cache(space).w_int64dtype
- box = index_type.itemtype.box
-
nd = len(impl.shape)
- s = loop.count_all_true(self)
w_res = W_NDimArray.from_shape(space, [s, nd], index_type)
res_iter = w_res.create_iter()
diff --git a/pypy/module/micronumpy/test/test_numarray.py b/pypy/module/micronumpy/test/test_numarray.py
--- a/pypy/module/micronumpy/test/test_numarray.py
+++ b/pypy/module/micronumpy/test/test_numarray.py
@@ -2302,8 +2302,13 @@
def test_nonzero(self):
from numpypy import array
- a = array([[1, 0, 3], [2, 0, 4]])
- nz = a.nonzero()
+ nz = array(0).nonzero()
+ assert nz[0].size == 0
+
+ nz = array(2).nonzero()
+ assert (nz[0] == array([0])).all()
+
+ nz = array([[1, 0, 3], [2, 0, 4]]).nonzero()
assert (nz[0] == array([0, 0, 1, 1])).all()
assert (nz[1] == array([0, 2, 0, 2])).all()
More information about the pypy-commit mailing list