This is an automated email from the ASF dual-hosted git repository. dianfu pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/flink.git
The following commit(s) were added to refs/heads/master by this push: new d8417565d6f [FLINK-28526][python] Fix Python UDF to support time indicator inputs d8417565d6f is described below commit d8417565d6fb7f907f54eeabb8a53ebf790ffad8 Author: Dian Fu <dia...@apache.org> AuthorDate: Mon Jan 16 14:05:59 2023 +0800 [FLINK-28526][python] Fix Python UDF to support time indicator inputs This closes #21686. --- flink-python/pyflink/table/tests/test_udf.py | 43 ++++++++++++++++++++++ .../plan/nodes/exec/utils/CommonPythonUtil.java | 33 ++++++++++++----- 2 files changed, 66 insertions(+), 10 deletions(-) diff --git a/flink-python/pyflink/table/tests/test_udf.py b/flink-python/pyflink/table/tests/test_udf.py index 851c43fc0b2..d974c2402bf 100644 --- a/flink-python/pyflink/table/tests/test_udf.py +++ b/flink-python/pyflink/table/tests/test_udf.py @@ -24,6 +24,7 @@ import uuid import pytest import pytz +from pyflink.common import Row from pyflink.table import DataTypes, expressions as expr from pyflink.table.expressions import call from pyflink.table.udf import ScalarFunction, udf, FunctionContext @@ -860,6 +861,48 @@ class PyFlinkStreamUserDefinedFunctionTests(UserDefinedFunctionTests, lines.sort() self.assertEqual(lines, ['1,2', '2,3', '3,4']) + def test_udf_with_rowtime_arguments(self): + from pyflink.common import WatermarkStrategy + from pyflink.common.typeinfo import Types + from pyflink.common.watermark_strategy import TimestampAssigner + from pyflink.table import Schema + + class MyTimestampAssigner(TimestampAssigner): + + def extract_timestamp(self, value, record_timestamp) -> int: + return int(value[0]) + + ds = self.env.from_collection( + [(1, 42, "a"), (2, 5, "a"), (3, 1000, "c"), (100, 1000, "c")], + Types.ROW_NAMED(["a", "b", "c"], [Types.LONG(), Types.INT(), Types.STRING()])) + + ds = ds.assign_timestamps_and_watermarks( + WatermarkStrategy.for_monotonous_timestamps() + .with_timestamp_assigner(MyTimestampAssigner())) + + table = self.t_env.from_data_stream( + ds, + Schema.new_builder() + .column_by_metadata("rowtime", "TIMESTAMP_LTZ(3)") + .watermark("rowtime", "SOURCE_WATERMARK()") + .build()) + + @udf(result_type=DataTypes.ROW([DataTypes.FIELD('f1', DataTypes.INT())])) + def inc(input_row): + return Row(input_row.b) + + sink_table = generate_random_table_name() + sink_table_ddl = f""" + CREATE TABLE {sink_table}( + a INT + ) WITH ('connector'='test-sink') + """ + self.t_env.execute_sql(sink_table_ddl) + table.map(inc).execute_insert(sink_table).wait() + + actual = source_sink_utils.results() + self.assert_equals(actual, ['+I[42]', '+I[5]', '+I[1000]', '+I[1000]']) + class PyFlinkBatchUserDefinedFunctionTests(UserDefinedFunctionTests, PyFlinkBatchTableTestCase): diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/utils/CommonPythonUtil.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/utils/CommonPythonUtil.java index 201407b718a..ff4ed47dc3b 100644 --- a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/utils/CommonPythonUtil.java +++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/utils/CommonPythonUtil.java @@ -47,6 +47,7 @@ import org.apache.flink.table.planner.functions.bridging.BridgingSqlFunction; import org.apache.flink.table.planner.functions.utils.AggSqlFunction; import org.apache.flink.table.planner.functions.utils.ScalarSqlFunction; import org.apache.flink.table.planner.functions.utils.TableSqlFunction; +import org.apache.flink.table.planner.plan.schema.TimeIndicatorRelDataType; import org.apache.flink.table.planner.plan.utils.AggregateInfo; import org.apache.flink.table.planner.plan.utils.AggregateInfoList; import org.apache.flink.table.planner.utils.DummyStreamExecutionEnvironment; @@ -70,10 +71,12 @@ import org.apache.flink.table.types.logical.StructuredType; import org.apache.calcite.rel.core.AggregateCall; import org.apache.calcite.rex.RexCall; +import org.apache.calcite.rex.RexInputRef; import org.apache.calcite.rex.RexLiteral; import org.apache.calcite.rex.RexNode; import org.apache.calcite.sql.SqlAggFunction; import org.apache.calcite.sql.SqlOperator; +import org.apache.calcite.sql.fun.SqlCastFunction; import org.apache.calcite.sql.type.SqlTypeName; import java.lang.reflect.Field; @@ -438,22 +441,32 @@ public class CommonPythonUtil { for (RexNode operand : pythonRexCall.getOperands()) { if (operand instanceof RexCall) { RexCall childPythonRexCall = (RexCall) operand; - PythonFunctionInfo argPythonInfo = - createPythonFunctionInfo(childPythonRexCall, inputNodes, classLoader); - inputs.add(argPythonInfo); + if (childPythonRexCall.getOperator() instanceof SqlCastFunction + && childPythonRexCall.getOperands().get(0) instanceof RexInputRef + && childPythonRexCall.getOperands().get(0).getType() + instanceof TimeIndicatorRelDataType) { + operand = childPythonRexCall.getOperands().get(0); + } else { + PythonFunctionInfo argPythonInfo = + createPythonFunctionInfo(childPythonRexCall, inputNodes, classLoader); + inputs.add(argPythonInfo); + continue; + } } else if (operand instanceof RexLiteral) { RexLiteral literal = (RexLiteral) operand; inputs.add( convertLiteralToPython( literal, literal.getType().getSqlTypeName(), classLoader)); + continue; + } + + assert operand instanceof RexInputRef; + if (inputNodes.containsKey(operand)) { + inputs.add(inputNodes.get(operand)); } else { - if (inputNodes.containsKey(operand)) { - inputs.add(inputNodes.get(operand)); - } else { - Integer inputOffset = inputNodes.size(); - inputs.add(inputOffset); - inputNodes.put(operand, inputOffset); - } + Integer inputOffset = inputNodes.size(); + inputs.add(inputOffset); + inputNodes.put(operand, inputOffset); } } return new PythonFunctionInfo((PythonFunction) functionDefinition, inputs.toArray());