[pypy-commit] pypy default: test/fix ufunc reduce with comparison func when dtype specified

bdkearns noreply at buildbot.pypy.org
Fri Dec 5 06:50:14 CET 2014


Author: Brian Kearns <bdkearns at gmail.com>
Branch: 
Changeset: r74824:4213885db36d
Date: 2014-12-05 00:23 -0500
http://bitbucket.org/pypy/pypy/changeset/4213885db36d/

Log:	test/fix ufunc reduce with comparison func when dtype specified

diff --git a/pypy/module/micronumpy/test/test_ufuncs.py b/pypy/module/micronumpy/test/test_ufuncs.py
--- a/pypy/module/micronumpy/test/test_ufuncs.py
+++ b/pypy/module/micronumpy/test/test_ufuncs.py
@@ -785,6 +785,7 @@
         assert exc.value[0] == "'axis' entry is out of bounds"
 
     def test_reduce_1d(self):
+        import numpy as np
         from numpypy import array, add, maximum, less, float16, complex64
 
         assert less.reduce([5, 4, 3, 2, 1])
@@ -799,6 +800,10 @@
         assert type(add.reduce(array([True, False] * 200, dtype='float16'))) is float16
         assert type(add.reduce(array([True, False] * 200, dtype='complex64'))) is complex64
 
+        for dtype in ['bool', 'int']:
+            assert np.equal.reduce([1, 2], dtype=dtype) == True
+            assert np.equal.reduce([1, 2, 0], dtype=dtype) == False
+
     def test_reduceND(self):
         from numpypy import add, arange
         a = arange(12).reshape(3, 4)
diff --git a/pypy/module/micronumpy/ufuncs.py b/pypy/module/micronumpy/ufuncs.py
--- a/pypy/module/micronumpy/ufuncs.py
+++ b/pypy/module/micronumpy/ufuncs.py
@@ -196,16 +196,15 @@
                 axis += shapelen
         assert axis >= 0
         dtype = descriptor.decode_w_dtype(space, dtype)
-        if dtype is None:
-            if self.comparison_func:
-                dtype = descriptor.get_dtype_cache(space).w_booldtype
-            else:
-                dtype = find_unaryop_result_dtype(
-                    space, obj.get_dtype(),
-                    promote_to_float=self.promote_to_float,
-                    promote_to_largest=self.promote_to_largest,
-                    promote_bools=self.promote_bools,
-                )
+        if self.comparison_func:
+            dtype = descriptor.get_dtype_cache(space).w_booldtype
+        elif dtype is None:
+            dtype = find_unaryop_result_dtype(
+                space, obj.get_dtype(),
+                promote_to_float=self.promote_to_float,
+                promote_to_largest=self.promote_to_largest,
+                promote_bools=self.promote_bools,
+            )
         if self.identity is None:
             for i in range(shapelen):
                 if space.is_none(w_axis) or i == axis:


More information about the pypy-commit mailing list