This is an automated email from the ASF dual-hosted git repository. hequn pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/flink.git
commit 175038b705f82655a4a082de64a73f989f6abea4 Author: hequn.chq <hequn....@alibaba-inc.com> AuthorDate: Sat Aug 15 12:39:41 2020 +0800 [FLINK-18943][python] Support CoMapFunction for Python DataStream API --- flink-python/dev/glibc_version_fix.h | 0 flink-python/pyflink/datastream/data_stream.py | 105 ++++++- flink-python/pyflink/datastream/functions.py | 35 +++ .../pyflink/datastream/tests/test_data_stream.py | 38 +++ .../pyflink/fn_execution/flink_fn_execution_pb2.py | 90 +++--- .../pyflink/fn_execution/operation_utils.py | 6 + .../pyflink/proto/flink-fn-execution.proto | 2 + ...eamTwoInputPythonStatelessFunctionOperator.java | 196 +++++++++++++ .../python/AbstractPythonFunctionOperator.java | 320 +-------------------- ...ava => AbstractPythonFunctionOperatorBase.java} | 18 +- .../AbstractTwoInputPythonFunctionOperator.java | 44 +++ 11 files changed, 482 insertions(+), 372 deletions(-) diff --git a/flink-python/dev/glibc_version_fix.h b/flink-python/dev/glibc_version_fix.h old mode 100644 new mode 100755 diff --git a/flink-python/pyflink/datastream/data_stream.py b/flink-python/pyflink/datastream/data_stream.py index 67b2903..17abb70 100644 --- a/flink-python/pyflink/datastream/data_stream.py +++ b/flink-python/pyflink/datastream/data_stream.py @@ -24,7 +24,7 @@ from pyflink.common.typeinfo import TypeInformation from pyflink.datastream.functions import _get_python_env, FlatMapFunctionWrapper, FlatMapFunction, \ MapFunction, MapFunctionWrapper, Function, FunctionWrapper, SinkFunction, FilterFunction, \ FilterFunctionWrapper, KeySelectorFunctionWrapper, KeySelector, ReduceFunction, \ - ReduceFunctionWrapper + ReduceFunctionWrapper, CoMapFunction from pyflink.java_gateway import get_gateway @@ -359,6 +359,17 @@ class DataStream(object): j_united_stream = self._j_data_stream.union(j_data_stream_arr) return DataStream(j_data_stream=j_united_stream) + def connect(self, ds: 'DataStream') -> 'ConnectedStreams': + """ + Creates a new 'ConnectedStreams' by connecting 
'DataStream' outputs of (possibly) + different types with each other. The DataStreams connected using this operator can + be used with CoFunctions to apply joint transformations. + + :param ds: The DataStream with which this stream will be connected. + :return: The `ConnectedStreams`. + """ + return ConnectedStreams(self, ds) + def shuffle(self) -> 'DataStream': """ Sets the partitioning of the DataStream so that the output elements are shuffled uniformly @@ -687,6 +698,9 @@ class KeyedStream(DataStream): j_python_data_stream_scalar_function_operator )) + def connect(self, ds: 'KeyedStream') -> 'ConnectedStreams': + raise Exception('Connect on KeyedStream has not been supported yet.') + def filter(self, func: Union[Callable, FilterFunction]) -> 'DataStream': return self._values().filter(func) @@ -763,3 +777,92 @@ class KeyedStream(DataStream): def slot_sharing_group(self, slot_sharing_group: str) -> 'DataStream': raise Exception("Setting slot sharing group for KeyedStream is not supported.") + + +class ConnectedStreams(object): + """ + ConnectedStreams represent two connected streams of (possibly) different data types. + Connected streams are useful for cases where operations on one stream directly + affect the operations on the other stream, usually via shared state between the streams. + + An example for the use of connected streams would be to apply rules that change over time + onto another stream. One of the connected streams has the rules, the other stream the + elements to apply the rules to. The operation on the connected stream maintains the + current set of rules in the state. It may receive either a rule update and update the state + or a data element and apply the rules in the state to the element. + + The connected stream can be conceptually viewed as a union stream of an Either type, that + holds either the first stream's type or the second stream's type.
+ """ + + def __init__(self, stream1: DataStream, stream2: DataStream): + self.stream1 = stream1 + self.stream2 = stream2 + + def map(self, func: CoMapFunction, type_info: TypeInformation = None) \ + -> 'DataStream': + """ + Applies a CoMap transformation on a `ConnectedStreams` and maps the output to a common + type. The transformation calls a `CoMapFunction.map1` for each element of the first + input and `CoMapFunction.map2` for each element of the second input. Each CoMapFunction + call returns exactly one element. + + :param func: The CoMapFunction used to jointly transform the two input DataStreams + :param type_info: `TypeInformation` for the result type of the function. + :return: The transformed `DataStream` + """ + if not isinstance(func, CoMapFunction): + raise TypeError("The input function must be a CoMapFunction!") + func_name = str(func) + + # get connected stream + j_connected_stream = self.stream1._j_data_stream.connect(self.stream2._j_data_stream) + from pyflink.fn_execution import flink_fn_execution_pb2 + j_operator, j_output_type = self._get_connected_stream_operator( + func, type_info, func_name, flink_fn_execution_pb2.UserDefinedDataStreamFunction.CO_MAP) + return DataStream(j_connected_stream.transform("Co-Process", j_output_type, j_operator)) + + def _get_connected_stream_operator(self, func: Union[Function, FunctionWrapper], + type_info: TypeInformation, func_name: str, + func_type: int): + gateway = get_gateway() + import cloudpickle + serialized_func = cloudpickle.dumps(func) + + j_input_types1 = self.stream1._j_data_stream.getTransformation().getOutputType() + j_input_types2 = self.stream2._j_data_stream.getTransformation().getOutputType() + + if type_info is None: + output_type_info = PickledBytesTypeInfo.PICKLED_BYTE_ARRAY_TYPE_INFO() + else: + if isinstance(type_info, list): + output_type_info = RowTypeInfo(type_info) + else: + output_type_info = type_info + + DataStreamPythonFunction = 
gateway.jvm.org.apache.flink.datastream.runtime.functions \ + .python.DataStreamPythonFunction + j_python_data_stream_scalar_function = DataStreamPythonFunction( + func_name, + bytearray(serialized_func), + _get_python_env()) + + DataStreamPythonFunctionInfo = gateway.jvm. \ + org.apache.flink.datastream.runtime.functions.python \ + .DataStreamPythonFunctionInfo + + j_python_data_stream_function_info = DataStreamPythonFunctionInfo( + j_python_data_stream_scalar_function, + func_type) + + j_conf = gateway.jvm.org.apache.flink.configuration.Configuration() + DataStreamPythonFunctionOperator = gateway.jvm.org.apache.flink.datastream.runtime \ + .operators.python.DataStreamTwoInputPythonStatelessFunctionOperator + j_python_data_stream_function_operator = DataStreamPythonFunctionOperator( + j_conf, + j_input_types1, + j_input_types2, + output_type_info.get_java_type_info(), + j_python_data_stream_function_info) + + return j_python_data_stream_function_operator, output_type_info.get_java_type_info() diff --git a/flink-python/pyflink/datastream/functions.py b/flink-python/pyflink/datastream/functions.py index c5049e8..10433ee 100644 --- a/flink-python/pyflink/datastream/functions.py +++ b/flink-python/pyflink/datastream/functions.py @@ -56,6 +56,41 @@ class MapFunction(Function): pass +class CoMapFunction(Function): + """ + A CoMapFunction implements a map() transformation over two connected streams. + + The same instance of the transformation function is used to transform both of + the connected streams. That way, the stream transformations can share state. + + The basic syntax for using a CoMapFunction is as follows: + :: + >>> ds1 = ... + >>> ds2 = ... + >>> new_ds = ds1.connect(ds2).map(MyCoMapFunction()) + """ + + @abc.abstractmethod + def map1(self, value): + """ + This method is called for each element in the first of the connected streams.
+ + :param value: The stream element + :return: The resulting element + """ + pass + + @abc.abstractmethod + def map2(self, value): + """ + This method is called for each element in the second of the connected streams. + + :param value: The stream element + :return: The resulting element + """ + pass + + class FlatMapFunction(Function): """ Base class for flatMap functions. FLatMap functions take elements and transform them, into zero, diff --git a/flink-python/pyflink/datastream/tests/test_data_stream.py b/flink-python/pyflink/datastream/tests/test_data_stream.py index 937d1db..d05b2e7 100644 --- a/flink-python/pyflink/datastream/tests/test_data_stream.py +++ b/flink-python/pyflink/datastream/tests/test_data_stream.py @@ -22,6 +22,7 @@ from pyflink.datastream import StreamExecutionEnvironment from pyflink.datastream.functions import FilterFunction from pyflink.datastream.functions import KeySelector from pyflink.datastream.functions import MapFunction, FlatMapFunction +from pyflink.datastream.functions import CoMapFunction from pyflink.datastream.tests.test_util import DataStreamTestSinkFunction from pyflink.java_gateway import get_gateway from pyflink.testing.test_case_utils import PyFlinkTestCase @@ -105,6 +106,34 @@ class DataStreamTests(PyFlinkTestCase): results.sort() self.assertEqual(expected, results) + def test_co_map_function_without_data_types(self): + self.env.set_parallelism(1) + ds1 = self.env.from_collection([(1, 1), (2, 2), (3, 3)], + type_info=Types.ROW([Types.INT(), Types.INT()])) + ds2 = self.env.from_collection([("a", "a"), ("b", "b"), ("c", "c")], + type_info=Types.ROW([Types.STRING(), Types.STRING()])) + ds1.connect(ds2).map(MyCoMapFunction()).add_sink(self.test_sink) + self.env.execute('co_map_function_test') + results = self.test_sink.get_results(True) + expected = ['2', '3', '4', 'a', 'b', 'c'] + expected.sort() + results.sort() + self.assertEqual(expected, results) + + def test_co_map_function_with_data_types(self): + 
self.env.set_parallelism(1) + ds1 = self.env.from_collection([(1, 1), (2, 2), (3, 3)], + type_info=Types.ROW([Types.INT(), Types.INT()])) + ds2 = self.env.from_collection([("a", "a"), ("b", "b"), ("c", "c")], + type_info=Types.ROW([Types.STRING(), Types.STRING()])) + ds1.connect(ds2).map(MyCoMapFunction(), type_info=Types.STRING()).add_sink(self.test_sink) + self.env.execute('co_map_function_test') + results = self.test_sink.get_results(False) + expected = ['2', '3', '4', 'a', 'b', 'c'] + expected.sort() + results.sort() + self.assertEqual(expected, results) + def test_map_function_with_data_types_and_function_object(self): ds = self.env.from_collection([('ab', 1), ('bdc', 2), ('cfgs', 3), ('deeefg', 4)], type_info=Types.ROW([Types.STRING(), Types.INT()])) @@ -431,3 +460,12 @@ class MyFilterFunction(FilterFunction): def filter(self, value): return value[0] % 2 == 0 + + +class MyCoMapFunction(CoMapFunction): + + def map1(self, value): + return str(value[0] + 1) + + def map2(self, value): + return value[0] diff --git a/flink-python/pyflink/fn_execution/flink_fn_execution_pb2.py b/flink-python/pyflink/fn_execution/flink_fn_execution_pb2.py index 6a7945e..f90071d 100644 --- a/flink-python/pyflink/fn_execution/flink_fn_execution_pb2.py +++ b/flink-python/pyflink/fn_execution/flink_fn_execution_pb2.py @@ -36,7 +36,7 @@ DESCRIPTOR = _descriptor.FileDescriptor( name='flink-fn-execution.proto', package='org.apache.flink.fn_execution.v1', syntax='proto3', - serialized_pb=_b('\n\x18\x66link-fn-execution.proto\x12 org.apache.flink.fn_execution.v1\"\xfc\x01\n\x13UserDefinedFunction\x12\x0f\n\x07payload\x18\x01 \x01(\x0c\x12K\n\x06inputs\x18\x02 \x03(\x0b\x32;.org.apache.flink.fn_execution.v1.UserDefinedFunction.Input\x1a\x86\x01\n\x05Input\x12\x44\n\x03udf\x18\x01 \x01(\x0b\x32\x35.org.apache.flink.fn_execution.v1.UserDefinedFunctionH\x00\x12\x15\n\x0binputOffset\x18\x02 \x01(\x05H\x00\x12\x17\n\rinputConstant\x18\x03 \x01(\x0cH\x00\x42\x07\ [...] 
+ serialized_pb=_b('\n\x18\x66link-fn-execution.proto\x12 org.apache.flink.fn_execution.v1\"\xfc\x01\n\x13UserDefinedFunction\x12\x0f\n\x07payload\x18\x01 \x01(\x0c\x12K\n\x06inputs\x18\x02 \x03(\x0b\x32;.org.apache.flink.fn_execution.v1.UserDefinedFunction.Input\x1a\x86\x01\n\x05Input\x12\x44\n\x03udf\x18\x01 \x01(\x0b\x32\x35.org.apache.flink.fn_execution.v1.UserDefinedFunctionH\x00\x12\x15\n\x0binputOffset\x18\x02 \x01(\x05H\x00\x12\x17\n\rinputConstant\x18\x03 \x01(\x0cH\x00\x42\x07\ [...] ) @@ -59,11 +59,19 @@ _USERDEFINEDDATASTREAMFUNCTION_FUNCTIONTYPE = _descriptor.EnumDescriptor( name='REDUCE', index=2, number=2, options=None, type=None), + _descriptor.EnumValueDescriptor( + name='CO_MAP', index=3, number=3, + options=None, + type=None), + _descriptor.EnumValueDescriptor( + name='CO_FLAT_MAP', index=4, number=4, + options=None, + type=None), ], containing_type=None, options=None, serialized_start=585, - serialized_end=634, + serialized_end=663, ) _sym_db.RegisterEnumDescriptor(_USERDEFINEDDATASTREAMFUNCTION_FUNCTIONTYPE) @@ -160,8 +168,8 @@ _SCHEMA_TYPENAME = _descriptor.EnumDescriptor( ], containing_type=None, options=None, - serialized_start=2514, - serialized_end=2797, + serialized_start=2543, + serialized_end=2826, ) _sym_db.RegisterEnumDescriptor(_SCHEMA_TYPENAME) @@ -246,8 +254,8 @@ _TYPEINFO_TYPENAME = _descriptor.EnumDescriptor( ], containing_type=None, options=None, - serialized_start=3318, - serialized_end=3549, + serialized_start=3347, + serialized_end=3578, ) _sym_db.RegisterEnumDescriptor(_TYPEINFO_TYPENAME) @@ -410,7 +418,7 @@ _USERDEFINEDDATASTREAMFUNCTION = _descriptor.Descriptor( oneofs=[ ], serialized_start=435, - serialized_end=634, + serialized_end=663, ) @@ -447,8 +455,8 @@ _USERDEFINEDDATASTREAMFUNCTIONS = _descriptor.Descriptor( extension_ranges=[], oneofs=[ ], - serialized_start=637, - serialized_end=772, + serialized_start=666, + serialized_end=801, ) @@ -485,8 +493,8 @@ _SCHEMA_MAPINFO = _descriptor.Descriptor( extension_ranges=[], 
oneofs=[ ], - serialized_start=850, - serialized_end=1001, + serialized_start=879, + serialized_end=1030, ) _SCHEMA_TIMEINFO = _descriptor.Descriptor( @@ -515,8 +523,8 @@ _SCHEMA_TIMEINFO = _descriptor.Descriptor( extension_ranges=[], oneofs=[ ], - serialized_start=1003, - serialized_end=1032, + serialized_start=1032, + serialized_end=1061, ) _SCHEMA_TIMESTAMPINFO = _descriptor.Descriptor( @@ -545,8 +553,8 @@ _SCHEMA_TIMESTAMPINFO = _descriptor.Descriptor( extension_ranges=[], oneofs=[ ], - serialized_start=1034, - serialized_end=1068, + serialized_start=1063, + serialized_end=1097, ) _SCHEMA_LOCALZONEDTIMESTAMPINFO = _descriptor.Descriptor( @@ -575,8 +583,8 @@ _SCHEMA_LOCALZONEDTIMESTAMPINFO = _descriptor.Descriptor( extension_ranges=[], oneofs=[ ], - serialized_start=1070, - serialized_end=1114, + serialized_start=1099, + serialized_end=1143, ) _SCHEMA_ZONEDTIMESTAMPINFO = _descriptor.Descriptor( @@ -605,8 +613,8 @@ _SCHEMA_ZONEDTIMESTAMPINFO = _descriptor.Descriptor( extension_ranges=[], oneofs=[ ], - serialized_start=1116, - serialized_end=1155, + serialized_start=1145, + serialized_end=1184, ) _SCHEMA_DECIMALINFO = _descriptor.Descriptor( @@ -642,8 +650,8 @@ _SCHEMA_DECIMALINFO = _descriptor.Descriptor( extension_ranges=[], oneofs=[ ], - serialized_start=1157, - serialized_end=1204, + serialized_start=1186, + serialized_end=1233, ) _SCHEMA_BINARYINFO = _descriptor.Descriptor( @@ -672,8 +680,8 @@ _SCHEMA_BINARYINFO = _descriptor.Descriptor( extension_ranges=[], oneofs=[ ], - serialized_start=1206, - serialized_end=1234, + serialized_start=1235, + serialized_end=1263, ) _SCHEMA_VARBINARYINFO = _descriptor.Descriptor( @@ -702,8 +710,8 @@ _SCHEMA_VARBINARYINFO = _descriptor.Descriptor( extension_ranges=[], oneofs=[ ], - serialized_start=1236, - serialized_end=1267, + serialized_start=1265, + serialized_end=1296, ) _SCHEMA_CHARINFO = _descriptor.Descriptor( @@ -732,8 +740,8 @@ _SCHEMA_CHARINFO = _descriptor.Descriptor( extension_ranges=[], oneofs=[ ], - 
serialized_start=1269, - serialized_end=1295, + serialized_start=1298, + serialized_end=1324, ) _SCHEMA_VARCHARINFO = _descriptor.Descriptor( @@ -762,8 +770,8 @@ _SCHEMA_VARCHARINFO = _descriptor.Descriptor( extension_ranges=[], oneofs=[ ], - serialized_start=1297, - serialized_end=1326, + serialized_start=1326, + serialized_end=1355, ) _SCHEMA_FIELDTYPE = _descriptor.Descriptor( @@ -886,8 +894,8 @@ _SCHEMA_FIELDTYPE = _descriptor.Descriptor( name='type_info', full_name='org.apache.flink.fn_execution.v1.Schema.FieldType.type_info', index=0, containing_type=None, fields=[]), ], - serialized_start=1329, - serialized_end=2401, + serialized_start=1358, + serialized_end=2430, ) _SCHEMA_FIELD = _descriptor.Descriptor( @@ -930,8 +938,8 @@ _SCHEMA_FIELD = _descriptor.Descriptor( extension_ranges=[], oneofs=[ ], - serialized_start=2403, - serialized_end=2511, + serialized_start=2432, + serialized_end=2540, ) _SCHEMA = _descriptor.Descriptor( @@ -961,8 +969,8 @@ _SCHEMA = _descriptor.Descriptor( extension_ranges=[], oneofs=[ ], - serialized_start=775, - serialized_end=2797, + serialized_start=804, + serialized_end=2826, ) @@ -1016,8 +1024,8 @@ _TYPEINFO_FIELDTYPE = _descriptor.Descriptor( name='type_info', full_name='org.apache.flink.fn_execution.v1.TypeInfo.FieldType.type_info', index=0, containing_type=None, fields=[]), ], - serialized_start=2878, - serialized_end=3203, + serialized_start=2907, + serialized_end=3232, ) _TYPEINFO_FIELD = _descriptor.Descriptor( @@ -1060,8 +1068,8 @@ _TYPEINFO_FIELD = _descriptor.Descriptor( extension_ranges=[], oneofs=[ ], - serialized_start=3205, - serialized_end=3315, + serialized_start=3234, + serialized_end=3344, ) _TYPEINFO = _descriptor.Descriptor( @@ -1091,8 +1099,8 @@ _TYPEINFO = _descriptor.Descriptor( extension_ranges=[], oneofs=[ ], - serialized_start=2800, - serialized_end=3549, + serialized_start=2829, + serialized_end=3578, ) _USERDEFINEDFUNCTION_INPUT.fields_by_name['udf'].message_type = _USERDEFINEDFUNCTION diff --git 
a/flink-python/pyflink/fn_execution/operation_utils.py b/flink-python/pyflink/fn_execution/operation_utils.py index 52426ce..c361681 100644 --- a/flink-python/pyflink/fn_execution/operation_utils.py +++ b/flink-python/pyflink/fn_execution/operation_utils.py @@ -108,6 +108,12 @@ def extract_data_stream_stateless_funcs(udfs): def wrap_func(value): return reduce_func(value[0], value[1]) func = wrap_func + elif func_type == udf.CO_MAP: + co_map_func = cloudpickle.loads(udfs[0].payload) + + def wrap_func(value): + return co_map_func.map1(value[1]) if value[0] else co_map_func.map2(value[2]) + func = wrap_func return func diff --git a/flink-python/pyflink/proto/flink-fn-execution.proto b/flink-python/pyflink/proto/flink-fn-execution.proto index 9a1e64b..cd903f3 100644 --- a/flink-python/pyflink/proto/flink-fn-execution.proto +++ b/flink-python/pyflink/proto/flink-fn-execution.proto @@ -58,6 +58,8 @@ message UserDefinedDataStreamFunction { MAP = 0; FLAT_MAP = 1; REDUCE = 2; + CO_MAP = 3; + CO_FLAT_MAP = 4; } FunctionType functionType = 1; bytes payload = 2; diff --git a/flink-python/src/main/java/org/apache/flink/datastream/runtime/operators/python/DataStreamTwoInputPythonStatelessFunctionOperator.java b/flink-python/src/main/java/org/apache/flink/datastream/runtime/operators/python/DataStreamTwoInputPythonStatelessFunctionOperator.java new file mode 100644 index 0000000..7383076 --- /dev/null +++ b/flink-python/src/main/java/org/apache/flink/datastream/runtime/operators/python/DataStreamTwoInputPythonStatelessFunctionOperator.java @@ -0,0 +1,196 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.datastream.runtime.operators.python; + +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.common.typeinfo.Types; +import org.apache.flink.api.common.typeutils.TypeSerializer; +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.api.java.typeutils.RowTypeInfo; +import org.apache.flink.configuration.Configuration; +import org.apache.flink.core.memory.ByteArrayInputStreamWithPos; +import org.apache.flink.core.memory.ByteArrayOutputStreamWithPos; +import org.apache.flink.core.memory.DataInputViewStreamWrapper; +import org.apache.flink.core.memory.DataOutputViewStreamWrapper; +import org.apache.flink.datastream.runtime.functions.python.DataStreamPythonFunctionInfo; +import org.apache.flink.datastream.runtime.runners.python.beam.BeamDataStreamPythonStatelessFunctionRunner; +import org.apache.flink.datastream.runtime.typeutils.python.PythonTypeUtils; +import org.apache.flink.fnexecution.v1.FlinkFnApi; +import org.apache.flink.python.PythonFunctionRunner; +import org.apache.flink.streaming.api.operators.python.AbstractTwoInputPythonFunctionOperator; +import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; +import org.apache.flink.table.functions.python.PythonEnv; +import org.apache.flink.table.runtime.util.StreamRecordCollector; +import org.apache.flink.types.Row; + +import com.google.protobuf.ByteString; + +import java.util.Map; + +/** + * {@link DataStreamTwoInputPythonStatelessFunctionOperator} is responsible for launching beam + * runner which will 
start a python harness to execute two-input user defined python function. + */ +public class DataStreamTwoInputPythonStatelessFunctionOperator<IN1, IN2, OUT> + extends AbstractTwoInputPythonFunctionOperator<IN1, IN2, OUT> { + + private static final long serialVersionUID = 1L; + + private static final String DATA_STREAM_STATELESS_PYTHON_FUNCTION_URN = + "flink:transform:datastream_stateless_function:v1"; + private static final String DATA_STREAM_MAP_FUNCTION_CODER_URN = "flink:coder:datastream:map_function:v1"; + private static final String DATA_STREAM_FLAT_MAP_FUNCTION_CODER_URN = "flink:coder:datastream:flatmap_function:v1"; + + + private final DataStreamPythonFunctionInfo pythonFunctionInfo; + + private final TypeInformation<OUT> outputTypeInfo; + + private final Map<String, String> jobOptions; + private transient TypeSerializer<OUT> outputTypeSerializer; + + private transient ByteArrayInputStreamWithPos bais; + + private transient DataInputViewStreamWrapper baisWrapper; + + private transient ByteArrayOutputStreamWithPos baos; + + private transient DataOutputViewStreamWrapper baosWrapper; + + private transient StreamRecordCollector streamRecordCollector; + + private transient TypeSerializer<Row> runnerInputTypeSerializer; + + private final TypeInformation<Row> runnerInputTypeInfo; + + private transient Row reuseRow; + + public DataStreamTwoInputPythonStatelessFunctionOperator( + Configuration config, + TypeInformation<IN1> inputTypeInfo1, + TypeInformation<IN2> inputTypeInfo2, + TypeInformation<OUT> outputTypeInfo, + DataStreamPythonFunctionInfo pythonFunctionInfo) { + super(config); + this.pythonFunctionInfo = pythonFunctionInfo; + jobOptions = config.toMap(); + this.outputTypeInfo = outputTypeInfo; + // The row contains three fields. The first field indicates left input or right input. + // The second field contains left input and the third field contains right input.
+ runnerInputTypeInfo = new RowTypeInfo(Types.BOOLEAN, inputTypeInfo1, inputTypeInfo2); + } + + @Override + public void open() throws Exception { + super.open(); + bais = new ByteArrayInputStreamWithPos(); + baisWrapper = new DataInputViewStreamWrapper(bais); + + baos = new ByteArrayOutputStreamWithPos(); + baosWrapper = new DataOutputViewStreamWrapper(baos); + this.outputTypeSerializer = PythonTypeUtils.TypeInfoToSerializerConverter + .typeInfoSerializerConverter(outputTypeInfo); + runnerInputTypeSerializer = PythonTypeUtils.TypeInfoToSerializerConverter + .typeInfoSerializerConverter(runnerInputTypeInfo); + + reuseRow = new Row(3); + this.streamRecordCollector = new StreamRecordCollector(output); + } + + @Override + public PythonFunctionRunner createPythonFunctionRunner() throws Exception { + + String coderUrn; + int functionType = this.pythonFunctionInfo.getFunctionType(); + if (functionType == FlinkFnApi.UserDefinedDataStreamFunction.FunctionType.CO_MAP.getNumber()) { + coderUrn = DATA_STREAM_MAP_FUNCTION_CODER_URN; + } else if (functionType == FlinkFnApi.UserDefinedDataStreamFunction.FunctionType.CO_FLAT_MAP.getNumber()) { + coderUrn = DATA_STREAM_FLAT_MAP_FUNCTION_CODER_URN; + } else { + throw new RuntimeException("Function Type for ConnectedStream should be Map or FlatMap"); + } + + return new BeamDataStreamPythonStatelessFunctionRunner( + getRuntimeContext().getTaskName(), + createPythonEnvironmentManager(), + runnerInputTypeInfo, + outputTypeInfo, + DATA_STREAM_STATELESS_PYTHON_FUNCTION_URN, + getUserDefinedDataStreamFunctionsProto(), + coderUrn, + jobOptions, + getFlinkMetricContainer() + ); + } + + @Override + public PythonEnv getPythonEnv() { + return pythonFunctionInfo.getPythonFunction().getPythonEnv(); + } + + @Override + public void emitResult(Tuple2<byte[], Integer> resultTuple) throws Exception { + byte[] rawResult = resultTuple.f0; + int length = resultTuple.f1; + bais.setBuffer(rawResult, 0, length); + 
streamRecordCollector.collect(outputTypeSerializer.deserialize(baisWrapper)); + } + + @Override + public void processElement1(StreamRecord<IN1> element) throws Exception { + // construct combined row. + reuseRow.setField(0, true); + reuseRow.setField(1, element.getValue()); + reuseRow.setField(2, null); // need to set null since it is a reuse row. + processElement(); + } + + @Override + public void processElement2(StreamRecord<IN2> element) throws Exception { + // construct combined row. + reuseRow.setField(0, false); + reuseRow.setField(1, null); // need to set null since it is a reuse row. + reuseRow.setField(2, element.getValue()); + processElement(); + } + + private void processElement() throws Exception { + runnerInputTypeSerializer.serialize(reuseRow, baosWrapper); + pythonFunctionRunner.process(baos.toByteArray()); + baos.reset(); + checkInvokeFinishBundleByCount(); + emitResults(); + } + + protected FlinkFnApi.UserDefinedDataStreamFunctions getUserDefinedDataStreamFunctionsProto() { + FlinkFnApi.UserDefinedDataStreamFunctions.Builder builder = FlinkFnApi.UserDefinedDataStreamFunctions.newBuilder(); + builder.addUdfs(getUserDefinedDataStreamFunctionProto(pythonFunctionInfo)); + return builder.build(); + } + + private FlinkFnApi.UserDefinedDataStreamFunction getUserDefinedDataStreamFunctionProto( + DataStreamPythonFunctionInfo dataStreamPythonFunctionInfo) { + FlinkFnApi.UserDefinedDataStreamFunction.Builder builder = + FlinkFnApi.UserDefinedDataStreamFunction.newBuilder(); + builder.setFunctionType(FlinkFnApi.UserDefinedDataStreamFunction.FunctionType.forNumber( + dataStreamPythonFunctionInfo.getFunctionType())); + builder.setPayload(ByteString.copyFrom( + dataStreamPythonFunctionInfo.getPythonFunction().getSerializedPythonFunction())); + return builder.build(); + } +} diff --git a/flink-python/src/main/java/org/apache/flink/streaming/api/operators/python/AbstractPythonFunctionOperator.java 
b/flink-python/src/main/java/org/apache/flink/streaming/api/operators/python/AbstractPythonFunctionOperator.java index 7361320..1295b8f 100644 --- a/flink-python/src/main/java/org/apache/flink/streaming/api/operators/python/AbstractPythonFunctionOperator.java +++ b/flink-python/src/main/java/org/apache/flink/streaming/api/operators/python/AbstractPythonFunctionOperator.java @@ -19,340 +19,26 @@ package org.apache.flink.streaming.api.operators.python; import org.apache.flink.annotation.Internal; -import org.apache.flink.api.java.tuple.Tuple2; import org.apache.flink.configuration.Configuration; -import org.apache.flink.configuration.MemorySize; -import org.apache.flink.python.PythonConfig; -import org.apache.flink.python.PythonFunctionRunner; -import org.apache.flink.python.PythonOptions; -import org.apache.flink.python.env.PythonDependencyInfo; -import org.apache.flink.python.env.PythonEnvironmentManager; -import org.apache.flink.python.env.beam.ProcessPythonEnvironmentManager; -import org.apache.flink.python.metric.FlinkMetricContainer; -import org.apache.flink.runtime.memory.MemoryManager; -import org.apache.flink.runtime.memory.MemoryReservationException; -import org.apache.flink.streaming.api.operators.AbstractStreamOperator; import org.apache.flink.streaming.api.operators.BoundedOneInput; -import org.apache.flink.streaming.api.operators.ChainingStrategy; import org.apache.flink.streaming.api.operators.OneInputStreamOperator; -import org.apache.flink.streaming.api.watermark.Watermark; -import org.apache.flink.table.functions.python.PythonEnv; -import org.apache.flink.util.Preconditions; - -import java.io.IOException; -import java.util.concurrent.ScheduledFuture; /** - * Base class for all stream operators to execute Python functions. + * Base class for all one input stream operators to execute Python functions. 
*/ @Internal public abstract class AbstractPythonFunctionOperator<IN, OUT> - extends AbstractStreamOperator<OUT> + extends AbstractPythonFunctionOperatorBase<OUT> implements OneInputStreamOperator<IN, OUT>, BoundedOneInput { private static final long serialVersionUID = 1L; - /** - * The {@link PythonFunctionRunner} which is responsible for Python user-defined function execution. - */ - protected transient PythonFunctionRunner pythonFunctionRunner; - - /** - * Max number of elements to include in a bundle. - */ - protected transient int maxBundleSize; - - /** - * Number of processed elements in the current bundle. - */ - private transient int elementCount; - - /** - * Max duration of a bundle. - */ - private transient long maxBundleTimeMills; - - /** - * Time that the last bundle was finished. - */ - private transient long lastFinishBundleTime; - - /** - * A timer that finishes the current bundle after a fixed amount of time. - */ - private transient ScheduledFuture<?> checkFinishBundleTimer; - - /** - * Callback to be executed after the current bundle was finished. - */ - private transient Runnable bundleFinishedCallback; - - /** - * The size of the reserved memory from the MemoryManager. - */ - private transient long reservedMemory; - - /** - * The python config. - */ - private PythonConfig config; - public AbstractPythonFunctionOperator(Configuration config) { - this.config = new PythonConfig(Preconditions.checkNotNull(config)); - this.chainingStrategy = ChainingStrategy.ALWAYS; - } - - public PythonConfig getPythonConfig() { - return config; - } - - @Override - public void open() throws Exception { - try { - - if (config.isUsingManagedMemory()) { - reserveMemoryForPythonWorker(); - } - - this.maxBundleSize = config.getMaxBundleSize(); - if (this.maxBundleSize <= 0) { - this.maxBundleSize = PythonOptions.MAX_BUNDLE_SIZE.defaultValue(); - LOG.error("Invalid value for the maximum bundle size. 
Using default value of " + - this.maxBundleSize + '.'); - } else { - LOG.info("The maximum bundle size is configured to {}.", this.maxBundleSize); - } - - this.maxBundleTimeMills = config.getMaxBundleTimeMills(); - if (this.maxBundleTimeMills <= 0L) { - this.maxBundleTimeMills = PythonOptions.MAX_BUNDLE_TIME_MILLS.defaultValue(); - LOG.error("Invalid value for the maximum bundle time. Using default value of " + - this.maxBundleTimeMills + '.'); - } else { - LOG.info("The maximum bundle time is configured to {} milliseconds.", this.maxBundleTimeMills); - } - - this.pythonFunctionRunner = createPythonFunctionRunner(); - this.pythonFunctionRunner.open(config); - - this.elementCount = 0; - this.lastFinishBundleTime = getProcessingTimeService().getCurrentProcessingTime(); - - // Schedule timer to check timeout of finish bundle. - long bundleCheckPeriod = Math.max(this.maxBundleTimeMills, 1); - this.checkFinishBundleTimer = - getProcessingTimeService() - .scheduleAtFixedRate( - // ProcessingTimeService callbacks are executed under the checkpointing lock - timestamp -> checkInvokeFinishBundleByTime(), bundleCheckPeriod, bundleCheckPeriod); - } finally { - super.open(); - } - } - - @Override - public void close() throws Exception { - try { - invokeFinishBundle(); - } finally { - super.close(); - } - } - - @Override - public void dispose() throws Exception { - try { - if (checkFinishBundleTimer != null) { - checkFinishBundleTimer.cancel(true); - checkFinishBundleTimer = null; - } - if (pythonFunctionRunner != null) { - pythonFunctionRunner.close(); - pythonFunctionRunner = null; - } - if (reservedMemory > 0) { - getContainingTask().getEnvironment().getMemoryManager().releaseMemory(this, reservedMemory); - reservedMemory = -1; - } - } finally { - super.dispose(); - } + super(config); } @Override public void endInput() throws Exception { invokeFinishBundle(); } - - @Override - public void prepareSnapshotPreBarrier(long checkpointId) throws Exception { - try { - 
invokeFinishBundle(); - } finally { - super.prepareSnapshotPreBarrier(checkpointId); - } - } - - @Override - public void processWatermark(Watermark mark) throws Exception { - // Due to the asynchronous communication with the SDK harness, - // a bundle might still be in progress and not all items have - // yet been received from the SDK harness. If we just set this - // watermark as the new output watermark, we could violate the - // order of the records, i.e. pending items in the SDK harness - // could become "late" although they were "on time". - // - // We can solve this problem using one of the following options: - // - // 1) Finish the current bundle and emit this watermark as the - // new output watermark. Finishing the bundle ensures that - // all the items have been processed by the SDK harness and - // the execution results sent to the downstream operator. - // - // 2) Hold on the output watermark for as long as the current - // bundle has not been finished. We have to remember to manually - // finish the bundle in case we receive the final watermark. - // To avoid latency, we should process this watermark again as - // soon as the current bundle is finished. - // - // Approach 1) is the easiest and gives better latency, yet 2) - // gives better throughput due to the bundle not getting cut on - // every watermark. So we have implemented 2) below. - if (mark.getTimestamp() == Long.MAX_VALUE) { - invokeFinishBundle(); - super.processWatermark(mark); - } else if (elementCount == 0) { - // forward the watermark immediately if the bundle is already finished. - super.processWatermark(mark); - } else { - // It is not safe to advance the output watermark yet, so add a hold on the current - // output watermark. 
- bundleFinishedCallback = - () -> { - try { - // at this point the bundle is finished, allow the watermark to pass - super.processWatermark(mark); - } catch (Exception e) { - throw new RuntimeException( - "Failed to process watermark after finished bundle.", e); - } - }; - } - } - - /** - * Reset the {@link PythonConfig} if needed. - * */ - public void setPythonConfig(PythonConfig pythonConfig) { - this.config = pythonConfig; - } - - /** - * Returns the {@link PythonConfig}. - * */ - public PythonConfig getConfig() { - return config; - } - - /** - * Creates the {@link PythonFunctionRunner} which is responsible for Python user-defined function execution. - */ - public abstract PythonFunctionRunner createPythonFunctionRunner() throws Exception; - - /** - * Returns the {@link PythonEnv} used to create PythonEnvironmentManager.. - */ - public abstract PythonEnv getPythonEnv(); - - /** - * Sends the execution result to the downstream operator. - */ - public abstract void emitResult(Tuple2<byte[], Integer> resultTuple) throws Exception; - - /** - * Reserves the memory used by the Python worker from the MemoryManager. This makes sure that - * the memory used by the Python worker is managed by Flink. 
- */ - private void reserveMemoryForPythonWorker() throws MemoryReservationException { - long requiredPythonWorkerMemory = MemorySize.parse(config.getPythonFrameworkMemorySize()) - .add(MemorySize.parse(config.getPythonDataBufferMemorySize())) - .getBytes(); - MemoryManager memoryManager = getContainingTask().getEnvironment().getMemoryManager(); - long availableManagedMemory = memoryManager.computeMemorySize( - getOperatorConfig().getManagedMemoryFraction()); - if (requiredPythonWorkerMemory <= availableManagedMemory) { - memoryManager.reserveMemory(this, requiredPythonWorkerMemory); - LOG.info("Reserved memory {} for Python worker.", requiredPythonWorkerMemory); - this.reservedMemory = requiredPythonWorkerMemory; - // TODO enforce the memory limit of the Python worker - } else { - LOG.warn("Required Python worker memory {} exceeds the available managed off-heap " + - "memory {}. Skipping reserving off-heap memory from the MemoryManager. This does " + - "not affect the functionality. However, it may affect the stability of a job as " + - "the memory used by the Python worker is not managed by Flink.", - requiredPythonWorkerMemory, availableManagedMemory); - this.reservedMemory = -1; - } - } - - protected void emitResults() throws Exception { - Tuple2<byte[], Integer> resultTuple; - while ((resultTuple = pythonFunctionRunner.pollResult()) != null) { - emitResult(resultTuple); - } - } - - /** - * Checks whether to invoke finishBundle by elements count. Called in processElement. - */ - protected void checkInvokeFinishBundleByCount() throws Exception { - elementCount++; - if (elementCount >= maxBundleSize) { - invokeFinishBundle(); - } - } - - /** - * Checks whether to invoke finishBundle by timeout. 
- */ - private void checkInvokeFinishBundleByTime() throws Exception { - long now = getProcessingTimeService().getCurrentProcessingTime(); - if (now - lastFinishBundleTime >= maxBundleTimeMills) { - invokeFinishBundle(); - } - } - - protected void invokeFinishBundle() throws Exception { - if (elementCount > 0) { - pythonFunctionRunner.flush(); - elementCount = 0; - emitResults(); - lastFinishBundleTime = getProcessingTimeService().getCurrentProcessingTime(); - // callback only after current bundle was fully finalized - if (bundleFinishedCallback != null) { - bundleFinishedCallback.run(); - bundleFinishedCallback = null; - } - } - } - - protected PythonEnvironmentManager createPythonEnvironmentManager() throws IOException { - PythonDependencyInfo dependencyInfo = PythonDependencyInfo.create( - config, getRuntimeContext().getDistributedCache()); - PythonEnv pythonEnv = getPythonEnv(); - if (pythonEnv.getExecType() == PythonEnv.ExecType.PROCESS) { - return new ProcessPythonEnvironmentManager( - dependencyInfo, - getContainingTask().getEnvironment().getTaskManagerInfo().getTmpDirectories(), - System.getenv()); - } else { - throw new UnsupportedOperationException(String.format( - "Execution type '%s' is not supported.", pythonEnv.getExecType())); - } - } - - protected FlinkMetricContainer getFlinkMetricContainer() { - return this.config.isMetricEnabled() ? 
- new FlinkMetricContainer(getRuntimeContext().getMetricGroup()) : null; - } } diff --git a/flink-python/src/main/java/org/apache/flink/streaming/api/operators/python/AbstractPythonFunctionOperator.java b/flink-python/src/main/java/org/apache/flink/streaming/api/operators/python/AbstractPythonFunctionOperatorBase.java similarity index 95% copy from flink-python/src/main/java/org/apache/flink/streaming/api/operators/python/AbstractPythonFunctionOperator.java copy to flink-python/src/main/java/org/apache/flink/streaming/api/operators/python/AbstractPythonFunctionOperatorBase.java index 7361320..280bf0a 100644 --- a/flink-python/src/main/java/org/apache/flink/streaming/api/operators/python/AbstractPythonFunctionOperator.java +++ b/flink-python/src/main/java/org/apache/flink/streaming/api/operators/python/AbstractPythonFunctionOperatorBase.java @@ -32,9 +32,7 @@ import org.apache.flink.python.metric.FlinkMetricContainer; import org.apache.flink.runtime.memory.MemoryManager; import org.apache.flink.runtime.memory.MemoryReservationException; import org.apache.flink.streaming.api.operators.AbstractStreamOperator; -import org.apache.flink.streaming.api.operators.BoundedOneInput; import org.apache.flink.streaming.api.operators.ChainingStrategy; -import org.apache.flink.streaming.api.operators.OneInputStreamOperator; import org.apache.flink.streaming.api.watermark.Watermark; import org.apache.flink.table.functions.python.PythonEnv; import org.apache.flink.util.Preconditions; @@ -46,9 +44,8 @@ import java.util.concurrent.ScheduledFuture; * Base class for all stream operators to execute Python functions. 
*/ @Internal -public abstract class AbstractPythonFunctionOperator<IN, OUT> - extends AbstractStreamOperator<OUT> - implements OneInputStreamOperator<IN, OUT>, BoundedOneInput { +public abstract class AbstractPythonFunctionOperatorBase<OUT> + extends AbstractStreamOperator<OUT> { private static final long serialVersionUID = 1L; @@ -65,7 +62,7 @@ public abstract class AbstractPythonFunctionOperator<IN, OUT> /** * Number of processed elements in the current bundle. */ - private transient int elementCount; + protected transient int elementCount; /** * Max duration of a bundle. @@ -85,7 +82,7 @@ public abstract class AbstractPythonFunctionOperator<IN, OUT> /** * Callback to be executed after the current bundle was finished. */ - private transient Runnable bundleFinishedCallback; + protected transient Runnable bundleFinishedCallback; /** * The size of the reserved memory from the MemoryManager. @@ -97,7 +94,7 @@ public abstract class AbstractPythonFunctionOperator<IN, OUT> */ private PythonConfig config; - public AbstractPythonFunctionOperator(Configuration config) { + public AbstractPythonFunctionOperatorBase(Configuration config) { this.config = new PythonConfig(Preconditions.checkNotNull(config)); this.chainingStrategy = ChainingStrategy.ALWAYS; } @@ -180,11 +177,6 @@ public abstract class AbstractPythonFunctionOperator<IN, OUT> } @Override - public void endInput() throws Exception { - invokeFinishBundle(); - } - - @Override public void prepareSnapshotPreBarrier(long checkpointId) throws Exception { try { invokeFinishBundle(); diff --git a/flink-python/src/main/java/org/apache/flink/streaming/api/operators/python/AbstractTwoInputPythonFunctionOperator.java b/flink-python/src/main/java/org/apache/flink/streaming/api/operators/python/AbstractTwoInputPythonFunctionOperator.java new file mode 100644 index 0000000..ed221c6 --- /dev/null +++ b/flink-python/src/main/java/org/apache/flink/streaming/api/operators/python/AbstractTwoInputPythonFunctionOperator.java @@ -0,0 
+1,44 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.streaming.api.operators.python; + +import org.apache.flink.annotation.Internal; +import org.apache.flink.configuration.Configuration; +import org.apache.flink.streaming.api.operators.BoundedMultiInput; +import org.apache.flink.streaming.api.operators.TwoInputStreamOperator; + +/** + * Base class for all two input stream operators to execute Python functions. + */ +@Internal +public abstract class AbstractTwoInputPythonFunctionOperator<IN1, IN2, OUT> + extends AbstractPythonFunctionOperatorBase<OUT> + implements TwoInputStreamOperator<IN1, IN2, OUT>, BoundedMultiInput { + + private static final long serialVersionUID = 1L; + + public AbstractTwoInputPythonFunctionOperator(Configuration config) { + super(config); + } + + @Override + public void endInput(int inputId) throws Exception { + invokeFinishBundle(); + } +}