This is an automated email from the ASF dual-hosted git repository.

dianfu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git


The following commit(s) were added to refs/heads/master by this push:
     new 46757739cf5 [FLINK-30168][python] Fix DataStream.execute_and_collect to support None data and ObjectArray
46757739cf5 is described below

commit 46757739cf50c1e7b7305a4bc9cf779bb1945a1f
Author: Dian Fu <[email protected]>
AuthorDate: Fri Jan 13 16:56:51 2023 +0800

    [FLINK-30168][python] Fix DataStream.execute_and_collect to support None data and ObjectArray
    
    This closes #21664.
---
 .../pyflink/datastream/tests/test_data_stream.py   | 22 ++++++++++++++++++++--
 flink-python/pyflink/datastream/utils.py           |  5 ++++-
 .../flink/api/common/python/PythonBridgeUtils.java | 11 ++++++++---
 3 files changed, 32 insertions(+), 6 deletions(-)

diff --git a/flink-python/pyflink/datastream/tests/test_data_stream.py 
b/flink-python/pyflink/datastream/tests/test_data_stream.py
index 6013c9cf0d8..b8589f54fbf 100644
--- a/flink-python/pyflink/datastream/tests/test_data_stream.py
+++ b/flink-python/pyflink/datastream/tests/test_data_stream.py
@@ -1231,8 +1231,10 @@ class CommonDataStreamTests(PyFlinkTestCase):
         self.test_sink.clear()
 
     def assert_equals_sorted(self, expected, actual):
-        expected.sort()
-        actual.sort()
+        # otherwise, it may throw exceptions such as the following:
+        # TypeError: '<' not supported between instances of 'NoneType' and 'str'
+        expected.sort(key=lambda x: str(x))
+        actual.sort(key=lambda x: str(x))
         self.assertEqual(expected, actual)
 
     def test_data_stream_name(self):
@@ -1496,6 +1498,22 @@ class CommonDataStreamTests(PyFlinkTestCase):
             actual = [r for r in results]
             self.assert_equals_sorted(expected, actual)
 
+        test_data = [
+            (["test", "test"], [0.0, 0.0]),
+            ([None, ], [0.0, 0.0])
+        ]
+
+        ds = self.env.from_collection(
+            test_data,
+            type_info=Types.TUPLE(
+                [Types.OBJECT_ARRAY(Types.STRING()), 
Types.OBJECT_ARRAY(Types.DOUBLE())]
+            )
+        )
+        expected = test_data
+        with ds.execute_and_collect() as results:
+            actual = [result for result in results]
+            self.assert_equals_sorted(expected, actual)
+
     def test_function_with_error(self):
         ds = self.env.from_collection([('a', 0), ('b', 0), ('c', 1), ('d', 1), 
('e', 1)],
                                       type_info=Types.ROW([Types.STRING(), 
Types.INT()]))
diff --git a/flink-python/pyflink/datastream/utils.py 
b/flink-python/pyflink/datastream/utils.py
index caad21f88f0..a88170996ee 100644
--- a/flink-python/pyflink/datastream/utils.py
+++ b/flink-python/pyflink/datastream/utils.py
@@ -60,7 +60,10 @@ def convert_to_python_obj(data, type_info):
         pickle_bytes = gateway.jvm.PythonBridgeUtils. \
             getPickledBytesFromJavaObject(data, type_info.get_java_type_info())
         if isinstance(type_info, RowTypeInfo) or isinstance(type_info, 
TupleTypeInfo):
-            field_data = zip(list(pickle_bytes[1:]), 
type_info.get_field_types())
+            if isinstance(type_info, RowTypeInfo):
+                field_data = zip(list(pickle_bytes[1:]), 
type_info.get_field_types())
+            else:
+                field_data = zip(pickle_bytes, type_info.get_field_types())
             fields = []
             for data, field_type in field_data:
                 if len(data) == 0:
diff --git 
a/flink-python/src/main/java/org/apache/flink/api/common/python/PythonBridgeUtils.java
 
b/flink-python/src/main/java/org/apache/flink/api/common/python/PythonBridgeUtils.java
index ebf4f346de3..b0bd0ab45ee 100644
--- 
a/flink-python/src/main/java/org/apache/flink/api/common/python/PythonBridgeUtils.java
+++ 
b/flink-python/src/main/java/org/apache/flink/api/common/python/PythonBridgeUtils.java
@@ -28,6 +28,7 @@ import org.apache.flink.api.common.typeutils.TypeSerializer;
 import org.apache.flink.api.java.tuple.Tuple;
 import org.apache.flink.api.java.typeutils.ListTypeInfo;
 import org.apache.flink.api.java.typeutils.MapTypeInfo;
+import org.apache.flink.api.java.typeutils.ObjectArrayTypeInfo;
 import org.apache.flink.api.java.typeutils.RowTypeInfo;
 import org.apache.flink.api.java.typeutils.TupleTypeInfo;
 import org.apache.flink.api.java.typeutils.TupleTypeInfoBase;
@@ -241,7 +242,7 @@ public final class PythonBridgeUtils {
         Pickler pickler = new Pickler();
         initialize();
         if (obj == null) {
-            return new byte[0];
+            return pickler.dumps(null);
         } else {
             if (dataType instanceof SqlTimeTypeInfo) {
                 SqlTimeTypeInfo<?> sqlTimeTypeInfo =
@@ -270,15 +271,19 @@ public final class PythonBridgeUtils {
                 }
                 return fieldBytes;
             } else if (dataType instanceof BasicArrayTypeInfo
-                    || dataType instanceof PrimitiveArrayTypeInfo) {
+                    || dataType instanceof PrimitiveArrayTypeInfo
+                    || dataType instanceof ObjectArrayTypeInfo) {
                 Object[] objects;
                 TypeInformation<?> elementType;
                 if (dataType instanceof BasicArrayTypeInfo) {
                     objects = (Object[]) obj;
                     elementType = ((BasicArrayTypeInfo<?, ?>) 
dataType).getComponentInfo();
-                } else {
+                } else if (dataType instanceof PrimitiveArrayTypeInfo) {
                     objects = primitiveArrayConverter(obj, dataType);
                     elementType = ((PrimitiveArrayTypeInfo<?>) 
dataType).getComponentType();
+                } else {
+                    objects = (Object[]) obj;
+                    elementType = ((ObjectArrayTypeInfo<?, ?>) 
dataType).getComponentInfo();
                 }
                 List<Object> serializedElements = new 
ArrayList<>(objects.length);
                 for (Object object : objects) {

Reply via email to