[pypy-commit] pypy default: support keepdims arg for array reduce operations

bdkearns noreply at buildbot.pypy.org
Wed Jan 29 23:34:19 CET 2014


Author: Brian Kearns <bdkearns at gmail.com>
Branch: 
Changeset: r69004:9371233f1468
Date: 2014-01-29 17:26 -0500
http://bitbucket.org/pypy/pypy/changeset/9371233f1468/

Log:	support keepdims arg for array reduce operations

diff --git a/pypy/module/micronumpy/interp_numarray.py b/pypy/module/micronumpy/interp_numarray.py
--- a/pypy/module/micronumpy/interp_numarray.py
+++ b/pypy/module/micronumpy/interp_numarray.py
@@ -961,7 +961,8 @@
 
     def _reduce_ufunc_impl(ufunc_name, promote_to_largest=False,
                            cumulative=False):
-        def impl(self, space, w_axis=None, w_dtype=None, w_out=None):
+        @unwrap_spec(keepdims=bool)
+        def impl(self, space, w_axis=None, w_dtype=None, w_out=None, keepdims=False):
             if space.is_none(w_out):
                 out = None
             elif not isinstance(w_out, W_NDimArray):
@@ -971,7 +972,7 @@
                 out = w_out
             return getattr(interp_ufuncs.get(space), ufunc_name).reduce(
                 space, self, promote_to_largest, w_axis,
-                False, out, w_dtype, cumulative=cumulative)
+                keepdims, out, w_dtype, cumulative=cumulative)
         return func_with_new_name(impl, "reduce_%s_impl_%d_%d" % (ufunc_name,
                     promote_to_largest, cumulative))
 
diff --git a/pypy/module/micronumpy/interp_ufuncs.py b/pypy/module/micronumpy/interp_ufuncs.py
--- a/pypy/module/micronumpy/interp_ufuncs.py
+++ b/pypy/module/micronumpy/interp_ufuncs.py
@@ -252,6 +252,11 @@
         if out:
             out.set_scalar_value(res)
             return out
+        if keepdims:
+            shape = [1] * len(obj_shape)
+            out = W_NDimArray.from_shape(space, [1] * len(obj_shape), dtype, w_instance=obj)
+            out.implementation.setitem(0, res)
+            return out
         return res
 
     def descr_outer(self, space, __args__):
diff --git a/pypy/module/micronumpy/test/test_numarray.py b/pypy/module/micronumpy/test/test_numarray.py
--- a/pypy/module/micronumpy/test/test_numarray.py
+++ b/pypy/module/micronumpy/test/test_numarray.py
@@ -1399,6 +1399,8 @@
         from numpypy import arange, array
         a = arange(15).reshape(5, 3)
         assert a.sum() == 105
+        assert a.sum(keepdims=True) == 105
+        assert a.sum(keepdims=True).shape == (1, 1)
         assert a.max() == 14
         assert array([]).sum() == 0.0
         assert array([]).reshape(0, 2).sum() == 0.
@@ -1431,6 +1433,8 @@
         from numpypy import array, dtype
         a = array(range(1, 6))
         assert a.prod() == 120.0
+        assert a.prod(keepdims=True) == 120.0
+        assert a.prod(keepdims=True).shape == (1,)
         assert a[:4].prod() == 24.0
         for dt in ['bool', 'int8', 'uint8', 'int16', 'uint16']:
             a = array([True, False], dtype=dt)
@@ -1445,6 +1449,8 @@
         from numpypy import array, zeros
         a = array([-1.2, 3.4, 5.7, -3.0, 2.7])
         assert a.max() == 5.7
+        assert a.max(keepdims=True) == 5.7
+        assert a.max(keepdims=True).shape == (1,)
         b = array([])
         raises(ValueError, "b.max()")
         assert list(zeros((0, 2)).max(axis=1)) == []
@@ -1458,6 +1464,8 @@
         from numpypy import array, zeros
         a = array([-1.2, 3.4, 5.7, -3.0, 2.7])
         assert a.min() == -3.0
+        assert a.min(keepdims=True) == -3.0
+        assert a.min(keepdims=True).shape == (1,)
         b = array([])
         raises(ValueError, "b.min()")
         assert list(zeros((0, 2)).min(axis=1)) == []
@@ -1508,6 +1516,8 @@
         assert a.all() == False
         a[0] = 3.0
         assert a.all() == True
+        assert a.all(keepdims=True) == True
+        assert a.all(keepdims=True).shape == (1,)
         b = array([])
         assert b.all() == True
 
@@ -1515,6 +1525,8 @@
         from numpypy import array, zeros
         a = array(range(5))
         assert a.any() == True
+        assert a.any(keepdims=True) == True
+        assert a.any(keepdims=True).shape == (1,)
         b = zeros(5)
         assert b.any() == False
         c = array([])


More information about the pypy-commit mailing list