Author: Armin Rigo <[email protected]>
Branch: 
Changeset: r63024:8e8cbf7127f1
Date: 2013-04-04 23:09 +0200
http://bitbucket.org/pypy/pypy/changeset/8e8cbf7127f1/

Log:    merge heads

diff --git a/lib_pypy/_sqlite3.py b/lib_pypy/_sqlite3.py
--- a/lib_pypy/_sqlite3.py
+++ b/lib_pypy/_sqlite3.py
@@ -377,7 +377,7 @@
         self.maxcount = maxcount
         self.cache = OrderedDict()
 
-    def get(self, sql, row_factory):
+    def get(self, sql):
         try:
             stat = self.cache[sql]
         except KeyError:
@@ -389,7 +389,6 @@
             if stat._in_use:
                 stat = Statement(self.connection, sql)
                 self.cache[sql] = stat
-        stat._row_factory = row_factory
         return stat
 
 
@@ -552,7 +551,7 @@
     @_check_thread_wrap
     @_check_closed_wrap
     def __call__(self, sql):
-        return self._statement_cache.get(sql, self.row_factory)
+        return self._statement_cache.get(sql)
 
     def cursor(self, factory=None):
         self._check_thread()
@@ -881,16 +880,96 @@
             return func(self, *args, **kwargs)
         return wrapper
 
+    def __check_reset(self):
+        if self._reset:
+            raise InterfaceError(
+                    "Cursor needed to be reset because of commit/rollback "
+                    "and can no longer be fetched from.")
+
+    def __build_row_cast_map(self):
+        if not self.__connection._detect_types:
+            return
+        self.__row_cast_map = []
+        for i in xrange(_lib.sqlite3_column_count(self.__statement._statement)):
+            converter = None
+
+            if self.__connection._detect_types & PARSE_COLNAMES:
+                colname = _lib.sqlite3_column_name(self.__statement._statement, i)
+                if colname:
+                    colname = _ffi.string(colname).decode('utf-8')
+                    type_start = -1
+                    key = None
+                    for pos in range(len(colname)):
+                        if colname[pos] == '[':
+                            type_start = pos + 1
+                        elif colname[pos] == ']' and type_start != -1:
+                            key = colname[type_start:pos]
+                            converter = converters[key.upper()]
+
+            if converter is None and self.__connection._detect_types & PARSE_DECLTYPES:
+                decltype = _lib.sqlite3_column_decltype(self.__statement._statement, i)
+                if decltype:
+                    decltype = _ffi.string(decltype).decode('utf-8')
+                    # if multiple words, use first, eg.
+                    # "INTEGER NOT NULL" => "INTEGER"
+                    decltype = decltype.split()[0]
+                    if '(' in decltype:
+                        decltype = decltype[:decltype.index('(')]
+                    converter = converters.get(decltype.upper(), None)
+
+            self.__row_cast_map.append(converter)
+
+    def __fetch_one_row(self):
+        row = []
+        num_cols = _lib.sqlite3_data_count(self.__statement._statement)
+        for i in xrange(num_cols):
+            if self.__connection._detect_types:
+                converter = self.__row_cast_map[i]
+            else:
+                converter = None
+
+            if converter is not None:
+                blob = _lib.sqlite3_column_blob(self.__statement._statement, i)
+                if not blob:
+                    val = None
+                else:
+                    blob_len = _lib.sqlite3_column_bytes(self.__statement._statement, i)
+                    val = _ffi.buffer(blob, blob_len)[:]
+                    val = converter(val)
+            else:
+                typ = _lib.sqlite3_column_type(self.__statement._statement, i)
+                if typ == _lib.SQLITE_NULL:
+                    val = None
+                elif typ == _lib.SQLITE_INTEGER:
+                    val = _lib.sqlite3_column_int64(self.__statement._statement, i)
+                    val = int(val)
+                elif typ == _lib.SQLITE_FLOAT:
+                    val = _lib.sqlite3_column_double(self.__statement._statement, i)
+                elif typ == _lib.SQLITE_TEXT:
+                    text = _lib.sqlite3_column_text(self.__statement._statement, i)
+                    text_len = _lib.sqlite3_column_bytes(self.__statement._statement, i)
+                    val = _ffi.buffer(text, text_len)[:]
+                    val = self.__connection.text_factory(val)
+                elif typ == _lib.SQLITE_BLOB:
+                    blob = _lib.sqlite3_column_blob(self.__statement._statement, i)
+                    blob_len = _lib.sqlite3_column_bytes(self.__statement._statement, i)
+                    val = _BLOB_TYPE(_ffi.buffer(blob, blob_len))
+            row.append(val)
+        return tuple(row)
+
     def __execute(self, multiple, sql, many_params):
         self.__locked = True
+        self._reset = False
         try:
-            self._reset = False
+            del self.__next_row
+        except AttributeError:
+            pass
+        try:
             if not isinstance(sql, basestring):
                 raise ValueError("operation parameter must be str or unicode")
             self.__description = None
             self.__rowcount = -1
-            self.__statement = self.__connection._statement_cache.get(
-                sql, self.row_factory)
+            self.__statement = self.__connection._statement_cache.get(sql)
 
             if self.__connection._isolation_level is not None:
                 if self.__statement._kind == Statement._DDL:
@@ -915,9 +994,9 @@
                 if self.__statement._kind == Statement._DML:
                     self.__statement._reset()
 
-                if self.__statement._kind == Statement._DQL and ret == _lib.SQLITE_ROW:
-                    self.__statement._build_row_cast_map()
-                    self.__statement._readahead(self)
+                if ret == _lib.SQLITE_ROW:
+                    self.__build_row_cast_map()
+                    self.__next_row = self.__fetch_one_row()
 
                 if self.__statement._kind == Statement._DML:
                     if self.__rowcount == -1:
@@ -978,12 +1057,6 @@
                 break
         return self
 
-    def __check_reset(self):
-        if self._reset:
-            raise self.__connection.InterfaceError(
-                    "Cursor needed to be reset because of commit/rollback "
-                    "and can no longer be fetched from.")
-
     def __iter__(self):
         return self
 
@@ -992,7 +1065,25 @@
         self.__check_reset()
         if not self.__statement:
             raise StopIteration
-        return self.__statement._next(self)
+
+        try:
+            next_row = self.__next_row
+        except AttributeError:
+            self.__statement._reset()
+            self.__statement = None
+            raise StopIteration
+        del self.__next_row
+
+        if self.row_factory is not None:
+            next_row = self.row_factory(self, next_row)
+
+        ret = _lib.sqlite3_step(self.__statement._statement)
+        if ret not in (_lib.SQLITE_DONE, _lib.SQLITE_ROW):
+            self.__statement._reset()
+            raise self.__connection._get_exception(ret)
+        elif ret == _lib.SQLITE_ROW:
+            self.__next_row = self.__fetch_one_row()
+        return next_row
 
     if sys.version_info[0] < 3:
         next = __next__
@@ -1049,7 +1140,6 @@
         self.__con._remember_statement(self)
 
         self._in_use = False
-        self._row_factory = None
 
         if not isinstance(sql, basestring):
             raise Warning("SQL is of wrong type. Must be string or unicode.")
@@ -1186,98 +1276,6 @@
         else:
             raise ValueError("parameters are of unsupported type")
 
-    def _build_row_cast_map(self):
-        if not self.__con._detect_types:
-            return
-        self.__row_cast_map = []
-        for i in xrange(_lib.sqlite3_column_count(self._statement)):
-            converter = None
-
-            if self.__con._detect_types & PARSE_COLNAMES:
-                colname = _lib.sqlite3_column_name(self._statement, i)
-                if colname:
-                    colname = _ffi.string(colname).decode('utf-8')
-                    type_start = -1
-                    key = None
-                    for pos in range(len(colname)):
-                        if colname[pos] == '[':
-                            type_start = pos + 1
-                        elif colname[pos] == ']' and type_start != -1:
-                            key = colname[type_start:pos]
-                            converter = converters[key.upper()]
-
-            if converter is None and self.__con._detect_types & PARSE_DECLTYPES:
-                decltype = _lib.sqlite3_column_decltype(self._statement, i)
-                if decltype:
-                    decltype = _ffi.string(decltype).decode('utf-8')
-                    # if multiple words, use first, eg.
-                    # "INTEGER NOT NULL" => "INTEGER"
-                    decltype = decltype.split()[0]
-                    if '(' in decltype:
-                        decltype = decltype[:decltype.index('(')]
-                    converter = converters.get(decltype.upper(), None)
-
-            self.__row_cast_map.append(converter)
-
-    def _readahead(self, cursor):
-        row = []
-        num_cols = _lib.sqlite3_data_count(self._statement)
-        for i in xrange(num_cols):
-            if self.__con._detect_types:
-                converter = self.__row_cast_map[i]
-            else:
-                converter = None
-
-            if converter is not None:
-                blob = _lib.sqlite3_column_blob(self._statement, i)
-                if not blob:
-                    val = None
-                else:
-                    blob_len = _lib.sqlite3_column_bytes(self._statement, i)
-                    val = _ffi.buffer(blob, blob_len)[:]
-                    val = converter(val)
-            else:
-                typ = _lib.sqlite3_column_type(self._statement, i)
-                if typ == _lib.SQLITE_NULL:
-                    val = None
-                elif typ == _lib.SQLITE_INTEGER:
-                    val = _lib.sqlite3_column_int64(self._statement, i)
-                    val = int(val)
-                elif typ == _lib.SQLITE_FLOAT:
-                    val = _lib.sqlite3_column_double(self._statement, i)
-                elif typ == _lib.SQLITE_TEXT:
-                    text = _lib.sqlite3_column_text(self._statement, i)
-                    text_len = _lib.sqlite3_column_bytes(self._statement, i)
-                    val = _ffi.buffer(text, text_len)[:]
-                    val = self.__con.text_factory(val)
-                elif typ == _lib.SQLITE_BLOB:
-                    blob = _lib.sqlite3_column_blob(self._statement, i)
-                    blob_len = _lib.sqlite3_column_bytes(self._statement, i)
-                    val = _BLOB_TYPE(_ffi.buffer(blob, blob_len))
-            row.append(val)
-
-        row = tuple(row)
-        if self._row_factory is not None:
-            row = self._row_factory(cursor, row)
-        self._item = row
-
-    def _next(self, cursor):
-        try:
-            item = self._item
-        except AttributeError:
-            self._reset()
-            raise StopIteration
-        del self._item
-
-        ret = _lib.sqlite3_step(self._statement)
-        if ret not in (_lib.SQLITE_DONE, _lib.SQLITE_ROW):
-            _lib.sqlite3_reset(self._statement)
-            raise self.__con._get_exception(ret)
-        elif ret == _lib.SQLITE_ROW:
-            self._readahead(cursor)
-
-        return item
-
     def _get_description(self):
         if self._kind == Statement._DML:
             return None
diff --git a/pypy/module/micronumpy/arrayimpl/concrete.py b/pypy/module/micronumpy/arrayimpl/concrete.py
--- a/pypy/module/micronumpy/arrayimpl/concrete.py
+++ b/pypy/module/micronumpy/arrayimpl/concrete.py
@@ -281,7 +281,7 @@
 
     def astype(self, space, dtype):
         new_arr = W_NDimArray.from_shape(self.get_shape(), dtype)
-        if dtype.is_str_or_unicode():
+        if self.dtype.is_str_or_unicode() and not dtype.is_str_or_unicode():
             raise OperationError(space.w_NotImplementedError, space.wrap(
                 "astype(%s) not implemented yet" % self.dtype))
         else:
diff --git a/pypy/module/micronumpy/interp_support.py b/pypy/module/micronumpy/interp_support.py
--- a/pypy/module/micronumpy/interp_support.py
+++ b/pypy/module/micronumpy/interp_support.py
@@ -15,7 +15,7 @@
     items = []
     num_items = 0
     idx = 0
-    
+
     while (num_items < count or count == -1) and idx < len(s):
         nextidx = s.find(sep, idx)
         if nextidx < 0:
@@ -45,7 +45,7 @@
             items.append(val)
             num_items += 1
         idx = nextidx + 1
-    
+
     if count > num_items:
         raise OperationError(space.w_ValueError, space.wrap(
             "string is smaller than requested size"))
@@ -70,7 +70,7 @@
     if count * itemsize > length:
         raise OperationError(space.w_ValueError, space.wrap(
             "string is smaller than requested size"))
-        
+
     a = W_NDimArray.from_shape([count], dtype=dtype)
     loop.fromstring_loop(a, dtype, itemsize, s)
     return space.wrap(a)
diff --git a/pypy/module/micronumpy/test/test_numarray.py b/pypy/module/micronumpy/test/test_numarray.py
--- a/pypy/module/micronumpy/test/test_numarray.py
+++ b/pypy/module/micronumpy/test/test_numarray.py
@@ -1487,14 +1487,14 @@
         a = concatenate((['abcdef'], ['abc']))
         assert a[0] == 'abcdef'
         assert str(a.dtype) == '|S6'
-    
+
     def test_record_concatenate(self):
         # only an exact match can succeed
         from numpypy import zeros, concatenate
         a = concatenate((zeros((2,),dtype=[('x', int), ('y', float)]),
                          zeros((2,),dtype=[('x', int), ('y', float)])))
         assert a.shape == (4,)
-        exc = raises(TypeError, concatenate, 
+        exc = raises(TypeError, concatenate,
                             (zeros((2,), dtype=[('x', int), ('y', float)]),
                             (zeros((2,), dtype=[('x', float), ('y', float)]))))
         assert str(exc.value).startswith('record type mismatch')
@@ -1677,11 +1677,15 @@
         a = array('x').astype('S3').dtype
         assert a.itemsize == 3
         # scalar vs. array
+        a = array([1, 2, 3.14156]).astype('S3').dtype
+        assert a.itemsize == 3
+        a = array(3.1415).astype('S3').dtype
+        assert a.itemsize == 3
         try:
-            a = array([1, 2, 3.14156]).astype('S3').dtype
-            assert a.itemsize == 3
+            a = array(['1', '2','3']).astype(float)
+            assert a[2] == 3.0
         except NotImplementedError:
-            skip('astype("S3") not implemented for numeric arrays')
+            skip('astype("float") not implemented for str arrays')
 
     def test_base(self):
         from numpypy import array
diff --git a/pypy/module/test_lib_pypy/test_sqlite3.py b/pypy/module/test_lib_pypy/test_sqlite3.py
--- a/pypy/module/test_lib_pypy/test_sqlite3.py
+++ b/pypy/module/test_lib_pypy/test_sqlite3.py
@@ -193,3 +193,8 @@
     con.commit()
     con.execute('BEGIN')
     con.commit()
+
+def test_row_factory_use():
+    con = _sqlite3.connect(':memory:')
+    con.row_factory = 42
+    con.execute('select 1')
_______________________________________________
pypy-commit mailing list
[email protected]
http://mail.python.org/mailman/listinfo/pypy-commit

Reply via email to