[pypy-commit] pypy default: merge numpypy-reshape
fijal
noreply at buildbot.pypy.org
Sat Jan 14 20:40:47 CET 2012
Author: Maciej Fijalkowski <fijall at gmail.com>
Branch:
Changeset: r51315:c9f77f542246
Date: 2012-01-14 21:40 +0200
http://bitbucket.org/pypy/pypy/changeset/c9f77f542246/
Log: merge numpypy-reshape
diff --git a/pypy/module/micronumpy/interp_numarray.py b/pypy/module/micronumpy/interp_numarray.py
--- a/pypy/module/micronumpy/interp_numarray.py
+++ b/pypy/module/micronumpy/interp_numarray.py
@@ -157,9 +157,6 @@
# (meaning that the realignment of elements crosses from one step into another)
# return None so that the caller can raise an exception.
def calc_new_strides(new_shape, old_shape, old_strides):
- # Return the proper strides for new_shape, or None if the mapping crosses
- # stepping boundaries
-
# Assumes that prod(old_shape) == prod(new_shape), len(old_shape) > 1, and
# len(new_shape) > 0
steps = []
@@ -167,6 +164,7 @@
oldI = 0
new_strides = []
if old_strides[0] < old_strides[-1]:
+ #Start at old_shape[0], old_strides[0]
for i in range(len(old_shape)):
steps.append(old_strides[i] / last_step)
last_step *= old_shape[i]
@@ -184,10 +182,11 @@
if n_new_elems_used == n_old_elems_to_use:
oldI += 1
if oldI >= len(old_shape):
- break
+ continue
cur_step = steps[oldI]
n_old_elems_to_use *= old_shape[oldI]
else:
+ #Start at old_shape[-1], old_strides[-1]
for i in range(len(old_shape) - 1, -1, -1):
steps.insert(0, old_strides[i] / last_step)
last_step *= old_shape[i]
@@ -207,7 +206,7 @@
if n_new_elems_used == n_old_elems_to_use:
oldI -= 1
if oldI < -len(old_shape):
- break
+ continue
cur_step = steps[oldI]
n_old_elems_to_use *= old_shape[oldI]
return new_strides
diff --git a/pypy/module/micronumpy/test/test_numarray.py b/pypy/module/micronumpy/test/test_numarray.py
--- a/pypy/module/micronumpy/test/test_numarray.py
+++ b/pypy/module/micronumpy/test/test_numarray.py
@@ -157,6 +157,8 @@
assert calc_new_strides([2, 3, 4], [8, 3], [1, 16]) is None
assert calc_new_strides([24], [2, 4, 3], [48, 6, 1]) is None
assert calc_new_strides([24], [2, 4, 3], [24, 6, 2]) == [2]
+ assert calc_new_strides([105, 1], [3, 5, 7], [35, 7, 1]) == [1, 1]
+ assert calc_new_strides([1, 105], [3, 5, 7], [35, 7, 1]) == [105, 1]
class AppTestNumArray(BaseNumpyAppTest):
@@ -765,7 +767,6 @@
assert (a[:, 1, :].sum(1) == [70, 315, 560]).all()
raises (ValueError, 'a[:, 1, :].sum(2)')
assert ((a + a).T.sum(2).T == (a + a).sum(0)).all()
- skip("Those are broken on reshape, fix!")
assert (a.reshape(1,-1).sum(0) == range(105)).all()
assert (a.reshape(1,-1).sum(1) == 5460)
@@ -1556,3 +1557,7 @@
a = range(12)
b = reshape(a, (3, 4))
assert b.shape == (3, 4)
+ a = array(range(105)).reshape(3, 5, 7)
+ assert a.reshape(1, -1).shape == (1, 105)
+ assert a.reshape(1, 1, -1).shape == (1, 1, 105)
+ assert a.reshape(-1, 1, 1).shape == (105, 1, 1)
More information about the pypy-commit
mailing list