[Python-checkins] bpo-42064: Add module backref to `sqlite3` callback context (GH-28242)

encukou webhook-mailer at python.org
Tue Oct 19 09:44:55 EDT 2021


https://github.com/python/cpython/commit/09c04e7f0d26f0006964554b6a0caa5ef7f0bd24
commit: 09c04e7f0d26f0006964554b6a0caa5ef7f0bd24
branch: main
author: Erlend Egeberg Aasland <erlend.aasland at innova.no>
committer: encukou <encukou at gmail.com>
date: 2021-10-19T15:44:45+02:00
summary:

bpo-42064: Add module backref to `sqlite3` callback context (GH-28242)

files:
M Modules/_sqlite/clinic/connection.c.h
M Modules/_sqlite/connection.c
M Modules/_sqlite/connection.h

diff --git a/Modules/_sqlite/clinic/connection.c.h b/Modules/_sqlite/clinic/connection.c.h
index bf5a58d7756c9..e9e3064ae0f89 100644
--- a/Modules/_sqlite/clinic/connection.c.h
+++ b/Modules/_sqlite/clinic/connection.c.h
@@ -204,57 +204,30 @@ PyDoc_STRVAR(pysqlite_connection_create_function__doc__,
 "Creates a new function. Non-standard.");
 
 #define PYSQLITE_CONNECTION_CREATE_FUNCTION_METHODDEF    \
-    {"create_function", (PyCFunction)(void(*)(void))pysqlite_connection_create_function, METH_FASTCALL|METH_KEYWORDS, pysqlite_connection_create_function__doc__},
+    {"create_function", (PyCFunction)(void(*)(void))pysqlite_connection_create_function, METH_METHOD|METH_FASTCALL|METH_KEYWORDS, pysqlite_connection_create_function__doc__},
 
 static PyObject *
 pysqlite_connection_create_function_impl(pysqlite_Connection *self,
-                                         const char *name, int narg,
-                                         PyObject *func, int deterministic);
+                                         PyTypeObject *cls, const char *name,
+                                         int narg, PyObject *func,
+                                         int deterministic);
 
 static PyObject *
-pysqlite_connection_create_function(pysqlite_Connection *self, PyObject *const *args, Py_ssize_t nargs, PyObject *kwnames)
+pysqlite_connection_create_function(pysqlite_Connection *self, PyTypeObject *cls, PyObject *const *args, Py_ssize_t nargs, PyObject *kwnames)
 {
     PyObject *return_value = NULL;
     static const char * const _keywords[] = {"name", "narg", "func", "deterministic", NULL};
-    static _PyArg_Parser _parser = {NULL, _keywords, "create_function", 0};
-    PyObject *argsbuf[4];
-    Py_ssize_t noptargs = nargs + (kwnames ? PyTuple_GET_SIZE(kwnames) : 0) - 3;
+    static _PyArg_Parser _parser = {"siO|$p:create_function", _keywords, 0};
     const char *name;
     int narg;
     PyObject *func;
     int deterministic = 0;
 
-    args = _PyArg_UnpackKeywords(args, nargs, NULL, kwnames, &_parser, 3, 3, 0, argsbuf);
-    if (!args) {
-        goto exit;
-    }
-    if (!PyUnicode_Check(args[0])) {
-        _PyArg_BadArgument("create_function", "argument 'name'", "str", args[0]);
-        goto exit;
-    }
-    Py_ssize_t name_length;
-    name = PyUnicode_AsUTF8AndSize(args[0], &name_length);
-    if (name == NULL) {
+    if (!_PyArg_ParseStackAndKeywords(args, nargs, kwnames, &_parser,
+        &name, &narg, &func, &deterministic)) {
         goto exit;
     }
-    if (strlen(name) != (size_t)name_length) {
-        PyErr_SetString(PyExc_ValueError, "embedded null character");
-        goto exit;
-    }
-    narg = _PyLong_AsInt(args[1]);
-    if (narg == -1 && PyErr_Occurred()) {
-        goto exit;
-    }
-    func = args[2];
-    if (!noptargs) {
-        goto skip_optional_kwonly;
-    }
-    deterministic = PyObject_IsTrue(args[3]);
-    if (deterministic < 0) {
-        goto exit;
-    }
-skip_optional_kwonly:
-    return_value = pysqlite_connection_create_function_impl(self, name, narg, func, deterministic);
+    return_value = pysqlite_connection_create_function_impl(self, cls, name, narg, func, deterministic);
 
 exit:
     return return_value;
@@ -267,47 +240,29 @@ PyDoc_STRVAR(pysqlite_connection_create_aggregate__doc__,
 "Creates a new aggregate. Non-standard.");
 
 #define PYSQLITE_CONNECTION_CREATE_AGGREGATE_METHODDEF    \
-    {"create_aggregate", (PyCFunction)(void(*)(void))pysqlite_connection_create_aggregate, METH_FASTCALL|METH_KEYWORDS, pysqlite_connection_create_aggregate__doc__},
+    {"create_aggregate", (PyCFunction)(void(*)(void))pysqlite_connection_create_aggregate, METH_METHOD|METH_FASTCALL|METH_KEYWORDS, pysqlite_connection_create_aggregate__doc__},
 
 static PyObject *
 pysqlite_connection_create_aggregate_impl(pysqlite_Connection *self,
+                                          PyTypeObject *cls,
                                           const char *name, int n_arg,
                                           PyObject *aggregate_class);
 
 static PyObject *
-pysqlite_connection_create_aggregate(pysqlite_Connection *self, PyObject *const *args, Py_ssize_t nargs, PyObject *kwnames)
+pysqlite_connection_create_aggregate(pysqlite_Connection *self, PyTypeObject *cls, PyObject *const *args, Py_ssize_t nargs, PyObject *kwnames)
 {
     PyObject *return_value = NULL;
     static const char * const _keywords[] = {"name", "n_arg", "aggregate_class", NULL};
-    static _PyArg_Parser _parser = {NULL, _keywords, "create_aggregate", 0};
-    PyObject *argsbuf[3];
+    static _PyArg_Parser _parser = {"siO:create_aggregate", _keywords, 0};
     const char *name;
     int n_arg;
     PyObject *aggregate_class;
 
-    args = _PyArg_UnpackKeywords(args, nargs, NULL, kwnames, &_parser, 3, 3, 0, argsbuf);
-    if (!args) {
-        goto exit;
-    }
-    if (!PyUnicode_Check(args[0])) {
-        _PyArg_BadArgument("create_aggregate", "argument 'name'", "str", args[0]);
-        goto exit;
-    }
-    Py_ssize_t name_length;
-    name = PyUnicode_AsUTF8AndSize(args[0], &name_length);
-    if (name == NULL) {
+    if (!_PyArg_ParseStackAndKeywords(args, nargs, kwnames, &_parser,
+        &name, &n_arg, &aggregate_class)) {
         goto exit;
     }
-    if (strlen(name) != (size_t)name_length) {
-        PyErr_SetString(PyExc_ValueError, "embedded null character");
-        goto exit;
-    }
-    n_arg = _PyLong_AsInt(args[1]);
-    if (n_arg == -1 && PyErr_Occurred()) {
-        goto exit;
-    }
-    aggregate_class = args[2];
-    return_value = pysqlite_connection_create_aggregate_impl(self, name, n_arg, aggregate_class);
+    return_value = pysqlite_connection_create_aggregate_impl(self, cls, name, n_arg, aggregate_class);
 
 exit:
     return return_value;
@@ -320,27 +275,26 @@ PyDoc_STRVAR(pysqlite_connection_set_authorizer__doc__,
 "Sets authorizer callback. Non-standard.");
 
 #define PYSQLITE_CONNECTION_SET_AUTHORIZER_METHODDEF    \
-    {"set_authorizer", (PyCFunction)(void(*)(void))pysqlite_connection_set_authorizer, METH_FASTCALL|METH_KEYWORDS, pysqlite_connection_set_authorizer__doc__},
+    {"set_authorizer", (PyCFunction)(void(*)(void))pysqlite_connection_set_authorizer, METH_METHOD|METH_FASTCALL|METH_KEYWORDS, pysqlite_connection_set_authorizer__doc__},
 
 static PyObject *
 pysqlite_connection_set_authorizer_impl(pysqlite_Connection *self,
+                                        PyTypeObject *cls,
                                         PyObject *callable);
 
 static PyObject *
-pysqlite_connection_set_authorizer(pysqlite_Connection *self, PyObject *const *args, Py_ssize_t nargs, PyObject *kwnames)
+pysqlite_connection_set_authorizer(pysqlite_Connection *self, PyTypeObject *cls, PyObject *const *args, Py_ssize_t nargs, PyObject *kwnames)
 {
     PyObject *return_value = NULL;
     static const char * const _keywords[] = {"authorizer_callback", NULL};
-    static _PyArg_Parser _parser = {NULL, _keywords, "set_authorizer", 0};
-    PyObject *argsbuf[1];
+    static _PyArg_Parser _parser = {"O:set_authorizer", _keywords, 0};
     PyObject *callable;
 
-    args = _PyArg_UnpackKeywords(args, nargs, NULL, kwnames, &_parser, 1, 1, 0, argsbuf);
-    if (!args) {
+    if (!_PyArg_ParseStackAndKeywords(args, nargs, kwnames, &_parser,
+        &callable)) {
         goto exit;
     }
-    callable = args[0];
-    return_value = pysqlite_connection_set_authorizer_impl(self, callable);
+    return_value = pysqlite_connection_set_authorizer_impl(self, cls, callable);
 
 exit:
     return return_value;
@@ -353,32 +307,27 @@ PyDoc_STRVAR(pysqlite_connection_set_progress_handler__doc__,
 "Sets progress handler callback. Non-standard.");
 
 #define PYSQLITE_CONNECTION_SET_PROGRESS_HANDLER_METHODDEF    \
-    {"set_progress_handler", (PyCFunction)(void(*)(void))pysqlite_connection_set_progress_handler, METH_FASTCALL|METH_KEYWORDS, pysqlite_connection_set_progress_handler__doc__},
+    {"set_progress_handler", (PyCFunction)(void(*)(void))pysqlite_connection_set_progress_handler, METH_METHOD|METH_FASTCALL|METH_KEYWORDS, pysqlite_connection_set_progress_handler__doc__},
 
 static PyObject *
 pysqlite_connection_set_progress_handler_impl(pysqlite_Connection *self,
+                                              PyTypeObject *cls,
                                               PyObject *callable, int n);
 
 static PyObject *
-pysqlite_connection_set_progress_handler(pysqlite_Connection *self, PyObject *const *args, Py_ssize_t nargs, PyObject *kwnames)
+pysqlite_connection_set_progress_handler(pysqlite_Connection *self, PyTypeObject *cls, PyObject *const *args, Py_ssize_t nargs, PyObject *kwnames)
 {
     PyObject *return_value = NULL;
     static const char * const _keywords[] = {"progress_handler", "n", NULL};
-    static _PyArg_Parser _parser = {NULL, _keywords, "set_progress_handler", 0};
-    PyObject *argsbuf[2];
+    static _PyArg_Parser _parser = {"Oi:set_progress_handler", _keywords, 0};
     PyObject *callable;
     int n;
 
-    args = _PyArg_UnpackKeywords(args, nargs, NULL, kwnames, &_parser, 2, 2, 0, argsbuf);
-    if (!args) {
-        goto exit;
-    }
-    callable = args[0];
-    n = _PyLong_AsInt(args[1]);
-    if (n == -1 && PyErr_Occurred()) {
+    if (!_PyArg_ParseStackAndKeywords(args, nargs, kwnames, &_parser,
+        &callable, &n)) {
         goto exit;
     }
-    return_value = pysqlite_connection_set_progress_handler_impl(self, callable, n);
+    return_value = pysqlite_connection_set_progress_handler_impl(self, cls, callable, n);
 
 exit:
     return return_value;
@@ -393,27 +342,26 @@ PyDoc_STRVAR(pysqlite_connection_set_trace_callback__doc__,
 "Non-standard.");
 
 #define PYSQLITE_CONNECTION_SET_TRACE_CALLBACK_METHODDEF    \
-    {"set_trace_callback", (PyCFunction)(void(*)(void))pysqlite_connection_set_trace_callback, METH_FASTCALL|METH_KEYWORDS, pysqlite_connection_set_trace_callback__doc__},
+    {"set_trace_callback", (PyCFunction)(void(*)(void))pysqlite_connection_set_trace_callback, METH_METHOD|METH_FASTCALL|METH_KEYWORDS, pysqlite_connection_set_trace_callback__doc__},
 
 static PyObject *
 pysqlite_connection_set_trace_callback_impl(pysqlite_Connection *self,
+                                            PyTypeObject *cls,
                                             PyObject *callable);
 
 static PyObject *
-pysqlite_connection_set_trace_callback(pysqlite_Connection *self, PyObject *const *args, Py_ssize_t nargs, PyObject *kwnames)
+pysqlite_connection_set_trace_callback(pysqlite_Connection *self, PyTypeObject *cls, PyObject *const *args, Py_ssize_t nargs, PyObject *kwnames)
 {
     PyObject *return_value = NULL;
     static const char * const _keywords[] = {"trace_callback", NULL};
-    static _PyArg_Parser _parser = {NULL, _keywords, "set_trace_callback", 0};
-    PyObject *argsbuf[1];
+    static _PyArg_Parser _parser = {"O:set_trace_callback", _keywords, 0};
     PyObject *callable;
 
-    args = _PyArg_UnpackKeywords(args, nargs, NULL, kwnames, &_parser, 1, 1, 0, argsbuf);
-    if (!args) {
+    if (!_PyArg_ParseStackAndKeywords(args, nargs, kwnames, &_parser,
+        &callable)) {
         goto exit;
     }
-    callable = args[0];
-    return_value = pysqlite_connection_set_trace_callback_impl(self, callable);
+    return_value = pysqlite_connection_set_trace_callback_impl(self, cls, callable);
 
 exit:
     return return_value;
@@ -720,38 +668,28 @@ PyDoc_STRVAR(pysqlite_connection_create_collation__doc__,
 "Creates a collation function. Non-standard.");
 
 #define PYSQLITE_CONNECTION_CREATE_COLLATION_METHODDEF    \
-    {"create_collation", (PyCFunction)(void(*)(void))pysqlite_connection_create_collation, METH_FASTCALL, pysqlite_connection_create_collation__doc__},
+    {"create_collation", (PyCFunction)(void(*)(void))pysqlite_connection_create_collation, METH_METHOD|METH_FASTCALL|METH_KEYWORDS, pysqlite_connection_create_collation__doc__},
 
 static PyObject *
 pysqlite_connection_create_collation_impl(pysqlite_Connection *self,
+                                          PyTypeObject *cls,
                                           const char *name,
                                           PyObject *callable);
 
 static PyObject *
-pysqlite_connection_create_collation(pysqlite_Connection *self, PyObject *const *args, Py_ssize_t nargs)
+pysqlite_connection_create_collation(pysqlite_Connection *self, PyTypeObject *cls, PyObject *const *args, Py_ssize_t nargs, PyObject *kwnames)
 {
     PyObject *return_value = NULL;
+    static const char * const _keywords[] = {"", "", NULL};
+    static _PyArg_Parser _parser = {"sO:create_collation", _keywords, 0};
     const char *name;
     PyObject *callable;
 
-    if (!_PyArg_CheckPositional("create_collation", nargs, 2, 2)) {
-        goto exit;
-    }
-    if (!PyUnicode_Check(args[0])) {
-        _PyArg_BadArgument("create_collation", "argument 1", "str", args[0]);
-        goto exit;
-    }
-    Py_ssize_t name_length;
-    name = PyUnicode_AsUTF8AndSize(args[0], &name_length);
-    if (name == NULL) {
-        goto exit;
-    }
-    if (strlen(name) != (size_t)name_length) {
-        PyErr_SetString(PyExc_ValueError, "embedded null character");
+    if (!_PyArg_ParseStackAndKeywords(args, nargs, kwnames, &_parser,
+        &name, &callable)) {
         goto exit;
     }
-    callable = args[1];
-    return_value = pysqlite_connection_create_collation_impl(self, name, callable);
+    return_value = pysqlite_connection_create_collation_impl(self, cls, name, callable);
 
 exit:
     return return_value;
@@ -819,4 +757,4 @@ pysqlite_connection_exit(pysqlite_Connection *self, PyObject *const *args, Py_ss
 #ifndef PYSQLITE_CONNECTION_LOAD_EXTENSION_METHODDEF
     #define PYSQLITE_CONNECTION_LOAD_EXTENSION_METHODDEF
 #endif /* !defined(PYSQLITE_CONNECTION_LOAD_EXTENSION_METHODDEF) */
-/*[clinic end generated code: output=5b7268875f33c016 input=a9049054013a1b77]*/
+/*[clinic end generated code: output=7567e5d716309258 input=a9049054013a1b77]*/
diff --git a/Modules/_sqlite/connection.c b/Modules/_sqlite/connection.c
index d6d1fa8bf2876..e94c4cbb4e8c3 100644
--- a/Modules/_sqlite/connection.c
+++ b/Modules/_sqlite/connection.c
@@ -256,6 +256,7 @@ pysqlite_do_all_statements(pysqlite_Connection *self)
 do {                                \
     if (ctx) {                      \
         Py_VISIT(ctx->callable);    \
+        Py_VISIT(ctx->module);      \
     }                               \
 } while (0)
 
@@ -280,6 +281,7 @@ clear_callback_context(callback_context *ctx)
 {
     if (ctx != NULL) {
         Py_CLEAR(ctx->callable);
+        Py_CLEAR(ctx->module);
     }
 }
 
@@ -822,13 +824,21 @@ static void _pysqlite_drop_unused_cursor_references(pysqlite_Connection* self)
     Py_SETREF(self->cursors, new_list);
 }
 
+/* Allocate a UDF/callback context structure. To ensure that the module state
+ * always outlives the callback context, the context owns a strong reference
+ * to the module itself. create_callback_context() is always called from
+ * connection methods, so we use the defining class to fetch the module
+ * pointer.
+ */
 static callback_context *
-create_callback_context(pysqlite_state *state, PyObject *callable)
+create_callback_context(PyTypeObject *cls, PyObject *callable)
 {
     callback_context *ctx = PyMem_Malloc(sizeof(callback_context));
     if (ctx != NULL) {
+        PyObject *module = PyType_GetModule(cls);
         ctx->callable = Py_NewRef(callable);
-        ctx->state = state;
+        ctx->module = Py_NewRef(module);
+        ctx->state = pysqlite_get_state(module);
     }
     return ctx;
 }
@@ -838,6 +848,7 @@ free_callback_context(callback_context *ctx)
 {
     assert(ctx != NULL);
     Py_XDECREF(ctx->callable);
+    Py_XDECREF(ctx->module);
     PyMem_Free(ctx);
 }
 
@@ -867,6 +878,8 @@ destructor_callback(void *ctx)
 /*[clinic input]
 _sqlite3.Connection.create_function as pysqlite_connection_create_function
 
+    cls: defining_class
+    /
     name: str
     narg: int
     func: object
@@ -878,9 +891,10 @@ Creates a new function. Non-standard.
 
 static PyObject *
 pysqlite_connection_create_function_impl(pysqlite_Connection *self,
-                                         const char *name, int narg,
-                                         PyObject *func, int deterministic)
-/*[clinic end generated code: output=07d1877dd98c0308 input=f2edcf073e815beb]*/
+                                         PyTypeObject *cls, const char *name,
+                                         int narg, PyObject *func,
+                                         int deterministic)
+/*[clinic end generated code: output=8a811529287ad240 input=f0f99754bfeafd8d]*/
 {
     int rc;
     int flags = SQLITE_UTF8;
@@ -903,7 +917,7 @@ pysqlite_connection_create_function_impl(pysqlite_Connection *self,
         flags |= SQLITE_DETERMINISTIC;
 #endif
     }
-    callback_context *ctx = create_callback_context(self->state, func);
+    callback_context *ctx = create_callback_context(cls, func);
     if (ctx == NULL) {
         return NULL;
     }
@@ -924,6 +938,8 @@ pysqlite_connection_create_function_impl(pysqlite_Connection *self,
 /*[clinic input]
 _sqlite3.Connection.create_aggregate as pysqlite_connection_create_aggregate
 
+    cls: defining_class
+    /
     name: str
     n_arg: int
     aggregate_class: object
@@ -933,9 +949,10 @@ Creates a new aggregate. Non-standard.
 
 static PyObject *
 pysqlite_connection_create_aggregate_impl(pysqlite_Connection *self,
+                                          PyTypeObject *cls,
                                           const char *name, int n_arg,
                                           PyObject *aggregate_class)
-/*[clinic end generated code: output=fbb2f858cfa4d8db input=c2e13bbf234500a5]*/
+/*[clinic end generated code: output=1b02d0f0aec7ff96 input=bd527067e6c2e33f]*/
 {
     int rc;
 
@@ -943,8 +960,7 @@ pysqlite_connection_create_aggregate_impl(pysqlite_Connection *self,
         return NULL;
     }
 
-    callback_context *ctx = create_callback_context(self->state,
-                                                    aggregate_class);
+    callback_context *ctx = create_callback_context(cls, aggregate_class);
     if (ctx == NULL) {
         return NULL;
     }
@@ -1075,6 +1091,8 @@ trace_callback(void *ctx, const char *statement_string)
 /*[clinic input]
 _sqlite3.Connection.set_authorizer as pysqlite_connection_set_authorizer
 
+    cls: defining_class
+    /
     authorizer_callback as callable: object
 
 Sets authorizer callback. Non-standard.
@@ -1082,8 +1100,9 @@ Sets authorizer callback. Non-standard.
 
 static PyObject *
 pysqlite_connection_set_authorizer_impl(pysqlite_Connection *self,
+                                        PyTypeObject *cls,
                                         PyObject *callable)
-/*[clinic end generated code: output=c193601e9e8a5116 input=ec104f130b82050b]*/
+/*[clinic end generated code: output=75fa60114fc971c3 input=9f3e90d3d642c4a0]*/
 {
     if (!pysqlite_check_thread(self) || !pysqlite_check_connection(self)) {
         return NULL;
@@ -1095,7 +1114,7 @@ pysqlite_connection_set_authorizer_impl(pysqlite_Connection *self,
         set_callback_context(&self->authorizer_ctx, NULL);
     }
     else {
-        callback_context *ctx = create_callback_context(self->state, callable);
+        callback_context *ctx = create_callback_context(cls, callable);
         if (ctx == NULL) {
             return NULL;
         }
@@ -1114,6 +1133,8 @@ pysqlite_connection_set_authorizer_impl(pysqlite_Connection *self,
 /*[clinic input]
 _sqlite3.Connection.set_progress_handler as pysqlite_connection_set_progress_handler
 
+    cls: defining_class
+    /
     progress_handler as callable: object
     n: int
 
@@ -1122,8 +1143,9 @@ Sets progress handler callback. Non-standard.
 
 static PyObject *
 pysqlite_connection_set_progress_handler_impl(pysqlite_Connection *self,
+                                              PyTypeObject *cls,
                                               PyObject *callable, int n)
-/*[clinic end generated code: output=ba14008a483d7a53 input=3cf56d045f130a84]*/
+/*[clinic end generated code: output=0739957fd8034a50 input=83e8dcbb4ce183f7]*/
 {
     if (!pysqlite_check_thread(self) || !pysqlite_check_connection(self)) {
         return NULL;
@@ -1135,7 +1157,7 @@ pysqlite_connection_set_progress_handler_impl(pysqlite_Connection *self,
         set_callback_context(&self->progress_ctx, NULL);
     }
     else {
-        callback_context *ctx = create_callback_context(self->state, callable);
+        callback_context *ctx = create_callback_context(cls, callable);
         if (ctx == NULL) {
             return NULL;
         }
@@ -1148,6 +1170,8 @@ pysqlite_connection_set_progress_handler_impl(pysqlite_Connection *self,
 /*[clinic input]
 _sqlite3.Connection.set_trace_callback as pysqlite_connection_set_trace_callback
 
+    cls: defining_class
+    /
     trace_callback as callable: object
 
 Sets a trace callback called for each SQL statement (passed as unicode).
@@ -1157,8 +1181,9 @@ Non-standard.
 
 static PyObject *
 pysqlite_connection_set_trace_callback_impl(pysqlite_Connection *self,
+                                            PyTypeObject *cls,
                                             PyObject *callable)
-/*[clinic end generated code: output=c9fd551e359165d3 input=d76eabbb633057bc]*/
+/*[clinic end generated code: output=d91048c03bfcee05 input=96f03acec3ec8044]*/
 {
     if (!pysqlite_check_thread(self) || !pysqlite_check_connection(self)) {
         return NULL;
@@ -1180,7 +1205,7 @@ pysqlite_connection_set_trace_callback_impl(pysqlite_Connection *self,
         set_callback_context(&self->trace_ctx, NULL);
     }
     else {
-        callback_context *ctx = create_callback_context(self->state, callable);
+        callback_context *ctx = create_callback_context(cls, callable);
         if (ctx == NULL) {
             return NULL;
         }
@@ -1742,6 +1767,7 @@ pysqlite_connection_backup_impl(pysqlite_Connection *self,
 /*[clinic input]
 _sqlite3.Connection.create_collation as pysqlite_connection_create_collation
 
+    cls: defining_class
     name: str
     callback as callable: object
     /
@@ -1751,9 +1777,10 @@ Creates a collation function. Non-standard.
 
 static PyObject *
 pysqlite_connection_create_collation_impl(pysqlite_Connection *self,
+                                          PyTypeObject *cls,
                                           const char *name,
                                           PyObject *callable)
-/*[clinic end generated code: output=a4ceaff957fdef9a input=301647aab0f2fb1d]*/
+/*[clinic end generated code: output=32d339e97869c378 input=fee2c8e5708602ad]*/
 {
     if (!pysqlite_check_thread(self) || !pysqlite_check_connection(self)) {
         return NULL;
@@ -1771,7 +1798,7 @@ pysqlite_connection_create_collation_impl(pysqlite_Connection *self,
             PyErr_SetString(PyExc_TypeError, "parameter must be callable");
             return NULL;
         }
-        ctx = create_callback_context(self->state, callable);
+        ctx = create_callback_context(cls, callable);
         if (ctx == NULL) {
             return NULL;
         }
diff --git a/Modules/_sqlite/connection.h b/Modules/_sqlite/connection.h
index c4cec857ddbfe..6a2aa1c8e080e 100644
--- a/Modules/_sqlite/connection.h
+++ b/Modules/_sqlite/connection.h
@@ -35,6 +35,7 @@
 typedef struct _callback_context
 {
     PyObject *callable;
+    PyObject *module;
     pysqlite_state *state;
 } callback_context;
 



More information about the Python-checkins mailing list