[Numpy-svn] r4749 - in branches/maskedarray/numpy/ma: . tests
numpy-svn at scipy.org
numpy-svn at scipy.org
Wed Jan 23 20:43:08 EST 2008
Author: pierregm
Date: 2008-01-23 19:43:05 -0600 (Wed, 23 Jan 2008)
New Revision: 4749
Modified:
branches/maskedarray/numpy/ma/core.py
branches/maskedarray/numpy/ma/tests/test_core.py
Log:
ma.core : add the compress method/function
Modified: branches/maskedarray/numpy/ma/core.py
===================================================================
--- branches/maskedarray/numpy/ma/core.py 2008-01-23 22:13:10 UTC (rev 4748)
+++ branches/maskedarray/numpy/ma/core.py 2008-01-24 01:43:05 UTC (rev 4749)
@@ -31,7 +31,8 @@
'default_fill_value', 'diagonal', 'divide', 'dump', 'dumps',
'empty', 'empty_like', 'equal', 'exp',
'fabs', 'fmod', 'filled', 'floor', 'floor_divide','fix_invalid',
- 'getmask', 'getmaskarray', 'greater', 'greater_equal', 'hypot',
+ 'getdata','getmask', 'getmaskarray', 'greater', 'greater_equal',
+ 'hypot',
'ids', 'inner', 'innerproduct',
'isMA', 'isMaskedArray', 'is_mask', 'is_masked', 'isarray',
'left_shift', 'less', 'less_equal', 'load', 'loads', 'log', 'log10',
@@ -3099,11 +3100,35 @@
m = make_mask(mask_or(m, getmask(indices)), copy=0, shrink=True)
return masked_array(d, mask=m)
-def compress(a, condition):
+def compress(a, condition, axis=None, out=None):
"""Return a where condition is True.
+ If condition is a MaskedArray, missing values are considered as False.
+
+ Returns
+ -------
+ A MaskedArray (of a's own MaskedArray subclass, when a is one) holding the selected entries.
+
+ Notes
+ -----
+ Note the difference with compressed():
+ the output of compress has a mask; the output of compressed does not.
"""
- return a[condition]
+ # Split a into its raw data and its mask (the mask may be nomask)
+ (_data, _mask) = (getdata(a), getmask(a))
+ # The result keeps a's MaskedArray subclass; any other input yields a MaskedArray
+ if isinstance(a, MaskedArray):
+ _view = type(a)
+ else:
+ _view = MaskedArray
+ # Masked entries of the condition are filled with False, so they are never selected
+ condition = filled(condition, False)
+ # Compress the raw data, then view the result as the output type
+ _new = ndarray.compress(_data, condition, axis=axis, out=out).view(_view)
+ _new._update_from(a)  # presumably carries over a's metadata (e.g. fill_value) -- confirm
+ if _mask is not nomask:  # a has an explicit mask: compress it exactly like the data
+ _new._mask = _mask.compress(condition, axis=axis)
+ return _new
def round_(a, decimals=0, out=None):
"""Return a copy of a, rounded to 'decimals' places.
Modified: branches/maskedarray/numpy/ma/tests/test_core.py
===================================================================
--- branches/maskedarray/numpy/ma/tests/test_core.py 2008-01-23 22:13:10 UTC (rev 4748)
+++ branches/maskedarray/numpy/ma/tests/test_core.py 2008-01-24 01:43:05 UTC (rev 4749)
@@ -1389,12 +1389,29 @@
assert_equal(mxx, [1,2,30,4,5,60])
def test_compress(self):
- a = array([1,2,3],mask=[True,False,False])
- b = compress(a,a<3)
- assert_equal(b,[1,2])
- assert_equal(b.mask,[True,False])
+ "Test MaskedArray.compress on 1D and 2D input, with and without axis"
+ a = masked_array([10, 20, 30, 40], fill_value=9999)
+ condition = (a > 15) & (a < 35)
+ assert_equal(a.compress(condition),[20,30])
+ # Masked entries stay masked in the result, and fill_value is kept
+ a[1] = masked
+ b = a.compress(condition)
+ assert_equal(b._data,[20,30])
+ assert_equal(b._mask,[1,0])
+ assert_equal(b.fill_value,9999)
+ # 2D input, no axis: the condition applies to the flattened data
+ a = masked_array([[10,20,30],[40,50,60]], mask=[[0,0,1],[1,0,0]])
+ b = a.compress(a.ravel() >= 22)
+ assert_equal(b._data, [50, 60])
+ assert_equal(b._mask, [0,0])
+ # 2D input, axis=1: select columns using a plain ndarray condition
+ x = numpy.array([3,1,2])
+ b = a.compress(x >= 2, axis=1)
+ assert_equal(b._data, [[10,30],[40,60]])
+ assert_equal(b._mask, [[0,1],[1,0]])
+
#..............................................................................
###############################################################################
More information about the Numpy-svn mailing list