[Scipy-svn] r2988 - in trunk/Lib/sandbox/maskedarray: . tests

scipy-svn at scipy.org scipy-svn at scipy.org
Sun May 13 12:32:25 EDT 2007


Author: pierregm
Date: 2007-05-13 11:32:22 -0500 (Sun, 13 May 2007)
New Revision: 2988

Modified:
   trunk/Lib/sandbox/maskedarray/core.py
   trunk/Lib/sandbox/maskedarray/tests/test_core.py
Log:
maskedarray.core : fixed the .reshape method
maskedarray.core : fixed compressed when _smallmask is False
maskedarray.core : fixed ravel when _smallmask is False

Modified: trunk/Lib/sandbox/maskedarray/core.py
===================================================================
--- trunk/Lib/sandbox/maskedarray/core.py	2007-05-12 02:04:59 UTC (rev 2987)
+++ trunk/Lib/sandbox/maskedarray/core.py	2007-05-13 16:32:22 UTC (rev 2988)
@@ -913,6 +913,7 @@
         mask = self.obj._mask
         cls = type(self.obj)
         result = getattr(data, methodname)(*args, **params).view(cls)
+        result._smallmask = self.obj._smallmask
         if result.ndim:
             if not self._onmask:
                 result._mask = mask
@@ -1115,6 +1116,7 @@
         if hasattr(dout, 'shape') and len(dout.shape) > 0:
             # Not a scalar: make sure that dout is a MA
             dout = dout.view(type(self))
+            dout._smallmask = self._smallmask
             if m is not nomask:
                 # use _set_mask to take care of the shape
                 dout.__setmask__(m[indx])
@@ -1179,8 +1181,8 @@
 If `value` is masked, masks those locations."""
         self.__setitem__(slice(i,j), value)
     #............................................
-    def __setmask__(self, mask):
-        newmask = make_mask(mask, copy=False, small_mask=self._smallmask)
+    def __setmask__(self, mask, copy=False):
+        newmask = make_mask(mask, copy=copy, small_mask=self._smallmask)
 #        self.unshare_mask()
         if self._mask is nomask:
             self._mask = newmask
@@ -1292,6 +1294,8 @@
         d = self.ravel()
         if self._mask is nomask:
             return d
+        elif not self._smallmask and not self._mask.any():
+            return d
         else:
             return d[numeric.logical_not(d._mask)]
     #............................................
@@ -1440,10 +1444,11 @@
         """Reshapes the array to shape s.
 Returns a new masked array.
 If you want to modify the shape in place, please use `a.shape = s`"""
-        # TODO: Do we keep super, or reshape _data and take a view ?
-        result = super(MaskedArray, self).reshape(*s)
-        if self._mask is not nomask:
-            result._mask = self._mask.reshape(*s)
+        result = self._data.reshape(*s).view(type(self))
+        result.__dict__.update(self.__dict__)
+        if result._mask is not nomask:
+            result._mask = self._mask.copy()
+            result._mask.shape = result.shape
         return result
     #
     repeat = _arraymethod('repeat')
@@ -2634,4 +2639,10 @@
 if __name__ == '__main__':
     if 1:
         x = arange(10)
-        assert(x.ctypes.data == x.filled().ctypes.data)
\ No newline at end of file
+        assert(x.ctypes.data == x.filled().ctypes.data)
+    if 1:
+        a = array([1,2,3,4],mask=[0,0,0,0],small_mask=False)
+        assert(a.ravel()._mask, [0,0,0,0])
+        assert(a.compressed(), a)
+        a[0] = masked
+        assert(a.compressed()._mask, [0,0,0])
\ No newline at end of file

Modified: trunk/Lib/sandbox/maskedarray/tests/test_core.py
===================================================================
--- trunk/Lib/sandbox/maskedarray/tests/test_core.py	2007-05-12 02:04:59 UTC (rev 2987)
+++ trunk/Lib/sandbox/maskedarray/tests/test_core.py	2007-05-13 16:32:22 UTC (rev 2988)
@@ -1178,7 +1178,35 @@
         aravel = a.ravel()
         assert_equal(a.shape,(1,5))
         assert_equal(a._mask.shape, a.shape)
+        # Checks that small_mask is preserved
+        a = array([1,2,3,4],mask=[0,0,0,0],small_mask=False)
+        assert_equal(a.ravel()._mask, [0,0,0,0])
         
+    def check_reshape(self):
+        "Tests reshape"
+        x = arange(4)
+        x[0] = masked
+        y = x.reshape(2,2)
+        assert_equal(y.shape, (2,2,))
+        assert_equal(y._mask.shape, (2,2,))
+        assert_equal(x.shape, (4,))
+        assert_equal(x._mask.shape, (4,))
+        
+    def check_compressed(self):
+        "Tests compressed"
+        a = array([1,2,3,4],mask=[0,0,0,0],small_mask=False)
+        b = a.compressed()
+        assert_equal(b, a)
+        assert_equal(b._mask, a._mask)
+        a[0] = masked
+        b = a.compressed()
+        assert_equal(b._data, [2,3,4])
+        assert_equal(b._mask, [0,0,0])
+        a._smallmask = True
+        b = a.compressed()
+        assert_equal(b._data, [2,3,4])
+        assert_equal(b._mask, nomask)
+        
 #..............................................................................
 
 ###############################################################################




More information about the Scipy-svn mailing list