[Python-checkins] bpo-32953: Dataclasses: frozen should not be inherited for non-dataclass derived classes (#6147)

Eric V. Smith webhook-mailer at python.org
Sun Mar 18 20:40:38 EDT 2018


https://github.com/python/cpython/commit/f199bc655eb50c28e94010714629b376bbbd077b
commit: f199bc655eb50c28e94010714629b376bbbd077b
branch: master
author: Eric V. Smith <ericvsmith at users.noreply.github.com>
committer: GitHub <noreply at github.com>
date: 2018-03-18T20:40:34-04:00
summary:

bpo-32953: Dataclasses: frozen should not be inherited for non-dataclass derived classes (#6147)

If a non-dataclass derives from a frozen dataclass, allow attributes to be set.
Require either all of the dataclasses in a class hierarchy to be frozen, or all non-frozen.
Store `@dataclass` parameters on the class object under `__dataclass_params__`. This is needed to detect frozen base classes.

files:
A Misc/NEWS.d/next/Library/2018-03-18-17-38-48.bpo-32953.t8WAWN.rst
M Lib/dataclasses.py
M Lib/test/test_dataclasses.py

diff --git a/Lib/dataclasses.py b/Lib/dataclasses.py
index b55a497db302..8ab04dd5b975 100644
--- a/Lib/dataclasses.py
+++ b/Lib/dataclasses.py
@@ -171,7 +171,11 @@ class _MISSING_TYPE:
 
 # The name of an attribute on the class where we store the Field
 #  objects. Also used to check if a class is a Data Class.
-_MARKER = '__dataclass_fields__'
+_FIELDS = '__dataclass_fields__'
+
+# The name of an attribute on the class that stores the parameters to
+# @dataclass.
+_PARAMS = '__dataclass_params__'
 
 # The name of the function, that if it exists, is called at the end of
 # __init__.
@@ -192,7 +196,7 @@ class InitVar(metaclass=_InitVarMeta):
 # name and type are filled in after the fact, not in __init__. They're
 #  not known at the time this class is instantiated, but it's
 #  convenient if they're available later.
-# When cls._MARKER is filled in with a list of Field objects, the name
+# When cls._FIELDS is filled in with a list of Field objects, the name
 #  and type fields will have been populated.
 class Field:
     __slots__ = ('name',
@@ -236,6 +240,32 @@ def __repr__(self):
                 ')')
 
 
+class _DataclassParams:
+    __slots__ = ('init',
+                 'repr',
+                 'eq',
+                 'order',
+                 'unsafe_hash',
+                 'frozen',
+                 )
+    def __init__(self, init, repr, eq, order, unsafe_hash, frozen):
+        self.init = init
+        self.repr = repr
+        self.eq = eq
+        self.order = order
+        self.unsafe_hash = unsafe_hash
+        self.frozen = frozen
+
+    def __repr__(self):
+        return ('_DataclassParams('
+                f'init={self.init},'
+                f'repr={self.repr},'
+                f'eq={self.eq},'
+                f'order={self.order},'
+                f'unsafe_hash={self.unsafe_hash},'
+                f'frozen={self.frozen}'
+                ')')
+
 # This function is used instead of exposing Field creation directly,
 #  so that a type checker can be told (via overloads) that this is a
 #  function whose type depends on its parameters.
@@ -285,6 +315,7 @@ def _create_fn(name, args, body, *, globals=None, locals=None,
     args = ','.join(args)
     body = '\n'.join(f' {b}' for b in body)
 
+    # Compute the text of the entire function.
     txt = f'def {name}({args}){return_annotation}:\n{body}'
 
     exec(txt, globals, locals)
@@ -432,12 +463,29 @@ def _repr_fn(fields):
                        ')"'])
 
 
-def _frozen_setattr(self, name, value):
-    raise FrozenInstanceError(f'cannot assign to field {name!r}')
-
-
-def _frozen_delattr(self, name):
-    raise FrozenInstanceError(f'cannot delete field {name!r}')
+def _frozen_get_del_attr(cls, fields):
+    # XXX: globals is modified on the first call to _create_fn, then the
+    #  modified version is used in the second call.  Is this okay?
+    globals = {'cls': cls,
+              'FrozenInstanceError': FrozenInstanceError}
+    if fields:
+        fields_str = '(' + ','.join(repr(f.name) for f in fields) + ',)'
+    else:
+        # Special case for the zero-length tuple.
+        fields_str = '()'
+    return (_create_fn('__setattr__',
+                      ('self', 'name', 'value'),
+                      (f'if type(self) is cls or name in {fields_str}:',
+                        ' raise FrozenInstanceError(f"cannot assign to field {name!r}")',
+                       f'super(cls, self).__setattr__(name, value)'),
+                       globals=globals),
+            _create_fn('__delattr__',
+                      ('self', 'name'),
+                      (f'if type(self) is cls or name in {fields_str}:',
+                        ' raise FrozenInstanceError(f"cannot delete field {name!r}")',
+                       f'super(cls, self).__delattr__(name)'),
+                       globals=globals),
+            )
 
 
 def _cmp_fn(name, op, self_tuple, other_tuple):
@@ -583,23 +631,32 @@ def _set_new_attribute(cls, name, value):
 #  version of this table.
 
 
-def _process_class(cls, repr, eq, order, unsafe_hash, init, frozen):
+def _process_class(cls, init, repr, eq, order, unsafe_hash, frozen):
     # Now that dicts retain insertion order, there's no reason to use
     #  an ordered dict.  I am leveraging that ordering here, because
     #  derived class fields overwrite base class fields, but the order
     #  is defined by the base class, which is found first.
     fields = {}
 
+    setattr(cls, _PARAMS, _DataclassParams(init, repr, eq, order,
+                                           unsafe_hash, frozen))
+
     # Find our base classes in reverse MRO order, and exclude
     #  ourselves.  In reversed order so that more derived classes
     #  override earlier field definitions in base classes.
+    # As long as we're iterating over them, see if any are frozen.
+    any_frozen_base = False
+    has_dataclass_bases = False
     for b in cls.__mro__[-1:0:-1]:
         # Only process classes that have been processed by our
-        #  decorator.  That is, they have a _MARKER attribute.
-        base_fields = getattr(b, _MARKER, None)
+        #  decorator.  That is, they have a _FIELDS attribute.
+        base_fields = getattr(b, _FIELDS, None)
         if base_fields:
+            has_dataclass_bases = True
             for f in base_fields.values():
                 fields[f.name] = f
+            if getattr(b, _PARAMS).frozen:
+                any_frozen_base = True
 
     # Now find fields in our class.  While doing so, validate some
     #  things, and set the default values (as class attributes)
@@ -623,20 +680,21 @@ def _process_class(cls, repr, eq, order, unsafe_hash, init, frozen):
             else:
                 setattr(cls, f.name, f.default)
 
-    # We're inheriting from a frozen dataclass, but we're not frozen.
-    if cls.__setattr__ is _frozen_setattr and not frozen:
-        raise TypeError('cannot inherit non-frozen dataclass from a '
-                        'frozen one')
+    # Check rules that apply if we are derived from any dataclasses.
+    if has_dataclass_bases:
+        # Raise an exception if any of our bases are frozen, but we're not.
+        if any_frozen_base and not frozen:
+            raise TypeError('cannot inherit non-frozen dataclass from a '
+                            'frozen one')
 
-    # We're inheriting from a non-frozen dataclass, but we're frozen.
-    if (hasattr(cls, _MARKER) and cls.__setattr__ is not _frozen_setattr
-        and frozen):
-        raise TypeError('cannot inherit frozen dataclass from a '
-                        'non-frozen one')
+        # Raise an exception if we're frozen, but none of our bases are.
+        if not any_frozen_base and frozen:
+            raise TypeError('cannot inherit frozen dataclass from a '
+                            'non-frozen one')
 
-    # Remember all of the fields on our class (including bases).  This
+    # Remember all of the fields on our class (including bases).  This also
     #  marks this class as being a dataclass.
-    setattr(cls, _MARKER, fields)
+    setattr(cls, _FIELDS, fields)
 
     # Was this class defined with an explicit __hash__?  Note that if
     #  __eq__ is defined in this class, then python will automatically
@@ -704,10 +762,10 @@ def _process_class(cls, repr, eq, order, unsafe_hash, init, frozen):
                                 'functools.total_ordering')
 
     if frozen:
-        for name, fn in [('__setattr__', _frozen_setattr),
-                         ('__delattr__', _frozen_delattr)]:
-            if _set_new_attribute(cls, name, fn):
-                raise TypeError(f'Cannot overwrite attribute {name} '
+        # XXX: Which fields are frozen? InitVar? ClassVar? hashed-only?
+        for fn in _frozen_get_del_attr(cls, field_list):
+            if _set_new_attribute(cls, fn.__name__, fn):
+                raise TypeError(f'Cannot overwrite attribute {fn.__name__} '
                                 f'in class {cls.__name__}')
 
     # Decide if/how we're going to create a hash function.
@@ -759,7 +817,7 @@ def dataclass(_cls=None, *, init=True, repr=True, eq=True, order=False,
     """
 
     def wrap(cls):
-        return _process_class(cls, repr, eq, order, unsafe_hash, init, frozen)
+        return _process_class(cls, init, repr, eq, order, unsafe_hash, frozen)
 
     # See if we're being called as @dataclass or @dataclass().
     if _cls is None:
@@ -779,7 +837,7 @@ def fields(class_or_instance):
 
     # Might it be worth caching this, per class?
     try:
-        fields =  getattr(class_or_instance, _MARKER)
+        fields =  getattr(class_or_instance, _FIELDS)
     except AttributeError:
         raise TypeError('must be called with a dataclass type or instance')
 
@@ -790,13 +848,13 @@ def fields(class_or_instance):
 
 def _is_dataclass_instance(obj):
     """Returns True if obj is an instance of a dataclass."""
-    return not isinstance(obj, type) and hasattr(obj, _MARKER)
+    return not isinstance(obj, type) and hasattr(obj, _FIELDS)
 
 
 def is_dataclass(obj):
     """Returns True if obj is a dataclass or an instance of a
     dataclass."""
-    return hasattr(obj, _MARKER)
+    return hasattr(obj, _FIELDS)
 
 
 def asdict(obj, *, dict_factory=dict):
@@ -953,7 +1011,7 @@ class C:
     # It's an error to have init=False fields in 'changes'.
     # If a field is not in 'changes', read its value from the provided obj.
 
-    for f in getattr(obj, _MARKER).values():
+    for f in getattr(obj, _FIELDS).values():
         if not f.init:
             # Error if this field is specified in changes.
             if f.name in changes:
diff --git a/Lib/test/test_dataclasses.py b/Lib/test/test_dataclasses.py
index 46d485c0157b..3e6726360940 100755
--- a/Lib/test/test_dataclasses.py
+++ b/Lib/test/test_dataclasses.py
@@ -2476,41 +2476,92 @@ class D(C):
         d = D(0, 10)
         with self.assertRaises(FrozenInstanceError):
             d.i = 5
+        with self.assertRaises(FrozenInstanceError):
+            d.j = 6
         self.assertEqual(d.i, 0)
+        self.assertEqual(d.j, 10)
+
+    # Test both ways: with an intermediate normal (non-dataclass)
+    #  class and without an intermediate class.
+    def test_inherit_nonfrozen_from_frozen(self):
+        for intermediate_class in [True, False]:
+            with self.subTest(intermediate_class=intermediate_class):
+                @dataclass(frozen=True)
+                class C:
+                    i: int
 
-    def test_inherit_from_nonfrozen_from_frozen(self):
-        @dataclass(frozen=True)
-        class C:
-            i: int
+                if intermediate_class:
+                    class I(C): pass
+                else:
+                    I = C
 
-        with self.assertRaisesRegex(TypeError,
-                                    'cannot inherit non-frozen dataclass from a frozen one'):
-            @dataclass
-            class D(C):
-                pass
+                with self.assertRaisesRegex(TypeError,
+                                            'cannot inherit non-frozen dataclass from a frozen one'):
+                    @dataclass
+                    class D(I):
+                        pass
 
-    def test_inherit_from_frozen_from_nonfrozen(self):
-        @dataclass
-        class C:
-            i: int
+    def test_inherit_frozen_from_nonfrozen(self):
+        for intermediate_class in [True, False]:
+            with self.subTest(intermediate_class=intermediate_class):
+                @dataclass
+                class C:
+                    i: int
 
-        with self.assertRaisesRegex(TypeError,
-                                    'cannot inherit frozen dataclass from a non-frozen one'):
-            @dataclass(frozen=True)
-            class D(C):
-                pass
+                if intermediate_class:
+                    class I(C): pass
+                else:
+                    I = C
+
+                with self.assertRaisesRegex(TypeError,
+                                            'cannot inherit frozen dataclass from a non-frozen one'):
+                    @dataclass(frozen=True)
+                    class D(I):
+                        pass
 
     def test_inherit_from_normal_class(self):
-        class C:
-            pass
+        for intermediate_class in [True, False]:
+            with self.subTest(intermediate_class=intermediate_class):
+                class C:
+                    pass
+
+                if intermediate_class:
+                    class I(C): pass
+                else:
+                    I = C
+
+                @dataclass(frozen=True)
+                class D(I):
+                    i: int
+
+            d = D(10)
+            with self.assertRaises(FrozenInstanceError):
+                d.i = 5
+
+    def test_non_frozen_normal_derived(self):
+        # See bpo-32953.
 
         @dataclass(frozen=True)
-        class D(C):
-            i: int
+        class D:
+            x: int
+            y: int = 10
 
-        d = D(10)
+        class S(D):
+            pass
+
+        s = S(3)
+        self.assertEqual(s.x, 3)
+        self.assertEqual(s.y, 10)
+        s.cached = True
+
+        # But can't change the frozen attributes.
         with self.assertRaises(FrozenInstanceError):
-            d.i = 5
+            s.x = 5
+        with self.assertRaises(FrozenInstanceError):
+            s.y = 5
+        self.assertEqual(s.x, 3)
+        self.assertEqual(s.y, 10)
+        self.assertEqual(s.cached, True)
 
 
 if __name__ == '__main__':
diff --git a/Misc/NEWS.d/next/Library/2018-03-18-17-38-48.bpo-32953.t8WAWN.rst b/Misc/NEWS.d/next/Library/2018-03-18-17-38-48.bpo-32953.t8WAWN.rst
new file mode 100644
index 000000000000..fbea34aa9a2a
--- /dev/null
+++ b/Misc/NEWS.d/next/Library/2018-03-18-17-38-48.bpo-32953.t8WAWN.rst
@@ -0,0 +1,4 @@
+If a non-dataclass inherits from a frozen dataclass, allow attributes to be
+added to the derived class.  Only the fields of the frozen dataclass remain
+read-only.  Require the dataclasses in a hierarchy to be either all frozen
+or all non-frozen.



More information about the Python-checkins mailing list