[Python-checkins] cpython (3.3): Issue #18594: Make the C code more closely match the pure python code.

raymond.hettinger python-checkins at python.org
Sat Oct 5 01:53:36 CEST 2013


http://hg.python.org/cpython/rev/e4cec1116e5c
changeset:   85960:e4cec1116e5c
branch:      3.3
parent:      85955:bfebfadfc4aa
user:        Raymond Hettinger <python at rcn.com>
date:        Fri Oct 04 16:51:02 2013 -0700
summary:
  Issue #18594: Make the C code more closely match the pure Python code.

files:
  Lib/test/test_collections.py |  24 +++++++++++++++++++
  Modules/_collectionsmodule.c |  29 ++++++++++++-----------
  2 files changed, 39 insertions(+), 14 deletions(-)


diff --git a/Lib/test/test_collections.py b/Lib/test/test_collections.py
--- a/Lib/test/test_collections.py
+++ b/Lib/test/test_collections.py
@@ -818,6 +818,24 @@
 ### Counter
 ################################################################################
 
+class CounterSubclassWithSetItem(Counter):
+    # Test a counter subclass that overrides __setitem__
+    def __init__(self, *args, **kwds):
+        self.called = False
+        Counter.__init__(self, *args, **kwds)
+    def __setitem__(self, key, value):
+        self.called = True
+        Counter.__setitem__(self, key, value)
+
+class CounterSubclassWithGet(Counter):
+    # Test a counter subclass that overrides get()
+    def __init__(self, *args, **kwds):
+        self.called = False
+        Counter.__init__(self, *args, **kwds)
+    def get(self, key, default):
+        self.called = True
+        return Counter.get(self, key, default)
+
 class TestCounter(unittest.TestCase):
 
     def test_basics(self):
@@ -1022,6 +1040,12 @@
         self.assertEqual(m,
              OrderedDict([('a', 5), ('b', 2), ('r', 2), ('c', 1), ('d', 1)]))
 
+        # test fidelity to the pure python version
+        c = CounterSubclassWithSetItem('abracadabra')
+        self.assertTrue(c.called)
+        c = CounterSubclassWithGet('abracadabra')
+        self.assertTrue(c.called)
+
 
 ################################################################################
 ### OrderedDict
diff --git a/Modules/_collectionsmodule.c b/Modules/_collectionsmodule.c
--- a/Modules/_collectionsmodule.c
+++ b/Modules/_collectionsmodule.c
@@ -1689,17 +1689,17 @@
 static PyObject *
 _count_elements(PyObject *self, PyObject *args)
 {
-    _Py_IDENTIFIER(__getitem__);
+    _Py_IDENTIFIER(get);
     _Py_IDENTIFIER(__setitem__);
     PyObject *it, *iterable, *mapping, *oldval;
     PyObject *newval = NULL;
     PyObject *key = NULL;
     PyObject *zero = NULL;
     PyObject *one = NULL;
-    PyObject *mapping_get = NULL;
-    PyObject *mapping_getitem;
+    PyObject *bound_get = NULL;
+    PyObject *mapping_get;
+    PyObject *dict_get;
     PyObject *mapping_setitem;
-    PyObject *dict_getitem;
     PyObject *dict_setitem;
 
     if (!PyArg_UnpackTuple(args, "_count_elements", 2, 2, &mapping, &iterable))
@@ -1713,15 +1713,16 @@
     if (one == NULL)
         goto done;
 
-    mapping_getitem = _PyType_LookupId(Py_TYPE(mapping), &PyId___getitem__);
-    dict_getitem = _PyType_LookupId(&PyDict_Type, &PyId___getitem__);
+    /* Only take the fast path when get() and __setitem__()
+     * have not been overridden.
+     */
+    mapping_get = _PyType_LookupId(Py_TYPE(mapping), &PyId_get);
+    dict_get = _PyType_LookupId(&PyDict_Type, &PyId_get);
     mapping_setitem = _PyType_LookupId(Py_TYPE(mapping), &PyId___setitem__);
     dict_setitem = _PyType_LookupId(&PyDict_Type, &PyId___setitem__);
 
-    if (mapping_getitem != NULL &&
-        mapping_getitem == dict_getitem &&
-        mapping_setitem != NULL &&
-        mapping_setitem == dict_setitem) {
+    if (mapping_get != NULL && mapping_get == dict_get &&
+        mapping_setitem != NULL && mapping_setitem == dict_setitem) {
         while (1) {
             key = PyIter_Next(it);
             if (key == NULL)
@@ -1741,8 +1742,8 @@
             Py_DECREF(key);
         }
     } else {
-        mapping_get = PyObject_GetAttrString(mapping, "get");
-        if (mapping_get == NULL)
+        bound_get = PyObject_GetAttrString(mapping, "get");
+        if (bound_get == NULL)
             goto done;
 
         zero = PyLong_FromLong(0);
@@ -1753,7 +1754,7 @@
             key = PyIter_Next(it);
             if (key == NULL)
                 break;
-            oldval = PyObject_CallFunctionObjArgs(mapping_get, key, zero, NULL);
+            oldval = PyObject_CallFunctionObjArgs(bound_get, key, zero, NULL);
             if (oldval == NULL)
                 break;
             newval = PyNumber_Add(oldval, one);
@@ -1771,7 +1772,7 @@
     Py_DECREF(it);
     Py_XDECREF(key);
     Py_XDECREF(newval);
-    Py_XDECREF(mapping_get);
+    Py_XDECREF(bound_get);
     Py_XDECREF(zero);
     Py_XDECREF(one);
     if (PyErr_Occurred())

-- 
Repository URL: http://hg.python.org/cpython


More information about the Python-checkins mailing list