[Numpy-svn] r8509 - in trunk/numpy/core: src/multiarray tests

numpy-svn at scipy.org numpy-svn at scipy.org
Sun Jul 18 16:55:56 EDT 2010


Author: ptvirtan
Date: 2010-07-18 15:55:56 -0500 (Sun, 18 Jul 2010)
New Revision: 8509

Modified:
   trunk/numpy/core/src/multiarray/arraytypes.c.src
   trunk/numpy/core/tests/test_multiarray.py
Log:
BUG: core: fix argmax and argmin NaN handling to conform with max/min (#1429)

This makes `argmax` treat NaN as a maximal element and `argmin` treat NaN as a minimal element.
Effectively, this causes propagation of NaNs, which is consistent with
the current behavior of amax & amin.

Modified: trunk/numpy/core/src/multiarray/arraytypes.c.src
===================================================================
--- trunk/numpy/core/src/multiarray/arraytypes.c.src	2010-07-18 20:55:33 UTC (rev 8508)
+++ trunk/numpy/core/src/multiarray/arraytypes.c.src	2010-07-18 20:55:56 UTC (rev 8509)
@@ -2735,6 +2735,8 @@
  * #type = Bool, byte, ubyte, short, ushort, int, uint, long, ulong,
  *         longlong, ulonglong, float, double, longdouble,
  *         float, double, longdouble, datetime, timedelta#
+ * #isfloat = 0*11, 1*6, 0*2#
+ * #iscomplex = 0*14, 1*3, 0*2#
  * #incr = ip++*14, ip+=2*3, ip++*2#
  */
 static int
@@ -2742,14 +2744,54 @@
 {
     intp i;
     @type@ mp = *ip;
+#if @iscomplex@
+    @type@ mp_im = ip[1];
+#endif
 
     *max_ind = 0;
+
+#if @isfloat@
+    if (npy_isnan(mp)) {
+        /* nan encountered; it's maximal */
+        return 0;
+    }
+#endif
+#if @iscomplex@
+    if (npy_isnan(mp_im)) {
+        /* nan encountered; it's maximal */
+        return 0;
+    }
+#endif
+
     for (i = 1; i < n; i++) {
         @incr@;
-        if (*ip > mp) {
+        /*
+         * Propagate nans, similarly as max() and min()
+         */
+#if @iscomplex@
+        /* Lexical order for complex numbers */
+        if ((ip[0] > mp) || ((ip[0] == mp) && (ip[1] > mp_im))
+                || npy_isnan(ip[0]) || npy_isnan(ip[1])) {
+            mp = ip[0];
+            mp_im = ip[1];
+            *max_ind = i;
+            if (npy_isnan(mp) || npy_isnan(mp_im)) {
+                /* nan encountered, it's maximal */
+                break;
+            }
+        }
+#else
+        if (!(*ip <= mp)) {  /* negated, for correct nan handling */
             mp = *ip;
             *max_ind = i;
+#if @isfloat@
+            if (npy_isnan(mp)) {
+                /* nan encountered, it's maximal */
+                break;
+            }
+#endif
         }
+#endif
     }
     return 0;
 }

Modified: trunk/numpy/core/tests/test_multiarray.py
===================================================================
--- trunk/numpy/core/tests/test_multiarray.py	2010-07-18 20:55:33 UTC (rev 8508)
+++ trunk/numpy/core/tests/test_multiarray.py	2010-07-18 20:55:56 UTC (rev 8509)
@@ -671,6 +671,27 @@
 
 
 class TestArgmax(TestCase):
+
+    nan_arr = [
+        ([0, 1, 2, 3, np.nan], 4),
+        ([0, 1, 2, np.nan, 3], 3),
+        ([np.nan, 0, 1, 2, 3], 0),
+        ([np.nan, 0, np.nan, 2, 3], 0),
+        ([0, 1, 2, 3, complex(0,np.nan)], 4),
+        ([0, 1, 2, 3, complex(np.nan,0)], 4),
+        ([0, 1, 2, complex(np.nan,0), 3], 3),
+        ([0, 1, 2, complex(0,np.nan), 3], 3),
+        ([complex(0,np.nan), 0, 1, 2, 3], 0),
+        ([complex(np.nan, np.nan), 0, 1, 2, 3], 0),
+        ([complex(np.nan, 0), complex(np.nan, 2), complex(np.nan, 1)], 0),
+        ([complex(np.nan, np.nan), complex(np.nan, 2), complex(np.nan, 1)], 0),
+        ([complex(np.nan, 0), complex(np.nan, 2), complex(np.nan, np.nan)], 0),
+
+        ([complex(0, 0), complex(0, 2), complex(0, 1)], 1),
+        ([complex(1, 0), complex(0, 2), complex(0, 1)], 0),
+        ([complex(1, 0), complex(0, 2), complex(1, 1)], 2),
+    ]
+
     def test_all(self):
         a = np.random.normal(0,1,(4,5,6,7,8))
         for i in xrange(a.ndim):
@@ -680,6 +701,12 @@
             axes.remove(i)
             assert all(amax == aargmax.choose(*a.transpose(i,*axes)))
 
+    def test_combinations(self):
+        for arr, pos in self.nan_arr:
+            assert_equal(np.argmax(arr), pos, err_msg="%r"%arr)
+            assert_equal(arr[np.argmax(arr)], np.max(arr), err_msg="%r"%arr)
+
+
 class TestMinMax(TestCase):
     def test_scalar(self):
         assert_raises(ValueError, np.amax, 1, 1)




More information about the Numpy-svn mailing list