This is an automated email from the ASF dual-hosted git repository. shengkai pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/flink.git
commit 9171194ef2647af1b55e58b98daeebabb6c84ad7 Author: Feng Jin <jinfeng1...@gmail.com> AuthorDate: Mon Jan 22 11:41:12 2024 +0800 [FLINK-34057][table] Support named parameters for functions Co-authored-by: xuyang <xyzhong...@163.com> --- .../types/extraction/BaseMappingExtractor.java | 14 +- .../table/types/extraction/DataTypeExtractor.java | 9 +- .../table/types/extraction/DataTypeTemplate.java | 9 + .../table/types/extraction/ExtractionUtils.java | 13 +- .../extraction/FunctionSignatureTemplate.java | 6 + .../table/types/extraction/FunctionTemplate.java | 53 ++++- .../types/extraction/TypeInferenceExtractor.java | 4 +- .../extraction/TypeInferenceExtractorTest.java | 233 ++++++++++++++++++++- .../calcite/sql/validate/SqlValidatorImpl.java | 15 +- .../apache/calcite/sql2rel/SqlToRelConverter.java | 22 +- .../planner/calcite/RexSetSemanticsTableCall.java | 6 + .../inference/TypeInferenceOperandChecker.java | 44 +++- .../logical/FlinkLogicalTableFunctionScan.scala | 7 +- .../planner/runtime/stream/sql/FunctionITCase.java | 181 +++++++++++++++- 14 files changed, 585 insertions(+), 31 deletions(-) diff --git a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/extraction/BaseMappingExtractor.java b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/extraction/BaseMappingExtractor.java index fda3e8d6dd1..fdd12e1d795 100644 --- a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/extraction/BaseMappingExtractor.java +++ b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/extraction/BaseMappingExtractor.java @@ -18,6 +18,7 @@ package org.apache.flink.table.types.extraction; +import org.apache.flink.table.annotation.ArgumentHint; import org.apache.flink.table.annotation.DataTypeHint; import org.apache.flink.table.api.DataTypes; import org.apache.flink.table.api.ValidationException; @@ -362,7 +363,18 @@ abstract class BaseMappingExtractor { Method method, int paramPos) { final 
Parameter parameter = method.getParameters()[paramPos]; final DataTypeHint hint = parameter.getAnnotation(DataTypeHint.class); - if (hint != null) { + final ArgumentHint argumentHint = parameter.getAnnotation(ArgumentHint.class); + if (hint != null && argumentHint != null) { + throw extractionError( + "Argument and dataType hints cannot be declared in the same parameter at position %d.", + paramPos); + } + if (argumentHint != null) { + final DataTypeTemplate template = DataTypeTemplate.fromAnnotation(argumentHint, null); + if (template.inputGroup != null) { + return Optional.of(FunctionArgumentTemplate.of(template.inputGroup)); + } + } else if (hint != null) { final DataTypeTemplate template = DataTypeTemplate.fromAnnotation(hint, null); if (template.inputGroup != null) { return Optional.of(FunctionArgumentTemplate.of(template.inputGroup)); diff --git a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/extraction/DataTypeExtractor.java b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/extraction/DataTypeExtractor.java index e1f0122d10e..ed05ba7a99b 100644 --- a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/extraction/DataTypeExtractor.java +++ b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/extraction/DataTypeExtractor.java @@ -19,6 +19,7 @@ package org.apache.flink.table.types.extraction; import org.apache.flink.annotation.Internal; +import org.apache.flink.table.annotation.ArgumentHint; import org.apache.flink.table.annotation.DataTypeHint; import org.apache.flink.table.api.DataTypes; import org.apache.flink.table.api.dataview.DataView; @@ -140,8 +141,11 @@ public final class DataTypeExtractor { DataTypeFactory typeFactory, Class<?> baseClass, Method method, int paramPos) { final Parameter parameter = method.getParameters()[paramPos]; final DataTypeHint hint = parameter.getAnnotation(DataTypeHint.class); + final ArgumentHint argumentHint = 
parameter.getAnnotation(ArgumentHint.class); final DataTypeTemplate template; - if (hint != null) { + if (argumentHint != null) { + template = DataTypeTemplate.fromAnnotation(typeFactory, argumentHint.type()); + } else if (hint != null) { template = DataTypeTemplate.fromAnnotation(typeFactory, hint); } else { template = DataTypeTemplate.fromDefaults(); @@ -261,8 +265,11 @@ public final class DataTypeExtractor { final Class<?> clazz = toClass(resolvedType); if (clazz != null) { final DataTypeHint hint = clazz.getAnnotation(DataTypeHint.class); + final ArgumentHint argumentHint = clazz.getAnnotation(ArgumentHint.class); if (hint != null) { template = outerTemplate.mergeWithInnerAnnotation(typeFactory, hint); + } else if (argumentHint != null) { + template = outerTemplate.mergeWithInnerAnnotation(typeFactory, argumentHint.type()); } } // main work diff --git a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/extraction/DataTypeTemplate.java b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/extraction/DataTypeTemplate.java index d93badfd815..08426874ec2 100644 --- a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/extraction/DataTypeTemplate.java +++ b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/extraction/DataTypeTemplate.java @@ -20,6 +20,7 @@ package org.apache.flink.table.types.extraction; import org.apache.flink.annotation.Internal; import org.apache.flink.api.common.typeutils.TypeSerializer; +import org.apache.flink.table.annotation.ArgumentHint; import org.apache.flink.table.annotation.DataTypeHint; import org.apache.flink.table.annotation.ExtractionVersion; import org.apache.flink.table.annotation.HintFlag; @@ -114,6 +115,14 @@ final class DataTypeTemplate { return fromAnnotation(hint, null); } + /** + * Creates an instance from the given {@link ArgumentHint} with a resolved data type if + * available. 
+ */ + static DataTypeTemplate fromAnnotation(ArgumentHint argumentHint, @Nullable DataType dataType) { + return fromAnnotation(argumentHint.type(), dataType); + } + /** * Creates an instance from the given {@link DataTypeHint} with a resolved data type if * available. diff --git a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/extraction/ExtractionUtils.java b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/extraction/ExtractionUtils.java index b0b9a74f636..9854e62eb0c 100644 --- a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/extraction/ExtractionUtils.java +++ b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/extraction/ExtractionUtils.java @@ -20,6 +20,7 @@ package org.apache.flink.table.types.extraction; import org.apache.flink.annotation.Internal; import org.apache.flink.api.common.typeutils.TypeSerializer; +import org.apache.flink.table.annotation.ArgumentHint; import org.apache.flink.table.api.DataTypes; import org.apache.flink.table.api.ValidationException; import org.apache.flink.table.catalog.DataTypeFactory; @@ -44,7 +45,6 @@ import java.lang.reflect.Field; import java.lang.reflect.GenericArrayType; import java.lang.reflect.Method; import java.lang.reflect.Modifier; -import java.lang.reflect.Parameter; import java.lang.reflect.ParameterizedType; import java.lang.reflect.Type; import java.lang.reflect.TypeVariable; @@ -759,7 +759,16 @@ public final class ExtractionUtils { // so we need to extract them manually if possible List<String> parameterNames = Stream.of(executable.getParameters()) - .map(Parameter::getName) + .map( + parameter -> { + ArgumentHint argumentHint = + parameter.getAnnotation(ArgumentHint.class); + if (argumentHint != null && argumentHint.name() != "") { + return argumentHint.name(); + } else { + return parameter.getName(); + } + }) .collect(Collectors.toList()); if (parameterNames.stream().allMatch(n -> n.startsWith("arg"))) { 
final ParameterExtractor extractor; diff --git a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/extraction/FunctionSignatureTemplate.java b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/extraction/FunctionSignatureTemplate.java index 80252572c3d..2efbbd0e8f6 100644 --- a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/extraction/FunctionSignatureTemplate.java +++ b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/extraction/FunctionSignatureTemplate.java @@ -26,6 +26,7 @@ import org.apache.flink.table.types.inference.InputTypeStrategy; import javax.annotation.Nullable; import java.lang.reflect.Array; +import java.util.Arrays; import java.util.List; import java.util.Objects; import java.util.stream.Collectors; @@ -61,6 +62,11 @@ final class FunctionSignatureTemplate { "Mismatch between number of argument names '%s' and argument types '%s'.", argumentNames.length, argumentTemplates.size()); } + if (argumentNames != null + && argumentNames.length != Arrays.stream(argumentNames).distinct().count()) { + throw extractionError( + "Argument name conflict, there are at least two argument names that are the same."); + } return new FunctionSignatureTemplate(argumentTemplates, isVarArgs, argumentNames); } diff --git a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/extraction/FunctionTemplate.java b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/extraction/FunctionTemplate.java index 68f23586ce2..80e84b9137a 100644 --- a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/extraction/FunctionTemplate.java +++ b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/extraction/FunctionTemplate.java @@ -19,6 +19,7 @@ package org.apache.flink.table.types.extraction; import org.apache.flink.annotation.Internal; +import org.apache.flink.table.annotation.ArgumentHint; import 
org.apache.flink.table.annotation.DataTypeHint; import org.apache.flink.table.annotation.FunctionHint; import org.apache.flink.table.annotation.ProcedureHint; @@ -68,6 +69,7 @@ final class FunctionTemplate { typeFactory, defaultAsNull(hint, FunctionHint::input), defaultAsNull(hint, FunctionHint::argumentNames), + defaultAsNull(hint, FunctionHint::argument), hint.isVarArgs()), createResultTemplate(typeFactory, defaultAsNull(hint, FunctionHint::accumulator)), createResultTemplate(typeFactory, defaultAsNull(hint, FunctionHint::output))); @@ -83,6 +85,7 @@ final class FunctionTemplate { typeFactory, defaultAsNull(hint, ProcedureHint::input), defaultAsNull(hint, ProcedureHint::argumentNames), + defaultAsNull(hint, ProcedureHint::argument), hint.isVarArgs()), null, createResultTemplate(typeFactory, defaultAsNull(hint, ProcedureHint::output))); @@ -145,6 +148,7 @@ final class FunctionTemplate { @ProcedureHint @FunctionHint + @ArgumentHint private static class DefaultAnnotationHelper { // no implementation } @@ -161,6 +165,10 @@ final class FunctionTemplate { return defaultAsNull(hint, getDefaultAnnotation(ProcedureHint.class), accessor); } + private static <T> T defaultAsNull(ArgumentHint hint, Function<ArgumentHint, T> accessor) { + return defaultAsNull(hint, getDefaultAnnotation(ArgumentHint.class), accessor); + } + private static <T, H extends Annotation> T defaultAsNull( H hint, H defaultHint, Function<H, T> accessor) { final T defaultValue = accessor.apply(defaultHint); @@ -173,18 +181,55 @@ final class FunctionTemplate { private static @Nullable FunctionSignatureTemplate createSignatureTemplate( DataTypeFactory typeFactory, - @Nullable DataTypeHint[] input, + @Nullable DataTypeHint[] inputs, @Nullable String[] argumentNames, + @Nullable ArgumentHint[] argumentHints, boolean isVarArg) { - if (input == null) { + + String[] argumentHintNames; + DataTypeHint[] argumentHintTypes; + + if (argumentHints != null && inputs != null) { + throw extractionError( + "Argument and 
input hints cannot be declared in the same function hint."); + } + + if (argumentHints != null) { + argumentHintNames = new String[argumentHints.length]; + argumentHintTypes = new DataTypeHint[argumentHints.length]; + boolean allArgumentNameNotSet = true; + for (int i = 0; i < argumentHints.length; i++) { + ArgumentHint argumentHint = argumentHints[i]; + argumentHintNames[i] = defaultAsNull(argumentHint, ArgumentHint::name); + argumentHintTypes[i] = defaultAsNull(argumentHint, ArgumentHint::type); + if (argumentHintTypes[i] == null) { + throw extractionError("The type of the argument at position %d is not set.", i); + } + if (argumentHintNames[i] != null) { + allArgumentNameNotSet = false; + } else if (!allArgumentNameNotSet) { + throw extractionError( + "The argument name in function hint must be either fully set or not set at all."); + } + } + if (allArgumentNameNotSet) { + argumentHintNames = null; + } + } else { + argumentHintTypes = inputs; + argumentHintNames = argumentNames; + } + + if (argumentHintTypes == null) { return null; } + return FunctionSignatureTemplate.of( - Arrays.stream(input) + Arrays.stream(argumentHintTypes) .map(dataTypeHint -> createArgumentTemplate(typeFactory, dataTypeHint)) .collect(Collectors.toList()), isVarArg, - argumentNames); + argumentHintNames); } private static FunctionArgumentTemplate createArgumentTemplate( diff --git a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/extraction/TypeInferenceExtractor.java b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/extraction/TypeInferenceExtractor.java index 2ba0cc25b5b..b817efa2b84 100644 --- a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/extraction/TypeInferenceExtractor.java +++ b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/extraction/TypeInferenceExtractor.java @@ -254,14 +254,14 @@ public final class TypeInferenceExtractor { if (signatures.stream().anyMatch(s -> 
s.isVarArgs || s.argumentNames == null)) { return; } - final Set<List<String>> argumentNames = + final List<List<String>> argumentNames = signatures.stream() .map( s -> { assert s.argumentNames != null; return Arrays.asList(s.argumentNames); }) - .collect(Collectors.toSet()); + .collect(Collectors.toList()); if (argumentNames.size() != 1) { return; } diff --git a/flink-table/flink-table-common/src/test/java/org/apache/flink/table/types/extraction/TypeInferenceExtractorTest.java b/flink-table/flink-table-common/src/test/java/org/apache/flink/table/types/extraction/TypeInferenceExtractorTest.java index 97e08fc2a2a..483e156636e 100644 --- a/flink-table/flink-table-common/src/test/java/org/apache/flink/table/types/extraction/TypeInferenceExtractorTest.java +++ b/flink-table/flink-table-common/src/test/java/org/apache/flink/table/types/extraction/TypeInferenceExtractorTest.java @@ -19,6 +19,7 @@ package org.apache.flink.table.types.extraction; import org.apache.flink.core.testutils.FlinkAssertions; +import org.apache.flink.table.annotation.ArgumentHint; import org.apache.flink.table.annotation.DataTypeHint; import org.apache.flink.table.annotation.FunctionHint; import org.apache.flink.table.annotation.InputGroup; @@ -427,8 +428,8 @@ class TypeInferenceExtractorTest { "Could not find a publicly accessible method named 'eval'."), // named arguments with overloaded function - TestSpec.forScalarFunction(NamedArgumentsScalarFunction.class) - .expectNamedArguments("n"), + // expected no named argument for overloaded function + TestSpec.forScalarFunction(NamedArgumentsScalarFunction.class), // scalar function that takes any input TestSpec.forScalarFunction(InputGroupScalarFunction.class) @@ -535,7 +536,80 @@ class TypeInferenceExtractorTest { new String[] {}, new ArgumentTypeStrategy[] {}), TypeStrategies.explicit( DataTypes.ROW(DataTypes.FIELD("i", DataTypes.INT())) - .bridgedTo(RowData.class)))); + .bridgedTo(RowData.class))), + TestSpec.forScalarFunction( + "Scalar function 
with arguments hints", + ArgumentHintScalarFunction.class) + .expectNamedArguments("f1", "f2") + .expectTypedArguments(DataTypes.STRING(), DataTypes.INT()) + .expectOutputMapping( + InputTypeStrategies.sequence( + new String[] {"f1", "f2"}, + new ArgumentTypeStrategy[] { + InputTypeStrategies.explicit(DataTypes.STRING()), + InputTypeStrategies.explicit(DataTypes.INT()) + }), + TypeStrategies.explicit(DataTypes.STRING())), + TestSpec.forScalarFunction( + "Scalar function with arguments hints missing type", + ArgumentHintMissingTypeScalarFunction.class) + .expectErrorMessage("The type of the argument at position 0 is not set."), + TestSpec.forScalarFunction( + "Scalar function with arguments hints all missing name", + ArgumentHintMissingNameScalarFunction.class) + .expectTypedArguments(DataTypes.STRING(), DataTypes.INT()), + TestSpec.forScalarFunction( + "Scalar function with arguments hints all missing partial name", + ArgumentHintMissingPartialNameScalarFunction.class) + .expectErrorMessage( + "The argument name in function hint must be either fully set or not set at all."), + TestSpec.forScalarFunction( + "Scalar function with arguments hints name conflict", + ArgumentHintNameConflictScalarFunction.class) + .expectErrorMessage( + "Argument name conflict, there are at least two argument names that are the same."), + TestSpec.forScalarFunction( + "Scalar function with arguments hints on method parameter", + ArgumentHintOnParameterScalarFunction.class) + .expectNamedArguments("in1", "in2") + .expectTypedArguments(DataTypes.STRING(), DataTypes.INT()) + .expectOutputMapping( + InputTypeStrategies.sequence( + new String[] {"in1", "in2"}, + new ArgumentTypeStrategy[] { + InputTypeStrategies.explicit(DataTypes.STRING()), + InputTypeStrategies.explicit(DataTypes.INT()) + }), + TypeStrategies.explicit(DataTypes.STRING())), + TestSpec.forScalarFunction( + "Scalar function with arguments hints and inputs hints both defined", + ArgumentsAndInputsScalarFunction.class) + 
.expectErrorMessage( + "Argument and input hints cannot be declared in the same function hint."), + TestSpec.forScalarFunction( + "Scalar function with argument hint and dataType hint declared in the same parameter", + ArgumentsHintAndDataTypeHintScalarFunction.class) + .expectErrorMessage( + "Argument and dataType hints cannot be declared in the same parameter at position 0."), + TestSpec.forScalarFunction( + "An invalid scalar function that declare FunctionHint for both class and method in the same class.", + InvalidFunctionHintOnClassAndMethod.class) + .expectErrorMessage( + "Argument and input hints cannot be declared in the same function hint."), + TestSpec.forScalarFunction( + "A valid scalar class that declare FunctionHint for both class and method in the same class.", + ValidFunctionHintOnClassAndMethod.class) + .expectNamedArguments("f1", "f2") + .expectTypedArguments(DataTypes.STRING(), DataTypes.INT()), + TestSpec.forScalarFunction( + "The FunctionHint of the function conflicts with the method.", + ScalarFunctionWithFunctionHintConflictMethod.class) + .expectErrorMessage( + "Considering all hints, the method should comply with the signature"), + // For overloaded functions, no named arguments are expected + TestSpec.forScalarFunction( + "Scalar function with overloaded functions and arguments hint declared.", + ArgumentsHintScalarFunctionWithOverloadedFunction.class)); } private static Stream<TestSpec> procedureSpecs() { @@ -700,7 +774,8 @@ class TypeInferenceExtractorTest { TypeStrategies.explicit( DataTypes.DOUBLE().notNull().bridgedTo(double.class))), // named arguments with overloaded function - TestSpec.forProcedure(NamedArgumentsProcedure.class).expectNamedArguments("n"), + // no named arguments are expected for overloaded functions + TestSpec.forProcedure(NamedArgumentsProcedure.class), // scalar function that takes any input TestSpec.forProcedure(InputGroupProcedure.class) @@ -1585,4 +1660,154 @@ class TypeInferenceExtractorTest { private
static class DataTypeHintOnScalarFunctionAsync extends AsyncScalarFunction { public void eval(@DataTypeHint("ROW<i INT>") CompletableFuture<RowData> f) {} } + + private static class ArgumentHintScalarFunction extends ScalarFunction { + @FunctionHint( + argument = { + @ArgumentHint(type = @DataTypeHint("STRING"), name = "f1"), + @ArgumentHint(type = @DataTypeHint("INTEGER"), name = "f2") + }) + public String eval(String f1, Integer f2) { + return ""; + } + } + + private static class ArgumentHintMissingTypeScalarFunction extends ScalarFunction { + @FunctionHint(argument = {@ArgumentHint(name = "f1"), @ArgumentHint(name = "f2")}) + public String eval(String f1, Integer f2) { + return ""; + } + } + + private static class ArgumentHintMissingNameScalarFunction extends ScalarFunction { + @FunctionHint( + argument = { + @ArgumentHint(type = @DataTypeHint("STRING")), + @ArgumentHint(type = @DataTypeHint("INTEGER")) + }) + public String eval(String f1, Integer f2) { + return ""; + } + } + + private static class ArgumentHintMissingPartialNameScalarFunction extends ScalarFunction { + @FunctionHint( + argument = { + @ArgumentHint(type = @DataTypeHint("STRING"), name = "in1"), + @ArgumentHint(type = @DataTypeHint("INTEGER")) + }) + public String eval(String f1, Integer f2) { + return ""; + } + } + + private static class ArgumentHintNameConflictScalarFunction extends ScalarFunction { + @FunctionHint( + argument = { + @ArgumentHint(name = "in1", type = @DataTypeHint("STRING")), + @ArgumentHint(name = "in1", type = @DataTypeHint("INTEGER")) + }) + public String eval(String f1, Integer f2) { + return ""; + } + } + + private static class ArgumentHintOnParameterScalarFunction extends ScalarFunction { + public String eval( + @ArgumentHint(type = @DataTypeHint("STRING"), name = "in1") String f1, + @ArgumentHint(type = @DataTypeHint("INTEGER"), name = "in2") Integer f2) { + return ""; + } + } + + private static class ArgumentsAndInputsScalarFunction extends ScalarFunction { + 
@FunctionHint( + argument = { + @ArgumentHint(type = @DataTypeHint("STRING"), name = "f1"), + @ArgumentHint(type = @DataTypeHint("INTEGER"), name = "f2") + }, + input = {@DataTypeHint("STRING"), @DataTypeHint("INTEGER")}) + public String eval(String f1, Integer f2) { + return ""; + } + } + + private static class ArgumentsHintAndDataTypeHintScalarFunction extends ScalarFunction { + + public String eval( + @DataTypeHint("STRING") @ArgumentHint(name = "f1", type = @DataTypeHint("STRING")) + String f1, + @ArgumentHint(name = "f2", type = @DataTypeHint("INTEGER")) Integer f2) { + return ""; + } + } + + @FunctionHint( + argument = { + @ArgumentHint(type = @DataTypeHint("STRING"), name = "f1"), + @ArgumentHint(type = @DataTypeHint("INTEGER"), name = "f2") + }) + private static class InvalidFunctionHintOnClassAndMethod extends ScalarFunction { + @FunctionHint( + argument = { + @ArgumentHint(type = @DataTypeHint("STRING"), name = "f1"), + @ArgumentHint(type = @DataTypeHint("INTEGER"), name = "f2") + }, + input = {@DataTypeHint("STRING"), @DataTypeHint("INTEGER")}) + public String eval(String f1, Integer f2) { + return ""; + } + } + + @FunctionHint( + argument = { + @ArgumentHint(type = @DataTypeHint("STRING"), name = "f1"), + @ArgumentHint(type = @DataTypeHint("INTEGER"), name = "f2") + }) + private static class ValidFunctionHintOnClassAndMethod extends ScalarFunction { + @FunctionHint( + argument = { + @ArgumentHint(type = @DataTypeHint("STRING"), name = "f1"), + @ArgumentHint(type = @DataTypeHint("INTEGER"), name = "f2") + }) + public String eval(String f1, Integer f2) { + return ""; + } + } + + @FunctionHint( + argument = { + @ArgumentHint(type = @DataTypeHint("STRING"), name = "f1"), + @ArgumentHint(type = @DataTypeHint("INTEGER"), name = "f2") + }) + @FunctionHint( + argument = { + @ArgumentHint(type = @DataTypeHint("INTEGER"), name = "f1"), + @ArgumentHint(type = @DataTypeHint("INTEGER"), name = "f2") + }) + private static class 
ScalarFunctionWithFunctionHintConflictMethod extends ScalarFunction { + public String eval(String f1, Integer f2) { + return ""; + } + } + + private static class ArgumentsHintScalarFunctionWithOverloadedFunction extends ScalarFunction { + @FunctionHint( + argument = { + @ArgumentHint(type = @DataTypeHint("STRING"), name = "f1"), + @ArgumentHint(type = @DataTypeHint("INTEGER"), name = "f2") + }) + public String eval(String f1, Integer f2) { + return ""; + } + + @FunctionHint( + argument = { + @ArgumentHint(type = @DataTypeHint("STRING"), name = "f1"), + @ArgumentHint(type = @DataTypeHint("STRING"), name = "f2") + }) + public String eval(String f1, String f2) { + return ""; + } + } } diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/calcite/sql/validate/SqlValidatorImpl.java b/flink-table/flink-table-planner/src/main/java/org/apache/calcite/sql/validate/SqlValidatorImpl.java index 28c3504f97e..3322f4f458a 100644 --- a/flink-table/flink-table-planner/src/main/java/org/apache/calcite/sql/validate/SqlValidatorImpl.java +++ b/flink-table/flink-table-planner/src/main/java/org/apache/calcite/sql/validate/SqlValidatorImpl.java @@ -158,13 +158,16 @@ import static org.apache.calcite.util.Static.RESOURCE; * Default implementation of {@link SqlValidator}, the class was copied over because of * CALCITE-4554. * - * <p>Lines 1954 ~ 1977, Flink improves error message for functions without appropriate arguments in + * <p>Lines 1958 ~ 1978, Flink improves error message for functions without appropriate arguments in * handleUnresolvedFunction at {@link SqlValidatorImpl#handleUnresolvedFunction}. * - * <p>Lines 5101 ~ 5114, Flink enables TIMESTAMP and TIMESTAMP_LTZ for system time period + * <p>Lines 3736 ~ 3740, Flink optimizes the retrieval of sub-operands in SqlCall when using + * named parameters at {@link SqlValidatorImpl#checkRollUp}.
+ * + * <p>Lines 5108 ~ 5121, Flink enables TIMESTAMP and TIMESTAMP_LTZ for system time period * specification type at {@link org.apache.calcite.sql.validate.SqlValidatorImpl#validateSnapshot}. * - * <p>Lines 5458 ~ 5464, Flink enables TIMESTAMP and TIMESTAMP_LTZ for first orderBy column in + * <p>Lines 5465 ~ 5471, Flink enables TIMESTAMP and TIMESTAMP_LTZ for first orderBy column in * matchRecognize at {@link SqlValidatorImpl#validateMatchRecognize}. */ public class SqlValidatorImpl implements SqlValidatorWithHints { @@ -3730,7 +3733,11 @@ public class SqlValidatorImpl implements SqlValidatorWithHints { // can be another SqlCall, or an SqlIdentifier. checkRollUp(grandParent, parent, stripDot, scope, contextClause); } else { - List<? extends @Nullable SqlNode> children = ((SqlCall) stripDot).getOperandList(); + // ----- FLINK MODIFICATION BEGIN ----- + SqlCall call = (SqlCall) stripDot; + List<? extends @Nullable SqlNode> children = + new SqlCallBinding(this, scope, call).operands(); + // ----- FLINK MODIFICATION END ----- for (SqlNode child : children) { checkRollUp(parent, current, child, scope, contextClause); } diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/calcite/sql2rel/SqlToRelConverter.java b/flink-table/flink-table-planner/src/main/java/org/apache/calcite/sql2rel/SqlToRelConverter.java index e36b5e91f49..bccced1342b 100644 --- a/flink-table/flink-table-planner/src/main/java/org/apache/calcite/sql2rel/SqlToRelConverter.java +++ b/flink-table/flink-table-planner/src/main/java/org/apache/calcite/sql2rel/SqlToRelConverter.java @@ -234,13 +234,14 @@ import static org.apache.flink.util.Preconditions.checkNotNull; * <p>FLINK modifications are at lines * * <ol> - * <li>Added in FLINK-29081, FLINK-28682, FLINK-33395: Lines 653 ~ 670 - * <li>Added in Flink-24024: Lines 1434 ~ 1444, Lines 1458 ~ 1500 - * <li>Added in FLINK-28682: Lines 2322 ~ 2339 - * <li>Added in FLINK-28682: Lines 2376 ~ 2404 - * <li>Added in FLINK-32474: Lines 2874 ~ 2886 
- * <li>Added in FLINK-32474: Lines 2986 ~ 3020 - * <li>Added in FLINK-20873: Lines 5518 ~ 5527 + * <li>Added in FLINK-29081, FLINK-28682, FLINK-33395: Lines 654 ~ 671 + * <li>Added in Flink-24024: Lines 1435 ~ 1445, Lines 1459 ~ 1501 + * <li>Added in FLINK-28682: Lines 2323 ~ 2340 + * <li>Added in FLINK-28682: Lines 2377 ~ 2405 + * <li>Added in FLINK-32474: Lines 2875 ~ 2887 + * <li>Added in FLINK-32474: Lines 2987 ~ 3021 + * <li>Added in FLINK-20873: Lines 5519 ~ 5528 + * <li>Added in FLINK-34057: Lines 6089 ~ 6092 * </ol> */ @SuppressWarnings("UnstableApiUsage") @@ -6085,8 +6086,11 @@ public class SqlToRelConverter { try { // switch out of agg mode bb.agg = null; - for (SqlNode operand : call.getOperandList()) { - + // ----- FLINK MODIFICATION BEGIN ----- + for (SqlNode operand : + new SqlCallBinding(validator(), aggregatingSelectScope, call).operands()) + // ----- FLINK MODIFICATION END ----- + { // special case for COUNT(*): delete the * if (operand instanceof SqlIdentifier) { SqlIdentifier id = (SqlIdentifier) operand; diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/calcite/RexSetSemanticsTableCall.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/calcite/RexSetSemanticsTableCall.java index f2d2ce07b1e..2f038a6662c 100644 --- a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/calcite/RexSetSemanticsTableCall.java +++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/calcite/RexSetSemanticsTableCall.java @@ -92,4 +92,10 @@ public class RexSetSemanticsTableCall extends RexCall { List<? 
extends RexNode> newOperands, int[] newPartitionKeys, int[] newOrderKeys) { return new RexSetSemanticsTableCall(type, op, newOperands, newPartitionKeys, newOrderKeys); } + + @Override + public RexSetSemanticsTableCall clone(RelDataType type, List<RexNode> operands) { + return new RexSetSemanticsTableCall( + type, getOperator(), operands, partitionKeys, orderKeys); + } } diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/inference/TypeInferenceOperandChecker.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/inference/TypeInferenceOperandChecker.java index a98efee8656..5c122ca5038 100644 --- a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/inference/TypeInferenceOperandChecker.java +++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/inference/TypeInferenceOperandChecker.java @@ -30,20 +30,24 @@ import org.apache.flink.table.types.inference.TypeInferenceUtil; import org.apache.flink.table.types.logical.LogicalType; import org.apache.calcite.rel.type.RelDataType; +import org.apache.calcite.rel.type.RelDataTypeFactory; import org.apache.calcite.sql.SqlCallBinding; import org.apache.calcite.sql.SqlNode; import org.apache.calcite.sql.SqlOperandCountRange; import org.apache.calcite.sql.SqlOperator; import org.apache.calcite.sql.fun.SqlStdOperatorTable; import org.apache.calcite.sql.parser.SqlParserPos; +import org.apache.calcite.sql.type.SqlOperandMetadata; import org.apache.calcite.sql.type.SqlOperandTypeChecker; import org.apache.calcite.sql.type.SqlTypeUtil; import org.apache.calcite.sql.validate.SqlValidator; import org.apache.calcite.sql.validate.SqlValidatorNamespace; import java.util.List; +import java.util.stream.Collectors; import static org.apache.flink.table.planner.calcite.FlinkTypeFactory.toLogicalType; +import static 
org.apache.flink.table.planner.typeutils.LogicalRelDataTypeConverter.toRelDataType; import static org.apache.flink.table.planner.utils.ShortcutUtils.unwrapTypeFactory; import static org.apache.flink.table.types.inference.TypeInferenceUtil.adaptArguments; import static org.apache.flink.table.types.inference.TypeInferenceUtil.createInvalidCallException; @@ -57,7 +61,8 @@ import static org.apache.flink.table.types.logical.utils.LogicalTypeCasts.suppor * <p>Note: This class must be kept in sync with {@link TypeInferenceUtil}. */ @Internal -public final class TypeInferenceOperandChecker implements SqlOperandTypeChecker { +public final class TypeInferenceOperandChecker + implements SqlOperandTypeChecker, SqlOperandMetadata { private final DataTypeFactory dataTypeFactory; @@ -114,6 +119,40 @@ public final class TypeInferenceOperandChecker implements SqlOperandTypeChecker return false; } + @Override + public List<RelDataType> paramTypes(RelDataTypeFactory typeFactory) { + return typeInference + .getTypedArguments() + .map( + types -> + types.stream() + .map( + type -> + toRelDataType( + type.getLogicalType(), typeFactory)) + .collect(Collectors.toList())) + .orElseThrow( + () -> + new ValidationException( + "Could not find the argument types. " + + "Currently named arguments are not supported " + + "for varArgs and multi different argument names " + + "with overload function")); + } + + @Override + public List<String> paramNames() { + return typeInference + .getNamedArguments() + .orElseThrow( + () -> + new ValidationException( + "Could not find the argument names. 
" + + "Currently named arguments are not supported " + + "for varArgs and multi different argument names " + + "with overload function")); + } + // -------------------------------------------------------------------------------------------- private boolean checkOperandTypesOrError(SqlCallBinding callBinding, CallContext callContext) { @@ -134,7 +173,8 @@ public final class TypeInferenceOperandChecker implements SqlOperandTypeChecker final List<SqlNode> operands = callBinding.operands(); for (int i = 0; i < operands.size(); i++) { final LogicalType expectedType = expectedDataTypes.get(i).getLogicalType(); - final LogicalType argumentType = toLogicalType(callBinding.getOperandType(i)); + final LogicalType argumentType = + toLogicalType(SqlTypeUtil.deriveType(callBinding, operands.get(i))); if (!supportsAvoidingCast(argumentType, expectedType)) { final RelDataType expectedRelDataType = diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/nodes/logical/FlinkLogicalTableFunctionScan.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/nodes/logical/FlinkLogicalTableFunctionScan.scala index c402ab6c770..ca664d7b63e 100644 --- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/nodes/logical/FlinkLogicalTableFunctionScan.scala +++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/nodes/logical/FlinkLogicalTableFunctionScan.scala @@ -106,12 +106,17 @@ class FlinkLogicalTableFunctionScanConverter(config: Config) extends ConverterRu val scan = rel.asInstanceOf[LogicalTableFunctionScan] val traitSet = rel.getTraitSet.replace(FlinkConventions.LOGICAL).simplify() val newInputs = scan.getInputs.map(input => RelOptRule.convert(input, FlinkConventions.LOGICAL)) + val rexCall = scan.getCall.asInstanceOf[RexCall]; + val builder = rel.getCluster.getRexBuilder + // When rexCall uses NamedArguments, RexCall is not inferred with the correct 
type. + // We just use the type of scan as the type of RexCall. + val newCall = rexCall.clone(rel.getRowType, rexCall.getOperands) new FlinkLogicalTableFunctionScan( scan.getCluster, traitSet, newInputs, - scan.getCall, + newCall, scan.getElementType, scan.getRowType, scan.getColumnMappings diff --git a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/runtime/stream/sql/FunctionITCase.java b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/runtime/stream/sql/FunctionITCase.java index 8f562b4f189..a137e3578ed 100644 --- a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/runtime/stream/sql/FunctionITCase.java +++ b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/runtime/stream/sql/FunctionITCase.java @@ -21,6 +21,7 @@ package org.apache.flink.table.planner.runtime.stream.sql; import org.apache.flink.api.common.ExecutionConfig; import org.apache.flink.api.java.typeutils.runtime.kryo.KryoSerializer; import org.apache.flink.core.fs.Path; +import org.apache.flink.table.annotation.ArgumentHint; import org.apache.flink.table.annotation.DataTypeHint; import org.apache.flink.table.annotation.FunctionHint; import org.apache.flink.table.annotation.HintFlag; @@ -1056,7 +1057,108 @@ public class FunctionITCase extends StreamingTestBase { } @Test - void testInvalidUseOfScalarFunction() { + void testNamedArgumentsTableFunction() throws Exception { + final Row[] sinkData = new Row[] {Row.of("str1, str2")}; + + TestCollectionTableFactory.reset(); + + tEnv().executeSql("CREATE TABLE SinkTable(s STRING) WITH ('connector' = 'COLLECTION')"); + + tEnv().createFunction("NamedArgumentsTableFunction", NamedArgumentsTableFunction.class); + tEnv().executeSql( + "INSERT INTO SinkTable " + + "SELECT T1.s FROM TABLE(NamedArgumentsTableFunction(in2 => 'str2', in1 => 'str1')) AS T1(s) ") + .await(); + + 
assertThat(TestCollectionTableFactory.getResult()).containsExactlyInAnyOrder(sinkData); + } + + @Test + void testNamedArgumentsScalarFunction() throws Exception { + final List<Row> sourceData = + Arrays.asList(Row.of(1, 2, "str1"), Row.of(3, 4, "str2"), Row.of(5, 6, "str3")); + + final List<Row> sinkData = + Arrays.asList(Row.of(1, 2, "1: 2"), Row.of(3, 4, "3: 4"), Row.of(5, 6, "5: 6")); + + TestCollectionTableFactory.reset(); + TestCollectionTableFactory.initData(sourceData); + + tEnv().executeSql( + "CREATE TABLE TestTable(i1 INT NOT NULL, i2 INT NOT NULL, s1 STRING) WITH ('connector' = 'COLLECTION')"); + + tEnv().createTemporarySystemFunction( + "NamedArgumentsScalarFunction", NamedArgumentsScalarFunction.class); + tEnv().executeSql( + "INSERT INTO TestTable SELECT" + + " i1, i2," + + " NamedArgumentsScalarFunction(in2 => i2, in1 => i1) as s1 FROM TestTable") + .await(); + + assertThat(TestCollectionTableFactory.getResult()).isEqualTo(sinkData); + } + + @Test + void testNamedParametersScalarFunctionWithOverloadedMethod() throws Exception { + final List<Row> sourceData = + Arrays.asList(Row.of(1, 2, "str1"), Row.of(3, 4, "str2"), Row.of(5, 6, "str3")); + + TestCollectionTableFactory.reset(); + TestCollectionTableFactory.initData(sourceData); + + tEnv().executeSql( + "CREATE TABLE TestTable(i1 INT NOT NULL, i2 INT NOT NULL, s1 STRING) WITH ('connector' = 'COLLECTION')"); + tEnv().createTemporarySystemFunction( + "NamedArgumentsScalarFunction", + NamedArgumentsWithOverloadedScalarFunction.class); + + assertThatThrownBy( + () -> + tEnv().executeSql( + "INSERT INTO TestTable SELECT" + + " i1, i2," + + " NamedArgumentsScalarFunction(in2 => i2, in1 => i1) as s1 FROM TestTable") + .await()) + .hasMessageContaining( + "SQL validation failed. Could not find the argument names. 
Currently named arguments are not supported for varArgs and multi different argument names with overload function"); + } + + @Test + void testNamedArgumentAggregateFunction() throws Exception { + final List<Row> sourceData = + Arrays.asList( + Row.of(LocalDateTime.parse("2007-12-03T10:15:30"), "a", "b", 1, 2), + Row.of(LocalDateTime.parse("2007-12-03T10:15:30"), "c", "d", 33, 44), + Row.of(LocalDateTime.parse("2007-12-03T10:15:32"), "e", "f", 5, 6), + Row.of(LocalDateTime.parse("2007-12-03T10:15:32"), "gg", "hh", 7, 88)); + + final List<Row> sinkData = Arrays.asList(Row.of("a:b", "b:a"), Row.of("gg:hh", "hh:gg")); + + TestCollectionTableFactory.reset(); + TestCollectionTableFactory.initData(sourceData); + + tEnv().executeSql( + "CREATE TABLE SourceTable(ts TIMESTAMP(3), s1 STRING, s2 STRING, i1 INT, i2 INT, WATERMARK FOR ts AS ts - INTERVAL '1' SECOND) " + + "WITH ('connector' = 'COLLECTION')"); + tEnv().executeSql( + "CREATE TABLE SinkTable(s1 STRING, s2 STRING) WITH ('connector' = 'COLLECTION')"); + + tEnv().createTemporarySystemFunction( + "NamedArgumentAggregateFunction", NamedArgumentAggregateFunction.class); + + tEnv().executeSql( + "INSERT INTO SinkTable " + + "SELECT NamedArgumentAggregateFunction(in2 => s2, in1 => s1)," + + "NamedArgumentAggregateFunction(in1 => s2, in2 => s1)" + + "FROM SourceTable " + + "GROUP BY TUMBLE(ts, INTERVAL '1' SECOND)") + .await(); + + assertThat(TestCollectionTableFactory.getResult()).isEqualTo(sinkData); + } + + @Test + public void testInvalidUseOfScalarFunction() { tEnv().executeSql( "CREATE TABLE SinkTable(s BIGINT NOT NULL) WITH ('connector' = 'COLLECTION')"); @@ -1393,6 +1495,42 @@ public class FunctionITCase extends StreamingTestBase { } } + /** Scalar function with argument hint. 
*/ + public static class NamedArgumentsScalarFunction extends ScalarFunction { + @FunctionHint( + output = @DataTypeHint("STRING"), + argument = { + @ArgumentHint(name = "in1", type = @DataTypeHint("int")), + @ArgumentHint(name = "in2", type = @DataTypeHint("int")) + }) + public String eval(Integer arg1, Integer arg2) { + return (arg1 + ": " + arg2); + } + } + + /** Scalar function with overloaded functions and arguments declared. */ + public static class NamedArgumentsWithOverloadedScalarFunction extends ScalarFunction { + @FunctionHint( + output = @DataTypeHint("STRING"), + argument = { + @ArgumentHint(name = "in1", type = @DataTypeHint("int")), + @ArgumentHint(name = "in2", type = @DataTypeHint("int")) + }) + public String eval(Integer arg1, Integer arg2) { + return (arg1 + ": " + arg2); + } + + @FunctionHint( + output = @DataTypeHint("STRING"), + argument = { + @ArgumentHint(name = "in1", type = @DataTypeHint("string")), + @ArgumentHint(name = "in2", type = @DataTypeHint("string")) + }) + public String eval(String arg1, String arg2) { + return (arg1 + ": " + arg2); + } + } + /** Function that is overloaded and takes use of annotations. */ public static class ComplexScalarFunction extends ScalarFunction { public String eval( @@ -1509,6 +1647,17 @@ public class FunctionITCase extends StreamingTestBase { } } + /** Function that returns a string or integer. */ + public static class NamedArgumentsTableFunction extends TableFunction<Object> { + @FunctionHint( + input = {@DataTypeHint("STRING"), @DataTypeHint("STRING")}, + output = @DataTypeHint("STRING"), + argumentNames = {"in1", "in2"}) + public void eval(String arg1, String arg2) { + collect(arg1 + ", " + arg2); + } + } + /** * Function that returns which method has been called. * @@ -1623,6 +1772,36 @@ public class FunctionITCase extends StreamingTestBase { } } + /** Function that aggregates strings and finds the longest string. 
*/ + public static class NamedArgumentAggregateFunction extends AggregateFunction<String, Row> { + + @Override + public Row createAccumulator() { + return Row.of((String) null); + } + + @FunctionHint( + input = {@DataTypeHint("STRING"), @DataTypeHint("STRING")}, + output = @DataTypeHint("STRING"), + argumentNames = {"in1", "in2"}, + accumulator = @DataTypeHint("ROW<longestString STRING>")) + public void accumulate(Row acc, String arg1, String arg2) { + if (arg1 == null || arg2 == null) { + return; + } + String value = arg1 + ":" + arg2; + final String longestString = (String) acc.getField(0); + if (longestString == null || longestString.length() < value.length()) { + acc.setField(0, value); + } + } + + @Override + public String getValue(Row acc) { + return (String) acc.getField(0); + } + } + /** Aggregate function that tests raw types in map views. */ public static class RawMapViewAggregateFunction extends AggregateFunction<String, RawMapViewAggregateFunction.AccWithRawView> {