[pypy-commit] pypy propogate-nans: avoid using ne on nans by adding argmax, argmin to types

mattip noreply at buildbot.pypy.org
Sun Nov 8 11:18:20 EST 2015


Author: mattip <matti.picus at gmail.com>
Branch: propogate-nans
Changeset: r80589:50e5f751fee5
Date: 2015-11-08 18:18 +0200
http://bitbucket.org/pypy/pypy/changeset/50e5f751fee5/

Log:	avoid comparing NaNs with ne() by adding argmax() and argmin() methods to the item types

diff --git a/pypy/module/micronumpy/loop.py b/pypy/module/micronumpy/loop.py
--- a/pypy/module/micronumpy/loop.py
+++ b/pypy/module/micronumpy/loop.py
@@ -534,10 +534,10 @@
             while not inner_iter.done(inner_state):
                 arg_driver.jit_merge_point(shapelen=shapelen, dtype=dtype)
                 w_val = inner_iter.getitem(inner_state)
-                new_best = getattr(dtype.itemtype, op_name)(cur_best, w_val)
-                if dtype.itemtype.ne(new_best, cur_best):
+                old_best = getattr(dtype.itemtype, op_name)(cur_best, w_val)
+                if not old_best:
                     result = idx
-                    cur_best = new_best
+                    cur_best = w_val
                 inner_state = inner_iter.next(inner_state)
                 idx += 1
             result = get_dtype_cache(space).w_longdtype.box(result)
@@ -557,17 +557,17 @@
         while not iter.done(state):
             arg_flat_driver.jit_merge_point(shapelen=shapelen, dtype=dtype)
             w_val = iter.getitem(state)
-            new_best = getattr(dtype.itemtype, op_name)(cur_best, w_val)
-            if dtype.itemtype.ne(new_best, cur_best):
+            old_best = getattr(dtype.itemtype, op_name)(cur_best, w_val)
+            if not old_best:
                 result = idx
-                cur_best = new_best
+                cur_best = w_val
             state = iter.next(state)
             idx += 1
         return result
 
     return argmin_argmax, argmin_argmax_flat
-argmin, argmin_flat = _new_argmin_argmax('min')
-argmax, argmax_flat = _new_argmin_argmax('max')
+argmin, argmin_flat = _new_argmin_argmax('argmin')
+argmax, argmax_flat = _new_argmin_argmax('argmax')
 
 dot_driver = jit.JitDriver(name = 'numpy_dot',
                            greens = ['dtype'],
diff --git a/pypy/module/micronumpy/test/test_ndarray.py b/pypy/module/micronumpy/test/test_ndarray.py
--- a/pypy/module/micronumpy/test/test_ndarray.py
+++ b/pypy/module/micronumpy/test/test_ndarray.py
@@ -1858,6 +1858,7 @@
         e = array([0, -1, -float('inf'), float('nan'), 6], dtype='float16')
         assert map(isnan, e) == [False, False, False, True, False]
         assert map(isinf, e) == [False, False, True, False, False]
+        assert e.argmax() == 3
         # numpy preserves value for uint16 -> cast_as_float16 -> 
         #     convert_to_float64 -> convert_to_float16 -> uint16
         #  even for float16 various float16 nans
diff --git a/pypy/module/micronumpy/types.py b/pypy/module/micronumpy/types.py
--- a/pypy/module/micronumpy/types.py
+++ b/pypy/module/micronumpy/types.py
@@ -345,6 +345,14 @@
     def min(self, v1, v2):
         return min(v1, v2)
 
+    @raw_binary_op
+    def argmax(self, v1, v2):
+        return v1 >= v2
+
+    @raw_binary_op
+    def argmin(self, v1, v2):
+        return v1 <= v2
+
     @raw_unary_op
     def rint(self, v):
         float64 = Float64(self.space)
@@ -820,6 +828,14 @@
     def min(self, v1, v2):
         return v1 if v1 <= v2 or rfloat.isnan(v1) else v2
 
+    @raw_binary_op
+    def argmax(self, v1, v2):
+        return v1 >= v2 or rfloat.isnan(v1)
+
+    @raw_binary_op
+    def argmin(self, v1, v2):
+        return v1 <= v2 or rfloat.isnan(v1)
+
     @simple_binary_op
     def fmax(self, v1, v2):
         return v1 if v1 >= v2 or rfloat.isnan(v2) else v2
@@ -1407,6 +1423,16 @@
             return v1
         return v2
 
+    def argmin(self, v1, v2):
+        if self.le(v1, v2) or self.isnan(v1):
+            return True
+        return False
+
+    def argmax(self, v1, v2):
+        if self.ge(v1, v2) or self.isnan(v1):
+            return True
+        return False
+
     @complex_binary_op
     def floordiv(self, v1, v2):
         (r1, i1), (r2, i2) = v1, v2
@@ -1927,6 +1953,18 @@
             return v1
         return v2
 
+    @raw_binary_op
+    def argmax(self, v1, v2):
+        if self.space.is_true(self.space.ge(v1, v2)):
+            return True
+        return False
+
+    @raw_binary_op
+    def argmin(self, v1, v2):
+        if self.space.is_true(self.space.le(v1, v2)):
+            return True
+        return False
+
     @raw_unary_op
     def bool(self,v):
         return self._obool(v)


More information about the pypy-commit mailing list