[pypy-commit] pypy default: support keepdims arg for array reduce operations
bdkearns
noreply at buildbot.pypy.org
Wed Jan 29 23:34:19 CET 2014
Author: Brian Kearns <bdkearns at gmail.com>
Branch:
Changeset: r69004:9371233f1468
Date: 2014-01-29 17:26 -0500
http://bitbucket.org/pypy/pypy/changeset/9371233f1468/
Log: support keepdims arg for array reduce operations
diff --git a/pypy/module/micronumpy/interp_numarray.py b/pypy/module/micronumpy/interp_numarray.py
--- a/pypy/module/micronumpy/interp_numarray.py
+++ b/pypy/module/micronumpy/interp_numarray.py
@@ -961,7 +961,8 @@
def _reduce_ufunc_impl(ufunc_name, promote_to_largest=False,
cumulative=False):
- def impl(self, space, w_axis=None, w_dtype=None, w_out=None):
+ @unwrap_spec(keepdims=bool)
+ def impl(self, space, w_axis=None, w_dtype=None, w_out=None, keepdims=False):
if space.is_none(w_out):
out = None
elif not isinstance(w_out, W_NDimArray):
@@ -971,7 +972,7 @@
out = w_out
return getattr(interp_ufuncs.get(space), ufunc_name).reduce(
space, self, promote_to_largest, w_axis,
- False, out, w_dtype, cumulative=cumulative)
+ keepdims, out, w_dtype, cumulative=cumulative)
return func_with_new_name(impl, "reduce_%s_impl_%d_%d" % (ufunc_name,
promote_to_largest, cumulative))
diff --git a/pypy/module/micronumpy/interp_ufuncs.py b/pypy/module/micronumpy/interp_ufuncs.py
--- a/pypy/module/micronumpy/interp_ufuncs.py
+++ b/pypy/module/micronumpy/interp_ufuncs.py
@@ -252,6 +252,11 @@
if out:
out.set_scalar_value(res)
return out
+ if keepdims:
+ shape = [1] * len(obj_shape)
+ out = W_NDimArray.from_shape(space, [1] * len(obj_shape), dtype, w_instance=obj)
+ out.implementation.setitem(0, res)
+ return out
return res
def descr_outer(self, space, __args__):
diff --git a/pypy/module/micronumpy/test/test_numarray.py b/pypy/module/micronumpy/test/test_numarray.py
--- a/pypy/module/micronumpy/test/test_numarray.py
+++ b/pypy/module/micronumpy/test/test_numarray.py
@@ -1399,6 +1399,8 @@
from numpypy import arange, array
a = arange(15).reshape(5, 3)
assert a.sum() == 105
+ assert a.sum(keepdims=True) == 105
+ assert a.sum(keepdims=True).shape == (1, 1)
assert a.max() == 14
assert array([]).sum() == 0.0
assert array([]).reshape(0, 2).sum() == 0.
@@ -1431,6 +1433,8 @@
from numpypy import array, dtype
a = array(range(1, 6))
assert a.prod() == 120.0
+ assert a.prod(keepdims=True) == 120.0
+ assert a.prod(keepdims=True).shape == (1,)
assert a[:4].prod() == 24.0
for dt in ['bool', 'int8', 'uint8', 'int16', 'uint16']:
a = array([True, False], dtype=dt)
@@ -1445,6 +1449,8 @@
from numpypy import array, zeros
a = array([-1.2, 3.4, 5.7, -3.0, 2.7])
assert a.max() == 5.7
+ assert a.max(keepdims=True) == 5.7
+ assert a.max(keepdims=True).shape == (1,)
b = array([])
raises(ValueError, "b.max()")
assert list(zeros((0, 2)).max(axis=1)) == []
@@ -1458,6 +1464,8 @@
from numpypy import array, zeros
a = array([-1.2, 3.4, 5.7, -3.0, 2.7])
assert a.min() == -3.0
+ assert a.min(keepdims=True) == -3.0
+ assert a.min(keepdims=True).shape == (1,)
b = array([])
raises(ValueError, "b.min()")
assert list(zeros((0, 2)).min(axis=1)) == []
@@ -1508,6 +1516,8 @@
assert a.all() == False
a[0] = 3.0
assert a.all() == True
+ assert a.all(keepdims=True) == True
+ assert a.all(keepdims=True).shape == (1,)
b = array([])
assert b.all() == True
@@ -1515,6 +1525,8 @@
from numpypy import array, zeros
a = array(range(5))
assert a.any() == True
+ assert a.any(keepdims=True) == True
+ assert a.any(keepdims=True).shape == (1,)
b = zeros(5)
assert b.any() == False
c = array([])
More information about the pypy-commit
mailing list