[Numpy-svn] r5696 - in trunk/numpy/ma: . tests

numpy-svn at scipy.org numpy-svn at scipy.org
Mon Aug 25 12:36:31 EDT 2008


Author: pierregm
Date: 2008-08-25 11:36:27 -0500 (Mon, 25 Aug 2008)
New Revision: 5696

Modified:
   trunk/numpy/ma/core.py
   trunk/numpy/ma/mrecords.py
   trunk/numpy/ma/tests/test_core.py
Log:
core    : make sure that masked_equal and masked_not_equal work with a list as input
mrecords: make sure that the keys of self._optinfo are recognized as valid attributes in __setattr__.

Modified: trunk/numpy/ma/core.py
===================================================================
--- trunk/numpy/ma/core.py	2008-08-25 13:35:27 UTC (rev 5695)
+++ trunk/numpy/ma/core.py	2008-08-25 16:36:27 UTC (rev 5696)
@@ -916,9 +916,8 @@
 
 def masked_not_equal(x, value, copy=True):
     "Shortcut to masked_where, with condition = (x != value)."
-    return masked_where((x != value), x, copy=copy)
+    return masked_where(not_equal(x, value), x, copy=copy)
 
-#
 def masked_equal(x, value, copy=True):
     """Shortcut to masked_where, with condition = (x == value).  For
     floating point, consider `masked_values(x, value)` instead.
@@ -929,7 +928,7 @@
     # c = umath.equal(d, value)
     # m = mask_or(c, getmask(x))
     # return array(d, mask=m, copy=copy)
-    return masked_where((x == value), x, copy=copy)
+    return masked_where(equal(x, value), x, copy=copy)
 
 def masked_inside(x, v1, v2, copy=True):
     """Shortcut to masked_where, where condition is True for x inside

Modified: trunk/numpy/ma/mrecords.py
===================================================================
--- trunk/numpy/ma/mrecords.py	2008-08-25 13:35:27 UTC (rev 5695)
+++ trunk/numpy/ma/mrecords.py	2008-08-25 16:36:27 UTC (rev 5696)
@@ -158,10 +158,11 @@
         _fieldmask = getattr(obj, '_fieldmask', None)
         if _fieldmask is None:
             objmask = getattr(obj, '_mask', nomask)
+            _dtype = ndarray.__getattribute__(self,'dtype')
             if objmask is nomask:
-                _mask = ma.make_mask_none(self.shape, dtype=self.dtype)
+                _mask = ma.make_mask_none(self.shape, dtype=_dtype)
             else:
-                mdescr = ma.make_mask_descr(self.dtype)
+                mdescr = ma.make_mask_descr(_dtype)
                 _mask = narray([tuple([m]*len(mdescr)) for m in objmask],
                                dtype=mdescr).view(recarray)
         else:
@@ -232,7 +233,7 @@
             self.__setmask__(val)
             return
         # Create a shortcut (so that we don't have to call getattr all the time)
-        _localdict = self.__dict__
+        _localdict = object.__getattribute__(self, '__dict__')
         # Check whether we're creating a new field
         newattr = attr not in _localdict
         try:
@@ -241,7 +242,8 @@
         except:
             # Not a generic attribute: exit if it's not a valid field
             fielddict = ndarray.__getattribute__(self,'dtype').fields or {}
-            if attr not in fielddict:
+            optinfo = ndarray.__getattribute__(self,'_optinfo') or {}
+            if not (attr in fielddict or attr in optinfo):
                 exctype, value = sys.exc_info()[:2]
                 raise exctype, value
         else:

Modified: trunk/numpy/ma/tests/test_core.py
===================================================================
--- trunk/numpy/ma/tests/test_core.py	2008-08-25 13:35:27 UTC (rev 5695)
+++ trunk/numpy/ma/tests/test_core.py	2008-08-25 16:36:27 UTC (rev 5696)
@@ -1942,13 +1942,21 @@
         xm.set_fill_value(1.e+20)
         self.info = (xm, ym)
 
-    #
     def test_masked_where_bool(self):
         x = [1,2]
         y = masked_where(False,x)
         assert_equal(y,[1,2])
         assert_equal(y[1],2)
 
+    def test_masked_equal_wlist(self):
+        x = [1, 2, 3]
+        mx = masked_equal(x, 3)
+        assert_equal(mx, x)
+        assert_equal(mx._mask, [0,0,1])
+        mx = masked_not_equal(x, 3)
+        assert_equal(mx, x)
+        assert_equal(mx._mask, [1,1,0])
+
     def test_masked_where_condition(self):
         "Tests masking functions."
         x = array([1.,2.,3.,4.,5.])




More information about the Numpy-svn mailing list