[Python-checkins] cpython (2.7): Issue #10242: backport of more fixes to unittest.TestCase.assertItemsEqual

michael.foord python-checkins at python.org
Thu Mar 17 01:34:05 CET 2011


http://hg.python.org/cpython/rev/d8eaeee1c063
changeset:   68630:d8eaeee1c063
branch:      2.7
parent:      68596:b8f280d0cdbf
user:        Michael Foord <michael at python.org>
date:        Wed Mar 16 20:34:53 2011 -0400
summary:
  Issue #10242: backport of more fixes to unittest.TestCase.assertItemsEqual

files:
  Lib/unittest/case.py
  Lib/unittest/test/test_assertions.py
  Lib/unittest/test/test_case.py
  Lib/unittest/util.py

diff --git a/Lib/unittest/case.py b/Lib/unittest/case.py
--- a/Lib/unittest/case.py
+++ b/Lib/unittest/case.py
@@ -10,9 +10,11 @@
 
 from . import result
 from .util import (
-    strclass, safe_repr, sorted_list_difference, unorderable_list_difference
+    strclass, safe_repr, unorderable_list_difference,
+    _count_diff_all_purpose, _count_diff_hashable
 )
 
+
 __unittest = True
 
 
@@ -863,6 +865,7 @@
             - [0, 1, 1] and [1, 0, 1] compare equal.
             - [0, 0, 1] and [0, 1] compare unequal.
         """
+        first_seq, second_seq = list(actual_seq), list(expected_seq)
         with warnings.catch_warnings():
             if sys.py3kwarning:
                 # Silence Py3k warning raised during the sorting
@@ -871,29 +874,23 @@
                              "comparing unequal types"]:
                     warnings.filterwarnings("ignore", _msg, DeprecationWarning)
             try:
-                actual = collections.Counter(iter(actual_seq))
-                expected = collections.Counter(iter(expected_seq))
+                first = collections.Counter(first_seq)
+                second = collections.Counter(second_seq)
             except TypeError:
-                # Unsortable items (example: set(), complex(), ...)
-                actual = list(actual_seq)
-                expected = list(expected_seq)
-                missing, unexpected = unorderable_list_difference(expected, actual)
+                # Handle case with unhashable elements
+                differences = _count_diff_all_purpose(first_seq, second_seq)
             else:
-                if actual == expected:
+                if first == second:
                     return
-                missing = list(expected - actual)
-                unexpected = list(actual - expected)
+                differences = _count_diff_hashable(first_seq, second_seq)
 
-        errors = []
-        if missing:
-            errors.append('Expected, but missing:\n    %s' %
-                           safe_repr(missing))
-        if unexpected:
-            errors.append('Unexpected, but present:\n    %s' %
-                           safe_repr(unexpected))
-        if errors:
-            standardMsg = '\n'.join(errors)
-            self.fail(self._formatMessage(msg, standardMsg))
+        if differences:
+            standardMsg = 'Element counts were not equal:\n'
+            lines = ['First has %d, Second has %d:  %r' % diff for diff in differences]
+            diffMsg = '\n'.join(lines)
+            standardMsg = self._truncateMessage(standardMsg, diffMsg)
+            msg = self._formatMessage(msg, standardMsg)
+            self.fail(msg)
 
     def assertMultiLineEqual(self, first, second, msg=None):
         """Assert that two multi-line strings are equal."""
diff --git a/Lib/unittest/test/test_assertions.py b/Lib/unittest/test/test_assertions.py
--- a/Lib/unittest/test/test_assertions.py
+++ b/Lib/unittest/test/test_assertions.py
@@ -228,12 +228,6 @@
                              "^Missing: 'key'$",
                              "^Missing: 'key' : oops$"])
 
-    def testAssertItemsEqual(self):
-        self.assertMessages('assertItemsEqual', ([], [None]),
-                            [r"\[None\]$", "^oops$",
-                             r"\[None\]$",
-                             r"\[None\] : oops$"])
-
     def testAssertMultiLineEqual(self):
         self.assertMessages('assertMultiLineEqual', ("", "foo"),
                             [r"\+ foo$", "^oops$",
diff --git a/Lib/unittest/test/test_case.py b/Lib/unittest/test/test_case.py
--- a/Lib/unittest/test/test_case.py
+++ b/Lib/unittest/test/test_case.py
@@ -686,20 +686,19 @@
 
         # Test that sequences of unhashable objects can be tested for sameness:
         self.assertItemsEqual([[1, 2], [3, 4], 0], [False, [3, 4], [1, 2]])
-        with test_support.check_warnings(quiet=True) as w:
-            # hashable types, but not orderable
-            self.assertRaises(self.failureException, self.assertItemsEqual,
-                              [], [divmod, 'x', 1, 5j, 2j, frozenset()])
-            # comparing dicts raises a py3k warning
-            self.assertItemsEqual([{'a': 1}, {'b': 2}], [{'b': 2}, {'a': 1}])
-            # comparing heterogenous non-hashable sequences raises a py3k warning
-            self.assertItemsEqual([1, 'x', divmod, []], [divmod, [], 'x', 1])
-            self.assertRaises(self.failureException, self.assertItemsEqual,
-                              [], [divmod, [], 'x', 1, 5j, 2j, set()])
-            # fail the test if warnings are not silenced
-            if w.warnings:
-                self.fail('assertItemsEqual raised a warning: ' +
-                          str(w.warnings[0]))
+        # Test that iterator of unhashable objects can be tested for sameness:
+        self.assertItemsEqual(iter([1, 2, [], 3, 4]),
+                              iter([1, 2, [], 3, 4]))
+
+        # hashable types, but not orderable
+        self.assertRaises(self.failureException, self.assertItemsEqual,
+                          [], [divmod, 'x', 1, 5j, 2j, frozenset()])
+        # comparing dicts
+        self.assertItemsEqual([{'a': 1}, {'b': 2}], [{'b': 2}, {'a': 1}])
+        # comparing heterogenous non-hashable sequences
+        self.assertItemsEqual([1, 'x', divmod, []], [divmod, [], 'x', 1])
+        self.assertRaises(self.failureException, self.assertItemsEqual,
+                          [], [divmod, [], 'x', 1, 5j, 2j, set()])
         self.assertRaises(self.failureException, self.assertItemsEqual,
                           [[1]], [[2]])
 
@@ -717,6 +716,19 @@
         b = a[::-1]
         self.assertItemsEqual(a, b)
 
+        # test utility functions supporting assertItemsEqual()
+
+        diffs = set(unittest.util._count_diff_all_purpose('aaabccd', 'abbbcce'))
+        expected = {(3,1,'a'), (1,3,'b'), (1,0,'d'), (0,1,'e')}
+        self.assertEqual(diffs, expected)
+
+        diffs = unittest.util._count_diff_all_purpose([[]], [])
+        self.assertEqual(diffs, [(1, 0, [])])
+
+        diffs = set(unittest.util._count_diff_hashable('aaabccd', 'abbbcce'))
+        expected = {(3,1,'a'), (1,3,'b'), (1,0,'d'), (0,1,'e')}
+        self.assertEqual(diffs, expected)
+
     def testAssertSetEqual(self):
         set1 = set()
         set2 = set()
diff --git a/Lib/unittest/util.py b/Lib/unittest/util.py
--- a/Lib/unittest/util.py
+++ b/Lib/unittest/util.py
@@ -1,4 +1,6 @@
 """Various utility functions."""
+from collections import namedtuple, OrderedDict
+
 
 __unittest = True
 
@@ -92,3 +94,63 @@
 
     # anything left in actual is unexpected
     return missing, actual
+
+_Mismatch = namedtuple('Mismatch', 'actual expected value')
+
+def _count_diff_all_purpose(actual, expected):
+    'Returns list of (cnt_act, cnt_exp, elem) triples where the counts differ'
+    # elements need not be hashable
+    s, t = list(actual), list(expected)
+    m, n = len(s), len(t)
+    NULL = object()
+    result = []
+    for i, elem in enumerate(s):
+        if elem is NULL:
+            continue
+        cnt_s = cnt_t = 0
+        for j in range(i, m):
+            if s[j] == elem:
+                cnt_s += 1
+                s[j] = NULL
+        for j, other_elem in enumerate(t):
+            if other_elem == elem:
+                cnt_t += 1
+                t[j] = NULL
+        if cnt_s != cnt_t:
+            diff = _Mismatch(cnt_s, cnt_t, elem)
+            result.append(diff)
+
+    for i, elem in enumerate(t):
+        if elem is NULL:
+            continue
+        cnt_t = 0
+        for j in range(i, n):
+            if t[j] == elem:
+                cnt_t += 1
+                t[j] = NULL
+        diff = _Mismatch(0, cnt_t, elem)
+        result.append(diff)
+    return result
+
+def _ordered_count(iterable):
+    'Return dict of element counts, in the order they were first seen'
+    c = OrderedDict()
+    for elem in iterable:
+        c[elem] = c.get(elem, 0) + 1
+    return c
+
+def _count_diff_hashable(actual, expected):
+    'Returns list of (cnt_act, cnt_exp, elem) triples where the counts differ'
+    # elements must be hashable
+    s, t = _ordered_count(actual), _ordered_count(expected)
+    result = []
+    for elem, cnt_s in s.items():
+        cnt_t = t.get(elem, 0)
+        if cnt_s != cnt_t:
+            diff = _Mismatch(cnt_s, cnt_t, elem)
+            result.append(diff)
+    for elem, cnt_t in t.items():
+        if elem not in s:
+            diff = _Mismatch(0, cnt_t, elem)
+            result.append(diff)
+    return result

-- 
Repository URL: http://hg.python.org/cpython


More information about the Python-checkins mailing list