Author: Brian Kearns <bdkea...@gmail.com>
Branch: 
Changeset: r62133:45bbcccf99e5
Date: 2013-03-06 15:25 -0500
http://bitbucket.org/pypy/pypy/changeset/45bbcccf99e5/

Log:    cleanup _sqlite3.Statement

diff --git a/lib_pypy/_sqlite3.py b/lib_pypy/_sqlite3.py
--- a/lib_pypy/_sqlite3.py
+++ b/lib_pypy/_sqlite3.py
@@ -237,8 +237,6 @@
     sqlite.sqlite3_enable_load_extension.argtypes = [c_void_p, c_int]
     sqlite.sqlite3_enable_load_extension.restype = c_int
 
-_DML, _DQL, _DDL = range(3)
-
 ##########################################
 # END Wrapped SQLite C API and constants
 ##########################################
@@ -302,9 +300,9 @@
             if len(self.cache) > self.maxcount:
                 self.cache.popitem(0)
 
-        if stat.in_use:
+        if stat._in_use:
             stat = Statement(self.connection, sql)
-        stat.set_row_factory(row_factory)
+        stat._row_factory = row_factory
         return stat
 
 
@@ -363,7 +361,7 @@
         for statement in self.__statements:
             obj = statement()
             if obj is not None:
-                obj.finalize()
+                obj._finalize()
 
         if self._db:
             ret = sqlite.sqlite3_close(self._db)
@@ -496,7 +494,7 @@
         for statement in self.__statements:
             obj = statement()
             if obj is not None:
-                obj.reset()
+                obj._reset()
 
         statement = c_void_p()
         next_char = c_char_p()
@@ -521,7 +519,7 @@
         for statement in self.__statements:
             obj = statement()
             if obj is not None:
-                obj.reset()
+                obj._reset()
 
         for cursor_ref in self._cursors:
             cursor = cursor_ref()
@@ -771,13 +769,13 @@
             except ValueError:
                 pass
         if self.__statement:
-            self.__statement.reset()
+            self.__statement._reset()
 
     def close(self):
         self.__connection._check_thread()
         self.__connection._check_closed()
         if self.__statement:
-            self.__statement.reset()
+            self.__statement._reset()
             self.__statement = None
         self.__closed = True
 
@@ -808,35 +806,35 @@
                 sql, self.row_factory)
 
             if self.__connection._isolation_level is not None:
-                if self.__statement.kind == _DDL:
+                if self.__statement._kind == Statement._DDL:
                     if self.__connection._in_transaction:
                         self.__connection.commit()
-                elif self.__statement.kind == _DML:
+                elif self.__statement._kind == Statement._DML:
                     if not self.__connection._in_transaction:
                         self.__connection._begin()
 
-            self.__statement.set_params(params)
+            self.__statement._set_params(params)
 
             # Actually execute the SQL statement
-            ret = sqlite.sqlite3_step(self.__statement.statement)
+            ret = sqlite.sqlite3_step(self.__statement._statement)
             if ret not in (SQLITE_DONE, SQLITE_ROW):
-                self.__statement.reset()
+                self.__statement._reset()
                 self.__connection._in_transaction = \
                         not sqlite.sqlite3_get_autocommit(self.__connection._db)
                 raise self.__connection._get_exception(ret)
 
-            if self.__statement.kind == _DML:
-                self.__statement.reset()
+            if self.__statement._kind == Statement._DML:
+                self.__statement._reset()
 
-            if self.__statement.kind == _DQL and ret == SQLITE_ROW:
+            if self.__statement._kind == Statement._DQL and ret == SQLITE_ROW:
                 self.__statement._build_row_cast_map()
                 self.__statement._readahead(self)
             else:
-                self.__statement.item = None
-                self.__statement.exhausted = True
+                self.__statement._item = None
+                self.__statement._exhausted = True
 
             self.__rowcount = -1
-            if self.__statement.kind == _DML:
+            if self.__statement._kind == Statement._DML:
                 self.__rowcount = sqlite.sqlite3_changes(self.__connection._db)
         finally:
             self.__locked = False
@@ -852,7 +850,7 @@
             self.__statement = self.__connection._statement_cache.get(
                 sql, self.row_factory)
 
-            if self.__statement.kind == _DML:
+            if self.__statement._kind == Statement._DML:
                 if self.__connection._isolation_level is not None:
                     if not self.__connection._in_transaction:
                         self.__connection._begin()
@@ -861,15 +859,15 @@
 
             self.__rowcount = 0
             for params in many_params:
-                self.__statement.set_params(params)
-                ret = sqlite.sqlite3_step(self.__statement.statement)
+                self.__statement._set_params(params)
+                ret = sqlite.sqlite3_step(self.__statement._statement)
                 if ret != SQLITE_DONE:
-                    self.__statement.reset()
+                    self.__statement._reset()
                     self.__connection._in_transaction = \
                             not sqlite.sqlite3_get_autocommit(self.__connection._db)
                     raise self.__connection._get_exception(ret)
                 self.__rowcount += sqlite.sqlite3_changes(self.__connection._db)
-            self.__statement.reset()
+            self.__statement._reset()
         finally:
             self.__locked = False
 
@@ -926,7 +924,7 @@
             return None
 
         try:
-            return self.__statement.next(self)
+            return self.__statement._next(self)
         except StopIteration:
             return None
 
@@ -980,55 +978,67 @@
 
 
 class Statement(object):
-    statement = None
+    _DML, _DQL, _DDL = range(3)
+
+    _statement = None
 
     def __init__(self, connection, sql):
-        if not isinstance(sql, str):
+        self.__con = connection
+
+        if isinstance(sql, unicode):
+            sql = sql.encode('utf-8')
+        elif not isinstance(sql, str):
             raise ValueError("sql must be a string")
-        self.con = connection
-        self.sql = sql  # DEBUG ONLY
         first_word = self._statement_kind = sql.lstrip().split(" ")[0].upper()
         if first_word in ("INSERT", "UPDATE", "DELETE", "REPLACE"):
-            self.kind = _DML
+            self._kind = Statement._DML
         elif first_word in ("SELECT", "PRAGMA"):
-            self.kind = _DQL
+            self._kind = Statement._DQL
         else:
-            self.kind = _DDL
-        self.exhausted = False
-        self.in_use = False
-        #
-        # set by set_row_factory
-        self.row_factory = None
+            self._kind = Statement._DDL
 
-        self.statement = c_void_p()
+        self._in_use = False
+        self._exhausted = False
+        self._row_factory = None
+
+        self._statement = c_void_p()
         next_char = c_char_p()
         sql_char = c_char_p(sql)
-        ret = sqlite.sqlite3_prepare_v2(self.con._db, sql_char, -1, byref(self.statement), byref(next_char))
-        if ret == SQLITE_OK and self.statement.value is None:
+        ret = sqlite.sqlite3_prepare_v2(self.__con._db, sql_char, -1, byref(self._statement), byref(next_char))
+        if ret == SQLITE_OK and self._statement.value is None:
             # an empty statement, we work around that, as it's the least trouble
-            ret = sqlite.sqlite3_prepare_v2(self.con._db, "select 42", -1, byref(self.statement), byref(next_char))
-            self.kind = _DQL
+            ret = sqlite.sqlite3_prepare_v2(self.__con._db, "select 42", -1, byref(self._statement), byref(next_char))
+            self._kind = Statement._DQL
 
         if ret != SQLITE_OK:
-            raise self.con._get_exception(ret)
-        self.con._remember_statement(self)
+            raise self.__con._get_exception(ret)
+        self.__con._remember_statement(self)
         if _check_remaining_sql(next_char.value):
             raise Warning("One and only one statement required: %r" %
                           (next_char.value,))
-        # sql_char should remain alive until here
 
-        self._build_row_cast_map()
+    def __del__(self):
+        if self._statement:
+            sqlite.sqlite3_finalize(self._statement)
 
-    def set_row_factory(self, row_factory):
-        self.row_factory = row_factory
+    def _finalize(self):
+        sqlite.sqlite3_finalize(self._statement)
+        self._statement = None
+        self._in_use = False
+
+    def _reset(self):
+        ret = sqlite.sqlite3_reset(self._statement)
+        self._in_use = False
+        self._exhausted = False
+        return ret
 
     def _build_row_cast_map(self):
-        self.row_cast_map = []
-        for i in xrange(sqlite.sqlite3_column_count(self.statement)):
+        self.__row_cast_map = []
+        for i in xrange(sqlite.sqlite3_column_count(self._statement)):
             converter = None
 
-            if self.con._detect_types & PARSE_COLNAMES:
-                colname = sqlite.sqlite3_column_name(self.statement, i)
+            if self.__con._detect_types & PARSE_COLNAMES:
+                colname = sqlite.sqlite3_column_name(self._statement, i)
                 if colname is not None:
                     type_start = -1
                     key = None
@@ -1039,28 +1049,28 @@
                             key = colname[type_start:pos]
                             converter = converters[key.upper()]
 
-            if converter is None and self.con._detect_types & PARSE_DECLTYPES:
-                decltype = sqlite.sqlite3_column_decltype(self.statement, i)
+            if converter is None and self.__con._detect_types & PARSE_DECLTYPES:
+                decltype = sqlite.sqlite3_column_decltype(self._statement, i)
                 if decltype is not None:
                     decltype = decltype.split()[0]      # if multiple words, use first, eg. "INTEGER NOT NULL" => "INTEGER"
                     if '(' in decltype:
                         decltype = decltype[:decltype.index('(')]
                     converter = converters.get(decltype.upper(), None)
 
-            self.row_cast_map.append(converter)
+            self.__row_cast_map.append(converter)
 
-    def _check_decodable(self, param):
-        if self.con.text_factory in (unicode, OptimizedUnicode, unicode_text_factory):
+    def __check_decodable(self, param):
+        if self.__con.text_factory in (unicode, OptimizedUnicode, unicode_text_factory):
             for c in param:
                 if ord(c) & 0x80 != 0:
-                    raise self.con.ProgrammingError(
+                    raise self.__con.ProgrammingError(
                         "You must not use 8-bit bytestrings unless "
                         "you use a text_factory that can interpret "
                         "8-bit bytestrings (like text_factory = str). "
                         "It is highly recommended that you instead "
                         "just switch your application to Unicode strings.")
 
-    def set_param(self, idx, param):
+    def __set_param(self, idx, param):
         cvt = converters.get(type(param))
         if cvt is not None:
             cvt = param = cvt(param)
@@ -1068,34 +1078,34 @@
         param = adapt(param)
 
         if param is None:
-            sqlite.sqlite3_bind_null(self.statement, idx)
+            sqlite.sqlite3_bind_null(self._statement, idx)
         elif type(param) in (bool, int, long):
             if -2147483648 <= param <= 2147483647:
-                sqlite.sqlite3_bind_int(self.statement, idx, param)
+                sqlite.sqlite3_bind_int(self._statement, idx, param)
             else:
-                sqlite.sqlite3_bind_int64(self.statement, idx, param)
+                sqlite.sqlite3_bind_int64(self._statement, idx, param)
         elif type(param) is float:
-            sqlite.sqlite3_bind_double(self.statement, idx, param)
+            sqlite.sqlite3_bind_double(self._statement, idx, param)
         elif isinstance(param, str):
-            self._check_decodable(param)
-            sqlite.sqlite3_bind_text(self.statement, idx, param, len(param), SQLITE_TRANSIENT)
+            self.__check_decodable(param)
+            sqlite.sqlite3_bind_text(self._statement, idx, param, len(param), SQLITE_TRANSIENT)
         elif isinstance(param, unicode):
             param = param.encode("utf-8")
-            sqlite.sqlite3_bind_text(self.statement, idx, param, len(param), SQLITE_TRANSIENT)
+            sqlite.sqlite3_bind_text(self._statement, idx, param, len(param), SQLITE_TRANSIENT)
         elif type(param) is buffer:
-            sqlite.sqlite3_bind_blob(self.statement, idx, str(param), len(param), SQLITE_TRANSIENT)
+            sqlite.sqlite3_bind_blob(self._statement, idx, str(param), len(param), SQLITE_TRANSIENT)
         else:
             raise InterfaceError("parameter type %s is not supported" %
                                  str(type(param)))
 
-    def set_params(self, params):
-        ret = sqlite.sqlite3_reset(self.statement)
+    def _set_params(self, params):
+        ret = sqlite.sqlite3_reset(self._statement)
         if ret != SQLITE_OK:
-            raise self.con._get_exception(ret)
-        self.mark_dirty()
+            raise self.__con._get_exception(ret)
+        self._in_use = True
 
         if params is None:
-            if sqlite.sqlite3_bind_parameter_count(self.statement) != 0:
+            if sqlite.sqlite3_bind_parameter_count(self._statement) != 0:
                 raise ProgrammingError("wrong number of arguments")
             return
 
@@ -1106,14 +1116,14 @@
             params_type = list
 
         if params_type == list:
-            if len(params) != sqlite.sqlite3_bind_parameter_count(self.statement):
+            if len(params) != sqlite.sqlite3_bind_parameter_count(self._statement):
                 raise ProgrammingError("wrong number of arguments")
 
             for i in range(len(params)):
-                self.set_param(i+1, params[i])
+                self.__set_param(i+1, params[i])
         else:
-            for idx in range(1, sqlite.sqlite3_bind_parameter_count(self.statement) + 1):
-                param_name = sqlite.sqlite3_bind_parameter_name(self.statement, idx)
+            for idx in range(1, sqlite.sqlite3_bind_parameter_count(self._statement) + 1):
+                param_name = sqlite.sqlite3_bind_parameter_name(self._statement, idx)
                 if param_name is None:
                     raise ProgrammingError("need named parameters")
                 param_name = param_name[1:]
@@ -1121,92 +1131,73 @@
                     param = params[param_name]
                 except KeyError:
                     raise ProgrammingError("missing parameter '%s'" % param)
-                self.set_param(idx, param)
+                self.__set_param(idx, param)
 
-    def next(self, cursor):
-        self.con._check_closed()
-        self.con._check_thread()
-        if self.exhausted:
+    def _next(self, cursor):
+        self.__con._check_closed()
+        self.__con._check_thread()
+        if self._exhausted:
             raise StopIteration
-        item = self.item
+        item = self._item
 
-        ret = sqlite.sqlite3_step(self.statement)
+        ret = sqlite.sqlite3_step(self._statement)
         if ret == SQLITE_DONE:
-            self.exhausted = True
-            self.item = None
+            self._exhausted = True
+            self._item = None
         elif ret != SQLITE_ROW:
-            exc = self.con._get_exception(ret)
-            sqlite.sqlite3_reset(self.statement)
+            exc = self.__con._get_exception(ret)
+            sqlite.sqlite3_reset(self._statement)
             raise exc
 
         self._readahead(cursor)
         return item
 
     def _readahead(self, cursor):
-        self.column_count = sqlite.sqlite3_column_count(self.statement)
+        self.column_count = sqlite.sqlite3_column_count(self._statement)
         row = []
         for i in xrange(self.column_count):
-            typ = sqlite.sqlite3_column_type(self.statement, i)
+            typ = sqlite.sqlite3_column_type(self._statement, i)
 
-            converter = self.row_cast_map[i]
+            converter = self.__row_cast_map[i]
             if converter is None:
                 if typ == SQLITE_INTEGER:
-                    val = sqlite.sqlite3_column_int64(self.statement, i)
+                    val = sqlite.sqlite3_column_int64(self._statement, i)
                     if -sys.maxint-1 <= val <= sys.maxint:
                         val = int(val)
                 elif typ == SQLITE_FLOAT:
-                    val = sqlite.sqlite3_column_double(self.statement, i)
+                    val = sqlite.sqlite3_column_double(self._statement, i)
                 elif typ == SQLITE_BLOB:
-                    blob_len = sqlite.sqlite3_column_bytes(self.statement, i)
-                    blob = sqlite.sqlite3_column_blob(self.statement, i)
+                    blob_len = sqlite.sqlite3_column_bytes(self._statement, i)
+                    blob = sqlite.sqlite3_column_blob(self._statement, i)
                     val = buffer(string_at(blob, blob_len))
                 elif typ == SQLITE_NULL:
                     val = None
                 elif typ == SQLITE_TEXT:
-                    text_len = sqlite.sqlite3_column_bytes(self.statement, i)
-                    text = sqlite.sqlite3_column_text(self.statement, i)
+                    text_len = sqlite.sqlite3_column_bytes(self._statement, i)
+                    text = sqlite.sqlite3_column_text(self._statement, i)
                     val = string_at(text, text_len)
-                    val = self.con.text_factory(val)
+                    val = self.__con.text_factory(val)
             else:
-                blob = sqlite.sqlite3_column_blob(self.statement, i)
+                blob = sqlite.sqlite3_column_blob(self._statement, i)
                 if not blob:
                     val = None
                 else:
-                    blob_len = sqlite.sqlite3_column_bytes(self.statement, i)
+                    blob_len = sqlite.sqlite3_column_bytes(self._statement, i)
                     val = string_at(blob, blob_len)
                     val = converter(val)
             row.append(val)
 
         row = tuple(row)
-        if self.row_factory is not None:
-            row = self.row_factory(cursor, row)
-        self.item = row
-
-    def reset(self):
-        self.row_cast_map = None
-        ret = sqlite.sqlite3_reset(self.statement)
-        self.in_use = False
-        self.exhausted = False
-        return ret
-
-    def finalize(self):
-        sqlite.sqlite3_finalize(self.statement)
-        self.statement = None
-        self.in_use = False
-
-    def mark_dirty(self):
-        self.in_use = True
-
-    def __del__(self):
-        if self.statement:
-            sqlite.sqlite3_finalize(self.statement)
+        if self._row_factory is not None:
+            row = self._row_factory(cursor, row)
+        self._item = row
 
     def _get_description(self):
-        if self.kind == _DML:
+        if self._kind == Statement._DML:
             return None
         desc = []
-        for i in xrange(sqlite.sqlite3_column_count(self.statement)):
-            name = sqlite.sqlite3_column_name(self.statement, i).split("[")[0].strip()
+        for i in xrange(sqlite.sqlite3_column_count(self._statement)):
+            name = sqlite.sqlite3_column_name(self._statement, i).split("[")[0].strip()
             desc.append((name, None, None, None, None, None, None))
         return desc
 
_______________________________________________
pypy-commit mailing list
pypy-commit@python.org
http://mail.python.org/mailman/listinfo/pypy-commit

Reply via email to