[Numpy-svn] r6087 - in trunk/numpy/ma: . tests

numpy-svn at scipy.org numpy-svn at scipy.org
Fri Nov 21 15:49:35 EST 2008


Author: pierregm
Date: 2008-11-21 14:49:33 -0600 (Fri, 21 Nov 2008)
New Revision: 6087

Modified:
   trunk/numpy/ma/core.py
   trunk/numpy/ma/tests/test_core.py
Log:
Rewrote allclose to allow comparison with a scalar

Modified: trunk/numpy/ma/core.py
===================================================================
--- trunk/numpy/ma/core.py	2008-11-21 17:12:47 UTC (rev 6086)
+++ trunk/numpy/ma/core.py	2008-11-21 20:49:33 UTC (rev 6087)
@@ -1105,10 +1105,10 @@
         Whether to collapse a mask full of False to nomask
 
     """
-    abs = umath.absolute
+    mabs = umath.absolute
     xnew = filled(x, value)
     if issubclass(xnew.dtype.type, np.floating):
-        condition = umath.less_equal(abs(xnew-value), atol+rtol*abs(value))
+        condition = umath.less_equal(mabs(xnew-value), atol + rtol*mabs(value))
         mask = getattr(x, '_mask', nomask)
     else:
         condition = umath.equal(xnew, value)
@@ -1446,7 +1446,7 @@
         self.__dict__.update(_optinfo)
         return
     #........................
-    def __array_finalize__(self,obj):
+    def __array_finalize__(self, obj):
         """Finalizes the masked array.
         """
         # Get main attributes .........
@@ -1574,7 +1574,7 @@
             if self._mask is nomask:
                 output._mask = nomask
             else:
-                output._mask = self._mask.astype([(n,bool) for n in names])
+                output._mask = self._mask.astype([(n, bool) for n in names])
         # Don't check _fill_value if it's None, that'll speed things up
         if self._fill_value is not None:
             output._fill_value = _check_fill_value(self._fill_value, newtype)
@@ -1685,7 +1685,7 @@
             ndarray.__setitem__(_mask, indx, mval)
         elif hasattr(indx, 'dtype') and (indx.dtype==MaskType):
             indx = indx * umath.logical_not(_mask)
-            ndarray.__setitem__(_data,indx,dval)
+            ndarray.__setitem__(_data, indx, dval)
         else:
             if nbfields:
                 err_msg = "Flexible 'hard' masks are not yet supported..."
@@ -1716,7 +1716,7 @@
     those locations.
 
         """
-        self.__setitem__(slice(i,j), value)
+        self.__setitem__(slice(i, j), value)
     #............................................
     def __setmask__(self, mask, copy=False):
         """Set the mask.
@@ -2220,12 +2220,14 @@
         return int(self.item())
     #............................................
     def get_imag(self):
+        "Returns the imaginary part."
         result = self._data.imag.view(type(self))
         result.__setmask__(self._mask)
         return result
     imag = property(fget=get_imag, doc="Imaginary part.")
 
     def get_real(self):
+        "Returns the real part."
         result = self._data.real.view(type(self))
         result.__setmask__(self._mask)
         return result
@@ -2234,14 +2236,14 @@
 
     #............................................
     def count(self, axis=None):
-        """Count the non-masked elements of the array along the given
-        axis.
+        """
+        Count the non-masked elements of the array along the given axis.
 
         Parameters
         ----------
         axis : int, optional
-            Axis along which to count the non-masked elements. If
-            not given, all the non masked elements are counted.
+            Axis along which to count the non-masked elements. If axis is None,
+            all the non masked elements are counted.
 
         Returns
         -------
@@ -3447,9 +3449,11 @@
                 (self.__class__, self._baseclass, (0,), 'b', ),
                 self.__getstate__())
     #
-    def __deepcopy__(self, memo={}):
+    def __deepcopy__(self, memo=None):
         from copy import deepcopy
         copied = MaskedArray.__new__(type(self), self, copy=True)
+        if memo is None:
+            memo = {}
         memo[id(self)] = copied
         for (k,v) in self.__dict__.iteritems():
             copied.__dict__[k] = deepcopy(v, memo)
@@ -3687,16 +3691,16 @@
     fa = getdata(a)
     fb = getdata(b)
     # Get the type of the result (so that we preserve subclasses)
-    if isinstance(a,MaskedArray):
+    if isinstance(a, MaskedArray):
         basetype = type(a)
     else:
         basetype = MaskedArray
     # Get the result and view it as a (subclass of) MaskedArray
-    result = umath.power(fa,fb).view(basetype)
+    result = umath.power(fa, fb).view(basetype)
     # Find where we're in trouble w/ NaNs and Infs
     invalid = np.logical_not(np.isfinite(result.view(ndarray)))
     # Retrieve some extra attributes if needed
-    if isinstance(result,MaskedArray):
+    if isinstance(result, MaskedArray):
         result._update_from(a)
     # Add the initial mask
     if m is not nomask:
@@ -3770,7 +3774,7 @@
         filler = fill_value
 #    return
     indx = np.indices(a.shape).tolist()
-    indx[axis] = filled(a,filler).argsort(axis=axis,kind=kind,order=order)
+    indx[axis] = filled(a, filler).argsort(axis=axis, kind=kind, order=order)
     return a[indx]
 sort.__doc__ = MaskedArray.sort.__doc__
 
@@ -3820,7 +3824,7 @@
 count.__doc__ = MaskedArray.count.__doc__
 
 
-def expand_dims(x,axis):
+def expand_dims(x, axis):
     """
     Expand the shape of the array by including a new axis before
     the given one.
@@ -4160,24 +4164,76 @@
     else:
         return False
 
-def allclose (a, b, fill_value=True, rtol=1.e-5, atol=1.e-8):
-    """ Return True if all elements of a and b are equal subject to
+def allclose (a, b, masked_equal=True, rtol=1.e-5, atol=1.e-8, fill_value=None):
+    """
+    Returns True if two arrays are element-wise equal within a tolerance.
+
+    The tolerance values are positive, typically very small numbers.  The
+    relative difference (`rtol` * abs(`b`)) and the absolute difference
+    `atol` are added together to compare against the absolute difference
+    between `a` and `b`.
+
+    Parameters
+    ----------
+    a, b : array_like
+        Input arrays to compare.
+    masked_equal : boolean, optional
+        Whether masked values in a or b are considered equal (True) or not
+        (False).
+    fill_value : deprecated; use `masked_equal` instead.
+    rtol : float, optional
+        Relative tolerance: the relative difference is `rtol` * abs(`b`).
+    atol : float, optional
+        Absolute tolerance: the absolute difference is `atol`.
+
+    Returns
+    -------
+    y : bool
+        Returns True if the two arrays are equal within the given
+        tolerance; False otherwise. If either array contains NaN, then
+        False is returned.
+
+    See Also
+    --------
+    all, any, alltrue, sometrue
+
+    Notes
+    -----
+    If the following equation is element-wise True, then allclose returns
+    True.
+
+     absolute(`a` - `b`) <= (`atol` + `rtol` * absolute(`b`))
+    
+    Return True if all elements of a and b are equal subject to
     given tolerances.
 
-    If fill_value is True, masked values are considered equal.
-    If fill_value is False, masked values considered unequal.
-    The relative error rtol should be positive and << 1.0
-    The absolute error atol comes into play for those elements of b
-    that are very small or zero; it says how small `a` must be also.
-
     """
-    m = mask_or(getmask(a), getmask(b))
-    d1 = getdata(a)
-    d2 = getdata(b)
-    x = filled(array(d1, copy=0, mask=m), fill_value).astype(float)
-    y = filled(array(d2, copy=0, mask=m), 1).astype(float)
-    d = umath.less_equal(umath.absolute(x-y), atol + rtol * umath.absolute(y))
-    return np.alltrue(np.ravel(d))
+    if fill_value is not None:
+        warnings.warn("The use of fill_value is deprecated."\
+                      " Please use masked_equal instead.")
+        masked_equal = fill_value
+    #
+    x = masked_array(a, copy=False)
+    y = masked_array(b, copy=False)
+    m = mask_or(getmask(x), getmask(y))
+    xinf = np.isinf(masked_array(x, copy=False, mask=m)).filled(False)
+    # If we have some infs, they should fall at the same place.
+    if not np.all(xinf == filled(np.isinf(y), False)):
+        return False
+    # No infs at all
+    if not np.any(xinf):
+        d = filled(umath.less_equal(umath.absolute(x-y),
+                                    atol + rtol * umath.absolute(y)),
+                   masked_equal)
+        return np.all(d)
+    if not np.all(filled(x[xinf] == y[xinf], masked_equal)):
+        return False
+    x = x[~xinf]
+    y = y[~xinf]
+    d = filled(umath.less_equal(umath.absolute(x-y),
+                                atol + rtol * umath.absolute(y)),
+               masked_equal)
+    return np.all(d)
 
 #..............................................................................
 def asarray(a, dtype=None):
@@ -4225,7 +4281,7 @@
 #####--------------------------------------------------------------------------
 #---- --- Pickling ---
 #####--------------------------------------------------------------------------
-def dump(a,F):
+def dump(a, F):
     """
     Pickle the MaskedArray `a` to the file `F`.  `F` can either be
     the handle of an exiting file, or a string representing a file

Modified: trunk/numpy/ma/tests/test_core.py
===================================================================
--- trunk/numpy/ma/tests/test_core.py	2008-11-21 17:12:47 UTC (rev 6086)
+++ trunk/numpy/ma/tests/test_core.py	2008-11-21 20:49:33 UTC (rev 6087)
@@ -1298,6 +1298,27 @@
         assert_equal(m.transpose(), m._data.transpose())
 
 
+    def test_allclose(self):
+        "Tests allclose on arrays"
+        a = np.random.rand(10)
+        b = a + np.random.rand(10) * 1e-8
+        self.failUnless(allclose(a,b))
+        # Test allclose w/ infs
+        a[0] = np.inf
+        self.failUnless(not allclose(a,b))
+        b[0] = np.inf
+        self.failUnless(allclose(a,b))
+        # Test all close w/ masked
+        a = masked_array(a)
+        a[-1] = masked
+        self.failUnless(allclose(a,b, masked_equal=True))
+        self.failUnless(not allclose(a, b, masked_equal=False))
+        # Test comparison w/ scalar
+        a *= 1e-8
+        a[0] = 0
+        self.failUnless(allclose(a, 0, masked_equal=True))
+        
+
     def test_allany(self):
         """Checks the any/all methods/functions."""
         x = np.array([[ 0.13,  0.26,  0.90],
@@ -1467,7 +1488,7 @@
 
     def test_empty(self):
         "Tests empty/like"
-        datatype = [('a',int_),('b',float),('c','|S8')]
+        datatype = [('a',int),('b',float),('c','|S8')]
         a = masked_array([(1,1.1,'1.1'),(2,2.2,'2.2'),(3,3.3,'3.3')],
                          dtype=datatype)
         assert_equal(len(a.fill_value.item()), len(datatype))




More information about the Numpy-svn mailing list