autophagy commented on code in PR #27928: URL: https://github.com/apache/flink/pull/27928#discussion_r3181239588
########## flink-table/flink-table-test-utils/src/main/java/org/apache/flink/table/runtime/functions/ProcessTableFunctionTestHarness.java: ########## @@ -0,0 +1,1199 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.runtime.functions; + +import org.apache.flink.annotation.PublicEvolving; +import org.apache.flink.table.annotation.ArgumentTrait; +import org.apache.flink.table.annotation.StateHint; +import org.apache.flink.table.catalog.DataTypeFactory; +import org.apache.flink.table.data.conversion.DataStructureConverter; +import org.apache.flink.table.data.conversion.DataStructureConverters; +import org.apache.flink.table.functions.FunctionContext; +import org.apache.flink.table.functions.FunctionKind; +import org.apache.flink.table.functions.ProcessTableFunction; +import org.apache.flink.table.types.AbstractDataType; +import org.apache.flink.table.types.DataType; +import org.apache.flink.table.types.FieldsDataType; +import org.apache.flink.table.types.inference.StaticArgument; +import org.apache.flink.table.types.inference.StaticArgumentTrait; +import org.apache.flink.table.types.inference.SystemTypeInference; +import org.apache.flink.table.types.inference.TypeInference; 
+import org.apache.flink.table.types.logical.LogicalType; +import org.apache.flink.table.types.logical.RowType; +import org.apache.flink.table.types.logical.StructuredType; +import org.apache.flink.table.types.utils.TypeConversions; +import org.apache.flink.types.Row; +import org.apache.flink.types.RowKind; +import org.apache.flink.util.Collector; + +import java.lang.reflect.InvocationTargetException; +import java.lang.reflect.Method; +import java.lang.reflect.Parameter; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.EnumSet; +import java.util.HashMap; +import java.util.HashSet; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.Set; + +import static org.apache.flink.util.Preconditions.checkArgument; +import static org.apache.flink.util.Preconditions.checkNotNull; +import static org.apache.flink.util.Preconditions.checkState; + +/** + * Test harness for {@link ProcessTableFunction}. + * + * <p>Provides a fluent builder API for configuring and testing ProcessTableFunctions (PTFs) with + * table and scalar arguments, lifecycle management, and output collection. 
+ * + * <p>Example usage: + * + * <pre>{@code + * ProcessTableFunctionTestHarness<Row> harness = + * ProcessTableFunctionTestHarness.ofClass(MyPTF.class) + * .withTableArgument("input", DataTypes.of("ROW<id INT, name STRING>")) + * .withScalarArgument("threshold", 100) + * .build(); + * + * harness.processElement(Row.of(1, "Alice")); + * harness.processElement(Row.of(2, "Bob")); + * + * List<Row> output = harness.getOutput(); + * }</pre> + */ +@PublicEvolving +public class ProcessTableFunctionTestHarness<OUT> implements AutoCloseable { + + private final ProcessTableFunction<OUT> function; + private final FunctionContext functionContext; + private final List<OUT> output; + private boolean isOpen; + private final HarnessCollector collector; + + private final String defaultTableArgument; + private final Method evalMethod; + private final List<ArgumentInfo> arguments; + + private final Map<String, ArgumentInfo> argumentsByName; + private final boolean isSingleTableFunction; + private final Map<String, Object> scalarArgumentValues; + + private final Map<String, DataStructureConverter<Object, Object>> inputConverters; + private final Map<String, DataStructureConverter<Object, Object>> outputConverters; + + private ProcessTableFunctionTestHarness( + ProcessTableFunction<OUT> function, + FunctionContext functionContext, + String defaultTableArgument, + Method evalMethod, + List<ArgumentInfo> arguments, + Map<String, ArgumentInfo> argumentsByName, + boolean isSingleTableFunction, + Map<String, Object> scalarArgumentValues, + Map<String, DataStructureConverter<Object, Object>> inputConverters, + Map<String, DataStructureConverter<Object, Object>> outputConverters) + throws Exception { + this.function = function; + this.functionContext = functionContext; + this.defaultTableArgument = defaultTableArgument; + this.evalMethod = evalMethod; + this.arguments = arguments; + this.argumentsByName = argumentsByName; + this.isSingleTableFunction = isSingleTableFunction; + 
this.scalarArgumentValues = scalarArgumentValues; + this.inputConverters = inputConverters; + this.outputConverters = outputConverters; + this.output = new ArrayList<>(); + this.collector = new HarnessCollector(); + this.isOpen = false; + + openFunction(); + } + + /** Creates a new harness builder for the given ProcessTableFunction class. */ + public static <OUT> Builder<OUT> ofClass( + Class<? extends ProcessTableFunction<OUT>> functionClass) { + return new Builder<>(functionClass); + } + + private void openFunction() throws Exception { + function.open(functionContext); + function.setCollector(collector); + isOpen = true; + } + + @Override + public void close() throws Exception { + if (isOpen) { + function.close(); + isOpen = false; + } + } + + /** + * Process a single element for the default table argument. + * + * <p>For PTFs with a single table argument, this processes one row. For multiple table + * arguments, use {@link #processElementForTable(String, Row)}. + */ + public void processElement(Row row) throws Exception { + if (!isSingleTableFunction) { + throw new IllegalStateException( + "PTF has multiple table arguments. Use processElementForTable(argumentName, row) " + + "to specify which table argument should receive the row."); + } + + processElementForTable(defaultTableArgument, row); + } + + /** Process a single element constructed from values. */ + public void processElement(Object... values) throws Exception { + processElement(Row.of(values)); + } + + /** Process a single element with a specific RowKind. */ + public void processElement(RowKind rowKind, Object... values) throws Exception { + processElement(Row.ofKind(rowKind, values)); + } + + /** Process a single element for a specific table argument. 
*/ + public void processElementForTable(String tableArgument, Row row) throws Exception { + checkState(isOpen, "Harness not open"); + checkNotNull(tableArgument, "tableArgument must not be null"); + + // Try named arguments first + ArgumentInfo tableArg = argumentsByName.get(tableArgument); + if (tableArg == null) { + throw new IllegalArgumentException("Unknown table argument: " + tableArgument); + } else { + invokeEval(tableArg, row); + } + } + + /** Process a single element for a specific table argument. */ + public void processElementForTable(String tableArgument, Object... values) throws Exception { + processElementForTable(tableArgument, Row.of(values)); + } + + /** Process a single element for a specific table argument with RowKind. */ + public void processElementForTable(String tableArgument, RowKind rowKind, Object... values) + throws Exception { + processElementForTable(tableArgument, Row.ofKind(rowKind, values)); + } + + /** + * Invokes the PTF's eval() method with scalar arguments only. + * + * <p>This method is specifically for scalar-only PTFs (PTFs with only scalar arguments and no + * table arguments). For PTFs that accept table arguments, use {@link #processElement(Row)} or + * {@link #processElementForTable(String, Row)} instead. + * + * @throws IllegalStateException if the PTF has any table arguments + * @throws Exception if the eval() invocation fails + */ + public void invoke() throws Exception { + checkState(isOpen, "Harness not open"); + + // Validate this is a scalar-only PTF + boolean hasTableArguments = arguments.stream().anyMatch(arg -> arg.isTableArgument); + if (hasTableArguments) { + throw new IllegalStateException( + "invoke() is only for scalar-only PTFs. This PTF has table arguments. 
" + + "Use processElement() or processElementForTable() instead."); + } + + // Clear collector context since there's no active table argument + collector.setContext(null, null); + + // Build arguments array with only scalar values + Object[] args = new Object[arguments.size()]; + for (int i = 0; i < arguments.size(); i++) { + ArgumentInfo arg = arguments.get(i); + if (arg.isScalar) { + args[i] = scalarArgumentValues.get(arg.name); + } else { + throw new IllegalStateException( + "Unexpected non-scalar argument at position " + i + ": " + arg.name); + } + } + + // Invoke eval() method + try { + evalMethod.invoke(function, args); + } catch (InvocationTargetException e) { + Throwable cause = e.getCause(); + if (cause instanceof Exception) { + Exception userException = (Exception) cause; + userException.addSuppressed( + new Exception( + String.format( + "Exception occurred during scalar-only PTF eval() invocation. " + + "Scalar arguments: %s", + scalarArgumentValues))); + throw userException; + } else { + throw new RuntimeException("Error invoking PTF eval() method", e); + } + } + } + + /** Returns all collected output rows. */ + public List<OUT> getOutput() { + return List.copyOf(output); + } + + /** Clears all collected output. */ + public void clearOutput() { + output.clear(); + } + + /** + * Given a target table argument and a row to process, construct the right set of arguments for + * the PTF's eval function and attempt to invoke it. 
+ */ + private void invokeEval(ArgumentInfo activeTableArg, Row activeRow) throws Exception { + // Set collector context so it can prepend columns if needed + collector.setContext(activeTableArg, activeRow); + + Object[] args = new Object[arguments.size()]; + + for (int i = 0; i < arguments.size(); i++) { + ArgumentInfo arg = arguments.get(i); + + if (arg.isTableArgument && arg.name.equals(activeTableArg.name)) { + // If the argument is the active table argument, first convert the input row + // to an internal RowData type, and then convert the RowData to type that the + // argument expects. For Rows, this will structure the Row based on the table + // argument structure. Otherwise, for POJOs, it will pass the expected POJO to eval. + + DataStructureConverter<Object, Object> inputConverter = + inputConverters.get(arg.name); + DataStructureConverter<Object, Object> outputConverter = + outputConverters.get(arg.name); + + args[i] = + outputConverter.toExternalOrNull( + inputConverter.toInternalOrNull(activeRow)); + + } else if (arg.isScalar) { + // If the argument is a scalar argument, pull it from the predefined scalar + // argument fixtures. + args[i] = scalarArgumentValues.get(arg.name); + + } else if (arg.isTableArgument) { + // If the argument is a table argument but is not the current, active table argument + // then just pass in null. + args[i] = null; + } else { + throw new IllegalStateException( + "Unexpected argument type at position " + i + ": " + arg.name); + } + } + + try { + evalMethod.invoke(function, args); + } catch (InvocationTargetException e) { + Throwable cause = e.getCause(); + if (cause instanceof Exception) { + Exception userException = (Exception) cause; + String partitionInfo = + activeTableArg.partitionColumnNames != null + && activeTableArg.partitionColumnNames.length > 0 + ? 
String.format( + ", partition columns: %s", + Arrays.toString(activeTableArg.partitionColumnNames)) + : ", no partitioning"; + userException.addSuppressed( + new Exception( + String.format( + "Exception occurred during PTF eval() while processing table argument '%s'%s. " + + "Input row: %s", + activeTableArg.name, partitionInfo, activeRow))); + throw userException; + } else { + throw new RuntimeException("Error invoking PTF eval() method", e); + } + } + } + + /** + * Collector implementation that stores output in the harness. + * + * <p>For SET_SEMANTIC_TABLE arguments, automatically prepends partition key columns to the PTF + * output. If the argument has PASS_COLUMNS_THROUGH trait, prepends all input columns. + */ + private class HarnessCollector implements Collector<OUT> { + // Context set before each eval() invocation + private ArgumentInfo activeTableArg; + private Row activeRow; + + void setContext(ArgumentInfo tableArg, Row row) { + this.activeTableArg = tableArg; + this.activeRow = row; + } + + @Override + public void collect(OUT record) { + if (activeTableArg == null || !activeTableArg.isTableArgument) { + // No active table argument or it's scalar - just collect as-is + output.add(record); + return; + } + + // Determine which columns to prepend + if (activeTableArg.hasPassColumnsThrough) { + // PASS_COLUMNS_THROUGH: Prepend ALL input columns + output.add(prependAllColumns(record)); + } else if (activeTableArg.isSetSemantic + && activeTableArg.partitionColumnNames != null) { + // SET_SEMANTIC_TABLE: Prepend partition key columns only + output.add(prependPartitionKeys(record)); + } else { + // ROW_SEMANTIC_TABLE or no partitioning: no prepending + output.add(record); + } + } + + @SuppressWarnings("unchecked") + private OUT prependPartitionKeys(OUT ptfOutput) { + if (!(ptfOutput instanceof Row)) { + throw new IllegalStateException( + "Cannot prepend partition keys to non-Row output type: " + + ptfOutput.getClass()); + } + + Row ptfRow = (Row) ptfOutput; + + 
// For multi-table PTFs, prepend partition keys from ALL SET_SEMANTIC_TABLE arguments + // Active table contributes actual partition key values, inactive tables contribute + // nulls + int totalPartitionKeyCount = 0; + for (ArgumentInfo arg : arguments) { + if (arg.isSetSemantic && arg.partitionColumnNames != null) { + totalPartitionKeyCount += arg.partitionColumnNames.length; + } + } + + int ptfOutputArity = ptfRow.getArity(); + int totalArity = totalPartitionKeyCount + ptfOutputArity; + + Row result = new Row(ptfRow.getKind(), totalArity); + + // Prepend partition key values from all SET_SEMANTIC_TABLE arguments + int resultIndex = 0; + for (ArgumentInfo arg : arguments) { + if (arg.isSetSemantic && arg.partitionColumnNames != null) { + // Check if this is the active table + boolean isActive = arg.name.equals(activeTableArg.name); + + for (String columnName : arg.partitionColumnNames) { + if (isActive) { + // Active table: extract partition key value from input row + // Convert column name to position index + int columnIndex = getFieldIndex(arg.dataType, columnName); + result.setField(resultIndex++, activeRow.getField(columnIndex)); + } else { + // Inactive table: use null + result.setField(resultIndex++, null); + } + } + } + } + + // Append PTF output + for (int i = 0; i < ptfOutputArity; i++) { + result.setField(resultIndex++, ptfRow.getField(i)); + } + + return (OUT) result; + } + + /** Helper to get field index by name from a DataType. 
*/ + private int getFieldIndex(DataType dataType, String fieldName) { + RowType rowType = (RowType) dataType.getLogicalType(); + int index = 0; + for (RowType.RowField field : rowType.getFields()) { + if (field.getName().equals(fieldName)) { + return index; + } + index++; + } + throw new IllegalStateException( + String.format("Field '%s' not found in type %s", fieldName, dataType)); + } + + @SuppressWarnings("unchecked") + private OUT prependAllColumns(OUT ptfOutput) { + if (!(ptfOutput instanceof Row)) { + throw new IllegalStateException( + "Cannot prepend columns to non-Row output type: " + ptfOutput.getClass()); + } + + Row ptfRow = (Row) ptfOutput; + int inputArity = activeRow.getArity(); + int ptfOutputArity = ptfRow.getArity(); + int totalArity = inputArity + ptfOutputArity; + + Row result = new Row(ptfRow.getKind(), totalArity); + + // Prepend ALL input columns + for (int i = 0; i < inputArity; i++) { + result.setField(i, activeRow.getField(i)); + } + + // Append PTF output + for (int i = 0; i < ptfOutputArity; i++) { + result.setField(inputArity + i, ptfRow.getField(i)); + } + + return (OUT) result; + } + + @Override + public void close() {} + } + + /** + * Builder for {@link ProcessTableFunctionTestHarness}. + * + * @param <OUT> The output type of the ProcessTableFunction + */ + @PublicEvolving + public static class Builder<OUT> { + private final Class<? extends ProcessTableFunction<OUT>> functionClass; + + // Position counter for positional-only arguments (those without @ArgumentHint names) + // Tracks eval() signature position for consistent dual-lookup + private int nextPosition = 0; + private final LinkedHashMap<String, ScalarArgumentConfiguration> scalarArgs = + new LinkedHashMap<>(); + private final LinkedHashMap<String, TableArgumentConfiguration> tableArgs = + new LinkedHashMap<>(); + private final Map<String, PartitionConfiguration> partitionConfigs = new HashMap<>(); + + private Builder(Class<? 
extends ProcessTableFunction<OUT>> functionClass) { + this.functionClass = checkNotNull(functionClass, "functionClass must not be null"); + } + + // --------------------------------------------------------------------- + // Table & Scalar Arguments + // --------------------------------------------------------------------- + + /** + * Configures a table argument with its schema (named argument). + * + * <p>Use this for dynamic tables that receive elements during the test. Elements are + * provided via {@link #processElement(Row)} or {@link #processElementForTable(String, + * Row)}. + * + * @param argumentName The table argument name + * @param dataType The schema/structure of the table + */ + public Builder<OUT> withTableArgument(String argumentName, AbstractDataType<?> dataType) { + checkNotNull(argumentName, "argumentName must not be null"); + checkNotNull(dataType, "dataType must not be null"); + + if (scalarArgs.containsKey(argumentName)) { + throw new IllegalArgumentException( + "Argument already configured as scalar: " + argumentName); + } + + if (tableArgs.containsKey(argumentName)) { + throw new IllegalArgumentException( + "Table argument already configured: " + argumentName); + } + + TableArgumentConfiguration config = new TableArgumentConfiguration(argumentName); + config.explicitType = dataType; + tableArgs.put(argumentName, config); + return this; + } + + /** + * Configures a scalar (non-table) argument for the PTF's eval() method. + * + * <p>Scalar arguments are constant values passed to every eval() invocation, such as + * thresholds, multipliers, or configuration parameters. 
+ * + * @param argumentName Must match the parameter name in eval() or the @ArgumentHint name + * @param value The value to pass for this argument in all eval() calls + */ + public Builder<OUT> withScalarArgument(String argumentName, Object value) { + checkNotNull(argumentName, "argumentName must not be null"); + checkNotNull(value, "value must not be null"); Review Comment: Good point! I've been a little overzealous with null checks here; I'll clean them up where it makes sense. -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: [email protected] For queries about this service, please contact Infrastructure at: [email protected]
