[Jython-checkins] jython: from: http://hg.python.org/cpython/file/225126c9d4b5/Lib/test/test_set.py

frank.wierzbicki jython-checkins at python.org
Tue Apr 17 00:45:09 CEST 2012


http://hg.python.org/jython/rev/7b057c1b0157
changeset:   6596:7b057c1b0157
user:        Darjus Loktevic <darjus at gmail.com>
date:        Sat Apr 14 14:51:58 2012 -0700
summary:
  from: http://hg.python.org/cpython/file/225126c9d4b5/Lib/test/test_set.py

files:
  Lib/test/test_set.py |  469 +++++++++++++++++++++++++-----
  1 files changed, 380 insertions(+), 89 deletions(-)


diff --git a/Lib/test/test_set.py b/Lib/test/test_set.py
--- a/Lib/test/test_set.py
+++ b/Lib/test/test_set.py
@@ -1,13 +1,14 @@
+
 import unittest
 from test import test_support
-from test_weakref import extra_collect
-from weakref import proxy
+import gc
+import weakref
 import operator
 import copy
 import pickle
-import os
 from random import randrange, shuffle
 import sys
+import collections
 
 class PassThru(Exception):
     pass
@@ -47,6 +48,7 @@
 
     def test_new_or_init(self):
         self.assertRaises(TypeError, self.thetype, [], 2)
+        self.assertRaises(TypeError, set().__init__, a=1)
 
     def test_uniquification(self):
         actual = sorted(self.s)
@@ -63,7 +65,7 @@
             self.assertEqual(c in self.s, c in self.d)
         self.assertRaises(TypeError, self.s.__contains__, [[]])
         s = self.thetype([frozenset(self.letters)])
-        self.assert_(self.thetype(self.letters) in s)
+        self.assertIn(self.thetype(self.letters), s)
 
     def test_union(self):
         u = self.s.union(self.otherword)
@@ -78,6 +80,11 @@
             self.assertEqual(self.thetype('abcba').union(C('efgfe')), set('abcefg'))
             self.assertEqual(self.thetype('abcba').union(C('ccb')), set('abc'))
             self.assertEqual(self.thetype('abcba').union(C('ef')), set('abcef'))
+            self.assertEqual(self.thetype('abcba').union(C('ef'), C('fg')), set('abcefg'))
+
+        # Issue #6573
+        x = self.thetype()
+        self.assertEqual(x.union(set([1]), x, set([2])), self.thetype([1, 2]))
 
     def test_or(self):
         i = self.s.union(self.otherword)
@@ -102,6 +109,27 @@
             self.assertEqual(self.thetype('abcba').intersection(C('efgfe')), set(''))
             self.assertEqual(self.thetype('abcba').intersection(C('ccb')), set('bc'))
             self.assertEqual(self.thetype('abcba').intersection(C('ef')), set(''))
+            self.assertEqual(self.thetype('abcba').intersection(C('cbcf'), C('bag')), set('b'))
+        s = self.thetype('abcba')
+        z = s.intersection()
+        if self.thetype == frozenset():
+            self.assertEqual(id(s), id(z))
+        else:
+            self.assertNotEqual(id(s), id(z))
+
+    def test_isdisjoint(self):
+        def f(s1, s2):
+            'Pure python equivalent of isdisjoint()'
+            return not set(s1).intersection(s2)
+        for larg in '', 'a', 'ab', 'abc', 'ababac', 'cdc', 'cc', 'efgfe', 'ccb', 'ef':
+            s1 = self.thetype(larg)
+            for rarg in '', 'a', 'ab', 'abc', 'ababac', 'cdc', 'cc', 'efgfe', 'ccb', 'ef':
+                for C in set, frozenset, dict.fromkeys, str, unicode, list, tuple:
+                    s2 = C(rarg)
+                    actual = s1.isdisjoint(s2)
+                    expected = f(s1, s2)
+                    self.assertEqual(actual, expected)
+                    self.assertTrue(actual is True or actual is False)
 
     def test_and(self):
         i = self.s.intersection(self.otherword)
@@ -127,6 +155,8 @@
             self.assertEqual(self.thetype('abcba').difference(C('efgfe')), set('abc'))
             self.assertEqual(self.thetype('abcba').difference(C('ccb')), set('a'))
             self.assertEqual(self.thetype('abcba').difference(C('ef')), set('abc'))
+            self.assertEqual(self.thetype('abcba').difference(), set('abc'))
+            self.assertEqual(self.thetype('abcba').difference(C('a'), C('b')), set('c'))
 
     def test_sub(self):
         i = self.s.difference(self.otherword)
@@ -182,22 +212,22 @@
 
     def test_sub_and_super(self):
         p, q, r = map(self.thetype, ['ab', 'abcde', 'def'])
-        self.assert_(p < q)
-        self.assert_(p <= q)
-        self.assert_(q <= q)
-        self.assert_(q > p)
-        self.assert_(q >= p)
-        self.failIf(q < r)
-        self.failIf(q <= r)
-        self.failIf(q > r)
-        self.failIf(q >= r)
-        self.assert_(set('a').issubset('abc'))
-        self.assert_(set('abc').issuperset('a'))
-        self.failIf(set('a').issubset('cbs'))
-        self.failIf(set('cbs').issuperset('a'))
+        self.assertTrue(p < q)
+        self.assertTrue(p <= q)
+        self.assertTrue(q <= q)
+        self.assertTrue(q > p)
+        self.assertTrue(q >= p)
+        self.assertFalse(q < r)
+        self.assertFalse(q <= r)
+        self.assertFalse(q > r)
+        self.assertFalse(q >= r)
+        self.assertTrue(set('a').issubset('abc'))
+        self.assertTrue(set('abc').issuperset('a'))
+        self.assertFalse(set('a').issubset('cbs'))
+        self.assertFalse(set('cbs').issuperset('a'))
 
     def test_pickling(self):
-        for i in (0, 1, 2):
+        for i in range(pickle.HIGHEST_PROTOCOL + 1):
             p = pickle.dumps(self.s, i)
             dup = pickle.loads(p)
             self.assertEqual(self.s, dup, "%s != %s" % (self.s, dup))
@@ -242,7 +272,7 @@
         s=H()
         f=set()
         f.add(s)
-        self.assert_(s in f)
+        self.assertIn(s, f)
         f.remove(s)
         f.add(s)
         f.discard(s)
@@ -269,18 +299,17 @@
         w = ReprWrapper()
         s = self.thetype([w])
         w.value = s
+        fo = open(test_support.TESTFN, "wb")
         try:
-            fo = open(test_support.TESTFN, "wb")
             print >> fo, s,
             fo.close()
             fo = open(test_support.TESTFN, "rb")
             self.assertEqual(fo.read(), repr(s))
         finally:
             fo.close()
-            os.remove(test_support.TESTFN)
+            test_support.unlink(test_support.TESTFN)
 
-    # XXX: Tests CPython internals (caches key hashes)
-    def _test_do_not_rehash_dict_keys(self):
+    def test_do_not_rehash_dict_keys(self):
         n = 10
         d = dict.fromkeys(map(HashCountingInt, xrange(n)))
         self.assertEqual(sum(elem.hash_count for elem in d), n)
@@ -299,6 +328,18 @@
         self.assertEqual(sum(elem.hash_count for elem in d), n)
         self.assertEqual(d3, dict.fromkeys(d, 123))
 
+    def test_container_iterator(self):
+        # Bug #3680: tp_traverse was not implemented for set iterator object
+        class C(object):
+            pass
+        obj = C()
+        ref = weakref.ref(obj)
+        container = set([obj, 1])
+        obj.x = iter(container)
+        del obj, container
+        gc.collect()
+        self.assertTrue(ref() is None, "Cycle was not collected")
+
 class TestSet(TestJointOps):
     thetype = set
 
@@ -331,7 +372,7 @@
 
     def test_add(self):
         self.s.add('Q')
-        self.assert_('Q' in self.s)
+        self.assertIn('Q', self.s)
         dup = self.s.copy()
         self.s.add('Q')
         self.assertEqual(self.s, dup)
@@ -339,13 +380,13 @@
 
     def test_remove(self):
         self.s.remove('a')
-        self.assert_('a' not in self.s)
+        self.assertNotIn('a', self.s)
         self.assertRaises(KeyError, self.s.remove, 'Q')
         self.assertRaises(TypeError, self.s.remove, [])
         s = self.thetype([frozenset(self.word)])
-        self.assert_(self.thetype(self.word) in s)
+        self.assertIn(self.thetype(self.word), s)
         s.remove(self.thetype(self.word))
-        self.assert_(self.thetype(self.word) not in s)
+        self.assertNotIn(self.thetype(self.word), s)
         self.assertRaises(KeyError, self.s.remove, self.thetype(self.word))
 
     def test_remove_keyerror_unpacking(self):
@@ -359,28 +400,39 @@
             else:
                 self.fail()
 
+    def test_remove_keyerror_set(self):
+        key = self.thetype([3, 4])
+        try:
+            self.s.remove(key)
+        except KeyError as e:
+            self.assertTrue(e.args[0] is key,
+                         "KeyError should be {0}, not {1}".format(key,
+                                                                  e.args[0]))
+        else:
+            self.fail()
+
     def test_discard(self):
         self.s.discard('a')
-        self.assert_('a' not in self.s)
+        self.assertNotIn('a', self.s)
         self.s.discard('Q')
         self.assertRaises(TypeError, self.s.discard, [])
         s = self.thetype([frozenset(self.word)])
-        self.assert_(self.thetype(self.word) in s)
+        self.assertIn(self.thetype(self.word), s)
         s.discard(self.thetype(self.word))
-        self.assert_(self.thetype(self.word) not in s)
+        self.assertNotIn(self.thetype(self.word), s)
         s.discard(self.thetype(self.word))
 
     def test_pop(self):
         for i in xrange(len(self.s)):
             elem = self.s.pop()
-            self.assert_(elem not in self.s)
+            self.assertNotIn(elem, self.s)
         self.assertRaises(KeyError, self.s.pop)
 
     def test_update(self):
         retval = self.s.update(self.otherword)
         self.assertEqual(retval, None)
         for c in (self.word + self.otherword):
-            self.assert_(c in self.s)
+            self.assertIn(c, self.s)
         self.assertRaises(PassThru, self.s.update, check_pass_thru())
         self.assertRaises(TypeError, self.s.update, [[]])
         for p, q in (('cdc', 'abcd'), ('efgfe', 'abcefg'), ('ccb', 'abc'), ('ef', 'abcef')):
@@ -388,20 +440,26 @@
                 s = self.thetype('abcba')
                 self.assertEqual(s.update(C(p)), None)
                 self.assertEqual(s, set(q))
+        for p in ('cdc', 'efgfe', 'ccb', 'ef', 'abcda'):
+            q = 'ahi'
+            for C in set, frozenset, dict.fromkeys, str, unicode, list, tuple:
+                s = self.thetype('abcba')
+                self.assertEqual(s.update(C(p), C(q)), None)
+                self.assertEqual(s, set(s) | set(p) | set(q))
 
     def test_ior(self):
         self.s |= set(self.otherword)
         for c in (self.word + self.otherword):
-            self.assert_(c in self.s)
+            self.assertIn(c, self.s)
 
     def test_intersection_update(self):
         retval = self.s.intersection_update(self.otherword)
         self.assertEqual(retval, None)
         for c in (self.word + self.otherword):
             if c in self.otherword and c in self.word:
-                self.assert_(c in self.s)
+                self.assertIn(c, self.s)
             else:
-                self.assert_(c not in self.s)
+                self.assertNotIn(c, self.s)
         self.assertRaises(PassThru, self.s.intersection_update, check_pass_thru())
         self.assertRaises(TypeError, self.s.intersection_update, [[]])
         for p, q in (('cdc', 'c'), ('efgfe', ''), ('ccb', 'bc'), ('ef', '')):
@@ -409,23 +467,28 @@
                 s = self.thetype('abcba')
                 self.assertEqual(s.intersection_update(C(p)), None)
                 self.assertEqual(s, set(q))
+                ss = 'abcba'
+                s = self.thetype(ss)
+                t = 'cbc'
+                self.assertEqual(s.intersection_update(C(p), C(t)), None)
+                self.assertEqual(s, set('abcba')&set(p)&set(t))
 
     def test_iand(self):
         self.s &= set(self.otherword)
         for c in (self.word + self.otherword):
             if c in self.otherword and c in self.word:
-                self.assert_(c in self.s)
+                self.assertIn(c, self.s)
             else:
-                self.assert_(c not in self.s)
+                self.assertNotIn(c, self.s)
 
     def test_difference_update(self):
         retval = self.s.difference_update(self.otherword)
         self.assertEqual(retval, None)
         for c in (self.word + self.otherword):
             if c in self.word and c not in self.otherword:
-                self.assert_(c in self.s)
+                self.assertIn(c, self.s)
             else:
-                self.assert_(c not in self.s)
+                self.assertNotIn(c, self.s)
         self.assertRaises(PassThru, self.s.difference_update, check_pass_thru())
         self.assertRaises(TypeError, self.s.difference_update, [[]])
         self.assertRaises(TypeError, self.s.symmetric_difference_update, [[]])
@@ -435,22 +498,34 @@
                 self.assertEqual(s.difference_update(C(p)), None)
                 self.assertEqual(s, set(q))
 
+                s = self.thetype('abcdefghih')
+                s.difference_update()
+                self.assertEqual(s, self.thetype('abcdefghih'))
+
+                s = self.thetype('abcdefghih')
+                s.difference_update(C('aba'))
+                self.assertEqual(s, self.thetype('cdefghih'))
+
+                s = self.thetype('abcdefghih')
+                s.difference_update(C('cdc'), C('aba'))
+                self.assertEqual(s, self.thetype('efghih'))
+
     def test_isub(self):
         self.s -= set(self.otherword)
         for c in (self.word + self.otherword):
             if c in self.word and c not in self.otherword:
-                self.assert_(c in self.s)
+                self.assertIn(c, self.s)
             else:
-                self.assert_(c not in self.s)
+                self.assertNotIn(c, self.s)
 
     def test_symmetric_difference_update(self):
         retval = self.s.symmetric_difference_update(self.otherword)
         self.assertEqual(retval, None)
         for c in (self.word + self.otherword):
             if (c in self.word) ^ (c in self.otherword):
-                self.assert_(c in self.s)
+                self.assertIn(c, self.s)
             else:
-                self.assert_(c not in self.s)
+                self.assertNotIn(c, self.s)
         self.assertRaises(PassThru, self.s.symmetric_difference_update, check_pass_thru())
         self.assertRaises(TypeError, self.s.symmetric_difference_update, [[]])
         for p, q in (('cdc', 'abd'), ('efgfe', 'abcefg'), ('ccb', 'a'), ('ef', 'abcef')):
@@ -463,9 +538,9 @@
         self.s ^= set(self.otherword)
         for c in (self.word + self.otherword):
             if (c in self.word) ^ (c in self.otherword):
-                self.assert_(c in self.s)
+                self.assertIn(c, self.s)
             else:
-                self.assert_(c not in self.s)
+                self.assertNotIn(c, self.s)
 
     def test_inplace_on_self(self):
         t = self.s.copy()
@@ -481,16 +556,15 @@
 
     def test_weakref(self):
         s = self.thetype('gallahad')
-        p = proxy(s)
+        p = weakref.proxy(s)
         self.assertEqual(str(p), str(s))
         s = None
-        extra_collect()
         self.assertRaises(ReferenceError, str, p)
 
     # C API test only available in a debug build
     if hasattr(set, "test_c_api"):
         def test_c_api(self):
-            self.assertEqual(set('abc').test_c_api(), True)
+            self.assertEqual(set().test_c_api(), True)
 
 class SetSubclass(set):
     pass
@@ -561,8 +635,7 @@
         f = self.thetype('abcdcda')
         self.assertEqual(hash(f), hash(f))
 
-    # XXX: tied to CPython's hash implementation
-    def _test_hash_effectiveness(self):
+    def test_hash_effectiveness(self):
         n = 13
         hashvalues = set()
         addhashvalue = hashvalues.add
@@ -614,16 +687,27 @@
         if self.repr is not None:
             self.assertEqual(repr(self.set), self.repr)
 
+    def check_repr_against_values(self):
+        text = repr(self.set)
+        self.assertTrue(text.startswith('{'))
+        self.assertTrue(text.endswith('}'))
+
+        result = text[1:-1].split(', ')
+        result.sort()
+        sorted_repr_values = [repr(value) for value in self.values]
+        sorted_repr_values.sort()
+        self.assertEqual(result, sorted_repr_values)
+
     def test_print(self):
+        fo = open(test_support.TESTFN, "wb")
         try:
-            fo = open(test_support.TESTFN, "wb")
             print >> fo, self.set,
             fo.close()
             fo = open(test_support.TESTFN, "rb")
             self.assertEqual(fo.read(), repr(self.set))
         finally:
             fo.close()
-            os.remove(test_support.TESTFN)
+            test_support.unlink(test_support.TESTFN)
 
     def test_length(self):
         self.assertEqual(len(self.set), self.length)
@@ -661,11 +745,23 @@
         result = empty_set & self.set
         self.assertEqual(result, empty_set)
 
+    def test_self_isdisjoint(self):
+        result = self.set.isdisjoint(self.set)
+        self.assertEqual(result, not self.set)
+
+    def test_empty_isdisjoint(self):
+        result = self.set.isdisjoint(empty_set)
+        self.assertEqual(result, True)
+
+    def test_isdisjoint_empty(self):
+        result = empty_set.isdisjoint(self.set)
+        self.assertEqual(result, True)
+
     def test_self_symmetric_difference(self):
         result = self.set ^ self.set
         self.assertEqual(result, empty_set)
 
-    def checkempty_symmetric_difference(self):
+    def test_empty_symmetric_difference(self):
         result = self.set ^ empty_set
         self.assertEqual(result, self.set)
 
@@ -683,13 +779,11 @@
 
     def test_iteration(self):
         for v in self.set:
-            self.assert_(v in self.values)
-        # XXX: jython does not use length_hint
-        if not test_support.is_jython:
-            setiter = iter(self.set)
-            # note: __length_hint__ is an internal undocumented API,
-            # don't rely on it in your own programs
-            self.assertEqual(setiter.__length_hint__(), len(self.set))
+            self.assertIn(v, self.values)
+        setiter = iter(self.set)
+        # note: __length_hint__ is an internal undocumented API,
+        # don't rely on it in your own programs
+        self.assertEqual(setiter.__length_hint__(), len(self.set))
 
     def test_pickling(self):
         p = pickle.dumps(self.set)
@@ -720,10 +814,10 @@
         self.repr   = "set([3])"
 
     def test_in(self):
-        self.failUnless(3 in self.set)
+        self.assertIn(3, self.set)
 
     def test_not_in(self):
-        self.failUnless(2 not in self.set)
+        self.assertNotIn(2, self.set)
 
 #------------------------------------------------------------------------------
 
@@ -737,10 +831,10 @@
         self.repr   = "set([(0, 'zero')])"
 
     def test_in(self):
-        self.failUnless((0, "zero") in self.set)
+        self.assertIn((0, "zero"), self.set)
 
     def test_not_in(self):
-        self.failUnless(9 not in self.set)
+        self.assertNotIn(9, self.set)
 
 #------------------------------------------------------------------------------
 
@@ -753,6 +847,46 @@
         self.length = 3
         self.repr   = None
 
+#------------------------------------------------------------------------------
+
+class TestBasicOpsString(TestBasicOps):
+    def setUp(self):
+        self.case   = "string set"
+        self.values = ["a", "b", "c"]
+        self.set    = set(self.values)
+        self.dup    = set(self.values)
+        self.length = 3
+
+    def test_repr(self):
+        self.check_repr_against_values()
+
+#------------------------------------------------------------------------------
+
+class TestBasicOpsUnicode(TestBasicOps):
+    def setUp(self):
+        self.case   = "unicode set"
+        self.values = [u"a", u"b", u"c"]
+        self.set    = set(self.values)
+        self.dup    = set(self.values)
+        self.length = 3
+
+    def test_repr(self):
+        self.check_repr_against_values()
+
+#------------------------------------------------------------------------------
+
+class TestBasicOpsMixedStringUnicode(TestBasicOps):
+    def setUp(self):
+        self.case   = "string and bytes set"
+        self.values = ["a", "b", u"a", u"b"]
+        self.set    = set(self.values)
+        self.dup    = set(self.values)
+        self.length = 4
+
+    def test_repr(self):
+        with test_support.check_warnings():
+            self.check_repr_against_values()
+
 #==============================================================================
 
 def baditer():
@@ -841,6 +975,22 @@
         result = self.set & set([8])
         self.assertEqual(result, empty_set)
 
+    def test_isdisjoint_subset(self):
+        result = self.set.isdisjoint(set((2, 4)))
+        self.assertEqual(result, False)
+
+    def test_isdisjoint_superset(self):
+        result = self.set.isdisjoint(set([2, 4, 6, 8]))
+        self.assertEqual(result, False)
+
+    def test_isdisjoint_overlap(self):
+        result = self.set.isdisjoint(set([3, 4, 5]))
+        self.assertEqual(result, False)
+
+    def test_isdisjoint_non_overlap(self):
+        result = self.set.isdisjoint(set([8]))
+        self.assertEqual(result, True)
+
     def test_sym_difference_subset(self):
         result = self.set ^ set((2, 4))
         self.assertEqual(result, set([6]))
@@ -1016,7 +1166,7 @@
             popped[self.set.pop()] = None
         self.assertEqual(len(popped), len(self.values))
         for v in self.values:
-            self.failUnless(v in popped)
+            self.assertIn(v, popped)
 
     def test_update_empty_tuple(self):
         self.set.update(())
@@ -1248,6 +1398,10 @@
         self.other = operator.add
         self.otherIsIterable = False
 
+    def test_ge_gt_le_lt(self):
+        with test_support.check_py3k_warnings():
+            super(TestOnlySetsOperator, self).test_ge_gt_le_lt()
+
 #------------------------------------------------------------------------------
 
 class TestOnlySetsTuple(TestOnlySetsInBinaryOps):
@@ -1280,21 +1434,17 @@
 class TestCopying(unittest.TestCase):
 
     def test_copy(self):
-        dup = self.set.copy()
-        dup_list = list(dup); dup_list.sort()
-        set_list = list(self.set); set_list.sort()
-        self.assertEqual(len(dup_list), len(set_list))
-        for i in range(len(dup_list)):
-            self.failUnless(dup_list[i] is set_list[i])
+        dup = list(self.set.copy())
+        self.assertEqual(len(dup), len(self.set))
+        for el in self.set:
+            self.assertIn(el, dup)
+            pos = dup.index(el)
+            self.assertIs(el, dup.pop(pos))
+        self.assertFalse(dup)
 
     def test_deep_copy(self):
         dup = copy.deepcopy(self.set)
-        ##print type(dup), repr(dup)
-        dup_list = list(dup); dup_list.sort()
-        set_list = list(self.set); set_list.sort()
-        self.assertEqual(len(dup_list), len(set_list))
-        for i in range(len(dup_list)):
-            self.assertEqual(dup_list[i], set_list[i])
+        self.assertSetEqual(dup, self.set)
 
 #------------------------------------------------------------------------------
 
@@ -1335,13 +1485,13 @@
 
     def test_binopsVsSubsets(self):
         a, b = self.a, self.b
-        self.assert_(a - b < a)
-        self.assert_(b - a < b)
-        self.assert_(a & b < a)
-        self.assert_(a & b < b)
-        self.assert_(a | b > a)
-        self.assert_(a | b > b)
-        self.assert_(a ^ b < a | b)
+        self.assertTrue(a - b < a)
+        self.assertTrue(b - a < b)
+        self.assertTrue(a & b < a)
+        self.assertTrue(a & b < b)
+        self.assertTrue(a | b > a)
+        self.assertTrue(a | b > b)
+        self.assertTrue(a ^ b < a | b)
 
     def test_commutativity(self):
         a, b = self.a, self.b
@@ -1454,7 +1604,7 @@
         for cons in (set, frozenset):
             for s in ("123", "", range(1000), ('do', 1.2), xrange(2000,2200,5)):
                 for g in (G, I, Ig, S, L, R):
-                    self.assertEqual(sorted(cons(g(s))), sorted(g(s)))
+                    self.assertSetEqual(cons(g(s)), set(g(s)))
                 self.assertRaises(TypeError, cons , X(s))
                 self.assertRaises(TypeError, cons , N(s))
                 self.assertRaises(ZeroDivisionError, cons , E(s))
@@ -1462,11 +1612,14 @@
     def test_inline_methods(self):
         s = set('november')
         for data in ("123", "", range(1000), ('do', 1.2), xrange(2000,2200,5), 'december'):
-            for meth in (s.union, s.intersection, s.difference, s.symmetric_difference):
+            for meth in (s.union, s.intersection, s.difference, s.symmetric_difference, s.isdisjoint):
                 for g in (G, I, Ig, L, R):
                     expected = meth(data)
                     actual = meth(G(data))
-                    self.assertEqual(sorted(actual), sorted(expected))
+                    if isinstance(expected, bool):
+                        self.assertEqual(actual, expected)
+                    else:
+                        self.assertSetEqual(actual, expected)
                 self.assertRaises(TypeError, meth, X(s))
                 self.assertRaises(TypeError, meth, N(s))
                 self.assertRaises(ZeroDivisionError, meth, E(s))
@@ -1480,16 +1633,152 @@
                     t = s.copy()
                     getattr(s, methname)(list(g(data)))
                     getattr(t, methname)(g(data))
-                    self.assertEqual(sorted(s), sorted(t))
+                    self.assertSetEqual(s, t)
 
                 self.assertRaises(TypeError, getattr(set('january'), methname), X(data))
                 self.assertRaises(TypeError, getattr(set('january'), methname), N(data))
                 self.assertRaises(ZeroDivisionError, getattr(set('january'), methname), E(data))
 
+class bad_eq:
+    def __eq__(self, other):
+        if be_bad:
+            set2.clear()
+            raise ZeroDivisionError
+        return self is other
+    def __hash__(self):
+        return 0
+
+class bad_dict_clear:
+    def __eq__(self, other):
+        if be_bad:
+            dict2.clear()
+        return self is other
+    def __hash__(self):
+        return 0
+
+class TestWeirdBugs(unittest.TestCase):
+    def test_8420_set_merge(self):
+        # This used to segfault
+        global be_bad, set2, dict2
+        be_bad = False
+        set1 = {bad_eq()}
+        set2 = {bad_eq() for i in range(75)}
+        be_bad = True
+        self.assertRaises(ZeroDivisionError, set1.update, set2)
+
+        be_bad = False
+        set1 = {bad_dict_clear()}
+        dict2 = {bad_dict_clear(): None}
+        be_bad = True
+        set1.symmetric_difference_update(dict2)
+
+# Application tests (based on David Eppstein's graph recipes) ====================================
+
+def powerset(U):
+    """Generates all subsets of a set or sequence U."""
+    U = iter(U)
+    try:
+        x = frozenset([U.next()])
+        for S in powerset(U):
+            yield S
+            yield S | x
+    except StopIteration:
+        yield frozenset()
+
+def cube(n):
+    """Graph of n-dimensional hypercube."""
+    singletons = [frozenset([x]) for x in range(n)]
+    return dict([(x, frozenset([x^s for s in singletons]))
+                 for x in powerset(range(n))])
+
+def linegraph(G):
+    """Graph, the vertices of which are edges of G,
+    with two vertices being adjacent iff the corresponding
+    edges share a vertex."""
+    L = {}
+    for x in G:
+        for y in G[x]:
+            nx = [frozenset([x,z]) for z in G[x] if z != y]
+            ny = [frozenset([y,z]) for z in G[y] if z != x]
+            L[frozenset([x,y])] = frozenset(nx+ny)
+    return L
+
+def faces(G):
+    'Return a set of faces in G.  Where a face is a set of vertices on that face'
+    # currently limited to triangles, squares, and pentagons
+    f = set()
+    for v1, edges in G.items():
+        for v2 in edges:
+            for v3 in G[v2]:
+                if v1 == v3:
+                    continue
+                if v1 in G[v3]:
+                    f.add(frozenset([v1, v2, v3]))
+                else:
+                    for v4 in G[v3]:
+                        if v4 == v2:
+                            continue
+                        if v1 in G[v4]:
+                            f.add(frozenset([v1, v2, v3, v4]))
+                        else:
+                            for v5 in G[v4]:
+                                if v5 == v3 or v5 == v2:
+                                    continue
+                                if v1 in G[v5]:
+                                    f.add(frozenset([v1, v2, v3, v4, v5]))
+    return f
+
+
+class TestGraphs(unittest.TestCase):
+
+    def test_cube(self):
+
+        g = cube(3)                             # vert --> {v1, v2, v3}
+        vertices1 = set(g)
+        self.assertEqual(len(vertices1), 8)     # eight vertices
+        for edge in g.values():
+            self.assertEqual(len(edge), 3)      # each vertex connects to three edges
+        vertices2 = set(v for edges in g.values() for v in edges)
+        self.assertEqual(vertices1, vertices2)  # edge vertices in original set
+
+        cubefaces = faces(g)
+        self.assertEqual(len(cubefaces), 6)     # six faces
+        for face in cubefaces:
+            self.assertEqual(len(face), 4)      # each face is a square
+
+    def test_cuboctahedron(self):
+
+        # http://en.wikipedia.org/wiki/Cuboctahedron
+        # 8 triangular faces and 6 square faces
+        # 12 identical vertices each connecting a triangle and square
+
+        g = cube(3)
+        cuboctahedron = linegraph(g)            # V( --> {V1, V2, V3, V4}
+        self.assertEqual(len(cuboctahedron), 12)# twelve vertices
+
+        vertices = set(cuboctahedron)
+        for edges in cuboctahedron.values():
+            self.assertEqual(len(edges), 4)     # each vertex connects to four other vertices
+        othervertices = set(edge for edges in cuboctahedron.values() for edge in edges)
+        self.assertEqual(vertices, othervertices)   # edge vertices in original set
+
+        cubofaces = faces(cuboctahedron)
+        facesizes = collections.defaultdict(int)
+        for face in cubofaces:
+            facesizes[len(face)] += 1
+        self.assertEqual(facesizes[3], 8)       # eight triangular faces
+        self.assertEqual(facesizes[4], 6)       # six square faces
+
+        for vertex in cuboctahedron:
+            edge = vertex                       # Cuboctahedron vertices are edges in Cube
+            self.assertEqual(len(edge), 2)      # Two cube vertices define an edge
+            for cubevert in edge:
+                self.assertIn(cubevert, g)
+
+
 #==============================================================================
 
 def test_main(verbose=None):
-    from test import test_sets
     test_classes = (
         TestSet,
         TestSetSubclass,
@@ -1523,6 +1812,8 @@
         TestCopyingNested,
         TestIdentities,
         TestVariousIteratorArgs,
+        TestGraphs,
+        TestWeirdBugs,
         )
 
     test_support.run_unittest(*test_classes)

-- 
Repository URL: http://hg.python.org/jython


More information about the Jython-checkins mailing list