[Python-3000-checkins] r61126 - in python/branches/py3k: Lib/decimal.py Lib/test/test_decimal.py Lib/test/test_itertools.py Modules/itertoolsmodule.c

christian.heimes python-3000-checkins at python.org
Fri Feb 29 15:57:44 CET 2008


Author: christian.heimes
Date: Fri Feb 29 15:57:44 2008
New Revision: 61126

Modified:
   python/branches/py3k/   (props changed)
   python/branches/py3k/Lib/decimal.py
   python/branches/py3k/Lib/test/test_decimal.py
   python/branches/py3k/Lib/test/test_itertools.py
   python/branches/py3k/Modules/itertoolsmodule.c
Log:
Merged revisions 61038,61042-61045,61047,61050,61053,61055-61056,61061-61062,61066,61068,61070,61083,61085,61092-61097,61103-61104,61110-61112,61114-61115,61117-61125 via svnmerge from 
svn+ssh://pythondev@svn.python.org/python/trunk

........
  r61118 | raymond.hettinger | 2008-02-28 23:30:42 +0100 (Thu, 28 Feb 2008) | 1 line
  
  Have itertools.chain() consume its inputs lazily instead of building a tuple of iterators at the outset.
........
  r61119 | raymond.hettinger | 2008-02-28 23:46:41 +0100 (Thu, 28 Feb 2008) | 1 line
  
  Add alternate constructor for itertools.chain().
........
  r61123 | mark.dickinson | 2008-02-29 03:16:37 +0100 (Fri, 29 Feb 2008) | 2 lines
  
  Add __format__ method to Decimal, to support PEP 3101
........
  r61124 | raymond.hettinger | 2008-02-29 03:21:48 +0100 (Fri, 29 Feb 2008) | 1 line
  
  Handle the repeat keyword argument for itertools.product().
........
  r61125 | mark.dickinson | 2008-02-29 04:29:17 +0100 (Fri, 29 Feb 2008) | 2 lines
  
  Fix docstring typo.
........


Modified: python/branches/py3k/Lib/decimal.py
==============================================================================
--- python/branches/py3k/Lib/decimal.py	(original)
+++ python/branches/py3k/Lib/decimal.py	Fri Feb 29 15:57:44 2008
@@ -2381,6 +2381,29 @@
             coeff = str(int(coeff)+1)
         return _dec_from_triple(self._sign, coeff, exp)
 
+    def _round(self, places, rounding):
+        """Round self to a fixed number (places) of significant
+        figures, using the given rounding mode, and return the result.
+
+        places must be at least 1; otherwise ValueError is raised.
+        Infinities, NaNs and zeros are returned unaltered.
+
+        This operation is quiet: it raises no flags, and uses no
+        information from the context.
+        """
+        if places <= 0:
+            raise ValueError("argument should be at least 1 in _round")
+        if self._is_special or not self:
+            return Decimal(self)
+        ans = self._rescale(self.adjusted()+1-places, rounding)
+        # it can happen that the rescale alters the adjusted exponent;
+        # for example when rounding 99.97 to 3 significant figures.
+        # When this happens we end up with an extra 0 at the end of
+        # the number; a second rescale fixes this.
+        if ans.adjusted() != self.adjusted():
+            ans = ans._rescale(ans.adjusted()+1-places, rounding)
+        return ans
+
     def to_integral_exact(self, rounding=None, context=None):
         """Rounds to a nearby integer.
 
@@ -3432,6 +3455,95 @@
             return self     # My components are also immutable
         return self.__class__(str(self))
 
+    # PEP 3101 support.  See also _parse_format_specifier and _format_align
+    def __format__(self, specifier, context=None):
+        """Format a Decimal instance according to the given specifier.
+
+        The specifier should be a standard format specifier, with the
+        form described in PEP 3101.  Formatting types 'e', 'E', 'f',
+        'F', 'g', 'G', and '%' are supported.  If the formatting type
+        is omitted it defaults to 'g' or 'G', depending on the value
+        of context.capitals (context defaults to getcontext()).
+
+        At this time the 'n' format specifier type (which is supposed
+        to use the current locale) is not supported.
+        """
+
+        # Note: PEP 3101 says that if the type is not present then
+        # there should be at least one digit after the decimal point.
+        # We take the liberty of ignoring this requirement for
+        # Decimal---it's presumably there to make sure that
+        # format(float, '') behaves similarly to str(float).
+        if context is None:
+            context = getcontext()
+
+        spec = _parse_format_specifier(specifier)
+
+        # special values don't care about the type or precision...
+        if self._is_special:
+            return _format_align(str(self), spec)
+
+        # a type of None defaults to 'g' or 'G', depending on context;
+        # for '%', add 2 to the exponent (i.e. multiply by 100) first
+        if spec['type'] is None:
+            spec['type'] = ['g', 'G'][context.capitals]
+        elif spec['type'] == '%':
+            self = _dec_from_triple(self._sign, self._int, self._exp+2)
+
+        # round if necessary, taking rounding mode from the context
+        rounding = context.rounding
+        precision = spec['precision']
+        if precision is not None:
+            if spec['type'] in 'eE':  # one digit precedes the point
+                self = self._round(precision+1, rounding)
+            elif spec['type'] in 'gG':
+                if len(self._int) > precision:  # shorten only, never pad
+                    self = self._round(precision, rounding)
+            elif spec['type'] in 'fF%':
+                self = self._rescale(-precision, rounding)
+        # special case: zeros with a positive exponent can't be
+        # represented in fixed point; rescale them to 0e0.
+        elif not self and self._exp > 0 and spec['type'] in 'fF%':
+            self = self._rescale(0, rounding)
+
+        # figure out placement of the decimal point
+        leftdigits = self._exp + len(self._int)
+        if spec['type'] in 'fF%':
+            dotplace = leftdigits
+        elif spec['type'] in 'eE':
+            if not self and precision is not None:
+                dotplace = 1 - precision
+            else:
+                dotplace = 1
+        elif spec['type'] in 'gG':
+            if self._exp <= 0 and leftdigits > -6:  # plain notation
+                dotplace = leftdigits
+            else:
+                dotplace = 1
+
+        # figure out main part of numeric string...
+        if dotplace <= 0:
+            num = '0.' + '0'*(-dotplace) + self._int
+        elif dotplace >= len(self._int):
+            # make sure we're not padding a '0' with extra zeros on the right
+            assert dotplace==len(self._int) or self._int != '0'
+            num = self._int + '0'*(dotplace-len(self._int))
+        else:
+            num = self._int[:dotplace] + '.' + self._int[dotplace:]
+
+        # ...then the trailing exponent, or trailing '%'
+        if leftdigits != dotplace or spec['type'] in 'eE':
+            echar = {'E': 'E', 'e': 'e', 'G': 'E', 'g': 'e'}[spec['type']]
+            num = num + "{0}{1:+}".format(echar, leftdigits-dotplace)
+        elif spec['type'] == '%':
+            num = num + '%'
+
+        # add sign
+        if self._sign == 1:
+            num = '-' + num
+        return _format_align(num, spec)
+
+
 def _dec_from_triple(sign, coefficient, exponent, special=False):
     """Create a decimal instance directly, without any validation,
     normalization (e.g. removal of leading zeros) or argument
@@ -5249,8 +5361,136 @@
 
 _all_zeros = re.compile('0*$').match
 _exact_half = re.compile('50*$').match
+
+##### PEP3101 support functions ##############################################
+# The functions _parse_format_specifier and _format_align have little to
+# do with the Decimal class, and could potentially be reused for other pure
+# Python numeric classes that want to implement __format__
+#
+# A format specifier for Decimal looks like:
+#
+#   [[fill]align][sign][0][minimumwidth][.precision][type]
+#
+
+_parse_format_specifier_regex = re.compile(r"""\A
+(?:
+   (?P<fill>.)?
+   (?P<align>[<>=^])
+)?
+(?P<sign>[-+ ])?
+(?P<zeropad>0)?
+(?P<minimumwidth>(?!0)\d+)?
+(?:\.(?P<precision>0|(?!0)\d+))?
+(?P<type>[eEfFgG%])?
+\Z
+""", re.VERBOSE)
+
 del re
 
+def _parse_format_specifier(format_spec):
+    """Parse and validate a format specifier.
+
+    Turns a standard numeric format specifier into a dict, with the
+    following entries (ValueError is raised for an invalid specifier):
+
+      fill: fill character to pad field to minimum width (default ' ')
+      align: alignment type, either '<', '>', '=' or '^' (default '>')
+      sign: either '+', '-' or ' ' (default '-')
+      minimumwidth: nonnegative integer giving minimum width
+      precision: nonnegative integer giving precision, or None
+      type: one of the characters 'eEfFgG%', or None
+      unicode: either True or False (always True for Python 3.x)
+
+    """
+    m = _parse_format_specifier_regex.match(format_spec)
+    if m is None:
+        raise ValueError("Invalid format specifier: " + format_spec)
+
+    # get the dictionary
+    format_dict = m.groupdict()
+
+    # defaults for fill and alignment (numbers right-align by default)
+    fill = format_dict['fill']
+    align = format_dict['align']
+    if format_dict.pop('zeropad') is not None:
+        # in the face of conflict, refuse the temptation to guess
+        if fill is not None and fill != '0':
+            raise ValueError("Fill character conflicts with '0'"
+                             " in format specifier: " + format_spec)
+        if align is not None and align != '=':
+            raise ValueError("Alignment conflicts with '0' in "
+                             "format specifier: " + format_spec)
+        fill = '0'
+        align = '='
+    format_dict['fill'] = fill or ' '
+    format_dict['align'] = align or '>'
+
+    if format_dict['sign'] is None:
+        format_dict['sign'] = '-'
+
+    # turn minimumwidth and precision entries into integers.
+    # minimumwidth defaults to 0; precision remains None if not given
+    format_dict['minimumwidth'] = int(format_dict['minimumwidth'] or '0')
+    if format_dict['precision'] is not None:
+        format_dict['precision'] = int(format_dict['precision'])
+
+    # if format type is 'g' or 'G' then a precision of 0 makes little
+    # sense; convert it to 1.  Same if format type is unspecified
+    # (test for None first: "None in 'gG'" raises TypeError).
+    if format_dict['precision'] == 0:
+        if format_dict['type'] is None or format_dict['type'] in 'gG':
+            format_dict['precision'] = 1
+
+    # py3k has no separate 'unicode' builtin; every str is unicode
+    format_dict['unicode'] = isinstance(format_spec, str)
+
+    return format_dict
+
+def _format_align(body, spec_dict):
+    """Given an unpadded, non-aligned numeric string, add padding and
+    alignment to conform with the given format specifier dictionary (as
+    output from _parse_format_specifier).
+
+    It's assumed that if body is negative then it starts with '-'.
+    Any leading sign ('-' or '+') is stripped from the body before
+    applying the alignment and padding rules, and replaced in the
+    appropriate position.
+
+    """
+    # figure out the sign; we only examine the first character, so if
+    # body has leading whitespace the results may be surprising.
+    if len(body) > 0 and body[0] in '-+':
+        sign = body[0]
+        body = body[1:]
+    else:
+        sign = ''
+
+    if sign != '-':
+        if spec_dict['sign'] in ' +':
+            sign = spec_dict['sign']
+        else:
+            sign = ''
+
+    # how much extra space do we have to play with?
+    minimumwidth = spec_dict['minimumwidth']
+    fill = spec_dict['fill']
+    padding = fill*(max(minimumwidth - (len(sign+body)), 0))
+
+    align = spec_dict['align']
+    if align == '<':    # left-aligned: padding goes on the right
+        result = sign + body + padding
+    elif align == '>':  # right-aligned: padding goes on the left
+        result = padding + sign + body
+    elif align == '=':
+        result = sign + padding + body
+    else: #align == '^'
+        half = len(padding)//2
+        result = padding[:half] + sign + body + padding[half:]
+
+    # make sure that result is unicode (i.e. str in py3k) if necessary
+    if spec_dict['unicode']:
+        result = str(result)
+
+    return result
 
 ##### Useful Constants (internal use only) ################################
 

Modified: python/branches/py3k/Lib/test/test_decimal.py
==============================================================================
--- python/branches/py3k/Lib/test/test_decimal.py	(original)
+++ python/branches/py3k/Lib/test/test_decimal.py	Fri Feb 29 15:57:44 2008
@@ -610,6 +610,98 @@
             self.assertEqual(eval('Decimal(10)' + sym + 'E()'),
                              '10' + rop + 'str')
 
+class DecimalFormatTest(unittest.TestCase):
+    '''Unit tests for the format function.'''
+    def test_formatting(self):
+        # triples giving a format, a Decimal, and the expected result
+        test_values = [
+            ('e', '0E-15', '0e-15'),
+            ('e', '2.3E-15', '2.3e-15'),
+            ('e', '2.30E+2', '2.30e+2'), # preserve significant zeros
+            ('e', '2.30000E-15', '2.30000e-15'),
+            ('e', '1.23456789123456789e40', '1.23456789123456789e+40'),
+            ('e', '1.5', '1.5e+0'),
+            ('e', '0.15', '1.5e-1'),
+            ('e', '0.015', '1.5e-2'),
+            ('e', '0.0000000000015', '1.5e-12'),
+            ('e', '15.0', '1.50e+1'),
+            ('e', '-15', '-1.5e+1'),
+            ('e', '0', '0e+0'),
+            ('e', '0E1', '0e+1'),
+            ('e', '0.0', '0e-1'),
+            ('e', '0.00', '0e-2'),
+            ('.6e', '0E-15', '0.000000e-9'),
+            ('.6e', '0', '0.000000e+6'),
+            ('.6e', '9.999999', '9.999999e+0'),
+            ('.6e', '9.9999999', '1.000000e+1'),
+            ('.6e', '-1.23e5', '-1.230000e+5'),
+            ('.6e', '1.23456789e-3', '1.234568e-3'),
+            ('f', '0', '0'),
+            ('f', '0.0', '0.0'),
+            ('f', '0E-2', '0.00'),
+            ('f', '0.00E-8', '0.0000000000'),
+            ('f', '0E1', '0'), # loses exponent information
+            ('f', '3.2E1', '32'),
+            ('f', '3.2E2', '320'),
+            ('f', '3.20E2', '320'),
+            ('f', '3.200E2', '320.0'),
+            ('f', '3.2E-6', '0.0000032'),
+            ('.6f', '0E-15', '0.000000'), # all zeros treated equally
+            ('.6f', '0E1', '0.000000'),
+            ('.6f', '0', '0.000000'),
+            ('.0f', '0', '0'), # no decimal point
+            ('.0f', '0e-2', '0'),
+            ('.0f', '3.14159265', '3'),
+            ('.1f', '3.14159265', '3.1'),
+            ('.4f', '3.14159265', '3.1416'),
+            ('.6f', '3.14159265', '3.141593'),
+            ('.7f', '3.14159265', '3.1415926'), # round-half-even!
+            ('.8f', '3.14159265', '3.14159265'),
+            ('.9f', '3.14159265', '3.141592650'),
+
+            ('g', '0', '0'),
+            ('g', '0.0', '0.0'),
+            ('g', '0E1', '0e+1'),
+            ('G', '0E1', '0E+1'),
+            ('g', '0E-5', '0.00000'),
+            ('g', '0E-6', '0.000000'),
+            ('g', '0E-7', '0e-7'),
+            ('g', '-0E2', '-0e+2'),
+            ('.0g', '3.14159265', '3'),  # 0 sig fig -> 1 sig fig
+            ('.1g', '3.14159265', '3'),
+            ('.2g', '3.14159265', '3.1'),
+            ('.5g', '3.14159265', '3.1416'),
+            ('.7g', '3.14159265', '3.141593'),
+            ('.8g', '3.14159265', '3.1415926'), # round-half-even!
+            ('.9g', '3.14159265', '3.14159265'),
+            ('.10g', '3.14159265', '3.14159265'), # don't pad
+
+            ('%', '0E1', '0%'),
+            ('%', '0E0', '0%'),
+            ('%', '0E-1', '0%'),
+            ('%', '0E-2', '0%'),
+            ('%', '0E-3', '0.0%'),
+            ('%', '0E-4', '0.00%'),
+
+            ('.3%', '0', '0.000%'), # all zeros treated equally
+            ('.3%', '0E10', '0.000%'),
+            ('.3%', '0E-10', '0.000%'),
+            ('.3%', '2.34', '234.000%'),
+            ('.3%', '1.234567', '123.457%'),
+            ('.0%', '1.23', '123%'),
+
+            ('e', 'NaN', 'NaN'),
+            ('f', '-NaN123', '-NaN123'),
+            ('+g', 'NaN456', '+NaN456'),
+            ('.3e', 'Inf', 'Infinity'),
+            ('.16f', '-Inf', '-Infinity'),
+            ('.0g', '-sNaN', '-sNaN'),
+
+            ('', '1.00', '1.00'),
+            ]
+        for fmt, d, result in test_values:
+            self.assertEqual(format(Decimal(d), fmt), result)
+
 class DecimalArithmeticOperatorsTest(unittest.TestCase):
     '''Unit tests for all arithmetic operators, binary and unary.'''
 
@@ -1351,6 +1443,7 @@
             DecimalExplicitConstructionTest,
             DecimalImplicitConstructionTest,
             DecimalArithmeticOperatorsTest,
+            DecimalFormatTest,
             DecimalUseOfContextTest,
             DecimalUsabilityTest,
             DecimalPythonAPItests,

Modified: python/branches/py3k/Lib/test/test_itertools.py
==============================================================================
--- python/branches/py3k/Lib/test/test_itertools.py	(original)
+++ python/branches/py3k/Lib/test/test_itertools.py	Fri Feb 29 15:57:44 2008
@@ -54,7 +54,14 @@
         self.assertEqual(list(chain('abc')), list('abc'))
         self.assertEqual(list(chain('')), [])
         self.assertEqual(take(4, chain('abc', 'def')), list('abcd'))
-        self.assertRaises(TypeError, chain, 2, 3)
+        self.assertRaises(TypeError, list,chain(2, 3))
+
+    def test_chain_from_iterable(self):
+        self.assertEqual(list(chain.from_iterable(['abc', 'def'])), list('abcdef'))
+        self.assertEqual(list(chain.from_iterable(['abc'])), list('abc'))
+        self.assertEqual(list(chain.from_iterable([''])), [])
+        self.assertEqual(take(4, chain.from_iterable(['abc', 'def'])), list('abcd'))
+        self.assertRaises(TypeError, list, chain.from_iterable([2, 3]))
 
     def test_combinations(self):
         self.assertRaises(TypeError, combinations, 'abc')   # missing r argument
@@ -298,6 +305,9 @@
             ([range(2), range(3), range(0)], []),           # last iterable with zero length
             ]:
             self.assertEqual(list(product(*args)), result)
+            for r in range(4):
+                self.assertEqual(list(product(*(args*r))),
+                                 list(product(*args, **dict(repeat=r))))
         self.assertEqual(len(list(product(*[range(7)]*6))), 7**6)
         self.assertRaises(TypeError, product, range(6), None)
         argtypes = ['', 'abc', '', range(0), range(4), dict(a=1, b=2, c=3),
@@ -684,7 +694,7 @@
             for g in (G, I, Ig, S, L, R):
                 self.assertEqual(list(chain(g(s))), list(g(s)))
                 self.assertEqual(list(chain(g(s), g(s))), list(g(s))+list(g(s)))
-            self.assertRaises(TypeError, chain, X(s))
+            self.assertRaises(TypeError, list, chain(X(s)))
             self.assertRaises(TypeError, chain, N(s))
             self.assertRaises(ZeroDivisionError, list, chain(E(s)))
 

Modified: python/branches/py3k/Modules/itertoolsmodule.c
==============================================================================
--- python/branches/py3k/Modules/itertoolsmodule.c	(original)
+++ python/branches/py3k/Modules/itertoolsmodule.c	Fri Feb 29 15:57:44 2008
@@ -1571,92 +1571,104 @@
 
 typedef struct {
 	PyObject_HEAD
-	Py_ssize_t tuplesize;
-	Py_ssize_t iternum;		/* which iterator is active */
-	PyObject *ittuple;		/* tuple of iterators */
+	PyObject *source;		/* Iterator over input iterables */
+	PyObject *active;		/* Currently running input iterator */
 } chainobject;
 
 static PyTypeObject chain_type;
 
+static PyObject * 
+chain_new_internal(PyTypeObject *type, PyObject *source)
+{
+	chainobject *lz;
+	/* Shared by both constructors; takes over the caller's reference to source. */
+	lz = (chainobject *)type->tp_alloc(type, 0);
+	if (lz == NULL) {
+		Py_DECREF(source);	/* allocation failed: drop the reference we were given */
+		return NULL;
+	}
+	/* active stays NULL until chain_next() fetches the first input iterable */
+	lz->source = source;
+	lz->active = NULL;
+	return (PyObject *)lz;
+}
+
 static PyObject *
 chain_new(PyTypeObject *type, PyObject *args, PyObject *kwds)
 {
-	chainobject *lz;
-	Py_ssize_t tuplesize = PySequence_Length(args);
-	Py_ssize_t i;
-	PyObject *ittuple;
+	PyObject *source;
 
 	if (type == &chain_type && !_PyArg_NoKeywords("chain()", kwds))
 		return NULL;
-
-	/* obtain iterators */
-	assert(PyTuple_Check(args));
-	ittuple = PyTuple_New(tuplesize);
-	if (ittuple == NULL)
+	
+	source = PyObject_GetIter(args);
+	if (source == NULL)
 		return NULL;
-	for (i=0; i < tuplesize; ++i) {
-		PyObject *item = PyTuple_GET_ITEM(args, i);
-		PyObject *it = PyObject_GetIter(item);
-		if (it == NULL) {
-			if (PyErr_ExceptionMatches(PyExc_TypeError))
-				PyErr_Format(PyExc_TypeError,
-				    "chain argument #%zd must support iteration",
-				    i+1);
-			Py_DECREF(ittuple);
-			return NULL;
-		}
-		PyTuple_SET_ITEM(ittuple, i, it);
-	}
 
-	/* create chainobject structure */
-	lz = (chainobject *)type->tp_alloc(type, 0);
-	if (lz == NULL) {
-		Py_DECREF(ittuple);
-		return NULL;
-	}
+	return chain_new_internal(type, source);
+}
 
-	lz->ittuple = ittuple;
-	lz->iternum = 0;
-	lz->tuplesize = tuplesize;
+static PyObject *
+chain_new_from_iterable(PyTypeObject *type, PyObject *arg)
+{
+	PyObject *source;
+	
+	source = PyObject_GetIter(arg);
+	if (source == NULL)
+		return NULL;
 
-	return (PyObject *)lz;
+	return chain_new_internal(type, source);
 }
 
 static void
 chain_dealloc(chainobject *lz)
 {
 	PyObject_GC_UnTrack(lz);
-	Py_XDECREF(lz->ittuple);
+	Py_XDECREF(lz->active);
+	Py_XDECREF(lz->source);
 	Py_TYPE(lz)->tp_free(lz);
 }
 
 static int
 chain_traverse(chainobject *lz, visitproc visit, void *arg)
 {
-	Py_VISIT(lz->ittuple);
+	Py_VISIT(lz->source);
+	Py_VISIT(lz->active);
 	return 0;
 }
 
 static PyObject *
 chain_next(chainobject *lz)
 {
-	PyObject *it;
 	PyObject *item;
 
-	while (lz->iternum < lz->tuplesize) {
-		it = PyTuple_GET_ITEM(lz->ittuple, lz->iternum);
-		item = PyIter_Next(it);
-		if (item != NULL)
-			return item;
-		if (PyErr_Occurred()) {
-			if (PyErr_ExceptionMatches(PyExc_StopIteration))
-				PyErr_Clear();
-			else
-				return NULL;
+	if (lz->source == NULL)
+		return NULL;				/* already stopped */
+again:
+	if (lz->active == NULL) {
+		PyObject *iterable = PyIter_Next(lz->source);
+		if (iterable == NULL) {
+			Py_CLEAR(lz->source);
+			return NULL;			/* no more input sources */
+		}
+		lz->active = PyObject_GetIter(iterable);
+		Py_DECREF(iterable);			/* release our reference on success and failure alike */
+		if (lz->active == NULL) {
+			Py_CLEAR(lz->source);
+			return NULL;			/* input not iterable */
 		}
-		lz->iternum++;
 	}
-	return NULL;
+	item = PyIter_Next(lz->active);
+	if (item != NULL)
+		return item;
+	if (PyErr_Occurred()) {
+		if (PyErr_ExceptionMatches(PyExc_StopIteration))
+			PyErr_Clear();
+		else
+			return NULL; 			/* input raised an exception */
+	}
+	Py_CLEAR(lz->active);
+	goto again;				/* try the next input; a jump avoids unbounded C recursion */
 }
 
 PyDoc_STRVAR(chain_doc,
@@ -1666,6 +1678,18 @@
 first iterable until it is exhausted, then elements from the next\n\
 iterable, until all of the iterables are exhausted.");
 
+PyDoc_STRVAR(chain_from_iterable_doc,
+"chain.from_iterable(iterable) --> chain object\n\
+\n\
+Alternate chain() constructor taking a single iterable argument\n\
+that evaluates lazily.");
+/* Method table for chain: exposes the from_iterable() class method. */
+static PyMethodDef chain_methods[] = {
+	{"from_iterable", (PyCFunction) chain_new_from_iterable,	METH_O | METH_CLASS,
+		chain_from_iterable_doc},
+	{NULL,		NULL}	/* sentinel */
+};
+
 static PyTypeObject chain_type = {
 	PyVarObject_HEAD_INIT(NULL, 0)
 	"itertools.chain",		/* tp_name */
@@ -1696,7 +1720,7 @@
 	0,				/* tp_weaklistoffset */
 	PyObject_SelfIter,		/* tp_iter */
 	(iternextfunc)chain_next,	/* tp_iternext */
-	0,				/* tp_methods */
+	chain_methods,			/* tp_methods */
 	0,				/* tp_members */
 	0,				/* tp_getset */
 	0,				/* tp_base */
@@ -1728,17 +1752,32 @@
 product_new(PyTypeObject *type, PyObject *args, PyObject *kwds)
 {
 	productobject *lz;
-	Py_ssize_t npools;
+	Py_ssize_t nargs, npools, repeat=1;
 	PyObject *pools = NULL;
 	Py_ssize_t *maxvec = NULL;
 	Py_ssize_t *indices = NULL;
 	Py_ssize_t i;
 
-	if (type == &product_type && !_PyArg_NoKeywords("product()", kwds))
-		return NULL;
+	if (kwds != NULL) {
+		char *kwlist[] = {"repeat", 0};
+		PyObject *tmpargs = PyTuple_New(0);
+		if (tmpargs == NULL)
+			return NULL;
+		if (!PyArg_ParseTupleAndKeywords(tmpargs, kwds, "|n:product", kwlist, &repeat)) {
+			Py_DECREF(tmpargs);
+			return NULL;
+		}
+		Py_DECREF(tmpargs);
+		if (repeat < 0) {
+			PyErr_SetString(PyExc_ValueError, 
+					"repeat argument cannot be negative");
+			return NULL;
+		}
+	}
 
 	assert(PyTuple_Check(args));
-	npools = PyTuple_GET_SIZE(args);
+	nargs = (repeat == 0) ? 0 : PyTuple_GET_SIZE(args);
+	npools = nargs * repeat;
 
 	maxvec = PyMem_Malloc(npools * sizeof(Py_ssize_t));
 	indices = PyMem_Malloc(npools * sizeof(Py_ssize_t));
@@ -1751,7 +1790,7 @@
 	if (pools == NULL)
 		goto error;
 
-	for (i=0; i < npools; ++i) {
+	for (i=0; i < nargs ; ++i) {
 		PyObject *item = PyTuple_GET_ITEM(args, i);
 		PyObject *pool = PySequence_Tuple(item);
 		if (pool == NULL)
@@ -1761,6 +1800,13 @@
 		maxvec[i] = PyTuple_GET_SIZE(pool);
 		indices[i] = 0;
 	}
+	for ( ; i < npools; ++i) {
+		PyObject *pool = PyTuple_GET_ITEM(pools, i - nargs);
+		Py_INCREF(pool);
+		PyTuple_SET_ITEM(pools, i, pool);
+		maxvec[i] = maxvec[i - nargs];
+		indices[i] = 0;
+	}
 
 	/* create productobject structure */
 	lz = (productobject *)type->tp_alloc(type, 0);


More information about the Python-3000-checkins mailing list