[pypy-commit] pypy fix_transpose_for_list_v3: Fixed ndarray.transpose when argument is a list or an array

Sergey Matyunin pypy.commits at gmail.com
Tue Mar 22 15:56:17 EDT 2016


Author: Sergey Matyunin <sbmatyunin at gmail.com>
Branch: fix_transpose_for_list_v3
Changeset: r83277:ef93194a1339
Date: 2016-03-22 18:08 +0100
http://bitbucket.org/pypy/pypy/changeset/ef93194a1339/

Log:	Fixed ndarray.transpose when argument is a list or an array

diff --git a/pypy/module/micronumpy/ndarray.py b/pypy/module/micronumpy/ndarray.py
--- a/pypy/module/micronumpy/ndarray.py
+++ b/pypy/module/micronumpy/ndarray.py
@@ -502,29 +502,52 @@
         return W_NDimArray(self.implementation.transpose(self, axes))
 
     def descr_transpose(self, space, args_w):
-        if len(args_w) == 1 and space.isinstance_w(args_w[0], space.w_tuple):
-            args_w = space.fixedview(args_w[0])
-        if (len(args_w) == 0 or
-                len(args_w) == 1 and space.is_none(args_w[0])):
+        if len(args_w) == 0 or len(args_w) == 1 and space.is_none(args_w[0]):
             return self.descr_get_transpose(space)
         else:
-            if len(args_w) != self.ndims():
-                raise oefmt(space.w_ValueError, "axes don't match array")
-            axes = []
-            axes_seen = [False] * self.ndims()
-            for w_arg in args_w:
-                try:
-                    axis = support.index_w(space, w_arg)
-                except OperationError:
-                    raise oefmt(space.w_TypeError, "an integer is required")
-                if axis < 0 or axis >= self.ndims():
-                    raise oefmt(space.w_ValueError, "invalid axis for this array")
-                if axes_seen[axis] is True:
-                    raise oefmt(space.w_ValueError, "repeated axis in transpose")
-                axes.append(axis)
-                axes_seen[axis] = True
-            return self.descr_get_transpose(space, axes)
+            if len(args_w) > 1:
+                # a.transpose(1, 2, 0): each positional argument is one axis.
+                axes_w = args_w
+            else:
+                # One non-None argument: either a sequence of axes (tuple,
+                # list, ndarray, ...) or, as numpy allows, a single bare
+                # axis such as a.transpose(0) on a 1-d array.
+                w_arg = args_w[0]
+                try:
+                    axes_w = space.fixedview(w_arg)
+                except OperationError as e:
+                    if not e.match(space, space.w_TypeError):
+                        raise
+                    axes_w = [w_arg]
 
+        axes = self._checked_axes(axes_w, space)
+        return self.descr_get_transpose(space, axes)
+
+    def _checked_axes(self, axes_raw, space):
+        """Unwrap and validate the wrapped axes given to descr_transpose.
+
+        Returns the axes as a list of ints.  Raises ValueError when the
+        number of axes differs from self.ndims(), an axis is out of range
+        or an axis is repeated, and TypeError when support.index_w rejects
+        an element.
+        """
+        ndims = self.ndims()
+        if len(axes_raw) != ndims:
+            raise oefmt(space.w_ValueError, "axes don't match array")
+        axes = []
+        axes_seen = [False] * ndims
+        for w_elem in axes_raw:
+            try:
+                axis = support.index_w(space, w_elem)
+            except OperationError:
+                raise oefmt(space.w_TypeError, "an integer is required")
+            if axis < 0 or axis >= ndims:
+                raise oefmt(space.w_ValueError, "invalid axis for this array")
+            if axes_seen[axis]:
+                raise oefmt(space.w_ValueError, "repeated axis in transpose")
+            axes.append(axis)
+            axes_seen[axis] = True
+        return axes
 
     @unwrap_spec(axis1=int, axis2=int)
     def descr_swapaxes(self, space, axis1, axis2):
diff --git a/pypy/module/micronumpy/test/test_ndarray.py b/pypy/module/micronumpy/test/test_ndarray.py
--- a/pypy/module/micronumpy/test/test_ndarray.py
+++ b/pypy/module/micronumpy/test/test_ndarray.py
@@ -2960,6 +2960,36 @@
         assert (a.transpose() == b).all()
         assert (a.transpose(None) == b).all()
 
+    def test_transpose_arg_tuple(self):
+        import numpy as np
+        a = np.arange(24).reshape(2, 3, 4)
+        transpose_args = a.transpose(1, 2, 0)
+        # Same permutation as transpose_args, but passed as one tuple.
+        transpose_test = a.transpose((1, 2, 0))
+
+        assert transpose_test.shape == (3, 4, 2)
+        assert (transpose_args == transpose_test).all()
+
+    def test_transpose_arg_list(self):
+        import numpy as np
+        a = np.arange(24).reshape(2, 3, 4)
+        transpose_args = a.transpose(1, 2, 0)
+        # Same permutation as transpose_args, but passed as one list.
+        transpose_test = a.transpose([1, 2, 0])
+
+        assert transpose_test.shape == (3, 4, 2)
+        assert (transpose_args == transpose_test).all()
+
+    def test_transpose_arg_array(self):
+        import numpy as np
+        a = np.arange(24).reshape(2, 3, 4)
+        transpose_args = a.transpose(1, 2, 0)
+        # Same permutation as transpose_args, but passed as one ndarray.
+        transpose_test = a.transpose(np.array([1, 2, 0]))
+
+        assert transpose_test.shape == (3, 4, 2)
+        assert (transpose_args == transpose_test).all()
+
     def test_transpose_error(self):
         import numpy as np
         a = np.arange(24).reshape(2, 3, 4)
@@ -2968,6 +2998,11 @@
         raises(ValueError, a.transpose, 1, 0, 1)
         raises(TypeError, a.transpose, 1, 0, '2')
 
+    def test_transpose_unexpected_argument(self):
+        import numpy as np
+        a = np.array([[1, 2], [3, 4], [5, 6]])
+        raises(TypeError, 'a.transpose(axes=(1,2,0))')  # axes are positional-only
+
     def test_flatiter(self):
         from numpy import array, flatiter, arange, zeros
         a = array([[10, 30], [40, 60]])


More information about the pypy-commit mailing list