[Numpy-discussion] efficient use of numpy.where() and .any()

Mark.Miller mpmusu at cc.usu.edu
Mon Apr 23 10:37:57 EDT 2007


Greetings:

In some of my code, I need to use a large matrix of random numbers that
meet specific criteria (i.e., random numbers that fail the criteria need
to be removed and replaced with new ones).

I have been working with .any() and numpy.where() to facilitate this
process. In the code below, .any() is used in a while loop to test for
the presence of matrix elements that do not meet the criteria. After
that, numpy.where() is used to obtain the tuple of indices of the
elements that do not meet the criteria. I then iterate over that tuple
of indices and replace each 'bad' random number one at a time.

Here's an example:

>>> import numpy
>>> a=numpy.random.normal(0,1,(3,3))
>>> a
array([[ 0.79653228, -2.28751484,  1.15989261],
       [-0.7549573 ,  2.35133816,  0.22551564],
       [ 0.85666713,  1.17484243,  1.18306248]])
>>> while (a<0).any() or (a>1).any():
        # replace values below 0, one element at a time
        ind=numpy.where(a<0)
        for aa in xrange(len(ind[0])):
            a[ind[0][aa],ind[1][aa]]=numpy.random.normal(0,1)
        # replace values above 1, one element at a time
        ind=numpy.where(a>1)
        for aa in xrange(len(ind[0])):
            a[ind[0][aa],ind[1][aa]]=numpy.random.normal(0,1)

>>> a
array([[ 0.79653228,  0.99298488,  0.24903299],
       [ 0.10884186,  0.10139654,  0.22551564],
       [ 0.85666713,  0.76554597,  0.38383126]])
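For comparison, here is a minimal sketch of the same one-element-at-a-time strategy with less index bookkeeping: the index arrays returned by numpy.where() are zipped into coordinate tuples instead of being indexed by position. It assumes Python 3 (the session above uses Python 2's xrange) and NumPy 1.17 or later for np.random.default_rng(); the function name redraw_out_of_range is hypothetical.

```python
import numpy as np


def redraw_out_of_range(a, low=0.0, high=1.0, rng=None):
    """Redraw entries of `a` outside [low, high], one element at a time.

    Hypothetical helper: same strategy as the loops above, but the index
    arrays from np.where are zipped into coordinate tuples instead of
    being indexed by position. Modifies `a` in place and also returns it.
    """
    rng = np.random.default_rng() if rng is None else rng
    while ((a < low) | (a > high)).any():
        # For a 2-D array, np.where(mask) returns (rows, cols);
        # zip(*...) turns that pair of arrays into (row, col) tuples.
        for idx in zip(*np.where((a < low) | (a > high))):
            a[idx] = rng.normal(0, 1)
    return a
```

This still runs a Python-level loop over every bad element, so it is only a readability improvement, not a speedup.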


My main question:  is there a more efficient approach that I could be
using?  Ideally, I would like to eliminate the two for loops.
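One way to drop both for loops, sketched here under the same assumptions (Python 3, NumPy 1.17 or later, hypothetical function name), is boolean-mask assignment: build a single mask of all out-of-range elements and assign a whole array of fresh draws into those positions at once, so only the outer while loop remains in Python.

```python
import numpy as np


def redraw_out_of_range_vectorized(a, low=0.0, high=1.0, rng=None):
    """Redraw every entry of `a` outside [low, high], all at once per pass.

    Hypothetical helper: boolean-mask assignment replaces all offending
    entries in a single step, so the only Python-level loop left is the
    outer while. Modifies `a` in place and also returns it.
    """
    rng = np.random.default_rng() if rng is None else rng
    bad = (a < low) | (a > high)
    while bad.any():
        # a[bad] = ... writes one new N(0, 1) draw into each flagged slot;
        # the right-hand side must supply exactly bad.sum() values.
        a[bad] = rng.normal(0, 1, bad.sum())
        bad = (a < low) | (a > high)
    return a
```

This is plain rejection sampling from a normal distribution restricted to [low, high]. Recomputing the mask over the whole array on each pass keeps the code simple; re-testing only the positions just redrawn would be a possible refinement. SciPy's scipy.stats.truncnorm can also sample a truncated normal directly, which is a different technique from the resampling loop used here.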



An ancillary question:  I don't quite see why the test in the while loop
above works fine, while the following formulation does not.

>>> while (0<a<1).any():
        ind=numpy.where(a<0)
        for aa in xrange(len(ind[0])):
            a[ind[0][aa],ind[1][aa]]=numpy.random.normal(0,1)
        ind=numpy.where(a>1)
        for aa in xrange(len(ind[0])):
            a[ind[0][aa],ind[1][aa]]=numpy.random.normal(0,1)
        print type(ind)

Traceback (most recent call last):
  File "<pyshell#71>", line 1, in <module>
    while (0<a<1).any():
ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()

Can someone clarify this?
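A short illustration of where that error comes from: Python expands the chained comparison 0<a<1 into (0<a) and (a<1), and the `and` operator must call bool() on the boolean array 0<a, which NumPy refuses to do for arrays with more than one element. Combining the two masks element by element needs `&` (or numpy.logical_and), with parentheses because `&` binds more tightly than `<`. Note also that (0<a<1).any() would ask whether any element is inside the interval, whereas the first loop tests whether any element is outside it. The sketch below assumes Python 3 and uses a small made-up array.

```python
import numpy as np

a = np.array([[0.5, -2.0], [1.5, 0.25]])

# Chained form: Python evaluates (0 < a) and (a < 1); `and` calls
# bool() on a 2x2 boolean array, which raises ValueError.
chained_error = None
try:
    (0 < a < 1).any()
except ValueError as exc:
    chained_error = str(exc)

# Elementwise form: & combines the two boolean arrays element by element.
# The parentheses are required because & binds tighter than <.
inside = (0 < a) & (a < 1)  # same result as np.logical_and(0 < a, a < 1)

# Condition matching the first while loop: any element outside [0, 1].
any_outside = ((a < 0) | (a > 1)).any()
```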

Thanks,

-Mark


