[pypy-commit] pypy py3.6-sqlite: Follow CPython's behaviour more closely in sqlite3 and fix extra_tests to pass on CPython 3.6

rlamy pypy.commits at gmail.com
Fri Dec 27 15:04:38 EST 2019


Author: Ronan Lamy <ronan.lamy at gmail.com>
Branch: py3.6-sqlite
Changeset: r98402:dc2537b84e7c
Date: 2019-12-27 21:03 +0100
http://bitbucket.org/pypy/pypy/changeset/dc2537b84e7c/

Log:	Follow CPython's behaviour more closely in sqlite3 and fix
	extra_tests to pass on CPython 3.6

diff --git a/extra_tests/test_sqlite3.py b/extra_tests/test_sqlite3.py
--- a/extra_tests/test_sqlite3.py
+++ b/extra_tests/test_sqlite3.py
@@ -88,8 +88,7 @@
     with pytest.raises(StopIteration):
         next(cur)
 
-    with pytest.raises(_sqlite3.ProgrammingError):
-        cur.executemany('select 1', [])
+    cur.executemany('select 1', [])
     with pytest.raises(StopIteration):
         next(cur)
 
@@ -201,8 +200,8 @@
 
 def test_explicit_begin(con):
     con.execute('BEGIN')
-    con.execute('BEGIN ')
-    con.execute('BEGIN')
+    with pytest.raises(_sqlite3.OperationalError):
+        con.execute('BEGIN ')
     con.commit()
     con.execute('BEGIN')
     con.commit()
@@ -228,15 +227,15 @@
     cur = con.cursor()
     cur.execute("create table test(a)")
     cur.executemany("insert into test values (?)", [[1], [2], [3]])
-    assert cur.lastrowid is None
+    assert cur.lastrowid == 0
     # issue 2682
     cur.execute('''insert
                 into test
                 values (?)
                 ''', (1, ))
-    assert cur.lastrowid is not None
+    assert cur.lastrowid
     cur.execute('''insert\t into test values (?) ''', (1, ))
-    assert cur.lastrowid is not None
+    assert cur.lastrowid
 
 def test_authorizer_bad_value(con):
     def authorizer_cb(action, arg1, arg2, dbname, source):
diff --git a/lib_pypy/_sqlite3.py b/lib_pypy/_sqlite3.py
--- a/lib_pypy/_sqlite3.py
+++ b/lib_pypy/_sqlite3.py
@@ -233,7 +233,6 @@
         self.text_factory = _unicode_text_factory
 
         self._detect_types = detect_types
-        self._in_transaction = False
         self.isolation_level = isolation_level
 
         self.__cursors = []
@@ -433,7 +432,7 @@
 
     def _begin(self):
         statement_star = _ffi.new('sqlite3_stmt **')
-        ret = _lib.sqlite3_prepare_v2(self._db, self.__begin_statement, -1,
+        ret = _lib.sqlite3_prepare_v2(self._db, self._begin_statement, -1,
                                       statement_star, _ffi.NULL)
         try:
             if ret != _lib.SQLITE_OK:
@@ -441,14 +440,13 @@
             ret = _lib.sqlite3_step(statement_star[0])
             if ret != _lib.SQLITE_DONE:
                 raise self._get_exception(ret)
-            self._in_transaction = True
         finally:
             _lib.sqlite3_finalize(statement_star[0])
 
     def commit(self):
         self._check_thread()
         self._check_closed()
-        if not self._in_transaction:
+        if not self.in_transaction:
             return
 
         # PyPy fix for non-refcounting semantics: since 2.7.13 (and in
@@ -473,14 +471,13 @@
             ret = _lib.sqlite3_step(statement_star[0])
             if ret != _lib.SQLITE_DONE:
                 raise self._get_exception(ret)
-            self._in_transaction = False
         finally:
             _lib.sqlite3_finalize(statement_star[0])
 
     def rollback(self):
         self._check_thread()
         self._check_closed()
-        if not self._in_transaction:
+        if not self.in_transaction:
             return
 
         self.__do_all_statements(Statement._reset, True)
@@ -494,7 +491,6 @@
             ret = _lib.sqlite3_step(statement_star[0])
             if ret != _lib.SQLITE_DONE:
                 raise self._get_exception(ret)
-            self._in_transaction = False
         finally:
             _lib.sqlite3_finalize(statement_star[0])
 
@@ -687,10 +683,10 @@
                 self.__func_cache[callable] = trace_callback
         _lib.sqlite3_trace(self._db, trace_callback, _ffi.NULL)
 
-    if sys.version_info[0] >= 3:
-        def __get_in_transaction(self):
-            return self._in_transaction
-        in_transaction = property(__get_in_transaction)
+    @property
+    @_check_closed_wrap
+    def in_transaction(self):
+        return not _lib.sqlite3_get_autocommit(self._db)
 
     def __get_total_changes(self):
         self._check_closed()
@@ -710,7 +706,7 @@
             stmt = str("BEGIN " + val).upper()
             if stmt not in BEGIN_STATMENTS:
                 raise ValueError("invalid value for isolation_level")
-            self.__begin_statement = stmt.encode('utf-8')
+            self._begin_statement = stmt.encode('utf-8')
         self._isolation_level = val
     isolation_level = property(__get_isolation_level, __set_isolation_level)
 
@@ -878,22 +874,9 @@
             self.__rowcount = -1
             self.__statement = self.__connection._statement_cache.get(sql)
 
-            if self.__connection._isolation_level is not None:
-                if self.__statement._type in (
-                    _STMT_TYPE_UPDATE,
-                    _STMT_TYPE_DELETE,
-                    _STMT_TYPE_INSERT,
-                    _STMT_TYPE_REPLACE
-                ):
-                    if not self.__connection._in_transaction:
-                        self.__connection._begin()
-                elif self.__statement._type == _STMT_TYPE_OTHER:
-                    if self.__connection._in_transaction:
-                        self.__connection.commit()
-                elif self.__statement._type == _STMT_TYPE_SELECT:
-                    if multiple:
-                        raise ProgrammingError("You cannot execute SELECT "
-                                               "statements in executemany().")
+            if self.__connection._begin_statement and self.__statement._is_dml:
+                if _lib.sqlite3_get_autocommit(self.__connection._db):
+                    self.__connection._begin()
 
             for params in many_params:
                 self.__statement._set_params(params)
@@ -911,6 +894,16 @@
                     self.__connection._reset_already_committed_statements()
                     ret = _lib.sqlite3_step(self.__statement._statement)
 
+                if self.__statement._is_dml:
+                    if self.__rowcount == -1:
+                        self.__rowcount = 0
+                    self.__rowcount += _lib.sqlite3_changes(self.__connection._db)
+                else:
+                    self.__rowcount = -1
+
+                if not multiple:
+                    self.__lastrowid = _lib.sqlite3_last_insert_rowid(self.__connection._db)
+
                 if ret == _lib.SQLITE_ROW:
                     if multiple:
                         raise ProgrammingError("executemany() can only execute DML statements.")
@@ -923,28 +916,9 @@
                     self.__statement._reset()
                     raise self.__connection._get_exception(ret)
 
-                if self.__statement._type in (
-                    _STMT_TYPE_UPDATE,
-                    _STMT_TYPE_DELETE,
-                    _STMT_TYPE_INSERT,
-                    _STMT_TYPE_REPLACE
-                ):
-                    if self.__rowcount == -1:
-                        self.__rowcount = 0
-                    self.__rowcount += _lib.sqlite3_changes(self.__connection._db)
-
-                if not multiple and self.__statement._type in (
-                        # REPLACE is an alias for INSERT OR REPLACE
-                        _STMT_TYPE_INSERT, _STMT_TYPE_REPLACE):
-                    self.__lastrowid = _lib.sqlite3_last_insert_rowid(self.__connection._db)
-                else:
-                    self.__lastrowid = None
-
                 if multiple:
                     self.__statement._reset()
         finally:
-            self.__connection._in_transaction = \
-                not _lib.sqlite3_get_autocommit(self.__connection._db)
             self.__locked = False
         return self
 
@@ -1086,38 +1060,19 @@
         if '\0' in sql:
             raise ValueError("the query contains a null character")
 
-        
-        if sql:
-            first_word = sql.lstrip().split()[0].upper()
-            if first_word == '':
-                self._type = _STMT_TYPE_INVALID
-            if first_word == "SELECT":
-                self._type = _STMT_TYPE_SELECT
-            elif first_word == "INSERT":
-                self._type = _STMT_TYPE_INSERT
-            elif first_word == "UPDATE":
-                self._type = _STMT_TYPE_UPDATE
-            elif first_word == "DELETE":
-                self._type = _STMT_TYPE_DELETE
-            elif first_word == "REPLACE":
-                self._type = _STMT_TYPE_REPLACE
-            else:
-                self._type = _STMT_TYPE_OTHER
-        else:
-            self._type = _STMT_TYPE_INVALID
+        to_check = sql.lstrip().upper()
+        self._valid = bool(to_check)
+        self._is_dml = to_check.startswith(('INSERT', 'UPDATE', 'DELETE', 'REPLACE'))
 
-        if isinstance(sql, unicode):
-            sql = sql.encode('utf-8')
         statement_star = _ffi.new('sqlite3_stmt **')
         next_char = _ffi.new('char **')
-        c_sql = _ffi.new("char[]", sql)
+        c_sql = _ffi.new("char[]", sql.encode('utf-8'))
         ret = _lib.sqlite3_prepare_v2(self.__con._db, c_sql, -1,
                                       statement_star, next_char)
         self._statement = statement_star[0]
 
         if ret == _lib.SQLITE_OK and not self._statement:
             # an empty statement, work around that, as it's the least trouble
-            self._type = _STMT_TYPE_SELECT
             c_sql = _ffi.new("char[]", b"select 42")
             ret = _lib.sqlite3_prepare_v2(self.__con._db, c_sql, -1,
                                           statement_star, next_char)
@@ -1238,13 +1193,6 @@
             raise ValueError("parameters are of unsupported type")
 
     def _get_description(self):
-        if self._type in (
-            _STMT_TYPE_INSERT,
-            _STMT_TYPE_UPDATE,
-            _STMT_TYPE_DELETE,
-            _STMT_TYPE_REPLACE
-        ):
-            return None
         desc = []
         for i in xrange(_lib.sqlite3_column_count(self._statement)):
             name = _lib.sqlite3_column_name(self._statement, i)


More information about the pypy-commit mailing list