This is an automated email from the ASF dual-hosted git repository. twalthr pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/flink.git
commit 03839f1993b4f525d34789e2f53702200aa3d7bd Author: Timo Walther <twal...@apache.org> AuthorDate: Mon Jun 24 16:14:44 2019 +0200 [FLINK-12924][table] Add basic type inference logic This closes #8865. --- .../table/functions/BuiltInFunctionDefinition.java | 49 ++++- .../functions/BuiltInFunctionDefinitions.java | 136 ++++++++++++ .../table/types/inference/TypeInferenceUtil.java | 233 ++++++++++++++++++++- .../functions/InternalFunctionDefinitions.java | 3 + .../rules/ResolveCallByArgumentsRule.java | 141 ++++++++++++- 5 files changed, 555 insertions(+), 7 deletions(-) diff --git a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/functions/BuiltInFunctionDefinition.java b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/functions/BuiltInFunctionDefinition.java index af8a11d..bfa3fb0 100644 --- a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/functions/BuiltInFunctionDefinition.java +++ b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/functions/BuiltInFunctionDefinition.java @@ -19,8 +19,14 @@ package org.apache.flink.table.functions; import org.apache.flink.annotation.Internal; +import org.apache.flink.table.types.DataType; +import org.apache.flink.table.types.inference.InputTypeValidator; +import org.apache.flink.table.types.inference.TypeInference; +import org.apache.flink.table.types.inference.TypeStrategy; import org.apache.flink.util.Preconditions; +import java.util.List; + /** * Definition of a built-in function. It enables unique identification across different * modules by reference equality. 
@@ -37,17 +43,29 @@ public final class BuiltInFunctionDefinition implements FunctionDefinition { private final FunctionKind kind; + private final TypeInference typeInference; + private BuiltInFunctionDefinition( String name, - FunctionKind kind) { + FunctionKind kind, + TypeInference typeInference) { this.name = Preconditions.checkNotNull(name, "Name must not be null."); this.kind = Preconditions.checkNotNull(kind, "Kind must not be null."); + this.typeInference = Preconditions.checkNotNull(typeInference, "Type inference must not be null."); } public String getName() { return name; } + /** + * Currently, the type inference is just exposed here. In the future, function definition will + * require it. + */ + public TypeInference getTypeInference() { + return typeInference; + } + @Override public FunctionKind getKind() { return kind; @@ -69,6 +87,8 @@ public final class BuiltInFunctionDefinition implements FunctionDefinition { private FunctionKind kind; + private TypeInference.Builder typeInferenceBuilder = new TypeInference.Builder(); + public Builder() { // default constructor to allow a fluent definition } @@ -83,8 +103,33 @@ public final class BuiltInFunctionDefinition implements FunctionDefinition { return this; } + public Builder inputTypeValidator(InputTypeValidator inputTypeValidator) { + this.typeInferenceBuilder.inputTypeValidator(inputTypeValidator); + return this; + } + + public Builder accumulatorTypeStrategy(TypeStrategy accumulatorTypeStrategy) { + this.typeInferenceBuilder.accumulatorTypeStrategy(accumulatorTypeStrategy); + return this; + } + + public Builder outputTypeStrategy(TypeStrategy outputTypeStrategy) { + this.typeInferenceBuilder.outputTypeStrategy(outputTypeStrategy); + return this; + } + + public Builder namedArguments(List<String> argumentNames) { + this.typeInferenceBuilder.namedArguments(argumentNames); + return this; + } + + public Builder typedArguments(List<DataType> argumentTypes) { + 
this.typeInferenceBuilder.typedArguments(argumentTypes); + return this; + } + public BuiltInFunctionDefinition build() { - return new BuiltInFunctionDefinition(name, kind); + return new BuiltInFunctionDefinition(name, kind, typeInferenceBuilder.build()); } } } diff --git a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/functions/BuiltInFunctionDefinitions.java b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/functions/BuiltInFunctionDefinitions.java index fd02e46..5c6a70c 100644 --- a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/functions/BuiltInFunctionDefinitions.java +++ b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/functions/BuiltInFunctionDefinitions.java @@ -20,6 +20,7 @@ package org.apache.flink.table.functions; import org.apache.flink.annotation.PublicEvolving; import org.apache.flink.table.api.TableException; +import org.apache.flink.table.types.inference.TypeStrategies; import org.apache.flink.util.Preconditions; import java.lang.reflect.Field; @@ -44,21 +45,25 @@ public final class BuiltInFunctionDefinitions { new BuiltInFunctionDefinition.Builder() .name("and") .kind(SCALAR) + .outputTypeStrategy(TypeStrategies.MISSING) .build(); public static final BuiltInFunctionDefinition OR = new BuiltInFunctionDefinition.Builder() .name("or") .kind(SCALAR) + .outputTypeStrategy(TypeStrategies.MISSING) .build(); public static final BuiltInFunctionDefinition NOT = new BuiltInFunctionDefinition.Builder() .name("not") .kind(SCALAR) + .outputTypeStrategy(TypeStrategies.MISSING) .build(); public static final BuiltInFunctionDefinition IF = new BuiltInFunctionDefinition.Builder() .name("ifThenElse") .kind(SCALAR) + .outputTypeStrategy(TypeStrategies.MISSING) .build(); // comparison functions @@ -66,71 +71,85 @@ public final class BuiltInFunctionDefinitions { new BuiltInFunctionDefinition.Builder() .name("equals") .kind(SCALAR) + .outputTypeStrategy(TypeStrategies.MISSING) .build(); 
public static final BuiltInFunctionDefinition GREATER_THAN = new BuiltInFunctionDefinition.Builder() .name("greaterThan") .kind(SCALAR) + .outputTypeStrategy(TypeStrategies.MISSING) .build(); public static final BuiltInFunctionDefinition GREATER_THAN_OR_EQUAL = new BuiltInFunctionDefinition.Builder() .name("greaterThanOrEqual") .kind(SCALAR) + .outputTypeStrategy(TypeStrategies.MISSING) .build(); public static final BuiltInFunctionDefinition LESS_THAN = new BuiltInFunctionDefinition.Builder() .name("lessThan") .kind(SCALAR) + .outputTypeStrategy(TypeStrategies.MISSING) .build(); public static final BuiltInFunctionDefinition LESS_THAN_OR_EQUAL = new BuiltInFunctionDefinition.Builder() .name("lessThanOrEqual") .kind(SCALAR) + .outputTypeStrategy(TypeStrategies.MISSING) .build(); public static final BuiltInFunctionDefinition NOT_EQUALS = new BuiltInFunctionDefinition.Builder() .name("notEquals") .kind(SCALAR) + .outputTypeStrategy(TypeStrategies.MISSING) .build(); public static final BuiltInFunctionDefinition IS_NULL = new BuiltInFunctionDefinition.Builder() .name("isNull") .kind(SCALAR) + .outputTypeStrategy(TypeStrategies.MISSING) .build(); public static final BuiltInFunctionDefinition IS_NOT_NULL = new BuiltInFunctionDefinition.Builder() .name("isNotNull") .kind(SCALAR) + .outputTypeStrategy(TypeStrategies.MISSING) .build(); public static final BuiltInFunctionDefinition IS_TRUE = new BuiltInFunctionDefinition.Builder() .name("isTrue") .kind(SCALAR) + .outputTypeStrategy(TypeStrategies.MISSING) .build(); public static final BuiltInFunctionDefinition IS_FALSE = new BuiltInFunctionDefinition.Builder() .name("isFalse") .kind(SCALAR) + .outputTypeStrategy(TypeStrategies.MISSING) .build(); public static final BuiltInFunctionDefinition IS_NOT_TRUE = new BuiltInFunctionDefinition.Builder() .name("isNotTrue") .kind(SCALAR) + .outputTypeStrategy(TypeStrategies.MISSING) .build(); public static final BuiltInFunctionDefinition IS_NOT_FALSE = new 
BuiltInFunctionDefinition.Builder() .name("isNotFalse") .kind(SCALAR) + .outputTypeStrategy(TypeStrategies.MISSING) .build(); public static final BuiltInFunctionDefinition BETWEEN = new BuiltInFunctionDefinition.Builder() .name("between") .kind(SCALAR) + .outputTypeStrategy(TypeStrategies.MISSING) .build(); public static final BuiltInFunctionDefinition NOT_BETWEEN = new BuiltInFunctionDefinition.Builder() .name("notBetween") .kind(SCALAR) + .outputTypeStrategy(TypeStrategies.MISSING) .build(); // aggregate functions @@ -138,61 +157,73 @@ public final class BuiltInFunctionDefinitions { new BuiltInFunctionDefinition.Builder() .name("avg") .kind(AGGREGATE) + .outputTypeStrategy(TypeStrategies.MISSING) .build(); public static final BuiltInFunctionDefinition COUNT = new BuiltInFunctionDefinition.Builder() .name("count") .kind(AGGREGATE) + .outputTypeStrategy(TypeStrategies.MISSING) .build(); public static final BuiltInFunctionDefinition MAX = new BuiltInFunctionDefinition.Builder() .name("max") .kind(AGGREGATE) + .outputTypeStrategy(TypeStrategies.MISSING) .build(); public static final BuiltInFunctionDefinition MIN = new BuiltInFunctionDefinition.Builder() .name("min") .kind(AGGREGATE) + .outputTypeStrategy(TypeStrategies.MISSING) .build(); public static final BuiltInFunctionDefinition SUM = new BuiltInFunctionDefinition.Builder() .name("sum") .kind(AGGREGATE) + .outputTypeStrategy(TypeStrategies.MISSING) .build(); public static final BuiltInFunctionDefinition SUM0 = new BuiltInFunctionDefinition.Builder() .name("sum0") .kind(AGGREGATE) + .outputTypeStrategy(TypeStrategies.MISSING) .build(); public static final BuiltInFunctionDefinition STDDEV_POP = new BuiltInFunctionDefinition.Builder() .name("stddevPop") .kind(AGGREGATE) + .outputTypeStrategy(TypeStrategies.MISSING) .build(); public static final BuiltInFunctionDefinition STDDEV_SAMP = new BuiltInFunctionDefinition.Builder() .name("stddevSamp") .kind(AGGREGATE) + .outputTypeStrategy(TypeStrategies.MISSING) .build(); 
public static final BuiltInFunctionDefinition VAR_POP = new BuiltInFunctionDefinition.Builder() .name("varPop") .kind(AGGREGATE) + .outputTypeStrategy(TypeStrategies.MISSING) .build(); public static final BuiltInFunctionDefinition VAR_SAMP = new BuiltInFunctionDefinition.Builder() .name("varSamp") .kind(AGGREGATE) + .outputTypeStrategy(TypeStrategies.MISSING) .build(); public static final BuiltInFunctionDefinition COLLECT = new BuiltInFunctionDefinition.Builder() .name("collect") .kind(AGGREGATE) + .outputTypeStrategy(TypeStrategies.MISSING) .build(); public static final BuiltInFunctionDefinition DISTINCT = new BuiltInFunctionDefinition.Builder() .name("distinct") .kind(AGGREGATE) + .outputTypeStrategy(TypeStrategies.MISSING) .build(); // string functions @@ -200,116 +231,139 @@ public final class BuiltInFunctionDefinitions { new BuiltInFunctionDefinition.Builder() .name("charLength") .kind(SCALAR) + .outputTypeStrategy(TypeStrategies.MISSING) .build(); public static final BuiltInFunctionDefinition INIT_CAP = new BuiltInFunctionDefinition.Builder() .name("initCap") .kind(SCALAR) + .outputTypeStrategy(TypeStrategies.MISSING) .build(); public static final BuiltInFunctionDefinition LIKE = new BuiltInFunctionDefinition.Builder() .name("like") .kind(SCALAR) + .outputTypeStrategy(TypeStrategies.MISSING) .build(); public static final BuiltInFunctionDefinition LOWER = new BuiltInFunctionDefinition.Builder() .name("lowerCase") .kind(SCALAR) + .outputTypeStrategy(TypeStrategies.MISSING) .build(); public static final BuiltInFunctionDefinition SIMILAR = new BuiltInFunctionDefinition.Builder() .name("similar") .kind(SCALAR) + .outputTypeStrategy(TypeStrategies.MISSING) .build(); public static final BuiltInFunctionDefinition SUBSTRING = new BuiltInFunctionDefinition.Builder() .name("substring") .kind(SCALAR) + .outputTypeStrategy(TypeStrategies.MISSING) .build(); public static final BuiltInFunctionDefinition REPLACE = new BuiltInFunctionDefinition.Builder() .name("replace") 
.kind(SCALAR) + .outputTypeStrategy(TypeStrategies.MISSING) .build(); public static final BuiltInFunctionDefinition TRIM = new BuiltInFunctionDefinition.Builder() .name("trim") .kind(SCALAR) + .outputTypeStrategy(TypeStrategies.MISSING) .build(); public static final BuiltInFunctionDefinition UPPER = new BuiltInFunctionDefinition.Builder() .name("upperCase") .kind(SCALAR) + .outputTypeStrategy(TypeStrategies.MISSING) .build(); public static final BuiltInFunctionDefinition POSITION = new BuiltInFunctionDefinition.Builder() .name("position") .kind(SCALAR) + .outputTypeStrategy(TypeStrategies.MISSING) .build(); public static final BuiltInFunctionDefinition OVERLAY = new BuiltInFunctionDefinition.Builder() .name("overlay") .kind(SCALAR) + .outputTypeStrategy(TypeStrategies.MISSING) .build(); public static final BuiltInFunctionDefinition CONCAT = new BuiltInFunctionDefinition.Builder() .name("concat") .kind(SCALAR) + .outputTypeStrategy(TypeStrategies.MISSING) .build(); public static final BuiltInFunctionDefinition CONCAT_WS = new BuiltInFunctionDefinition.Builder() .name("concat_ws") .kind(SCALAR) + .outputTypeStrategy(TypeStrategies.MISSING) .build(); public static final BuiltInFunctionDefinition LPAD = new BuiltInFunctionDefinition.Builder() .name("lpad") .kind(SCALAR) + .outputTypeStrategy(TypeStrategies.MISSING) .build(); public static final BuiltInFunctionDefinition RPAD = new BuiltInFunctionDefinition.Builder() .name("rpad") .kind(SCALAR) + .outputTypeStrategy(TypeStrategies.MISSING) .build(); public static final BuiltInFunctionDefinition REGEXP_EXTRACT = new BuiltInFunctionDefinition.Builder() .name("regexpExtract") .kind(SCALAR) + .outputTypeStrategy(TypeStrategies.MISSING) .build(); public static final BuiltInFunctionDefinition FROM_BASE64 = new BuiltInFunctionDefinition.Builder() .name("fromBase64") .kind(SCALAR) + .outputTypeStrategy(TypeStrategies.MISSING) .build(); public static final BuiltInFunctionDefinition TO_BASE64 = new 
BuiltInFunctionDefinition.Builder() .name("toBase64") .kind(SCALAR) + .outputTypeStrategy(TypeStrategies.MISSING) .build(); public static final BuiltInFunctionDefinition UUID = new BuiltInFunctionDefinition.Builder() .name("uuid") .kind(SCALAR) + .outputTypeStrategy(TypeStrategies.MISSING) .build(); public static final BuiltInFunctionDefinition LTRIM = new BuiltInFunctionDefinition.Builder() .name("ltrim") .kind(SCALAR) + .outputTypeStrategy(TypeStrategies.MISSING) .build(); public static final BuiltInFunctionDefinition RTRIM = new BuiltInFunctionDefinition.Builder() .name("rtrim") .kind(SCALAR) + .outputTypeStrategy(TypeStrategies.MISSING) .build(); public static final BuiltInFunctionDefinition REPEAT = new BuiltInFunctionDefinition.Builder() .name("repeat") .kind(SCALAR) + .outputTypeStrategy(TypeStrategies.MISSING) .build(); public static final BuiltInFunctionDefinition REGEXP_REPLACE = new BuiltInFunctionDefinition.Builder() .name("regexpReplace") .kind(SCALAR) + .outputTypeStrategy(TypeStrategies.MISSING) .build(); // math functions @@ -317,191 +371,229 @@ public final class BuiltInFunctionDefinitions { new BuiltInFunctionDefinition.Builder() .name("plus") .kind(SCALAR) + .outputTypeStrategy(TypeStrategies.MISSING) .build(); public static final BuiltInFunctionDefinition MINUS = new BuiltInFunctionDefinition.Builder() .name("minus") .kind(SCALAR) + .outputTypeStrategy(TypeStrategies.MISSING) .build(); public static final BuiltInFunctionDefinition DIVIDE = new BuiltInFunctionDefinition.Builder() .name("divide") .kind(SCALAR) + .outputTypeStrategy(TypeStrategies.MISSING) .build(); public static final BuiltInFunctionDefinition TIMES = new BuiltInFunctionDefinition.Builder() .name("times") .kind(SCALAR) + .outputTypeStrategy(TypeStrategies.MISSING) .build(); public static final BuiltInFunctionDefinition ABS = new BuiltInFunctionDefinition.Builder() .name("abs") .kind(SCALAR) + .outputTypeStrategy(TypeStrategies.MISSING) .build(); public static final 
BuiltInFunctionDefinition CEIL = new BuiltInFunctionDefinition.Builder() .name("ceil") .kind(SCALAR) + .outputTypeStrategy(TypeStrategies.MISSING) .build(); public static final BuiltInFunctionDefinition EXP = new BuiltInFunctionDefinition.Builder() .name("exp") .kind(SCALAR) + .outputTypeStrategy(TypeStrategies.MISSING) .build(); public static final BuiltInFunctionDefinition FLOOR = new BuiltInFunctionDefinition.Builder() .name("floor") .kind(SCALAR) + .outputTypeStrategy(TypeStrategies.MISSING) .build(); public static final BuiltInFunctionDefinition LOG10 = new BuiltInFunctionDefinition.Builder() .name("log10") .kind(SCALAR) + .outputTypeStrategy(TypeStrategies.MISSING) .build(); public static final BuiltInFunctionDefinition LOG2 = new BuiltInFunctionDefinition.Builder() .name("log2") .kind(SCALAR) + .outputTypeStrategy(TypeStrategies.MISSING) .build(); public static final BuiltInFunctionDefinition LN = new BuiltInFunctionDefinition.Builder() .name("ln") .kind(SCALAR) + .outputTypeStrategy(TypeStrategies.MISSING) .build(); public static final BuiltInFunctionDefinition LOG = new BuiltInFunctionDefinition.Builder() .name("log") .kind(SCALAR) + .outputTypeStrategy(TypeStrategies.MISSING) .build(); public static final BuiltInFunctionDefinition POWER = new BuiltInFunctionDefinition.Builder() .name("power") .kind(SCALAR) + .outputTypeStrategy(TypeStrategies.MISSING) .build(); public static final BuiltInFunctionDefinition MOD = new BuiltInFunctionDefinition.Builder() .name("mod") .kind(SCALAR) + .outputTypeStrategy(TypeStrategies.MISSING) .build(); public static final BuiltInFunctionDefinition SQRT = new BuiltInFunctionDefinition.Builder() .name("sqrt") .kind(SCALAR) + .outputTypeStrategy(TypeStrategies.MISSING) .build(); public static final BuiltInFunctionDefinition MINUS_PREFIX = new BuiltInFunctionDefinition.Builder() .name("minusPrefix") .kind(SCALAR) + .outputTypeStrategy(TypeStrategies.MISSING) .build(); public static final BuiltInFunctionDefinition SIN = new 
BuiltInFunctionDefinition.Builder() .name("sin") .kind(SCALAR) + .outputTypeStrategy(TypeStrategies.MISSING) .build(); public static final BuiltInFunctionDefinition COS = new BuiltInFunctionDefinition.Builder() .name("cos") .kind(SCALAR) + .outputTypeStrategy(TypeStrategies.MISSING) .build(); public static final BuiltInFunctionDefinition SINH = new BuiltInFunctionDefinition.Builder() .name("sinh") .kind(SCALAR) + .outputTypeStrategy(TypeStrategies.MISSING) .build(); public static final BuiltInFunctionDefinition TAN = new BuiltInFunctionDefinition.Builder() .name("tan") .kind(SCALAR) + .outputTypeStrategy(TypeStrategies.MISSING) .build(); public static final BuiltInFunctionDefinition TANH = new BuiltInFunctionDefinition.Builder() .name("tanh") .kind(SCALAR) + .outputTypeStrategy(TypeStrategies.MISSING) .build(); public static final BuiltInFunctionDefinition COT = new BuiltInFunctionDefinition.Builder() .name("cot") .kind(SCALAR) + .outputTypeStrategy(TypeStrategies.MISSING) .build(); public static final BuiltInFunctionDefinition ASIN = new BuiltInFunctionDefinition.Builder() .name("asin") .kind(SCALAR) + .outputTypeStrategy(TypeStrategies.MISSING) .build(); public static final BuiltInFunctionDefinition ACOS = new BuiltInFunctionDefinition.Builder() .name("acos") .kind(SCALAR) + .outputTypeStrategy(TypeStrategies.MISSING) .build(); public static final BuiltInFunctionDefinition ATAN = new BuiltInFunctionDefinition.Builder() .name("atan") .kind(SCALAR) + .outputTypeStrategy(TypeStrategies.MISSING) .build(); public static final BuiltInFunctionDefinition ATAN2 = new BuiltInFunctionDefinition.Builder() .name("atan2") .kind(SCALAR) + .outputTypeStrategy(TypeStrategies.MISSING) .build(); public static final BuiltInFunctionDefinition COSH = new BuiltInFunctionDefinition.Builder() .name("cosh") .kind(SCALAR) + .outputTypeStrategy(TypeStrategies.MISSING) .build(); public static final BuiltInFunctionDefinition DEGREES = new BuiltInFunctionDefinition.Builder() .name("degrees") 
.kind(SCALAR) + .outputTypeStrategy(TypeStrategies.MISSING) .build(); public static final BuiltInFunctionDefinition RADIANS = new BuiltInFunctionDefinition.Builder() .name("radians") .kind(SCALAR) + .outputTypeStrategy(TypeStrategies.MISSING) .build(); public static final BuiltInFunctionDefinition SIGN = new BuiltInFunctionDefinition.Builder() .name("sign") .kind(SCALAR) + .outputTypeStrategy(TypeStrategies.MISSING) .build(); public static final BuiltInFunctionDefinition ROUND = new BuiltInFunctionDefinition.Builder() .name("round") .kind(SCALAR) + .outputTypeStrategy(TypeStrategies.MISSING) .build(); public static final BuiltInFunctionDefinition PI = new BuiltInFunctionDefinition.Builder() .name("pi") .kind(SCALAR) + .outputTypeStrategy(TypeStrategies.MISSING) .build(); public static final BuiltInFunctionDefinition E = new BuiltInFunctionDefinition.Builder() .name("e") .kind(SCALAR) + .outputTypeStrategy(TypeStrategies.MISSING) .build(); public static final BuiltInFunctionDefinition RAND = new BuiltInFunctionDefinition.Builder() .name("rand") .kind(SCALAR) + .outputTypeStrategy(TypeStrategies.MISSING) .build(); public static final BuiltInFunctionDefinition RAND_INTEGER = new BuiltInFunctionDefinition.Builder() .name("randInteger") .kind(SCALAR) + .outputTypeStrategy(TypeStrategies.MISSING) .build(); public static final BuiltInFunctionDefinition BIN = new BuiltInFunctionDefinition.Builder() .name("bin") .kind(SCALAR) + .outputTypeStrategy(TypeStrategies.MISSING) .build(); public static final BuiltInFunctionDefinition HEX = new BuiltInFunctionDefinition.Builder() .name("hex") .kind(SCALAR) + .outputTypeStrategy(TypeStrategies.MISSING) .build(); public static final BuiltInFunctionDefinition TRUNCATE = new BuiltInFunctionDefinition.Builder() .name("truncate") .kind(SCALAR) + .outputTypeStrategy(TypeStrategies.MISSING) .build(); // time functions @@ -509,51 +601,61 @@ public final class BuiltInFunctionDefinitions { new BuiltInFunctionDefinition.Builder() 
.name("extract") .kind(SCALAR) + .outputTypeStrategy(TypeStrategies.MISSING) .build(); public static final BuiltInFunctionDefinition CURRENT_DATE = new BuiltInFunctionDefinition.Builder() .name("currentDate") .kind(SCALAR) + .outputTypeStrategy(TypeStrategies.MISSING) .build(); public static final BuiltInFunctionDefinition CURRENT_TIME = new BuiltInFunctionDefinition.Builder() .name("currentTime") .kind(SCALAR) + .outputTypeStrategy(TypeStrategies.MISSING) .build(); public static final BuiltInFunctionDefinition CURRENT_TIMESTAMP = new BuiltInFunctionDefinition.Builder() .name("currentTimestamp") .kind(SCALAR) + .outputTypeStrategy(TypeStrategies.MISSING) .build(); public static final BuiltInFunctionDefinition LOCAL_TIME = new BuiltInFunctionDefinition.Builder() .name("localTime") .kind(SCALAR) + .outputTypeStrategy(TypeStrategies.MISSING) .build(); public static final BuiltInFunctionDefinition LOCAL_TIMESTAMP = new BuiltInFunctionDefinition.Builder() .name("localTimestamp") .kind(SCALAR) + .outputTypeStrategy(TypeStrategies.MISSING) .build(); public static final BuiltInFunctionDefinition TEMPORAL_OVERLAPS = new BuiltInFunctionDefinition.Builder() .name("temporalOverlaps") .kind(SCALAR) + .outputTypeStrategy(TypeStrategies.MISSING) .build(); public static final BuiltInFunctionDefinition DATE_TIME_PLUS = new BuiltInFunctionDefinition.Builder() .name("dateTimePlus") .kind(SCALAR) + .outputTypeStrategy(TypeStrategies.MISSING) .build(); public static final BuiltInFunctionDefinition DATE_FORMAT = new BuiltInFunctionDefinition.Builder() .name("dateFormat") .kind(SCALAR) + .outputTypeStrategy(TypeStrategies.MISSING) .build(); public static final BuiltInFunctionDefinition TIMESTAMP_DIFF = new BuiltInFunctionDefinition.Builder() .name("timestampDiff") .kind(SCALAR) + .outputTypeStrategy(TypeStrategies.MISSING) .build(); // collection @@ -561,31 +663,37 @@ public final class BuiltInFunctionDefinitions { new BuiltInFunctionDefinition.Builder() .name("at") .kind(SCALAR) + 
.outputTypeStrategy(TypeStrategies.MISSING) .build(); public static final BuiltInFunctionDefinition CARDINALITY = new BuiltInFunctionDefinition.Builder() .name("cardinality") .kind(SCALAR) + .outputTypeStrategy(TypeStrategies.MISSING) .build(); public static final BuiltInFunctionDefinition ARRAY = new BuiltInFunctionDefinition.Builder() .name("array") .kind(SCALAR) + .outputTypeStrategy(TypeStrategies.MISSING) .build(); public static final BuiltInFunctionDefinition ARRAY_ELEMENT = new BuiltInFunctionDefinition.Builder() .name("element") .kind(SCALAR) + .outputTypeStrategy(TypeStrategies.MISSING) .build(); public static final BuiltInFunctionDefinition MAP = new BuiltInFunctionDefinition.Builder() .name("map") .kind(SCALAR) + .outputTypeStrategy(TypeStrategies.MISSING) .build(); public static final BuiltInFunctionDefinition ROW = new BuiltInFunctionDefinition.Builder() .name("row") .kind(SCALAR) + .outputTypeStrategy(TypeStrategies.MISSING) .build(); // composite @@ -593,11 +701,13 @@ public final class BuiltInFunctionDefinitions { new BuiltInFunctionDefinition.Builder() .name("flatten") .kind(OTHER) + .outputTypeStrategy(TypeStrategies.MISSING) .build(); public static final BuiltInFunctionDefinition GET = new BuiltInFunctionDefinition.Builder() .name("get") .kind(OTHER) + .outputTypeStrategy(TypeStrategies.MISSING) .build(); // window properties @@ -605,11 +715,13 @@ public final class BuiltInFunctionDefinitions { new BuiltInFunctionDefinition.Builder() .name("start") .kind(OTHER) + .outputTypeStrategy(TypeStrategies.MISSING) .build(); public static final BuiltInFunctionDefinition WINDOW_END = new BuiltInFunctionDefinition.Builder() .name("end") .kind(OTHER) + .outputTypeStrategy(TypeStrategies.MISSING) .build(); // ordering @@ -617,11 +729,13 @@ public final class BuiltInFunctionDefinitions { new BuiltInFunctionDefinition.Builder() .name("asc") .kind(OTHER) + .outputTypeStrategy(TypeStrategies.MISSING) .build(); public static final BuiltInFunctionDefinition 
ORDER_DESC = new BuiltInFunctionDefinition.Builder() .name("desc") .kind(OTHER) + .outputTypeStrategy(TypeStrategies.MISSING) .build(); // crypto hash @@ -629,36 +743,43 @@ public final class BuiltInFunctionDefinitions { new BuiltInFunctionDefinition.Builder() .name("md5") .kind(SCALAR) + .outputTypeStrategy(TypeStrategies.MISSING) .build(); public static final BuiltInFunctionDefinition SHA1 = new BuiltInFunctionDefinition.Builder() .name("sha1") .kind(SCALAR) + .outputTypeStrategy(TypeStrategies.MISSING) .build(); public static final BuiltInFunctionDefinition SHA224 = new BuiltInFunctionDefinition.Builder() .name("sha224") .kind(SCALAR) + .outputTypeStrategy(TypeStrategies.MISSING) .build(); public static final BuiltInFunctionDefinition SHA256 = new BuiltInFunctionDefinition.Builder() .name("sha256") .kind(SCALAR) + .outputTypeStrategy(TypeStrategies.MISSING) .build(); public static final BuiltInFunctionDefinition SHA384 = new BuiltInFunctionDefinition.Builder() .name("sha384") .kind(SCALAR) + .outputTypeStrategy(TypeStrategies.MISSING) .build(); public static final BuiltInFunctionDefinition SHA512 = new BuiltInFunctionDefinition.Builder() .name("sha512") .kind(SCALAR) + .outputTypeStrategy(TypeStrategies.MISSING) .build(); public static final BuiltInFunctionDefinition SHA2 = new BuiltInFunctionDefinition.Builder() .name("sha2") .kind(SCALAR) + .outputTypeStrategy(TypeStrategies.MISSING) .build(); // time attributes @@ -666,11 +787,13 @@ public final class BuiltInFunctionDefinitions { new BuiltInFunctionDefinition.Builder() .name("proctime") .kind(OTHER) + .outputTypeStrategy(TypeStrategies.MISSING) .build(); public static final BuiltInFunctionDefinition ROWTIME = new BuiltInFunctionDefinition.Builder() .name("rowtime") .kind(OTHER) + .outputTypeStrategy(TypeStrategies.MISSING) .build(); // over window @@ -678,26 +801,31 @@ public final class BuiltInFunctionDefinitions { new BuiltInFunctionDefinition.Builder() .name("over") .kind(OTHER) + 
.outputTypeStrategy(TypeStrategies.MISSING) .build(); public static final BuiltInFunctionDefinition UNBOUNDED_RANGE = new BuiltInFunctionDefinition.Builder() .name("unboundedRange") .kind(OTHER) + .outputTypeStrategy(TypeStrategies.MISSING) .build(); public static final BuiltInFunctionDefinition UNBOUNDED_ROW = new BuiltInFunctionDefinition.Builder() .name("unboundedRow") .kind(OTHER) + .outputTypeStrategy(TypeStrategies.MISSING) .build(); public static final BuiltInFunctionDefinition CURRENT_RANGE = new BuiltInFunctionDefinition.Builder() .name("currentRange") .kind(OTHER) + .outputTypeStrategy(TypeStrategies.MISSING) .build(); public static final BuiltInFunctionDefinition CURRENT_ROW = new BuiltInFunctionDefinition.Builder() .name("currentRow") .kind(OTHER) + .outputTypeStrategy(TypeStrategies.MISSING) .build(); // columns @@ -705,11 +833,13 @@ public final class BuiltInFunctionDefinitions { new BuiltInFunctionDefinition.Builder() .name("withColumns") .kind(OTHER) + .outputTypeStrategy(TypeStrategies.MISSING) .build(); public static final BuiltInFunctionDefinition WITHOUT_COLUMNS = new BuiltInFunctionDefinition.Builder() .name("withoutColumns") .kind(OTHER) + .outputTypeStrategy(TypeStrategies.MISSING) .build(); // etc @@ -717,31 +847,37 @@ public final class BuiltInFunctionDefinitions { new BuiltInFunctionDefinition.Builder() .name("in") .kind(SCALAR) + .outputTypeStrategy(TypeStrategies.MISSING) .build(); public static final BuiltInFunctionDefinition CAST = new BuiltInFunctionDefinition.Builder() .name("cast") .kind(SCALAR) + .outputTypeStrategy(TypeStrategies.MISSING) .build(); public static final BuiltInFunctionDefinition REINTERPRET_CAST = new BuiltInFunctionDefinition.Builder() .name("reinterpretCast") .kind(SCALAR) + .outputTypeStrategy(TypeStrategies.MISSING) .build(); public static final BuiltInFunctionDefinition AS = new BuiltInFunctionDefinition.Builder() .name("as") .kind(OTHER) + .outputTypeStrategy(TypeStrategies.MISSING) .build(); public static 
final BuiltInFunctionDefinition STREAM_RECORD_TIMESTAMP = new BuiltInFunctionDefinition.Builder() .name("streamRecordTimestamp") .kind(OTHER) + .outputTypeStrategy(TypeStrategies.MISSING) .build(); public static final BuiltInFunctionDefinition RANGE_TO = new BuiltInFunctionDefinition.Builder() .name("rangeTo") .kind(OTHER) + .outputTypeStrategy(TypeStrategies.MISSING) .build(); public static final Set<FunctionDefinition> WINDOW_PROPERTIES = new HashSet<>(Arrays.asList( diff --git a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/inference/TypeInferenceUtil.java b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/inference/TypeInferenceUtil.java index 6d87625..210d004 100644 --- a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/inference/TypeInferenceUtil.java +++ b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/inference/TypeInferenceUtil.java @@ -19,12 +19,17 @@ package org.apache.flink.table.types.inference; import org.apache.flink.annotation.Internal; +import org.apache.flink.table.api.TableException; +import org.apache.flink.table.api.ValidationException; +import org.apache.flink.table.functions.FunctionDefinition; +import org.apache.flink.table.functions.FunctionKind; import org.apache.flink.table.types.DataType; import javax.annotation.Nullable; import java.util.List; import java.util.Optional; +import java.util.stream.Collectors; /** * Utility for performing type inference. @@ -33,13 +38,32 @@ import java.util.Optional; public final class TypeInferenceUtil { public static Result runTypeInference(TypeInference typeInference, CallContext callContext) { - throw new UnsupportedOperationException(); + try { + return runTypeInferenceInternal(typeInference, callContext); + } catch (ValidationException e) { + throw new ValidationException( + String.format( + "Invalid call to function '%s'. 
Given arguments: %s", + callContext.getName(), + callContext.getArgumentDataTypes().stream() + .map(DataType::toString) + .collect(Collectors.joining(", "))), + e); + } catch (Throwable t) { + throw new TableException( + String.format( + "Unexpected error in type inference logic of function '%s'. This is a bug.", + callContext.getName()), + t); + } } /** * The result of a type inference run. It contains information about how arguments need to be - * modified in order to comply with the function's signature. This includes casts that need to be - * inserted, reordering of arguments, or insertion of default values. + * adapted in order to comply with the function's signature. + * + * <p>This includes casts that need to be inserted, reordering of arguments (*), or insertion of default + * values (*) where (*) is future work. */ public static final class Result { @@ -71,6 +95,209 @@ public final class TypeInferenceUtil { } } + // -------------------------------------------------------------------------------------------- + + private static Result runTypeInferenceInternal(TypeInference typeInference, CallContext callContext) { + final List<DataType> argumentTypes = callContext.getArgumentDataTypes(); + + try { + validateArgumentCount( + typeInference.getInputTypeValidator().getArgumentCount(), + argumentTypes.size()); + } catch (ValidationException e) { + throw getInvalidInputException(typeInference.getInputTypeValidator(), callContext, e); + } + + final List<DataType> expectedTypes = typeInference.getArgumentTypes() + .orElse(argumentTypes); + + final AdaptedCallContext adaptedCallContext = adaptArguments( + callContext, + expectedTypes); + + try { + validateInputTypes( + typeInference.getInputTypeValidator(), + adaptedCallContext); + } catch (ValidationException e) { + throw getInvalidInputException(typeInference.getInputTypeValidator(), adaptedCallContext, e); + } + + return inferTypes( + adaptedCallContext, + typeInference.getAccumulatorTypeStrategy().orElse(null),
typeInference.getOutputTypeStrategy()); + } + + private static ValidationException getInvalidInputException( /* cause preserves the specific validation failure, e.g. the argument count message */ + InputTypeValidator validator, + CallContext callContext, Throwable cause) { + return new ValidationException( + String.format( + "Invalid input arguments. Expected signatures are:\n%s", + String.join( + "\n", + validator.getExpectedSignatures( + callContext.getName(), + callContext.getFunctionDefinition()))), cause); + } + + private static void validateArgumentCount(ArgumentCount argumentCount, int actualCount) { + argumentCount.getMinCount().ifPresent((min) -> { + if (actualCount < min) { + throw new ValidationException( + String.format( + "Invalid number of arguments. At least %d arguments expected but %d passed.", + min, + actualCount)); + } + }); + + argumentCount.getMaxCount().ifPresent((max) -> { + if (actualCount > max) { + throw new ValidationException( + String.format( + "Invalid number of arguments. At most %d arguments expected but %d passed.", + max, + actualCount)); + } + }); + + if (!argumentCount.isValidCount(actualCount)) { /* min/max bounds passed; still reject any count that isValidCount does not accept */ + throw new ValidationException( + String.format( + "Invalid number of arguments. %d arguments passed.", + actualCount)); + } + } + + private static void validateInputTypes(InputTypeValidator inputTypeValidator, CallContext callContext) { + if (!inputTypeValidator.validate(callContext, true)) { + throw new ValidationException("Invalid input arguments."); + } + } + + /** + * Adapts the call's arguments if necessary. + * + * <p>This includes casts that need to be inserted, reordering of arguments (*), or insertion of default + * values (*) where (*) is future work.
+ */ + private static AdaptedCallContext adaptArguments( + CallContext callContext, + List<DataType> expectedTypes) { + + final List<DataType> actualTypes = callContext.getArgumentDataTypes(); + for (int pos = 0; pos < actualTypes.size(); pos++) { + final DataType expectedType = expectedTypes.get(pos); + final DataType actualType = actualTypes.get(pos); + + if (!actualType.equals(expectedType) && !canCast(actualType, expectedType)) { + throw new ValidationException( + String.format( + "Invalid argument type at position %d. Data type %s expected but %s passed.", + pos, + expectedType, + actualType)); + } + } + + return new AdaptedCallContext(callContext, expectedTypes); + } + + private static boolean canCast(DataType sourceDataType, DataType targetDataType) { + return false; // TODO unsupported for now + } + + private static Result inferTypes( + AdaptedCallContext adaptedCallContext, + @Nullable TypeStrategy accumulatorTypeStrategy, + TypeStrategy outputTypeStrategy) { + + // infer output type first for better error message + // (logically an accumulator type should be inferred first) + final Optional<DataType> potentialOutputType = outputTypeStrategy.inferType(adaptedCallContext); + if (!potentialOutputType.isPresent()) { + throw new ValidationException("Could not infer an output type for the given arguments."); + } + final DataType outputType = potentialOutputType.get(); + + if (adaptedCallContext.getFunctionDefinition().getKind() == FunctionKind.TABLE_AGGREGATE || + adaptedCallContext.getFunctionDefinition().getKind() == FunctionKind.AGGREGATE) { + // an accumulator might be an internal feature of the planner, therefore it is not + // mandatory here; we assume the output type to be the accumulator type in this case + if (accumulatorTypeStrategy == null) { + return new Result(adaptedCallContext.expectedArguments, outputType, outputType); + } + final Optional<DataType> potentialAccumulatorType = accumulatorTypeStrategy.inferType(adaptedCallContext); + if 
(!potentialAccumulatorType.isPresent()) { + throw new ValidationException("Could not infer an accumulator type for the given arguments."); + } + return new Result(adaptedCallContext.expectedArguments, potentialAccumulatorType.get(), outputType); + + } else { + return new Result(adaptedCallContext.expectedArguments, null, outputType); + } + } + + /** + * Helper context that deals with adapted arguments. + * + * <p>For example, if an argument needs to be casted to a target type, an expression that was a + * literal before is not a literal anymore in this call context. + */ + private static class AdaptedCallContext implements CallContext { + + private final CallContext originalContext; + private final List<DataType> expectedArguments; + + public AdaptedCallContext(CallContext originalContext, List<DataType> castedArguments) { + this.originalContext = originalContext; + this.expectedArguments = castedArguments; + } + + @Override + public List<DataType> getArgumentDataTypes() { + return expectedArguments; + } + + @Override + public FunctionDefinition getFunctionDefinition() { + return originalContext.getFunctionDefinition(); + } + + @Override + public boolean isArgumentLiteral(int pos) { + if (isCasted(pos)) { + return false; + } + return originalContext.isArgumentLiteral(pos); + } + + @Override + public boolean isArgumentNull(int pos) { + // null remains null regardless of casting + return originalContext.isArgumentNull(pos); + } + + @Override + public <T> Optional<T> getArgumentValue(int pos, Class<T> clazz) { + if (isCasted(pos)) { + return Optional.empty(); + } + return originalContext.getArgumentValue(pos, clazz); + } + + @Override + public String getName() { + return originalContext.getName(); + } + + private boolean isCasted(int pos) { + return !originalContext.getArgumentDataTypes().get(pos).equals(expectedArguments.get(pos)); + } + } + private TypeInferenceUtil() { // no instantiation } diff --git 
a/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/functions/InternalFunctionDefinitions.java b/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/functions/InternalFunctionDefinitions.java index ef2ee16..c9d89e2 100644 --- a/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/functions/InternalFunctionDefinitions.java +++ b/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/functions/InternalFunctionDefinitions.java @@ -18,6 +18,8 @@ package org.apache.flink.table.functions; +import org.apache.flink.table.types.inference.TypeStrategies; + import static org.apache.flink.table.functions.FunctionKind.SCALAR; /** @@ -29,6 +31,7 @@ public class InternalFunctionDefinitions { new BuiltInFunctionDefinition.Builder() .name("throwException") .kind(SCALAR) + .outputTypeStrategy(TypeStrategies.MISSING) .build(); } diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/expressions/rules/ResolveCallByArgumentsRule.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/expressions/rules/ResolveCallByArgumentsRule.java index 81a45ef..d308891 100644 --- a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/expressions/rules/ResolveCallByArgumentsRule.java +++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/expressions/rules/ResolveCallByArgumentsRule.java @@ -22,15 +22,25 @@ import org.apache.flink.annotation.Internal; import org.apache.flink.api.common.typeinfo.TypeInformation; import org.apache.flink.table.api.TableException; import org.apache.flink.table.api.ValidationException; +import org.apache.flink.table.catalog.ObjectIdentifier; import org.apache.flink.table.expressions.Expression; import org.apache.flink.table.expressions.InputTypeSpec; import org.apache.flink.table.expressions.PlannerExpression; import org.apache.flink.table.expressions.ResolvedExpression; import 
org.apache.flink.table.expressions.UnresolvedCallExpression; +import org.apache.flink.table.expressions.ValueLiteralExpression; +import org.apache.flink.table.functions.BuiltInFunctionDefinition; import org.apache.flink.table.functions.BuiltInFunctionDefinitions; +import org.apache.flink.table.functions.FunctionDefinition; +import org.apache.flink.table.types.DataType; +import org.apache.flink.table.types.inference.CallContext; +import org.apache.flink.table.types.inference.TypeInference; +import org.apache.flink.table.types.inference.TypeInferenceUtil; +import org.apache.flink.table.types.inference.TypeStrategies; import org.apache.flink.table.typeutils.TypeCoercion; import org.apache.flink.table.validate.ValidationFailure; import org.apache.flink.table.validate.ValidationResult; +import org.apache.flink.util.Preconditions; import java.util.List; import java.util.Optional; @@ -75,6 +85,71 @@ final class ResolveCallByArgumentsRule implements ResolverRule { }) .collect(Collectors.toList()); + if (unresolvedCall.getFunctionDefinition() instanceof BuiltInFunctionDefinition) { + final BuiltInFunctionDefinition definition = + (BuiltInFunctionDefinition) unresolvedCall.getFunctionDefinition(); + if (definition.getTypeInference().getOutputTypeStrategy() != TypeStrategies.MISSING) { + return runTypeInference( + unresolvedCall, + definition.getTypeInference(), + resolvedArgs); + } + } + return runLegacyTypeInference(unresolvedCall, resolvedArgs); + } + + private ResolvedExpression runTypeInference( + UnresolvedCallExpression unresolvedCall, + TypeInference inference, + List<ResolvedExpression> resolvedArgs) { + + final String name = unresolvedCall.getObjectIdentifier() + .map(ObjectIdentifier::toString) + .orElseGet(() -> unresolvedCall.getFunctionDefinition().toString()); + + final TypeInferenceUtil.Result inferenceResult = TypeInferenceUtil.runTypeInference( + inference, + new TableApiCallContext(name, unresolvedCall.getFunctionDefinition(), resolvedArgs)); + + final 
List<ResolvedExpression> adaptedArguments = adaptArguments(inferenceResult, resolvedArgs); + + return unresolvedCall.resolve(adaptedArguments, inferenceResult.getOutputDataType()); + } + + /** + * Adapts the arguments according to the properties of the {@link TypeInferenceUtil.Result}. + */ + private List<ResolvedExpression> adaptArguments( + TypeInferenceUtil.Result inferenceResult, + List<ResolvedExpression> resolvedArgs) { + + return IntStream.range(0, resolvedArgs.size()) + .mapToObj(pos -> { + final ResolvedExpression argument = resolvedArgs.get(pos); + final DataType argumentType = argument.getOutputDataType(); + final DataType expectedType = inferenceResult.getExpectedArgumentTypes().get(pos); + if (!argumentType.equals(expectedType)) { + return resolutionContext + .postResolutionFactory() + .cast(argument, expectedType); + } + return argument; + }) + .collect(Collectors.toList()); + } + + @Override + protected Expression defaultMethod(Expression expression) { + return expression; + } + + // ---------------------------------------------------------------------------------------- + // legacy code + // ---------------------------------------------------------------------------------------- + + private ResolvedExpression runLegacyTypeInference( + UnresolvedCallExpression unresolvedCall, + List<ResolvedExpression> resolvedArgs) { final PlannerExpression plannerCall = resolutionContext.bridge(unresolvedCall); if (plannerCall instanceof InputTypeSpec) { @@ -160,10 +235,72 @@ final class ResolveCallByArgumentsRule implements ResolverRule { expectedType)); } } + } + + // -------------------------------------------------------------------------------------------- + + private class TableApiCallContext implements CallContext { + + private final String name; + + private final FunctionDefinition definition; + + private final List<ResolvedExpression> resolvedArgs; + + public TableApiCallContext( + String name, + FunctionDefinition definition, + List<ResolvedExpression> 
resolvedArgs) { + this.name = name; + this.definition = definition; + this.resolvedArgs = resolvedArgs; + } @Override - protected Expression defaultMethod(Expression expression) { - return expression; + public List<DataType> getArgumentDataTypes() { + return resolvedArgs.stream() + .map(ResolvedExpression::getOutputDataType) + .collect(Collectors.toList()); + } + + @Override + public FunctionDefinition getFunctionDefinition() { + return definition; + } + + @Override + public boolean isArgumentLiteral(int pos) { + return getArgument(pos) instanceof ValueLiteralExpression; + } + + @Override + public boolean isArgumentNull(int pos) { + Preconditions.checkArgument(isArgumentLiteral(pos), "Argument at position %s is not a literal.", pos); + final ValueLiteralExpression literal = (ValueLiteralExpression) getArgument(pos); + return literal.isNull(); + } + + @Override + public <T> Optional<T> getArgumentValue(int pos, Class<T> clazz) { + Preconditions.checkArgument(isArgumentLiteral(pos), "Argument at position %s is not a literal.", pos); + final ValueLiteralExpression literal = (ValueLiteralExpression) getArgument(pos); + return literal.getValueAs(clazz); + } + + @Override + public String getName() { + return name; + } + + private ResolvedExpression getArgument(int pos) { + if (pos >= resolvedArgs.size()) { + throw new IndexOutOfBoundsException( + String.format( + "Not enough arguments to access literal at position %d for function '%s'.", + pos, + name)); + } + return resolvedArgs.get(pos); } } }