[Numpy-svn] r5565 - in trunk/numpy/ma: . tests

numpy-svn at scipy.org numpy-svn at scipy.org
Mon Jul 28 16:46:03 EDT 2008


Author: pierregm
Date: 2008-07-28 15:45:51 -0500 (Mon, 28 Jul 2008)
New Revision: 5565

Modified:
   trunk/numpy/ma/core.py
   trunk/numpy/ma/tests/test_core.py
Log:
* core : prevent ._basedict from being incorrectly propagated

Modified: trunk/numpy/ma/core.py
===================================================================
--- trunk/numpy/ma/core.py	2008-07-28 06:13:32 UTC (rev 5564)
+++ trunk/numpy/ma/core.py	2008-07-28 20:45:51 UTC (rev 5565)
@@ -1322,7 +1322,9 @@
             _baseclass = type(obj)
         else:
             _baseclass = ndarray
-        _basedict = getattr(obj, '_basedict', getattr(obj, '__dict__',{}))
+        # We need to copy the _basedict to avoid backward propagation
+        _basedict = {}
+        _basedict.update(getattr(obj, '_basedict', getattr(obj, '__dict__',{})))
         _dict = dict(_fill_value=getattr(obj, '_fill_value', None),
                      _hardmask=getattr(obj, '_hardmask', False),
                      _sharedmask=getattr(obj, '_sharedmask', False),

Modified: trunk/numpy/ma/tests/test_core.py
===================================================================
--- trunk/numpy/ma/tests/test_core.py	2008-07-28 06:13:32 UTC (rev 5564)
+++ trunk/numpy/ma/tests/test_core.py	2008-07-28 20:45:51 UTC (rev 5565)
@@ -436,6 +436,16 @@
         assert_equal(flexi.filled(1),
                      np.array([(1, '1', 1.)], dtype=flexi.dtype))
 
+
+    def test_basedict_propagation(self):
+        "Checks that basedict isn't back-propagated"
+        x = array([1,2,3,], dtype=float)
+        x._basedict['info'] = '???'
+        y = x.copy()
+        assert_equal(y._basedict['info'],'???')
+        y._basedict['info'] = '!!!'
+        assert_equal(x._basedict['info'], '???')
+
 #------------------------------------------------------------------------------
 
 class TestMaskedArrayArithmetic(TestCase):
@@ -509,7 +519,6 @@
         z = x/y[:,None]
         assert_equal(z, [[-1.,-1.,-1.], [3.,4.,5.]])
         assert_equal(z.mask, [[1,1,1],[0,0,0]])
-        
 
     def test_mixed_arithmetic(self):
         "Tests mixed arithmetics."




More information about the Numpy-svn mailing list