[Numpy-svn] r6315 - in trunk/numpy/ma: . tests
numpy-svn at scipy.org
numpy-svn at scipy.org
Fri Jan 9 15:18:17 EST 2009
Author: pierregm
Date: 2009-01-09 14:18:12 -0600 (Fri, 09 Jan 2009)
New Revision: 6315
Modified:
trunk/numpy/ma/core.py
trunk/numpy/ma/tests/test_core.py
Log:
* Added flatten_structured_array
* Fixed _get_recordmask for nested structures
Modified: trunk/numpy/ma/core.py
===================================================================
--- trunk/numpy/ma/core.py 2009-01-09 19:37:16 UTC (rev 6314)
+++ trunk/numpy/ma/core.py 2009-01-09 20:18:12 UTC (rev 6315)
@@ -1483,6 +1483,56 @@
return d
def flatten_structured_array(a):
    """
    Flatten a structured array.

    The fields of each record, including the fields of nested
    sub-structures, are laid out along a new trailing axis. The datatype of
    the output is the one ``np.array`` selects to hold all the (nested)
    fields together, typically the largest datatype of the fields.

    Parameters
    ----------
    a : array_like
        Structured ndarray or MaskedArray to flatten.

    Returns
    -------
    output : var
        Flattened MaskedArray if the input is a MaskedArray, standard
        ndarray otherwise. For non-empty input, the output has shape
        ``(a.size, nfields)`` when ``a`` has at most one dimension and
        ``a.shape + (nfields,)`` otherwise, where ``nfields`` is the number
        of leaf fields. The mask of a MaskedArray output is boolean and
        has the same shape as the data.

    Examples
    --------
    >>> ndtype = [('a', int), ('b', float)]
    >>> a = np.array([(1, 1), (2, 2)], dtype=ndtype)
    >>> flatten_structured_array(a)
    array([[1., 1.],
           [2., 2.]])

    """
    #
    def flatten_sequence(iterable):
        """
        Flatten a compound of nested sequences into its leaves.

        Only tuples (nested records), lists and ndarrays (subarray fields)
        are descended into; every other value, including str and bytes, is
        a leaf. Descending into strings would recurse forever on Python 3,
        where a one-character string iterates to itself.
        """
        for elm in iter(iterable):
            if isinstance(elm, (tuple, list, np.ndarray)):
                for f in flatten_sequence(elm):
                    yield f
            else:
                yield elm
    #
    a = np.asanyarray(a)
    inishape = a.shape
    a = a.ravel()
    if isinstance(a, MaskedArray):
        out = np.array([tuple(flatten_sequence(d.item())) for d in a._data])
        out = out.view(MaskedArray)
        # Force a boolean mask: without an explicit dtype, an empty input
        # would produce a float64 mask.
        out._mask = np.array([tuple(flatten_sequence(d.item()))
                              for d in getmaskarray(a)], dtype=bool)
    else:
        out = np.array([tuple(flatten_sequence(d.item())) for d in a])
    if len(inishape) > 1:
        # Restore the leading dimensions. Use reshape (rather than assigning
        # to .shape) so that a MaskedArray reshapes its mask with the data.
        out = out.reshape(tuple(inishape) + out.shape[1:])
    return out
+
+
+
+
class MaskedArray(ndarray):
"""
Arrays with possibly masked values. Masked values of True
@@ -2021,34 +2071,28 @@
# return self._mask.reshape(self.shape)
return self._mask
mask = property(fget=_get_mask, fset=__setmask__, doc="Mask")
- #
- def _getrecordmask(self):
- """Return the mask of the records.
+
+
+ def _get_recordmask(self):
+ """
+ Return the mask of the records.
A record is masked when all the fields are masked.
"""
_mask = ndarray.__getattribute__(self, '_mask').view(ndarray)
if _mask.dtype.names is None:
return _mask
- if _mask.size > 1:
- axis = 1
- else:
- axis = None
- #
- try:
- return _mask.view((bool_, len(self.dtype))).all(axis)
- except ValueError:
- # In case we have nested fields...
- return np.all([[f[n].all() for n in _mask.dtype.names]
- for f in _mask], axis=axis)
+ return np.all(flatten_structured_array(_mask), axis=-1)
- def _setrecordmask(self):
+
+ def _set_recordmask(self):
"""Return the mask of the records.
A record is masked when all the fields are masked.
"""
raise NotImplementedError("Coming soon: setting the mask per records!")
- recordmask = property(fget=_getrecordmask)
+ recordmask = property(fget=_get_recordmask)
+
#............................................
def harden_mask(self):
"""Force the mask to hard.
Modified: trunk/numpy/ma/tests/test_core.py
===================================================================
--- trunk/numpy/ma/tests/test_core.py 2009-01-09 19:37:16 UTC (rev 6314)
+++ trunk/numpy/ma/tests/test_core.py 2009-01-09 20:18:12 UTC (rev 6315)
@@ -482,8 +482,12 @@
test = a.filled(0)
control = np.array([(1, (0, 1)), (2, (2, 0))], dtype=ndtype)
assert_equal(test, control)
-
+ #
+ test = a['B'].filled(0)
+ control = np.array([(0, 1), (2, 0)], dtype=a['B'].dtype)
+ assert_equal(test, control)
+
def test_optinfo_propagation(self):
"Checks that _optinfo dictionary isn't back-propagated"
x = array([1,2,3,], dtype=float)
@@ -503,6 +507,45 @@
control = "[(--, (2, --)) (4, (--, 6.0))]"
assert_equal(str(test), control)
+
+ def test_flatten_structured_array(self):
+ "Test flatten_structured_array on arrays"
+ # On ndarray
+ ndtype = [('a', int), ('b', float)]
+ a = np.array([(1, 1), (2, 2)], dtype=ndtype)
+ test = flatten_structured_array(a)
+ control = np.array([[1., 1.], [2., 2.]], dtype=np.float)
+ assert_equal(test, control)
+ assert_equal(test.dtype, control.dtype)
+ # On masked_array
+ a = ma.array([(1, 1), (2, 2)], mask=[(0, 1), (1, 0)], dtype=ndtype)
+ test = flatten_structured_array(a)
+ control = ma.array([[1., 1.], [2., 2.]],
+ mask=[[0, 1], [1, 0]], dtype=np.float)
+ assert_equal(test, control)
+ assert_equal(test.dtype, control.dtype)
+ assert_equal(test.mask, control.mask)
+ # On masked array with nested structure
+ ndtype = [('a', int), ('b', [('ba', int), ('bb', float)])]
+ a = ma.array([(1, (1, 1.1)), (2, (2, 2.2))],
+ mask=[(0, (1, 0)), (1, (0, 1))], dtype=ndtype)
+ test = flatten_structured_array(a)
+ control = ma.array([[1., 1., 1.1], [2., 2., 2.2]],
+ mask=[[0, 1, 0], [1, 0, 1]], dtype=np.float)
+ assert_equal(test, control)
+ assert_equal(test.dtype, control.dtype)
+ assert_equal(test.mask, control.mask)
+ # Keeping the initial shape
+ ndtype = [('a', int), ('b', float)]
+ a = np.array([[(1, 1),], [(2, 2),]], dtype=ndtype)
+ test = flatten_structured_array(a)
+ control = np.array([[[1., 1.],], [[2., 2.],]], dtype=np.float)
+ assert_equal(test, control)
+ assert_equal(test.dtype, control.dtype)
+
+
+
+
#------------------------------------------------------------------------------
class TestMaskedArrayArithmetic(TestCase):
More information about the Numpy-svn
mailing list