[Numpy-discussion] use index array of len n to select columns of n x m array
Martin Spacek
numpy at mspacek.mm.st
Fri Aug 6 06:01:10 EDT 2010
Keith Goodman wrote:
> Here's one way:
>
>>> a.flat[i + a.shape[1] * np.arange(a.shape[0])]
> array([0, 3, 5, 6, 9])
I'm afraid I made my example a little too simple. In retrospect, what I really
want is to be able to use a 2D index array "i", like this:
>>> a = np.array([[ 0,  1,  2,  3],
                  [ 4,  5,  6,  7],
                  [ 8,  9, 10, 11],
                  [12, 13, 14, 15],
                  [16, 17, 18, 19]])
>>> i = np.array([[2, 1],
                  [3, 1],
                  [1, 1],
                  [0, 0],
                  [3, 1]])
>>> foo(a, i)
array([[ 2,  1],
       [ 7,  5],
       [ 9,  9],
       [12, 12],
       [19, 17]])
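Put another way, foo(a, i)[r, k] should equal a[r, i[r, k]] for every row r and every column position k of i. A plain double loop states this directly. It is a slow sketch meant only as a correctness baseline for faster versions, and rowtake_loop is a hypothetical name:

```python
import numpy as np

def rowtake_loop(a, i):
    """Slow reference (hypothetical name): out[r, k] = a[r, i[r, k]]."""
    a = np.asarray(a)
    i = np.asarray(i)
    # One output row per row of a, with as many columns as i has.
    out = np.empty(i.shape, dtype=a.dtype)
    for r in range(i.shape[0]):
        for k in range(i.shape[1]):
            out[r, k] = a[r, i[r, k]]
    return out

a = np.arange(20).reshape(5, 4)  # same values as the array above
i = np.array([[2, 1], [3, 1], [1, 1], [0, 0], [3, 1]])
print(rowtake_loop(a, i))  # matches the desired foo(a, i) output above
```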
I think the flat iterator indexing suggestion is about the only thing that'll
work. Here's the function I've pretty much settled on:
import numpy as np

def rowtake(a, i):
    """For each row in a, return values according to column indices in the
    corresponding row in i. Returned shape == i.shape"""
    assert a.ndim == 2
    assert i.ndim <= 2
    if i.ndim == 1:
        return a.flat[i + a.shape[1] * np.arange(a.shape[0])]
    else:  # i.ndim == 2
        # np.vstack turns the 1-D row numbers into an (nrows, 1) column,
        # which broadcasts against every column of i
        return a.flat[i + a.shape[1] * np.vstack(np.arange(a.shape[0]))]
This is about half as fast as my Cython function, but the Cython function is
limited to fixed dtypes and ndim:
cimport cython
cimport numpy as np
import numpy as np

np.import_array()

@cython.boundscheck(False)
@cython.wraparound(False)
@cython.cdivision(True)
def rowtake_cy(np.ndarray[np.int32_t, ndim=2] a,
               np.ndarray[np.int32_t, ndim=2] i):
    """For each row in a, return values according to column indices in the
    corresponding row in i. Returned shape == i.shape"""
    cdef Py_ssize_t nrows, ncols, rowi, coli
    cdef np.ndarray[np.int32_t, ndim=2] out
    nrows = i.shape[0]
    ncols = i.shape[1]  # num cols to take from a for each row
    # boundscheck and wraparound are off, so validate the indices up front
    assert a.shape[0] == nrows
    assert i.min() >= 0
    assert i.max() < a.shape[1]
    out = np.empty((nrows, ncols), dtype=np.int32)
    for rowi in range(nrows):
        for coli in range(ncols):
            out[rowi, coli] = a[rowi, i[rowi, coli]]
    return out
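For comparison, here is a sketch using NumPy's integer-array ("fancy") indexing instead of the flat iterator: an (nrows, 1) column of row numbers broadcasts against i, so no flat offsets need computing, and like rowtake it is not tied to one dtype. rowtake_fancy is a hypothetical name, np.take_along_axis (shown alongside) was only added in NumPy 1.15, well after this message, and no claim is made about speed relative to rowtake or rowtake_cy.

```python
import numpy as np

def rowtake_fancy(a, i):
    """Hypothetical alternative: out[r, k] = a[r, i[r, k]] via fancy indexing."""
    a = np.asarray(a)
    i = np.asarray(i)
    rows = np.arange(a.shape[0])
    if i.ndim == 2:
        # An (nrows, 1) column of row numbers broadcasts against i's columns.
        rows = rows[:, None]
    return a[rows, i]

a = np.arange(20).reshape(5, 4)
i = np.array([[2, 1], [3, 1], [1, 1], [0, 0], [3, 1]])
print(rowtake_fancy(a, i))                           # same values as foo(a, i)
print(np.take_along_axis(a, i, axis=1))              # NumPy >= 1.15 only
print(rowtake_fancy(a.astype(np.float64), i).dtype)  # float64
```

With a 1-D i, the row numbers stay 1-D and pair elementwise with i, mirroring the i.ndim == 1 branch of rowtake.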
Cheers,
Martin