[pypy-commit] pypy indexing-by-array: fix nd array indexing with a lower dimensional array

mattip noreply at buildbot.pypy.org
Mon Mar 18 19:50:50 CET 2013


Author: Matti Picus <matti.picus at gmail.com>
Branch: indexing-by-array
Changeset: r62401:1f15361645b3
Date: 2013-03-18 11:35 -0700
http://bitbucket.org/pypy/pypy/changeset/1f15361645b3/

Log:	fix boolean indexing of an n-dimensional array with a lower-dimensional array

diff --git a/pypy/module/micronumpy/interp_numarray.py b/pypy/module/micronumpy/interp_numarray.py
--- a/pypy/module/micronumpy/interp_numarray.py
+++ b/pypy/module/micronumpy/interp_numarray.py
@@ -79,7 +79,8 @@
             raise OperationError(space.w_ValueError,
                                  space.wrap("index out of range for array"))
         size = loop.count_all_true(arr)
-        res = W_NDimArray.from_shape([size], self.get_dtype())
+        res_shape = [size] + self.get_shape()[1:]
+        res = W_NDimArray.from_shape(res_shape, self.get_dtype())
         return loop.getitem_filter(res, self, arr)
 
     def setitem_filter(self, space, idx, val):
diff --git a/pypy/module/micronumpy/loop.py b/pypy/module/micronumpy/loop.py
--- a/pypy/module/micronumpy/loop.py
+++ b/pypy/module/micronumpy/loop.py
@@ -300,13 +300,14 @@
 
 def getitem_filter(res, arr, index):
     res_iter = res.create_iter()
-    index_iter = index.create_iter()
+    index_iter = index.create_iter(arr.get_shape())
     arr_iter = arr.create_iter()
     shapelen = len(arr.get_shape())
     arr_dtype = arr.get_dtype()
     index_dtype = index.get_dtype()
     # XXX length of shape of index as well?
     while not index_iter.done():
+        print 'res,arr,index', res_iter.get_index(), arr.get_index(), index.get_index()
         getitem_filter_driver.jit_merge_point(shapelen=shapelen,
                                               index_dtype=index_dtype,
                                               arr_dtype=arr_dtype,
diff --git a/pypy/module/micronumpy/test/test_numarray.py b/pypy/module/micronumpy/test/test_numarray.py
--- a/pypy/module/micronumpy/test/test_numarray.py
+++ b/pypy/module/micronumpy/test/test_numarray.py
@@ -1595,7 +1595,7 @@
         assert (zeros(1)[[]] == []).all()
 
     def test_int_array_index_setitem(self):
-        from numpypy import array, arange, zeros
+        from numpypy import arange, zeros
         a = arange(10)
         a[[3, 2, 1, 5]] = zeros(4, dtype=int)
         assert (a == [0, 0, 0, 0, 4, 0, 6, 7, 8, 9]).all()
@@ -1610,6 +1610,10 @@
         assert (b[array([True, False, True])] == [0, 2]).all()
         raises(ValueError, "array([1, 2])[array([True, True, True])]")
         raises(ValueError, "b[array([[True, False], [True, False]])]")
+        a = array([[1,2,3],[4,5,6],[7,8,9]],int)
+        c = array([True,False,True],bool)
+        b = a[c]
+        assert (a[c] == [[1, 2, 3], [7, 8, 9]]).all()
 
     def test_bool_array_index_setitem(self):
         from numpypy import arange, array


More information about the pypy-commit mailing list