[Python-checkins] gh-56276: Add tests to test_compare (#3199)

terryjreedy webhook-mailer at python.org
Sat May 20 12:07:54 EDT 2023


https://github.com/python/cpython/commit/68ee8b3f15b744339c156bfeb583a414061ab22d
commit: 68ee8b3f15b744339c156bfeb583a414061ab22d
branch: main
author: Cheryl Sabella <cheryl.sabella at gmail.com>
committer: terryjreedy <tjreedy at udel.edu>
date: 2023-05-20T12:07:40-04:00
summary:

gh-56276: Add tests to test_compare (#3199)

Co-authored-by: Terry Jan Reedy <tjreedy at udel.edu>
Co-authored-by: Oleg Iarygin <oleg at arhadthedev.net>

files:
M Lib/test/test_compare.py

diff --git a/Lib/test/test_compare.py b/Lib/test/test_compare.py
index 2b3faed7965b..8166b0eea306 100644
--- a/Lib/test/test_compare.py
+++ b/Lib/test/test_compare.py
@@ -1,21 +1,27 @@
+"""Test equality and order comparisons."""
 import unittest
 from test.support import ALWAYS_EQ
+from fractions import Fraction
+from decimal import Decimal
 
-class Empty:
-    def __repr__(self):
-        return '<Empty>'
 
-class Cmp:
-    def __init__(self,arg):
-        self.arg = arg
+class ComparisonSimpleTest(unittest.TestCase):
+    """Test equality and order comparisons for some simple cases."""
 
-    def __repr__(self):
-        return '<Cmp %s>' % self.arg
+    class Empty:
+        def __repr__(self):
+            return '<Empty>'
 
-    def __eq__(self, other):
-        return self.arg == other
+    class Cmp:
+        def __init__(self, arg):
+            self.arg = arg
+
+        def __repr__(self):
+            return '<Cmp %s>' % self.arg
+
+        def __eq__(self, other):
+            return self.arg == other
 
-class ComparisonTest(unittest.TestCase):
     set1 = [2, 2.0, 2, 2+0j, Cmp(2.0)]
     set2 = [[1], (3,), None, Empty()]
     candidates = set1 + set2
@@ -32,16 +38,15 @@ def test_id_comparisons(self):
         # Ensure default comparison compares id() of args
         L = []
         for i in range(10):
-            L.insert(len(L)//2, Empty())
+            L.insert(len(L)//2, self.Empty())
         for a in L:
             for b in L:
-                self.assertEqual(a == b, id(a) == id(b),
-                                 'a=%r, b=%r' % (a, b))
+                self.assertEqual(a == b, a is b, 'a=%r, b=%r' % (a, b))
 
     def test_ne_defaults_to_not_eq(self):
-        a = Cmp(1)
-        b = Cmp(1)
-        c = Cmp(2)
+        a = self.Cmp(1)
+        b = self.Cmp(1)
+        c = self.Cmp(2)
         self.assertIs(a == b, True)
         self.assertIs(a != b, False)
         self.assertIs(a != c, True)
@@ -114,5 +119,392 @@ def test_issue_1393(self):
         self.assertEqual(ALWAYS_EQ, y)
 
 
+class ComparisonFullTest(unittest.TestCase):
+    """Test equality and ordering comparisons for built-in types and
+    user-defined classes that implement relevant combinations of rich
+    comparison methods.
+    """
+
+    class CompBase:
+        """Base class for classes with rich comparison methods.
+
+        The "x" attribute should be set to an underlying value to compare.
+
+        Derived classes have a "meth" tuple attribute listing names of
+        comparison methods implemented. See assert_total_order().
+        """
+
+    # Class without any rich comparison methods.
+    class CompNone(CompBase):
+        meth = ()
+
+    # Classes with all combinations of value-based equality comparison methods.
+    class CompEq(CompBase):
+        meth = ("eq",)
+        def __eq__(self, other):
+            return self.x == other.x
+
+    class CompNe(CompBase):
+        meth = ("ne",)
+        def __ne__(self, other):
+            return self.x != other.x
+
+    class CompEqNe(CompBase):
+        meth = ("eq", "ne")
+        def __eq__(self, other):
+            return self.x == other.x
+        def __ne__(self, other):
+            return self.x != other.x
+
+    # Classes with all combinations of value-based less/greater-than order
+    # comparison methods.
+    class CompLt(CompBase):
+        meth = ("lt",)
+        def __lt__(self, other):
+            return self.x < other.x
+
+    class CompGt(CompBase):
+        meth = ("gt",)
+        def __gt__(self, other):
+            return self.x > other.x
+
+    class CompLtGt(CompBase):
+        meth = ("lt", "gt")
+        def __lt__(self, other):
+            return self.x < other.x
+        def __gt__(self, other):
+            return self.x > other.x
+
+    # Classes with all combinations of value-based less/greater-or-equal-than
+    # order comparison methods
+    class CompLe(CompBase):
+        meth = ("le",)
+        def __le__(self, other):
+            return self.x <= other.x
+
+    class CompGe(CompBase):
+        meth = ("ge",)
+        def __ge__(self, other):
+            return self.x >= other.x
+
+    class CompLeGe(CompBase):
+        meth = ("le", "ge")
+        def __le__(self, other):
+            return self.x <= other.x
+        def __ge__(self, other):
+            return self.x >= other.x
+
+    # It should be sufficient to combine the comparison methods only within
+    # each group.
+    all_comp_classes = (
+            CompNone,
+            CompEq, CompNe, CompEqNe,  # equal group
+            CompLt, CompGt, CompLtGt,  # less/greater-than group
+            CompLe, CompGe, CompLeGe)  # less/greater-or-equal group
+
+    def create_sorted_instances(self, class_, values):
+        """Create objects of type `class_` and return them in a list.
+
+        `values` is a list of values that determines the value of data
+        attribute `x` of each object.
+
+        Objects in the returned list are sorted by their identity.  They
+        are assigned values in `values` list order.  By assigning decreasing
+        values to objects with increasing identities, testcases can assert
+        that order comparison is performed by value and not by identity.
+        """
+
+        instances = [class_() for __ in range(len(values))]
+        instances.sort(key=id)
+        # Assign the provided values to the instances.
+        for inst, value in zip(instances, values):
+            inst.x = value
+        return instances
+
+    def assert_equality_only(self, a, b, equal):
+        """Assert equality result and that ordering is not implemented.
+
+        a, b: Instances to be tested (of same or different type).
+        equal: Boolean indicating the expected equality comparison results.
+        """
+        self.assertEqual(a == b, equal)
+        self.assertEqual(b == a, equal)
+        self.assertEqual(a != b, not equal)
+        self.assertEqual(b != a, not equal)
+        with self.assertRaisesRegex(TypeError, "not supported"):
+            a < b
+        with self.assertRaisesRegex(TypeError, "not supported"):
+            a <= b
+        with self.assertRaisesRegex(TypeError, "not supported"):
+            a > b
+        with self.assertRaisesRegex(TypeError, "not supported"):
+            a >= b
+        with self.assertRaisesRegex(TypeError, "not supported"):
+            b < a
+        with self.assertRaisesRegex(TypeError, "not supported"):
+            b <= a
+        with self.assertRaisesRegex(TypeError, "not supported"):
+            b > a
+        with self.assertRaisesRegex(TypeError, "not supported"):
+            b >= a
+
+    def assert_total_order(self, a, b, comp, a_meth=None, b_meth=None):
+        """Test total ordering comparison of two instances.
+
+        a, b: Instances to be tested (of same or different type).
+
+        comp: -1, 0, or 1 indicates that the expected order comparison
+           result for operations that are supported by the classes is
+           a <, ==, or > b.
+
+        a_meth, b_meth: Either None, indicating that all rich comparison
+           methods are available, as for builtins, or the tuple (subset)
+           of "eq", "ne", "lt", "le", "gt", and "ge" that are available
+           for the corresponding instance (of a user-defined class).
+        """
+        self.assert_eq_subtest(a, b, comp, a_meth, b_meth)
+        self.assert_ne_subtest(a, b, comp, a_meth, b_meth)
+        self.assert_lt_subtest(a, b, comp, a_meth, b_meth)
+        self.assert_le_subtest(a, b, comp, a_meth, b_meth)
+        self.assert_gt_subtest(a, b, comp, a_meth, b_meth)
+        self.assert_ge_subtest(a, b, comp, a_meth, b_meth)
+
+    # The body of each subtest has form:
+    #
+    #     if value-based comparison methods:
+    #         expect what the testcase defined for a op b and b rop a;
+    #     else:  no value-based comparison
+    #         expect default behavior of object for a op b and b rop a.
+
+    def assert_eq_subtest(self, a, b, comp, a_meth, b_meth):
+        if a_meth is None or "eq" in a_meth or "eq" in b_meth:
+            self.assertEqual(a == b, comp == 0)
+            self.assertEqual(b == a, comp == 0)
+        else:
+            self.assertEqual(a == b, a is b)
+            self.assertEqual(b == a, a is b)
+
+    def assert_ne_subtest(self, a, b, comp, a_meth, b_meth):
+        if a_meth is None or not {"ne", "eq"}.isdisjoint(a_meth + b_meth):
+            self.assertEqual(a != b, comp != 0)
+            self.assertEqual(b != a, comp != 0)
+        else:
+            self.assertEqual(a != b, a is not b)
+            self.assertEqual(b != a, a is not b)
+
+    def assert_lt_subtest(self, a, b, comp, a_meth, b_meth):
+        if a_meth is None or "lt" in a_meth or "gt" in b_meth:
+            self.assertEqual(a < b, comp < 0)
+            self.assertEqual(b > a, comp < 0)
+        else:
+            with self.assertRaisesRegex(TypeError, "not supported"):
+                a < b
+            with self.assertRaisesRegex(TypeError, "not supported"):
+                b > a
+
+    def assert_le_subtest(self, a, b, comp, a_meth, b_meth):
+        if a_meth is None or "le" in a_meth or "ge" in b_meth:
+            self.assertEqual(a <= b, comp <= 0)
+            self.assertEqual(b >= a, comp <= 0)
+        else:
+            with self.assertRaisesRegex(TypeError, "not supported"):
+                a <= b
+            with self.assertRaisesRegex(TypeError, "not supported"):
+                b >= a
+
+    def assert_gt_subtest(self, a, b, comp, a_meth, b_meth):
+        if a_meth is None or "gt" in a_meth or "lt" in b_meth:
+            self.assertEqual(a > b, comp > 0)
+            self.assertEqual(b < a, comp > 0)
+        else:
+            with self.assertRaisesRegex(TypeError, "not supported"):
+                a > b
+            with self.assertRaisesRegex(TypeError, "not supported"):
+                b < a
+
+    def assert_ge_subtest(self, a, b, comp, a_meth, b_meth):
+        if a_meth is None or "ge" in a_meth or "le" in b_meth:
+            self.assertEqual(a >= b, comp >= 0)
+            self.assertEqual(b <= a, comp >= 0)
+        else:
+            with self.assertRaisesRegex(TypeError, "not supported"):
+                a >= b
+            with self.assertRaisesRegex(TypeError, "not supported"):
+                b <= a
+
+    def test_objects(self):
+        """Compare instances of type 'object'."""
+        a = object()
+        b = object()
+        self.assert_equality_only(a, a, True)
+        self.assert_equality_only(a, b, False)
+
+    def test_comp_classes_same(self):
+        """Compare same-class instances with comparison methods."""
+
+        for cls in self.all_comp_classes:
+            with self.subTest(cls):
+                instances = self.create_sorted_instances(cls, (1, 2, 1))
+
+                # Same object.
+                self.assert_total_order(instances[0], instances[0], 0,
+                                        cls.meth, cls.meth)
+
+                # Different objects, same value.
+                self.assert_total_order(instances[0], instances[2], 0,
+                                        cls.meth, cls.meth)
+
+                # Different objects, value ascending for ascending identities.
+                self.assert_total_order(instances[0], instances[1], -1,
+                                        cls.meth, cls.meth)
+
+                # Different objects, value descending for ascending identities.
+                # This is the interesting case to assert that order comparison
+                # is performed based on the value and not based on the identity.
+                self.assert_total_order(instances[1], instances[2], +1,
+                                        cls.meth, cls.meth)
+
+    def test_comp_classes_different(self):
+        """Compare different-class instances with comparison methods."""
+
+        for cls_a in self.all_comp_classes:
+            for cls_b in self.all_comp_classes:
+                with self.subTest(a=cls_a, b=cls_b):
+                    a1 = cls_a()
+                    a1.x = 1
+                    b1 = cls_b()
+                    b1.x = 1
+                    b2 = cls_b()
+                    b2.x = 2
+
+                    self.assert_total_order(
+                        a1, b1, 0, cls_a.meth, cls_b.meth)
+                    self.assert_total_order(
+                        a1, b2, -1, cls_a.meth, cls_b.meth)
+
+    def test_str_subclass(self):
+        """Compare instances of str and a subclass."""
+        class StrSubclass(str):
+            pass
+
+        s1 = str("a")
+        s2 = str("b")
+        c1 = StrSubclass("a")
+        c2 = StrSubclass("b")
+        c3 = StrSubclass("b")
+
+        self.assert_total_order(s1, s1,   0)
+        self.assert_total_order(s1, s2, -1)
+        self.assert_total_order(c1, c1,   0)
+        self.assert_total_order(c1, c2, -1)
+        self.assert_total_order(c2, c3,   0)
+
+        self.assert_total_order(s1, c2, -1)
+        self.assert_total_order(s2, c3,   0)
+        self.assert_total_order(c1, s2, -1)
+        self.assert_total_order(c2, s2,   0)
+
+    def test_numbers(self):
+        """Compare number types."""
+
+        # Same types.
+        i1 = 1001
+        i2 = 1002
+        self.assert_total_order(i1, i1, 0)
+        self.assert_total_order(i1, i2, -1)
+
+        f1 = 1001.0
+        f2 = 1001.1
+        self.assert_total_order(f1, f1, 0)
+        self.assert_total_order(f1, f2, -1)
+
+        q1 = Fraction(2002, 2)
+        q2 = Fraction(2003, 2)
+        self.assert_total_order(q1, q1, 0)
+        self.assert_total_order(q1, q2, -1)
+
+        d1 = Decimal('1001.0')
+        d2 = Decimal('1001.1')
+        self.assert_total_order(d1, d1, 0)
+        self.assert_total_order(d1, d2, -1)
+
+        c1 = 1001+0j
+        c2 = 1001+1j
+        self.assert_equality_only(c1, c1, True)
+        self.assert_equality_only(c1, c2, False)
+
+
+        # Mixing types.
+        for n1, n2 in ((i1,f1), (i1,q1), (i1,d1), (f1,q1), (f1,d1), (q1,d1)):
+            self.assert_total_order(n1, n2, 0)
+        for n1 in (i1, f1, q1, d1):
+            self.assert_equality_only(n1, c1, True)
+
+    def test_sequences(self):
+        """Compare list, tuple, and range."""
+        l1 = [1, 2]
+        l2 = [2, 3]
+        self.assert_total_order(l1, l1, 0)
+        self.assert_total_order(l1, l2, -1)
+
+        t1 = (1, 2)
+        t2 = (2, 3)
+        self.assert_total_order(t1, t1, 0)
+        self.assert_total_order(t1, t2, -1)
+
+        r1 = range(1, 2)
+        r2 = range(2, 2)
+        self.assert_equality_only(r1, r1, True)
+        self.assert_equality_only(r1, r2, False)
+
+        self.assert_equality_only(t1, l1, False)
+        self.assert_equality_only(l1, r1, False)
+        self.assert_equality_only(r1, t1, False)
+
+    def test_bytes(self):
+        """Compare bytes and bytearray."""
+        bs1 = b'a1'
+        bs2 = b'b2'
+        self.assert_total_order(bs1, bs1, 0)
+        self.assert_total_order(bs1, bs2, -1)
+
+        ba1 = bytearray(b'a1')
+        ba2 = bytearray(b'b2')
+        self.assert_total_order(ba1, ba1,  0)
+        self.assert_total_order(ba1, ba2, -1)
+
+        self.assert_total_order(bs1, ba1, 0)
+        self.assert_total_order(bs1, ba2, -1)
+        self.assert_total_order(ba1, bs1, 0)
+        self.assert_total_order(ba1, bs2, -1)
+
+    def test_sets(self):
+        """Compare set and frozenset."""
+        s1 = {1, 2}
+        s2 = {1, 2, 3}
+        self.assert_total_order(s1, s1, 0)
+        self.assert_total_order(s1, s2, -1)
+
+        f1 = frozenset(s1)
+        f2 = frozenset(s2)
+        self.assert_total_order(f1, f1,  0)
+        self.assert_total_order(f1, f2, -1)
+
+        self.assert_total_order(s1, f1, 0)
+        self.assert_total_order(s1, f2, -1)
+        self.assert_total_order(f1, s1, 0)
+        self.assert_total_order(f1, s2, -1)
+
+    def test_mappings(self):
+        """ Compare dict.
+        """
+        d1 = {1: "a", 2: "b"}
+        d2 = {2: "b", 3: "c"}
+        d3 = {3: "c", 2: "b"}
+        self.assert_equality_only(d1, d1, True)
+        self.assert_equality_only(d1, d2, False)
+        self.assert_equality_only(d2, d3, True)
+
+
 if __name__ == '__main__':
     unittest.main()



More information about the Python-checkins mailing list