[Python-checkins] bpo-45126: Harden `sqlite3` connection initialisation (GH-28227)

encukou webhook-mailer at python.org
Tue Nov 16 09:53:44 EST 2021


https://github.com/python/cpython/commit/9d6215a54c177a5e359c37ecd1c50b594b194f41
commit: 9d6215a54c177a5e359c37ecd1c50b594b194f41
branch: main
author: Erlend Egeberg Aasland <erlend.aasland at innova.no>
committer: encukou <encukou at gmail.com>
date: 2021-11-16T15:53:35+01:00
summary:

bpo-45126: Harden `sqlite3` connection initialisation (GH-28227)

files:
M Lib/test/test_sqlite3/test_dbapi.py
M Modules/_sqlite/clinic/connection.c.h
M Modules/_sqlite/connection.c

diff --git a/Lib/test/test_sqlite3/test_dbapi.py b/Lib/test/test_sqlite3/test_dbapi.py
index 802a6919371f9..18359e1a5e2ab 100644
--- a/Lib/test/test_sqlite3/test_dbapi.py
+++ b/Lib/test/test_sqlite3/test_dbapi.py
@@ -523,6 +523,44 @@ def test_connection_init_good_isolation_levels(self):
                 with memory_database(isolation_level=level) as cx:
                     cx.execute("select 'ok'")
 
+    def test_connection_reinit(self):
+        db = ":memory:"
+        cx = sqlite.connect(db)
+        cx.text_factory = bytes
+        cx.row_factory = sqlite.Row
+        cu = cx.cursor()
+        cu.execute("create table foo (bar)")
+        cu.executemany("insert into foo (bar) values (?)",
+                       ((str(v),) for v in range(4)))
+        cu.execute("select bar from foo")
+
+        rows = [r for r in cu.fetchmany(2)]
+        self.assertTrue(all(isinstance(r, sqlite.Row) for r in rows))
+        self.assertEqual([r[0] for r in rows], [b"0", b"1"])
+
+        cx.__init__(db)
+        cx.execute("create table foo (bar)")
+        cx.executemany("insert into foo (bar) values (?)",
+                       ((v,) for v in ("a", "b", "c", "d")))
+
+        # This uses the old database, old row factory, but new text factory
+        rows = [r for r in cu.fetchall()]
+        self.assertTrue(all(isinstance(r, sqlite.Row) for r in rows))
+        self.assertEqual([r[0] for r in rows], ["2", "3"])
+
+    def test_connection_bad_reinit(self):
+        cx = sqlite.connect(":memory:")
+        with cx:
+            cx.execute("create table t(t)")
+        with temp_dir() as db:
+            self.assertRaisesRegex(sqlite.OperationalError,
+                                   "unable to open database file",
+                                   cx.__init__, db)
+            self.assertRaisesRegex(sqlite.ProgrammingError,
+                                   "Base Connection.__init__ not called",
+                                   cx.executemany, "insert into t values(?)",
+                                   ((v,) for v in range(3)))
+
 
 class UninitialisedConnectionTests(unittest.TestCase):
     def setUp(self):
diff --git a/Modules/_sqlite/clinic/connection.c.h b/Modules/_sqlite/clinic/connection.c.h
index 5bfc589aeb149..3a3ae04e8a193 100644
--- a/Modules/_sqlite/clinic/connection.c.h
+++ b/Modules/_sqlite/clinic/connection.c.h
@@ -7,7 +7,7 @@ pysqlite_connection_init_impl(pysqlite_Connection *self,
                               const char *database, double timeout,
                               int detect_types, const char *isolation_level,
                               int check_same_thread, PyObject *factory,
-                              int cached_statements, int uri);
+                              int cache_size, int uri);
 
 static int
 pysqlite_connection_init(PyObject *self, PyObject *args, PyObject *kwargs)
@@ -25,7 +25,7 @@ pysqlite_connection_init(PyObject *self, PyObject *args, PyObject *kwargs)
     const char *isolation_level = "";
     int check_same_thread = 1;
     PyObject *factory = (PyObject*)clinic_state()->ConnectionType;
-    int cached_statements = 128;
+    int cache_size = 128;
     int uri = 0;
 
     fastargs = _PyArg_UnpackKeywords(_PyTuple_CAST(args)->ob_item, nargs, kwargs, NULL, &_parser, 1, 8, 0, argsbuf);
@@ -101,8 +101,8 @@ pysqlite_connection_init(PyObject *self, PyObject *args, PyObject *kwargs)
         }
     }
     if (fastargs[6]) {
-        cached_statements = _PyLong_AsInt(fastargs[6]);
-        if (cached_statements == -1 && PyErr_Occurred()) {
+        cache_size = _PyLong_AsInt(fastargs[6]);
+        if (cache_size == -1 && PyErr_Occurred()) {
             goto exit;
         }
         if (!--noptargs) {
@@ -114,7 +114,7 @@ pysqlite_connection_init(PyObject *self, PyObject *args, PyObject *kwargs)
         goto exit;
     }
 skip_optional_pos:
-    return_value = pysqlite_connection_init_impl((pysqlite_Connection *)self, database, timeout, detect_types, isolation_level, check_same_thread, factory, cached_statements, uri);
+    return_value = pysqlite_connection_init_impl((pysqlite_Connection *)self, database, timeout, detect_types, isolation_level, check_same_thread, factory, cache_size, uri);
 
 exit:
     /* Cleanup for database */
@@ -851,4 +851,4 @@ getlimit(pysqlite_Connection *self, PyObject *arg)
 #ifndef PYSQLITE_CONNECTION_LOAD_EXTENSION_METHODDEF
     #define PYSQLITE_CONNECTION_LOAD_EXTENSION_METHODDEF
 #endif /* !defined(PYSQLITE_CONNECTION_LOAD_EXTENSION_METHODDEF) */
-/*[clinic end generated code: output=663b1e9e71128f19 input=a9049054013a1b77]*/
+/*[clinic end generated code: output=6f267f20e77f92d0 input=a9049054013a1b77]*/
diff --git a/Modules/_sqlite/connection.c b/Modules/_sqlite/connection.c
index b902dc845c618..e7947671db354 100644
--- a/Modules/_sqlite/connection.c
+++ b/Modules/_sqlite/connection.c
@@ -83,15 +83,17 @@ static void _pysqlite_drop_unused_cursor_references(pysqlite_Connection* self);
 static void free_callback_context(callback_context *ctx);
 static void set_callback_context(callback_context **ctx_pp,
                                  callback_context *ctx);
+static void connection_close(pysqlite_Connection *self);
 
 static PyObject *
-new_statement_cache(pysqlite_Connection *self, int maxsize)
+new_statement_cache(pysqlite_Connection *self, pysqlite_state *state,
+                    int maxsize)
 {
     PyObject *args[] = { NULL, PyLong_FromLong(maxsize), };
     if (args[1] == NULL) {
         return NULL;
     }
-    PyObject *lru_cache = self->state->lru_cache;
+    PyObject *lru_cache = state->lru_cache;
     size_t nargsf = 1 | PY_VECTORCALL_ARGUMENTS_OFFSET;
     PyObject *inner = PyObject_Vectorcall(lru_cache, args + 1, nargsf, NULL);
     Py_DECREF(args[1]);
@@ -153,7 +155,7 @@ _sqlite3.Connection.__init__ as pysqlite_connection_init
     isolation_level: str(accept={str, NoneType}) = ""
     check_same_thread: bool(accept={int}) = True
     factory: object(c_default='(PyObject*)clinic_state()->ConnectionType') = ConnectionType
-    cached_statements: int = 128
+    cached_statements as cache_size: int = 128
     uri: bool = False
 [clinic start generated code]*/
 
@@ -162,78 +164,82 @@ pysqlite_connection_init_impl(pysqlite_Connection *self,
                               const char *database, double timeout,
                               int detect_types, const char *isolation_level,
                               int check_same_thread, PyObject *factory,
-                              int cached_statements, int uri)
-/*[clinic end generated code: output=d8c37afc46d318b0 input=adfb29ac461f9e61]*/
+                              int cache_size, int uri)
+/*[clinic end generated code: output=7d640ae1d83abfd4 input=35e316f66d9f70fd]*/
 {
-    int rc;
-
     if (PySys_Audit("sqlite3.connect", "s", database) < 0) {
         return -1;
     }
 
-    pysqlite_state *state = pysqlite_get_state_by_type(Py_TYPE(self));
-    self->state = state;
-
-    Py_CLEAR(self->statement_cache);
-    Py_CLEAR(self->cursors);
-
-    Py_INCREF(Py_None);
-    Py_XSETREF(self->row_factory, Py_None);
-
-    Py_INCREF(&PyUnicode_Type);
-    Py_XSETREF(self->text_factory, (PyObject*)&PyUnicode_Type);
+    if (self->initialized) {
+        PyTypeObject *tp = Py_TYPE(self);
+        tp->tp_clear((PyObject *)self);
+        connection_close(self);
+        self->initialized = 0;
+    }
 
+    // Create and configure SQLite database object.
+    sqlite3 *db;
+    int rc;
     Py_BEGIN_ALLOW_THREADS
-    rc = sqlite3_open_v2(database, &self->db,
+    rc = sqlite3_open_v2(database, &db,
                          SQLITE_OPEN_READWRITE | SQLITE_OPEN_CREATE |
                          (uri ? SQLITE_OPEN_URI : 0), NULL);
+    if (rc == SQLITE_OK) {
+        (void)sqlite3_busy_timeout(db, (int)(timeout*1000));
+    }
     Py_END_ALLOW_THREADS
 
-    if (self->db == NULL && rc == SQLITE_NOMEM) {
+    if (db == NULL && rc == SQLITE_NOMEM) {
         PyErr_NoMemory();
         return -1;
     }
+
+    pysqlite_state *state = pysqlite_get_state_by_type(Py_TYPE(self));
     if (rc != SQLITE_OK) {
-        _pysqlite_seterror(state, self->db);
+        _pysqlite_seterror(state, db);
         return -1;
     }
 
-    if (isolation_level) {
-        const char *stmt = get_begin_statement(isolation_level);
-        if (stmt == NULL) {
+    // Convert isolation level to begin statement.
+    const char *begin_statement = NULL;
+    if (isolation_level != NULL) {
+        begin_statement = get_begin_statement(isolation_level);
+        if (begin_statement == NULL) {
             return -1;
         }
-        self->begin_statement = stmt;
-    }
-    else {
-        self->begin_statement = NULL;
     }
 
-    self->statement_cache = new_statement_cache(self, cached_statements);
-    if (self->statement_cache == NULL) {
-        return -1;
-    }
-    if (PyErr_Occurred()) {
+    // Create LRU statement cache; returns a new reference.
+    PyObject *statement_cache = new_statement_cache(self, state, cache_size);
+    if (statement_cache == NULL) {
         return -1;
     }
 
-    self->created_cursors = 0;
-
-    /* Create list of weak references to cursors */
-    self->cursors = PyList_New(0);
-    if (self->cursors == NULL) {
+    // Create list of weak references to cursors.
+    PyObject *cursors = PyList_New(0);
+    if (cursors == NULL) {
+        Py_DECREF(statement_cache);
         return -1;
     }
 
+    // Init connection state members.
+    self->db = db;
+    self->state = state;
     self->detect_types = detect_types;
-    (void)sqlite3_busy_timeout(self->db, (int)(timeout*1000));
-    self->thread_ident = PyThread_get_thread_ident();
+    self->begin_statement = begin_statement;
     self->check_same_thread = check_same_thread;
+    self->thread_ident = PyThread_get_thread_ident();
+    self->statement_cache = statement_cache;
+    self->cursors = cursors;
+    self->created_cursors = 0;
+    self->row_factory = Py_NewRef(Py_None);
+    self->text_factory = Py_NewRef(&PyUnicode_Type);
+    self->trace_ctx = NULL;
+    self->progress_ctx = NULL;
+    self->authorizer_ctx = NULL;
 
-    set_callback_context(&self->trace_ctx, NULL);
-    set_callback_context(&self->progress_ctx, NULL);
-    set_callback_context(&self->authorizer_ctx, NULL);
-
+    // Borrowed refs
     self->Warning               = state->Warning;
     self->Error                 = state->Error;
     self->InterfaceError        = state->InterfaceError;
@@ -250,7 +256,6 @@ pysqlite_connection_init_impl(pysqlite_Connection *self,
     }
 
     self->initialized = 1;
-
     return 0;
 }
 
@@ -321,16 +326,6 @@ connection_clear(pysqlite_Connection *self)
     return 0;
 }
 
-static void
-connection_close(pysqlite_Connection *self)
-{
-    if (self->db) {
-        int rc = sqlite3_close_v2(self->db);
-        assert(rc == SQLITE_OK), (void)rc;
-        self->db = NULL;
-    }
-}
-
 static void
 free_callback_contexts(pysqlite_Connection *self)
 {
@@ -339,6 +334,22 @@ free_callback_contexts(pysqlite_Connection *self)
     set_callback_context(&self->authorizer_ctx, NULL);
 }
 
+static void
+connection_close(pysqlite_Connection *self)
+{
+    if (self->db) {
+        free_callback_contexts(self);
+
+        sqlite3 *db = self->db;
+        self->db = NULL;
+
+        Py_BEGIN_ALLOW_THREADS
+        int rc = sqlite3_close_v2(db);
+        assert(rc == SQLITE_OK), (void)rc;
+        Py_END_ALLOW_THREADS
+    }
+}
+
 static void
 connection_dealloc(pysqlite_Connection *self)
 {
@@ -348,7 +359,6 @@ connection_dealloc(pysqlite_Connection *self)
 
     /* Clean up if user has not called .close() explicitly. */
     connection_close(self);
-    free_callback_contexts(self);
 
     tp->tp_free(self);
     Py_DECREF(tp);



More information about the Python-checkins mailing list