[Python-checkins] CVS: python/nondist/sandbox/sets set.py,NONE,1.1 test_set.py,NONE,1.1

Greg Wilson gvwilson@users.sourceforge.net
Tue, 22 May 2001 09:55:14 -0700


Update of /cvsroot/python/python/nondist/sandbox/sets
In directory usw-pr-cvs1:/tmp/cvs-serv9016

Added Files:
	set.py test_set.py 
Log Message:
First beta version of set module for Python.

--- NEW FILE: set.py ---
"""A class to represent sets in Python.

This module implements sets as dictionaries whose values are ignored.
The usual operations (union, intersection, difference, symmetric
difference, etc.) are provided as both methods and operators.  The only
unusual feature of this class is that once a set's hash code has been
calculated (for example, once it has been used as a dictionary key, or
as an element in another set), that set 'freezes' and becomes
immutable: add, update, remove, discard, clear and the in-place
operators then raise ValueError.  See PEP 218 for a full discussion.
"""

# CVS keyword strings ($Revision$, $Author$, $Date$), filled in by CVS
# keyword expansion; edit them by hand only if CVS is no longer used.
__version__ = "$Revision: 1.1 $"
__author__  = "$Author: gvwilson $"
__date__    = "$Date: 2001/05/22 16:55:12 $"

from copy import deepcopy

class Set:

    # Displayed when operation forbidden because set has been frozen
    _Frozen_Msg = "Set is frozen: %s not permitted"

    #----------------------------------------
    def __init__(self, seq=None, sortRepr=0):

        """Construct a set, optionally initializing it with elements
        drawn from a sequence (any iterable).  If 'sortRepr' is true,
        the set's elements are displayed in sorted order.  This slows
        down conversion, but simplifies comparison during testing.
        The 'hashcode' attribute is given a non-None value the first
        time the set's hash code is calculated; the set is frozen
        thereafter."""

        self.elements = {}
        self.sortRepr = sortRepr
        if seq is not None:
            for x in seq:
                self.elements[x] = None
        self.hashcode = None

    #----------------------------------------
    def __str__(self):
        """Convert set to string, e.g. 'Set([1, 2])'.  Elements are
        sorted only when 'sortRepr' is true; otherwise they appear in
        the underlying dictionary's iteration order."""
        content = list(self.elements)
        if self.sortRepr:
            content.sort()
        return 'Set(' + repr(content) + ')'

    #----------------------------------------
    # '__repr__' returns the same thing as '__str__'
    __repr__ = __str__

    #----------------------------------------
    def __len__(self):
        """Return number of elements in set."""
        return len(self.elements)

    #----------------------------------------
    def __contains__(self, item):
        """Test presence of value in set."""
        return item in self.elements

    #----------------------------------------
    def __iter__(self):
        """Return iterator for enumerating set elements.  This is a
        keys iterator for the underlying dictionary."""
        return iter(self.elements)

    #----------------------------------------
    def __cmp__(self, other):
        """Compare one set with another.  Sets may only be compared
        with sets; ordering is that of the underlying dictionaries.
        Raises ValueError if 'other' is not a Set."""
        if not isinstance(other, Set):
            raise ValueError("Sets can only be compared to sets")
        return cmp(self.elements, other.elements)

    #----------------------------------------
    def __eq__(self, other):
        """Report whether two sets hold the same elements.  Defined
        explicitly so equality does not rely on '__cmp__'/'cmp'.
        Raises ValueError if 'other' is not a Set, as '__cmp__' does."""
        if not isinstance(other, Set):
            raise ValueError("Sets can only be compared to sets")
        return self.elements == other.elements

    #----------------------------------------
    def __ne__(self, other):
        """Negation of '__eq__' (same ValueError for non-sets)."""
        return not self.__eq__(other)

    #----------------------------------------
    def __hash__(self):

        """Calculate hash code for set by xor'ing hash codes of set
        elements.  This algorithm ensures that the hash code does not
        depend on the order in which elements are added to the set.
        Computing the hash freezes the set."""

        # If set already has hashcode, the set has been frozen, so
        # code is still valid.
        if self.hashcode is not None:
            return self.hashcode

        # Combine hash codes of set elements to produce set's hash code.
        # Accumulate locally so the set is only marked frozen once the
        # full value is known.
        result = 0
        for elt in self.elements:
            result ^= hash(elt)
        self.hashcode = result
        return result

    #----------------------------------------
    def isFrozen(self):

        """Report whether set is frozen or not.  A frozen set is one
        whose hash code has already been calculated.  Frozen sets may
        not be mutated, but unfrozen sets can be."""

        return self.hashcode is not None

    #----------------------------------------
    def __copy__(self):
        """Return a shallow, unfrozen copy of the set that keeps this
        set's 'sortRepr' setting."""
        result = Set(sortRepr=self.sortRepr)
        result.elements = self.elements.copy()
        return result

    #----------------------------------------
    # Define 'copy' method as readable alias for '__copy__'.
    copy = __copy__

    #----------------------------------------
    def __deepcopy__(self, memo):
        """Return an unfrozen deep copy of the set that keeps its
        'sortRepr' setting.  The result is registered in 'memo' before
        the elements are copied so that self-references resolve."""
        result          = Set(sortRepr=self.sortRepr)
        memo[id(self)]  = result
        result.elements = deepcopy(self.elements, memo)
        return result

    #----------------------------------------
    def clear(self):
        """Remove all elements of unfrozen set."""
        if self.hashcode is not None:
            raise ValueError(Set._Frozen_Msg % "clearing")
        self.elements.clear()

    #----------------------------------------
    def unionUpdate(self, other):
        """Update set with union of its own elements and the elements
        in another set.  Returns self so it can back '|='."""

        self._binaryOpSanityCheck(other, "updating union")
        self.elements.update(other.elements)
        return self

    #----------------------------------------
    def union(self, other):
        """Create new set whose elements are the union of this set's
        and another's."""

        self._binaryOpSanityCheck(other)
        result = self.__copy__()
        result.unionUpdate(other)
        return result

    #----------------------------------------
    def intersectUpdate(self, other):
        """Update set with intersection of its own elements and the
        elements in another set.  Returns self so it can back '&='."""

        self._binaryOpSanityCheck(other, "updating intersection")
        new_elements = {}
        for elt in self.elements:
            if elt in other.elements:
                new_elements[elt] = None
        self.elements = new_elements
        return self

    #----------------------------------------
    def intersect(self, other):
        """Create new set whose elements are the intersection of this
        set's and another's."""

        self._binaryOpSanityCheck(other)
        # Iterate over the smaller set and probe the larger one.
        if len(self) <= len(other):
            little, big = self, other
        else:
            little, big = other, self
        result = Set(sortRepr=self.sortRepr)
        for elt in little.elements:
            if elt in big.elements:
                result.elements[elt] = None
        return result

    #----------------------------------------
    def symDifferenceUpdate(self, other):
        """Update set with symmetric difference of its own elements
        and the elements in another set.  A value 'x' is in the result
        if it was originally present in one or the other set, but not
        in both.  Returns self so it can back '^='."""

        self._binaryOpSanityCheck(other, "updating symmetric difference")
        self.elements = self._rawSymDifference(self.elements, other.elements)
        return self

    #----------------------------------------
    def symDifference(self, other):
        """Create new set with symmetric difference of this set's own
        elements and the elements in another set.  A value 'x' is in
        the result if it was originally present in one or the other
        set, but not in both."""

        self._binaryOpSanityCheck(other)
        result = Set(sortRepr=self.sortRepr)
        result.elements = self._rawSymDifference(self.elements, other.elements)
        return result

    #----------------------------------------
    def differenceUpdate(self, other):
        """Remove all elements of another set from this set.  Returns
        self so it can back '-='."""

        self._binaryOpSanityCheck(other, "updating difference")
        new_elements = {}
        for elt in self.elements:
            if elt not in other.elements:
                new_elements[elt] = None
        self.elements = new_elements
        return self

    #----------------------------------------
    def difference(self, other):
        """Create new set containing elements of this set that are not
        present in another set."""

        self._binaryOpSanityCheck(other)
        result = Set(sortRepr=self.sortRepr)
        for elt in self.elements:
            if elt not in other.elements:
                result.elements[elt] = None
        return result

    #----------------------------------------
    def add(self, item):
        """Add an item to a set.  This has no effect if the item is
        already present."""

        if self.hashcode is not None:
            raise ValueError(Set._Frozen_Msg % "adding an element")
        self.elements[item] = None

    #----------------------------------------
    def update(self, iterable):
        """Add all values from an iterable (such as a tuple, list,
        or file) to this set."""

        if self.hashcode is not None:
            raise ValueError(Set._Frozen_Msg % "adding an element")
        for item in iterable:
            self.elements[item] = None

    #----------------------------------------
    def remove(self, item):
        """Remove an element from a set if it is present, or raise a
        LookupError (specifically KeyError, a LookupError subclass) if
        it is not."""

        if self.hashcode is not None:
            raise ValueError(Set._Frozen_Msg % "removing an element")
        # The dictionary's own KeyError satisfies callers catching either
        # KeyError or LookupError.
        del self.elements[item]

    #----------------------------------------
    def discard(self, item):
        """Remove an element from a set if it is present, or do
        nothing if it is not."""

        if self.hashcode is not None:
            raise ValueError(Set._Frozen_Msg % "removing an element")
        try:
            del self.elements[item]
        except KeyError:
            pass

    #----------------------------------------
    def popitem(self):
        """Remove and return an arbitrary set element.  Raises
        ValueError if the set is frozen and LookupError (KeyError) if
        it is empty."""

        # Like every other mutator, refuse to change a frozen set: its
        # cached hash code would otherwise become stale.
        if self.hashcode is not None:
            raise ValueError(Set._Frozen_Msg % "removing an element")
        try:
            key, value = self.elements.popitem()
        except KeyError:
            raise KeyError("set is empty")
        return key

    #----------------------------------------
    def isSubsetOf(self, other):
        """Report whether other set contains this set."""
        if not isinstance(other, Set):
            raise ValueError("Subset tests only permitted between sets")
        for element in self.elements:
            if element not in other.elements:
                return 0
        return 1

    #----------------------------------------
    def containsAllOf(self, other):
        """Report whether this set contains every element of another
        set (i.e. whether the other set is a subset of this one)."""
        if not isinstance(other, Set):
            raise ValueError("Subset tests only permitted between sets")
        for element in other.elements:
            if element not in self.elements:
                return 0
        return 1

    #----------------------------------------
    # Arithmetic forms of operations.  Union, intersection and symmetric
    # difference are commutative, so their reflected forms can alias the
    # forward methods; difference cannot (see '__rsub__').
    __or__      = union
    __ror__     = union
    __ior__     = unionUpdate
    __and__     = intersect
    __rand__    = intersect
    __iand__    = intersectUpdate
    __xor__     = symDifference
    __rxor__    = symDifference
    __ixor__    = symDifferenceUpdate
    __sub__     = difference
    __isub__    = differenceUpdate

    #----------------------------------------
    def __rsub__(self, other):
        """Compute 'other - self'.  Difference is not commutative, so
        this must not alias 'difference' (which computes
        'self - other').  Raises ValueError if 'other' is not a Set."""
        self._binaryOpSanityCheck(other)
        return other.difference(self)

    #----------------------------------------
    # Check that the other argument to a binary operation is also a
    # set, and that this set is still mutable (if appropriate),
    # raising a ValueError if either condition is not met.
    def _binaryOpSanityCheck(self, other, updating_op=''):
        if updating_op and (self.hashcode is not None):
            raise ValueError(Set._Frozen_Msg % updating_op)
        if not isinstance(other, Set):
            raise ValueError("Binary operation only permitted between sets")

    #----------------------------------------
    # Calculate the symmetric difference between the keys in two
    # dictionaries with don't-care values.
    def _rawSymDifference(self, left, right):
        result = {}
        for elt in left:
            if elt not in right:
                result[elt] = None
        for elt in right:
            if elt not in left:
                result[elt] = None
        return result

#----------------------------------------------------------------------
# Rudimentary self-tests
#----------------------------------------------------------------------

if __name__ == "__main__":

    # Sets with more than one element are built with sortRepr=1 so the
    # expected repr() string does not depend on dictionary ordering.

    # Empty set
    red = Set()
    assert repr(red) == "Set([])", "Empty set: %s" % repr(red)

    # Unit set
    green = Set((0,))
    assert repr(green) == "Set([0])", "Unit set: %s" % repr(green)

    # 3-element set
    blue = Set([0, 1, 2], 1)
    assert repr(blue) == "Set([0, 1, 2])", "3-element set: %s" % repr(blue)

    # 2-element set with other values
    black = Set([0, 5], 1)
    assert repr(black) == "Set([0, 5])", "2-element set: %s" % repr(black)

    # All elements from all sets
    white = Set([0, 1, 2, 5], 1)
    assert repr(white) == "Set([0, 1, 2, 5])", "4-element set: %s" % repr(white)

    # Add element to empty set
    red.add(9)
    assert repr(red) == "Set([9])", "Add to empty set: %s" % repr(red)

    # Remove element from unit set
    red.remove(9)
    assert repr(red) == "Set([])", "Remove from unit set: %s" % repr(red)

    # Remove element from empty set.  remove() is documented to raise
    # LookupError; catching that also covers its KeyError subclass.
    try:
        red.remove(0)
        assert 0, "Remove element from empty set: %s" % repr(red)
    except LookupError:
        pass

    # Length
    assert len(red) == 0,   "Length of empty set"
    assert len(green) == 1, "Length of unit set"
    assert len(blue) == 3,  "Length of 3-element set"

    # Compare
    assert green == Set([0]), "Equality failed"
    assert green != Set([1]), "Inequality failed"

    # Union
    assert blue  | red   == blue,  "Union non-empty with empty"
    assert red   | blue  == blue,  "Union empty with non-empty"
    assert green | blue  == blue,  "Union non-empty with non-empty"
    assert blue  | black == white, "Enclosing union"

    # Intersection
    assert blue  & red   == red,   "Intersect non-empty with empty"
    assert red   & blue  == red,   "Intersect empty with non-empty"
    assert green & blue  == green, "Intersect non-empty with non-empty"
    assert blue  & black == green, "Enclosing intersection"

    # Symmetric difference
    assert red ^ green == green,        "Empty symdiff non-empty"
    assert green ^ blue == Set([1, 2]), "Non-empty symdiff"
    assert white ^ white == red,        "Self symdiff"

    # Difference
    assert red - green == red,           "Empty - non-empty"
    assert blue - red == blue,           "Non-empty - empty"
    assert white - black == Set([1, 2]), "Non-empty - non-empty"

    # In-place union
    orange = Set([])
    orange |= Set([1])
    assert orange == Set([1]), "In-place union"

    # In-place intersection
    orange = Set([1, 2])
    orange &= Set([2])
    assert orange == Set([2]), "In-place intersection"

    # In-place difference
    orange = Set([1, 2, 3])
    orange -= Set([2, 4])
    assert orange == Set([1, 3]), "In-place difference"

    # In-place symmetric difference
    orange = Set([1, 2, 3])
    orange ^= Set([3, 4])
    assert orange == Set([1, 2, 4]), "In-place symmetric difference"

    print("All tests passed")

--- NEW FILE: test_set.py ---
#!/usr/bin/env python

from set import Set
import unittest, operator, copy

# Shared empty set used as an operand and as an expected result by the
# tests; it is module-level, so mutating it would leak between tests.
EmptySet = Set(())

#===============================================================================

class TestBasicOps(unittest.TestCase):

    # Checks shared by every single-set fixture.  Subclasses' setUp() must
    # define:
    #   case   - description used in failure messages
    #   values - list of the distinct values the set was built from
    #   set    - the Set under test
    #   dup    - an independently built Set holding the same values
    #   length - expected number of elements
    #   repr   - expected repr() string, or None to skip checkRepr

    def checkRepr(self):
        if self.repr is not None:
            assert repr(self.set) == self.repr, "Wrong representation for " + self.case

    def checkLength(self):
        assert len(self.set) == self.length, "Wrong length for " + self.case

    def checkSelfEquality(self):
        assert self.set == self.set, "Self-equality failed for " + self.case

    def checkEquivalentEquality(self):
        assert self.set == self.dup, "Equivalent equality failed for " + self.case

    def checkCopy(self):
        assert self.set.copy() == self.dup, "Copy and comparison failed for " + self.case

    def checkSelfUnion(self):
        result = self.set | self.set
        assert result == self.dup, "Self-union failed for " + self.case

    def checkEmptyUnion(self):
        result = self.set | EmptySet
        assert result == self.dup, "Union with empty failed for " + self.case

    def checkUnionEmpty(self):
        result = EmptySet | self.set
        assert result == self.dup, "Union with empty failed for " + self.case

    def checkSelfIntersection(self):
        result = self.set & self.set
        assert result == self.dup, "Self-intersection failed for " + self.case

    def checkEmptyIntersection(self):
        result = self.set & EmptySet
        assert result == EmptySet, "Intersection with empty failed for " + self.case

    def checkIntersectionEmpty(self):
        result = EmptySet & self.set
        assert result == EmptySet, "Intersection with empty failed for " + self.case

    def checkSelfSymmetricDifference(self):
        result = self.set ^ self.set
        assert result == EmptySet, "Self-symdiff failed for " + self.case

    def checkEmptySymmetricDifference(self):
        result = self.set ^ EmptySet
        assert result == self.dup, "Symdiff with empty failed for " + self.case

    def checkSymmetricDifferenceEmpty(self):
        # Reflected form of checkEmptySymmetricDifference, mirroring
        # checkUnionEmpty / checkIntersectionEmpty.
        result = EmptySet ^ self.set
        assert result == self.dup, "Symdiff from empty failed for " + self.case

    def checkSelfDifference(self):
        result = self.set - self.set
        assert result == EmptySet, "Self-difference failed for " + self.case

    def checkEmptyDifference(self):
        result = self.set - EmptySet
        assert result == self.dup, "Difference with empty failed for " + self.case

    def checkEmptyDifferenceRev(self):
        result = EmptySet - self.set
        assert result == EmptySet, "Difference from empty failed for " + self.case

    def checkIteration(self):
        # Check both directions: every yielded item is a known value, and
        # every value is yielded exactly once.  The original version only
        # checked the first direction and passed for an empty iterator.
        items = list(self.set)
        assert len(items) == len(self.values), "Wrong item count in iteration for " + self.case
        for v in items:
            assert v in self.values, "Missing item in iteration for " + self.case
        for v in self.values:
            assert v in items, "Item not produced by iteration for " + self.case

#-------------------------------------------------------------------------------

class TestBasicOpsEmpty(TestBasicOps):
    # Basic operations on a set with no elements.
    def setUp(self):
        members = []
        self.values = members
        self.case = "empty set"
        self.set, self.dup = Set(members, 1), Set(members, 1)
        self.length = len(members)
        self.repr = "Set([])"

#-------------------------------------------------------------------------------

class TestBasicOpsSingleton(TestBasicOps):
    # Basic operations on a set holding a single number, plus membership
    # checks.

    def setUp(self):
        self.case   = "unit set (number)"
        self.values = [3]
        self.set    = Set(self.values, 1)
        self.dup    = Set(self.values, 1)
        self.length = 1
        self.repr   = "Set([3])"

    def checkIn(self):
        assert 3 in self.set, "Membership for unit set"

    def checkNotIn(self):
        assert 2 not in self.set, "Non-membership for unit set"

#-------------------------------------------------------------------------------

class TestBasicOpsTuple(TestBasicOps):
    # Basic operations on a set whose single element is a tuple, plus
    # membership checks.

    def setUp(self):
        self.case   = "unit set (tuple)"
        self.values = [(0, "zero")]
        self.set    = Set(self.values, 1)
        self.dup    = Set(self.values, 1)
        self.length = 1
        self.repr   = "Set([(0, 'zero')])"

    def checkIn(self):
        assert (0, "zero") in self.set, "Membership for tuple set"

    def checkNotIn(self):
        assert 9 not in self.set, "Non-membership for tuple set"

#-------------------------------------------------------------------------------

class TestBasicOpsTriple(TestBasicOps):
    # Three values of different types (int, str, builtin function); repr
    # is None, so checkRepr is skipped for this fixture.
    def setUp(self):
        members = [0, "zero", operator.add]
        self.values = members
        self.case = "triple set"
        self.set, self.dup = Set(members, 1), Set(members, 1)
        self.length = len(members)
        self.repr = None

#===============================================================================

class TestBinaryOps(unittest.TestCase):
    # Non-mutating binary operators applied to the fixture {2, 4, 6} and a
    # second set that is a subset, a superset, a partial overlap, or
    # disjoint.

    def setUp(self):
        self.set = Set((2, 4, 6))

    def checkUnionSubset(self):
        result = self.set | Set([2])
        assert result == Set((2, 4, 6)), "Subset union"

    def checkUnionSuperset(self):
        result = self.set | Set([2, 4, 6, 8])
        assert result == Set([2, 4, 6, 8]), "Superset union"

    def checkUnionOverlap(self):
        result = self.set | Set([3, 4, 5])
        assert result == Set([2, 3, 4, 5, 6]), "Overlapping union"

    def checkUnionNonOverlap(self):
        result = self.set | Set([8])
        assert result == Set([2, 4, 6, 8]), "Non-overlapping union"

    def checkIntersectionSubset(self):
        result = self.set & Set((2, 4))
        assert result == Set((2, 4)), "Subset intersection"

    def checkIntersectionSuperset(self):
        result = self.set & Set([2, 4, 6, 8])
        assert result == Set([2, 4, 6]), "Superset intersection"

    def checkIntersectionOverlap(self):
        result = self.set & Set([3, 4, 5])
        assert result == Set([4]), "Overlapping intersection"

    def checkIntersectionNonOverlap(self):
        result = self.set & Set([8])
        assert result == EmptySet, "Non-overlapping intersection"

    def checkSymDifferenceSubset(self):
        result = self.set ^ Set((2, 4))
        assert result == Set([6]), "Subset symmetric difference"

    def checkSymDifferenceSuperset(self):
        result = self.set ^ Set((2, 4, 6, 8))
        assert result == Set([8]), "Superset symmetric difference"

    def checkSymDifferenceOverlap(self):
        result = self.set ^ Set((3, 4, 5))
        assert result == Set([2, 3, 5, 6]), "Overlapping symmetric difference"

    def checkSymDifferenceNonOverlap(self):
        result = self.set ^ Set([8])
        assert result == Set([2, 4, 6, 8]), "Non-overlapping symmetric difference"

    def checkDifferenceSubset(self):
        result = self.set - Set((2, 4))
        assert result == Set([6]), "Subset difference"

    def checkDifferenceSuperset(self):
        result = self.set - Set((2, 4, 6, 8))
        assert result == EmptySet, "Superset difference"

    def checkDifferenceOverlap(self):
        result = self.set - Set((3, 4, 5))
        assert result == Set([2, 6]), "Overlapping difference"

    def checkDifferenceNonOverlap(self):
        result = self.set - Set([8])
        assert result == Set([2, 4, 6]), "Non-overlapping difference"

    def checkOperandsUnchanged(self):
        # The non-mutating operators must leave both operands intact.
        other = Set([3, 4, 5])
        unused = self.set | other
        unused = self.set & other
        unused = self.set ^ other
        unused = self.set - other
        assert self.set == Set((2, 4, 6)), "Binary operation modified left operand"
        assert other == Set([3, 4, 5]), "Binary operation modified right operand"

#===============================================================================

class TestUpdateOps(unittest.TestCase):
    # In-place operators applied to the fixture {2, 4, 6} with a second set
    # that is a subset, a superset, a partial overlap, or disjoint.

    def setUp(self):
        self.set = Set((2, 4, 6))

    def checkUnionSubset(self):
        self.set |= Set([2])
        assert self.set == Set((2, 4, 6)), "Subset union"

    def checkUnionSuperset(self):
        self.set |= Set([2, 4, 6, 8])
        assert self.set == Set([2, 4, 6, 8]), "Superset union"

    def checkUnionOverlap(self):
        self.set |= Set([3, 4, 5])
        assert self.set == Set([2, 3, 4, 5, 6]), "Overlapping union"

    def checkUnionNonOverlap(self):
        self.set |= Set([8])
        assert self.set == Set([2, 4, 6, 8]), "Non-overlapping union"

    def checkIntersectionSubset(self):
        self.set &= Set((2, 4))
        assert self.set == Set((2, 4)), "Subset intersection"

    def checkIntersectionSuperset(self):
        self.set &= Set([2, 4, 6, 8])
        assert self.set == Set([2, 4, 6]), "Superset intersection"

    def checkIntersectionOverlap(self):
        self.set &= Set([3, 4, 5])
        assert self.set == Set([4]), "Overlapping intersection"

    def checkIntersectionNonOverlap(self):
        self.set &= Set([8])
        assert self.set == EmptySet, "Non-overlapping intersection"

    def checkSymDifferenceSubset(self):
        self.set ^= Set((2, 4))
        assert self.set == Set([6]), "Subset symmetric difference"

    def checkSymDifferenceSuperset(self):
        self.set ^= Set((2, 4, 6, 8))
        assert self.set == Set([8]), "Superset symmetric difference"

    def checkSymDifferenceOverlap(self):
        self.set ^= Set((3, 4, 5))
        assert self.set == Set([2, 3, 5, 6]), "Overlapping symmetric difference"

    def checkSymDifferenceNonOverlap(self):
        self.set ^= Set([8])
        assert self.set == Set([2, 4, 6, 8]), "Non-overlapping symmetric difference"

    def checkDifferenceSubset(self):
        self.set -= Set((2, 4))
        assert self.set == Set([6]), "Subset difference"

    def checkDifferenceSuperset(self):
        self.set -= Set((2, 4, 6, 8))
        assert self.set == EmptySet, "Superset difference"

    def checkDifferenceOverlap(self):
        self.set -= Set((3, 4, 5))
        assert self.set == Set([2, 6]), "Overlapping difference"

    def checkDifferenceNonOverlap(self):
        self.set -= Set([8])
        assert self.set == Set([2, 4, 6]), "Non-overlapping difference"

    def checkInPlaceKeepsIdentity(self):
        # In-place operators must update the existing object rather than
        # rebinding the name to a new set.
        original = self.set
        self.set |= Set([8])
        self.set &= Set([2, 8])
        self.set ^= Set([3])
        self.set -= Set([2])
        assert self.set is original, "In-place operator replaced the set"
        assert self.set == Set([3, 8]), "Chained in-place operators"

#===============================================================================

class TestMutate(unittest.TestCase):
    # Element-level mutators (add, remove, discard, clear, popitem, update)
    # applied to the fixture {"a", "b", "c"}.

    def setUp(self):
        self.values = ["a", "b", "c"]
        self.set = Set(self.values)

    def checkAddPresent(self):
        self.set.add("c")
        assert self.set == Set(("a", "b", "c")), "Adding present element"

    def checkAddAbsent(self):
        self.set.add("d")
        assert self.set == Set(("a", "b", "c", "d")), "Adding missing element"

    def checkAddUntilFull(self):
        tmp = Set()
        expectedLen = 0
        for v in self.values:
            tmp.add(v)
            expectedLen += 1
            assert len(tmp) == expectedLen, "Adding values one by one to temporary"
        assert tmp == self.set, "Adding values one by one"

    def checkRemovePresent(self):
        self.set.remove("b")
        assert self.set == Set(("a", "c")), "Removing present element"

    def checkRemoveAbsent(self):
        try:
            self.set.remove("d")
            assert 0, "Removing missing element"
        except LookupError:
            pass

    def checkRemoveUntilEmpty(self):
        expectedLen = len(self.set)
        for v in self.values:
            self.set.remove(v)
            expectedLen -= 1
            assert len(self.set) == expectedLen, "Removing values one by one"
        assert self.set == EmptySet, "Removing all values"

    def checkDiscardPresent(self):
        self.set.discard("c")
        assert self.set == Set(("a", "b")), "Discarding present element"

    def checkDiscardAbsent(self):
        self.set.discard("d")
        assert self.set == Set(("a", "b", "c")), "Discarding missing element"

    def checkClear(self):
        self.set.clear()
        assert len(self.set) == 0, "Clearing set"

    def checkPopitem(self):
        popped = {}
        while self.set:
            popped[self.set.popitem()] = None
        assert len(popped) == len(self.values), "Popping items"
        for v in self.values:
            assert v in popped, "Popping items"

    def checkPopitemEmpty(self):
        self.set.clear()
        try:
            self.set.popitem()
            assert 0, "Popping from empty set"
        except LookupError:
            pass

    def checkUpdateEmptyTuple(self):
        self.set.update(())
        assert self.set == Set(self.values), "Updating with empty tuple"

    def checkUpdateUnitTupleOverlap(self):
        self.set.update(("a",))
        assert self.set == Set(self.values), "Updating with overlapping unit tuple"

    def checkUpdateUnitTupleNonOverlap(self):
        # A one-element tuple whose element is not already present, as the
        # test name and message say.
        self.set.update(("z",))
        assert self.set == Set(self.values + ["z"]), "Updating with non-overlapping unit tuple"

    def checkUpdatePartialOverlap(self):
        # Keeps the coverage the non-overlap test used to provide by
        # accident: one element present, one absent.
        self.set.update(("a", "z"))
        assert self.set == Set(self.values + ["z"]), "Updating with partially overlapping tuple"

#===============================================================================

class TestFreeze(unittest.TestCase):
    # Hashing a set freezes it; every mutating operation must then raise
    # ValueError.

    def setUp(self):
        self.values = [0, 1]
        self.set = Set(self.values)
        hash(self.set)

    def checkFreezeAfterHash(self):
        assert self.set.isFrozen(), "Set not frozen after hashing"

    def checkClearAfterFreeze(self):
        try:
            self.set.clear()
            assert 0, "Clear disregards freezing"
        except ValueError:
            pass

    def checkUnionAfterFreeze(self):
        try:
            self.set |= Set([2])
            assert 0, "Union update disregards freezing"
        except ValueError:
            pass

    def checkIntersectionAfterFreeze(self):
        try:
            self.set &= Set([2])
            assert 0, "Intersection update disregards freezing"
        except ValueError:
            pass

    def checkSymDifferenceAfterFreeze(self):
        try:
            self.set ^= Set([2])
            assert 0, "Symmetric difference update disregards freezing"
        except ValueError:
            pass

    def checkDifferenceAfterFreeze(self):
        try:
            self.set -= Set([2])
            assert 0, "Difference update disregards freezing"
        except ValueError:
            pass

    def checkAddAfterFreeze(self):
        try:
            self.set.add(4)
            assert 0, "Add disregards freezing"
        except ValueError:
            pass

    def checkUpdateAfterFreeze(self):
        try:
            self.set.update([4, 5])
            assert 0, "Update disregards freezing"
        except ValueError:
            pass

    def checkRemoveAfterFreeze(self):
        try:
            self.set.remove(0)
            assert 0, "Remove disregards freezing"
        except ValueError:
            pass
        assert self.set == Set(self.values), "Frozen set changed by remove"

    def checkDiscardAfterFreeze(self):
        try:
            self.set.discard(0)
            assert 0, "Discard disregards freezing"
        except ValueError:
            pass
        assert self.set == Set(self.values), "Frozen set changed by discard"

    def checkPopitemAfterFreeze(self):
        # popitem also mutates, so it must respect freezing like the
        # other mutators; otherwise the cached hash goes stale.
        try:
            self.set.popitem()
            assert 0, "Popitem disregards freezing"
        except ValueError:
            pass
        assert self.set == Set(self.values), "Frozen set changed by popitem"

#===============================================================================

class TestSubsets(unittest.TestCase):

    # Subclasses' setUp() must define 'left', 'right', 'name' and 'cases'.
    # 'cases' contains "<" when left.isSubsetOf(right) should be true and
    # ">" when left.containsAllOf(right) should be true.

    def checkIsSubsetOf(self):
        outcome = self.left.isSubsetOf(self.right)
        if "<" not in self.cases:
            assert not outcome, "non-subset: " + self.name
        else:
            assert outcome, "subset: " + self.name

    def checkContainsAllOf(self):
        outcome = self.left.containsAllOf(self.right)
        if ">" not in self.cases:
            assert not outcome, "not contains all: " + self.name
        else:
            assert outcome, "contains all: " + self.name

#-------------------------------------------------------------------------------

class TestSubsetEqualEmpty(TestSubsets):
    # Two empty sets: each is a subset of the other.
    def setUp(self):
        self.name = "both empty"
        self.cases = "<>"
        self.left, self.right = Set(), Set()

#-------------------------------------------------------------------------------

class TestSubsetEqualNonEmpty(TestSubsets):
    # Two equal non-empty sets: each is a subset of the other.
    def setUp(self):
        self.name = "equal pair"
        self.cases = "<>"
        self.left, self.right = Set([1, 2]), Set([1, 2])

#-------------------------------------------------------------------------------

class TestSubsetEmptyNonEmpty(TestSubsets):
    # The empty set is a subset of a non-empty set, but not vice versa.
    def setUp(self):
        self.name = "one empty, one non-empty"
        self.cases = "<"
        self.left, self.right = Set(), Set([1, 2])

#-------------------------------------------------------------------------------

class TestSubsetPartial(TestSubsets):
    # A proper non-empty subset on the left.
    def setUp(self):
        self.name = "one a non-empty subset of other"
        self.cases = "<"
        self.left, self.right = Set([1]), Set([1, 2])

#-------------------------------------------------------------------------------

class TestSubsetNonOverlap(TestSubsets):
    # Disjoint non-empty sets: neither contains the other.
    def setUp(self):
        self.name = "neither empty, neither contains"
        self.cases = ""
        self.left, self.right = Set([1]), Set([2])

#===============================================================================

class TestOnlySetsInBinaryOps(unittest.TestCase):

    # Subclasses' setUp() must define 'set' (a Set) and 'other' (a
    # non-set).  Every comparison, binary operator and subset test that
    # mixes the two must raise ValueError.

    def checkCmp(self):
        try:
            self.other < self.set
            assert 0, "Comparison with non-set on left"
        except ValueError:
            pass
        try:
            self.set >= self.other
            assert 0, "Comparison with non-set on right"
        except ValueError:
            pass

    def checkUnionUpdate(self):
        try:
            self.set |= self.other
            assert 0, "Union update with non-set"
        except ValueError:
            pass

    def checkUnion(self):
        try:
            self.other | self.set
            assert 0, "Union with non-set on left"
        except ValueError:
            pass
        try:
            self.set | self.other
            assert 0, "Union with non-set on right"
        except ValueError:
            pass

    def checkIntersectionUpdate(self):
        try:
            self.set &= self.other
            assert 0, "Intersection update with non-set"
        except ValueError:
            pass

    def checkIntersection(self):
        try:
            self.other & self.set
            assert 0, "Intersection with non-set on left"
        except ValueError:
            pass
        try:
            self.set & self.other
            assert 0, "Intersection with non-set on right"
        except ValueError:
            pass

    def checkSymDifferenceUpdate(self):
        try:
            self.set ^= self.other
            assert 0, "Symmetric difference update with non-set"
        except ValueError:
            pass

    def checkSymDifference(self):
        try:
            self.other ^ self.set
            assert 0, "Symmetric difference with non-set on left"
        except ValueError:
            pass
        try:
            self.set ^ self.other
            assert 0, "Symmetric difference with non-set on right"
        except ValueError:
            pass

    def checkDifferenceUpdate(self):
        try:
            self.set -= self.other
            assert 0, "Difference update with non-set"
        except ValueError:
            pass

    def checkDifference(self):
        try:
            self.other - self.set
            assert 0, "Difference with non-set on left"
        except ValueError:
            pass
        try:
            self.set - self.other
            assert 0, "Difference with non-set on right"
        except ValueError:
            pass

    def checkSubsetTests(self):
        try:
            self.set.isSubsetOf(self.other)
            assert 0, "Subset test with non-set"
        except ValueError:
            pass
        try:
            self.set.containsAllOf(self.other)
            assert 0, "Containment test with non-set"
        except ValueError:
            pass

#-------------------------------------------------------------------------------

class TestOnlySetsNumeric(TestOnlySetsInBinaryOps):
    """Binary-operator checks where the non-set operand is an integer."""
    def setUp(self):
        self.other = 19
        self.set   = Set([1, 2, 3])

#-------------------------------------------------------------------------------

class TestOnlySetsDict(TestOnlySetsInBinaryOps):
    """Binary-operator checks where the non-set operand is a dictionary."""
    def setUp(self):
        self.other = {1: 2, 3: 4}
        self.set   = Set([1, 2, 3])

#-------------------------------------------------------------------------------

class TestOnlySetsOperator(TestOnlySetsInBinaryOps):
    """Binary-operator checks where the non-set operand is a built-in
    function object (operator.add)."""
    def setUp(self):
        self.other = operator.add
        self.set   = Set([1, 2, 3])

#===============================================================================

class TestCopying(unittest.TestCase):
    """Shared shallow- and deep-copy checks.  Subclasses provide the set
    under test as 'self.set' in their setUp()."""

    def _sortedItems(self, container):
        """Return the elements of 'container' as a freshly sorted list, so
        that a set and its copy can be compared element by element."""
        # NOTE(review): list.sort() on mixed element types (for example
        # TestCopyingTriple's "zero", 0, None) raises TypeError on Python 3.
        items = list(container)
        items.sort()
        return items

    def checkCopy(self):
        """A shallow copy must contain the very same element objects."""
        dup_list = self._sortedItems(self.set.copy())
        set_list = self._sortedItems(self.set)
        assert len(dup_list) == len(set_list), "Unequal lengths after copy"
        # zip() stops at the shorter list, so the length check must come first.
        for dup_item, set_item in zip(dup_list, set_list):
            assert dup_item is set_item, "Non-identical items after copy"

    def checkDeepCopy(self):
        """A deep copy must contain elements equal to the originals."""
        dup_list = self._sortedItems(copy.deepcopy(self.set))
        set_list = self._sortedItems(self.set)
        assert len(dup_list) == len(set_list), "Unequal lengths after deep copy"
        for dup_item, set_item in zip(dup_list, set_list):
            assert dup_item == set_item, "Unequal items after deep copy"

#-------------------------------------------------------------------------------

class TestCopyingEmpty(TestCopying):
    """Copy checks on a set with no elements."""
    def setUp(self):
        self.set = Set(seq=None)

#-------------------------------------------------------------------------------

class TestCopyingSingleton(TestCopying):
    """Copy checks on a set holding a single string."""
    def setUp(self):
        self.set = Set(("hello",))

#-------------------------------------------------------------------------------

class TestCopyingTriple(TestCopying):
    """Copy checks on a set whose three elements have different types."""
    def setUp(self):
        self.set = Set(("zero", 0, None))

#-------------------------------------------------------------------------------

class TestCopyingTuple(TestCopying):
    """Copy checks on a set whose only element is a flat tuple."""
    def setUp(self):
        element = (1, 2)
        self.set = Set([element])

#-------------------------------------------------------------------------------

class TestCopyingNested(TestCopying):
    """Copy checks on a set whose only element is a tuple of tuples."""
    def setUp(self):
        element = ((1, 2), (3, 4))
        self.set = Set([element])

#===============================================================================

def makeAllTests():
    """Build and return a TestSuite holding every 'check*' method of each
    test class in this module, in the listed order.

    A single TestLoader configured with the 'check' prefix is used instead
    of unittest.makeSuite, which was deprecated in Python 3.11 and removed
    in 3.13."""
    loader = unittest.TestLoader()
    loader.testMethodPrefix = 'check'
    testClasses = (
        TestBasicOpsEmpty,
        TestBasicOpsSingleton,
        TestBasicOpsTuple,
        TestBasicOpsTriple,
        TestBinaryOps,
        TestUpdateOps,
        TestMutate,
        TestFreeze,
        TestSubsetEqualEmpty,
        TestSubsetEqualNonEmpty,
        TestSubsetEmptyNonEmpty,
        TestSubsetPartial,
        TestSubsetNonOverlap,
        TestOnlySetsNumeric,
        TestOnlySetsDict,
        TestOnlySetsOperator,
        TestCopyingEmpty,
        TestCopyingSingleton,
        TestCopyingTriple,
        TestCopyingTuple,
        TestCopyingNested,
    )
    suite = unittest.TestSuite()
    for testClass in testClasses:
        suite.addTest(loader.loadTestsFromTestCase(testClass))
    return suite

#-------------------------------------------------------------------------------

if __name__ == "__main__":
    # Run the suite assembled by makeAllTests() when executed as a script.
    unittest.main(module="__main__", defaultTest="makeAllTests")