[pypy-commit] pypy default: Merge the numpy-fancy indexing branch. This branch implements arr[arr-of-ints]

fijal noreply at buildbot.pypy.org
Wed Sep 19 16:46:01 CEST 2012


Author: Maciej Fijalkowski <fijall at gmail.com>
Branch: 
Changeset: r57389:fff8bdfee215
Date: 2012-09-19 16:45 +0200
http://bitbucket.org/pypy/pypy/changeset/fff8bdfee215/

Log:	Merge the numpy-fancy indexing branch. This branch implements arr[arr-of-ints]
	(indexing an array with an array of integers).

diff --git a/pypy/module/micronumpy/arrayimpl/concrete.py b/pypy/module/micronumpy/arrayimpl/concrete.py
--- a/pypy/module/micronumpy/arrayimpl/concrete.py
+++ b/pypy/module/micronumpy/arrayimpl/concrete.py
@@ -1,7 +1,8 @@
 
 from pypy.module.micronumpy.arrayimpl import base
 from pypy.module.micronumpy import support, loop
-from pypy.module.micronumpy.base import convert_to_array, W_NDimArray
+from pypy.module.micronumpy.base import convert_to_array, W_NDimArray,\
+     ArrayArgumentException
 from pypy.module.micronumpy.strides import calc_new_strides, shape_agreement,\
      calculate_broadcast_strides, calculate_dot_strides
 from pypy.module.micronumpy.iter import Chunk, Chunks, NewAxisChunk, RecordChunk
@@ -223,6 +224,26 @@
             item += idx * self.strides[i]
         return item
 
+    @jit.unroll_safe
+    def _lookup_by_unwrapped_index(self, space, lst):
+        item = self.start
+        assert len(lst) == len(self.shape)
+        for i, idx in enumerate(lst):
+            if idx < 0:
+                idx = self.shape[i] + idx
+            if idx < 0 or idx >= self.shape[i]:
+                raise operationerrfmt(space.w_IndexError,
+                      "index (%d) out of range (0<=index<%d", i, self.shape[i],
+                )
+            item += idx * self.strides[i]
+        return item
+
+    def getitem_index(self, space, index):
+        return self.getitem(self._lookup_by_unwrapped_index(space, index))
+
+    def setitem_index(self, space, index, value):
+        self.setitem(self._lookup_by_unwrapped_index(space, index), value)
+
     def _single_item_index(self, space, w_idx):
         """ Return an index of single item if possible, otherwise raises
         IndexError
@@ -231,10 +252,16 @@
             space.isinstance_w(w_idx, space.w_slice) or
             space.is_w(w_idx, space.w_None)):
             raise IndexError
+        if isinstance(w_idx, W_NDimArray):
+            raise ArrayArgumentException
         shape_len = len(self.shape)
         if shape_len == 0:
             raise OperationError(space.w_IndexError, space.wrap(
                 "0-d arrays can't be indexed"))
+        view_w = None
+        if (space.isinstance_w(w_idx, space.w_list) or
+            isinstance(w_idx, W_NDimArray)):
+            raise ArrayArgumentException
         if space.isinstance_w(w_idx, space.w_tuple):
             view_w = space.fixedview(w_idx)
             if len(view_w) < shape_len:
@@ -249,6 +276,11 @@
                     raise IndexError # but it's still not a single item
                 raise OperationError(space.w_IndexError,
                                      space.wrap("invalid index"))
+            # check for arrays
+            for w_item in view_w:
+                if (isinstance(w_item, W_NDimArray) or
+                    space.isinstance_w(w_item, space.w_list)):
+                    raise ArrayArgumentException
             return self._lookup_by_index(space, view_w)
         if shape_len > 1:
             raise IndexError
diff --git a/pypy/module/micronumpy/arrayimpl/scalar.py b/pypy/module/micronumpy/arrayimpl/scalar.py
--- a/pypy/module/micronumpy/arrayimpl/scalar.py
+++ b/pypy/module/micronumpy/arrayimpl/scalar.py
@@ -58,10 +58,17 @@
         raise OperationError(space.w_IndexError,
                              space.wrap("scalars cannot be indexed"))
 
+    def getitem_index(self, space, idx):
+        raise OperationError(space.w_IndexError,
+                             space.wrap("scalars cannot be indexed"))
+
     def descr_setitem(self, space, w_idx, w_val):
         raise OperationError(space.w_IndexError,
                              space.wrap("scalars cannot be indexed"))
         
+    def setitem_index(self, space, idx, w_val):
+        raise OperationError(space.w_IndexError,
+                             space.wrap("scalars cannot be indexed"))
     def set_shape(self, space, new_shape):
         if not new_shape:
             return self
diff --git a/pypy/module/micronumpy/base.py b/pypy/module/micronumpy/base.py
--- a/pypy/module/micronumpy/base.py
+++ b/pypy/module/micronumpy/base.py
@@ -3,6 +3,9 @@
 from pypy.tool.pairtype import extendabletype
 from pypy.module.micronumpy.support import calc_strides
 
+class ArrayArgumentException(Exception):
+    pass
+
 class W_NDimArray(Wrappable):
     __metaclass__ = extendabletype
 
diff --git a/pypy/module/micronumpy/interp_numarray.py b/pypy/module/micronumpy/interp_numarray.py
--- a/pypy/module/micronumpy/interp_numarray.py
+++ b/pypy/module/micronumpy/interp_numarray.py
@@ -2,10 +2,11 @@
 from pypy.interpreter.error import operationerrfmt, OperationError
 from pypy.interpreter.typedef import TypeDef, GetSetProperty
 from pypy.interpreter.gateway import interp2app, unwrap_spec
-from pypy.module.micronumpy.base import W_NDimArray, convert_to_array
+from pypy.module.micronumpy.base import W_NDimArray, convert_to_array,\
+     ArrayArgumentException
 from pypy.module.micronumpy import interp_dtype, interp_ufuncs, interp_boxes
 from pypy.module.micronumpy.strides import find_shape_and_elems,\
-     get_shape_from_iterable, to_coords
+     get_shape_from_iterable, to_coords, shape_agreement
 from pypy.module.micronumpy.interp_flatiter import W_FlatIterator
 from pypy.module.micronumpy.interp_support import unwrap_axis_arg
 from pypy.module.micronumpy.appbridge import get_appbridge_cache
@@ -63,30 +64,101 @@
 
     def getitem_filter(self, space, arr):
         if arr.get_size() > self.get_size():
-            raise OperationError(space.w_IndexError,
+            raise OperationError(space.w_ValueError,
                                  space.wrap("index out of range for array"))
         size = loop.count_all_true(arr)
         res = W_NDimArray.from_shape([size], self.get_dtype())
         return loop.getitem_filter(res, self, arr)
 
     def setitem_filter(self, space, idx, val):
+        if idx.get_size() > self.get_size():
+            raise OperationError(space.w_ValueError,
+                                 space.wrap("index out of range for array"))
         loop.setitem_filter(self, idx, val)
 
+    def _prepare_array_index(self, space, w_index):
+        if isinstance(w_index, W_NDimArray):
+            return [], w_index.get_shape(), w_index.get_shape(), [w_index]
+        w_lst = space.listview(w_index)
+        for w_item in w_lst:
+            if not space.isinstance_w(w_item, space.w_int):
+                break
+        else:
+            arr = convert_to_array(space, w_index)
+            return [], arr.get_shape(), arr.get_shape(), [arr]
+        shape = None
+        indexes_w = [None] * len(w_lst)
+        res_shape = []
+        arr_index_in_shape = False
+        prefix = []
+        for i, w_item in enumerate(w_lst):
+            if (isinstance(w_item, W_NDimArray) or
+                space.isinstance_w(w_item, space.w_list)):
+                w_item = convert_to_array(space, w_item)
+                if shape is None:
+                    shape = w_item.get_shape()
+                else:
+                    shape = shape_agreement(space, shape, w_item)
+                indexes_w[i] = w_item
+                if not arr_index_in_shape:
+                    res_shape.append(-1)
+                    arr_index_in_shape = True
+            else:
+                if space.isinstance_w(w_item, space.w_slice):
+                    _, _, _, lgt = space.decode_index4(w_item, self.get_shape()[i])
+                    if not arr_index_in_shape:
+                        prefix.append(w_item)
+                    res_shape.append(lgt)
+                indexes_w[i] = w_item
+        real_shape = []
+        for i in res_shape:
+            if i == -1:
+                real_shape += shape
+            else:
+                real_shape.append(i)
+        return prefix, real_shape[:], shape, indexes_w
+
+    def getitem_array_int(self, space, w_index):
+        prefix, res_shape, iter_shape, indexes = \
+                self._prepare_array_index(space, w_index)
+        shape = res_shape + self.get_shape()[len(indexes):]
+        res = W_NDimArray.from_shape(shape, self.get_dtype(), self.get_order())
+        return loop.getitem_array_int(space, self, res, iter_shape, indexes,
+                                      prefix)
+
+    def setitem_array_int(self, space, w_index, w_value):
+        val_arr = convert_to_array(space, w_value)
+        prefix, _, iter_shape, indexes = \
+                self._prepare_array_index(space, w_index)
+        return loop.setitem_array_int(space, self, iter_shape, indexes, val_arr,
+                                      prefix)
+
     def descr_getitem(self, space, w_idx):
-        if (isinstance(w_idx, W_NDimArray) and w_idx.get_shape() == self.get_shape() and
+        if (isinstance(w_idx, W_NDimArray) and
             w_idx.get_dtype().is_bool_type()):
             return self.getitem_filter(space, w_idx)
         try:
             return self.implementation.descr_getitem(space, w_idx)
+        except ArrayArgumentException:
+            return self.getitem_array_int(space, w_idx)
         except OperationError:
             raise OperationError(space.w_IndexError, space.wrap("wrong index"))
 
+    def getitem(self, space, index_list):
+        return self.implementation.getitem_index(space, index_list)
+
+    def setitem(self, space, index_list, w_value):
+        self.implementation.setitem_index(space, index_list, w_value)
+
     def descr_setitem(self, space, w_idx, w_value):
-        if (isinstance(w_idx, W_NDimArray) and w_idx.get_shape() == self.get_shape() and
+        if (isinstance(w_idx, W_NDimArray) and
             w_idx.get_dtype().is_bool_type()):
             return self.setitem_filter(space, w_idx,
                                        convert_to_array(space, w_value))
-        self.implementation.descr_setitem(space, w_idx, w_value)
+        try:
+            self.implementation.descr_setitem(space, w_idx, w_value)
+        except ArrayArgumentException:
+            self.setitem_array_int(space, w_idx, w_value)
 
     def descr_len(self, space):
         shape = self.get_shape()
@@ -265,9 +337,8 @@
             if self.is_scalar():
                 return self.get_scalar_value().item(space)
             if self.get_size() == 1:
-                w_obj = self.descr_getitem(space,
-                                           space.newtuple([space.wrap(0) for i
-                                      in range(len(self.get_shape()))]))
+                w_obj = self.getitem(space,
+                                     [0] * len(self.get_shape()))
                 assert isinstance(w_obj, interp_boxes.W_GenericBox)
                 return w_obj.item(space)
             raise OperationError(space.w_IndexError,
@@ -277,8 +348,7 @@
                 raise OperationError(space.w_IndexError,
                                      space.wrap("index out of bounds"))
             i = self.to_coords(space, w_arg)
-            item = self.descr_getitem(space, space.newtuple([space.wrap(x)
-                                             for x in i]))
+            item = self.getitem(space, i)
             assert isinstance(item, interp_boxes.W_GenericBox)
             return item.item(space)
         raise OperationError(space.w_NotImplementedError, space.wrap(
diff --git a/pypy/module/micronumpy/loop.py b/pypy/module/micronumpy/loop.py
--- a/pypy/module/micronumpy/loop.py
+++ b/pypy/module/micronumpy/loop.py
@@ -4,9 +4,10 @@
 """
 
 from pypy.rlib.objectmodel import specialize
+from pypy.rlib.rstring import StringBuilder
+from pypy.rlib import jit
+from pypy.rpython.lltypesystem import lltype, rffi
 from pypy.module.micronumpy.base import W_NDimArray
-from pypy.rlib.rstring import StringBuilder
-from pypy.rpython.lltypesystem import lltype, rffi
 
 def call2(shape, func, name, calc_dtype, res_dtype, w_lhs, w_rhs, out):
     if out is None:
@@ -244,3 +245,67 @@
             builder.append(res_str_casted[i])
         iter.next()
     return builder.build()
+
+class PureShapeIterator(object):
+    def __init__(self, shape, idx_w):
+        self.shape = shape
+        self.shapelen = len(shape)
+        self.indexes = [0] * len(shape)
+        self._done = False
+        self.idx_w = [None] * len(idx_w)
+        for i, w_idx in enumerate(idx_w):
+            if isinstance(w_idx, W_NDimArray):
+                self.idx_w[i] = w_idx.create_iter(shape)
+
+    def done(self):
+        return self._done
+
+    @jit.unroll_safe
+    def next(self):
+        for w_idx in self.idx_w:
+            if w_idx is not None:
+                w_idx.next()
+        for i in range(self.shapelen - 1, -1, -1):
+            if self.indexes[i] < self.shape[i] - 1:
+                self.indexes[i] += 1
+                break
+            else:
+                self.indexes[i] = 0
+        else:
+            self._done = True
+
+    def get_index(self, space):
+        return [space.wrap(i) for i in self.indexes]
+
+def getitem_array_int(space, arr, res, iter_shape, indexes_w, prefix_w):
+    iter = PureShapeIterator(iter_shape, indexes_w)
+    while not iter.done():
+        # prepare the index
+        index_w = [None] * len(indexes_w)
+        for i in range(len(indexes_w)):
+            if iter.idx_w[i] is not None:
+                index_w[i] = iter.idx_w[i].getitem()
+            else:
+                index_w[i] = indexes_w[i]
+        res.descr_setitem(space, space.newtuple(prefix_w +
+                                                iter.get_index(space)),
+                          arr.descr_getitem(space, space.newtuple(index_w)))
+        iter.next()
+    return res
+
+def setitem_array_int(space, arr, iter_shape, indexes_w, val_arr,
+                      prefix_w):
+    iter = PureShapeIterator(iter_shape, indexes_w)
+    while not iter.done():
+        # prepare the index
+        index_w = [None] * len(indexes_w)
+        for i in range(len(indexes_w)):
+            if iter.idx_w[i] is not None:
+                index_w[i] = iter.idx_w[i].getitem()
+            else:
+                index_w[i] = indexes_w[i]
+        w_idx = space.newtuple(prefix_w + iter.get_index(space))
+        arr.descr_setitem(space, space.newtuple(index_w),
+                          val_arr.descr_getitem(space, w_idx))
+        iter.next()
+
diff --git a/pypy/module/micronumpy/test/test_numarray.py b/pypy/module/micronumpy/test/test_numarray.py
--- a/pypy/module/micronumpy/test/test_numarray.py
+++ b/pypy/module/micronumpy/test/test_numarray.py
@@ -1517,7 +1517,38 @@
         a = array([1.0,-1.0])
         a[a<0] = -a[a<0]
         assert (a == [1, 1]).all()
-                        
+
+    def test_int_array_index(self):
+        from numpypy import array, arange
+        b = arange(10)[array([3, 2, 1, 5])]
+        assert (b == [3, 2, 1, 5]).all()
+        raises(IndexError, "arange(10)[array([10])]")
+        assert (arange(10)[[-5, -3]] == [5, 7]).all()
+        raises(IndexError, "arange(10)[[-11]]")
+
+    def test_int_array_index_setitem(self):
+        from numpypy import array, arange, zeros
+        a = arange(10)
+        a[[3, 2, 1, 5]] = zeros(4, dtype=int)
+        assert (a == [0, 0, 0, 0, 4, 0, 6, 7, 8, 9]).all()
+        a[[-9, -8]] = [1, 1]
+        assert (a == [0, 1, 1, 0, 4, 0, 6, 7, 8, 9]).all()
+        raises(IndexError, "arange(10)[array([10])] = 3")
+        raises(IndexError, "arange(10)[[-11]] = 3")
+
+    def test_bool_array_index(self):
+        from numpypy import arange, array
+        b = arange(10)
+        assert (b[array([True, False, True])] == [0, 2]).all()
+        raises(ValueError, "array([1, 2])[array([True, True, True])]")
+
+    def test_bool_array_index_setitem(self):
+        from numpypy import arange, array
+        b = arange(5)
+        b[array([True, False, True])] = [20, 21]
+        assert (b == [20, 1, 21, 3, 4]).all() 
+        raises(ValueError, "array([1, 2])[array([True, False, True])] = [1, 2, 3]")
+
 class AppTestMultiDim(BaseNumpyAppTest):
     def test_init(self):
         import _numpypy
@@ -1943,7 +1974,7 @@
         assert (a.compress([1, 0, 13.5]) == [0, 2]).all()
         a = arange(10).reshape(2, 5)
         assert (a.compress([True, False, True]) == [0, 2]).all()
-        raises(IndexError, "a.compress([1] * 100)")
+        raises((IndexError, ValueError), "a.compress([1] * 100)")
 
     def test_item(self):
         from _numpypy import array
@@ -1961,7 +1992,27 @@
         assert (a + a).item(1) == 4
         raises(IndexError, "array(5).item(1)")
         assert array([1]).item() == 1
- 
+
+    def test_int_array_index(self):
+        from _numpypy import array
+        a = array([[1, 2], [3, 4], [5, 6]])
+        assert (a[slice(0, 3), [0, 0]] == [[1, 1], [3, 3], [5, 5]]).all()
+        assert (a[array([0, 2]), slice(0, 2)] == [[1, 2], [5, 6]]).all()
+        b = a[array([0, 0])]
+        assert (b == [[1, 2], [1, 2]]).all()
+        assert (a[[[0, 1], [0, 0]]] == array([1, 3])).all()
+        assert (a[array([0, 2])] == [[1, 2], [5, 6]]).all()
+        assert (a[array([0, 2]), 1] == [2, 6]).all()
+        assert (a[array([0, 2]), array([1])] == [2, 6]).all()
+
+    def test_int_array_index_setitem(self):
+        from _numpypy import array
+        a = array([[1, 2], [3, 4], [5, 6]])
+        a[slice(0, 3), [0, 0]] = [[0, 0], [0, 0], [0, 0]]
+        assert (a == [[0, 2], [0, 4], [0, 6]]).all()
+        a = array([[1, 2], [3, 4], [5, 6]])
+        a[array([0, 2]), slice(0, 2)] = [[10, 11], [12, 13]]
+        assert (a == [[10, 11], [3, 4], [12, 13]]).all()
 
 class AppTestSupport(BaseNumpyAppTest):
     def setup_class(cls):


More information about the pypy-commit mailing list