[Numpy-svn] r5835 - in trunk/numpy/ma: . tests

numpy-svn at scipy.org numpy-svn at scipy.org
Fri Sep 19 15:43:09 EDT 2008


Author: pierregm
Date: 2008-09-19 14:43:05 -0500 (Fri, 19 Sep 2008)
New Revision: 5835

Modified:
   trunk/numpy/ma/core.py
   trunk/numpy/ma/mrecords.py
   trunk/numpy/ma/tests/test_core.py
   trunk/numpy/ma/tests/test_mrecords.py
Log:
core:
* add the dtype to the repr of masked arrays with a flexible dtype
* prevent __getitem__ from returning masked when indexing a single record of a flexible-dtype masked array
* make sure __str__ returns something sensible for flexible dtypes with masked fields
* simplify the count method

mrecords:
* fix a problem with fromrecords when the number of fields cannot be determined from the first element.

Modified: trunk/numpy/ma/core.py
===================================================================
--- trunk/numpy/ma/core.py	2008-09-19 03:33:40 UTC (rev 5834)
+++ trunk/numpy/ma/core.py	2008-09-19 19:43:05 UTC (rev 5835)
@@ -1479,15 +1479,22 @@
 #        if getmask(indx) is not nomask:
 #            msg = "Masked arrays must be filled before they can be used as indices!"
 #            raise IndexError, msg
-        dout = ndarray.__getitem__(self.view(ndarray), indx)
+        dout = ndarray.__getitem__(ndarray.view(self,ndarray), indx)
         # We could directly use ndarray.__getitem__ on self...
         # But then we would have to modify __array_finalize__ to prevent the
         # mask of being reshaped if it hasn't been set up properly yet...
         # So it's easier to stick to the current version
         _mask = self._mask
         if not getattr(dout,'ndim', False):
+            # A record ................
+            if isinstance(dout, np.void):
+                mask = _mask[indx]
+                if mask.view((bool,len(mask.dtype))).any():
+                    dout = masked_array(dout, mask=mask)
+                else:
+                    return dout
             # Just a scalar............
-            if _mask is not nomask and _mask[indx]:
+            elif _mask is not nomask and _mask[indx]:
                 return masked
         else:
             # Force dout to MA ........
@@ -1896,7 +1903,15 @@
                 res = self._data
             else:
                 if m.shape == ():
-                    if m:
+                    if m.dtype.names:
+                        m = m.view((bool, len(m.dtype)))
+                        if m.any():
+                            r = np.array(self._data.tolist(), dtype=object)
+                            np.putmask(r, m, f)
+                            return str(tuple(r))
+                        else:
+                            return str(self._data)
+                    elif m:
                         return str(f)
                     else:
                         return str(self._data)
@@ -1933,21 +1948,31 @@
       mask = %(mask)s,
       fill_value=%(fill)s)
 """
+        with_mask_flx = """\
+masked_%(name)s(data =
+ %(data)s,
+      mask =
+ %(mask)s,
+      fill_value=%(fill)s,
+      dtype=%(dtype)s)
+"""
+        with_mask1_flx = """\
+masked_%(name)s(data = %(data)s,
+      mask = %(mask)s,
+      fill_value=%(fill)s
+      dtype=%(dtype)s)
+"""
         n = len(self.shape)
         name = repr(self._data).split('(')[0]
-        if n <= 1:
-            return with_mask1 % {
-                'name': name,
-                'data': str(self),
-                'mask': str(self._mask),
-                'fill': str(self.fill_value),
-                }
-        return with_mask % {
-            'name': name,
-            'data': str(self),
-            'mask': str(self._mask),
-            'fill': str(self.fill_value),
-            }
+        parameters =  dict(name=name, data=str(self), mask=str(self._mask),
+                           fill=str(self.fill_value), dtype=str(self.dtype))
+        if self.dtype.names:
+            if n<= 1:
+                return with_mask1_flx % parameters
+            return  with_mask_flx % parameters
+        elif n <= 1:
+            return with_mask1 % parameters
+        return with_mask % parameters
     #............................................
     def __add__(self, other):
         "Add other to self, and return a new masked array."
@@ -3509,6 +3534,8 @@
     return data
 
 def count(a, axis = None):
+    if isinstance(a, MaskedArray):
+        return a.count(axis)
     return masked_array(a, copy=False).count(axis)
 count.__doc__ = MaskedArray.count.__doc__
 

Modified: trunk/numpy/ma/mrecords.py
===================================================================
--- trunk/numpy/ma/mrecords.py	2008-09-19 03:33:40 UTC (rev 5834)
+++ trunk/numpy/ma/mrecords.py	2008-09-19 19:43:05 UTC (rev 5835)
@@ -151,10 +151,9 @@
         return self
     #......................................................
     def __array_finalize__(self,obj):
-        MaskedArray._update_from(self,obj)
         # Make sure we have a _fieldmask by default ..
-        _fieldmask = getattr(obj, '_fieldmask', None)
-        if _fieldmask is None:
+        _mask = getattr(obj, '_mask', None)
+        if _mask is None:
             objmask = getattr(obj, '_mask', nomask)
             _dtype = ndarray.__getattribute__(self,'dtype')
             if objmask is nomask:
@@ -163,15 +162,15 @@
                 mdescr = ma.make_mask_descr(_dtype)
                 _mask = narray([tuple([m]*len(mdescr)) for m in objmask],
                                dtype=mdescr).view(recarray)
-        else:
-            _mask = _fieldmask
         # Update some of the attributes
-        _locdict = self.__dict__
-        if _locdict['_baseclass'] == ndarray:
-            _locdict['_baseclass'] = recarray
-        _locdict.update(_mask=_mask, _fieldmask=_mask)
+        _dict = self.__dict__
+        _dict.update(_mask=_mask, _fieldmask=_mask)
+        self._update_from(obj)
+        if _dict['_baseclass'] == ndarray:
+            _dict['_baseclass'] = recarray
         return
 
+
     def _getdata(self):
         "Returns the data as a recarray."
         return ndarray.view(self,recarray)
@@ -248,7 +247,6 @@
             # Get the list of names ......
             fielddict = ndarray.__getattribute__(self,'dtype').fields or {}
             # Check the attribute
-#####            _localdict = self.__dict__
             if attr not in fielddict:
                 return ret
             if newattr:         # We just added this one
@@ -282,8 +280,8 @@
         """Returns all the fields sharing the same fieldname base.
 The fieldname base is either `_data` or `_mask`."""
         _localdict = self.__dict__
-        _mask = _localdict['_fieldmask']
-        _data = self._data
+        _mask = ndarray.__getattribute__(self,'_mask')
+        _data = ndarray.view(self, _localdict['_baseclass'])
         # We want a field ........
         if isinstance(indx, basestring):
             #!!!: Make sure _sharedmask is True to propagate back to _fieldmask
@@ -471,7 +469,7 @@
                            dtype=dtype, shape=shape, formats=formats,
                            names=names, titles=titles, aligned=aligned,
                            byteorder=byteorder).view(mrecarray)
-    _array._fieldmask.flat = zip(*masklist)
+    _array._mask.flat = zip(*masklist)
     if fill_value is not None:
         _array.fill_value = fill_value
     return _array
@@ -505,13 +503,17 @@
     mask : {nomask, sequence}, optional.
         External mask to apply on the data.
 
-*Notes*:
+    Notes
+    -----
     Lists of tuples should be preferred over lists of lists for faster processing.
     """
     # Grab the initial _fieldmask, if needed:
     _fieldmask = getattr(reclist, '_fieldmask', None)
     # Get the list of records.....
-    nfields = len(reclist[0])
+    try:
+        nfields = len(reclist[0])
+    except TypeError:
+        nfields = len(reclist[0].dtype)
     if isinstance(reclist, ndarray):
         # Make sure we don't have some hidden mask
         if isinstance(reclist,MaskedArray):
@@ -654,7 +656,7 @@
 set to 'fi', where `i` is the number of existing fields.
     """
     _data = mrecord._data
-    _mask = mrecord._fieldmask
+    _mask = mrecord._mask
     if newfieldname is None or newfieldname in reserved_fields:
         newfieldname = 'f%i' % len(_data.dtype)
     newfield = ma.array(newfield)

Modified: trunk/numpy/ma/tests/test_core.py
===================================================================
--- trunk/numpy/ma/tests/test_core.py	2008-09-19 03:33:40 UTC (rev 5834)
+++ trunk/numpy/ma/tests/test_core.py	2008-09-19 19:43:05 UTC (rev 5835)
@@ -2372,8 +2372,24 @@
         test = a.view((float,2), np.matrix)
         assert_equal(test, data)
         assert(isinstance(test, np.matrix))
+    #
+    def test_getitem(self):
+        ndtype = [('a',float), ('b',float)]
+        a = array(zip(np.random.rand(10),np.arange(10)), dtype=ndtype)
+        a.mask = np.array(zip([0,0,0,0,0,0,0,0,1,1],
+                              [1,0,0,0,0,0,0,0,1,0]),
+                          dtype=[('a',bool),('b',bool)])
+        # No mask
+        assert(isinstance(a[1], np.void))
+        # One element masked
+        assert(isinstance(a[0], MaskedArray))
+        assert_equal_records(a[0]._data, a._data[0])
+        assert_equal_records(a[0]._mask, a._mask[0])
+        # All element masked
+        assert(isinstance(a[-2], MaskedArray))
+        assert_equal_records(a[-2]._data, a._data[-2])
+        assert_equal_records(a[-2]._mask, a._mask[-2])
 
-
 ###############################################################################
 #------------------------------------------------------------------------------
 if __name__ == "__main__":

Modified: trunk/numpy/ma/tests/test_mrecords.py
===================================================================
--- trunk/numpy/ma/tests/test_mrecords.py	2008-09-19 03:33:40 UTC (rev 5834)
+++ trunk/numpy/ma/tests/test_mrecords.py	2008-09-19 19:43:05 UTC (rev 5835)
@@ -203,8 +203,8 @@
     #
     def test_set_elements(self):
         base = self.base.copy()
-        mbase = base.view(mrecarray)
         # Set an element to mask .....................
+        mbase = base.view(mrecarray).copy()
         mbase[-2] = masked
         assert_equal(mbase._fieldmask.tolist(),
                      np.array([(0,0,0),(1,1,1),(0,0,0),(1,1,1),(1,1,1)],




More information about the Numpy-svn mailing list