[Python-checkins] r61118 - in python/trunk: Lib/test/test_itertools.py Modules/itertoolsmodule.c

raymond.hettinger python-checkins at python.org
Thu Feb 28 23:30:43 CET 2008


Author: raymond.hettinger
Date: Thu Feb 28 23:30:42 2008
New Revision: 61118

Modified:
   python/trunk/Lib/test/test_itertools.py
   python/trunk/Modules/itertoolsmodule.c
Log:
Have itertools.chain() consume its inputs lazily instead of building a tuple of iterators at the outset.

Modified: python/trunk/Lib/test/test_itertools.py
==============================================================================
--- python/trunk/Lib/test/test_itertools.py	(original)
+++ python/trunk/Lib/test/test_itertools.py	Thu Feb 28 23:30:42 2008
@@ -50,7 +50,7 @@
         self.assertEqual(list(chain('abc')), list('abc'))
         self.assertEqual(list(chain('')), [])
         self.assertEqual(take(4, chain('abc', 'def')), list('abcd'))
-        self.assertRaises(TypeError, chain, 2, 3)
+        self.assertRaises(TypeError, list,chain(2, 3))
 
     def test_combinations(self):
         self.assertRaises(TypeError, combinations, 'abc')   # missing r argument
@@ -670,7 +670,7 @@
             for g in (G, I, Ig, S, L, R):
                 self.assertEqual(list(chain(g(s))), list(g(s)))
                 self.assertEqual(list(chain(g(s), g(s))), list(g(s))+list(g(s)))
-            self.assertRaises(TypeError, chain, X(s))
+            self.assertRaises(TypeError, list, chain(X(s)))
             self.assertRaises(TypeError, list, chain(N(s)))
             self.assertRaises(ZeroDivisionError, list, chain(E(s)))
 

Modified: python/trunk/Modules/itertoolsmodule.c
==============================================================================
--- python/trunk/Modules/itertoolsmodule.c	(original)
+++ python/trunk/Modules/itertoolsmodule.c	Thu Feb 28 23:30:42 2008
@@ -1601,92 +1601,92 @@
 
 typedef struct {
 	PyObject_HEAD
-	Py_ssize_t tuplesize;
-	Py_ssize_t iternum;		/* which iterator is active */
-	PyObject *ittuple;		/* tuple of iterators */
+	PyObject *source;		/* Iterator over input iterables */
+	PyObject *active;		/* Currently running input iterator */
 } chainobject;
 
 static PyTypeObject chain_type;
 
-static PyObject *
-chain_new(PyTypeObject *type, PyObject *args, PyObject *kwds)
+static PyObject * 
+chain_new_internal(PyTypeObject *type, PyObject *source)
 {
 	chainobject *lz;
-	Py_ssize_t tuplesize = PySequence_Length(args);
-	Py_ssize_t i;
-	PyObject *ittuple;
-
-	if (type == &chain_type && !_PyArg_NoKeywords("chain()", kwds))
-		return NULL;
-
-	/* obtain iterators */
-	assert(PyTuple_Check(args));
-	ittuple = PyTuple_New(tuplesize);
-	if (ittuple == NULL)
-		return NULL;
-	for (i=0; i < tuplesize; ++i) {
-		PyObject *item = PyTuple_GET_ITEM(args, i);
-		PyObject *it = PyObject_GetIter(item);
-		if (it == NULL) {
-			if (PyErr_ExceptionMatches(PyExc_TypeError))
-				PyErr_Format(PyExc_TypeError,
-				    "chain argument #%zd must support iteration",
-				    i+1);
-			Py_DECREF(ittuple);
-			return NULL;
-		}
-		PyTuple_SET_ITEM(ittuple, i, it);
-	}
 
-	/* create chainobject structure */
 	lz = (chainobject *)type->tp_alloc(type, 0);
 	if (lz == NULL) {
-		Py_DECREF(ittuple);
+		Py_DECREF(source);
 		return NULL;
 	}
+	
+	lz->source = source;
+	lz->active = NULL;
+	return (PyObject *)lz;
+}
 
-	lz->ittuple = ittuple;
-	lz->iternum = 0;
-	lz->tuplesize = tuplesize;
+static PyObject *
+chain_new(PyTypeObject *type, PyObject *args, PyObject *kwds)
+{
+	PyObject *source;
 
-	return (PyObject *)lz;
+	if (type == &chain_type && !_PyArg_NoKeywords("chain()", kwds))
+		return NULL;
+	
+	source = PyObject_GetIter(args);
+	if (source == NULL)
+		return NULL;
+
+	return chain_new_internal(type, source);
 }
 
 static void
 chain_dealloc(chainobject *lz)
 {
 	PyObject_GC_UnTrack(lz);
-	Py_XDECREF(lz->ittuple);
+	Py_XDECREF(lz->active);
+	Py_XDECREF(lz->source);
 	Py_TYPE(lz)->tp_free(lz);
 }
 
 static int
 chain_traverse(chainobject *lz, visitproc visit, void *arg)
 {
-	Py_VISIT(lz->ittuple);
+	Py_VISIT(lz->source);
+	Py_VISIT(lz->active);
 	return 0;
 }
 
 static PyObject *
 chain_next(chainobject *lz)
 {
-	PyObject *it;
 	PyObject *item;
 
-	while (lz->iternum < lz->tuplesize) {
-		it = PyTuple_GET_ITEM(lz->ittuple, lz->iternum);
-		item = PyIter_Next(it);
-		if (item != NULL)
-			return item;
-		if (PyErr_Occurred()) {
-			if (PyErr_ExceptionMatches(PyExc_StopIteration))
-				PyErr_Clear();
-			else
-				return NULL;
+	if (lz->source == NULL)
+		return NULL;				/* already stopped */
+
+	if (lz->active == NULL) {
+		PyObject *iterable = PyIter_Next(lz->source);
+		if (iterable == NULL) {
+			Py_CLEAR(lz->source);
+			return NULL;			/* no more input sources */
+		}
+		lz->active = PyObject_GetIter(iterable);
+		Py_DECREF(iterable);			/* release our ref on success and failure alike */
+		if (lz->active == NULL) {
+			Py_CLEAR(lz->source);
+			return NULL;			/* input not iterable */
 		}
-		lz->iternum++;
 	}
-	return NULL;
+	item = PyIter_Next(lz->active);
+	if (item != NULL)
+		return item;
+	if (PyErr_Occurred()) {
+		if (PyErr_ExceptionMatches(PyExc_StopIteration))
+			PyErr_Clear();
+		else
+			return NULL; 			/* input raised an exception */
+	}
+	Py_CLEAR(lz->active);
+	return chain_next(lz);			/* recurse and use next active */
 }
 
 PyDoc_STRVAR(chain_doc,


More information about the Python-checkins mailing list