[Numpy-svn] r6194 - in trunk/numpy/ma: . tests

numpy-svn at scipy.org numpy-svn at scipy.org
Tue Dec 23 18:43:48 EST 2008


Author: pierregm
Date: 2008-12-23 17:43:43 -0600 (Tue, 23 Dec 2008)
New Revision: 6194

Modified:
   trunk/numpy/ma/core.py
   trunk/numpy/ma/tests/test_core.py
   trunk/numpy/ma/tests/test_mrecords.py
   trunk/numpy/ma/testutils.py
Log:
testutils:
* assert_equal: use assert_array_equal on records
* assert_array_compare: prevent the common mask from being back-propagated to the initial input arrays.
* assert_array_equal: use operator.__eq__ instead of ma.equal
* assert_array_less: use operator.__lt__ instead of ma.less

core:
* Fixed _check_fill_value for nested flexible types
* Added a newtype option to _recursive_make_descr (the recursive helper factored out of make_mask_descr)
* Fixed mask_or for nested flexible types
* Fixed the printing of masked arrays with flexible types.


Modified: trunk/numpy/ma/core.py
===================================================================
--- trunk/numpy/ma/core.py	2008-12-23 09:02:15 UTC (rev 6193)
+++ trunk/numpy/ma/core.py	2008-12-23 23:43:43 UTC (rev 6194)
@@ -217,6 +217,28 @@
         raise TypeError(errmsg)
 
 
+def _recursive_set_default_fill_value(dtypedescr):
+    deflist = []
+    for currentdescr in dtypedescr:
+        currenttype = currentdescr[1]
+        if isinstance(currenttype, list):
+            deflist.append(tuple(_recursive_set_default_fill_value(currenttype)))
+        else:
+            deflist.append(default_fill_value(np.dtype(currenttype)))
+    return tuple(deflist)
+
+def _recursive_set_fill_value(fillvalue, dtypedescr):
+    fillvalue = np.resize(fillvalue, len(dtypedescr))
+    output_value = []
+    for (fval, descr) in zip(fillvalue, dtypedescr):
+        cdtype = descr[1]
+        if isinstance(cdtype, list):
+            output_value.append(tuple(_recursive_set_fill_value(fval, cdtype)))
+        else:
+            output_value.append(np.array(fval, dtype=cdtype).item())
+    return tuple(output_value)
+
+
 def _check_fill_value(fill_value, ndtype):
     """
     Private function validating the given `fill_value` for the given dtype.
@@ -233,10 +255,9 @@
     fields = ndtype.fields
     if fill_value is None:
         if fields:
-            fdtype = [(_[0], _[1]) for _ in ndtype.descr]
-            fill_value = np.array(tuple([default_fill_value(fields[n][0])
-                                         for n in ndtype.names]),
-                                  dtype=fdtype)
+            descr = ndtype.descr
+            fill_value = np.array(_recursive_set_default_fill_value(descr),
+                                  dtype=ndtype)
         else:
             fill_value = default_fill_value(ndtype)
     elif fields:
@@ -248,10 +269,9 @@
                 err_msg = "Unable to transform %s to dtype %s"
                 raise ValueError(err_msg % (fill_value, fdtype))
         else:
-            fval = np.resize(fill_value, len(ndtype.descr))
-            fill_value = [np.asarray(f).astype(desc[1]).item()
-                          for (f, desc) in zip(fval, ndtype.descr)]
-            fill_value = np.array(tuple(fill_value), copy=False, dtype=fdtype)
+            descr = ndtype.descr
+            fill_value = np.array(_recursive_set_fill_value(fill_value, descr),
+                                  dtype=ndtype)
     else:
         if isinstance(fill_value, basestring) and (ndtype.char not in 'SV'):
             fill_value = default_fill_value(ndtype)
@@ -831,35 +851,35 @@
 #####--------------------------------------------------------------------------
 #---- --- Mask creation functions ---
 #####--------------------------------------------------------------------------
+def _recursive_make_descr(datatype, newtype=bool_):
+    "Private function allowing recursion in make_descr."
+    # Do we have some name fields ?
+    if datatype.names:
+        descr = []
+        for name in datatype.names:
+            field = datatype.fields[name]
+            if len(field) == 3:
+                # Prepend the title to the name
+                name = (field[-1], name)
+            descr.append((name, _recursive_make_descr(field[0], newtype)))
+        return descr
+    # Is this some kind of composite a la (np.float,2)
+    elif datatype.subdtype:
+        mdescr = list(datatype.subdtype)
+        mdescr[0] = newtype
+        return tuple(mdescr)
+    else:
+        return newtype
 
 def make_mask_descr(ndtype):
     """Constructs a dtype description list from a given dtype.
     Each field is set to a bool.
 
     """
-    def _make_descr(datatype):
-        "Private function allowing recursion."
-        # Do we have some name fields ?
-        if datatype.names:
-            descr = []
-            for name in datatype.names:
-                field = datatype.fields[name]
-                if len(field) == 3:
-                    # Prepend the title to the name
-                    name = (field[-1], name)
-                descr.append((name, _make_descr(field[0])))
-            return descr
-        # Is this some kind of composite a la (np.float,2)
-        elif datatype.subdtype:
-            mdescr = list(datatype.subdtype)
-            mdescr[0] = np.dtype(bool)
-            return tuple(mdescr)
-        else:
-            return np.bool
     # Make sure we do have a dtype
     if not isinstance(ndtype, np.dtype):
         ndtype = np.dtype(ndtype)
-    return np.dtype(_make_descr(ndtype))
+    return np.dtype(_recursive_make_descr(ndtype, np.bool))
 
 def get_mask(a):
     """Return the mask of a, if any, or nomask.
@@ -988,7 +1008,17 @@
     ValueError
         If m1 and m2 have different flexible dtypes.
 
-     """
+    """
+    def _recursive_mask_or(m1, m2, newmask):
+        names = m1.dtype.names
+        for name in names:
+            current1 = m1[name]
+            if current1.dtype.names:
+                _recursive_mask_or(current1, m2[name], newmask[name])
+            else:
+                 umath.logical_or(current1, m2[name], newmask[name])
+        return
+    #
     if (m1 is nomask) or (m1 is False):
         dtype = getattr(m2, 'dtype', MaskType)
         return make_mask(m2, copy=copy, shrink=shrink, dtype=dtype)
@@ -1002,8 +1032,7 @@
         raise ValueError("Incompatible dtypes '%s'<>'%s'" % (dtype1, dtype2))
     if dtype1.names:
         newmask = np.empty_like(m1)
-        for n in dtype1.names:
-            newmask[n] = umath.logical_or(m1[n], m2[n])
+        _recursive_mask_or(m1, m2, newmask)
         return newmask
     return make_mask(umath.logical_or(m1, m2), copy=copy, shrink=shrink)
 
@@ -1291,6 +1320,22 @@
 #if you single index into a masked location you get this object.
 masked_print_option = _MaskedPrintOption('--')
 
+
+def _recursive_printoption(result, mask, printopt):
+    """
+    Puts printoptions in result where mask is True.
+    Private function allowing for recursion
+    """
+    names = result.dtype.names
+    for name in names:
+        (curdata, curmask) = (result[name], mask[name])
+        if curdata.dtype.names:
+            _recursive_printoption(curdata, curmask, printopt)
+        else:
+            np.putmask(curdata, curmask, printopt)
+    return
+
+
 #####--------------------------------------------------------------------------
 #---- --- MaskedArray class ---
 #####--------------------------------------------------------------------------
@@ -2184,13 +2229,9 @@
                     res = self._data.astype("|O8")
                     res[m] = f
                 else:
-                    rdtype = [list(_) for _ in self.dtype.descr]
-                    for r in rdtype:
-                        r[1] = '|O8'
-                    rdtype = [tuple(_) for _ in rdtype]
+                    rdtype = _recursive_make_descr(self.dtype, "|O8")
                     res = self._data.astype(rdtype)
-                    for field in names:
-                        np.putmask(res[field], m[field], f)
+                    _recursive_printoption(res, m, f)
         else:
             res = self.filled(self.fill_value)
         return str(res)

Modified: trunk/numpy/ma/tests/test_core.py
===================================================================
--- trunk/numpy/ma/tests/test_core.py	2008-12-23 09:02:15 UTC (rev 6193)
+++ trunk/numpy/ma/tests/test_core.py	2008-12-23 23:43:43 UTC (rev 6194)
@@ -483,6 +483,16 @@
         y._optinfo['info'] = '!!!'
         assert_equal(x._optinfo['info'], '???')
 
+
+    def test_fancy_printoptions(self):
+        "Test printing a masked array w/ fancy dtype."
+        fancydtype = np.dtype([('x', int), ('y', [('t', int), ('s', float)])])
+        test = array([(1, (2, 3.0)), (4, (5, 6.0))],
+                     mask=[(1, (0, 1)), (0, (1, 0))],
+                     dtype=fancydtype)
+        control = "[(--, (2, --)) (4, (--, 6.0))]"
+        assert_equal(str(test), control)
+
 #------------------------------------------------------------------------------
 
 class TestMaskedArrayArithmetic(TestCase):
@@ -1049,19 +1059,19 @@
         # The shape shouldn't matter
         ndtype = [('f0', float, (2, 2))]
         control = np.array((default_fill_value(0.),),
-                           dtype=[('f0',float)])
+                           dtype=[('f0',float)]).astype(ndtype)
         assert_equal(_check_fill_value(None, ndtype), control)
-        control = np.array((0,), dtype=[('f0',float)])
+        control = np.array((0,), dtype=[('f0',float)]).astype(ndtype)
         assert_equal(_check_fill_value(0, ndtype), control)
         #
         ndtype = np.dtype("int, (2,3)float, float")
         control = np.array((default_fill_value(0),
                             default_fill_value(0.),
                             default_fill_value(0.),),
-                           dtype="int, float, float")
+                           dtype="int, float, float").astype(ndtype)
         test = _check_fill_value(None, ndtype)
         assert_equal(test, control)
-        control = np.array((0,0,0), dtype="int, float, float")
+        control = np.array((0,0,0), dtype="int, float, float").astype(ndtype)
         assert_equal(_check_fill_value(0, ndtype), control)
 
 #------------------------------------------------------------------------------
@@ -1912,8 +1922,8 @@
                      dtype=ndtype)
         data[[0,1,2,-1]] = masked
         record = data.torecords()
-        assert_equal(record['_data'], data._data)
-        assert_equal(record['_mask'], data._mask)
+        assert_equal_records(record['_data'], data._data)
+        assert_equal_records(record['_mask'], data._mask)
 
 #------------------------------------------------------------------------------
 
@@ -2531,6 +2541,12 @@
             test = mask_or(mask, other)
         except ValueError:
             pass
+        # Using nested arrays
+        dtype = [('a', np.bool), ('b', [('ba', np.bool), ('bb', np.bool)])]
+        amask = np.array([(0, (1, 0)), (0, (1, 0))], dtype=dtype)
+        bmask = np.array([(1, (0, 1)), (0, (0, 0))], dtype=dtype)
+        cntrl = np.array([(1, (1, 1)), (0, (1, 0))], dtype=dtype)
+        assert_equal(mask_or(amask, bmask), cntrl)
 
 
     def test_flatten_mask(self):

Modified: trunk/numpy/ma/tests/test_mrecords.py
===================================================================
--- trunk/numpy/ma/tests/test_mrecords.py	2008-12-23 09:02:15 UTC (rev 6193)
+++ trunk/numpy/ma/tests/test_mrecords.py	2008-12-23 23:43:43 UTC (rev 6194)
@@ -334,8 +334,8 @@
         mult[0] = masked
         mult[1] = (1, 1, 1)
         mult.filled(0)
-        assert_equal(mult.filled(0),
-                     np.array([(0,0,0),(1,1,1)], dtype=mult.dtype))
+        assert_equal_records(mult.filled(0),
+                             np.array([(0,0,0),(1,1,1)], dtype=mult.dtype))
 
 
 class TestView(TestCase):

Modified: trunk/numpy/ma/testutils.py
===================================================================
--- trunk/numpy/ma/testutils.py	2008-12-23 09:02:15 UTC (rev 6193)
+++ trunk/numpy/ma/testutils.py	2008-12-23 23:43:43 UTC (rev 6194)
@@ -110,14 +110,14 @@
         return _assert_equal_on_sequences(actual.tolist(),
                                           desired.tolist(),
                                           err_msg='')
-    elif actual_dtype.char in "OV" and desired_dtype.char in "OV":
-        if (actual_dtype != desired_dtype) and actual_dtype:
-            msg = build_err_msg([actual_dtype, desired_dtype],
-                                err_msg, header='', names=('actual', 'desired'))
-            raise ValueError(msg)
-        return _assert_equal_on_sequences(actual.tolist(),
-                                          desired.tolist(),
-                                          err_msg='')
+#    elif actual_dtype.char in "OV" and desired_dtype.char in "OV":
+#        if (actual_dtype != desired_dtype) and actual_dtype:
+#            msg = build_err_msg([actual_dtype, desired_dtype],
+#                                err_msg, header='', names=('actual', 'desired'))
+#            raise ValueError(msg)
+#        return _assert_equal_on_sequences(actual.tolist(),
+#                                          desired.tolist(),
+#                                          err_msg='')
     return assert_array_equal(actual, desired, err_msg)
 
 
@@ -171,15 +171,14 @@
 #    yf = filled(y)
     # Allocate a common mask and refill
     m = mask_or(getmask(x), getmask(y))
-    x = masked_array(x, copy=False, mask=m, subok=False)
-    y = masked_array(y, copy=False, mask=m, subok=False)
+    x = masked_array(x, copy=False, mask=m, keep_mask=False, subok=False)
+    y = masked_array(y, copy=False, mask=m, keep_mask=False, subok=False)
     if ((x is masked) and not (y is masked)) or \
         ((y is masked) and not (x is masked)):
         msg = build_err_msg([x, y], err_msg=err_msg, verbose=verbose,
                             header=header, names=('x', 'y'))
         raise ValueError(msg)
     # OK, now run the basic tests on filled versions
-    comparison = getattr(np, comparison.__name__, lambda x,y: True)
     return utils.assert_array_compare(comparison,
                                       x.filled(fill_value),
                                       y.filled(fill_value),
@@ -189,7 +188,8 @@
 
 def assert_array_equal(x, y, err_msg='', verbose=True):
     """Checks the elementwise equality of two masked arrays."""
-    assert_array_compare(equal, x, y, err_msg=err_msg, verbose=verbose,
+    assert_array_compare(operator.__eq__, x, y,
+                         err_msg=err_msg, verbose=verbose,
                          header='Arrays are not equal')
 
 
@@ -223,7 +223,8 @@
 
 def assert_array_less(x, y, err_msg='', verbose=True):
     "Checks that x is smaller than y elementwise."
-    assert_array_compare(less, x, y, err_msg=err_msg, verbose=verbose,
+    assert_array_compare(operator.__lt__, x, y,
+                         err_msg=err_msg, verbose=verbose,
                          header='Arrays are not less-ordered')
 
 




More information about the Numpy-svn mailing list