[Numpy-discussion] Optimisation of a Numpy set intersection function

Tim Churches tchur at optushome.com.au
Sun Apr 8 17:32:09 EDT 2001


I have been trying to optimise a simple set intersection function
for a Numpy-based application on which I am working. The function needs
to find the common elements (i.e. the intersection) of two or more
Numpy rank-1 arrays of integers passed to it as arguments - call them
array A, array B etc. My first attempt used searchsorted() to see where
in array A each element of array B would fit, and then checked whether the
element of array A at that position equalled the element of array B. If
it did, then that element was part of the intersection:

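from numpy import searchsorted, compress, less, equal, take
def intersect(A, B):
    # A must be sorted; c[i] is where B[i] would be inserted into A
    c = searchsorted(A, B)
    # drop elements of B that would fall beyond the end of A
    d = compress(less(c, len(A)), B)
    # keep elements of B equal to the element of A at their insertion point
    return compress(equal(take(A, compress(less(c, len(A)), c)), d), d)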

This works OK, but it requires array A to be stored in sorted order
(which is what we did) or to be sorted by the function (not shown).
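The variant that sorts inside the function can be sketched as follows. This is a minimal sketch assuming modern NumPy (`np.sort`, `np.searchsorted` and boolean-mask indexing in place of Numeric's `compress`/`take`); the name `intersect_sorting` is hypothetical. Only A needs sorting, since searchsorted() looks up B's elements in A and B can be in any order.

```python
import numpy as np

def intersect_sorting(A, B):
    """Intersect two 1-D integer arrays; A is sorted here, not by the caller."""
    A = np.sort(A)
    # Insertion index of each element of B in the sorted A
    c = np.searchsorted(A, B)
    # Elements of B that would land past the end of A cannot be in A
    inside = c < len(A)
    d = B[inside]
    # Keep elements of B equal to the element of A at their insertion point
    return d[A[c[inside]] == d]
```

The result follows the order of B, and a value repeated in B is repeated in the output.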

I should add that the arrays passed to the function may have up to a few
million elements each.

Ole Nielsen at ANU (Australian National University), who gave a paper on
data mining using Python at IPC9, suggested an improvement, shown below.
Provided the arrays are pre-sorted, Ole's function is about 40% faster
than my attempt, and about the same speed if the arrays aren't
pre-sorted. Moreover, Ole's function can be generalised to find the
intersection of more than two arrays at a time. Most of the time is spent
sorting, and this time grows as N log N, where N is the total size of the
concatenated arrays (I checked this empirically).

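from numpy import sort, concatenate, subtract, compress, equal

def intersect(Arraylist):
    # n + 1 arrays, none of which contains repeated values
    n = len(Arraylist) - 1
    C = sort(concatenate(Arraylist))
    # D[i] == 0 where C[i] == C[i+n], i.e. a value present in every array
    D = subtract(C[0:-n], C[n:])   #or
    # D = convolve(C,[n,-n])
    return compress(equal(D, 0), C)     #or
    # return take(C, nonzero(equal(D, 0))[0])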
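Why Ole's comparison works: once the n + 1 arrays are concatenated and sorted, a value present in all of them occupies n + 1 consecutive slots, so C[i] == C[i+n] picks it out exactly once. That reasoning assumes no single array contains a repeated value; otherwise a value duplicated within one array can be reported as common. A minimal sketch in modern NumPy that guards against this by de-duplicating each input with `np.unique` first (an addition, not part of Ole's function; the name `intersect_many` is hypothetical):

```python
import numpy as np

def intersect_many(arrays):
    """Intersection of two or more 1-D integer arrays via sort-and-compare."""
    n = len(arrays) - 1            # a common value appears n + 1 times
    if n < 1:
        raise ValueError("need at least two arrays")
    # np.unique (sorted, duplicates removed) so each array contributes a value once
    C = np.sort(np.concatenate([np.unique(a) for a in arrays]))
    # The n + 1 copies of a common value are adjacent after sorting,
    # so C[i] == C[i + n] holds exactly at the first copy of each one.
    return C[:-n][C[:-n] == C[n:]]
```

For example, `intersect_many([np.array([1, 1, 2]), np.array([3])])` returns an empty array, whereas the sort-and-compare step alone would report 1.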

Paul Dubois suggested the following elegant alternative:

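from numpy import compress, sum, equal, subtract
def intersect(x, y):
    # row i of the rank-2 array counts how many elements of y equal x[i]
    return compress(sum(equal(subtract.outer(x, y), 0), 1), x)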

Unfortunately, perhaps due to the overhead of creating the rank-2 array
(len(x) by len(y)) to hold the results of the subtract.outer() method
call, it turns out to be slower than Ole's function, as well as using huge
tracts of memory.
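The memory cost follows from the shape of the temporary: subtract.outer(x, y) holds one element per pair, and equal() then creates a second array of the same shape. A rough estimate for arrays of the size mentioned above, assuming 10^6 elements each and 4-byte integers:

```latex
\begin{aligned}
|x| = |y| = 10^6 \;&\Rightarrow\; |x|\,|y| = 10^{12} \text{ elements} \\
10^{12} \times 4\ \text{bytes} &= 4 \times 10^{12}\ \text{bytes} \approx 4\ \text{TB}
\end{aligned}
```

By contrast, the concatenated array in Ole's function holds only |x| + |y| = 2 x 10^6 elements, about 8 MB at the same element size.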

My questions for the list are:

a) can anyone suggest any further optimisation of Ole's function, or
some alternative?

b) how much do you think would be gained by re-implementing this
function in C using the Numpy C API? If such an effort would be
worthwhile, are there any things we should consider while tackling
this task?
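One consideration for b): when both arrays are already sorted, a C routine need not sort at all, because a single two-pointer merge pass finds the intersection in O(len(A) + len(B)) time with no temporary arrays. The sketch below shows that loop in pure Python for readability. It is a different technique from Ole's sort-and-compare, it would be slow if run as Python, and the name `merge_intersect` is hypothetical; the point is the loop structure a C implementation might use.

```python
def merge_intersect(A, B):
    """Two-pointer intersection of two sorted sequences without duplicates."""
    i, j = 0, 0
    out = []
    while i < len(A) and j < len(B):
        if A[i] < B[j]:
            i += 1              # A[i] is too small to match anything left in B
        elif A[i] > B[j]:
            j += 1              # B[j] is too small to match anything left in A
        else:
            out.append(A[i])    # common element: record it and advance both
            i += 1
            j += 1
    return out
```

Whether this beats Ole's function in practice would need measuring, but it avoids the N log N sort that dominates the running time.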

Regards,

Tim Churches
Sydney, Australia