[Numpy-svn] r6130 - in trunk/numpy/ma: . tests

numpy-svn at scipy.org numpy-svn at scipy.org
Mon Dec 1 21:40:25 EST 2008


Author: pierregm
Date: 2008-12-01 20:40:22 -0600 (Mon, 01 Dec 2008)
New Revision: 6130

Modified:
   trunk/numpy/ma/core.py
   trunk/numpy/ma/extras.py
   trunk/numpy/ma/tests/test_extras.py
Log:
* Fixed MaskedArray for nested dtype w/ input mask
* Fixed masked_all for nested dtype
* Fixed masked_all_like for nested dtype

Modified: trunk/numpy/ma/core.py
===================================================================
--- trunk/numpy/ma/core.py	2008-12-01 17:56:58 UTC (rev 6129)
+++ trunk/numpy/ma/core.py	2008-12-02 02:40:22 UTC (rev 6130)
@@ -1488,8 +1488,16 @@
                     _data._sharedmask = not copy
                 else:
                     if names_:
-                        for n in names_:
-                            _data._mask[n] |= mask[n]
+                        def _recursive_or(a, b):
+                            "do a|=b on each field of a, recursively"
+                            for name in a.dtype.names:
+                                (af, bf) = (a[name], b[name])
+                                if af.dtype.names:
+                                    _recursive_or(af, bf)
+                                else:
+                                    af |= bf
+                            return
+                        _recursive_or(_data._mask, mask)
                     else:
                         _data._mask = np.logical_or(mask, _data._mask)
                     _data._sharedmask = False

Modified: trunk/numpy/ma/extras.py
===================================================================
--- trunk/numpy/ma/extras.py	2008-12-01 17:56:58 UTC (rev 6129)
+++ trunk/numpy/ma/extras.py	2008-12-02 02:40:22 UTC (rev 6130)
@@ -32,8 +32,8 @@
 
 import core as ma
 from core import MaskedArray, MAError, add, array, asarray, concatenate, count,\
-    filled, getmask, getmaskarray, masked, masked_array, mask_or, nomask, ones,\
-    sort, zeros
+    filled, getmask, getmaskarray, make_mask_descr, masked, masked_array,\
+    mask_or, nomask, ones, sort, zeros
 #from core import *
 
 import numpy as np
@@ -77,7 +77,7 @@
 
     """
     a = masked_array(np.empty(shape, dtype),
-                     mask=np.ones(shape, bool))
+                     mask=np.ones(shape, make_mask_descr(dtype)))
     return a
 
 def masked_all_like(arr):
@@ -85,8 +85,8 @@
     the array `a`, where all the data are masked.
 
     """
-    a = masked_array(np.empty_like(arr),
-                     mask=np.ones(arr.shape, bool))
+    a = np.empty_like(arr).view(MaskedArray)
+    a._mask = np.ones(a.shape, dtype=make_mask_descr(a.dtype))
     return a
 
 

Modified: trunk/numpy/ma/tests/test_extras.py
===================================================================
--- trunk/numpy/ma/tests/test_extras.py	2008-12-01 17:56:58 UTC (rev 6129)
+++ trunk/numpy/ma/tests/test_extras.py	2008-12-02 02:40:22 UTC (rev 6130)
@@ -17,6 +17,62 @@
 from numpy.ma.core import *
 from numpy.ma.extras import *
 
+
+class TestGeneric(TestCase):
+    #
+    def test_masked_all(self):
+        "Tests masked_all"
+        # Standard dtype 
+        test = masked_all((2,), dtype=float)
+        control = array([1, 1], mask=[1, 1], dtype=float)
+        assert_equal(test, control)
+        # Flexible dtype
+        dt = np.dtype({'names': ['a', 'b'], 'formats': ['f', 'f']})
+        test = masked_all((2,), dtype=dt)
+        control = array([(0, 0), (0, 0)], mask=[(1, 1), (1, 1)], dtype=dt)
+        assert_equal(test, control)
+        test = masked_all((2,2), dtype=dt)
+        control = array([[(0, 0), (0, 0)], [(0, 0), (0, 0)]],
+                        mask=[[(1, 1), (1, 1)], [(1, 1), (1, 1)]],
+                        dtype=dt)
+        assert_equal(test, control)
+        # Nested dtype
+        dt = np.dtype([('a','f'), ('b', [('ba', 'f'), ('bb', 'f')])])
+        test = masked_all((2,), dtype=dt)
+        control = array([(1, (1, 1)), (1, (1, 1))],
+                         mask=[(1, (1, 1)), (1, (1, 1))], dtype=dt)
+        assert_equal(test, control)
+        test = masked_all((2,), dtype=dt)
+        control = array([(1, (1, 1)), (1, (1, 1))],
+                         mask=[(1, (1, 1)), (1, (1, 1))], dtype=dt)
+        assert_equal(test, control)
+        test = masked_all((1,1), dtype=dt)
+        control = array([[(1, (1, 1))]], mask=[[(1, (1, 1))]], dtype=dt)
+        assert_equal(test, control)
+
+
+    def test_masked_all_like(self):
+        "Tests masked_all_like"
+        # Standard dtype 
+        base = array([1, 2], dtype=float)
+        test = masked_all_like(base)
+        control = array([1, 1], mask=[1, 1], dtype=float)
+        assert_equal(test, control)
+        # Flexible dtype
+        dt = np.dtype({'names': ['a', 'b'], 'formats': ['f', 'f']})
+        base = array([(0, 0), (0, 0)], mask=[(1, 1), (1, 1)], dtype=dt)
+        test = masked_all_like(base)
+        control = array([(10, 10), (10, 10)], mask=[(1, 1), (1, 1)], dtype=dt)
+        assert_equal(test, control)
+        # Nested dtype
+        dt = np.dtype([('a','f'), ('b', [('ba', 'f'), ('bb', 'f')])])
+        control = array([(1, (1, 1)), (1, (1, 1))],
+                        mask=[(1, (1, 1)), (1, (1, 1))], dtype=dt)
+        test = masked_all_like(control)
+        assert_equal(test, control)
+        #
+
+
 class TestAverage(TestCase):
     "Several tests of average. Why so many ? Good point..."
     def test_testAverage1(self):




More information about the Numpy-svn mailing list