Nasty gotcha/bug in heapq.nlargest/nsmallest
Peter Otten
__peter__ at web.de
Thu May 15 03:06:19 EDT 2008
George Sakkis wrote:
> I spent several hours debugging some bogus data results that turned
> out to be caused by the fact that heapq.nlargest doesn't respect rich
> comparisons:
>
> import heapq
> import random
>
> class X(object):
>     def __init__(self, x): self.x=x
>     def __repr__(self): return 'X(%s)' % self.x
>     if True:
>         # this succeeds
>         def __cmp__(self, other): return cmp(self.x , other.x)
>     else:
>         # this fails
>         def __lt__(self, other): return self.x < other.x
>
> s = [X(i) for i in range(10)]
> random.shuffle(s)
>
> s1 = heapq.nlargest(5, s)
> s2 = sorted(s, reverse=True)[:5]
> assert s1 == s2, (s,s1,s2)
>
> s1 = heapq.nsmallest(5, s)
> s2 = sorted(s)[:5]
> assert s1 == s2, (s,s1,s2)
>
>
> According to the docs, nlargest is equivalent to: "sorted(iterable,
> key=key, reverse=True)[:n]" and similarly for nsmallest. So that must
> be at least a documentation bug, if not an implementation one.
Implementing a subset of the rich comparisons is always dangerous. According
to my ad hoc test you need <, <=, and == for nlargest()/nsmallest() to
work:
import heapq
import random
# Names of the rich-comparison methods that have been invoked so far.
# Every comparison method of X adds its short name ("lt", "eq", ...) here.
used_rich = set()


class X(object):
    """Wrapper around a value that records which rich comparisons are used.

    All six rich-comparison methods are implemented; each one adds its
    short name to the module-level ``used_rich`` set and then delegates to
    the same comparison on the wrapped ``x`` attributes.
    """

    def __init__(self, x):
        """Store *x* as the value that comparisons operate on."""
        self.x = x

    def __repr__(self):
        """Return ``X(<value>)`` for readable assertion messages."""
        return 'X(%s)' % self.x

    def __lt__(self, other):
        """Record ``"lt"`` and return ``self.x < other.x``."""
        used_rich.add("lt")
        return self.x < other.x

    def __eq__(self, other):
        """Record ``"eq"`` and return ``self.x == other.x``."""
        used_rich.add("eq")
        return self.x == other.x

    def __gt__(self, other):
        """Record ``"gt"`` and return ``self.x > other.x``."""
        used_rich.add("gt")
        return self.x > other.x

    def __ge__(self, other):
        """Record ``"ge"`` and return ``self.x >= other.x``."""
        used_rich.add("ge")
        return self.x >= other.x

    def __ne__(self, other):
        """Record ``"ne"`` and return ``self.x != other.x``."""
        used_rich.add("ne")
        return self.x != other.x

    def __le__(self, other):
        """Record ``"le"`` and return ``self.x <= other.x``."""
        used_rich.add("le")
        return self.x <= other.x
# Driver: compare heapq.nlargest/nsmallest against sorted() on X instances
# and report which rich-comparison methods each approach invoked.
s = [X(8), X(0), X(3), X(4), X(5), X(2), X(1), X(6), X(7), X(9)]

# Reference answers, computed with sorted().
smallest = sorted(s)[:5]
largest = sorted(s, reverse=True)[:5]
# A single-argument parenthesised print with %-formatting emits the same
# text as the Python 2 print statement and also parses under Python 3.
print("used by sorted: %s" % used_rich)

# Rebind to a fresh set so the heapq phase is tallied separately from the
# sorted() phase above.
used_rich = set()
# NOTE(review): the loop extent is reconstructed from a paste that lost its
# indentation; the shuffle is assumed to be inside the loop so that each
# iteration exercises a different input order -- confirm against the
# original post.
for i in range(10000):
    s1 = heapq.nlargest(5, s)
    assert s1 == largest, (s, s1, largest)
    s1 = heapq.nsmallest(5, s)
    assert s1 == smallest, (s, s1, smallest)
    random.shuffle(s)
print("used by nlargest/nsmallest: %s" % used_rich)
Output:
used by sorted: set(['lt'])
used by nlargest/nsmallest: set(['lt', 'le', 'eq'])
Peter
More information about the Python-list
mailing list