[Numpy-svn] r5630 - in trunk/numpy/lib: . tests

numpy-svn at scipy.org numpy-svn at scipy.org
Tue Aug 12 02:56:13 EDT 2008


Author: stefan
Date: 2008-08-12 01:56:11 -0500 (Tue, 12 Aug 2008)
New Revision: 5630

Modified:
   trunk/numpy/lib/function_base.py
   trunk/numpy/lib/tests/test_function_base.py
Log:
Make NaN operations more consistent: nansum, nanmin, nanargmin, nanmax and nanargmax now share a single _nanop helper.


Modified: trunk/numpy/lib/function_base.py
===================================================================
--- trunk/numpy/lib/function_base.py	2008-08-10 22:21:49 UTC (rev 5629)
+++ trunk/numpy/lib/function_base.py	2008-08-12 06:56:11 UTC (rev 5630)
@@ -1356,6 +1356,38 @@
     """
     return _insert(arr, mask, vals)
 
+def _nanop(op, fill, a, axis=None):
+    """General operation on arrays with not-a-number values.
+
+    Parameters
+    ----------
+    op : callable
+        Operation to perform, called as ``op(y, axis=axis)``.
+    fill : float
+        NaN values are set to fill before doing the operation.
+    a : array-like
+        Input array.
+    axis : {int, None}, optional
+        Axis along which the operation is computed.
+        By default the input is flattened.
+
+    Returns
+    -------
+    y : {ndarray, scalar}
+        Processed data; a non-empty slice holding only NaNs gives NaN.
+    """
+    y = array(a, subok=True)
+    mask = isnan(y)
+    # Integer arrays cannot hold NaN, so there is nothing to fill.
+    if not issubclass(y.dtype.type, np.integer):
+        y[mask] = fill
+    res = op(y, axis=axis)
+    # The fill value is meaningless for slices that were entirely NaN.
+    all_nan = mask.all(axis=axis)
+    if y.size == 0 or not all_nan.any():
+        return res
+    return np.nan if all_nan.ndim == 0 else np.where(all_nan, np.nan, res)
+
 def nansum(a, axis=None):
     """
     Sum the array along the given axis, treating NaNs as zero.
@@ -1381,10 +1413,7 @@
     array([ 2.,  1.])
 
     """
-    y = array(a,subok=True)
-    if not issubclass(y.dtype.type, _nx.integer):
-        y[isnan(a)] = 0
-    return y.sum(axis)
+    return _nanop(np.sum, 0, a, axis)
 
 def nanmin(a, axis=None):
     """
@@ -1413,10 +1442,7 @@
     array([ 1.,  3.])
 
     """
-    y = array(a,subok=True)
-    if not issubclass(y.dtype.type, _nx.integer):
-        y[isnan(a)] = _nx.inf
-    return y.min(axis)
+    return _nanop(np.min, np.inf, a, axis)
 
 def nanargmin(a, axis=None):
     """
@@ -1426,10 +1452,7 @@
     Refer to `numpy.nanargmax` for detailed documentation.
 
     """
-    y = array(a, subok=True)
-    if not issubclass(y.dtype.type, _nx.integer):
-        y[isnan(a)] = _nx.inf
-    return y.argmin(axis)
+    return _nanop(np.argmin, np.inf, a, axis)
 
 def nanmax(a, axis=None):
     """
@@ -1458,10 +1481,7 @@
     array([ 2.,  3.])
 
     """
-    y = array(a, subok=True)
-    if not issubclass(y.dtype.type, _nx.integer):
-        y[isnan(a)] = -_nx.inf
-    return y.max(axis)
+    return _nanop(np.max, -np.inf, a, axis)
 
 def nanargmax(a, axis=None):
     """
@@ -1497,10 +1517,7 @@
     array([1, 0])
 
     """
-    y = array(a,subok=True)
-    if not issubclass(y.dtype.type, _nx.integer):
-        y[isnan(a)] = -_nx.inf
-    return y.argmax(axis)
+    return _nanop(np.argmax, -np.inf, a, axis)
 
 def disp(mesg, device=None, linefeed=True):
     """Display a message to the given device (default is sys.stdout)

Modified: trunk/numpy/lib/tests/test_function_base.py
===================================================================
--- trunk/numpy/lib/tests/test_function_base.py	2008-08-10 22:21:49 UTC (rev 5629)
+++ trunk/numpy/lib/tests/test_function_base.py	2008-08-12 06:56:11 UTC (rev 5630)
@@ -685,6 +685,7 @@
                             array([[ 0.01319214,  0.11704017,  0.1630199 ],
                                    [ 0.37910852,  0.87964135,  0.34306596],
                                    [ 0.72687499,  0.23913896,  0.33850425]]))
+        assert nanmin([nan, nan]) is nan
 
     def test_nanargmin(self):
         assert_almost_equal(nanargmin(self.A), 1)




More information about the Numpy-svn mailing list