[Numpy-svn] r5696 - in trunk/numpy/ma: . tests
numpy-svn at scipy.org
numpy-svn at scipy.org
Mon Aug 25 12:36:31 EDT 2008
Author: pierregm
Date: 2008-08-25 11:36:27 -0500 (Mon, 25 Aug 2008)
New Revision: 5696
Modified:
trunk/numpy/ma/core.py
trunk/numpy/ma/mrecords.py
trunk/numpy/ma/tests/test_core.py
Log:
core: make sure that masked_equal and masked_not_equal work with a list as input
mrecords: make sure that the keys of self._optinfo are recognized as valid attributes.
Modified: trunk/numpy/ma/core.py
===================================================================
--- trunk/numpy/ma/core.py 2008-08-25 13:35:27 UTC (rev 5695)
+++ trunk/numpy/ma/core.py 2008-08-25 16:36:27 UTC (rev 5696)
@@ -916,9 +916,8 @@
def masked_not_equal(x, value, copy=True):
"Shortcut to masked_where, with condition = (x != value)."
- return masked_where((x != value), x, copy=copy)
+ return masked_where(not_equal(x, value), x, copy=copy)
-#
def masked_equal(x, value, copy=True):
"""Shortcut to masked_where, with condition = (x == value). For
floating point, consider `masked_values(x, value)` instead.
@@ -929,7 +928,7 @@
# c = umath.equal(d, value)
# m = mask_or(c, getmask(x))
# return array(d, mask=m, copy=copy)
- return masked_where((x == value), x, copy=copy)
+ return masked_where(equal(x, value), x, copy=copy)
def masked_inside(x, v1, v2, copy=True):
"""Shortcut to masked_where, where condition is True for x inside
Modified: trunk/numpy/ma/mrecords.py
===================================================================
--- trunk/numpy/ma/mrecords.py 2008-08-25 13:35:27 UTC (rev 5695)
+++ trunk/numpy/ma/mrecords.py 2008-08-25 16:36:27 UTC (rev 5696)
@@ -158,10 +158,11 @@
_fieldmask = getattr(obj, '_fieldmask', None)
if _fieldmask is None:
objmask = getattr(obj, '_mask', nomask)
+ _dtype = ndarray.__getattribute__(self,'dtype')
if objmask is nomask:
- _mask = ma.make_mask_none(self.shape, dtype=self.dtype)
+ _mask = ma.make_mask_none(self.shape, dtype=_dtype)
else:
- mdescr = ma.make_mask_descr(self.dtype)
+ mdescr = ma.make_mask_descr(_dtype)
_mask = narray([tuple([m]*len(mdescr)) for m in objmask],
dtype=mdescr).view(recarray)
else:
@@ -232,7 +233,7 @@
self.__setmask__(val)
return
# Create a shortcut (so that we don't have to call getattr all the time)
- _localdict = self.__dict__
+ _localdict = object.__getattribute__(self, '__dict__')
# Check whether we're creating a new field
newattr = attr not in _localdict
try:
@@ -241,7 +242,8 @@
except:
# Not a generic attribute: exit if it's not a valid field
fielddict = ndarray.__getattribute__(self,'dtype').fields or {}
- if attr not in fielddict:
+ optinfo = ndarray.__getattribute__(self,'_optinfo') or {}
+ if not (attr in fielddict or attr in optinfo):
exctype, value = sys.exc_info()[:2]
raise exctype, value
else:
Modified: trunk/numpy/ma/tests/test_core.py
===================================================================
--- trunk/numpy/ma/tests/test_core.py 2008-08-25 13:35:27 UTC (rev 5695)
+++ trunk/numpy/ma/tests/test_core.py 2008-08-25 16:36:27 UTC (rev 5696)
@@ -1942,13 +1942,21 @@
xm.set_fill_value(1.e+20)
self.info = (xm, ym)
- #
def test_masked_where_bool(self):
x = [1,2]
y = masked_where(False,x)
assert_equal(y,[1,2])
assert_equal(y[1],2)
+ def test_masked_equal_wlist(self):
+ x = [1, 2, 3]
+ mx = masked_equal(x, 3)
+ assert_equal(mx, x)
+ assert_equal(mx._mask, [0,0,1])
+ mx = masked_not_equal(x, 3)
+ assert_equal(mx, x)
+ assert_equal(mx._mask, [1,1,0])
+
def test_masked_where_condition(self):
"Tests masking functions."
x = array([1.,2.,3.,4.,5.])
More information about the Numpy-svn mailing list