[pypy-commit] pypy numpy-generic-item: generic.swapaxes, improve test

yuyichao noreply at buildbot.pypy.org
Tue Oct 21 23:10:55 CEST 2014


Author: Yichao Yu <yyc1992 at gmail.com>
Branch: numpy-generic-item
Changeset: r74054:0703a6e2723e
Date: 2014-09-23 12:29 -0400
http://bitbucket.org/pypy/pypy/changeset/0703a6e2723e/

Log:	generic.swapaxes, improve test

diff --git a/pypy/module/micronumpy/boxes.py b/pypy/module/micronumpy/boxes.py
--- a/pypy/module/micronumpy/boxes.py
+++ b/pypy/module/micronumpy/boxes.py
@@ -416,6 +416,10 @@
             self.w_flags = W_FlagsObject(self)
         return self.w_flags
 
+    @unwrap_spec(axis1=int, axis2=int)
+    def descr_swapaxes(self, space, axis1, axis2):
+        return self.item(space)
+
 class W_BoolBox(W_GenericBox, PrimitiveBox):
     descr__new__, _get_dtype, descr_reduce = new_dtype_getter(NPY.BOOL)
 
@@ -669,6 +673,7 @@
     tostring = interp2app(W_GenericBox.descr_tostring),
     tobytes = interp2app(W_GenericBox.descr_tostring),
     reshape = interp2app(W_GenericBox.descr_reshape),
+    swapaxes = interp2app(W_GenericBox.descr_swapaxes),
 
     dtype = GetSetProperty(W_GenericBox.descr_get_dtype),
     size = GetSetProperty(W_GenericBox.descr_get_size),
diff --git a/pypy/module/micronumpy/test/test_scalar.py b/pypy/module/micronumpy/test/test_scalar.py
--- a/pypy/module/micronumpy/test/test_scalar.py
+++ b/pypy/module/micronumpy/test/test_scalar.py
@@ -301,13 +301,14 @@
     def test_item_tolist(self):
         from numpypy import int8, int16, int32, int64, float32, float64
         from numpypy import complex64, complex128
-        for t in int8, int16, int32, int64:
-            val = t(17)
-            assert val == 17
-            assert val.item() == 17
-            assert val.tolist() == 17
-            assert type(val.item()) == int
-            assert type(val.tolist()) == int
+
+        def _do_test(np_type, py_type, orig_val, exp_val):
+            val = np_type(orig_val)
+            assert val == orig_val
+            assert val.item() == exp_val
+            assert val.tolist() == exp_val
+            assert type(val.item()) == py_type
+            assert type(val.tolist()) == py_type
             val.item(0)
             val.item(())
             val.item((0,))
@@ -316,65 +317,54 @@
             raises(TypeError, val.item, '')
             raises(IndexError, val.item, 2)
 
+        for t in int8, int16, int32, int64:
+            _do_test(t, int, 17, 17)
+
         for t in float32, float64:
-            val = t(17)
-            assert val == 17
-            assert val.item() == 17
-            assert val.tolist() == 17
-            assert type(val.item()) == float
-            assert type(val.tolist()) == float
-            val.item(0)
-            val.item(())
-            val.item((0,))
-            raises(ValueError, val.item, 0, 1)
-            raises(ValueError, val.item, 0, '')
-            raises(TypeError, val.item, '')
-            raises(IndexError, val.item, 2)
+            _do_test(t, float, 17, 17)
 
         for t in complex64, complex128:
-            val = t(17j)
-            assert val == 17j
-            assert val.item() == 17j
-            assert val.tolist() == 17j
-            assert type(val.item()) == complex
-            assert type(val.tolist()) == complex
-            val.item(0)
-            val.item(())
-            val.item((0,))
-            raises(ValueError, val.item, 0, 1)
-            raises(ValueError, val.item, 0, '')
-            raises(TypeError, val.item, '')
-            raises(IndexError, val.item, 2)
+            _do_test(t, complex, 17j, 17j)
 
     def test_transpose(self):
         from numpypy import int8, int16, int32, int64, float32, float64
         from numpypy import complex64, complex128
-        for t in int8, int16, int32, int64:
-            val = t(17)
-            assert val == 17
-            assert val.transpose() == 17
-            assert type(val.transpose()) == int
+
+        def _do_test(np_type, py_type, orig_val, exp_val):
+            val = np_type(orig_val)
+            assert val == orig_val
+            assert val.transpose() == exp_val
+            assert type(val.transpose()) == py_type
             val.transpose(())
             raises(ValueError, val.transpose, 0, 1)
             raises(TypeError, val.transpose, 0, '')
             raises(ValueError, val.transpose, 0)
 
+        for t in int8, int16, int32, int64:
+            _do_test(t, int, 17, 17)
+
         for t in float32, float64:
-            val = t(17)
-            assert val == 17
-            assert val.transpose() == 17
-            assert type(val.transpose()) == float
-            val.transpose(())
-            raises(ValueError, val.transpose, 0, 1)
-            raises(TypeError, val.transpose, 0, '')
-            raises(ValueError, val.transpose, 0)
+            _do_test(t, float, 17, 17)
 
         for t in complex64, complex128:
-            val = t(17j)
-            assert val == 17j
-            assert val.transpose() == 17j
-            assert type(val.transpose()) == complex
-            val.transpose(())
-            raises(ValueError, val.transpose, 0, 1)
-            raises(TypeError, val.transpose, 0, '')
-            raises(ValueError, val.transpose, 0)
+            _do_test(t, complex, 17j, 17j)
+
+    def test_swapaxes(self):
+        from numpypy import int8, int16, int32, int64, float32, float64
+        from numpypy import complex64, complex128
+
+        def _do_test(np_type, py_type, orig_val, exp_val):
+            val = np_type(orig_val)
+            assert val == orig_val
+            assert val.swapaxes(10, 20) == exp_val
+            assert type(val.swapaxes(0, 1)) == py_type
+            raises(TypeError, val.swapaxes, 0, ())
+
+        for t in int8, int16, int32, int64:
+            _do_test(t, int, 17, 17)
+
+        for t in float32, float64:
+            _do_test(t, float, 17, 17)
+
+        for t in complex64, complex128:
+            _do_test(t, complex, 17j, 17j)


More information about the pypy-commit mailing list