This is an automated email from the ASF dual-hosted git repository. hxb pushed a commit to branch release-1.16 in repository https://gitbox.apache.org/repos/asf/flink.git
The following commit(s) were added to refs/heads/release-1.16 by this push: new 041d8635523 [FLINK-30366][python] Fix Python Group Agg failed in cleaning the idle state 041d8635523 is described below commit 041d863552396105d08af097a456ee291263d434 Author: huangxingbo <h...@apache.org> AuthorDate: Mon Dec 12 11:49:15 2022 +0800 [FLINK-30366][python] Fix Python Group Agg failed in cleaning the idle state This closes #21488. --- flink-python/pyflink/fn_execution/coder_impl_fast.pyx | 2 +- flink-python/pyflink/fn_execution/coder_impl_slow.py | 2 ++ flink-python/pyflink/fn_execution/table/aggregate_fast.pyx | 2 +- flink-python/pyflink/fn_execution/table/aggregate_slow.py | 4 ++-- flink-python/pyflink/fn_execution/tests/test_coders.py | 6 ++++++ 5 files changed, 12 insertions(+), 4 deletions(-) diff --git a/flink-python/pyflink/fn_execution/coder_impl_fast.pyx b/flink-python/pyflink/fn_execution/coder_impl_fast.pyx index 41ed4ab38c4..92dff893fe9 100644 --- a/flink-python/pyflink/fn_execution/coder_impl_fast.pyx +++ b/flink-python/pyflink/fn_execution/coder_impl_fast.pyx @@ -327,7 +327,7 @@ cdef class FlattenRowCoderImpl(FieldCoderImpl): cdef size_t i cdef FieldCoderImpl field_coder - list_value = <list> value + list_value = <list?> value # encode mask value self._mask_utils.write_mask(list_value, 0, out_stream) diff --git a/flink-python/pyflink/fn_execution/coder_impl_slow.py b/flink-python/pyflink/fn_execution/coder_impl_slow.py index 23af3483e0a..769720dc277 100644 --- a/flink-python/pyflink/fn_execution/coder_impl_slow.py +++ b/flink-python/pyflink/fn_execution/coder_impl_slow.py @@ -209,6 +209,8 @@ class FlattenRowCoderImpl(FieldCoderImpl): self._mask_utils = MaskUtils(self._field_count) def encode_to_stream(self, value, out_stream: OutputStream): + if not isinstance(value, List): + raise TypeError('Expected list, got {0}'.format(type(value))) # encode mask value self._mask_utils.write_mask(value, 0, out_stream) diff --git a/flink-python/pyflink/fn_execution/table/aggregate_fast.pyx b/flink-python/pyflink/fn_execution/table/aggregate_fast.pyx index 2db06c735df..ca68f8546f1 100644 --- a/flink-python/pyflink/fn_execution/table/aggregate_fast.pyx +++ b/flink-python/pyflink/fn_execution/table/aggregate_fast.pyx @@ -456,7 +456,7 @@ cdef class GroupAggFunctionBase: cpdef void on_timer(self, InternalRow key): if self.state_cleaning_enabled: - self.state_backend.set_current_key(key) + self.state_backend.set_current_key(list(key.values)) accumulator_state = self.state_backend.get_value_state( "accumulators", self.state_value_coder) accumulator_state.clear() diff --git a/flink-python/pyflink/fn_execution/table/aggregate_slow.py b/flink-python/pyflink/fn_execution/table/aggregate_slow.py index e10f6b06946..52c61e0f855 100644 --- a/flink-python/pyflink/fn_execution/table/aggregate_slow.py +++ b/flink-python/pyflink/fn_execution/table/aggregate_slow.py @@ -436,9 +436,9 @@ class GroupAggFunctionBase(object): except KeyError: self.buffer[tuple(key)] = [input_data] - def on_timer(self, key): + def on_timer(self, key: Row): if self.state_cleaning_enabled: - self.state_backend.set_current_key(key) + self.state_backend.set_current_key(list(key._values)) accumulator_state = self.state_backend.get_value_state( "accumulators", self.state_value_coder) accumulator_state.clear() diff --git a/flink-python/pyflink/fn_execution/tests/test_coders.py b/flink-python/pyflink/fn_execution/tests/test_coders.py index 4d11ca18e5f..8d4a3b24cc2 100644 --- a/flink-python/pyflink/fn_execution/tests/test_coders.py +++ b/flink-python/pyflink/fn_execution/tests/test_coders.py @@ -189,6 +189,12 @@ class CodersTest(PyFlinkTestCase): coder = CountWindowCoder() self.check_coder(coder, CountWindow(100)) + def test_coder_with_unmatched_type(self): + from pyflink.common import Row + coder = FlattenRowCoder([BigIntCoder()]) + with self.assertRaises(TypeError, msg='Expected list, got Row'): + self.check_coder(coder, Row(1)) + if __name__ == '__main__': logging.getLogger().setLevel(logging.INFO)