[pypy-commit] pypy default: move unnecessary additions to base iterator from nditer branches to subclass
bdkearns
noreply at buildbot.pypy.org
Thu Dec 4 03:03:52 CET 2014
Author: Brian Kearns <bdkearns at gmail.com>
Branch:
Changeset: r74803:73e3f0512a73
Date: 2014-12-03 21:02 -0500
http://bitbucket.org/pypy/pypy/changeset/73e3f0512a73/
Log: move unnecessary additions to base iterator from nditer branches to subclass
diff --git a/pypy/module/micronumpy/iterators.py b/pypy/module/micronumpy/iterators.py
--- a/pypy/module/micronumpy/iterators.py
+++ b/pypy/module/micronumpy/iterators.py
@@ -41,16 +41,6 @@
from pypy.module.micronumpy.base import W_NDimArray
from pypy.module.micronumpy.flagsobj import _update_contiguous_flags
-class OpFlag(object):
- def __init__(self):
- self.rw = ''
- self.broadcast = True
- self.force_contig = False
- self.force_align = False
- self.native_byte_order = False
- self.tmp_copy = ''
- self.allocate = False
-
class PureShapeIter(object):
def __init__(self, shape, idx_w):
@@ -99,14 +89,12 @@
class ArrayIter(object):
_immutable_fields_ = ['contiguous', 'array', 'size', 'ndim_m1', 'shape_m1[*]',
'strides[*]', 'backstrides[*]', 'factors[*]',
- 'slice_shape', 'slice_stride', 'slice_backstride',
- 'track_index', 'operand_type', 'slice_operand_type']
+ 'track_index']
track_index = True
@jit.unroll_safe
- def __init__(self, array, size, shape, strides, backstrides, op_flags=OpFlag()):
- from pypy.module.micronumpy import concrete
+ def __init__(self, array, size, shape, strides, backstrides):
assert len(shape) == len(strides) == len(backstrides)
_update_contiguous_flags(array)
self.contiguous = (array.flags & NPY.ARRAY_C_CONTIGUOUS and
@@ -118,12 +106,6 @@
self.shape_m1 = [s - 1 for s in shape]
self.strides = strides
self.backstrides = backstrides
- self.slice_shape = 1
- self.slice_stride = -1
- if strides:
- self.slice_stride = strides[-1]
- self.slice_backstride = 1
- self.slice_operand_type = concrete.SliceArray
ndim = len(shape)
factors = [0] * ndim
@@ -133,10 +115,6 @@
else:
factors[ndim-i-1] = factors[ndim-i] * shape[ndim-i]
self.factors = factors
- if op_flags.rw == 'r':
- self.operand_type = concrete.ConcreteNonWritableArrayWithBase
- else:
- self.operand_type = concrete.ConcreteArrayWithBase
@jit.unroll_safe
def reset(self, state=None):
@@ -220,12 +198,6 @@
assert state.iterator is self
self.array.setitem(state.offset, elem)
- def getoperand(self, st, base):
- impl = self.operand_type
- res = impl([], self.array.dtype, self.array.order, [], [],
- self.array.storage, base)
- res.start = st.offset
- return res
def AxisIter(array, shape, axis, cumulative):
strides = array.get_strides()
@@ -249,42 +221,3 @@
size /= shape[axis]
shape[axis] = backstrides[axis] = 0
return ArrayIter(array, size, shape, array.strides, backstrides)
-
-class SliceIter(ArrayIter):
- '''
- used with external loops, getitem and setitem return a SliceArray
- view into the original array
- '''
- _immutable_fields_ = ['base', 'slice_shape[*]', 'slice_stride[*]', 'slice_backstride[*]']
-
- def __init__(self, array, size, shape, strides, backstrides, slice_shape,
- slice_stride, slice_backstride, op_flags, base):
- from pypy.module.micronumpy import concrete
- ArrayIter.__init__(self, array, size, shape, strides, backstrides, op_flags)
- self.slice_shape = slice_shape
- self.slice_stride = slice_stride
- self.slice_backstride = slice_backstride
- self.base = base
- if op_flags.rw == 'r':
- self.slice_operand_type = concrete.NonWritableSliceArray
- else:
- self.slice_operand_type = concrete.SliceArray
-
- def getitem(self, state):
- # XXX cannot be called - must return a boxed value
- assert False
-
- def getitem_bool(self, state):
- # XXX cannot be called - must return a boxed value
- assert False
-
- def setitem(self, state, elem):
- # XXX cannot be called - must return a boxed value
- assert False
-
- def getoperand(self, state, base):
- assert state.iterator is self
- impl = self.slice_operand_type
- arr = impl(state.offset, [self.slice_stride], [self.slice_backstride],
- [self.slice_shape], self.array, self.base)
- return arr
diff --git a/pypy/module/micronumpy/nditer.py b/pypy/module/micronumpy/nditer.py
--- a/pypy/module/micronumpy/nditer.py
+++ b/pypy/module/micronumpy/nditer.py
@@ -6,7 +6,7 @@
from pypy.module.micronumpy import ufuncs, support, concrete
from pypy.module.micronumpy.base import W_NDimArray, convert_to_array
from pypy.module.micronumpy.descriptor import decode_w_dtype
-from pypy.module.micronumpy.iterators import ArrayIter, SliceIter, OpFlag
+from pypy.module.micronumpy.iterators import ArrayIter
from pypy.module.micronumpy.strides import (calculate_broadcast_strides,
shape_agreement, shape_agreement_multiple)
@@ -36,6 +36,16 @@
return ret
+class OpFlag(object):
+ def __init__(self):
+ self.rw = ''
+ self.broadcast = True
+ self.force_contig = False
+ self.force_align = False
+ self.native_byte_order = False
+ self.tmp_copy = ''
+ self.allocate = False
+
def parse_op_flag(space, lst):
op_flag = OpFlag()
for w_item in lst:
@@ -142,11 +152,75 @@
raise NotImplementedError('not implemented yet')
-def get_iter(space, order, arr, shape, dtype, op_flags):
+class OperandIter(ArrayIter):
+ _immutable_fields_ = ['slice_shape', 'slice_stride', 'slice_backstride',
+ 'operand_type', 'base']
+
+ def getitem(self, state):
+ # XXX cannot be called - must return a boxed value
+ assert False
+
+ def getitem_bool(self, state):
+ # XXX cannot be called - must return a boxed value
+ assert False
+
+ def setitem(self, state, elem):
+ # XXX cannot be called - must return a boxed value
+ assert False
+
+
+class ConcreteIter(OperandIter):
+ def __init__(self, array, size, shape, strides, backstrides,
+ op_flags, base):
+ OperandIter.__init__(self, array, size, shape, strides, backstrides)
+ self.slice_shape = 1
+ self.slice_stride = -1
+ if strides:
+ self.slice_stride = strides[-1]
+ self.slice_backstride = 1
+ if op_flags.rw == 'r':
+ self.operand_type = concrete.ConcreteNonWritableArrayWithBase
+ else:
+ self.operand_type = concrete.ConcreteArrayWithBase
+ self.base = base
+
+ def getoperand(self, state):
+ assert state.iterator is self
+ impl = self.operand_type
+ #assert issubclass(impl, concrete.ConcreteArrayWithBase)
+ res = impl([], self.array.dtype, self.array.order, [], [],
+ self.array.storage, self.base)
+ res.start = state.offset
+ return res
+
+
+class SliceIter(OperandIter):
+ def __init__(self, array, size, shape, strides, backstrides, slice_shape,
+ slice_stride, slice_backstride, op_flags, base):
+ OperandIter.__init__(self, array, size, shape, strides, backstrides)
+ self.slice_shape = slice_shape
+ self.slice_stride = slice_stride
+ self.slice_backstride = slice_backstride
+ if op_flags.rw == 'r':
+ self.operand_type = concrete.NonWritableSliceArray
+ else:
+ self.operand_type = concrete.SliceArray
+ self.base = base
+
+ def getoperand(self, state):
+ assert state.iterator is self
+ impl = self.operand_type
+ #assert issubclass(impl, concrete.SliceArray)
+ arr = impl(state.offset, [self.slice_stride], [self.slice_backstride],
+ [self.slice_shape], self.array, self.base)
+ return arr
+
+
+def get_iter(space, order, arr, shape, dtype, op_flags, base):
imp = arr.implementation
backward = is_backward(imp, order)
if arr.is_scalar():
- return ArrayIter(imp, 1, [], [], [], op_flags=op_flags)
+ return ConcreteIter(imp, 1, [], [], [], op_flags, base)
if (imp.strides[0] < imp.strides[-1] and not backward) or \
(imp.strides[0] > imp.strides[-1] and backward):
# flip the strides. Is this always true for multidimension?
@@ -161,7 +235,7 @@
backstrides = imp.backstrides
r = calculate_broadcast_strides(strides, backstrides, imp.shape,
shape, backward)
- return ArrayIter(imp, imp.get_size(), shape, r[0], r[1], op_flags=op_flags)
+ return ConcreteIter(imp, imp.get_size(), shape, r[0], r[1], op_flags, base)
def calculate_ndim(op_in, oa_ndim):
if oa_ndim >=0:
@@ -398,7 +472,7 @@
self.iters = []
for i in range(len(self.seq)):
it = get_iter(space, self.order, self.seq[i], self.shape,
- self.dtypes[i], self.op_flags[i])
+ self.dtypes[i], self.op_flags[i], self)
it.contiguous = False
self.iters.append((it, it.reset()))
@@ -437,7 +511,7 @@
return space.wrap(self)
def getitem(self, it, st):
- res = it.getoperand(st, self)
+ res = it.getoperand(st)
return W_NDimArray(res)
def descr_getitem(self, space, w_idx):
More information about the pypy-commit mailing list