[Python-checkins] bpo-43977: Make sure that tp_flags for pattern matching are inherited correctly. (GH-25813)

markshannon webhook-mailer at python.org
Sun May 2 19:38:32 EDT 2021


https://github.com/python/cpython/commit/33ec88ac81f23668293d101b83367b086c795e5e
commit: 33ec88ac81f23668293d101b83367b086c795e5e
branch: master
author: Mark Shannon <mark at hotpy.org>
committer: markshannon <mark at hotpy.org>
date: 2021-05-03T00:38:22+01:00
summary:

bpo-43977: Make sure that tp_flags for pattern matching are inherited correctly. (GH-25813)

files:
A Misc/NEWS.d/next/Core and Builtins/2021-05-02-11-59-00.bpo-43977.R0hSDo.rst
M Lib/test/test_collections.py
M Lib/test/test_patma.py
M Modules/_abc.c
M Objects/typeobject.c

diff --git a/Lib/test/test_collections.py b/Lib/test/test_collections.py
index 98690d231e606..2ba1a19ead9d8 100644
--- a/Lib/test/test_collections.py
+++ b/Lib/test/test_collections.py
@@ -1967,6 +1967,12 @@ def insert(self, index, value):
         self.assertEqual(len(mss), len(mss2))
         self.assertEqual(list(mss), list(mss2))
 
+    def test_illegal_patma_flags(self):
+        with self.assertRaises(TypeError):
+            class Both(Collection):
+                __abc_tpflags__ = (Sequence.__flags__ | Mapping.__flags__)
+
+
 
 ################################################################################
 ### Counter
diff --git a/Lib/test/test_patma.py b/Lib/test/test_patma.py
index 8a273be7250bb..084d0879f1764 100644
--- a/Lib/test/test_patma.py
+++ b/Lib/test/test_patma.py
@@ -2979,6 +2979,47 @@ def f(x):
         self.assertEqual(f((False, range(10, 20), True)), alts[4])
 
 
+class TestInheritance(unittest.TestCase):
+
+    def test_multiple_inheritance(self):
+        class C:
+            pass
+        class S1(collections.UserList, collections.abc.Mapping):
+            pass
+        class S2(C, collections.UserList, collections.abc.Mapping):
+            pass
+        class S3(list, C, collections.abc.Mapping):
+            pass
+        class S4(collections.UserList, dict, C):
+            pass
+        class M1(collections.UserDict, collections.abc.Sequence):
+            pass
+        class M2(C, collections.UserDict, collections.abc.Sequence):
+            pass
+        class M3(collections.UserDict, C, list):
+            pass
+        class M4(dict, collections.abc.Sequence, C):
+            pass
+        def f(x):
+            match x:
+                case []:
+                    return "seq"
+                case {}:
+                    return "map"
+        def g(x):
+            match x:
+                case {}:
+                    return "map"
+                case []:
+                    return "seq"
+        for Seq in (S1, S2, S3, S4):
+            self.assertEqual(f(Seq()), "seq")
+            self.assertEqual(g(Seq()), "seq")
+        for Map in (M1, M2, M3, M4):
+            self.assertEqual(f(Map()), "map")
+            self.assertEqual(g(Map()), "map")
+
+
 class PerfPatma(TestPatma):
 
     def assertEqual(*_, **__):
diff --git a/Misc/NEWS.d/next/Core and Builtins/2021-05-02-11-59-00.bpo-43977.R0hSDo.rst b/Misc/NEWS.d/next/Core and Builtins/2021-05-02-11-59-00.bpo-43977.R0hSDo.rst
new file mode 100644
index 0000000000000..95aacaf5fa2c3
--- /dev/null
+++ b/Misc/NEWS.d/next/Core and Builtins/2021-05-02-11-59-00.bpo-43977.R0hSDo.rst	
@@ -0,0 +1 @@
+Prevent classes from being both a sequence and a mapping when pattern matching.
diff --git a/Modules/_abc.c b/Modules/_abc.c
index 39261dd3cd579..7720d4051fe9e 100644
--- a/Modules/_abc.c
+++ b/Modules/_abc.c
@@ -467,6 +467,10 @@ _abc__abc_init(PyObject *module, PyObject *self)
                 if (val == -1 && PyErr_Occurred()) {
                     return NULL;
                 }
+                if ((val & COLLECTION_FLAGS) == COLLECTION_FLAGS) {
+                    PyErr_SetString(PyExc_TypeError, "__abc_tpflags__ cannot be both Py_TPFLAGS_SEQUENCE and Py_TPFLAGS_MAPPING");
+                    return NULL;
+                }
                 ((PyTypeObject *)self)->tp_flags |= (val & COLLECTION_FLAGS);
             }
             if (_PyDict_DelItemId(cls->tp_dict, &PyId___abc_tpflags__) < 0) {
@@ -527,9 +531,12 @@ _abc__abc_register_impl(PyObject *module, PyObject *self, PyObject *subclass)
     /* Invalidate negative cache */
     get_abc_state(module)->abc_invalidation_counter++;
 
-    if (PyType_Check(subclass) && PyType_Check(self) &&
-        !PyType_HasFeature((PyTypeObject *)subclass, Py_TPFLAGS_IMMUTABLETYPE))
+    /* Set Py_TPFLAGS_SEQUENCE  or Py_TPFLAGS_MAPPING flag */
+    if (PyType_Check(self) &&
+        !PyType_HasFeature((PyTypeObject *)subclass, Py_TPFLAGS_IMMUTABLETYPE) &&
+        ((PyTypeObject *)self)->tp_flags & COLLECTION_FLAGS)
     {
+        ((PyTypeObject *)subclass)->tp_flags &= ~COLLECTION_FLAGS;
         ((PyTypeObject *)subclass)->tp_flags |= (((PyTypeObject *)self)->tp_flags & COLLECTION_FLAGS);
     }
     Py_INCREF(subclass);
diff --git a/Objects/typeobject.c b/Objects/typeobject.c
index 621bb0ca93022..e511cf9ebfc7e 100644
--- a/Objects/typeobject.c
+++ b/Objects/typeobject.c
@@ -5713,12 +5713,6 @@ inherit_special(PyTypeObject *type, PyTypeObject *base)
     if (PyType_HasFeature(base, _Py_TPFLAGS_MATCH_SELF)) {
         type->tp_flags |= _Py_TPFLAGS_MATCH_SELF;
     }
-    if (PyType_HasFeature(base, Py_TPFLAGS_SEQUENCE)) {
-        type->tp_flags |= Py_TPFLAGS_SEQUENCE;
-    }
-    if (PyType_HasFeature(base, Py_TPFLAGS_MAPPING)) {
-        type->tp_flags |= Py_TPFLAGS_MAPPING;
-    }
 }
 
 static int
@@ -5936,6 +5930,7 @@ inherit_slots(PyTypeObject *type, PyTypeObject *base)
 static int add_operators(PyTypeObject *);
 static int add_tp_new_wrapper(PyTypeObject *type);
 
+#define COLLECTION_FLAGS (Py_TPFLAGS_SEQUENCE | Py_TPFLAGS_MAPPING)
 
 static int
 type_ready_checks(PyTypeObject *type)
@@ -5962,6 +5957,10 @@ type_ready_checks(PyTypeObject *type)
         _PyObject_ASSERT((PyObject *)type, type->tp_as_async->am_send != NULL);
     }
 
+    /* Consistency checks for pattern matching
+     * Py_TPFLAGS_SEQUENCE and Py_TPFLAGS_MAPPING are mutually exclusive */
+    _PyObject_ASSERT((PyObject *)type, (type->tp_flags & COLLECTION_FLAGS) != COLLECTION_FLAGS);
+
     if (type->tp_name == NULL) {
         PyErr_Format(PyExc_SystemError,
                      "Type does not define the tp_name field.");
@@ -6156,6 +6155,12 @@ type_ready_inherit_as_structs(PyTypeObject *type, PyTypeObject *base)
     }
 }
 
+static void
+inherit_patma_flags(PyTypeObject *type, PyTypeObject *base) {
+    if ((type->tp_flags & COLLECTION_FLAGS) == 0) {
+        type->tp_flags |= base->tp_flags & COLLECTION_FLAGS;
+    }
+}
 
 static int
 type_ready_inherit(PyTypeObject *type)
@@ -6175,6 +6180,7 @@ type_ready_inherit(PyTypeObject *type)
             if (inherit_slots(type, (PyTypeObject *)b) < 0) {
                 return -1;
             }
+            inherit_patma_flags(type, (PyTypeObject *)b);
         }
     }
 



More information about the Python-checkins mailing list