[pypy-commit] pypy numpy_broadcast_nd: W_Broadcast (micronumpy) is rewritten to use W_FlatIterator for the implementation of the iters attribute. The W_FlatIterator constructor gains optional arguments.

Sergey Matyunin pypy.commits at gmail.com
Sat Apr 30 16:13:26 EDT 2016


Author: Sergey Matyunin <sbmatyunin at gmail.com>
Branch: numpy_broadcast_nd
Changeset: r84065:c0d40603d40b
Date: 2016-04-24 13:36 +0200
http://bitbucket.org/pypy/pypy/changeset/c0d40603d40b/

Log:	W_Broadcast (micronumpy) is rewritten to use W_FlatIterator for the
	implementation of the iters attribute. The W_FlatIterator constructor
	gains the optional arguments shape and backward_broadcast.

diff --git a/pypy/module/micronumpy/broadcast.py b/pypy/module/micronumpy/broadcast.py
--- a/pypy/module/micronumpy/broadcast.py
+++ b/pypy/module/micronumpy/broadcast.py
@@ -1,12 +1,12 @@
 import pypy.module.micronumpy.constants as NPY
-from nditer import ConcreteIter, parse_op_flag, parse_op_arg
 from pypy.interpreter.error import OperationError, oefmt
 from pypy.interpreter.gateway import interp2app
 from pypy.interpreter.typedef import TypeDef, GetSetProperty
 from pypy.module.micronumpy import support
-from pypy.module.micronumpy.base import W_NDimArray, convert_to_array, W_NumpyObject
+from pypy.module.micronumpy.base import convert_to_array, W_NumpyObject
+from pypy.module.micronumpy.flatiter import W_FlatIterator
 from rpython.rlib import jit
-from strides import calculate_broadcast_strides, shape_agreement_multiple
+from strides import shape_agreement_multiple
 
 def descr_new_broadcast(space, w_subtype, __args__):
     return W_Broadcast(space, __args__.arguments_w)
@@ -26,45 +26,21 @@
         self.seq = [convert_to_array(space, w_elem)
                     for w_elem in args]
 
-        self.op_flags = parse_op_arg(space, 'op_flags', space.w_None,
-                                     len(self.seq), parse_op_flag)
-
         self.shape = shape_agreement_multiple(space, self.seq, shape=None)
         self.order = NPY.CORDER
 
-        self.iters = []
+        self.list_iter_state = []
         self.index = 0
 
         try:
             self.size = support.product_check(self.shape)
         except OverflowError as e:
             raise oefmt(space.w_ValueError, "broadcast dimensions too large.")
-        for i in range(len(self.seq)):
-            it = self.get_iter(space, i)
-            it.contiguous = False
-            self.iters.append((it, it.reset()))
+
+        self.list_iter_state = [W_FlatIterator(arr, self.shape, arr.get_order() != self.order)
+                                for arr in self.seq]
 
         self.done = False
-        pass
-
-    def get_iter(self, space, i):
-        arr = self.seq[i]
-        imp = arr.implementation
-        if arr.is_scalar():
-            return ConcreteIter(imp, 1, [], [], [], self.op_flags[i], self)
-        shape = self.shape
-
-        backward = imp.order != self.order
-
-        r = calculate_broadcast_strides(imp.strides, imp.backstrides, imp.shape,
-                                        shape, backward)
-
-        iter_shape = shape
-        if len(shape) != len(r[0]):
-            # shape can be shorter when using an external loop, just return a view
-            iter_shape = imp.shape
-        return ConcreteIter(imp, imp.get_size(), iter_shape, r[0], r[1],
-                            self.op_flags[i], self)
 
     def descr_iter(self, space):
         return space.wrap(self)
@@ -79,28 +55,26 @@
         return space.wrap(self.index)
 
     def descr_get_numiter(self, space):
-        return space.wrap(len(self.iters))
+        return space.wrap(len(self.list_iter_state))
 
     def descr_get_number_of_dimensions(self, space):
         return space.wrap(len(self.shape))
 
+    def descr_get_iters(self, space):
+        return space.newtuple(self.list_iter_state)
+
     @jit.unroll_safe
     def descr_next(self, space):
         if self.index >= self.size:
             self.done = True
             raise OperationError(space.w_StopIteration, space.w_None)
         self.index += 1
-        res = []
-        for i, (it, st) in enumerate(self.iters):
-            res.append(self._get_item(it, st))
-            self.iters[i] = (it, it.next(st))
+        res = [it.descr_next(space) for it in self.list_iter_state]
+
         if len(res) < 2:
             return res[0]
         return space.newtuple(res)
 
-    def _get_item(self, it, st):
-        return W_NDimArray(it.getoperand(st))
-
 
 W_Broadcast.typedef = TypeDef("numpy.broadcast",
                               __new__=interp2app(descr_new_broadcast),
@@ -111,4 +85,5 @@
                               index=GetSetProperty(W_Broadcast.descr_get_index),
                               numiter=GetSetProperty(W_Broadcast.descr_get_numiter),
                               nd=GetSetProperty(W_Broadcast.descr_get_number_of_dimensions),
+                              iters=GetSetProperty(W_Broadcast.descr_get_iters),
                               )
diff --git a/pypy/module/micronumpy/flatiter.py b/pypy/module/micronumpy/flatiter.py
--- a/pypy/module/micronumpy/flatiter.py
+++ b/pypy/module/micronumpy/flatiter.py
@@ -33,9 +33,9 @@
 
 
 class W_FlatIterator(W_NDimArray):
-    def __init__(self, arr):
+    def __init__(self, arr, shape=None, backward_broadcast=False):
         self.base = arr
-        self.iter, self.state = arr.create_iter()
+        self.iter, self.state = arr.create_iter(shape=shape, backward_broadcast=backward_broadcast)
         # this is needed to support W_NDimArray interface
         self.implementation = FakeArrayImplementation(self.base)
 
diff --git a/pypy/module/micronumpy/test/test_broadcast.py b/pypy/module/micronumpy/test/test_broadcast.py
--- a/pypy/module/micronumpy/test/test_broadcast.py
+++ b/pypy/module/micronumpy/test/test_broadcast.py
@@ -102,3 +102,24 @@
 
         assert hasattr(b, 'nd')
         assert b.nd == 3
+
+    def test_broadcast_iters(self):
+        import numpy as np
+        x = np.array([[[1, 2]]])
+        y = np.array([[3], [4], [5]])
+
+        b = np.broadcast(x, y)
+        iters = b.iters
+
+        # iters has right shape
+        assert len(iters) == 2
+        assert isinstance(iters, tuple)
+
+        step_in_y = iters[1].next()
+        step_in_broadcast = b.next()
+        step2_in_y = iters[1].next()
+
+        # iters should not interfere with iteration in broadcast
+        assert step_in_y == y[0, 0]  # == 3
+        assert step_in_broadcast == (1, 3)
+        assert step2_in_y == y[1, 0]  # == 4


More information about the pypy-commit mailing list