[Scipy-svn] r2798 - in trunk/Lib/sandbox/maskedarray: . tests

scipy-svn at scipy.org scipy-svn at scipy.org
Wed Feb 28 23:05:37 EST 2007


Author: pierregm
Date: 2007-02-28 22:05:32 -0600 (Wed, 28 Feb 2007)
New Revision: 2798

Modified:
   trunk/Lib/sandbox/maskedarray/core.py
   trunk/Lib/sandbox/maskedarray/tests/test_core.py
   trunk/Lib/sandbox/maskedarray/tests/test_extras.py
   trunk/Lib/sandbox/maskedarray/tests/test_mrecords.py
   trunk/Lib/sandbox/maskedarray/tests/test_subclassing.py
Log:
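core: fixed pickling when `_data` is a subclass of ndarray. Pickling is now implemented as MaskedArray methods (__getstate__/__setstate__/__reduce__) and also preserves `fill_value`.
tests: removed the `reload()` calls; moved and extended the pickling test (now also covers a numpy.matrix base).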

Modified: trunk/Lib/sandbox/maskedarray/core.py
===================================================================
--- trunk/Lib/sandbox/maskedarray/core.py	2007-03-01 02:49:48 UTC (rev 2797)
+++ trunk/Lib/sandbox/maskedarray/core.py	2007-03-01 04:05:32 UTC (rev 2798)
@@ -1884,6 +1884,52 @@
         """Returns the `_data` part of the MaskedArray.
 You should really use `data` instead..."""
         return self._data
+    #--------------------------------------------
+    # Pickling
+    def __getstate__(self):
+        "Returns the internal state of the masked array, for pickling purposes."
+        state = (1,
+                 self.shape,
+                 self.dtype,
+                 self.flags.fnc,
+                 self._data.tostring(),
+                 getmaskarray(self).tostring(),
+                 self._fill_value,
+                 )
+        return state    
+    #
+    def __setstate__(self, state):
+        """Restores the internal state of the masked array, for pickling purposes.
+    `state` is typically the output of the ``__getstate__`` output, and is a 5-tuple:
+    
+        - class name
+        - a tuple giving the shape of the data
+        - a typecode for the data
+        - a binary string for the data
+        - a binary string for the mask.
+            """
+        (ver, shp, typ, isf, raw, msk, flv) = state
+        ndarray.__setstate__(self, (shp, typ, isf, raw))
+        self._mask.__setstate__((shp, dtype(bool), isf, msk))
+        self.fill_value = flv
+    #
+    def __reduce__(self):
+        """Returns a 3-tuple for pickling a MaskedArray."""
+        return (_mareconstruct,
+                (self.__class__, self._baseclass, (0,), 'b', ),
+                self.__getstate__())
+    
+    
+def _mareconstruct(subtype, baseclass, baseshape, basetype,):
+    """Internal function that builds a new MaskedArray from the information stored
+in a pickle."""
+    _data = ndarray.__new__(baseclass, baseshape, basetype)
+    _mask = ndarray.__new__(ndarray, baseshape, 'b1')
+    return subtype.__new__(subtype, _data, mask=_mask, dtype=basetype, small_mask=False)
+#MaskedArray.__dump__ = dump
+#MaskedArray.__dumps__ = dumps
+    
+    
 
 #####--------------------------------------------------------------------------
 #---- --- Shortcuts ---
@@ -2531,44 +2577,6 @@
 #####--------------------------------------------------------------------------
 #---- --- Pickling ---
 #####--------------------------------------------------------------------------
-#FIXME: We're kinda stuck with forcing the mask to have the same shape as the data
-def _mareconstruct(subtype, baseshape, basetype,):
-    """Internal function that builds a new MaskedArray from the information stored
-in a pickle."""
-    _data = ndarray.__new__(ndarray, baseshape, basetype)
-    _mask = ndarray.__new__(ndarray, baseshape, basetype)
-    return MaskedArray.__new__(subtype, _data, mask=_mask, dtype=basetype, small_mask=False)
-
-def _getstate(a):
-    "Returns the internal state of the masked array, for pickling purposes."
-    state = (1,
-             a.shape,
-             a.dtype,
-             a.flags.fnc,
-             a.tostring(),
-             getmaskarray(a).tostring())
-    return state
-
-def _setstate(a, state):
-    """Restores the internal state of the masked array, for pickling purposes.
-`state` is typically the output of the ``__getstate__`` output, and is a 5-tuple:
-
-    - class name
-    - a tuple giving the shape of the data
-    - a typecode for the data
-    - a binary string for the data
-    - a binary string for the mask.
-        """
-    (ver, shp, typ, isf, raw, msk) = state
-    super(MaskedArray, a).__setstate__((shp, typ, isf, raw))
-    (a._mask).__setstate__((shp, dtype('|b1'), isf, msk))
-
-def _reduce(a):
-    """Returns a 3-tuple for pickling a MaskedArray."""
-    return (_mareconstruct,
-            (a.__class__, (0,), 'b', ),
-            a.__getstate__())
-
 def dump(a,F):
     """Pickles the MaskedArray `a` to the file `F`.
 `F` can either be the handle of an exiting file, or a string representing a file name.
@@ -2592,15 +2600,24 @@
     "Loads a pickle from the current string."""
     return cPickle.loads(strg)
 
-MaskedArray.__getstate__ = _getstate
-MaskedArray.__setstate__ = _setstate
-MaskedArray.__reduce__ = _reduce
-MaskedArray.__dump__ = dump
-MaskedArray.__dumps__ = dumps
 
 ################################################################################
 
-#if __name__ == '__main__':
-#    import numpy as N
-#    from maskedarray.testutils import assert_equal, assert_array_equal
-#    pi = N.pi
+if __name__ == '__main__':
+    import numpy as N
+    from maskedarray.testutils import assert_equal, assert_array_equal
+    #
+    a = arange(10)
+    a[::3] = masked
+    a.fill_value = 999
+    a_pickled = cPickle.loads(a.dumps())
+    assert_equal(a_pickled._mask, a._mask)
+    assert_equal(a_pickled._data, a._data)
+    assert_equal(a_pickled.fill_value, 999)
+    #
+    a = array(numpy.matrix(range(10)), mask=[1,0,1,0,0]*2)
+    a_pickled = cPickle.loads(a.dumps())
+    assert_equal(a_pickled._mask, a._mask)
+    assert_equal(a_pickled, a)
+    assert(isinstance(a_pickled._data,numpy.matrix))
+    

Modified: trunk/Lib/sandbox/maskedarray/tests/test_core.py
===================================================================
--- trunk/Lib/sandbox/maskedarray/tests/test_core.py	2007-03-01 02:49:48 UTC (rev 2797)
+++ trunk/Lib/sandbox/maskedarray/tests/test_core.py	2007-03-01 04:05:32 UTC (rev 2798)
@@ -18,11 +18,9 @@
 from numpy.testing.utils import build_err_msg
 
 import maskedarray.testutils
-reload(maskedarray.testutils)
 from maskedarray.testutils import *
 
 import maskedarray.core as coremodule
-reload(coremodule)
 from maskedarray.core import *
 
 pi = N.pi
@@ -600,16 +598,6 @@
         assert t[0] == 'abc'
         assert t[1] == 2
         assert t[2] == 3
-    #........................
-    def check_pickling(self):
-        "Test of pickling"
-        import pickle
-        x = arange(12)
-        x[4:10:2] = masked
-        x = x.reshape(4,3)
-        s = pickle.dumps(x)
-        y = pickle.loads(s)
-        assert_equal(x,y)
     #.......................
     def check_maskedelement(self):
         "Test of masked element"
@@ -700,6 +688,23 @@
         assert_equal(X._mask, x.mask)
         assert_equal(getmask(x), [0,0,1,0,0])
         
+    def check_pickling(self):
+        "Tests pickling"
+        import cPickle
+        a = arange(10)
+        a[::3] = masked
+        a.fill_value = 999
+        a_pickled = cPickle.loads(a.dumps())
+        assert_equal(a_pickled._mask, a._mask)
+        assert_equal(a_pickled._data, a._data)
+        assert_equal(a_pickled.fill_value, 999)
+        #
+        a = array(N.matrix(range(10)), mask=[1,0,1,0,0]*2)
+        a_pickled = cPickle.loads(a.dumps())
+        assert_equal(a_pickled._mask, a._mask)
+        assert_equal(a_pickled, a)
+        assert(isinstance(a_pickled._data,N.matrix))
+        
 #...............................................................................
         
 class test_ufuncs(NumpyTestCase):

Modified: trunk/Lib/sandbox/maskedarray/tests/test_extras.py
===================================================================
--- trunk/Lib/sandbox/maskedarray/tests/test_extras.py	2007-03-01 02:49:48 UTC (rev 2797)
+++ trunk/Lib/sandbox/maskedarray/tests/test_extras.py	2007-03-01 04:05:32 UTC (rev 2798)
@@ -16,14 +16,11 @@
 from numpy.testing.utils import build_err_msg
 
 import maskedarray.testutils
-reload(maskedarray.testutils)
 from maskedarray.testutils import *
 
 import maskedarray.core
-reload(maskedarray.core)
 from maskedarray.core import *
 import maskedarray.extras
-reload(maskedarray.extras)
 from maskedarray.extras import *
 
 class test_average(NumpyTestCase):        

Modified: trunk/Lib/sandbox/maskedarray/tests/test_mrecords.py
===================================================================
--- trunk/Lib/sandbox/maskedarray/tests/test_mrecords.py	2007-03-01 02:49:48 UTC (rev 2797)
+++ trunk/Lib/sandbox/maskedarray/tests/test_mrecords.py	2007-03-01 04:05:32 UTC (rev 2798)
@@ -18,16 +18,12 @@
 from numpy.testing.utils import build_err_msg
 
 import maskedarray.testutils
-reload(maskedarray.testutils)
 from maskedarray.testutils import *
 
 import maskedarray.core as MA
-##reload(MA)
 #import maskedarray.mrecords
-##reload(maskedarray.mrecords)
 #from maskedarray.mrecords import mrecarray, fromarrays, fromtextfile, fromrecords
 import maskedarray.mrecords
-reload(maskedarray.mrecords)
 from maskedarray.mrecords import MaskedRecords, \
     fromarrays, fromtextfile, fromrecords, addfield
 

Modified: trunk/Lib/sandbox/maskedarray/tests/test_subclassing.py
===================================================================
--- trunk/Lib/sandbox/maskedarray/tests/test_subclassing.py	2007-03-01 02:49:48 UTC (rev 2797)
+++ trunk/Lib/sandbox/maskedarray/tests/test_subclassing.py	2007-03-01 04:05:32 UTC (rev 2798)
@@ -16,11 +16,9 @@
 from numpy.testing import NumpyTest, NumpyTestCase
 
 import maskedarray.testutils
-#reload(maskedarray.testutils)
 from maskedarray.testutils import *
 
 import maskedarray.core as coremodule
-#reload(coremodule)
 from maskedarray.core import *
 
 
@@ -125,5 +123,15 @@
 ################################################################################
 if __name__ == '__main__':
     NumpyTest().run()
+    if 1:
+        x = N.arange(5)
+        m = [0,0,1,0,0]
+        xinfo = [(i,j) for (i,j) in zip(x,m)]
+        xsub = MSubArray(x, mask=m, info={'xsub':xinfo})
+        #
+        xsub_low = less(xsub,3)
+        assert isinstance(xsub, MSubArray)
+        assert_equal(xsub_low.info, xinfo)
+                     
 
 




More information about the Scipy-svn mailing list