[Python-checkins] bpo-42064: Pass module state to `sqlite3` UDF callbacks (GH-27456)

miss-islington webhook-mailer at python.org
Tue Aug 24 08:24:17 EDT 2021


https://github.com/python/cpython/commit/9ed523159c7ba840dbf403e02498eeae1b5d3ed9
commit: 9ed523159c7ba840dbf403e02498eeae1b5d3ed9
branch: main
author: Erlend Egeberg Aasland <erlend.aasland at innova.no>
committer: miss-islington <31488909+miss-islington at users.noreply.github.com>
date: 2021-08-24T05:24:09-07:00
summary:

bpo-42064: Pass module state to `sqlite3` UDF callbacks (GH-27456)



- Establish common callback context struct
- Convert UDF callbacks to fetch module state from callback context

files:
M Modules/_sqlite/connection.c
M Modules/_sqlite/connection.h

diff --git a/Modules/_sqlite/connection.c b/Modules/_sqlite/connection.c
index 0645367988e63..8ad5f5f061da5 100644
--- a/Modules/_sqlite/connection.c
+++ b/Modules/_sqlite/connection.c
@@ -612,8 +612,10 @@ set_sqlite_error(sqlite3_context *context, const char *msg)
     else {
         sqlite3_result_error(context, msg, -1);
     }
-    pysqlite_state *state = pysqlite_get_state(NULL);
-    if (state->enable_callback_tracebacks) {
+    callback_context *ctx = (callback_context *)sqlite3_user_data(context);
+    assert(ctx != NULL);
+    assert(ctx->state != NULL);
+    if (ctx->state->enable_callback_tracebacks) {
         PyErr_Print();
     }
     else {
@@ -625,7 +627,6 @@ static void
 _pysqlite_func_callback(sqlite3_context *context, int argc, sqlite3_value **argv)
 {
     PyObject* args;
-    PyObject* py_func;
     PyObject* py_retval = NULL;
     int ok;
 
@@ -633,11 +634,11 @@ _pysqlite_func_callback(sqlite3_context *context, int argc, sqlite3_value **argv
 
     threadstate = PyGILState_Ensure();
 
-    py_func = (PyObject*)sqlite3_user_data(context);
-
     args = _pysqlite_build_py_params(context, argc, argv);
     if (args) {
-        py_retval = PyObject_CallObject(py_func, args);
+        callback_context *ctx = (callback_context *)sqlite3_user_data(context);
+        assert(ctx != NULL);
+        py_retval = PyObject_CallObject(ctx->callable, args);
         Py_DECREF(args);
     }
 
@@ -657,7 +658,6 @@ static void _pysqlite_step_callback(sqlite3_context *context, int argc, sqlite3_
 {
     PyObject* args;
     PyObject* function_result = NULL;
-    PyObject* aggregate_class;
     PyObject** aggregate_instance;
     PyObject* stepmethod = NULL;
 
@@ -665,12 +665,12 @@ static void _pysqlite_step_callback(sqlite3_context *context, int argc, sqlite3_
 
     threadstate = PyGILState_Ensure();
 
-    aggregate_class = (PyObject*)sqlite3_user_data(context);
-
     aggregate_instance = (PyObject**)sqlite3_aggregate_context(context, sizeof(PyObject*));
 
     if (*aggregate_instance == NULL) {
-        *aggregate_instance = _PyObject_CallNoArg(aggregate_class);
+        callback_context *ctx = (callback_context *)sqlite3_user_data(context);
+        assert(ctx != NULL);
+        *aggregate_instance = _PyObject_CallNoArg(ctx->callable);
         if (!*aggregate_instance) {
             set_sqlite_error(context,
                     "user-defined aggregate's '__init__' method raised error");
@@ -784,14 +784,35 @@ static void _pysqlite_drop_unused_cursor_references(pysqlite_Connection* self)
     Py_SETREF(self->cursors, new_list);
 }
 
-static void _destructor(void* args)
+static callback_context *
+create_callback_context(pysqlite_state *state, PyObject *callable)
 {
-    // This function may be called without the GIL held, so we need to ensure
-    // that we destroy 'args' with the GIL
-    PyGILState_STATE gstate;
-    gstate = PyGILState_Ensure();
-    Py_DECREF((PyObject*)args);
+    PyGILState_STATE gstate = PyGILState_Ensure();
+    callback_context *ctx = PyMem_Malloc(sizeof(callback_context));
+    if (ctx != NULL) {
+        ctx->callable = Py_NewRef(callable);
+        ctx->state = state;
+    }
     PyGILState_Release(gstate);
+    return ctx;
+}
+
+static void
+free_callback_context(callback_context *ctx)
+{
+    if (ctx != NULL) {
+        // This function may be called without the GIL held, so we need to
+        // ensure that we destroy 'ctx' with the GIL held.
+        PyGILState_STATE gstate = PyGILState_Ensure();
+        Py_DECREF(ctx->callable);
+        PyMem_Free(ctx);
+        PyGILState_Release(gstate);
+    }
+}
+
+static void _destructor(void* args)
+{
+    free_callback_context((callback_context *)args);
 }
 
 /*[clinic input]
@@ -833,11 +854,11 @@ pysqlite_connection_create_function_impl(pysqlite_Connection *self,
         flags |= SQLITE_DETERMINISTIC;
 #endif
     }
-    rc = sqlite3_create_function_v2(self->db,
-                                    name,
-                                    narg,
-                                    flags,
-                                    (void*)Py_NewRef(func),
+    callback_context *ctx = create_callback_context(self->state, func);
+    if (ctx == NULL) {
+        return NULL;
+    }
+    rc = sqlite3_create_function_v2(self->db, name, narg, flags, ctx,
                                     _pysqlite_func_callback,
                                     NULL,
                                     NULL,
@@ -873,11 +894,12 @@ pysqlite_connection_create_aggregate_impl(pysqlite_Connection *self,
         return NULL;
     }
 
-    rc = sqlite3_create_function_v2(self->db,
-                                    name,
-                                    n_arg,
-                                    SQLITE_UTF8,
-                                    (void*)Py_NewRef(aggregate_class),
+    callback_context *ctx = create_callback_context(self->state,
+                                                    aggregate_class);
+    if (ctx == NULL) {
+        return NULL;
+    }
+    rc = sqlite3_create_function_v2(self->db, name, n_arg, SQLITE_UTF8, ctx,
                                     0,
                                     &_pysqlite_step_callback,
                                     &_pysqlite_final_callback,
@@ -1439,7 +1461,6 @@ pysqlite_collation_callback(
         int text1_length, const void* text1_data,
         int text2_length, const void* text2_data)
 {
-    PyObject* callback = (PyObject*)context;
     PyObject* string1 = 0;
     PyObject* string2 = 0;
     PyGILState_STATE gilstate;
@@ -1459,8 +1480,10 @@ pysqlite_collation_callback(
         goto finally; /* failed to allocate strings */
     }
 
+    callback_context *ctx = (callback_context *)context;
+    assert(ctx != NULL);
     PyObject *args[] = { string1, string2 };  // Borrowed refs.
-    retval = PyObject_Vectorcall(callback, args, 2, NULL);
+    retval = PyObject_Vectorcall(ctx->callable, args, 2, NULL);
     if (retval == NULL) {
         /* execution failed */
         goto finally;
@@ -1690,6 +1713,7 @@ pysqlite_connection_create_collation_impl(pysqlite_Connection *self,
         return NULL;
     }
 
+    callback_context *ctx = NULL;
     int rc;
     int flags = SQLITE_UTF8;
     if (callable == Py_None) {
@@ -1701,8 +1725,11 @@ pysqlite_connection_create_collation_impl(pysqlite_Connection *self,
             PyErr_SetString(PyExc_TypeError, "parameter must be callable");
             return NULL;
         }
-        rc = sqlite3_create_collation_v2(self->db, name, flags,
-                                         Py_NewRef(callable),
+        ctx = create_callback_context(self->state, callable);
+        if (ctx == NULL) {
+            return NULL;
+        }
+        rc = sqlite3_create_collation_v2(self->db, name, flags, ctx,
                                          &pysqlite_collation_callback,
                                          &_destructor);
     }
@@ -1713,7 +1740,7 @@ pysqlite_connection_create_collation_impl(pysqlite_Connection *self,
          * the context before returning.
          */
         if (callable != Py_None) {
-            Py_DECREF(callable);
+            free_callback_context(ctx);
         }
         _pysqlite_seterror(self->state, self->db);
         return NULL;
diff --git a/Modules/_sqlite/connection.h b/Modules/_sqlite/connection.h
index 4f08a6d5f7b21..11b3a8005e1f9 100644
--- a/Modules/_sqlite/connection.h
+++ b/Modules/_sqlite/connection.h
@@ -32,6 +32,12 @@
 
 #include "sqlite3.h"
 
+typedef struct _callback_context
+{
+    PyObject *callable;
+    pysqlite_state *state;
+} callback_context;
+
 typedef struct
 {
     PyObject_HEAD



More information about the Python-checkins mailing list