This is an automated email from the ASF dual-hosted git repository.

dianfu pushed a commit to branch release-1.17
in repository https://gitbox.apache.org/repos/asf/flink.git


The following commit(s) were added to refs/heads/release-1.17 by this push:
     new ec5a09b3ce5 [FLINK-31478][python] Fix ds.execute_and_collect to support nested tuple
ec5a09b3ce5 is described below

commit ec5a09b3ce56426d1bdc8eeac4bf52cac9be015b
Author: Dian Fu <dia...@apache.org>
AuthorDate: Thu Mar 16 14:31:23 2023 +0800

    [FLINK-31478][python] Fix ds.execute_and_collect to support nested tuple
    
    This closes #22190.
---
 flink-python/pyflink/common/types.py               |  2 +-
 .../pyflink/datastream/tests/test_data_stream.py   | 16 ++++-
 flink-python/pyflink/datastream/utils.py           | 81 +++++++++++-----------
 .../table/tests/test_table_environment_api.py      |  3 +-
 4 files changed, 55 insertions(+), 47 deletions(-)
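
As a rough usage sketch (illustrative, not part of the commit; adapted from the
test added below), the fix targets collecting results whose type nests a tuple,
for example a KeyedStream built over a TUPLE-typed source:

    from pyflink.common.typeinfo import Types
    from pyflink.datastream import StreamExecutionEnvironment

    env = StreamExecutionEnvironment.get_execution_environment()
    ds = env.from_collection(
        [('pyflink', 1), ('datastream', 2)],
        type_info=Types.TUPLE([Types.STRING(), Types.INT()]))
    # Per the test below, records collected from a KeyedStream take the form
    # Row(f0=key, f1=value), so each result contains a tuple nested in a Row.
    with ds.key_by(lambda i: i[0], Types.STRING()).execute_and_collect() as results:
        for result in results:
            print(result)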

diff --git a/flink-python/pyflink/common/types.py b/flink-python/pyflink/common/types.py
index eeb68272402..d3151bdafe3 100644
--- a/flink-python/pyflink/common/types.py
+++ b/flink-python/pyflink/common/types.py
@@ -251,7 +251,7 @@ class Row(object):
             return "Row(%s)" % ", ".join("%s=%r" % (k, v)
                                          for k, v in zip(self._fields, tuple(self)))
         else:
-            return "<Row(%s)>" % ", ".join("%r" % field for field in self)
+            return "<Row(%s)>" % ", ".join(repr(field) for field in self)
 
     def __eq__(self, other):
         if not isinstance(other, Row):
diff --git a/flink-python/pyflink/datastream/tests/test_data_stream.py b/flink-python/pyflink/datastream/tests/test_data_stream.py
index 8d4e0bea1b8..64f18f5cc83 100644
--- a/flink-python/pyflink/datastream/tests/test_data_stream.py
+++ b/flink-python/pyflink/datastream/tests/test_data_stream.py
@@ -1545,20 +1545,28 @@ class CommonDataStreamTests(PyFlinkTestCase):
         test_data = ['pyflink', 'datastream', 'execute', 'collect']
         ds = self.env.from_collection(test_data)
 
+        # test collect with limit
         expected = test_data[:3]
         actual = []
         for result in ds.execute_and_collect(limit=3):
             actual.append(result)
         self.assertEqual(expected, actual)
 
-        expected = test_data
-        ds = self.env.from_collection(collection=test_data, type_info=Types.STRING())
-        with ds.execute_and_collect() as results:
+        # test collect KeyedStream
+        test_data = [('pyflink', 1), ('datastream', 2), ('pyflink', 1), ('collect', 2)]
+        expected = [Row(f0='pyflink', f1=('pyflink', 1)),
+                    Row(f0='datastream', f1=('datastream', 2)),
+                    Row(f0='pyflink', f1=('pyflink', 1)),
+                    Row(f0='collect', f1=('collect', 2))]
+        ds = self.env.from_collection(collection=test_data,
+                                      type_info=Types.TUPLE([Types.STRING(), Types.INT()]))
+        with ds.key_by(lambda i: i[0], Types.STRING()).execute_and_collect() as results:
             actual = []
             for result in results:
                 actual.append(result)
             self.assertEqual(expected, actual)
 
+        # test all kinds of data types
         test_data = [(1, None, 1, True, 32767, -2147483648, 1.23, 1.98932,
                       bytearray(b'flink'), 'pyflink',
                       datetime.date(2014, 9, 13),
@@ -1585,6 +1593,7 @@ class CommonDataStreamTests(PyFlinkTestCase):
             actual = [result for result in results]
             self.assert_equals_sorted(expected, actual)
 
+        # test primitive array
         test_data = [[1, 2, 3], [4, 5]]
         expected = test_data
         ds = self.env.from_collection(test_data, type_info=Types.PRIMITIVE_ARRAY(Types.INT()))
@@ -1597,6 +1606,7 @@ class CommonDataStreamTests(PyFlinkTestCase):
             ([None, ], [0.0, 0.0])
         ]
 
+        # test object array
         ds = self.env.from_collection(
             test_data,
             type_info=Types.TUPLE(
diff --git a/flink-python/pyflink/datastream/utils.py b/flink-python/pyflink/datastream/utils.py
index a88170996ee..2579e761e39 100644
--- a/flink-python/pyflink/datastream/utils.py
+++ b/flink-python/pyflink/datastream/utils.py
@@ -57,68 +57,65 @@ def convert_to_python_obj(data, type_info):
         return convert_to_python_obj(data, type_info._type_info)
     else:
         gateway = get_gateway()
-        pickle_bytes = gateway.jvm.PythonBridgeUtils. \
+        pickled_bytes = gateway.jvm.PythonBridgeUtils. \
             getPickledBytesFromJavaObject(data, type_info.get_java_type_info())
-        if isinstance(type_info, RowTypeInfo) or isinstance(type_info, TupleTypeInfo):
-            if isinstance(type_info, RowTypeInfo):
-                field_data = zip(list(pickle_bytes[1:]), type_info.get_field_types())
-            else:
-                field_data = zip(pickle_bytes, type_info.get_field_types())
-            fields = []
-            for data, field_type in field_data:
-                if len(data) == 0:
-                    fields.append(None)
-                else:
-                    fields.append(pickled_bytes_to_python_converter(data, field_type))
-            if isinstance(type_info, RowTypeInfo):
-                return Row.of_kind(RowKind(int.from_bytes(pickle_bytes[0], 'little')), *fields)
-            else:
-                return tuple(fields)
-        else:
-            return pickled_bytes_to_python_converter(pickle_bytes, type_info)
+        return pickled_bytes_to_python_obj(pickled_bytes, type_info)
 
 
-def pickled_bytes_to_python_converter(data, field_type):
-    if isinstance(field_type, RowTypeInfo):
+def pickled_bytes_to_python_obj(data, type_info):
+    if isinstance(type_info, RowTypeInfo):
         row_kind = RowKind(int.from_bytes(data[0], 'little'))
-        data = zip(list(data[1:]), field_type.get_field_types())
+        field_data_with_types = zip(list(data[1:]), type_info.get_field_types())
         fields = []
-        for d, d_type in data:
-            fields.append(pickled_bytes_to_python_converter(d, d_type))
+        for field_data, field_type in field_data_with_types:
+            if len(field_data) == 0:
+                fields.append(None)
+            else:
+                fields.append(pickled_bytes_to_python_obj(field_data, field_type))
         row = Row.of_kind(row_kind, *fields)
+        row.set_field_names(type_info.get_field_names())
         return row
+    elif isinstance(type_info, TupleTypeInfo):
+        field_data_with_types = zip(data, type_info.get_field_types())
+        fields = []
+        for field_data, field_type in field_data_with_types:
+            if len(field_data) == 0:
+                fields.append(None)
+            else:
+                fields.append(pickled_bytes_to_python_obj(field_data, field_type))
+        return tuple(fields)
     else:
         data = pickle.loads(data)
-        if field_type == Types.SQL_TIME():
+        if type_info == Types.SQL_TIME():
             seconds, microseconds = divmod(data, 10 ** 6)
             minutes, seconds = divmod(seconds, 60)
             hours, minutes = divmod(minutes, 60)
             return datetime.time(hours, minutes, seconds, microseconds)
-        elif field_type == Types.SQL_DATE():
-            return field_type.from_internal_type(data)
-        elif field_type == Types.SQL_TIMESTAMP():
-            return field_type.from_internal_type(int(data.timestamp() * 10 ** 6))
-        elif field_type == Types.FLOAT():
-            return field_type.from_internal_type(ast.literal_eval(data))
-        elif isinstance(field_type,
+        elif type_info == Types.SQL_DATE():
+            return type_info.from_internal_type(data)
+        elif type_info == Types.SQL_TIMESTAMP():
+            return type_info.from_internal_type(int(data.timestamp() * 10 ** 6))
+        elif type_info == Types.FLOAT():
+            return type_info.from_internal_type(ast.literal_eval(data))
+        elif isinstance(type_info,
                         (BasicArrayTypeInfo, PrimitiveArrayTypeInfo, ObjectArrayTypeInfo)):
-            element_type = field_type._element_type
+            element_type = type_info._element_type
             elements = []
             for element_bytes in data:
-                elements.append(pickled_bytes_to_python_converter(element_bytes, element_type))
+                elements.append(pickled_bytes_to_python_obj(element_bytes, element_type))
             return elements
-        elif isinstance(field_type, MapTypeInfo):
-            key_type = field_type._key_type_info
-            value_type = field_type._value_type_info
+        elif isinstance(type_info, MapTypeInfo):
+            key_type = type_info._key_type_info
+            value_type = type_info._value_type_info
             zip_kv = zip(data[0], data[1])
-            return dict((pickled_bytes_to_python_converter(k, key_type),
-                         pickled_bytes_to_python_converter(v, value_type))
+            return dict((pickled_bytes_to_python_obj(k, key_type),
+                         pickled_bytes_to_python_obj(v, value_type))
                         for k, v in zip_kv)
-        elif isinstance(field_type, ListTypeInfo):
-            element_type = field_type.elem_type
+        elif isinstance(type_info, ListTypeInfo):
+            element_type = type_info.elem_type
             elements = []
             for element_bytes in data:
-                elements.append(pickled_bytes_to_python_converter(element_bytes, element_type))
+                elements.append(pickled_bytes_to_python_obj(element_bytes, element_type))
             return elements
         else:
-            return field_type.from_internal_type(data)
+            return type_info.from_internal_type(data)
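
For intuition on the refactoring above (a sketch, not part of the commit): with a
KeyedStream over a TUPLE-typed source, the collected record type is roughly a Row
whose second field is itself a tuple, so the conversion must recurse into
TupleTypeInfo rather than fall through to from_internal_type:

    from pyflink.common.typeinfo import Types
    from pyflink.common.types import Row

    # Illustrative type shape for the KeyedStream test added above
    nested_type = Types.ROW([Types.STRING(),
                             Types.TUPLE([Types.STRING(), Types.INT()])])
    # After the fix, pickled_bytes_to_python_obj converts the tuple field
    # recursively, yielding results such as:
    expected = Row('pyflink', ('pyflink', 1))
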
diff --git a/flink-python/pyflink/table/tests/test_table_environment_api.py b/flink-python/pyflink/table/tests/test_table_environment_api.py
index 7b23d9535f6..de3c4187724 100644
--- a/flink-python/pyflink/table/tests/test_table_environment_api.py
+++ b/flink-python/pyflink/table/tests/test_table_environment_api.py
@@ -352,7 +352,8 @@ class DataStreamConversionTestCases(PyFlinkUTTestCase):
                          result._j_table_result.getResolvedSchema().toString())
         with result.collect() as result:
             collected_result = [str(item) for item in result]
-            expected_result = [item for item in map(str, [Row(1), Row(2), Row(3), Row(4), Row(5)])]
+            expected_result = [item for item
+                               in map(str, [Row((1,)), Row((2,)), Row((3,)), Row((4,)), Row((5,))])]
             expected_result.sort()
             collected_result.sort()
             self.assertEqual(expected_result, collected_result)
