[Numpy-svn] r4749 - in branches/maskedarray/numpy/ma: . tests

numpy-svn at scipy.org numpy-svn at scipy.org
Wed Jan 23 20:43:08 EST 2008


Author: pierregm
Date: 2008-01-23 19:43:05 -0600 (Wed, 23 Jan 2008)
New Revision: 4749

Modified:
   branches/maskedarray/numpy/ma/core.py
   branches/maskedarray/numpy/ma/tests/test_core.py
Log:
ma.core: add the compress method/function

Modified: branches/maskedarray/numpy/ma/core.py
===================================================================
--- branches/maskedarray/numpy/ma/core.py	2008-01-23 22:13:10 UTC (rev 4748)
+++ branches/maskedarray/numpy/ma/core.py	2008-01-24 01:43:05 UTC (rev 4749)
@@ -31,7 +31,8 @@
            'default_fill_value', 'diagonal', 'divide', 'dump', 'dumps',
            'empty', 'empty_like', 'equal', 'exp',
            'fabs', 'fmod', 'filled', 'floor', 'floor_divide','fix_invalid',
-           'getmask', 'getmaskarray', 'greater', 'greater_equal', 'hypot',
+           'getdata','getmask', 'getmaskarray', 'greater', 'greater_equal', 
+           'hypot',
            'ids', 'inner', 'innerproduct',
            'isMA', 'isMaskedArray', 'is_mask', 'is_masked', 'isarray',
            'left_shift', 'less', 'less_equal', 'load', 'loads', 'log', 'log10',
@@ -3099,11 +3100,35 @@
     m = make_mask(mask_or(m, getmask(indices)), copy=0, shrink=True)
     return masked_array(d, mask=m)
 
-def compress(a, condition):
+def compress(a, condition, axis=None, out=None):
     """Return a where condition is True.
+    Masked entries of condition are treated as False (not selected).
+
+    Returns
+    -------
+    A MaskedArray, of the same subclass as a when a is a MaskedArray.
+
+    Notes
+    -----
+    Unlike compressed(), whose output has no mask, compress keeps the
+    mask of the selected entries. out, if given, receives only the data.

     """
-    return a[condition]
+    # Split a into its data and its mask (nomask when a has none)
+    (_data, _mask) = (getdata(a), getmask(a))
+    # Keep a's MaskedArray subclass; any other input becomes a MaskedArray
+    if isinstance(a, MaskedArray):
+        _view = type(a)
+    else:
+        _view = MaskedArray
+    # Masked entries of the condition are filled with False (not selected)
+    condition = filled(condition, False)
+    # Select the data (into out, if given), then the mask along the same axis
+    _new = ndarray.compress(_data, condition, axis=axis, out=out).view(_view)
+    _new._update_from(a)  # presumably copies attributes such as fill_value
+    if _mask is not nomask:
+        _new._mask = _mask.compress(condition, axis=axis)
+    return _new
 
 def round_(a, decimals=0, out=None):
     """Return a copy of a, rounded to 'decimals' places.

Modified: branches/maskedarray/numpy/ma/tests/test_core.py
===================================================================
--- branches/maskedarray/numpy/ma/tests/test_core.py	2008-01-23 22:13:10 UTC (rev 4748)
+++ branches/maskedarray/numpy/ma/tests/test_core.py	2008-01-24 01:43:05 UTC (rev 4749)
@@ -1389,12 +1389,29 @@
         assert_equal(mxx, [1,2,30,4,5,60])
 
     def test_compress(self):
-        a = array([1,2,3],mask=[True,False,False])
-        b = compress(a,a<3)
-        assert_equal(b,[1,2])
-        assert_equal(b.mask,[True,False])
+        "Test MaskedArray.compress: data, mask, fill_value and axis handling"
+        a = masked_array([10, 20, 30, 40], fill_value=9999)
+        condition = (a > 15) & (a < 35)
+        assert_equal(a.compress(condition),[20,30])
+        # Selected entries that are masked stay masked; fill_value is kept
+        a[1] = masked
+        b = a.compress(condition)
+        assert_equal(b._data,[20,30])
+        assert_equal(b._mask,[1,0])
+        assert_equal(b.fill_value,9999)
+        # 2D, axis=None: flattened selection; masked condition entries act as False
+        a = masked_array([[10,20,30],[40,50,60]], mask=[[0,0,1],[1,0,0]])
+        b = a.compress(a.ravel() >= 22)
+        assert_equal(b._data, [50, 60])
+        assert_equal(b._mask, [0,0])
+        # axis=1: keep the columns where x >= 2, together with their mask
+        x = numpy.array([3,1,2])
+        b = a.compress(x >= 2, axis=1)
+        assert_equal(b._data, [[10,30],[40,60]])
+        assert_equal(b._mask, [[0,1],[1,0]])
 
 
+
 #..............................................................................
 
 ###############################################################################




More information about the Numpy-svn mailing list