[Python-Dev] "groupby" iterator

Hye-Shik Chang perky at i18n.org
Sat Nov 29 17:32:20 EST 2003


On Sat, Nov 29, 2003 at 01:12:34AM -0500, Raymond Hettinger wrote:
> [Guido]
> > I would make one change: after looking at another use case, I'd like
> > to change the outer iterator to produce (key, grouper) tuples.  This
> > way, you can write things like
> > 
> >   totals = {}
> >   for key, group in groupby(sequence):
> >       totals[key] = sum(group)

Heh. I love that!

> 
> Here is an implementation that translates readily into C.  It uses
> Guido's syntax and meets my requirement that bad things don't happen
> when someone runs the outer iterator independently of the inner
> iterator.
> 

I updated my implementation according to your guidelines. Please
see the attachments. The docstrings are still sparse because of my
limited English. :)

Thanks!


Regards,
  Hye-Shik
-------------- next part --------------
Index: Modules/itertoolsmodule.c
===================================================================
RCS file: /cvsroot/python/python/dist/src/Modules/itertoolsmodule.c,v
retrieving revision 1.26
diff -u -u -r1.26 itertoolsmodule.c
--- Modules/itertoolsmodule.c	12 Nov 2003 14:32:26 -0000	1.26
+++ Modules/itertoolsmodule.c	29 Nov 2003 22:25:18 -0000
@@ -2081,6 +2081,332 @@
 };
 
 
+/* groupby object ***********************************************************/
+
+typedef struct {		/* state of one itertools.groupby iterator */
+	PyObject_HEAD
+	PyObject *it;		/* iterator obtained from the input iterable */
+	PyObject *keyfunc;	/* key callable, or Py_None to use each item as its own key */
+	PyObject *tgtkey;	/* key of the group most recently returned; NULL before the first */
+	PyObject *currkey;	/* key of the buffered item; NULL when nothing is buffered */
+	PyObject *currvalue;	/* buffered item read from `it`; NULL when nothing is buffered */
+} groupbyobject;
+
+static PyTypeObject groupby_type;
+static PyObject *_grouper_create(groupbyobject *, PyObject *);
+
+static PyObject *	/* tp_new: build a groupby object from (keyfunc, iterable) */
+groupby_new(PyTypeObject *type, PyObject *args, PyObject *kwds)
+{
+	groupbyobject *gbo;
+	PyObject *it, *keyfunc;
+
+	if (!PyArg_UnpackTuple(args, "groupby", 2, 2, &keyfunc, &it))
+		return NULL;	/* needs exactly two positional arguments; kwds is never read */
+
+	if (keyfunc != Py_None && !PyCallable_Check(keyfunc)) {
+		PyErr_SetString(PyExc_ValueError,
+		   "Key argument must be a callable object or None.");
+		return NULL;
+	}
+
+	gbo = (groupbyobject *)type->tp_alloc(type, 0);
+	if (gbo == NULL)
+		return NULL;
+	gbo->tgtkey = NULL;	/* no group has been returned yet */
+	gbo->currkey = NULL;	/* no item is buffered yet */
+	gbo->currvalue = NULL;
+	gbo->keyfunc = keyfunc;
+	Py_INCREF(keyfunc);	/* keep our own reference to the key function */
+	gbo->it = PyObject_GetIter(it);
+	if (gbo->it == NULL) {	/* e.g. the second argument is not iterable */
+		Py_DECREF(gbo);	/* drop the half-built object; its fields are NULL-safe in dealloc */
+		return NULL;
+	}
+	return (PyObject *)gbo;
+}
+
+static void		/* tp_dealloc: release every held reference, then free the object */
+groupby_dealloc(groupbyobject *gbo)
+{
+	PyObject_GC_UnTrack(gbo);
+	Py_XDECREF(gbo->it);	/* X variants: `it` and the key/value fields can be NULL */
+	Py_XDECREF(gbo->keyfunc);
+	Py_XDECREF(gbo->tgtkey);
+	Py_XDECREF(gbo->currkey);
+	Py_XDECREF(gbo->currvalue);
+	gbo->ob_type->tp_free(gbo);	/* free through the type's tp_free slot */
+}
+
+static int		/* tp_traverse: pass each non-NULL member to the visitor */
+groupby_traverse(groupbyobject *gbo, visitproc visit, void *arg)
+{
+	int err;	/* the first nonzero visit() result is returned unchanged */
+
+	if (gbo->it) {
+		err = visit(gbo->it, arg);
+		if (err)
+			return err;
+	}
+
+	if (gbo->keyfunc) {
+		err = visit(gbo->keyfunc, arg);
+		if (err)
+			return err;
+	}
+
+	if (gbo->tgtkey) {
+		err = visit(gbo->tgtkey, arg);
+		if (err)
+			return err;
+	}
+
+	if (gbo->currkey) {
+		err = visit(gbo->currkey, arg);
+		if (err)
+			return err;
+	}
+
+	if (gbo->currvalue) {
+		err = visit(gbo->currvalue, arg);
+		if (err)
+			return err;
+	}
+
+	return 0;	/* every member visited without the visitor objecting */
+}
+
+/* tp_iternext: skip the rest of the current group and return a new
+   (key, grouper) tuple; NULL on error or when the source is exhausted. */
+static PyObject *
+groupby_next(groupbyobject *gbo)
+{
+	PyObject *newvalue, *newkey, *r, *grouper;
+	int rcmp;
+
+	/* skip to next iteration group */
+	for (;;) {
+		if (gbo->currkey == NULL)
+			rcmp = 0;	/* nothing buffered: must fetch an item */
+		else if (gbo->tgtkey == NULL)
+			break;		/* no group has been handed out yet */
+		else if (PyObject_Cmp(gbo->tgtkey, gbo->currkey, &rcmp) == -1)
+			return NULL;
+		if (rcmp != 0)
+			break;		/* buffered item opens a new group */
+
+		newvalue = PyIter_Next(gbo->it);
+		if (newvalue == NULL)
+			return NULL;
+		if (gbo->keyfunc == Py_None) {
+			newkey = newvalue;
+			Py_INCREF(newvalue);
+		} else {
+			newkey = PyObject_CallFunctionObjArgs(gbo->keyfunc,
+							      newvalue, NULL);
+			if (newkey == NULL) {
+				Py_DECREF(newvalue);
+				return NULL;
+			}
+		}
+
+		Py_XDECREF(gbo->currkey);
+		gbo->currkey = newkey;
+		Py_XDECREF(gbo->currvalue);
+		gbo->currvalue = newvalue;
+	}
+
+	Py_XDECREF(gbo->tgtkey);
+	gbo->tgtkey = gbo->currkey;
+	Py_INCREF(gbo->currkey);
+
+	grouper = _grouper_create(gbo, gbo->tgtkey);
+	if (grouper == NULL)
+		return NULL;
+	r = PyTuple_New(2);
+	if (r == NULL) {
+		Py_DECREF(grouper);	/* don't leak the grouper just created */
+		return NULL;
+	}
+	PyTuple_SET_ITEM(r, 0, gbo->tgtkey);	/* slot takes over a reference... */
+	Py_INCREF(gbo->tgtkey);			/* ...so add one on the tuple's behalf */
+	PyTuple_SET_ITEM(r, 1, grouper);	/* grouper's reference moves into r */
+	return r;
+}
+
+PyDoc_STRVAR(groupby_doc,	/* installed as tp_doc of groupby_type */
+"groupby(keyfunc, iterable) -> create an iterator which returns\n\
+(key, sub-iterator) grouped by each value of key(value).\n");
+
+static PyTypeObject groupby_type = {	/* type object for itertools.groupby */
+	PyObject_HEAD_INIT(NULL)
+	0,				/* ob_size */
+	"itertools.groupby",		/* tp_name */
+	sizeof(groupbyobject),		/* tp_basicsize */
+	0,				/* tp_itemsize */
+	/* methods */
+	(destructor)groupby_dealloc,	/* tp_dealloc */
+	0,				/* tp_print */
+	0,				/* tp_getattr */
+	0,				/* tp_setattr */
+	0,				/* tp_compare */
+	0,				/* tp_repr */
+	0,				/* tp_as_number */
+	0,				/* tp_as_sequence */
+	0,				/* tp_as_mapping */
+	0,				/* tp_hash */
+	0,				/* tp_call */
+	0,				/* tp_str */
+	PyObject_GenericGetAttr,	/* tp_getattro */
+	0,				/* tp_setattro */
+	0,				/* tp_as_buffer */
+	Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HAVE_GC |
+		Py_TPFLAGS_BASETYPE,	/* tp_flags */
+	groupby_doc,			/* tp_doc */
+	(traverseproc)groupby_traverse,	/* tp_traverse */
+	0,				/* tp_clear */
+	0,				/* tp_richcompare */
+	0,				/* tp_weaklistoffset */
+	PyObject_SelfIter,		/* tp_iter */
+	(iternextfunc)groupby_next,	/* tp_iternext */
+	0,				/* tp_methods */
+	0,				/* tp_members */
+	0,				/* tp_getset */
+	0,				/* tp_base */
+	0,				/* tp_dict */
+	0,				/* tp_descr_get */
+	0,				/* tp_descr_set */
+	0,				/* tp_dictoffset */
+	0,				/* tp_init */
+	0,				/* tp_alloc: 0 here, yet groupby_new calls type->tp_alloc; presumably set by PyType_Ready -- confirm */
+	groupby_new,			/* tp_new */
+	PyObject_GC_Del,		/* tp_free */
+};
+
+
+/* _grouper object (internal) ************************************************/
+
+typedef struct {		/* inner iterator that yields the items of one group */
+	PyObject_HEAD
+	PyObject *parent;	/* the groupbyobject that created it (owned reference) */
+	PyObject *tgtkey;	/* key of the group this iterator yields (owned reference) */
+} _grouperobject;
+
+static PyTypeObject _grouper_type;
+
+static PyObject *	/* create the inner iterator for the group keyed by tgtkey */
+_grouper_create(groupbyobject *parent, PyObject *tgtkey)
+{
+	_grouperobject *igo;
+
+	igo = PyObject_New(_grouperobject, &_grouper_type);
+	if (igo == NULL)
+		return PyErr_NoMemory();
+	igo->parent = (PyObject *)parent;
+	Py_INCREF(parent);	/* the grouper keeps its groupby alive */
+	igo->tgtkey = tgtkey;
+	Py_INCREF(tgtkey);	/* own reference, independent of parent->tgtkey being replaced */
+
+	return (PyObject *)igo;
+}
+
+static void		/* tp_dealloc: drop the two references taken in _grouper_create */
+_grouper_dealloc(_grouperobject *igo)
+{
+	Py_DECREF(igo->parent);
+	Py_DECREF(igo->tgtkey);
+	PyObject_Del(igo);
+}
+
+static PyObject *	/* tp_iternext: next item of this group, or NULL when the group ends or on error */
+_grouper_next(_grouperobject *igo)
+{
+	groupbyobject *gbo = (groupbyobject *)igo->parent;	/* the buffered item lives on the parent */
+	PyObject *newvalue, *newkey, *r;
+	int rcmp;
+
+	if (gbo->currvalue == NULL) {	/* nothing buffered: pull the next source item and its key */
+		newvalue = PyIter_Next(gbo->it);
+		if (newvalue == NULL)
+			return NULL;
+
+		if (gbo->keyfunc == Py_None) {
+			newkey = newvalue;
+			Py_INCREF(newvalue);
+		} else {
+			newkey = PyObject_CallFunctionObjArgs(gbo->keyfunc,
+							      newvalue, NULL);
+			if (newkey == NULL) {
+				Py_DECREF(newvalue);
+				return NULL;
+			}
+		}
+
+		assert(gbo->currkey == NULL);	/* key and value are buffered and cleared together */
+		gbo->currkey = newkey;
+		gbo->currvalue = newvalue;
+	}
+
+	assert(gbo->currkey != NULL);
+	if (PyObject_Cmp(igo->tgtkey, gbo->currkey, &rcmp) == -1)
+		return NULL;
+
+	if (rcmp != 0)
+		return NULL;	/* item belongs to another group: stop here and leave it buffered */
+
+	r = gbo->currvalue;	/* hand the buffered reference over to the caller */
+	gbo->currvalue = NULL;
+	Py_DECREF(gbo->currkey);
+	gbo->currkey = NULL;
+
+	return r;
+}
+
+static PyTypeObject _grouper_type = {	/* type object for the internal per-group iterator */
+	PyObject_HEAD_INIT(NULL)
+	0,				/* ob_size */
+	"itertools._grouper",		/* tp_name */
+	sizeof(_grouperobject),		/* tp_basicsize */
+	0,				/* tp_itemsize */
+	/* methods */
+	(destructor)_grouper_dealloc,	/* tp_dealloc */
+	0,				/* tp_print */
+	0,				/* tp_getattr */
+	0,				/* tp_setattr */
+	0,				/* tp_compare */
+	0,				/* tp_repr */
+	0,				/* tp_as_number */
+	0,				/* tp_as_sequence */
+	0,				/* tp_as_mapping */
+	0,				/* tp_hash */
+	0,				/* tp_call */
+	0,				/* tp_str */
+	PyObject_GenericGetAttr,	/* tp_getattro */
+	0,				/* tp_setattro */
+	0,				/* tp_as_buffer */
+	Py_TPFLAGS_DEFAULT,		/* tp_flags: no Py_TPFLAGS_HAVE_GC, unlike groupby_type */
+	0,				/* tp_doc */
+	0, 				/* tp_traverse */
+	0,				/* tp_clear */
+	0,				/* tp_richcompare */
+	0,				/* tp_weaklistoffset */
+	PyObject_SelfIter,		/* tp_iter */
+	(iternextfunc)_grouper_next,	/* tp_iternext */
+	0,				/* tp_methods */
+	0,				/* tp_members */
+	0,				/* tp_getset */
+	0,				/* tp_base */
+	0,				/* tp_dict */
+	0,				/* tp_descr_get */
+	0,				/* tp_descr_set */
+	0,				/* tp_dictoffset */
+	0,				/* tp_init */
+	0,				/* tp_alloc */
+	0,				/* tp_new: 0 here; instances are made by _grouper_create() */
+	_PyObject_Del,			/* tp_free */
+};
+
+
 /* module level code ********************************************************/
 
 PyDoc_STRVAR(module_doc,
@@ -2103,6 +2429,7 @@
 chain(p, q, ...) --> p0, p1, ... plast, q0, q1, ... \n\
 takewhile(pred, seq) --> seq[0], seq[1], until pred fails\n\
 dropwhile(pred, seq) --> seq[n], seq[n+1], starting when pred fails\n\
+groupby(keyfunc, iterable) --> sub-iteraters grouped by value of keyfunc(v)\n\
 ");
 
 
@@ -2130,6 +2457,7 @@
 		&count_type,
 		&izip_type,
 		&repeat_type,
+		&groupby_type,
 		NULL
 	};
 
@@ -2148,5 +2476,6 @@
 		return;
 	if (PyType_Ready(&tee_type) < 0)
 		return;
-
+	if (PyType_Ready(&_grouper_type) < 0)
+		return;
 }
-------------- next part --------------
import unittest
from itertools import groupby

class TestBasicOps(unittest.TestCase):
    """Exercises the proposed itertools.groupby(keyfunc, iterable) iterator."""

    def test_groupby(self):
        # One test covering: empty input, plain and nested grouping, ignoring
        # the inner iterators, keyfunc=None, and propagation of errors raised
        # by the source iterator, by key comparison and by the key function.

        # Check zero length input
        self.assertEqual([], list(groupby(lambda r:r[0], [])))

        # Check normal input
        # (s is already ordered by its first field, which is the key below)
        s = [(0, 10, 20), (0, 11,21), (0,12,21), (1,13,21), (1,14,22),
             (2,15,22), (3,16,23), (3,17,23)]
        dup = []
        for k, g in groupby(lambda r:r[0], s):
            for elem in g:
                self.assertEqual(k, elem[0])
                dup.append(elem)
        # Every element must come back exactly once, in the original order.
        self.assertEqual(s, dup)

        # Check nested case: regroup each outer group by the third field
        dup = []
        for k, g in groupby(lambda r:r[0], s):
            for ik, ig in groupby(lambda r:r[2], g):
                for elem in ig:
                    self.assertEqual(k, elem[0])
                    self.assertEqual(ik, elem[2])
                    dup.append(elem)
        self.assertEqual(s, dup)

        # Check case where inner iterator is not used: the outer iterator
        # alone must still produce each key exactly once
        keys = [k for k, g in groupby(lambda r:r[0], s)]
        expectedkeys = set([r[0] for r in s])
        self.assertEqual(set(keys), expectedkeys)
        self.assertEqual(len(keys), len(expectedkeys))

        # Check case where key is None (each item serves as its own key);
        # the input is sorted first so equal letters are adjacent
        word = 'abracadabra'
        keys = [k for k, g in groupby(None, list.sorted(word))]
        expectedkeys = set(word)
        self.assertEqual(set(keys), expectedkeys)
        self.assertEqual(len(keys), len(expectedkeys))

        # Exercise pipes and filters style
        s = 'abracadabra'
        # ilen counts the items of an iterator by consuming it
        ilen = lambda it: len(list(it))
        # sort s | uniq
        r = [k for k, g in groupby(None, list.sorted(s))]
        self.assertEqual(r, ['a', 'b', 'c', 'd', 'r'])
        # sort s | uniq -d
        r = [k for k, g in groupby(None, list.sorted(s)) if ilen(g)>1]
        self.assertEqual(r, ['a', 'b', 'r'])
        # sort s | uniq -c
        r = [(ilen(g), k) for k, g in groupby(None, list.sorted(s))]
        self.assertEqual(r, [(5, 'a'), (2, 'b'), (1, 'c'), (1, 'd'), (2, 'r')])
        # sort s | uniq -c | sort -rn | head -3
        r = list.sorted([(ilen(g), k) for k, g in groupby(None, list.sorted(s))], reverse=True)[:3]
        self.assertEqual(r, [(5, 'a'), (2, 'r'), (2, 'b')])

        # Non-iterable argument
        self.assertRaises(TypeError, groupby, None, None)

        # iter.next failure
        class ExpectedError(Exception):
            pass
        # Generator that yields 'yo' n times and then raises ExpectedError.
        def delayed_raise(n=0):
            for i in range(n):
                yield 'yo'
            raise ExpectedError
        # Runs groupby to completion, applying func to every inner iterator;
        # func=list consumes each group, func=id leaves the groups untouched.
        def gulp(key, iterable, func=list):
            return [func(g) for k, g in groupby(key, iterable)]

        # iter.next failure on outer object
        # (n=0: the very first fetch, made by the outer iterator, raises)
        self.assertRaises(ExpectedError, gulp, None, delayed_raise(0))
        # iter.next failure on inner object
        # (n=1: the first item is fine; the raise comes while list() drains the group)
        self.assertRaises(ExpectedError, gulp, None, delayed_raise(1))

        # __cmp__ failure
        class DummyCmp:
            def __cmp__(self, dst):
                raise ExpectedError
        s = [DummyCmp(), DummyCmp(), None]

        # __cmp__ failure on outer object
        # (func=id never consumes a group, so the outer iterator does the comparing)
        self.assertRaises(ExpectedError, gulp, None, s, id)
        # __cmp__ failure on inner object
        self.assertRaises(ExpectedError, gulp, None, s)

        # keyfunc failure
        # keyfunc.skip is a countdown: that many calls succeed, the next one raises.
        def keyfunc(obj):
            if keyfunc.skip > 0:
                keyfunc.skip -= 1
                return obj
            else:
                raise ExpectedError

        # keyfunc failure on outer object
        keyfunc.skip = 0
        self.assertRaises(ExpectedError, gulp, keyfunc, [None])
        # keyfunc failure on inner object: the first call (made by the outer
        # iterator) succeeds, the second call happens while list() drains the group
        keyfunc.skip = 1
        self.assertRaises(ExpectedError, gulp, keyfunc, [None, None])


# Build a suite from every test method of TestBasicOps and run it at once,
# printing one result line per test.
suite = unittest.TestSuite([unittest.makeSuite(TestBasicOps)])
runner = unittest.TextTestRunner(verbosity=2)
runner.run(suite)


More information about the Python-Dev mailing list