[pypy-commit] pypy numpy-fancy-indexing: implement fancy indexing; I'm not 100% sure about complex cases, but it seems to work

fijal noreply at buildbot.pypy.org
Wed Sep 19 16:39:30 CEST 2012


Author: Maciej Fijalkowski <fijall at gmail.com>
Branch: numpy-fancy-indexing
Changeset: r57387:3bc12f08194a
Date: 2012-09-19 16:39 +0200
http://bitbucket.org/pypy/pypy/changeset/3bc12f08194a/

Log:	implement fancy indexing; I'm not 100% sure about complex cases, but
	it seems to work

diff --git a/pypy/module/micronumpy/arrayimpl/concrete.py b/pypy/module/micronumpy/arrayimpl/concrete.py
--- a/pypy/module/micronumpy/arrayimpl/concrete.py
+++ b/pypy/module/micronumpy/arrayimpl/concrete.py
@@ -272,13 +272,15 @@
                 for w_item in view_w:
                     if space.is_w(w_item, space.w_None):
                         count -= 1
-                    if (space.isinstance_w(w_item, space.w_list) or
-                        isinstance(w_item, W_NDimArray)):
-                        raise ArrayArgumentException
                 if count == shape_len:
                     raise IndexError # but it's still not a single item
                 raise OperationError(space.w_IndexError,
                                      space.wrap("invalid index"))
+            # check for arrays
+            for w_item in view_w:
+                if (isinstance(w_item, W_NDimArray) or
+                    space.isinstance_w(w_item, space.w_list)):
+                    raise ArrayArgumentException
             return self._lookup_by_index(space, view_w)
         if shape_len > 1:
             raise IndexError
diff --git a/pypy/module/micronumpy/interp_numarray.py b/pypy/module/micronumpy/interp_numarray.py
--- a/pypy/module/micronumpy/interp_numarray.py
+++ b/pypy/module/micronumpy/interp_numarray.py
@@ -6,7 +6,7 @@
      ArrayArgumentException
 from pypy.module.micronumpy import interp_dtype, interp_ufuncs, interp_boxes
 from pypy.module.micronumpy.strides import find_shape_and_elems,\
-     get_shape_from_iterable, to_coords
+     get_shape_from_iterable, to_coords, shape_agreement
 from pypy.module.micronumpy.interp_flatiter import W_FlatIterator
 from pypy.module.micronumpy.interp_support import unwrap_axis_arg
 from pypy.module.micronumpy.appbridge import get_appbridge_cache
@@ -78,27 +78,60 @@
 
     def _prepare_array_index(self, space, w_index):
         if isinstance(w_index, W_NDimArray):
-            return w_index.get_shape(), [w_index]
+            return [], w_index.get_shape(), w_index.get_shape(), [w_index]
         w_lst = space.listview(w_index)
         for w_item in w_lst:
             if not space.isinstance_w(w_item, space.w_int):
                 break
         else:
             arr = convert_to_array(space, w_index)
-            return arr.get_shape(), [arr]
-        xxx # determine shape
-        return w_lst
+            return [], arr.get_shape(), arr.get_shape(), [arr]
+        shape = None
+        indexes_w = [None] * len(w_lst)
+        res_shape = []
+        arr_index_in_shape = False
+        prefix = []
+        for i, w_item in enumerate(w_lst):
+            if (isinstance(w_item, W_NDimArray) or
+                space.isinstance_w(w_item, space.w_list)):
+                w_item = convert_to_array(space, w_item)
+                if shape is None:
+                    shape = w_item.get_shape()
+                else:
+                    shape = shape_agreement(space, shape, w_item)
+                indexes_w[i] = w_item
+                if not arr_index_in_shape:
+                    res_shape.append(None)
+                    arr_index_in_shape = True
+            else:
+                if space.isinstance_w(w_item, space.w_slice):
+                    _, _, _, lgt = space.decode_index4(w_item, self.get_shape()[i])
+                    if not arr_index_in_shape:
+                        prefix.append(w_item)
+                    res_shape.append(lgt)
+                indexes_w[i] = w_item
+        real_shape = []
+        for i in res_shape:
+            if i is None:
+                real_shape += shape
+            else:
+                real_shape.append(i)
+        return prefix, real_shape[:], shape, indexes_w
 
     def getitem_array_int(self, space, w_index):
-        iter_shape, indexes = self._prepare_array_index(space, w_index)
-        shape = iter_shape + self.get_shape()[len(indexes):]
+        prefix, res_shape, iter_shape, indexes = \
+                self._prepare_array_index(space, w_index)
+        shape = res_shape + self.get_shape()[len(indexes):]
         res = W_NDimArray.from_shape(shape, self.get_dtype(), self.get_order())
-        return loop.getitem_array_int(space, self, res, iter_shape, indexes)
+        return loop.getitem_array_int(space, self, res, iter_shape, indexes,
+                                      prefix)
 
     def setitem_array_int(self, space, w_index, w_value):
         val_arr = convert_to_array(space, w_value)
-        iter_shape, indexes = self._prepare_array_index(space, w_index)
-        return loop.setitem_array_int(space, self, iter_shape, indexes, val_arr)
+        prefix, _, iter_shape, indexes = \
+                self._prepare_array_index(space, w_index)
+        return loop.setitem_array_int(space, self, iter_shape, indexes, val_arr,
+                                      prefix)
 
     def descr_getitem(self, space, w_idx):
         if (isinstance(w_idx, W_NDimArray) and
diff --git a/pypy/module/micronumpy/loop.py b/pypy/module/micronumpy/loop.py
--- a/pypy/module/micronumpy/loop.py
+++ b/pypy/module/micronumpy/loop.py
@@ -275,34 +275,37 @@
             self._done = True
 
     def get_index(self, space):
-        return space.newtuple([space.wrap(i) for i in self.indexes])
+        return [space.wrap(i) for i in self.indexes]
 
-def getitem_array_int(space, arr, res, iter_shape, indexes_w):
+def getitem_array_int(space, arr, res, iter_shape, indexes_w, prefix_w):
     iter = PureShapeIterator(iter_shape, indexes_w)
     while not iter.done():
         # prepare the index
-        index_w = [None] * len(iter_shape)
-        for i in range(len(iter_shape)):
+        index_w = [None] * len(indexes_w)
+        for i in range(len(indexes_w)):
             if iter.idx_w[i] is not None:
                 index_w[i] = iter.idx_w[i].getitem()
             else:
                 index_w[i] = indexes_w[i]
-        res.descr_setitem(space, iter.get_index(space),
+        res.descr_setitem(space, space.newtuple(prefix_w +
+                                                iter.get_index(space)),
                           arr.descr_getitem(space, space.newtuple(index_w)))
         iter.next()
     return res
 
-def setitem_array_int(space, arr, iter_shape, indexes_w, val_arr):
+def setitem_array_int(space, arr, iter_shape, indexes_w, val_arr,
+                      prefix_w):
     iter = PureShapeIterator(iter_shape, indexes_w)
     while not iter.done():
         # prepare the index
-        index_w = [None] * len(iter_shape)
-        for i in range(len(iter_shape)):
+        index_w = [None] * len(indexes_w)
+        for i in range(len(indexes_w)):
             if iter.idx_w[i] is not None:
                 index_w[i] = iter.idx_w[i].getitem()
             else:
                 index_w[i] = indexes_w[i]
+        w_idx = space.newtuple(prefix_w + iter.get_index(space))
         arr.descr_setitem(space, space.newtuple(index_w),
-                          val_arr.descr_getitem(space, iter.get_index(space)))
+                          val_arr.descr_getitem(space, w_idx))
         iter.next()
 
diff --git a/pypy/module/micronumpy/test/test_numarray.py b/pypy/module/micronumpy/test/test_numarray.py
--- a/pypy/module/micronumpy/test/test_numarray.py
+++ b/pypy/module/micronumpy/test/test_numarray.py
@@ -1995,9 +1995,24 @@
 
     def test_int_array_index(self):
         from _numpypy import array
-        a = array([[1, 2], [3, 4]])
+        a = array([[1, 2], [3, 4], [5, 6]])
+        assert (a[slice(0, 3), [0, 0]] == [[1, 1], [3, 3], [5, 5]]).all()
+        assert (a[array([0, 2]), slice(0, 2)] == [[1, 2], [5, 6]]).all()
         b = a[array([0, 0])]
         assert (b == [[1, 2], [1, 2]]).all()
+        assert (a[[[0, 1], [0, 0]]] == array([1, 3])).all()
+        assert (a[array([0, 2])] == [[1, 2], [5, 6]]).all()
+        assert (a[array([0, 2]), 1] == [2, 6]).all()
+        assert (a[array([0, 2]), array([1])] == [2, 6]).all()
+
+    def test_int_array_index_setitem(self):
+        from _numpypy import array
+        a = array([[1, 2], [3, 4], [5, 6]])
+        a[slice(0, 3), [0, 0]] = [[0, 0], [0, 0], [0, 0]]
+        assert (a == [[0, 2], [0, 4], [0, 6]]).all()
+        a = array([[1, 2], [3, 4], [5, 6]])
+        a[array([0, 2]), slice(0, 2)] = [[10, 11], [12, 13]]
+        assert (a == [[10, 11], [3, 4], [12, 13]]).all()
 
 class AppTestSupport(BaseNumpyAppTest):
     def setup_class(cls):


More information about the pypy-commit mailing list