[Numpy-discussion] use index array of len n to select columns of n x m array

Bruce Southey bsouthey at gmail.com
Thu Aug 5 16:47:46 EDT 2010


  On 08/05/2010 03:07 PM, Martin Spacek wrote:
> josef.pkt wrote:
>>>> a = np.array([[0, 1],
>                     [2, 3],
>                     [4, 5],
>                     [6, 7],
>                     [8, 9]])
>>>> i = np.array([0, 1, 1, 0, 1])
>>>> a[range(a.shape[0]), i]
> array([0, 3, 5, 6, 9])
>>>> a[np.arange(a.shape[0]), i]
> array([0, 3, 5, 6, 9])
>
>
> Thanks for all the tips. I guess I was hoping for something that could avoid
> having to generate np.arange(a.shape[0]), but
>
>   >>>  a[np.arange(a.shape[0]), i]
>
> sure is easy to understand. Is there maybe a more CPU and/or memory efficient
> way? I kind of like John Salvatier's idea:
>
>   >>>  np.choose(i, (a[:,0], a[:,1]))
>
> but that would need to be generalized to "a" of arbitrary columns. This could be
> done using split or vsplit:
>
>   >>>  np.choose(i, np.vsplit(a.T, a.shape[1]))[0]
> array([0, 3, 5, 6, 9])
>
> That avoids having to generate an np.arange(), but looks kind of wordy. Is there
> a more compact way? Maybe this is better:
>
>   >>>  b, = i.choose(np.vsplit(a.T, a.shape[1]))
>   >>>  b
> array([0, 3, 5, 6, 9])
>
> Ah, but I've just discovered a strange limitation of choose():
>
>   >>>  a = np.arange(9*32)
>   >>>  a.shape = 9, 32
>   >>>  i = np.random.randint(0, a.shape[1], size=a.shape[0])
>   >>>  i
> array([ 1, 21, 23,  2, 30, 23, 20, 30, 17])
>   >>>  b, = i.choose(np.vsplit(a.T, a.shape[1]))
> Traceback (most recent call last):
>     File "<input>", line 1, in <module>
> ValueError: Need between 2 and (32) array objects (inclusive).
>
> Compare with:
>
>   >>>  a = np.arange(9*31)
>   >>>  a.shape = 9, 31
>   >>>  i = np.random.randint(0, a.shape[1], size=a.shape[0])
>   >>>  i
> array([14, 22, 18,  6,  1, 12,  8,  8, 30])
>   >>>  b, = i.choose(np.vsplit(a.T, a.shape[1]))
>   >>>  b
> array([ 14,  53,  80,  99, 125, 167, 194, 225, 278])
>
> So, the ValueError should really read "Need between 2 and 31 array objects
> (inclusive)", should it not? Also, I can't seem to find this limitation in the
> docs for choose(). I guess I'll stick to using the np.arange(a.shape[0]) method.
>
> Martin
> _______________________________________________
> NumPy-Discussion mailing list
> NumPy-Discussion at scipy.org
> http://mail.scipy.org/mailman/listinfo/numpy-discussion

I think you might want numpy's where function:
 >>> np.where(i,a[:,1],a[:,0])
array([0, 3, 5, 6, 9])
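
Note that np.where() treats i as a boolean mask, so it only covers the
two-column case. A rough sketch of how the general case might look is below.
It assumes a NumPy recent enough to have np.take_along_axis (1.15 or later,
so not available when this thread was written); the names two_col, fancy,
along, wide, j and wide_pick are made up for illustration. The 9 x 32 array
and its indices are copied from the example above where choose() failed.

```python
import numpy as np

a = np.arange(10).reshape(5, 2)      # same 5 x 2 array as in the thread
i = np.array([0, 1, 1, 0, 1])        # one column index per row

# Two-column case only: nonzero entries of i pick column 1, zeros pick column 0.
two_col = np.where(i, a[:, 1], a[:, 0])

# General case from the thread: pair each row number with its column index.
rows = np.arange(a.shape[0])
fancy = a[rows, i]

# Same selection without building the row index by hand
# (assumes NumPy >= 1.15 for np.take_along_axis).
along = np.take_along_axis(a, i[:, None], axis=1)[:, 0]

# The 9 x 32 case from the thread, where choose() raised ValueError.
wide = np.arange(9 * 32).reshape(9, 32)
j = np.array([1, 21, 23, 2, 30, 23, 20, 30, 17])
wide_pick = np.take_along_axis(wide, j[:, None], axis=1)[:, 0]

print(two_col, fancy, along, wide_pick)
```

All three selections on the 5 x 2 array should give the same result as
a[np.arange(a.shape[0]), i], and the last one does not go through choose()
at all, so the limit on the number of choice arrays should not come into play.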

Bruce





