This is an automated email from the ASF dual-hosted git repository.

dianfu pushed a commit to branch release-1.17
in repository https://gitbox.apache.org/repos/asf/flink.git
The following commit(s) were added to refs/heads/release-1.17 by this push:
     new ec5a09b3ce5  [FLINK-31478][python] Fix ds.execute_and_collect to support nested tuple
ec5a09b3ce5 is described below

commit ec5a09b3ce56426d1bdc8eeac4bf52cac9be015b
Author: Dian Fu <dia...@apache.org>
AuthorDate: Thu Mar 16 14:31:23 2023 +0800

    [FLINK-31478][python] Fix ds.execute_and_collect to support nested tuple

    This closes #22190.
---
 flink-python/pyflink/common/types.py               |  2 +-
 .../pyflink/datastream/tests/test_data_stream.py   | 16 ++++-
 flink-python/pyflink/datastream/utils.py           | 81 +++++++++++-----------
 .../table/tests/test_table_environment_api.py      |  3 +-
 4 files changed, 55 insertions(+), 47 deletions(-)

diff --git a/flink-python/pyflink/common/types.py b/flink-python/pyflink/common/types.py
index eeb68272402..d3151bdafe3 100644
--- a/flink-python/pyflink/common/types.py
+++ b/flink-python/pyflink/common/types.py
@@ -251,7 +251,7 @@ class Row(object):
             return "Row(%s)" % ", ".join("%s=%r" % (k, v)
                                          for k, v in zip(self._fields, tuple(self)))
         else:
-            return "<Row(%s)>" % ", ".join("%r" % field for field in self)
+            return "<Row(%s)>" % ", ".join(repr(field) for field in self)
 
     def __eq__(self, other):
         if not isinstance(other, Row):
diff --git a/flink-python/pyflink/datastream/tests/test_data_stream.py b/flink-python/pyflink/datastream/tests/test_data_stream.py
index 8d4e0bea1b8..64f18f5cc83 100644
--- a/flink-python/pyflink/datastream/tests/test_data_stream.py
+++ b/flink-python/pyflink/datastream/tests/test_data_stream.py
@@ -1545,20 +1545,28 @@ class CommonDataStreamTests(PyFlinkTestCase):
         test_data = ['pyflink', 'datastream', 'execute', 'collect']
         ds = self.env.from_collection(test_data)
 
+        # test collect with limit
         expected = test_data[:3]
         actual = []
         for result in ds.execute_and_collect(limit=3):
             actual.append(result)
         self.assertEqual(expected, actual)
 
-        expected = test_data
-        ds = self.env.from_collection(collection=test_data, type_info=Types.STRING())
-        with ds.execute_and_collect() as results:
+        # test collect KeyedStream
+        test_data = [('pyflink', 1), ('datastream', 2), ('pyflink', 1), ('collect', 2)]
+        expected = [Row(f0='pyflink', f1=('pyflink', 1)),
+                    Row(f0='datastream', f1=('datastream', 2)),
+                    Row(f0='pyflink', f1=('pyflink', 1)),
+                    Row(f0='collect', f1=('collect', 2))]
+        ds = self.env.from_collection(collection=test_data,
+                                      type_info=Types.TUPLE([Types.STRING(), Types.INT()]))
+        with ds.key_by(lambda i: i[0], Types.STRING()).execute_and_collect() as results:
             actual = []
             for result in results:
                 actual.append(result)
             self.assertEqual(expected, actual)
 
+        # test all kinds of data types
         test_data = [(1, None, 1, True, 32767, -2147483648, 1.23, 1.98932,
                       bytearray(b'flink'), 'pyflink',
                       datetime.date(2014, 9, 13),
@@ -1585,6 +1593,7 @@ class CommonDataStreamTests(PyFlinkTestCase):
             actual = [result for result in results]
             self.assert_equals_sorted(expected, actual)
 
+        # test primitive array
         test_data = [[1, 2, 3], [4, 5]]
         expected = test_data
         ds = self.env.from_collection(test_data, type_info=Types.PRIMITIVE_ARRAY(Types.INT()))
@@ -1597,6 +1606,7 @@ class CommonDataStreamTests(PyFlinkTestCase):
             ([None, ], [0.0, 0.0])
         ]
 
+        # test object array
         ds = self.env.from_collection(
             test_data,
             type_info=Types.TUPLE(
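The new KeyedStream case above is the heart of the fix: keying a stream of tuples makes each collected record a Row whose second field is itself a tuple, so deserialization has to recurse into the nested TupleTypeInfo. A minimal standalone repro, sketched here assuming a local PyFlink setup (variable names are illustrative, not from the patch):

    from pyflink.common import Types
    from pyflink.datastream import StreamExecutionEnvironment

    env = StreamExecutionEnvironment.get_execution_environment()
    ds = env.from_collection(
        [('pyflink', 1), ('datastream', 2)],
        type_info=Types.TUPLE([Types.STRING(), Types.INT()]))

    # key_by wraps each record in a Row of (key, original record), so
    # collecting yields e.g. Row(f0='pyflink', f1=('pyflink', 1)) -- a tuple
    # nested inside a Row, the case that previously failed to deserialize.
    with ds.key_by(lambda i: i[0], Types.STRING()).execute_and_collect() as results:
        for row in results:
            print(row)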
diff --git a/flink-python/pyflink/datastream/utils.py b/flink-python/pyflink/datastream/utils.py
index a88170996ee..2579e761e39 100644
--- a/flink-python/pyflink/datastream/utils.py
+++ b/flink-python/pyflink/datastream/utils.py
@@ -57,68 +57,65 @@ def convert_to_python_obj(data, type_info):
         return convert_to_python_obj(data, type_info._type_info)
     else:
         gateway = get_gateway()
-        pickle_bytes = gateway.jvm.PythonBridgeUtils. \
+        pickled_bytes = gateway.jvm.PythonBridgeUtils. \
             getPickledBytesFromJavaObject(data, type_info.get_java_type_info())
-        if isinstance(type_info, RowTypeInfo) or isinstance(type_info, TupleTypeInfo):
-            if isinstance(type_info, RowTypeInfo):
-                field_data = zip(list(pickle_bytes[1:]), type_info.get_field_types())
-            else:
-                field_data = zip(pickle_bytes, type_info.get_field_types())
-            fields = []
-            for data, field_type in field_data:
-                if len(data) == 0:
-                    fields.append(None)
-                else:
-                    fields.append(pickled_bytes_to_python_converter(data, field_type))
-            if isinstance(type_info, RowTypeInfo):
-                return Row.of_kind(RowKind(int.from_bytes(pickle_bytes[0], 'little')), *fields)
-            else:
-                return tuple(fields)
-        else:
-            return pickled_bytes_to_python_converter(pickle_bytes, type_info)
+        return pickled_bytes_to_python_obj(pickled_bytes, type_info)
 
 
-def pickled_bytes_to_python_converter(data, field_type):
-    if isinstance(field_type, RowTypeInfo):
+def pickled_bytes_to_python_obj(data, type_info):
+    if isinstance(type_info, RowTypeInfo):
         row_kind = RowKind(int.from_bytes(data[0], 'little'))
-        data = zip(list(data[1:]), field_type.get_field_types())
+        field_data_with_types = zip(list(data[1:]), type_info.get_field_types())
         fields = []
-        for d, d_type in data:
-            fields.append(pickled_bytes_to_python_converter(d, d_type))
+        for field_data, field_type in field_data_with_types:
+            if len(field_data) == 0:
+                fields.append(None)
+            else:
+                fields.append(pickled_bytes_to_python_obj(field_data, field_type))
         row = Row.of_kind(row_kind, *fields)
+        row.set_field_names(type_info.get_field_names())
         return row
+    elif isinstance(type_info, TupleTypeInfo):
+        field_data_with_types = zip(data, type_info.get_field_types())
+        fields = []
+        for field_data, field_type in field_data_with_types:
+            if len(field_data) == 0:
+                fields.append(None)
+            else:
+                fields.append(pickled_bytes_to_python_obj(field_data, field_type))
+        return tuple(fields)
     else:
         data = pickle.loads(data)
-        if field_type == Types.SQL_TIME():
+        if type_info == Types.SQL_TIME():
             seconds, microseconds = divmod(data, 10 ** 6)
             minutes, seconds = divmod(seconds, 60)
             hours, minutes = divmod(minutes, 60)
             return datetime.time(hours, minutes, seconds, microseconds)
-        elif field_type == Types.SQL_DATE():
-            return field_type.from_internal_type(data)
-        elif field_type == Types.SQL_TIMESTAMP():
-            return field_type.from_internal_type(int(data.timestamp() * 10 ** 6))
-        elif field_type == Types.FLOAT():
-            return field_type.from_internal_type(ast.literal_eval(data))
-        elif isinstance(field_type,
+        elif type_info == Types.SQL_DATE():
+            return type_info.from_internal_type(data)
+        elif type_info == Types.SQL_TIMESTAMP():
+            return type_info.from_internal_type(int(data.timestamp() * 10 ** 6))
+        elif type_info == Types.FLOAT():
+            return type_info.from_internal_type(ast.literal_eval(data))
+        elif isinstance(type_info,
                         (BasicArrayTypeInfo, PrimitiveArrayTypeInfo, ObjectArrayTypeInfo)):
-            element_type = field_type._element_type
+            element_type = type_info._element_type
             elements = []
             for element_bytes in data:
-                elements.append(pickled_bytes_to_python_converter(element_bytes, element_type))
+                elements.append(pickled_bytes_to_python_obj(element_bytes, element_type))
             return elements
-        elif isinstance(field_type, MapTypeInfo):
-            key_type = field_type._key_type_info
-            value_type = field_type._value_type_info
+        elif isinstance(type_info, MapTypeInfo):
+            key_type = type_info._key_type_info
+            value_type = type_info._value_type_info
             zip_kv = zip(data[0], data[1])
-            return dict((pickled_bytes_to_python_converter(k, key_type),
-                         pickled_bytes_to_python_converter(v, value_type))
+            return dict((pickled_bytes_to_python_obj(k, key_type),
+                         pickled_bytes_to_python_obj(v, value_type))
                         for k, v in zip_kv)
-        elif isinstance(field_type, ListTypeInfo):
-            element_type = field_type.elem_type
+        elif isinstance(type_info, ListTypeInfo):
+            element_type = type_info.elem_type
             elements = []
             for element_bytes in data:
-                elements.append(pickled_bytes_to_python_converter(element_bytes, element_type))
+                elements.append(pickled_bytes_to_python_obj(element_bytes, element_type))
             return elements
         else:
-            return field_type.from_internal_type(data)
+            return type_info.from_internal_type(data)
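To see what the rewritten pickled_bytes_to_python_obj walks, here is an illustrative sketch (not part of the patch; the payload value is hypothetical) of the layout implied by the code above: for a Row, element 0 carries the RowKind as little-endian bytes and each following element carries one field as pickled bytes, with empty bytes standing for NULL. Nested rows and tuples repeat this layout, which is why the conversion now recurses per field type:

    import pickle

    # Hypothetical payload for a row like Row(f0='pyflink', f1=None),
    # mirroring the decoding steps in pickled_bytes_to_python_obj above.
    payload = [
        (0).to_bytes(1, 'little'),  # RowKind.INSERT, little-endian bytes
        pickle.dumps('pyflink'),    # field f0, pickled
        b'',                        # field f1 is NULL -> empty bytes
    ]
    row_kind = int.from_bytes(payload[0], 'little')
    fields = [None if len(b) == 0 else pickle.loads(b) for b in payload[1:]]
    print(row_kind, fields)  # 0 ['pyflink', None]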
diff --git a/flink-python/pyflink/table/tests/test_table_environment_api.py b/flink-python/pyflink/table/tests/test_table_environment_api.py
index 7b23d9535f6..de3c4187724 100644
--- a/flink-python/pyflink/table/tests/test_table_environment_api.py
+++ b/flink-python/pyflink/table/tests/test_table_environment_api.py
@@ -352,7 +352,8 @@ class DataStreamConversionTestCases(PyFlinkUTTestCase):
                          result._j_table_result.getResolvedSchema().toString())
         with result.collect() as result:
             collected_result = [str(item) for item in result]
-            expected_result = [item for item in map(str, [Row(1), Row(2), Row(3), Row(4), Row(5)])]
+            expected_result = [item for item
+                               in map(str, [Row((1,)), Row((2,)), Row((3,)), Row((4,)), Row((5,))])]
             expected_result.sort()
             collected_result.sort()
             self.assertEqual(expected_result, collected_result)
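The expected strings change because a collected row's nested tuple field now survives both conversion and printing: the old Row.__repr__ used "%r" % field, and %-formatting consumes a tuple field as an argument pack, so a row holding the tuple (1,) printed as <Row(1)>, indistinguishable from a row holding the bare int 1. A quick illustration using only pyflink.common.Row:

    from pyflink.common import Row

    field = (1,)
    print("%r" % field)    # 1     -- tuple consumed as %-format arguments
    print(repr(field))     # (1,)  -- the patched code keeps the nesting
    print(str(Row((1,))))  # <Row((1,))> with the fixed __repr__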