[Numpy-svn] r8509 - in trunk/numpy/core: src/multiarray tests
numpy-svn at scipy.org
numpy-svn at scipy.org
Sun Jul 18 16:55:56 EDT 2010
Author: ptvirtan
Date: 2010-07-18 15:55:56 -0500 (Sun, 18 Jul 2010)
New Revision: 8509
Modified:
trunk/numpy/core/src/multiarray/arraytypes.c.src
trunk/numpy/core/tests/test_multiarray.py
Log:
BUG: core: fix argmax and argmin NaN handling to conform with max/min (#1429)
This makes `argmax` and `argmin` treat NaN as a maximal element.
Effectively, this causes NaNs to propagate (the index of the first NaN
is returned), which is consistent with the current behavior of `amax`
and `amin`.
Modified: trunk/numpy/core/src/multiarray/arraytypes.c.src
===================================================================
--- trunk/numpy/core/src/multiarray/arraytypes.c.src 2010-07-18 20:55:33 UTC (rev 8508)
+++ trunk/numpy/core/src/multiarray/arraytypes.c.src 2010-07-18 20:55:56 UTC (rev 8509)
@@ -2735,6 +2735,8 @@
* #type = Bool, byte, ubyte, short, ushort, int, uint, long, ulong,
* longlong, ulonglong, float, double, longdouble,
* float, double, longdouble, datetime, timedelta#
+ * #isfloat = 0*11, 1*6, 0*2#
+ * #iscomplex = 0*14, 1*3, 0*2#
* #incr = ip++*14, ip+=2*3, ip++*2#
*/
static int
@@ -2742,14 +2744,54 @@
{
intp i;
@type@ mp = *ip;
+#if @iscomplex@
+ @type@ mp_im = ip[1];
+#endif
*max_ind = 0;
+
+#if @isfloat@
+ if (npy_isnan(mp)) {
+ /* nan encountered; it's maximal */
+ return 0;
+ }
+#endif
+#if @iscomplex@
+ if (npy_isnan(mp_im)) {
+ /* nan encountered; it's maximal */
+ return 0;
+ }
+#endif
+
for (i = 1; i < n; i++) {
@incr@;
- if (*ip > mp) {
+ /*
+ * Propagate nans, similarly as max() and min()
+ */
+#if @iscomplex@
+ /* Lexical order for complex numbers */
+ if ((ip[0] > mp) || ((ip[0] == mp) && (ip[1] > mp_im))
+ || npy_isnan(ip[0]) || npy_isnan(ip[1])) {
+ mp = ip[0];
+ mp_im = ip[1];
+ *max_ind = i;
+ if (npy_isnan(mp) || npy_isnan(mp_im)) {
+ /* nan encountered, it's maximal */
+ break;
+ }
+ }
+#else
+ if (!(*ip <= mp)) { /* negated, for correct nan handling */
mp = *ip;
*max_ind = i;
+#if @isfloat@
+ if (npy_isnan(mp)) {
+ /* nan encountered, it's maximal */
+ break;
+ }
+#endif
}
+#endif
}
return 0;
}
Modified: trunk/numpy/core/tests/test_multiarray.py
===================================================================
--- trunk/numpy/core/tests/test_multiarray.py 2010-07-18 20:55:33 UTC (rev 8508)
+++ trunk/numpy/core/tests/test_multiarray.py 2010-07-18 20:55:56 UTC (rev 8509)
@@ -671,6 +671,27 @@
class TestArgmax(TestCase):
+
+ nan_arr = [
+ ([0, 1, 2, 3, np.nan], 4),
+ ([0, 1, 2, np.nan, 3], 3),
+ ([np.nan, 0, 1, 2, 3], 0),
+ ([np.nan, 0, np.nan, 2, 3], 0),
+ ([0, 1, 2, 3, complex(0,np.nan)], 4),
+ ([0, 1, 2, 3, complex(np.nan,0)], 4),
+ ([0, 1, 2, complex(np.nan,0), 3], 3),
+ ([0, 1, 2, complex(0,np.nan), 3], 3),
+ ([complex(0,np.nan), 0, 1, 2, 3], 0),
+ ([complex(np.nan, np.nan), 0, 1, 2, 3], 0),
+ ([complex(np.nan, 0), complex(np.nan, 2), complex(np.nan, 1)], 0),
+ ([complex(np.nan, np.nan), complex(np.nan, 2), complex(np.nan, 1)], 0),
+ ([complex(np.nan, 0), complex(np.nan, 2), complex(np.nan, np.nan)], 0),
+
+ ([complex(0, 0), complex(0, 2), complex(0, 1)], 1),
+ ([complex(1, 0), complex(0, 2), complex(0, 1)], 0),
+ ([complex(1, 0), complex(0, 2), complex(1, 1)], 2),
+ ]
+
def test_all(self):
a = np.random.normal(0,1,(4,5,6,7,8))
for i in xrange(a.ndim):
@@ -680,6 +701,12 @@
axes.remove(i)
assert all(amax == aargmax.choose(*a.transpose(i,*axes)))
+ def test_combinations(self):
+ for arr, pos in self.nan_arr:
+ assert_equal(np.argmax(arr), pos, err_msg="%r"%arr)
+ assert_equal(arr[np.argmax(arr)], np.max(arr), err_msg="%r"%arr)
+
+
class TestMinMax(TestCase):
def test_scalar(self):
assert_raises(ValueError, np.amax, 1, 1)
More information about the Numpy-svn
mailing list