[pypy-commit] pypy default: (rguillebert, ronan, joanna) Progress on fancy indexing with booleans

rguillebert noreply at buildbot.pypy.org
Wed Aug 28 14:43:32 CEST 2013


Author: Romain Guillebert <romain.py at gmail.com>
Branch: 
Changeset: r66377:9515e4524aaa
Date: 2013-08-27 17:56 +0100
http://bitbucket.org/pypy/pypy/changeset/9515e4524aaa/

Log:	(rguillebert, ronan, joanna) Progress on fancy indexing with
	booleans

diff --git a/pypy/module/micronumpy/interp_numarray.py b/pypy/module/micronumpy/interp_numarray.py
--- a/pypy/module/micronumpy/interp_numarray.py
+++ b/pypy/module/micronumpy/interp_numarray.py
@@ -95,6 +95,12 @@
         if idx.get_size() > self.get_size():
             raise OperationError(space.w_ValueError,
                                  space.wrap("index out of range for array"))
+        idx_iter = idx.create_iter(self.get_shape())
+        size = loop.count_all_true_iter(idx_iter, self.get_shape(), idx.get_dtype())
+        if size != val.get_shape()[0]:
+            raise OperationError(space.w_ValueError, space.wrap("NumPy boolean array indexing assignment "
+                                                                "cannot assign %d input values to "
+                                                                "the %d output values where the mask is true" % (val.get_shape()[0],size)))
         loop.setitem_filter(self, idx, val)
 
     def _prepare_array_index(self, space, w_index):
diff --git a/pypy/module/micronumpy/loop.py b/pypy/module/micronumpy/loop.py
--- a/pypy/module/micronumpy/loop.py
+++ b/pypy/module/micronumpy/loop.py
@@ -318,23 +318,27 @@
         lefti.next()
     return result
 
-count_all_true_driver = jit.JitDriver(name = 'numpy_count',
-                                      greens = ['shapelen', 'dtype'],
-                                      reds = 'auto')
 
 def count_all_true(arr):
-    s = 0
     if arr.is_scalar():
         return arr.get_dtype().itemtype.bool(arr.get_scalar_value())
     iter = arr.create_iter()
-    shapelen = len(arr.get_shape())
-    dtype = arr.get_dtype()
+    return count_all_true_iter(iter, arr.get_shape(), arr.get_dtype())
+
+count_all_true_iter_driver = jit.JitDriver(name = 'numpy_count',
+                                      greens = ['shapelen', 'dtype'],
+                                      reds = 'auto')
+def count_all_true_iter(iter, shape, dtype):
+    s = 0
+    shapelen = len(shape)
+    dtype = dtype
     while not iter.done():
-        count_all_true_driver.jit_merge_point(shapelen=shapelen, dtype=dtype)
+        count_all_true_iter_driver.jit_merge_point(shapelen=shapelen, dtype=dtype)
         s += iter.getitem_bool()
         iter.next()
     return s
 
+
 getitem_filter_driver = jit.JitDriver(name = 'numpy_getitem_bool',
                                       greens = ['shapelen', 'arr_dtype',
                                                 'index_dtype'],
diff --git a/pypy/module/micronumpy/test/test_numarray.py b/pypy/module/micronumpy/test/test_numarray.py
--- a/pypy/module/micronumpy/test/test_numarray.py
+++ b/pypy/module/micronumpy/test/test_numarray.py
@@ -2354,6 +2354,12 @@
     def test_array_indexing_bool_specialcases(self):
         from numpypy import arange, array
         a = arange(6)
+        try:
+            a[a < 3] = [1, 2]
+            assert False, "Should not work"
+        except ValueError:
+            pass
+        a = arange(6)
         a[a > 3] = array([15])
         assert (a == [0, 1, 2, 3, 15, 15]).all()
         a = arange(6).reshape(3, 2)


More information about the pypy-commit mailing list