[pypy-commit] pypy default: implement and test ndarray.trace()
bdkearns
noreply at buildbot.pypy.org
Wed Oct 16 01:22:35 CEST 2013
Author: Brian Kearns <bdkearns at gmail.com>
Branch:
Changeset: r67400:3fd593fe30bf
Date: 2013-10-15 19:09 -0400
http://bitbucket.org/pypy/pypy/changeset/3fd593fe30bf/
Log: implement and test ndarray.trace()
diff --git a/pypy/module/micronumpy/interp_numarray.py b/pypy/module/micronumpy/interp_numarray.py
--- a/pypy/module/micronumpy/interp_numarray.py
+++ b/pypy/module/micronumpy/interp_numarray.py
@@ -550,6 +550,12 @@
return interp_arrayops.diagonal(space, self.implementation, offset,
axis1, axis2)
+ @unwrap_spec(offset=int, axis1=int, axis2=int)  # only offset/axis1/axis2 are unwrapped to ints
+ def descr_trace(self, space, offset=0, axis1=0, axis2=1,  # ndarray.trace(): sum of diagonal(offset, axis1, axis2)
+ w_dtype=None, w_out=None):
+ diag = self.descr_diagonal(space, offset, axis1, axis2)  # diagonal entries sit on the last axis of diag
+ return diag.descr_sum(space, w_axis=space.wrap(-1), w_dtype=w_dtype, w_out=w_out)  # so reduce over axis -1; dtype/out go to sum
+
def descr_dump(self, space, w_file):
raise OperationError(space.w_NotImplementedError, space.wrap(
"dump not implemented yet"))
@@ -653,11 +659,6 @@
raise OperationError(space.w_NotImplementedError, space.wrap(
"tofile not implemented yet"))
- def descr_trace(self, space, w_offset=0, w_axis1=0, w_axis2=1,
- w_dtype=None, w_out=None):
- raise OperationError(space.w_NotImplementedError, space.wrap(
- "trace not implemented yet"))
-
def descr_view(self, space, w_dtype=None, w_type=None) :
if not w_type and w_dtype:
try:
@@ -1153,6 +1154,7 @@
round = interp2app(W_NDimArray.descr_round),
data = GetSetProperty(W_NDimArray.descr_get_data),
diagonal = interp2app(W_NDimArray.descr_diagonal),
+ trace = interp2app(W_NDimArray.descr_trace),
view = interp2app(W_NDimArray.descr_view),
ctypes = GetSetProperty(W_NDimArray.descr_get_ctypes), # XXX unimplemented
diff --git a/pypy/module/micronumpy/test/test_numarray.py b/pypy/module/micronumpy/test/test_numarray.py
--- a/pypy/module/micronumpy/test/test_numarray.py
+++ b/pypy/module/micronumpy/test/test_numarray.py
@@ -1465,6 +1465,14 @@
assert a[3].imag == -10
assert a[2].imag == -5
+ def test_trace(self):
+ import numpypy as np  # NOTE(review): exercises np.trace(); presumably dispatches to ndarray.trace - confirm
+ assert np.trace(np.eye(3)) == 3.0  # 2-D: sum of the main diagonal
+ a = np.arange(8).reshape((2,2,2))
+ assert np.array_equal(np.trace(a), [6, 8])  # 3-D: a[i,i,k] summed over i -> [0+6, 1+7]
+ a = np.arange(24).reshape((2,2,2,3))
+ assert np.trace(a).shape == (2, 3)  # axes 0,1 reduced; remaining axes keep their order
+
def test_view(self):
from numpypy import array, int8, int16, dtype
x = array((1, 2), dtype=int8)
More information about the pypy-commit mailing list