[pypy-commit] pypy numppy-flatitter: add setitem, getitem to flatiter

mattip noreply at buildbot.pypy.org
Thu Jan 19 01:05:05 CET 2012


Author: mattip
Branch: numppy-flatitter
Changeset: r51471:60b724406de5
Date: 2012-01-19 01:46 +0200
http://bitbucket.org/pypy/pypy/changeset/60b724406de5/

Log:	add setitem, getitem to flatiter

diff --git a/pypy/module/micronumpy/interp_numarray.py b/pypy/module/micronumpy/interp_numarray.py
--- a/pypy/module/micronumpy/interp_numarray.py
+++ b/pypy/module/micronumpy/interp_numarray.py
@@ -1299,13 +1299,18 @@
         size = 1
         for sh in arr.shape:
             size *= sh
-        self.strides = [arr.strides[-1]]
-        self.backstrides = [arr.backstrides[-1]]
-        ViewArray.__init__(self, size, [size], arr.dtype, arr.order,
-                               arr)
+        if arr.strides[-1] < arr.strides[0]:
+            self.strides = [arr.strides[-1]]
+            self.backstrides = [arr.backstrides[-1]]
+        else:
+            self.strides = [arr.strides[0]]
+            self.backstrides = [arr.backstrides[0]]
+        ViewArray.__init__(self, size, [size], arr.dtype, order=arr.order,
+                               parent=arr)
         self.shapelen = len(arr.shape)
         self.iter = OneDimIterator(arr.start, self.strides[0],
                                    self.shape[0])
+        self.base = arr
 
     def descr_next(self, space):
         if self.iter.done():
@@ -1317,9 +1322,44 @@
     def descr_iter(self):
         return self
 
+    def _flat_offset(self, space, w_idx):
+        """Translate an app-level flat index into a storage offset.
+
+        Only integer indices are supported for now (ValueError otherwise).
+        Negative indices count back from the end, as for Python lists;
+        anything outside [-size, size) raises IndexError.
+        """
+        if not space.isinstance_w(w_idx, space.w_int):
+            raise OperationError(space.w_ValueError, space.wrap(
+                        "non-integer indexing not supported yet"))
+        _i = space.int_w(w_idx)
+        i = _i
+        if i < 0:
+            i += self.size
+        if i >= self.size or i < 0:
+            raise operationerrfmt(space.w_IndexError,
+                            "index (%d) out of range (%d<=index<%d)",
+                                _i, -self.size, self.size)
+        # Flat elements are one stride (self.strides[0], chosen in __init__)
+        # apart, starting at the base array's start offset.
+        return self.base.start + i * self.strides[0]
+
+    def descr_getitem(self, space, w_idx):
+        """Implement flat[i] for an integer i."""
+        return self.getitem(self._flat_offset(space, w_idx))
+
+    def descr_setitem(self, space, w_idx, w_value):
+        """Implement flat[i] = value for an integer i."""
+        offset = self._flat_offset(space, w_idx)
+        # NOTE(review): presumably storage-level setitem expects a dtype box,
+        # not an app-level object (array __setitem__ coerces first) - confirm.
+        self.setitem(offset, self.dtype.coerce(space, w_value))
+
 W_FlatIterator.typedef = TypeDef(
     'flatiter',
     next = interp2app(W_FlatIterator.descr_next),
     __iter__ = interp2app(W_FlatIterator.descr_iter),
+    __getitem__ = interp2app(W_FlatIterator.descr_getitem),
+    __setitem__ = interp2app(W_FlatIterator.descr_setitem),
 )
 W_FlatIterator.acceptable_as_base_class = False
diff --git a/pypy/module/micronumpy/test/test_numarray.py b/pypy/module/micronumpy/test/test_numarray.py
--- a/pypy/module/micronumpy/test/test_numarray.py
+++ b/pypy/module/micronumpy/test/test_numarray.py
@@ -276,6 +276,12 @@
         for i in xrange(5):
             assert a[i] == b[i]
 
+    def test_getitem_nd(self):
+        from _numpypy import arange
+        a = arange(15).reshape(3, 5)
+        assert a[1, 3] == 8  # row-major: 1 * 5 + 3
+        assert a.T[1, 2] == 11  # a.T[1, 2] is a[2, 1] == 2 * 5 + 1
+
     def test_setitem(self):
         from _numpypy import array
         a = array(range(5))
@@ -1286,6 +1292,35 @@
         a = ones((2, 2))
         assert list(((a + a).flat)) == [2, 2, 2, 2]
 
+    def test_flatiter_getitem(self):
+        from _numpypy import arange
+        a = arange(10)
+        assert a.flat[3] == 3
+        assert a[2:].flat[3] == 5
+        assert (a + a).flat[3] == 6
+        assert a[::2].flat[3] == 6
+        assert a.reshape(2, 5).flat[3] == 3
+        b = a.flat
+        b.next()
+        b.next()
+        b.next()
+        # indexing ignores how far the iterator has already advanced
+        assert b[3] == 3
+        assert b[-2] == 8
+        assert b[-10] == 0
+        assert b[9] == 9
+        # the valid range is -size <= index < size; check both edges
+        raises(IndexError, "b[10]")
+        raises(IndexError, "b[-11]")
+
+    def test_flatiter_transpose(self):
+        from _numpypy import arange
+        skip('out-of-order transformations do not work yet')
+        a = arange(10)
+        # the transpose is 5x2 with rows [0, 5], [1, 6], ...; flat order is
+        # 0, 5, 1, 6, ... so element 3 is 6
+        assert a.reshape(2, 5).T.flat[3] == 6
+
     def test_slice_copy(self):
         from _numpypy import zeros
         a = zeros((10, 10))


More information about the pypy-commit mailing list