[pypy-commit] pypy pypy-pyarray: Split nonzero() between scalar.py, concrete.py and loop.py.

shmuller noreply at buildbot.pypy.org
Mon Aug 26 22:03:04 CEST 2013


Author: Stefan H. Muller <shmueller2 at gmail.com>
Branch: pypy-pyarray
Changeset: r66350:b70301c90922
Date: 2013-08-12 23:49 +0200
http://bitbucket.org/pypy/pypy/changeset/b70301c90922/

Log:	Split nonzero() between scalar.py, concrete.py and loop.py.

	- Separate implementations for the 1D and ND cases. Try to reunify them
	in the next commit.

diff --git a/pypy/module/micronumpy/arrayimpl/concrete.py b/pypy/module/micronumpy/arrayimpl/concrete.py
--- a/pypy/module/micronumpy/arrayimpl/concrete.py
+++ b/pypy/module/micronumpy/arrayimpl/concrete.py
@@ -279,6 +279,36 @@
         return W_NDimArray.new_slice(space, self.start, strides,
                                      backstrides, shape, self, orig_arr)
 
+    def nonzero(self, space, index_type):
+        """Implement ndarray.nonzero() for a concrete array implementation.
+
+        Returns a wrapped tuple holding one index array of dtype
+        ``index_type`` per dimension; together they give the coordinates
+        of every truthy element.
+        """
+        s = loop.count_all_true_concrete(self)
+        box = index_type.itemtype.box
+        nd = len(self.shape)
+        # NOTE(review): assumes nd >= 1 (0-d arrays presumably use the
+        # scalar implementation); nd == 0 would take the N-D branch.
+
+        if nd == 1:
+            # 1-D case: the single index array is filled directly.
+            w_res = W_NDimArray.from_shape(space, [s], index_type)
+            loop.nonzero_onedim(w_res, self, box)
+            return space.newtuple([w_res])
+        else:
+            # N-D case: nonzero_multidim writes the nd coordinates of each
+            # truthy element consecutively into an (s, nd) array.  After
+            # swapping axes 0 and 1, item d is the index array for
+            # dimension d.
+            w_res = W_NDimArray.from_shape(space, [s, nd], index_type)
+            loop.nonzero_multidim(w_res, self, box)
+            w_res = w_res.implementation.swapaxes(space, w_res, 0, 1)
+            l_w = [w_res.descr_getitem(space, space.wrap(d))
+                   for d in range(nd)]
+            return space.newtuple(l_w)
+
     def get_storage_as_int(self, space):
         return rffi.cast(lltype.Signed, self.storage) + self.start
 
diff --git a/pypy/module/micronumpy/arrayimpl/scalar.py b/pypy/module/micronumpy/arrayimpl/scalar.py
--- a/pypy/module/micronumpy/arrayimpl/scalar.py
+++ b/pypy/module/micronumpy/arrayimpl/scalar.py
@@ -155,6 +155,20 @@
     def swapaxes(self, space, orig_array, axis1, axis2):
         raise Exception("should not be called")
 
+    def nonzero(self, space, index_type):
+        """Implement ndarray.nonzero() for a scalar (0-d) array.
+
+        Returns a 1-tuple holding an index array of dtype ``index_type``:
+        ``[0]`` if the scalar's value is truthy, empty otherwise.
+        """
+        # Normalize the truth value to an explicit 0/1 int, because it is
+        # used as the length of the result array.
+        s = 1 if self.dtype.itemtype.bool(self.value) else 0
+        w_res = W_NDimArray.from_shape(space, [s], index_type)
+        if s:
+            w_res.implementation.setitem(0, index_type.itemtype.box(0))
+        return space.newtuple([w_res])
+
     def fill(self, w_value):
         self.value = w_value
 
diff --git a/pypy/module/micronumpy/interp_numarray.py b/pypy/module/micronumpy/interp_numarray.py
--- a/pypy/module/micronumpy/interp_numarray.py
+++ b/pypy/module/micronumpy/interp_numarray.py
@@ -332,6 +332,15 @@
             return self
         return self.implementation.swapaxes(space, self, axis1, axis2)
 
+    def descr_nonzero(self, space):
+        """App-level ndarray.nonzero().
+
+        Index arrays always use the int64 dtype; building the result tuple
+        is delegated to the array implementation's nonzero().
+        """
+        index_type = interp_dtype.get_dtype_cache(space).w_int64dtype
+        return self.implementation.nonzero(space, index_type)
+
     def descr_tolist(self, space):
         if len(self.get_shape()) == 0:
             return self.get_scalar_value().item(space)
@@ -351,37 +360,6 @@
                 "order not implemented"))
         return self.descr_reshape(space, [space.wrap(-1)])
 
-    def descr_nonzero(self, space):
-        s = loop.count_all_true(self)
-        index_type = interp_dtype.get_dtype_cache(space).w_int64dtype
-        box = index_type.itemtype.box
-        
-        if self.is_scalar():
-            w_res = W_NDimArray.from_shape(space, [s], index_type)
-            if s == 1:
-                w_res.implementation.setitem(0, box(0))
-            return space.newtuple([w_res])
-
-        impl = self.implementation
-        arr_iter = iter.MultiDimViewIterator(impl, impl.dtype, 0, 
-                impl.strides, impl.backstrides, impl.shape)
-        
-        nd = len(impl.shape)
-        w_res = W_NDimArray.from_shape(space, [s, nd], index_type)        
-        res_iter = w_res.create_iter()
-
-        dims = range(nd)
-        while not arr_iter.done():
-            if arr_iter.getitem_bool():
-                for d in dims:
-                    res_iter.setitem(box(arr_iter.indexes[d]))
-                    res_iter.next()
-            arr_iter.next()
-
-        w_res = w_res.implementation.swapaxes(space, w_res, 0, 1)
-        l_w = [w_res.descr_getitem(space, space.wrap(d)) for d in dims]
-        return space.newtuple(l_w)
 
     def descr_take(self, space, w_obj, w_axis=None, w_out=None):
         # if w_axis is None and w_out is Nont this is an equivalent to
         # fancy indexing
@@ -1101,11 +1079,11 @@
     tolist = interp2app(W_NDimArray.descr_tolist),
     flatten = interp2app(W_NDimArray.descr_flatten),
     ravel = interp2app(W_NDimArray.descr_ravel),
-    nonzero = interp2app(W_NDimArray.descr_nonzero),
     take = interp2app(W_NDimArray.descr_take),
     compress = interp2app(W_NDimArray.descr_compress),
     repeat = interp2app(W_NDimArray.descr_repeat),
     swapaxes = interp2app(W_NDimArray.descr_swapaxes),
+    nonzero = interp2app(W_NDimArray.descr_nonzero),
     flat = GetSetProperty(W_NDimArray.descr_get_flatiter),
     item = interp2app(W_NDimArray.descr_item),
     real = GetSetProperty(W_NDimArray.descr_get_real,
diff --git a/pypy/module/micronumpy/loop.py b/pypy/module/micronumpy/loop.py
--- a/pypy/module/micronumpy/loop.py
+++ b/pypy/module/micronumpy/loop.py
@@ -9,7 +9,8 @@
 from rpython.rlib import jit
 from rpython.rtyper.lltypesystem import lltype, rffi
 from pypy.module.micronumpy.base import W_NDimArray
-from pypy.module.micronumpy.iter import PureShapeIterator
+from pypy.module.micronumpy.iter import PureShapeIterator, OneDimViewIterator, \
+        MultiDimViewIterator
 from pypy.module.micronumpy import constants
 from pypy.module.micronumpy.support import int_w
 
@@ -323,19 +324,87 @@
                                       greens = ['shapelen', 'dtype'],
                                       reds = 'auto')
 
-def count_all_true(arr):
+def count_all_true_concrete(impl):
+    """Return the number of truthy elements of the array
+    implementation ``impl``, iterating with impl.create_iter()."""
     s = 0
-    if arr.is_scalar():
-        return arr.get_dtype().itemtype.bool(arr.get_scalar_value())
-    iter = arr.create_iter()
-    shapelen = len(arr.get_shape())
-    dtype = arr.get_dtype()
+    iter = impl.create_iter()
+    shapelen = len(impl.shape)
+    dtype = impl.dtype
     while not iter.done():
         count_all_true_driver.jit_merge_point(shapelen=shapelen, dtype=dtype)
         s += iter.getitem_bool()
         iter.next()
     return s
 
+def count_all_true(arr):
+    """Return the number of truthy elements of the W_NDimArray ``arr``.
+
+    Scalar arrays are answered directly from itemtype.bool() of their
+    value; everything else is delegated to count_all_true_concrete().
+    """
+    if arr.is_scalar():
+        return arr.get_dtype().itemtype.bool(arr.get_scalar_value())
+    else:
+        return count_all_true_concrete(arr.implementation)
+
+nonzero_driver_onedim = jit.JitDriver(name = 'numpy_nonzero_onedim',
+                                      greens = ['shapelen', 'dtype'],
+                                      reds = 'auto')
+
+def nonzero_onedim(res, arr, box):
+    """Write the positions of the truthy elements of the 1-D array
+    implementation ``arr`` into ``res`` and return ``res``.
+
+    ``box`` converts each integer index into an item of ``res``'s dtype.
+    ``res`` is expected to hold exactly one slot per truthy element.
+    """
+    res_iter = res.create_iter()
+    arr_iter = OneDimViewIterator(arr, arr.dtype, 0,
+            arr.strides, arr.shape)
+    shapelen = 1
+    dtype = arr.dtype
+    while not arr_iter.done():
+        nonzero_driver_onedim.jit_merge_point(shapelen=shapelen, dtype=dtype)
+        if arr_iter.getitem_bool():
+            res_iter.setitem(box(arr_iter.index))
+            res_iter.next()
+        arr_iter.next()
+    return res
+
+# Must not share its name with nonzero_driver_onedim; a distinct name keeps
+# the 1-D and N-D loops distinguishable.
+nonzero_driver_multidim = jit.JitDriver(name = 'numpy_nonzero_multidim',
+                                        greens = ['shapelen', 'dims', 'dtype'],
+                                        reds = 'auto')
+
+def nonzero_multidim(res, arr, box):
+    """Write the coordinates of the truthy elements of the N-D array
+    implementation ``arr`` into ``res`` and return ``res``.
+
+    For every truthy element, its len(arr.shape) coordinates are written
+    consecutively through res's iterator, so ``res`` is expected to be a
+    (count, ndim) array.  ``box`` converts each coordinate into an item
+    of ``res``'s dtype.
+    """
+    res_iter = res.create_iter()
+    arr_iter = MultiDimViewIterator(arr, arr.dtype, 0,
+        arr.strides, arr.backstrides, arr.shape)
+    shapelen = len(arr.shape)
+    dtype = arr.dtype
+    # NOTE(review): dims is passed as a green, presumably so the JIT can
+    # specialize the per-coordinate loop for a given rank -- confirm.
+    dims = range(shapelen)
+    while not arr_iter.done():
+        nonzero_driver_multidim.jit_merge_point(shapelen=shapelen,
+                                                dims=dims, dtype=dtype)
+        if arr_iter.getitem_bool():
+            for d in dims:
+                res_iter.setitem(box(arr_iter.indexes[d]))
+                res_iter.next()
+        arr_iter.next()
+    return res
+
 getitem_filter_driver = jit.JitDriver(name = 'numpy_getitem_bool',
                                       greens = ['shapelen', 'arr_dtype',
                                                 'index_dtype'],
diff --git a/pypy/module/micronumpy/test/test_numarray.py b/pypy/module/micronumpy/test/test_numarray.py
--- a/pypy/module/micronumpy/test/test_numarray.py
+++ b/pypy/module/micronumpy/test/test_numarray.py
@@ -1348,7 +1348,7 @@
             for i in xrange(5):
                 assert c[i] == func(b[i], 3)
 
-    def test_nonzero(self):
+    def test___nonzero__(self):
         from numpypy import array
         a = array([1, 2])
         raises(ValueError, bool, a)
@@ -2306,11 +2306,14 @@
         assert nz[0].size == 0
 
         nz = array(2).nonzero()
-        assert (nz[0] == array([0])).all()
+        assert (nz[0] == [0]).all()
+
+        nz = array([1, 0, 3]).nonzero()
+        assert (nz[0] == [0, 2]).all()
 
         nz = array([[1, 0, 3], [2, 0, 4]]).nonzero()
-        assert (nz[0] == array([0, 0, 1, 1])).all()
-        assert (nz[1] == array([0, 2, 0, 2])).all()
+        assert (nz[0] == [0, 0, 1, 1]).all()
+        assert (nz[1] == [0, 2, 0, 2]).all()
 
     def test_take(self):
         from numpypy import arange


More information about the pypy-commit mailing list