[pypy-commit] pypy indexing-by-array: add broadcast_backwards to allow indexing by a lower-shaped array

mattip noreply at buildbot.pypy.org
Mon Mar 18 21:21:33 CET 2013


Author: Matti Picus <matti.picus at gmail.com>
Branch: indexing-by-array
Changeset: r62404:7f18238a3962
Date: 2013-03-18 12:47 -0700
http://bitbucket.org/pypy/pypy/changeset/7f18238a3962/

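Log:	add a backward_broadcast option to create_iter (backwards in calculate_broadcast_strides) to allow indexing by an array with fewer dimensions; the index is broadcast across the trailing axes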

diff --git a/pypy/module/micronumpy/arrayimpl/concrete.py b/pypy/module/micronumpy/arrayimpl/concrete.py
--- a/pypy/module/micronumpy/arrayimpl/concrete.py
+++ b/pypy/module/micronumpy/arrayimpl/concrete.py
@@ -294,12 +294,12 @@
         self.backstrides = backstrides
         self.storage = storage
 
-    def create_iter(self, shape=None):
+    def create_iter(self, shape=None, backward_broadcast=False):
         if shape is None or shape == self.get_shape():
             return iter.ConcreteArrayIterator(self)
         r = calculate_broadcast_strides(self.get_strides(),
                                         self.get_backstrides(),
-                                        self.get_shape(), shape)
+                                        self.get_shape(), shape, backward_broadcast)
         return iter.MultiDimViewIterator(self, self.dtype, 0, r[0], r[1], shape)
 
     def fill(self, box):
@@ -362,11 +362,12 @@
     def fill(self, box):
         loop.fill(self, box.convert_to(self.dtype))
 
-    def create_iter(self, shape=None):
+    def create_iter(self, shape=None, backward_broadcast=False):
         if shape is not None and shape != self.get_shape():
             r = calculate_broadcast_strides(self.get_strides(),
                                             self.get_backstrides(),
-                                            self.get_shape(), shape)
+                                            self.get_shape(), shape,
+                                            backward_broadcast)
             return iter.MultiDimViewIterator(self.parent, self.dtype,
                                              self.start, r[0], r[1], shape)
         if len(self.get_shape()) == 1:
diff --git a/pypy/module/micronumpy/interp_numarray.py b/pypy/module/micronumpy/interp_numarray.py
--- a/pypy/module/micronumpy/interp_numarray.py
+++ b/pypy/module/micronumpy/interp_numarray.py
@@ -224,9 +224,10 @@
         s.append('])')
         return s.build()
 
-    def create_iter(self, shape=None):
+    def create_iter(self, shape=None, backward_broadcast=False):
         assert isinstance(self.implementation, BaseArrayImplementation)
-        return self.implementation.create_iter(shape)
+        return self.implementation.create_iter(shape,
+                                   backward_broadcast=backward_broadcast)
 
     def create_axis_iter(self, shape, dim, cum):
         return self.implementation.create_axis_iter(shape, dim, cum)
diff --git a/pypy/module/micronumpy/loop.py b/pypy/module/micronumpy/loop.py
--- a/pypy/module/micronumpy/loop.py
+++ b/pypy/module/micronumpy/loop.py
@@ -300,14 +300,13 @@
 
 def getitem_filter(res, arr, index):
     res_iter = res.create_iter()
-    index_iter = index.create_iter(arr.get_shape())
+    index_iter = index.create_iter(arr.get_shape(), backward_broadcast=True)
     arr_iter = arr.create_iter()
     shapelen = len(arr.get_shape())
     arr_dtype = arr.get_dtype()
     index_dtype = index.get_dtype()
     # XXX length of shape of index as well?
     while not index_iter.done():
-        print 'res,arr,index', res_iter.offset, arr_iter.offset, index_iter.offset, index_iter.getitem_bool()
         getitem_filter_driver.jit_merge_point(shapelen=shapelen,
                                               index_dtype=index_dtype,
                                               arr_dtype=arr_dtype,
diff --git a/pypy/module/micronumpy/strides.py b/pypy/module/micronumpy/strides.py
--- a/pypy/module/micronumpy/strides.py
+++ b/pypy/module/micronumpy/strides.py
@@ -40,7 +40,7 @@
     rshape += shape[s:]
     return rshape, rstart, rstrides, rbackstrides
 
-def calculate_broadcast_strides(strides, backstrides, orig_shape, res_shape):
+def calculate_broadcast_strides(strides, backstrides, orig_shape, res_shape, backwards=False):
     rstrides = []
     rbackstrides = []
     for i in range(len(orig_shape)):
@@ -50,8 +50,12 @@
         else:
             rstrides.append(strides[i])
             rbackstrides.append(backstrides[i])
-    rstrides = [0] * (len(res_shape) - len(orig_shape)) + rstrides
-    rbackstrides = [0] * (len(res_shape) - len(orig_shape)) + rbackstrides
+    if backwards:
+        rstrides = rstrides + [0] * (len(res_shape) - len(orig_shape))  
+        rbackstrides = rbackstrides + [0] * (len(res_shape) - len(orig_shape)) 
+    else:
+        rstrides = [0] * (len(res_shape) - len(orig_shape)) + rstrides
+        rbackstrides = [0] * (len(res_shape) - len(orig_shape)) + rbackstrides
     return rstrides, rbackstrides
 
 def is_single_elem(space, w_elem, is_rec_type):
diff --git a/pypy/module/micronumpy/test/test_numarray.py b/pypy/module/micronumpy/test/test_numarray.py
--- a/pypy/module/micronumpy/test/test_numarray.py
+++ b/pypy/module/micronumpy/test/test_numarray.py
@@ -1612,9 +1612,7 @@
         raises(ValueError, "b[array([[True, False], [True, False]])]")
         a = array([[1,2,3],[4,5,6],[7,8,9]],int)
         c = array([True,False,True],bool)
-        print 'xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx'
         b = a[c]
-        print 'yyyyyyyyyyyyyyyyyyyyyyyyyyyyyyy'
         assert (a[c] == [[1, 2, 3], [7, 8, 9]]).all()
 
     def test_bool_array_index_setitem(self):


More information about the pypy-commit mailing list