[Numpy-svn] r6337 - in trunk/numpy/lib: . tests
numpy-svn at scipy.org
numpy-svn at scipy.org
Mon Feb 2 00:20:23 EST 2009
Author: pierregm
Date: 2009-02-01 23:20:17 -0600 (Sun, 01 Feb 2009)
New Revision: 6337
Modified:
trunk/numpy/lib/recfunctions.py
trunk/numpy/lib/tests/test_recfunctions.py
Log:
* Added an 'autoconvert' option to stack_arrays.
* Fixed 'stack_arrays' to work with fields with titles.
Modified: trunk/numpy/lib/recfunctions.py
===================================================================
--- trunk/numpy/lib/recfunctions.py 2009-01-30 00:26:44 UTC (rev 6336)
+++ trunk/numpy/lib/recfunctions.py 2009-02-02 05:20:17 UTC (rev 6337)
@@ -628,7 +628,8 @@
-def stack_arrays(arrays, defaults=None, usemask=True, asrecarray=False):
+def stack_arrays(arrays, defaults=None, usemask=True, asrecarray=False,
+ autoconvert=False):
"""
Superposes arrays fields by fields
@@ -644,6 +645,8 @@
asrecarray : {False, True}, optional
Whether to return a recarray (or MaskedRecords if `usemask==True`) or
just a flexible-type ndarray.
+ autoconvert : {False, True}, optional
+ Whether to automatically cast the type of the field to the maximum.
Examples
--------
@@ -673,16 +676,24 @@
#
dtype_l = ndtype[0]
newdescr = dtype_l.descr
- names = list(dtype_l.names or ()) or ['']
+ names = [_[0] for _ in newdescr]
for dtype_n in ndtype[1:]:
for descr in dtype_n.descr:
name = descr[0] or ''
if name not in names:
newdescr.append(descr)
names.append(name)
- elif descr[1] != dict(newdescr)[name]:
- raise TypeError("Incompatible type '%s' <> '%s'" %\
- (dict(newdescr)[name], descr[1]))
+ else:
+ nameidx = names.index(name)
+ current_descr = newdescr[nameidx]
+ if autoconvert:
+ if np.dtype(descr[1]) > np.dtype(current_descr[-1]):
+ current_descr = list(current_descr)
+ current_descr[-1] = descr[1]
+ newdescr[nameidx] = tuple(current_descr)
+ elif descr[1] != current_descr[-1]:
+ raise TypeError("Incompatible type '%s' <> '%s'" %\
+ (dict(newdescr)[name], descr[1]))
# Only one field: use concatenate
if len(newdescr) == 1:
output = ma.concatenate(seqarrays)
Modified: trunk/numpy/lib/tests/test_recfunctions.py
===================================================================
--- trunk/numpy/lib/tests/test_recfunctions.py 2009-01-30 00:26:44 UTC (rev 6336)
+++ trunk/numpy/lib/tests/test_recfunctions.py 2009-02-02 05:20:17 UTC (rev 6337)
@@ -485,7 +485,6 @@
assert_equal(test.mask, control.mask)
- #
def test_defaults(self):
"Test defaults: no exception raised if keys of defaults are not fields."
(_, _, _, z) = self.data
@@ -503,7 +502,38 @@
assert_equal(test.mask, control.mask)
+ def test_autoconversion(self):
+ "Tests autoconversion"
+ adtype = [('A', int), ('B', bool), ('C', float)]
+ a = ma.array([(1, 2, 3)], mask=[(0, 1, 0)], dtype=adtype)
+ bdtype = [('A', int), ('B', float), ('C', float)]
+ b = ma.array([(4, 5, 6)], dtype=bdtype)
+ control = ma.array([(1, 2, 3), (4, 5, 6)], mask=[(0, 1, 0), (0, 0, 0)],
+ dtype=bdtype)
+ test = stack_arrays((a, b), autoconvert=True)
+ assert_equal(test, control)
+ assert_equal(test.mask, control.mask)
+ try:
+ test = stack_arrays((a, b), autoconvert=False)
+ except TypeError:
+ pass
+ else:
+ raise AssertionError
+
+ def test_checktitles(self):
+ "Test using titles in the field names"
+ adtype = [(('a', 'A'), int), (('b', 'B'), bool), (('c', 'C'), float)]
+ a = ma.array([(1, 2, 3)], mask=[(0, 1, 0)], dtype=adtype)
+ bdtype = [(('a', 'A'), int), (('b', 'B'), bool), (('c', 'C'), float)]
+ b = ma.array([(4, 5, 6)], dtype=bdtype)
+ test = stack_arrays((a, b))
+ control = ma.array([(1, 2, 3), (4, 5, 6)], mask=[(0, 1, 0), (0, 0, 0)],
+ dtype=bdtype)
+ assert_equal(test, control)
+ assert_equal(test.mask, control.mask)
+
+
class TestJoinBy(TestCase):
#
def test_base(self):
More information about the Numpy-svn mailing list