[Python-checkins] bpo-32513: Make it easier to override dunders in dataclasses. (GH-5366)

Eric V. Smith webhook-mailer at python.org
Sat Jan 27 19:07:43 EST 2018


https://github.com/python/cpython/commit/ea8fc52e75363276db23c6a8d7a689f79efce4f9
commit: ea8fc52e75363276db23c6a8d7a689f79efce4f9
branch: master
author: Eric V. Smith <ericvsmith at users.noreply.github.com>
committer: GitHub <noreply at github.com>
date: 2018-01-27T19:07:40-05:00
summary:

bpo-32513: Make it easier to override dunders in dataclasses. (GH-5366)

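Class authors no longer need to specify repr=False if they want to provide a custom __repr__ for dataclasses. The same thing applies for the other dunder methods that the dataclass decorator adds. If dataclass finds that a dunder method is defined in the class, it will not overwrite it. (The exceptions are __setattr__/__delattr__ when frozen=True and the ordering methods when order=True: defining those in the class raises TypeError instead.)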

files:
A Misc/NEWS.d/next/Library/2018-01-27-11-20-16.bpo-32513.ak-iD2.rst
M Lib/dataclasses.py
M Lib/test/test_dataclasses.py

diff --git a/Lib/dataclasses.py b/Lib/dataclasses.py
index 7d30da1aacff..fb279cd30517 100644
--- a/Lib/dataclasses.py
+++ b/Lib/dataclasses.py
@@ -18,6 +18,142 @@
            'is_dataclass',
            ]
 
+# Conditions for adding methods.  The boxes indicate what action the
+#  dataclass decorator takes.  For all of these tables, when I talk
+#  about init=, repr=, eq=, order=, hash=, or frozen=, I'm referring
+#  to the arguments to the @dataclass decorator.  When checking if a
+#  dunder method already exists, I mean check for an entry in the
+#  class's __dict__.  I never check to see if an attribute is defined
+#  in a base class.
+
+# Key:
+# +=========+=========================================+
+# + Value   | Meaning                                 |
+# +=========+=========================================+
+# | <blank> | No action: no method is added.          |
+# +---------+-----------------------------------------+
+# | add     | Generated method is added.              |
+# +---------+-----------------------------------------+
+# | add*    | Generated method is added only if the   |
+# |         |  existing attribute is None and if the  |
+# |         |  user supplied a __eq__ method in the   |
+# |         |  class definition.                      |
+# +---------+-----------------------------------------+
+# | raise   | TypeError is raised.                    |
+# +---------+-----------------------------------------+
+# | None    | Attribute is set to None.               |
+# +=========+=========================================+
+
+# __init__
+#
+#   +--- init= parameter
+#   |
+#   v     |       |       |
+#         |  no   |  yes  |  <--- class has __init__ in __dict__?
+# +=======+=======+=======+
+# | False |       |       |
+# +-------+-------+-------+
+# | True  | add   |       |  <- the default
+# +=======+=======+=======+
+
+# __repr__
+#
+#    +--- repr= parameter
+#    |
+#    v    |       |       |
+#         |  no   |  yes  |  <--- class has __repr__ in __dict__?
+# +=======+=======+=======+
+# | False |       |       |
+# +-------+-------+-------+
+# | True  | add   |       |  <- the default
+# +=======+=======+=======+
+
+
+# __setattr__
+# __delattr__
+#
+#    +--- frozen= parameter
+#    |
+#    v    |       |       |
+#         |  no   |  yes  |  <--- class has __setattr__ or __delattr__ in __dict__?
+# +=======+=======+=======+
+# | False |       |       |  <- the default
+# +-------+-------+-------+
+# | True  | add   | raise |
+# +=======+=======+=======+
+# Raise because not adding these methods would break the "frozen-ness"
+#  of the class.
+
+# __eq__
+#
+#    +--- eq= parameter
+#    |
+#    v    |       |       |
+#         |  no   |  yes  |  <--- class has __eq__ in __dict__?
+# +=======+=======+=======+
+# | False |       |       |
+# +-------+-------+-------+
+# | True  | add   |       |  <- the default
+# +=======+=======+=======+
+
+# __lt__
+# __le__
+# __gt__
+# __ge__
+#
+#    +--- order= parameter
+#    |
+#    v    |       |       |
+#         |  no   |  yes  |  <--- class has any comparison method in __dict__?
+# +=======+=======+=======+
+# | False |       |       |  <- the default
+# +-------+-------+-------+
+# | True  | add   | raise |
+# +=======+=======+=======+
+# Raise because to allow this case would interfere with using
+#  functools.total_ordering.
+
+# __hash__
+
+#      +------------------- hash= parameter
+#      |       +----------- eq= parameter
+#      |       |       +--- frozen= parameter
+#      |       |       |
+#      v       v       v    |        |        |
+#                           |   no   |  yes   |  <--- class has __hash__ in __dict__?
+# +=========+=======+=======+========+========+
+# | 1 None  | False | False |        |        | No __eq__, use the base class __hash__
+# +---------+-------+-------+--------+--------+
+# | 2 None  | False | True  |        |        | No __eq__, use the base class __hash__
+# +---------+-------+-------+--------+--------+
+# | 3 None  | True  | False | None   |        | <-- the default, not hashable
+# +---------+-------+-------+--------+--------+
+# | 4 None  | True  | True  | add    | add*   | Frozen, so hashable
+# +---------+-------+-------+--------+--------+
+# | 5 False | False | False |        |        |
+# +---------+-------+-------+--------+--------+
+# | 6 False | False | True  |        |        |
+# +---------+-------+-------+--------+--------+
+# | 7 False | True  | False |        |        |
+# +---------+-------+-------+--------+--------+
+# | 8 False | True  | True  |        |        |
+# +---------+-------+-------+--------+--------+
+# | 9 True  | False | False | add    | add*   | Has no __eq__, but hashable
+# +---------+-------+-------+--------+--------+
+# |10 True  | False | True  | add    | add*   | Has no __eq__, but hashable
+# +---------+-------+-------+--------+--------+
+# |11 True  | True  | False | add    | add*   | Not frozen, but hashable
+# +---------+-------+-------+--------+--------+
+# |12 True  | True  | True  | add    | add*   | Frozen, so hashable
+# +=========+=======+=======+========+========+
+# For boxes that are blank, __hash__ is untouched and therefore
+#  inherited from the base class.  If the base is object, then
+#  id-based hashing is used.
+# Note that a class may already have __hash__=None if it specified an
+#  __eq__ method in the class body (not one that was created by
+#  @dataclass).
+
+
 # Raised when an attempt is made to modify a frozen class.
 class FrozenInstanceError(AttributeError): pass
 
@@ -143,13 +279,13 @@ def _tuple_str(obj_name, fields):
     #  return "(self.x,self.y)".
 
     # Special case for the 0-tuple.
-    if len(fields) == 0:
+    if not fields:
         return '()'
     # Note the trailing comma, needed if this turns out to be a 1-tuple.
     return f'({",".join([f"{obj_name}.{f.name}" for f in fields])},)'
 
 
-def _create_fn(name, args, body, globals=None, locals=None,
+def _create_fn(name, args, body, *, globals=None, locals=None,
                return_type=MISSING):
     # Note that we mutate locals when exec() is called. Caller beware!
     if locals is None:
@@ -287,7 +423,7 @@ def _init_fn(fields, frozen, has_post_init, self_name):
         body_lines += [f'{self_name}.{_POST_INIT_NAME}({params_str})']
 
     # If no body lines, use 'pass'.
-    if len(body_lines) == 0:
+    if not body_lines:
         body_lines = ['pass']
 
     locals = {f'_type_{f.name}': f.type for f in fields}
@@ -329,32 +465,6 @@ def _cmp_fn(name, op, self_tuple, other_tuple):
                         'return NotImplemented'])
 
 
-def _set_eq_fns(cls, fields):
-    # Create and set the equality comparison methods on cls.
-    # Pre-compute self_tuple and other_tuple, then re-use them for
-    #  each function.
-    self_tuple = _tuple_str('self', fields)
-    other_tuple = _tuple_str('other', fields)
-    for name, op in [('__eq__', '=='),
-                     ('__ne__', '!='),
-                     ]:
-        _set_attribute(cls, name, _cmp_fn(name, op, self_tuple, other_tuple))
-
-
-def _set_order_fns(cls, fields):
-    # Create and set the ordering methods on cls.
-    # Pre-compute self_tuple and other_tuple, then re-use them for
-    #  each function.
-    self_tuple = _tuple_str('self', fields)
-    other_tuple = _tuple_str('other', fields)
-    for name, op in [('__lt__', '<'),
-                     ('__le__', '<='),
-                     ('__gt__', '>'),
-                     ('__ge__', '>='),
-                     ]:
-        _set_attribute(cls, name, _cmp_fn(name, op, self_tuple, other_tuple))
-
-
 def _hash_fn(fields):
     self_tuple = _tuple_str('self', fields)
     return _create_fn('__hash__',
@@ -431,20 +541,20 @@ def _find_fields(cls):
     #  a Field(), then it contains additional info beyond (and
     #  possibly including) the actual default value.  Pseudo-fields
     #  ClassVars and InitVars are included, despite the fact that
-    #  they're not real fields.  That's deal with later.
+    #  they're not real fields.  That's dealt with later.
 
     annotations = getattr(cls, '__annotations__', {})
-
     return [_get_field(cls, a_name, a_type)
             for a_name, a_type in annotations.items()]
 
 
-def _set_attribute(cls, name, value):
-    # Raise TypeError if an attribute by this name already exists.
+def _set_new_attribute(cls, name, value):
+    # Never overwrites an existing attribute.  Returns True if the
+    #  attribute already exists.
     if name in cls.__dict__:
-        raise TypeError(f'Cannot overwrite attribute {name} '
-                        f'in {cls.__name__}')
+        return True
     setattr(cls, name, value)
+    return False
 
 
 def _process_class(cls, repr, eq, order, hash, init, frozen):
@@ -495,6 +605,9 @@ def _process_class(cls, repr, eq, order, hash, init, frozen):
     #  be inherited down.
     is_frozen = frozen or cls.__setattr__ is _frozen_setattr
 
+    # Did the class body define __eq__ and leave its own __hash__ as None?
+    auto_hash_test = '__eq__' in cls.__dict__ and cls.__dict__.get('__hash__', MISSING) is None
+
     # If we're generating ordering methods, we must be generating
     #  the eq methods.
     if order and not eq:
@@ -505,62 +618,91 @@ def _process_class(cls, repr, eq, order, hash, init, frozen):
         has_post_init = hasattr(cls, _POST_INIT_NAME)
 
         # Include InitVars and regular fields (so, not ClassVars).
-        _set_attribute(cls, '__init__',
-                       _init_fn(list(filter(lambda f: f._field_type
-                                              in (_FIELD, _FIELD_INITVAR),
-                                            fields.values())),
-                                is_frozen,
-                                has_post_init,
-                                # The name to use for the "self" param
-                                #  in __init__.  Use "self" if possible.
-                                '__dataclass_self__' if 'self' in fields
-                                    else 'self',
-                                ))
+        flds = [f for f in fields.values()
+                if f._field_type in (_FIELD, _FIELD_INITVAR)]
+        _set_new_attribute(cls, '__init__',
+                           _init_fn(flds,
+                                    is_frozen,
+                                    has_post_init,
+                                    # The name to use for the "self" param
+                                    #  in __init__.  Use "self" if possible.
+                                    '__dataclass_self__' if 'self' in fields
+                                            else 'self',
+                          ))
 
     # Get the fields as a list, and include only real fields.  This is
     #  used in all of the following methods.
-    field_list = list(filter(lambda f: f._field_type is _FIELD,
-                             fields.values()))
+    field_list = [f for f in fields.values() if f._field_type is _FIELD]
 
     if repr:
-        _set_attribute(cls, '__repr__',
-                       _repr_fn(list(filter(lambda f: f.repr, field_list))))
-
-    if is_frozen:
-        _set_attribute(cls, '__setattr__', _frozen_setattr)
-        _set_attribute(cls, '__delattr__', _frozen_delattr)
-
-    generate_hash = False
-    if hash is None:
-        if eq and frozen:
-            # Generate a hash function.
-            generate_hash = True
-        elif eq and not frozen:
-            # Not hashable.
-            _set_attribute(cls, '__hash__', None)
-        elif not eq:
-            # Otherwise, use the base class definition of hash().  That is,
-            #  don't set anything on this class.
-            pass
-        else:
-            assert "can't get here"
-    else:
-        generate_hash = hash
-    if generate_hash:
-        _set_attribute(cls, '__hash__',
-                       _hash_fn(list(filter(lambda f: f.compare
-                                                      if f.hash is None
-                                                      else f.hash,
-                                            field_list))))
+        flds = [f for f in field_list if f.repr]
+        _set_new_attribute(cls, '__repr__', _repr_fn(flds))
 
     if eq:
-        # Create and __eq__ and __ne__ methods.
-        _set_eq_fns(cls, list(filter(lambda f: f.compare, field_list)))
+        # Create _eq__ method.  There's no need for a __ne__ method,
+        #  since python will call __eq__ and negate it.
+        flds = [f for f in field_list if f.compare]
+        self_tuple = _tuple_str('self', flds)
+        other_tuple = _tuple_str('other', flds)
+        _set_new_attribute(cls, '__eq__',
+                           _cmp_fn('__eq__', '==',
+                                   self_tuple, other_tuple))
 
     if order:
-        # Create and __lt__, __le__, __gt__, and __ge__ methods.
-        # Create and set the comparison functions.
-        _set_order_fns(cls, list(filter(lambda f: f.compare, field_list)))
+        # Create and set the ordering methods.
+        flds = [f for f in field_list if f.compare]
+        self_tuple = _tuple_str('self', flds)
+        other_tuple = _tuple_str('other', flds)
+        for name, op in [('__lt__', '<'),
+                         ('__le__', '<='),
+                         ('__gt__', '>'),
+                         ('__ge__', '>='),
+                         ]:
+            if _set_new_attribute(cls, name,
+                                  _cmp_fn(name, op, self_tuple, other_tuple)):
+                raise TypeError(f'Cannot overwrite attribute {name} '
+                                f'in {cls.__name__}. Consider using '
+                                'functools.total_ordering')
+
+    if is_frozen:
+        for name, fn in [('__setattr__', _frozen_setattr),
+                         ('__delattr__', _frozen_delattr)]:
+            if _set_new_attribute(cls, name, fn):
+                raise TypeError(f'Cannot overwrite attribute {name} '
+                                f'in {cls.__name__}')
+
+    # Decide if/how we're going to create a hash function.
+    # TODO: Move this table to module scope, so it's not recreated
+    #  all the time.
+    generate_hash = {(None,  False, False): ('',     ''),
+                     (None,  False, True):  ('',     ''),
+                     (None,  True,  False): ('none', ''),
+                     (None,  True,  True):  ('fn',   'fn-x'),
+                     (False, False, False): ('',     ''),
+                     (False, False, True):  ('',     ''),
+                     (False, True,  False): ('',     ''),
+                     (False, True,  True):  ('',     ''),
+                     (True,  False, False): ('fn',   'fn-x'),
+                     (True,  False, True):  ('fn',   'fn-x'),
+                     (True,  True,  False): ('fn',   'fn-x'),
+                     (True,  True,  True):  ('fn',   'fn-x'),
+                     }[None if hash is None else bool(hash),   # Force bool() if not None.
+                       bool(eq),
+                       bool(frozen)]['__hash__' in cls.__dict__]
+    # No need to call _set_new_attribute here, since we already know if
+    #  we're overwriting a __hash__ or not.
+    if generate_hash == '':
+        # Do nothing.
+        pass
+    elif generate_hash == 'none':
+        cls.__hash__ = None
+    elif generate_hash in ('fn', 'fn-x'):
+        if generate_hash == 'fn' or auto_hash_test:
+            flds = [f for f in field_list
+                    if (f.compare if f.hash is None else f.hash)]
+            cls.__hash__ = _hash_fn(flds)
+    else:
+        assert False, f"can't get here: {generate_hash}"
 
     if not getattr(cls, '__doc__'):
         # Create a class doc-string.
diff --git a/Lib/test/test_dataclasses.py b/Lib/test/test_dataclasses.py
index 69819ea45078..53281f9dd9de 100755
--- a/Lib/test/test_dataclasses.py
+++ b/Lib/test/test_dataclasses.py
@@ -9,6 +9,7 @@
 from unittest.mock import Mock
 from typing import ClassVar, Any, List, Union, Tuple, Dict, Generic, TypeVar
 from collections import deque, OrderedDict, namedtuple
+from functools import total_ordering
 
 # Just any custom exception we can catch.
 class CustomError(Exception): pass
@@ -82,68 +83,12 @@ class B:
             class C(B):
                 x: int = 0
 
-    def test_overwriting_init(self):
-        with self.assertRaisesRegex(TypeError,
-                                    'Cannot overwrite attribute __init__ '
-                                    'in C'):
-            @dataclass
-            class C:
-                x: int
-                def __init__(self, x):
-                    self.x = 2 * x
-
-        @dataclass(init=False)
-        class C:
-            x: int
-            def __init__(self, x):
-                self.x = 2 * x
-        self.assertEqual(C(5).x, 10)
-
-    def test_overwriting_repr(self):
-        with self.assertRaisesRegex(TypeError,
-                                    'Cannot overwrite attribute __repr__ '
-                                    'in C'):
-            @dataclass
-            class C:
-                x: int
-                def __repr__(self):
-                    pass
-
-        @dataclass(repr=False)
-        class C:
-            x: int
-            def __repr__(self):
-                return 'x'
-        self.assertEqual(repr(C(0)), 'x')
-
-    def test_overwriting_cmp(self):
-        with self.assertRaisesRegex(TypeError,
-                                    'Cannot overwrite attribute __eq__ '
-                                    'in C'):
-            # This will generate the comparison functions, make sure we can't
-            #  overwrite them.
-            @dataclass(hash=False, frozen=False)
-            class C:
-                x: int
-                def __eq__(self):
-                    pass
-
-        @dataclass(order=False, eq=False)
+    def test_overwriting_hash(self):
+        @dataclass(frozen=True)
         class C:
             x: int
-            def __eq__(self, other):
-                return True
-        self.assertEqual(C(0), 'x')
-
-    def test_overwriting_hash(self):
-        with self.assertRaisesRegex(TypeError,
-                                    'Cannot overwrite attribute __hash__ '
-                                    'in C'):
-            @dataclass(frozen=True)
-            class C:
-                x: int
-                def __hash__(self):
-                    pass
+            def __hash__(self):
+                pass
 
         @dataclass(frozen=True,hash=False)
         class C:
@@ -152,14 +97,11 @@ def __hash__(self):
                 return 600
         self.assertEqual(hash(C(0)), 600)
 
-        with self.assertRaisesRegex(TypeError,
-                                    'Cannot overwrite attribute __hash__ '
-                                    'in C'):
-            @dataclass(frozen=True)
-            class C:
-                x: int
-                def __hash__(self):
-                    pass
+        @dataclass(frozen=True)
+        class C:
+            x: int
+            def __hash__(self):
+                pass
 
         @dataclass(frozen=True, hash=False)
         class C:
@@ -168,33 +110,6 @@ def __hash__(self):
                 return 600
         self.assertEqual(hash(C(0)), 600)
 
-    def test_overwriting_frozen(self):
-        # frozen uses __setattr__ and __delattr__
-        with self.assertRaisesRegex(TypeError,
-                                    'Cannot overwrite attribute __setattr__ '
-                                    'in C'):
-            @dataclass(frozen=True)
-            class C:
-                x: int
-                def __setattr__(self):
-                    pass
-
-        with self.assertRaisesRegex(TypeError,
-                                    'Cannot overwrite attribute __delattr__ '
-                                    'in C'):
-            @dataclass(frozen=True)
-            class C:
-                x: int
-                def __delattr__(self):
-                    pass
-
-        @dataclass(frozen=False)
-        class C:
-            x: int
-            def __setattr__(self, name, value):
-                self.__dict__['x'] = value * 2
-        self.assertEqual(C(10).x, 20)
-
     def test_overwrite_fields_in_derived_class(self):
         # Note that x from C1 replaces x in Base, but the order remains
         #  the same as defined in Base.
@@ -239,34 +154,6 @@ class C:
         first = next(iter(sig.parameters))
         self.assertEqual('self', first)
 
-    def test_repr(self):
-        @dataclass
-        class B:
-            x: int
-
-        @dataclass
-        class C(B):
-            y: int = 10
-
-        o = C(4)
-        self.assertEqual(repr(o), 'TestCase.test_repr.<locals>.C(x=4, y=10)')
-
-        @dataclass
-        class D(C):
-            x: int = 20
-        self.assertEqual(repr(D()), 'TestCase.test_repr.<locals>.D(x=20, y=10)')
-
-        @dataclass
-        class C:
-            @dataclass
-            class D:
-                i: int
-            @dataclass
-            class E:
-                pass
-        self.assertEqual(repr(C.D(0)), 'TestCase.test_repr.<locals>.C.D(i=0)')
-        self.assertEqual(repr(C.E()), 'TestCase.test_repr.<locals>.C.E()')
-
     def test_0_field_compare(self):
         # Ensure that order=False is the default.
         @dataclass
@@ -420,80 +307,8 @@ class C:
         self.assertEqual(hash(C(4)), hash((4,)))
         self.assertEqual(hash(C(42)), hash((42,)))
 
-    def test_hash(self):
-        @dataclass(hash=True)
-        class C:
-            x: int
-            y: str
-        self.assertEqual(hash(C(1, 'foo')), hash((1, 'foo')))
-
-    def test_no_hash(self):
-        @dataclass(hash=None)
-        class C:
-            x: int
-        with self.assertRaisesRegex(TypeError,
-                                    "unhashable type: 'C'"):
-            hash(C(1))
-
-    def test_hash_rules(self):
-        # There are 24 cases of:
-        #  hash=True/False/None
-        #  eq=True/False
-        #  order=True/False
-        #  frozen=True/False
-        for (hash,  eq,    order, frozen, result  ) in [
-            (False, False, False, False,  'absent'),
-            (False, False, False, True,   'absent'),
-            (False, False, True,  False,  'exception'),
-            (False, False, True,  True,   'exception'),
-            (False, True,  False, False,  'absent'),
-            (False, True,  False, True,   'absent'),
-            (False, True,  True,  False,  'absent'),
-            (False, True,  True,  True,   'absent'),
-            (True,  False, False, False,  'fn'),
-            (True,  False, False, True,   'fn'),
-            (True,  False, True,  False,  'exception'),
-            (True,  False, True,  True,   'exception'),
-            (True,  True,  False, False,  'fn'),
-            (True,  True,  False, True,   'fn'),
-            (True,  True,  True,  False,  'fn'),
-            (True,  True,  True,  True,   'fn'),
-            (None,  False, False, False,  'absent'),
-            (None,  False, False, True,   'absent'),
-            (None,  False, True,  False,  'exception'),
-            (None,  False, True,  True,   'exception'),
-            (None,  True,  False, False,  'none'),
-            (None,  True,  False, True,   'fn'),
-            (None,  True,  True,  False,  'none'),
-            (None,  True,  True,  True,   'fn'),
-        ]:
-            with self.subTest(hash=hash, eq=eq, order=order, frozen=frozen):
-                if result == 'exception':
-                    with self.assertRaisesRegex(ValueError, 'eq must be true if order is true'):
-                        @dataclass(hash=hash, eq=eq, order=order, frozen=frozen)
-                        class C:
-                            pass
-                else:
-                    @dataclass(hash=hash, eq=eq, order=order, frozen=frozen)
-                    class C:
-                        pass
-
-                    # See if the result matches what's expected.
-                    if result == 'fn':
-                        # __hash__ contains the function we generated.
-                        self.assertIn('__hash__', C.__dict__)
-                        self.assertIsNotNone(C.__dict__['__hash__'])
-                    elif result == 'absent':
-                        # __hash__ is not present in our class.
-                        self.assertNotIn('__hash__', C.__dict__)
-                    elif result == 'none':
-                        # __hash__ is set to None.
-                        self.assertIn('__hash__', C.__dict__)
-                        self.assertIsNone(C.__dict__['__hash__'])
-                    else:
-                        assert False, f'unknown result {result!r}'
-
     def test_eq_order(self):
+        # Test combining eq and order.
         for (eq,    order, result   ) in [
             (False, False, 'neither'),
             (False, True,  'exception'),
@@ -513,21 +328,18 @@ class C:
 
                     if result == 'neither':
                         self.assertNotIn('__eq__', C.__dict__)
-                        self.assertNotIn('__ne__', C.__dict__)
                         self.assertNotIn('__lt__', C.__dict__)
                         self.assertNotIn('__le__', C.__dict__)
                         self.assertNotIn('__gt__', C.__dict__)
                         self.assertNotIn('__ge__', C.__dict__)
                     elif result == 'both':
                         self.assertIn('__eq__', C.__dict__)
-                        self.assertIn('__ne__', C.__dict__)
                         self.assertIn('__lt__', C.__dict__)
                         self.assertIn('__le__', C.__dict__)
                         self.assertIn('__gt__', C.__dict__)
                         self.assertIn('__ge__', C.__dict__)
                     elif result == 'eq_only':
                         self.assertIn('__eq__', C.__dict__)
-                        self.assertIn('__ne__', C.__dict__)
                         self.assertNotIn('__lt__', C.__dict__)
                         self.assertNotIn('__le__', C.__dict__)
                         self.assertNotIn('__gt__', C.__dict__)
@@ -811,19 +623,6 @@ class C:
             y: int
         self.assertNotEqual(Point(1, 3), C(1, 3))
 
-    def test_base_has_init(self):
-        class B:
-            def __init__(self):
-                pass
-
-        # Make sure that declaring this class doesn't raise an error.
-        #  The issue is that we can't override __init__ in our class,
-        #  but it should be okay to add __init__ to us if our base has
-        #  an __init__.
-        @dataclass
-        class C(B):
-            x: int = 0
-
     def test_frozen(self):
         @dataclass(frozen=True)
         class C:
@@ -2065,6 +1864,7 @@ def test_helper_make_dataclass_no_types(self):
                                              'y': int,
                                              'z': 'typing.Any'})
 
+
 class TestDocString(unittest.TestCase):
     def assertDocStrEqual(self, a, b):
         # Because 3.6 and 3.7 differ in how inspect.signature work
@@ -2154,5 +1954,445 @@ class C:
         self.assertDocStrEqual(C.__doc__, "C(x:collections.deque=<factory>)")
 
 
+class TestInit(unittest.TestCase):
+    def test_base_has_init(self):
+        class B:
+            def __init__(self):
+                # z records whether B.__init__ was called.
+                self.z = 100
+
+        # Make sure that declaring this class doesn't raise an error.
+        #  The issue is that we can't override __init__ in our class,
+        #  but it should be okay to add __init__ to us if our base has
+        #  an __init__.
+        @dataclass
+        class C(B):
+            x: int = 0
+        c = C(10)
+        self.assertEqual(c.x, 10)
+        self.assertNotIn('z', vars(c))  # Generated __init__ doesn't call B's.
+
+        # Make sure that if we don't add an init, the base __init__
+        #  gets called.
+        @dataclass(init=False)
+        class C(B):
+            x: int = 10
+        c = C()
+        self.assertEqual(c.x, 10)
+        self.assertEqual(c.z, 100)
+
+    def test_no_init(self):
+        @dataclass(init=False)
+        class C:
+            i: int = 0
+        self.assertEqual(C().i, 0)  # No __init__ generated; i is the class default.
+
+        @dataclass(init=False)
+        class C:
+            i: int = 2
+            def __init__(self):
+                self.i = 3
+        self.assertEqual(C().i, 3)  # The class's own __init__ is kept.
+
+    def test_overwriting_init(self):
+        # If the class has __init__, use it no matter the value of
+        #  init=.
+
+        @dataclass
+        class C:
+            x: int
+            def __init__(self, x):
+                self.x = 2 * x
+        self.assertEqual(C(3).x, 6)
+
+        @dataclass(init=True)
+        class C:
+            x: int
+            def __init__(self, x):
+                self.x = 2 * x
+        self.assertEqual(C(4).x, 8)
+
+        @dataclass(init=False)
+        class C:
+            x: int
+            def __init__(self, x):
+                self.x = 2 * x
+        self.assertEqual(C(5).x, 10)
+
+class TestRepr(unittest.TestCase):
+    def test_repr(self):
+        @dataclass
+        class B:
+            x: int
+
+        @dataclass
+        class C(B):
+            y: int = 10
+
+        o = C(4)  # x comes from base B; y uses its default.
+        self.assertEqual(repr(o), 'TestRepr.test_repr.<locals>.C(x=4, y=10)')
+
+        @dataclass
+        class D(C):
+            x: int = 20  # Redefined field keeps its original (first) position.
+        self.assertEqual(repr(D()), 'TestRepr.test_repr.<locals>.D(x=20, y=10)')
+
+        @dataclass
+        class C:  # Nested dataclasses: repr uses the full __qualname__.
+            @dataclass
+            class D:
+                i: int
+            @dataclass
+            class E:
+                pass
+        self.assertEqual(repr(C.D(0)), 'TestRepr.test_repr.<locals>.C.D(i=0)')
+        self.assertEqual(repr(C.E()), 'TestRepr.test_repr.<locals>.C.E()')
+
+    def test_no_repr(self):
+        # Test a class with no __repr__ and repr=False.
+        @dataclass(repr=False)
+        class C:
+            x: int
+        self.assertIn('test_dataclasses.TestRepr.test_no_repr.<locals>.C object at',
+                      repr(C(3)))
+
+        # Test a class with a __repr__ and repr=False.
+        @dataclass(repr=False)
+        class C:
+            x: int
+            def __repr__(self):
+                return 'C-class'
+        self.assertEqual(repr(C(3)), 'C-class')
+
+    def test_overwriting_repr(self):
+        # If the class has __repr__, use it no matter the value of
+        #  repr=.
+
+        @dataclass
+        class C:
+            x: int
+            def __repr__(self):
+                return 'x'
+        self.assertEqual(repr(C(0)), 'x')
+
+        @dataclass(repr=True)
+        class C:
+            x: int
+            def __repr__(self):
+                return 'x'
+        self.assertEqual(repr(C(0)), 'x')
+
+        @dataclass(repr=False)
+        class C:
+            x: int
+            def __repr__(self):
+                return 'x'
+        self.assertEqual(repr(C(0)), 'x')
+
+
+class TestFrozen(unittest.TestCase):
+    def test_overwriting_frozen(self):
+        # frozen=True adds __setattr__/__delattr__; defining either one is an error.
+        with self.assertRaisesRegex(TypeError,
+                                    'Cannot overwrite attribute __setattr__'):
+            @dataclass(frozen=True)
+            class C:
+                x: int
+                def __setattr__(self):
+                    pass
+
+        with self.assertRaisesRegex(TypeError,
+                                    'Cannot overwrite attribute __delattr__'):
+            @dataclass(frozen=True)
+            class C:
+                x: int
+                def __delattr__(self):
+                    pass
+
+        @dataclass(frozen=False)  # Not frozen: the user's __setattr__ is kept.
+        class C:
+            x: int
+            def __setattr__(self, name, value):
+                self.__dict__['x'] = value * 2
+        self.assertEqual(C(10).x, 20)
+
+
+class TestEq(unittest.TestCase):
+    def test_no_eq(self):
+        # Test a class with no __eq__ and eq=False.
+        @dataclass(eq=False)
+        class C:
+            x: int
+        self.assertNotEqual(C(0), C(0))  # Falls back to identity comparison.
+        c = C(3)
+        self.assertEqual(c, c)
+
+        # Test a class with an __eq__ and eq=False.
+        @dataclass(eq=False)
+        class C:
+            x: int
+            def __eq__(self, other):
+                return other == 10
+        self.assertEqual(C(3), 10)
+
+    def test_overwriting_eq(self):
+        # If the class has __eq__, use it no matter the value of
+        #  eq=.
+
+        @dataclass
+        class C:
+            x: int
+            def __eq__(self, other):
+                return other == 3
+        self.assertEqual(C(1), 3)
+        self.assertNotEqual(C(1), 1)
+
+        @dataclass(eq=True)
+        class C:
+            x: int
+            def __eq__(self, other):
+                return other == 4
+        self.assertEqual(C(1), 4)
+        self.assertNotEqual(C(1), 1)
+
+        @dataclass(eq=False)
+        class C:
+            x: int
+            def __eq__(self, other):
+                return other == 5
+        self.assertEqual(C(1), 5)
+        self.assertNotEqual(C(1), 1)
+
+
+class TestOrdering(unittest.TestCase):
+    def test_functools_total_ordering(self):
+        # Test that functools.total_ordering works with this class.
+        @total_ordering
+        @dataclass
+        class C:
+            x: int
+            def __lt__(self, other):
+                # Perform the test "backward", just to make
+                #  sure this is being called.
+                return self.x >= other
+
+        self.assertLess(C(0), -1)
+        self.assertLessEqual(C(0), -1)
+        self.assertGreater(C(0), 1)
+        self.assertGreaterEqual(C(0), 1)
+
+    def test_no_order(self):
+        # Test that no ordering functions are added when order=False.
+        @dataclass(order=False)
+        class C:
+            x: int
+        # Make sure no order methods are added.
+        self.assertNotIn('__le__', C.__dict__)
+        self.assertNotIn('__lt__', C.__dict__)
+        self.assertNotIn('__ge__', C.__dict__)
+        self.assertNotIn('__gt__', C.__dict__)
+
+        # Test a class that defines its own __lt__, with order=False.
+        @dataclass(order=False)
+        class C:
+            x: int
+            def __lt__(self, other):
+                return False
+        # Make sure other methods aren't added.
+        self.assertNotIn('__le__', C.__dict__)
+        self.assertNotIn('__ge__', C.__dict__)
+        self.assertNotIn('__gt__', C.__dict__)
+
+    def test_overwriting_order(self):  # order=True rejects any user-defined order method.
+        with self.assertRaisesRegex(TypeError,
+                                    'Cannot overwrite attribute __lt__'
+                                    '.*using functools.total_ordering'):
+            @dataclass(order=True)
+            class C:
+                x: int
+                def __lt__(self):
+                    pass
+
+        with self.assertRaisesRegex(TypeError,
+                                    'Cannot overwrite attribute __le__'
+                                    '.*using functools.total_ordering'):
+            @dataclass(order=True)
+            class C:
+                x: int
+                def __le__(self):
+                    pass
+
+        with self.assertRaisesRegex(TypeError,
+                                    'Cannot overwrite attribute __gt__'
+                                    '.*using functools.total_ordering'):
+            @dataclass(order=True)
+            class C:
+                x: int
+                def __gt__(self):
+                    pass
+
+        with self.assertRaisesRegex(TypeError,
+                                    'Cannot overwrite attribute __ge__'
+                                    '.*using functools.total_ordering'):
+            @dataclass(order=True)
+            class C:
+                x: int
+                def __ge__(self):
+                    pass
+
+class TestHash(unittest.TestCase):
+    def test_hash(self):
+        @dataclass(hash=True)
+        class C:
+            x: int
+            y: str
+        self.assertEqual(hash(C(1, 'foo')), hash((1, 'foo')))
+
+    def test_hash_false(self):
+        @dataclass(hash=False)
+        class C:
+            x: int
+            y: str
+        self.assertNotEqual(hash(C(1, 'foo')), hash((1, 'foo')))
+
+    def test_hash_none(self):
+        @dataclass(hash=None)
+        class C:
+            x: int
+        with self.assertRaisesRegex(TypeError,
+                                    "unhashable type: 'C'"):
+            hash(C(1))
+
+    def test_hash_rules(self):
+        def non_bool(value):
+            # Map to something else that's True, but not a bool.
+            if value is None:
+                return None
+            if value:
+                return (3,)
+            return 0
+
+        def test(case, hash, eq, frozen, with_hash, result):
+            with self.subTest(case=case, hash=hash, eq=eq, frozen=frozen):
+                if with_hash:
+                    @dataclass(hash=hash, eq=eq, frozen=frozen)
+                    class C:
+                        def __hash__(self):
+                            return 0
+                else:
+                    @dataclass(hash=hash, eq=eq, frozen=frozen)
+                    class C:
+                        pass
+
+                # See if the result matches what's expected.
+                if result in ('fn', 'fn-x'):
+                    # __hash__ contains the function we generated.
+                    self.assertIn('__hash__', C.__dict__)
+                    self.assertIsNotNone(C.__dict__['__hash__'])
+
+                    if result == 'fn-x':
+                        # This is the "auto-hash test" case.  We
+                        #  should overwrite __hash__ iff there's an
+                        #  __eq__ and if __hash__=None.
+
+                        # There are two ways of getting __hash__=None:
+                        #  explicitly, and by defining __eq__.  If
+                        #  __eq__ is defined, Python sets __hash__ to None
+                        #  when the class is created.
+                        @dataclass(hash=hash, eq=eq, frozen=frozen)
+                        class C:
+                            def __eq__(self, other): pass
+                            __hash__ = None
+
+                        # Hash should be overwritten (non-None).
+                        self.assertIsNotNone(C.__dict__['__hash__'])
+
+                        # Same test as above, but we don't provide
+                        #  __hash__; it will implicitly be set to None.
+                        @dataclass(hash=hash, eq=eq, frozen=frozen)
+                        class C:
+                            def __eq__(self, other): pass
+
+                        # Hash should be overwritten (non-None).
+                        self.assertIsNotNone(C.__dict__['__hash__'])
+
+                elif result == '':
+                    # __hash__ is not present in our class.
+                    if not with_hash:
+                        self.assertNotIn('__hash__', C.__dict__)
+                elif result == 'none':
+                    # __hash__ is set to None.
+                    self.assertIn('__hash__', C.__dict__)
+                    self.assertIsNone(C.__dict__['__hash__'])
+                else:
+                    assert False, f'unknown result {result!r}'
+
+        # There are 12 cases of:
+        #  hash=True/False/None
+        #  eq=True/False
+        #  frozen=True/False
+        # And for each of these, a different result if
+        #  __hash__ is defined or not.
+        for case, (hash,  eq,    frozen, result_no, result_yes) in enumerate([
+                  (None,  False, False,  '',        ''),
+                  (None,  False, True,   '',        ''),
+                  (None,  True,  False,  'none',    ''),
+                  (None,  True,  True,   'fn',      'fn-x'),
+                  (False, False, False,  '',        ''),
+                  (False, False, True,   '',        ''),
+                  (False, True,  False,  '',        ''),
+                  (False, True,  True,   '',        ''),
+                  (True,  False, False,  'fn',      'fn-x'),
+                  (True,  False, True,   'fn',      'fn-x'),
+                  (True,  True,  False,  'fn',      'fn-x'),
+                  (True,  True,  True,   'fn',      'fn-x'),
+        ], 1):
+            test(case, hash, eq, frozen, False, result_no)
+            test(case, hash, eq, frozen, True,  result_yes)
+
+            # Test non-bool truth values, too.  This is just to
+            #  make sure the data-driven table in the decorator
+            #  handles non-bool values.
+            test(case, non_bool(hash), non_bool(eq), non_bool(frozen), False, result_no)
+            test(case, non_bool(hash), non_bool(eq), non_bool(frozen), True,  result_yes)
+
+
+    def test_eq_only(self):
+        # If a class defines __eq__, __hash__ is automatically added
+        #  and set to None.  This is normal Python behavior, not
+        #  related to dataclasses.  Make sure we don't interfere with
+        #  that (see bpo-32546).
+
+        @dataclass
+        class C:
+            i: int
+            def __eq__(self, other):
+                return self.i == other.i
+        self.assertEqual(C(1), C(1))
+        self.assertNotEqual(C(1), C(4))
+
+        # And make sure things work in this case if we specify
+        #  hash=True.
+        @dataclass(hash=True)
+        class C:
+            i: int
+            def __eq__(self, other):
+                return self.i == other.i
+        self.assertEqual(C(1), C(1.0))
+        self.assertEqual(hash(C(1)), hash(C(1.0)))
+
+        # And check that the class's __eq__ is being used, despite
+        #  specifying eq=True.
+        @dataclass(hash=True, eq=True)
+        class C:
+            i: int
+            def __eq__(self, other):
+                return self.i == 3 and self.i == other.i
+        self.assertEqual(C(3), C(3))
+        self.assertNotEqual(C(1), C(1))
+        self.assertEqual(hash(C(1)), hash(C(1.0)))
+
+
 if __name__ == '__main__':
     unittest.main()
diff --git a/Misc/NEWS.d/next/Library/2018-01-27-11-20-16.bpo-32513.ak-iD2.rst b/Misc/NEWS.d/next/Library/2018-01-27-11-20-16.bpo-32513.ak-iD2.rst
new file mode 100644
index 000000000000..48072417f771
--- /dev/null
+++ b/Misc/NEWS.d/next/Library/2018-01-27-11-20-16.bpo-32513.ak-iD2.rst
@@ -0,0 +1,2 @@
+In dataclasses, allow easier overriding of dunder methods without specifying
+decorator parameters.



More information about the Python-checkins mailing list