[Numpy-svn] r6337 - in trunk/numpy/lib: . tests

numpy-svn at scipy.org numpy-svn at scipy.org
Mon Feb 2 00:20:23 EST 2009


Author: pierregm
Date: 2009-02-01 23:20:17 -0600 (Sun, 01 Feb 2009)
New Revision: 6337

Modified:
   trunk/numpy/lib/recfunctions.py
   trunk/numpy/lib/tests/test_recfunctions.py
Log:
* Added an 'autoconvert' option to stack_arrays.
* Fixed 'stack_arrays' to work with fields with titles.

Modified: trunk/numpy/lib/recfunctions.py
===================================================================
--- trunk/numpy/lib/recfunctions.py	2009-01-30 00:26:44 UTC (rev 6336)
+++ trunk/numpy/lib/recfunctions.py	2009-02-02 05:20:17 UTC (rev 6337)
@@ -628,7 +628,8 @@
 
 
 
-def stack_arrays(arrays, defaults=None, usemask=True, asrecarray=False):
+def stack_arrays(arrays, defaults=None, usemask=True, asrecarray=False,
+                 autoconvert=False):
     """
     Superposes arrays fields by fields
 
@@ -644,6 +645,8 @@
     asrecarray : {False, True}, optional
         Whether to return a recarray (or MaskedRecords if `usemask==True`) or
         just a flexible-type ndarray.
+    autoconvert : {False, True}, optional
+        Whether to automatically cast the type of the field to the maximum.
 
     Examples
     --------
@@ -673,16 +676,24 @@
     #
     dtype_l = ndtype[0]
     newdescr = dtype_l.descr
-    names = list(dtype_l.names or ()) or ['']
+    names = [_[0] for _ in newdescr]
     for dtype_n in ndtype[1:]:
         for descr in dtype_n.descr:
             name = descr[0] or ''
             if name not in names:
                 newdescr.append(descr)
                 names.append(name)
-            elif descr[1] != dict(newdescr)[name]:
-                raise TypeError("Incompatible type '%s' <> '%s'" %\
-                                (dict(newdescr)[name], descr[1]))
+            else:
+                nameidx = names.index(name)
+                current_descr = newdescr[nameidx]
+                if autoconvert:
+                    if np.dtype(descr[1]) > np.dtype(current_descr[-1]):
+                        current_descr = list(current_descr)
+                        current_descr[-1] = descr[1]
+                        newdescr[nameidx] = tuple(current_descr)
+                elif descr[1] != current_descr[-1]:
+                    raise TypeError("Incompatible type '%s' <> '%s'" %\
+                                    (dict(newdescr)[name], descr[1]))
     # Only one field: use concatenate
     if len(newdescr) == 1:
         output = ma.concatenate(seqarrays)

Modified: trunk/numpy/lib/tests/test_recfunctions.py
===================================================================
--- trunk/numpy/lib/tests/test_recfunctions.py	2009-01-30 00:26:44 UTC (rev 6336)
+++ trunk/numpy/lib/tests/test_recfunctions.py	2009-02-02 05:20:17 UTC (rev 6337)
@@ -485,7 +485,6 @@
         assert_equal(test.mask, control.mask)
 
 
-    #
     def test_defaults(self):
         "Test defaults: no exception raised if keys of defaults are not fields."
         (_, _, _, z) = self.data
@@ -503,7 +502,38 @@
         assert_equal(test.mask, control.mask)
 
 
+    def test_autoconversion(self):
+        "Tests autoconversion"
+        adtype = [('A', int), ('B', bool), ('C', float)]
+        a = ma.array([(1, 2, 3)], mask=[(0, 1, 0)], dtype=adtype)
+        bdtype = [('A', int), ('B', float), ('C', float)]
+        b = ma.array([(4, 5, 6)], dtype=bdtype)
+        control = ma.array([(1, 2, 3), (4, 5, 6)], mask=[(0, 1, 0), (0, 0, 0)],
+                           dtype=bdtype)
+        test = stack_arrays((a, b), autoconvert=True)
+        assert_equal(test, control)
+        assert_equal(test.mask, control.mask)
+        try:
+            test = stack_arrays((a, b), autoconvert=False)
+        except TypeError:
+            pass
+        else:
+            raise AssertionError
 
+
+    def test_checktitles(self):
+        "Test using titles in the field names"
+        adtype = [(('a', 'A'), int), (('b', 'B'), bool), (('c', 'C'), float)]
+        a = ma.array([(1, 2, 3)], mask=[(0, 1, 0)], dtype=adtype)
+        bdtype = [(('a', 'A'), int), (('b', 'B'), bool), (('c', 'C'), float)]
+        b = ma.array([(4, 5, 6)], dtype=bdtype)
+        test = stack_arrays((a, b))
+        control = ma.array([(1, 2, 3), (4, 5, 6)], mask=[(0, 1, 0), (0, 0, 0)],
+                           dtype=bdtype)
+        assert_equal(test, control)
+        assert_equal(test.mask, control.mask)
+
+
 class TestJoinBy(TestCase):
     #
     def test_base(self):




More information about the Numpy-svn mailing list