This is an automated email from the ASF dual-hosted git repository.

hequn pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git


The following commit(s) were added to refs/heads/master by this push:
     new f1e34e6  [FLINK-18866][python] Support filter() operation for Python 
DataStream API. (#13098)
f1e34e6 is described below

commit f1e34e659aed9f7bbf7cd7c38e10e97ca88e4b73
Author: Shuiqiang Chen <acqua....@alibaba-inc.com>
AuthorDate: Tue Aug 11 19:44:51 2020 +0800

    [FLINK-18866][python] Support filter() operation for Python DataStream API. 
(#13098)
---
 flink-python/pyflink/datastream/data_stream.py     | 54 ++++++++++++++++++++--
 flink-python/pyflink/datastream/functions.py       | 36 +++++++++++++++
 .../pyflink/datastream/tests/test_data_stream.py   | 43 +++++++++++++++++
 3 files changed, 128 insertions(+), 5 deletions(-)

diff --git a/flink-python/pyflink/datastream/data_stream.py 
b/flink-python/pyflink/datastream/data_stream.py
index d6f44b0..4d1b0dc 100644
--- a/flink-python/pyflink/datastream/data_stream.py
+++ b/flink-python/pyflink/datastream/data_stream.py
@@ -22,8 +22,9 @@ from pyflink.common import typeinfo, ExecutionConfig
 from pyflink.common.typeinfo import RowTypeInfo, PickledBytesTypeInfo, Types
 from pyflink.common.typeinfo import TypeInformation
 from pyflink.datastream.functions import _get_python_env, 
FlatMapFunctionWrapper, FlatMapFunction, \
-    MapFunction, MapFunctionWrapper, Function, FunctionWrapper, SinkFunction, \
-    KeySelectorFunctionWrapper, KeySelector, ReduceFunction, 
ReduceFunctionWrapper
+    MapFunction, MapFunctionWrapper, Function, FunctionWrapper, SinkFunction, 
FilterFunction, \
+    FilterFunctionWrapper, KeySelectorFunctionWrapper, KeySelector, 
ReduceFunction, \
+    ReduceFunctionWrapper
 from pyflink.java_gateway import get_gateway
 
 
@@ -262,10 +263,41 @@ class DataStream(object):
                                                                          
output_type_info]))
                                            ._j_data_stream
                                            
.keyBy(PickledKeySelector(is_key_pickled_byte_array),
-                                                  
key_type_info.get_java_type_info()))
+                                                  
key_type_info.get_java_type_info()), self)
         generated_key_stream._original_data_type_info = output_type_info
         return generated_key_stream
 
+    def filter(self, func: Union[Callable, FilterFunction]) -> 'DataStream':
+        """
+        Applies a Filter transformation on a DataStream. The transformation 
calls a FilterFunction
+        for each element of the DataStream and retains only those elements for 
which the function
+        returns true. Elements for which the function returns false are 
filtered out. The user can also
+        extend RichFilterFunction to gain access to other features provided by 
the RichFunction
+        interface.
+
+        :param func: The FilterFunction that is called for each element of the 
DataStream.
+        :return: The filtered DataStream.
+        """
+        class FilterFlatMap(FlatMapFunction):
+            def __init__(self, filter_func):
+                self._func = filter_func
+
+            def flat_map(self, value):
+                if self._func.filter(value):
+                    yield value
+
+        if isinstance(func, Callable):
+            func = FilterFunctionWrapper(func)
+        elif not isinstance(func, FilterFunction):
+            raise TypeError("func must be a Callable or instance of 
FilterFunction.")
+
+        j_input_type = self._j_data_stream.getTransformation().getOutputType()
+        type_info = typeinfo._from_java_type(j_input_type)
+        j_data_stream = self.flat_map(FilterFlatMap(func), 
type_info=type_info)._j_data_stream
+        filtered_stream = DataStream(j_data_stream)
+        filtered_stream.name("Filter")
+        return filtered_stream
+
     def _get_java_python_function_operator(self, func: Union[Function, 
FunctionWrapper],
                                            type_info: TypeInformation, 
func_name: str,
                                            func_type: int):
@@ -432,14 +464,16 @@ class KeyedStream(DataStream):
     Reduce-style operations, such as reduce and sum work on elements that have 
the same key.
     """
 
-    def __init__(self, j_keyed_stream):
+    def __init__(self, j_keyed_stream, origin_stream: DataStream):
         """
         Constructor of KeyedStream.
 
         :param j_keyed_stream: A java KeyedStream object.
+        :param origin_stream: The DataStream before key by.
         """
         super(KeyedStream, self).__init__(j_data_stream=j_keyed_stream)
         self._original_data_type_info = None
+        self._origin_stream = origin_stream
 
     def map(self, func: Union[Callable, MapFunction], type_info: 
TypeInformation = None) \
             -> 'DataStream':
@@ -483,7 +517,17 @@ class KeyedStream(DataStream):
             j_python_data_stream_scalar_function_operator
         ))
 
-    def _values(self):
+    def filter(self, func: Union[Callable, FilterFunction]) -> 'DataStream':
+        return self._values().filter(func)
+
+    def add_sink(self, sink_func: SinkFunction) -> 'DataStreamSink':
+        return self._values().add_sink(sink_func)
+
+    def key_by(self, key_selector: Union[Callable, KeySelector],
+               key_type_info: TypeInformation = None) -> 'KeyedStream':
+        return self._origin_stream.key_by(key_selector, key_type_info)
+
+    def _values(self) -> 'DataStream':
         """
         Since python KeyedStream is in the format of Row(key_value, 
original_data), it is used for
         getting the original_data.
diff --git a/flink-python/pyflink/datastream/functions.py 
b/flink-python/pyflink/datastream/functions.py
index 30f8c6f..c5049e8 100644
--- a/flink-python/pyflink/datastream/functions.py
+++ b/flink-python/pyflink/datastream/functions.py
@@ -131,6 +131,29 @@ class KeySelector(Function):
         pass
 
 
+class FilterFunction(Function):
+    """
+    A filter function is a predicate applied individually to each record. The 
predicate decides
+    whether to keep the element, or to discard it.
+    The basic syntax for using a FilterFunction is as follows:
+    ::
+         >>> ds = ...
+         >>> result = ds.filter(MyFilterFunction())
+    Note that the system assumes that the function does not modify the 
elements on which the
+    predicate is applied. Violating this assumption can lead to incorrect 
results.
+    """
+
+    @abc.abstractmethod
+    def filter(self, value):
+        """
+        The filter function that evaluates the predicate.
+
+        :param value: The value to be filtered.
+        :return: True for values that should be retained, false for values to 
be filtered out.
+        """
+        pass
+
+
 class FunctionWrapper(object):
     """
     A basic wrapper class for user defined function.
@@ -188,6 +211,19 @@ class FlatMapFunctionWrapper(FunctionWrapper):
         return self._func(value)
 
 
+class FilterFunctionWrapper(FunctionWrapper):
+    """
+        A wrapper class for FilterFunction. It's used for wrapping up a user-
defined function in a
+        FilterFunction when the user does not implement a FilterFunction but 
directly passes a function
+        object or a lambda function to the filter() function.
+        """
+    def __init__(self, func):
+        super(FilterFunctionWrapper, self).__init__(func)
+
+    def filter(self, value):
+        return self._func(value)
+
+
 class ReduceFunctionWrapper(FunctionWrapper):
     """
     A wrapper class for ReduceFunction. It's used for wrapping up user defined 
function in a
diff --git a/flink-python/pyflink/datastream/tests/test_data_stream.py 
b/flink-python/pyflink/datastream/tests/test_data_stream.py
index 6d4fa49..1abd80d 100644
--- a/flink-python/pyflink/datastream/tests/test_data_stream.py
+++ b/flink-python/pyflink/datastream/tests/test_data_stream.py
@@ -19,6 +19,7 @@ import decimal
 
 from pyflink.common.typeinfo import Types
 from pyflink.datastream import StreamExecutionEnvironment
+from pyflink.datastream.functions import FilterFunction
 from pyflink.datastream.functions import KeySelector
 from pyflink.datastream.functions import MapFunction, FlatMapFunction
 from pyflink.datastream.tests.test_util import DataStreamTestSinkFunction
@@ -148,6 +149,29 @@ class DataStreamTests(PyFlinkTestCase):
         expected.sort()
         self.assertEqual(expected, results)
 
+    def test_filter_without_data_types(self):
+        ds = self.env.from_collection([(1, 'Hi', 'Hello'), (2, 'Hello', 'Hi')])
+        ds.filter(MyFilterFunction()).add_sink(self.test_sink)
+        self.env.execute("test filter")
+        results = self.test_sink.get_results(True)
+        expected = ["(2, 'Hello', 'Hi')"]
+        results.sort()
+        expected.sort()
+        self.assertEqual(expected, results)
+
+    def test_filter_with_data_types(self):
+        ds = self.env.from_collection([(1, 'Hi', 'Hello'), (2, 'Hello', 'Hi')],
+                                      type_info=Types.ROW(
+                                          [Types.INT(), Types.STRING(), 
Types.STRING()])
+                                      )
+        ds.filter(lambda x: x[0] % 2 == 0).add_sink(self.test_sink)
+        self.env.execute("test filter")
+        results = self.test_sink.get_results(False)
+        expected = ['2,Hello,Hi']
+        results.sort()
+        expected.sort()
+        self.assertEqual(expected, results)
+
     def test_add_sink(self):
         ds = self.env.from_collection([('ab', 1), ('bdc', 2), ('cfgs', 3), 
('deeefg', 4)],
                                       type_info=Types.ROW([Types.STRING(), 
Types.INT()]))
@@ -188,6 +212,19 @@ class DataStreamTests(PyFlinkTestCase):
         expected.sort()
         self.assertEqual(expected, results)
 
+    def test_multi_key_by(self):
+        ds = self.env.from_collection([('a', 0), ('b', 0), ('c', 1), ('d', 1), 
('e', 2)],
+                                      type_info=Types.ROW([Types.STRING(), 
Types.INT()]))
+        ds.key_by(MyKeySelector(), key_type_info=Types.INT()).key_by(lambda x: 
x[0])\
+            .add_sink(self.test_sink)
+
+        self.env.execute("test multi key by")
+        results = self.test_sink.get_results(False)
+        expected = ['d,1', 'c,1', 'a,0', 'b,0', 'e,2']
+        results.sort()
+        expected.sort()
+        self.assertEqual(expected, results)
+
     def tearDown(self) -> None:
         self.test_sink.get_results()
 
@@ -209,3 +246,9 @@ class MyFlatMapFunction(FlatMapFunction):
 class MyKeySelector(KeySelector):
     def get_key(self, value):
         return value[1]
+
+
+class MyFilterFunction(FilterFunction):
+
+    def filter(self, value):
+        return value[0] % 2 == 0

Reply via email to