[Python-checkins] GH-98363: Fix exception handling in batched() (GH-98523)

rhettinger webhook-mailer at python.org
Fri Oct 21 13:32:22 EDT 2022


https://github.com/python/cpython/commit/a5ff80c8bc96210bace3ffb683b01fbd7f4ab76d
commit: a5ff80c8bc96210bace3ffb683b01fbd7f4ab76d
branch: main
author: Raymond Hettinger <rhettinger at users.noreply.github.com>
committer: rhettinger <rhettinger at users.noreply.github.com>
date: 2022-10-21T12:31:52-05:00
summary:

GH-98363:  Fix exception handling in batched() (GH-98523)

files:
M Lib/test/test_itertools.py
M Modules/itertoolsmodule.c

diff --git a/Lib/test/test_itertools.py b/Lib/test/test_itertools.py
index c0e35711a2b3..a0a740fba8e8 100644
--- a/Lib/test/test_itertools.py
+++ b/Lib/test/test_itertools.py
@@ -2012,6 +2012,20 @@ def __iter__(self):
     def __next__(self):
         3 // 0
 
+class E2:
+    'Test propagation of exceptions after two iterations'
+    def __init__(self, seqn):
+        self.seqn = seqn
+        self.i = 0
+    def __iter__(self):
+        return self
+    def __next__(self):
+        if self.i == 2:
+            raise ZeroDivisionError
+        v = self.seqn[self.i]
+        self.i += 1
+        return v
+
 class S:
     'Test immediate stop'
     def __init__(self, seqn):
@@ -2050,6 +2064,7 @@ def test_batched(self):
         self.assertRaises(TypeError, batched, X(s), 2)
         self.assertRaises(TypeError, batched, N(s), 2)
         self.assertRaises(ZeroDivisionError, list, batched(E(s), 2))
+        self.assertRaises(ZeroDivisionError, list, batched(E2(s), 4))
 
     def test_chain(self):
         for s in ("123", "", range(1000), ('do', 1.2), range(2000,2200,5)):
diff --git a/Modules/itertoolsmodule.c b/Modules/itertoolsmodule.c
index 868e8a8b384f..627e698fc6b9 100644
--- a/Modules/itertoolsmodule.c
+++ b/Modules/itertoolsmodule.c
@@ -154,23 +154,36 @@ batched_next(batchedobject *bo)
     if (result == NULL) {
         return NULL;
     }
+    iternextfunc iternext = *Py_TYPE(it)->tp_iternext;
+    PyObject **items = PySequence_Fast_ITEMS(result);
     for (i=0 ; i < n ; i++) {
-        item = PyIter_Next(it);
+        item = iternext(it);
         if (item == NULL) {
-            break;
+            goto null_item;
+        }
+        items[i] = item;
+    }
+    return result;
+
+ null_item:
+    if (PyErr_Occurred()) {
+        if (PyErr_ExceptionMatches(PyExc_StopIteration)) {
+            PyErr_Clear();
+        } else {
+            /* input raised an exception other than StopIteration */
+            Py_CLEAR(bo->it);
+            Py_DECREF(result);
+            return NULL;
         }
-        PyList_SET_ITEM(result, i, item);
     }
     if (i == 0) {
         Py_CLEAR(bo->it);
         Py_DECREF(result);
         return NULL;
     }
-    if (i < n) {
-        PyObject *short_list = PyList_GetSlice(result, 0, i);
-        Py_SETREF(result, short_list);
-    }
-    return result;
+    PyObject *short_list = PyList_GetSlice(result, 0, i);
+    Py_DECREF(result);
+    return short_list;
 }
 
 static PyTypeObject batched_type = {



More information about the Python-checkins mailing list