[pypy-commit] pypy more_strategies: Provide fast paths in find for integer and float strategy lists.

ltratt noreply at buildbot.pypy.org
Thu Nov 7 23:14:26 CET 2013


Author: Laurence Tratt <laurie at tratt.net>
Branch: more_strategies
Changeset: r67876:599ed4285a6d
Date: 2013-11-07 21:56 +0000
http://bitbucket.org/pypy/pypy/changeset/599ed4285a6d/

Log:	Provide fast paths in find for integer and float strategy lists.

	This patch affects "x in l" and "l.index(x)" where l is a list. It
	leaves the expected common path (searching for an integer in an
	integer list; for a float in a float list) unchanged. However,
	comparisons of other types are significantly sped up. In some cases,
	we can use the type of an object to immediately prove that it can't
	be in the list (e.g. a user object which doesn't override __eq__
	can't possibly be in an integer or float list) and return
	immediately; in others (e.g. when searching for a float in an
	integer list), we can convert the input type into a primitive that
	allows significantly faster comparisons.

	As rough examples, searching for a float in an integer list is
	approximately 3x faster; for a long in an integer list approximately
	10x faster; searching for a string in an integer list returns
	immediately, no matter the size of the list.

diff --git a/pypy/objspace/std/listobject.py b/pypy/objspace/std/listobject.py
--- a/pypy/objspace/std/listobject.py
+++ b/pypy/objspace/std/listobject.py
@@ -19,6 +19,7 @@
 from pypy.objspace.std import slicetype
 from pypy.objspace.std.floatobject import W_FloatObject
 from pypy.objspace.std.intobject import W_IntObject
+from pypy.objspace.std.longobject import W_LongObject
 from pypy.objspace.std.iterobject import (W_FastListIterObject,
     W_ReverseSeqIterObject)
 from pypy.objspace.std.sliceobject import W_SliceObject, normalize_simple_slice
@@ -1537,6 +1538,61 @@
     def getitems_int(self, w_list):
         return self.unerase(w_list.lstorage)
 
+    _orig_find = find
+    def find(self, w_list, w_obj, start, stop):
+        # Find an element in this integer list. For integers, longs, and
+        # floats with no fractional part, we can use primitive comparisons
+        # (possibly after a conversion to an int). For other user types
+        # (strings and user objects which don't play funny tricks with
+        # __eq__ etc.) we can prove immediately that an object could not be
+        # in the list and return.
+        #
+        # Note: although it might seem we want to do the clever tricks first,
+        # we expect that the common case is searching for an integer in an
+        # integer list. The clauses of this if are thus ordered in likely order
+        # of frequency of use.
+
+        w_objt = type(w_obj)
+        if w_objt is W_IntObject:
+            return self._safe_find(w_list, self.unwrap(w_obj), start, stop)
+        elif w_objt is W_FloatObject or w_objt is W_LongObject:
+            if w_objt is W_FloatObject:
+                floatv = self.space.float_w(w_obj)
+                # NaN and the infinities are equal to no int (and int() of
+                # them raises); floatv * 0.0 is NaN exactly when floatv is
+                # not finite.
+                if floatv * 0.0 != 0.0:
+                    raise ValueError
+                # Asking for an int from a W_FloatObject can return either a
+                # W_IntObject or W_LongObject, so we then need to disambiguate
+                # between the two.
+                w_obj = self.space.int(w_obj)
+                w_objt = type(w_obj)
+                # int() truncates (1.5 becomes 1), but only a float with no
+                # fractional part can equal an int. A W_LongObject result is
+                # outside machine-int range and is rejected below, so only
+                # the W_IntObject case needs checking.
+                if (w_objt is W_IntObject and
+                        float(self.unwrap(w_obj)) != floatv):
+                    raise ValueError
+
+            if w_objt is W_IntObject:
+                intv = self.unwrap(w_obj)
+            else:
+                assert w_objt is W_LongObject
+                try:
+                    intv = w_obj.toint()
+                except OverflowError:
+                    # Longs which overflow can't possibly be found in an integer
+                    # list.
+                    raise ValueError
+            return self._safe_find(w_list, intv, start, stop)
+        elif w_objt is W_StringObject or w_objt is W_UnicodeObject:
+            raise ValueError
+        elif self.space.type(w_obj).compares_by_identity():
+            raise ValueError
+        return self._orig_find(w_list, w_obj, start, stop)
+
 
     _base_extend_from_list = _extend_from_list
 
@@ -1581,6 +1637,40 @@
     def list_is_correct_type(self, w_list):
         return w_list.strategy is self.space.fromcache(FloatListStrategy)
 
+    _orig_find = find
+    def find(self, w_list, w_obj, start, stop):
+        # Find an element in this float list. Floats use primitive
+        # comparisons directly. Ints and longs are converted to a float, but
+        # only when that conversion is exact: float() may round integers whose
+        # magnitude is 2**53 or more (so 2**53 + 1 would wrongly equal
+        # 2.0**53), so such values fall back to the generic comparison.
+        # Strings and objects which compare by identity can never equal a
+        # float.
+        w_objt = type(w_obj)
+        if w_objt is W_FloatObject:
+            return self._safe_find(w_list, self.unwrap(w_obj), start, stop)
+        elif w_objt is W_IntObject or w_objt is W_LongObject:
+            if w_objt is W_IntObject:
+                intv = self.space.int_w(w_obj)
+            else:
+                try:
+                    intv = w_obj.toint()
+                except OverflowError:
+                    # Converting such a long to a float could round or raise
+                    # OverflowError; the generic comparison is always exact.
+                    return self._orig_find(w_list, w_obj, start, stop)
+            floatv = float(intv)
+            # Integers strictly between -2**53 and 2**53 convert to a float
+            # exactly, and no integer outside that range converts to a float
+            # inside it.
+            if -9007199254740992.0 < floatv < 9007199254740992.0:
+                return self._safe_find(w_list, floatv, start, stop)
+        elif w_objt is W_StringObject or w_objt is W_UnicodeObject:
+            raise ValueError
+        elif self.space.type(w_obj).compares_by_identity():
+            raise ValueError
+        return self._orig_find(w_list, w_obj, start, stop)
+
     def sort(self, w_list, reverse):
         l = self.unerase(w_list.lstorage)
         sorter = FloatSort(l, len(l))
diff --git a/pypy/objspace/std/test/test_listobject.py b/pypy/objspace/std/test/test_listobject.py
--- a/pypy/objspace/std/test/test_listobject.py
+++ b/pypy/objspace/std/test/test_listobject.py
@@ -457,6 +457,43 @@
         assert l.__contains__(2)
         assert not l.__contains__("2")
         assert l.__contains__(1.0)
+        assert not l.__contains__(1.1)
+        assert not l.__contains__(1.9)
+        assert not l.__contains__(float("nan"))
+        assert not l.__contains__(float("inf"))
+        assert l.__contains__(1L)
+        assert not l.__contains__(object())
+        class t(object):
+            def __eq__(self, o):
+                if o == 2:
+                    return True
+                return False
+        assert l.__contains__(t())
+        assert not [1,3].__contains__(t())
+        assert "1" not in l
+
+        l = [1.0,2.0,3.0]
+        assert l.__contains__(2.0)
+        assert l.__contains__(2)
+        assert not l.__contains__(4)
+        assert not l.__contains__("2")
+        assert l.__contains__(1.0)
+        assert not l.__contains__(1.1)
+        assert l.__contains__(1L)
+        assert not l.__contains__(4.0)
+        assert not l.__contains__(object())
+        assert l.__contains__(t())
+        assert not [1.0,3.0].__contains__(t())
+        assert "1" not in l
+        assert [2.0 ** 53].__contains__(2 ** 53)
+        assert not [2.0 ** 53].__contains__(2 ** 53 + 1)
+        assert not [1.0].__contains__(10 ** 400)
+
+        import sys
+        l = [sys.maxint]
+        assert l.__contains__(sys.maxint)
+        assert not l.__contains__(sys.maxint + 1)
+        assert not l.__contains__(sys.maxint * 1.0)
 
         l = ["1","2","3"]
         assert l.__contains__("2")


More information about the pypy-commit mailing list