[pypy-commit] pypy default: implement and test ndarray.trace()

bdkearns noreply at buildbot.pypy.org
Wed Oct 16 01:22:35 CEST 2013


Author: Brian Kearns <bdkearns at gmail.com>
Branch: default
Changeset: r67400:3fd593fe30bf
Date: 2013-10-15 19:09 -0400
http://bitbucket.org/pypy/pypy/changeset/3fd593fe30bf/

Log:	implement and test ndarray.trace()

diff --git a/pypy/module/micronumpy/interp_numarray.py b/pypy/module/micronumpy/interp_numarray.py
--- a/pypy/module/micronumpy/interp_numarray.py
+++ b/pypy/module/micronumpy/interp_numarray.py
@@ -550,6 +550,12 @@
         return interp_arrayops.diagonal(space, self.implementation, offset,
                                         axis1, axis2)
 
+    @unwrap_spec(offset=int, axis1=int, axis2=int)
+    def descr_trace(self, space, offset=0, axis1=0, axis2=1,
+                    w_dtype=None, w_out=None):
+        diag = self.descr_diagonal(space, offset, axis1, axis2)
+        return diag.descr_sum(space, w_axis=space.wrap(-1), w_dtype=w_dtype, w_out=w_out)
+
     def descr_dump(self, space, w_file):
         raise OperationError(space.w_NotImplementedError, space.wrap(
             "dump not implemented yet"))
@@ -653,11 +659,6 @@
         raise OperationError(space.w_NotImplementedError, space.wrap(
             "tofile not implemented yet"))
 
-    def descr_trace(self, space, w_offset=0, w_axis1=0, w_axis2=1,
-                    w_dtype=None, w_out=None):
-        raise OperationError(space.w_NotImplementedError, space.wrap(
-            "trace not implemented yet"))
-
     def descr_view(self, space, w_dtype=None, w_type=None) :
         if not w_type and w_dtype:
             try:
@@ -1153,6 +1154,7 @@
     round    = interp2app(W_NDimArray.descr_round),
     data     = GetSetProperty(W_NDimArray.descr_get_data),
     diagonal = interp2app(W_NDimArray.descr_diagonal),
+    trace = interp2app(W_NDimArray.descr_trace),
     view = interp2app(W_NDimArray.descr_view),
 
     ctypes = GetSetProperty(W_NDimArray.descr_get_ctypes), # XXX unimplemented
diff --git a/pypy/module/micronumpy/test/test_numarray.py b/pypy/module/micronumpy/test/test_numarray.py
--- a/pypy/module/micronumpy/test/test_numarray.py
+++ b/pypy/module/micronumpy/test/test_numarray.py
@@ -1465,6 +1465,14 @@
         assert a[3].imag == -10
         assert a[2].imag == -5
 
+    def test_trace(self):
+        import numpypy as np
+        assert np.trace(np.eye(3)) == 3.0
+        a = np.arange(8).reshape((2,2,2))
+        assert np.array_equal(np.trace(a), [6, 8])
+        a = np.arange(24).reshape((2,2,2,3))
+        assert np.trace(a).shape == (2, 3)
+
     def test_view(self):
         from numpypy import array, int8, int16, dtype
         x = array((1, 2), dtype=int8)


More information about the pypy-commit mailing list