Author: Brian Kearns <bdkea...@gmail.com> Branch: Changeset: r62133:45bbcccf99e5 Date: 2013-03-06 15:25 -0500 http://bitbucket.org/pypy/pypy/changeset/45bbcccf99e5/
Log: cleanup _sqlite3.Statement diff --git a/lib_pypy/_sqlite3.py b/lib_pypy/_sqlite3.py --- a/lib_pypy/_sqlite3.py +++ b/lib_pypy/_sqlite3.py @@ -237,8 +237,6 @@ sqlite.sqlite3_enable_load_extension.argtypes = [c_void_p, c_int] sqlite.sqlite3_enable_load_extension.restype = c_int -_DML, _DQL, _DDL = range(3) - ########################################## # END Wrapped SQLite C API and constants ########################################## @@ -302,9 +300,9 @@ if len(self.cache) > self.maxcount: self.cache.popitem(0) - if stat.in_use: + if stat._in_use: stat = Statement(self.connection, sql) - stat.set_row_factory(row_factory) + stat._row_factory = row_factory return stat @@ -363,7 +361,7 @@ for statement in self.__statements: obj = statement() if obj is not None: - obj.finalize() + obj._finalize() if self._db: ret = sqlite.sqlite3_close(self._db) @@ -496,7 +494,7 @@ for statement in self.__statements: obj = statement() if obj is not None: - obj.reset() + obj._reset() statement = c_void_p() next_char = c_char_p() @@ -521,7 +519,7 @@ for statement in self.__statements: obj = statement() if obj is not None: - obj.reset() + obj._reset() for cursor_ref in self._cursors: cursor = cursor_ref() @@ -771,13 +769,13 @@ except ValueError: pass if self.__statement: - self.__statement.reset() + self.__statement._reset() def close(self): self.__connection._check_thread() self.__connection._check_closed() if self.__statement: - self.__statement.reset() + self.__statement._reset() self.__statement = None self.__closed = True @@ -808,35 +806,35 @@ sql, self.row_factory) if self.__connection._isolation_level is not None: - if self.__statement.kind == _DDL: + if self.__statement._kind == Statement._DDL: if self.__connection._in_transaction: self.__connection.commit() - elif self.__statement.kind == _DML: + elif self.__statement._kind == Statement._DML: if not self.__connection._in_transaction: self.__connection._begin() - self.__statement.set_params(params) + self.__statement._set_params(params) # Actually execute the SQL statement - ret = sqlite.sqlite3_step(self.__statement.statement) + ret = sqlite.sqlite3_step(self.__statement._statement) if ret not in (SQLITE_DONE, SQLITE_ROW): - self.__statement.reset() + self.__statement._reset() self.__connection._in_transaction = \ not sqlite.sqlite3_get_autocommit(self.__connection._db) raise self.__connection._get_exception(ret) - if self.__statement.kind == _DML: - self.__statement.reset() + if self.__statement._kind == Statement._DML: + self.__statement._reset() - if self.__statement.kind == _DQL and ret == SQLITE_ROW: + if self.__statement._kind == Statement._DQL and ret == SQLITE_ROW: self.__statement._build_row_cast_map() self.__statement._readahead(self) else: - self.__statement.item = None - self.__statement.exhausted = True + self.__statement._item = None + self.__statement._exhausted = True self.__rowcount = -1 - if self.__statement.kind == _DML: + if self.__statement._kind == Statement._DML: self.__rowcount = sqlite.sqlite3_changes(self.__connection._db) finally: self.__locked = False @@ -852,7 +850,7 @@ self.__statement = self.__connection._statement_cache.get( sql, self.row_factory) - if self.__statement.kind == _DML: + if self.__statement._kind == Statement._DML: if self.__connection._isolation_level is not None: if not self.__connection._in_transaction: self.__connection._begin() @@ -861,15 +859,15 @@ self.__rowcount = 0 for params in many_params: - self.__statement.set_params(params) - ret = sqlite.sqlite3_step(self.__statement.statement) + self.__statement._set_params(params) + ret = sqlite.sqlite3_step(self.__statement._statement) if ret != SQLITE_DONE: - self.__statement.reset() + self.__statement._reset() self.__connection._in_transaction = \ not sqlite.sqlite3_get_autocommit(self.__connection._db) raise self.__connection._get_exception(ret) self.__rowcount += sqlite.sqlite3_changes(self.__connection._db) - self.__statement.reset() + self.__statement._reset() finally: self.__locked = False @@ -926,7 +924,7 @@ return None try: - return self.__statement.next(self) + return self.__statement._next(self) except StopIteration: return None @@ -980,55 +978,67 @@ class Statement(object): - statement = None + _DML, _DQL, _DDL = range(3) + + _statement = None def __init__(self, connection, sql): - if not isinstance(sql, str): + self.__con = connection + + if isinstance(sql, unicode): + sql = sql.encode('utf-8') + elif not isinstance(sql, str): raise ValueError("sql must be a string") - self.con = connection - self.sql = sql # DEBUG ONLY first_word = self._statement_kind = sql.lstrip().split(" ")[0].upper() if first_word in ("INSERT", "UPDATE", "DELETE", "REPLACE"): - self.kind = _DML + self._kind = Statement._DML elif first_word in ("SELECT", "PRAGMA"): - self.kind = _DQL + self._kind = Statement._DQL else: - self.kind = _DDL - self.exhausted = False - self.in_use = False - # - # set by set_row_factory - self.row_factory = None + self._kind = Statement._DDL - self.statement = c_void_p() + self._in_use = False + self._exhausted = False + self._row_factory = None + + self._statement = c_void_p() next_char = c_char_p() sql_char = c_char_p(sql) - ret = sqlite.sqlite3_prepare_v2(self.con._db, sql_char, -1, byref(self.statement), byref(next_char)) - if ret == SQLITE_OK and self.statement.value is None: + ret = sqlite.sqlite3_prepare_v2(self.__con._db, sql_char, -1, byref(self._statement), byref(next_char)) + if ret == SQLITE_OK and self._statement.value is None: # an empty statement, we work around that, as it's the least trouble - ret = sqlite.sqlite3_prepare_v2(self.con._db, "select 42", -1, byref(self.statement), byref(next_char)) - self.kind = _DQL + ret = sqlite.sqlite3_prepare_v2(self.__con._db, "select 42", -1, byref(self._statement), byref(next_char)) + self._kind = Statement._DQL if ret != SQLITE_OK: - raise self.con._get_exception(ret) - self.con._remember_statement(self) + raise self.__con._get_exception(ret) + self.__con._remember_statement(self) if _check_remaining_sql(next_char.value): raise Warning("One and only one statement required: %r" % (next_char.value,)) - # sql_char should remain alive until here - self._build_row_cast_map() + def __del__(self): + if self._statement: + sqlite.sqlite3_finalize(self._statement) - def set_row_factory(self, row_factory): - self.row_factory = row_factory + def _finalize(self): + sqlite.sqlite3_finalize(self._statement) + self._statement = None + self._in_use = False + + def _reset(self): + ret = sqlite.sqlite3_reset(self._statement) + self._in_use = False + self._exhausted = False + return ret def _build_row_cast_map(self): - self.row_cast_map = [] - for i in xrange(sqlite.sqlite3_column_count(self.statement)): + self.__row_cast_map = [] + for i in xrange(sqlite.sqlite3_column_count(self._statement)): converter = None - if self.con._detect_types & PARSE_COLNAMES: - colname = sqlite.sqlite3_column_name(self.statement, i) + if self.__con._detect_types & PARSE_COLNAMES: + colname = sqlite.sqlite3_column_name(self._statement, i) if colname is not None: type_start = -1 key = None @@ -1039,28 +1049,28 @@ key = colname[type_start:pos] converter = converters[key.upper()] - if converter is None and self.con._detect_types & PARSE_DECLTYPES: - decltype = sqlite.sqlite3_column_decltype(self.statement, i) + if converter is None and self.__con._detect_types & PARSE_DECLTYPES: + decltype = sqlite.sqlite3_column_decltype(self._statement, i) if decltype is not None: decltype = decltype.split()[0] # if multiple words, use first, eg. "INTEGER NOT NULL" => "INTEGER" if '(' in decltype: decltype = decltype[:decltype.index('(')] converter = converters.get(decltype.upper(), None) - self.row_cast_map.append(converter) + self.__row_cast_map.append(converter) - def _check_decodable(self, param): - if self.con.text_factory in (unicode, OptimizedUnicode, unicode_text_factory): + def __check_decodable(self, param): + if self.__con.text_factory in (unicode, OptimizedUnicode, unicode_text_factory): for c in param: if ord(c) & 0x80 != 0: - raise self.con.ProgrammingError( + raise self.__con.ProgrammingError( "You must not use 8-bit bytestrings unless " "you use a text_factory that can interpret " "8-bit bytestrings (like text_factory = str). " "It is highly recommended that you instead " "just switch your application to Unicode strings.") - def set_param(self, idx, param): + def __set_param(self, idx, param): cvt = converters.get(type(param)) if cvt is not None: cvt = param = cvt(param) @@ -1068,34 +1078,34 @@ param = adapt(param) if param is None: - sqlite.sqlite3_bind_null(self.statement, idx) + sqlite.sqlite3_bind_null(self._statement, idx) elif type(param) in (bool, int, long): if -2147483648 <= param <= 2147483647: - sqlite.sqlite3_bind_int(self.statement, idx, param) + sqlite.sqlite3_bind_int(self._statement, idx, param) else: - sqlite.sqlite3_bind_int64(self.statement, idx, param) + sqlite.sqlite3_bind_int64(self._statement, idx, param) elif type(param) is float: - sqlite.sqlite3_bind_double(self.statement, idx, param) + sqlite.sqlite3_bind_double(self._statement, idx, param) elif isinstance(param, str): - self._check_decodable(param) - sqlite.sqlite3_bind_text(self.statement, idx, param, len(param), SQLITE_TRANSIENT) + self.__check_decodable(param) + sqlite.sqlite3_bind_text(self._statement, idx, param, len(param), SQLITE_TRANSIENT) elif isinstance(param, unicode): param = param.encode("utf-8") - sqlite.sqlite3_bind_text(self.statement, idx, param, len(param), SQLITE_TRANSIENT) + sqlite.sqlite3_bind_text(self._statement, idx, param, len(param), SQLITE_TRANSIENT) elif type(param) is buffer: - sqlite.sqlite3_bind_blob(self.statement, idx, str(param), len(param), SQLITE_TRANSIENT) + sqlite.sqlite3_bind_blob(self._statement, idx, str(param), len(param), SQLITE_TRANSIENT) else: raise InterfaceError("parameter type %s is not supported" % str(type(param))) - def set_params(self, params): - ret = sqlite.sqlite3_reset(self.statement) + def _set_params(self, params): + ret = sqlite.sqlite3_reset(self._statement) if ret != SQLITE_OK: - raise self.con._get_exception(ret) - self.mark_dirty() + raise self.__con._get_exception(ret) + self._in_use = True if params is None: - if sqlite.sqlite3_bind_parameter_count(self.statement) != 0: + if sqlite.sqlite3_bind_parameter_count(self._statement) != 0: raise ProgrammingError("wrong number of arguments") return @@ -1106,14 +1116,14 @@ params_type = list if params_type == list: - if len(params) != sqlite.sqlite3_bind_parameter_count(self.statement): + if len(params) != sqlite.sqlite3_bind_parameter_count(self._statement): raise ProgrammingError("wrong number of arguments") for i in range(len(params)): - self.set_param(i+1, params[i]) + self.__set_param(i+1, params[i]) else: - for idx in range(1, sqlite.sqlite3_bind_parameter_count(self.statement) + 1): - param_name = sqlite.sqlite3_bind_parameter_name(self.statement, idx) + for idx in range(1, sqlite.sqlite3_bind_parameter_count(self._statement) + 1): + param_name = sqlite.sqlite3_bind_parameter_name(self._statement, idx) if param_name is None: raise ProgrammingError("need named parameters") param_name = param_name[1:] @@ -1121,92 +1131,73 @@ param = params[param_name] except KeyError: raise ProgrammingError("missing parameter '%s'" % param) - self.set_param(idx, param) + self.__set_param(idx, param) - def next(self, cursor): - self.con._check_closed() - self.con._check_thread() - if self.exhausted: + def _next(self, cursor): + self.__con._check_closed() + self.__con._check_thread() + if self._exhausted: raise StopIteration - item = self.item + item = self._item - ret = sqlite.sqlite3_step(self.statement) + ret = sqlite.sqlite3_step(self._statement) if ret == SQLITE_DONE: - self.exhausted = True - self.item = None + self._exhausted = True + self._item = None elif ret != SQLITE_ROW: - exc = self.con._get_exception(ret) - sqlite.sqlite3_reset(self.statement) + exc = self.__con._get_exception(ret) + sqlite.sqlite3_reset(self._statement) raise exc self._readahead(cursor) return item def _readahead(self, cursor): - self.column_count = sqlite.sqlite3_column_count(self.statement) + self.column_count = sqlite.sqlite3_column_count(self._statement) row = [] for i in xrange(self.column_count): - typ = sqlite.sqlite3_column_type(self.statement, i) + typ = sqlite.sqlite3_column_type(self._statement, i) - converter = self.row_cast_map[i] + converter = self.__row_cast_map[i] if converter is None: if typ == SQLITE_INTEGER: - val = sqlite.sqlite3_column_int64(self.statement, i) + val = sqlite.sqlite3_column_int64(self._statement, i) if -sys.maxint-1 <= val <= sys.maxint: val = int(val) elif typ == SQLITE_FLOAT: - val = sqlite.sqlite3_column_double(self.statement, i) + val = sqlite.sqlite3_column_double(self._statement, i) elif typ == SQLITE_BLOB: - blob_len = sqlite.sqlite3_column_bytes(self.statement, i) - blob = sqlite.sqlite3_column_blob(self.statement, i) + blob_len = sqlite.sqlite3_column_bytes(self._statement, i) + blob = sqlite.sqlite3_column_blob(self._statement, i) val = buffer(string_at(blob, blob_len)) elif typ == SQLITE_NULL: val = None elif typ == SQLITE_TEXT: - text_len = sqlite.sqlite3_column_bytes(self.statement, i) - text = sqlite.sqlite3_column_text(self.statement, i) + text_len = sqlite.sqlite3_column_bytes(self._statement, i) + text = sqlite.sqlite3_column_text(self._statement, i) val = string_at(text, text_len) - val = self.con.text_factory(val) + val = self.__con.text_factory(val) else: - blob = sqlite.sqlite3_column_blob(self.statement, i) + blob = sqlite.sqlite3_column_blob(self._statement, i) if not blob: val = None else: - blob_len = sqlite.sqlite3_column_bytes(self.statement, i) + blob_len = sqlite.sqlite3_column_bytes(self._statement, i) val = string_at(blob, blob_len) val = converter(val) row.append(val) row = tuple(row) - if self.row_factory is not None: - row = self.row_factory(cursor, row) - self.item = row - - def reset(self): - self.row_cast_map = None - ret = sqlite.sqlite3_reset(self.statement) - self.in_use = False - self.exhausted = False - return ret - - def finalize(self): - sqlite.sqlite3_finalize(self.statement) - self.statement = None - self.in_use = False - - def mark_dirty(self): - self.in_use = True - - def __del__(self): - if self.statement: - sqlite.sqlite3_finalize(self.statement) + if self._row_factory is not None: + row = self._row_factory(cursor, row) + self._item = row def _get_description(self): - if self.kind == _DML: + if self._kind == Statement._DML: return None desc = [] - for i in xrange(sqlite.sqlite3_column_count(self.statement)): - name = sqlite.sqlite3_column_name(self.statement, i).split("[")[0].strip() + for i in xrange(sqlite.sqlite3_column_count(self._statement)): + name = sqlite.sqlite3_column_name(self._statement, i).split("[")[0].strip() desc.append((name, None, None, None, None, None, None)) return desc _______________________________________________ pypy-commit mailing list pypy-commit@python.org http://mail.python.org/mailman/listinfo/pypy-commit