https://github.com/python/cpython/commit/909c6f718913e713c990d69e6d8a74c05f81e2c2
commit: 909c6f718913e713c990d69e6d8a74c05f81e2c2
branch: main
author: Raymond Hettinger <[email protected]>
committer: rhettinger <[email protected]>
date: 2024-09-25T13:38:05-07:00
summary:

gh-123884 Tee of tee was not producing n independent iterators (gh-124490)

files:
A Misc/NEWS.d/next/Library/2024-09-24-22-38-51.gh-issue-123884.iEPTK4.rst
M Doc/library/itertools.rst
M Include/internal/pycore_global_objects_fini_generated.h
M Include/internal/pycore_global_strings.h
M Include/internal/pycore_runtime_init_generated.h
M Include/internal/pycore_unicodeobject_generated.h
M Lib/test/test_itertools.py
M Modules/itertoolsmodule.c

diff --git a/Doc/library/itertools.rst b/Doc/library/itertools.rst
index 508c20f4df6f5e..c1299ebfe8d27a 100644
--- a/Doc/library/itertools.rst
+++ b/Doc/library/itertools.rst
@@ -691,25 +691,36 @@ loops that truncate the stream.
 
         def tee(iterable, n=2):
             if n < 0:
-                raise ValueError('n must be >= 0')
-            iterator = iter(iterable)
-            shared_link = [None, None]
-            return tuple(_tee(iterator, shared_link) for _ in range(n))
-
-        def _tee(iterator, link):
-            try:
-                while True:
-                    if link[1] is None:
-                        link[0] = next(iterator)
-                        link[1] = [None, None]
-                    value, link = link
-                    yield value
-            except StopIteration:
-                return
-
-   Once a :func:`tee` has been created, the original *iterable* should not be
-   used anywhere else; otherwise, the *iterable* could get advanced without
-   the tee objects being informed.
+                raise ValueError
+            if n == 0:
+                return ()
+            iterator = _tee(iterable)
+            result = [iterator]
+            for _ in range(n - 1):
+                result.append(_tee(iterator))
+            return tuple(result)
+
+        class _tee:
+
+            def __init__(self, iterable):
+                it = iter(iterable)
+                if isinstance(it, _tee):
+                    self.iterator = it.iterator
+                    self.link = it.link
+                else:
+                    self.iterator = it
+                    self.link = [None, None]
+
+            def __iter__(self):
+                return self
+
+            def __next__(self):
+                link = self.link
+                if link[1] is None:
+                    link[0] = next(self.iterator)
+                    link[1] = [None, None]
+                value, self.link = link
+                return value
 
    When the input *iterable* is already a tee iterator object, all
    members of the return tuple are constructed as if they had been
diff --git a/Include/internal/pycore_global_objects_fini_generated.h b/Include/internal/pycore_global_objects_fini_generated.h
index 6e948e16b7dbe8..a5f13692c627e4 100644
--- a/Include/internal/pycore_global_objects_fini_generated.h
+++ b/Include/internal/pycore_global_objects_fini_generated.h
@@ -604,7 +604,6 @@ _PyStaticObjects_CheckRefcnt(PyInterpreterState *interp) {
     _PyStaticObject_CheckRefcnt((PyObject *)&_Py_ID(__classdictcell__));
     _PyStaticObject_CheckRefcnt((PyObject *)&_Py_ID(__complex__));
     _PyStaticObject_CheckRefcnt((PyObject *)&_Py_ID(__contains__));
-    _PyStaticObject_CheckRefcnt((PyObject *)&_Py_ID(__copy__));
     _PyStaticObject_CheckRefcnt((PyObject *)&_Py_ID(__ctypes_from_outparam__));
     _PyStaticObject_CheckRefcnt((PyObject *)&_Py_ID(__del__));
     _PyStaticObject_CheckRefcnt((PyObject *)&_Py_ID(__delattr__));
diff --git a/Include/internal/pycore_global_strings.h b/Include/internal/pycore_global_strings.h
index 5c63a6e519b93d..dd958dcd9fe3ac 100644
--- a/Include/internal/pycore_global_strings.h
+++ b/Include/internal/pycore_global_strings.h
@@ -93,7 +93,6 @@ struct _Py_global_strings {
         STRUCT_FOR_ID(__classdictcell__)
         STRUCT_FOR_ID(__complex__)
         STRUCT_FOR_ID(__contains__)
-        STRUCT_FOR_ID(__copy__)
         STRUCT_FOR_ID(__ctypes_from_outparam__)
         STRUCT_FOR_ID(__del__)
         STRUCT_FOR_ID(__delattr__)
diff --git a/Include/internal/pycore_runtime_init_generated.h b/Include/internal/pycore_runtime_init_generated.h
index bac6b5b8fcfd9d..8d7da8befc94c1 100644
--- a/Include/internal/pycore_runtime_init_generated.h
+++ b/Include/internal/pycore_runtime_init_generated.h
@@ -602,7 +602,6 @@ extern "C" {
     INIT_ID(__classdictcell__), \
     INIT_ID(__complex__), \
     INIT_ID(__contains__), \
-    INIT_ID(__copy__), \
     INIT_ID(__ctypes_from_outparam__), \
     INIT_ID(__del__), \
     INIT_ID(__delattr__), \
diff --git a/Include/internal/pycore_unicodeobject_generated.h b/Include/internal/pycore_unicodeobject_generated.h
index efdbde4c8ea3c6..d5ad61d4c62e22 100644
--- a/Include/internal/pycore_unicodeobject_generated.h
+++ b/Include/internal/pycore_unicodeobject_generated.h
@@ -172,10 +172,6 @@ _PyUnicode_InitStaticStrings(PyInterpreterState *interp) {
     _PyUnicode_InternStatic(interp, &string);
     assert(_PyUnicode_CheckConsistency(string, 1));
     assert(PyUnicode_GET_LENGTH(string) != 1);
-    string = &_Py_ID(__copy__);
-    _PyUnicode_InternStatic(interp, &string);
-    assert(_PyUnicode_CheckConsistency(string, 1));
-    assert(PyUnicode_GET_LENGTH(string) != 1);
     string = &_Py_ID(__ctypes_from_outparam__);
     _PyUnicode_InternStatic(interp, &string);
     assert(_PyUnicode_CheckConsistency(string, 1));
diff --git a/Lib/test/test_itertools.py b/Lib/test/test_itertools.py
index 6820dce3f12620..8469de998ba014 100644
--- a/Lib/test/test_itertools.py
+++ b/Lib/test/test_itertools.py
@@ -1249,10 +1249,11 @@ def test_tee(self):
             self.assertEqual(len(result), n)
             self.assertEqual([list(x) for x in result], [list('abc')]*n)
 
-        # tee pass-through to copyable iterator
+        # tee objects are independent (see bug gh-123884)
         a, b = tee('abc')
         c, d = tee(a)
-        self.assertTrue(a is c)
+        e, f = tee(c)
+        self.assertTrue(len({a, b, c, d, e, f}) == 6)
 
         # test tee_new
         t1, t2 = tee('abc')
@@ -1759,21 +1760,36 @@ def test_tee_recipe(self):
 
         def tee(iterable, n=2):
             if n < 0:
-                raise ValueError('n must be >= 0')
-            iterator = iter(iterable)
-            shared_link = [None, None]
-            return tuple(_tee(iterator, shared_link) for _ in range(n))
+                raise ValueError
+            if n == 0:
+                return ()
+            iterator = _tee(iterable)
+            result = [iterator]
+            for _ in range(n - 1):
+                result.append(_tee(iterator))
+            return tuple(result)
+
+        class _tee:
+
+            def __init__(self, iterable):
+                it = iter(iterable)
+                if isinstance(it, _tee):
+                    self.iterator = it.iterator
+                    self.link = it.link
+                else:
+                    self.iterator = it
+                    self.link = [None, None]
 
-        def _tee(iterator, link):
-            try:
-                while True:
-                    if link[1] is None:
-                        link[0] = next(iterator)
-                        link[1] = [None, None]
-                    value, link = link
-                    yield value
-            except StopIteration:
-                return
+            def __iter__(self):
+                return self
+
+            def __next__(self):
+                link = self.link
+                if link[1] is None:
+                    link[0] = next(self.iterator)
+                    link[1] = [None, None]
+                value, self.link = link
+                return value
 
         # End tee() recipe #############################################
 
@@ -1819,12 +1835,10 @@ def _tee(iterator, link):
         self.assertRaises(TypeError, tee, [1,2], 'x')
         self.assertRaises(TypeError, tee, [1,2], 3, 'x')
 
-        # Tests not applicable to the tee() recipe
-        if False:
-            # tee object should be instantiable
-            a, b = tee('abc')
-            c = type(a)('def')
-            self.assertEqual(list(c), list('def'))
+        # tee object should be instantiable
+        a, b = tee('abc')
+        c = type(a)('def')
+        self.assertEqual(list(c), list('def'))
 
         # test long-lagged and multi-way split
         a, b, c = tee(range(2000), 3)
@@ -1845,21 +1859,19 @@ def _tee(iterator, link):
             self.assertEqual(len(result), n)
             self.assertEqual([list(x) for x in result], [list('abc')]*n)
 
+        # tee objects are independent (see bug gh-123884)
+        a, b = tee('abc')
+        c, d = tee(a)
+        e, f = tee(c)
+        self.assertTrue(len({a, b, c, d, e, f}) == 6)
 
-        # Tests not applicable to the tee() recipe
-        if False:
-            # tee pass-through to copyable iterator
-            a, b = tee('abc')
-            c, d = tee(a)
-            self.assertTrue(a is c)
-
-            # test tee_new
-            t1, t2 = tee('abc')
-            tnew = type(t1)
-            self.assertRaises(TypeError, tnew)
-            self.assertRaises(TypeError, tnew, 10)
-            t3 = tnew(t1)
-            self.assertTrue(list(t1) == list(t2) == list(t3) == list('abc'))
+        # test tee_new
+        t1, t2 = tee('abc')
+        tnew = type(t1)
+        self.assertRaises(TypeError, tnew)
+        self.assertRaises(TypeError, tnew, 10)
+        t3 = tnew(t1)
+        self.assertTrue(list(t1) == list(t2) == list(t3) == list('abc'))
 
         # test that tee objects are weak referencable
         a, b = tee(range(10))
diff --git a/Misc/NEWS.d/next/Library/2024-09-24-22-38-51.gh-issue-123884.iEPTK4.rst b/Misc/NEWS.d/next/Library/2024-09-24-22-38-51.gh-issue-123884.iEPTK4.rst
new file mode 100644
index 00000000000000..55f1d4b41125c3
--- /dev/null
+++ b/Misc/NEWS.d/next/Library/2024-09-24-22-38-51.gh-issue-123884.iEPTK4.rst
@@ -0,0 +1,4 @@
+Fixed bug in itertools.tee() handling of other tee inputs (a tee in a tee).
+The output now has the promised *n* independent new iterators.  Formerly,
+the first iterator was identical (not independent) to the input iterator.
+This would sometimes give surprising results.
diff --git a/Modules/itertoolsmodule.c b/Modules/itertoolsmodule.c
index e740ec4d7625c3..1201fa094902d7 100644
--- a/Modules/itertoolsmodule.c
+++ b/Modules/itertoolsmodule.c
@@ -1036,7 +1036,7 @@ itertools_tee_impl(PyObject *module, PyObject *iterable, Py_ssize_t n)
 /*[clinic end generated code: output=1c64519cd859c2f0 input=c99a1472c425d66d]*/
 {
     Py_ssize_t i;
-    PyObject *it, *copyable, *copyfunc, *result;
+    PyObject *it, *to, *result;
 
     if (n < 0) {
         PyErr_SetString(PyExc_ValueError, "n must be >= 0");
@@ -1053,41 +1053,23 @@ itertools_tee_impl(PyObject *module, PyObject *iterable, Py_ssize_t n)
         return NULL;
     }
 
-    if (PyObject_GetOptionalAttr(it, &_Py_ID(__copy__), &copyfunc) < 0) {
-        Py_DECREF(it);
+    itertools_state *state = get_module_state(module);
+    to = tee_fromiterable(state, it);
+    Py_DECREF(it);
+    if (to == NULL) {
         Py_DECREF(result);
         return NULL;
     }
-    if (copyfunc != NULL) {
-        copyable = it;
-    }
-    else {
-        itertools_state *state = get_module_state(module);
-        copyable = tee_fromiterable(state, it);
-        Py_DECREF(it);
-        if (copyable == NULL) {
-            Py_DECREF(result);
-            return NULL;
-        }
-        copyfunc = PyObject_GetAttr(copyable, &_Py_ID(__copy__));
-        if (copyfunc == NULL) {
-            Py_DECREF(copyable);
-            Py_DECREF(result);
-            return NULL;
-        }
-    }
 
-    PyTuple_SET_ITEM(result, 0, copyable);
+    PyTuple_SET_ITEM(result, 0, to);
     for (i = 1; i < n; i++) {
-        copyable = _PyObject_CallNoArgs(copyfunc);
-        if (copyable == NULL) {
-            Py_DECREF(copyfunc);
+        to = tee_copy((teeobject *)to, NULL);
+        if (to == NULL) {
             Py_DECREF(result);
             return NULL;
         }
-        PyTuple_SET_ITEM(result, i, copyable);
+        PyTuple_SET_ITEM(result, i, to);
     }
-    Py_DECREF(copyfunc);
     return result;
 }
 

_______________________________________________
Python-checkins mailing list -- [email protected]
To unsubscribe send an email to [email protected]
https://mail.python.org/mailman3/lists/python-checkins.python.org/
Member address: [email protected]

Reply via email to