Author: Brian Kearns <[email protected]>
Branch: py3k
Changeset: r62161:254c727895bb
Date: 2013-03-07 03:22 -0500
http://bitbucket.org/pypy/pypy/changeset/254c727895bb/
Log: merge default
diff --git a/lib_pypy/_sqlite3.py b/lib_pypy/_sqlite3.py
--- a/lib_pypy/_sqlite3.py
+++ b/lib_pypy/_sqlite3.py
@@ -29,6 +29,7 @@
from collections import OrderedDict
from functools import wraps
import datetime
+import string
import sys
import weakref
from threading import _get_ident as _thread_get_ident
@@ -226,7 +227,7 @@
sqlite.sqlite3_total_changes.argtypes = [c_void_p]
sqlite.sqlite3_total_changes.restype = c_int
-sqlite.sqlite3_result_blob.argtypes = [c_void_p, c_char_p, c_int, c_void_p]
+sqlite.sqlite3_result_blob.argtypes = [c_void_p, c_void_p, c_int, c_void_p]
sqlite.sqlite3_result_blob.restype = None
sqlite.sqlite3_result_int64.argtypes = [c_void_p, c_int64]
sqlite.sqlite3_result_int64.restype = None
@@ -319,6 +320,7 @@
self.__initialized = True
self._db = c_void_p()
+ database = database.encode('utf-8')
if sqlite.sqlite3_open(database, byref(self._db)) != SQLITE_OK:
raise OperationalError("Could not open database")
if timeout is not None:
@@ -408,8 +410,7 @@
def _get_exception(self, error_code=None):
if error_code is None:
error_code = sqlite.sqlite3_errcode(self._db)
- error_message = sqlite.sqlite3_errmsg(self._db)
- error_message = error_message.decode('utf-8')
+ error_message = sqlite.sqlite3_errmsg(self._db).decode('utf-8')
if error_code == SQLITE_OK:
raise ValueError("error signalled but got SQLITE_OK")
@@ -503,7 +504,7 @@
statement = c_void_p()
next_char = c_char_p()
- ret = sqlite.sqlite3_prepare_v2(self._db, "COMMIT", -1,
+ ret = sqlite.sqlite3_prepare_v2(self._db, b"COMMIT", -1,
byref(statement), next_char)
try:
if ret != SQLITE_OK:
@@ -533,7 +534,7 @@
statement = c_void_p()
next_char = c_char_p()
- ret = sqlite.sqlite3_prepare_v2(self._db, "ROLLBACK", -1,
+ ret = sqlite.sqlite3_prepare_v2(self._db, b"ROLLBACK", -1,
byref(statement), next_char)
try:
if ret != SQLITE_OK:
@@ -564,6 +565,8 @@
function_callback(callback, context, nargs, c_params)
c_closure = _FUNC(closure)
self.__func_cache[callback] = c_closure, closure
+
+ name = name.encode('utf-8')
ret = sqlite.sqlite3_create_function(self._db, name, num_args,
SQLITE_UTF8, None,
c_closure,
@@ -579,7 +582,6 @@
c_step_callback, c_final_callback, _, _ = self.__aggregates[cls]
except KeyError:
def step_callback(context, argc, c_params):
-
aggregate_ptr = cast(
sqlite.sqlite3_aggregate_context(
context, sizeof(c_ssize_t)),
@@ -589,8 +591,8 @@
try:
aggregate = cls()
except Exception:
- msg = ("user-defined aggregate's '__init__' "
- "method raised error")
+ msg = (b"user-defined aggregate's '__init__' "
+ b"method raised error")
sqlite.sqlite3_result_error(context, msg, len(msg))
return
aggregate_id = id(aggregate)
@@ -603,12 +605,11 @@
try:
aggregate.step(*params)
except Exception:
- msg = ("user-defined aggregate's 'step' "
- "method raised error")
+ msg = (b"user-defined aggregate's 'step' "
+ b"method raised error")
sqlite.sqlite3_result_error(context, msg, len(msg))
def final_callback(context):
-
aggregate_ptr = cast(
sqlite.sqlite3_aggregate_context(
context, sizeof(c_ssize_t)),
@@ -619,8 +620,8 @@
try:
val = aggregate.finalize()
except Exception:
- msg = ("user-defined aggregate's 'finalize' "
- "method raised error")
+ msg = (b"user-defined aggregate's 'finalize' "
+ b"method raised error")
sqlite.sqlite3_result_error(context, msg, len(msg))
else:
_convert_result(context, val)
@@ -633,6 +634,7 @@
self.__aggregates[cls] = (c_step_callback, c_final_callback,
step_callback, final_callback)
+ name = name.encode('utf-8')
ret = sqlite.sqlite3_create_function(self._db, name, num_args,
SQLITE_UTF8, None,
cast(None, _FUNC),
@@ -645,7 +647,7 @@
@_check_closed_wrap
def create_collation(self, name, callback):
name = name.upper()
- if not name.replace('_', '').isalnum():
+ if not all(c in string.ascii_uppercase + string.digits + '_' for c in
name):
raise ProgrammingError("invalid character in collation name")
if callback is None:
@@ -656,14 +658,15 @@
raise TypeError("parameter must be callable")
def collation_callback(context, len1, str1, len2, str2):
- text1 = string_at(str1, len1)
- text2 = string_at(str2, len2)
+ text1 = string_at(str1, len1).decode('utf-8')
+ text2 = string_at(str2, len2).decode('utf-8')
return callback(text1, text2)
c_collation_callback = _COLLATION(collation_callback)
self.__collations[name] = c_collation_callback
+ name = name.encode('utf-8')
ret = sqlite.sqlite3_create_collation(self._db, name,
SQLITE_UTF8,
None,
@@ -733,7 +736,7 @@
if val is None:
self.commit()
else:
- self.__begin_statement = 'BEGIN ' + val
+ self.__begin_statement = b"BEGIN " + val.encode('utf-8')
self._isolation_level = val
isolation_level = property(__get_isolation_level, __set_isolation_level)
@@ -748,7 +751,6 @@
class Cursor(object):
__initialized = False
- __connection = None
__statement = None
def __init__(self, con):
@@ -770,11 +772,10 @@
self.__rowcount = -1
def __del__(self):
- if self.__connection:
- try:
- self.__connection._cursors.remove(weakref.ref(self))
- except ValueError:
- pass
+ try:
+ self.__connection._cursors.remove(weakref.ref(self))
+ except (AttributeError, ValueError):
+ pass
if self.__statement:
self.__statement._reset()
@@ -873,8 +874,8 @@
self.__connection._in_transaction = \
not
sqlite.sqlite3_get_autocommit(self.__connection._db)
raise self.__connection._get_exception(ret)
+ self.__statement._reset()
self.__rowcount +=
sqlite.sqlite3_changes(self.__connection._db)
- self.__statement._reset()
finally:
self.__locked = False
@@ -883,10 +884,9 @@
def executescript(self, sql):
self.__description = None
self._reset = False
- if type(sql) is str:
- sql = sql.encode("utf-8")
self.__check_cursor()
statement = c_void_p()
+ sql = sql.encode('utf-8')
c_sql = c_char_p(sql)
self.__connection.commit()
@@ -1008,11 +1008,12 @@
self._statement = c_void_p()
next_char = c_char_p()
- sql_char = sql
- ret = sqlite.sqlite3_prepare_v2(self.__con._db, sql_char, -1,
byref(self._statement), byref(next_char))
+ sql = sql.encode('utf-8')
+
+ ret = sqlite.sqlite3_prepare_v2(self.__con._db, sql, -1,
byref(self._statement), byref(next_char))
if ret == SQLITE_OK and self._statement.value is None:
# an empty statement, we work around that, as it's the least
trouble
- ret = sqlite.sqlite3_prepare_v2(self.__con._db, "select 42", -1,
byref(self._statement), byref(next_char))
+ ret = sqlite.sqlite3_prepare_v2(self.__con._db, b"select 42", -1,
byref(self._statement), byref(next_char))
self._kind = Statement._DQL
if ret != SQLITE_OK:
@@ -1021,22 +1022,23 @@
next_char = next_char.value.decode('utf-8')
if _check_remaining_sql(next_char):
raise Warning("One and only one statement required: %r" %
- (next_char,))
+ next_char)
def __del__(self):
if self._statement:
sqlite.sqlite3_finalize(self._statement)
def _finalize(self):
- sqlite.sqlite3_finalize(self._statement)
- self._statement = None
+ if self._statement:
+ sqlite.sqlite3_finalize(self._statement)
+ self._statement = None
self._in_use = False
def _reset(self):
- ret = sqlite.sqlite3_reset(self._statement)
- self._in_use = False
+ if self._in_use and self._statement:
+ ret = sqlite.sqlite3_reset(self._statement)
+ self._in_use = False
self._exhausted = False
- return ret
def _build_row_cast_map(self):
self.__row_cast_map = []
@@ -1059,8 +1061,8 @@
if converter is None and self.__con._detect_types &
PARSE_DECLTYPES:
decltype = sqlite.sqlite3_column_decltype(self._statement, i)
if decltype is not None:
+ decltype = decltype.decode('utf-8')
decltype = decltype.split()[0] # if multiple words,
use first, eg. "INTEGER NOT NULL" => "INTEGER"
- decltype = decltype.decode('utf-8')
if '(' in decltype:
decltype = decltype[:decltype.index('(')]
converter = converters.get(decltype.upper(), None)
@@ -1070,37 +1072,36 @@
def __set_param(self, idx, param):
cvt = converters.get(type(param))
if cvt is not None:
- cvt = param = cvt(param)
+ param = cvt(param)
param = adapt(param)
if param is None:
- sqlite.sqlite3_bind_null(self._statement, idx)
+ rc = sqlite.sqlite3_bind_null(self._statement, idx)
elif type(param) in (bool, int):
if -2147483648 <= param <= 2147483647:
- sqlite.sqlite3_bind_int(self._statement, idx, param)
+ rc = sqlite.sqlite3_bind_int(self._statement, idx, param)
else:
- sqlite.sqlite3_bind_int64(self._statement, idx, param)
+ rc = sqlite.sqlite3_bind_int64(self._statement, idx, param)
elif type(param) is float:
- sqlite.sqlite3_bind_double(self._statement, idx, param)
+ rc = sqlite.sqlite3_bind_double(self._statement, idx, param)
elif isinstance(param, str):
- param = param.encode('utf-8')
- sqlite.sqlite3_bind_text(self._statement, idx, param, len(param),
SQLITE_TRANSIENT)
+ param = param.encode("utf-8")
+ rc = sqlite.sqlite3_bind_text(self._statement, idx, param,
len(param), SQLITE_TRANSIENT)
elif type(param) in (bytes, memoryview):
param = bytes(param)
- sqlite.sqlite3_bind_blob(self._statement, idx, param, len(param),
SQLITE_TRANSIENT)
+ rc = sqlite.sqlite3_bind_blob(self._statement, idx, param,
len(param), SQLITE_TRANSIENT)
else:
- raise InterfaceError("parameter type %s is not supported" %
- type(param))
+ rc = -1
+ return rc
def _set_params(self, params):
- ret = sqlite.sqlite3_reset(self._statement)
- if ret != SQLITE_OK:
- raise self.__con._get_exception(ret)
self._in_use = True
num_params_needed =
sqlite.sqlite3_bind_parameter_count(self._statement)
- if not isinstance(params, dict):
+ if isinstance(params, (tuple, list)) or \
+ not isinstance(params, dict) and \
+ hasattr(params, '__len__') and hasattr(params, '__getitem__'):
num_params = len(params)
if num_params != num_params_needed:
raise ProgrammingError("Incorrect number of bindings supplied.
"
@@ -1108,25 +1109,32 @@
"there are %d supplied." %
(num_params_needed, num_params))
for i in range(num_params):
- self.__set_param(i + 1, params[i])
- else:
+ rc = self.__set_param(i + 1, params[i])
+ if rc != SQLITE_OK:
+ raise InterfaceError("Error binding parameter %d - "
+ "probably unsupported type." % i)
+ elif isinstance(params, dict):
for i in range(1, num_params_needed + 1):
param_name =
sqlite.sqlite3_bind_parameter_name(self._statement, i)
if param_name is None:
raise ProgrammingError("Binding %d has no name, but you "
"supplied a dictionary (which has "
"only names)." % i)
- param_name = param_name[1:].decode('utf-8')
+ param_name = param_name.decode('utf-8')[1:]
try:
param = params[param_name]
except KeyError:
raise ProgrammingError("You did not supply a value for "
"binding %d." % i)
- self.__set_param(i, param)
+ rc = self.__set_param(i, param)
+ if rc != SQLITE_OK:
+ raise InterfaceError("Error binding parameter :%s - "
+ "probably unsupported type." %
+ param_name)
+ else:
+ raise ValueError("parameters are of unsupported type")
def _next(self, cursor):
- self.__con._check_closed()
- self.__con._check_thread()
if self._exhausted:
raise StopIteration
item = self._item
@@ -1158,14 +1166,14 @@
elif typ == SQLITE_FLOAT:
val = sqlite.sqlite3_column_double(self._statement, i)
elif typ == SQLITE_BLOB:
+ blob = sqlite.sqlite3_column_blob(self._statement, i)
blob_len = sqlite.sqlite3_column_bytes(self._statement, i)
- blob = sqlite.sqlite3_column_blob(self._statement, i)
val = bytes(string_at(blob, blob_len))
elif typ == SQLITE_NULL:
val = None
elif typ == SQLITE_TEXT:
+ text = sqlite.sqlite3_column_text(self._statement, i)
text_len = sqlite.sqlite3_column_bytes(self._statement, i)
- text = sqlite.sqlite3_column_text(self._statement, i)
val = string_at(text, text_len)
val = self.__con.text_factory(val)
else:
@@ -1174,7 +1182,7 @@
val = None
else:
blob_len = sqlite.sqlite3_column_bytes(self._statement, i)
- val = string_at(blob, blob_len)
+ val = bytes(string_at(blob, blob_len))
val = converter(val)
row.append(val)
@@ -1188,8 +1196,8 @@
return None
desc = []
for i in range(sqlite.sqlite3_column_count(self._statement)):
- col_name = sqlite.sqlite3_column_name(self._statement, i)
- name = col_name.decode('utf-8').split("[")[0].strip()
+ name = sqlite.sqlite3_column_name(self._statement, i)
+ name = name.decode('utf-8').split("[")[0].strip()
desc.append((name, None, None, None, None, None, None))
return desc
@@ -1282,15 +1290,14 @@
elif typ == SQLITE_FLOAT:
val = sqlite.sqlite3_value_double(params[i])
elif typ == SQLITE_BLOB:
+ blob = sqlite.sqlite3_value_blob(params[i])
blob_len = sqlite.sqlite3_value_bytes(params[i])
- blob = sqlite.sqlite3_value_blob(params[i])
val = bytes(string_at(blob, blob_len))
elif typ == SQLITE_NULL:
val = None
elif typ == SQLITE_TEXT:
val = sqlite.sqlite3_value_text(params[i])
- # XXX changed from con.text_factory
- val = str(val, 'utf-8')
+ val = val.decode('utf-8')
else:
raise NotImplementedError
_params.append(val)
@@ -1303,13 +1310,12 @@
elif isinstance(val, (bool, int)):
sqlite.sqlite3_result_int64(con, int(val))
elif isinstance(val, str):
- sqlite.sqlite3_result_text(con, val, len(val), SQLITE_TRANSIENT)
- elif isinstance(val, bytes):
+ val = val.encode('utf-8')
sqlite.sqlite3_result_text(con, val, len(val), SQLITE_TRANSIENT)
elif isinstance(val, float):
sqlite.sqlite3_result_double(con, val)
- elif isinstance(val, buffer):
- sqlite.sqlite3_result_blob(con, str(val), len(val), SQLITE_TRANSIENT)
+ elif isinstance(val, (bytes, memoryview)):
+ sqlite.sqlite3_result_blob(con, bytes(val), len(val), SQLITE_TRANSIENT)
else:
raise NotImplementedError
@@ -1319,7 +1325,7 @@
try:
val = real_cb(*params)
except Exception:
- msg = "user-defined function raised exception"
+ msg = b"user-defined function raised exception"
sqlite.sqlite3_result_error(context, msg, len(msg))
else:
_convert_result(context, val)
diff --git a/pypy/module/test_lib_pypy/test_sqlite3.py b/pypy/module/test_lib_pypy/test_sqlite3.py
--- a/pypy/module/test_lib_pypy/test_sqlite3.py
+++ b/pypy/module/test_lib_pypy/test_sqlite3.py
@@ -126,3 +126,20 @@
con.commit()
except _sqlite3.OperationalError:
pytest.fail("_sqlite3 knew nothing about the implicit ROLLBACK")
+
+def test_statement_param_checking():
+ con = _sqlite3.connect(':memory:')
+ con.execute('create table foo(x)')
+ con.execute('insert into foo(x) values (?)', [2])
+ con.execute('insert into foo(x) values (?)', (2,))
+ class seq(object):
+ def __len__(self):
+ return 1
+ def __getitem__(self, key):
+ return 2
+ con.execute('insert into foo(x) values (?)', seq())
+ with pytest.raises(_sqlite3.ProgrammingError):
+ con.execute('insert into foo(x) values (?)', {2:2})
+ with pytest.raises(ValueError) as e:
+ con.execute('insert into foo(x) values (?)', 2)
+ assert str(e.value) == 'parameters are of unsupported type'
_______________________________________________
pypy-commit mailing list
pypy-commit@python.org
http://mail.python.org/mailman/listinfo/pypy-commit