[Python-checkins] [3.9] bpo-42345: Fix three issues with typing.Literal parameters (GH-23294) (GH-23335)

miss-islington webhook-mailer at python.org
Tue Nov 17 10:23:44 EST 2020


https://github.com/python/cpython/commit/ac472b316cbb22ab8b750a474e991b46d1e92e15
commit: ac472b316cbb22ab8b750a474e991b46d1e92e15
branch: 3.9
author: Yurii Karabas <1998uriyyo at gmail.com>
committer: miss-islington <31488909+miss-islington at users.noreply.github.com>
date: 2020-11-17T07:23:36-08:00
summary:

[3.9] bpo-42345: Fix three issues with typing.Literal parameters (GH-23294) (GH-23335)



Literal equality no longer depends on the order of arguments.

Fix an issue related to `typing.Literal` caching by adding a `typed` parameter to the `typing._tp_cache` function.

Add deduplication of `typing.Literal` arguments.

(cherry picked from commit f03d318ca42578e45405717aedd4ac26ea52aaed)

Co-authored-by: Yurii Karabas <1998uriyyo at gmail.com>

files:
A Misc/NEWS.d/next/Library/2020-11-15-15-23-34.bpo-42345.hiIR7x.rst
M Lib/test/test_typing.py
M Lib/typing.py
M Misc/ACKS

diff --git a/Lib/test/test_typing.py b/Lib/test/test_typing.py
index 67cadc37e4fbe..9d82eec3f5376 100644
--- a/Lib/test/test_typing.py
+++ b/Lib/test/test_typing.py
@@ -532,6 +532,7 @@ def test_repr(self):
         self.assertEqual(repr(Literal[int]), "typing.Literal[int]")
         self.assertEqual(repr(Literal), "typing.Literal")
         self.assertEqual(repr(Literal[None]), "typing.Literal[None]")
+        self.assertEqual(repr(Literal[1, 2, 3, 3]), "typing.Literal[1, 2, 3]")
 
     def test_cannot_init(self):
         with self.assertRaises(TypeError):
@@ -563,6 +564,30 @@ def test_no_multiple_subscripts(self):
         with self.assertRaises(TypeError):
             Literal[1][1]
 
+    def test_equal(self):
+        self.assertNotEqual(Literal[0], Literal[False])
+        self.assertNotEqual(Literal[True], Literal[1])
+        self.assertNotEqual(Literal[1], Literal[2])
+        self.assertNotEqual(Literal[1, True], Literal[1])
+        self.assertEqual(Literal[1], Literal[1])
+        self.assertEqual(Literal[1, 2], Literal[2, 1])
+        self.assertEqual(Literal[1, 2, 3], Literal[1, 2, 3, 3])
+
+    def test_args(self):
+        self.assertEqual(Literal[1, 2, 3].__args__, (1, 2, 3))
+        self.assertEqual(Literal[1, 2, 3, 3].__args__, (1, 2, 3))
+        self.assertEqual(Literal[1, Literal[2], Literal[3, 4]].__args__, (1, 2, 3, 4))
+        # Mutable arguments will not be deduplicated
+        self.assertEqual(Literal[[], []].__args__, ([], []))
+
+    def test_flatten(self):
+        l1 = Literal[Literal[1], Literal[2], Literal[3]]
+        l2 = Literal[Literal[1, 2], 3]
+        l3 = Literal[Literal[1, 2, 3]]
+        for l in l1, l2, l3:
+            self.assertEqual(l, Literal[1, 2, 3])
+            self.assertEqual(l.__args__, (1, 2, 3))
+
 
 XK = TypeVar('XK', str, bytes)
 XV = TypeVar('XV')
diff --git a/Lib/typing.py b/Lib/typing.py
index 6fd67b038834e..14952ec6cc695 100644
--- a/Lib/typing.py
+++ b/Lib/typing.py
@@ -200,6 +200,20 @@ def _check_generic(cls, parameters, elen):
                         f" actual {alen}, expected {elen}")
 
 
+def _deduplicate(params):
+    # Weed out strict duplicates, preserving the first of each occurrence.
+    all_params = set(params)
+    if len(all_params) < len(params):
+        new_params = []
+        for t in params:
+            if t in all_params:
+                new_params.append(t)
+                all_params.remove(t)
+        params = new_params
+        assert not all_params, all_params
+    return params
+
+
 def _remove_dups_flatten(parameters):
     """An internal helper for Union creation and substitution: flatten Unions
     among parameters, then remove duplicates.
@@ -213,38 +227,45 @@ def _remove_dups_flatten(parameters):
             params.extend(p[1:])
         else:
             params.append(p)
-    # Weed out strict duplicates, preserving the first of each occurrence.
-    all_params = set(params)
-    if len(all_params) < len(params):
-        new_params = []
-        for t in params:
-            if t in all_params:
-                new_params.append(t)
-                all_params.remove(t)
-        params = new_params
-        assert not all_params, all_params
+
+    return tuple(_deduplicate(params))
+
+
+def _flatten_literal_params(parameters):
+    """An internal helper for Literal creation: flatten Literals among parameters"""
+    params = []
+    for p in parameters:
+        if isinstance(p, _LiteralGenericAlias):
+            params.extend(p.__args__)
+        else:
+            params.append(p)
     return tuple(params)
 
 
 _cleanups = []
 
 
-def _tp_cache(func):
+def _tp_cache(func=None, /, *, typed=False):
     """Internal wrapper caching __getitem__ of generic types with a fallback to
     original function for non-hashable arguments.
     """
-    cached = functools.lru_cache()(func)
-    _cleanups.append(cached.cache_clear)
+    def decorator(func):
+        cached = functools.lru_cache(typed=typed)(func)
+        _cleanups.append(cached.cache_clear)
 
-    @functools.wraps(func)
-    def inner(*args, **kwds):
-        try:
-            return cached(*args, **kwds)
-        except TypeError:
-            pass  # All real errors (not unhashable args) are raised below.
-        return func(*args, **kwds)
-    return inner
+        @functools.wraps(func)
+        def inner(*args, **kwds):
+            try:
+                return cached(*args, **kwds)
+            except TypeError:
+                pass  # All real errors (not unhashable args) are raised below.
+            return func(*args, **kwds)
+        return inner
+
+    if func is not None:
+        return decorator(func)
 
+    return decorator
 
 def _eval_type(t, globalns, localns, recursive_guard=frozenset()):
     """Evaluate all forward references in the given type t.
@@ -317,6 +338,13 @@ def __subclasscheck__(self, cls):
     def __getitem__(self, parameters):
         return self._getitem(self, parameters)
 
+
+class _LiteralSpecialForm(_SpecialForm, _root=True):
+    @_tp_cache(typed=True)
+    def __getitem__(self, parameters):
+        return self._getitem(self, parameters)
+
+
 @_SpecialForm
 def Any(self, parameters):
     """Special type indicating an unconstrained type.
@@ -434,7 +462,7 @@ def Optional(self, parameters):
     arg = _type_check(parameters, f"{self} requires a single type.")
     return Union[arg, type(None)]
 
- at _SpecialForm
+ at _LiteralSpecialForm
 def Literal(self, parameters):
     """Special typing form to define literal types (a.k.a. value types).
 
@@ -458,7 +486,17 @@ def open_helper(file: str, mode: MODE) -> str:
     """
     # There is no '_type_check' call because arguments to Literal[...] are
     # values, not types.
-    return _GenericAlias(self, parameters)
+    if not isinstance(parameters, tuple):
+        parameters = (parameters,)
+
+    parameters = _flatten_literal_params(parameters)
+
+    try:
+        parameters = tuple(p for p, _ in _deduplicate(list(_value_and_type_iter(parameters))))
+    except TypeError:  # unhashable parameters
+        pass
+
+    return _LiteralGenericAlias(self, parameters)
 
 
 class ForwardRef(_Final, _root=True):
@@ -881,6 +919,22 @@ def __repr__(self):
         return super().__repr__()
 
 
+def _value_and_type_iter(parameters):
+    return ((p, type(p)) for p in parameters)
+
+
+class _LiteralGenericAlias(_GenericAlias, _root=True):
+
+    def __eq__(self, other):
+        if not isinstance(other, _LiteralGenericAlias):
+            return NotImplemented
+
+        return set(_value_and_type_iter(self.__args__)) == set(_value_and_type_iter(other.__args__))
+
+    def __hash__(self):
+        return hash(tuple(_value_and_type_iter(self.__args__)))
+
+
 class Generic:
     """Abstract base class for generic types.
 
diff --git a/Misc/ACKS b/Misc/ACKS
index 9ad9dffe22aea..12a5ac1410a77 100644
--- a/Misc/ACKS
+++ b/Misc/ACKS
@@ -855,6 +855,7 @@ Jan Kanis
 Rafe Kaplan
 Jacob Kaplan-Moss
 Allison Kaptur
+Yurii Karabas
 Janne Karila
 Per Øyvind Karlsen
 Anton Kasyanov
diff --git a/Misc/NEWS.d/next/Library/2020-11-15-15-23-34.bpo-42345.hiIR7x.rst b/Misc/NEWS.d/next/Library/2020-11-15-15-23-34.bpo-42345.hiIR7x.rst
new file mode 100644
index 0000000000000..6339182c3ae72
--- /dev/null
+++ b/Misc/NEWS.d/next/Library/2020-11-15-15-23-34.bpo-42345.hiIR7x.rst
@@ -0,0 +1,2 @@
+Fix various issues with ``typing.Literal`` parameter handling (flatten,
+deduplicate, use type to cache key). Patch provided by Yurii Karabas.



More information about the Python-checkins mailing list