[Numpy-svn] r5630 - in trunk/numpy/lib: . tests
numpy-svn at scipy.org
numpy-svn at scipy.org
Tue Aug 12 02:56:13 EDT 2008
Author: stefan
Date: 2008-08-12 01:56:11 -0500 (Tue, 12 Aug 2008)
New Revision: 5630
Modified:
trunk/numpy/lib/function_base.py
trunk/numpy/lib/tests/test_function_base.py
Log:
More consistent nan-operations.
Modified: trunk/numpy/lib/function_base.py
===================================================================
--- trunk/numpy/lib/function_base.py 2008-08-10 22:21:49 UTC (rev 5629)
+++ trunk/numpy/lib/function_base.py 2008-08-12 06:56:11 UTC (rev 5630)
@@ -1356,6 +1356,38 @@
"""
return _insert(arr, mask, vals)
+def _nanop(op, fill, a, axis=None):
+ """
+ General operation on arrays with not-a-number values.
+
+ Parameters
+ ----------
+ op : callable
+ Operation to perform.
+ fill : float
+ NaN values are set to fill before doing the operation.
+ a : array-like
+ Input array.
+ axis : {int, None}, optional
+ Axis along which the operation is computed.
+ By default the input is flattened.
+
+ Returns
+ -------
+ y : {ndarray, scalar}
+ Processed data.
+
+ """
+ y = array(a,subok=True)
+ mask = isnan(a)
+ if mask.all():
+ return np.nan
+
+ if not issubclass(y.dtype.type, np.integer):
+ y[mask] = fill
+
+ return op(y, axis=axis)
+
def nansum(a, axis=None):
"""
Sum the array along the given axis, treating NaNs as zero.
@@ -1381,10 +1413,7 @@
array([ 2., 1.])
"""
- y = array(a,subok=True)
- if not issubclass(y.dtype.type, _nx.integer):
- y[isnan(a)] = 0
- return y.sum(axis)
+ return _nanop(np.sum, 0, a, axis)
def nanmin(a, axis=None):
"""
@@ -1413,10 +1442,7 @@
array([ 1., 3.])
"""
- y = array(a,subok=True)
- if not issubclass(y.dtype.type, _nx.integer):
- y[isnan(a)] = _nx.inf
- return y.min(axis)
+ return _nanop(np.min, np.inf, a, axis)
def nanargmin(a, axis=None):
"""
@@ -1426,10 +1452,7 @@
Refer to `numpy.nanargmax` for detailed documentation.
"""
- y = array(a, subok=True)
- if not issubclass(y.dtype.type, _nx.integer):
- y[isnan(a)] = _nx.inf
- return y.argmin(axis)
+ return _nanop(np.argmin, np.inf, a, axis)
def nanmax(a, axis=None):
"""
@@ -1458,10 +1481,7 @@
array([ 2., 3.])
"""
- y = array(a, subok=True)
- if not issubclass(y.dtype.type, _nx.integer):
- y[isnan(a)] = -_nx.inf
- return y.max(axis)
+ return _nanop(np.max, -np.inf, a, axis)
def nanargmax(a, axis=None):
"""
@@ -1497,10 +1517,7 @@
array([1, 0])
"""
- y = array(a,subok=True)
- if not issubclass(y.dtype.type, _nx.integer):
- y[isnan(a)] = -_nx.inf
- return y.argmax(axis)
+ return _nanop(np.argmax, -np.inf, a, axis)
def disp(mesg, device=None, linefeed=True):
"""Display a message to the given device (default is sys.stdout)
Modified: trunk/numpy/lib/tests/test_function_base.py
===================================================================
--- trunk/numpy/lib/tests/test_function_base.py 2008-08-10 22:21:49 UTC (rev 5629)
+++ trunk/numpy/lib/tests/test_function_base.py 2008-08-12 06:56:11 UTC (rev 5630)
@@ -685,6 +685,7 @@
array([[ 0.01319214, 0.11704017, 0.1630199 ],
[ 0.37910852, 0.87964135, 0.34306596],
[ 0.72687499, 0.23913896, 0.33850425]]))
+ assert nanmin([nan, nan]) is nan
def test_nanargmin(self):
assert_almost_equal(nanargmin(self.A), 1)
More information about the Numpy-svn mailing list