This is an automated email from the ASF dual-hosted git repository.

weizhong pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git


The following commit(s) were added to refs/heads/master by this push:
     new 582c7b5  [FLINK-21242][python] Support state access API for the map/flat_map/filter/reduce operations of Python KeyedStream
582c7b5 is described below

commit 582c7b5596738ab2c8fa73fbbc2e09189b863fa0
Author: Wei Zhong <[email protected]>
AuthorDate: Thu Mar 11 11:16:48 2021 +0800

    [FLINK-21242][python] Support state access API for the map/flat_map/filter/reduce operations of Python KeyedStream
    
    This closes #15083
---
 flink-python/pyflink/common/typeinfo.py            |   3 +-
 flink-python/pyflink/datastream/data_stream.py     | 117 +++++++++++++----
 flink-python/pyflink/datastream/state.py           |   6 +-
 .../pyflink/datastream/tests/test_data_stream.py   | 146 ++++++++++++++++++++-
 .../tests/test_stream_execution_environment.py     |  17 ++-
 .../pyflink/fn_execution/coder_impl_fast.pxd       |   5 -
 .../pyflink/fn_execution/coder_impl_fast.pyx       | 100 +++++---------
 .../pyflink/fn_execution/flink_fn_execution_pb2.py | 134 +++++++++----------
 .../pyflink/fn_execution/operation_utils.py        |   6 -
 .../pyflink/proto/flink-fn-execution.proto         |  11 +-
 .../api/functions/python/KeyByKeySelector.java     |  26 +---
 .../api/operators/python/PythonReduceOperator.java | 106 ---------------
 .../python/beam/BeamPythonFunctionRunner.java      |   8 ++
 13 files changed, 372 insertions(+), 313 deletions(-)

diff --git a/flink-python/pyflink/common/typeinfo.py 
b/flink-python/pyflink/common/typeinfo.py
index 548a81c..8467241 100644
--- a/flink-python/pyflink/common/typeinfo.py
+++ b/flink-python/pyflink/common/typeinfo.py
@@ -886,7 +886,8 @@ def _from_java_type(j_type_info: JavaObject) -> 
TypeInformation:
         j_row_field_types = j_type_info.getFieldTypes()
         row_field_types = [_from_java_type(j_row_field_type) for 
j_row_field_type in
                            j_row_field_types]
-        return Types.ROW_NAMED(j_row_field_names, row_field_types)
+        row_field_names = [field_name for field_name in j_row_field_names]
+        return Types.ROW_NAMED(row_field_names, row_field_types)
 
     JTupleTypeInfo = 
gateway.jvm.org.apache.flink.api.java.typeutils.TupleTypeInfo
     if _is_instance_of(j_type_info, JTupleTypeInfo):
diff --git a/flink-python/pyflink/datastream/data_stream.py 
b/flink-python/pyflink/datastream/data_stream.py
index 3c32c47..9c297cf 100644
--- a/flink-python/pyflink/datastream/data_stream.py
+++ b/flink-python/pyflink/datastream/data_stream.py
@@ -15,16 +15,18 @@
 #  See the License for the specific language governing permissions and
 # limitations under the License.
 
################################################################################
+import uuid
 from typing import Callable, Union, List
 
 from pyflink.common import typeinfo, ExecutionConfig, Row
-from pyflink.common.typeinfo import RowTypeInfo, Types, TypeInformation
+from pyflink.common.typeinfo import RowTypeInfo, Types, TypeInformation, 
_from_java_type
 from pyflink.common.watermark_strategy import WatermarkStrategy
 from pyflink.datastream.functions import _get_python_env, 
FlatMapFunctionWrapper, FlatMapFunction, \
     MapFunction, MapFunctionWrapper, Function, FunctionWrapper, SinkFunction, 
FilterFunction, \
     FilterFunctionWrapper, KeySelectorFunctionWrapper, KeySelector, 
ReduceFunction, \
     ReduceFunctionWrapper, CoMapFunction, CoFlatMapFunction, Partitioner, \
     PartitionerFunctionWrapper, RuntimeContext, ProcessFunction, 
KeyedProcessFunction
+from pyflink.datastream.state import ValueStateDescriptor, ValueState
 from pyflink.datastream.utils import convert_to_python_obj
 from pyflink.java_gateway import get_gateway
 
@@ -289,10 +291,8 @@ class DataStream(object):
 
         output_type_info = typeinfo._from_java_type(
             self._j_data_stream.getTransformation().getOutputType())
-        is_key_pickled_byte_array = False
         if key_type_info is None:
             key_type_info = Types.PICKLED_BYTE_ARRAY()
-            is_key_pickled_byte_array = True
 
         intermediate_map_stream = self.map(
             lambda x: Row(key_selector.get_key(x), x),  # type: ignore
@@ -303,8 +303,8 @@ class DataStream(object):
                                      .STREAM_KEY_BY_MAP_OPERATOR_NAME)
         key_stream = KeyedStream(
             intermediate_map_stream._j_data_stream.keyBy(
-                JKeyByKeySelector(is_key_pickled_byte_array),
-                key_type_info.get_java_type_info()), output_type_info,
+                JKeyByKeySelector(),
+                Types.ROW([key_type_info]).get_java_type_info()), 
output_type_info,
             self)
         return key_stream
 
@@ -785,11 +785,49 @@ class KeyedStream(DataStream):
 
     def map(self, func: Union[Callable, MapFunction], output_type: 
TypeInformation = None) \
             -> 'DataStream':
-        return self._values().map(func, output_type)
+        if not isinstance(func, MapFunction):
+            if callable(func):
+                func = MapFunctionWrapper(func)  # type: ignore
+            else:
+                raise TypeError("The input func must be a MapFunction or a 
callable function.")
+
+        class KeyedMapFunctionWrapper(KeyedProcessFunction):
+
+            def __init__(self, underlying: MapFunction):
+                self._underlying = underlying
+
+            def open(self, runtime_context: RuntimeContext):
+                self._underlying.open(runtime_context)
+
+            def close(self):
+                self._underlying.close()
+
+            def process_element(self, value, ctx: 
'KeyedProcessFunction.Context'):
+                return [self._underlying.map(value)]
+        return self.process(KeyedMapFunctionWrapper(func), output_type)  # 
type: ignore
 
     def flat_map(self, func: Union[Callable, FlatMapFunction], result_type: 
TypeInformation = None)\
             -> 'DataStream':
-        return self._values().flat_map(func, result_type)
+        if not isinstance(func, FlatMapFunction):
+            if callable(func):
+                func = FlatMapFunctionWrapper(func)  # type: ignore
+            else:
+                raise TypeError("The input func must be a FlatMapFunction or a 
callable function.")
+
+        class KeyedFlatMapFunctionWrapper(KeyedProcessFunction):
+
+            def __init__(self, underlying: FlatMapFunction):
+                self._underlying = underlying
+
+            def open(self, runtime_context: RuntimeContext):
+                self._underlying.open(runtime_context)
+
+            def close(self):
+                self._underlying.close()
+
+            def process_element(self, value, ctx: 
'KeyedProcessFunction.Context'):
+                yield from self._underlying.flat_map(value)
+        return self.process(KeyedFlatMapFunctionWrapper(func), result_type)  # 
type: ignore
 
     def reduce(self, func: Union[Callable, ReduceFunction]) -> 'DataStream':
         """
@@ -810,20 +848,56 @@ class KeyedStream(DataStream):
             if callable(func):
                 func = ReduceFunctionWrapper(func)  # type: ignore
             else:
-                raise TypeError("The input must be a ReduceFunction or a 
callable function!")
+                raise TypeError("The input func must be a ReduceFunction or a 
callable function.")
+        output_type = 
_from_java_type(self._original_data_type_info.get_java_type_info())
 
-        from pyflink.fn_execution.flink_fn_execution_pb2 import 
UserDefinedDataStreamFunction
-        j_operator, j_output_type_info = \
-            _get_one_input_stream_operator(
-                self, func, UserDefinedDataStreamFunction.REDUCE)  # type: 
ignore
-        return DataStream(self._j_data_stream.transform(
-            "Keyed Reduce",
-            j_output_type_info,
-            j_operator
-        ))
+        class KeyedReduceFunctionWrapper(KeyedProcessFunction):
+
+            def __init__(self, underlying: ReduceFunction):
+                self._underlying = underlying
+                self._reduce_state_name = "_reduce_state" + str(uuid.uuid4())
+                self._reduce_value_state = None  # type: ValueState
+
+            def open(self, runtime_context: RuntimeContext):
+                self._reduce_value_state = runtime_context.get_state(
+                    ValueStateDescriptor(self._reduce_state_name, output_type))
+                self._underlying.open(runtime_context)
+
+            def process_element(self, value, ctx: 
'KeyedProcessFunction.Context'):
+                reduce_value = self._reduce_value_state.value()
+                if reduce_value is not None:
+                    reduce_value = self._underlying.reduce(reduce_value, value)
+                else:
+                    reduce_value = value
+                self._reduce_value_state.update(reduce_value)
+                return [reduce_value]
+
+        return self.process(KeyedReduceFunctionWrapper(func), output_type)  # 
type: ignore
 
     def filter(self, func: Union[Callable, FilterFunction]) -> 'DataStream':
-        return self._values().filter(func)
+        if callable(func):
+            func = FilterFunctionWrapper(func)  # type: ignore
+        elif not isinstance(func, FilterFunction):
+            raise TypeError("The input func must be a FilterFunction or a 
callable function.")
+
+        class KeyedFilterFunctionWrapper(KeyedProcessFunction):
+
+            def __init__(self, underlying: FilterFunction):
+                self._underlying = underlying
+
+            def open(self, runtime_context: RuntimeContext):
+                self._underlying.open(runtime_context)
+
+            def close(self):
+                self._underlying.close()
+
+            def process_element(self, value, ctx: 
'KeyedProcessFunction.Context'):
+                if self._underlying.filter(value):
+                    return [value]
+                else:
+                    return []
+        return self.process(
+            KeyedFilterFunctionWrapper(func), self._original_data_type_info)  
# type: ignore
 
     def add_sink(self, sink_func: SinkFunction) -> 'DataStreamSink':
         return self._values().add_sink(sink_func)
@@ -1067,12 +1141,7 @@ def _get_one_input_stream_operator(data_stream: 
DataStream,
     j_conf = gateway.jvm.org.apache.flink.configuration.Configuration()
 
     from pyflink.fn_execution.flink_fn_execution_pb2 import 
UserDefinedDataStreamFunction
-    if func_type == UserDefinedDataStreamFunction.REDUCE:  # type: ignore
-        # set max bundle size to 1 to force synchronize process for reduce 
function.
-        
j_conf.setInteger(gateway.jvm.org.apache.flink.python.PythonOptions.MAX_BUNDLE_SIZE,
 1)
-        j_output_type_info = j_input_types.getTypeAt(1)
-        JDataStreamPythonFunctionOperator = gateway.jvm.PythonReduceOperator
-    elif func_type == UserDefinedDataStreamFunction.MAP:  # type: ignore
+    if func_type == UserDefinedDataStreamFunction.MAP:  # type: ignore
         if str(func) == '_Flink_PartitionCustomMapFunction':
             JDataStreamPythonFunctionOperator = 
gateway.jvm.PythonPartitionCustomOperator
         else:
diff --git a/flink-python/pyflink/datastream/state.py 
b/flink-python/pyflink/datastream/state.py
index 4c11b29..a576d41 100644
--- a/flink-python/pyflink/datastream/state.py
+++ b/flink-python/pyflink/datastream/state.py
@@ -29,7 +29,9 @@ __all__ = [
     'MapStateDescriptor',
     'MapState',
     'ReducingStateDescriptor',
-    'ReducingState'
+    'ReducingState',
+    'AggregatingStateDescriptor',
+    'AggregatingState'
 ]
 
 T = TypeVar('T')
@@ -395,7 +397,7 @@ class AggregatingStateDescriptor(StateDescriptor):
         super(AggregatingStateDescriptor, self).__init__(name, state_type_info)
         from pyflink.datastream.functions import AggregateFunction
         if not isinstance(agg_function, AggregateFunction):
-            raise TypeError("The input must be a AggregateFunction!")
+            raise TypeError("The input must be a 
pyflink.datastream.functions.AggregateFunction!")
         self._agg_function = agg_function
 
     def get_agg_function(self):
diff --git a/flink-python/pyflink/datastream/tests/test_data_stream.py 
b/flink-python/pyflink/datastream/tests/test_data_stream.py
index 3d7a0da..82b07d8 100644
--- a/flink-python/pyflink/datastream/tests/test_data_stream.py
+++ b/flink-python/pyflink/datastream/tests/test_data_stream.py
@@ -25,7 +25,8 @@ from pyflink.common.typeinfo import Types
 from pyflink.common.watermark_strategy import WatermarkStrategy, 
TimestampAssigner
 from pyflink.datastream import StreamExecutionEnvironment, TimeCharacteristic, 
RuntimeContext
 from pyflink.datastream.data_stream import DataStream
-from pyflink.datastream.functions import CoMapFunction, CoFlatMapFunction, 
AggregateFunction
+from pyflink.datastream.functions import CoMapFunction, CoFlatMapFunction, 
AggregateFunction, \
+    ReduceFunction
 from pyflink.datastream.functions import FilterFunction, ProcessFunction, 
KeyedProcessFunction
 from pyflink.datastream.functions import KeySelector
 from pyflink.datastream.functions import MapFunction, FlatMapFunction
@@ -383,13 +384,28 @@ class DataStreamTests(PyFlinkTestCase):
         class AssertKeyMapFunction(MapFunction):
             def __init__(self):
                 self.pre = None
+                self.state = None
+
+            def open(self, runtime_context: RuntimeContext):
+                self.state = runtime_context.get_state(
+                    ValueStateDescriptor("test_state", Types.INT()))
 
             def map(self, value):
+                state_value = self.state.value()
+                if state_value is None:
+                    state_value = 1
+                else:
+                    state_value += 1
                 if value[0] == 'b':
                     assert self.pre == 'a'
+                    assert state_value == 2
                 if value[0] == 'd':
                     assert self.pre == 'c'
+                    assert state_value == 2
+                if value[0] == 'e':
+                    assert state_value == 1
                 self.pre = value[0]
+                self.state.update(state_value)
                 return value
 
         keyed_stream.map(AssertKeyMapFunction()).add_sink(self.test_sink)
@@ -401,6 +417,134 @@ class DataStreamTests(PyFlinkTestCase):
         expected.sort()
         self.assertEqual(expected, results)
 
+    def test_key_by_flat_map(self):
+        ds = self.env.from_collection([('a', 0), ('b', 0), ('c', 1), ('d', 1), 
('e', 2)],
+                                      type_info=Types.ROW([Types.STRING(), 
Types.INT()]))
+        keyed_stream = ds.key_by(MyKeySelector(), key_type_info=Types.INT())
+
+        with self.assertRaises(Exception):
+            keyed_stream.name("keyed stream")
+
+        class AssertKeyMapFunction(FlatMapFunction):
+            def __init__(self):
+                self.pre = None
+                self.state = None
+
+            def open(self, runtime_context: RuntimeContext):
+                self.state = runtime_context.get_state(
+                    ValueStateDescriptor("test_state", Types.INT()))
+
+            def flat_map(self, value):
+                state_value = self.state.value()
+                if state_value is None:
+                    state_value = 1
+                else:
+                    state_value += 1
+                if value[0] == 'b':
+                    assert self.pre == 'a'
+                    assert state_value == 2
+                if value[0] == 'd':
+                    assert self.pre == 'c'
+                    assert state_value == 2
+                if value[0] == 'e':
+                    assert state_value == 1
+                self.pre = value[0]
+                self.state.update(state_value)
+                yield value
+
+        keyed_stream.flat_map(AssertKeyMapFunction()).add_sink(self.test_sink)
+        self.env.execute('key_by_test')
+        results = self.test_sink.get_results(True)
+        expected = ["Row(f0='e', f1=2)", "Row(f0='a', f1=0)", "Row(f0='b', 
f1=0)",
+                    "Row(f0='c', f1=1)", "Row(f0='d', f1=1)"]
+        results.sort()
+        expected.sort()
+        self.assertEqual(expected, results)
+
+    def test_key_by_filter(self):
+        ds = self.env.from_collection([('a', 0), ('b', 0), ('c', 1), ('d', 1), 
('e', 2)],
+                                      type_info=Types.ROW([Types.STRING(), 
Types.INT()]))
+        keyed_stream = ds.key_by(MyKeySelector())
+
+        with self.assertRaises(Exception):
+            keyed_stream.name("keyed stream")
+
+        class AssertKeyFilterFunction(FilterFunction):
+            def __init__(self):
+                self.pre = None
+                self.state = None
+
+            def open(self, runtime_context: RuntimeContext):
+                self.state = runtime_context.get_state(
+                    ValueStateDescriptor("test_state", Types.INT()))
+
+            def filter(self, value):
+                state_value = self.state.value()
+                if state_value is None:
+                    state_value = 1
+                else:
+                    state_value += 1
+                if value[0] == 'b':
+                    assert self.pre == 'a'
+                    assert state_value == 2
+                    return False
+                if value[0] == 'd':
+                    assert self.pre == 'c'
+                    assert state_value == 2
+                    return False
+                if value[0] == 'e':
+                    assert state_value == 1
+                self.pre = value[0]
+                self.state.update(state_value)
+                return True
+
+        keyed_stream.filter(AssertKeyFilterFunction()).add_sink(self.test_sink)
+        self.env.execute('key_by_test')
+        results = self.test_sink.get_results(False)
+        expected = ['+I[a, 0]', '+I[c, 1]', '+I[e, 2]']
+        results.sort()
+        expected.sort()
+        self.assertEqual(expected, results)
+
+    def test_reduce_with_state(self):
+        ds = self.env.from_collection([('a', 0), ('b', 0), ('c', 1), ('d', 1), 
('e', 1)],
+                                      type_info=Types.ROW([Types.STRING(), 
Types.INT()]))
+        keyed_stream = ds.key_by(MyKeySelector(), key_type_info=Types.INT())
+
+        with self.assertRaises(Exception):
+            keyed_stream.name("keyed stream")
+
+        class AssertKeyReduceFunction(ReduceFunction):
+
+            def __init__(self):
+                self.state = None
+
+            def open(self, runtime_context: RuntimeContext):
+                self.state = runtime_context.get_state(
+                    ValueStateDescriptor("test_state", Types.INT()))
+
+            def reduce(self, value1, value2):
+                state_value = self.state.value()
+                if state_value is None:
+                    state_value = 2
+                else:
+                    state_value += 1
+                result_value = Row(value1[0] + value2[0], value1[1])
+                if result_value[0] == 'ab':
+                    assert state_value == 2
+                if result_value[0] == 'cde':
+                    assert state_value == 3
+                self.state.update(state_value)
+                return result_value
+
+        keyed_stream.reduce(AssertKeyReduceFunction()).add_sink(self.test_sink)
+        self.env.execute('key_by_test')
+        results = self.test_sink.get_results(False)
+        expected = ['+I[a, 0]', '+I[ab, 0]', '+I[c, 1]', '+I[cd, 1]', '+I[cde, 
1]']
+        results.sort()
+        expected.sort()
+        self.assertEqual(expected, results)
+
     def test_multi_key_by(self):
         ds = self.env.from_collection([('a', 0), ('b', 0), ('c', 1), ('d', 1), 
('e', 2)],
                                       type_info=Types.ROW([Types.STRING(), 
Types.INT()]))
diff --git 
a/flink-python/pyflink/datastream/tests/test_stream_execution_environment.py 
b/flink-python/pyflink/datastream/tests/test_stream_execution_environment.py
index ca5780b..e5efb91 100644
--- a/flink-python/pyflink/datastream/tests/test_stream_execution_environment.py
+++ b/flink-python/pyflink/datastream/tests/test_stream_execution_environment.py
@@ -541,26 +541,25 @@ class StreamExecutionEnvironmentTests(PyFlinkTestCase):
         nodes = eval(self.env.get_execution_plan())['nodes']
 
         # The StreamGraph should be as bellow:
-        # Source: From Collection -> _stream_key_by_map_operator -> 
_keyed_stream_values_operator ->
+        # Source: From Collection -> _stream_key_by_map_operator ->
         # Plus Two Map -> Add From File Map -> Sink: Test Sink.
 
         # Source: From Collection and _stream_key_by_map_operator should have 
same parallelism.
         self.assertEqual(nodes[0]['parallelism'], nodes[1]['parallelism'])
 
-        # _keyed_stream_values_operator and Plus Two Map should have same 
parallisim.
-        self.assertEqual(nodes[3]['parallelism'], 3)
-        self.assertEqual(nodes[2]['parallelism'], nodes[3]['parallelism'])
+        # The parallelism of Plus Two Map should be 3
+        self.assertEqual(nodes[2]['parallelism'], 3)
 
-        # The ship_strategy for Source: From Collection and 
_stream_key_by_map_operator shoule be
+        # The ship_strategy for Source: From Collection and 
_stream_key_by_map_operator should be
         # FORWARD
         self.assertEqual(nodes[1]['predecessors'][0]['ship_strategy'], 
"FORWARD")
 
-        # The ship_strategy for _keyed_stream_values_operator and Plus Two Map 
shoule be
-        # FORWARD
-        self.assertEqual(nodes[3]['predecessors'][0]['ship_strategy'], 
"FORWARD")
+        # The ship_strategy for _stream_key_by_map_operator and Plus Two Map should be
+        # HASH
+        self.assertEqual(nodes[2]['predecessors'][0]['ship_strategy'], "HASH")
 
         # The parallelism of Sink: Test Sink should be 4
-        self.assertEqual(nodes[5]['parallelism'], 4)
+        self.assertEqual(nodes[4]['parallelism'], 4)
 
         env_config_with_dependencies = 
dict(get_gateway().jvm.org.apache.flink.python.util
                                             
.PythonConfigUtil.getEnvConfigWithDependencies(
diff --git a/flink-python/pyflink/fn_execution/coder_impl_fast.pxd 
b/flink-python/pyflink/fn_execution/coder_impl_fast.pxd
index 4ecad43..e4512ae 100644
--- a/flink-python/pyflink/fn_execution/coder_impl_fast.pxd
+++ b/flink-python/pyflink/fn_execution/coder_impl_fast.pxd
@@ -114,11 +114,6 @@ cdef class TableFunctionRowCoderImpl(FlattenRowCoderImpl):
 
 cdef class DataStreamMapCoderImpl(FlattenRowCoderImpl):
     cdef readonly FieldCoder _single_field_coder
-    cdef object _decode_data_stream_field_simple(self, TypeName field_type)
-    cdef object _decode_data_stream_field_complex(self, TypeName field_type, 
FieldCoder field_coder)
-    cdef void _encode_data_stream_field_simple(self, TypeName field_type, item)
-    cdef void _encode_data_stream_field_complex(self, TypeName field_type, 
FieldCoder field_coder,
-                                                item)
 
 cdef class DataStreamFlatMapCoderImpl(BaseCoderImpl):
     cdef readonly object _single_field_coder
diff --git a/flink-python/pyflink/fn_execution/coder_impl_fast.pyx 
b/flink-python/pyflink/fn_execution/coder_impl_fast.pyx
index 6d98c35..e3143c5 100644
--- a/flink-python/pyflink/fn_execution/coder_impl_fast.pyx
+++ b/flink-python/pyflink/fn_execution/coder_impl_fast.pyx
@@ -25,6 +25,7 @@ from libc.string cimport memcpy
 
 import datetime
 import decimal
+import pickle
 
 from pyflink.fn_execution.window import TimeWindow, CountWindow
 from pyflink.table import Row
@@ -188,26 +189,6 @@ cdef class DataStreamMapCoderImpl(FlattenRowCoderImpl):
         output_stream.write(self._tmp_output_data, self._tmp_output_pos)
         self._tmp_output_pos = 0
 
-    cdef void _encode_field(self, CoderType coder_type, TypeName field_type, 
FieldCoder field_coder,
-                        item):
-        if coder_type == SIMPLE:
-            self._encode_field_simple(field_type, item)
-            self._encode_data_stream_field_simple(field_type, item)
-        else:
-            self._encode_field_complex(field_type, field_coder, item)
-            self._encode_data_stream_field_complex(field_type, field_coder, 
item)
-
-    cdef object _decode_field(self, CoderType coder_type, TypeName field_type,
-                        FieldCoder field_coder):
-        if coder_type == SIMPLE:
-            decoded_obj = self._decode_field_simple(field_type)
-            return decoded_obj if decoded_obj is not None \
-                else self._decode_data_stream_field_simple(field_type)
-        else:
-            decoded_obj = self._decode_field_complex(field_type, field_coder)
-            return decoded_obj if decoded_obj is not None \
-                else self._decode_data_stream_field_complex(field_type, 
field_coder)
-
     cpdef object decode_from_stream(self, LengthPrefixInputStream 
input_stream):
         input_stream.read(&self._input_data)
         self._input_pos = 0
@@ -216,50 +197,6 @@ cdef class DataStreamMapCoderImpl(FlattenRowCoderImpl):
         decoded_obj = self._decode_field(coder_type, type_name, 
self._single_field_coder)
         return decoded_obj
 
-    cdef void _encode_data_stream_field_simple(self, TypeName field_type, 
item):
-        if field_type == PICKLED_BYTES:
-            import pickle
-            pickled_bytes = pickle.dumps(item)
-            self._encode_bytes(pickled_bytes, len(pickled_bytes))
-        elif field_type == BIG_DEC:
-            item_bytes = str(item).encode('utf-8')
-            self._encode_bytes(item_bytes, len(item_bytes))
-
-    cdef void _encode_data_stream_field_complex(self, TypeName field_type, 
FieldCoder field_coder,
-                                           item):
-        if field_type == TUPLE:
-            tuple_field_coders = (<TupleCoderImpl> field_coder).field_coders
-            tuple_field_count = len(tuple_field_coders)
-            tuple_value = list(item)
-            for i in range(tuple_field_count):
-                field_item = tuple_value[i]
-                tuple_field_coder = tuple_field_coders[i]
-                if field_item is not None:
-                    self._encode_field(tuple_field_coder.coder_type(),
-                                       tuple_field_coder.type_name(),
-                                       tuple_field_coder,
-                                       field_item)
-    cdef object _decode_data_stream_field_simple(self, TypeName field_type):
-        if field_type == PICKLED_BYTES:
-            decoded_bytes = self._decode_bytes()
-            import pickle
-            return pickle.loads(decoded_bytes)
-        elif field_type == BIG_DEC:
-            return decimal.Decimal(self._decode_bytes().decode("utf-8"))
-
-    cdef object _decode_data_stream_field_complex(self, TypeName field_type, 
FieldCoder field_coder):
-        if field_type == TUPLE:
-            tuple_field_coders = (<TupleCoderImpl> field_coder).field_coders
-            tuple_field_count = len(tuple_field_coders)
-            decoded_list = []
-            for i in range(tuple_field_count):
-                decoded_list.append(self._decode_field(
-                    tuple_field_coders[i].coder_type(),
-                    tuple_field_coders[i].type_name(),
-                    tuple_field_coders[i]
-                ))
-            return (*decoded_list,)
-
 ROW_KIND_BIT_SIZE = 2
 
 cdef class FlattenRowCoderImpl(BaseCoderImpl):
@@ -425,6 +362,11 @@ cdef class FlattenRowCoderImpl(BaseCoderImpl):
             hours = minutes // 60
             minutes %= 60
             return datetime.time(hours, minutes, seconds, milliseconds * 1000)
+        elif field_type == PICKLED_BYTES:
+            decoded_bytes = self._decode_bytes()
+            return pickle.loads(decoded_bytes)
+        elif field_type == BIG_DEC:
+            return decimal.Decimal(self._decode_bytes().decode("utf-8"))
 
     cdef object _decode_field_complex(self, TypeName field_type, FieldCoder 
field_coder):
         cdef libc.stdint.int32_t nanoseconds, microseconds, seconds, length
@@ -501,6 +443,17 @@ cdef class FlattenRowCoderImpl(BaseCoderImpl):
         elif field_type == ROW:
             # Row
             return self._decode_field_row(field_coder)
+        elif field_type == TUPLE:
+            tuple_field_coders = (<TupleCoderImpl> field_coder).field_coders
+            tuple_field_count = len(tuple_field_coders)
+            decoded_list = []
+            for i in range(tuple_field_count):
+                decoded_list.append(self._decode_field(
+                    tuple_field_coders[i].coder_type(),
+                    tuple_field_coders[i].type_name(),
+                    tuple_field_coders[i]
+                ))
+            return (*decoded_list,)
 
     cdef object _decode_field_row(self, RowCoderImpl field_coder):
         cdef list row_field_coders, row_field_names
@@ -621,6 +574,13 @@ cdef class FlattenRowCoderImpl(BaseCoderImpl):
             microsecond = item.microsecond
             milliseconds = hour * 3600000 + minute * 60000 + seconds * 1000 + 
microsecond // 1000
             self._encode_int(milliseconds)
+        elif field_type == PICKLED_BYTES:
+            # pickled object
+            pickled_bytes = pickle.dumps(item)
+            self._encode_bytes(pickled_bytes, len(pickled_bytes))
+        elif field_type == BIG_DEC:
+            item_bytes = str(item).encode('utf-8')
+            self._encode_bytes(item_bytes, len(item_bytes))
 
     cdef void _encode_field_complex(self, TypeName field_type, FieldCoder 
field_coder, item):
         cdef libc.stdint.int32_t nanoseconds, microseconds_of_second, length, 
row_field_count
@@ -711,6 +671,18 @@ cdef class FlattenRowCoderImpl(BaseCoderImpl):
                 if field_item is not None:
                     self._encode_field(row_field_coder.coder_type(), 
row_field_coder.type_name(),
                                        row_field_coder, field_item)
+        elif field_type == TUPLE:
+            tuple_field_coders = (<TupleCoderImpl> field_coder).field_coders
+            tuple_field_count = len(tuple_field_coders)
+            tuple_value = list(item)
+            for i in range(tuple_field_count):
+                field_item = tuple_value[i]
+                tuple_field_coder = tuple_field_coders[i]
+                if field_item is not None:
+                    self._encode_field(tuple_field_coder.coder_type(),
+                                       tuple_field_coder.type_name(),
+                                       tuple_field_coder,
+                                       field_item)
 
     cdef void _extend(self, size_t missing):
         while self._tmp_output_buffer_size < self._tmp_output_pos + missing:
diff --git a/flink-python/pyflink/fn_execution/flink_fn_execution_pb2.py 
b/flink-python/pyflink/fn_execution/flink_fn_execution_pb2.py
index 0d3aee9..a9aba93 100644
--- a/flink-python/pyflink/fn_execution/flink_fn_execution_pb2.py
+++ b/flink-python/pyflink/fn_execution/flink_fn_execution_pb2.py
@@ -36,7 +36,7 @@ DESCRIPTOR = _descriptor.FileDescriptor(
   name='flink-fn-execution.proto',
   package='org.apache.flink.fn_execution.v1',
   syntax='proto3',
-  serialized_pb=_b('\n\x18\x66link-fn-execution.proto\x12 
org.apache.flink.fn_execution.v1\"\x86\x01\n\x05Input\x12\x44\n\x03udf\x18\x01 
\x01(\x0b\x32\x35.org.apache.flink.fn_execution.v1.UserDefinedFunctionH\x00\x12\x15\n\x0binputOffset\x18\x02
 \x01(\x05H\x00\x12\x17\n\rinputConstant\x18\x03 
\x01(\x0cH\x00\x42\x07\n\x05input\"\x91\x01\n\x13UserDefinedFunction\x12\x0f\n\x07payload\x18\x01
 \x01(\x0c\x12\x37\n\x06inputs\x18\x02 
\x03(\x0b\x32\'.org.apache.flink.fn_execution.v1.Input\x12\x14 [...]
+  serialized_pb=_b('\n\x18\x66link-fn-execution.proto\x12 
org.apache.flink.fn_execution.v1\"\x86\x01\n\x05Input\x12\x44\n\x03udf\x18\x01 
\x01(\x0b\x32\x35.org.apache.flink.fn_execution.v1.UserDefinedFunctionH\x00\x12\x15\n\x0binputOffset\x18\x02
 \x01(\x05H\x00\x12\x17\n\rinputConstant\x18\x03 
\x01(\x0cH\x00\x42\x07\n\x05input\"\x91\x01\n\x13UserDefinedFunction\x12\x0f\n\x07payload\x18\x01
 \x01(\x0c\x12\x37\n\x06inputs\x18\x02 
\x03(\x0b\x32\'.org.apache.flink.fn_execution.v1.Input\x12\x14 [...]
 )
 
 
@@ -102,34 +102,30 @@ _USERDEFINEDDATASTREAMFUNCTION_FUNCTIONTYPE = 
_descriptor.EnumDescriptor(
       options=None,
       type=None),
     _descriptor.EnumValueDescriptor(
-      name='REDUCE', index=2, number=2,
+      name='CO_MAP', index=2, number=2,
       options=None,
       type=None),
     _descriptor.EnumValueDescriptor(
-      name='CO_MAP', index=3, number=3,
+      name='CO_FLAT_MAP', index=3, number=3,
       options=None,
       type=None),
     _descriptor.EnumValueDescriptor(
-      name='CO_FLAT_MAP', index=4, number=4,
+      name='PROCESS', index=4, number=4,
       options=None,
       type=None),
     _descriptor.EnumValueDescriptor(
-      name='PROCESS', index=5, number=5,
+      name='KEYED_PROCESS', index=5, number=5,
       options=None,
       type=None),
     _descriptor.EnumValueDescriptor(
-      name='KEYED_PROCESS', index=6, number=6,
-      options=None,
-      type=None),
-    _descriptor.EnumValueDescriptor(
-      name='TIMESTAMP_ASSIGNER', index=7, number=7,
+      name='TIMESTAMP_ASSIGNER', index=6, number=6,
       options=None,
       type=None),
   ],
   containing_type=None,
   options=None,
-  serialized_start=1579,
-  serialized_end=1713,
+  serialized_start=1578,
+  serialized_end=1700,
 )
 _sym_db.RegisterEnumDescriptor(_USERDEFINEDDATASTREAMFUNCTION_FUNCTIONTYPE)
 
@@ -154,8 +150,8 @@ _GROUPWINDOW_WINDOWTYPE = _descriptor.EnumDescriptor(
   ],
   containing_type=None,
   options=None,
-  serialized_start=2838,
-  serialized_end=2929,
+  serialized_start=2825,
+  serialized_end=2916,
 )
 _sym_db.RegisterEnumDescriptor(_GROUPWINDOW_WINDOWTYPE)
 
@@ -184,8 +180,8 @@ _GROUPWINDOW_WINDOWPROPERTY = _descriptor.EnumDescriptor(
   ],
   containing_type=None,
   options=None,
-  serialized_start=2931,
-  serialized_end=3030,
+  serialized_start=2918,
+  serialized_end=3017,
 )
 _sym_db.RegisterEnumDescriptor(_GROUPWINDOW_WINDOWPROPERTY)
 
@@ -282,8 +278,8 @@ _SCHEMA_TYPENAME = _descriptor.EnumDescriptor(
   ],
   containing_type=None,
   options=None,
-  serialized_start=5284,
-  serialized_end=5573,
+  serialized_start=5271,
+  serialized_end=5560,
 )
 _sym_db.RegisterEnumDescriptor(_SCHEMA_TYPENAME)
 
@@ -380,8 +376,8 @@ _TYPEINFO_TYPENAME = _descriptor.EnumDescriptor(
   ],
   containing_type=None,
   options=None,
-  serialized_start=6398,
-  serialized_end=6675,
+  serialized_start=6385,
+  serialized_end=6662,
 )
 _sym_db.RegisterEnumDescriptor(_TYPEINFO_TYPENAME)
 
@@ -742,7 +738,7 @@ _USERDEFINEDDATASTREAMFUNCTION = _descriptor.Descriptor(
   oneofs=[
   ],
   serialized_start=881,
-  serialized_end=1713,
+  serialized_end=1700,
 )
 
 
@@ -772,8 +768,8 @@ _USERDEFINEDAGGREGATEFUNCTION_DATAVIEWSPEC_LISTVIEW = 
_descriptor.Descriptor(
   extension_ranges=[],
   oneofs=[
   ],
-  serialized_start=2244,
-  serialized_end=2328,
+  serialized_start=2231,
+  serialized_end=2315,
 )
 
 _USERDEFINEDAGGREGATEFUNCTION_DATAVIEWSPEC_MAPVIEW = _descriptor.Descriptor(
@@ -809,8 +805,8 @@ _USERDEFINEDAGGREGATEFUNCTION_DATAVIEWSPEC_MAPVIEW = 
_descriptor.Descriptor(
   extension_ranges=[],
   oneofs=[
   ],
-  serialized_start=2331,
-  serialized_end=2482,
+  serialized_start=2318,
+  serialized_end=2469,
 )
 
 _USERDEFINEDAGGREGATEFUNCTION_DATAVIEWSPEC = _descriptor.Descriptor(
@@ -863,8 +859,8 @@ _USERDEFINEDAGGREGATEFUNCTION_DATAVIEWSPEC = 
_descriptor.Descriptor(
       name='data_view', 
full_name='org.apache.flink.fn_execution.v1.UserDefinedAggregateFunction.DataViewSpec.data_view',
       index=0, containing_type=None, fields=[]),
   ],
-  serialized_start=1981,
-  serialized_end=2495,
+  serialized_start=1968,
+  serialized_end=2482,
 )
 
 _USERDEFINEDAGGREGATEFUNCTION = _descriptor.Descriptor(
@@ -928,8 +924,8 @@ _USERDEFINEDAGGREGATEFUNCTION = _descriptor.Descriptor(
   extension_ranges=[],
   oneofs=[
   ],
-  serialized_start=1716,
-  serialized_end=2495,
+  serialized_start=1703,
+  serialized_end=2482,
 )
 
 
@@ -1017,8 +1013,8 @@ _GROUPWINDOW = _descriptor.Descriptor(
   extension_ranges=[],
   oneofs=[
   ],
-  serialized_start=2498,
-  serialized_end=3030,
+  serialized_start=2485,
+  serialized_end=3017,
 )
 
 
@@ -1125,8 +1121,8 @@ _USERDEFINEDAGGREGATEFUNCTIONS = _descriptor.Descriptor(
   extension_ranges=[],
   oneofs=[
   ],
-  serialized_start=3033,
-  serialized_end=3542,
+  serialized_start=3020,
+  serialized_end=3529,
 )
 
 
@@ -1163,8 +1159,8 @@ _SCHEMA_MAPINFO = _descriptor.Descriptor(
   extension_ranges=[],
   oneofs=[
   ],
-  serialized_start=3620,
-  serialized_end=3771,
+  serialized_start=3607,
+  serialized_end=3758,
 )
 
 _SCHEMA_TIMEINFO = _descriptor.Descriptor(
@@ -1193,8 +1189,8 @@ _SCHEMA_TIMEINFO = _descriptor.Descriptor(
   extension_ranges=[],
   oneofs=[
   ],
-  serialized_start=3773,
-  serialized_end=3802,
+  serialized_start=3760,
+  serialized_end=3789,
 )
 
 _SCHEMA_TIMESTAMPINFO = _descriptor.Descriptor(
@@ -1223,8 +1219,8 @@ _SCHEMA_TIMESTAMPINFO = _descriptor.Descriptor(
   extension_ranges=[],
   oneofs=[
   ],
-  serialized_start=3804,
-  serialized_end=3838,
+  serialized_start=3791,
+  serialized_end=3825,
 )
 
 _SCHEMA_LOCALZONEDTIMESTAMPINFO = _descriptor.Descriptor(
@@ -1253,8 +1249,8 @@ _SCHEMA_LOCALZONEDTIMESTAMPINFO = _descriptor.Descriptor(
   extension_ranges=[],
   oneofs=[
   ],
-  serialized_start=3840,
-  serialized_end=3884,
+  serialized_start=3827,
+  serialized_end=3871,
 )
 
 _SCHEMA_ZONEDTIMESTAMPINFO = _descriptor.Descriptor(
@@ -1283,8 +1279,8 @@ _SCHEMA_ZONEDTIMESTAMPINFO = _descriptor.Descriptor(
   extension_ranges=[],
   oneofs=[
   ],
-  serialized_start=3886,
-  serialized_end=3925,
+  serialized_start=3873,
+  serialized_end=3912,
 )
 
 _SCHEMA_DECIMALINFO = _descriptor.Descriptor(
@@ -1320,8 +1316,8 @@ _SCHEMA_DECIMALINFO = _descriptor.Descriptor(
   extension_ranges=[],
   oneofs=[
   ],
-  serialized_start=3927,
-  serialized_end=3974,
+  serialized_start=3914,
+  serialized_end=3961,
 )
 
 _SCHEMA_BINARYINFO = _descriptor.Descriptor(
@@ -1350,8 +1346,8 @@ _SCHEMA_BINARYINFO = _descriptor.Descriptor(
   extension_ranges=[],
   oneofs=[
   ],
-  serialized_start=3976,
-  serialized_end=4004,
+  serialized_start=3963,
+  serialized_end=3991,
 )
 
 _SCHEMA_VARBINARYINFO = _descriptor.Descriptor(
@@ -1380,8 +1376,8 @@ _SCHEMA_VARBINARYINFO = _descriptor.Descriptor(
   extension_ranges=[],
   oneofs=[
   ],
-  serialized_start=4006,
-  serialized_end=4037,
+  serialized_start=3993,
+  serialized_end=4024,
 )
 
 _SCHEMA_CHARINFO = _descriptor.Descriptor(
@@ -1410,8 +1406,8 @@ _SCHEMA_CHARINFO = _descriptor.Descriptor(
   extension_ranges=[],
   oneofs=[
   ],
-  serialized_start=4039,
-  serialized_end=4065,
+  serialized_start=4026,
+  serialized_end=4052,
 )
 
 _SCHEMA_VARCHARINFO = _descriptor.Descriptor(
@@ -1440,8 +1436,8 @@ _SCHEMA_VARCHARINFO = _descriptor.Descriptor(
   extension_ranges=[],
   oneofs=[
   ],
-  serialized_start=4067,
-  serialized_end=4096,
+  serialized_start=4054,
+  serialized_end=4083,
 )
 
 _SCHEMA_FIELDTYPE = _descriptor.Descriptor(
@@ -1564,8 +1560,8 @@ _SCHEMA_FIELDTYPE = _descriptor.Descriptor(
       name='type_info', 
full_name='org.apache.flink.fn_execution.v1.Schema.FieldType.type_info',
       index=0, containing_type=None, fields=[]),
   ],
-  serialized_start=4099,
-  serialized_end=5171,
+  serialized_start=4086,
+  serialized_end=5158,
 )
 
 _SCHEMA_FIELD = _descriptor.Descriptor(
@@ -1608,8 +1604,8 @@ _SCHEMA_FIELD = _descriptor.Descriptor(
   extension_ranges=[],
   oneofs=[
   ],
-  serialized_start=5173,
-  serialized_end=5281,
+  serialized_start=5160,
+  serialized_end=5268,
 )
 
 _SCHEMA = _descriptor.Descriptor(
@@ -1639,8 +1635,8 @@ _SCHEMA = _descriptor.Descriptor(
   extension_ranges=[],
   oneofs=[
   ],
-  serialized_start=3545,
-  serialized_end=5573,
+  serialized_start=3532,
+  serialized_end=5560,
 )
 
 
@@ -1677,8 +1673,8 @@ _TYPEINFO_MAPTYPEINFO = _descriptor.Descriptor(
   extension_ranges=[],
   oneofs=[
   ],
-  serialized_start=5987,
-  serialized_end=6126,
+  serialized_start=5974,
+  serialized_end=6113,
 )
 
 _TYPEINFO_ROWTYPEINFO_FIELD = _descriptor.Descriptor(
@@ -1714,8 +1710,8 @@ _TYPEINFO_ROWTYPEINFO_FIELD = _descriptor.Descriptor(
   extension_ranges=[],
   oneofs=[
   ],
-  serialized_start=6222,
-  serialized_end=6313,
+  serialized_start=6209,
+  serialized_end=6300,
 )
 
 _TYPEINFO_ROWTYPEINFO = _descriptor.Descriptor(
@@ -1744,8 +1740,8 @@ _TYPEINFO_ROWTYPEINFO = _descriptor.Descriptor(
   extension_ranges=[],
   oneofs=[
   ],
-  serialized_start=6129,
-  serialized_end=6313,
+  serialized_start=6116,
+  serialized_end=6300,
 )
 
 _TYPEINFO_TUPLETYPEINFO = _descriptor.Descriptor(
@@ -1774,8 +1770,8 @@ _TYPEINFO_TUPLETYPEINFO = _descriptor.Descriptor(
   extension_ranges=[],
   oneofs=[
   ],
-  serialized_start=6315,
-  serialized_end=6395,
+  serialized_start=6302,
+  serialized_end=6382,
 )
 
 _TYPEINFO = _descriptor.Descriptor(
@@ -1836,8 +1832,8 @@ _TYPEINFO = _descriptor.Descriptor(
       name='type_info', 
full_name='org.apache.flink.fn_execution.v1.TypeInfo.type_info',
       index=0, containing_type=None, fields=[]),
   ],
-  serialized_start=5576,
-  serialized_end=6688,
+  serialized_start=5563,
+  serialized_end=6675,
 )
 
 _INPUT.fields_by_name['udf'].message_type = _USERDEFINEDFUNCTION
diff --git a/flink-python/pyflink/fn_execution/operation_utils.py 
b/flink-python/pyflink/fn_execution/operation_utils.py
index aaa2add..8b65782 100644
--- a/flink-python/pyflink/fn_execution/operation_utils.py
+++ b/flink-python/pyflink/fn_execution/operation_utils.py
@@ -237,12 +237,6 @@ def extract_data_stream_stateless_function(udf_proto):
         func = user_defined_func.map
     elif func_type == UserDefinedDataStreamFunction.FLAT_MAP:
         func = user_defined_func.flat_map
-    elif func_type == UserDefinedDataStreamFunction.REDUCE:
-        reduce_func = user_defined_func.reduce
-
-        def wrapped_func(value):
-            return reduce_func(value[0], value[1])
-        func = wrapped_func
     elif func_type == UserDefinedDataStreamFunction.CO_MAP:
         co_map_func = user_defined_func
 
diff --git a/flink-python/pyflink/proto/flink-fn-execution.proto 
b/flink-python/pyflink/proto/flink-fn-execution.proto
index c34dede..3667997 100644
--- a/flink-python/pyflink/proto/flink-fn-execution.proto
+++ b/flink-python/pyflink/proto/flink-fn-execution.proto
@@ -81,12 +81,11 @@ message UserDefinedDataStreamFunction {
   enum FunctionType {
     MAP = 0;
     FLAT_MAP = 1;
-    REDUCE = 2;
-    CO_MAP = 3;
-    CO_FLAT_MAP = 4;
-    PROCESS = 5;
-    KEYED_PROCESS = 6;
-    TIMESTAMP_ASSIGNER = 7;
+    CO_MAP = 2;
+    CO_FLAT_MAP = 3;
+    PROCESS = 4;
+    KEYED_PROCESS = 5;
+    TIMESTAMP_ASSIGNER = 6;
   }
 
   message JobParameter {
diff --git 
a/flink-python/src/main/java/org/apache/flink/streaming/api/functions/python/KeyByKeySelector.java
 
b/flink-python/src/main/java/org/apache/flink/streaming/api/functions/python/KeyByKeySelector.java
index b4d7b41..4cacdc9 100644
--- 
a/flink-python/src/main/java/org/apache/flink/streaming/api/functions/python/KeyByKeySelector.java
+++ 
b/flink-python/src/main/java/org/apache/flink/streaming/api/functions/python/KeyByKeySelector.java
@@ -21,34 +21,20 @@ import org.apache.flink.annotation.Internal;
 import org.apache.flink.api.java.functions.KeySelector;
 import org.apache.flink.types.Row;
 
-import net.razorvine.pickle.Unpickler;
-
 /**
  * {@link KeyByKeySelector} is responsible for extracting the first field of 
the input row as key.
  * The input row is generated by python DataStream map function in the format 
of
  * (key_selector.get_key(value), value) tuple2.
  */
 @Internal
-public class KeyByKeySelector implements KeySelector<Row, Object> {
+public class KeyByKeySelector implements KeySelector<Row, Row> {
     private static final long serialVersionUID = 1L;
 
-    private final boolean isKeyOfPickledByteArray;
-    private transient Unpickler unpickler = null;
-
-    public KeyByKeySelector(boolean isKeyPickledByteArray) {
-        this.isKeyOfPickledByteArray = isKeyPickledByteArray;
-    }
-
     @Override
-    public Object getKey(Row value) throws Exception {
-        Object key = value.getField(0);
-        if (!isKeyOfPickledByteArray) {
-            return key;
-        } else {
-            if (this.unpickler == null) {
-                unpickler = new Unpickler();
-            }
-            return this.unpickler.loads((byte[]) key);
-        }
+    public Row getKey(Row value) {
+        Object realKey = value.getField(0);
+        Row wrapper = new Row(1);
+        wrapper.setField(0, realKey);
+        return wrapper;
     }
 }
diff --git 
a/flink-python/src/main/java/org/apache/flink/streaming/api/operators/python/PythonReduceOperator.java
 
b/flink-python/src/main/java/org/apache/flink/streaming/api/operators/python/PythonReduceOperator.java
deleted file mode 100644
index 9c57811..0000000
--- 
a/flink-python/src/main/java/org/apache/flink/streaming/api/operators/python/PythonReduceOperator.java
+++ /dev/null
@@ -1,106 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *     http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.flink.streaming.api.operators.python;
-
-import org.apache.flink.annotation.Internal;
-import org.apache.flink.api.common.state.ValueState;
-import org.apache.flink.api.common.state.ValueStateDescriptor;
-import org.apache.flink.api.common.typeinfo.TypeInformation;
-import org.apache.flink.api.java.tuple.Tuple2;
-import org.apache.flink.api.java.typeutils.RowTypeInfo;
-import org.apache.flink.configuration.Configuration;
-import 
org.apache.flink.streaming.api.functions.python.DataStreamPythonFunctionInfo;
-import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
-import org.apache.flink.types.Row;
-
-/**
- * {@link PythonReduceOperator} is responsible for launching beam runner which 
will start a python
- * harness to execute user defined python ReduceFunction.
- */
-@Internal
-public class PythonReduceOperator<OUT> extends 
OneInputPythonFunctionOperator<Row, OUT, Row, OUT> {
-
-    private static final long serialVersionUID = 1L;
-
-    private static final String MAP_CODER_URN = "flink:coder:map:v1";
-
-    private static final String STATE_NAME = "_python_reduce_state";
-
-    /** This state is used to store the currently reduce value. */
-    private transient ValueState<OUT> valueState;
-
-    private transient Row reuseRow;
-
-    public PythonReduceOperator(
-            Configuration config,
-            TypeInformation<Row> inputTypeInfo,
-            TypeInformation<OUT> outputTypeInfo,
-            DataStreamPythonFunctionInfo pythonFunctionInfo) {
-        super(
-                config,
-                new RowTypeInfo(outputTypeInfo, outputTypeInfo),
-                outputTypeInfo,
-                pythonFunctionInfo);
-    }
-
-    @Override
-    public void open() throws Exception {
-        super.open();
-
-        // create state
-        ValueStateDescriptor<OUT> stateId =
-                new ValueStateDescriptor<>(STATE_NAME, runnerOutputTypeInfo);
-        valueState = getPartitionedState(stateId);
-
-        reuseRow = new Row(2);
-    }
-
-    @Override
-    public void processElement(StreamRecord<Row> element) throws Exception {
-        OUT inputData = (OUT) element.getValue().getField(1);
-        OUT currentValue = valueState.value();
-        if (currentValue == null) {
-            // emit directly for the first element.
-            valueState.update(inputData);
-            collector.setAbsoluteTimestamp(element.getTimestamp());
-            collector.collect(inputData);
-        } else {
-            reuseRow.setField(0, currentValue);
-            reuseRow.setField(1, inputData);
-            element.replace(reuseRow);
-            super.processElement(element);
-        }
-    }
-
-    @Override
-    public void emitResult(Tuple2<byte[], Integer> resultTuple) throws 
Exception {
-        byte[] rawResult = resultTuple.f0;
-        int length = resultTuple.f1;
-        bais.setBuffer(rawResult, 0, length);
-        OUT result = runnerOutputTypeSerializer.deserialize(baisWrapper);
-        valueState.update(result);
-        collector.setAbsoluteTimestamp(bufferedTimestamp.poll());
-        collector.collect(result);
-    }
-
-    @Override
-    public String getCoderUrn() {
-        return MAP_CODER_URN;
-    }
-}
diff --git 
a/flink-python/src/main/java/org/apache/flink/streaming/api/runners/python/beam/BeamPythonFunctionRunner.java
 
b/flink-python/src/main/java/org/apache/flink/streaming/api/runners/python/beam/BeamPythonFunctionRunner.java
index 92ba04e..def9057 100644
--- 
a/flink-python/src/main/java/org/apache/flink/streaming/api/runners/python/beam/BeamPythonFunctionRunner.java
+++ 
b/flink-python/src/main/java/org/apache/flink/streaming/api/runners/python/beam/BeamPythonFunctionRunner.java
@@ -28,6 +28,7 @@ import org.apache.flink.api.common.state.StateDescriptor;
 import org.apache.flink.api.common.typeinfo.PrimitiveArrayTypeInfo;
 import org.apache.flink.api.common.typeutils.TypeSerializer;
 import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.typeutils.runtime.RowSerializer;
 import org.apache.flink.core.memory.ByteArrayInputStreamWithPos;
 import org.apache.flink.core.memory.ByteArrayOutputStreamWithPos;
 import org.apache.flink.core.memory.DataInputViewStreamWrapper;
@@ -47,6 +48,7 @@ import org.apache.flink.runtime.state.VoidNamespaceSerializer;
 import org.apache.flink.streaming.api.utils.ByteArrayWrapper;
 import org.apache.flink.streaming.api.utils.ByteArrayWrapperSerializer;
 import org.apache.flink.table.data.RowData;
+import org.apache.flink.table.runtime.typeutils.AbstractRowDataSerializer;
 import org.apache.flink.table.runtime.typeutils.RowDataSerializer;
 import org.apache.flink.util.Preconditions;
 import org.apache.flink.util.function.LongFunctionWithException;
@@ -657,6 +659,12 @@ public abstract class BeamPythonFunctionRunner implements 
PythonFunctionRunner {
                 TypeSerializer keySerializer,
                 Map<String, String> config) {
             this.keyedStateBackend = keyedStateBackend;
+            TypeSerializer frameworkKeySerializer = 
keyedStateBackend.getKeySerializer();
+            if (!(frameworkKeySerializer instanceof AbstractRowDataSerializer
+                    || frameworkKeySerializer instanceof RowSerializer)) {
+                throw new RuntimeException(
+                        "Currently SimpleStateRequestHandler only support row 
key!");
+            }
             this.keySerializer = keySerializer;
             this.valueSerializer =
                     
PrimitiveArrayTypeInfo.BYTE_PRIMITIVE_ARRAY_TYPE_INFO.createSerializer(

Reply via email to