[pypy-commit] pypy default: unify sqlite3 between default and py3k

bdkearns noreply at buildbot.pypy.org
Thu Mar 7 23:43:48 CET 2013


Author: Brian Kearns <bdkearns at gmail.com>
Branch: 
Changeset: r62191:71538a3eb8f3
Date: 2013-03-07 17:38 -0500
http://bitbucket.org/pypy/pypy/changeset/71538a3eb8f3/

Log:	unify sqlite3 between default and py3k

diff --git a/lib_pypy/_sqlite3.py b/lib_pypy/_sqlite3.py
--- a/lib_pypy/_sqlite3.py
+++ b/lib_pypy/_sqlite3.py
@@ -34,6 +34,16 @@
 import weakref
 from threading import _get_ident as _thread_get_ident
 
+if sys.version_info[0] >= 3:
+    StandardError = Exception
+    long = int
+    xrange = range
+    basestring = unicode = str
+    buffer = memoryview
+    BLOB_TYPE = bytes
+else:
+    BLOB_TYPE = buffer
+
 names = "sqlite3.dll libsqlite3.so.0 libsqlite3.so libsqlite3.dylib".split()
 for name in names:
     try:
@@ -243,7 +253,7 @@
 ##########################################
 
 # SQLite version information
-sqlite_version = sqlite.sqlite3_libversion()
+sqlite_version = str(sqlite.sqlite3_libversion().decode('ascii'))
 
 class Error(StandardError):
     pass
@@ -282,6 +292,16 @@
 def unicode_text_factory(x):
     return unicode(x, 'utf-8')
 
+if sys.version_info[0] < 3:
+    def OptimizedUnicode(s):
+        try:
+            val = unicode(s, "ascii").encode("ascii")
+        except UnicodeDecodeError:
+            val = unicode(s, "utf-8")
+        return val
+else:
+    OptimizedUnicode = unicode_text_factory
+
 
 class _StatementCache(object):
     def __init__(self, connection, maxcount):
@@ -440,7 +460,7 @@
     @_check_thread_wrap
     @_check_closed_wrap
     def __call__(self, sql):
-        if not isinstance(sql, (str, unicode)):
+        if not isinstance(sql, basestring):
             raise Warning("SQL is of wrong type. Must be string or unicode.")
         return self._statement_cache.get(sql, self.row_factory)
 
@@ -640,7 +660,7 @@
     @_check_closed_wrap
     def create_collation(self, name, callback):
         name = name.upper()
-        if not all(c in string.uppercase + string.digits + '_' for c in name):
+        if not all(c in string.ascii_uppercase + string.digits + '_' for c in name):
             raise ProgrammingError("invalid character in collation name")
 
         if callback is None:
@@ -714,6 +734,11 @@
         if ret != SQLITE_OK:
             raise self._get_exception(ret)
 
+    if sys.version_info[0] >= 3:
+        def __get_in_transaction(self):
+            return self._in_transaction
+        in_transaction = property(__get_in_transaction)
+
     def __get_total_changes(self):
         self._check_closed()
         return sqlite.sqlite3_total_changes(self._db)
@@ -726,7 +751,7 @@
         if val is None:
             self.commit()
         else:
-            self.__begin_statement = b"BEGIN " + val.encode('ascii')
+            self.__begin_statement = str("BEGIN " + val).encode('utf-8')
         self._isolation_level = val
     isolation_level = property(__get_isolation_level, __set_isolation_level)
 
@@ -800,7 +825,7 @@
         try:
             self.__description = None
             self._reset = False
-            if not isinstance(sql, (str, unicode)):
+            if not isinstance(sql, basestring):
                 raise ValueError("operation parameter must be str or unicode")
             self.__statement = self.__connection._statement_cache.get(
                 sql, self.row_factory)
@@ -847,7 +872,7 @@
         try:
             self.__description = None
             self._reset = False
-            if not isinstance(sql, (str, unicode)):
+            if not isinstance(sql, basestring):
                 raise ValueError("operation parameter must be str or unicode")
             self.__statement = self.__connection._statement_cache.get(
                 sql, self.row_factory)
@@ -982,7 +1007,7 @@
     def __init__(self, connection, sql):
         self.__con = connection
 
-        if not isinstance(sql, (str, unicode)):
+        if not isinstance(sql, basestring):
             raise ValueError("sql must be a string")
         first_word = self._statement_kind = sql.lstrip().split(" ")[0].upper()
         if first_word in ("INSERT", "UPDATE", "DELETE", "REPLACE"):
@@ -1060,16 +1085,18 @@
 
             self.__row_cast_map.append(converter)
 
-    def __check_decodable(self, param):
-        if self.__con.text_factory in (unicode, OptimizedUnicode, unicode_text_factory):
-            for c in param:
-                if ord(c) & 0x80 != 0:
-                    raise self.__con.ProgrammingError(
-                        "You must not use 8-bit bytestrings unless "
-                        "you use a text_factory that can interpret "
-                        "8-bit bytestrings (like text_factory = str). "
-                        "It is highly recommended that you instead "
-                        "just switch your application to Unicode strings.")
+    if sys.version_info[0] < 3:
+        def __check_decodable(self, param):
+            if self.__con.text_factory in (unicode, OptimizedUnicode,
+                                           unicode_text_factory):
+                for c in param:
+                    if ord(c) & 0x80 != 0:
+                        raise self.__con.ProgrammingError(
+                            "You must not use 8-bit bytestrings unless "
+                            "you use a text_factory that can interpret "
+                            "8-bit bytestrings (like text_factory = str). "
+                            "It is highly recommended that you instead "
+                            "just switch your application to Unicode strings.")
 
     def __set_param(self, idx, param):
         cvt = converters.get(type(param))
@@ -1080,20 +1107,20 @@
 
         if param is None:
             rc = sqlite.sqlite3_bind_null(self._statement, idx)
-        elif type(param) in (bool, int, long):
+        elif isinstance(param, (bool, int, long)):
             if -2147483648 <= param <= 2147483647:
                 rc = sqlite.sqlite3_bind_int(self._statement, idx, param)
             else:
                 rc = sqlite.sqlite3_bind_int64(self._statement, idx, param)
-        elif type(param) is float:
+        elif isinstance(param, float):
             rc = sqlite.sqlite3_bind_double(self._statement, idx, param)
+        elif isinstance(param, unicode):
+            param = param.encode("utf-8")
+            rc = sqlite.sqlite3_bind_text(self._statement, idx, param, len(param), SQLITE_TRANSIENT)
         elif isinstance(param, str):
             self.__check_decodable(param)
             rc = sqlite.sqlite3_bind_text(self._statement, idx, param, len(param), SQLITE_TRANSIENT)
-        elif isinstance(param, unicode):
-            param = param.encode("utf-8")
-            rc = sqlite.sqlite3_bind_text(self._statement, idx, param, len(param), SQLITE_TRANSIENT)
-        elif type(param) is buffer:
+        elif isinstance(param, (buffer, bytes)):
             param = bytes(param)
             rc = sqlite.sqlite3_bind_blob(self._statement, idx, param, len(param), SQLITE_TRANSIENT)
         else:
@@ -1167,23 +1194,21 @@
 
             converter = self.__row_cast_map[i]
             if converter is None:
-                if typ == SQLITE_INTEGER:
+                if typ == SQLITE_NULL:
+                    val = None
+                elif typ == SQLITE_INTEGER:
                     val = sqlite.sqlite3_column_int64(self._statement, i)
-                    if -sys.maxint-1 <= val <= sys.maxint:
-                        val = int(val)
                 elif typ == SQLITE_FLOAT:
                     val = sqlite.sqlite3_column_double(self._statement, i)
-                elif typ == SQLITE_BLOB:
-                    blob = sqlite.sqlite3_column_blob(self._statement, i)
-                    blob_len = sqlite.sqlite3_column_bytes(self._statement, i)
-                    val = buffer(string_at(blob, blob_len))
-                elif typ == SQLITE_NULL:
-                    val = None
                 elif typ == SQLITE_TEXT:
                     text = sqlite.sqlite3_column_text(self._statement, i)
                     text_len = sqlite.sqlite3_column_bytes(self._statement, i)
                     val = string_at(text, text_len)
                     val = self.__con.text_factory(val)
+                elif typ == SQLITE_BLOB:
+                    blob = sqlite.sqlite3_column_blob(self._statement, i)
+                    blob_len = sqlite.sqlite3_column_bytes(self._statement, i)
+                    val = BLOB_TYPE(string_at(blob, blob_len))
             else:
                 blob = sqlite.sqlite3_column_blob(self._statement, i)
                 if not blob:
@@ -1292,21 +1317,19 @@
     _params = []
     for i in range(nargs):
         typ = sqlite.sqlite3_value_type(params[i])
-        if typ == SQLITE_INTEGER:
+        if typ == SQLITE_NULL:
+            val = None
+        elif typ == SQLITE_INTEGER:
             val = sqlite.sqlite3_value_int64(params[i])
-            if -sys.maxint-1 <= val <= sys.maxint:
-                val = int(val)
         elif typ == SQLITE_FLOAT:
             val = sqlite.sqlite3_value_double(params[i])
+        elif typ == SQLITE_TEXT:
+            val = sqlite.sqlite3_value_text(params[i])
+            val = val.decode('utf-8')
         elif typ == SQLITE_BLOB:
             blob = sqlite.sqlite3_value_blob(params[i])
             blob_len = sqlite.sqlite3_value_bytes(params[i])
-            val = buffer(string_at(blob, blob_len))
-        elif typ == SQLITE_NULL:
-            val = None
-        elif typ == SQLITE_TEXT:
-            val = sqlite.sqlite3_value_text(params[i])
-            val = val.decode('utf-8')
+            val = BLOB_TYPE(string_at(blob, blob_len))
         else:
             raise NotImplementedError
         _params.append(val)
@@ -1318,14 +1341,14 @@
         sqlite.sqlite3_result_null(con)
     elif isinstance(val, (bool, int, long)):
         sqlite.sqlite3_result_int64(con, int(val))
-    elif isinstance(val, str):
-        sqlite.sqlite3_result_text(con, val, len(val), SQLITE_TRANSIENT)
+    elif isinstance(val, float):
+        sqlite.sqlite3_result_double(con, val)
     elif isinstance(val, unicode):
         val = val.encode('utf-8')
         sqlite.sqlite3_result_text(con, val, len(val), SQLITE_TRANSIENT)
-    elif isinstance(val, float):
-        sqlite.sqlite3_result_double(con, val)
-    elif isinstance(val, buffer):
+    elif isinstance(val, str):
+        sqlite.sqlite3_result_text(con, val, len(val), SQLITE_TRANSIENT)
+    elif isinstance(val, (buffer, bytes)):
         sqlite.sqlite3_result_blob(con, bytes(val), len(val), SQLITE_TRANSIENT)
     else:
         raise NotImplementedError
@@ -1397,8 +1420,8 @@
             microseconds = int(timepart_full[1])
         else:
             microseconds = 0
-        return datetime.datetime(year, month, day,
-                                 hours, minutes, seconds, microseconds)
+        return datetime.datetime(year, month, day, hours, minutes, seconds,
+                                 microseconds)
 
     register_adapter(datetime.date, adapt_date)
     register_adapter(datetime.datetime, adapt_datetime)
@@ -1435,11 +1458,3 @@
     return val
 
 register_adapters_and_converters()
-
-
-def OptimizedUnicode(s):
-    try:
-        val = unicode(s, "ascii").encode("ascii")
-    except UnicodeDecodeError:
-        val = unicode(s, "utf-8")
-    return val


More information about the pypy-commit mailing list