[Python-checkins] gh-91051: allow setting a callback hook on PyType_Modified (GH-97875)

markshannon webhook-mailer at python.org
Fri Oct 21 09:42:11 EDT 2022


https://github.com/python/cpython/commit/82ccbf69a842db25d8117f1c41b47aa5b4ed96ab
commit: 82ccbf69a842db25d8117f1c41b47aa5b4ed96ab
branch: main
author: Carl Meyer <carl at oddbird.net>
committer: markshannon <mark at hotpy.org>
date: 2022-10-21T14:41:51+01:00
summary:

gh-91051: allow setting a callback hook on PyType_Modified (GH-97875)

files:
A Misc/NEWS.d/next/C API/2022-10-05-10-43-32.gh-issue-91051.ODDRsQ.rst
M Doc/c-api/type.rst
M Doc/whatsnew/3.12.rst
M Include/cpython/object.h
M Include/internal/pycore_interp.h
M Lib/test/test_capi.py
M Lib/test/test_sys.py
M Modules/_testcapimodule.c
M Objects/typeobject.c

diff --git a/Doc/c-api/type.rst b/Doc/c-api/type.rst
index 1dc05001adfa..7b5d1fac40ed 100644
--- a/Doc/c-api/type.rst
+++ b/Doc/c-api/type.rst
@@ -57,6 +57,55 @@ Type Objects
    modification of the attributes or base classes of the type.
 
 
+.. c:function:: int PyType_AddWatcher(PyType_WatchCallback callback)
+
+   Register *callback* as a type watcher. Return a non-negative integer ID
+   which must be passed to future calls to :c:func:`PyType_Watch`. In case of
+   error (e.g. no more watcher IDs available), return ``-1`` and set an
+   exception.
+
+   .. versionadded:: 3.12
+
+
+.. c:function:: int PyType_ClearWatcher(int watcher_id)
+
+   Clear watcher identified by *watcher_id* (previously returned from
+   :c:func:`PyType_AddWatcher`). Return ``0`` on success, ``-1`` on error (e.g.
+   if *watcher_id* was never registered).
+
+   An extension should never call ``PyType_ClearWatcher`` with a *watcher_id*
+   that was not returned to it by a previous call to
+   :c:func:`PyType_AddWatcher`.
+
+   .. versionadded:: 3.12
+
+
+.. c:function:: int PyType_Watch(int watcher_id, PyObject *type)
+
+   Mark *type* as watched. The callback granted *watcher_id* by
+   :c:func:`PyType_AddWatcher` will be called whenever
+   :c:func:`PyType_Modified` reports a change to *type*. (The callback may be
+   called only once for a series of consecutive modifications to *type*, if
+   :c:func:`_PyType_Lookup` is not called on *type* between the modifications;
+   this is an implementation detail and subject to change.)
+
+   An extension should never call ``PyType_Watch`` with a *watcher_id* that was
+   not returned to it by a previous call to :c:func:`PyType_AddWatcher`.
+
+   .. versionadded:: 3.12
+
+
+.. c:type:: int (*PyType_WatchCallback)(PyTypeObject *type)
+
+   Type of a type-watcher callback function.
+
+   The callback must not modify *type* or cause :c:func:`PyType_Modified` to be
+   called on *type* or any type in its MRO; violating this rule could cause
+   infinite recursion.
+
+   .. versionadded:: 3.12
+
+
 .. c:function:: int PyType_HasFeature(PyTypeObject *o, int feature)
 
    Return non-zero if the type object *o* sets the feature *feature*.
diff --git a/Doc/whatsnew/3.12.rst b/Doc/whatsnew/3.12.rst
index 525efc405c85..3e0b106c4a04 100644
--- a/Doc/whatsnew/3.12.rst
+++ b/Doc/whatsnew/3.12.rst
@@ -587,6 +587,12 @@ New Features
   :c:func:`PyDict_AddWatch` and related APIs to be called whenever a dictionary
   is modified. This is intended for use by optimizing interpreters, JIT
   compilers, or debuggers.
+  (Contributed by Carl Meyer in :gh:`91052`.)
+
+* Added :c:func:`PyType_AddWatcher` and :c:func:`PyType_Watch` API to register
+  callbacks to receive notification on changes to a type.
+  (Contributed by Carl Meyer in :gh:`91051`.)
+
 
 Porting to Python 3.12
 ----------------------
diff --git a/Include/cpython/object.h b/Include/cpython/object.h
index c80fc1df0e0b..900b52321dff 100644
--- a/Include/cpython/object.h
+++ b/Include/cpython/object.h
@@ -224,6 +224,9 @@ struct _typeobject {
 
     destructor tp_finalize;
     vectorcallfunc tp_vectorcall;
+
+    /* bitset of which type-watchers care about this type */
+    char tp_watched;
 };
 
 /* This struct is used by the specializer
@@ -510,3 +513,11 @@ Py_DEPRECATED(3.11) typedef int UsingDeprecatedTrashcanMacro;
 
 PyAPI_FUNC(int) _PyObject_VisitManagedDict(PyObject *obj, visitproc visit, void *arg);
 PyAPI_FUNC(void) _PyObject_ClearManagedDict(PyObject *obj);
+
+#define TYPE_MAX_WATCHERS 8
+
+typedef int(*PyType_WatchCallback)(PyTypeObject *);
+PyAPI_FUNC(int) PyType_AddWatcher(PyType_WatchCallback callback);
+PyAPI_FUNC(int) PyType_ClearWatcher(int watcher_id);
+PyAPI_FUNC(int) PyType_Watch(int watcher_id, PyObject *type);
+PyAPI_FUNC(int) PyType_Unwatch(int watcher_id, PyObject *type);
diff --git a/Include/internal/pycore_interp.h b/Include/internal/pycore_interp.h
index e643c7e9e4ed..9017e84e15af 100644
--- a/Include/internal/pycore_interp.h
+++ b/Include/internal/pycore_interp.h
@@ -166,6 +166,7 @@ struct _is {
     struct atexit_state atexit;
 
     PyObject *audit_hooks;
+    PyType_WatchCallback type_watchers[TYPE_MAX_WATCHERS];
 
     struct _Py_unicode_state unicode;
     struct _Py_float_state float_state;
diff --git a/Lib/test/test_capi.py b/Lib/test/test_capi.py
index a2183cfb0fdf..364c607b3c18 100644
--- a/Lib/test/test_capi.py
+++ b/Lib/test/test_capi.py
@@ -2,7 +2,7 @@
 # these are all functions _testcapi exports whose name begins with 'test_'.
 
 from collections import OrderedDict
-from contextlib import contextmanager
+from contextlib import contextmanager, ExitStack
 import _thread
 import importlib.machinery
 import importlib.util
@@ -1606,5 +1606,172 @@ def test_clear_unassigned_watcher_id(self):
             self.clear_watcher(1)
 
 
+class TestTypeWatchers(unittest.TestCase):
+    # types of watchers testcapimodule can add:
+    TYPES = 0    # appends modified types to global event list
+    ERROR = 1    # unconditionally sets and signals a RuntimeError
+    WRAP = 2     # appends modified type wrapped in list to global event list
+
+    # duplicating the C constant
+    TYPE_MAX_WATCHERS = 8
+
+    def add_watcher(self, kind=TYPES):
+        return _testcapi.add_type_watcher(kind)
+
+    def clear_watcher(self, watcher_id):
+        _testcapi.clear_type_watcher(watcher_id)
+
+    @contextmanager
+    def watcher(self, kind=TYPES):
+        wid = self.add_watcher(kind)
+        try:
+            yield wid
+        finally:
+            self.clear_watcher(wid)
+
+    def assert_events(self, expected):
+        actual = _testcapi.get_type_modified_events()
+        self.assertEqual(actual, expected)
+
+    def watch(self, wid, t):
+        _testcapi.watch_type(wid, t)
+
+    def unwatch(self, wid, t):
+        _testcapi.unwatch_type(wid, t)
+
+    def test_watch_type(self):
+        class C: pass
+        with self.watcher() as wid:
+            self.watch(wid, C)
+            C.foo = "bar"
+            self.assert_events([C])
+
+    def test_event_aggregation(self):
+        class C: pass
+        with self.watcher() as wid:
+            self.watch(wid, C)
+            C.foo = "bar"
+            C.bar = "baz"
+            # only one event registered for both modifications
+            self.assert_events([C])
+
+    def test_lookup_resets_aggregation(self):
+        class C: pass
+        with self.watcher() as wid:
+            self.watch(wid, C)
+            C.foo = "bar"
+            # lookup resets type version tag
+            self.assertEqual(C.foo, "bar")
+            C.bar = "baz"
+            # both events registered
+            self.assert_events([C, C])
+
+    def test_unwatch_type(self):
+        class C: pass
+        with self.watcher() as wid:
+            self.watch(wid, C)
+            C.foo = "bar"
+            self.assertEqual(C.foo, "bar")
+            self.assert_events([C])
+            self.unwatch(wid, C)
+            C.bar = "baz"
+            self.assert_events([C])
+
+    def test_clear_watcher(self):
+        class C: pass
+        # outer watcher is unused, it's just to keep events list alive
+        with self.watcher() as _:
+            with self.watcher() as wid:
+                self.watch(wid, C)
+                C.foo = "bar"
+                self.assertEqual(C.foo, "bar")
+                self.assert_events([C])
+            C.bar = "baz"
+            # Watcher on C has been cleared, no new event
+            self.assert_events([C])
+
+    def test_watch_type_subclass(self):
+        class C: pass
+        class D(C): pass
+        with self.watcher() as wid:
+            self.watch(wid, D)
+            C.foo = "bar"
+            self.assert_events([D])
+
+    def test_error(self):
+        class C: pass
+        with self.watcher(kind=self.ERROR) as wid:
+            self.watch(wid, C)
+            with catch_unraisable_exception() as cm:
+                C.foo = "bar"
+                self.assertIs(cm.unraisable.object, C)
+                self.assertEqual(str(cm.unraisable.exc_value), "boom!")
+            self.assert_events([])
+
+    def test_two_watchers(self):
+        class C1: pass
+        class C2: pass
+        with self.watcher() as wid1:
+            with self.watcher(kind=self.WRAP) as wid2:
+                self.assertNotEqual(wid1, wid2)
+                self.watch(wid1, C1)
+                self.watch(wid2, C2)
+                C1.foo = "bar"
+                C2.hmm = "baz"
+                self.assert_events([C1, [C2]])
+
+    def test_watch_non_type(self):
+        with self.watcher() as wid:
+            with self.assertRaisesRegex(ValueError, r"Cannot watch non-type"):
+                self.watch(wid, 1)
+
+    def test_watch_out_of_range_watcher_id(self):
+        class C: pass
+        with self.assertRaisesRegex(ValueError, r"Invalid type watcher ID -1"):
+            self.watch(-1, C)
+        with self.assertRaisesRegex(ValueError, r"Invalid type watcher ID 8"):
+            self.watch(self.TYPE_MAX_WATCHERS, C)
+
+    def test_watch_unassigned_watcher_id(self):
+        class C: pass
+        with self.assertRaisesRegex(ValueError, r"No type watcher set for ID 1"):
+            self.watch(1, C)
+
+    def test_unwatch_non_type(self):
+        with self.watcher() as wid:
+            with self.assertRaisesRegex(ValueError, r"Cannot watch non-type"):
+                self.unwatch(wid, 1)
+
+    def test_unwatch_out_of_range_watcher_id(self):
+        class C: pass
+        with self.assertRaisesRegex(ValueError, r"Invalid type watcher ID -1"):
+            self.unwatch(-1, C)
+        with self.assertRaisesRegex(ValueError, r"Invalid type watcher ID 8"):
+            self.unwatch(self.TYPE_MAX_WATCHERS, C)
+
+    def test_unwatch_unassigned_watcher_id(self):
+        class C: pass
+        with self.assertRaisesRegex(ValueError, r"No type watcher set for ID 1"):
+            self.unwatch(1, C)
+
+    def test_clear_out_of_range_watcher_id(self):
+        with self.assertRaisesRegex(ValueError, r"Invalid type watcher ID -1"):
+            self.clear_watcher(-1)
+        with self.assertRaisesRegex(ValueError, r"Invalid type watcher ID 8"):
+            self.clear_watcher(self.TYPE_MAX_WATCHERS)
+
+    def test_clear_unassigned_watcher_id(self):
+        with self.assertRaisesRegex(ValueError, r"No type watcher set for ID 1"):
+            self.clear_watcher(1)
+
+    def test_no_more_ids_available(self):
+        contexts = [self.watcher() for i in range(self.TYPE_MAX_WATCHERS)]
+        with ExitStack() as stack:
+            for ctx in contexts:
+                stack.enter_context(ctx)
+            with self.assertRaisesRegex(RuntimeError, r"no more type watcher IDs"):
+                self.add_watcher()
+
+
 if __name__ == "__main__":
     unittest.main()
diff --git a/Lib/test/test_sys.py b/Lib/test/test_sys.py
index 41482734872e..9184e9a42f19 100644
--- a/Lib/test/test_sys.py
+++ b/Lib/test/test_sys.py
@@ -1521,7 +1521,7 @@ def delx(self): del self.__x
         check((1,2,3), vsize('') + 3*self.P)
         # type
         # static type: PyTypeObject
-        fmt = 'P2nPI13Pl4Pn9Pn12PIP'
+        fmt = 'P2nPI13Pl4Pn9Pn12PIPc'
         s = vsize('2P' + fmt)
         check(int, s)
         # class
diff --git a/Misc/NEWS.d/next/C API/2022-10-05-10-43-32.gh-issue-91051.ODDRsQ.rst b/Misc/NEWS.d/next/C API/2022-10-05-10-43-32.gh-issue-91051.ODDRsQ.rst
new file mode 100644
index 000000000000..c18e2d61f397
--- /dev/null
+++ b/Misc/NEWS.d/next/C API/2022-10-05-10-43-32.gh-issue-91051.ODDRsQ.rst	
@@ -0,0 +1,2 @@
+Add :c:func:`PyType_Watch` and related APIs to allow callbacks on
+:c:func:`PyType_Modified`.
diff --git a/Modules/_testcapimodule.c b/Modules/_testcapimodule.c
index 624e878b20d8..194b2de2f0d6 100644
--- a/Modules/_testcapimodule.c
+++ b/Modules/_testcapimodule.c
@@ -5695,6 +5695,128 @@ function_get_module(PyObject *self, PyObject *func)
 }
 
 
+// type watchers
+
+static PyObject *g_type_modified_events;
+static int g_type_watchers_installed;
+
+static int
+type_modified_callback(PyTypeObject *type)
+{
+    assert(PyList_Check(g_type_modified_events));
+    if(PyList_Append(g_type_modified_events, (PyObject *)type) < 0) {
+        return -1;
+    }
+    return 0;
+}
+
+static int
+type_modified_callback_wrap(PyTypeObject *type)
+{
+    assert(PyList_Check(g_type_modified_events));
+    PyObject *list = PyList_New(0);
+    if (!list) {
+        return -1;
+    }
+    if (PyList_Append(list, (PyObject *)type) < 0) {
+        Py_DECREF(list);
+        return -1;
+    }
+    if (PyList_Append(g_type_modified_events, list) < 0) {
+        Py_DECREF(list);
+        return -1;
+    }
+    Py_DECREF(list);
+    return 0;
+}
+
+static int
+type_modified_callback_error(PyTypeObject *type)
+{
+    PyErr_SetString(PyExc_RuntimeError, "boom!");
+    return -1;
+}
+
+static PyObject *
+add_type_watcher(PyObject *self, PyObject *kind)
+{
+    int watcher_id;
+    assert(PyLong_Check(kind));
+    long kind_l = PyLong_AsLong(kind);
+    if (kind_l == 2) {
+        watcher_id = PyType_AddWatcher(type_modified_callback_wrap);
+    } else if (kind_l == 1) {
+        watcher_id = PyType_AddWatcher(type_modified_callback_error);
+    } else {
+        watcher_id = PyType_AddWatcher(type_modified_callback);
+    }
+    if (watcher_id < 0) {
+        return NULL;
+    }
+    if (!g_type_watchers_installed) {
+        assert(!g_type_modified_events);
+        if (!(g_type_modified_events = PyList_New(0))) {
+            return NULL;
+        }
+    }
+    g_type_watchers_installed++;
+    return PyLong_FromLong(watcher_id);
+}
+
+static PyObject *
+clear_type_watcher(PyObject *self, PyObject *watcher_id)
+{
+    if (PyType_ClearWatcher(PyLong_AsLong(watcher_id))) {
+        return NULL;
+    }
+    g_type_watchers_installed--;
+    if (!g_type_watchers_installed) {
+        assert(g_type_modified_events);
+        Py_CLEAR(g_type_modified_events);
+    }
+    Py_RETURN_NONE;
+}
+
+static PyObject *
+get_type_modified_events(PyObject *self, PyObject *Py_UNUSED(args))
+{
+    if (!g_type_modified_events) {
+        PyErr_SetString(PyExc_RuntimeError, "no watchers active");
+        return NULL;
+    }
+    Py_INCREF(g_type_modified_events);
+    return g_type_modified_events;
+}
+
+static PyObject *
+watch_type(PyObject *self, PyObject *args)
+{
+    PyObject *type;
+    int watcher_id;
+    if (!PyArg_ParseTuple(args, "iO", &watcher_id, &type)) {
+        return NULL;
+    }
+    if (PyType_Watch(watcher_id, type)) {
+        return NULL;
+    }
+    Py_RETURN_NONE;
+}
+
+static PyObject *
+unwatch_type(PyObject *self, PyObject *args)
+{
+    PyObject *type;
+    int watcher_id;
+    if (!PyArg_ParseTuple(args, "iO", &watcher_id, &type)) {
+        return NULL;
+    }
+    if (PyType_Unwatch(watcher_id, type)) {
+        return NULL;
+    }
+    Py_RETURN_NONE;
+}
+
+
 static PyObject *test_buildvalue_issue38913(PyObject *, PyObject *);
 static PyObject *getargs_s_hash_int(PyObject *, PyObject *, PyObject*);
 static PyObject *getargs_s_hash_int2(PyObject *, PyObject *, PyObject*);
@@ -5981,6 +6103,11 @@ static PyMethodDef TestMethods[] = {
     {"function_get_code", function_get_code, METH_O, NULL},
     {"function_get_globals", function_get_globals, METH_O, NULL},
     {"function_get_module", function_get_module, METH_O, NULL},
+    {"add_type_watcher", add_type_watcher, METH_O, NULL},
+    {"clear_type_watcher", clear_type_watcher, METH_O, NULL},
+    {"watch_type", watch_type, METH_VARARGS, NULL},
+    {"unwatch_type", unwatch_type, METH_VARARGS, NULL},
+    {"get_type_modified_events", get_type_modified_events, METH_NOARGS, NULL},
     {NULL, NULL} /* sentinel */
 };
 
diff --git a/Objects/typeobject.c b/Objects/typeobject.c
index 196a6aee4993..7f8f2c7846eb 100644
--- a/Objects/typeobject.c
+++ b/Objects/typeobject.c
@@ -372,6 +372,83 @@ _PyTypes_Fini(PyInterpreterState *interp)
 
 static PyObject * lookup_subclasses(PyTypeObject *);
 
+int
+PyType_AddWatcher(PyType_WatchCallback callback)
+{
+    PyInterpreterState *interp = _PyInterpreterState_GET();
+
+    for (int i = 0; i < TYPE_MAX_WATCHERS; i++) {
+        if (!interp->type_watchers[i]) {
+            interp->type_watchers[i] = callback;
+            return i;
+        }
+    }
+
+    PyErr_SetString(PyExc_RuntimeError, "no more type watcher IDs available");
+    return -1;
+}
+
+static inline int
+validate_watcher_id(PyInterpreterState *interp, int watcher_id)
+{
+    if (watcher_id < 0 || watcher_id >= TYPE_MAX_WATCHERS) {
+        PyErr_Format(PyExc_ValueError, "Invalid type watcher ID %d", watcher_id);
+        return -1;
+    }
+    if (!interp->type_watchers[watcher_id]) {
+        PyErr_Format(PyExc_ValueError, "No type watcher set for ID %d", watcher_id);
+        return -1;
+    }
+    return 0;
+}
+
+int
+PyType_ClearWatcher(int watcher_id)
+{
+    PyInterpreterState *interp = _PyInterpreterState_GET();
+    if (validate_watcher_id(interp, watcher_id) < 0) {
+        return -1;
+    }
+    interp->type_watchers[watcher_id] = NULL;
+    return 0;
+}
+
+static int assign_version_tag(PyTypeObject *type);
+
+int
+PyType_Watch(int watcher_id, PyObject* obj)
+{
+    if (!PyType_Check(obj)) {
+        PyErr_SetString(PyExc_ValueError, "Cannot watch non-type");
+        return -1;
+    }
+    PyTypeObject *type = (PyTypeObject *)obj;
+    PyInterpreterState *interp = _PyInterpreterState_GET();
+    if (validate_watcher_id(interp, watcher_id) < 0) {
+        return -1;
+    }
+    // ensure we will get a callback on the next modification
+    assign_version_tag(type);
+    type->tp_watched |= (1 << watcher_id);
+    return 0;
+}
+
+int
+PyType_Unwatch(int watcher_id, PyObject* obj)
+{
+    if (!PyType_Check(obj)) {
+        PyErr_SetString(PyExc_ValueError, "Cannot watch non-type");
+        return -1;
+    }
+    PyTypeObject *type = (PyTypeObject *)obj;
+    PyInterpreterState *interp = _PyInterpreterState_GET();
+    if (validate_watcher_id(interp, watcher_id)) {
+        return -1;
+    }
+    type->tp_watched &= ~(1 << watcher_id);
+    return 0;
+}
+
 void
 PyType_Modified(PyTypeObject *type)
 {
@@ -409,6 +486,23 @@ PyType_Modified(PyTypeObject *type)
         }
     }
 
+    if (type->tp_watched) {
+        PyInterpreterState *interp = _PyInterpreterState_GET();
+        int bits = type->tp_watched;
+        int i = 0;
+        while(bits && i < TYPE_MAX_WATCHERS) {
+            if (bits & 1) {
+                PyType_WatchCallback cb = interp->type_watchers[i];
+                if (cb && (cb(type) < 0)) {
+                    PyErr_WriteUnraisable((PyObject *)type);
+                }
+            }
+            i += 1;
+            bits >>= 1;
+        }
+    }
+
+
     type->tp_flags &= ~Py_TPFLAGS_VALID_VERSION_TAG;
     type->tp_version_tag = 0; /* 0 is not a valid version tag */
 }
@@ -467,7 +561,7 @@ type_mro_modified(PyTypeObject *type, PyObject *bases) {
 }
 
 static int
-assign_version_tag(struct type_cache *cache, PyTypeObject *type)
+assign_version_tag(PyTypeObject *type)
 {
     /* Ensure that the tp_version_tag is valid and set
        Py_TPFLAGS_VALID_VERSION_TAG.  To respect the invariant, this
@@ -492,7 +586,7 @@ assign_version_tag(struct type_cache *cache, PyTypeObject *type)
     Py_ssize_t n = PyTuple_GET_SIZE(bases);
     for (Py_ssize_t i = 0; i < n; i++) {
         PyObject *b = PyTuple_GET_ITEM(bases, i);
-        if (!assign_version_tag(cache, _PyType_CAST(b)))
+        if (!assign_version_tag(_PyType_CAST(b)))
             return 0;
     }
     type->tp_flags |= Py_TPFLAGS_VALID_VERSION_TAG;
@@ -4111,7 +4205,7 @@ _PyType_Lookup(PyTypeObject *type, PyObject *name)
         return NULL;
     }
 
-    if (MCACHE_CACHEABLE_NAME(name) && assign_version_tag(cache, type)) {
+    if (MCACHE_CACHEABLE_NAME(name) && assign_version_tag(type)) {
         h = MCACHE_HASH_METHOD(type, name);
         struct type_cache_entry *entry = &cache->hashtable[h];
         entry->version = type->tp_version_tag;



More information about the Python-checkins mailing list