[Python-checkins] bpo-44796: Unify TypeVar and ParamSpec substitution (GH-31143)

serhiy-storchaka webhook-mailer at python.org
Fri Mar 11 03:47:56 EST 2022


https://github.com/python/cpython/commit/b6a5d8590c4bfe4553d796b36af03bda8c0d5af5
commit: b6a5d8590c4bfe4553d796b36af03bda8c0d5af5
branch: main
author: Serhiy Storchaka <storchaka at gmail.com>
committer: serhiy-storchaka <storchaka at gmail.com>
date: 2022-03-11T10:47:26+02:00
summary:

bpo-44796: Unify TypeVar and ParamSpec substitution (GH-31143)

Add __typing_subst__() methods to TypeVar and ParamSpec.
Simplify code by using a more object-oriented approach, especially
the C code for types.GenericAlias and the Python code for
collections.abc.Callable.

files:
M Include/internal/pycore_global_strings.h
M Include/internal/pycore_runtime_init.h
M Lib/_collections_abc.py
M Lib/test/test_typing.py
M Lib/typing.py
M Objects/genericaliasobject.c

diff --git a/Include/internal/pycore_global_strings.h b/Include/internal/pycore_global_strings.h
index 755d69a873cdc..35bffa7aff949 100644
--- a/Include/internal/pycore_global_strings.h
+++ b/Include/internal/pycore_global_strings.h
@@ -199,6 +199,7 @@ struct _Py_global_strings {
         STRUCT_FOR_ID(__subclasshook__)
         STRUCT_FOR_ID(__truediv__)
         STRUCT_FOR_ID(__trunc__)
+        STRUCT_FOR_ID(__typing_subst__)
         STRUCT_FOR_ID(__warningregistry__)
         STRUCT_FOR_ID(__weakref__)
         STRUCT_FOR_ID(__xor__)
diff --git a/Include/internal/pycore_runtime_init.h b/Include/internal/pycore_runtime_init.h
index 5ba18267aeb34..20d543a8cbc56 100644
--- a/Include/internal/pycore_runtime_init.h
+++ b/Include/internal/pycore_runtime_init.h
@@ -822,6 +822,7 @@ extern "C" {
                 INIT_ID(__subclasshook__), \
                 INIT_ID(__truediv__), \
                 INIT_ID(__trunc__), \
+                INIT_ID(__typing_subst__), \
                 INIT_ID(__warningregistry__), \
                 INIT_ID(__weakref__), \
                 INIT_ID(__xor__), \
diff --git a/Lib/_collections_abc.py b/Lib/_collections_abc.py
index 40417dc1d3133..86eb042e3a75a 100644
--- a/Lib/_collections_abc.py
+++ b/Lib/_collections_abc.py
@@ -430,25 +430,13 @@ def __new__(cls, origin, args):
             raise TypeError(
                 "Callable must be used as Callable[[arg, ...], result].")
         t_args, t_result = args
-        if isinstance(t_args, list):
+        if isinstance(t_args, (tuple, list)):
             args = (*t_args, t_result)
         elif not _is_param_expr(t_args):
             raise TypeError(f"Expected a list of types, an ellipsis, "
                             f"ParamSpec, or Concatenate. Got {t_args}")
         return super().__new__(cls, origin, args)
 
-    @property
-    def __parameters__(self):
-        params = []
-        for arg in self.__args__:
-            # Looks like a genericalias
-            if hasattr(arg, "__parameters__") and isinstance(arg.__parameters__, tuple):
-                params.extend(arg.__parameters__)
-            else:
-                if _is_typevarlike(arg):
-                    params.append(arg)
-        return tuple(dict.fromkeys(params))
-
     def __repr__(self):
         if len(self.__args__) == 2 and _is_param_expr(self.__args__[0]):
             return super().__repr__()
@@ -468,57 +456,24 @@ def __getitem__(self, item):
         # code is copied from typing's _GenericAlias and the builtin
         # types.GenericAlias.
 
-        # A special case in PEP 612 where if X = Callable[P, int],
-        # then X[int, str] == X[[int, str]].
-        param_len = len(self.__parameters__)
-        if param_len == 0:
-            raise TypeError(f'{self} is not a generic class')
         if not isinstance(item, tuple):
             item = (item,)
-        if (param_len == 1 and _is_param_expr(self.__parameters__[0])
+        # A special case in PEP 612 where if X = Callable[P, int],
+        # then X[int, str] == X[[int, str]].
+        if (len(self.__parameters__) == 1
+                and _is_param_expr(self.__parameters__[0])
                 and item and not _is_param_expr(item[0])):
-            item = (list(item),)
-        item_len = len(item)
-        if item_len != param_len:
-            raise TypeError(f'Too {"many" if item_len > param_len else "few"}'
-                            f' arguments for {self};'
-                            f' actual {item_len}, expected {param_len}')
-        subst = dict(zip(self.__parameters__, item))
-        new_args = []
-        for arg in self.__args__:
-            if _is_typevarlike(arg):
-                if _is_param_expr(arg):
-                    arg = subst[arg]
-                    if not _is_param_expr(arg):
-                        raise TypeError(f"Expected a list of types, an ellipsis, "
-                                        f"ParamSpec, or Concatenate. Got {arg}")
-                else:
-                    arg = subst[arg]
-            # Looks like a GenericAlias
-            elif hasattr(arg, '__parameters__') and isinstance(arg.__parameters__, tuple):
-                subparams = arg.__parameters__
-                if subparams:
-                    subargs = tuple(subst[x] for x in subparams)
-                    arg = arg[subargs]
-            if isinstance(arg, tuple):
-                new_args.extend(arg)
-            else:
-                new_args.append(arg)
+            item = (item,)
+
+        new_args = super().__getitem__(item).__args__
 
         # args[0] occurs due to things like Z[[int, str, bool]] from PEP 612
-        if not isinstance(new_args[0], list):
+        if not isinstance(new_args[0], (tuple, list)):
             t_result = new_args[-1]
             t_args = new_args[:-1]
             new_args = (t_args, t_result)
         return _CallableGenericAlias(Callable, tuple(new_args))
 
-
-def _is_typevarlike(arg):
-    obj = type(arg)
-    # looks like a TypeVar/ParamSpec
-    return (obj.__module__ == 'typing'
-            and obj.__name__ in {'ParamSpec', 'TypeVar'})
-
 def _is_param_expr(obj):
     """Checks if obj matches either a list of types, ``...``, ``ParamSpec`` or
     ``_ConcatenateGenericAlias`` from typing.py
diff --git a/Lib/test/test_typing.py b/Lib/test/test_typing.py
index fc596e4d90b21..91b2e77e97b5a 100644
--- a/Lib/test/test_typing.py
+++ b/Lib/test/test_typing.py
@@ -360,10 +360,31 @@ def test_no_bivariant(self):
         with self.assertRaises(ValueError):
             TypeVar('T', covariant=True, contravariant=True)
 
+    def test_var_substitution(self):
+        T = TypeVar('T')
+        subst = T.__typing_subst__
+        self.assertIs(subst(int), int)
+        self.assertEqual(subst(list[int]), list[int])
+        self.assertEqual(subst(List[int]), List[int])
+        self.assertEqual(subst(List), List)
+        self.assertIs(subst(Any), Any)
+        self.assertIs(subst(None), type(None))
+        self.assertIs(subst(T), T)
+        self.assertEqual(subst(int|str), int|str)
+        self.assertEqual(subst(Union[int, str]), Union[int, str])
+
     def test_bad_var_substitution(self):
         T = TypeVar('T')
-        for arg in (), (int, str):
+        P = ParamSpec("P")
+        bad_args = (
+            42, ..., [int], (), (int, str), Union,
+            Generic, Generic[T], Protocol, Protocol[T],
+            Final, Final[int], ClassVar, ClassVar[int],
+        )
+        for arg in bad_args:
             with self.subTest(arg=arg):
+                with self.assertRaises(TypeError):
+                    T.__typing_subst__(arg)
                 with self.assertRaises(TypeError):
                     List[T][arg]
                 with self.assertRaises(TypeError):
@@ -1110,8 +1131,7 @@ def test_var_substitution(self):
         C2 = Callable[[KT, T], VT]
         C3 = Callable[..., T]
         self.assertEqual(C1[str], Callable[[int, str], str])
-        if Callable is typing.Callable:
-            self.assertEqual(C1[None], Callable[[int, type(None)], type(None)])
+        self.assertEqual(C1[None], Callable[[int, type(None)], type(None)])
         self.assertEqual(C2[int, float, str], Callable[[int, float], str])
         self.assertEqual(C3[int], Callable[..., int])
         self.assertEqual(C3[NoReturn], Callable[..., NoReturn])
@@ -2696,7 +2716,10 @@ def test_all_repr_eq_any(self):
         for obj in objs:
             self.assertNotEqual(repr(obj), '')
             self.assertEqual(obj, obj)
-            if getattr(obj, '__parameters__', None) and len(obj.__parameters__) == 1:
+            if (getattr(obj, '__parameters__', None)
+                    and not isinstance(obj, typing.TypeVar)
+                    and isinstance(obj.__parameters__, tuple)
+                    and len(obj.__parameters__) == 1):
                 self.assertEqual(obj[Any].__args__, (Any,))
             if isinstance(obj, type):
                 for base in obj.__mro__:
@@ -5748,33 +5771,30 @@ class X(Generic[P, P2]):
         self.assertEqual(G1.__args__, ((int, str), (bytes,)))
         self.assertEqual(G2.__args__, ((int,), (str, bytes)))
 
+    def test_var_substitution(self):
+        T = TypeVar("T")
+        P = ParamSpec("P")
+        subst = P.__typing_subst__
+        self.assertEqual(subst((int, str)), (int, str))
+        self.assertEqual(subst([int, str]), (int, str))
+        self.assertEqual(subst([None]), (type(None),))
+        self.assertIs(subst(...), ...)
+        self.assertIs(subst(P), P)
+        self.assertEqual(subst(Concatenate[int, P]), Concatenate[int, P])
+
     def test_bad_var_substitution(self):
         T = TypeVar('T')
         P = ParamSpec('P')
         bad_args = (42, int, None, T, int|str, Union[int, str])
         for arg in bad_args:
             with self.subTest(arg=arg):
+                with self.assertRaises(TypeError):
+                    P.__typing_subst__(arg)
                 with self.assertRaises(TypeError):
                     typing.Callable[P, T][arg, str]
                 with self.assertRaises(TypeError):
                     collections.abc.Callable[P, T][arg, str]
 
-    def test_no_paramspec_in__parameters__(self):
-        # ParamSpec should not be found in __parameters__
-        # of generics. Usages outside Callable, Concatenate
-        # and Generic are invalid.
-        T = TypeVar("T")
-        P = ParamSpec("P")
-        self.assertNotIn(P, List[P].__parameters__)
-        self.assertIn(T, Tuple[T, P].__parameters__)
-
-        # Test for consistency with builtin generics.
-        self.assertNotIn(P, list[P].__parameters__)
-        self.assertIn(T, tuple[T, P].__parameters__)
-
-        self.assertNotIn(P, (list[P] | int).__parameters__)
-        self.assertIn(T, (tuple[T, P] | int).__parameters__)
-
     def test_paramspec_in_nested_generics(self):
         # Although ParamSpec should not be found in __parameters__ of most
         # generics, they probably should be found when nested in
diff --git a/Lib/typing.py b/Lib/typing.py
index e3015563b3e8c..062c01ef2a9b9 100644
--- a/Lib/typing.py
+++ b/Lib/typing.py
@@ -179,7 +179,9 @@ def _type_check(arg, msg, is_argument=True, module=None, *, allow_special_forms=
     if (isinstance(arg, _GenericAlias) and
             arg.__origin__ in invalid_generic_forms):
         raise TypeError(f"{arg} is not valid as type argument")
-    if arg in (Any, NoReturn, Never, Self, ClassVar, Final, TypeAlias):
+    if arg in (Any, NoReturn, Never, Self, TypeAlias):
+        return arg
+    if allow_special_forms and arg in (ClassVar, Final):
         return arg
     if isinstance(arg, _SpecialForm) or arg in (Generic, Protocol):
         raise TypeError(f"Plain {arg} is not valid as type argument")
@@ -217,21 +219,22 @@ def _type_repr(obj):
     return repr(obj)
 
 
-def _collect_type_vars(types_, typevar_types=None):
-    """Collect all type variable contained
-    in types in order of first appearance (lexicographic order). For example::
+def _collect_parameters(args):
+    """Collect all type variables and parameter specifications in args
+    in order of first appearance (lexicographic order). For example::
 
-        _collect_type_vars((T, List[S, T])) == (T, S)
+        _collect_parameters((T, Callable[P, T])) == (T, P)
     """
-    if typevar_types is None:
-        typevar_types = TypeVar
-    tvars = []
-    for t in types_:
-        if isinstance(t, typevar_types) and t not in tvars:
-            tvars.append(t)
-        if isinstance(t, (_GenericAlias, GenericAlias, types.UnionType)):
-            tvars.extend([t for t in t.__parameters__ if t not in tvars])
-    return tuple(tvars)
+    parameters = []
+    for t in args:
+        if hasattr(t, '__typing_subst__'):
+            if t not in parameters:
+                parameters.append(t)
+        else:
+            for x in getattr(t, '__parameters__', ()):
+                if x not in parameters:
+                    parameters.append(x)
+    return tuple(parameters)
 
 
 def _check_generic(cls, parameters, elen):
@@ -671,7 +674,6 @@ def Concatenate(self, parameters):
     msg = "Concatenate[arg, ...]: each arg must be a type."
     parameters = (*(_type_check(p, msg) for p in parameters[:-1]), parameters[-1])
     return _ConcatenateGenericAlias(self, parameters,
-                                    _typevar_types=(TypeVar, ParamSpec),
                                     _paramspec_tvars=True)
 
 
@@ -909,6 +911,11 @@ def __init__(self, name, *constraints, bound=None,
         if def_mod != 'typing':
             self.__module__ = def_mod
 
+    def __typing_subst__(self, arg):
+        msg = "Parameters to generic types must be types."
+        arg = _type_check(arg, msg, is_argument=True)
+        return arg
+
 
 class TypeVarTuple(_Final, _Immutable, _root=True):
     """Type variable tuple.
@@ -942,6 +949,9 @@ def __iter__(self):
     def __repr__(self):
         return self._name
 
+    def __typing_subst__(self, arg):
+        raise AssertionError
+
 
 class ParamSpecArgs(_Final, _Immutable, _root=True):
     """The args for a ParamSpec object.
@@ -1052,6 +1062,14 @@ def __init__(self, name, *, bound=None, covariant=False, contravariant=False):
         if def_mod != 'typing':
             self.__module__ = def_mod
 
+    def __typing_subst__(self, arg):
+        if isinstance(arg, (list, tuple)):
+            arg = tuple(_type_check(a, "Expected a type.") for a in arg)
+        elif not _is_param_expr(arg):
+            raise TypeError(f"Expected a list of types, an ellipsis, "
+                            f"ParamSpec, or Concatenate. Got {arg}")
+        return arg
+
 
 def _is_dunder(attr):
     return attr.startswith('__') and attr.endswith('__')
@@ -1106,7 +1124,7 @@ def __getattr__(self, attr):
 
     def __setattr__(self, attr, val):
         if _is_dunder(attr) or attr in {'_name', '_inst', '_nparams',
-                                        '_typevar_types', '_paramspec_tvars'}:
+                                        '_paramspec_tvars'}:
             super().__setattr__(attr, val)
         else:
             setattr(self.__origin__, attr, val)
@@ -1199,7 +1217,6 @@ class _GenericAlias(_BaseGenericAlias, _root=True):
     #     TypeVar[bool]
 
     def __init__(self, origin, args, *, inst=True, name=None,
-                 _typevar_types=(TypeVar, TypeVarTuple),
                  _paramspec_tvars=False):
         super().__init__(origin, inst=inst, name=name)
         if not isinstance(args, tuple):
@@ -1207,8 +1224,7 @@ def __init__(self, origin, args, *, inst=True, name=None,
         self.__args__ = tuple(... if a is _TypingEllipsis else
                               () if a is _TypingEmpty else
                               a for a in args)
-        self.__parameters__ = _collect_type_vars(args, typevar_types=_typevar_types)
-        self._typevar_types = _typevar_types
+        self.__parameters__ = _collect_parameters(args)
         self._paramspec_tvars = _paramspec_tvars
         if not name:
             self.__module__ = origin.__module__
@@ -1291,26 +1307,20 @@ def _determine_new_args(self, args):
         new_args = []
         for old_arg in self.__args__:
 
-            if isinstance(old_arg, ParamSpec):
-                new_arg = new_arg_by_param[old_arg]
-                if not _is_param_expr(new_arg):
-                    raise TypeError(f"Expected a list of types, an ellipsis, "
-                                    f"ParamSpec, or Concatenate. Got {new_arg}")
-            elif isinstance(old_arg, self._typevar_types):
-                new_arg = new_arg_by_param[old_arg]
-            elif (TypeVarTuple in self._typevar_types
-                  and _is_unpacked_typevartuple(old_arg)):
+            if _is_unpacked_typevartuple(old_arg):
                 original_typevartuple = old_arg.__parameters__[0]
                 new_arg = new_arg_by_param[original_typevartuple]
-            elif isinstance(old_arg, (_GenericAlias, GenericAlias, types.UnionType)):
-                subparams = old_arg.__parameters__
-                if not subparams:
-                    new_arg = old_arg
-                else:
-                    subargs = tuple(new_arg_by_param[x] for x in subparams)
-                    new_arg = old_arg[subargs]
             else:
-                new_arg = old_arg
+                substfunc = getattr(old_arg, '__typing_subst__', None)
+                if substfunc:
+                    new_arg = substfunc(new_arg_by_param[old_arg])
+                else:
+                    subparams = getattr(old_arg, '__parameters__', ())
+                    if not subparams:
+                        new_arg = old_arg
+                    else:
+                        subargs = tuple(new_arg_by_param[x] for x in subparams)
+                        new_arg = old_arg[subargs]
 
             if self.__origin__ == collections.abc.Callable and isinstance(new_arg, tuple):
                 # Consider the following `Callable`.
@@ -1342,7 +1352,6 @@ def _determine_new_args(self, args):
 
     def copy_with(self, args):
         return self.__class__(self.__origin__, args, name=self._name, inst=self._inst,
-                              _typevar_types=self._typevar_types,
                               _paramspec_tvars=self._paramspec_tvars)
 
     def __repr__(self):
@@ -1454,7 +1463,6 @@ class _CallableType(_SpecialGenericAlias, _root=True):
     def copy_with(self, params):
         return _CallableGenericAlias(self.__origin__, params,
                                      name=self._name, inst=self._inst,
-                                     _typevar_types=(TypeVar, ParamSpec),
                                      _paramspec_tvars=True)
 
     def __getitem__(self, params):
@@ -1675,11 +1683,8 @@ def __class_getitem__(cls, params):
                 # don't check variadic generic arity at runtime (to reduce
                 # complexity of typing.py).
                 _check_generic(cls, params, len(cls.__parameters__))
-        return _GenericAlias(
-            cls, params,
-            _typevar_types=(TypeVar, TypeVarTuple, ParamSpec),
-            _paramspec_tvars=True,
-        )
+        return _GenericAlias(cls, params,
+                             _paramspec_tvars=True)
 
     def __init_subclass__(cls, *args, **kwargs):
         super().__init_subclass__(*args, **kwargs)
@@ -1691,9 +1696,7 @@ def __init_subclass__(cls, *args, **kwargs):
         if error:
             raise TypeError("Cannot inherit from plain Generic")
         if '__orig_bases__' in cls.__dict__:
-            tvars = _collect_type_vars(
-                cls.__orig_bases__, (TypeVar, TypeVarTuple, ParamSpec)
-            )
+            tvars = _collect_parameters(cls.__orig_bases__)
             # Look for Generic[T1, ..., Tn].
             # If found, tvars must be a subset of it.
             # If not found, tvars is it.
diff --git a/Objects/genericaliasobject.c b/Objects/genericaliasobject.c
index b41644910f5d2..45caf2e2ee7db 100644
--- a/Objects/genericaliasobject.c
+++ b/Objects/genericaliasobject.c
@@ -152,25 +152,6 @@ ga_repr(PyObject *self)
     return NULL;
 }
 
-// isinstance(obj, TypeVar) without importing typing.py.
-// Returns -1 for errors.
-static int
-is_typevar(PyObject *obj)
-{
-    PyTypeObject *type = Py_TYPE(obj);
-    if (strcmp(type->tp_name, "TypeVar") != 0) {
-        return 0;
-    }
-    PyObject *module = PyObject_GetAttrString((PyObject *)type, "__module__");
-    if (module == NULL) {
-        return -1;
-    }
-    int res = PyUnicode_Check(module)
-        && _PyUnicode_EqualToASCIIString(module, "typing");
-    Py_DECREF(module);
-    return res;
-}
-
 // Index of item in self[:len], or -1 if not found (self is a tuple)
 static Py_ssize_t
 tuple_index(PyObject *self, Py_ssize_t len, PyObject *item)
@@ -205,13 +186,14 @@ _Py_make_parameters(PyObject *args)
     Py_ssize_t iparam = 0;
     for (Py_ssize_t iarg = 0; iarg < nargs; iarg++) {
         PyObject *t = PyTuple_GET_ITEM(args, iarg);
-        int typevar = is_typevar(t);
-        if (typevar < 0) {
+        PyObject *subst;
+        if (_PyObject_LookupAttr(t, &_Py_ID(__typing_subst__), &subst) < 0) {
             Py_DECREF(parameters);
             return NULL;
         }
-        if (typevar) {
+        if (subst) {
             iparam += tuple_add(parameters, iparam, t);
+            Py_DECREF(subst);
         }
         else {
             PyObject *subparams;
@@ -295,7 +277,7 @@ _Py_subs_parameters(PyObject *self, PyObject *args, PyObject *parameters, PyObje
     Py_ssize_t nparams = PyTuple_GET_SIZE(parameters);
     if (nparams == 0) {
         return PyErr_Format(PyExc_TypeError,
-                            "There are no type variables left in %R",
+                            "%R is not a generic class",
                             self);
     }
     int is_tuple = PyTuple_Check(item);
@@ -320,23 +302,23 @@ _Py_subs_parameters(PyObject *self, PyObject *args, PyObject *parameters, PyObje
     }
     for (Py_ssize_t iarg = 0; iarg < nargs; iarg++) {
         PyObject *arg = PyTuple_GET_ITEM(args, iarg);
-        int typevar = is_typevar(arg);
-        if (typevar < 0) {
+        PyObject *subst;
+        if (_PyObject_LookupAttr(arg, &_Py_ID(__typing_subst__), &subst) < 0) {
             Py_DECREF(newargs);
             return NULL;
         }
-        if (typevar) {
+        if (subst) {
             Py_ssize_t iparam = tuple_index(parameters, nparams, arg);
             assert(iparam >= 0);
-            arg = argitems[iparam];
-            Py_INCREF(arg);
+            arg = PyObject_CallOneArg(subst, argitems[iparam]);
+            Py_DECREF(subst);
         }
         else {
             arg = subs_tvars(arg, parameters, argitems);
-            if (arg == NULL) {
-                Py_DECREF(newargs);
-                return NULL;
-            }
+        }
+        if (arg == NULL) {
+            Py_DECREF(newargs);
+            return NULL;
         }
         PyTuple_SET_ITEM(newargs, iarg, arg);
     }



More information about the Python-checkins mailing list