[pypy-commit] pypy numpypy-nditer: Refactor the way nditer iterates

rguillebert noreply at buildbot.pypy.org
Thu Jun 6 17:22:33 CEST 2013


Author: Romain Guillebert <romain.py at gmail.com>
Branch: numpypy-nditer
Changeset: r64813:8c3a4fc396d3
Date: 2013-06-06 17:21 +0200
http://bitbucket.org/pypy/pypy/changeset/8c3a4fc396d3/

Log:	Refactor the way nditer iterates

diff --git a/pypy/module/micronumpy/interp_nditer.py b/pypy/module/micronumpy/interp_nditer.py
--- a/pypy/module/micronumpy/interp_nditer.py
+++ b/pypy/module/micronumpy/interp_nditer.py
@@ -5,10 +5,41 @@
 from pypy.module.micronumpy.base import W_NDimArray, convert_to_array
 from pypy.module.micronumpy.strides import (calculate_broadcast_strides,
                                              shape_agreement_multiple)
-from pypy.module.micronumpy.iter import MultiDimViewIterator
+from pypy.module.micronumpy.iter import MultiDimViewIterator, SliceIterator
 from pypy.module.micronumpy import support
 from pypy.module.micronumpy.arrayimpl.concrete import SliceArray
 
+class AbstractIterator(object):
+    def done(self):
+        raise NotImplementedError("Abstract Class")
+
+    def next(self):
+        raise NotImplementedError("Abstract Class")
+
+    def getitem(self, array):
+        raise NotImplementedError("Abstract Class")
+
+class IteratorMixin(object):
+    _mixin_ = True
+    def __init__(self, it, op_flags):
+        self.it = it
+        self.op_flags = op_flags
+
+    def done(self):
+        return self.it.done()
+
+    def next(self):
+        self.it.next()
+
+    def getitem(self, space, array):
+        return self.op_flags.get_it_item(space, array, self.it)
+
+class BoxIterator(IteratorMixin):
+    pass
+
+class SliceIterator(IteratorMixin):
+    pass
+
 def parse_op_arg(space, name, w_op_flags, n, parse_one_arg):
     ret = []
     if space.is_w(w_op_flags, space.w_None):
@@ -53,6 +84,13 @@
     #it.dtype.setitem(res, 0, it.getitem())
     return W_NDimArray(res)
 
+def get_readonly_slice(space, array, it):
+    #XXX Not readonly
+    return W_NDimArray(it.getslice())
+
+def get_readwrite_slice(space, array, it):
+    return W_NDimArray(it.getslice())
+
 def parse_op_flag(space, lst):
     op_flag = OpFlag()
     for w_item in lst:
@@ -191,11 +229,11 @@
         self.iters=[]
         self.shape = iter_shape = shape_agreement_multiple(space, self.seq)
         if self.external_loop:
-            xxx find longest contiguous shape
+            #XXX find longest contiguous shape
             iter_shape = iter_shape[1:]
         for i in range(len(self.seq)):
-            self.iters.append(get_iter(space, self.order,
-                            self.seq[i].implementation, iter_shape))
+            self.iters.append(BoxIterator(get_iter(space, self.order,
+                            self.seq[i].implementation, iter_shape), self.op_flags[i]))
 
     def descr_iter(self, space):
         return space.wrap(self)
@@ -220,8 +258,7 @@
             raise OperationError(space.w_StopIteration, space.w_None)
         res = []
         for i in range(len(self.iters)):
-            res.append(self.op_flags[i].get_it_item(space, self.seq[i],
-                                                    self.iters[i]))
+            res.append(self.iters[i].getitem(space, self.seq[i]))
             self.iters[i].next()
         if len(res) <2:
             return res[0]
diff --git a/pypy/module/micronumpy/iter.py b/pypy/module/micronumpy/iter.py
--- a/pypy/module/micronumpy/iter.py
+++ b/pypy/module/micronumpy/iter.py
@@ -32,13 +32,13 @@
 shape dimension
   which is back 25 and forward 1,
   which is x.strides[1] * (x.shape[1] - 1) + x.strides[0]
-so if we precalculate the overflow backstride as 
+so if we precalculate the overflow backstride as
 [x.strides[i] * (x.shape[i] - 1) for i in range(len(x.shape))]
 we can go faster.
 All the calculations happen in next()
 
 next_skip_x() tries to do the iteration for a number of steps at once,
-but then we cannot gaurentee that we only overflow one single shape 
+but then we cannot guarantee that we only overflow one single shape
 dimension, perhaps we could overflow times in one big step.
 """
 
@@ -266,6 +266,30 @@
     def reset(self):
         self.offset %= self.size
 
+class SliceIterator(object):
+    def __init__(self, arr, stride, backstride, shape, dtype=None):
+        self.step = 0
+        self.arr = arr
+        self.stride = stride
+        self.backstride = backstride
+        self.shape = shape
+        if dtype is None:
+            dtype = arr.implementation.dtype
+        self.dtype = dtype
+        self._done = False
+
+    def done():
+        return self._done
+
+    def next():
+        self.step += self.arr.implementation.dtype.get_size()
+        if self.step == self.backstride - self.implementation.dtype.get_size():
+            self._done = True
+
+    def getslice(self):
+        from pypy.module.micronumpy.arrayimpl.concrete import SliceArray
+        return SliceArray(self.step, [self.stride], [self.backstride], self.shape, self.arr.implementation, self.arr, self.dtype)
+
 class AxisIterator(base.BaseArrayIterator):
     def __init__(self, array, shape, dim, cumultative):
         self.shape = shape
@@ -288,7 +312,7 @@
         self.dim = dim
         self.array = array
         self.dtype = array.dtype
-        
+
     def setitem(self, elem):
         self.dtype.setitem(self.array, self.offset, elem)
 


More information about the pypy-commit mailing list