This is an automated email from the ASF dual-hosted git repository.

dianfu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git


The following commit(s) were added to refs/heads/master by this push:
     new d8417565d6f [FLINK-28526][python] Fix Python UDF to support time indicator inputs
d8417565d6f is described below

commit d8417565d6fb7f907f54eeabb8a53ebf790ffad8
Author: Dian Fu <dia...@apache.org>
AuthorDate: Mon Jan 16 14:05:59 2023 +0800

    [FLINK-28526][python] Fix Python UDF to support time indicator inputs
    
    This closes #21686.
---
 flink-python/pyflink/table/tests/test_udf.py       | 43 ++++++++++++++++++++++
 .../plan/nodes/exec/utils/CommonPythonUtil.java    | 33 ++++++++++++-----
 2 files changed, 66 insertions(+), 10 deletions(-)

diff --git a/flink-python/pyflink/table/tests/test_udf.py b/flink-python/pyflink/table/tests/test_udf.py
index 851c43fc0b2..d974c2402bf 100644
--- a/flink-python/pyflink/table/tests/test_udf.py
+++ b/flink-python/pyflink/table/tests/test_udf.py
@@ -24,6 +24,7 @@ import uuid
 import pytest
 import pytz
 
+from pyflink.common import Row
 from pyflink.table import DataTypes, expressions as expr
 from pyflink.table.expressions import call
 from pyflink.table.udf import ScalarFunction, udf, FunctionContext
@@ -860,6 +861,48 @@ class PyFlinkStreamUserDefinedFunctionTests(UserDefinedFunctionTests,
         lines.sort()
         self.assertEqual(lines, ['1,2', '2,3', '3,4'])
 
+    def test_udf_with_rowtime_arguments(self):
+        from pyflink.common import WatermarkStrategy
+        from pyflink.common.typeinfo import Types
+        from pyflink.common.watermark_strategy import TimestampAssigner
+        from pyflink.table import Schema
+
+        class MyTimestampAssigner(TimestampAssigner):
+
+            def extract_timestamp(self, value, record_timestamp) -> int:
+                return int(value[0])
+
+        ds = self.env.from_collection(
+            [(1, 42, "a"), (2, 5, "a"), (3, 1000, "c"), (100, 1000, "c")],
+            Types.ROW_NAMED(["a", "b", "c"], [Types.LONG(), Types.INT(), 
Types.STRING()]))
+
+        ds = ds.assign_timestamps_and_watermarks(
+            WatermarkStrategy.for_monotonous_timestamps()
+            .with_timestamp_assigner(MyTimestampAssigner()))
+
+        table = self.t_env.from_data_stream(
+            ds,
+            Schema.new_builder()
+                  .column_by_metadata("rowtime", "TIMESTAMP_LTZ(3)")
+                  .watermark("rowtime", "SOURCE_WATERMARK()")
+                  .build())
+
+        @udf(result_type=DataTypes.ROW([DataTypes.FIELD('f1', 
DataTypes.INT())]))
+        def inc(input_row):
+            return Row(input_row.b)
+
+        sink_table = generate_random_table_name()
+        sink_table_ddl = f"""
+                    CREATE TABLE {sink_table}(
+                        a INT
+                    ) WITH ('connector'='test-sink')
+                """
+        self.t_env.execute_sql(sink_table_ddl)
+        table.map(inc).execute_insert(sink_table).wait()
+
+        actual = source_sink_utils.results()
+        self.assert_equals(actual, ['+I[42]', '+I[5]', '+I[1000]', '+I[1000]'])
+
 
 class PyFlinkBatchUserDefinedFunctionTests(UserDefinedFunctionTests,
                                            PyFlinkBatchTableTestCase):
diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/utils/CommonPythonUtil.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/utils/CommonPythonUtil.java
index 201407b718a..ff4ed47dc3b 100644
--- a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/utils/CommonPythonUtil.java
+++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/utils/CommonPythonUtil.java
@@ -47,6 +47,7 @@ import org.apache.flink.table.planner.functions.bridging.BridgingSqlFunction;
 import org.apache.flink.table.planner.functions.utils.AggSqlFunction;
 import org.apache.flink.table.planner.functions.utils.ScalarSqlFunction;
 import org.apache.flink.table.planner.functions.utils.TableSqlFunction;
+import org.apache.flink.table.planner.plan.schema.TimeIndicatorRelDataType;
 import org.apache.flink.table.planner.plan.utils.AggregateInfo;
 import org.apache.flink.table.planner.plan.utils.AggregateInfoList;
 import org.apache.flink.table.planner.utils.DummyStreamExecutionEnvironment;
@@ -70,10 +71,12 @@ import org.apache.flink.table.types.logical.StructuredType;
 
 import org.apache.calcite.rel.core.AggregateCall;
 import org.apache.calcite.rex.RexCall;
+import org.apache.calcite.rex.RexInputRef;
 import org.apache.calcite.rex.RexLiteral;
 import org.apache.calcite.rex.RexNode;
 import org.apache.calcite.sql.SqlAggFunction;
 import org.apache.calcite.sql.SqlOperator;
+import org.apache.calcite.sql.fun.SqlCastFunction;
 import org.apache.calcite.sql.type.SqlTypeName;
 
 import java.lang.reflect.Field;
@@ -438,22 +441,32 @@ public class CommonPythonUtil {
         for (RexNode operand : pythonRexCall.getOperands()) {
             if (operand instanceof RexCall) {
                 RexCall childPythonRexCall = (RexCall) operand;
-                PythonFunctionInfo argPythonInfo =
-                        createPythonFunctionInfo(childPythonRexCall, 
inputNodes, classLoader);
-                inputs.add(argPythonInfo);
+                if (childPythonRexCall.getOperator() instanceof SqlCastFunction
+                        && childPythonRexCall.getOperands().get(0) instanceof 
RexInputRef
+                        && childPythonRexCall.getOperands().get(0).getType()
+                                instanceof TimeIndicatorRelDataType) {
+                    operand = childPythonRexCall.getOperands().get(0);
+                } else {
+                    PythonFunctionInfo argPythonInfo =
+                            createPythonFunctionInfo(childPythonRexCall, 
inputNodes, classLoader);
+                    inputs.add(argPythonInfo);
+                    continue;
+                }
             } else if (operand instanceof RexLiteral) {
                 RexLiteral literal = (RexLiteral) operand;
                 inputs.add(
                         convertLiteralToPython(
                                 literal, literal.getType().getSqlTypeName(), 
classLoader));
+                continue;
+            }
+
+            assert operand instanceof RexInputRef;
+            if (inputNodes.containsKey(operand)) {
+                inputs.add(inputNodes.get(operand));
             } else {
-                if (inputNodes.containsKey(operand)) {
-                    inputs.add(inputNodes.get(operand));
-                } else {
-                    Integer inputOffset = inputNodes.size();
-                    inputs.add(inputOffset);
-                    inputNodes.put(operand, inputOffset);
-                }
+                Integer inputOffset = inputNodes.size();
+                inputs.add(inputOffset);
+                inputNodes.put(operand, inputOffset);
             }
         }
         return new PythonFunctionInfo((PythonFunction) functionDefinition, 
inputs.toArray());

Reply via email to