[Python-checkins] bpo-45081: Fix __init__ method generation when inheriting from Protocol (GH-28121)

ambv webhook-mailer at python.org
Thu Sep 2 12:17:22 EDT 2021


https://github.com/python/cpython/commit/0635e201beaf52373f776ff32702795e38f43ae3
commit: 0635e201beaf52373f776ff32702795e38f43ae3
branch: main
author: Yurii Karabas <1998uriyyo at gmail.com>
committer: ambv <lukasz at langa.pl>
date: 2021-09-02T18:17:13+02:00
summary:

bpo-45081: Fix __init__ method generation when inheriting from Protocol (GH-28121)

Co-authored-by: Ken Jin <28750310+Fidget-Spinner at users.noreply.github.com>

files:
A Misc/NEWS.d/next/Library/2021-09-02-12-42-25.bpo-45081.tOjJ1k.rst
M Lib/test/test_dataclasses.py
M Lib/typing.py

diff --git a/Lib/test/test_dataclasses.py b/Lib/test/test_dataclasses.py
index 8e645aeb4a750..33c9fcd165621 100644
--- a/Lib/test/test_dataclasses.py
+++ b/Lib/test/test_dataclasses.py
@@ -10,7 +10,7 @@
 import builtins
 import unittest
 from unittest.mock import Mock
-from typing import ClassVar, Any, List, Union, Tuple, Dict, Generic, TypeVar, Optional
+from typing import ClassVar, Any, List, Union, Tuple, Dict, Generic, TypeVar, Optional, Protocol
 from typing import get_type_hints
 from collections import deque, OrderedDict, namedtuple
 from functools import total_ordering
@@ -2150,6 +2150,26 @@ def __init__(self, x):
                 self.x = 2 * x
         self.assertEqual(C(5).x, 10)
 
+    def test_inherit_from_protocol(self):
+        # Dataclasses inheriting from protocol should preserve their own `__init__`.
+        # See bpo-45081.
+
+        class P(Protocol):
+            a: int
+
+        @dataclass
+        class C(P):
+            a: int
+
+        self.assertEqual(C(5).a, 5)
+
+        @dataclass
+        class D(P):
+            def __init__(self, a):
+                self.a = a * 2
+
+        self.assertEqual(D(5).a, 10)
+
 
 class TestRepr(unittest.TestCase):
     def test_repr(self):
diff --git a/Lib/typing.py b/Lib/typing.py
index 35c57c21b37c2..892f1b3506851 100644
--- a/Lib/typing.py
+++ b/Lib/typing.py
@@ -1400,8 +1400,29 @@ def _is_callable_members_only(cls):
     return all(callable(getattr(cls, attr, None)) for attr in _get_protocol_attrs(cls))
 
 
-def _no_init(self, *args, **kwargs):
-    raise TypeError('Protocols cannot be instantiated')
+def _no_init_or_replace_init(self, *args, **kwargs):
+    cls = type(self)
+
+    if cls._is_protocol:
+        raise TypeError('Protocols cannot be instantiated')
+
+    # Initially, `__init__` of a protocol subclass is set to `_no_init_or_replace_init`.
+    # The first instantiation of the subclass will call `_no_init_or_replace_init` which
+    # searches for a proper new `__init__` in the MRO. The new `__init__`
+    # replaces the subclass' old `__init__` (ie `_no_init_or_replace_init`). Subsequent
+    # instantiation of the protocol subclass will thus use the new
+    # `__init__` and no longer call `_no_init_or_replace_init`.
+    for base in cls.__mro__:
+        init = base.__dict__.get('__init__', _no_init_or_replace_init)
+        if init is not _no_init_or_replace_init:
+            cls.__init__ = init
+            break
+    else:
+        # should not happen
+        cls.__init__ = object.__init__
+
+    cls.__init__(self, *args, **kwargs)
+
 
 def _caller(depth=1, default='__main__'):
     try:
@@ -1541,15 +1562,6 @@ def _proto_hook(other):
 
         # We have nothing more to do for non-protocols...
         if not cls._is_protocol:
-            if cls.__init__ == _no_init:
-                for base in cls.__mro__:
-                    init = base.__dict__.get('__init__', _no_init)
-                    if init != _no_init:
-                        cls.__init__ = init
-                        break
-                else:
-                    # should not happen
-                    cls.__init__ = object.__init__
             return
 
         # ... otherwise check consistency of bases, and prohibit instantiation.
@@ -1560,7 +1572,7 @@ def _proto_hook(other):
                     issubclass(base, Generic) and base._is_protocol):
                 raise TypeError('Protocols can only inherit from other'
                                 ' protocols, got %r' % base)
-        cls.__init__ = _no_init
+        cls.__init__ = _no_init_or_replace_init
 
 
 class _AnnotatedAlias(_GenericAlias, _root=True):
diff --git a/Misc/NEWS.d/next/Library/2021-09-02-12-42-25.bpo-45081.tOjJ1k.rst b/Misc/NEWS.d/next/Library/2021-09-02-12-42-25.bpo-45081.tOjJ1k.rst
new file mode 100644
index 0000000000000..86d7182003bb9
--- /dev/null
+++ b/Misc/NEWS.d/next/Library/2021-09-02-12-42-25.bpo-45081.tOjJ1k.rst
@@ -0,0 +1,2 @@
+Fix issue when dataclasses that inherit from ``typing.Protocol`` subclasses
+have wrong ``__init__``. Patch provided by Yurii Karabas.



More information about the Python-checkins mailing list