[Numpy-svn] r6331 - in trunk/numpy/lib: . tests

numpy-svn at scipy.org numpy-svn at scipy.org
Thu Jan 22 00:40:28 EST 2009


Author: pierregm
Date: 2009-01-21 23:40:25 -0600 (Wed, 21 Jan 2009)
New Revision: 6331

Added:
   trunk/numpy/lib/recfunctions.py
   trunk/numpy/lib/tests/test_recfunctions.py
Log:
* added recfunctions, a collection of utilities to manipulate structured arrays.

Added: trunk/numpy/lib/recfunctions.py
===================================================================
--- trunk/numpy/lib/recfunctions.py	2009-01-22 05:37:36 UTC (rev 6330)
+++ trunk/numpy/lib/recfunctions.py	2009-01-22 05:40:25 UTC (rev 6331)
@@ -0,0 +1,931 @@
+"""
+Collection of utilities to manipulate structured arrays.
+
+Most of these functions were initially implemented by John Hunter for matplotlib.
+They have been rewritten and extended for convenience.
+
+
+"""
+
+
+import itertools
+from itertools import chain as iterchain, repeat as iterrepeat, izip as iterizip
+import numpy as np
+from numpy import ndarray, recarray
+import numpy.ma as ma
+from numpy.ma import MaskedArray
+from numpy.ma.mrecords import MaskedRecords
+
+from numpy.lib._iotools import _is_string_like
+
+_check_fill_value = np.ma.core._check_fill_value
+
+__all__ = ['append_fields',
+           'drop_fields',
+           'find_duplicates',
+           'get_fieldstructure',
+           'join_by',
+           'merge_arrays',
+           'rec_append_fields', 'rec_drop_fields', 'rec_join',
+           'recursive_fill_fields', 'rename_fields',
+           'stack_arrays',
+           ]
+
+
+def recursive_fill_fields(input, output):
+    """
+    Fills fields from output with fields from input,
+    with support for nested structures.
+
+    Parameters
+    ----------
+    input : ndarray
+        Input array.
+    output : ndarray
+        Output array.
+
+    Notes
+    -----
+    * `output` should be at least the same size as `input`
+
+    Examples
+    --------
+    >>> a = np.array([(1, 10.), (2, 20.)], dtype=[('A', int), ('B', float)])
+    >>> b = np.zeros((3,), dtype=a.dtype)
+    >>> recursive_fill_fields(a, b)
+    np.array([(1, 10.), (2, 20.), (0, 0.)], dtype=[('A', int), ('B', float)])
+
+    """
+    newdtype = output.dtype
+    for field in newdtype.names:
+        try:
+            current = input[field]
+        except ValueError:
+            continue
+        if current.dtype.names:
+            recursive_fill_fields(current, output[field])
+        else:
+            output[field][:len(current)] = current
+    return output
+
+
+
+def get_names(adtype):
+    """
+    Returns the field names of the input datatype as a tuple.
+
+    Parameters
+    ----------
+    adtype : dtype
+        Input datatype
+
+    Examples
+    --------
+    >>> get_names(np.empty((1,), dtype=int)) is None
+    True
+    >>> get_names(np.empty((1,), dtype=[('A',int), ('B', float)]).dtype)
+    ('A', 'B')
+    >>> adtype = np.dtype([('a', int), ('b', [('ba', int), ('bb', int)])])
+    >>> get_names(adtype)
+    ('a', ('b', ('ba', 'bb')))
+    """
+    listnames = []
+    names = adtype.names
+    for name in names:
+        current = adtype[name]
+        if current.names:
+            listnames.append((name, tuple(get_names(current))))
+        else:
+            listnames.append(name)
+    return tuple(listnames) or None
+
+
+def get_names_flat(adtype):
+    """
+    Returns the field names of the input datatype as a tuple. Nested structures
+    are flattened beforehand.
+
+    Parameters
+    ----------
+    adtype : dtype
+        Input datatype
+
+    Examples
+    --------
+    >>> get_names_flat(np.empty((1,), dtype=int)) is None
+    True
+    >>> get_names_flat(np.empty((1,), dtype=[('A',int), ('B', float)]).dtype)
+    ('A', 'B')
+    >>> adtype = np.dtype([('a', int), ('b', [('ba', int), ('bb', int)])])
+    >>> get_names_flat(adtype)
+    ('a', 'b', 'ba', 'bb')
+    """
+    listnames = []
+    names = adtype.names
+    for name in names:
+        listnames.append(name)
+        current = adtype[name]
+        if current.names:
+            listnames.extend(get_names_flat(current))
+    return tuple(listnames) or None
+
+
+def flatten_descr(ndtype):
+    """
+    Flatten a structured data-type description.
+
+    Examples
+    --------
+    >>> ndtype = np.dtype([('a', '<i4'), ('b', [('ba', '<f8'), ('bb', '<i4')])])
+    >>> flatten_descr(ndtype)
+    (('a', dtype('int32')), ('ba', dtype('float64')), ('bb', dtype('int32')))
+
+    """
+    names = ndtype.names
+    if names is None:
+        return ndtype.descr
+    else:
+        descr = []
+        for field in names:
+            (typ, _) = ndtype.fields[field]
+            if typ.names:
+                descr.extend(flatten_descr(typ))
+            else:
+                descr.append((field, typ))
+        return tuple(descr)
+
+
+def zip_descr(seqarrays, flatten=False):
+    """
+    Combine the dtype description of a series of arrays.
+
+    Parameters
+    ----------
+    seqarrays : sequence of arrays
+        Sequence of arrays
+    flatten : {boolean}, optional
+        Whether to collapse nested descriptions.
+    """
+    newdtype = []
+    if flatten:
+        for a in seqarrays:
+            newdtype.extend(flatten_descr(a.dtype))
+    else:
+        for a in seqarrays:
+            current = a.dtype
+            names = current.names or ()
+            if len(names) > 1:
+                newdtype.append(('', current.descr))
+            else:
+                newdtype.extend(current.descr)
+    return np.dtype(newdtype).descr
+
+
+def get_fieldstructure(adtype, lastname=None, parents=None,):
+    """
+    Returns a dictionary with fields as keys and a list of parent fields as values.
+
+    This function is used to simplify access to fields nested in other fields.
+
+    Parameters
+    ----------
+    adtype : np.dtype
+        Input datatype
+    lastname : optional
+        Last processed field name (used internally during recursion).
+    parents : dictionary
+        Dictionary of parent fields (used internally during recursion).
+
+    Examples
+    --------
+    >>> ndtype =  np.dtype([('A', int), 
+    ...                     ('B', [('BA', int),
+    ...                            ('BB', [('BBA', int), ('BBB', int)])])])
+    >>> get_fieldstructure(ndtype)
+    {'A': [], 'B': [], 'BA': ['B'], 'BB': ['B'],
+     'BBA': ['B', 'BB'], 'BBB': ['B', 'BB']}
+    
+    """
+    if parents is None:
+        parents = {}
+    names = adtype.names
+    for name in names:
+        current = adtype[name]
+        if current.names:
+            if lastname:
+                parents[name] = [lastname,]
+            else:
+                parents[name] = []
+            parents.update(get_fieldstructure(current, name, parents))
+        else:
+            lastparent = [_ for _ in (parents.get(lastname, []) or [])]
+            if lastparent:
+#                if (lastparent[-1] != lastname):
+                    lastparent.append(lastname)
+            elif lastname:
+                lastparent = [lastname,]
+            parents[name] = lastparent or []
+    return parents or None
+
+
+def _izip_fields_flat(iterable):
+    """
+    Returns an iterator of concatenated fields from a sequence of arrays,
+    collapsing any nested structure.
+    """
+    for element in iterable:
+        if isinstance(element, np.void):
+            for f in _izip_fields_flat(tuple(element)):
+                yield f
+        else:
+            yield element
+
+
+def _izip_fields(iterable):
+    """
+    Returns an iterator of concatenated fields from a sequence of arrays.
+    """
+    for element in iterable:
+        if hasattr(element, '__iter__') and not isinstance(element, basestring):
+            for f in _izip_fields(element):
+                yield f
+        elif isinstance(element, np.void) and len(tuple(element)) == 1:
+            for f in _izip_fields(element):
+                yield f
+        else:
+            yield element
+
+
+def izip_records(seqarrays, fill_value=None, flatten=True):
+    """
+    Returns an iterator of concatenated items from a sequence of arrays.
+
+    Parameters
+    ----------
+    seqarrays : sequence of arrays
+        Sequence of arrays.
+    fill_value : {None, integer}
+        Value used to pad shorter iterables.
+    flatten : {True, False}, optional
+        Whether to collapse nested fields.
+    """
+    # OK, that's a complete ripoff from Python2.6 itertools.izip_longest
+    def sentinel(counter = ([fill_value]*(len(seqarrays)-1)).pop):
+        "Yields the fill_value or raises IndexError"
+        yield counter()
+    #
+    fillers = iterrepeat(fill_value)
+    iters = [iterchain(it, sentinel(), fillers) for it in seqarrays] 
+    # Should we flatten the items, or just use a nested approach
+    if flatten:
+        zipfunc = _izip_fields_flat
+    else:
+        zipfunc = _izip_fields
+    #
+    try:
+        for tup in iterizip(*iters):
+            yield tuple(zipfunc(tup))
+    except IndexError:
+        pass
+
+
+def _fix_output(output, usemask=True, asrecarray=False):
+    """
+    Private function: return a recarray, a ndarray, a MaskedArray
+    or a MaskedRecords depending on the input parameters
+    """
+    if not isinstance(output, MaskedArray):
+        usemask = False
+    if usemask:
+        if asrecarray:
+            output = output.view(MaskedRecords)
+    else:
+        output = ma.filled(output)
+        if asrecarray:
+            output = output.view(recarray)
+    return output
+
+
+def _fix_defaults(output, defaults=None):
+    """
+    Update the fill_value and masked data of `output`
+    from the default given in a dictionary defaults.
+    """
+    names = output.dtype.names
+    (data, mask, fill_value) = (output.data, output.mask, output.fill_value)
+    for (k, v) in (defaults or {}).iteritems():
+        if k in names:
+            fill_value[k] = v
+            data[k][mask[k]] = v
+    return output
+
+
+def merge_arrays(seqarrays,
+                 fill_value=-1, flatten=False, usemask=True, asrecarray=False):
+    """
+    Merge arrays field by field.
+
+    Parameters
+    ----------
+    seqarrays : sequence of ndarrays
+        Sequence of arrays
+    fill_value : {float}, optional
+        Filling value used to pad missing data on the shorter arrays.
+    flatten : {False, True}, optional
+        Whether to collapse nested fields.
+    usemask : {False, True}, optional
+        Whether to return a masked array or not.
+    asrecarray : {False, True}, optional
+        Whether to return a recarray (MaskedRecords) or not.
+
+    Examples
+    --------
+    >>> merge_arrays((np.array([1, 2]), np.array([10., 20., 30.])))
+    masked_array(data = [(1, 10.0) (2, 20.0) (--, 30.0)],
+          mask = [(False, False) (False, False) (True, False)],
+          fill_value=(999999, 1e+20)
+          dtype=[('f0', '<i4'), ('f1', '<f8')])
+    >>> merge_arrays((np.array([1, 2]), np.array([10., 20., 30.])),
+    ...              usemask=False)
+    array(data = [(1, 10.0) (2, 20.0) (-1, 30.0)],
+          dtype=[('f0', '<i4'), ('f1', '<f8')])
+    >>> merge_arrays((np.array([1, 2]).view([('a', int)]),
+    ...               np.array([10., 20., 30.])),
+    ...              usemask=False, asrecarray=True)
+    rec.array(data = [(1, 10.0) (2, 20.0) (-1, 30.0)],
+              dtype=[('a', int), ('f1', '<f8')])
+    """
+    if (len(seqarrays) == 1):
+        seqarrays = seqarrays[0]
+    if isinstance(seqarrays, ndarray):
+        seqdtype = seqarrays.dtype
+        if (not flatten) or \
+           (zip_descr((seqarrays,), flatten=True) == seqdtype.descr):
+            seqarrays = seqarrays.ravel()
+            if not seqdtype.names:
+                seqarrays = seqarrays.view([('', seqdtype)])
+            if usemask:
+                if asrecarray:
+                    return seqarrays.view(MaskedRecords)
+                return seqarrays.view(MaskedArray)
+            elif asrecarray:
+                return seqarrays.view(recarray)
+            return seqarrays
+        else:
+            seqarrays = (seqarrays,)
+    # Get the dtype
+    newdtype = zip_descr(seqarrays, flatten=flatten)
+    # Get the data and the fill_value from each array
+    seqdata = [ma.getdata(a.ravel()) for a in seqarrays]
+    seqmask = [ma.getmaskarray(a).ravel() for a in seqarrays]
+    fill_value = [_check_fill_value(fill_value, a.dtype) for a in seqdata]
+    # Make an iterator from each array, padding w/ fill_values
+    maxlength = max(len(a) for a in seqarrays)
+    for (i, (a, m, fval)) in enumerate(zip(seqdata, seqmask, fill_value)):
+        # Flatten the fill_values if there's only one field
+        if isinstance(fval, (ndarray, np.void)):
+            fmsk = ma.ones((1,), m.dtype)[0]
+            if len(fval.dtype) == 1:
+                fval = fval.item()[0]
+                fmsk = True
+            else:
+                # fval and fmsk should be np.void objects
+                fval = np.array([fval,], dtype=a.dtype)[0]
+#                fmsk = np.array([fmsk,], dtype=m.dtype)[0]
+        else:
+            fmsk = True
+        nbmissing = (maxlength-len(a))
+        seqdata[i] = iterchain(a, [fval]*nbmissing)
+        seqmask[i] = iterchain(m, [fmsk]*nbmissing)
+    #
+    data = izip_records(seqdata, flatten=flatten)
+    data = tuple(data)
+    if usemask:
+        mask = izip_records(seqmask, fill_value=True, flatten=flatten)
+        mask = tuple(mask)
+        output = ma.array(np.fromiter(data, dtype=newdtype))
+        output._mask[:] = list(mask)
+        if asrecarray:
+            output = output.view(MaskedRecords)
+    else:
+        output = np.fromiter(data, dtype=newdtype)
+        if asrecarray:
+            output = output.view(recarray)
+    return output
+
+
+
+def drop_fields(base, drop_names, usemask=True, asrecarray=False):
+    """
+    Return a new array with fields in `drop_names` dropped.
+
+    Nested fields are supported.
+
+    Parameters
+    ----------
+    base : array
+        Input array
+    drop_names : string or sequence
+        String or sequence of strings corresponding to the names of the fields
+        to drop.
+    usemask : {False, True}, optional
+        Whether to return a masked array or not.
+    asrecarray : {False, True}, optional
+        Whether to return a recarray or a mrecarray (`asrecarray=True`) or
+        a plain ndarray or masked array with flexible dtype (`asrecarray=False`)
+
+    Examples
+    --------
+    >>> a = np.array([(1, (2, 3.0)), (4, (5, 6.0))],
+    ...              dtype=[('a', int), ('b', [('ba', float), ('bb', int)])])
+    >>> drop_fields(a, 'a')
+    array([((2.0, 3),), ((5.0, 6),)], 
+          dtype=[('b', [('ba', '<f8'), ('bb', '<i4')])])
+    >>> drop_fields(a, 'ba')
+    array([(1, (3,)), (4, (6,))], 
+          dtype=[('a', '<i4'), ('b', [('bb', '<i4')])])
+    >>> drop_fields(a, ['ba', 'bb'])
+    array([(1,), (4,)], 
+          dtype=[('a', '<i4')])
+    """
+    if _is_string_like(drop_names):
+        drop_names = [drop_names,]
+    else:
+        drop_names = set(drop_names)
+    #
+    def _drop_descr(ndtype, drop_names):
+        names = ndtype.names
+        newdtype = []
+        for name in names:
+            current = ndtype[name]
+            if name in drop_names:
+                continue
+            if current.names:
+                descr = _drop_descr(current, drop_names)
+                if descr:
+                    newdtype.append((name, descr))
+            else:
+                newdtype.append((name, current))
+        return newdtype
+    #
+    newdtype = _drop_descr(base.dtype, drop_names)
+    if not newdtype:
+        return None
+    #
+    output = np.empty(base.shape, dtype=newdtype)
+    output = recursive_fill_fields(base, output)
+    return _fix_output(output, usemask=usemask, asrecarray=asrecarray)
+
+
+def rec_drop_fields(base, drop_names):
+    """
+    Returns a new numpy.recarray with fields in `drop_names` dropped.
+    """
+    return drop_fields(base, drop_names, usemask=False, asrecarray=True)
+
+
+
+def rename_fields(base, namemapper):
+    """
+    Rename the fields from a flexible-datatype ndarray or recarray.
+
+    Nested fields are supported.
+
+    Parameters
+    ----------
+    base : ndarray
+        Input array whose fields must be modified.
+    namemapper : dictionary
+        Dictionary mapping old field names to their new version.
+
+    Examples
+    --------
+    >>> a = np.array([(1, (2, [3.0, 30.])), (4, (5, [6.0, 60.]))],
+    ...               dtype=[('a', int),
+    ...                      ('b', [('ba', float), ('bb', (float, 2))])])
+    >>> rename_fields(a, {'a':'A', 'bb':'BB'})
+    array([(1, (2.0, 3)), (4, (5.0, 6))], 
+          dtype=[('A', '<i4'), ('b', [('ba', '<f8'), ('BB', '<i4')])])
+
+    """
+    def _recursive_rename_fields(ndtype, namemapper):
+        newdtype = []
+        for name in ndtype.names:
+            newname = namemapper.get(name, name)
+            current = ndtype[name]
+            if current.names:
+                newdtype.append((newname,
+                                 _recursive_rename_fields(current, namemapper)))
+            else:
+                newdtype.append((newname, current))
+        return newdtype
+    newdtype = _recursive_rename_fields(base.dtype, namemapper)
+    return base.view(newdtype)
+
+
+def append_fields(base, names, data=None, dtypes=None, 
+                  fill_value=-1, usemask=True, asrecarray=False):
+    """
+    Add new fields to an existing array.
+
+    The names of the fields are given with the `names` arguments,
+    the corresponding values with the `data` arguments.
+    If a single field is appended, `names`, `data` and `dtypes` do not have
+    to be lists but just values.
+
+    Parameters
+    ----------
+    base : array
+        Input array to extend.
+    names : string, sequence
+        String or sequence of strings corresponding to the names
+        of the new fields.
+    data : array or sequence of arrays
+        Array or sequence of arrays storing the fields to add to the base.
+    dtypes : sequence of datatypes
+        Datatype or sequence of datatypes.
+        If None, the datatypes are estimated from the `data`.
+    fill_value : {float}, optional
+        Filling value used to pad missing data on the shorter arrays.
+    usemask : {False, True}, optional
+        Whether to return a masked array or not.
+    asrecarray : {False, True}, optional
+        Whether to return a recarray (MaskedRecords) or not.
+
+    """
+    # Check the names
+    if isinstance(names, (tuple, list)):
+        if len(names) != len(data):
+            err_msg = "The number of arrays does not match the number of names"
+            raise ValueError(err_msg)
+    elif isinstance(names, basestring):
+        names = [names,]
+        data = [data,]
+    #
+    if dtypes is None:
+        data = [np.array(a, copy=False, subok=True) for a in data]
+        data = [a.view([(name, a.dtype)]) for (name, a) in zip(names, data)]
+    elif not hasattr(dtypes, '__iter__'):
+        dtypes = [dtypes,]
+        if len(data) != len(dtypes):
+            if len(dtypes) == 1:
+                dtypes = dtypes * len(data)
+            else:
+                msg = "The dtypes argument must be None, "\
+                      "a single dtype or a list."
+                raise ValueError(msg)
+        data = [np.array(a, copy=False, subok=True, dtype=d).view([(n, d)])
+                for (a, n, d) in zip(data, names, dtypes)]
+    #
+    base = merge_arrays(base, usemask=usemask, fill_value=fill_value)
+    if len(data) > 1:
+        data = merge_arrays(data, flatten=True, usemask=usemask,
+                            fill_value=fill_value)
+    else:
+        data = data.pop()
+    #
+    output = ma.masked_all(max(len(base), len(data)),
+                           dtype=base.dtype.descr + data.dtype.descr)
+    output = recursive_fill_fields(base, output)
+    output = recursive_fill_fields(data, output)
+    #
+    return _fix_output(output, usemask=usemask, asrecarray=asrecarray)
+
+
+
+def rec_append_fields(base, names, data, dtypes=None):
+    """
+    Add new fields to an existing array.
+
+    The names of the fields are given with the `names` arguments,
+    the corresponding values with the `data` arguments.
+    If a single field is appended, `names`, `data` and `dtypes` do not have
+    to be lists but just values.
+    
+    Parameters
+    ----------
+    base : array
+        Input array to extend.
+    names : string, sequence
+        String or sequence of strings corresponding to the names
+        of the new fields.
+    data : array or sequence of arrays
+        Array or sequence of arrays storing the fields to add to the base.
+    dtypes : sequence of datatypes, optional
+        Datatype or sequence of datatypes.
+        If None, the datatypes are estimated from the `data`.
+    
+    See Also
+    --------
+    append_fields
+
+    Returns
+    -------
+    appended_array : np.recarray
+    """
+    return append_fields(base, names, data=data, dtypes=dtypes,
+                         asrecarray=True, usemask=False)
+
+
+
+def stack_arrays(arrays, defaults=None, usemask=True, asrecarray=False):
+    """
+    Superposes arrays field by field.
+
+    Parameters
+    ----------
+    arrays : array or sequence
+        Sequence of input arrays.
+    defaults : dictionary, optional
+        Dictionary mapping field names to the corresponding default values.
+    usemask : {True, False}, optional
+        Whether to return a MaskedArray (or MaskedRecords if `asrecarray==True`)
+        or a ndarray.
+    asrecarray : {False, True}, optional
+        Whether to return a recarray (or MaskedRecords if `usemask==True`) or
+        just a flexible-type ndarray.
+
+    Examples
+    --------
+    >>> x = np.array([1, 2,])
+    >>> stack_arrays(x) is x
+    True
+    >>> z = np.array([('A', 1), ('B', 2)], dtype=[('A', '|S3'), ('B', float)])
+    >>> zz = np.array([('a', 10., 100.), ('b', 20., 200.), ('c', 30., 300.)],
+    ...               dtype=[('A', '|S3'), ('B', float), ('C', float)])
+    >>> test = stack_arrays((z,zz))
+    >>> masked_array(data = [('A', 1.0, --) ('B', 2.0, --) ('a', 10.0, 100.0)
+    ... ('b', 20.0, 200.0) ('c', 30.0, 300.0)],
+    ...       mask = [(False, False, True) (False, False, True) (False, False, False)
+    ... (False, False, False) (False, False, False)],
+    ...       fill_value=('N/A', 1e+20, 1e+20)
+    ...       dtype=[('A', '|S3'), ('B', '<f8'), ('C', '<f8')])
+
+    """
+    if isinstance(arrays, ndarray):
+        return arrays
+    elif len(arrays) == 1:
+        return arrays[0]
+    seqarrays = [np.asanyarray(a).ravel() for a in arrays]
+    nrecords = [len(a) for a in seqarrays]
+    ndtype = [a.dtype for a in seqarrays]
+    fldnames = [d.names for d in ndtype]
+    #
+    dtype_l = ndtype[0]
+    newdescr = dtype_l.descr
+    names = list(dtype_l.names or ()) or ['']
+    for dtype_n in ndtype[1:]:
+        for descr in dtype_n.descr:
+            name = descr[0] or ''
+            if name not in names:
+                newdescr.append(descr)
+                names.append(name)
+            elif descr[1] != dict(newdescr)[name]:
+                raise TypeError("Incompatible type '%s' <> '%s'" %\
+                                (dict(newdescr)[name], descr[1]))
+    # Only one field: use concatenate
+    if len(newdescr) == 1:
+        output = ma.concatenate(seqarrays)
+    else:
+        #
+        output = ma.masked_all((np.sum(nrecords),), newdescr)
+        offset = np.cumsum(np.r_[0, nrecords])
+        seen = []
+        for (a, n, i, j) in zip(seqarrays, fldnames, offset[:-1], offset[1:]):
+            names = a.dtype.names
+            if names is None:
+                output['f%i' % len(seen)][i:j] = a
+            else:
+                for name in n:
+                    output[name][i:j] = a[name]
+                    if name not in seen:
+                        seen.append(name)
+    #
+    return _fix_output(_fix_defaults(output, defaults),
+                       usemask=usemask, asrecarray=asrecarray)
+
+
+
+def find_duplicates(a, key=None, ignoremask=True, return_index=False):
+    """
+    Find the duplicates in a structured array along a given key
+
+    Parameters
+    ----------
+    a : array-like
+        Input array
+    key : {string, None}, optional
+        Name of the fields along which to check the duplicates.
+        If None, the search is performed by records
+    ignoremask : {True, False}, optional
+        Whether masked data should be discarded or considered as duplicates.
+    return_index : {False, True}, optional
+        Whether to return the indices of the duplicated values.
+
+    Examples
+    --------
+    >>> ndtype = [('a', int)]
+    >>> a = ma.array([1, 1, 1, 2, 2, 3, 3], 
+    ...         mask=[0, 0, 1, 0, 0, 0, 1]).view(ndtype)
+    >>> find_duplicates(a, ignoremask=True, return_index=True)
+    """
+    a = np.asanyarray(a).ravel()
+    # Get a dictionary of fields
+    fields = get_fieldstructure(a.dtype)
+    # Get the sorting data (by selecting the corresponding field)
+    base = a
+    if key:
+        for f in fields[key]:
+            base = base[f]
+        base = base[key]
+    # Get the sorting indices and the sorted data
+    sortidx = base.argsort()
+    sortedbase = base[sortidx]
+    sorteddata = sortedbase.filled()
+    # Compare the sorting data
+    flag = (sorteddata[:-1] == sorteddata[1:])
+    # If masked data must be ignored, set the flag to false where needed
+    if ignoremask:
+        sortedmask = sortedbase.recordmask
+        flag[sortedmask[1:]] = False
+    flag = np.concatenate(([False], flag))
+    # We need to take the point on the left as well (else we're missing it)
+    flag[:-1] = flag[:-1] + flag[1:]
+    duplicates = a[sortidx][flag]
+    if return_index:
+        return (duplicates, sortidx[flag])
+    else:
+        return duplicates
+
+
+
+def join_by(key, r1, r2, jointype='inner', r1postfix='1', r2postfix='2',
+                defaults=None, usemask=True, asrecarray=False):
+    """
+    Join arrays `r1` and `r2` on key `key`.
+
+    The key should be either a string or a sequence of strings corresponding
+    to the fields used to join the array.
+    An exception is raised if the `key` field cannot be found in the two input
+    arrays.
+    Neither `r1` nor `r2` should have any duplicates along `key`: the presence
+    of duplicates will make the output quite unreliable. Note that duplicates
+    are not looked for by the algorithm.
+
+    Parameters
+    ----------
+    key : {string, sequence}
+        A string or a sequence of strings corresponding to the fields used
+        for comparison.
+    r1, r2 : arrays
+        Structured arrays.
+    jointype : {'inner', 'outer', 'leftouter'}, optional
+        If 'inner', returns the elements common to both r1 and r2.
+        If 'outer', returns the common elements as well as the elements of r1
+        not in r2 and the elements of r2 not in r1.
+        If 'leftouter', returns the common elements and the elements of r1 not
+        in r2.
+    r1postfix : string, optional
+        String appended to the names of the fields of r1 that are present in r2
+        but absent from the key.
+    r2postfix : string, optional
+        String appended to the names of the fields of r2 that are present in r1
+        but absent from the key.
+    defaults : {dictionary}, optional
+        Dictionary mapping field names to the corresponding default values.
+    usemask : {True, False}, optional
+        Whether to return a MaskedArray (or MaskedRecords if `asrecarray==True`)
+        or a ndarray.
+    asrecarray : {False, True}, optional
+        Whether to return a recarray (or MaskedRecords if `usemask==True`) or
+        just a flexible-type ndarray.
+
+    Notes
+    -----
+    * The output is sorted along the key.
+    * A temporary array is formed by dropping the fields not in the key for the
+      two arrays and concatenating the result. This array is then sorted, and
+      the common entries selected. The output is constructed by filling the fields
+      with the selected entries. Matching is not preserved if there are some
+      duplicates...
+
+    """
+    # Check jointype
+    if jointype not in ('inner', 'outer', 'leftouter'):
+        raise ValueError("The 'jointype' argument should be in 'inner', "\
+                         "'outer' or 'leftouter' (got '%s' instead)" % jointype)
+    # If we have a single key, put it in a tuple
+    if isinstance(key, basestring):
+        key = (key, )
+
+    # Check the keys
+    for name in key:
+        if name not in r1.dtype.names:
+            raise ValueError('r1 does not have key field %s'%name)
+        if name not in r2.dtype.names:
+            raise ValueError('r2 does not have key field %s'%name)
+
+    # Make sure we work with ravelled arrays
+    r1 = r1.ravel()
+    r2 = r2.ravel()
+    (nb1, nb2) = (len(r1), len(r2))
+    (r1names, r2names) = (r1.dtype.names, r2.dtype.names)
+
+    # Make temporary arrays of just the keys
+    r1k = drop_fields(r1, [n for n in r1names if n not in key])
+    r2k = drop_fields(r2, [n for n in r2names if n not in key])
+
+    # Concatenate the two arrays for comparison
+    aux = ma.concatenate((r1k, r2k))
+    idx_sort = aux.argsort(order=key)
+    aux = aux[idx_sort]
+    #
+    # Get the common keys
+    flag_in = ma.concatenate(([False], aux[1:] == aux[:-1]))
+    flag_in[:-1] = flag_in[1:] + flag_in[:-1]
+    idx_in = idx_sort[flag_in]
+    idx_1 = idx_in[(idx_in < nb1)]
+    idx_2 = idx_in[(idx_in >= nb1)] - nb1
+    (r1cmn, r2cmn) = (len(idx_1), len(idx_2))
+    if jointype == 'inner':
+        (r1spc, r2spc) = (0, 0)
+    elif jointype == 'outer':
+        idx_out = idx_sort[~flag_in]
+        idx_1 = np.concatenate((idx_1, idx_out[(idx_out < nb1)]))
+        idx_2 = np.concatenate((idx_2, idx_out[(idx_out >= nb1)] - nb1))
+        (r1spc, r2spc) = (len(idx_1) - r1cmn, len(idx_2) - r2cmn)
+    elif jointype == 'leftouter':
+        idx_out = idx_sort[~flag_in]
+        idx_1 = np.concatenate((idx_1, idx_out[(idx_out < nb1)]))
+        (r1spc, r2spc) = (len(idx_1) - r1cmn, 0)
+    # Select the entries from each input
+    (s1, s2) = (r1[idx_1], r2[idx_2])
+    #
+    # Build the new description of the output array .......
+    # Start with the key fields
+    ndtype = [list(_) for _ in r1k.dtype.descr]
+    # Add the other fields
+    ndtype.extend(list(_) for _ in r1.dtype.descr if _[0] not in key)
+    # Find the new list of names (it may be different from r1names)
+    names = list(_[0] for _ in ndtype)
+    for desc in r2.dtype.descr:
+        desc = list(desc)
+        name = desc[0]
+        # Have we seen the current name already ?
+        if name in names:
+            nameidx = names.index(name)
+            current = ndtype[nameidx]
+            # The current field is part of the key: take the largest dtype
+            if name in key:
+                current[-1] = max(desc[1], current[-1])
+            # The current field is not part of the key: add the suffixes
+            else:
+                current[0] += r1postfix
+                desc[0] += r2postfix
+                ndtype.insert(nameidx+1, desc)
+        #... we haven't: just add the description to the current list
+        else:
+            names.extend(desc[0])
+            ndtype.append(desc)
+    # Revert the elements to tuples
+    ndtype = [tuple(_) for _ in ndtype]
+    # Find the largest nb of common fields : r1cmn and r2cmn should be equal, but...
+    cmn = max(r1cmn, r2cmn)
+    # Construct an empty array
+    output = ma.masked_all((cmn + r1spc + r2spc,), dtype=ndtype)
+    names = output.dtype.names
+    for f in r1names:
+        selected = s1[f]
+        if f not in names:
+            f += r1postfix
+        current = output[f]
+        current[:r1cmn] = selected[:r1cmn]
+        if jointype in ('outer', 'leftouter'):
+            current[cmn:cmn+r1spc] = selected[r1cmn:]
+    for f in r2names:
+        selected = s2[f]
+        if f not in names:
+            f += r2postfix
+        current = output[f]
+        current[:r2cmn] = selected[:r2cmn]
+        if (jointype == 'outer') and r2spc:
+            current[-r2spc:] = selected[r2cmn:]
+    # Sort and finalize the output
+    output.sort(order=key)
+    kwargs = dict(usemask=usemask, asrecarray=asrecarray)
+    return _fix_output(_fix_defaults(output, defaults), **kwargs)
+
+
+def rec_join(key, r1, r2, jointype='inner', r1postfix='1', r2postfix='2',
+             defaults=None):
+    """
+    Join arrays `r1` and `r2` on keys.
+    Alternative to join_by, that always returns a np.recarray.
+
+    See Also
+    --------
+    join_by : equivalent function
+    """
+    kwargs = dict(jointype=jointype, r1postfix=r1postfix, r2postfix=r2postfix,
+                  defaults=defaults, usemask=False, asrecarray=True)
+    return join_by(key, r1, r2, **kwargs)


Property changes on: trunk/numpy/lib/recfunctions.py
___________________________________________________________________
Name: svn:mime-type
   + text/plain

Added: trunk/numpy/lib/tests/test_recfunctions.py
===================================================================
--- trunk/numpy/lib/tests/test_recfunctions.py	2009-01-22 05:37:36 UTC (rev 6330)
+++ trunk/numpy/lib/tests/test_recfunctions.py	2009-01-22 05:40:25 UTC (rev 6331)
@@ -0,0 +1,570 @@
+
+import numpy as np
+import numpy.ma as ma
+from numpy.ma.testutils import *
+
+from numpy.ma.mrecords import MaskedRecords
+
+from numpy.lib.recfunctions import *
+get_names = np.lib.recfunctions.get_names
+get_names_flat = np.lib.recfunctions.get_names_flat
+zip_descr = np.lib.recfunctions.zip_descr
+
+class TestRecFunctions(TestCase):
+    """
+    Misc tests
+    """
+    #
+    def setUp(self):
+        x = np.array([1, 2,])
+        y = np.array([10, 20, 30])
+        z = np.array([('A', 1.), ('B', 2.)],
+                     dtype=[('A', '|S3'), ('B', float)])
+        w = np.array([(1, (2, 3.0)), (4, (5, 6.0))],
+            dtype=[('a', int), ('b', [('ba', float), ('bb', int)])])
+        self.data = (w, x, y, z)
+
+
+    def test_zip_descr(self):
+        "Test zip_descr"
+        (w, x, y, z) = self.data
+        # Std array
+        test = zip_descr((x, x), flatten=True)
+        assert_equal(test,
+                     np.dtype([('', '<i4'), ('', '<i4')]))
+        test = zip_descr((x, x), flatten=False)
+        assert_equal(test,
+                     np.dtype([('', '<i4'), ('', '<i4')]))
+        # Std & flexible-dtype 
+        test = zip_descr((x, z), flatten=True)
+        assert_equal(test,
+                     np.dtype([('', '<i4'), ('A', '|S3'), ('B', float)]))
+        test = zip_descr((x, z), flatten=False)
+        assert_equal(test, 
+                     np.dtype([('', '<i4'),
+                               ('', [('A', '|S3'), ('B', float)])]))
+        # Standard & nested dtype 
+        test = zip_descr((x, w), flatten=True)
+        assert_equal(test,
+                     np.dtype([('', '<i4'),
+                               ('a', int),
+                               ('ba', float), ('bb', int)]))
+        test = zip_descr((x, w), flatten=False)
+        assert_equal(test,
+                     np.dtype([('', '<i4'),
+                               ('', [('a', int),
+                                     ('b', [('ba', float), ('bb', int)])])]))
+
+
+    def test_drop_fields(self):
+        "Test drop_fields"
+        a = np.array([(1, (2, 3.0)), (4, (5, 6.0))],
+                     dtype=[('a', int), ('b', [('ba', float), ('bb', int)])])
+        # A basic field
+        test = drop_fields(a, 'a')
+        control = np.array([((2, 3.0),), ((5, 6.0),)],
+                     dtype=[('b', [('ba', float), ('bb', int)])])
+        assert_equal(test, control)
+        # Another top-level field (one that itself nests two sub-fields)
+        test = drop_fields(a, 'b')
+        control = np.array([(1,), (4,)], dtype=[('a', int)])
+        assert_equal(test, control)
+        # A nested sub-field
+        test = drop_fields(a, ['ba',])
+        control = np.array([(1, (3.0,)), (4, (6.0,))],
+                     dtype=[('a', int), ('b', [('bb', int)])])
+        assert_equal(test, control)
+        # All the nested sub-fields of a field: the parent field is dropped too
+        test = drop_fields(a, ['ba', 'bb'])
+        control = np.array([(1,), (4,)], dtype=[('a', int)])
+        assert_equal(test, control)
+        # Dropping every top-level field leaves nothing: None is returned
+        test = drop_fields(a, ['a', 'b'])
+        assert(test is None)
+
+
+    def test_rename_fields(self):
+        "Tests rename fields"
+        a = np.array([(1, (2, [3.0, 30.])), (4, (5, [6.0, 60.]))],
+                     dtype=[('a', int),
+                            ('b', [('ba', float), ('bb', (float, 2))])])
+        test = rename_fields(a, {'a':'A', 'bb':'BB'})
+        newdtype = [('A', int), ('b', [('ba', float), ('BB', (float, 2))])]
+        control = a.view(newdtype)
+        assert_equal(test.dtype, newdtype)
+        assert_equal(test, control)
+
+
+    def test_get_names(self):
+        "Tests get_names"
+        ndtype = np.dtype([('A', '|S3'), ('B', float)])
+        test = get_names(ndtype)
+        assert_equal(test, ('A', 'B'))
+        #
+        ndtype = np.dtype([('a', int), ('b', [('ba', float), ('bb', int)])])
+        test = get_names(ndtype)
+        assert_equal(test, ('a', ('b', ('ba', 'bb'))))
+
+
+    def test_get_names_flat(self):
+        "Test get_names_flat"
+        ndtype = np.dtype([('A', '|S3'), ('B', float)])
+        test = get_names_flat(ndtype)
+        assert_equal(test, ('A', 'B'))
+        #
+        ndtype = np.dtype([('a', int), ('b', [('ba', float), ('bb', int)])])
+        test = get_names_flat(ndtype)
+        assert_equal(test, ('a', 'b', 'ba', 'bb'))
+
+
+    def test_get_fieldstructure(self):
+        "Test get_fieldstructure"
+        # No nested fields
+        ndtype = np.dtype([('A', '|S3'), ('B', float)])
+        test = get_fieldstructure(ndtype)
+        assert_equal(test, {'A':[], 'B':[]})
+        # One field with a single level of nesting
+        ndtype = np.dtype([('A', int), ('B', [('BA', float), ('BB', '|S1')])])
+        test = get_fieldstructure(ndtype)
+        assert_equal(test, {'A': [], 'B': [], 'BA':['B',], 'BB':['B']})
+        # One field with two levels of nesting
+        ndtype = np.dtype([('A', int),
+                           ('B', [('BA', int),
+                                  ('BB', [('BBA', int), ('BBB', int)])])])
+        test = get_fieldstructure(ndtype)
+        control = {'A': [], 'B': [], 'BA': ['B'], 'BB': ['B'], 
+                   'BBA': ['B', 'BB'], 'BBB': ['B', 'BB']}
+        assert_equal(test, control)
+
+
+    def test_find_duplicates(self):
+        "Test find_duplicates"
+        a = ma.array([(2, (2., 'B')), (1, (2., 'B')), (2, (2., 'B')),
+                      (1, (1., 'B')), (2, (2., 'B')), (2, (2., 'C'))],
+                      mask=[(0, (0, 0)), (0, (0, 0)), (0, (0, 0)),
+                            (0, (0, 0)), (1, (0, 0)), (0, (1, 0))],
+                      dtype=[('A', int), ('B', [('BA', float), ('BB', '|S1')])])
+        #
+        test = find_duplicates(a, ignoremask=False, return_index=True)
+        control = [0, 2]
+        assert_equal(test[-1], control)
+        assert_equal(test[0], a[control])
+        #
+        test = find_duplicates(a, key='A', return_index=True)
+        control = [1, 3, 0, 2, 5]
+        assert_equal(test[-1], control)
+        assert_equal(test[0], a[control])
+        #
+        test = find_duplicates(a, key='B', return_index=True)
+        control = [0, 1, 2, 4]
+        assert_equal(test[-1], control)
+        assert_equal(test[0], a[control])
+        #
+        test = find_duplicates(a, key='BA', return_index=True)
+        control = [0, 1, 2, 4]
+        assert_equal(test[-1], control)
+        assert_equal(test[0], a[control])
+        #
+        test = find_duplicates(a, key='BB', return_index=True)
+        control = [0, 1, 2, 3, 4]
+        assert_equal(test[-1], control)
+        assert_equal(test[0], a[control])
+
+
+    def test_find_duplicates_ignoremask(self):
+        "Test the ignoremask option of find_duplicates"
+        ndtype = [('a', int)]
+        a = ma.array([1, 1, 1, 2, 2, 3, 3], 
+                mask=[0, 0, 1, 0, 0, 0, 1]).view(ndtype)
+        test = find_duplicates(a, ignoremask=True, return_index=True)
+        control = [0, 1, 3, 4]
+        assert_equal(test[-1], control)
+        assert_equal(test[0], a[control])
+        #
+        test = find_duplicates(a, ignoremask=False, return_index=True)
+        control = [0, 1, 3, 4, 6, 2]
+        assert_equal(test[-1], control)
+        assert_equal(test[0], a[control])
+
+
+class TestRecursiveFillFields(TestCase):
+    """
+    Test recursive_fill_fields.
+    """
+    def test_simple_flexible(self):
+        "Test recursive_fill_fields on flexible-array"
+        a = np.array([(1, 10.), (2, 20.)], dtype=[('A', int), ('B', float)])
+        b = np.zeros((3,), dtype=a.dtype)
+        test = recursive_fill_fields(a, b)
+        control = np.array([(1, 10.), (2, 20.), (0, 0.)],
+                           dtype=[('A', int), ('B', float)])
+        assert_equal(test, control)
+    #
+    def test_masked_flexible(self):
+        "Test recursive_fill_fields on masked flexible-array"
+        a = ma.array([(1, 10.), (2, 20.)], mask=[(0, 1), (1, 0)],
+                     dtype=[('A', int), ('B', float)])
+        b = ma.zeros((3,), dtype=a.dtype)
+        test = recursive_fill_fields(a, b)
+        control = ma.array([(1, 10.), (2, 20.), (0, 0.)],
+                           mask=[(0, 1), (1, 0), (0, 0)],
+                           dtype=[('A', int), ('B', float)])
+        assert_equal(test, control)
+    #
+
+
+
+class TestMergeArrays(TestCase):
+    """
+    Test merge_arrays
+    """
+    def setUp(self):
+        x = np.array([1, 2,])
+        y = np.array([10, 20, 30])
+        z = np.array([('A', 1.), ('B', 2.)], dtype=[('A', '|S3'), ('B', float)])
+        w = np.array([(1, (2, 3.0)), (4, (5, 6.0))],
+            dtype=[('a', int), ('b', [('ba', float), ('bb', int)])])
+        self.data = (w, x, y, z)
+    #
+    def test_solo(self):
+        "Test merge_arrays on a single array."
+        (_, x, _, z) = self.data
+        #
+        test = merge_arrays(x)
+        control = np.array([(1,), (2,)], dtype=[('f0', int)])
+        assert_equal(test, control)
+        test = merge_arrays((x,))
+        assert_equal(test, control)
+        #
+        test = merge_arrays(z, flatten=False)
+        assert_equal(test, z)
+        test = merge_arrays(z, flatten=True)
+        assert_equal(test, z)
+    #
+    def test_solo_w_flatten(self):
+        "Test merge_arrays on a single array w & w/o flattening"
+        w = self.data[0]
+        test = merge_arrays(w, flatten=False)
+        assert_equal(test, w)
+        #
+        test = merge_arrays(w, flatten=True)
+        control = np.array([(1, 2, 3.0), (4, 5, 6.0)],
+            dtype=[('a', int), ('ba', float), ('bb', int)])
+        assert_equal(test, control)
+    #
+    def test_standard(self):
+        "Test standard & standard"
+        # Test merge arrays
+        (_, x, y, _) = self.data
+        test = merge_arrays((x, y), usemask=False)
+        control = np.array([(1, 10), (2, 20), (-1, 30)],
+                           dtype=[('f0', int), ('f1', int)])
+        assert_equal(test, control)
+        # 
+        test = merge_arrays((x, y), usemask=True)
+        control = ma.array([(1, 10), (2, 20), (-1, 30)],
+                           mask=[(0, 0), (0, 0), (1, 0)],
+                           dtype=[('f0', int), ('f1', int)])
+        assert_equal(test, control)
+        assert_equal(test.mask, control.mask)
+    #
+    def test_flatten(self):
+        "Test standard & flexible"
+        (_, x, _, z) = self.data
+        test = merge_arrays((x, z), flatten=True)
+        control = np.array([(1, 'A', 1.), (2, 'B', 2.)],
+                           dtype=[('f0', int), ('A', '|S3'), ('B', float)])
+        assert_equal(test, control)
+        #
+        test = merge_arrays((x, z), flatten=False)
+        control = np.array([(1, ('A', 1.)), (2, ('B', 2.))],
+                           dtype=[('f0', int),
+                                  ('f1', [('A', '|S3'), ('B', float)])])
+        assert_equal(test, control)
+    #
+    def test_flatten_wflexible(self):
+        "Test flatten standard & nested"
+        (w, x, _, _) = self.data
+        test = merge_arrays((x, w), flatten=True)
+        control = np.array([(1, 1, 2, 3.0), (2, 4, 5, 6.0)],
+                           dtype=[('f0', int),
+                                  ('a', int), ('ba', float), ('bb', int)])
+        assert_equal(test, control)
+        #
+        test = merge_arrays((x, w), flatten=False)
+        controldtype = dtype=[('f0', int),
+                              ('f1', [('a', int),
+                                      ('b', [('ba', float), ('bb', int)])])]
+        control = np.array([(1., (1, (2, 3.0))), (2, (4, (5, 6.0)))],
+                           dtype=controldtype)
+    # NOTE(review): in test_flatten_wflexible, the last `control` is never compared to `test`.
+    def test_wmasked_arrays(self):
+        "Test merge_arrays on masked arrays"
+        (_, x, _, _) = self.data
+        mx = ma.array([1, 2, 3], mask=[1, 0, 0])
+        test = merge_arrays((x, mx), usemask=True)
+        control = ma.array([(1, 1), (2, 2), (-1, 3)],
+                           mask=[(0, 1), (0, 0), (1, 0)],
+                           dtype=[('f0', int), ('f1', int)])
+        assert_equal(test, control)
+        test = merge_arrays((x, mx), usemask=True, asrecarray=True)
+        assert_equal(test, control)
+        assert(isinstance(test, MaskedRecords))
+    #
+    def test_w_singlefield(self):
+        "Test single field"
+        test = merge_arrays((np.array([1, 2]).view([('a', int)]),
+                             np.array([10., 20., 30.])),)
+        control = ma.array([(1, 10.), (2, 20.), (-1, 30.)],
+                           mask=[(0, 0), (0, 0), (1, 0)],
+                           dtype=[('a', int), ('f1', float)])
+        assert_equal(test, control)
+    #
+    def test_w_shorter_flex(self):
+        "Test merge_arrays w/ a shorter flexndarray."
+        z = self.data[-1]
+        test = merge_arrays((z, np.array([10, 20, 30]).view([('C', int)])))
+        control = np.array([('A', 1., 10), ('B', 2., 20), ('-1', -1, 20)],
+                           dtype=[('A', '|S3'), ('B', float), ('C', int)])
+
+
+
+class TestAppendFields(TestCase):
+    """
+    Test append_fields
+    """
+    def setUp(self):
+        x = np.array([1, 2,])
+        y = np.array([10, 20, 30])
+        z = np.array([('A', 1.), ('B', 2.)], dtype=[('A', '|S3'), ('B', float)])
+        w = np.array([(1, (2, 3.0)), (4, (5, 6.0))],
+            dtype=[('a', int), ('b', [('ba', float), ('bb', int)])])
+        self.data = (w, x, y, z)
+    #
+    def test_append_single(self):
+        "Test simple case"
+        (_, x, _, _) = self.data
+        test = append_fields(x, 'A', data=[10, 20, 30])
+        control = ma.array([(1, 10), (2, 20), (-1, 30)],
+                           mask=[(0, 0), (0, 0), (1, 0)],
+                           dtype=[('f0', int), ('A', int)],)
+        assert_equal(test, control)
+    #
+    def test_append_double(self):
+        "Test appending two fields at once"
+        (_, x, _, _) = self.data
+        test = append_fields(x, ('A', 'B'), data=[[10, 20, 30], [100, 200]])
+        control = ma.array([(1, 10, 100), (2, 20, 200), (-1, 30, -1)],
+                           mask=[(0, 0, 0), (0, 0, 0), (1, 0, 1)],
+                           dtype=[('f0', int), ('A', int), ('B', int)],)
+        assert_equal(test, control)
+    #
+    def test_append_on_flex(self):
+        "Test append_fields on flexible type arrays"
+        z = self.data[-1]
+        test = append_fields(z, 'C', data=[10, 20, 30])
+        control = ma.array([('A', 1., 10), ('B', 2., 20), (-1, -1., 30)],
+                           mask=[(0, 0, 0), (0, 0, 0), (1, 1, 0)],
+                           dtype=[('A', '|S3'), ('B', float), ('C', int)],)
+        assert_equal(test, control)
+    #
+    def test_append_on_nested(self):
+        "Test append_fields on nested fields"
+        w = self.data[0]
+        test = append_fields(w, 'C', data=[10, 20, 30])
+        control = ma.array([(1, (2, 3.0), 10),
+                            (4, (5, 6.0), 20),
+                            (-1, (-1, -1.), 30)],
+                      mask=[(0, (0, 0), 0), (0, (0, 0), 0), (1, (1, 1), 0)],
+                     dtype=[('a', int),
+                            ('b', [('ba', float), ('bb', int)]),
+                            ('C', int)],)
+        assert_equal(test, control)
+
+
+
+class TestStackArrays(TestCase):
+    """
+    Test stack_arrays
+    """
+    def setUp(self):
+        x = np.array([1, 2,])
+        y = np.array([10, 20, 30])
+        z = np.array([('A', 1.), ('B', 2.)], dtype=[('A', '|S3'), ('B', float)])
+        w = np.array([(1, (2, 3.0)), (4, (5, 6.0))],
+            dtype=[('a', int), ('b', [('ba', float), ('bb', int)])])
+        self.data = (w, x, y, z)
+    #
+    def test_solo(self):
+        "Test stack_arrays on single arrays"
+        (_, x, _, _) = self.data
+        test = stack_arrays((x,))
+        assert_equal(test, x)
+        self.failUnless(test is x)
+        #
+        test = stack_arrays(x)
+        assert_equal(test, x)
+        self.failUnless(test is x)
+    #
+    def test_unnamed_fields(self):
+        "Tests combinations of arrays w/o named fields"
+        (_, x, y, _) = self.data
+        #
+        test = stack_arrays((x, x), usemask=False)
+        control = np.array([1, 2, 1, 2])
+        assert_equal(test, control)
+        #
+        test = stack_arrays((x, y), usemask=False)
+        control = np.array([1, 2, 10, 20, 30])
+        assert_equal(test, control)
+        #
+        test = stack_arrays((y, x), usemask=False)
+        control = np.array([10, 20, 30, 1, 2])
+        assert_equal(test, control)
+    #
+    def test_unnamed_and_named_fields(self):
+        "Test combination of arrays w/ & w/o named fields"
+        (_, x, _, z) = self.data
+        #
+        test = stack_arrays((x, z))
+        control = ma.array([(1, -1, -1), (2, -1, -1),
+                            (-1, 'A', 1), (-1, 'B', 2)],
+                      mask=[(0, 1, 1), (0, 1, 1),
+                            (1, 0, 0), (1, 0, 0)],
+                      dtype=[('f0', int), ('A', '|S3'), ('B', float)])
+        assert_equal(test, control)
+        assert_equal(test.mask, control.mask)
+        #
+        test = stack_arrays((z, x))
+        control = ma.array([('A', 1, -1), ('B', 2, -1),
+                            (-1, -1, 1), (-1, -1, 2),],
+                      mask=[(0, 0, 1), (0, 0, 1),
+                            (1, 1, 0), (1, 1, 0)],
+                      dtype=[('A', '|S3'), ('B', float), ('f2', int)])
+        assert_equal(test, control)
+        assert_equal(test.mask, control.mask)
+        #
+        test = stack_arrays((z, z, x))
+        control = ma.array([('A', 1, -1), ('B', 2, -1),
+                            ('A', 1, -1), ('B', 2, -1),
+                            (-1, -1, 1), (-1, -1, 2),],
+                      mask=[(0, 0, 1), (0, 0, 1),
+                            (0, 0, 1), (0, 0, 1),
+                            (1, 1, 0), (1, 1, 0)],
+                      dtype=[('A', '|S3'), ('B', float), ('f2', int)])
+        assert_equal(test, control)
+    #
+    def test_matching_named_fields(self):
+        "Test combination of arrays w/ matching field names"
+        (_, x, _, z) = self.data
+        zz = np.array([('a', 10., 100.), ('b', 20., 200.), ('c', 30., 300.)],
+                      dtype=[('A', '|S3'), ('B', float), ('C', float)])
+        test = stack_arrays((z, zz))
+        control = ma.array([('A', 1, -1), ('B', 2, -1),
+                            ('a', 10., 100.), ('b', 20., 200.), ('c', 30., 300.)],
+                            dtype=[('A', '|S3'), ('B', float), ('C', float)],
+                            mask=[(0, 0, 1), (0, 0, 1),
+                                  (0, 0, 0), (0, 0, 0), (0, 0, 0)])
+        assert_equal(test, control)
+        assert_equal(test.mask, control.mask)
+        #
+        test = stack_arrays((z, zz, x))
+        ndtype = [('A', '|S3'), ('B', float), ('C', float), ('f3', int)]
+        control = ma.array([('A', 1, -1, -1), ('B', 2, -1, -1),
+                            ('a', 10., 100., -1), ('b', 20., 200., -1),
+                            ('c', 30., 300., -1),
+                            (-1, -1, -1, 1), (-1, -1, -1, 2)],
+                            dtype=ndtype,
+                            mask=[(0, 0, 1, 1), (0, 0, 1, 1),
+                                  (0, 0, 0, 1), (0, 0, 0, 1), (0, 0, 0, 1),
+                                  (1, 1, 1, 0), (1, 1, 1, 0)])
+        assert_equal(test, control)
+        assert_equal(test.mask, control.mask)
+
+
+    #
+    def test_defaults(self):
+        "Test defaults: no exception raised if keys of defaults are not fields."
+        (_, _, _, z) = self.data
+        zz = np.array([('a', 10., 100.), ('b', 20., 200.), ('c', 30., 300.)],
+                      dtype=[('A', '|S3'), ('B', float), ('C', float)])
+        defaults = {'A':'???', 'B':-999., 'C':-9999., 'D':-99999.}
+        test = stack_arrays((z, zz), defaults=defaults)
+        control = ma.array([('A', 1, -9999.), ('B', 2, -9999.),
+                            ('a', 10., 100.), ('b', 20., 200.), ('c', 30., 300.)],
+                            dtype=[('A', '|S3'), ('B', float), ('C', float)],
+                            mask=[(0, 0, 1), (0, 0, 1),
+                                  (0, 0, 0), (0, 0, 0), (0, 0, 0)])
+        assert_equal(test, control)
+        assert_equal(test.data, control.data)
+        assert_equal(test.mask, control.mask)
+
+
+
+class TestJoinBy(TestCase):
+    #
+    def test_base(self):
+        "Basic test of join_by"
+        a = np.array(zip(np.arange(10), np.arange(50, 60), np.arange(100, 110)),
+                     dtype=[('a', int), ('b', int), ('c', int)])
+        b = np.array(zip(np.arange(5, 15), np.arange(65, 75), np.arange(100, 110)),
+                     dtype=[('a', int), ('b', int), ('d', int)])
+        #
+        test = join_by('a', a, b, jointype='inner')
+        control = np.array([(5, 55, 65, 105, 100), (6, 56, 66, 106, 101),
+                            (7, 57, 67, 107, 102), (8, 58, 68, 108, 103),
+                            (9, 59, 69, 109, 104)],
+                           dtype=[('a', int), ('b1', int), ('b2', int),
+                                  ('c', int), ('d', int)])
+        assert_equal(test, control)
+        #
+        test = join_by(('a', 'b'), a, b)
+        control = np.array([(5, 55, 105, 100), (6, 56, 106, 101),
+                            (7, 57, 107, 102), (8, 58, 108, 103),
+                            (9, 59, 109, 104)],
+                           dtype=[('a', int), ('b', int),
+                                  ('c', int), ('d', int)])
+        # NOTE(review): the control above is never asserted against `test`.
+        test = join_by(('a', 'b'), a, b, 'outer')
+        control = ma.array([( 0, 50, 100, -1),  ( 1, 51, 101,  -1),
+                            ( 2, 52, 102, -1),  ( 3, 53, 103,  -1),
+                            ( 4, 54, 104, -1),  ( 5, 55, 105,  -1),
+                            ( 5, 65,  -1, 100), ( 6, 56, 106,  -1),
+                            ( 6, 66,  -1, 101), ( 7, 57, 107,  -1),
+                            ( 7, 67,  -1, 102), ( 8, 58, 108,  -1),
+                            ( 8, 68,  -1, 103), ( 9, 59, 109,  -1),
+                            ( 9, 69,  -1, 104), (10, 70,  -1, 105),
+                            (11, 71,  -1, 106), (12, 72,  -1, 107),
+                            (13, 73,  -1, 108), (14, 74,  -1, 109)],
+                      mask=[( 0,  0,   0,   1), ( 0,  0,   0,   1),
+                            ( 0,  0,   0,   1), ( 0,  0,   0,   1),
+                            ( 0,  0,   0,   1), ( 0,  0,   0,   1),
+                            ( 0,  0,   1,   0), ( 0,  0,   0,   1),
+                            ( 0,  0,   1,   0), ( 0,  0,   0,   1),
+                            ( 0,  0,   1,   0), ( 0,  0,   0,   1),
+                            ( 0,  0,   1,   0), ( 0,  0,   0,   1),
+                            ( 0,  0,   1,   0), ( 0,  0,   1,   0),
+                            ( 0,  0,   1,   0), ( 0,  0,   1,   0),
+                            ( 0,  0,   1,   0), ( 0,  0,   1,   0)],
+                      dtype=[('a', int), ('b', int),
+                             ('c', int), ('d', int)])
+        assert_equal(test, control)
+        #
+        test = join_by(('a', 'b'), a, b, 'leftouter')
+        control = ma.array([(0, 50, 100, -1), (1, 51, 101, -1),
+                            (2, 52, 102, -1), (3, 53, 103, -1),
+                            (4, 54, 104, -1), (5, 55, 105, -1),
+                            (6, 56, 106, -1), (7, 57, 107, -1),
+                            (8, 58, 108, -1), (9, 59, 109, -1)],
+                      mask=[(0,  0,   0,  1), (0,  0,   0,  1),
+                            (0,  0,   0,  1), (0,  0,   0,  1),
+                            (0,  0,   0,  1), (0,  0,   0,  1),
+                            (0,  0,   0,  1), (0,  0,   0,  1),
+                            (0,  0,   0,  1), (0,  0,   0,  1)],
+                      dtype=[('a', int), ('b', int), ('c', int), ('d', int)])
+
+
+
+
+if __name__ == '__main__':
+    run_module_suite()


Property changes on: trunk/numpy/lib/tests/test_recfunctions.py
___________________________________________________________________
Name: svn:mime-type
   + text/plain




More information about the Numpy-svn mailing list