This is an automated email from the ASF dual-hosted git repository.
weizhong pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git
The following commit(s) were added to refs/heads/master by this push:
new 582c7b5 [FLINK-21242][python] Support state access API for the
map/flat_map/filter/reduce operations of Python KeyedStream
582c7b5 is described below
commit 582c7b5596738ab2c8fa73fbbc2e09189b863fa0
Author: Wei Zhong <[email protected]>
AuthorDate: Thu Mar 11 11:16:48 2021 +0800
[FLINK-21242][python] Support state access API for the
map/flat_map/filter/reduce operations of Python KeyedStream
This closes #15083
---
flink-python/pyflink/common/typeinfo.py | 3 +-
flink-python/pyflink/datastream/data_stream.py | 117 +++++++++++++----
flink-python/pyflink/datastream/state.py | 6 +-
.../pyflink/datastream/tests/test_data_stream.py | 146 ++++++++++++++++++++-
.../tests/test_stream_execution_environment.py | 17 ++-
.../pyflink/fn_execution/coder_impl_fast.pxd | 5 -
.../pyflink/fn_execution/coder_impl_fast.pyx | 100 +++++---------
.../pyflink/fn_execution/flink_fn_execution_pb2.py | 134 +++++++++----------
.../pyflink/fn_execution/operation_utils.py | 6 -
.../pyflink/proto/flink-fn-execution.proto | 11 +-
.../api/functions/python/KeyByKeySelector.java | 26 +---
.../api/operators/python/PythonReduceOperator.java | 106 ---------------
.../python/beam/BeamPythonFunctionRunner.java | 8 ++
13 files changed, 372 insertions(+), 313 deletions(-)
diff --git a/flink-python/pyflink/common/typeinfo.py
b/flink-python/pyflink/common/typeinfo.py
index 548a81c..8467241 100644
--- a/flink-python/pyflink/common/typeinfo.py
+++ b/flink-python/pyflink/common/typeinfo.py
@@ -886,7 +886,8 @@ def _from_java_type(j_type_info: JavaObject) ->
TypeInformation:
j_row_field_types = j_type_info.getFieldTypes()
row_field_types = [_from_java_type(j_row_field_type) for
j_row_field_type in
j_row_field_types]
- return Types.ROW_NAMED(j_row_field_names, row_field_types)
+ row_field_names = [field_name for field_name in j_row_field_names]
+ return Types.ROW_NAMED(row_field_names, row_field_types)
JTupleTypeInfo =
gateway.jvm.org.apache.flink.api.java.typeutils.TupleTypeInfo
if _is_instance_of(j_type_info, JTupleTypeInfo):
diff --git a/flink-python/pyflink/datastream/data_stream.py
b/flink-python/pyflink/datastream/data_stream.py
index 3c32c47..9c297cf 100644
--- a/flink-python/pyflink/datastream/data_stream.py
+++ b/flink-python/pyflink/datastream/data_stream.py
@@ -15,16 +15,18 @@
# See the License for the specific language governing permissions and
# limitations under the License.
################################################################################
+import uuid
from typing import Callable, Union, List
from pyflink.common import typeinfo, ExecutionConfig, Row
-from pyflink.common.typeinfo import RowTypeInfo, Types, TypeInformation
+from pyflink.common.typeinfo import RowTypeInfo, Types, TypeInformation,
_from_java_type
from pyflink.common.watermark_strategy import WatermarkStrategy
from pyflink.datastream.functions import _get_python_env,
FlatMapFunctionWrapper, FlatMapFunction, \
MapFunction, MapFunctionWrapper, Function, FunctionWrapper, SinkFunction,
FilterFunction, \
FilterFunctionWrapper, KeySelectorFunctionWrapper, KeySelector,
ReduceFunction, \
ReduceFunctionWrapper, CoMapFunction, CoFlatMapFunction, Partitioner, \
PartitionerFunctionWrapper, RuntimeContext, ProcessFunction,
KeyedProcessFunction
+from pyflink.datastream.state import ValueStateDescriptor, ValueState
from pyflink.datastream.utils import convert_to_python_obj
from pyflink.java_gateway import get_gateway
@@ -289,10 +291,8 @@ class DataStream(object):
output_type_info = typeinfo._from_java_type(
self._j_data_stream.getTransformation().getOutputType())
- is_key_pickled_byte_array = False
if key_type_info is None:
key_type_info = Types.PICKLED_BYTE_ARRAY()
- is_key_pickled_byte_array = True
intermediate_map_stream = self.map(
lambda x: Row(key_selector.get_key(x), x), # type: ignore
@@ -303,8 +303,8 @@ class DataStream(object):
.STREAM_KEY_BY_MAP_OPERATOR_NAME)
key_stream = KeyedStream(
intermediate_map_stream._j_data_stream.keyBy(
- JKeyByKeySelector(is_key_pickled_byte_array),
- key_type_info.get_java_type_info()), output_type_info,
+ JKeyByKeySelector(),
+ Types.ROW([key_type_info]).get_java_type_info()),
output_type_info,
self)
return key_stream
@@ -785,11 +785,49 @@ class KeyedStream(DataStream):
def map(self, func: Union[Callable, MapFunction], output_type:
TypeInformation = None) \
-> 'DataStream':
- return self._values().map(func, output_type)
+ if not isinstance(func, MapFunction):
+ if callable(func):
+ func = MapFunctionWrapper(func) # type: ignore
+ else:
+ raise TypeError("The input func must be a MapFunction or a
callable function.")
+
+ class KeyedMapFunctionWrapper(KeyedProcessFunction):
+
+ def __init__(self, underlying: MapFunction):
+ self._underlying = underlying
+
+ def open(self, runtime_context: RuntimeContext):
+ self._underlying.open(runtime_context)
+
+ def close(self):
+ self._underlying.close()
+
+ def process_element(self, value, ctx:
'KeyedProcessFunction.Context'):
+ return [self._underlying.map(value)]
+ return self.process(KeyedMapFunctionWrapper(func), output_type) #
type: ignore
def flat_map(self, func: Union[Callable, FlatMapFunction], result_type:
TypeInformation = None)\
-> 'DataStream':
- return self._values().flat_map(func, result_type)
+ if not isinstance(func, FlatMapFunction):
+ if callable(func):
+ func = FlatMapFunctionWrapper(func) # type: ignore
+ else:
+ raise TypeError("The input func must be a FlatMapFunction or a
callable function.")
+
+ class KeyedFlatMapFunctionWrapper(KeyedProcessFunction):
+
+ def __init__(self, underlying: FlatMapFunction):
+ self._underlying = underlying
+
+ def open(self, runtime_context: RuntimeContext):
+ self._underlying.open(runtime_context)
+
+ def close(self):
+ self._underlying.close()
+
+ def process_element(self, value, ctx:
'KeyedProcessFunction.Context'):
+ yield from self._underlying.flat_map(value)
+ return self.process(KeyedFlatMapFunctionWrapper(func), result_type) #
type: ignore
def reduce(self, func: Union[Callable, ReduceFunction]) -> 'DataStream':
"""
@@ -810,20 +848,56 @@ class KeyedStream(DataStream):
if callable(func):
func = ReduceFunctionWrapper(func) # type: ignore
else:
- raise TypeError("The input must be a ReduceFunction or a
callable function!")
+ raise TypeError("The input func must be a ReduceFunction or a
callable function.")
+ output_type =
_from_java_type(self._original_data_type_info.get_java_type_info())
- from pyflink.fn_execution.flink_fn_execution_pb2 import
UserDefinedDataStreamFunction
- j_operator, j_output_type_info = \
- _get_one_input_stream_operator(
- self, func, UserDefinedDataStreamFunction.REDUCE) # type:
ignore
- return DataStream(self._j_data_stream.transform(
- "Keyed Reduce",
- j_output_type_info,
- j_operator
- ))
+ class KeyedReduceFunctionWrapper(KeyedProcessFunction):
+
+ def __init__(self, underlying: ReduceFunction):
+ self._underlying = underlying
+ self._reduce_state_name = "_reduce_state" + str(uuid.uuid4())
+ self._reduce_value_state = None # type: ValueState
+
+ def open(self, runtime_context: RuntimeContext):
+ self._reduce_value_state = runtime_context.get_state(
+ ValueStateDescriptor(self._reduce_state_name, output_type))
+ self._underlying.open(runtime_context)
+
+ def process_element(self, value, ctx:
'KeyedProcessFunction.Context'):
+ reduce_value = self._reduce_value_state.value()
+ if reduce_value is not None:
+ reduce_value = self._underlying.reduce(reduce_value, value)
+ else:
+ reduce_value = value
+ self._reduce_value_state.update(reduce_value)
+ return [reduce_value]
+
+ return self.process(KeyedReduceFunctionWrapper(func), output_type) #
type: ignore
def filter(self, func: Union[Callable, FilterFunction]) -> 'DataStream':
- return self._values().filter(func)
+ if callable(func):
+ func = FilterFunctionWrapper(func) # type: ignore
+ elif not isinstance(func, FilterFunction):
+ raise TypeError("The input func must be a FilterFunction or a
callable function.")
+
+ class KeyedFilterFunctionWrapper(KeyedProcessFunction):
+
+ def __init__(self, underlying: FilterFunction):
+ self._underlying = underlying
+
+ def open(self, runtime_context: RuntimeContext):
+ self._underlying.open(runtime_context)
+
+ def close(self):
+ self._underlying.close()
+
+ def process_element(self, value, ctx:
'KeyedProcessFunction.Context'):
+ if self._underlying.filter(value):
+ return [value]
+ else:
+ return []
+ return self.process(
+ KeyedFilterFunctionWrapper(func), self._original_data_type_info)
# type: ignore
def add_sink(self, sink_func: SinkFunction) -> 'DataStreamSink':
return self._values().add_sink(sink_func)
@@ -1067,12 +1141,7 @@ def _get_one_input_stream_operator(data_stream:
DataStream,
j_conf = gateway.jvm.org.apache.flink.configuration.Configuration()
from pyflink.fn_execution.flink_fn_execution_pb2 import
UserDefinedDataStreamFunction
- if func_type == UserDefinedDataStreamFunction.REDUCE: # type: ignore
- # set max bundle size to 1 to force synchronize process for reduce
function.
-
j_conf.setInteger(gateway.jvm.org.apache.flink.python.PythonOptions.MAX_BUNDLE_SIZE,
1)
- j_output_type_info = j_input_types.getTypeAt(1)
- JDataStreamPythonFunctionOperator = gateway.jvm.PythonReduceOperator
- elif func_type == UserDefinedDataStreamFunction.MAP: # type: ignore
+ if func_type == UserDefinedDataStreamFunction.MAP: # type: ignore
if str(func) == '_Flink_PartitionCustomMapFunction':
JDataStreamPythonFunctionOperator =
gateway.jvm.PythonPartitionCustomOperator
else:
diff --git a/flink-python/pyflink/datastream/state.py
b/flink-python/pyflink/datastream/state.py
index 4c11b29..a576d41 100644
--- a/flink-python/pyflink/datastream/state.py
+++ b/flink-python/pyflink/datastream/state.py
@@ -29,7 +29,9 @@ __all__ = [
'MapStateDescriptor',
'MapState',
'ReducingStateDescriptor',
- 'ReducingState'
+ 'ReducingState',
+ 'AggregatingStateDescriptor',
+ 'AggregatingState'
]
T = TypeVar('T')
@@ -395,7 +397,7 @@ class AggregatingStateDescriptor(StateDescriptor):
super(AggregatingStateDescriptor, self).__init__(name, state_type_info)
from pyflink.datastream.functions import AggregateFunction
if not isinstance(agg_function, AggregateFunction):
- raise TypeError("The input must be a AggregateFunction!")
+ raise TypeError("The input must be a
pyflink.datastream.functions.AggregateFunction!")
self._agg_function = agg_function
def get_agg_function(self):
diff --git a/flink-python/pyflink/datastream/tests/test_data_stream.py
b/flink-python/pyflink/datastream/tests/test_data_stream.py
index 3d7a0da..82b07d8 100644
--- a/flink-python/pyflink/datastream/tests/test_data_stream.py
+++ b/flink-python/pyflink/datastream/tests/test_data_stream.py
@@ -25,7 +25,8 @@ from pyflink.common.typeinfo import Types
from pyflink.common.watermark_strategy import WatermarkStrategy,
TimestampAssigner
from pyflink.datastream import StreamExecutionEnvironment, TimeCharacteristic,
RuntimeContext
from pyflink.datastream.data_stream import DataStream
-from pyflink.datastream.functions import CoMapFunction, CoFlatMapFunction,
AggregateFunction
+from pyflink.datastream.functions import CoMapFunction, CoFlatMapFunction,
AggregateFunction, \
+ ReduceFunction
from pyflink.datastream.functions import FilterFunction, ProcessFunction,
KeyedProcessFunction
from pyflink.datastream.functions import KeySelector
from pyflink.datastream.functions import MapFunction, FlatMapFunction
@@ -383,13 +384,28 @@ class DataStreamTests(PyFlinkTestCase):
class AssertKeyMapFunction(MapFunction):
def __init__(self):
self.pre = None
+ self.state = None
+
+ def open(self, runtime_context: RuntimeContext):
+ self.state = runtime_context.get_state(
+ ValueStateDescriptor("test_state", Types.INT()))
def map(self, value):
+ state_value = self.state.value()
+ if state_value is None:
+ state_value = 1
+ else:
+ state_value += 1
if value[0] == 'b':
assert self.pre == 'a'
+ assert state_value == 2
if value[0] == 'd':
assert self.pre == 'c'
+ assert state_value == 2
+ if value[0] == 'e':
+ assert state_value == 1
self.pre = value[0]
+ self.state.update(state_value)
return value
keyed_stream.map(AssertKeyMapFunction()).add_sink(self.test_sink)
@@ -401,6 +417,134 @@ class DataStreamTests(PyFlinkTestCase):
expected.sort()
self.assertEqual(expected, results)
+ def test_key_by_flat_map(self):
+ ds = self.env.from_collection([('a', 0), ('b', 0), ('c', 1), ('d', 1),
('e', 2)],
+ type_info=Types.ROW([Types.STRING(),
Types.INT()]))
+ keyed_stream = ds.key_by(MyKeySelector(), key_type_info=Types.INT())
+
+ with self.assertRaises(Exception):
+ keyed_stream.name("keyed stream")
+
+ class AssertKeyMapFunction(FlatMapFunction):
+ def __init__(self):
+ self.pre = None
+ self.state = None
+
+ def open(self, runtime_context: RuntimeContext):
+ self.state = runtime_context.get_state(
+ ValueStateDescriptor("test_state", Types.INT()))
+
+ def flat_map(self, value):
+ state_value = self.state.value()
+ if state_value is None:
+ state_value = 1
+ else:
+ state_value += 1
+ if value[0] == 'b':
+ assert self.pre == 'a'
+ assert state_value == 2
+ if value[0] == 'd':
+ assert self.pre == 'c'
+ assert state_value == 2
+ if value[0] == 'e':
+ assert state_value == 1
+ self.pre = value[0]
+ self.state.update(state_value)
+ yield value
+
+ keyed_stream.flat_map(AssertKeyMapFunction()).add_sink(self.test_sink)
+ self.env.execute('key_by_test')
+ results = self.test_sink.get_results(True)
+ expected = ["Row(f0='e', f1=2)", "Row(f0='a', f1=0)", "Row(f0='b',
f1=0)",
+ "Row(f0='c', f1=1)", "Row(f0='d', f1=1)"]
+ results.sort()
+ expected.sort()
+ self.assertEqual(expected, results)
+
+ def test_key_by_filter(self):
+ ds = self.env.from_collection([('a', 0), ('b', 0), ('c', 1), ('d', 1),
('e', 2)],
+ type_info=Types.ROW([Types.STRING(),
Types.INT()]))
+ keyed_stream = ds.key_by(MyKeySelector())
+
+ with self.assertRaises(Exception):
+ keyed_stream.name("keyed stream")
+
+ class AssertKeyFilterFunction(FilterFunction):
+ def __init__(self):
+ self.pre = None
+ self.state = None
+
+ def open(self, runtime_context: RuntimeContext):
+ self.state = runtime_context.get_state(
+ ValueStateDescriptor("test_state", Types.INT()))
+
+ def filter(self, value):
+ state_value = self.state.value()
+ if state_value is None:
+ state_value = 1
+ else:
+ state_value += 1
+ if value[0] == 'b':
+ assert self.pre == 'a'
+ assert state_value == 2
+ return False
+ if value[0] == 'd':
+ assert self.pre == 'c'
+ assert state_value == 2
+ return False
+ if value[0] == 'e':
+ assert state_value == 1
+ self.pre = value[0]
+ self.state.update(state_value)
+ return True
+
+ keyed_stream.filter(AssertKeyFilterFunction()).add_sink(self.test_sink)
+ self.env.execute('key_by_test')
+ results = self.test_sink.get_results(False)
+ expected = ['+I[a, 0]', '+I[c, 1]', '+I[e, 2]']
+ results.sort()
+ expected.sort()
+ self.assertEqual(expected, results)
+
+ def test_reduce_with_state(self):
+ ds = self.env.from_collection([('a', 0), ('b', 0), ('c', 1), ('d', 1),
('e', 1)],
+ type_info=Types.ROW([Types.STRING(),
Types.INT()]))
+ keyed_stream = ds.key_by(MyKeySelector(), key_type_info=Types.INT())
+
+ with self.assertRaises(Exception):
+ keyed_stream.name("keyed stream")
+
+ class AssertKeyReduceFunction(ReduceFunction):
+
+ def __init__(self):
+ self.state = None
+
+ def open(self, runtime_context: RuntimeContext):
+ self.state = runtime_context.get_state(
+ ValueStateDescriptor("test_state", Types.INT()))
+
+ def reduce(self, value1, value2):
+ state_value = self.state.value()
+ if state_value is None:
+ state_value = 2
+ else:
+ state_value += 1
+ result_value = Row(value1[0] + value2[0], value1[1])
+ if result_value[0] == 'ab':
+ assert state_value == 2
+ if result_value[0] == 'cde':
+ assert state_value == 3
+ self.state.update(state_value)
+ return result_value
+
+ keyed_stream.reduce(AssertKeyReduceFunction()).add_sink(self.test_sink)
+ self.env.execute('key_by_test')
+ results = self.test_sink.get_results(False)
+ expected = ['+I[a, 0]', '+I[ab, 0]', '+I[c, 1]', '+I[cd, 1]', '+I[cde,
1]']
+ results.sort()
+ expected.sort()
+ self.assertEqual(expected, results)
+
def test_multi_key_by(self):
ds = self.env.from_collection([('a', 0), ('b', 0), ('c', 1), ('d', 1),
('e', 2)],
type_info=Types.ROW([Types.STRING(),
Types.INT()]))
diff --git
a/flink-python/pyflink/datastream/tests/test_stream_execution_environment.py
b/flink-python/pyflink/datastream/tests/test_stream_execution_environment.py
index ca5780b..e5efb91 100644
--- a/flink-python/pyflink/datastream/tests/test_stream_execution_environment.py
+++ b/flink-python/pyflink/datastream/tests/test_stream_execution_environment.py
@@ -541,26 +541,25 @@ class StreamExecutionEnvironmentTests(PyFlinkTestCase):
nodes = eval(self.env.get_execution_plan())['nodes']
# The StreamGraph should be as below:
- # Source: From Collection -> _stream_key_by_map_operator ->
_keyed_stream_values_operator ->
+ # Source: From Collection -> _stream_key_by_map_operator ->
# Plus Two Map -> Add From File Map -> Sink: Test Sink.
# Source: From Collection and _stream_key_by_map_operator should have
same parallelism.
self.assertEqual(nodes[0]['parallelism'], nodes[1]['parallelism'])
- # _keyed_stream_values_operator and Plus Two Map should have same
parallisim.
- self.assertEqual(nodes[3]['parallelism'], 3)
- self.assertEqual(nodes[2]['parallelism'], nodes[3]['parallelism'])
+ # The parallelism of Plus Two Map should be 3
+ self.assertEqual(nodes[2]['parallelism'], 3)
- # The ship_strategy for Source: From Collection and
_stream_key_by_map_operator shoule be
+ # The ship_strategy for Source: From Collection and
_stream_key_by_map_operator should be
# FORWARD
self.assertEqual(nodes[1]['predecessors'][0]['ship_strategy'],
"FORWARD")
- # The ship_strategy for _keyed_stream_values_operator and Plus Two Map
shoule be
- # FORWARD
- self.assertEqual(nodes[3]['predecessors'][0]['ship_strategy'],
"FORWARD")
+ # The ship_strategy for _stream_key_by_map_operator and Plus Two Map
should be
+ # HASH
+ self.assertEqual(nodes[2]['predecessors'][0]['ship_strategy'], "HASH")
# The parallelism of Sink: Test Sink should be 4
- self.assertEqual(nodes[5]['parallelism'], 4)
+ self.assertEqual(nodes[4]['parallelism'], 4)
env_config_with_dependencies =
dict(get_gateway().jvm.org.apache.flink.python.util
.PythonConfigUtil.getEnvConfigWithDependencies(
diff --git a/flink-python/pyflink/fn_execution/coder_impl_fast.pxd
b/flink-python/pyflink/fn_execution/coder_impl_fast.pxd
index 4ecad43..e4512ae 100644
--- a/flink-python/pyflink/fn_execution/coder_impl_fast.pxd
+++ b/flink-python/pyflink/fn_execution/coder_impl_fast.pxd
@@ -114,11 +114,6 @@ cdef class TableFunctionRowCoderImpl(FlattenRowCoderImpl):
cdef class DataStreamMapCoderImpl(FlattenRowCoderImpl):
cdef readonly FieldCoder _single_field_coder
- cdef object _decode_data_stream_field_simple(self, TypeName field_type)
- cdef object _decode_data_stream_field_complex(self, TypeName field_type,
FieldCoder field_coder)
- cdef void _encode_data_stream_field_simple(self, TypeName field_type, item)
- cdef void _encode_data_stream_field_complex(self, TypeName field_type,
FieldCoder field_coder,
- item)
cdef class DataStreamFlatMapCoderImpl(BaseCoderImpl):
cdef readonly object _single_field_coder
diff --git a/flink-python/pyflink/fn_execution/coder_impl_fast.pyx
b/flink-python/pyflink/fn_execution/coder_impl_fast.pyx
index 6d98c35..e3143c5 100644
--- a/flink-python/pyflink/fn_execution/coder_impl_fast.pyx
+++ b/flink-python/pyflink/fn_execution/coder_impl_fast.pyx
@@ -25,6 +25,7 @@ from libc.string cimport memcpy
import datetime
import decimal
+import pickle
from pyflink.fn_execution.window import TimeWindow, CountWindow
from pyflink.table import Row
@@ -188,26 +189,6 @@ cdef class DataStreamMapCoderImpl(FlattenRowCoderImpl):
output_stream.write(self._tmp_output_data, self._tmp_output_pos)
self._tmp_output_pos = 0
- cdef void _encode_field(self, CoderType coder_type, TypeName field_type,
FieldCoder field_coder,
- item):
- if coder_type == SIMPLE:
- self._encode_field_simple(field_type, item)
- self._encode_data_stream_field_simple(field_type, item)
- else:
- self._encode_field_complex(field_type, field_coder, item)
- self._encode_data_stream_field_complex(field_type, field_coder,
item)
-
- cdef object _decode_field(self, CoderType coder_type, TypeName field_type,
- FieldCoder field_coder):
- if coder_type == SIMPLE:
- decoded_obj = self._decode_field_simple(field_type)
- return decoded_obj if decoded_obj is not None \
- else self._decode_data_stream_field_simple(field_type)
- else:
- decoded_obj = self._decode_field_complex(field_type, field_coder)
- return decoded_obj if decoded_obj is not None \
- else self._decode_data_stream_field_complex(field_type,
field_coder)
-
cpdef object decode_from_stream(self, LengthPrefixInputStream
input_stream):
input_stream.read(&self._input_data)
self._input_pos = 0
@@ -216,50 +197,6 @@ cdef class DataStreamMapCoderImpl(FlattenRowCoderImpl):
decoded_obj = self._decode_field(coder_type, type_name,
self._single_field_coder)
return decoded_obj
- cdef void _encode_data_stream_field_simple(self, TypeName field_type,
item):
- if field_type == PICKLED_BYTES:
- import pickle
- pickled_bytes = pickle.dumps(item)
- self._encode_bytes(pickled_bytes, len(pickled_bytes))
- elif field_type == BIG_DEC:
- item_bytes = str(item).encode('utf-8')
- self._encode_bytes(item_bytes, len(item_bytes))
-
- cdef void _encode_data_stream_field_complex(self, TypeName field_type,
FieldCoder field_coder,
- item):
- if field_type == TUPLE:
- tuple_field_coders = (<TupleCoderImpl> field_coder).field_coders
- tuple_field_count = len(tuple_field_coders)
- tuple_value = list(item)
- for i in range(tuple_field_count):
- field_item = tuple_value[i]
- tuple_field_coder = tuple_field_coders[i]
- if field_item is not None:
- self._encode_field(tuple_field_coder.coder_type(),
- tuple_field_coder.type_name(),
- tuple_field_coder,
- field_item)
- cdef object _decode_data_stream_field_simple(self, TypeName field_type):
- if field_type == PICKLED_BYTES:
- decoded_bytes = self._decode_bytes()
- import pickle
- return pickle.loads(decoded_bytes)
- elif field_type == BIG_DEC:
- return decimal.Decimal(self._decode_bytes().decode("utf-8"))
-
- cdef object _decode_data_stream_field_complex(self, TypeName field_type,
FieldCoder field_coder):
- if field_type == TUPLE:
- tuple_field_coders = (<TupleCoderImpl> field_coder).field_coders
- tuple_field_count = len(tuple_field_coders)
- decoded_list = []
- for i in range(tuple_field_count):
- decoded_list.append(self._decode_field(
- tuple_field_coders[i].coder_type(),
- tuple_field_coders[i].type_name(),
- tuple_field_coders[i]
- ))
- return (*decoded_list,)
-
ROW_KIND_BIT_SIZE = 2
cdef class FlattenRowCoderImpl(BaseCoderImpl):
@@ -425,6 +362,11 @@ cdef class FlattenRowCoderImpl(BaseCoderImpl):
hours = minutes // 60
minutes %= 60
return datetime.time(hours, minutes, seconds, milliseconds * 1000)
+ elif field_type == PICKLED_BYTES:
+ decoded_bytes = self._decode_bytes()
+ return pickle.loads(decoded_bytes)
+ elif field_type == BIG_DEC:
+ return decimal.Decimal(self._decode_bytes().decode("utf-8"))
cdef object _decode_field_complex(self, TypeName field_type, FieldCoder
field_coder):
cdef libc.stdint.int32_t nanoseconds, microseconds, seconds, length
@@ -501,6 +443,17 @@ cdef class FlattenRowCoderImpl(BaseCoderImpl):
elif field_type == ROW:
# Row
return self._decode_field_row(field_coder)
+ elif field_type == TUPLE:
+ tuple_field_coders = (<TupleCoderImpl> field_coder).field_coders
+ tuple_field_count = len(tuple_field_coders)
+ decoded_list = []
+ for i in range(tuple_field_count):
+ decoded_list.append(self._decode_field(
+ tuple_field_coders[i].coder_type(),
+ tuple_field_coders[i].type_name(),
+ tuple_field_coders[i]
+ ))
+ return (*decoded_list,)
cdef object _decode_field_row(self, RowCoderImpl field_coder):
cdef list row_field_coders, row_field_names
@@ -621,6 +574,13 @@ cdef class FlattenRowCoderImpl(BaseCoderImpl):
microsecond = item.microsecond
milliseconds = hour * 3600000 + minute * 60000 + seconds * 1000 +
microsecond // 1000
self._encode_int(milliseconds)
+ elif field_type == PICKLED_BYTES:
+ # pickled object
+ pickled_bytes = pickle.dumps(item)
+ self._encode_bytes(pickled_bytes, len(pickled_bytes))
+ elif field_type == BIG_DEC:
+ item_bytes = str(item).encode('utf-8')
+ self._encode_bytes(item_bytes, len(item_bytes))
cdef void _encode_field_complex(self, TypeName field_type, FieldCoder
field_coder, item):
cdef libc.stdint.int32_t nanoseconds, microseconds_of_second, length,
row_field_count
@@ -711,6 +671,18 @@ cdef class FlattenRowCoderImpl(BaseCoderImpl):
if field_item is not None:
self._encode_field(row_field_coder.coder_type(),
row_field_coder.type_name(),
row_field_coder, field_item)
+ elif field_type == TUPLE:
+ tuple_field_coders = (<TupleCoderImpl> field_coder).field_coders
+ tuple_field_count = len(tuple_field_coders)
+ tuple_value = list(item)
+ for i in range(tuple_field_count):
+ field_item = tuple_value[i]
+ tuple_field_coder = tuple_field_coders[i]
+ if field_item is not None:
+ self._encode_field(tuple_field_coder.coder_type(),
+ tuple_field_coder.type_name(),
+ tuple_field_coder,
+ field_item)
cdef void _extend(self, size_t missing):
while self._tmp_output_buffer_size < self._tmp_output_pos + missing:
diff --git a/flink-python/pyflink/fn_execution/flink_fn_execution_pb2.py
b/flink-python/pyflink/fn_execution/flink_fn_execution_pb2.py
index 0d3aee9..a9aba93 100644
--- a/flink-python/pyflink/fn_execution/flink_fn_execution_pb2.py
+++ b/flink-python/pyflink/fn_execution/flink_fn_execution_pb2.py
@@ -36,7 +36,7 @@ DESCRIPTOR = _descriptor.FileDescriptor(
name='flink-fn-execution.proto',
package='org.apache.flink.fn_execution.v1',
syntax='proto3',
- serialized_pb=_b('\n\x18\x66link-fn-execution.proto\x12
org.apache.flink.fn_execution.v1\"\x86\x01\n\x05Input\x12\x44\n\x03udf\x18\x01
\x01(\x0b\x32\x35.org.apache.flink.fn_execution.v1.UserDefinedFunctionH\x00\x12\x15\n\x0binputOffset\x18\x02
\x01(\x05H\x00\x12\x17\n\rinputConstant\x18\x03
\x01(\x0cH\x00\x42\x07\n\x05input\"\x91\x01\n\x13UserDefinedFunction\x12\x0f\n\x07payload\x18\x01
\x01(\x0c\x12\x37\n\x06inputs\x18\x02
\x03(\x0b\x32\'.org.apache.flink.fn_execution.v1.Input\x12\x14 [...]
+ serialized_pb=_b('\n\x18\x66link-fn-execution.proto\x12
org.apache.flink.fn_execution.v1\"\x86\x01\n\x05Input\x12\x44\n\x03udf\x18\x01
\x01(\x0b\x32\x35.org.apache.flink.fn_execution.v1.UserDefinedFunctionH\x00\x12\x15\n\x0binputOffset\x18\x02
\x01(\x05H\x00\x12\x17\n\rinputConstant\x18\x03
\x01(\x0cH\x00\x42\x07\n\x05input\"\x91\x01\n\x13UserDefinedFunction\x12\x0f\n\x07payload\x18\x01
\x01(\x0c\x12\x37\n\x06inputs\x18\x02
\x03(\x0b\x32\'.org.apache.flink.fn_execution.v1.Input\x12\x14 [...]
)
@@ -102,34 +102,30 @@ _USERDEFINEDDATASTREAMFUNCTION_FUNCTIONTYPE =
_descriptor.EnumDescriptor(
options=None,
type=None),
_descriptor.EnumValueDescriptor(
- name='REDUCE', index=2, number=2,
+ name='CO_MAP', index=2, number=2,
options=None,
type=None),
_descriptor.EnumValueDescriptor(
- name='CO_MAP', index=3, number=3,
+ name='CO_FLAT_MAP', index=3, number=3,
options=None,
type=None),
_descriptor.EnumValueDescriptor(
- name='CO_FLAT_MAP', index=4, number=4,
+ name='PROCESS', index=4, number=4,
options=None,
type=None),
_descriptor.EnumValueDescriptor(
- name='PROCESS', index=5, number=5,
+ name='KEYED_PROCESS', index=5, number=5,
options=None,
type=None),
_descriptor.EnumValueDescriptor(
- name='KEYED_PROCESS', index=6, number=6,
- options=None,
- type=None),
- _descriptor.EnumValueDescriptor(
- name='TIMESTAMP_ASSIGNER', index=7, number=7,
+ name='TIMESTAMP_ASSIGNER', index=6, number=6,
options=None,
type=None),
],
containing_type=None,
options=None,
- serialized_start=1579,
- serialized_end=1713,
+ serialized_start=1578,
+ serialized_end=1700,
)
_sym_db.RegisterEnumDescriptor(_USERDEFINEDDATASTREAMFUNCTION_FUNCTIONTYPE)
@@ -154,8 +150,8 @@ _GROUPWINDOW_WINDOWTYPE = _descriptor.EnumDescriptor(
],
containing_type=None,
options=None,
- serialized_start=2838,
- serialized_end=2929,
+ serialized_start=2825,
+ serialized_end=2916,
)
_sym_db.RegisterEnumDescriptor(_GROUPWINDOW_WINDOWTYPE)
@@ -184,8 +180,8 @@ _GROUPWINDOW_WINDOWPROPERTY = _descriptor.EnumDescriptor(
],
containing_type=None,
options=None,
- serialized_start=2931,
- serialized_end=3030,
+ serialized_start=2918,
+ serialized_end=3017,
)
_sym_db.RegisterEnumDescriptor(_GROUPWINDOW_WINDOWPROPERTY)
@@ -282,8 +278,8 @@ _SCHEMA_TYPENAME = _descriptor.EnumDescriptor(
],
containing_type=None,
options=None,
- serialized_start=5284,
- serialized_end=5573,
+ serialized_start=5271,
+ serialized_end=5560,
)
_sym_db.RegisterEnumDescriptor(_SCHEMA_TYPENAME)
@@ -380,8 +376,8 @@ _TYPEINFO_TYPENAME = _descriptor.EnumDescriptor(
],
containing_type=None,
options=None,
- serialized_start=6398,
- serialized_end=6675,
+ serialized_start=6385,
+ serialized_end=6662,
)
_sym_db.RegisterEnumDescriptor(_TYPEINFO_TYPENAME)
@@ -742,7 +738,7 @@ _USERDEFINEDDATASTREAMFUNCTION = _descriptor.Descriptor(
oneofs=[
],
serialized_start=881,
- serialized_end=1713,
+ serialized_end=1700,
)
@@ -772,8 +768,8 @@ _USERDEFINEDAGGREGATEFUNCTION_DATAVIEWSPEC_LISTVIEW =
_descriptor.Descriptor(
extension_ranges=[],
oneofs=[
],
- serialized_start=2244,
- serialized_end=2328,
+ serialized_start=2231,
+ serialized_end=2315,
)
_USERDEFINEDAGGREGATEFUNCTION_DATAVIEWSPEC_MAPVIEW = _descriptor.Descriptor(
@@ -809,8 +805,8 @@ _USERDEFINEDAGGREGATEFUNCTION_DATAVIEWSPEC_MAPVIEW =
_descriptor.Descriptor(
extension_ranges=[],
oneofs=[
],
- serialized_start=2331,
- serialized_end=2482,
+ serialized_start=2318,
+ serialized_end=2469,
)
_USERDEFINEDAGGREGATEFUNCTION_DATAVIEWSPEC = _descriptor.Descriptor(
@@ -863,8 +859,8 @@ _USERDEFINEDAGGREGATEFUNCTION_DATAVIEWSPEC =
_descriptor.Descriptor(
name='data_view',
full_name='org.apache.flink.fn_execution.v1.UserDefinedAggregateFunction.DataViewSpec.data_view',
index=0, containing_type=None, fields=[]),
],
- serialized_start=1981,
- serialized_end=2495,
+ serialized_start=1968,
+ serialized_end=2482,
)
_USERDEFINEDAGGREGATEFUNCTION = _descriptor.Descriptor(
@@ -928,8 +924,8 @@ _USERDEFINEDAGGREGATEFUNCTION = _descriptor.Descriptor(
extension_ranges=[],
oneofs=[
],
- serialized_start=1716,
- serialized_end=2495,
+ serialized_start=1703,
+ serialized_end=2482,
)
@@ -1017,8 +1013,8 @@ _GROUPWINDOW = _descriptor.Descriptor(
extension_ranges=[],
oneofs=[
],
- serialized_start=2498,
- serialized_end=3030,
+ serialized_start=2485,
+ serialized_end=3017,
)
@@ -1125,8 +1121,8 @@ _USERDEFINEDAGGREGATEFUNCTIONS = _descriptor.Descriptor(
extension_ranges=[],
oneofs=[
],
- serialized_start=3033,
- serialized_end=3542,
+ serialized_start=3020,
+ serialized_end=3529,
)
@@ -1163,8 +1159,8 @@ _SCHEMA_MAPINFO = _descriptor.Descriptor(
extension_ranges=[],
oneofs=[
],
- serialized_start=3620,
- serialized_end=3771,
+ serialized_start=3607,
+ serialized_end=3758,
)
_SCHEMA_TIMEINFO = _descriptor.Descriptor(
@@ -1193,8 +1189,8 @@ _SCHEMA_TIMEINFO = _descriptor.Descriptor(
extension_ranges=[],
oneofs=[
],
- serialized_start=3773,
- serialized_end=3802,
+ serialized_start=3760,
+ serialized_end=3789,
)
_SCHEMA_TIMESTAMPINFO = _descriptor.Descriptor(
@@ -1223,8 +1219,8 @@ _SCHEMA_TIMESTAMPINFO = _descriptor.Descriptor(
extension_ranges=[],
oneofs=[
],
- serialized_start=3804,
- serialized_end=3838,
+ serialized_start=3791,
+ serialized_end=3825,
)
_SCHEMA_LOCALZONEDTIMESTAMPINFO = _descriptor.Descriptor(
@@ -1253,8 +1249,8 @@ _SCHEMA_LOCALZONEDTIMESTAMPINFO = _descriptor.Descriptor(
extension_ranges=[],
oneofs=[
],
- serialized_start=3840,
- serialized_end=3884,
+ serialized_start=3827,
+ serialized_end=3871,
)
_SCHEMA_ZONEDTIMESTAMPINFO = _descriptor.Descriptor(
@@ -1283,8 +1279,8 @@ _SCHEMA_ZONEDTIMESTAMPINFO = _descriptor.Descriptor(
extension_ranges=[],
oneofs=[
],
- serialized_start=3886,
- serialized_end=3925,
+ serialized_start=3873,
+ serialized_end=3912,
)
_SCHEMA_DECIMALINFO = _descriptor.Descriptor(
@@ -1320,8 +1316,8 @@ _SCHEMA_DECIMALINFO = _descriptor.Descriptor(
extension_ranges=[],
oneofs=[
],
- serialized_start=3927,
- serialized_end=3974,
+ serialized_start=3914,
+ serialized_end=3961,
)
_SCHEMA_BINARYINFO = _descriptor.Descriptor(
@@ -1350,8 +1346,8 @@ _SCHEMA_BINARYINFO = _descriptor.Descriptor(
extension_ranges=[],
oneofs=[
],
- serialized_start=3976,
- serialized_end=4004,
+ serialized_start=3963,
+ serialized_end=3991,
)
_SCHEMA_VARBINARYINFO = _descriptor.Descriptor(
@@ -1380,8 +1376,8 @@ _SCHEMA_VARBINARYINFO = _descriptor.Descriptor(
extension_ranges=[],
oneofs=[
],
- serialized_start=4006,
- serialized_end=4037,
+ serialized_start=3993,
+ serialized_end=4024,
)
_SCHEMA_CHARINFO = _descriptor.Descriptor(
@@ -1410,8 +1406,8 @@ _SCHEMA_CHARINFO = _descriptor.Descriptor(
extension_ranges=[],
oneofs=[
],
- serialized_start=4039,
- serialized_end=4065,
+ serialized_start=4026,
+ serialized_end=4052,
)
_SCHEMA_VARCHARINFO = _descriptor.Descriptor(
@@ -1440,8 +1436,8 @@ _SCHEMA_VARCHARINFO = _descriptor.Descriptor(
extension_ranges=[],
oneofs=[
],
- serialized_start=4067,
- serialized_end=4096,
+ serialized_start=4054,
+ serialized_end=4083,
)
_SCHEMA_FIELDTYPE = _descriptor.Descriptor(
@@ -1564,8 +1560,8 @@ _SCHEMA_FIELDTYPE = _descriptor.Descriptor(
name='type_info',
full_name='org.apache.flink.fn_execution.v1.Schema.FieldType.type_info',
index=0, containing_type=None, fields=[]),
],
- serialized_start=4099,
- serialized_end=5171,
+ serialized_start=4086,
+ serialized_end=5158,
)
_SCHEMA_FIELD = _descriptor.Descriptor(
@@ -1608,8 +1604,8 @@ _SCHEMA_FIELD = _descriptor.Descriptor(
extension_ranges=[],
oneofs=[
],
- serialized_start=5173,
- serialized_end=5281,
+ serialized_start=5160,
+ serialized_end=5268,
)
_SCHEMA = _descriptor.Descriptor(
@@ -1639,8 +1635,8 @@ _SCHEMA = _descriptor.Descriptor(
extension_ranges=[],
oneofs=[
],
- serialized_start=3545,
- serialized_end=5573,
+ serialized_start=3532,
+ serialized_end=5560,
)
@@ -1677,8 +1673,8 @@ _TYPEINFO_MAPTYPEINFO = _descriptor.Descriptor(
extension_ranges=[],
oneofs=[
],
- serialized_start=5987,
- serialized_end=6126,
+ serialized_start=5974,
+ serialized_end=6113,
)
_TYPEINFO_ROWTYPEINFO_FIELD = _descriptor.Descriptor(
@@ -1714,8 +1710,8 @@ _TYPEINFO_ROWTYPEINFO_FIELD = _descriptor.Descriptor(
extension_ranges=[],
oneofs=[
],
- serialized_start=6222,
- serialized_end=6313,
+ serialized_start=6209,
+ serialized_end=6300,
)
_TYPEINFO_ROWTYPEINFO = _descriptor.Descriptor(
@@ -1744,8 +1740,8 @@ _TYPEINFO_ROWTYPEINFO = _descriptor.Descriptor(
extension_ranges=[],
oneofs=[
],
- serialized_start=6129,
- serialized_end=6313,
+ serialized_start=6116,
+ serialized_end=6300,
)
_TYPEINFO_TUPLETYPEINFO = _descriptor.Descriptor(
@@ -1774,8 +1770,8 @@ _TYPEINFO_TUPLETYPEINFO = _descriptor.Descriptor(
extension_ranges=[],
oneofs=[
],
- serialized_start=6315,
- serialized_end=6395,
+ serialized_start=6302,
+ serialized_end=6382,
)
_TYPEINFO = _descriptor.Descriptor(
@@ -1836,8 +1832,8 @@ _TYPEINFO = _descriptor.Descriptor(
name='type_info',
full_name='org.apache.flink.fn_execution.v1.TypeInfo.type_info',
index=0, containing_type=None, fields=[]),
],
- serialized_start=5576,
- serialized_end=6688,
+ serialized_start=5563,
+ serialized_end=6675,
)
_INPUT.fields_by_name['udf'].message_type = _USERDEFINEDFUNCTION
diff --git a/flink-python/pyflink/fn_execution/operation_utils.py
b/flink-python/pyflink/fn_execution/operation_utils.py
index aaa2add..8b65782 100644
--- a/flink-python/pyflink/fn_execution/operation_utils.py
+++ b/flink-python/pyflink/fn_execution/operation_utils.py
@@ -237,12 +237,6 @@ def extract_data_stream_stateless_function(udf_proto):
func = user_defined_func.map
elif func_type == UserDefinedDataStreamFunction.FLAT_MAP:
func = user_defined_func.flat_map
- elif func_type == UserDefinedDataStreamFunction.REDUCE:
- reduce_func = user_defined_func.reduce
-
- def wrapped_func(value):
- return reduce_func(value[0], value[1])
- func = wrapped_func
elif func_type == UserDefinedDataStreamFunction.CO_MAP:
co_map_func = user_defined_func
diff --git a/flink-python/pyflink/proto/flink-fn-execution.proto
b/flink-python/pyflink/proto/flink-fn-execution.proto
index c34dede..3667997 100644
--- a/flink-python/pyflink/proto/flink-fn-execution.proto
+++ b/flink-python/pyflink/proto/flink-fn-execution.proto
@@ -81,12 +81,11 @@ message UserDefinedDataStreamFunction {
enum FunctionType {
MAP = 0;
FLAT_MAP = 1;
- REDUCE = 2;
- CO_MAP = 3;
- CO_FLAT_MAP = 4;
- PROCESS = 5;
- KEYED_PROCESS = 6;
- TIMESTAMP_ASSIGNER = 7;
+ CO_MAP = 2;
+ CO_FLAT_MAP = 3;
+ PROCESS = 4;
+ KEYED_PROCESS = 5;
+ TIMESTAMP_ASSIGNER = 6;
}
message JobParameter {
diff --git
a/flink-python/src/main/java/org/apache/flink/streaming/api/functions/python/KeyByKeySelector.java
b/flink-python/src/main/java/org/apache/flink/streaming/api/functions/python/KeyByKeySelector.java
index b4d7b41..4cacdc9 100644
---
a/flink-python/src/main/java/org/apache/flink/streaming/api/functions/python/KeyByKeySelector.java
+++
b/flink-python/src/main/java/org/apache/flink/streaming/api/functions/python/KeyByKeySelector.java
@@ -21,34 +21,20 @@ import org.apache.flink.annotation.Internal;
import org.apache.flink.api.java.functions.KeySelector;
import org.apache.flink.types.Row;
-import net.razorvine.pickle.Unpickler;
-
/**
* {@link KeyByKeySelector} is responsible for extracting the first field of
the input row as key.
* The input row is generated by python DataStream map function in the format
of
* (key_selector.get_key(value), value) tuple2.
*/
@Internal
-public class KeyByKeySelector implements KeySelector<Row, Object> {
+public class KeyByKeySelector implements KeySelector<Row, Row> {
private static final long serialVersionUID = 1L;
- private final boolean isKeyOfPickledByteArray;
- private transient Unpickler unpickler = null;
-
- public KeyByKeySelector(boolean isKeyPickledByteArray) {
- this.isKeyOfPickledByteArray = isKeyPickledByteArray;
- }
-
@Override
- public Object getKey(Row value) throws Exception {
- Object key = value.getField(0);
- if (!isKeyOfPickledByteArray) {
- return key;
- } else {
- if (this.unpickler == null) {
- unpickler = new Unpickler();
- }
- return this.unpickler.loads((byte[]) key);
- }
+ public Row getKey(Row value) {
+ Object realKey = value.getField(0);
+ Row wrapper = new Row(1);
+ wrapper.setField(0, realKey);
+ return wrapper;
}
}
diff --git
a/flink-python/src/main/java/org/apache/flink/streaming/api/operators/python/PythonReduceOperator.java
b/flink-python/src/main/java/org/apache/flink/streaming/api/operators/python/PythonReduceOperator.java
deleted file mode 100644
index 9c57811..0000000
---
a/flink-python/src/main/java/org/apache/flink/streaming/api/operators/python/PythonReduceOperator.java
+++ /dev/null
@@ -1,106 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.flink.streaming.api.operators.python;
-
-import org.apache.flink.annotation.Internal;
-import org.apache.flink.api.common.state.ValueState;
-import org.apache.flink.api.common.state.ValueStateDescriptor;
-import org.apache.flink.api.common.typeinfo.TypeInformation;
-import org.apache.flink.api.java.tuple.Tuple2;
-import org.apache.flink.api.java.typeutils.RowTypeInfo;
-import org.apache.flink.configuration.Configuration;
-import
org.apache.flink.streaming.api.functions.python.DataStreamPythonFunctionInfo;
-import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
-import org.apache.flink.types.Row;
-
-/**
- * {@link PythonReduceOperator} is responsible for launching beam runner which
will start a python
- * harness to execute user defined python ReduceFunction.
- */
-@Internal
-public class PythonReduceOperator<OUT> extends
OneInputPythonFunctionOperator<Row, OUT, Row, OUT> {
-
- private static final long serialVersionUID = 1L;
-
- private static final String MAP_CODER_URN = "flink:coder:map:v1";
-
- private static final String STATE_NAME = "_python_reduce_state";
-
- /** This state is used to store the currently reduce value. */
- private transient ValueState<OUT> valueState;
-
- private transient Row reuseRow;
-
- public PythonReduceOperator(
- Configuration config,
- TypeInformation<Row> inputTypeInfo,
- TypeInformation<OUT> outputTypeInfo,
- DataStreamPythonFunctionInfo pythonFunctionInfo) {
- super(
- config,
- new RowTypeInfo(outputTypeInfo, outputTypeInfo),
- outputTypeInfo,
- pythonFunctionInfo);
- }
-
- @Override
- public void open() throws Exception {
- super.open();
-
- // create state
- ValueStateDescriptor<OUT> stateId =
- new ValueStateDescriptor<>(STATE_NAME, runnerOutputTypeInfo);
- valueState = getPartitionedState(stateId);
-
- reuseRow = new Row(2);
- }
-
- @Override
- public void processElement(StreamRecord<Row> element) throws Exception {
- OUT inputData = (OUT) element.getValue().getField(1);
- OUT currentValue = valueState.value();
- if (currentValue == null) {
- // emit directly for the first element.
- valueState.update(inputData);
- collector.setAbsoluteTimestamp(element.getTimestamp());
- collector.collect(inputData);
- } else {
- reuseRow.setField(0, currentValue);
- reuseRow.setField(1, inputData);
- element.replace(reuseRow);
- super.processElement(element);
- }
- }
-
- @Override
- public void emitResult(Tuple2<byte[], Integer> resultTuple) throws
Exception {
- byte[] rawResult = resultTuple.f0;
- int length = resultTuple.f1;
- bais.setBuffer(rawResult, 0, length);
- OUT result = runnerOutputTypeSerializer.deserialize(baisWrapper);
- valueState.update(result);
- collector.setAbsoluteTimestamp(bufferedTimestamp.poll());
- collector.collect(result);
- }
-
- @Override
- public String getCoderUrn() {
- return MAP_CODER_URN;
- }
-}
diff --git
a/flink-python/src/main/java/org/apache/flink/streaming/api/runners/python/beam/BeamPythonFunctionRunner.java
b/flink-python/src/main/java/org/apache/flink/streaming/api/runners/python/beam/BeamPythonFunctionRunner.java
index 92ba04e..def9057 100644
---
a/flink-python/src/main/java/org/apache/flink/streaming/api/runners/python/beam/BeamPythonFunctionRunner.java
+++
b/flink-python/src/main/java/org/apache/flink/streaming/api/runners/python/beam/BeamPythonFunctionRunner.java
@@ -28,6 +28,7 @@ import org.apache.flink.api.common.state.StateDescriptor;
import org.apache.flink.api.common.typeinfo.PrimitiveArrayTypeInfo;
import org.apache.flink.api.common.typeutils.TypeSerializer;
import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.typeutils.runtime.RowSerializer;
import org.apache.flink.core.memory.ByteArrayInputStreamWithPos;
import org.apache.flink.core.memory.ByteArrayOutputStreamWithPos;
import org.apache.flink.core.memory.DataInputViewStreamWrapper;
@@ -47,6 +48,7 @@ import org.apache.flink.runtime.state.VoidNamespaceSerializer;
import org.apache.flink.streaming.api.utils.ByteArrayWrapper;
import org.apache.flink.streaming.api.utils.ByteArrayWrapperSerializer;
import org.apache.flink.table.data.RowData;
+import org.apache.flink.table.runtime.typeutils.AbstractRowDataSerializer;
import org.apache.flink.table.runtime.typeutils.RowDataSerializer;
import org.apache.flink.util.Preconditions;
import org.apache.flink.util.function.LongFunctionWithException;
@@ -657,6 +659,12 @@ public abstract class BeamPythonFunctionRunner implements
PythonFunctionRunner {
TypeSerializer keySerializer,
Map<String, String> config) {
this.keyedStateBackend = keyedStateBackend;
+ TypeSerializer frameworkKeySerializer =
keyedStateBackend.getKeySerializer();
+ if (!(frameworkKeySerializer instanceof AbstractRowDataSerializer
+ || frameworkKeySerializer instanceof RowSerializer)) {
+ throw new RuntimeException(
+ "Currently SimpleStateRequestHandler only support row
key!");
+ }
this.keySerializer = keySerializer;
this.valueSerializer =
PrimitiveArrayTypeInfo.BYTE_PRIMITIVE_ARRAY_TYPE_INFO.createSerializer(