[Python-checkins] r61141 - in python/trunk: Lib/sqlite3/test/dbapi.py Lib/sqlite3/test/hooks.py Lib/sqlite3/test/py25tests.py Lib/sqlite3/test/regression.py Lib/sqlite3/test/transactions.py Lib/sqlite3/test/types.py Lib/test/test_sqlite.py Modules/_sqlite/connection.c Modules/_sqlite/connection.h Modules/_sqlite/cursor.c Modules/_sqlite/cursor.h Modules/_sqlite/microprotocols.h Modules/_sqlite/module.c Modules/_sqlite/module.h Modules/_sqlite/statement.c Modules/_sqlite/util.c Modules/_sqlite/util.h

gerhard.haering python-checkins at python.org
Fri Feb 29 23:08:41 CET 2008


Author: gerhard.haering
Date: Fri Feb 29 23:08:41 2008
New Revision: 61141

Added:
   python/trunk/Lib/sqlite3/test/py25tests.py
Modified:
   python/trunk/Lib/sqlite3/test/dbapi.py
   python/trunk/Lib/sqlite3/test/hooks.py
   python/trunk/Lib/sqlite3/test/regression.py
   python/trunk/Lib/sqlite3/test/transactions.py
   python/trunk/Lib/sqlite3/test/types.py
   python/trunk/Lib/test/test_sqlite.py
   python/trunk/Modules/_sqlite/connection.c
   python/trunk/Modules/_sqlite/connection.h
   python/trunk/Modules/_sqlite/cursor.c
   python/trunk/Modules/_sqlite/cursor.h
   python/trunk/Modules/_sqlite/microprotocols.h
   python/trunk/Modules/_sqlite/module.c
   python/trunk/Modules/_sqlite/module.h
   python/trunk/Modules/_sqlite/statement.c
   python/trunk/Modules/_sqlite/util.c
   python/trunk/Modules/_sqlite/util.h
Log:
Updated to pysqlite 2.4.1. Documentation additions will come later.


Modified: python/trunk/Lib/sqlite3/test/dbapi.py
==============================================================================
--- python/trunk/Lib/sqlite3/test/dbapi.py	(original)
+++ python/trunk/Lib/sqlite3/test/dbapi.py	Fri Feb 29 23:08:41 2008
@@ -1,7 +1,7 @@
 #-*- coding: ISO-8859-1 -*-
 # pysqlite2/test/dbapi.py: tests for DB-API compliance
 #
-# Copyright (C) 2004-2005 Gerhard Häring <gh at ghaering.de>
+# Copyright (C) 2004-2007 Gerhard Häring <gh at ghaering.de>
 #
 # This file is part of pysqlite.
 #
@@ -22,6 +22,7 @@
 # 3. This notice may not be removed or altered from any source distribution.
 
 import unittest
+import sys
 import threading
 import sqlite3 as sqlite
 
@@ -223,12 +224,45 @@
         except sqlite.ProgrammingError:
             pass
 
+    def CheckExecuteParamList(self):
+        self.cu.execute("insert into test(name) values ('foo')")
+        self.cu.execute("select name from test where name=?", ["foo"])
+        row = self.cu.fetchone()
+        self.failUnlessEqual(row[0], "foo")
+
+    def CheckExecuteParamSequence(self):
+        class L(object):
+            def __len__(self):
+                return 1
+            def __getitem__(self, x):
+                assert x == 0
+                return "foo"
+
+        self.cu.execute("insert into test(name) values ('foo')")
+        self.cu.execute("select name from test where name=?", L())
+        row = self.cu.fetchone()
+        self.failUnlessEqual(row[0], "foo")
+
     def CheckExecuteDictMapping(self):
         self.cu.execute("insert into test(name) values ('foo')")
         self.cu.execute("select name from test where name=:name", {"name": "foo"})
         row = self.cu.fetchone()
         self.failUnlessEqual(row[0], "foo")
 
+    def CheckExecuteDictMapping_Mapping(self):
+        # Test only works with Python 2.5 or later
+        if sys.version_info < (2, 5, 0):
+            return
+
+        class D(dict):
+            def __missing__(self, key):
+                return "foo"
+
+        self.cu.execute("insert into test(name) values ('foo')")
+        self.cu.execute("select name from test where name=:name", D())
+        row = self.cu.fetchone()
+        self.failUnlessEqual(row[0], "foo")
+
     def CheckExecuteDictMappingTooLittleArgs(self):
         self.cu.execute("insert into test(name) values ('foo')")
         try:
@@ -378,6 +412,12 @@
         res = self.cu.fetchmany(100)
         self.failUnlessEqual(res, [])
 
+    def CheckFetchmanyKwArg(self):
+        """Checks if fetchmany works with keyword arguments"""
+        self.cu.execute("select name from test")
+        res = self.cu.fetchmany(size=100)
+        self.failUnlessEqual(len(res), 1)
+
     def CheckFetchall(self):
         self.cu.execute("select name from test")
         res = self.cu.fetchall()

Modified: python/trunk/Lib/sqlite3/test/hooks.py
==============================================================================
--- python/trunk/Lib/sqlite3/test/hooks.py	(original)
+++ python/trunk/Lib/sqlite3/test/hooks.py	Fri Feb 29 23:08:41 2008
@@ -1,7 +1,7 @@
 #-*- coding: ISO-8859-1 -*-
 # pysqlite2/test/hooks.py: tests for various SQLite-specific hooks
 #
-# Copyright (C) 2006 Gerhard Häring <gh at ghaering.de>
+# Copyright (C) 2006-2007 Gerhard Häring <gh at ghaering.de>
 #
 # This file is part of pysqlite.
 #
@@ -21,7 +21,7 @@
 #    misrepresented as being the original software.
 # 3. This notice may not be removed or altered from any source distribution.
 
-import unittest
+import os, unittest
 import sqlite3 as sqlite
 
 class CollationTests(unittest.TestCase):
@@ -105,9 +105,80 @@
             if not e.args[0].startswith("no such collation sequence"):
                 self.fail("wrong OperationalError raised")
 
+class ProgressTests(unittest.TestCase):
+    def CheckProgressHandlerUsed(self):
+        """
+        Test that the progress handler is invoked once it is set.
+        """
+        con = sqlite.connect(":memory:")
+        progress_calls = []
+        def progress():
+            progress_calls.append(None)
+            return 0
+        con.set_progress_handler(progress, 1)
+        con.execute("""
+            create table foo(a, b)
+            """)
+        self.failUnless(progress_calls)
+
+    def CheckOpcodeCount(self):
+        """
+        Test that the opcode argument is respected.
+        """
+        con = sqlite.connect(":memory:")
+        progress_calls = []
+        def progress():
+            progress_calls.append(None)
+            return 0
+        con.set_progress_handler(progress, 1)
+        curs = con.cursor()
+        curs.execute("""
+            create table foo (a, b)
+            """)
+        first_count = len(progress_calls)
+        progress_calls = []
+        con.set_progress_handler(progress, 2)
+        curs.execute("""
+            create table bar (a, b)
+            """)
+        second_count = len(progress_calls)
+        self.failUnless(first_count > second_count)
+
+    def CheckCancelOperation(self):
+        """
+        Test that returning a non-zero value stops the operation in progress.
+        """
+        con = sqlite.connect(":memory:")
+        progress_calls = []
+        def progress():
+            progress_calls.append(None)
+            return 1
+        con.set_progress_handler(progress, 1)
+        curs = con.cursor()
+        self.assertRaises(
+            sqlite.OperationalError,
+            curs.execute,
+            "create table bar (a, b)")
+        self.failUnless(progress_calls)
+
+    def CheckClearHandler(self):
+        """Test that setting the progress handler to None clears it."""
+        con = sqlite.connect(":memory:")
+        # A list is used because a nested function cannot rebind an outer
+        # local in Python 2 (no "nonlocal"); "action = 1" would be a no-op.
+        action = []
+        def progress():
+            action.append(1)
+            return 0
+        con.set_progress_handler(progress, 1)
+        con.set_progress_handler(None, 1)
+        con.execute("select 1 union select 2 union select 3").fetchall()
+        self.failUnlessEqual(action, [], "progress handler was not cleared")
+
 def suite():
     collation_suite = unittest.makeSuite(CollationTests, "Check")
-    return unittest.TestSuite((collation_suite,))
+    progress_suite = unittest.makeSuite(ProgressTests, "Check")
+    return unittest.TestSuite((collation_suite, progress_suite))
 
 def test():
     runner = unittest.TextTestRunner()

Added: python/trunk/Lib/sqlite3/test/py25tests.py
==============================================================================
--- (empty file)
+++ python/trunk/Lib/sqlite3/test/py25tests.py	Fri Feb 29 23:08:41 2008
@@ -0,0 +1,80 @@
+#-*- coding: ISO-8859-1 -*-
+# pysqlite2/test/py25tests.py: pysqlite tests that need Python 2.5 features
+#
+# Copyright (C) 2007 Gerhard Häring <gh at ghaering.de>
+#
+# This file is part of pysqlite.
+#
+# This software is provided 'as-is', without any express or implied
+# warranty.  In no event will the authors be held liable for any damages
+# arising from the use of this software.
+#
+# Permission is granted to anyone to use this software for any purpose,
+# including commercial applications, and to alter it and redistribute it
+# freely, subject to the following restrictions:
+#
+# 1. The origin of this software must not be misrepresented; you must not
+#    claim that you wrote the original software. If you use this software
+#    in a product, an acknowledgment in the product documentation would be
+#    appreciated but is not required.
+# 2. Altered source versions must be plainly marked as such, and must not be
+#    misrepresented as being the original software.
+# 3. This notice may not be removed or altered from any source distribution.
+
+from __future__ import with_statement
+import unittest
+import sqlite3 as sqlite
+
+did_rollback = False
+
+class MyConnection(sqlite.Connection):
+    def rollback(self):
+        global did_rollback
+        did_rollback = True
+        sqlite.Connection.rollback(self)
+
+class ContextTests(unittest.TestCase):
+    def setUp(self):
+        global did_rollback
+        self.con = sqlite.connect(":memory:", factory=MyConnection)
+        self.con.execute("create table test(c unique)")
+        did_rollback = False
+
+    def tearDown(self):
+        self.con.close()
+
+    def CheckContextManager(self):
+        """Can the connection be used as a context manager at all?"""
+        with self.con:
+            pass
+
+    def CheckContextManagerCommit(self):
+        """Is a commit called in the context manager?"""
+        with self.con:
+            self.con.execute("insert into test(c) values ('foo')")
+        self.con.rollback()
+        count = self.con.execute("select count(*) from test").fetchone()[0]
+        self.failUnlessEqual(count, 1)
+
+    def CheckContextManagerRollback(self):
+        """Is a rollback called in the context manager?"""
+        global did_rollback
+        self.failUnlessEqual(did_rollback, False)
+        try:
+            with self.con:
+                self.con.execute("insert into test(c) values (4)")
+                self.con.execute("insert into test(c) values (4)")
+        except sqlite.IntegrityError:
+            pass
+        self.failUnlessEqual(did_rollback, True)
+
+def suite():
+    ctx_suite = unittest.makeSuite(ContextTests, "Check")
+    return unittest.TestSuite((ctx_suite,))
+
+def test():
+    runner = unittest.TextTestRunner()
+    runner.run(suite())
+
+if __name__ == "__main__":
+    test()

Modified: python/trunk/Lib/sqlite3/test/regression.py
==============================================================================
--- python/trunk/Lib/sqlite3/test/regression.py	(original)
+++ python/trunk/Lib/sqlite3/test/regression.py	Fri Feb 29 23:08:41 2008
@@ -1,7 +1,7 @@
 #-*- coding: ISO-8859-1 -*-
 # pysqlite2/test/regression.py: pysqlite regression tests
 #
-# Copyright (C) 2006 Gerhard Häring <gh at ghaering.de>
+# Copyright (C) 2006-2007 Gerhard Häring <gh at ghaering.de>
 #
 # This file is part of pysqlite.
 #
@@ -21,6 +21,7 @@
 #    misrepresented as being the original software.
 # 3. This notice may not be removed or altered from any source distribution.
 
+import datetime
 import unittest
 import sqlite3 as sqlite
 
@@ -79,6 +80,79 @@
         cur.fetchone()
         cur.fetchone()
 
+    def CheckStatementFinalizationOnCloseDb(self):
+        # pysqlite versions <= 2.3.3 only finalized statements in the statement
+        # cache when closing the database. Statements that were still
+        # referenced in cursors weren't closed and could provoke
+        # "OperationalError: Unable to close due to unfinalised statements".
+        con = sqlite.connect(":memory:")
+        cursors = []
+        # default statement cache size is 100, so 105 distinct statements overflow it
+        for i in range(105):
+            cur = con.cursor()
+            cursors.append(cur)
+            cur.execute("select 1 x union select " + str(i))
+        con.close()
+
+    def CheckOnConflictRollback(self):
+        if sqlite.sqlite_version_info < (3, 2, 2):
+            return
+        con = sqlite.connect(":memory:")
+        con.execute("create table foo(x, unique(x) on conflict rollback)")
+        con.execute("insert into foo(x) values (1)")
+        try:
+            con.execute("insert into foo(x) values (1)")
+        except sqlite.DatabaseError:
+            pass
+        con.execute("insert into foo(x) values (2)")
+        try:
+            con.commit()
+        except sqlite.OperationalError:
+            self.fail("pysqlite knew nothing about the implicit ROLLBACK")
+
+    def CheckWorkaroundForBuggySqliteTransferBindings(self):
+        """
+        pysqlite would crash with older SQLite versions unless
+        a workaround is implemented.
+        """
+        self.con.execute("create table if not exists foo(bar)")
+        self.con.execute("create table if not exists foo(bar)")
+
+    def CheckEmptyStatement(self):
+        """
+        pysqlite used to segfault with SQLite versions 3.5.x. These return NULL
+        for "no-operation" statements
+        """
+        self.con.execute("")
+
+    def CheckUnicodeConnect(self):
+        """
+        With pysqlite 2.4.0 you needed to use a string or an APSW connection
+        object for opening database connections.
+
+        Formerly, both bytestrings and unicode strings used to work.
+
+        Let's make sure unicode strings work in the future.
+        """
+        con = sqlite.connect(u":memory:")
+        con.close()
+
+    def CheckTypeMapUsage(self):
+        """
+        pysqlite until 2.4.1 did not rebuild the row_cast_map when recompiling
+        a statement. This test exhibits the problem.
+        """
+        SELECT = "select * from foo"
+        con = sqlite.connect(":memory:",detect_types=sqlite.PARSE_DECLTYPES)
+        con.execute("create table foo(bar timestamp)")
+        con.execute("insert into foo(bar) values (?)", (datetime.datetime.now(),))
+        con.execute(SELECT)
+        con.execute("drop table foo")
+        con.execute("create table foo(bar integer)")
+        con.execute("insert into foo(bar) values (5)")
+        con.execute(SELECT)
+
+
 def suite():
     regression_suite = unittest.makeSuite(RegressionTests, "Check")
     return unittest.TestSuite((regression_suite,))

Modified: python/trunk/Lib/sqlite3/test/transactions.py
==============================================================================
--- python/trunk/Lib/sqlite3/test/transactions.py	(original)
+++ python/trunk/Lib/sqlite3/test/transactions.py	Fri Feb 29 23:08:41 2008
@@ -1,7 +1,7 @@
 #-*- coding: ISO-8859-1 -*-
 # pysqlite2/test/transactions.py: tests transactions
 #
-# Copyright (C) 2005 Gerhard Häring <gh at ghaering.de>
+# Copyright (C) 2005-2007 Gerhard Häring <gh at ghaering.de>
 #
 # This file is part of pysqlite.
 #
@@ -21,6 +21,7 @@
 #    misrepresented as being the original software.
 # 3. This notice may not be removed or altered from any source distribution.
 
+import sys
 import os, unittest
 import sqlite3 as sqlite
 
@@ -119,6 +120,23 @@
         except:
             self.fail("should have raised an OperationalError")
 
+    def CheckLocking(self):
+        """
+        This tests the improved concurrency with pysqlite 2.3.4. You needed
+        to roll back con2 before you could commit con1.
+        """
+        self.cur1.execute("create table test(i)")
+        self.cur1.execute("insert into test(i) values (5)")
+        # con1's insert is not committed yet, so con2's write is expected to
+        # fail with OperationalError. Any other exception now propagates with
+        # its own traceback instead of being masked by a bare "except:".
+        self.assertRaises(sqlite.OperationalError,
+                          self.cur2.execute,
+                          "insert into test(i) values (5)")
+        # Deliberately no self.con2.rollback() before committing con1
+        # (that rollback was required before pysqlite 2.3.4).
+        self.con1.commit()
+
 class SpecialCommandTests(unittest.TestCase):
     def setUp(self):
         self.con = sqlite.connect(":memory:")

Modified: python/trunk/Lib/sqlite3/test/types.py
==============================================================================
--- python/trunk/Lib/sqlite3/test/types.py	(original)
+++ python/trunk/Lib/sqlite3/test/types.py	Fri Feb 29 23:08:41 2008
@@ -1,7 +1,7 @@
 #-*- coding: ISO-8859-1 -*-
 # pysqlite2/test/types.py: tests for type conversion and detection
 #
-# Copyright (C) 2005 Gerhard Häring <gh at ghaering.de>
+# Copyright (C) 2005-2007 Gerhard Häring <gh at ghaering.de>
 #
 # This file is part of pysqlite.
 #
@@ -21,7 +21,7 @@
 #    misrepresented as being the original software.
 # 3. This notice may not be removed or altered from any source distribution.
 
-import bz2, datetime
+import zlib, datetime
 import unittest
 import sqlite3 as sqlite
 
@@ -287,7 +287,7 @@
 
 class BinaryConverterTests(unittest.TestCase):
     def convert(s):
-        return bz2.decompress(s)
+        return zlib.decompress(s)
     convert = staticmethod(convert)
 
     def setUp(self):
@@ -299,7 +299,7 @@
 
     def CheckBinaryInputForConverter(self):
         testdata = "abcdefg" * 10
-        result = self.con.execute('select ? as "x [bin]"', (buffer(bz2.compress(testdata)),)).fetchone()[0]
+        result = self.con.execute('select ? as "x [bin]"', (buffer(zlib.compress(testdata)),)).fetchone()[0]
         self.failUnlessEqual(testdata, result)
 
 class DateTimeTests(unittest.TestCase):
@@ -331,7 +331,8 @@
         if sqlite.sqlite_version_info < (3, 1):
             return
 
-        now = datetime.datetime.utcnow()
+        # SQLite's current_timestamp uses UTC time, so compare against datetime.datetime.utcnow().
+        now = datetime.datetime.utcnow()
         self.cur.execute("insert into test(ts) values (current_timestamp)")
         self.cur.execute("select ts from test")
         ts = self.cur.fetchone()[0]

Modified: python/trunk/Lib/test/test_sqlite.py
==============================================================================
--- python/trunk/Lib/test/test_sqlite.py	(original)
+++ python/trunk/Lib/test/test_sqlite.py	Fri Feb 29 23:08:41 2008
@@ -4,13 +4,13 @@
     import _sqlite3
 except ImportError:
     raise TestSkipped('no sqlite available')
-from sqlite3.test import (dbapi, types, userfunctions,
+from sqlite3.test import (dbapi, types, userfunctions, py25tests,
                                 factory, transactions, hooks, regression)
 
 def test_main():
     run_unittest(dbapi.suite(), types.suite(), userfunctions.suite(),
-                 factory.suite(), transactions.suite(), hooks.suite(),
-                 regression.suite())
+                 py25tests.suite(), factory.suite(), transactions.suite(),
+                 hooks.suite(), regression.suite())
 
 if __name__ == "__main__":
     test_main()

Modified: python/trunk/Modules/_sqlite/connection.c
==============================================================================
--- python/trunk/Modules/_sqlite/connection.c	(original)
+++ python/trunk/Modules/_sqlite/connection.c	Fri Feb 29 23:08:41 2008
@@ -1,6 +1,6 @@
 /* connection.c - the connection type
  *
- * Copyright (C) 2004-2006 Gerhard Häring <gh at ghaering.de>
+ * Copyright (C) 2004-2007 Gerhard Häring <gh at ghaering.de>
  *
  * This file is part of pysqlite.
  * 
@@ -32,6 +32,9 @@
 
 #include "pythread.h"
 
+#define ACTION_FINALIZE 1
+#define ACTION_RESET 2
+
 static int pysqlite_connection_set_isolation_level(pysqlite_Connection* self, PyObject* isolation_level);
 
 
@@ -51,7 +54,7 @@
 {
     static char *kwlist[] = {"database", "timeout", "detect_types", "isolation_level", "check_same_thread", "factory", "cached_statements", NULL, NULL};
 
-    char* database;
+    PyObject* database;
     int detect_types = 0;
     PyObject* isolation_level = NULL;
     PyObject* factory = NULL;
@@ -59,11 +62,15 @@
     int cached_statements = 100;
     double timeout = 5.0;
     int rc;
+    PyObject* class_attr = NULL;
+    PyObject* class_attr_str = NULL;
+    int is_apsw_connection = 0;
+    PyObject* database_utf8;
 
-    if (!PyArg_ParseTupleAndKeywords(args, kwargs, "s|diOiOi", kwlist,
+    if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O|diOiOi", kwlist,
                                      &database, &timeout, &detect_types, &isolation_level, &check_same_thread, &factory, &cached_statements))
     {
-        return -1; 
+        return -1;
     }
 
     self->begin_statement = NULL;
@@ -77,13 +84,53 @@
     Py_INCREF(&PyUnicode_Type);
     self->text_factory = (PyObject*)&PyUnicode_Type;
 
-    Py_BEGIN_ALLOW_THREADS
-    rc = sqlite3_open(database, &self->db);
-    Py_END_ALLOW_THREADS
+    if (PyString_Check(database) || PyUnicode_Check(database)) {
+        if (PyString_Check(database)) {
+            database_utf8 = database;
+            Py_INCREF(database_utf8);
+        } else {
+            database_utf8 = PyUnicode_AsUTF8String(database);
+            if (!database_utf8) {
+                return -1;
+            }
+        }
 
-    if (rc != SQLITE_OK) {
-        _pysqlite_seterror(self->db);
-        return -1;
+        Py_BEGIN_ALLOW_THREADS
+        rc = sqlite3_open(PyString_AsString(database_utf8), &self->db);
+        Py_END_ALLOW_THREADS
+
+        Py_DECREF(database_utf8);
+
+        if (rc != SQLITE_OK) {
+            _pysqlite_seterror(self->db, NULL);
+            return -1;
+        }
+    } else {
+        /* Create a pysqlite connection from a APSW connection */
+        class_attr = PyObject_GetAttrString(database, "__class__");
+        if (class_attr) {
+            class_attr_str = PyObject_Str(class_attr);
+            if (class_attr_str) {
+                if (strcmp(PyString_AsString(class_attr_str), "<type 'apsw.Connection'>") == 0) {
+                    /* In the APSW Connection object, the first entry after
+                     * PyObject_HEAD is the sqlite3* we want to get hold of.
+                     * Luckily, this is the same layout as we have in our
+                     * pysqlite_Connection */
+                    self->db = ((pysqlite_Connection*)database)->db;
+
+                    Py_INCREF(database);
+                    self->apsw_connection = database;
+                    is_apsw_connection = 1;
+                }
+            }
+        }
+        Py_XDECREF(class_attr_str);
+        Py_XDECREF(class_attr);
+
+        if (!is_apsw_connection) {
+            PyErr_SetString(PyExc_ValueError, "database parameter must be string or APSW Connection object");
+            return -1;
+        }
     }
 
     if (!isolation_level) {
@@ -169,7 +216,8 @@
     self->statement_cache->decref_factory = 0;
 }
 
-void pysqlite_reset_all_statements(pysqlite_Connection* self)
+/* action in (ACTION_RESET, ACTION_FINALIZE) */
+void pysqlite_do_all_statements(pysqlite_Connection* self, int action)
 {
     int i;
     PyObject* weakref;
@@ -179,13 +227,19 @@
         weakref = PyList_GetItem(self->statements, i);
         statement = PyWeakref_GetObject(weakref);
         if (statement != Py_None) {
-            (void)pysqlite_statement_reset((pysqlite_Statement*)statement);
+            if (action == ACTION_RESET) {
+                (void)pysqlite_statement_reset((pysqlite_Statement*)statement);
+            } else {
+                (void)pysqlite_statement_finalize((pysqlite_Statement*)statement);
+            }
         }
     }
 }
 
 void pysqlite_connection_dealloc(pysqlite_Connection* self)
 {
+    PyObject* ret = NULL;
+
     Py_XDECREF(self->statement_cache);
 
     /* Clean up if user has not called .close() explicitly. */
@@ -193,6 +247,10 @@
         Py_BEGIN_ALLOW_THREADS
         sqlite3_close(self->db);
         Py_END_ALLOW_THREADS
+    } else if (self->apsw_connection) {
+        ret = PyObject_CallMethod(self->apsw_connection, "close", "");
+        Py_XDECREF(ret);
+        Py_XDECREF(self->apsw_connection);
     }
 
     if (self->begin_statement) {
@@ -205,7 +263,7 @@
     Py_XDECREF(self->collations);
     Py_XDECREF(self->statements);
 
-    Py_TYPE(self)->tp_free((PyObject*)self);
+    self->ob_type->tp_free((PyObject*)self);
 }
 
 PyObject* pysqlite_connection_cursor(pysqlite_Connection* self, PyObject* args, PyObject* kwargs)
@@ -241,24 +299,33 @@
 
 PyObject* pysqlite_connection_close(pysqlite_Connection* self, PyObject* args)
 {
+    PyObject* ret;
     int rc;
 
     if (!pysqlite_check_thread(self)) {
         return NULL;
     }
 
-    pysqlite_flush_statement_cache(self);
+    pysqlite_do_all_statements(self, ACTION_FINALIZE);
 
     if (self->db) {
-        Py_BEGIN_ALLOW_THREADS
-        rc = sqlite3_close(self->db);
-        Py_END_ALLOW_THREADS
-
-        if (rc != SQLITE_OK) {
-            _pysqlite_seterror(self->db);
-            return NULL;
-        } else {
+        if (self->apsw_connection) {
+            ret = PyObject_CallMethod(self->apsw_connection, "close", "");
+            Py_XDECREF(ret);
+            Py_XDECREF(self->apsw_connection);
+            self->apsw_connection = NULL;
             self->db = NULL;
+        } else {
+            Py_BEGIN_ALLOW_THREADS
+            rc = sqlite3_close(self->db);
+            Py_END_ALLOW_THREADS
+
+            if (rc != SQLITE_OK) {
+                _pysqlite_seterror(self->db, NULL);
+                return NULL;
+            } else {
+                self->db = NULL;
+            }
         }
     }
 
@@ -292,7 +359,7 @@
     Py_END_ALLOW_THREADS
 
     if (rc != SQLITE_OK) {
-        _pysqlite_seterror(self->db);
+        _pysqlite_seterror(self->db, statement);
         goto error;
     }
 
@@ -300,7 +367,7 @@
     if (rc == SQLITE_DONE) {
         self->inTransaction = 1;
     } else {
-        _pysqlite_seterror(self->db);
+        _pysqlite_seterror(self->db, statement);
     }
 
     Py_BEGIN_ALLOW_THREADS
@@ -308,7 +375,7 @@
     Py_END_ALLOW_THREADS
 
     if (rc != SQLITE_OK && !PyErr_Occurred()) {
-        _pysqlite_seterror(self->db);
+        _pysqlite_seterror(self->db, NULL);
     }
 
 error:
@@ -335,7 +402,7 @@
         rc = sqlite3_prepare(self->db, "COMMIT", -1, &statement, &tail);
         Py_END_ALLOW_THREADS
         if (rc != SQLITE_OK) {
-            _pysqlite_seterror(self->db);
+            _pysqlite_seterror(self->db, NULL);
             goto error;
         }
 
@@ -343,14 +410,14 @@
         if (rc == SQLITE_DONE) {
             self->inTransaction = 0;
         } else {
-            _pysqlite_seterror(self->db);
+            _pysqlite_seterror(self->db, statement);
         }
 
         Py_BEGIN_ALLOW_THREADS
         rc = sqlite3_finalize(statement);
         Py_END_ALLOW_THREADS
         if (rc != SQLITE_OK && !PyErr_Occurred()) {
-            _pysqlite_seterror(self->db);
+            _pysqlite_seterror(self->db, NULL);
         }
 
     }
@@ -375,13 +442,13 @@
     }
 
     if (self->inTransaction) {
-        pysqlite_reset_all_statements(self);
+        pysqlite_do_all_statements(self, ACTION_RESET);
 
         Py_BEGIN_ALLOW_THREADS
         rc = sqlite3_prepare(self->db, "ROLLBACK", -1, &statement, &tail);
         Py_END_ALLOW_THREADS
         if (rc != SQLITE_OK) {
-            _pysqlite_seterror(self->db);
+            _pysqlite_seterror(self->db, NULL);
             goto error;
         }
 
@@ -389,14 +456,14 @@
         if (rc == SQLITE_DONE) {
             self->inTransaction = 0;
         } else {
-            _pysqlite_seterror(self->db);
+            _pysqlite_seterror(self->db, statement);
         }
 
         Py_BEGIN_ALLOW_THREADS
         rc = sqlite3_finalize(statement);
         Py_END_ALLOW_THREADS
         if (rc != SQLITE_OK && !PyErr_Occurred()) {
-            _pysqlite_seterror(self->db);
+            _pysqlite_seterror(self->db, NULL);
         }
 
     }
@@ -762,6 +829,33 @@
     return rc;
 }
 
+static int _progress_handler(void* user_arg)
+{
+    int rc;
+    PyObject *ret;
+    PyGILState_STATE gilstate;
+
+    gilstate = PyGILState_Ensure();
+    ret = PyObject_CallFunction((PyObject*)user_arg, "");
+    /* A true result aborts the query; -1 marks a pending Python error. */
+    rc = ret ? PyObject_IsTrue(ret) : -1;
+    Py_XDECREF(ret);
+
+    if (rc < 0) {
+        /* The callback raised, or its result has no truth value. Report or
+         * discard the error (none may stay set) and abort the query. */
+        if (_enable_callback_tracebacks) {
+            PyErr_Print();
+        } else {
+            PyErr_Clear();
+        }
+        rc = 1;
+    }
+
+    PyGILState_Release(gilstate);
+    return rc;
+}
+
 PyObject* pysqlite_connection_set_authorizer(pysqlite_Connection* self, PyObject* args, PyObject* kwargs)
 {
     PyObject* authorizer_cb;
@@ -787,6 +881,30 @@
     }
 }
 
+PyObject* pysqlite_connection_set_progress_handler(pysqlite_Connection* self, PyObject* args, PyObject* kwargs)
+{
+    PyObject* progress_handler;
+    int n;
+    static char *kwlist[] = { "progress_handler", "n", NULL };
+    if (!PyArg_ParseTupleAndKeywords(args, kwargs, "Oi:set_progress_handler",
+                                      kwlist, &progress_handler, &n)) {
+        return NULL;
+    }
+
+    if (progress_handler == Py_None) {
+        /* None clears the progress handler previously set */
+        sqlite3_progress_handler(self->db, 0, 0, (void*)0);
+    } else {
+        /* Pin first, so SQLite never holds a callable we failed to keep alive. */
+        if (PyDict_SetItem(self->function_pinboard, progress_handler, Py_None) == -1) {
+            return NULL;
+        }
+        sqlite3_progress_handler(self->db, n, _progress_handler, progress_handler);
+    }
+    Py_INCREF(Py_None);
+    return Py_None;
+}
+
 int pysqlite_check_thread(pysqlite_Connection* self)
 {
     if (self->check_same_thread) {
@@ -892,7 +1010,8 @@
         } else if (rc == PYSQLITE_SQL_WRONG_TYPE) {
             PyErr_SetString(pysqlite_Warning, "SQL is of wrong type. Must be string or unicode.");
         } else {
-            _pysqlite_seterror(self->db);
+            (void)pysqlite_statement_reset(statement);
+            _pysqlite_seterror(self->db, NULL);
         }
 
         Py_DECREF(statement);
@@ -1134,7 +1253,7 @@
                                   (callable != Py_None) ? pysqlite_collation_callback : NULL);
     if (rc != SQLITE_OK) {
         PyDict_DelItem(self->collations, uppercase_name);
-        _pysqlite_seterror(self->db);
+        _pysqlite_seterror(self->db, NULL);
         goto finally;
     }
 
@@ -1151,6 +1270,44 @@
     return retval;
 }
 
+/* Called when the connection is used as a context manager. Returns itself as a
+ * convenience to the caller. */
+static PyObject *
+pysqlite_connection_enter(pysqlite_Connection* self, PyObject* args)
+{
+    Py_INCREF(self);
+    return (PyObject*)self;
+}
+
+/** Called when the connection is used as a context manager. If there was any
+ * exception, a rollback takes place; otherwise we commit. */
+static PyObject *
+pysqlite_connection_exit(pysqlite_Connection* self, PyObject* args)
+{
+    PyObject* exc_type, *exc_value, *exc_tb;
+    char* method_name;
+    PyObject* result;
+
+    if (!PyArg_ParseTuple(args, "OOO", &exc_type, &exc_value, &exc_tb)) {
+        return NULL;
+    }
+
+    if (exc_type == Py_None && exc_value == Py_None && exc_tb == Py_None) {
+        method_name = "commit";
+    } else {
+        method_name = "rollback";
+    }
+
+    result = PyObject_CallMethod((PyObject*)self, method_name, "");
+    if (!result) {
+        return NULL;
+    }
+    Py_DECREF(result);
+
+    Py_INCREF(Py_False);
+    return Py_False;
+}
+
 static char connection_doc[] =
 PyDoc_STR("SQLite database connection object.");
 
@@ -1175,6 +1332,8 @@
         PyDoc_STR("Creates a new aggregate. Non-standard.")},
     {"set_authorizer", (PyCFunction)pysqlite_connection_set_authorizer, METH_VARARGS|METH_KEYWORDS,
         PyDoc_STR("Sets authorizer callback. Non-standard.")},
+    {"set_progress_handler", (PyCFunction)pysqlite_connection_set_progress_handler, METH_VARARGS|METH_KEYWORDS,
+        PyDoc_STR("Sets progress handler callback. Non-standard.")},
     {"execute", (PyCFunction)pysqlite_connection_execute, METH_VARARGS,
         PyDoc_STR("Executes a SQL statement. Non-standard.")},
     {"executemany", (PyCFunction)pysqlite_connection_executemany, METH_VARARGS,
@@ -1185,6 +1344,10 @@
         PyDoc_STR("Creates a collation function. Non-standard.")},
     {"interrupt", (PyCFunction)pysqlite_connection_interrupt, METH_NOARGS,
         PyDoc_STR("Abort any pending database operation. Non-standard.")},
+    {"__enter__", (PyCFunction)pysqlite_connection_enter, METH_NOARGS,
+        PyDoc_STR("For context manager. Non-standard.")},
+    {"__exit__", (PyCFunction)pysqlite_connection_exit, METH_VARARGS,
+        PyDoc_STR("For context manager. Non-standard.")},
     {NULL, NULL}
 };
 

Modified: python/trunk/Modules/_sqlite/connection.h
==============================================================================
--- python/trunk/Modules/_sqlite/connection.h	(original)
+++ python/trunk/Modules/_sqlite/connection.h	Fri Feb 29 23:08:41 2008
@@ -1,6 +1,6 @@
 /* connection.h - definitions for the connection type
  *
- * Copyright (C) 2004-2006 Gerhard Häring <gh at ghaering.de>
+ * Copyright (C) 2004-2007 Gerhard Häring <gh at ghaering.de>
  *
  * This file is part of pysqlite.
  *
@@ -95,6 +95,11 @@
     /* a dictionary of registered collation name => collation callable mappings */
     PyObject* collations;
 
+    /* if our connection was created from an APSW connection, we keep a
+     * reference to the APSW connection around and get rid of it in our
+     * destructor */
+    PyObject* apsw_connection;
+
     /* Exception objects */
     PyObject* Warning;
     PyObject* Error;

Modified: python/trunk/Modules/_sqlite/cursor.c
==============================================================================
--- python/trunk/Modules/_sqlite/cursor.c	(original)
+++ python/trunk/Modules/_sqlite/cursor.c	Fri Feb 29 23:08:41 2008
@@ -1,6 +1,6 @@
 /* cursor.c - the cursor type
  *
- * Copyright (C) 2004-2006 Gerhard Häring <gh at ghaering.de>
+ * Copyright (C) 2004-2007 Gerhard Häring <gh at ghaering.de>
  *
  * This file is part of pysqlite.
  *
@@ -80,7 +80,7 @@
 
     if (!PyArg_ParseTuple(args, "O!", &pysqlite_ConnectionType, &connection))
     {
-        return -1; 
+        return -1;
     }
 
     Py_INCREF(connection);
@@ -435,7 +435,7 @@
     if (multiple) {
         /* executemany() */
         if (!PyArg_ParseTuple(args, "OO", &operation, &second_argument)) {
-            return NULL; 
+            return NULL;
         }
 
         if (!PyString_Check(operation) && !PyUnicode_Check(operation)) {
@@ -457,7 +457,7 @@
     } else {
         /* execute() */
         if (!PyArg_ParseTuple(args, "O|O", &operation, &second_argument)) {
-            return NULL; 
+            return NULL;
         }
 
         if (!PyString_Check(operation) && !PyUnicode_Check(operation)) {
@@ -506,16 +506,47 @@
         operation_cstr = PyString_AsString(operation_bytestr);
     }
 
-    /* reset description and rowcount */
+    /* reset description */
     Py_DECREF(self->description);
     Py_INCREF(Py_None);
     self->description = Py_None;
 
-    Py_DECREF(self->rowcount);
-    self->rowcount = PyInt_FromLong(-1L);
-    if (!self->rowcount) {
+    func_args = PyTuple_New(1);
+    if (!func_args) {
         goto error;
     }
+    Py_INCREF(operation);
+    if (PyTuple_SetItem(func_args, 0, operation) != 0) {
+        goto error;
+    }
+
+    if (self->statement) {
+        (void)pysqlite_statement_reset(self->statement);
+        Py_DECREF(self->statement);
+    }
+
+    self->statement = (pysqlite_Statement*)pysqlite_cache_get(self->connection->statement_cache, func_args);
+    Py_DECREF(func_args);
+
+    if (!self->statement) {
+        goto error;
+    }
+
+    if (self->statement->in_use) {
+        Py_DECREF(self->statement);
+        self->statement = PyObject_New(pysqlite_Statement, &pysqlite_StatementType);
+        if (!self->statement) {
+            goto error;
+        }
+        rc = pysqlite_statement_create(self->statement, self->connection, operation);
+        if (rc != SQLITE_OK) {
+            self->statement = 0;
+            goto error;
+        }
+    }
+
+    pysqlite_statement_reset(self->statement);
+    pysqlite_statement_mark_dirty(self->statement);
 
     statement_type = detect_statement_type(operation_cstr);
     if (self->connection->begin_statement) {
@@ -553,43 +584,6 @@
         }
     }
 
-    func_args = PyTuple_New(1);
-    if (!func_args) {
-        goto error;
-    }
-    Py_INCREF(operation);
-    if (PyTuple_SetItem(func_args, 0, operation) != 0) {
-        goto error;
-    }
-
-    if (self->statement) {
-        (void)pysqlite_statement_reset(self->statement);
-        Py_DECREF(self->statement);
-    }
-
-    self->statement = (pysqlite_Statement*)pysqlite_cache_get(self->connection->statement_cache, func_args);
-    Py_DECREF(func_args);
-
-    if (!self->statement) {
-        goto error;
-    }
-
-    if (self->statement->in_use) {
-        Py_DECREF(self->statement);
-        self->statement = PyObject_New(pysqlite_Statement, &pysqlite_StatementType);
-        if (!self->statement) {
-            goto error;
-        }
-        rc = pysqlite_statement_create(self->statement, self->connection, operation);
-        if (rc != SQLITE_OK) {
-            self->statement = 0;
-            goto error;
-        }
-    }
-
-    pysqlite_statement_reset(self->statement);
-    pysqlite_statement_mark_dirty(self->statement);
-
     while (1) {
         parameters = PyIter_Next(parameters_iter);
         if (!parameters) {
@@ -603,11 +597,6 @@
             goto error;
         }
 
-        if (pysqlite_build_row_cast_map(self) != 0) {
-            PyErr_SetString(pysqlite_OperationalError, "Error while building row_cast_map");
-            goto error;
-        }
-
         /* Keep trying the SQL statement until the schema stops changing. */
         while (1) {
             /* Actually execute the SQL statement. */
@@ -626,7 +615,8 @@
                     continue;
                 } else {
                     /* If the database gave us an error, promote it to Python. */
-                    _pysqlite_seterror(self->connection->db);
+                    (void)pysqlite_statement_reset(self->statement);
+                    _pysqlite_seterror(self->connection->db, NULL);
                     goto error;
                 }
             } else {
@@ -638,17 +628,23 @@
                         PyErr_Clear();
                     }
                 }
-                _pysqlite_seterror(self->connection->db);
+                (void)pysqlite_statement_reset(self->statement);
+                _pysqlite_seterror(self->connection->db, NULL);
                 goto error;
             }
         }
 
-        if (rc == SQLITE_ROW || (rc == SQLITE_DONE && statement_type == STATEMENT_SELECT)) {
-            Py_BEGIN_ALLOW_THREADS
-            numcols = sqlite3_column_count(self->statement->st);
-            Py_END_ALLOW_THREADS
+        if (pysqlite_build_row_cast_map(self) != 0) {
+            PyErr_SetString(pysqlite_OperationalError, "Error while building row_cast_map");
+            goto error;
+        }
 
+        if (rc == SQLITE_ROW || (rc == SQLITE_DONE && statement_type == STATEMENT_SELECT)) {
             if (self->description == Py_None) {
+                Py_BEGIN_ALLOW_THREADS
+                numcols = sqlite3_column_count(self->statement->st);
+                Py_END_ALLOW_THREADS
+
                 Py_DECREF(self->description);
                 self->description = PyTuple_New(numcols);
                 if (!self->description) {
@@ -689,15 +685,11 @@
             case STATEMENT_DELETE:
             case STATEMENT_INSERT:
             case STATEMENT_REPLACE:
-                Py_BEGIN_ALLOW_THREADS
                 rowcount += (long)sqlite3_changes(self->connection->db);
-                Py_END_ALLOW_THREADS
-                Py_DECREF(self->rowcount);
-                self->rowcount = PyInt_FromLong(rowcount);
         }
 
         Py_DECREF(self->lastrowid);
-        if (statement_type == STATEMENT_INSERT) {
+        if (!multiple && statement_type == STATEMENT_INSERT) {
             Py_BEGIN_ALLOW_THREADS
             lastrowid = sqlite3_last_insert_rowid(self->connection->db);
             Py_END_ALLOW_THREADS
@@ -714,14 +706,27 @@
     }
 
 error:
+    /* just to be sure: implicit ROLLBACKs (ON CONFLICT ROLLBACK / OR
+     * ROLLBACK) could have happened, so re-read the autocommit state */
+    #ifdef SQLITE_VERSION_NUMBER
+    #if SQLITE_VERSION_NUMBER >= 3002002
+    self->connection->inTransaction = !sqlite3_get_autocommit(self->connection->db);
+    #endif
+    #endif
+
     Py_XDECREF(operation_bytestr);
     Py_XDECREF(parameters);
     Py_XDECREF(parameters_iter);
     Py_XDECREF(parameters_list);
 
     if (PyErr_Occurred()) {
+        Py_DECREF(self->rowcount);
+        self->rowcount = PyInt_FromLong(-1L);
         return NULL;
     } else {
+        Py_DECREF(self->rowcount);
+        self->rowcount = PyInt_FromLong(rowcount);
+
         Py_INCREF(self);
         return (PyObject*)self;
     }
@@ -748,7 +753,7 @@
     int statement_completed = 0;
 
     if (!PyArg_ParseTuple(args, "O", &script_obj)) {
-        return NULL; 
+        return NULL;
     }
 
     if (!pysqlite_check_thread(self->connection) || !pysqlite_check_connection(self->connection)) {
@@ -788,7 +793,7 @@
                              &statement,
                              &script_cstr);
         if (rc != SQLITE_OK) {
-            _pysqlite_seterror(self->connection->db);
+            _pysqlite_seterror(self->connection->db, NULL);
             goto error;
         }
 
@@ -796,17 +801,18 @@
         rc = SQLITE_ROW;
         while (rc == SQLITE_ROW) {
             rc = _sqlite_step_with_busyhandler(statement, self->connection);
+            /* TODO: we probably need more error handling here */
         }
 
         if (rc != SQLITE_DONE) {
             (void)sqlite3_finalize(statement);
-            _pysqlite_seterror(self->connection->db);
+            _pysqlite_seterror(self->connection->db, NULL);
             goto error;
         }
 
         rc = sqlite3_finalize(statement);
         if (rc != SQLITE_OK) {
-            _pysqlite_seterror(self->connection->db);
+            _pysqlite_seterror(self->connection->db, NULL);
             goto error;
         }
     }
@@ -864,8 +870,9 @@
     if (self->statement) {
         rc = _sqlite_step_with_busyhandler(self->statement->st, self->connection);
         if (rc != SQLITE_DONE && rc != SQLITE_ROW) {
+            (void)pysqlite_statement_reset(self->statement);
             Py_DECREF(next_row);
-            _pysqlite_seterror(self->connection->db);
+            _pysqlite_seterror(self->connection->db, NULL);
             return NULL;
         }
 
@@ -890,15 +897,17 @@
     return row;
 }
 
-PyObject* pysqlite_cursor_fetchmany(pysqlite_Cursor* self, PyObject* args)
+PyObject* pysqlite_cursor_fetchmany(pysqlite_Cursor* self, PyObject* args, PyObject* kwargs)
 {
+    static char *kwlist[] = {"size", NULL, NULL};
+
     PyObject* row;
     PyObject* list;
     int maxrows = self->arraysize;
     int counter = 0;
 
-    if (!PyArg_ParseTuple(args, "|i", &maxrows)) {
-        return NULL; 
+    if (!PyArg_ParseTupleAndKeywords(args, kwargs, "|i:fetchmany", kwlist, &maxrows)) {
+        return NULL;
     }
 
     list = PyList_New(0);
@@ -992,7 +1001,7 @@
         PyDoc_STR("Executes a multiple SQL statements at once. Non-standard.")},
     {"fetchone", (PyCFunction)pysqlite_cursor_fetchone, METH_NOARGS,
         PyDoc_STR("Fetches one row from the resultset.")},
-    {"fetchmany", (PyCFunction)pysqlite_cursor_fetchmany, METH_VARARGS,
+    {"fetchmany", (PyCFunction)pysqlite_cursor_fetchmany, METH_VARARGS|METH_KEYWORDS,
         PyDoc_STR("Fetches several rows from the resultset.")},
     {"fetchall", (PyCFunction)pysqlite_cursor_fetchall, METH_NOARGS,
         PyDoc_STR("Fetches all rows from the resultset.")},

Modified: python/trunk/Modules/_sqlite/cursor.h
==============================================================================
--- python/trunk/Modules/_sqlite/cursor.h	(original)
+++ python/trunk/Modules/_sqlite/cursor.h	Fri Feb 29 23:08:41 2008
@@ -1,6 +1,6 @@
 /* cursor.h - definitions for the cursor type
  *
- * Copyright (C) 2004-2006 Gerhard Häring <gh at ghaering.de>
+ * Copyright (C) 2004-2007 Gerhard Häring <gh at ghaering.de>
  *
  * This file is part of pysqlite.
  *
@@ -60,7 +60,7 @@
 PyObject* pysqlite_cursor_getiter(pysqlite_Cursor *self);
 PyObject* pysqlite_cursor_iternext(pysqlite_Cursor *self);
 PyObject* pysqlite_cursor_fetchone(pysqlite_Cursor* self, PyObject* args);
-PyObject* pysqlite_cursor_fetchmany(pysqlite_Cursor* self, PyObject* args);
+PyObject* pysqlite_cursor_fetchmany(pysqlite_Cursor* self, PyObject* args, PyObject* kwargs);
 PyObject* pysqlite_cursor_fetchall(pysqlite_Cursor* self, PyObject* args);
 PyObject* pysqlite_noop(pysqlite_Connection* self, PyObject* args);
 PyObject* pysqlite_cursor_close(pysqlite_Cursor* self, PyObject* args);

Modified: python/trunk/Modules/_sqlite/microprotocols.h
==============================================================================
--- python/trunk/Modules/_sqlite/microprotocols.h	(original)
+++ python/trunk/Modules/_sqlite/microprotocols.h	Fri Feb 29 23:08:41 2008
@@ -28,10 +28,6 @@
 
 #include <Python.h>
 
-#ifdef __cplusplus
-extern "C" {
-#endif
-
 /** adapters registry **/
 
 extern PyObject *psyco_adapters;

Modified: python/trunk/Modules/_sqlite/module.c
==============================================================================
--- python/trunk/Modules/_sqlite/module.c	(original)
+++ python/trunk/Modules/_sqlite/module.c	Fri Feb 29 23:08:41 2008
@@ -1,25 +1,25 @@
-    /* module.c - the module itself
-     *
-     * Copyright (C) 2004-2006 Gerhard Häring <gh at ghaering.de>
-     *
-     * This file is part of pysqlite.
-     *
-     * This software is provided 'as-is', without any express or implied
-     * warranty.  In no event will the authors be held liable for any damages
-     * arising from the use of this software.
-     *
-     * Permission is granted to anyone to use this software for any purpose,
-     * including commercial applications, and to alter it and redistribute it
-     * freely, subject to the following restrictions:
-     *
-     * 1. The origin of this software must not be misrepresented; you must not
-     *    claim that you wrote the original software. If you use this software
-     *    in a product, an acknowledgment in the product documentation would be
-     *    appreciated but is not required.
-     * 2. Altered source versions must be plainly marked as such, and must not be
-     *    misrepresented as being the original software.
-     * 3. This notice may not be removed or altered from any source distribution.
-     */
+/* module.c - the module itself
+ *
+ * Copyright (C) 2004-2007 Gerhard Häring <gh at ghaering.de>
+ *
+ * This file is part of pysqlite.
+ *
+ * This software is provided 'as-is', without any express or implied
+ * warranty.  In no event will the authors be held liable for any damages
+ * arising from the use of this software.
+ *
+ * Permission is granted to anyone to use this software for any purpose,
+ * including commercial applications, and to alter it and redistribute it
+ * freely, subject to the following restrictions:
+ *
+ * 1. The origin of this software must not be misrepresented; you must not
+ *    claim that you wrote the original software. If you use this software
+ *    in a product, an acknowledgment in the product documentation would be
+ *    appreciated but is not required.
+ * 2. Altered source versions must be plainly marked as such, and must not be
+ *    misrepresented as being the original software.
+ * 3. This notice may not be removed or altered from any source distribution.
+ */
 
 #include "connection.h"
 #include "statement.h"
@@ -41,6 +41,7 @@
 
 PyObject* converters;
 int _enable_callback_tracebacks;
+int pysqlite_BaseTypeAdapted;
 
 static PyObject* module_connect(PyObject* self, PyObject* args, PyObject*
         kwargs)
@@ -50,7 +51,7 @@
      * connection.c and must always be copied from there ... */
 
     static char *kwlist[] = {"database", "timeout", "detect_types", "isolation_level", "check_same_thread", "factory", "cached_statements", NULL, NULL};
-    char* database;
+    PyObject* database;
     int detect_types = 0;
     PyObject* isolation_level;
     PyObject* factory = NULL;
@@ -60,7 +61,7 @@
 
     PyObject* result;
 
-    if (!PyArg_ParseTupleAndKeywords(args, kwargs, "s|diOiOi", kwlist,
+    if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O|diOiOi", kwlist,
                                      &database, &timeout, &detect_types, &isolation_level, &check_same_thread, &factory, &cached_statements))
     {
         return NULL; 
@@ -133,6 +134,13 @@
         return NULL;
     }
 
+    /* a base type is being adapted: disable the fast path that binds base types
+     * without adaptation (taken in ~99% of all usages) */
+    if (type == &PyInt_Type || type == &PyLong_Type || type == &PyFloat_Type
+            || type == &PyString_Type || type == &PyUnicode_Type || type == &PyBuffer_Type) {
+        pysqlite_BaseTypeAdapted = 1;
+    }
+
     microprotocols_add(type, (PyObject*)&pysqlite_PrepareProtocolType, caster);
 
     Py_INCREF(Py_None);
@@ -379,6 +387,8 @@
 
     _enable_callback_tracebacks = 0;
 
+    pysqlite_BaseTypeAdapted = 0;
+
     /* Original comment form _bsddb.c in the Python core. This is also still
      * needed nowadays for Python 2.3/2.4.
      * 

Modified: python/trunk/Modules/_sqlite/module.h
==============================================================================
--- python/trunk/Modules/_sqlite/module.h	(original)
+++ python/trunk/Modules/_sqlite/module.h	Fri Feb 29 23:08:41 2008
@@ -1,6 +1,6 @@
 /* module.h - definitions for the module
  *
- * Copyright (C) 2004-2006 Gerhard Häring <gh at ghaering.de>
+ * Copyright (C) 2004-2007 Gerhard Häring <gh at ghaering.de>
  *
  * This file is part of pysqlite.
  *
@@ -25,7 +25,7 @@
 #define PYSQLITE_MODULE_H
 #include "Python.h"
 
-#define PYSQLITE_VERSION "2.3.3"
+#define PYSQLITE_VERSION "2.4.1"
 
 extern PyObject* pysqlite_Error;
 extern PyObject* pysqlite_Warning;
@@ -51,6 +51,7 @@
 extern PyObject* converters;
 
 extern int _enable_callback_tracebacks;
+extern int pysqlite_BaseTypeAdapted;
 
 #define PARSE_DECLTYPES 1
 #define PARSE_COLNAMES 2

Modified: python/trunk/Modules/_sqlite/statement.c
==============================================================================
--- python/trunk/Modules/_sqlite/statement.c	(original)
+++ python/trunk/Modules/_sqlite/statement.c	Fri Feb 29 23:08:41 2008
@@ -1,6 +1,6 @@
 /* statement.c - the statement type
  *
- * Copyright (C) 2005-2006 Gerhard Häring <gh at ghaering.de>
+ * Copyright (C) 2005-2007 Gerhard Häring <gh at ghaering.de>
  *
  * This file is part of pysqlite.
  *
@@ -40,6 +40,16 @@
     NORMAL
 } parse_remaining_sql_state;
 
+typedef enum {  /* classification of a bind parameter; selects the sqlite3_bind_* call */
+    TYPE_INT,      /* int or int subclass -> sqlite3_bind_int64 */
+    TYPE_LONG,     /* long or subclass -> sqlite3_bind_int64; NOTE(review): no switch case handles it without HAVE_LONG_LONG */
+    TYPE_FLOAT,    /* float or subclass -> sqlite3_bind_double */
+    TYPE_STRING,   /* str or subclass -> sqlite3_bind_text */
+    TYPE_UNICODE,  /* unicode or subclass -> UTF-8 encoded, then sqlite3_bind_text */
+    TYPE_BUFFER,   /* buffer -> sqlite3_bind_blob */
+    TYPE_UNKNOWN   /* anything else: binding fails with rc = -1 */
+} parameter_type;
+
 int pysqlite_statement_create(pysqlite_Statement* self, pysqlite_Connection* connection, PyObject* sql)
 {
     const char* tail;
@@ -97,42 +107,96 @@
     char* string;
     Py_ssize_t buflen;
     PyObject* stringval;
+    parameter_type paramtype;
 
     if (parameter == Py_None) {
         rc = sqlite3_bind_null(self->st, pos);
+        goto final;
+    }
+
+    if (PyInt_CheckExact(parameter)) {
+        paramtype = TYPE_INT;
+    } else if (PyLong_CheckExact(parameter)) {
+        paramtype = TYPE_LONG;
+    } else if (PyFloat_CheckExact(parameter)) {
+        paramtype = TYPE_FLOAT;
+    } else if (PyString_CheckExact(parameter)) {
+        paramtype = TYPE_STRING;
+    } else if (PyUnicode_CheckExact(parameter)) {
+        paramtype = TYPE_UNICODE;
+    } else if (PyBuffer_Check(parameter)) {
+        paramtype = TYPE_BUFFER;
     } else if (PyInt_Check(parameter)) {
-        longval = PyInt_AsLong(parameter);
-        rc = sqlite3_bind_int64(self->st, pos, (sqlite_int64)longval);
-#ifdef HAVE_LONG_LONG
+        paramtype = TYPE_INT;
     } else if (PyLong_Check(parameter)) {
-        longlongval = PyLong_AsLongLong(parameter);
-        /* in the overflow error case, longlongval is -1, and an exception is set */
-        rc = sqlite3_bind_int64(self->st, pos, (sqlite_int64)longlongval);
-#endif
+        paramtype = TYPE_LONG;
     } else if (PyFloat_Check(parameter)) {
-        rc = sqlite3_bind_double(self->st, pos, PyFloat_AsDouble(parameter));
-    } else if (PyBuffer_Check(parameter)) {
-        if (PyObject_AsCharBuffer(parameter, &buffer, &buflen) == 0) {
-            rc = sqlite3_bind_blob(self->st, pos, buffer, buflen, SQLITE_TRANSIENT);
-        } else {
-            PyErr_SetString(PyExc_ValueError, "could not convert BLOB to buffer");
-            rc = -1;
-        }
-    } else if PyString_Check(parameter) {
-        string = PyString_AsString(parameter);
-        rc = sqlite3_bind_text(self->st, pos, string, -1, SQLITE_TRANSIENT);
-    } else if PyUnicode_Check(parameter) {
-        stringval = PyUnicode_AsUTF8String(parameter);
-        string = PyString_AsString(stringval);
-        rc = sqlite3_bind_text(self->st, pos, string, -1, SQLITE_TRANSIENT);
-        Py_DECREF(stringval);
+        paramtype = TYPE_FLOAT;
+    } else if (PyString_Check(parameter)) {
+        paramtype = TYPE_STRING;
+    } else if (PyUnicode_Check(parameter)) {
+        paramtype = TYPE_UNICODE;
     } else {
-        rc = -1;
+        paramtype = TYPE_UNKNOWN;
     }
 
+    switch (paramtype) {
+        case TYPE_INT:
+            longval = PyInt_AsLong(parameter);
+            rc = sqlite3_bind_int64(self->st, pos, (sqlite_int64)longval);
+            break;
+#ifdef HAVE_LONG_LONG
+        case TYPE_LONG:
+            longlongval = PyLong_AsLongLong(parameter);
+            /* in the overflow error case, longlongval is -1, and an exception is set */
+            rc = sqlite3_bind_int64(self->st, pos, (sqlite_int64)longlongval);
+            break;
+#endif
+        case TYPE_FLOAT:
+            rc = sqlite3_bind_double(self->st, pos, PyFloat_AsDouble(parameter));
+            break;
+        case TYPE_STRING:
+            string = PyString_AS_STRING(parameter);
+            rc = sqlite3_bind_text(self->st, pos, string, -1, SQLITE_TRANSIENT);
+            break;
+        case TYPE_UNICODE:
+            stringval = PyUnicode_AsUTF8String(parameter);
+            string = PyString_AsString(stringval);
+            rc = sqlite3_bind_text(self->st, pos, string, -1, SQLITE_TRANSIENT);
+            Py_DECREF(stringval);
+            break;
+        case TYPE_BUFFER:
+            if (PyObject_AsCharBuffer(parameter, &buffer, &buflen) == 0) {
+                rc = sqlite3_bind_blob(self->st, pos, buffer, buflen, SQLITE_TRANSIENT);
+            } else {
+                PyErr_SetString(PyExc_ValueError, "could not convert BLOB to buffer");
+                rc = -1;
+            }
+            break;
+        case TYPE_UNKNOWN:
+            rc = -1;
+    }
+
+final:
     return rc;
 }
 
+/* Returns 0 if obj is an exact int/long/float/str/unicode (or a buffer) that can be bound as-is; always 1 once any base type has an adapter. */
+static int _need_adapt(PyObject* obj)
+{
+    if (pysqlite_BaseTypeAdapted) {  /* set when an adapter is registered for a base type */
+        return 1;
+    }
+
+    if (PyInt_CheckExact(obj) || PyLong_CheckExact(obj) 
+            || PyFloat_CheckExact(obj) || PyString_CheckExact(obj)
+            || PyUnicode_CheckExact(obj) || PyBuffer_Check(obj)) {
+        return 0;  /* fast path: caller skips microprotocols_adapt() */
+    } else {
+        return 1;  /* subclasses and other types may have a registered adapter */
+    }
+}
+
 void pysqlite_statement_bind_parameters(pysqlite_Statement* self, PyObject* parameters)
 {
     PyObject* current_param;
@@ -147,7 +211,55 @@
     num_params_needed = sqlite3_bind_parameter_count(self->st);
     Py_END_ALLOW_THREADS
 
-    if (PyDict_Check(parameters)) {
+    if (PyTuple_CheckExact(parameters) || PyList_CheckExact(parameters) || (!PyDict_Check(parameters) && PySequence_Check(parameters))) {
+        /* parameters passed as sequence */
+        if (PyTuple_CheckExact(parameters)) {
+            num_params = PyTuple_GET_SIZE(parameters);
+        } else if (PyList_CheckExact(parameters)) {
+            num_params = PyList_GET_SIZE(parameters);
+        } else {
+            num_params = PySequence_Size(parameters);
+        }
+        if (num_params != num_params_needed) {
+            PyErr_Format(pysqlite_ProgrammingError, "Incorrect number of bindings supplied. The current statement uses %d, and there are %d supplied.",
+                         num_params_needed, num_params);
+            return;
+        }
+        for (i = 0; i < num_params; i++) {
+            if (PyTuple_CheckExact(parameters)) {
+                current_param = PyTuple_GET_ITEM(parameters, i);
+                Py_XINCREF(current_param);
+            } else if (PyList_CheckExact(parameters)) {
+                current_param = PyList_GET_ITEM(parameters, i);
+                Py_XINCREF(current_param);
+            } else {
+                current_param = PySequence_GetItem(parameters, i);
+            }
+            if (!current_param) {
+                return;
+            }
+
+            if (!_need_adapt(current_param)) {
+                adapted = current_param;
+            } else {
+                adapted = microprotocols_adapt(current_param, (PyObject*)&pysqlite_PrepareProtocolType, NULL);
+                if (adapted) {
+                    Py_DECREF(current_param);
+                } else {
+                    PyErr_Clear();
+                    adapted = current_param;
+                }
+            }
+
+            rc = pysqlite_statement_bind_parameter(self, i + 1, adapted);
+            Py_DECREF(adapted);
+
+            if (rc != SQLITE_OK) {
+                PyErr_Format(pysqlite_InterfaceError, "Error binding parameter %d - probably unsupported type.", i);
+                return;
+            }
+        }
+    } else if (PyDict_Check(parameters)) {
         /* parameters passed as dictionary */
         for (i = 1; i <= num_params_needed; i++) {
             Py_BEGIN_ALLOW_THREADS
@@ -159,19 +271,27 @@
             }
 
             binding_name++; /* skip first char (the colon) */
-            current_param = PyDict_GetItemString(parameters, binding_name);
+            if (PyDict_CheckExact(parameters)) {
+                current_param = PyDict_GetItemString(parameters, binding_name);
+                Py_XINCREF(current_param);
+            } else {
+                current_param = PyMapping_GetItemString(parameters, (char*)binding_name);
+            }
             if (!current_param) {
                 PyErr_Format(pysqlite_ProgrammingError, "You did not supply a value for binding %d.", i);
                 return;
             }
 
-            Py_INCREF(current_param);
-            adapted = microprotocols_adapt(current_param, (PyObject*)&pysqlite_PrepareProtocolType, NULL);
-            if (adapted) {
-                Py_DECREF(current_param);
-            } else {
-                PyErr_Clear();
+            if (!_need_adapt(current_param)) {
                 adapted = current_param;
+            } else {
+                adapted = microprotocols_adapt(current_param, (PyObject*)&pysqlite_PrepareProtocolType, NULL);
+                if (adapted) {
+                    Py_DECREF(current_param);
+                } else {
+                    PyErr_Clear();
+                    adapted = current_param;
+                }
             }
 
             rc = pysqlite_statement_bind_parameter(self, i, adapted);
@@ -183,35 +303,7 @@
            }
         }
     } else {
-        /* parameters passed as sequence */
-        num_params = PySequence_Length(parameters);
-        if (num_params != num_params_needed) {
-            PyErr_Format(pysqlite_ProgrammingError, "Incorrect number of bindings supplied. The current statement uses %d, and there are %d supplied.",
-                         num_params_needed, num_params);
-            return;
-        }
-        for (i = 0; i < num_params; i++) {
-            current_param = PySequence_GetItem(parameters, i);
-            if (!current_param) {
-                return;
-            }
-            adapted = microprotocols_adapt(current_param, (PyObject*)&pysqlite_PrepareProtocolType, NULL);
-
-            if (adapted) {
-                Py_DECREF(current_param);
-            } else {
-                PyErr_Clear();
-                adapted = current_param;
-            }
-
-            rc = pysqlite_statement_bind_parameter(self, i + 1, adapted);
-            Py_DECREF(adapted);
-
-            if (rc != SQLITE_OK) {
-                PyErr_Format(pysqlite_InterfaceError, "Error binding parameter %d - probably unsupported type.", i);
-                return;
-            }
-        }
+        PyErr_SetString(PyExc_ValueError, "parameters are of unsupported type");
     }
 }
 

Modified: python/trunk/Modules/_sqlite/util.c
==============================================================================
--- python/trunk/Modules/_sqlite/util.c	(original)
+++ python/trunk/Modules/_sqlite/util.c	Fri Feb 29 23:08:41 2008
@@ -1,6 +1,6 @@
 /* util.c - various utility functions
  *
- * Copyright (C) 2005-2006 Gerhard Häring <gh at ghaering.de>
+ * Copyright (C) 2005-2007 Gerhard Häring <gh at ghaering.de>
  *
  * This file is part of pysqlite.
  *
@@ -45,10 +45,15 @@
  * Checks the SQLite error code and sets the appropriate DB-API exception.
  * Returns the error code (0 means no error occurred).
  */
-int _pysqlite_seterror(sqlite3* db)
+int _pysqlite_seterror(sqlite3* db, sqlite3_stmt* st)
 {
     int errorcode;
 
+    /* SQLite often doesn't report anything useful, unless you reset the statement first */
+    if (st != NULL) {
+        (void)sqlite3_reset(st);
+    }
+
     errorcode = sqlite3_errcode(db);
 
     switch (errorcode)

Modified: python/trunk/Modules/_sqlite/util.h
==============================================================================
--- python/trunk/Modules/_sqlite/util.h	(original)
+++ python/trunk/Modules/_sqlite/util.h	Fri Feb 29 23:08:41 2008
@@ -34,5 +34,5 @@
  * Checks the SQLite error code and sets the appropriate DB-API exception.
  * Returns the error code (0 means no error occurred).
  */
-int _pysqlite_seterror(sqlite3* db);
+int _pysqlite_seterror(sqlite3* db, sqlite3_stmt* st);
 #endif


More information about the Python-checkins mailing list