Author: Carl Friedrich Bolz <cfb...@gmx.de>
Branch: 
Changeset: r80539:8c1d8f7986a2
Date: 2015-11-04 23:36 +0100
http://bitbucket.org/pypy/pypy/changeset/8c1d8f7986a2/

Log:    a much faster implementation of enumerate

diff --git a/pypy/module/__builtin__/functional.py b/pypy/module/__builtin__/functional.py
--- a/pypy/module/__builtin__/functional.py
+++ b/pypy/module/__builtin__/functional.py
@@ -8,7 +8,7 @@
 from pypy.interpreter.error import OperationError
 from pypy.interpreter.gateway import interp2app, unwrap_spec, WrappedDefault
 from pypy.interpreter.typedef import TypeDef
-from rpython.rlib import jit
+from rpython.rlib import jit, rarithmetic
 from rpython.rlib.objectmodel import specialize
 from rpython.rlib.rarithmetic import r_uint, intmask
 from rpython.rlib.rbigint import rbigint
@@ -229,10 +229,22 @@
     return min_max(space, __args__, "min")
 
 
+
 class W_Enumerate(W_Root):
-    def __init__(self, w_iter, w_start):
-        self.w_iter = w_iter
-        self.w_index = w_start
+    def __init__(self, space, w_iterable, w_start):
+        from pypy.objspace.std.listobject import W_ListObject
+        w_iter = space.iter(w_iterable)
+        if space.is_w(space.type(w_start), space.w_int):
+            self.index = space.int_w(w_start)
+            self.w_index = None
+            if self.index == 0 and type(w_iterable) is W_ListObject:
+                w_iter = w_iterable
+        else:
+            self.index = -1
+            self.w_index = w_start
+        self.w_iter_or_list = w_iter
+        if self.w_index is not None:
+            assert not type(self.w_iter_or_list) is W_ListObject
 
     def descr___new__(space, w_subtype, w_iterable, w_start=None):
         self = space.allocate_instance(W_Enumerate, w_subtype)
@@ -240,16 +252,42 @@
             w_start = space.wrap(0)
         else:
             w_start = space.index(w_start)
-        self.__init__(space.iter(w_iterable), w_start)
+        self.__init__(space, w_iterable, w_start)
         return space.wrap(self)
 
     def descr___iter__(self, space):
         return space.wrap(self)
 
     def descr_next(self, space):
-        w_item = space.next(self.w_iter)
+        from pypy.objspace.std.listobject import W_ListObject
         w_index = self.w_index
-        self.w_index = space.add(w_index, space.wrap(1))
+        w_iter_or_list = self.w_iter_or_list
+        w_item = None
+        if w_index is None:
+            index = self.index
+            if type(w_iter_or_list) is W_ListObject:
+                try:
+                    w_item = w_iter_or_list.getitem(index)
+                except IndexError:
+                    self.w_iter_or_list = None
+                    raise OperationError(space.w_StopIteration, space.w_None)
+                self.index = index + 1
+            elif w_iter_or_list is None:
+                raise OperationError(space.w_StopIteration, space.w_None)
+            else:
+                try:
+                    newval = rarithmetic.ovfcheck(index + 1)
+                except OverflowError:
+                    w_index = space.wrap(index)
+                    self.w_index = space.add(w_index, space.wrap(1))
+                    self.index = -1
+                else:
+                    self.index = newval
+            w_index = space.wrap(index)
+        else:
+            self.w_index = space.add(w_index, space.wrap(1))
+        if w_item is None:
+            w_item = space.next(self.w_iter_or_list)
         return space.newtuple([w_index, w_item])
 
     def descr___reduce__(self, space):
@@ -257,12 +295,17 @@
         w_mod    = space.getbuiltinmodule('_pickle_support')
         mod      = space.interp_w(MixedModule, w_mod)
         w_new_inst = mod.get('enumerate_new')
-        w_info = space.newtuple([self.w_iter, self.w_index])
+        w_index = self.w_index
+        if w_index is None:
+            w_index = space.wrap(self.index)
+        else:
+            w_index = self.w_index
+        w_info = space.newtuple([self.w_iter_or_list, w_index])
         return space.newtuple([w_new_inst, w_info])
 
 # exported through _pickle_support
 def _make_enumerate(space, w_iter, w_index):
-    return space.wrap(W_Enumerate(w_iter, w_index))
+    return space.wrap(W_Enumerate(space, w_iter, w_index))
 
 W_Enumerate.typedef = TypeDef("enumerate",
     __new__=interp2app(W_Enumerate.descr___new__.im_func),
diff --git a/pypy/module/__builtin__/test/test_builtin.py b/pypy/module/__builtin__/test/test_builtin.py
--- a/pypy/module/__builtin__/test/test_builtin.py
+++ b/pypy/module/__builtin__/test/test_builtin.py
@@ -264,6 +264,7 @@
         raises(StopIteration,x.next)
 
     def test_enumerate(self):
+        import sys
         seq = range(2,4)
         enum = enumerate(seq)
         assert enum.next() == (0, 2)
@@ -274,6 +275,15 @@
         enum = enumerate(range(5), 2)
         assert list(enum) == zip(range(2, 7), range(5))
 
+        enum = enumerate(range(2), 2**100)
+        assert list(enum) == [(2**100, 0), (2**100+1, 1)]
+
+        enum = enumerate(range(2), sys.maxint)
+        assert list(enum) == [(sys.maxint, 0), (sys.maxint+1, 1)]
+
+        raises(TypeError, enumerate, range(2), 5.5)
+
+
     def test_next(self):
         x = iter(['a', 'b', 'c'])
         assert next(x) == 'a'
diff --git a/pypy/module/pypyjit/test_pypy_c/test_containers.py b/pypy/module/pypyjit/test_pypy_c/test_containers.py
--- a/pypy/module/pypyjit/test_pypy_c/test_containers.py
+++ b/pypy/module/pypyjit/test_pypy_c/test_containers.py
@@ -248,3 +248,23 @@
         loop, = log.loops_by_filename(self.filepath)
         ops = loop.ops_by_id('getitem', include_guard_not_invalidated=False)
         assert log.opnames(ops) == []
+
+    def test_enumerate_list(self):
+        def main(n):
+            for a, b in enumerate([1, 2] * 1000):
+                a + b
+
+        log = self.run(main, [1000])
+        loop, = log.loops_by_filename(self.filepath)
+        opnames = log.opnames(loop.allops())
+        assert opnames.count('new_with_vtable') == 0
+
+    def test_enumerate(self):
+        def main(n):
+            for a, b in enumerate("abc" * 1000):
+                a + ord(b)
+
+        log = self.run(main, [1000])
+        loop, = log.loops_by_filename(self.filepath)
+        opnames = log.opnames(loop.allops())
+        assert opnames.count('new_with_vtable') == 0
_______________________________________________
pypy-commit mailing list
pypy-commit@python.org
https://mail.python.org/mailman/listinfo/pypy-commit

Reply via email to