[pypy-commit] pypy fix_transpose_for_list_v3: Fixed ndarray.transpose when argument is a list or an array
Sergey Matyunin
pypy.commits at gmail.com
Tue Mar 22 15:56:17 EDT 2016
Author: Sergey Matyunin <sbmatyunin at gmail.com>
Branch: fix_transpose_for_list_v3
Changeset: r83277:ef93194a1339
Date: 2016-03-22 18:08 +0100
http://bitbucket.org/pypy/pypy/changeset/ef93194a1339/
Log: Fixed ndarray.transpose when argument is a list or an array
diff --git a/pypy/module/micronumpy/ndarray.py b/pypy/module/micronumpy/ndarray.py
--- a/pypy/module/micronumpy/ndarray.py
+++ b/pypy/module/micronumpy/ndarray.py
@@ -502,29 +502,51 @@
         return W_NDimArray(self.implementation.transpose(self, axes))
     def descr_transpose(self, space, args_w):
-        if len(args_w) == 1 and space.isinstance_w(args_w[0], space.w_tuple):
-            args_w = space.fixedview(args_w[0])
-        if (len(args_w) == 0 or
-            len(args_w) == 1 and space.is_none(args_w[0])):
+        if len(args_w) == 0 or len(args_w) == 1 and space.is_none(args_w[0]):
             return self.descr_get_transpose(space)
         else:
-            if len(args_w) != self.ndims():
-                raise oefmt(space.w_ValueError, "axes don't match array")
-            axes = []
-            axes_seen = [False] * self.ndims()
-            for w_arg in args_w:
-                try:
-                    axis = support.index_w(space, w_arg)
-                except OperationError:
-                    raise oefmt(space.w_TypeError, "an integer is required")
-                if axis < 0 or axis >= self.ndims():
-                    raise oefmt(space.w_ValueError, "invalid axis for this array")
-                if axes_seen[axis] is True:
-                    raise oefmt(space.w_ValueError, "repeated axis in transpose")
-                axes.append(axis)
-                axes_seen[axis] = True
-            return self.descr_get_transpose(space, axes)
+            if len(args_w) > 1:
+                # transpose(1, 2, 0): every positional argument is one axis.
+                axes_w = args_w
+            else:
+                # A single non-None argument is either a sequence of axes
+                # (tuple, list, array, ...) or, as in numpy, a lone integer
+                # axis, so np.arange(3).transpose(0) must keep working.
+                # Only a sequence can be unpacked; anything that is not
+                # iterable is validated below as a single axis.
+                w_axes = args_w[0]
+                try:
+                    axes_w = space.fixedview(w_axes)
+                except OperationError as e:
+                    if not e.match(space, space.w_TypeError):
+                        raise
+                    axes_w = [w_axes]
+            axes = self._checked_axes(axes_w, space)
+            return self.descr_get_transpose(space, axes)
+
+    def _checked_axes(self, axes_raw, space):
+        """Return the wrapped axes in axes_raw as a list of ints.
+
+        Raises ValueError unless they form a permutation of
+        range(self.ndims()), and TypeError if one is not an integer.
+        """
+        ndims = self.ndims()
+        if len(axes_raw) != ndims:
+            raise oefmt(space.w_ValueError, "axes don't match array")
+        axes = []
+        axes_seen = [False] * ndims
+        for w_elem in axes_raw:
+            try:
+                axis = support.index_w(space, w_elem)
+            except OperationError:
+                raise oefmt(space.w_TypeError, "an integer is required")
+            if axis < 0 or axis >= ndims:
+                raise oefmt(space.w_ValueError, "invalid axis for this array")
+            if axes_seen[axis]:
+                raise oefmt(space.w_ValueError, "repeated axis in transpose")
+            axes.append(axis)
+            axes_seen[axis] = True
+        return axes
     @unwrap_spec(axis1=int, axis2=int)
     def descr_swapaxes(self, space, axis1, axis2):
diff --git a/pypy/module/micronumpy/test/test_ndarray.py b/pypy/module/micronumpy/test/test_ndarray.py
--- a/pypy/module/micronumpy/test/test_ndarray.py
+++ b/pypy/module/micronumpy/test/test_ndarray.py
@@ -2960,6 +2960,46 @@
         assert (a.transpose() == b).all()
         assert (a.transpose(None) == b).all()
+    def test_transpose_arg_tuple(self):
+        import numpy as np
+        a = np.arange(24).reshape(2, 3, 4)
+        transpose_args = a.transpose(1, 2, 0)
+
+        transpose_test = a.transpose((1, 2, 0))
+
+        assert transpose_test.shape == (3, 4, 2)
+        assert (transpose_args == transpose_test).all()
+
+    def test_transpose_arg_list(self):
+        import numpy as np
+        a = np.arange(24).reshape(2, 3, 4)
+        transpose_args = a.transpose(1, 2, 0)
+
+        transpose_test = a.transpose([1, 2, 0])
+
+        assert transpose_test.shape == (3, 4, 2)
+        assert (transpose_args == transpose_test).all()
+
+    def test_transpose_arg_array(self):
+        import numpy as np
+        a = np.arange(24).reshape(2, 3, 4)
+        transpose_args = a.transpose(1, 2, 0)
+
+        transpose_test = a.transpose(np.array([1, 2, 0]))
+
+        assert transpose_test.shape == (3, 4, 2)
+        assert (transpose_args == transpose_test).all()
+
+    def test_transpose_arg_int(self):
+        import numpy as np
+        # numpy accepts a lone integer as the axes argument.
+        a = np.arange(3)
+        assert (a.transpose(0) == a).all()
+
+        b = np.arange(6).reshape(2, 3)
+        raises(ValueError, b.transpose, 0)
+        raises(ValueError, b.transpose, [1, 0, 2])
+
     def test_transpose_error(self):
         import numpy as np
         a = np.arange(24).reshape(2, 3, 4)
@@ -2968,6 +3008,11 @@
         raises(ValueError, a.transpose, 1, 0, 1)
         raises(TypeError, a.transpose, 1, 0, '2')
+    def test_transpose_unexpected_argument(self):
+        import numpy as np
+        a = np.array([[1, 2], [3, 4], [5, 6]])
+        raises(TypeError, 'a.transpose(axes=(1,2,0))')
+
     def test_flatiter(self):
         from numpy import array, flatiter, arange, zeros
         a = array([[10, 30], [40, 60]])
More information about the pypy-commit
mailing list