fhueske commented on code in PR #27928: URL: https://github.com/apache/flink/pull/27928#discussion_r3232410899
########## flink-table/flink-table-test-utils/src/main/java/org/apache/flink/table/runtime/functions/ProcessTableFunctionTestHarness.java: ########## @@ -0,0 +1,1214 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.runtime.functions; + +import org.apache.flink.annotation.PublicEvolving; +import org.apache.flink.table.annotation.ArgumentTrait; +import org.apache.flink.table.annotation.StateHint; +import org.apache.flink.table.catalog.DataTypeFactory; +import org.apache.flink.table.data.conversion.DataStructureConverter; +import org.apache.flink.table.data.conversion.DataStructureConverters; +import org.apache.flink.table.functions.FunctionContext; +import org.apache.flink.table.functions.FunctionKind; +import org.apache.flink.table.functions.ProcessTableFunction; +import org.apache.flink.table.types.AbstractDataType; +import org.apache.flink.table.types.DataType; +import org.apache.flink.table.types.FieldsDataType; +import org.apache.flink.table.types.inference.StaticArgument; +import org.apache.flink.table.types.inference.StaticArgumentTrait; +import org.apache.flink.table.types.inference.SystemTypeInference; +import org.apache.flink.table.types.inference.TypeInference; +import org.apache.flink.table.types.logical.LogicalType; +import org.apache.flink.table.types.logical.RowType; +import org.apache.flink.table.types.logical.StructuredType; +import org.apache.flink.table.types.utils.TypeConversions; +import org.apache.flink.types.Row; +import org.apache.flink.types.RowKind; +import org.apache.flink.util.Collector; + +import java.lang.reflect.InvocationTargetException; +import java.lang.reflect.Method; +import java.lang.reflect.Parameter; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.EnumSet; +import java.util.HashMap; +import java.util.HashSet; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.Set; + +import static org.apache.flink.util.Preconditions.checkArgument; +import static org.apache.flink.util.Preconditions.checkNotNull; +import static org.apache.flink.util.Preconditions.checkState; + +/** + * Test harness for {@link ProcessTableFunction}. + * + * <p>Provides a fluent builder API for configuring and testing ProcessTableFunctions (PTFs) with + * table and scalar arguments, lifecycle management, and output collection. + * + * <p>Example usage: + * + * <pre>{@code + * ProcessTableFunctionTestHarness<Row> harness = + * ProcessTableFunctionTestHarness.ofClass(MyPTF.class) + * .withTableArgument("input", DataTypes.of("ROW<id INT, name STRING>")) + * .withScalarArgument("threshold", 100) + * .build(); + * + * harness.processElement(Row.of(1, "Alice")); + * harness.processElement(Row.of(2, "Bob")); + * + * List<Row> output = harness.getOutput(); + * }</pre> + */ +@PublicEvolving +public class ProcessTableFunctionTestHarness<OUT> implements AutoCloseable { + + private final ProcessTableFunction<OUT> function; + private final FunctionContext functionContext; + private final List<OUT> output; + private boolean isOpen; + private final HarnessCollector collector; + + private final String defaultTableArgument; + private final Method evalMethod; + private final List<ArgumentInfo> arguments; + + private final Map<String, ArgumentInfo> argumentsByName; + private final boolean isSingleTableFunction; + private final Map<String, Object> scalarArgumentValues; + + private final boolean hasTableArguments; + private final Map<String, DataStructureConverter<Object, Object>> inputConverters; + private final Map<String, DataStructureConverter<Object, Object>> outputConverters; + + private ProcessTableFunctionTestHarness( + ProcessTableFunction<OUT> function, + FunctionContext functionContext, + String defaultTableArgument, + Method evalMethod, + List<ArgumentInfo> arguments, + Map<String, ArgumentInfo> argumentsByName, + boolean isSingleTableFunction, + Map<String, Object> scalarArgumentValues, + Map<String, DataStructureConverter<Object, Object>> inputConverters, + Map<String, DataStructureConverter<Object, Object>> outputConverters) + throws Exception { + this.function = function; + this.functionContext = functionContext; + this.defaultTableArgument = defaultTableArgument; + this.evalMethod = evalMethod; + this.arguments = arguments; + this.argumentsByName = argumentsByName; + this.isSingleTableFunction = isSingleTableFunction; + this.scalarArgumentValues = scalarArgumentValues; + this.hasTableArguments = arguments.stream().anyMatch(arg -> arg.isTableArgument); + this.inputConverters = inputConverters; + this.outputConverters = outputConverters; + this.output = new ArrayList<>(); + this.collector = new HarnessCollector(); + this.isOpen = false; + + openFunction(); + } + + /** Creates a new harness builder for the given ProcessTableFunction class. */ + public static <OUT> Builder<OUT> ofClass( + Class<? extends ProcessTableFunction<OUT>> functionClass) { + return new Builder<>(functionClass); + } + + private void openFunction() throws Exception { + function.open(functionContext); + function.setCollector(collector); + isOpen = true; + } + + @Override + public void close() throws Exception { + if (isOpen) { + function.close(); + isOpen = false; + } + } + + /** + * Process a single element for the default table argument. + * + * <p>For PTFs with a single table argument, this processes one row. For multiple table + * arguments, use {@link #processElementForTable(String, Row)}. + */ + public void processElement(Row row) throws Exception { + if (!isSingleTableFunction) { + throw new IllegalStateException( + "PTF has multiple table arguments. Use processElementForTable(argumentName, row) " + + "to specify which table argument should receive the row."); + } + + processElementForTable(defaultTableArgument, row); + } + + /** Process a single element constructed from values. */ + public void processElement(Object... values) throws Exception { + processElement(Row.of(values)); + } + + /** Process a single element with a specific RowKind. */ + public void processElement(RowKind rowKind, Object... values) throws Exception { + processElement(Row.ofKind(rowKind, values)); + } + + /** Process a single element for a specific table argument. */ + public void processElementForTable(String tableArgument, Row row) throws Exception { + checkState(isOpen, "Harness not open"); + checkNotNull(tableArgument, "tableArgument must not be null"); + + ArgumentInfo tableArg = argumentsByName.get(tableArgument); + if (tableArg == null) { + throw new IllegalArgumentException("Unknown table argument: " + tableArgument); + } + invokeEval(tableArg, row); + } + + /** Process a single element for a specific table argument. */ + public void processElementForTable(String tableArgument, Object... values) throws Exception { + processElementForTable(tableArgument, Row.of(values)); + } + + /** Process a single element for a specific table argument with RowKind. */ + public void processElementForTable(String tableArgument, RowKind rowKind, Object... values) + throws Exception { + processElementForTable(tableArgument, Row.ofKind(rowKind, values)); + } + + /** + * Processes the PTF's eval() method with scalar arguments only. + * + * <p>This method is specifically for scalar-only PTFs (PTFs with only scalar arguments and no + * table arguments). For PTFs that accept table arguments, use {@link #processElement(Row)} or + * {@link #processElementForTable(String, Row)} instead. + * + * @throws IllegalStateException if the PTF has any table arguments + * @throws Exception if the eval() invocation fails + */ + public void process() throws Exception { + checkState(isOpen, "Harness not open"); + + if (hasTableArguments) { + throw new IllegalStateException( + "process() is only for scalar-only PTFs. This PTF has table arguments. " + + "Use processElement() or processElementForTable() instead."); + } + + // Clear collector context since there's no active table argument + collector.setContext(null, null); + + Object[] args = new Object[arguments.size()]; + for (int i = 0; i < arguments.size(); i++) { + ArgumentInfo arg = arguments.get(i); + if (arg.isScalar) { + args[i] = scalarArgumentValues.get(arg.name); + } else { + throw new IllegalStateException( + "Unexpected non-scalar argument at position " + i + ": " + arg.name); + } + } + try { + evalMethod.invoke(function, args); + } catch (InvocationTargetException e) { + Throwable cause = e.getCause(); + if (cause instanceof Exception) { + Exception userException = (Exception) cause; + userException.addSuppressed( + new Exception( + String.format( + "Exception occurred during scalar-only PTF eval() invocation. " + + "Scalar arguments: %s", + scalarArgumentValues))); + throw userException; + } else { + throw new RuntimeException("Error invoking PTF eval() method", e); + } + } + } + + /** Returns all collected output rows. */ + public List<OUT> getOutput() { + return List.copyOf(output); + } + + /** Clears all collected output. */ + public void clearOutput() { + output.clear(); + } + + /** + * Given a target table argument and a row to process, construct the right set of arguments for + * the PTF's eval function and attempt to invoke it. + */ + private void invokeEval(ArgumentInfo activeTableArg, Row activeRow) throws Exception { + // Set collector context so it can prepend columns if needed + collector.setContext(activeTableArg, activeRow); + + Object[] args = new Object[arguments.size()]; + + for (int i = 0; i < arguments.size(); i++) { + ArgumentInfo arg = arguments.get(i); + + if (arg.isTableArgument && arg.name.equals(activeTableArg.name)) { + // If the argument is the active table argument, first convert the input row + // to an internal RowData type, and then convert the RowData to type that the + // argument expects. For Rows, this will structure the Row based on the table + // argument structure. Otherwise, for POJOs, it will pass the expected POJO to eval. + + DataStructureConverter<Object, Object> inputConverter = + inputConverters.get(arg.name); + DataStructureConverter<Object, Object> outputConverter = + outputConverters.get(arg.name); + + args[i] = + outputConverter.toExternalOrNull( + inputConverter.toInternalOrNull(activeRow)); + + } else if (arg.isScalar) { + args[i] = scalarArgumentValues.get(arg.name); + + } else if (arg.isTableArgument) { + // Inactive table arguments receive null + args[i] = null; + } else { + throw new IllegalStateException( + "Unexpected argument type at position " + i + ": " + arg.name); + } + } + + try { + evalMethod.invoke(function, args); + } catch (InvocationTargetException e) { + Throwable cause = e.getCause(); + if (cause instanceof Exception) { + Exception userException = (Exception) cause; + String partitionInfo = + activeTableArg.partitionColumnNames != null + && activeTableArg.partitionColumnNames.length > 0 + ? String.format( + ", partition columns: %s", + Arrays.toString(activeTableArg.partitionColumnNames)) + : ", no partitioning"; + userException.addSuppressed( + new Exception( + String.format( + "Exception occurred during PTF eval() while processing table argument '%s'%s. " + + "Input row: %s", + activeTableArg.name, partitionInfo, activeRow))); + throw userException; + } else { + throw new RuntimeException("Error invoking PTF eval() method", e); + } + } + } + + /** + * Collector implementation that stores output in the harness. + * + * <p>For SET_SEMANTIC_TABLE arguments, automatically prepends partition key columns to the PTF + * output. If the argument has PASS_COLUMNS_THROUGH trait, prepends all input columns. + */ + private class HarnessCollector implements Collector<OUT> { + // Context set before each eval() invocation + private ArgumentInfo activeTableArg; + private Row activeRow; + + void setContext(ArgumentInfo tableArg, Row row) { + this.activeTableArg = tableArg; + this.activeRow = row; + } + + @Override + public void collect(OUT record) { + if (activeTableArg == null || !activeTableArg.isTableArgument) { + output.add(record); + return; + } + + if (activeTableArg.hasPassColumnsThrough) { + output.add(prependAllColumns(record)); + } else if (activeTableArg.isSetSemantic + && activeTableArg.partitionColumnNames != null) { + output.add(prependPartitionKeys(record)); + } else { Review Comment: That makes sense, thank you! -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: [email protected] For queries about this service, please contact Infrastructure at: [email protected]
