[pypy-commit] pypy default: test and fix for behavior of _sqlite3.Connection._check_{thread,closed}

bdkearns noreply at buildbot.pypy.org
Wed Mar 6 01:42:22 CET 2013


Author: Brian Kearns <bdkearns at gmail.com>
Branch: 
Changeset: r62101:df9670b5d959
Date: 2013-03-05 18:49 -0500
http://bitbucket.org/pypy/pypy/changeset/df9670b5d959/

Log:	test and fix for behavior of
	_sqlite3.Connection._check_{thread,closed}

diff --git a/lib_pypy/_sqlite3.py b/lib_pypy/_sqlite3.py
--- a/lib_pypy/_sqlite3.py
+++ b/lib_pypy/_sqlite3.py
@@ -27,6 +27,7 @@
 from ctypes import POINTER, byref, string_at, CFUNCTYPE, cast
 from ctypes import sizeof, c_ssize_t
 from collections import OrderedDict
+from functools import wraps
 import datetime
 import sys
 import weakref
@@ -392,6 +393,20 @@
                 "The object was created in thread id %d and this is thread id %d",
                 self.thread_ident, thread_get_ident())
 
+    def _check_thread_wrap(func):  # method decorator: run self._check_thread() before func
+        @wraps(func)  # copy func's __name__/__doc__/etc. onto the wrapper
+        def _check_thread_func(self, *args, **kwargs):  # accepts any args, so func's arity is not checked yet
+            self._check_thread()  # thread check happens before func binds its arguments
+            return func(self, *args, **kwargs)
+        return _check_thread_func
+
+    def _check_closed_wrap(func):  # method decorator: run self._check_closed() before func
+        @wraps(func)  # copy func's __name__/__doc__/etc. onto the wrapper
+        def _check_closed_func(self, *args, **kwargs):  # accepts any args, so func's arity is not checked yet
+            self._check_closed()  # closed check happens before func binds its arguments
+            return func(self, *args, **kwargs)
+        return _check_closed_func
+
     def _reset_cursors(self):
         for cursor_ref in self.cursors:
             cursor = cursor_ref()
@@ -429,8 +444,8 @@
             cur.row_factory = self.row_factory
         return cur.executescript(*args)
 
+    @_check_closed_wrap
     def __call__(self, sql):
-        self._check_closed()
         if not isinstance(sql, (str, unicode)):
             raise Warning("SQL is of wrong type. Must be string or unicode.")
         statement = self.statement_cache.get(sql, self.row_factory)
@@ -548,9 +563,9 @@
                 raise self._get_exception(ret)
             self.db.value = 0
 
+    @_check_thread_wrap
+    @_check_closed_wrap
     def create_collation(self, name, callback):
-        self._check_thread()
-        self._check_closed()
         name = name.upper()
         if not name.replace('_', '').isalnum():
             raise ProgrammingError("invalid character in collation name")
@@ -578,9 +593,9 @@
         if ret != SQLITE_OK:
             raise self._get_exception(ret)
 
+    @_check_thread_wrap
+    @_check_closed_wrap
     def set_progress_handler(self, callable, nsteps):
-        self._check_thread()
-        self._check_closed()
         if callable is None:
             c_progress_handler = cast(None, PROGRESS)
         else:
@@ -603,10 +618,9 @@
         if ret != SQLITE_OK:
             raise self._get_exception(ret)
 
+    @_check_thread_wrap
+    @_check_closed_wrap
     def set_authorizer(self, callback):
-        self._check_thread()
-        self._check_closed()
-
         try:
             c_authorizer, _ = self.func_cache[callback]
         except KeyError:
@@ -625,9 +639,9 @@
         if ret != SQLITE_OK:
             raise self._get_exception(ret)
 
+    @_check_thread_wrap
+    @_check_closed_wrap
     def create_function(self, name, num_args, callback):
-        self._check_thread()
-        self._check_closed()
         try:
             c_closure, _ = self.func_cache[callback]
         except KeyError:
@@ -643,10 +657,9 @@
         if ret != SQLITE_OK:
             raise self.OperationalError("Error creating function")
 
+    @_check_thread_wrap
+    @_check_closed_wrap
     def create_aggregate(self, name, num_args, cls):
-        self._check_thread()
-        self._check_closed()
-
         try:
             c_step_callback, c_final_callback, _, _ = self._aggregates[cls]
         except KeyError:
@@ -718,10 +731,9 @@
         return _iterdump(self)
 
     if HAS_LOAD_EXTENSION:
+        @_check_thread_wrap
+        @_check_closed_wrap
         def enable_load_extension(self, enabled):
-            self._check_thread()
-            self._check_closed()
-
             rc = sqlite.sqlite3_enable_load_extension(self.db, int(enabled))
             if rc != SQLITE_OK:
                 raise OperationalError("Error enabling load extension")
diff --git a/pypy/module/test_lib_pypy/test_sqlite3.py b/pypy/module/test_lib_pypy/test_sqlite3.py
--- a/pypy/module/test_lib_pypy/test_sqlite3.py
+++ b/pypy/module/test_lib_pypy/test_sqlite3.py
@@ -45,6 +45,13 @@
     e = pytest.raises(_sqlite3.ProgrammingError, "cur.execute('select 1')")
     assert '__init__' in e.value.message
 
+def test_connection_after_close():  # calling a Connection must check closed state before args
+    con = _sqlite3.connect(':memory:')
+    pytest.raises(TypeError, con)  # open connection: con() is missing the 'sql' argument
+    con.close()
+    # must raise ProgrammingError: the closed check has to run before argument checking
+    pytest.raises(_sqlite3.ProgrammingError, con)
+
 def test_cursor_after_close():
      con = _sqlite3.connect(':memory:')
      cur = con.execute('select 1')


More information about the pypy-commit mailing list