[pypy-commit] pypy indexing-by-array: add broadcast_backwards to allow indexing by a lower-shaped array
mattip
noreply at buildbot.pypy.org
Mon Mar 18 21:21:33 CET 2013
Author: Matti Picus <matti.picus at gmail.com>
Branch: indexing-by-array
Changeset: r62404:7f18238a3962
Date: 2013-03-18 12:47 -0700
http://bitbucket.org/pypy/pypy/changeset/7f18238a3962/
Log: add broadcast_backwards to allow indexing by a lower-shaped array
diff --git a/pypy/module/micronumpy/arrayimpl/concrete.py b/pypy/module/micronumpy/arrayimpl/concrete.py
--- a/pypy/module/micronumpy/arrayimpl/concrete.py
+++ b/pypy/module/micronumpy/arrayimpl/concrete.py
@@ -294,12 +294,12 @@
self.backstrides = backstrides
self.storage = storage
- def create_iter(self, shape=None):
+ def create_iter(self, shape=None, backward_broadcast=False):
if shape is None or shape == self.get_shape():
return iter.ConcreteArrayIterator(self)
r = calculate_broadcast_strides(self.get_strides(),
self.get_backstrides(),
- self.get_shape(), shape)
+ self.get_shape(), shape, backward_broadcast)
return iter.MultiDimViewIterator(self, self.dtype, 0, r[0], r[1], shape)
def fill(self, box):
@@ -362,11 +362,12 @@
def fill(self, box):
loop.fill(self, box.convert_to(self.dtype))
- def create_iter(self, shape=None):
+ def create_iter(self, shape=None, backward_broadcast=False):
if shape is not None and shape != self.get_shape():
r = calculate_broadcast_strides(self.get_strides(),
self.get_backstrides(),
- self.get_shape(), shape)
+ self.get_shape(), shape,
+ backward_broadcast)
return iter.MultiDimViewIterator(self.parent, self.dtype,
self.start, r[0], r[1], shape)
if len(self.get_shape()) == 1:
diff --git a/pypy/module/micronumpy/interp_numarray.py b/pypy/module/micronumpy/interp_numarray.py
--- a/pypy/module/micronumpy/interp_numarray.py
+++ b/pypy/module/micronumpy/interp_numarray.py
@@ -224,9 +224,10 @@
s.append('])')
return s.build()
- def create_iter(self, shape=None):
+ def create_iter(self, shape=None, backward_broadcast=False):
assert isinstance(self.implementation, BaseArrayImplementation)
- return self.implementation.create_iter(shape)
+ return self.implementation.create_iter(shape,
+ backward_broadcast=backward_broadcast)
def create_axis_iter(self, shape, dim, cum):
return self.implementation.create_axis_iter(shape, dim, cum)
diff --git a/pypy/module/micronumpy/loop.py b/pypy/module/micronumpy/loop.py
--- a/pypy/module/micronumpy/loop.py
+++ b/pypy/module/micronumpy/loop.py
@@ -300,14 +300,13 @@
def getitem_filter(res, arr, index):
res_iter = res.create_iter()
- index_iter = index.create_iter(arr.get_shape())
+ index_iter = index.create_iter(arr.get_shape(), backward_broadcast=True)
arr_iter = arr.create_iter()
shapelen = len(arr.get_shape())
arr_dtype = arr.get_dtype()
index_dtype = index.get_dtype()
# XXX length of shape of index as well?
while not index_iter.done():
- print 'res,arr,index', res_iter.offset, arr_iter.offset, index_iter.offset, index_iter.getitem_bool()
getitem_filter_driver.jit_merge_point(shapelen=shapelen,
index_dtype=index_dtype,
arr_dtype=arr_dtype,
diff --git a/pypy/module/micronumpy/strides.py b/pypy/module/micronumpy/strides.py
--- a/pypy/module/micronumpy/strides.py
+++ b/pypy/module/micronumpy/strides.py
@@ -40,7 +40,7 @@
rshape += shape[s:]
return rshape, rstart, rstrides, rbackstrides
-def calculate_broadcast_strides(strides, backstrides, orig_shape, res_shape):
+def calculate_broadcast_strides(strides, backstrides, orig_shape, res_shape, backwards=False):
rstrides = []
rbackstrides = []
for i in range(len(orig_shape)):
@@ -50,8 +50,12 @@
else:
rstrides.append(strides[i])
rbackstrides.append(backstrides[i])
- rstrides = [0] * (len(res_shape) - len(orig_shape)) + rstrides
- rbackstrides = [0] * (len(res_shape) - len(orig_shape)) + rbackstrides
+ if backwards:
+ rstrides = rstrides + [0] * (len(res_shape) - len(orig_shape))
+ rbackstrides = rbackstrides + [0] * (len(res_shape) - len(orig_shape))
+ else:
+ rstrides = [0] * (len(res_shape) - len(orig_shape)) + rstrides
+ rbackstrides = [0] * (len(res_shape) - len(orig_shape)) + rbackstrides
return rstrides, rbackstrides
def is_single_elem(space, w_elem, is_rec_type):
diff --git a/pypy/module/micronumpy/test/test_numarray.py b/pypy/module/micronumpy/test/test_numarray.py
--- a/pypy/module/micronumpy/test/test_numarray.py
+++ b/pypy/module/micronumpy/test/test_numarray.py
@@ -1612,9 +1612,7 @@
raises(ValueError, "b[array([[True, False], [True, False]])]")
a = array([[1,2,3],[4,5,6],[7,8,9]],int)
c = array([True,False,True],bool)
- print 'xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx'
b = a[c]
- print 'yyyyyyyyyyyyyyyyyyyyyyyyyyyyyyy'
assert (a[c] == [[1, 2, 3], [7, 8, 9]]).all()
def test_bool_array_index_setitem(self):
More information about the pypy-commit mailing list