This is an automated email from the ASF dual-hosted git repository. twalthr pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/flink.git
commit 5bae5b5251c5247cbb194d525ba41cd8aeedb443 Author: Ingo Bürk <ingo.bu...@tngtech.com> AuthorDate: Thu Jun 10 14:47:03 2021 +0200 [hotfix][table-common] Move specific type strategies out of TypeStrategies --- .../functions/BuiltInFunctionDefinitions.java | 49 +-- .../table/types/inference/TypeStrategies.java | 356 -------------------- .../inference/strategies/ArrayTypeStrategy.java | 46 +++ .../strategies/CurrentWatermarkTypeStrategy.java | 53 +++ .../strategies/DecimalDivideTypeStrategy.java | 73 ++++ .../strategies/DecimalModTypeStrategy.java | 76 +++++ .../strategies/DecimalPlusTypeStrategy.java | 72 ++++ .../strategies/DecimalScale0TypeStrategy.java | 62 ++++ .../strategies/DecimalTimesTypeStrategy.java | 73 ++++ .../inference/strategies/GetTypeStrategy.java | 64 ++++ .../inference/strategies/IfNullTypeStrategy.java | 44 +++ .../inference/strategies/MapTypeStrategy.java | 47 +++ .../inference/strategies/RoundTypeStrategy.java | 79 +++++ .../inference/strategies/RowTypeStrategy.java | 45 +++ .../strategies/SourceWatermarkTypeStrategy.java | 45 +++ .../strategies/SpecificTypeStrategies.java | 79 +++++ .../types/inference/strategies/StrategyUtils.java | 12 + .../strategies/StringConcatTypeStrategy.java | 74 ++++ .../types/inference/MappingTypeStrategiesTest.java | 75 +++++ .../table/types/inference/TypeStrategiesTest.java | 371 ++------------------- .../types/inference/TypeStrategiesTestBase.java | 165 +++++++++ .../strategies/ArrayTypeStrategyTest.java | 39 +++ .../CurrentWatermarkTypeStrategyTest.java | 69 ++++ .../strategies/DecimalTypeStrategyTest.java | 50 +++ .../inference/strategies/GetTypeStrategyTest.java | 137 ++++++++ .../inference/strategies/MapTypeStrategyTest.java | 41 +++ .../inference/strategies/RowTypeStrategyTest.java | 43 +++ .../strategies/StringConcatTypeStrategyTest.java | 39 +++ 28 files changed, 1650 insertions(+), 728 deletions(-) diff --git a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/functions/BuiltInFunctionDefinitions.java b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/functions/BuiltInFunctionDefinitions.java index d4c76c6..ccafce4 100644 --- a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/functions/BuiltInFunctionDefinitions.java +++ b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/functions/BuiltInFunctionDefinitions.java @@ -25,6 +25,7 @@ import org.apache.flink.table.types.inference.ArgumentTypeStrategy; import org.apache.flink.table.types.inference.ConstantArgumentCount; import org.apache.flink.table.types.inference.InputTypeStrategies; import org.apache.flink.table.types.inference.TypeStrategies; +import org.apache.flink.table.types.inference.strategies.SpecificTypeStrategies; import org.apache.flink.table.types.logical.LogicalTypeFamily; import org.apache.flink.table.types.logical.LogicalTypeRoot; import org.apache.flink.table.types.logical.StructuredType.StructuredComparision; @@ -59,12 +60,6 @@ import static org.apache.flink.table.types.inference.InputTypeStrategies.sequenc import static org.apache.flink.table.types.inference.InputTypeStrategies.varyingSequence; import static org.apache.flink.table.types.inference.InputTypeStrategies.wildcardWithCount; import static org.apache.flink.table.types.inference.TypeStrategies.COMMON; -import static org.apache.flink.table.types.inference.TypeStrategies.DECIMAL_DIVIDE; -import static org.apache.flink.table.types.inference.TypeStrategies.DECIMAL_MOD; -import static org.apache.flink.table.types.inference.TypeStrategies.DECIMAL_PLUS; -import static org.apache.flink.table.types.inference.TypeStrategies.DECIMAL_SCALE0; -import static org.apache.flink.table.types.inference.TypeStrategies.DECIMAL_TIMES; -import static org.apache.flink.table.types.inference.TypeStrategies.STRING_CONCAT; import static org.apache.flink.table.types.inference.TypeStrategies.argument; import static org.apache.flink.table.types.inference.TypeStrategies.explicit; import static org.apache.flink.table.types.inference.TypeStrategies.first; @@ -108,7 +103,7 @@ public final class BuiltInFunctionDefinitions { new ArgumentTypeStrategy[] { COMMON_ARG_NULLABLE, COMMON_ARG_NULLABLE })) - .outputTypeStrategy(TypeStrategies.IF_NULL) + .outputTypeStrategy(SpecificTypeStrategies.IF_NULL) .runtimeClass("org.apache.flink.table.runtime.functions.scalar.IfNullFunction") .build(); @@ -117,7 +112,7 @@ public final class BuiltInFunctionDefinitions { .name("SOURCE_WATERMARK") .kind(SCALAR) .inputTypeStrategy(NO_ARGS) - .outputTypeStrategy(TypeStrategies.SOURCE_WATERMARK) + .outputTypeStrategy(SpecificTypeStrategies.SOURCE_WATERMARK) .runtimeClass( "org.apache.flink.table.runtime.functions.scalar.SourceWatermarkFunction") .build(); @@ -540,7 +535,7 @@ public final class BuiltInFunctionDefinitions { logical(LogicalTypeFamily.CHARACTER_STRING), logical(LogicalTypeRoot.INTEGER), logical(LogicalTypeRoot.INTEGER)))) - .outputTypeStrategy(nullable(STRING_CONCAT)) + .outputTypeStrategy(nullable(SpecificTypeStrategies.STRING_CONCAT)) .build(); public static final BuiltInFunctionDefinition CONCAT = @@ -555,7 +550,7 @@ public final class BuiltInFunctionDefinitions { varyingSequence( logical(LogicalTypeFamily.BINARY_STRING), logical(LogicalTypeFamily.BINARY_STRING)))) - .outputTypeStrategy(nullable(STRING_CONCAT)) + .outputTypeStrategy(nullable(SpecificTypeStrategies.STRING_CONCAT)) .build(); public static final BuiltInFunctionDefinition CONCAT_WS = @@ -716,7 +711,11 @@ public final class BuiltInFunctionDefinitions { logical(LogicalTypeFamily.CHARACTER_STRING), logical(LogicalTypeFamily.PREDEFINED)))) .outputTypeStrategy( - nullable(first(DECIMAL_PLUS, COMMON, explicit(DataTypes.STRING())))) + nullable( + first( + SpecificTypeStrategies.DECIMAL_PLUS, + COMMON, + explicit(DataTypes.STRING())))) .build(); /** Combines numeric subtraction and "datetime - interval" arithmetic. */ @@ -738,7 +737,8 @@ public final class BuiltInFunctionDefinitions { sequence( logical(LogicalTypeFamily.DATETIME), logical(LogicalTypeFamily.INTERVAL)))) - .outputTypeStrategy(nullable(first(DECIMAL_PLUS, COMMON))) + .outputTypeStrategy( + nullable(first(SpecificTypeStrategies.DECIMAL_PLUS, COMMON))) .build(); public static final BuiltInFunctionDefinition DIVIDE = @@ -756,7 +756,7 @@ public final class BuiltInFunctionDefinitions { .outputTypeStrategy( nullable( first( - DECIMAL_DIVIDE, + SpecificTypeStrategies.DECIMAL_DIVIDE, matchFamily(0, LogicalTypeFamily.INTERVAL), COMMON))) .build(); @@ -779,7 +779,7 @@ public final class BuiltInFunctionDefinitions { .outputTypeStrategy( nullable( first( - DECIMAL_TIMES, + SpecificTypeStrategies.DECIMAL_TIMES, matchFamily(0, LogicalTypeFamily.INTERVAL), COMMON))) .build(); @@ -814,7 +814,8 @@ public final class BuiltInFunctionDefinitions { sequence( logical(LogicalTypeFamily.DATETIME), logical(LogicalTypeRoot.SYMBOL)))) - .outputTypeStrategy(nullable(first(DECIMAL_SCALE0, argument(0)))) + .outputTypeStrategy( + nullable(first(SpecificTypeStrategies.DECIMAL_SCALE_0, argument(0)))) .build(); public static final BuiltInFunctionDefinition CEIL = @@ -828,7 +829,8 @@ public final class BuiltInFunctionDefinitions { sequence( logical(LogicalTypeFamily.DATETIME), logical(LogicalTypeRoot.SYMBOL)))) - .outputTypeStrategy(nullable(first(DECIMAL_SCALE0, argument(0)))) + .outputTypeStrategy( + nullable(first(SpecificTypeStrategies.DECIMAL_SCALE_0, argument(0)))) .build(); public static final BuiltInFunctionDefinition LOG10 = @@ -887,7 +889,8 @@ public final class BuiltInFunctionDefinitions { sequence( logical(LogicalTypeFamily.EXACT_NUMERIC), logical(LogicalTypeFamily.EXACT_NUMERIC))) - .outputTypeStrategy(nullable(first(DECIMAL_MOD, argument(1)))) + .outputTypeStrategy( + nullable(first(SpecificTypeStrategies.DECIMAL_MOD, argument(1)))) .build(); public static final BuiltInFunctionDefinition SQRT = @@ -1035,7 +1038,7 @@ public final class BuiltInFunctionDefinitions { sequence( logical(LogicalTypeFamily.NUMERIC), logical(LogicalTypeRoot.INTEGER)))) - .outputTypeStrategy(nullable(TypeStrategies.ROUND)) + .outputTypeStrategy(nullable(SpecificTypeStrategies.ROUND)) .build(); public static final BuiltInFunctionDefinition PI = @@ -1212,7 +1215,7 @@ public final class BuiltInFunctionDefinitions { .name("array") .kind(SCALAR) .inputTypeStrategy(InputTypeStrategies.SPECIFIC_FOR_ARRAY) - .outputTypeStrategy(TypeStrategies.ARRAY) + .outputTypeStrategy(SpecificTypeStrategies.ARRAY) .build(); public static final BuiltInFunctionDefinition ARRAY_ELEMENT = @@ -1227,7 +1230,7 @@ public final class BuiltInFunctionDefinitions { .name("map") .kind(SCALAR) .inputTypeStrategy(InputTypeStrategies.SPECIFIC_FOR_MAP) - .outputTypeStrategy(TypeStrategies.MAP) + .outputTypeStrategy(SpecificTypeStrategies.MAP) .build(); public static final BuiltInFunctionDefinition ROW = @@ -1236,7 +1239,7 @@ public final class BuiltInFunctionDefinitions { .kind(SCALAR) .inputTypeStrategy( InputTypeStrategies.wildcardWithCount(ConstantArgumentCount.from(1))) - .outputTypeStrategy(TypeStrategies.ROW) + .outputTypeStrategy(SpecificTypeStrategies.ROW) .build(); // -------------------------------------------------------------------------------------------- @@ -1267,7 +1270,7 @@ public final class BuiltInFunctionDefinitions { or( logical(LogicalTypeRoot.INTEGER), logical(LogicalTypeFamily.CHARACTER_STRING))))) - .outputTypeStrategy(TypeStrategies.GET) + .outputTypeStrategy(SpecificTypeStrategies.GET) .build(); // -------------------------------------------------------------------------------------------- @@ -1392,7 +1395,7 @@ public final class BuiltInFunctionDefinitions { .name("CURRENT_WATERMARK") .kind(SCALAR) .inputTypeStrategy(InputTypeStrategies.SPECIFIC_FOR_CURRENT_WATERMARK) - .outputTypeStrategy(TypeStrategies.CURRENT_WATERMARK) + .outputTypeStrategy(SpecificTypeStrategies.CURRENT_WATERMARK) .notDeterministic() .runtimeProvided() .build(); diff --git a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/inference/TypeStrategies.java b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/inference/TypeStrategies.java index 34ffc68..8d007ca 100644 --- a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/inference/TypeStrategies.java +++ b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/inference/TypeStrategies.java @@ -19,8 +19,6 @@ package org.apache.flink.table.types.inference; import org.apache.flink.annotation.Internal; -import org.apache.flink.table.api.DataTypes; -import org.apache.flink.table.functions.BuiltInFunctionDefinitions; import org.apache.flink.table.types.DataType; import org.apache.flink.table.types.inference.strategies.CommonTypeStrategy; import org.apache.flink.table.types.inference.strategies.ExplicitTypeStrategy; @@ -31,32 +29,15 @@ import org.apache.flink.table.types.inference.strategies.MissingTypeStrategy; import org.apache.flink.table.types.inference.strategies.NullableTypeStrategy; import org.apache.flink.table.types.inference.strategies.UseArgumentTypeStrategy; import org.apache.flink.table.types.inference.strategies.VaryingStringTypeStrategy; -import org.apache.flink.table.types.logical.BinaryType; -import org.apache.flink.table.types.logical.CharType; -import org.apache.flink.table.types.logical.DecimalType; -import org.apache.flink.table.types.logical.LegacyTypeInformationType; import org.apache.flink.table.types.logical.LogicalType; import org.apache.flink.table.types.logical.LogicalTypeFamily; import org.apache.flink.table.types.logical.LogicalTypeRoot; -import org.apache.flink.table.types.logical.utils.LogicalTypeMerging; -import org.apache.flink.table.types.utils.DataTypeUtils; -import org.apache.flink.table.types.utils.TypeConversions; -import java.math.BigDecimal; import java.util.Arrays; -import java.util.List; import java.util.Map; import java.util.Optional; import java.util.function.Function; -import java.util.stream.IntStream; -import static org.apache.flink.table.types.logical.utils.LogicalTypeChecks.getLength; -import static org.apache.flink.table.types.logical.utils.LogicalTypeChecks.getPrecision; -import static org.apache.flink.table.types.logical.utils.LogicalTypeChecks.getScale; -import static org.apache.flink.table.types.logical.utils.LogicalTypeChecks.hasFamily; -import static org.apache.flink.table.types.logical.utils.LogicalTypeChecks.hasRoot; -import static org.apache.flink.table.types.logical.utils.LogicalTypeChecks.hasScale; -import static org.apache.flink.table.types.logical.utils.LogicalTypeMerging.findCommonType; import static org.apache.flink.table.types.utils.TypeConversions.fromLogicalToDataType; /** @@ -126,333 +107,6 @@ public final class TypeStrategies { return new VaryingStringTypeStrategy(initialStrategy); } - // -------------------------------------------------------------------------------------------- - // Specific type strategies - // -------------------------------------------------------------------------------------------- - - /** - * Type strategy that returns a {@link DataTypes#ROW()} with fields types equal to input types. - */ - public static final TypeStrategy ROW = - callContext -> { - List<DataType> argumentDataTypes = callContext.getArgumentDataTypes(); - DataTypes.Field[] fields = - IntStream.range(0, argumentDataTypes.size()) - .mapToObj( - idx -> - DataTypes.FIELD( - "f" + idx, argumentDataTypes.get(idx))) - .toArray(DataTypes.Field[]::new); - - return Optional.of(DataTypes.ROW(fields).notNull()); - }; - - /** - * Type strategy that returns a {@link DataTypes#MAP(DataType, DataType)} with a key type equal - * to type of the first argument and a value type equal to the type of second argument. - */ - public static final TypeStrategy MAP = - callContext -> { - List<DataType> argumentDataTypes = callContext.getArgumentDataTypes(); - if (argumentDataTypes.size() < 2) { - return Optional.empty(); - } - return Optional.of( - DataTypes.MAP(argumentDataTypes.get(0), argumentDataTypes.get(1)) - .notNull()); - }; - - /** - * Type strategy that returns a {@link DataTypes#ARRAY(DataType)} with element type equal to the - * type of the first argument. - */ - public static final TypeStrategy ARRAY = - callContext -> { - List<DataType> argumentDataTypes = callContext.getArgumentDataTypes(); - if (argumentDataTypes.size() < 1) { - return Optional.empty(); - } - return Optional.of(DataTypes.ARRAY(argumentDataTypes.get(0)).notNull()); - }; - - /** - * Type strategy that returns the sum of an exact numeric addition that includes at least one - * decimal. - */ - public static final TypeStrategy DECIMAL_PLUS = - callContext -> { - final List<DataType> argumentDataTypes = callContext.getArgumentDataTypes(); - final LogicalType addend1 = argumentDataTypes.get(0).getLogicalType(); - final LogicalType addend2 = argumentDataTypes.get(1).getLogicalType(); - // a hack to make legacy types possible until we drop them - if (addend1 instanceof LegacyTypeInformationType) { - return Optional.of(argumentDataTypes.get(0)); - } - if (addend2 instanceof LegacyTypeInformationType) { - return Optional.of(argumentDataTypes.get(1)); - } - if (!isDecimalComputation(addend1, addend2)) { - return Optional.empty(); - } - final DecimalType decimalType = - LogicalTypeMerging.findAdditionDecimalType( - getPrecision(addend1), - getScale(addend1), - getPrecision(addend2), - getScale(addend2)); - return Optional.of(fromLogicalToDataType(decimalType)); - }; - - /** - * Type strategy that returns the quotient of an exact numeric division that includes at least - * one decimal. - */ - public static final TypeStrategy DECIMAL_DIVIDE = - callContext -> { - final List<DataType> argumentDataTypes = callContext.getArgumentDataTypes(); - final LogicalType dividend = argumentDataTypes.get(0).getLogicalType(); - final LogicalType divisor = argumentDataTypes.get(1).getLogicalType(); - // a hack to make legacy types possible until we drop them - if (dividend instanceof LegacyTypeInformationType) { - return Optional.of(argumentDataTypes.get(0)); - } - if (divisor instanceof LegacyTypeInformationType) { - return Optional.of(argumentDataTypes.get(1)); - } - if (!isDecimalComputation(dividend, divisor)) { - return Optional.empty(); - } - final DecimalType decimalType = - LogicalTypeMerging.findDivisionDecimalType( - getPrecision(dividend), - getScale(dividend), - getPrecision(divisor), - getScale(divisor)); - return Optional.of(fromLogicalToDataType(decimalType)); - }; - - /** - * Type strategy that returns the product of an exact numeric multiplication that includes at - * least one decimal. - */ - public static final TypeStrategy DECIMAL_TIMES = - callContext -> { - final List<DataType> argumentDataTypes = callContext.getArgumentDataTypes(); - final LogicalType factor1 = argumentDataTypes.get(0).getLogicalType(); - final LogicalType factor2 = argumentDataTypes.get(1).getLogicalType(); - // a hack to make legacy types possible until we drop them - if (factor1 instanceof LegacyTypeInformationType) { - return Optional.of(argumentDataTypes.get(0)); - } - if (factor2 instanceof LegacyTypeInformationType) { - return Optional.of(argumentDataTypes.get(1)); - } - if (!isDecimalComputation(factor1, factor2)) { - return Optional.empty(); - } - final DecimalType decimalType = - LogicalTypeMerging.findMultiplicationDecimalType( - getPrecision(factor1), - getScale(factor1), - getPrecision(factor2), - getScale(factor2)); - return Optional.of(fromLogicalToDataType(decimalType)); - }; - - /** - * Type strategy that returns the modulo of an exact numeric division that includes at least one - * decimal. - */ - public static final TypeStrategy DECIMAL_MOD = - callContext -> { - final List<DataType> argumentDataTypes = callContext.getArgumentDataTypes(); - final LogicalType dividend = argumentDataTypes.get(0).getLogicalType(); - final LogicalType divisor = argumentDataTypes.get(1).getLogicalType(); - // a hack to make legacy types possible until we drop them - if (dividend instanceof LegacyTypeInformationType) { - return Optional.of(argumentDataTypes.get(0)); - } - if (divisor instanceof LegacyTypeInformationType) { - return Optional.of(argumentDataTypes.get(1)); - } - if (!isDecimalComputation(dividend, divisor)) { - return Optional.empty(); - } - final int dividendScale = getScale(dividend); - final int divisorScale = getScale(divisor); - if (dividendScale == 0 && divisorScale == 0) { - return Optional.of(argumentDataTypes.get(1)); - } - final DecimalType decimalType = - LogicalTypeMerging.findModuloDecimalType( - getPrecision(dividend), - dividendScale, - getPrecision(divisor), - divisorScale); - return Optional.of(fromLogicalToDataType(decimalType)); - }; - - /** Strategy that returns a decimal type but with a scale of 0. */ - public static final TypeStrategy DECIMAL_SCALE0 = - callContext -> { - final DataType argumentDataType = callContext.getArgumentDataTypes().get(0); - final LogicalType argumentType = argumentDataType.getLogicalType(); - // a hack to make legacy types possible until we drop them - if (argumentType instanceof LegacyTypeInformationType) { - return Optional.of(argumentDataType); - } - if (hasRoot(argumentType, LogicalTypeRoot.DECIMAL)) { - if (hasScale(argumentType, 0)) { - return Optional.of(argumentDataType); - } - final LogicalType inferredType = - new DecimalType( - argumentType.isNullable(), getPrecision(argumentType), 0); - return Optional.of(fromLogicalToDataType(inferredType)); - } - return Optional.empty(); - }; - - /** Type strategy that returns the result of a rounding operation. */ - public static final TypeStrategy ROUND = - callContext -> { - final List<DataType> argumentDataTypes = callContext.getArgumentDataTypes(); - final DataType argumentDataType = callContext.getArgumentDataTypes().get(0); - final LogicalType argumentType = argumentDataType.getLogicalType(); - // a hack to make legacy types possible until we drop them - if (argumentType instanceof LegacyTypeInformationType) { - return Optional.of(argumentDataType); - } - if (!hasRoot(argumentType, LogicalTypeRoot.DECIMAL)) { - return Optional.of(argumentDataType); - } - final BigDecimal roundLength; - if (argumentDataTypes.size() == 2) { - if (!callContext.isArgumentLiteral(1) || callContext.isArgumentNull(1)) { - return Optional.of(argumentDataType); - } - roundLength = - callContext - .getArgumentValue(1, BigDecimal.class) - .orElseThrow(AssertionError::new); - } else { - roundLength = BigDecimal.ZERO; - } - final LogicalType inferredType = - LogicalTypeMerging.findRoundDecimalType( - getPrecision(argumentType), - getScale(argumentType), - roundLength.intValueExact()); - return Optional.of(fromLogicalToDataType(inferredType)); - }; - - /** - * Type strategy that returns the type of a string concatenation. It assumes that the first two - * arguments are of the same family of either {@link LogicalTypeFamily#BINARY_STRING} or {@link - * LogicalTypeFamily#CHARACTER_STRING}. - */ - public static final TypeStrategy STRING_CONCAT = - callContext -> { - final List<DataType> argumentDataTypes = callContext.getArgumentDataTypes(); - final LogicalType type1 = argumentDataTypes.get(0).getLogicalType(); - final LogicalType type2 = argumentDataTypes.get(1).getLogicalType(); - int length = getLength(type1) + getLength(type2); - // handle overflow - if (length < 0) { - length = CharType.MAX_LENGTH; - } - final LogicalType minimumType; - if (hasFamily(type1, LogicalTypeFamily.CHARACTER_STRING) - || hasFamily(type2, LogicalTypeFamily.CHARACTER_STRING)) { - minimumType = new CharType(false, length); - } else if (hasFamily(type1, LogicalTypeFamily.BINARY_STRING) - || hasFamily(type2, LogicalTypeFamily.BINARY_STRING)) { - minimumType = new BinaryType(false, length); - } else { - return Optional.empty(); - } - // deal with nullability handling and varying semantics - return findCommonType(Arrays.asList(type1, type2, minimumType)) - .map(TypeConversions::fromLogicalToDataType); - }; - - /** - * Type strategy that returns a type of a field nested inside a composite type that is described - * by the second argument. The second argument must be a literal that describes either the - * nested field name or index. - */ - public static final TypeStrategy GET = - callContext -> { - List<DataType> argumentDataTypes = callContext.getArgumentDataTypes(); - DataType rowDataType = argumentDataTypes.get(0); - - Optional<DataType> result = Optional.empty(); - - Optional<String> fieldName = callContext.getArgumentValue(1, String.class); - if (fieldName.isPresent()) { - result = DataTypeUtils.getField(rowDataType, fieldName.get()); - } - - Optional<Integer> fieldIndex = callContext.getArgumentValue(1, Integer.class); - if (fieldIndex.isPresent()) { - result = DataTypeUtils.getField(rowDataType, fieldIndex.get()); - } - - return result.map( - type -> { - if (rowDataType.getLogicalType().isNullable()) { - return type.nullable(); - } else { - return type; - } - }); - }; - - /** Type strategy specific for avoiding nulls. */ - public static final TypeStrategy IF_NULL = - callContext -> { - final List<DataType> argumentDataTypes = callContext.getArgumentDataTypes(); - final DataType inputDataType = argumentDataTypes.get(0); - final DataType nullReplacementDataType = argumentDataTypes.get(1); - if (!inputDataType.getLogicalType().isNullable()) { - return Optional.of(inputDataType); - } - return Optional.of(nullReplacementDataType); - }; - - /** Type strategy specific for source watermarks that depend on the output type. */ - public static final TypeStrategy SOURCE_WATERMARK = - callContext -> { - final DataType timestampDataType = - callContext - .getOutputDataType() - .filter( - dt -> - hasFamily( - dt.getLogicalType(), - LogicalTypeFamily.TIMESTAMP)) - .orElse(DataTypes.TIMESTAMP_LTZ(3)); - return Optional.of(timestampDataType); - }; - - /** - * Type strategy for {@link BuiltInFunctionDefinitions#CURRENT_WATERMARK} which mirrors the type - * of the passed rowtime column, but removes the rowtime kind and enforces the correct precision - * for watermarks. - */ - public static final TypeStrategy CURRENT_WATERMARK = - callContext -> { - final LogicalType inputType = - callContext.getArgumentDataTypes().get(0).getLogicalType(); - if (hasRoot(inputType, LogicalTypeRoot.TIMESTAMP_WITHOUT_TIME_ZONE)) { - return Optional.of(DataTypes.TIMESTAMP(3)); - } else if (hasRoot(inputType, LogicalTypeRoot.TIMESTAMP_WITH_LOCAL_TIME_ZONE)) { - return Optional.of(DataTypes.TIMESTAMP_LTZ(3)); - } - - return Optional.empty(); - }; - /** * Type strategy specific for aggregations that partially produce different nullability * depending whether the result is grouped or not. @@ -477,16 +131,6 @@ public final class TypeStrategies { // -------------------------------------------------------------------------------------------- @SuppressWarnings("BooleanMethodIsAlwaysInverted") - private static boolean isDecimalComputation(LogicalType type1, LogicalType type2) { - // both must be exact numeric - if (!hasFamily(type1, LogicalTypeFamily.EXACT_NUMERIC) - || !hasFamily(type2, LogicalTypeFamily.EXACT_NUMERIC)) { - return false; - } - // one decimal must be present - return hasRoot(type1, LogicalTypeRoot.DECIMAL) || hasRoot(type2, LogicalTypeRoot.DECIMAL); - } - private TypeStrategies() { // no instantiation } diff --git a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/inference/strategies/ArrayTypeStrategy.java b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/inference/strategies/ArrayTypeStrategy.java new file mode 100644 index 0000000..95f9035 --- /dev/null +++ b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/inference/strategies/ArrayTypeStrategy.java @@ -0,0 +1,46 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.types.inference.strategies; + +import org.apache.flink.annotation.Internal; +import org.apache.flink.table.api.DataTypes; +import org.apache.flink.table.types.DataType; +import org.apache.flink.table.types.inference.CallContext; +import org.apache.flink.table.types.inference.TypeStrategy; + +import java.util.List; +import java.util.Optional; + +/** + * Type strategy that returns a {@link DataTypes#ARRAY(DataType)} with element type equal to the + * type of the first argument. + */ +@Internal +class ArrayTypeStrategy implements TypeStrategy { + + @Override + public Optional<DataType> inferType(CallContext callContext) { + List<DataType> argumentDataTypes = callContext.getArgumentDataTypes(); + if (argumentDataTypes.size() < 1) { + return Optional.empty(); + } + + return Optional.of(DataTypes.ARRAY(argumentDataTypes.get(0)).notNull()); + } +} diff --git a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/inference/strategies/CurrentWatermarkTypeStrategy.java b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/inference/strategies/CurrentWatermarkTypeStrategy.java new file mode 100644 index 0000000..11ab435 --- /dev/null +++ b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/inference/strategies/CurrentWatermarkTypeStrategy.java @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.types.inference.strategies; + +import org.apache.flink.annotation.Internal; +import org.apache.flink.table.api.DataTypes; +import org.apache.flink.table.functions.BuiltInFunctionDefinitions; +import org.apache.flink.table.types.DataType; +import org.apache.flink.table.types.inference.CallContext; +import org.apache.flink.table.types.inference.TypeStrategy; +import org.apache.flink.table.types.logical.LogicalType; +import org.apache.flink.table.types.logical.LogicalTypeRoot; + +import java.util.Optional; + +import static org.apache.flink.table.types.logical.utils.LogicalTypeChecks.hasRoot; + +/** + * Type strategy for {@link BuiltInFunctionDefinitions#CURRENT_WATERMARK} which mirrors the type of + * the passed rowtime column, but removes the rowtime kind and enforces the correct precision for + * watermarks. + */ +@Internal +class CurrentWatermarkTypeStrategy implements TypeStrategy { + + @Override + public Optional<DataType> inferType(CallContext callContext) { + final LogicalType inputType = callContext.getArgumentDataTypes().get(0).getLogicalType(); + if (hasRoot(inputType, LogicalTypeRoot.TIMESTAMP_WITHOUT_TIME_ZONE)) { + return Optional.of(DataTypes.TIMESTAMP(3)); + } else if (hasRoot(inputType, LogicalTypeRoot.TIMESTAMP_WITH_LOCAL_TIME_ZONE)) { + return Optional.of(DataTypes.TIMESTAMP_LTZ(3)); + } + + return Optional.empty(); + } +} diff --git a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/inference/strategies/DecimalDivideTypeStrategy.java b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/inference/strategies/DecimalDivideTypeStrategy.java new file mode 100644 index 0000000..cedd357 --- /dev/null +++ b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/inference/strategies/DecimalDivideTypeStrategy.java @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.types.inference.strategies; + +import org.apache.flink.annotation.Internal; +import org.apache.flink.table.types.DataType; +import org.apache.flink.table.types.inference.CallContext; +import org.apache.flink.table.types.inference.TypeStrategy; +import org.apache.flink.table.types.logical.DecimalType; +import org.apache.flink.table.types.logical.LegacyTypeInformationType; +import org.apache.flink.table.types.logical.LogicalType; +import org.apache.flink.table.types.logical.utils.LogicalTypeMerging; + +import java.util.List; +import java.util.Optional; + +import static org.apache.flink.table.types.inference.strategies.StrategyUtils.isDecimalComputation; +import static org.apache.flink.table.types.logical.utils.LogicalTypeChecks.getPrecision; +import static org.apache.flink.table.types.logical.utils.LogicalTypeChecks.getScale; +import static org.apache.flink.table.types.utils.TypeConversions.fromLogicalToDataType; + +/** + * Type strategy that returns the quotient of an exact numeric division that includes at least one + * decimal. + */ +@Internal +class DecimalDivideTypeStrategy implements TypeStrategy { + + @Override + public Optional<DataType> inferType(CallContext callContext) { + final List<DataType> argumentDataTypes = callContext.getArgumentDataTypes(); + final LogicalType dividend = argumentDataTypes.get(0).getLogicalType(); + final LogicalType divisor = argumentDataTypes.get(1).getLogicalType(); + + // a hack to make legacy types possible until we drop them + if (dividend instanceof LegacyTypeInformationType) { + return Optional.of(argumentDataTypes.get(0)); + } + + if (divisor instanceof LegacyTypeInformationType) { + return Optional.of(argumentDataTypes.get(1)); + } + + if (!isDecimalComputation(dividend, divisor)) { + return Optional.empty(); + } + + final DecimalType decimalType = + LogicalTypeMerging.findDivisionDecimalType( + getPrecision(dividend), + getScale(dividend), + getPrecision(divisor), + getScale(divisor)); + + return Optional.of(fromLogicalToDataType(decimalType)); + } +} diff --git a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/inference/strategies/DecimalModTypeStrategy.java b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/inference/strategies/DecimalModTypeStrategy.java new file mode 100644 index 0000000..7d1d096 --- /dev/null +++ b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/inference/strategies/DecimalModTypeStrategy.java @@ -0,0 +1,76 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.types.inference.strategies; + +import org.apache.flink.annotation.Internal; +import org.apache.flink.table.types.DataType; +import org.apache.flink.table.types.inference.CallContext; +import org.apache.flink.table.types.inference.TypeStrategy; +import org.apache.flink.table.types.logical.DecimalType; +import org.apache.flink.table.types.logical.LegacyTypeInformationType; +import org.apache.flink.table.types.logical.LogicalType; +import org.apache.flink.table.types.logical.utils.LogicalTypeMerging; + +import java.util.List; +import java.util.Optional; + +import static org.apache.flink.table.types.inference.strategies.StrategyUtils.isDecimalComputation; +import static org.apache.flink.table.types.logical.utils.LogicalTypeChecks.getPrecision; +import static org.apache.flink.table.types.logical.utils.LogicalTypeChecks.getScale; +import static org.apache.flink.table.types.utils.TypeConversions.fromLogicalToDataType; + +/** + * Type strategy that returns the modulo of an exact numeric division that includes at least one + * decimal. + */ +@Internal +class DecimalModTypeStrategy implements TypeStrategy { + + @Override + public Optional<DataType> inferType(CallContext callContext) { + final List<DataType> argumentDataTypes = callContext.getArgumentDataTypes(); + final LogicalType dividend = argumentDataTypes.get(0).getLogicalType(); + final LogicalType divisor = argumentDataTypes.get(1).getLogicalType(); + + // a hack to make legacy types possible until we drop them + if (dividend instanceof LegacyTypeInformationType) { + return Optional.of(argumentDataTypes.get(0)); + } + + if (divisor instanceof LegacyTypeInformationType) { + return Optional.of(argumentDataTypes.get(1)); + } + + if (!isDecimalComputation(dividend, divisor)) { + return Optional.empty(); + } + + final int dividendScale = getScale(dividend); + final int divisorScale = getScale(divisor); + if (dividendScale == 0 && divisorScale == 0) { + return Optional.of(argumentDataTypes.get(1)); + } + + final DecimalType decimalType = + LogicalTypeMerging.findModuloDecimalType( + getPrecision(dividend), dividendScale, getPrecision(divisor), divisorScale); + + return Optional.of(fromLogicalToDataType(decimalType)); + } +} diff --git a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/inference/strategies/DecimalPlusTypeStrategy.java b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/inference/strategies/DecimalPlusTypeStrategy.java new file mode 100644 index 0000000..bab6e27 --- /dev/null +++ b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/inference/strategies/DecimalPlusTypeStrategy.java @@ -0,0 +1,72 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.types.inference.strategies; + +import org.apache.flink.annotation.Internal; +import org.apache.flink.table.types.DataType; +import org.apache.flink.table.types.inference.CallContext; +import org.apache.flink.table.types.inference.TypeStrategy; +import org.apache.flink.table.types.logical.DecimalType; +import org.apache.flink.table.types.logical.LegacyTypeInformationType; +import org.apache.flink.table.types.logical.LogicalType; +import org.apache.flink.table.types.logical.utils.LogicalTypeMerging; + +import java.util.List; +import java.util.Optional; + +import static org.apache.flink.table.types.inference.strategies.StrategyUtils.isDecimalComputation; +import static org.apache.flink.table.types.logical.utils.LogicalTypeChecks.getPrecision; +import static org.apache.flink.table.types.logical.utils.LogicalTypeChecks.getScale; +import static org.apache.flink.table.types.utils.TypeConversions.fromLogicalToDataType; + +/** + * Type strategy that returns the sum of an exact numeric addition that includes at least one + * decimal. + */ +@Internal +class DecimalPlusTypeStrategy implements TypeStrategy { + + @Override + public Optional<DataType> inferType(CallContext callContext) { + final List<DataType> argumentDataTypes = callContext.getArgumentDataTypes(); + final LogicalType addend1 = argumentDataTypes.get(0).getLogicalType(); + final LogicalType addend2 = argumentDataTypes.get(1).getLogicalType(); + + // a hack to make legacy types possible until we drop them + if (addend1 instanceof LegacyTypeInformationType) { + return Optional.of(argumentDataTypes.get(0)); + } + + if (addend2 instanceof LegacyTypeInformationType) { + return Optional.of(argumentDataTypes.get(1)); + } + + if (!isDecimalComputation(addend1, addend2)) { + return Optional.empty(); + } + final DecimalType decimalType = + LogicalTypeMerging.findAdditionDecimalType( + getPrecision(addend1), + getScale(addend1), + getPrecision(addend2), + getScale(addend2)); + + return Optional.of(fromLogicalToDataType(decimalType)); + } +} diff --git a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/inference/strategies/DecimalScale0TypeStrategy.java b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/inference/strategies/DecimalScale0TypeStrategy.java new file mode 100644 index 0000000..00d4281 --- /dev/null +++ b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/inference/strategies/DecimalScale0TypeStrategy.java @@ -0,0 +1,62 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.types.inference.strategies; + +import org.apache.flink.annotation.Internal; +import org.apache.flink.table.types.DataType; +import org.apache.flink.table.types.inference.CallContext; +import org.apache.flink.table.types.inference.TypeStrategy; +import org.apache.flink.table.types.logical.DecimalType; +import org.apache.flink.table.types.logical.LegacyTypeInformationType; +import org.apache.flink.table.types.logical.LogicalType; +import org.apache.flink.table.types.logical.LogicalTypeRoot; + +import java.util.Optional; + +import static org.apache.flink.table.types.logical.utils.LogicalTypeChecks.getPrecision; +import static org.apache.flink.table.types.logical.utils.LogicalTypeChecks.hasRoot; +import static org.apache.flink.table.types.logical.utils.LogicalTypeChecks.hasScale; +import static org.apache.flink.table.types.utils.TypeConversions.fromLogicalToDataType; + +/** Strategy that returns a decimal type but with a scale of 0. */ +@Internal +class DecimalScale0TypeStrategy implements TypeStrategy { + + @Override + public Optional<DataType> inferType(CallContext callContext) { + final DataType argumentDataType = callContext.getArgumentDataTypes().get(0); + final LogicalType argumentType = argumentDataType.getLogicalType(); + + // a hack to make legacy types possible until we drop them + if (argumentType instanceof LegacyTypeInformationType) { + return Optional.of(argumentDataType); + } + + if (hasRoot(argumentType, LogicalTypeRoot.DECIMAL)) { + if (hasScale(argumentType, 0)) { + return Optional.of(argumentDataType); + } + final LogicalType inferredType = + new DecimalType(argumentType.isNullable(), getPrecision(argumentType), 0); + return Optional.of(fromLogicalToDataType(inferredType)); + } + + return Optional.empty(); + } +} diff --git a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/inference/strategies/DecimalTimesTypeStrategy.java b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/inference/strategies/DecimalTimesTypeStrategy.java new file mode 100644 index 0000000..d74413e --- /dev/null +++ b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/inference/strategies/DecimalTimesTypeStrategy.java @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.types.inference.strategies; + +import org.apache.flink.annotation.Internal; +import org.apache.flink.table.types.DataType; +import org.apache.flink.table.types.inference.CallContext; +import org.apache.flink.table.types.inference.TypeStrategy; +import org.apache.flink.table.types.logical.DecimalType; +import org.apache.flink.table.types.logical.LegacyTypeInformationType; +import org.apache.flink.table.types.logical.LogicalType; +import org.apache.flink.table.types.logical.utils.LogicalTypeMerging; + +import java.util.List; +import java.util.Optional; + +import static org.apache.flink.table.types.inference.strategies.StrategyUtils.isDecimalComputation; +import static org.apache.flink.table.types.logical.utils.LogicalTypeChecks.getPrecision; +import static org.apache.flink.table.types.logical.utils.LogicalTypeChecks.getScale; +import static org.apache.flink.table.types.utils.TypeConversions.fromLogicalToDataType; + +/** + * Type strategy that returns the product of an exact numeric multiplication that includes at least + * one decimal. + */ +@Internal +class DecimalTimesTypeStrategy implements TypeStrategy { + + @Override + public Optional<DataType> inferType(CallContext callContext) { + final List<DataType> argumentDataTypes = callContext.getArgumentDataTypes(); + final LogicalType factor1 = argumentDataTypes.get(0).getLogicalType(); + final LogicalType factor2 = argumentDataTypes.get(1).getLogicalType(); + + // a hack to make legacy types possible until we drop them + if (factor1 instanceof LegacyTypeInformationType) { + return Optional.of(argumentDataTypes.get(0)); + } + + if (factor2 instanceof LegacyTypeInformationType) { + return Optional.of(argumentDataTypes.get(1)); + } + + if (!isDecimalComputation(factor1, factor2)) { + return Optional.empty(); + } + + final DecimalType decimalType = + LogicalTypeMerging.findMultiplicationDecimalType( + getPrecision(factor1), + getScale(factor1), + getPrecision(factor2), + getScale(factor2)); + + return Optional.of(fromLogicalToDataType(decimalType)); + } +} diff --git a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/inference/strategies/GetTypeStrategy.java b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/inference/strategies/GetTypeStrategy.java new file mode 100644 index 0000000..1b007f5 --- /dev/null +++ b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/inference/strategies/GetTypeStrategy.java @@ -0,0 +1,64 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.types.inference.strategies; + +import org.apache.flink.annotation.Internal; +import org.apache.flink.table.types.DataType; +import org.apache.flink.table.types.inference.CallContext; +import org.apache.flink.table.types.inference.TypeStrategy; +import org.apache.flink.table.types.utils.DataTypeUtils; + +import java.util.List; +import java.util.Optional; + +/** + * Type strategy that returns a type of a field nested inside a composite type that is described by + * the second argument. The second argument must be a literal that describes either the nested field + * name or index. + */ +@Internal +class GetTypeStrategy implements TypeStrategy { + + @Override + public Optional<DataType> inferType(CallContext callContext) { + List<DataType> argumentDataTypes = callContext.getArgumentDataTypes(); + DataType rowDataType = argumentDataTypes.get(0); + + Optional<DataType> result = Optional.empty(); + + Optional<String> fieldName = callContext.getArgumentValue(1, String.class); + if (fieldName.isPresent()) { + result = DataTypeUtils.getField(rowDataType, fieldName.get()); + } + + Optional<Integer> fieldIndex = callContext.getArgumentValue(1, Integer.class); + if (fieldIndex.isPresent()) { + result = DataTypeUtils.getField(rowDataType, fieldIndex.get()); + } + + return result.map( + type -> { + if (rowDataType.getLogicalType().isNullable()) { + return type.nullable(); + } else { + return type; + } + }); + } +} diff --git a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/inference/strategies/IfNullTypeStrategy.java b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/inference/strategies/IfNullTypeStrategy.java new file mode 100644 index 0000000..f3f0956 --- /dev/null +++ b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/inference/strategies/IfNullTypeStrategy.java @@ -0,0 +1,44 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.types.inference.strategies; + +import org.apache.flink.annotation.Internal; +import org.apache.flink.table.types.DataType; +import org.apache.flink.table.types.inference.CallContext; +import org.apache.flink.table.types.inference.TypeStrategy; + +import java.util.List; +import java.util.Optional; + +/** Type strategy specific for avoiding nulls. */ +@Internal +class IfNullTypeStrategy implements TypeStrategy { + + @Override + public Optional<DataType> inferType(CallContext callContext) { + final List<DataType> argumentDataTypes = callContext.getArgumentDataTypes(); + final DataType inputDataType = argumentDataTypes.get(0); + final DataType nullReplacementDataType = argumentDataTypes.get(1); + if (!inputDataType.getLogicalType().isNullable()) { + return Optional.of(inputDataType); + } + + return Optional.of(nullReplacementDataType); + } +} diff --git a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/inference/strategies/MapTypeStrategy.java b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/inference/strategies/MapTypeStrategy.java new file mode 100644 index 0000000..917ea1f --- /dev/null +++ b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/inference/strategies/MapTypeStrategy.java @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.types.inference.strategies; + +import org.apache.flink.annotation.Internal; +import org.apache.flink.table.api.DataTypes; +import org.apache.flink.table.types.DataType; +import org.apache.flink.table.types.inference.CallContext; +import org.apache.flink.table.types.inference.TypeStrategy; + +import java.util.List; +import java.util.Optional; + +/** + * Type strategy that returns a {@link DataTypes#MAP(DataType, DataType)} with a key type equal to + * type of the first argument and a value type equal to the type of second argument. + */ +@Internal +class MapTypeStrategy implements TypeStrategy { + + @Override + public Optional<DataType> inferType(CallContext callContext) { + List<DataType> argumentDataTypes = callContext.getArgumentDataTypes(); + if (argumentDataTypes.size() < 2) { + return Optional.empty(); + } + + return Optional.of( + DataTypes.MAP(argumentDataTypes.get(0), argumentDataTypes.get(1)).notNull()); + } +} diff --git a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/inference/strategies/RoundTypeStrategy.java b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/inference/strategies/RoundTypeStrategy.java new file mode 100644 index 0000000..aa2de1f --- /dev/null +++ b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/inference/strategies/RoundTypeStrategy.java @@ -0,0 +1,79 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.types.inference.strategies; + +import org.apache.flink.annotation.Internal; +import org.apache.flink.table.types.DataType; +import org.apache.flink.table.types.inference.CallContext; +import org.apache.flink.table.types.inference.TypeStrategy; +import org.apache.flink.table.types.logical.LegacyTypeInformationType; +import org.apache.flink.table.types.logical.LogicalType; +import org.apache.flink.table.types.logical.LogicalTypeRoot; +import org.apache.flink.table.types.logical.utils.LogicalTypeMerging; + +import java.math.BigDecimal; +import java.util.List; +import java.util.Optional; + +import static org.apache.flink.table.types.logical.utils.LogicalTypeChecks.getPrecision; +import static org.apache.flink.table.types.logical.utils.LogicalTypeChecks.getScale; +import static org.apache.flink.table.types.logical.utils.LogicalTypeChecks.hasRoot; +import static org.apache.flink.table.types.utils.TypeConversions.fromLogicalToDataType; + +/** Type strategy that returns the result of a rounding operation. */ +@Internal +class RoundTypeStrategy implements TypeStrategy { + + @Override + public Optional<DataType> inferType(CallContext callContext) { + final List<DataType> argumentDataTypes = callContext.getArgumentDataTypes(); + final DataType argumentDataType = callContext.getArgumentDataTypes().get(0); + final LogicalType argumentType = argumentDataType.getLogicalType(); + + // a hack to make legacy types possible until we drop them + if (argumentType instanceof LegacyTypeInformationType) { + return Optional.of(argumentDataType); + } + + if (!hasRoot(argumentType, LogicalTypeRoot.DECIMAL)) { + return Optional.of(argumentDataType); + } + + final BigDecimal roundLength; + if (argumentDataTypes.size() == 2) { + if (!callContext.isArgumentLiteral(1) || callContext.isArgumentNull(1)) { + return Optional.of(argumentDataType); + } + roundLength = + callContext + .getArgumentValue(1, BigDecimal.class) + .orElseThrow(AssertionError::new); + } else { + roundLength = BigDecimal.ZERO; + } + + final LogicalType inferredType = + LogicalTypeMerging.findRoundDecimalType( + getPrecision(argumentType), + getScale(argumentType), + roundLength.intValueExact()); + + return Optional.of(fromLogicalToDataType(inferredType)); + } +} diff --git a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/inference/strategies/RowTypeStrategy.java b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/inference/strategies/RowTypeStrategy.java new file mode 100644 index 0000000..c9bedb8 --- /dev/null +++ b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/inference/strategies/RowTypeStrategy.java @@ -0,0 +1,45 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.types.inference.strategies; + +import org.apache.flink.annotation.Internal; +import org.apache.flink.table.api.DataTypes; +import org.apache.flink.table.types.DataType; +import org.apache.flink.table.types.inference.CallContext; +import org.apache.flink.table.types.inference.TypeStrategy; + +import java.util.List; +import java.util.Optional; +import java.util.stream.IntStream; + +/** Type strategy that returns a {@link DataTypes#ROW()} with fields types equal to input types. */ +@Internal +class RowTypeStrategy implements TypeStrategy { + + @Override + public Optional<DataType> inferType(CallContext callContext) { + List<DataType> argumentDataTypes = callContext.getArgumentDataTypes(); + DataTypes.Field[] fields = + IntStream.range(0, argumentDataTypes.size()) + .mapToObj(idx -> DataTypes.FIELD("f" + idx, argumentDataTypes.get(idx))) + .toArray(DataTypes.Field[]::new); + + return Optional.of(DataTypes.ROW(fields).notNull()); + } +} diff --git a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/inference/strategies/SourceWatermarkTypeStrategy.java b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/inference/strategies/SourceWatermarkTypeStrategy.java new file mode 100644 index 0000000..93bbf7c --- /dev/null +++ b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/inference/strategies/SourceWatermarkTypeStrategy.java @@ -0,0 +1,45 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.types.inference.strategies; + +import org.apache.flink.annotation.Internal; +import org.apache.flink.table.api.DataTypes; +import org.apache.flink.table.types.DataType; +import org.apache.flink.table.types.inference.CallContext; +import org.apache.flink.table.types.inference.TypeStrategy; +import org.apache.flink.table.types.logical.LogicalTypeFamily; + +import java.util.Optional; + +import static org.apache.flink.table.types.logical.utils.LogicalTypeChecks.hasFamily; + +/** Type strategy specific for source watermarks that depend on the output type. */ +@Internal +class SourceWatermarkTypeStrategy implements TypeStrategy { + + @Override + public Optional<DataType> inferType(CallContext callContext) { + final DataType timestampDataType = + callContext + .getOutputDataType() + .filter(dt -> hasFamily(dt.getLogicalType(), LogicalTypeFamily.TIMESTAMP)) + .orElse(DataTypes.TIMESTAMP_LTZ(3)); + return Optional.of(timestampDataType); + } +} diff --git a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/inference/strategies/SpecificTypeStrategies.java b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/inference/strategies/SpecificTypeStrategies.java new file mode 100644 index 0000000..5d7431c --- /dev/null +++ b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/inference/strategies/SpecificTypeStrategies.java @@ -0,0 +1,79 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.types.inference.strategies; + +import org.apache.flink.annotation.Internal; +import org.apache.flink.table.types.inference.TypeStrategies; +import org.apache.flink.table.types.inference.TypeStrategy; + +/** + * Entry point for specific type strategies not covered in {@link TypeStrategies}. + * + * <p>This primarily serves the purpose of reducing visibility of individual type strategy + * implementations to avoid polluting the API classpath. + */ +@Internal +public class SpecificTypeStrategies { + + /** See {@link RowTypeStrategy}. */ + public static final TypeStrategy ROW = new RowTypeStrategy(); + + /** See {@link RoundTypeStrategy}. */ + public static final TypeStrategy ROUND = new RoundTypeStrategy(); + + /** See {@link MapTypeStrategy}. */ + public static final TypeStrategy MAP = new MapTypeStrategy(); + + /** See {@link IfNullTypeStrategy}. */ + public static final TypeStrategy IF_NULL = new IfNullTypeStrategy(); + + /** See {@link StringConcatTypeStrategy}. */ + public static final TypeStrategy STRING_CONCAT = new StringConcatTypeStrategy(); + + /** See {@link ArrayTypeStrategy}. */ + public static final TypeStrategy ARRAY = new ArrayTypeStrategy(); + + /** See {@link GetTypeStrategy}. */ + public static final TypeStrategy GET = new GetTypeStrategy(); + + /** See {@link DecimalModTypeStrategy}. */ + public static final TypeStrategy DECIMAL_MOD = new DecimalModTypeStrategy(); + + /** See {@link DecimalDivideTypeStrategy}. */ + public static final TypeStrategy DECIMAL_DIVIDE = new DecimalDivideTypeStrategy(); + + /** See {@link DecimalPlusTypeStrategy}. */ + public static final TypeStrategy DECIMAL_PLUS = new DecimalPlusTypeStrategy(); + + /** See {@link DecimalScale0TypeStrategy}. */ + public static final TypeStrategy DECIMAL_SCALE_0 = new DecimalScale0TypeStrategy(); + + /** See {@link DecimalTimesTypeStrategy}. */ + public static final TypeStrategy DECIMAL_TIMES = new DecimalTimesTypeStrategy(); + + /** See {@link SourceWatermarkTypeStrategy}. */ + public static final TypeStrategy SOURCE_WATERMARK = new SourceWatermarkTypeStrategy(); + + /** See {@link CurrentWatermarkTypeStrategy}. */ + public static final TypeStrategy CURRENT_WATERMARK = new CurrentWatermarkTypeStrategy(); + + private SpecificTypeStrategies() { + // no instantiation + } +} diff --git a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/inference/strategies/StrategyUtils.java b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/inference/strategies/StrategyUtils.java index 4a8b5de..f352e4f 100644 --- a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/inference/strategies/StrategyUtils.java +++ b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/inference/strategies/StrategyUtils.java @@ -25,6 +25,7 @@ import org.apache.flink.table.types.logical.BinaryType; import org.apache.flink.table.types.logical.CharType; import org.apache.flink.table.types.logical.DecimalType; import org.apache.flink.table.types.logical.LogicalType; +import org.apache.flink.table.types.logical.LogicalTypeFamily; import org.apache.flink.table.types.logical.LogicalTypeRoot; import org.apache.flink.table.types.logical.VarBinaryType; import org.apache.flink.table.types.logical.VarCharType; @@ -99,6 +100,17 @@ final class StrategyUtils { }); } + static boolean isDecimalComputation(LogicalType type1, LogicalType type2) { + // both must be exact numeric + if (!hasFamily(type1, LogicalTypeFamily.EXACT_NUMERIC) + || !hasFamily(type2, LogicalTypeFamily.EXACT_NUMERIC)) { + return false; + } + + // one decimal must be present + return hasRoot(type1, LogicalTypeRoot.DECIMAL) || hasRoot(type2, LogicalTypeRoot.DECIMAL); + } + /** * Returns a data type for the given data type and expected root. * diff --git a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/inference/strategies/StringConcatTypeStrategy.java b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/inference/strategies/StringConcatTypeStrategy.java new file mode 100644 index 0000000..633be96 --- /dev/null +++ b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/inference/strategies/StringConcatTypeStrategy.java @@ -0,0 +1,74 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.types.inference.strategies; + +import org.apache.flink.annotation.Internal; +import org.apache.flink.table.types.DataType; +import org.apache.flink.table.types.inference.CallContext; +import org.apache.flink.table.types.inference.TypeStrategy; +import org.apache.flink.table.types.logical.BinaryType; +import org.apache.flink.table.types.logical.CharType; +import org.apache.flink.table.types.logical.LogicalType; +import org.apache.flink.table.types.logical.LogicalTypeFamily; +import org.apache.flink.table.types.utils.TypeConversions; + +import java.util.Arrays; +import java.util.List; +import java.util.Optional; + +import static org.apache.flink.table.types.logical.utils.LogicalTypeChecks.getLength; +import static org.apache.flink.table.types.logical.utils.LogicalTypeChecks.hasFamily; +import static org.apache.flink.table.types.logical.utils.LogicalTypeMerging.findCommonType; + +/** + * Type strategy that returns the type of a string concatenation. It assumes that the first two + * arguments are of the same family of either {@link LogicalTypeFamily#BINARY_STRING} or {@link + * LogicalTypeFamily#CHARACTER_STRING}. + */ +@Internal +class StringConcatTypeStrategy implements TypeStrategy { + + @Override + public Optional<DataType> inferType(CallContext callContext) { + final List<DataType> argumentDataTypes = callContext.getArgumentDataTypes(); + final LogicalType type1 = argumentDataTypes.get(0).getLogicalType(); + final LogicalType type2 = argumentDataTypes.get(1).getLogicalType(); + int length = getLength(type1) + getLength(type2); + + // handle overflow + if (length < 0) { + length = CharType.MAX_LENGTH; + } + + final LogicalType minimumType; + if (hasFamily(type1, LogicalTypeFamily.CHARACTER_STRING) + || hasFamily(type2, LogicalTypeFamily.CHARACTER_STRING)) { + minimumType = new CharType(false, length); + } else if (hasFamily(type1, LogicalTypeFamily.BINARY_STRING) + || hasFamily(type2, LogicalTypeFamily.BINARY_STRING)) { + minimumType = new BinaryType(false, length); + } else { + return Optional.empty(); + } + + // deal with nullability handling and varying semantics + return findCommonType(Arrays.asList(type1, type2, minimumType)) + .map(TypeConversions::fromLogicalToDataType); + } +} diff --git a/flink-table/flink-table-common/src/test/java/org/apache/flink/table/types/inference/MappingTypeStrategiesTest.java b/flink-table/flink-table-common/src/test/java/org/apache/flink/table/types/inference/MappingTypeStrategiesTest.java new file mode 100644 index 0000000..be59f1d --- /dev/null +++ b/flink-table/flink-table-common/src/test/java/org/apache/flink/table/types/inference/MappingTypeStrategiesTest.java @@ -0,0 +1,75 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.types.inference; + +import org.apache.flink.table.api.DataTypes; + +import org.junit.runners.Parameterized; + +import java.util.Arrays; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import static org.apache.flink.table.types.inference.TypeStrategies.explicit; + +/** Tests for {@link TypeStrategies#mapping(Map)}. */ +public class MappingTypeStrategiesTest extends TypeStrategiesTestBase { + + @Parameterized.Parameters(name = "{index}: {0}") + public static List<TestSpec> testData() { + return Arrays.asList( + // (INT, BOOLEAN) -> STRING + TestSpec.forStrategy(createMappingTypeStrategy()) + .inputTypes(DataTypes.INT(), DataTypes.BOOLEAN()) + .expectDataType(DataTypes.STRING()), + + // (INT, STRING) -> BOOLEAN + TestSpec.forStrategy(createMappingTypeStrategy()) + .inputTypes(DataTypes.INT(), DataTypes.STRING()) + .expectDataType(DataTypes.BOOLEAN().bridgedTo(boolean.class)), + + // (INT, CHAR(10)) -> BOOLEAN + // but avoiding casts (mapping actually expects STRING) + TestSpec.forStrategy(createMappingTypeStrategy()) + .inputTypes(DataTypes.INT(), DataTypes.CHAR(10)) + .expectDataType(DataTypes.BOOLEAN().bridgedTo(boolean.class)), + + // invalid mapping strategy + TestSpec.forStrategy(createMappingTypeStrategy()) + .inputTypes(DataTypes.INT(), DataTypes.INT()) + .expectErrorMessage( + "Could not infer an output type for the given arguments.")); + } + + private static TypeStrategy createMappingTypeStrategy() { + final Map<InputTypeStrategy, TypeStrategy> mappings = new HashMap<>(); + mappings.put( + InputTypeStrategies.sequence( + InputTypeStrategies.explicit(DataTypes.INT()), + InputTypeStrategies.explicit(DataTypes.STRING())), + explicit(DataTypes.BOOLEAN().bridgedTo(boolean.class))); + mappings.put( + InputTypeStrategies.sequence( + InputTypeStrategies.explicit(DataTypes.INT()), + InputTypeStrategies.explicit(DataTypes.BOOLEAN())), + explicit(DataTypes.STRING())); + return TypeStrategies.mapping(mappings); + } +} diff --git a/flink-table/flink-table-common/src/test/java/org/apache/flink/table/types/inference/TypeStrategiesTest.java b/flink-table/flink-table-common/src/test/java/org/apache/flink/table/types/inference/TypeStrategiesTest.java index 55985a3..88c0c60 100644 --- a/flink-table/flink-table-common/src/test/java/org/apache/flink/table/types/inference/TypeStrategiesTest.java +++ b/flink-table/flink-table-common/src/test/java/org/apache/flink/table/types/inference/TypeStrategiesTest.java @@ -19,165 +19,97 @@ package org.apache.flink.table.types.inference; import org.apache.flink.table.api.DataTypes; -import org.apache.flink.table.api.ValidationException; -import org.apache.flink.table.catalog.ObjectIdentifier; -import org.apache.flink.table.functions.FunctionKind; -import org.apache.flink.table.types.DataType; -import org.apache.flink.table.types.FieldsDataType; -import org.apache.flink.table.types.inference.utils.CallContextMock; -import org.apache.flink.table.types.inference.utils.FunctionDefinitionMock; -import org.apache.flink.table.types.logical.BigIntType; -import org.apache.flink.table.types.logical.LocalZonedTimestampType; import org.apache.flink.table.types.logical.LogicalTypeFamily; -import org.apache.flink.table.types.logical.StructuredType; -import org.apache.flink.table.types.logical.TimestampKind; -import org.apache.flink.table.types.logical.TimestampType; import org.apache.flink.table.types.logical.utils.LogicalTypeMerging; -import org.apache.flink.table.types.utils.TypeConversions; -import org.junit.Assert; -import org.junit.Rule; -import org.junit.Test; -import org.junit.rules.ExpectedException; -import org.junit.runner.RunWith; import org.junit.runners.Parameterized; -import org.junit.runners.Parameterized.Parameter; -import org.junit.runners.Parameterized.Parameters; - -import javax.annotation.Nullable; import java.util.Arrays; -import java.util.Collections; -import java.util.HashMap; import java.util.List; -import java.util.Map; import java.util.Optional; -import java.util.stream.Collectors; -import java.util.stream.IntStream; -import static org.apache.flink.core.testutils.FlinkMatchers.containsCause; import static org.apache.flink.table.types.inference.TypeStrategies.MISSING; -import static org.apache.flink.table.types.inference.TypeStrategies.STRING_CONCAT; import static org.apache.flink.table.types.inference.TypeStrategies.argument; import static org.apache.flink.table.types.inference.TypeStrategies.explicit; import static org.apache.flink.table.types.inference.TypeStrategies.nullable; import static org.apache.flink.table.types.inference.TypeStrategies.varyingString; -import static org.hamcrest.CoreMatchers.equalTo; /** Tests for built-in {@link TypeStrategies}. */ -@RunWith(Parameterized.class) -public class TypeStrategiesTest { +public class TypeStrategiesTest extends TypeStrategiesTestBase { - @Parameters(name = "{index}: {0}") - public static List<TestSpec> testData() { + @Parameterized.Parameters(name = "{index}: {0}") + public static List<TypeStrategiesTestBase.TestSpec> testData() { return Arrays.asList( // missing strategy with arbitrary argument - TestSpec.forStrategy(MISSING) + TypeStrategiesTestBase.TestSpec.forStrategy(MISSING) .inputTypes(DataTypes.INT()) .expectErrorMessage( "Could not infer an output type for the given arguments."), // valid explicit - TestSpec.forStrategy(explicit(DataTypes.BIGINT())) + TypeStrategiesTestBase.TestSpec.forStrategy(explicit(DataTypes.BIGINT())) .inputTypes() .expectDataType(DataTypes.BIGINT()), // infer from input - TestSpec.forStrategy(argument(0)) + TypeStrategiesTestBase.TestSpec.forStrategy(argument(0)) .inputTypes(DataTypes.INT(), DataTypes.STRING()) .expectDataType(DataTypes.INT()), // infer from not existing input - TestSpec.forStrategy(argument(0)) + TypeStrategiesTestBase.TestSpec.forStrategy(argument(0)) .inputTypes() .expectErrorMessage( "Could not infer an output type for the given arguments."), - // (INT, BOOLEAN) -> STRING - TestSpec.forStrategy(createMappingTypeStrategy()) - .inputTypes(DataTypes.INT(), DataTypes.BOOLEAN()) - .expectDataType(DataTypes.STRING()), - - // (INT, STRING) -> BOOLEAN - TestSpec.forStrategy(createMappingTypeStrategy()) - .inputTypes(DataTypes.INT(), DataTypes.STRING()) - .expectDataType(DataTypes.BOOLEAN().bridgedTo(boolean.class)), - - // (INT, CHAR(10)) -> BOOLEAN - // but avoiding casts (mapping actually expects STRING) - TestSpec.forStrategy(createMappingTypeStrategy()) - .inputTypes(DataTypes.INT(), DataTypes.CHAR(10)) - .expectDataType(DataTypes.BOOLEAN().bridgedTo(boolean.class)), - - // invalid mapping strategy - TestSpec.forStrategy(createMappingTypeStrategy()) - .inputTypes(DataTypes.INT(), DataTypes.INT()) - .expectErrorMessage( - "Could not infer an output type for the given arguments."), - // invalid return type - TestSpec.forStrategy(explicit(DataTypes.NULL())) + TypeStrategiesTestBase.TestSpec.forStrategy(explicit(DataTypes.NULL())) .inputTypes() .expectErrorMessage( "Could not infer an output type for the given arguments. Untyped NULL received."), - TestSpec.forStrategy( + TypeStrategiesTestBase.TestSpec.forStrategy( "First type strategy", TypeStrategies.first( (callContext) -> Optional.empty(), explicit(DataTypes.INT()))) .inputTypes() .expectDataType(DataTypes.INT()), - TestSpec.forStrategy( + TypeStrategiesTestBase.TestSpec.forStrategy( "Match root type strategy", TypeStrategies.matchFamily(0, LogicalTypeFamily.NUMERIC)) .inputTypes(DataTypes.INT()) .expectDataType(DataTypes.INT()), - TestSpec.forStrategy( + TypeStrategiesTestBase.TestSpec.forStrategy( "Invalid match root type strategy", TypeStrategies.matchFamily(0, LogicalTypeFamily.NUMERIC)) .inputTypes(DataTypes.BOOLEAN()) .expectErrorMessage( "Could not infer an output type for the given arguments."), - TestSpec.forStrategy("Infer a row type", TypeStrategies.ROW) - .inputTypes(DataTypes.BIGINT(), DataTypes.STRING()) - .expectDataType( - DataTypes.ROW( - DataTypes.FIELD("f0", DataTypes.BIGINT()), - DataTypes.FIELD("f1", DataTypes.STRING())) - .notNull()), - TestSpec.forStrategy("Infer an array type", TypeStrategies.ARRAY) - .inputTypes(DataTypes.BIGINT(), DataTypes.BIGINT()) - .expectDataType(DataTypes.ARRAY(DataTypes.BIGINT()).notNull()), - TestSpec.forStrategy("Infer a map type", TypeStrategies.MAP) - .inputTypes(DataTypes.BIGINT(), DataTypes.STRING().notNull()) - .expectDataType( - DataTypes.MAP(DataTypes.BIGINT(), DataTypes.STRING().notNull()) - .notNull()), - TestSpec.forStrategy( + TypeStrategiesTestBase.TestSpec.forStrategy( "Cascading to nullable type", nullable(explicit(DataTypes.BOOLEAN().notNull()))) .inputTypes(DataTypes.BIGINT().notNull(), DataTypes.VARCHAR(2).nullable()) .expectDataType(DataTypes.BOOLEAN().nullable()), - TestSpec.forStrategy( + TypeStrategiesTestBase.TestSpec.forStrategy( "Cascading to not null type", nullable(explicit(DataTypes.BOOLEAN().nullable()))) .inputTypes(DataTypes.BIGINT().notNull(), DataTypes.VARCHAR(2).notNull()) .expectDataType(DataTypes.BOOLEAN().notNull()), - TestSpec.forStrategy( + TypeStrategiesTestBase.TestSpec.forStrategy( "Cascading to not null type but only consider first argument", nullable( ConstantArgumentCount.to(0), explicit(DataTypes.BOOLEAN().nullable()))) .inputTypes(DataTypes.BIGINT().notNull(), DataTypes.VARCHAR(2).nullable()) .expectDataType(DataTypes.BOOLEAN().notNull()), - TestSpec.forStrategy( + TypeStrategiesTestBase.TestSpec.forStrategy( "Cascading to null type but only consider first two argument", nullable( ConstantArgumentCount.to(1), explicit(DataTypes.BOOLEAN().nullable()))) .inputTypes(DataTypes.BIGINT().notNull(), DataTypes.VARCHAR(2).nullable()) .expectDataType(DataTypes.BOOLEAN().nullable()), - TestSpec.forStrategy( + TypeStrategiesTestBase.TestSpec.forStrategy( "Cascading to not null type but only consider the second and third argument", nullable( ConstantArgumentCount.between(1, 2), @@ -187,287 +119,28 @@ public class TypeStrategiesTest { DataTypes.BIGINT().notNull(), DataTypes.VARCHAR(2).notNull()) .expectDataType(DataTypes.BOOLEAN().notNull()), - TestSpec.forStrategy("Find a common type", TypeStrategies.COMMON) + TypeStrategiesTestBase.TestSpec.forStrategy( + "Find a common type", TypeStrategies.COMMON) .inputTypes( DataTypes.INT(), DataTypes.TINYINT().notNull(), DataTypes.DECIMAL(20, 10)) .expectDataType(DataTypes.DECIMAL(20, 10)), - TestSpec.forStrategy("Find a decimal sum", TypeStrategies.DECIMAL_PLUS) - .inputTypes(DataTypes.DECIMAL(5, 4), DataTypes.DECIMAL(3, 2)) - .expectDataType(DataTypes.DECIMAL(6, 4).notNull()), - TestSpec.forStrategy("Find a decimal quotient", TypeStrategies.DECIMAL_DIVIDE) - .inputTypes(DataTypes.DECIMAL(5, 4), DataTypes.DECIMAL(3, 2)) - .expectDataType(DataTypes.DECIMAL(11, 8).notNull()), - TestSpec.forStrategy("Find a decimal product", TypeStrategies.DECIMAL_TIMES) - .inputTypes(DataTypes.DECIMAL(5, 4), DataTypes.DECIMAL(3, 2)) - .expectDataType(DataTypes.DECIMAL(9, 6).notNull()), - TestSpec.forStrategy("Find a decimal modulo", TypeStrategies.DECIMAL_MOD) - .inputTypes(DataTypes.DECIMAL(5, 4), DataTypes.DECIMAL(3, 2)) - .expectDataType(DataTypes.DECIMAL(5, 4).notNull()), - TestSpec.forStrategy( + TypeStrategiesTestBase.TestSpec.forStrategy( "Convert to varying string", varyingString(explicit(DataTypes.CHAR(12).notNull()))) .inputTypes(DataTypes.CHAR(12).notNull()) .expectDataType(DataTypes.VARCHAR(12).notNull()), - TestSpec.forStrategy("Concat two strings", STRING_CONCAT) - .inputTypes(DataTypes.CHAR(12).notNull(), DataTypes.VARCHAR(12)) - .expectDataType(DataTypes.VARCHAR(24)), - TestSpec.forStrategy( - "Access field of a row nullable type by name", TypeStrategies.GET) - .inputTypes( - DataTypes.ROW(DataTypes.FIELD("f0", DataTypes.BIGINT().notNull())), - DataTypes.STRING().notNull()) - .calledWithLiteralAt(1, "f0") - .expectDataType(DataTypes.BIGINT().nullable()), - TestSpec.forStrategy( - "Access field of a row not null type by name", TypeStrategies.GET) - .inputTypes( - DataTypes.ROW(DataTypes.FIELD("f0", DataTypes.BIGINT().notNull())) - .notNull(), - DataTypes.STRING().notNull()) - .calledWithLiteralAt(1, "f0") - .expectDataType(DataTypes.BIGINT().notNull()), - TestSpec.forStrategy( - "Access field of a structured nullable type by name", - TypeStrategies.GET) - .inputTypes( - new FieldsDataType( - StructuredType.newBuilder( - ObjectIdentifier.of( - "cat", "db", "type")) - .attributes( - Collections.singletonList( - new StructuredType - .StructuredAttribute( - "f0", - new BigIntType( - false)))) - .build(), - Collections.singletonList( - DataTypes.BIGINT().notNull())) - .nullable(), - DataTypes.STRING().notNull()) - .calledWithLiteralAt(1, "f0") - .expectDataType(DataTypes.BIGINT().nullable()), - TestSpec.forStrategy( - "Access field of a structured not null type by name", - TypeStrategies.GET) - .inputTypes( - new FieldsDataType( - StructuredType.newBuilder( - ObjectIdentifier.of( - "cat", "db", "type")) - .attributes( - Collections.singletonList( - new StructuredType - .StructuredAttribute( - "f0", - new BigIntType( - false)))) - .build(), - Collections.singletonList( - DataTypes.BIGINT().notNull())) - .notNull(), - DataTypes.STRING().notNull()) - .calledWithLiteralAt(1, "f0") - .expectDataType(DataTypes.BIGINT().notNull()), - TestSpec.forStrategy( - "Access field of a row nullable type by index", TypeStrategies.GET) - .inputTypes( - DataTypes.ROW(DataTypes.FIELD("f0", DataTypes.BIGINT().notNull())), - DataTypes.INT().notNull()) - .calledWithLiteralAt(1, 0) - .expectDataType(DataTypes.BIGINT().nullable()), - TestSpec.forStrategy( - "Access field of a row not null type by index", TypeStrategies.GET) - .inputTypes( - DataTypes.ROW(DataTypes.FIELD("f0", DataTypes.BIGINT().notNull())) - .notNull(), - DataTypes.INT().notNull()) - .calledWithLiteralAt(1, 0) - .expectDataType(DataTypes.BIGINT().notNull()), - TestSpec.forStrategy( - "Fields can be accessed only with a literal (name)", - TypeStrategies.GET) - .inputTypes( - DataTypes.ROW(DataTypes.FIELD("f0", DataTypes.BIGINT().notNull())) - .notNull(), - DataTypes.STRING().notNull()) - .expectErrorMessage( - "Could not infer an output type for the given arguments."), - TestSpec.forStrategy( - "Fields can be accessed only with a literal (index)", - TypeStrategies.GET) - .inputTypes( - DataTypes.ROW(DataTypes.FIELD("f0", DataTypes.BIGINT().notNull())) - .notNull(), - DataTypes.INT().notNull()) - .expectErrorMessage( - "Could not infer an output type for the given arguments."), - TestSpec.forStrategy( + TypeStrategiesTestBase.TestSpec.forStrategy( "Average with grouped aggregation", TypeStrategies.aggArg0(LogicalTypeMerging::findAvgAggType, true)) .inputTypes(DataTypes.INT().notNull()) .calledWithGroupedAggregation() .expectDataType(DataTypes.INT().notNull()), - TestSpec.forStrategy( + TypeStrategiesTestBase.TestSpec.forStrategy( "Average without grouped aggregation", TypeStrategies.aggArg0(LogicalTypeMerging::findAvgAggType, true)) .inputTypes(DataTypes.INT().notNull()) - .expectDataType(DataTypes.INT()), - - // CURRENT_WATERMARK - TestSpec.forStrategy("TIMESTAMP(3) *ROWTIME*", TypeStrategies.CURRENT_WATERMARK) - .inputTypes(createRowtimeType(TimestampKind.ROWTIME, 3).notNull()) - .expectDataType(DataTypes.TIMESTAMP(3)), - TestSpec.forStrategy("TIMESTAMP_LTZ(3) *ROWTIME*", TypeStrategies.CURRENT_WATERMARK) - .inputTypes(createRowtimeLtzType(TimestampKind.ROWTIME, 3).notNull()) - .expectDataType(DataTypes.TIMESTAMP_LTZ(3)), - TestSpec.forStrategy("TIMESTAMP(9) *ROWTIME*", TypeStrategies.CURRENT_WATERMARK) - .inputTypes(createRowtimeType(TimestampKind.ROWTIME, 9).notNull()) - .expectDataType(DataTypes.TIMESTAMP(3)), - TestSpec.forStrategy("TIMESTAMP_LTZ(9) *ROWTIME*", TypeStrategies.CURRENT_WATERMARK) - .inputTypes(createRowtimeLtzType(TimestampKind.ROWTIME, 9).notNull()) - .expectDataType(DataTypes.TIMESTAMP_LTZ(3))); - } - - @Parameter public TestSpec testSpec; - - @Rule public ExpectedException thrown = ExpectedException.none(); - - @Test - public void testTypeStrategy() { - if (testSpec.expectedErrorMessage != null) { - thrown.expect(ValidationException.class); - thrown.expectCause( - containsCause(new ValidationException(testSpec.expectedErrorMessage))); - } - TypeInferenceUtil.Result result = runTypeInference(); - if (testSpec.expectedDataType != null) { - Assert.assertThat(result.getOutputDataType(), equalTo(testSpec.expectedDataType)); - } - } - - // -------------------------------------------------------------------------------------------- - - private TypeInferenceUtil.Result runTypeInference() { - final FunctionDefinitionMock functionDefinitionMock = new FunctionDefinitionMock(); - functionDefinitionMock.functionKind = FunctionKind.SCALAR; - final CallContextMock callContextMock = new CallContextMock(); - callContextMock.functionDefinition = functionDefinitionMock; - callContextMock.argumentDataTypes = testSpec.inputTypes; - callContextMock.name = "f"; - callContextMock.outputDataType = Optional.empty(); - callContextMock.isGroupedAggregation = testSpec.isGroupedAggregation; - - callContextMock.argumentLiterals = - IntStream.range(0, testSpec.inputTypes.size()) - .mapToObj(i -> testSpec.literalPos != null && i == testSpec.literalPos) - .collect(Collectors.toList()); - callContextMock.argumentValues = - IntStream.range(0, testSpec.inputTypes.size()) - .mapToObj( - i -> - (testSpec.literalPos != null && i == testSpec.literalPos) - ? Optional.ofNullable(testSpec.literalValue) - : Optional.empty()) - .collect(Collectors.toList()); - - final TypeInference typeInference = - TypeInference.newBuilder() - .inputTypeStrategy(InputTypeStrategies.WILDCARD) - .outputTypeStrategy(testSpec.strategy) - .build(); - return TypeInferenceUtil.runTypeInference(typeInference, callContextMock, null); - } - - // -------------------------------------------------------------------------------------------- - - private static class TestSpec { - - private @Nullable final String description; - - private final TypeStrategy strategy; - - private List<DataType> inputTypes; - - private @Nullable DataType expectedDataType; - - private @Nullable String expectedErrorMessage; - - private @Nullable Integer literalPos; - - private @Nullable Object literalValue; - - private boolean isGroupedAggregation; - - private TestSpec(@Nullable String description, TypeStrategy strategy) { - this.description = description; - this.strategy = strategy; - } - - static TestSpec forStrategy(TypeStrategy strategy) { - return new TestSpec(null, strategy); - } - - static TestSpec forStrategy(String description, TypeStrategy strategy) { - return new TestSpec(description, strategy); - } - - TestSpec inputTypes(DataType... dataTypes) { - this.inputTypes = Arrays.asList(dataTypes); - return this; - } - - TestSpec calledWithLiteralAt(int pos, Object value) { - this.literalPos = pos; - this.literalValue = value; - return this; - } - - TestSpec calledWithGroupedAggregation() { - this.isGroupedAggregation = true; - return this; - } - - TestSpec expectDataType(DataType expectedDataType) { - this.expectedDataType = expectedDataType; - return this; - } - - TestSpec expectErrorMessage(String expectedErrorMessage) { - this.expectedErrorMessage = expectedErrorMessage; - return this; - } - - @Override - public String toString() { - return description != null ? description : ""; - } - } - - private static TypeStrategy createMappingTypeStrategy() { - final Map<InputTypeStrategy, TypeStrategy> mappings = new HashMap<>(); - mappings.put( - InputTypeStrategies.sequence( - InputTypeStrategies.explicit(DataTypes.INT()), - InputTypeStrategies.explicit(DataTypes.STRING())), - explicit(DataTypes.BOOLEAN().bridgedTo(boolean.class))); - mappings.put( - InputTypeStrategies.sequence( - InputTypeStrategies.explicit(DataTypes.INT()), - InputTypeStrategies.explicit(DataTypes.BOOLEAN())), - explicit(DataTypes.STRING())); - return TypeStrategies.mapping(mappings); - } - - private static DataType createRowtimeType(TimestampKind kind, int precision) { - return TypeConversions.fromLogicalToDataType(new TimestampType(false, kind, precision)); - } - - private static DataType createRowtimeLtzType(TimestampKind kind, int precision) { - return TypeConversions.fromLogicalToDataType( - new LocalZonedTimestampType(false, kind, precision)); + .expectDataType(DataTypes.INT())); } } diff --git a/flink-table/flink-table-common/src/test/java/org/apache/flink/table/types/inference/TypeStrategiesTestBase.java b/flink-table/flink-table-common/src/test/java/org/apache/flink/table/types/inference/TypeStrategiesTestBase.java new file mode 100644 index 0000000..b3abcd7 --- /dev/null +++ b/flink-table/flink-table-common/src/test/java/org/apache/flink/table/types/inference/TypeStrategiesTestBase.java @@ -0,0 +1,165 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.types.inference; + +import org.apache.flink.table.api.ValidationException; +import org.apache.flink.table.functions.FunctionKind; +import org.apache.flink.table.types.DataType; +import org.apache.flink.table.types.inference.utils.CallContextMock; +import org.apache.flink.table.types.inference.utils.FunctionDefinitionMock; + +import org.junit.Assert; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; +import org.junit.runners.Parameterized.Parameter; + +import javax.annotation.Nullable; + +import java.util.Arrays; +import java.util.List; +import java.util.Optional; +import java.util.stream.Collectors; +import java.util.stream.IntStream; + +import static org.apache.flink.core.testutils.FlinkMatchers.containsCause; +import static org.hamcrest.CoreMatchers.equalTo; + +/** Base class for tests of {@link TypeStrategies}. */ +@RunWith(Parameterized.class) +public abstract class TypeStrategiesTestBase { + + @Parameter public TestSpec testSpec; + + @Rule public ExpectedException thrown = ExpectedException.none(); + + @Test + public void testTypeStrategy() { + if (testSpec.expectedErrorMessage != null) { + thrown.expect(ValidationException.class); + thrown.expectCause( + containsCause(new ValidationException(testSpec.expectedErrorMessage))); + } + TypeInferenceUtil.Result result = runTypeInference(); + if (testSpec.expectedDataType != null) { + Assert.assertThat(result.getOutputDataType(), equalTo(testSpec.expectedDataType)); + } + } + + // -------------------------------------------------------------------------------------------- + + private TypeInferenceUtil.Result runTypeInference() { + final FunctionDefinitionMock functionDefinitionMock = new FunctionDefinitionMock(); + functionDefinitionMock.functionKind = FunctionKind.SCALAR; + final CallContextMock callContextMock = new CallContextMock(); + callContextMock.functionDefinition = functionDefinitionMock; + callContextMock.argumentDataTypes = testSpec.inputTypes; + callContextMock.name = "f"; + callContextMock.outputDataType = Optional.empty(); + callContextMock.isGroupedAggregation = testSpec.isGroupedAggregation; + + callContextMock.argumentLiterals = + IntStream.range(0, testSpec.inputTypes.size()) + .mapToObj(i -> testSpec.literalPos != null && i == testSpec.literalPos) + .collect(Collectors.toList()); + callContextMock.argumentValues = + IntStream.range(0, testSpec.inputTypes.size()) + .mapToObj( + i -> + (testSpec.literalPos != null && i == testSpec.literalPos) + ? Optional.ofNullable(testSpec.literalValue) + : Optional.empty()) + .collect(Collectors.toList()); + + final TypeInference typeInference = + TypeInference.newBuilder() + .inputTypeStrategy(InputTypeStrategies.WILDCARD) + .outputTypeStrategy(testSpec.strategy) + .build(); + return TypeInferenceUtil.runTypeInference(typeInference, callContextMock, null); + } + + // -------------------------------------------------------------------------------------------- + + /** Specification of a test scenario. */ + public static class TestSpec { + + private @Nullable final String description; + + private final TypeStrategy strategy; + + private List<DataType> inputTypes; + + private @Nullable DataType expectedDataType; + + private @Nullable String expectedErrorMessage; + + private @Nullable Integer literalPos; + + private @Nullable Object literalValue; + + private boolean isGroupedAggregation; + + private TestSpec(@Nullable String description, TypeStrategy strategy) { + this.description = description; + this.strategy = strategy; + } + + public static TestSpec forStrategy(TypeStrategy strategy) { + return new TestSpec(null, strategy); + } + + public static TestSpec forStrategy(String description, TypeStrategy strategy) { + return new TestSpec(description, strategy); + } + + public TestSpec inputTypes(DataType... dataTypes) { + this.inputTypes = Arrays.asList(dataTypes); + return this; + } + + public TestSpec calledWithLiteralAt(int pos, Object value) { + this.literalPos = pos; + this.literalValue = value; + return this; + } + + public TestSpec calledWithGroupedAggregation() { + this.isGroupedAggregation = true; + return this; + } + + public TestSpec expectDataType(DataType expectedDataType) { + this.expectedDataType = expectedDataType; + return this; + } + + public TestSpec expectErrorMessage(String expectedErrorMessage) { + this.expectedErrorMessage = expectedErrorMessage; + return this; + } + + @Override + public String toString() { + return description != null ? description : ""; + } + } +} diff --git a/flink-table/flink-table-common/src/test/java/org/apache/flink/table/types/inference/strategies/ArrayTypeStrategyTest.java b/flink-table/flink-table-common/src/test/java/org/apache/flink/table/types/inference/strategies/ArrayTypeStrategyTest.java new file mode 100644 index 0000000..31ac3e7 --- /dev/null +++ b/flink-table/flink-table-common/src/test/java/org/apache/flink/table/types/inference/strategies/ArrayTypeStrategyTest.java @@ -0,0 +1,39 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.types.inference.strategies; + +import org.apache.flink.table.api.DataTypes; +import org.apache.flink.table.types.inference.TypeStrategiesTestBase; + +import org.junit.runners.Parameterized; + +import java.util.Arrays; +import java.util.List; + +/** Tests for {@link ArrayTypeStrategy}. */ +public class ArrayTypeStrategyTest extends TypeStrategiesTestBase { + + @Parameterized.Parameters(name = "{index}: {0}") + public static List<TestSpec> testData() { + return Arrays.asList( + TestSpec.forStrategy("Infer an array type", SpecificTypeStrategies.ARRAY) + .inputTypes(DataTypes.BIGINT(), DataTypes.BIGINT()) + .expectDataType(DataTypes.ARRAY(DataTypes.BIGINT()).notNull())); + } +} diff --git a/flink-table/flink-table-common/src/test/java/org/apache/flink/table/types/inference/strategies/CurrentWatermarkTypeStrategyTest.java b/flink-table/flink-table-common/src/test/java/org/apache/flink/table/types/inference/strategies/CurrentWatermarkTypeStrategyTest.java new file mode 100644 index 0000000..d56e588 --- /dev/null +++ b/flink-table/flink-table-common/src/test/java/org/apache/flink/table/types/inference/strategies/CurrentWatermarkTypeStrategyTest.java @@ -0,0 +1,69 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.types.inference.strategies; + +import org.apache.flink.table.api.DataTypes; +import org.apache.flink.table.types.DataType; +import org.apache.flink.table.types.inference.TypeStrategiesTestBase; +import org.apache.flink.table.types.logical.LocalZonedTimestampType; +import org.apache.flink.table.types.logical.TimestampKind; +import org.apache.flink.table.types.logical.TimestampType; +import org.apache.flink.table.types.utils.TypeConversions; + +import org.junit.runners.Parameterized; + +import java.util.Arrays; +import java.util.List; + +/** Tests for {@link CurrentWatermarkTypeStrategy}. */ +public class CurrentWatermarkTypeStrategyTest extends TypeStrategiesTestBase { + + @Parameterized.Parameters(name = "{index}: {0}") + public static List<TestSpec> testData() { + return Arrays.asList( + // CURRENT_WATERMARK + TestSpec.forStrategy( + "TIMESTAMP(3) *ROWTIME*", SpecificTypeStrategies.CURRENT_WATERMARK) + .inputTypes(createRowtimeType(TimestampKind.ROWTIME, 3).notNull()) + .expectDataType(DataTypes.TIMESTAMP(3)), + TestSpec.forStrategy( + "TIMESTAMP_LTZ(3) *ROWTIME*", + SpecificTypeStrategies.CURRENT_WATERMARK) + .inputTypes(createRowtimeLtzType(TimestampKind.ROWTIME, 3).notNull()) + .expectDataType(DataTypes.TIMESTAMP_LTZ(3)), + TestSpec.forStrategy( + "TIMESTAMP(9) *ROWTIME*", SpecificTypeStrategies.CURRENT_WATERMARK) + .inputTypes(createRowtimeType(TimestampKind.ROWTIME, 9).notNull()) + .expectDataType(DataTypes.TIMESTAMP(3)), + TestSpec.forStrategy( + "TIMESTAMP_LTZ(9) *ROWTIME*", + SpecificTypeStrategies.CURRENT_WATERMARK) + .inputTypes(createRowtimeLtzType(TimestampKind.ROWTIME, 9).notNull()) + .expectDataType(DataTypes.TIMESTAMP_LTZ(3))); + } + + static DataType createRowtimeType(TimestampKind kind, int precision) { + return TypeConversions.fromLogicalToDataType(new TimestampType(false, kind, precision)); + } + + static DataType createRowtimeLtzType(TimestampKind kind, int precision) { + return TypeConversions.fromLogicalToDataType( + new LocalZonedTimestampType(false, kind, precision)); + } +} diff --git a/flink-table/flink-table-common/src/test/java/org/apache/flink/table/types/inference/strategies/DecimalTypeStrategyTest.java b/flink-table/flink-table-common/src/test/java/org/apache/flink/table/types/inference/strategies/DecimalTypeStrategyTest.java new file mode 100644 index 0000000..0186154 --- /dev/null +++ b/flink-table/flink-table-common/src/test/java/org/apache/flink/table/types/inference/strategies/DecimalTypeStrategyTest.java @@ -0,0 +1,50 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.types.inference.strategies; + +import org.apache.flink.table.api.DataTypes; +import org.apache.flink.table.types.inference.TypeStrategiesTestBase; +import org.apache.flink.table.types.inference.TypeStrategy; + +import org.junit.runners.Parameterized; + +import java.util.Arrays; +import java.util.List; + +/** Tests for decimal {@link TypeStrategy TypeStrategies}. */ +public class DecimalTypeStrategyTest extends TypeStrategiesTestBase { + + @Parameterized.Parameters(name = "{index}: {0}") + public static List<TestSpec> testData() { + return Arrays.asList( + TestSpec.forStrategy("Find a decimal sum", SpecificTypeStrategies.DECIMAL_PLUS) + .inputTypes(DataTypes.DECIMAL(5, 4), DataTypes.DECIMAL(3, 2)) + .expectDataType(DataTypes.DECIMAL(6, 4).notNull()), + TestSpec.forStrategy( + "Find a decimal quotient", SpecificTypeStrategies.DECIMAL_DIVIDE) + .inputTypes(DataTypes.DECIMAL(5, 4), DataTypes.DECIMAL(3, 2)) + .expectDataType(DataTypes.DECIMAL(11, 8).notNull()), + TestSpec.forStrategy("Find a decimal product", SpecificTypeStrategies.DECIMAL_TIMES) + .inputTypes(DataTypes.DECIMAL(5, 4), DataTypes.DECIMAL(3, 2)) + .expectDataType(DataTypes.DECIMAL(9, 6).notNull()), + TestSpec.forStrategy("Find a decimal modulo", SpecificTypeStrategies.DECIMAL_MOD) + .inputTypes(DataTypes.DECIMAL(5, 4), DataTypes.DECIMAL(3, 2)) + .expectDataType(DataTypes.DECIMAL(5, 4).notNull())); + } +} diff --git a/flink-table/flink-table-common/src/test/java/org/apache/flink/table/types/inference/strategies/GetTypeStrategyTest.java b/flink-table/flink-table-common/src/test/java/org/apache/flink/table/types/inference/strategies/GetTypeStrategyTest.java new file mode 100644 index 0000000..d52b4f5 --- /dev/null +++ b/flink-table/flink-table-common/src/test/java/org/apache/flink/table/types/inference/strategies/GetTypeStrategyTest.java @@ -0,0 +1,137 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.types.inference.strategies; + +import org.apache.flink.table.api.DataTypes; +import org.apache.flink.table.catalog.ObjectIdentifier; +import org.apache.flink.table.types.FieldsDataType; +import org.apache.flink.table.types.inference.TypeStrategiesTestBase; +import org.apache.flink.table.types.logical.BigIntType; +import org.apache.flink.table.types.logical.StructuredType; + +import org.junit.runners.Parameterized; + +import java.util.Arrays; +import java.util.Collections; +import java.util.List; + +/** Tests for {@link GetTypeStrategy}. */ +public class GetTypeStrategyTest extends TypeStrategiesTestBase { + + @Parameterized.Parameters(name = "{index}: {0}") + public static List<TestSpec> testData() { + return Arrays.asList( + TestSpec.forStrategy( + "Access field of a row nullable type by name", + SpecificTypeStrategies.GET) + .inputTypes( + DataTypes.ROW(DataTypes.FIELD("f0", DataTypes.BIGINT().notNull())), + DataTypes.STRING().notNull()) + .calledWithLiteralAt(1, "f0") + .expectDataType(DataTypes.BIGINT().nullable()), + TestSpec.forStrategy( + "Access field of a row not null type by name", + SpecificTypeStrategies.GET) + .inputTypes( + DataTypes.ROW(DataTypes.FIELD("f0", DataTypes.BIGINT().notNull())) + .notNull(), + DataTypes.STRING().notNull()) + .calledWithLiteralAt(1, "f0") + .expectDataType(DataTypes.BIGINT().notNull()), + TestSpec.forStrategy( + "Access field of a structured nullable type by name", + SpecificTypeStrategies.GET) + .inputTypes( + new FieldsDataType( + StructuredType.newBuilder( + ObjectIdentifier.of( + "cat", "db", "type")) + .attributes( + Collections.singletonList( + new StructuredType + .StructuredAttribute( + "f0", + new BigIntType( + false)))) + .build(), + Collections.singletonList( + DataTypes.BIGINT().notNull())) + .nullable(), + DataTypes.STRING().notNull()) + .calledWithLiteralAt(1, "f0") + .expectDataType(DataTypes.BIGINT().nullable()), + TestSpec.forStrategy( + "Access field of a structured not null type by name", + SpecificTypeStrategies.GET) + .inputTypes( + new FieldsDataType( + StructuredType.newBuilder( + ObjectIdentifier.of( + "cat", "db", "type")) + .attributes( + Collections.singletonList( + new StructuredType + .StructuredAttribute( + "f0", + new BigIntType( + false)))) + .build(), + Collections.singletonList( + DataTypes.BIGINT().notNull())) + .notNull(), + DataTypes.STRING().notNull()) + .calledWithLiteralAt(1, "f0") + .expectDataType(DataTypes.BIGINT().notNull()), + TestSpec.forStrategy( + "Access field of a row nullable type by index", + SpecificTypeStrategies.GET) + .inputTypes( + DataTypes.ROW(DataTypes.FIELD("f0", DataTypes.BIGINT().notNull())), + DataTypes.INT().notNull()) + .calledWithLiteralAt(1, 0) + .expectDataType(DataTypes.BIGINT().nullable()), + TestSpec.forStrategy( + "Access field of a row not null type by index", + SpecificTypeStrategies.GET) + .inputTypes( + DataTypes.ROW(DataTypes.FIELD("f0", DataTypes.BIGINT().notNull())) + .notNull(), + DataTypes.INT().notNull()) + .calledWithLiteralAt(1, 0) + .expectDataType(DataTypes.BIGINT().notNull()), + TestSpec.forStrategy( + "Fields can be accessed only with a literal (name)", + SpecificTypeStrategies.GET) + .inputTypes( + DataTypes.ROW(DataTypes.FIELD("f0", DataTypes.BIGINT().notNull())) + .notNull(), + DataTypes.STRING().notNull()) + .expectErrorMessage( + "Could not infer an output type for the given arguments."), + TypeStrategiesTestBase.TestSpec.forStrategy( + "Fields can be accessed only with a literal (index)", + SpecificTypeStrategies.GET) + .inputTypes( + DataTypes.ROW(DataTypes.FIELD("f0", DataTypes.BIGINT().notNull())) + .notNull(), + DataTypes.INT().notNull()) + .expectErrorMessage( + "Could not infer an output type for the given arguments.")); + } +} diff --git a/flink-table/flink-table-common/src/test/java/org/apache/flink/table/types/inference/strategies/MapTypeStrategyTest.java b/flink-table/flink-table-common/src/test/java/org/apache/flink/table/types/inference/strategies/MapTypeStrategyTest.java new file mode 100644 index 0000000..722033a --- /dev/null +++ b/flink-table/flink-table-common/src/test/java/org/apache/flink/table/types/inference/strategies/MapTypeStrategyTest.java @@ -0,0 +1,41 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.types.inference.strategies; + +import org.apache.flink.table.api.DataTypes; +import org.apache.flink.table.types.inference.TypeStrategiesTestBase; + +import org.junit.runners.Parameterized; + +import java.util.Arrays; +import java.util.List; + +/** Tests for {@link MapTypeStrategy}. */ +public class MapTypeStrategyTest extends TypeStrategiesTestBase { + + @Parameterized.Parameters(name = "{index}: {0}") + public static List<TestSpec> testData() { + return Arrays.asList( + TestSpec.forStrategy("Infer a map type", SpecificTypeStrategies.MAP) + .inputTypes(DataTypes.BIGINT(), DataTypes.STRING().notNull()) + .expectDataType( + DataTypes.MAP(DataTypes.BIGINT(), DataTypes.STRING().notNull()) + .notNull())); + } +} diff --git a/flink-table/flink-table-common/src/test/java/org/apache/flink/table/types/inference/strategies/RowTypeStrategyTest.java b/flink-table/flink-table-common/src/test/java/org/apache/flink/table/types/inference/strategies/RowTypeStrategyTest.java new file mode 100644 index 0000000..13bcb4a --- /dev/null +++ b/flink-table/flink-table-common/src/test/java/org/apache/flink/table/types/inference/strategies/RowTypeStrategyTest.java @@ -0,0 +1,43 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.types.inference.strategies; + +import org.apache.flink.table.api.DataTypes; +import org.apache.flink.table.types.inference.TypeStrategiesTestBase; + +import org.junit.runners.Parameterized; + +import java.util.Arrays; +import java.util.List; + +/** Tests for {@link RowTypeStrategy}. */ +public class RowTypeStrategyTest extends TypeStrategiesTestBase { + + @Parameterized.Parameters(name = "{index}: {0}") + public static List<TestSpec> testData() { + return Arrays.asList( + TestSpec.forStrategy("Infer a row type", SpecificTypeStrategies.ROW) + .inputTypes(DataTypes.BIGINT(), DataTypes.STRING()) + .expectDataType( + DataTypes.ROW( + DataTypes.FIELD("f0", DataTypes.BIGINT()), + DataTypes.FIELD("f1", DataTypes.STRING())) + .notNull())); + } +} diff --git a/flink-table/flink-table-common/src/test/java/org/apache/flink/table/types/inference/strategies/StringConcatTypeStrategyTest.java b/flink-table/flink-table-common/src/test/java/org/apache/flink/table/types/inference/strategies/StringConcatTypeStrategyTest.java new file mode 100644 index 0000000..1428179 --- /dev/null +++ b/flink-table/flink-table-common/src/test/java/org/apache/flink/table/types/inference/strategies/StringConcatTypeStrategyTest.java @@ -0,0 +1,39 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.types.inference.strategies; + +import org.apache.flink.table.api.DataTypes; +import org.apache.flink.table.types.inference.TypeStrategiesTestBase; + +import org.junit.runners.Parameterized; + +import java.util.Arrays; +import java.util.List; + +/** Tests for {@link StringConcatTypeStrategy}. */ +public class StringConcatTypeStrategyTest extends TypeStrategiesTestBase { + + @Parameterized.Parameters(name = "{index}: {0}") + public static List<TestSpec> testData() { + return Arrays.asList( + TestSpec.forStrategy("Concat two strings", SpecificTypeStrategies.STRING_CONCAT) + .inputTypes(DataTypes.CHAR(12).notNull(), DataTypes.VARCHAR(12)) + .expectDataType(DataTypes.VARCHAR(24))); + } +}