[Numpy-svn] r6331 - in trunk/numpy/lib: . tests
numpy-svn at scipy.org
numpy-svn at scipy.org
Thu Jan 22 00:40:28 EST 2009
Author: pierregm
Date: 2009-01-21 23:40:25 -0600 (Wed, 21 Jan 2009)
New Revision: 6331
Added:
trunk/numpy/lib/recfunctions.py
trunk/numpy/lib/tests/test_recfunctions.py
Log:
* added recfunctions, a collection of utilities to manipulate structured arrays.
Added: trunk/numpy/lib/recfunctions.py
===================================================================
--- trunk/numpy/lib/recfunctions.py 2009-01-22 05:37:36 UTC (rev 6330)
+++ trunk/numpy/lib/recfunctions.py 2009-01-22 05:40:25 UTC (rev 6331)
@@ -0,0 +1,931 @@
+"""
+Collection of utilities to manipulate structured arrays.
+
+Most of these functions were initially implemented by John Hunter for matplotlib.
+They have been rewritten and extended for convenience.
+
+
+"""
+
+
+import itertools
+from itertools import chain as iterchain, repeat as iterrepeat, izip as iterizip
+import numpy as np
+from numpy import ndarray, recarray
+import numpy.ma as ma
+from numpy.ma import MaskedArray
+from numpy.ma.mrecords import MaskedRecords
+
+from numpy.lib._iotools import _is_string_like
+
+_check_fill_value = np.ma.core._check_fill_value
+
+__all__ = ['append_fields',
+ 'drop_fields',
+ 'find_duplicates',
+ 'get_fieldstructure',
+ 'join_by',
+ 'merge_arrays',
+ 'rec_append_fields', 'rec_drop_fields', 'rec_join',
+ 'recursive_fill_fields', 'rename_fields',
+ 'stack_arrays',
+ ]
+
+
def recursive_fill_fields(input, output):
    """
    Fills fields from output with fields from input,
    with support for nested structures.

    Parameters
    ----------
    input : ndarray
        Input array.
    output : ndarray
        Output array, modified in place.

    Returns
    -------
    output : ndarray
        The same `output` object, with the fields shared with `input` filled.

    Notes
    -----
    * `output` should be at least the same size as `input`.
    * Fields of `output` that do not exist in `input` are left untouched.

    Examples
    --------
    >>> a = np.array([(1, 10.), (2, 20.)], dtype=[('A', int), ('B', float)])
    >>> b = np.zeros((3,), dtype=a.dtype)
    >>> recursive_fill_fields(a, b)
    array([(1, 10.0), (2, 20.0), (0, 0.0)],
          dtype=[('A', '<i4'), ('B', '<f8')])

    """
    newdtype = output.dtype
    for field in newdtype.names:
        try:
            current = input[field]
        except ValueError:
            # numpy raises ValueError for an unknown field name: skip it.
            continue
        if current.dtype.names:
            # Nested structure: output[field] is a view, so filling it
            # recursively writes straight into `output`.
            recursive_fill_fields(current, output[field])
        else:
            output[field][:len(current)] = current
    return output
+
+
+
def get_names(adtype):
    """
    Returns the field names of the input datatype as a tuple.

    Nested structures are represented as ``(name, subnames)`` pairs.

    Parameters
    ----------
    adtype : dtype
        Input datatype

    Returns
    -------
    names : tuple or None
        The (possibly nested) field names, or None if `adtype` has no fields.

    Examples
    --------
    >>> get_names(np.empty((1,), dtype=int).dtype) is None
    True
    >>> get_names(np.empty((1,), dtype=[('A',int), ('B', float)]).dtype)
    ('A', 'B')
    >>> adtype = np.dtype([('a', int), ('b', [('ba', int), ('bb', int)])])
    >>> get_names(adtype)
    ('a', ('b', ('ba', 'bb')))
    """
    names = adtype.names
    # Unstructured dtypes have ``names`` set to None: nothing to list.
    if names is None:
        return None
    listnames = []
    for name in names:
        current = adtype[name]
        if current.names:
            listnames.append((name, tuple(get_names(current))))
        else:
            listnames.append(name)
    return tuple(listnames) or None
+
+
def get_names_flat(adtype):
    """
    Returns the field names of the input datatype as a tuple. Nested
    structures are flattened beforehand: a nested field's own name is listed
    first, followed by the names of its subfields.

    Parameters
    ----------
    adtype : dtype
        Input datatype

    Returns
    -------
    names : tuple or None
        The flattened field names, or None if `adtype` has no fields.

    Examples
    --------
    >>> get_names_flat(np.empty((1,), dtype=int).dtype) is None
    True
    >>> get_names_flat(np.empty((1,), dtype=[('A',int), ('B', float)]).dtype)
    ('A', 'B')
    >>> adtype = np.dtype([('a', int), ('b', [('ba', int), ('bb', int)])])
    >>> get_names_flat(adtype)
    ('a', 'b', 'ba', 'bb')
    """
    names = adtype.names
    # Unstructured dtypes have ``names`` set to None: nothing to list.
    if names is None:
        return None
    listnames = []
    for name in names:
        listnames.append(name)
        current = adtype[name]
        if current.names:
            listnames.extend(get_names_flat(current))
    return tuple(listnames) or None
+
+
def flatten_descr(ndtype):
    """
    Flatten a structured data-type description.

    Parameters
    ----------
    ndtype : dtype
        Input datatype.

    Returns
    -------
    descr : tuple or list
        For a structured dtype, a tuple of ``(name, dtype)`` pairs, one per
        leaf field, with nested structures expanded in place.  For an
        unstructured dtype, ``ndtype.descr`` (a list) is returned as is.

    Examples
    --------
    >>> ndtype = np.dtype([('a', '<i4'), ('b', [('ba', '<f8'), ('bb', '<i4')])])
    >>> flatten_descr(ndtype)
    (('a', dtype('int32')), ('ba', dtype('float64')), ('bb', dtype('int32')))

    """
    names = ndtype.names
    if names is None:
        return ndtype.descr
    descr = []
    for field in names:
        # Index the dtype directly instead of unpacking ndtype.fields[field]:
        # for titled fields that entry is (dtype, offset, title), not a pair.
        typ = ndtype[field]
        if typ.names:
            descr.extend(flatten_descr(typ))
        else:
            descr.append((field, typ))
    return tuple(descr)
+
+
def zip_descr(seqarrays, flatten=False):
    """
    Combine the dtype description of a series of arrays.

    Parameters
    ----------
    seqarrays : sequence of arrays
        Sequence of arrays
    flatten : {boolean}, optional
        Whether to collapse nested descriptions.

    Returns
    -------
    descr : list
        The ``descr`` of the combined dtype.  Without `flatten`, an array
        with several fields contributes one unnamed nested field; any other
        array contributes its own description directly.
    """
    combined = []
    for arr in seqarrays:
        if flatten:
            combined.extend(flatten_descr(arr.dtype))
            continue
        dt = arr.dtype
        if len(dt.names or ()) > 1:
            combined.append(('', dt.descr))
        else:
            combined.extend(dt.descr)
    return np.dtype(combined).descr
+
+
def get_fieldstructure(adtype, lastname=None, parents=None,):
    """
    Returns a dictionary with fields as keys and a list of parent fields as values.

    This function is used to simplify access to fields nested in other fields.
    The list of parents is ordered from the outermost to the innermost one,
    whatever the nesting depth.

    Parameters
    ----------
    adtype : np.dtype
        Input datatype
    lastname : optional
        Last processed field name (used internally during recursion).
    parents : dictionary
        Dictionary of parent fields (used internally during recursion).

    Returns
    -------
    parents : dict or None
        The mapping described above, or None if `adtype` has no fields.

    Examples
    --------
    >>> ndtype = np.dtype([('A', int),
    ...                    ('B', [('BA', int),
    ...                           ('BB', [('BBA', int), ('BBB', int)])])])
    >>> get_fieldstructure(ndtype)
    {'A': [], 'B': [], 'BA': ['B'], 'BB': ['B'],
     'BBA': ['B', 'BB'], 'BBB': ['B', 'BB']}

    """
    if parents is None:
        parents = {}
    names = adtype.names
    # Unstructured dtype: there are no fields to describe.
    if names is None:
        return parents or None
    # Ancestors shared by every child of `lastname`.  Deriving it from the
    # already-recorded ancestors of `lastname` keeps the whole lineage, not
    # just the immediate parent, at any nesting depth.
    if lastname:
        lineage = parents.get(lastname, []) + [lastname]
    else:
        lineage = []
    for name in names:
        # Each field gets its own list so entries never alias each other.
        parents[name] = list(lineage)
        current = adtype[name]
        if current.names:
            # The recursion fills the same `parents` dictionary in place.
            get_fieldstructure(current, name, parents)
    return parents or None
+
+
+def _izip_fields_flat(iterable):
+ """
+ Returns an iterator of concatenated fields from a sequence of arrays,
+ collapsing any nested structure.
+ """
+ for element in iterable:
+ if isinstance(element, np.void):
+ for f in _izip_fields_flat(tuple(element)):
+ yield f
+ else:
+ yield element
+
+
+def _izip_fields(iterable):
+ """
+ Returns an iterator of concatenated fields from a sequence of arrays.
+ """
+ for element in iterable:
+ if hasattr(element, '__iter__') and not isinstance(element, basestring):
+ for f in _izip_fields(element):
+ yield f
+ elif isinstance(element, np.void) and len(tuple(element)) == 1:
+ for f in _izip_fields(element):
+ yield f
+ else:
+ yield element
+
+
def izip_records(seqarrays, fill_value=None, flatten=True):
    """
    Returns an iterator of concatenated items from a sequence of arrays.

    Shorter inputs are padded with `fill_value` until the longest one is
    exhausted, in the manner of Python 2.6's ``itertools.izip_longest``.

    Parameters
    ----------
    seqarrays : sequence of arrays
        Sequence of arrays.
    fill_value : {None, integer}
        Value used to pad shorter iterables.
    flatten : {True, False},
        Whether to collapse nested structured fields into flat tuples (True)
        or to keep their nesting (False).
    """
    # OK, that's a complete ripoff from Python2.6 itertools.izip_longest.
    # `counter` holds one item fewer than there are inputs: every input that
    # runs out pops one, so the pop triggered by the last input to run out
    # raises IndexError, which ends the iteration below.
    def sentinel(counter = ([fill_value]*(len(seqarrays)-1)).pop):
        "Yields the fill_value or raises IndexError"
        yield counter()
    #
    fillers = iterrepeat(fill_value)
    iters = [iterchain(it, sentinel(), fillers) for it in seqarrays]
    # Should we flatten the items, or just use a nested approach
    if flatten:
        zipfunc = _izip_fields_flat
    else:
        zipfunc = _izip_fields
    #
    try:
        for tup in iterizip(*iters):
            yield tuple(zipfunc(tup))
    except IndexError:
        pass
+
+
+def _fix_output(output, usemask=True, asrecarray=False):
+ """
+ Private function: return a recarray, a ndarray, a MaskedArray
+ or a MaskedRecords depending on the input parameters
+ """
+ if not isinstance(output, MaskedArray):
+ usemask = False
+ if usemask:
+ if asrecarray:
+ output = output.view(MaskedRecords)
+ else:
+ output = ma.filled(output)
+ if asrecarray:
+ output = output.view(recarray)
+ return output
+
+
+def _fix_defaults(output, defaults=None):
+ """
+ Update the fill_value and masked data of `output`
+ from the default given in a dictionary defaults.
+ """
+ names = output.dtype.names
+ (data, mask, fill_value) = (output.data, output.mask, output.fill_value)
+ for (k, v) in (defaults or {}).iteritems():
+ if k in names:
+ fill_value[k] = v
+ data[k][mask[k]] = v
+ return output
+
+
def merge_arrays(seqarrays,
                 fill_value=-1, flatten=False, usemask=True, asrecarray=False):
    """
    Merge arrays field by field.

    Parameters
    ----------
    seqarrays : ndarray or sequence of ndarrays
        A single array, or a sequence of arrays.
    fill_value : {float}, optional
        Filling value used to pad missing data on the shorter arrays.
    flatten : {False, True}, optional
        Whether to collapse nested fields.
    usemask : {False, True}, optional
        Whether to return a masked array or not.
    asrecarray : {False, True}, optional
        Whether to return a recarray (MaskedRecords) or not.

    Notes
    -----
    * A bare ndarray is always treated as one input array, even when it
      holds a single record.

    Examples
    --------
    >>> merge_arrays((np.array([1, 2]), np.array([10., 20., 30.])))
    masked_array(data = [(1, 10.0) (2, 20.0) (--, 30.0)],
          mask = [(False, False) (False, False) (True, False)],
          fill_value=(999999, 1e+20)
          dtype=[('f0', '<i4'), ('f1', '<f8')])
    >>> merge_arrays((np.array([1, 2]), np.array([10., 20., 30.])),
    ...              usemask=False)
    array([(1, 10.0), (2, 20.0), (-1, 30.0)],
          dtype=[('f0', '<i4'), ('f1', '<f8')])
    >>> merge_arrays((np.array([1, 2]).view([('a', int)]),
    ...               np.array([10., 20., 30.])),
    ...              usemask=False, asrecarray=True)
    rec.array([(1, 10.0), (2, 20.0), (-1, 30.0)],
              dtype=[('a', '<i4'), ('f1', '<f8')])
    """
    # Only unwrap a one-item *sequence*: indexing a bare ndarray of length 1
    # would replace the array by its only record.
    if not isinstance(seqarrays, ndarray) and (len(seqarrays) == 1):
        seqarrays = np.asanyarray(seqarrays[0])
    # Do we have a single ndarray as input ?
    if isinstance(seqarrays, ndarray):
        seqdtype = seqarrays.dtype
        if (not flatten) or \
           (zip_descr((seqarrays,), flatten=True) == seqdtype.descr):
            # Minimal processing: name the field if needed and return the
            # requested kind of array.
            seqarrays = seqarrays.ravel()
            if not seqdtype.names:
                seqarrays = seqarrays.view([('', seqdtype)])
            if usemask:
                if asrecarray:
                    return seqarrays.view(MaskedRecords)
                return seqarrays.view(MaskedArray)
            elif asrecarray:
                return seqarrays.view(recarray)
            return seqarrays
        else:
            seqarrays = (seqarrays,)
    else:
        # Accept any array-like item in the input sequence.
        seqarrays = [np.asanyarray(a) for a in seqarrays]
    # Get the dtype
    newdtype = zip_descr(seqarrays, flatten=flatten)
    # Get the data and the fill_value from each array
    seqdata = [ma.getdata(a.ravel()) for a in seqarrays]
    seqmask = [ma.getmaskarray(a).ravel() for a in seqarrays]
    fill_value = [_check_fill_value(fill_value, a.dtype) for a in seqdata]
    # Use the ravelled lengths: the padding below is computed on them too.
    maxlength = max(len(a) for a in seqdata)
    # Make an iterator from each array, padding w/ fill_values
    for (i, (a, m, fval)) in enumerate(zip(seqdata, seqmask, fill_value)):
        if isinstance(fval, (ndarray, np.void)):
            fmsk = ma.ones((1,), m.dtype)[0]
            if len(fval.dtype) == 1:
                # Flatten the fill_value and its mask if there's only one field
                fval = fval.item()[0]
                fmsk = True
            else:
                # fval should be a np.void of the array's own dtype
                fval = np.array([fval,], dtype=a.dtype)[0]
        else:
            fmsk = True
        nbmissing = (maxlength - len(a))
        seqdata[i] = iterchain(a, [fval] * nbmissing)
        seqmask[i] = iterchain(m, [fmsk] * nbmissing)
    #
    data = tuple(izip_records(seqdata, flatten=flatten))
    if usemask:
        mask = tuple(izip_records(seqmask, fill_value=True, flatten=flatten))
        output = ma.array(np.fromiter(data, dtype=newdtype))
        output._mask[:] = list(mask)
        if asrecarray:
            output = output.view(MaskedRecords)
    else:
        output = np.fromiter(data, dtype=newdtype)
        if asrecarray:
            output = output.view(recarray)
    return output
+
+
+
def drop_fields(base, drop_names, usemask=True, asrecarray=False):
    """
    Return a new array with fields in `drop_names` dropped.

    Nested fields are supported: a nested field whose subfields are all
    dropped disappears as well.  Returns None when no field is left.

    Parameters
    ----------
    base : array
        Input array
    drop_names : string or sequence
        String or sequence of strings corresponding to the names of the fields
        to drop.
    usemask : {False, True}, optional
        Whether to return a masked array or not.
    asrecarray : string or sequence
        Whether to return a recarray or a mrecarray (`asrecarray=True`) or
        a plain ndarray or masked array with flexible dtype (`asrecarray=False`)

    Examples
    --------
    >>> a = np.array([(1, (2, 3.0)), (4, (5, 6.0))],
    ...   dtype=[('a', int), ('b', [('ba', float), ('bb', int)])])
    >>> drop_fields(a, 'a')
    array([((2.0, 3),), ((5.0, 6),)],
          dtype=[('b', [('ba', '<f8'), ('bb', '<i4')])])
    >>> drop_fields(a, 'ba')
    array([(1, (3,)), (4, (6,))],
          dtype=[('a', '<i4'), ('b', [('bb', '<i4')])])
    >>> drop_fields(a, ['ba', 'bb'])
    array([(1,), (4,)],
          dtype=[('a', '<i4')])
    """
    if _is_string_like(drop_names):
        drop_names = [drop_names,]
    else:
        drop_names = set(drop_names)

    def _drop_descr(ndtype, drop_names):
        # Rebuild the description of `ndtype` without the dropped names.
        kept = []
        for fname in ndtype.names:
            if fname in drop_names:
                continue
            sub = ndtype[fname]
            if not sub.names:
                kept.append((fname, sub))
                continue
            subdescr = _drop_descr(sub, drop_names)
            if subdescr:
                kept.append((fname, subdescr))
        return kept

    newdtype = _drop_descr(base.dtype, drop_names)
    if not newdtype:
        return None
    output = recursive_fill_fields(base, np.empty(base.shape, dtype=newdtype))
    return _fix_output(output, usemask=usemask, asrecarray=asrecarray)
+
+
def rec_drop_fields(base, drop_names):
    """
    Returns a new numpy.recarray with fields in `drop_names` dropped.

    Shortcut for ``drop_fields(base, drop_names, usemask=False,
    asrecarray=True)``.
    """
    options = dict(usemask=False, asrecarray=True)
    return drop_fields(base, drop_names, **options)
+
+
+
def rename_fields(base, namemapper):
    """
    Rename the fields from a flexible-datatype ndarray or recarray.

    Nested fields are supported.

    Parameters
    ----------
    base : ndarray
        Input array whose fields must be modified.
    namemapper : dictionary
        Dictionary mapping old field names to their new version.
        Fields absent from the dictionary keep their name.

    Returns
    -------
    renamed : ndarray
        A view of `base` with the new field names (the data is not copied).

    Examples
    --------
    >>> a = np.array([(1, (2, [3.0, 30.])), (4, (5, [6.0, 60.]))],
    ...   dtype=[('a', int),
    ...          ('b', [('ba', float), ('bb', (float, 2))])])
    >>> rename_fields(a, {'a':'A', 'bb':'BB'})
    array([(1, (2.0, [3.0, 30.0])), (4, (5.0, [6.0, 60.0]))],
          dtype=[('A', '<i4'), ('b', [('ba', '<f8'), ('BB', '<f8', 2)])])

    """
    def _recursive_rename_fields(ndtype, namemapper):
        # Rebuild the description level by level, renaming as we go.
        newdtype = []
        for name in ndtype.names:
            newname = namemapper.get(name, name)
            current = ndtype[name]
            if current.names:
                newdtype.append((newname,
                                 _recursive_rename_fields(current, namemapper)))
            else:
                # `current` keeps any subarray shape, e.g. (float, 2).
                newdtype.append((newname, current))
        return newdtype
    newdtype = _recursive_rename_fields(base.dtype, namemapper)
    return base.view(newdtype)
+
+
def append_fields(base, names, data=None, dtypes=None,
                  fill_value=-1, usemask=True, asrecarray=False):
    """
    Add new fields to an existing array.

    The names of the fields are given with the `names` arguments,
    the corresponding values with the `data` arguments.
    If a single field is appended, `names`, `data` and `dtypes` do not have
    to be lists but just values.

    Parameters
    ----------
    base : array
        Input array to extend.
    names : string, sequence
        String or sequence of strings corresponding to the names
        of the new fields.
    data : array or sequence of arrays
        Array or sequence of arrays storing the fields to add to the base.
    dtypes : sequence of datatypes
        Datatype or sequence of datatypes.
        If None, the datatypes are estimated from the `data`.
        A single datatype is used for every new field.
    fill_value : {float}, optional
        Filling value used to pad missing data on the shorter arrays.
    usemask : {False, True}, optional
        Whether to return a masked array or not.
    asrecarray : {False, True}, optional
        Whether to return a recarray (MaskedRecords) or not.

    Raises
    ------
    ValueError
        If the number of names and arrays differ, or if `dtypes` is a
        sequence whose length matches neither 1 nor the number of arrays.

    """
    # Check the names
    if isinstance(names, (tuple, list)):
        if len(names) != len(data):
            err_msg = "The number of arrays does not match the number of names"
            raise ValueError(err_msg)
    elif isinstance(names, basestring):
        names = [names,]
        data = [data,]
    #
    if dtypes is None:
        data = [np.array(a, copy=False, subok=True) for a in data]
        data = [a.view([(name, a.dtype)]) for (name, a) in zip(names, data)]
    else:
        # Anything but an explicit tuple/list is a single dtype.  Testing for
        # __iter__ instead would treat a string such as 'f8' as a sequence
        # on Pythons where str is iterable.
        if not isinstance(dtypes, (tuple, list)):
            dtypes = [dtypes,]
        if len(data) != len(dtypes):
            if len(dtypes) == 1:
                dtypes = dtypes * len(data)
            else:
                msg = "The dtypes argument must be None, "\
                      "a single dtype or a list."
                raise ValueError(msg)
        # Applied for a single dtype *and* for a sequence of dtypes.
        data = [np.array(a, copy=False, subok=True, dtype=d).view([(n, d)])
                for (a, n, d) in zip(data, names, dtypes)]
    #
    base = merge_arrays(base, usemask=usemask, fill_value=fill_value)
    if len(data) > 1:
        data = merge_arrays(data, flatten=True, usemask=usemask,
                            fill_value=fill_value)
    else:
        data = data.pop()
    #
    output = ma.masked_all(max(len(base), len(data)),
                           dtype=base.dtype.descr + data.dtype.descr)
    output = recursive_fill_fields(base, output)
    output = recursive_fill_fields(data, output)
    #
    return _fix_output(output, usemask=usemask, asrecarray=asrecarray)
+
+
+
def rec_append_fields(base, names, data, dtypes=None):
    """
    Add new fields to an existing array and return a numpy.recarray.

    Thin wrapper around `append_fields` with ``usemask=False`` and
    ``asrecarray=True``.  As there, a single field can be given with plain
    values instead of lists for `names`, `data` and `dtypes`.

    Parameters
    ----------
    base : array
        Input array to extend.
    names : string, sequence
        String or sequence of strings corresponding to the names
        of the new fields.
    data : array or sequence of arrays
        Array or sequence of arrays storing the fields to add to the base.
    dtypes : sequence of datatypes, optional
        Datatype or sequence of datatypes.
        If None, the datatypes are estimated from the `data`.

    See Also
    --------
    append_fields

    Returns
    -------
    appended_array : np.recarray
    """
    return append_fields(base, names, data, dtypes,
                         usemask=False, asrecarray=True)
+
+
+
def stack_arrays(arrays, defaults=None, usemask=True, asrecarray=False):
    """
    Superpose arrays field by field.

    Parameters
    ----------
    arrays : array or sequence
        Sequence of input arrays.
    defaults : dictionary, optional
        Dictionary mapping field names to the corresponding default values.
    usemask : {True, False}, optional
        Whether to return a MaskedArray (or MaskedRecords if `asrecarray==True`)
        or a ndarray.
    asrecarray : {False, True}, optional
        Whether to return a recarray (or MaskedRecords if `usemask==True`) or
        just a flexible-type ndarray.

    Raises
    ------
    TypeError
        If a field appears in several arrays with incompatible types.

    Examples
    --------
    >>> x = np.array([1, 2,])
    >>> stack_arrays(x) is x
    True
    >>> z = np.array([('A', 1), ('B', 2)], dtype=[('A', '|S3'), ('B', float)])
    >>> zz = np.array([('a', 10., 100.), ('b', 20., 200.), ('c', 30., 300.)],
    ...               dtype=[('A', '|S3'), ('B', float), ('C', float)])
    >>> test = stack_arrays((z, zz))
    >>> test
    masked_array(data = [('A', 1.0, --) ('B', 2.0, --) ('a', 10.0, 100.0)
     ('b', 20.0, 200.0) ('c', 30.0, 300.0)],
          mask = [(False, False, True) (False, False, True) (False, False, False)
     (False, False, False) (False, False, False)],
          fill_value=('N/A', 1e+20, 1e+20)
          dtype=[('A', '|S3'), ('B', '<f8'), ('C', '<f8')])

    """
    if isinstance(arrays, ndarray):
        return arrays
    elif len(arrays) == 1:
        return arrays[0]
    seqarrays = [np.asanyarray(a).ravel() for a in arrays]
    nrecords = [len(a) for a in seqarrays]
    ndtype = [a.dtype for a in seqarrays]
    fldnames = [d.names for d in ndtype]
    #
    dtype_l = ndtype[0]
    newdescr = dtype_l.descr
    names = list(dtype_l.names or ()) or ['']
    # Full description seen so far for each name, kept up to date instead of
    # rebuilding dict(newdescr) for every field.  Whole entries are stored
    # because subarray fields have 3-tuple (name, type, shape) descriptions,
    # which dict(newdescr) cannot handle.
    seendescr = dict((d[0], d) for d in newdescr)
    for dtype_n in ndtype[1:]:
        for descr in dtype_n.descr:
            name = descr[0] or ''
            if name not in names:
                newdescr.append(descr)
                names.append(name)
                seendescr[name] = descr
            elif descr[1:] != seendescr[name][1:]:
                raise TypeError("Incompatible type '%s' <> '%s'" %\
                                (seendescr[name][1], descr[1]))
    # Only one field: use concatenate
    if len(newdescr) == 1:
        output = ma.concatenate(seqarrays)
    else:
        #
        output = ma.masked_all((np.sum(nrecords),), newdescr)
        offset = np.cumsum(np.r_[0, nrecords])
        seen = []
        for (a, n, i, j) in zip(seqarrays, fldnames, offset[:-1], offset[1:]):
            if a.dtype.names is None:
                # Unstructured input: goes into the auto-named field.
                output['f%i' % len(seen)][i:j] = a
            else:
                for name in n:
                    output[name][i:j] = a[name]
                    if name not in seen:
                        seen.append(name)
    #
    return _fix_output(_fix_defaults(output, defaults),
                       usemask=usemask, asrecarray=asrecarray)
+
+
+
def find_duplicates(a, key=None, ignoremask=True, return_index=False):
    """
    Find the duplicates in a structured array along a given key

    Parameters
    ----------
    a : array-like
        Input array
    key : {string, None}, optional
        Name of the fields along which to check the duplicates.
        If None, the search is performed by records
    ignoremask : {True, False}, optional
        Whether masked data should be discarded or considered as duplicates.
    return_index : {False, True}, optional
        Whether to return the indices of the duplicated values.

    Returns
    -------
    duplicates : MaskedArray
        The duplicated entries, in sorted order.
    indices : ndarray
        Their indices in the ravelled input (only if `return_index` is True).

    Examples
    --------
    >>> ndtype = [('a', int)]
    >>> a = ma.array([1, 1, 1, 2, 2, 3, 3],
    ...              mask=[0, 0, 1, 0, 0, 0, 1]).view(ndtype)
    >>> find_duplicates(a, ignoremask=True, return_index=True)
    (masked_array(data = [(1,) (1,) (2,) (2,)],
          mask = [(False,) (False,) (False,) (False,)],
          fill_value=(999999,)
          dtype=[('a', '<i4')]), array([0, 1, 3, 4]))
    """
    # Work on a masked array: plain ndarrays have no `filled` method.
    a = ma.asanyarray(a).ravel()
    # Get the sorting data (by selecting the corresponding field)
    base = a
    if key:
        # Dictionary of the parents of each field, to reach nested keys.
        fields = get_fieldstructure(a.dtype)
        for f in fields[key]:
            base = base[f]
        base = base[key]
    # Get the sorting indices and the sorted data
    sortidx = base.argsort()
    sortedbase = base[sortidx]
    sorteddata = sortedbase.filled()
    # Compare the sorting data
    flag = (sorteddata[:-1] == sorteddata[1:])
    # If masked data must be ignored, set the flag to false where needed.
    # Without any mask (`nomask`) there is nothing to ignore.
    if ignoremask and (ma.getmask(sortedbase) is not ma.nomask):
        sortedmask = sortedbase.recordmask
        flag[sortedmask[1:]] = False
    flag = np.concatenate(([False], flag))
    # We need to take the point on the left as well (else we're missing it)
    flag[:-1] = flag[:-1] + flag[1:]
    duplicates = a[sortidx][flag]
    if return_index:
        return (duplicates, sortidx[flag])
    else:
        return duplicates
+
+
+
def join_by(key, r1, r2, jointype='inner', r1postfix='1', r2postfix='2',
            defaults=None, usemask=True, asrecarray=False):
    """
    Join arrays `r1` and `r2` on key `key`.

    The key should be either a string or a sequence of string corresponding
    to the fields used to join the array.
    An exception is raised if the `key` field cannot be found in the two input
    arrays.
    Neither `r1` nor `r2` should have any duplicates along `key`: the presence
    of duplicates will make the output quite unreliable. Note that duplicates
    are not looked for by the algorithm.

    Parameters
    ----------
    key : {string, sequence}
        A string or a sequence of strings corresponding to the fields used
        for comparison.
    r1, r2 : arrays
        Structured arrays.
    jointype : {'inner', 'outer', 'leftouter'}, optional
        If 'inner', returns the elements common to both r1 and r2.
        If 'outer', returns the common elements as well as the elements of r1
        not in r2 and the elements of r2 not in r1.
        If 'leftouter', returns the common elements and the elements of r1 not
        in r2.
    r1postfix : string, optional
        String appended to the names of the fields of r1 that are present in r2
        but absent of the key.
    r2postfix : string, optional
        String appended to the names of the fields of r2 that are present in r1
        but absent of the key.
    defaults : {dictionary}, optional
        Dictionary mapping field names to the corresponding default values.
    usemask : {True, False}, optional
        Whether to return a MaskedArray (or MaskedRecords if `asrecarray==True`)
        or a ndarray.
    asrecarray : {False, True}, optional
        Whether to return a recarray (or MaskedRecords if `usemask==True`) or
        just a flexible-type ndarray.

    Raises
    ------
    ValueError
        If `jointype` is invalid or a key field is missing from r1 or r2.

    Notes
    -----
    * The output is sorted along the key.
    * A temporary array is formed by dropping the fields not in the key for the
      two arrays and concatenating the result. This array is then sorted, and
      the common entries selected. The output is constructed by filling the fields
      with the selected entries. Matching is not preserved if there are some
      duplicates...
    * When a key field has different types in r1 and r2, the output uses the
      type the other one can be safely cast to (numpy dtype ordering).

    """
    # Check jointype
    if jointype not in ('inner', 'outer', 'leftouter'):
        raise ValueError("The 'jointype' argument should be in 'inner', "\
                         "'outer' or 'leftouter' (got '%s' instead)" % jointype)
    # If we have a single key, put it in a tuple
    if isinstance(key, basestring):
        key = (key, )

    # Check the keys
    for name in key:
        if name not in r1.dtype.names:
            raise ValueError('r1 does not have key field %s'%name)
        if name not in r2.dtype.names:
            raise ValueError('r2 does not have key field %s'%name)

    # Make sure we work with ravelled arrays
    r1 = r1.ravel()
    r2 = r2.ravel()
    nb1 = len(r1)
    (r1names, r2names) = (r1.dtype.names, r2.dtype.names)

    # Make temporary arrays of just the keys
    r1k = drop_fields(r1, [n for n in r1names if n not in key])
    r2k = drop_fields(r2, [n for n in r2names if n not in key])

    # Concatenate the two arrays for comparison
    aux = ma.concatenate((r1k, r2k))
    idx_sort = aux.argsort(order=key)
    aux = aux[idx_sort]
    #
    # Get the common keys: an entry equal to its left neighbour is shared,
    # and so is that neighbour.
    flag_in = ma.concatenate(([False], aux[1:] == aux[:-1]))
    flag_in[:-1] = flag_in[1:] + flag_in[:-1]
    idx_in = idx_sort[flag_in]
    idx_1 = idx_in[(idx_in < nb1)]
    idx_2 = idx_in[(idx_in >= nb1)] - nb1
    (r1cmn, r2cmn) = (len(idx_1), len(idx_2))
    if jointype == 'inner':
        (r1spc, r2spc) = (0, 0)
    elif jointype == 'outer':
        idx_out = idx_sort[~flag_in]
        idx_1 = np.concatenate((idx_1, idx_out[(idx_out < nb1)]))
        idx_2 = np.concatenate((idx_2, idx_out[(idx_out >= nb1)] - nb1))
        (r1spc, r2spc) = (len(idx_1) - r1cmn, len(idx_2) - r2cmn)
    elif jointype == 'leftouter':
        idx_out = idx_sort[~flag_in]
        idx_1 = np.concatenate((idx_1, idx_out[(idx_out < nb1)]))
        (r1spc, r2spc) = (len(idx_1) - r1cmn, 0)
    # Select the entries from each input
    (s1, s2) = (r1[idx_1], r2[idx_2])
    #
    # Build the new description of the output array .......
    # Start with the key fields
    ndtype = [list(_) for _ in r1k.dtype.descr]
    # Add the other fields
    ndtype.extend(list(_) for _ in r1.dtype.descr if _[0] not in key)
    for desc in r2.dtype.descr:
        desc = list(desc)
        name = desc[0]
        # Recompute the current names at each step: previous iterations may
        # have renamed an entry or inserted a new one, which would leave a
        # cached list out of sync with `ndtype`.
        names = [_[0] for _ in ndtype]
        # Have we seen the current name already ?
        if name in names:
            nameidx = names.index(name)
            current = ndtype[nameidx]
            # The current field is part of the key: take the largest dtype,
            # compared as dtypes (string order would say '|S9' > '|S10').
            if name in key:
                current[1] = max(np.dtype(desc[1]), np.dtype(current[1]))
            # The current field is not part of the key: add the suffixes
            else:
                current[0] += r1postfix
                desc[0] += r2postfix
                ndtype.insert(nameidx + 1, desc)
        #... we haven't: just add the description to the current list
        else:
            ndtype.append(desc)
    # Revert the elements to tuples
    ndtype = [tuple(_) for _ in ndtype]
    # Find the largest nb of common fields : r1cmn and r2cmn should be equal, but...
    cmn = max(r1cmn, r2cmn)
    # Construct an empty array
    output = ma.masked_all((cmn + r1spc + r2spc,), dtype=ndtype)
    names = output.dtype.names
    for f in r1names:
        selected = s1[f]
        if f not in names:
            f += r1postfix
        current = output[f]
        current[:r1cmn] = selected[:r1cmn]
        if jointype in ('outer', 'leftouter'):
            current[cmn:cmn + r1spc] = selected[r1cmn:]
    for f in r2names:
        selected = s2[f]
        if f not in names:
            f += r2postfix
        current = output[f]
        current[:r2cmn] = selected[:r2cmn]
        if (jointype == 'outer') and r2spc:
            current[-r2spc:] = selected[r2cmn:]
    # Sort and finalize the output
    output.sort(order=key)
    kwargs = dict(usemask=usemask, asrecarray=asrecarray)
    return _fix_output(_fix_defaults(output, defaults), **kwargs)
+
+
def rec_join(key, r1, r2, jointype='inner', r1postfix='1', r2postfix='2',
             defaults=None):
    """
    Join arrays `r1` and `r2` on keys.
    Alternative to join_by, that always returns a np.recarray
    (it forces ``usemask=False`` and ``asrecarray=True``).

    See Also
    --------
    join_by : equivalent function
    """
    return join_by(key, r1, r2,
                   jointype=jointype,
                   r1postfix=r1postfix,
                   r2postfix=r2postfix,
                   defaults=defaults,
                   usemask=False,
                   asrecarray=True)
Property changes on: trunk/numpy/lib/recfunctions.py
___________________________________________________________________
Name: svn:mime-type
+ text/plain
Added: trunk/numpy/lib/tests/test_recfunctions.py
===================================================================
--- trunk/numpy/lib/tests/test_recfunctions.py 2009-01-22 05:37:36 UTC (rev 6330)
+++ trunk/numpy/lib/tests/test_recfunctions.py 2009-01-22 05:40:25 UTC (rev 6331)
@@ -0,0 +1,570 @@
+
+import numpy as np
+import numpy.ma as ma
+from numpy.ma.testutils import *
+
+from numpy.ma.mrecords import MaskedRecords
+
+from numpy.lib.recfunctions import *
+# These helpers are not listed in recfunctions.__all__, so the star import
+# above does not bring them in; import them explicitly.
+from numpy.lib.recfunctions import get_names, get_names_flat, zip_descr
+
+class TestRecFunctions(TestCase):
+    """
+    Misc tests of the recfunctions helpers: zip_descr, drop_fields,
+    rename_fields, get_names, get_names_flat, get_fieldstructure and
+    find_duplicates.
+    """
+    #
+    def setUp(self):
+        x = np.array([1, 2,])
+        y = np.array([10, 20, 30])
+        z = np.array([('A', 1.), ('B', 2.)],
+                     dtype=[('A', '|S3'), ('B', float)])
+        w = np.array([(1, (2, 3.0)), (4, (5, 6.0))],
+                     dtype=[('a', int), ('b', [('ba', float), ('bb', int)])])
+        self.data = (w, x, y, z)
+
+
+    def test_zip_descr(self):
+        "Test zip_descr"
+        (w, x, y, z) = self.data
+        # `x` is built from Python ints, so its dtype is numpy's default
+        # integer type.  Compare against `int` rather than a hard-coded
+        # '<i4', which does not match e.g. 64-bit Linux or big-endian builds.
+        # Std array
+        test = zip_descr((x, x), flatten=True)
+        assert_equal(test,
+                     np.dtype([('', int), ('', int)]))
+        test = zip_descr((x, x), flatten=False)
+        assert_equal(test,
+                     np.dtype([('', int), ('', int)]))
+        # Std & flexible-dtype
+        test = zip_descr((x, z), flatten=True)
+        assert_equal(test,
+                     np.dtype([('', int), ('A', '|S3'), ('B', float)]))
+        test = zip_descr((x, z), flatten=False)
+        assert_equal(test,
+                     np.dtype([('', int),
+                               ('', [('A', '|S3'), ('B', float)])]))
+        # Standard & nested dtype
+        test = zip_descr((x, w), flatten=True)
+        assert_equal(test,
+                     np.dtype([('', int),
+                               ('a', int),
+                               ('ba', float), ('bb', int)]))
+        test = zip_descr((x, w), flatten=False)
+        assert_equal(test,
+                     np.dtype([('', int),
+                               ('', [('a', int),
+                                     ('b', [('ba', float), ('bb', int)])])]))
+
+
+    def test_drop_fields(self):
+        "Test drop_fields"
+        a = np.array([(1, (2, 3.0)), (4, (5, 6.0))],
+                     dtype=[('a', int), ('b', [('ba', float), ('bb', int)])])
+        # A basic field
+        test = drop_fields(a, 'a')
+        control = np.array([((2, 3.0),), ((5, 6.0),)],
+                           dtype=[('b', [('ba', float), ('bb', int)])])
+        assert_equal(test, control)
+        # Another basic field (but nesting two fields)
+        test = drop_fields(a, 'b')
+        control = np.array([(1,), (4,)], dtype=[('a', int)])
+        assert_equal(test, control)
+        # A nested sub-field
+        test = drop_fields(a, ['ba',])
+        control = np.array([(1, (3.0,)), (4, (6.0,))],
+                           dtype=[('a', int), ('b', [('bb', int)])])
+        assert_equal(test, control)
+        # All the nested sub-field from a field: zap that field
+        test = drop_fields(a, ['ba', 'bb'])
+        control = np.array([(1,), (4,)], dtype=[('a', int)])
+        assert_equal(test, control)
+        # Dropping every field leaves nothing: drop_fields returns None.
+        # Use the TestCase check rather than a bare `assert`, which is
+        # stripped when Python runs with -O.
+        test = drop_fields(a, ['a', 'b'])
+        self.failUnless(test is None)
+
+
+    def test_rename_fields(self):
+        "Tests rename fields"
+        a = np.array([(1, (2, [3.0, 30.])), (4, (5, [6.0, 60.]))],
+                     dtype=[('a', int),
+                            ('b', [('ba', float), ('bb', (float, 2))])])
+        test = rename_fields(a, {'a':'A', 'bb':'BB'})
+        newdtype = [('A', int), ('b', [('ba', float), ('BB', (float, 2))])]
+        control = a.view(newdtype)
+        assert_equal(test.dtype, newdtype)
+        assert_equal(test, control)
+
+
+    def test_get_names(self):
+        "Tests get_names"
+        ndtype = np.dtype([('A', '|S3'), ('B', float)])
+        test = get_names(ndtype)
+        assert_equal(test, ('A', 'B'))
+        #
+        ndtype = np.dtype([('a', int), ('b', [('ba', float), ('bb', int)])])
+        test = get_names(ndtype)
+        assert_equal(test, ('a', ('b', ('ba', 'bb'))))
+
+
+    def test_get_names_flat(self):
+        "Test get_names_flat"
+        ndtype = np.dtype([('A', '|S3'), ('B', float)])
+        test = get_names_flat(ndtype)
+        assert_equal(test, ('A', 'B'))
+        #
+        ndtype = np.dtype([('a', int), ('b', [('ba', float), ('bb', int)])])
+        test = get_names_flat(ndtype)
+        assert_equal(test, ('a', 'b', 'ba', 'bb'))
+
+
+    def test_get_fieldstructure(self):
+        "Test get_fieldstructure"
+        # No nested fields
+        ndtype = np.dtype([('A', '|S3'), ('B', float)])
+        test = get_fieldstructure(ndtype)
+        assert_equal(test, {'A':[], 'B':[]})
+        # One 1-nested field
+        ndtype = np.dtype([('A', int), ('B', [('BA', float), ('BB', '|S1')])])
+        test = get_fieldstructure(ndtype)
+        assert_equal(test, {'A': [], 'B': [], 'BA':['B',], 'BB':['B']})
+        # One 2-nested fields
+        ndtype = np.dtype([('A', int),
+                           ('B', [('BA', int),
+                                  ('BB', [('BBA', int), ('BBB', int)])])])
+        test = get_fieldstructure(ndtype)
+        control = {'A': [], 'B': [], 'BA': ['B'], 'BB': ['B'],
+                   'BBA': ['B', 'BB'], 'BBB': ['B', 'BB']}
+        assert_equal(test, control)
+
+
+    def test_find_duplicates(self):
+        "Test find_duplicates"
+        a = ma.array([(2, (2., 'B')), (1, (2., 'B')), (2, (2., 'B')),
+                      (1, (1., 'B')), (2, (2., 'B')), (2, (2., 'C'))],
+                     mask=[(0, (0, 0)), (0, (0, 0)), (0, (0, 0)),
+                           (0, (0, 0)), (1, (0, 0)), (0, (1, 0))],
+                     dtype=[('A', int), ('B', [('BA', float), ('BB', '|S1')])])
+        #
+        test = find_duplicates(a, ignoremask=False, return_index=True)
+        control = [0, 2]
+        assert_equal(test[-1], control)
+        assert_equal(test[0], a[control])
+        #
+        test = find_duplicates(a, key='A', return_index=True)
+        control = [1, 3, 0, 2, 5]
+        assert_equal(test[-1], control)
+        assert_equal(test[0], a[control])
+        #
+        test = find_duplicates(a, key='B', return_index=True)
+        control = [0, 1, 2, 4]
+        assert_equal(test[-1], control)
+        assert_equal(test[0], a[control])
+        #
+        test = find_duplicates(a, key='BA', return_index=True)
+        control = [0, 1, 2, 4]
+        assert_equal(test[-1], control)
+        assert_equal(test[0], a[control])
+        #
+        test = find_duplicates(a, key='BB', return_index=True)
+        control = [0, 1, 2, 3, 4]
+        assert_equal(test[-1], control)
+        assert_equal(test[0], a[control])
+
+
+    def test_find_duplicates_ignoremask(self):
+        "Test the ignoremask option of find_duplicates"
+        ndtype = [('a', int)]
+        a = ma.array([1, 1, 1, 2, 2, 3, 3],
+                     mask=[0, 0, 1, 0, 0, 0, 1]).view(ndtype)
+        test = find_duplicates(a, ignoremask=True, return_index=True)
+        control = [0, 1, 3, 4]
+        assert_equal(test[-1], control)
+        assert_equal(test[0], a[control])
+        #
+        test = find_duplicates(a, ignoremask=False, return_index=True)
+        control = [0, 1, 3, 4, 6, 2]
+        assert_equal(test[-1], control)
+        assert_equal(test[0], a[control])
+
+
+class TestRecursiveFillFields(TestCase):
+    """
+    Check that recursive_fill_fields copies the input's records into the
+    leading part of a larger output, for plain and masked arrays.
+    """
+    def test_simple_flexible(self):
+        "recursive_fill_fields on a plain flexible-dtype array"
+        ndtype = [('A', int), ('B', float)]
+        source = np.array([(1, 10.), (2, 20.)], dtype=ndtype)
+        target = np.zeros((3,), dtype=source.dtype)
+        result = recursive_fill_fields(source, target)
+        # The third record of the zero-initialized target stays untouched.
+        expected = np.array([(1, 10.), (2, 20.), (0, 0.)], dtype=ndtype)
+        assert_equal(result, expected)
+
+    def test_masked_flexible(self):
+        "recursive_fill_fields on a masked flexible-dtype array"
+        ndtype = [('A', int), ('B', float)]
+        source = ma.array([(1, 10.), (2, 20.)], mask=[(0, 1), (1, 0)],
+                          dtype=ndtype)
+        target = ma.zeros((3,), dtype=source.dtype)
+        result = recursive_fill_fields(source, target)
+        # The input's mask is carried over; the extra record is unmasked.
+        expected = ma.array([(1, 10.), (2, 20.), (0, 0.)],
+                            mask=[(0, 1), (1, 0), (0, 0)],
+                            dtype=ndtype)
+        assert_equal(result, expected)
+
+
+
+class TestMergeArrays(TestCase):
+    """
+    Test merge_arrays on plain, flexible, nested and masked inputs.
+    """
+    def setUp(self):
+        x = np.array([1, 2,])
+        y = np.array([10, 20, 30])
+        z = np.array([('A', 1.), ('B', 2.)], dtype=[('A', '|S3'), ('B', float)])
+        w = np.array([(1, (2, 3.0)), (4, (5, 6.0))],
+                     dtype=[('a', int), ('b', [('ba', float), ('bb', int)])])
+        self.data = (w, x, y, z)
+    #
+    def test_solo(self):
+        "Test merge_arrays on a single array."
+        (_, x, _, z) = self.data
+        #
+        test = merge_arrays(x)
+        control = np.array([(1,), (2,)], dtype=[('f0', int)])
+        assert_equal(test, control)
+        test = merge_arrays((x,))
+        assert_equal(test, control)
+        #
+        test = merge_arrays(z, flatten=False)
+        assert_equal(test, z)
+        test = merge_arrays(z, flatten=True)
+        assert_equal(test, z)
+    #
+    def test_solo_w_flatten(self):
+        "Test merge_arrays on a single array w & w/o flattening"
+        w = self.data[0]
+        test = merge_arrays(w, flatten=False)
+        assert_equal(test, w)
+        #
+        test = merge_arrays(w, flatten=True)
+        control = np.array([(1, 2, 3.0), (4, 5, 6.0)],
+                           dtype=[('a', int), ('ba', float), ('bb', int)])
+        assert_equal(test, control)
+    #
+    def test_standard(self):
+        "Test standard & standard"
+        # Test merge arrays
+        (_, x, y, _) = self.data
+        test = merge_arrays((x, y), usemask=False)
+        control = np.array([(1, 10), (2, 20), (-1, 30)],
+                           dtype=[('f0', int), ('f1', int)])
+        assert_equal(test, control)
+        #
+        test = merge_arrays((x, y), usemask=True)
+        control = ma.array([(1, 10), (2, 20), (-1, 30)],
+                           mask=[(0, 0), (0, 0), (1, 0)],
+                           dtype=[('f0', int), ('f1', int)])
+        assert_equal(test, control)
+        assert_equal(test.mask, control.mask)
+    #
+    def test_flatten(self):
+        "Test standard & flexible"
+        (_, x, _, z) = self.data
+        test = merge_arrays((x, z), flatten=True)
+        control = np.array([(1, 'A', 1.), (2, 'B', 2.)],
+                           dtype=[('f0', int), ('A', '|S3'), ('B', float)])
+        assert_equal(test, control)
+        #
+        test = merge_arrays((x, z), flatten=False)
+        control = np.array([(1, ('A', 1.)), (2, ('B', 2.))],
+                           dtype=[('f0', int),
+                                  ('f1', [('A', '|S3'), ('B', float)])])
+        assert_equal(test, control)
+    #
+    def test_flatten_wflexible(self):
+        "Test flatten standard & nested"
+        (w, x, _, _) = self.data
+        test = merge_arrays((x, w), flatten=True)
+        control = np.array([(1, 1, 2, 3.0), (2, 4, 5, 6.0)],
+                           dtype=[('f0', int),
+                                  ('a', int), ('ba', float), ('bb', int)])
+        assert_equal(test, control)
+        #
+        # Without flattening, `w` is kept whole as the nested field 'f1'.
+        test = merge_arrays((x, w), flatten=False)
+        controldtype = [('f0', int),
+                        ('f1', [('a', int),
+                                ('b', [('ba', float), ('bb', int)])])]
+        control = np.array([(1, (1, (2, 3.0))), (2, (4, (5, 6.0)))],
+                           dtype=controldtype)
+        assert_equal(test, control)
+    #
+    def test_wmasked_arrays(self):
+        "Test merge_arrays masked arrays"
+        (_, x, _, _) = self.data
+        mx = ma.array([1, 2, 3], mask=[1, 0, 0])
+        test = merge_arrays((x, mx), usemask=True)
+        control = ma.array([(1, 1), (2, 2), (-1, 3)],
+                           mask=[(0, 1), (0, 0), (1, 0)],
+                           dtype=[('f0', int), ('f1', int)])
+        assert_equal(test, control)
+        test = merge_arrays((x, mx), usemask=True, asrecarray=True)
+        assert_equal(test, control)
+        # TestCase check instead of a bare `assert` (stripped under -O).
+        self.failUnless(isinstance(test, MaskedRecords))
+    #
+    def test_w_singlefield(self):
+        "Test single field"
+        test = merge_arrays((np.array([1, 2]).view([('a', int)]),
+                             np.array([10., 20., 30.])),)
+        control = ma.array([(1, 10.), (2, 20.), (-1, 30.)],
+                           mask=[(0, 0), (0, 0), (1, 0)],
+                           dtype=[('a', int), ('f1', float)])
+        assert_equal(test, control)
+    #
+    def test_w_shorter_flex(self):
+        "Test merge_arrays w/ a shorter flexndarray."
+        z = self.data[-1]
+        test = merge_arrays((z, np.array([10, 20, 30]).view([('C', int)])))
+        # Pin only the unambiguous part of the result: the full-length,
+        # single-field array comes through unchanged.  How the shorter,
+        # multi-field `z` is padded (and whether it is nested) is not
+        # checked here.
+        # TODO: add a full control for the padded `z` records.
+        assert_equal(test['C'], [10, 20, 30])
+
+
+
+class TestAppendFields(TestCase):
+    """
+    Exercise append_fields on plain, flexible and nested inputs.
+    """
+    def setUp(self):
+        ints = np.array([1, 2])
+        longer = np.array([10, 20, 30])
+        flexible = np.array([('A', 1.), ('B', 2.)],
+                            dtype=[('A', '|S3'), ('B', float)])
+        nested = np.array([(1, (2, 3.0)), (4, (5, 6.0))],
+                          dtype=[('a', int),
+                                 ('b', [('ba', float), ('bb', int)])])
+        self.data = (nested, ints, longer, flexible)
+
+    def test_append_single(self):
+        "Append one field that is longer than the base array"
+        base = self.data[1]
+        result = append_fields(base, 'A', data=[10, 20, 30])
+        # The base array is padded (and masked) to the new field's length.
+        expected = ma.array([(1, 10), (2, 20), (-1, 30)],
+                            mask=[(0, 0), (0, 0), (1, 0)],
+                            dtype=[('f0', int), ('A', int)])
+        assert_equal(result, expected)
+
+    def test_append_double(self):
+        "Append two fields of different lengths at once"
+        base = self.data[1]
+        new_names = ('A', 'B')
+        new_data = [[10, 20, 30], [100, 200]]
+        result = append_fields(base, new_names, data=new_data)
+        expected = ma.array([(1, 10, 100), (2, 20, 200), (-1, 30, -1)],
+                            mask=[(0, 0, 0), (0, 0, 0), (1, 0, 1)],
+                            dtype=[('f0', int), ('A', int), ('B', int)])
+        assert_equal(result, expected)
+
+    def test_append_on_flex(self):
+        "Append a field to a flexible-dtype array"
+        base = self.data[3]
+        result = append_fields(base, 'C', data=[10, 20, 30])
+        expected = ma.array([('A', 1., 10), ('B', 2., 20), (-1, -1., 30)],
+                            mask=[(0, 0, 0), (0, 0, 0), (1, 1, 0)],
+                            dtype=[('A', '|S3'), ('B', float), ('C', int)])
+        assert_equal(result, expected)
+
+    def test_append_on_nested(self):
+        "Append a field to an array with nested fields"
+        base = self.data[0]
+        result = append_fields(base, 'C', data=[10, 20, 30])
+        expected_dtype = [('a', int),
+                          ('b', [('ba', float), ('bb', int)]),
+                          ('C', int)]
+        expected = ma.array([(1, (2, 3.0), 10),
+                             (4, (5, 6.0), 20),
+                             (-1, (-1, -1.), 30)],
+                            mask=[(0, (0, 0), 0), (0, (0, 0), 0),
+                                  (1, (1, 1), 0)],
+                            dtype=expected_dtype)
+        assert_equal(result, expected)
+
+
+
+class TestStackArrays(TestCase):
+    """
+    Exercise stack_arrays on plain, flexible and mixed inputs.
+    """
+    def setUp(self):
+        ints = np.array([1, 2])
+        longer = np.array([10, 20, 30])
+        flexible = np.array([('A', 1.), ('B', 2.)],
+                            dtype=[('A', '|S3'), ('B', float)])
+        nested = np.array([(1, (2, 3.0)), (4, (5, 6.0))],
+                          dtype=[('a', int),
+                                 ('b', [('ba', float), ('bb', int)])])
+        self.data = (nested, ints, longer, flexible)
+
+    def test_solo(self):
+        "A single input is returned as-is, wrapped in a tuple or not"
+        x = self.data[1]
+        for arg in ((x,), x):
+            result = stack_arrays(arg)
+            assert_equal(result, x)
+            self.failUnless(result is x)
+
+    def test_unnamed_fields(self):
+        "Arrays without named fields are simply concatenated"
+        x, y = self.data[1], self.data[2]
+        cases = [((x, x), [1, 2, 1, 2]),
+                 ((x, y), [1, 2, 10, 20, 30]),
+                 ((y, x), [10, 20, 30, 1, 2])]
+        for (arrays, expected) in cases:
+            result = stack_arrays(arrays, usemask=False)
+            assert_equal(result, np.array(expected))
+
+    def test_unnamed_and_named_fields(self):
+        "Fields missing from one input are masked in the stacked output"
+        x, z = self.data[1], self.data[3]
+        # Unnamed first, then named.
+        result = stack_arrays((x, z))
+        expected = ma.array([(1, -1, -1), (2, -1, -1),
+                             (-1, 'A', 1), (-1, 'B', 2)],
+                            mask=[(0, 1, 1), (0, 1, 1),
+                                  (1, 0, 0), (1, 0, 0)],
+                            dtype=[('f0', int), ('A', '|S3'), ('B', float)])
+        assert_equal(result, expected)
+        assert_equal(result.mask, expected.mask)
+        # Named first: the unnamed input's field lands last as 'f2'.
+        zx_dtype = [('A', '|S3'), ('B', float), ('f2', int)]
+        result = stack_arrays((z, x))
+        expected = ma.array([('A', 1, -1), ('B', 2, -1),
+                             (-1, -1, 1), (-1, -1, 2),],
+                            mask=[(0, 0, 1), (0, 0, 1),
+                                  (1, 1, 0), (1, 1, 0)],
+                            dtype=zx_dtype)
+        assert_equal(result, expected)
+        assert_equal(result.mask, expected.mask)
+        # Repeated named input.
+        result = stack_arrays((z, z, x))
+        expected = ma.array([('A', 1, -1), ('B', 2, -1),
+                             ('A', 1, -1), ('B', 2, -1),
+                             (-1, -1, 1), (-1, -1, 2),],
+                            mask=[(0, 0, 1), (0, 0, 1),
+                                  (0, 0, 1), (0, 0, 1),
+                                  (1, 1, 0), (1, 1, 0)],
+                            dtype=zx_dtype)
+        assert_equal(result, expected)
+
+    def test_matching_named_fields(self):
+        "Fields sharing a name across inputs are stacked into one column"
+        x, z = self.data[1], self.data[3]
+        zz = np.array([('a', 10., 100.), ('b', 20., 200.), ('c', 30., 300.)],
+                      dtype=[('A', '|S3'), ('B', float), ('C', float)])
+        result = stack_arrays((z, zz))
+        expected = ma.array([('A', 1, -1), ('B', 2, -1),
+                             ('a', 10., 100.), ('b', 20., 200.),
+                             ('c', 30., 300.)],
+                            dtype=[('A', '|S3'), ('B', float), ('C', float)],
+                            mask=[(0, 0, 1), (0, 0, 1),
+                                  (0, 0, 0), (0, 0, 0), (0, 0, 0)])
+        assert_equal(result, expected)
+        assert_equal(result.mask, expected.mask)
+        # Adding an unnamed input appends one more field, 'f3'.
+        result = stack_arrays((z, zz, x))
+        expected = ma.array([('A', 1, -1, -1), ('B', 2, -1, -1),
+                             ('a', 10., 100., -1), ('b', 20., 200., -1),
+                             ('c', 30., 300., -1),
+                             (-1, -1, -1, 1), (-1, -1, -1, 2)],
+                            dtype=[('A', '|S3'), ('B', float),
+                                   ('C', float), ('f3', int)],
+                            mask=[(0, 0, 1, 1), (0, 0, 1, 1),
+                                  (0, 0, 0, 1), (0, 0, 0, 1), (0, 0, 0, 1),
+                                  (1, 1, 1, 0), (1, 1, 1, 0)])
+        assert_equal(result, expected)
+        assert_equal(result.mask, expected.mask)
+
+    def test_defaults(self):
+        "Test defaults: no exception raised if keys of defaults are not fields."
+        z = self.data[3]
+        zz = np.array([('a', 10., 100.), ('b', 20., 200.), ('c', 30., 300.)],
+                      dtype=[('A', '|S3'), ('B', float), ('C', float)])
+        # 'D' is not a field of either input and must simply be ignored.
+        defaults = {'A':'???', 'B':-999., 'C':-9999., 'D':-99999.}
+        result = stack_arrays((z, zz), defaults=defaults)
+        expected = ma.array([('A', 1, -9999.), ('B', 2, -9999.),
+                             ('a', 10., 100.), ('b', 20., 200.),
+                             ('c', 30., 300.)],
+                            dtype=[('A', '|S3'), ('B', float), ('C', float)],
+                            mask=[(0, 0, 1), (0, 0, 1),
+                                  (0, 0, 0), (0, 0, 0), (0, 0, 0)])
+        assert_equal(result, expected)
+        assert_equal(result.data, expected.data)
+        assert_equal(result.mask, expected.mask)
+
+
+
+class TestJoinBy(TestCase):
+    """
+    Test join_by with inner, outer and leftouter joins on one or two keys.
+    """
+    #
+    def test_base(self):
+        "Basic test of join_by"
+        # Field 'a' overlaps on 5..9; field 'b' never overlaps (50..59 in `a`,
+        # 65..74 in `b`).
+        a = np.array(zip(np.arange(10), np.arange(50, 60), np.arange(100, 110)),
+                     dtype=[('a', int), ('b', int), ('c', int)])
+        b = np.array(zip(np.arange(5, 15), np.arange(65, 75), np.arange(100, 110)),
+                     dtype=[('a', int), ('b', int), ('d', int)])
+        #
+        # Inner join on 'a': the shared non-key field 'b' gets suffixed.
+        test = join_by('a', a, b, jointype='inner')
+        control = np.array([(5, 55, 65, 105, 100), (6, 56, 66, 106, 101),
+                            (7, 57, 67, 107, 102), (8, 58, 68, 108, 103),
+                            (9, 59, 69, 109, 104)],
+                           dtype=[('a', int), ('b1', int), ('b2', int),
+                                  ('c', int), ('d', int)])
+        assert_equal(test, control)
+        #
+        # Inner join on ('a', 'b'): no (a, b) pair occurs in both arrays,
+        # so the result is empty.
+        test = join_by(('a', 'b'), a, b, jointype='inner')
+        assert_equal(len(test), 0)
+        #
+        test = join_by(('a', 'b'), a, b, 'outer')
+        control = ma.array([( 0, 50, 100, -1), ( 1, 51, 101, -1),
+                            ( 2, 52, 102, -1), ( 3, 53, 103, -1),
+                            ( 4, 54, 104, -1), ( 5, 55, 105, -1),
+                            ( 5, 65, -1, 100), ( 6, 56, 106, -1),
+                            ( 6, 66, -1, 101), ( 7, 57, 107, -1),
+                            ( 7, 67, -1, 102), ( 8, 58, 108, -1),
+                            ( 8, 68, -1, 103), ( 9, 59, 109, -1),
+                            ( 9, 69, -1, 104), (10, 70, -1, 105),
+                            (11, 71, -1, 106), (12, 72, -1, 107),
+                            (13, 73, -1, 108), (14, 74, -1, 109)],
+                           mask=[( 0, 0, 0, 1), ( 0, 0, 0, 1),
+                                 ( 0, 0, 0, 1), ( 0, 0, 0, 1),
+                                 ( 0, 0, 0, 1), ( 0, 0, 0, 1),
+                                 ( 0, 0, 1, 0), ( 0, 0, 0, 1),
+                                 ( 0, 0, 1, 0), ( 0, 0, 0, 1),
+                                 ( 0, 0, 1, 0), ( 0, 0, 0, 1),
+                                 ( 0, 0, 1, 0), ( 0, 0, 0, 1),
+                                 ( 0, 0, 1, 0), ( 0, 0, 1, 0),
+                                 ( 0, 0, 1, 0), ( 0, 0, 1, 0),
+                                 ( 0, 0, 1, 0), ( 0, 0, 1, 0)],
+                           dtype=[('a', int), ('b', int),
+                                  ('c', int), ('d', int)])
+        assert_equal(test, control)
+        #
+        # Left outer join: every record of `a` is kept; 'd' is fully masked
+        # because no record of `b` matches.
+        test = join_by(('a', 'b'), a, b, 'leftouter')
+        control = ma.array([(0, 50, 100, -1), (1, 51, 101, -1),
+                            (2, 52, 102, -1), (3, 53, 103, -1),
+                            (4, 54, 104, -1), (5, 55, 105, -1),
+                            (6, 56, 106, -1), (7, 57, 107, -1),
+                            (8, 58, 108, -1), (9, 59, 109, -1)],
+                           mask=[(0, 0, 0, 1), (0, 0, 0, 1),
+                                 (0, 0, 0, 1), (0, 0, 0, 1),
+                                 (0, 0, 0, 1), (0, 0, 0, 1),
+                                 (0, 0, 0, 1), (0, 0, 0, 1),
+                                 (0, 0, 0, 1), (0, 0, 0, 1)],
+                           dtype=[('a', int), ('b', int), ('c', int), ('d', int)])
+        assert_equal(test, control)
+
+
+
+
+# Run every test in this module when the file is executed directly.
+if __name__ == '__main__':
+ run_module_suite()
Property changes on: trunk/numpy/lib/tests/test_recfunctions.py
___________________________________________________________________
Name: svn:mime-type
+ text/plain
More information about the Numpy-svn
mailing list