[Numpy-discussion] use index array of len n to select columns of n x m array
Keith Goodman
kwgoodman at gmail.com
Fri Aug 6 09:57:40 EDT 2010
On Fri, Aug 6, 2010 at 3:01 AM, Martin Spacek <numpy at mspacek.mm.st> wrote:
> Keith Goodman wrote:
> > Here's one way:
> >
> > >>> a.flat[i + a.shape[1] * np.arange(a.shape[0])]
> > array([0, 3, 5, 6, 9])
>
>
> I'm afraid I made my example a little too simple. In retrospect, what I really
> want is to be able to use a 2D index array "i", like this:
>
> >>> a = np.array([[ 0,  1,  2,  3],
>                   [ 4,  5,  6,  7],
>                   [ 8,  9, 10, 11],
>                   [12, 13, 14, 15],
>                   [16, 17, 18, 19]])
> >>> i = np.array([[2, 1],
>                   [3, 1],
>                   [1, 1],
>                   [0, 0],
>                   [3, 1]])
> >>> foo(a, i)
> array([[ 2,  1],
>        [ 7,  5],
>        [ 9,  9],
>        [12, 12],
>        [19, 17]])
>
> I think the flat iterator indexing suggestion is about the only thing that'll
> work. Here's the function I've pretty much settled on:
>
> def rowtake(a, i):
>     """For each row in a, return values according to column indices in the
>     corresponding row in i. Returned shape == i.shape"""
>     assert a.ndim == 2
>     assert i.ndim <= 2
>     if i.ndim == 1:
>         return a.flat[i + a.shape[1] * np.arange(a.shape[0])]
>     else: # i.ndim == 2
>         return a.flat[i + a.shape[1] * np.vstack(np.arange(a.shape[0]))]
>
> This is about half as fast as my Cython function, but the Cython function is
> limited to fixed dtypes and ndim:

You can speed it up by getting rid of two copies:

    idx = np.arange(a.shape[0])
    idx *= a.shape[1]
    idx += i
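
For reference, here is a minimal sketch of how the in-place index computation might slot into rowtake. The name rowtake_inplace is hypothetical and not from the thread. It assumes a is C-contiguous 2D and i holds integer column indices. Note that the in-place "idx += i" can only work when i is 1D with one entry per row; for a 2D i the result is larger than idx, so the sketch falls back to a broadcast add that allocates one new array.

```python
import numpy as np

def rowtake_inplace(a, i):
    """Hypothetical variant of rowtake that builds the flat index in place.

    For each row in a, return values according to the column indices in the
    corresponding row of i. Returned shape == i.shape.
    """
    a = np.asarray(a)
    i = np.asarray(i)
    assert a.ndim == 2
    assert i.ndim <= 2
    # Offset of the start of each row in the flattened array: 0, ncols, 2*ncols, ...
    idx = np.arange(a.shape[0])
    idx *= a.shape[1]
    if i.ndim == 1:
        # Same shape as idx, so the add can reuse idx's buffer.
        idx += i
    else:
        # An in-place add cannot grow idx to i's 2D shape, so broadcast the
        # row offsets as a column; this allocates one new index array.
        idx = idx[:, np.newaxis] + i
    return a.flat[idx]

a = np.arange(20).reshape(5, 4)
i = np.array([[2, 1],
              [3, 1],
              [1, 1],
              [0, 0],
              [3, 1]])
result = rowtake_inplace(a, i)
```

With the a and i from the quoted example, result should match the foo(a, i) output shown above. Whether this is actually faster than the original rowtake would need timing on real data.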