[pypy-commit] pypy default: move unnecessary additions to base iterator from nditer branches to subclass

bdkearns noreply at buildbot.pypy.org
Thu Dec 4 03:03:52 CET 2014


Author: Brian Kearns <bdkearns at gmail.com>
Branch: 
Changeset: r74803:73e3f0512a73
Date: 2014-12-03 21:02 -0500
http://bitbucket.org/pypy/pypy/changeset/73e3f0512a73/

Log:	move the additions that the nditer branches made to the base iterator
	(ArrayIter) into nditer-specific subclasses, since the base iterator
	does not need them

diff --git a/pypy/module/micronumpy/iterators.py b/pypy/module/micronumpy/iterators.py
--- a/pypy/module/micronumpy/iterators.py
+++ b/pypy/module/micronumpy/iterators.py
@@ -41,16 +41,6 @@
 from pypy.module.micronumpy.base import W_NDimArray
 from pypy.module.micronumpy.flagsobj import _update_contiguous_flags
 
-class OpFlag(object):
-    def __init__(self):
-        self.rw = ''
-        self.broadcast = True
-        self.force_contig = False
-        self.force_align = False
-        self.native_byte_order = False
-        self.tmp_copy = ''
-        self.allocate = False
-
 
 class PureShapeIter(object):
     def __init__(self, shape, idx_w):
@@ -99,14 +89,12 @@
 class ArrayIter(object):
     _immutable_fields_ = ['contiguous', 'array', 'size', 'ndim_m1', 'shape_m1[*]',
                           'strides[*]', 'backstrides[*]', 'factors[*]',
-                          'slice_shape', 'slice_stride', 'slice_backstride',
-                          'track_index', 'operand_type', 'slice_operand_type']
+                          'track_index']
 
     track_index = True
 
     @jit.unroll_safe
-    def __init__(self, array, size, shape, strides, backstrides, op_flags=OpFlag()):
-        from pypy.module.micronumpy import concrete
+    def __init__(self, array, size, shape, strides, backstrides):
         assert len(shape) == len(strides) == len(backstrides)
         _update_contiguous_flags(array)
         self.contiguous = (array.flags & NPY.ARRAY_C_CONTIGUOUS and
@@ -118,12 +106,6 @@
         self.shape_m1 = [s - 1 for s in shape]
         self.strides = strides
         self.backstrides = backstrides
-        self.slice_shape = 1
-        self.slice_stride = -1
-        if strides:
-            self.slice_stride = strides[-1]
-        self.slice_backstride = 1
-        self.slice_operand_type = concrete.SliceArray
 
         ndim = len(shape)
         factors = [0] * ndim
@@ -133,10 +115,6 @@
             else:
                 factors[ndim-i-1] = factors[ndim-i] * shape[ndim-i]
         self.factors = factors
-        if op_flags.rw == 'r':
-            self.operand_type = concrete.ConcreteNonWritableArrayWithBase
-        else:
-            self.operand_type = concrete.ConcreteArrayWithBase
 
     @jit.unroll_safe
     def reset(self, state=None):
@@ -220,12 +198,6 @@
         assert state.iterator is self
         self.array.setitem(state.offset, elem)
 
-    def getoperand(self, st, base):
-        impl = self.operand_type
-        res = impl([], self.array.dtype, self.array.order, [], [],
-                   self.array.storage, base)
-        res.start = st.offset
-        return res
 
 def AxisIter(array, shape, axis, cumulative):
     strides = array.get_strides()
@@ -249,42 +221,3 @@
         size /= shape[axis]
     shape[axis] = backstrides[axis] = 0
     return ArrayIter(array, size, shape, array.strides, backstrides)
-
-class SliceIter(ArrayIter):
-    '''
-    used with external loops, getitem and setitem return a SliceArray
-    view into the original array
-    '''
-    _immutable_fields_ = ['base', 'slice_shape[*]', 'slice_stride[*]', 'slice_backstride[*]']
-
-    def __init__(self, array, size, shape, strides, backstrides, slice_shape,
-                 slice_stride, slice_backstride, op_flags, base):
-        from pypy.module.micronumpy import concrete
-        ArrayIter.__init__(self, array, size, shape, strides, backstrides, op_flags)
-        self.slice_shape = slice_shape
-        self.slice_stride = slice_stride
-        self.slice_backstride = slice_backstride
-        self.base = base
-        if op_flags.rw == 'r':
-            self.slice_operand_type = concrete.NonWritableSliceArray
-        else:
-            self.slice_operand_type = concrete.SliceArray
-
-    def getitem(self, state):
-        # XXX cannot be called - must return a boxed value
-        assert False
-
-    def getitem_bool(self, state):
-        # XXX cannot be called - must return a boxed value
-        assert False
-
-    def setitem(self, state, elem):
-        # XXX cannot be called - must return a boxed value
-        assert False
-
-    def getoperand(self, state, base):
-        assert state.iterator is self
-        impl = self.slice_operand_type
-        arr = impl(state.offset, [self.slice_stride], [self.slice_backstride],
-                   [self.slice_shape], self.array, self.base)
-        return arr
diff --git a/pypy/module/micronumpy/nditer.py b/pypy/module/micronumpy/nditer.py
--- a/pypy/module/micronumpy/nditer.py
+++ b/pypy/module/micronumpy/nditer.py
@@ -6,7 +6,7 @@
 from pypy.module.micronumpy import ufuncs, support, concrete
 from pypy.module.micronumpy.base import W_NDimArray, convert_to_array
 from pypy.module.micronumpy.descriptor import decode_w_dtype
-from pypy.module.micronumpy.iterators import ArrayIter, SliceIter, OpFlag
+from pypy.module.micronumpy.iterators import ArrayIter
 from pypy.module.micronumpy.strides import (calculate_broadcast_strides,
                                             shape_agreement, shape_agreement_multiple)
 
@@ -36,6 +36,16 @@
     return ret
 
 
+class OpFlag(object):
+    def __init__(self):
+        self.rw = ''
+        self.broadcast = True
+        self.force_contig = False
+        self.force_align = False
+        self.native_byte_order = False
+        self.tmp_copy = ''
+        self.allocate = False
+
 def parse_op_flag(space, lst):
     op_flag = OpFlag()
     for w_item in lst:
@@ -142,11 +152,75 @@
         raise NotImplementedError('not implemented yet')
 
 
-def get_iter(space, order, arr, shape, dtype, op_flags):
+class OperandIter(ArrayIter):
+    _immutable_fields_ = ['slice_shape', 'slice_stride', 'slice_backstride',
+                          'operand_type', 'base']
+
+    def getitem(self, state):
+        # XXX cannot be called - must return a boxed value
+        assert False
+
+    def getitem_bool(self, state):
+        # XXX cannot be called - must return a boxed value
+        assert False
+
+    def setitem(self, state, elem):
+        # XXX cannot be called - must return a boxed value
+        assert False
+
+
+class ConcreteIter(OperandIter):
+    def __init__(self, array, size, shape, strides, backstrides,
+                 op_flags, base):
+        OperandIter.__init__(self, array, size, shape, strides, backstrides)
+        self.slice_shape = 1
+        self.slice_stride = -1
+        if strides:
+            self.slice_stride = strides[-1]
+        self.slice_backstride = 1
+        if op_flags.rw == 'r':
+            self.operand_type = concrete.ConcreteNonWritableArrayWithBase
+        else:
+            self.operand_type = concrete.ConcreteArrayWithBase
+        self.base = base
+
+    def getoperand(self, state):
+        assert state.iterator is self
+        impl = self.operand_type
+        #assert issubclass(impl, concrete.ConcreteArrayWithBase)
+        res = impl([], self.array.dtype, self.array.order, [], [],
+                   self.array.storage, self.base)
+        res.start = state.offset
+        return res
+
+
+class SliceIter(OperandIter):
+    def __init__(self, array, size, shape, strides, backstrides, slice_shape,
+                 slice_stride, slice_backstride, op_flags, base):
+        OperandIter.__init__(self, array, size, shape, strides, backstrides)
+        self.slice_shape = slice_shape
+        self.slice_stride = slice_stride
+        self.slice_backstride = slice_backstride
+        if op_flags.rw == 'r':
+            self.operand_type = concrete.NonWritableSliceArray
+        else:
+            self.operand_type = concrete.SliceArray
+        self.base = base
+
+    def getoperand(self, state):
+        assert state.iterator is self
+        impl = self.operand_type
+        #assert issubclass(impl, concrete.SliceArray)
+        arr = impl(state.offset, [self.slice_stride], [self.slice_backstride],
+                   [self.slice_shape], self.array, self.base)
+        return arr
+
+
+def get_iter(space, order, arr, shape, dtype, op_flags, base):
     imp = arr.implementation
     backward = is_backward(imp, order)
     if arr.is_scalar():
-        return ArrayIter(imp, 1, [], [], [], op_flags=op_flags)
+        return ConcreteIter(imp, 1, [], [], [], op_flags, base)
     if (imp.strides[0] < imp.strides[-1] and not backward) or \
        (imp.strides[0] > imp.strides[-1] and backward):
         # flip the strides. Is this always true for multidimension?
@@ -161,7 +235,7 @@
         backstrides = imp.backstrides
     r = calculate_broadcast_strides(strides, backstrides, imp.shape,
                                     shape, backward)
-    return ArrayIter(imp, imp.get_size(), shape, r[0], r[1], op_flags=op_flags)
+    return ConcreteIter(imp, imp.get_size(), shape, r[0], r[1], op_flags, base)
 
 def calculate_ndim(op_in, oa_ndim):
     if oa_ndim >=0:
@@ -398,7 +472,7 @@
         self.iters = []
         for i in range(len(self.seq)):
             it = get_iter(space, self.order, self.seq[i], self.shape,
-                          self.dtypes[i], self.op_flags[i])
+                          self.dtypes[i], self.op_flags[i], self)
             it.contiguous = False
             self.iters.append((it, it.reset()))
 
@@ -437,7 +511,7 @@
         return space.wrap(self)
 
     def getitem(self, it, st):
-        res = it.getoperand(st, self)
+        res = it.getoperand(st)
         return W_NDimArray(res)
 
     def descr_getitem(self, space, w_idx):


More information about the pypy-commit mailing list