This is an automated email from the ASF dual-hosted git repository. dwysakowicz pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/flink.git
The following commit(s) were added to refs/heads/master by this push: new bc6c2cec37c [FLINK-33663] Serialize CallExpressions into SQL (#23811) bc6c2cec37c is described below commit bc6c2cec37c45f021ae22a2a7b5ab9537b8506cd Author: Dawid Wysakowicz <dwysakow...@apache.org> AuthorDate: Tue Dec 5 09:46:45 2023 +0100 [FLINK-33663] Serialize CallExpressions into SQL (#23811) --- .../expressions/ExpressionSerializationTest.java | 364 +++++++++++++++++++++ .../flink/table/expressions/CallExpression.java | 27 ++ .../expressions/FieldReferenceExpression.java | 6 + .../table/expressions/TypeLiteralExpression.java | 10 + .../table/functions/BuiltInFunctionDefinition.java | 55 ++++ .../functions/BuiltInFunctionDefinitions.java | 135 +++++++- .../flink/table/functions/CallSyntaxUtils.java | 49 +++ .../table/functions/JsonFunctionsCallSyntax.java | 185 +++++++++++ .../flink/table/functions/SqlCallSyntax.java | 275 ++++++++++++++++ .../BuiltInAggregateFunctionTestBase.java | 185 ++++++++++- .../planner/functions/BuiltInFunctionTestBase.java | 35 +- .../functions/IfThenElseFunctionITCase.java | 55 ++++ .../functions/JsonAggregationFunctionsITCase.java | 67 ++-- 13 files changed, 1396 insertions(+), 52 deletions(-) diff --git a/flink-table/flink-table-api-java/src/test/java/org/apache/flink/table/expressions/ExpressionSerializationTest.java b/flink-table/flink-table-api-java/src/test/java/org/apache/flink/table/expressions/ExpressionSerializationTest.java new file mode 100644 index 00000000000..2693cb38517 --- /dev/null +++ b/flink-table/flink-table-api-java/src/test/java/org/apache/flink/table/expressions/ExpressionSerializationTest.java @@ -0,0 +1,364 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.expressions; + +import org.apache.flink.table.api.DataTypes; +import org.apache.flink.table.api.Expressions; +import org.apache.flink.table.api.JsonExistsOnError; +import org.apache.flink.table.api.JsonOnNull; +import org.apache.flink.table.api.JsonQueryOnEmptyOrError; +import org.apache.flink.table.api.JsonQueryWrapper; +import org.apache.flink.table.api.JsonType; +import org.apache.flink.table.api.JsonValueOnEmptyOrError; +import org.apache.flink.table.api.TableConfig; +import org.apache.flink.table.catalog.Column; +import org.apache.flink.table.catalog.ResolvedSchema; +import org.apache.flink.table.expressions.resolver.ExpressionResolver; +import org.apache.flink.table.functions.BuiltInFunctionDefinitions; +import org.apache.flink.table.operations.ValuesQueryOperation; +import org.apache.flink.table.types.DataType; +import org.apache.flink.table.types.utils.DataTypeFactoryMock; +import org.apache.flink.table.utils.FunctionLookupMock; + +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.stream.Stream; + +import static org.apache.flink.table.api.Expressions.$; +import static 
org.assertj.core.api.Assertions.assertThat; + +/** Tests for serializing {@link BuiltInFunctionDefinitions} into a SQL string. */ +public class ExpressionSerializationTest { + + public static Stream<TestSpec> testData() { + return Stream.of( + TestSpec.forExpr(Expressions.uuid()).expectStr("UUID()"), + TestSpec.forExpr($("f0").abs()) + .withField("f0", DataTypes.BIGINT()) + .expectStr("ABS(`f0`)"), + TestSpec.forExpr($("f0").isLess(123)) + .withField("f0", DataTypes.BIGINT()) + .expectStr("`f0` < 123"), + TestSpec.forExpr($("f0").isLessOrEqual(123)) + .withField("f0", DataTypes.BIGINT()) + .expectStr("`f0` <= 123"), + TestSpec.forExpr($("f0").isEqual(123)) + .withField("f0", DataTypes.BIGINT()) + .expectStr("`f0` = 123"), + TestSpec.forExpr($("f0").isNotEqual(123)) + .withField("f0", DataTypes.BIGINT()) + .expectStr("`f0` <> 123"), + TestSpec.forExpr($("f0").isGreaterOrEqual(123)) + .withField("f0", DataTypes.BIGINT()) + .expectStr("`f0` >= 123"), + TestSpec.forExpr($("f0").isGreater(123)) + .withField("f0", DataTypes.BIGINT()) + .expectStr("`f0` > 123"), + TestSpec.forExpr($("f0").isNull()) + .withField("f0", DataTypes.BOOLEAN()) + .expectStr("`f0` IS NULL"), + TestSpec.forExpr($("f0").isNotNull()) + .withField("f0", DataTypes.BOOLEAN()) + .expectStr("`f0` IS NOT NULL"), + TestSpec.forExpr($("f0").isTrue()) + .withField("f0", DataTypes.BOOLEAN()) + .expectStr("`f0` IS TRUE"), + TestSpec.forExpr($("f0").isNotTrue()) + .withField("f0", DataTypes.BOOLEAN()) + .expectStr("`f0` IS NOT TRUE"), + TestSpec.forExpr($("f0").isFalse()) + .withField("f0", DataTypes.BOOLEAN()) + .expectStr("`f0` IS FALSE"), + TestSpec.forExpr($("f0").isNotFalse()) + .withField("f0", DataTypes.BOOLEAN()) + .expectStr("`f0` IS NOT FALSE"), + TestSpec.forExpr($("f0").not()) + .withField("f0", DataTypes.BOOLEAN()) + .expectStr("NOT `f0`"), + TestSpec.forExpr( + Expressions.and( + $("f0").isNotNull(), + $("f0").isLess(420), + $("f0").isGreater(123))) + .withField("f0", DataTypes.BIGINT()) + 
.expectStr("(`f0` IS NOT NULL) AND (`f0` < 420) AND (`f0` > 123)"), + TestSpec.forExpr( + Expressions.or( + $("f0").isNotNull(), + $("f0").isLess(420), + $("f0").isGreater(123))) + .withField("f0", DataTypes.BIGINT()) + .expectStr("(`f0` IS NOT NULL) OR (`f0` < 420) OR (`f0` > 123)"), + TestSpec.forExpr( + Expressions.ifThenElse( + $("f0").isNotNull(), $("f0").plus(420), $("f0").minus(123))) + .withField("f0", DataTypes.BIGINT()) + .expectStr( + "CASE WHEN `f0` IS NOT NULL THEN `f0` + 420 ELSE `f0` - 123 END"), + TestSpec.forExpr($("f0").times(3).dividedBy($("f1"))) + .withField("f0", DataTypes.BIGINT()) + .withField("f1", DataTypes.BIGINT()) + .expectStr("(`f0` * 3) / `f1`"), + TestSpec.forExpr($("f0").mod(5)) + .withField("f0", DataTypes.BIGINT()) + .expectStr("`f0` % 5"), + TestSpec.forExpr(Expressions.negative($("f0"))) + .withField("f0", DataTypes.BIGINT()) + .expectStr("- `f0`"), + TestSpec.forExpr($("f0").in(1, 2, 3, 4, 5)) + .withField("f0", DataTypes.INT()) + .expectStr("`f0` IN (1, 2, 3, 4, 5)"), + TestSpec.forExpr($("f0").cast(DataTypes.SMALLINT())) + .withField("f0", DataTypes.BIGINT()) + .expectStr("CAST(`f0` AS SMALLINT)"), + TestSpec.forExpr($("f0").tryCast(DataTypes.SMALLINT())) + .withField("f0", DataTypes.BIGINT()) + .expectStr("TRY_CAST(`f0` AS SMALLINT)"), + TestSpec.forExpr(Expressions.array($("f0"), $("f1"), "ABC")) + .withField("f0", DataTypes.STRING()) + .withField("f1", DataTypes.STRING()) + .expectStr("ARRAY[`f0`, `f1`, 'ABC']"), + TestSpec.forExpr(Expressions.map($("f0"), $("f1"), "ABC", "DEF")) + .withField("f0", DataTypes.STRING()) + .withField("f1", DataTypes.STRING()) + .expectStr("MAP[`f0`, `f1`, 'ABC', 'DEF']"), + TestSpec.forExpr($("f0").at(2)) + .withField("f0", DataTypes.ARRAY(DataTypes.STRING())) + .expectStr("`f0`[2]"), + TestSpec.forExpr($("f0").at("abc")) + .withField("f0", DataTypes.MAP(DataTypes.STRING(), DataTypes.BIGINT())) + .expectStr("`f0`['abc']"), + TestSpec.forExpr($("f0").between(1, 10)) + .withField("f0", 
DataTypes.BIGINT()) + .expectStr("`f0` BETWEEN 1 AND 10"), + TestSpec.forExpr($("f0").notBetween(1, 10)) + .withField("f0", DataTypes.BIGINT()) + .expectStr("`f0` NOT BETWEEN 1 AND 10"), + TestSpec.forExpr($("f0").like("ABC")) + .withField("f0", DataTypes.STRING()) + .expectStr("`f0` LIKE 'ABC'"), + TestSpec.forExpr($("f0").similar("ABC")) + .withField("f0", DataTypes.STRING()) + .expectStr("`f0` SIMILAR TO 'ABC'"), + TestSpec.forExpr($("f0").position("ABC")) + .withField("f0", DataTypes.STRING()) + .expectStr("POSITION(`f0` IN 'ABC')"), + TestSpec.forExpr($("f0").trim("ABC")) + .withField("f0", DataTypes.STRING()) + .expectStr("TRIM BOTH 'ABC' FROM `f0`"), + TestSpec.forExpr($("f0").trimLeading("ABC")) + .withField("f0", DataTypes.STRING()) + .expectStr("TRIM LEADING 'ABC' FROM `f0`"), + TestSpec.forExpr($("f0").trimTrailing("ABC")) + .withField("f0", DataTypes.STRING()) + .expectStr("TRIM TRAILING 'ABC' FROM `f0`"), + TestSpec.forExpr($("f0").overlay("ABC", 2)) + .withField("f0", DataTypes.STRING()) + .expectStr("OVERLAY(`f0` PLACING 'ABC' FROM 2)"), + TestSpec.forExpr($("f0").overlay("ABC", 2, 5)) + .withField("f0", DataTypes.STRING()) + .expectStr("OVERLAY(`f0` PLACING 'ABC' FROM 2 FOR 5)"), + TestSpec.forExpr($("f0").substr(2)) + .withField("f0", DataTypes.STRING()) + .expectStr("SUBSTR(`f0` FROM 2)"), + TestSpec.forExpr($("f0").substr(2, 5)) + .withField("f0", DataTypes.STRING()) + .expectStr("SUBSTR(`f0` FROM 2 FOR 5)"), + TestSpec.forExpr($("f0").substring(2)) + .withField("f0", DataTypes.STRING()) + .expectStr("SUBSTRING(`f0` FROM 2)"), + TestSpec.forExpr($("f0").substring(2, 5)) + .withField("f0", DataTypes.STRING()) + .expectStr("SUBSTRING(`f0` FROM 2 FOR 5)"), + TestSpec.forExpr($("f0").extract(TimeIntervalUnit.HOUR)) + .withField("f0", DataTypes.TIMESTAMP()) + .expectStr("EXTRACT(HOUR FROM `f0`)"), + TestSpec.forExpr($("f0").floor(TimeIntervalUnit.HOUR)) + .withField("f0", DataTypes.TIMESTAMP()) + .expectStr("FLOOR(`f0` TO HOUR)"), + 
TestSpec.forExpr($("f0").ceil(TimeIntervalUnit.HOUR)) + .withField("f0", DataTypes.TIMESTAMP()) + .expectStr("CEIL(`f0` TO HOUR)"), + TestSpec.forExpr( + Expressions.temporalOverlaps( + $("f0"), $("f1"), + $("f2"), $("f3"))) + .withField("f0", DataTypes.TIMESTAMP()) + .withField("f1", DataTypes.TIMESTAMP()) + .withField("f2", DataTypes.TIMESTAMP()) + .withField("f3", DataTypes.TIMESTAMP()) + .expectStr("(`f0`, `f1`) OVERLAPS (`f2`, `f3`)"), + TestSpec.forExpr($("f0").get("g0").plus($("f0").get("g1").get("h1"))) + .withField( + "f0", + DataTypes.ROW( + DataTypes.FIELD("g0", DataTypes.BIGINT()), + DataTypes.FIELD( + "g1", + DataTypes.ROW( + DataTypes.FIELD( + "h1", DataTypes.BIGINT()))))) + .expectStr("(`f0`.`g0`) + (`f0`.`g1`.`h1`)"), + TestSpec.forExpr($("f0").abs().as("absolute`F0")) + .withField("f0", DataTypes.BIGINT()) + .expectStr("(ABS(`f0`)) AS `absolute``F0`"), + + // JSON functions + TestSpec.forExpr($("f0").isJson()) + .withField("f0", DataTypes.STRING()) + .expectStr("`f0` IS JSON"), + TestSpec.forExpr($("f0").isJson(JsonType.SCALAR)) + .withField("f0", DataTypes.STRING()) + .expectStr("`f0` IS JSON SCALAR"), + TestSpec.forExpr($("f0").jsonExists("$.a")) + .withField("f0", DataTypes.STRING()) + .expectStr("JSON_EXISTS(`f0`, '$.a')"), + TestSpec.forExpr($("f0").jsonExists("$.a", JsonExistsOnError.UNKNOWN)) + .withField("f0", DataTypes.STRING()) + .expectStr("JSON_EXISTS(`f0`, '$.a' UNKNOWN ON ERROR)"), + TestSpec.forExpr($("f0").jsonValue("$.a")) + .withField("f0", DataTypes.STRING()) + .expectStr( + "JSON_VALUE(`f0`, '$.a' RETURNING VARCHAR(2147483647) NULL ON EMPTY NULL ON ERROR)"), + TestSpec.forExpr($("f0").jsonValue("$.a", DataTypes.BOOLEAN(), false)) + .withField("f0", DataTypes.STRING()) + .expectStr( + "JSON_VALUE(`f0`, '$.a' RETURNING BOOLEAN DEFAULT FALSE ON EMPTY DEFAULT FALSE ON ERROR)"), + TestSpec.forExpr( + $("f0").jsonValue( + "$.a", + DataTypes.BIGINT(), + JsonValueOnEmptyOrError.DEFAULT, + 1, + JsonValueOnEmptyOrError.ERROR, + null)) + 
.withField("f0", DataTypes.STRING()) + .expectStr( + "JSON_VALUE(`f0`, '$.a' RETURNING BIGINT DEFAULT 1 ON EMPTY ERROR ON ERROR)"), + TestSpec.forExpr($("f0").jsonQuery("$.a")) + .withField("f0", DataTypes.STRING()) + .expectStr( + "JSON_QUERY(`f0`, '$.a' WITHOUT ARRAY WRAPPER NULL ON EMPTY NULL ON ERROR)"), + TestSpec.forExpr($("f0").jsonQuery("$.a", JsonQueryWrapper.UNCONDITIONAL_ARRAY)) + .withField("f0", DataTypes.STRING()) + .expectStr( + "JSON_QUERY(`f0`, '$.a' WITH UNCONDITIONAL ARRAY WRAPPER NULL ON EMPTY NULL ON ERROR)"), + TestSpec.forExpr( + $("f0").jsonQuery( + "$.a", + JsonQueryWrapper.CONDITIONAL_ARRAY, + JsonQueryOnEmptyOrError.EMPTY_OBJECT, + JsonQueryOnEmptyOrError.EMPTY_ARRAY)) + .withField("f0", DataTypes.STRING()) + .expectStr( + "JSON_QUERY(`f0`, '$.a' WITH CONDITIONAL ARRAY WRAPPER EMPTY OBJECT ON EMPTY EMPTY ARRAY ON ERROR)"), + TestSpec.forExpr( + Expressions.jsonObject(JsonOnNull.ABSENT, "k1", $("f0"), "k2", 123)) + .withField("f0", DataTypes.STRING()) + .expectStr( + "JSON_OBJECT(KEY 'k1' VALUE `f0`, KEY 'k2' VALUE 123 ABSENT ON NULL)"), + TestSpec.forExpr(Expressions.jsonArray(JsonOnNull.ABSENT, "k1", $("f0"), "k2")) + .withField("f0", DataTypes.STRING()) + .expectStr("JSON_ARRAY('k1', `f0`, 'k2' ABSENT ON NULL)"), + TestSpec.forExpr(Expressions.jsonArrayAgg(JsonOnNull.ABSENT, $("f0"))) + .withField("f0", DataTypes.STRING()) + .expectStr("JSON_ARRAYAGG(`f0` ABSENT ON NULL)"), + TestSpec.forExpr(Expressions.jsonArrayAgg(JsonOnNull.NULL, $("f0"))) + .withField("f0", DataTypes.STRING()) + .expectStr("JSON_ARRAYAGG(`f0` NULL ON NULL)"), + TestSpec.forExpr(Expressions.jsonObjectAgg(JsonOnNull.ABSENT, $("f0"), $("f1"))) + .withField("f0", DataTypes.STRING()) + .withField("f1", DataTypes.STRING()) + .expectStr("JSON_OBJECTAGG(KEY `f0` VALUE `f1` ABSENT ON NULL)"), + TestSpec.forExpr(Expressions.jsonObjectAgg(JsonOnNull.NULL, $("f0"), $("f1"))) + .withField("f0", DataTypes.STRING()) + .withField("f1", DataTypes.STRING()) + 
.expectStr("JSON_OBJECTAGG(KEY `f0` VALUE `f1` NULL ON NULL)"), + + // Aggregate functions + TestSpec.forExpr( + $("f0").count() + .distinct() + .plus($("f0").avg().distinct()) + .plus($("f0").max())) + .withField("f0", DataTypes.BIGINT()) + .expectStr( + "((COUNT(DISTINCT `f0`)) + (AVG(DISTINCT `f0`))) + (MAX(`f0`))")); + } + + @ParameterizedTest + @MethodSource("testData") + void testSerialization(TestSpec spec) { + final List<ResolvedExpression> resolved = + ExpressionResolver.resolverFor( + TableConfig.getDefault(), + Thread.currentThread().getContextClassLoader(), + name -> Optional.empty(), + new FunctionLookupMock(Collections.emptyMap()), + new DataTypeFactoryMock(), + (sqlExpression, inputRowType, outputType) -> null, + new ValuesQueryOperation( + Collections.emptyList(), + ResolvedSchema.of(new ArrayList<>(spec.columns.values())))) + .build() + .resolve(Collections.singletonList(spec.expr)); + + assertThat(resolved) + .hasSize(1) + .extracting(ResolvedExpression::asSerializableString) + .containsOnly(spec.expectedStr); + } + + private static class TestSpec { + private final Expression expr; + private String expectedStr; + + private final Map<String, Column> columns = new HashMap<>(); + + public TestSpec(Expression expr) { + this.expr = expr; + } + + public static TestSpec forExpr(Expression expr) { + return new TestSpec(expr); + } + + public TestSpec withField(String name, DataType dataType) { + this.columns.put(name, Column.physical(name, dataType)); + return this; + } + + public TestSpec expectStr(String expected) { + this.expectedStr = expected; + return this; + } + + @Override + public String toString() { + return expr.asSummaryString(); + } + } +} diff --git a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/expressions/CallExpression.java b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/expressions/CallExpression.java index 24847ae143a..4465856977c 100644 --- 
a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/expressions/CallExpression.java +++ b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/expressions/CallExpression.java @@ -20,13 +20,16 @@ package org.apache.flink.table.expressions; import org.apache.flink.annotation.Internal; import org.apache.flink.annotation.PublicEvolving; +import org.apache.flink.table.api.TableException; import org.apache.flink.table.catalog.Catalog; import org.apache.flink.table.catalog.ObjectIdentifier; import org.apache.flink.table.functions.BuiltInFunctionDefinition; import org.apache.flink.table.functions.FunctionDefinition; import org.apache.flink.table.functions.FunctionIdentifier; +import org.apache.flink.table.functions.SqlCallSyntax; import org.apache.flink.table.module.Module; import org.apache.flink.table.types.DataType; +import org.apache.flink.table.utils.EncodingUtils; import org.apache.flink.util.Preconditions; import javax.annotation.Nullable; @@ -216,6 +219,30 @@ public final class CallExpression implements ResolvedExpression { return getFunctionName() + argList; } + @Override + public String asSerializableString() { + if (functionDefinition instanceof BuiltInFunctionDefinition) { + final BuiltInFunctionDefinition definition = + (BuiltInFunctionDefinition) functionDefinition; + return definition.getCallSyntax().unparse(definition.getSqlName(), args); + } else { + return SqlCallSyntax.FUNCTION.unparse(getSerializableFunctionName(), args); + } + } + + private String getSerializableFunctionName() { + if (functionIdentifier == null) { + throw new TableException( + "Only functions that have been registered before are serializable."); + } + + return functionIdentifier + .getIdentifier() + .map(ObjectIdentifier::asSerializableString) + .orElseGet( + () -> EncodingUtils.escapeIdentifier(functionIdentifier.getFunctionName())); + } + @Override public List<Expression> getChildren() { return Collections.unmodifiableList(this.args); diff --git 
a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/expressions/FieldReferenceExpression.java b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/expressions/FieldReferenceExpression.java index 5afb57ac1a8..42f19ca3d66 100644 --- a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/expressions/FieldReferenceExpression.java +++ b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/expressions/FieldReferenceExpression.java @@ -20,6 +20,7 @@ package org.apache.flink.table.expressions; import org.apache.flink.annotation.PublicEvolving; import org.apache.flink.table.types.DataType; +import org.apache.flink.table.utils.EncodingUtils; import org.apache.flink.util.Preconditions; import java.util.Collections; @@ -88,6 +89,11 @@ public final class FieldReferenceExpression implements ResolvedExpression { return name; } + @Override + public String asSerializableString() { + return EncodingUtils.escapeIdentifier(name); + } + @Override public List<Expression> getChildren() { return Collections.emptyList(); diff --git a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/expressions/TypeLiteralExpression.java b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/expressions/TypeLiteralExpression.java index c303736b42f..61c05ddb138 100644 --- a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/expressions/TypeLiteralExpression.java +++ b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/expressions/TypeLiteralExpression.java @@ -20,6 +20,7 @@ package org.apache.flink.table.expressions; import org.apache.flink.annotation.PublicEvolving; import org.apache.flink.table.types.DataType; +import org.apache.flink.table.types.logical.LogicalType; import org.apache.flink.util.Preconditions; import java.util.Collections; @@ -57,6 +58,15 @@ public final class TypeLiteralExpression implements ResolvedExpression { return dataType.toString(); } + @Override + 
public String asSerializableString() { + // In SQL, nullability is not part of the type; it is an additional constraint on + // table columns. We therefore serialize a nullable copy of the type so that the + // string representation can be used in SQL, e.g. CAST(f0 AS BIGINT). + final LogicalType logicalType = dataType.getLogicalType(); + return logicalType.copy(true).asSerializableString(); + } + @Override public List<Expression> getChildren() { return Collections.emptyList(); diff --git a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/functions/BuiltInFunctionDefinition.java b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/functions/BuiltInFunctionDefinition.java index 81241e2ecb7..fce8c4664fa 100644 --- a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/functions/BuiltInFunctionDefinition.java +++ b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/functions/BuiltInFunctionDefinition.java @@ -69,16 +69,23 @@ public final class BuiltInFunctionDefinition implements SpecializedFunction { private final boolean isInternal; + private final SqlCallSyntax sqlCallSyntax; + + private final String sqlName; + private BuiltInFunctionDefinition( String name, + String sqlName, int version, FunctionKind kind, TypeInference typeInference, + SqlCallSyntax sqlCallSyntax, boolean isDeterministic, boolean isRuntimeProvided, String runtimeClass, boolean isInternal) { this.name = checkNotNull(name, "Name must not be null."); + this.sqlName = sqlName; this.version = isInternal ? 
null : version; this.kind = checkNotNull(kind, "Kind must not be null."); this.typeInference = checkNotNull(typeInference, "Type inference must not be null."); @@ -86,6 +93,7 @@ public final class BuiltInFunctionDefinition implements SpecializedFunction { this.isRuntimeProvided = isRuntimeProvided; this.runtimeClass = runtimeClass; this.isInternal = isInternal; + this.sqlCallSyntax = sqlCallSyntax; validateFunction(this.name, this.version, this.isInternal); } @@ -98,6 +106,14 @@ public final class BuiltInFunctionDefinition implements SpecializedFunction { return name; } + public String getSqlName() { + if (sqlName != null) { + return sqlName; + } + + return getName().toUpperCase(Locale.ROOT); + } + public Optional<Integer> getVersion() { return Optional.ofNullable(version); } @@ -163,6 +179,10 @@ public final class BuiltInFunctionDefinition implements SpecializedFunction { return typeInference; } + public SqlCallSyntax getCallSyntax() { + return sqlCallSyntax; + } + @Override public boolean isDeterministic() { return isDeterministic; @@ -214,6 +234,8 @@ public final class BuiltInFunctionDefinition implements SpecializedFunction { private String name; + private String sqlName; + private int version = DEFAULT_VERSION; private FunctionKind kind; @@ -228,6 +250,8 @@ public final class BuiltInFunctionDefinition implements SpecializedFunction { private boolean isInternal = false; + private SqlCallSyntax sqlCallSyntax = SqlCallSyntax.FUNCTION; + public Builder() { // default constructor to allow a fluent definition } @@ -327,12 +351,43 @@ public final class BuiltInFunctionDefinition implements SpecializedFunction { return this; } + /** + * Overwrites the syntax used for unparsing a function into a SQL string. If not specified, + * {@link SqlCallSyntax#FUNCTION} is used. + */ + public Builder callSyntax(SqlCallSyntax syntax) { + this.sqlCallSyntax = syntax; + return this; + } + + /** + * Overwrites the syntax used for unparsing a function into a SQL string. 
If not specified, + * {@link SqlCallSyntax#FUNCTION} is used. This method overwrites the name as well. If the + * name is not provided, the upper-cased {@link #name(String)} is used for unparsing. + */ + public Builder callSyntax(String name, SqlCallSyntax syntax) { + this.sqlName = name; + this.sqlCallSyntax = syntax; + return this; + } + + /** + * Overwrites the name that is used for unparsing a function into a SQL string. If not + * specified, the upper-cased {@link #name(String)} is used. + */ + public Builder sqlName(String name) { + this.sqlName = name; + return this; + } + public BuiltInFunctionDefinition build() { return new BuiltInFunctionDefinition( name, + sqlName, version, kind, typeInferenceBuilder.build(), + sqlCallSyntax, isDeterministic, isRuntimeProvided, runtimeClass, diff --git a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/functions/BuiltInFunctionDefinitions.java b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/functions/BuiltInFunctionDefinitions.java index 8f12cdd9b05..08a1b13dfeb 100644 --- a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/functions/BuiltInFunctionDefinitions.java +++ b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/functions/BuiltInFunctionDefinitions.java @@ -28,7 +28,9 @@ import org.apache.flink.table.api.JsonQueryWrapper; import org.apache.flink.table.api.JsonType; import org.apache.flink.table.api.JsonValueOnEmptyOrError; import org.apache.flink.table.api.TableException; +import org.apache.flink.table.expressions.TimeIntervalUnit; import org.apache.flink.table.expressions.TimePointUnit; +import org.apache.flink.table.expressions.ValueLiteralExpression; import org.apache.flink.table.types.inference.ArgumentTypeStrategy; import org.apache.flink.table.types.inference.ConstantArgumentCount; import org.apache.flink.table.types.inference.InputTypeStrategies; @@ -44,6 +46,7 @@ import org.apache.flink.table.types.logical.StructuredType.StructuredComparison; 
import org.apache.flink.table.types.logical.TimestampKind; import org.apache.flink.table.types.logical.utils.LogicalTypeMerging; import org.apache.flink.table.types.utils.TypeConversions; +import org.apache.flink.table.utils.EncodingUtils; import org.apache.flink.util.Preconditions; import java.lang.reflect.Field; @@ -52,6 +55,7 @@ import java.util.Arrays; import java.util.Collections; import java.util.HashSet; import java.util.List; +import java.util.Optional; import java.util.Set; import static org.apache.flink.table.api.DataTypes.BIGINT; @@ -374,6 +378,7 @@ public final class BuiltInFunctionDefinitions { public static final BuiltInFunctionDefinition AND = BuiltInFunctionDefinition.newBuilder() .name("and") + .callSyntax("AND", SqlCallSyntax.MULTIPLE_BINARY_OP) .kind(SCALAR) .inputTypeStrategy( varyingSequence( @@ -386,6 +391,7 @@ public final class BuiltInFunctionDefinitions { public static final BuiltInFunctionDefinition OR = BuiltInFunctionDefinition.newBuilder() .name("or") + .callSyntax("OR", SqlCallSyntax.MULTIPLE_BINARY_OP) .kind(SCALAR) .inputTypeStrategy( varyingSequence( @@ -398,6 +404,7 @@ public final class BuiltInFunctionDefinitions { public static final BuiltInFunctionDefinition NOT = BuiltInFunctionDefinition.newBuilder() .name("not") + .callSyntax("NOT", SqlCallSyntax.UNARY_PREFIX_OP) .kind(SCALAR) .inputTypeStrategy(sequence(logical(LogicalTypeRoot.BOOLEAN))) .outputTypeStrategy(nullableIfArgs(explicit(DataTypes.BOOLEAN()))) @@ -406,6 +413,13 @@ public final class BuiltInFunctionDefinitions { public static final BuiltInFunctionDefinition IF = BuiltInFunctionDefinition.newBuilder() .name("ifThenElse") + .callSyntax( + (sqlName, operands) -> + String.format( + "CASE WHEN %s THEN %s ELSE %s END", + operands.get(0).asSerializableString(), + operands.get(1).asSerializableString(), + operands.get(2).asSerializableString())) .kind(SCALAR) .inputTypeStrategy( compositeSequence() @@ -422,6 +436,7 @@ public final class BuiltInFunctionDefinitions { public 
static final BuiltInFunctionDefinition EQUALS = BuiltInFunctionDefinition.newBuilder() .name("equals") + .callSyntax("=", SqlCallSyntax.BINARY_OP) .kind(SCALAR) .inputTypeStrategy(TWO_EQUALS_COMPARABLE) .outputTypeStrategy(nullableIfArgs(explicit(DataTypes.BOOLEAN()))) @@ -430,6 +445,7 @@ public final class BuiltInFunctionDefinitions { public static final BuiltInFunctionDefinition GREATER_THAN = BuiltInFunctionDefinition.newBuilder() .name("greaterThan") + .callSyntax(">", SqlCallSyntax.BINARY_OP) .kind(SCALAR) .inputTypeStrategy(TWO_FULLY_COMPARABLE) .outputTypeStrategy(nullableIfArgs(explicit(DataTypes.BOOLEAN()))) @@ -438,6 +454,7 @@ public final class BuiltInFunctionDefinitions { public static final BuiltInFunctionDefinition GREATER_THAN_OR_EQUAL = BuiltInFunctionDefinition.newBuilder() .name("greaterThanOrEqual") + .callSyntax(">=", SqlCallSyntax.BINARY_OP) .kind(SCALAR) .inputTypeStrategy(TWO_FULLY_COMPARABLE) .outputTypeStrategy(nullableIfArgs(explicit(DataTypes.BOOLEAN()))) @@ -446,6 +463,7 @@ public final class BuiltInFunctionDefinitions { public static final BuiltInFunctionDefinition LESS_THAN = BuiltInFunctionDefinition.newBuilder() .name("lessThan") + .callSyntax("<", SqlCallSyntax.BINARY_OP) .kind(SCALAR) .inputTypeStrategy(TWO_FULLY_COMPARABLE) .outputTypeStrategy(nullableIfArgs(explicit(DataTypes.BOOLEAN()))) @@ -454,6 +472,7 @@ public final class BuiltInFunctionDefinitions { public static final BuiltInFunctionDefinition LESS_THAN_OR_EQUAL = BuiltInFunctionDefinition.newBuilder() .name("lessThanOrEqual") + .callSyntax("<=", SqlCallSyntax.BINARY_OP) .kind(SCALAR) .inputTypeStrategy(TWO_FULLY_COMPARABLE) .outputTypeStrategy(nullableIfArgs(explicit(DataTypes.BOOLEAN()))) @@ -462,6 +481,7 @@ public final class BuiltInFunctionDefinitions { public static final BuiltInFunctionDefinition NOT_EQUALS = BuiltInFunctionDefinition.newBuilder() .name("notEquals") + .callSyntax("<>", SqlCallSyntax.BINARY_OP) .kind(SCALAR) .inputTypeStrategy(TWO_EQUALS_COMPARABLE) 
.outputTypeStrategy(nullableIfArgs(explicit(DataTypes.BOOLEAN()))) @@ -470,6 +490,7 @@ public final class BuiltInFunctionDefinitions { public static final BuiltInFunctionDefinition IS_NULL = BuiltInFunctionDefinition.newBuilder() .name("isNull") + .callSyntax("IS NULL", SqlCallSyntax.UNARY_SUFFIX_OP) .kind(SCALAR) .inputTypeStrategy(wildcardWithCount(ConstantArgumentCount.of(1))) .outputTypeStrategy(explicit(DataTypes.BOOLEAN().notNull())) @@ -478,6 +499,7 @@ public final class BuiltInFunctionDefinitions { public static final BuiltInFunctionDefinition IS_NOT_NULL = BuiltInFunctionDefinition.newBuilder() .name("isNotNull") + .callSyntax("IS NOT NULL", SqlCallSyntax.UNARY_SUFFIX_OP) .kind(SCALAR) .inputTypeStrategy(wildcardWithCount(ConstantArgumentCount.of(1))) .outputTypeStrategy(explicit(DataTypes.BOOLEAN().notNull())) @@ -486,6 +508,7 @@ public final class BuiltInFunctionDefinitions { public static final BuiltInFunctionDefinition IS_TRUE = BuiltInFunctionDefinition.newBuilder() .name("isTrue") + .callSyntax("IS TRUE", SqlCallSyntax.UNARY_SUFFIX_OP) .kind(SCALAR) .inputTypeStrategy(sequence(logical(LogicalTypeRoot.BOOLEAN))) .outputTypeStrategy(explicit(DataTypes.BOOLEAN().notNull())) @@ -494,6 +517,7 @@ public final class BuiltInFunctionDefinitions { public static final BuiltInFunctionDefinition IS_FALSE = BuiltInFunctionDefinition.newBuilder() .name("isFalse") + .callSyntax("IS FALSE", SqlCallSyntax.UNARY_SUFFIX_OP) .kind(SCALAR) .inputTypeStrategy(sequence(logical(LogicalTypeRoot.BOOLEAN))) .outputTypeStrategy(explicit(DataTypes.BOOLEAN().notNull())) @@ -502,6 +526,7 @@ public final class BuiltInFunctionDefinitions { public static final BuiltInFunctionDefinition IS_NOT_TRUE = BuiltInFunctionDefinition.newBuilder() .name("isNotTrue") + .callSyntax("IS NOT TRUE", SqlCallSyntax.UNARY_SUFFIX_OP) .kind(SCALAR) .inputTypeStrategy(sequence(logical(LogicalTypeRoot.BOOLEAN))) .outputTypeStrategy(explicit(DataTypes.BOOLEAN().notNull())) @@ -510,6 +535,7 @@ public final 
class BuiltInFunctionDefinitions { public static final BuiltInFunctionDefinition IS_NOT_FALSE = BuiltInFunctionDefinition.newBuilder() .name("isNotFalse") + .callSyntax("IS NOT FALSE", SqlCallSyntax.UNARY_SUFFIX_OP) .kind(SCALAR) .inputTypeStrategy(sequence(logical(LogicalTypeRoot.BOOLEAN))) .outputTypeStrategy(explicit(DataTypes.BOOLEAN().notNull())) @@ -519,6 +545,13 @@ public final class BuiltInFunctionDefinitions { BuiltInFunctionDefinition.newBuilder() .name("between") .kind(SCALAR) + .callSyntax( + (sqlName, operands) -> + String.format( + "%s BETWEEN %s AND %s", + CallSyntaxUtils.asSerializableOperand(operands.get(0)), + CallSyntaxUtils.asSerializableOperand(operands.get(1)), + CallSyntaxUtils.asSerializableOperand(operands.get(2)))) .inputTypeStrategy( comparable(ConstantArgumentCount.of(3), StructuredComparison.FULL)) .outputTypeStrategy(nullableIfArgs(explicit(DataTypes.BOOLEAN()))) @@ -527,6 +560,13 @@ public final class BuiltInFunctionDefinitions { public static final BuiltInFunctionDefinition NOT_BETWEEN = BuiltInFunctionDefinition.newBuilder() .name("notBetween") + .callSyntax( + (sqlName, operands) -> + String.format( + "%s NOT BETWEEN %s AND %s", + CallSyntaxUtils.asSerializableOperand(operands.get(0)), + CallSyntaxUtils.asSerializableOperand(operands.get(1)), + CallSyntaxUtils.asSerializableOperand(operands.get(2)))) .kind(SCALAR) .inputTypeStrategy( comparable(ConstantArgumentCount.of(3), StructuredComparison.FULL)) @@ -679,6 +719,7 @@ public final class BuiltInFunctionDefinitions { public static final BuiltInFunctionDefinition DISTINCT = BuiltInFunctionDefinition.newBuilder() .name("distinct") + .callSyntax(SqlCallSyntax.DISTINCT) .kind(AGGREGATE) .inputTypeStrategy(sequence(ANY)) .outputTypeStrategy(argument(0)) @@ -707,6 +748,7 @@ public final class BuiltInFunctionDefinitions { public static final BuiltInFunctionDefinition LIKE = BuiltInFunctionDefinition.newBuilder() .name("like") + .callSyntax("LIKE", SqlCallSyntax.BINARY_OP) .kind(SCALAR) 
.inputTypeStrategy( sequence( @@ -736,6 +778,7 @@ public final class BuiltInFunctionDefinitions { public static final BuiltInFunctionDefinition SIMILAR = BuiltInFunctionDefinition.newBuilder() .name("similar") + .callSyntax("SIMILAR TO", SqlCallSyntax.BINARY_OP) .kind(SCALAR) .inputTypeStrategy( sequence( @@ -747,6 +790,7 @@ public final class BuiltInFunctionDefinitions { public static final BuiltInFunctionDefinition SUBSTRING = BuiltInFunctionDefinition.newBuilder() .name("substring") + .callSyntax("SUBSTRING", SqlCallSyntax.SUBSTRING) .kind(SCALAR) .inputTypeStrategy( or( @@ -763,6 +807,7 @@ public final class BuiltInFunctionDefinitions { public static final BuiltInFunctionDefinition SUBSTR = BuiltInFunctionDefinition.newBuilder() .name("substr") + .callSyntax("SUBSTR", SqlCallSyntax.SUBSTRING) .kind(SCALAR) .inputTypeStrategy( or( @@ -791,6 +836,7 @@ public final class BuiltInFunctionDefinitions { public static final BuiltInFunctionDefinition TRIM = BuiltInFunctionDefinition.newBuilder() .name("trim") + .callSyntax(SqlCallSyntax.TRIM) .kind(SCALAR) .inputTypeStrategy( sequence( @@ -822,6 +868,12 @@ public final class BuiltInFunctionDefinitions { public static final BuiltInFunctionDefinition POSITION = BuiltInFunctionDefinition.newBuilder() .name("position") + .callSyntax( + (sqlName, operands) -> + String.format( + "POSITION(%s IN %s)", + operands.get(0).asSerializableString(), + operands.get(1).asSerializableString())) .kind(SCALAR) .inputTypeStrategy( sequence( @@ -833,6 +885,7 @@ public final class BuiltInFunctionDefinitions { public static final BuiltInFunctionDefinition OVERLAY = BuiltInFunctionDefinition.newBuilder() .name("overlay") + .callSyntax(SqlCallSyntax.OVERLAY) .kind(SCALAR) .inputTypeStrategy( or( @@ -901,6 +954,7 @@ public final class BuiltInFunctionDefinitions { public static final BuiltInFunctionDefinition REGEXP_EXTRACT = BuiltInFunctionDefinition.newBuilder() .name("regexpExtract") + .sqlName("REGEXP_EXTRACT") .kind(SCALAR) 
.inputTypeStrategy( or( @@ -1150,6 +1204,7 @@ public final class BuiltInFunctionDefinitions { public static final BuiltInFunctionDefinition PLUS = BuiltInFunctionDefinition.newBuilder() .name("plus") + .callSyntax("+", SqlCallSyntax.BINARY_OP) .kind(SCALAR) .inputTypeStrategy( or( @@ -1220,6 +1275,7 @@ public final class BuiltInFunctionDefinitions { public static final BuiltInFunctionDefinition MINUS = BuiltInFunctionDefinition.newBuilder() .name("minus") + .callSyntax("-", SqlCallSyntax.BINARY_OP) .kind(SCALAR) .inputTypeStrategy( or( @@ -1248,6 +1304,7 @@ public final class BuiltInFunctionDefinitions { public static final BuiltInFunctionDefinition AGG_DECIMAL_MINUS = BuiltInFunctionDefinition.newBuilder() .name("AGG_DECIMAL_MINUS") + .callSyntax("-", SqlCallSyntax.BINARY_OP) .kind(SCALAR) .inputTypeStrategy( sequence( @@ -1260,6 +1317,7 @@ public final class BuiltInFunctionDefinitions { public static final BuiltInFunctionDefinition DIVIDE = BuiltInFunctionDefinition.newBuilder() .name("divide") + .callSyntax("/", SqlCallSyntax.BINARY_OP) .kind(SCALAR) .inputTypeStrategy( or( @@ -1280,6 +1338,7 @@ public final class BuiltInFunctionDefinitions { public static final BuiltInFunctionDefinition TIMES = BuiltInFunctionDefinition.newBuilder() .name("times") + .callSyntax("*", SqlCallSyntax.BINARY_OP) .kind(SCALAR) .inputTypeStrategy( or( @@ -1322,6 +1381,7 @@ public final class BuiltInFunctionDefinitions { public static final BuiltInFunctionDefinition FLOOR = BuiltInFunctionDefinition.newBuilder() .name("floor") + .callSyntax("FLOOR", SqlCallSyntax.FLOOR_OR_CEIL) .kind(SCALAR) .inputTypeStrategy( or( @@ -1338,6 +1398,7 @@ public final class BuiltInFunctionDefinitions { public static final BuiltInFunctionDefinition CEIL = BuiltInFunctionDefinition.newBuilder() .name("ceil") + .callSyntax("CEIL", SqlCallSyntax.FLOOR_OR_CEIL) .kind(SCALAR) .inputTypeStrategy( or( @@ -1402,6 +1463,7 @@ public final class BuiltInFunctionDefinitions { public static final 
BuiltInFunctionDefinition MOD = BuiltInFunctionDefinition.newBuilder() .name("mod") + .callSyntax("%", SqlCallSyntax.BINARY_OP) .kind(SCALAR) .inputTypeStrategy( sequence( @@ -1422,6 +1484,7 @@ public final class BuiltInFunctionDefinitions { public static final BuiltInFunctionDefinition MINUS_PREFIX = BuiltInFunctionDefinition.newBuilder() .name("minusPrefix") + .callSyntax("-", SqlCallSyntax.UNARY_PREFIX_OP) .kind(SCALAR) .inputTypeStrategy( or( @@ -1649,6 +1712,17 @@ public final class BuiltInFunctionDefinitions { public static final BuiltInFunctionDefinition EXTRACT = BuiltInFunctionDefinition.newBuilder() .name("extract") + .callSyntax( + "EXTRACT", + (sqlName, operands) -> + String.format( + "%s(%s %s %s)", + sqlName, + ((ValueLiteralExpression) operands.get(0)) + .getValueAs(TimeIntervalUnit.class) + .get(), + "FROM", + operands.get(1).asSerializableString())) .kind(SCALAR) .inputTypeStrategy(SpecificInputTypeStrategies.EXTRACT) .outputTypeStrategy(nullableIfArgs(explicit(BIGINT()))) @@ -1707,6 +1781,14 @@ public final class BuiltInFunctionDefinitions { public static final BuiltInFunctionDefinition TEMPORAL_OVERLAPS = BuiltInFunctionDefinition.newBuilder() .name("temporalOverlaps") + .callSyntax( + (sqlName, operands) -> + String.format( + "(%s, %s) OVERLAPS (%s, %s)", + operands.get(0).asSerializableString(), + operands.get(1).asSerializableString(), + operands.get(2).asSerializableString(), + operands.get(3).asSerializableString())) .kind(SCALAR) .inputTypeStrategy(SpecificInputTypeStrategies.TEMPORAL_OVERLAPS) .outputTypeStrategy(nullableIfArgs(explicit(BOOLEAN()))) @@ -1830,6 +1912,12 @@ public final class BuiltInFunctionDefinitions { public static final BuiltInFunctionDefinition AT = BuiltInFunctionDefinition.newBuilder() .name("at") + .callSyntax( + (sqlName, operands) -> + String.format( + "%s[%s]", + operands.get(0).asSerializableString(), + operands.get(1).asSerializableString())) .kind(SCALAR) .inputTypeStrategy( sequence( @@ -1855,6 +1943,7 @@ 
public final class BuiltInFunctionDefinitions { public static final BuiltInFunctionDefinition ARRAY = BuiltInFunctionDefinition.newBuilder() .name("array") + .callSyntax("ARRAY", SqlCallSyntax.COLLECTION_CTOR) .kind(SCALAR) .inputTypeStrategy(SpecificInputTypeStrategies.ARRAY) .outputTypeStrategy(SpecificTypeStrategies.ARRAY) @@ -1871,6 +1960,7 @@ public final class BuiltInFunctionDefinitions { public static final BuiltInFunctionDefinition MAP = BuiltInFunctionDefinition.newBuilder() .name("map") + .callSyntax("MAP", SqlCallSyntax.COLLECTION_CTOR) .kind(SCALAR) .inputTypeStrategy(SpecificInputTypeStrategies.MAP) .outputTypeStrategy(SpecificTypeStrategies.MAP) @@ -1904,6 +1994,25 @@ public final class BuiltInFunctionDefinitions { public static final BuiltInFunctionDefinition GET = BuiltInFunctionDefinition.newBuilder() .name("get") + .callSyntax( + (sqlName, operands) -> { + final Optional<String> fieldName = + ((ValueLiteralExpression) operands.get(1)) + .getValueAs(String.class); + + return fieldName + .map( + n -> + String.format( + "%s.%s", + operands.get(0) + .asSerializableString(), + EncodingUtils.escapeIdentifier(n))) + .orElseGet( + () -> + SqlCallSyntax.FUNCTION.unparse( + sqlName, operands)); + }) .kind(OTHER) .inputTypeStrategy( sequence( @@ -2092,6 +2201,7 @@ public final class BuiltInFunctionDefinitions { public static final BuiltInFunctionDefinition IS_JSON = BuiltInFunctionDefinition.newBuilder() .name("IS_JSON") + .callSyntax(JsonFunctionsCallSyntax.IS_JSON) .kind(SCALAR) .inputTypeStrategy( or( @@ -2099,13 +2209,14 @@ public final class BuiltInFunctionDefinitions { sequence( logical(LogicalTypeFamily.CHARACTER_STRING), symbol(JsonType.class)))) - .outputTypeStrategy(explicit(DataTypes.BOOLEAN().notNull())) + .outputTypeStrategy(explicit(BOOLEAN().notNull())) .runtimeDeferred() .build(); public static final BuiltInFunctionDefinition JSON_EXISTS = BuiltInFunctionDefinition.newBuilder() .name("JSON_EXISTS") + .callSyntax("JSON_EXISTS", 
JsonFunctionsCallSyntax.JSON_EXISTS) .kind(SCALAR) .inputTypeStrategy( or( @@ -2120,13 +2231,14 @@ public final class BuiltInFunctionDefinitions { logical(LogicalTypeFamily.CHARACTER_STRING), LITERAL), symbol(JsonExistsOnError.class)))) - .outputTypeStrategy(explicit(DataTypes.BOOLEAN().nullable())) + .outputTypeStrategy(explicit(BOOLEAN().nullable())) .runtimeDeferred() .build(); public static final BuiltInFunctionDefinition JSON_VALUE = BuiltInFunctionDefinition.newBuilder() .name("JSON_VALUE") + .callSyntax("JSON_VALUE", JsonFunctionsCallSyntax.JSON_VALUE) .kind(SCALAR) .inputTypeStrategy( sequence( @@ -2145,6 +2257,7 @@ public final class BuiltInFunctionDefinitions { BuiltInFunctionDefinition.newBuilder() .name("JSON_QUERY") .kind(SCALAR) + .callSyntax(JsonFunctionsCallSyntax.JSON_QUERY) .inputTypeStrategy( sequence( logical(LogicalTypeFamily.CHARACTER_STRING), @@ -2168,6 +2281,7 @@ public final class BuiltInFunctionDefinitions { public static final BuiltInFunctionDefinition JSON_OBJECT = BuiltInFunctionDefinition.newBuilder() .name("JSON_OBJECT") + .callSyntax(JsonFunctionsCallSyntax.JSON_OBJECT) .kind(SCALAR) .inputTypeStrategy(SpecificInputTypeStrategies.JSON_OBJECT) .outputTypeStrategy(explicit(DataTypes.STRING().notNull())) @@ -2177,6 +2291,9 @@ public final class BuiltInFunctionDefinitions { public static final BuiltInFunctionDefinition JSON_OBJECTAGG_NULL_ON_NULL = BuiltInFunctionDefinition.newBuilder() .name("JSON_OBJECTAGG_NULL_ON_NULL") + .callSyntax( + "JSON_OBJECTAGG", + JsonFunctionsCallSyntax.jsonObjectAgg(JsonOnNull.NULL)) .kind(AGGREGATE) .inputTypeStrategy( sequence(logical(LogicalTypeFamily.CHARACTER_STRING), JSON_ARGUMENT)) @@ -2187,6 +2304,9 @@ public final class BuiltInFunctionDefinitions { public static final BuiltInFunctionDefinition JSON_OBJECTAGG_ABSENT_ON_NULL = BuiltInFunctionDefinition.newBuilder() .name("JSON_OBJECTAGG_ABSENT_ON_NULL") + .callSyntax( + "JSON_OBJECTAGG", + JsonFunctionsCallSyntax.jsonObjectAgg(JsonOnNull.ABSENT)) 
.kind(AGGREGATE) .inputTypeStrategy( sequence(logical(LogicalTypeFamily.CHARACTER_STRING), JSON_ARGUMENT)) @@ -2197,6 +2317,7 @@ public final class BuiltInFunctionDefinitions { public static final BuiltInFunctionDefinition JSON_ARRAY = BuiltInFunctionDefinition.newBuilder() .name("JSON_ARRAY") + .callSyntax(JsonFunctionsCallSyntax.JSON_ARRAY) .kind(SCALAR) .inputTypeStrategy( InputTypeStrategies.varyingSequence( @@ -2209,6 +2330,8 @@ public final class BuiltInFunctionDefinitions { public static final BuiltInFunctionDefinition JSON_ARRAYAGG_NULL_ON_NULL = BuiltInFunctionDefinition.newBuilder() .name("JSON_ARRAYAGG_NULL_ON_NULL") + .callSyntax( + "JSON_ARRAYAGG", JsonFunctionsCallSyntax.jsonArrayAgg(JsonOnNull.NULL)) .kind(AGGREGATE) .inputTypeStrategy(sequence(JSON_ARGUMENT)) .outputTypeStrategy(explicit(DataTypes.STRING().notNull())) @@ -2218,6 +2341,9 @@ public final class BuiltInFunctionDefinitions { public static final BuiltInFunctionDefinition JSON_ARRAYAGG_ABSENT_ON_NULL = BuiltInFunctionDefinition.newBuilder() .name("JSON_ARRAYAGG_ABSENT_ON_NULL") + .callSyntax( + "JSON_ARRAYAGG", + JsonFunctionsCallSyntax.jsonArrayAgg(JsonOnNull.ABSENT)) .kind(AGGREGATE) .inputTypeStrategy(sequence(JSON_ARGUMENT)) .outputTypeStrategy(explicit(DataTypes.STRING().notNull())) @@ -2232,6 +2358,7 @@ public final class BuiltInFunctionDefinitions { BuiltInFunctionDefinition.newBuilder() .name("in") .kind(SCALAR) + .callSyntax("IN", SqlCallSyntax.IN) .inputTypeStrategy(SpecificInputTypeStrategies.IN) .outputTypeStrategy(nullableIfArgs(explicit(DataTypes.BOOLEAN()))) .build(); @@ -2239,6 +2366,7 @@ public final class BuiltInFunctionDefinitions { public static final BuiltInFunctionDefinition CAST = BuiltInFunctionDefinition.newBuilder() .name("cast") + .callSyntax("CAST", SqlCallSyntax.CAST) .kind(SCALAR) .inputTypeStrategy(SpecificInputTypeStrategies.CAST) .outputTypeStrategy( @@ -2248,6 +2376,7 @@ public final class BuiltInFunctionDefinitions { public static final 
BuiltInFunctionDefinition TRY_CAST = BuiltInFunctionDefinition.newBuilder() .name("TRY_CAST") + .callSyntax("TRY_CAST", SqlCallSyntax.CAST) .kind(SCALAR) .inputTypeStrategy(SpecificInputTypeStrategies.CAST) .outputTypeStrategy(forceNullable(TypeStrategies.argument(1))) @@ -2256,6 +2385,7 @@ public final class BuiltInFunctionDefinitions { public static final BuiltInFunctionDefinition REINTERPRET_CAST = BuiltInFunctionDefinition.newBuilder() .name("reinterpretCast") + .callSyntax("REINTERPRET_CAST", SqlCallSyntax.CAST) .kind(SCALAR) .inputTypeStrategy(SpecificInputTypeStrategies.REINTERPRET_CAST) .outputTypeStrategy(TypeStrategies.argument(1)) @@ -2264,6 +2394,7 @@ public final class BuiltInFunctionDefinitions { public static final BuiltInFunctionDefinition AS = BuiltInFunctionDefinition.newBuilder() .name("as") + .callSyntax("AS", SqlCallSyntax.AS) .kind(OTHER) .inputTypeStrategy( varyingSequence( diff --git a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/functions/CallSyntaxUtils.java b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/functions/CallSyntaxUtils.java new file mode 100644 index 00000000000..d285db68be0 --- /dev/null +++ b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/functions/CallSyntaxUtils.java @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.functions; + +import org.apache.flink.annotation.Internal; +import org.apache.flink.table.expressions.FieldReferenceExpression; +import org.apache.flink.table.expressions.ResolvedExpression; +import org.apache.flink.table.expressions.TableSymbol; +import org.apache.flink.table.expressions.ValueLiteralExpression; + +/** Utility functions that can be used for writing {@link SqlCallSyntax}. */ +@Internal +class CallSyntaxUtils { + + /** + * Converts the given {@link ResolvedExpression} into a SQL string. Wraps the string with + * parentheses if the expression is not a leaf expression such as {@link + * ValueLiteralExpression} or {@link FieldReferenceExpression}. 
+ */ + static String asSerializableOperand(ResolvedExpression expression) { + if (expression.getResolvedChildren().isEmpty()) { + return expression.asSerializableString(); + } + + return String.format("(%s)", expression.asSerializableString()); + } + + static <T extends TableSymbol> T getSymbolLiteral(ResolvedExpression operands, Class<T> clazz) { + return ((ValueLiteralExpression) operands).getValueAs(clazz).get(); + } + + private CallSyntaxUtils() {} +} diff --git a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/functions/JsonFunctionsCallSyntax.java b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/functions/JsonFunctionsCallSyntax.java new file mode 100644 index 00000000000..f60b9cddcdd --- /dev/null +++ b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/functions/JsonFunctionsCallSyntax.java @@ -0,0 +1,185 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.table.functions; + +import org.apache.flink.annotation.Internal; +import org.apache.flink.table.api.JsonExistsOnError; +import org.apache.flink.table.api.JsonOnNull; +import org.apache.flink.table.api.JsonQueryOnEmptyOrError; +import org.apache.flink.table.api.JsonQueryWrapper; +import org.apache.flink.table.api.JsonType; +import org.apache.flink.table.api.JsonValueOnEmptyOrError; +import org.apache.flink.table.expressions.ResolvedExpression; + +import java.util.stream.Collectors; +import java.util.stream.IntStream; + +import static org.apache.flink.table.functions.CallSyntaxUtils.getSymbolLiteral; + +/** Implementations of {@link SqlCallSyntax} specific for JSON functions. */ +@Internal +class JsonFunctionsCallSyntax { + + static final SqlCallSyntax IS_JSON = + (sqlName, operands) -> { + final String s = + String.format( + "%s IS JSON", + CallSyntaxUtils.asSerializableOperand(operands.get(0))); + if (operands.size() > 1) { + return s + " " + getSymbolLiteral(operands.get(1), JsonType.class); + } + + return s; + }; + + static final SqlCallSyntax JSON_VALUE = + (sqlName, operands) -> { + StringBuilder s = + new StringBuilder( + String.format( + "JSON_VALUE(%s, %s RETURNING %s ", + operands.get(0).asSerializableString(), + operands.get(1).asSerializableString(), + operands.get(2).asSerializableString())); + + final JsonValueOnEmptyOrError onEmpty = + getSymbolLiteral(operands.get(3), JsonValueOnEmptyOrError.class); + + if (onEmpty == JsonValueOnEmptyOrError.DEFAULT) { + s.append(String.format("DEFAULT %s", operands.get(4).asSerializableString())); + } else { + s.append(onEmpty); + } + s.append(" ON EMPTY "); + + final JsonValueOnEmptyOrError onError = + getSymbolLiteral(operands.get(5), JsonValueOnEmptyOrError.class); + + if (onError == JsonValueOnEmptyOrError.DEFAULT) { + s.append(String.format("DEFAULT %s", operands.get(6).asSerializableString())); + } else { + s.append(onError); + } + s.append(" ON ERROR)"); + + return 
s.toString(); + }; + + static final SqlCallSyntax JSON_EXISTS = + (sqlName, operands) -> { + if (operands.size() == 3) { + return String.format( + "%s(%s, %s %s ON ERROR)", + sqlName, + operands.get(0).asSerializableString(), + operands.get(1).asSerializableString(), + getSymbolLiteral(operands.get(2), JsonExistsOnError.class)); + } else { + return SqlCallSyntax.FUNCTION.unparse(sqlName, operands); + } + }; + + static final SqlCallSyntax JSON_QUERY = + (sqlName, operands) -> { + final JsonQueryWrapper wrapper = + getSymbolLiteral(operands.get(2), JsonQueryWrapper.class); + final JsonQueryOnEmptyOrError onEmpty = + getSymbolLiteral(operands.get(3), JsonQueryOnEmptyOrError.class); + final JsonQueryOnEmptyOrError onError = + getSymbolLiteral(operands.get(4), JsonQueryOnEmptyOrError.class); + + return String.format( + "JSON_QUERY(%s, %s %s WRAPPER %s ON EMPTY %s ON ERROR)", + operands.get(0).asSerializableString(), + operands.get(1).asSerializableString(), + toString(wrapper), + onEmpty.toString().replaceAll("_", " "), + onError.toString().replaceAll("_", " ")); + }; + + static final SqlCallSyntax JSON_OBJECT = + (sqlName, operands) -> { + final String entries = + IntStream.range(0, operands.size() / 2) + .mapToObj( + i -> + String.format( + "KEY %s VALUE %s", + operands.get(2 * i + 1) + .asSerializableString(), + operands.get(2 * i + 2) + .asSerializableString())) + .collect(Collectors.joining(", ")); + + final JsonOnNull onNull = getSymbolLiteral(operands.get(0), JsonOnNull.class); + return String.format("JSON_OBJECT(%s %s ON NULL)", entries, onNull); + }; + + static final SqlCallSyntax JSON_ARRAY = + (sqlName, operands) -> { + if (operands.size() == 1) { + return "JSON_ARRAY()"; + } + final String entries = + operands.subList(1, operands.size()).stream() + .map(ResolvedExpression::asSerializableString) + .collect(Collectors.joining(", ")); + + final JsonOnNull onNull = getSymbolLiteral(operands.get(0), JsonOnNull.class); + return String.format("JSON_ARRAY(%s %s ON 
NULL)", entries, onNull); + }; + + static SqlCallSyntax jsonArrayAgg(JsonOnNull onNull) { + return (sqlName, operands) -> + String.format( + "%s(%s %s ON NULL)", + sqlName, operands.get(0).asSerializableString(), onNull); + } + + static SqlCallSyntax jsonObjectAgg(JsonOnNull onNull) { + return (sqlName, operands) -> + String.format( + "%s(KEY %s VALUE %s %s ON NULL)", + sqlName, + operands.get(0).asSerializableString(), + operands.get(1).asSerializableString(), + onNull); + } + + private static String toString(JsonQueryWrapper wrapper) { + final String wrapperStr; + switch (wrapper) { + case WITHOUT_ARRAY: + wrapperStr = "WITHOUT ARRAY"; + break; + case CONDITIONAL_ARRAY: + wrapperStr = "WITH CONDITIONAL ARRAY"; + break; + case UNCONDITIONAL_ARRAY: + wrapperStr = "WITH UNCONDITIONAL ARRAY"; + break; + default: + throw new IllegalStateException("Unexpected value: " + wrapper); + } + return wrapperStr; + } + + private JsonFunctionsCallSyntax() {} +} diff --git a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/functions/SqlCallSyntax.java b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/functions/SqlCallSyntax.java new file mode 100644 index 00000000000..785d0657658 --- /dev/null +++ b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/functions/SqlCallSyntax.java @@ -0,0 +1,275 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.functions; + +import org.apache.flink.annotation.Internal; +import org.apache.flink.table.api.TableException; +import org.apache.flink.table.expressions.CallExpression; +import org.apache.flink.table.expressions.ResolvedExpression; +import org.apache.flink.table.expressions.TimeIntervalUnit; +import org.apache.flink.table.expressions.ValueLiteralExpression; +import org.apache.flink.table.utils.EncodingUtils; + +import java.util.List; +import java.util.stream.Collectors; + +/** Provides a format for unparsing {@link BuiltInFunctionDefinitions} into a SQL string. */ +@Internal +public interface SqlCallSyntax { + + String unparse(String sqlName, List<ResolvedExpression> operands); + + /** + * Special case for aggregate functions, which can have a DISTINCT function applied. Called only + * from the DISTINCT function. + */ + default String unparseDistinct(String sqlName, List<ResolvedExpression> operands) { + throw new UnsupportedOperationException( + "Only the FUNCTION syntax supports the DISTINCT clause."); + } + + /** Function syntax, as in "Foo(x, y)". 
*/ + SqlCallSyntax FUNCTION = + new SqlCallSyntax() { + @Override + public String unparse(String sqlName, List<ResolvedExpression> operands) { + return doUnParse(sqlName, operands, false); + } + + @Override + public String unparseDistinct(String sqlName, List<ResolvedExpression> operands) { + return doUnParse(sqlName, operands, true); + } + + private String doUnParse( + String sqlName, List<ResolvedExpression> operands, boolean isDistinct) { + return String.format( + "%s(%s%s)", + sqlName, + isDistinct ? "DISTINCT " : "", + operands.stream() + .map(ResolvedExpression::asSerializableString) + .collect(Collectors.joining(", "))); + } + }; + + /** + * Function syntax for handling DISTINCT aggregates. Special case. It does not have a syntax + * itself, but modifies the syntax of the nested call. + */ + SqlCallSyntax DISTINCT = + (sqlName, operands) -> { + final CallExpression callExpression = (CallExpression) operands.get(0); + if (callExpression.getFunctionDefinition() instanceof BuiltInFunctionDefinition) { + final BuiltInFunctionDefinition builtinDefinition = + (BuiltInFunctionDefinition) callExpression.getFunctionDefinition(); + return builtinDefinition + .getCallSyntax() + .unparseDistinct( + builtinDefinition.getSqlName(), + callExpression.getResolvedChildren()); + } else { + return SqlCallSyntax.FUNCTION.unparseDistinct( + callExpression.getFunctionName(), callExpression.getResolvedChildren()); + } + }; + + /** Function syntax for collection ctors, such as ARRAY[1, 2, 3] or MAP['a', 1, 'b', 2]. */ + SqlCallSyntax COLLECTION_CTOR = + (sqlName, operands) -> + String.format( + "%s[%s]", + sqlName, + operands.stream() + .map(ResolvedExpression::asSerializableString) + .collect(Collectors.joining(", "))); + + /** Binary operator syntax, as in "x + y". 
*/ + SqlCallSyntax BINARY_OP = + (sqlName, operands) -> + String.format( + "%s %s %s", + CallSyntaxUtils.asSerializableOperand(operands.get(0)), + sqlName, + CallSyntaxUtils.asSerializableOperand(operands.get(1))); + + /** + * Binary operator syntax that in Table API can accept multiple operands, as in "x AND y AND t + * AND w". + */ + SqlCallSyntax MULTIPLE_BINARY_OP = + (sqlName, operands) -> + operands.stream() + .map(CallSyntaxUtils::asSerializableOperand) + .collect(Collectors.joining(String.format(" %s ", sqlName))); + + /** Postfix unary operator syntax, as in "x ++". */ + SqlCallSyntax UNARY_SUFFIX_OP = + (sqlName, operands) -> + String.format( + "%s %s", + CallSyntaxUtils.asSerializableOperand(operands.get(0)), sqlName); + + /** Prefix unary operator syntax, as in "- x". */ + SqlCallSyntax UNARY_PREFIX_OP = + (sqlName, operands) -> + String.format( + "%s %s", + sqlName, CallSyntaxUtils.asSerializableOperand(operands.get(0))); + + /** + * Special sql syntax for CAST operators (CAST, TRY_CAST, REINTERPRET_CAST). + * + * <p>Example: CAST(123 AS STRING) + */ + SqlCallSyntax CAST = + (sqlName, operands) -> + String.format( + "%s(%s AS %s)", + sqlName, + operands.get(0).asSerializableString(), + operands.get(1).asSerializableString()); + + /** + * Special sql syntax for SUBSTRING operators (SUBSTRING, SUBSTR). + * + * <p>Example: SUBSTR('abcdef' FROM 2 FOR 3) + */ + SqlCallSyntax SUBSTRING = + (sqlName, operands) -> { + final String s = + String.format( + "%s(%s FROM %s", + sqlName, + operands.get(0).asSerializableString(), + operands.get(1).asSerializableString()); + if (operands.size() == 3) { + return s + String.format(" FOR %s)", operands.get(2).asSerializableString()); + } + + return s + ")"; + }; + + /** + * Special sql syntax for FLOOR and CEIL. 
+ * + * <p>Examples: + * + * <ul> + * <li>FLOOR(TIME '12:44:31' TO MINUTE) + * <li>FLOOR(123) + * </ul> + */ + SqlCallSyntax FLOOR_OR_CEIL = + (sqlName, operands) -> { + if (operands.size() == 1) { + // case for numeric floor & ceil + return SqlCallSyntax.FUNCTION.unparse(sqlName, operands); + } else { + // case for flooring/ceiling to temporal units + return String.format( + "%s(%s TO %s)", + sqlName, + operands.get(0).asSerializableString(), + ((ValueLiteralExpression) operands.get(1)) + .getValueAs(TimeIntervalUnit.class) + .get()); + } + }; + + /** + * Special sql syntax for TRIM. + * + * <p>Example: TRIM BOTH ' ' FROM 0; + */ + SqlCallSyntax TRIM = + (sqlName, operands) -> { + final boolean trimLeading = + ((ValueLiteralExpression) operands.get(0)).getValueAs(Boolean.class).get(); + final boolean trimTrailing = + ((ValueLiteralExpression) operands.get(1)).getValueAs(Boolean.class).get(); + final String format; + + // leading & trailing is translated to BOTH + if (trimLeading && trimTrailing) { + format = "TRIM BOTH %s FROM %s"; + } else if (trimLeading) { + format = "TRIM LEADING %s FROM %s"; + } else if (trimTrailing) { + format = "TRIM TRAILING %s FROM %s"; + } else { + format = "TRIM %s FROM %s"; + } + + return String.format( + format, + operands.get(2).asSerializableString(), + operands.get(3).asSerializableString()); + }; + + /** + * Special sql syntax for OVERLAY. + * + * <p>Example: OVERLAY('abcd' PLACING 'def' FROM 3 FOR 2) + */ + SqlCallSyntax OVERLAY = + (sqlName, operands) -> { + final String s = + String.format( + "OVERLAY(%s PLACING %s FROM %s", + operands.get(0).asSerializableString(), + operands.get(1).asSerializableString(), + operands.get(2).asSerializableString()); + + // optional length + if (operands.size() == 4) { + return s + String.format(" FOR %s)", operands.get(3).asSerializableString()); + } + + return s + ")"; + }; + + /** Special sql syntax for AS. The string literal is formatted as an identifier. 
*/ + SqlCallSyntax AS = + (sqlName, operands) -> { + if (operands.size() != 2) { + throw new TableException( + "The AS function with multiple aliases is not SQL" + + " serializable. It should've been flattened during expression" + + " resolution."); + } + final String identifier = + ((ValueLiteralExpression) operands.get(1)).getValueAs(String.class).get(); + return String.format( + "%s %s %s", + CallSyntaxUtils.asSerializableOperand(operands.get(0)), + sqlName, + EncodingUtils.escapeIdentifier(identifier)); + }; + + /** Call syntax for {@link BuiltInFunctionDefinitions#IN}. */ + SqlCallSyntax IN = + (sqlName, operands) -> + String.format( + "%s IN (%s)", + operands.get(0).asSerializableString(), + operands.subList(1, operands.size()).stream() + .map(ResolvedExpression::asSerializableString) + .collect(Collectors.joining(", "))); +} diff --git a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/functions/BuiltInAggregateFunctionTestBase.java b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/functions/BuiltInAggregateFunctionTestBase.java index e4e5d595e74..dd05dad760a 100644 --- a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/functions/BuiltInAggregateFunctionTestBase.java +++ b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/functions/BuiltInAggregateFunctionTestBase.java @@ -31,7 +31,11 @@ import org.apache.flink.table.connector.ChangelogMode; import org.apache.flink.table.connector.source.DynamicTableSource; import org.apache.flink.table.connector.source.SourceFunctionProvider; import org.apache.flink.table.data.RowData; +import org.apache.flink.table.expressions.Expression; +import org.apache.flink.table.expressions.ResolvedExpression; import org.apache.flink.table.functions.BuiltInFunctionDefinition; +import org.apache.flink.table.operations.AggregateQueryOperation; +import org.apache.flink.table.operations.ProjectQueryOperation; import 
org.apache.flink.table.planner.factories.TableFactoryHarness; import org.apache.flink.table.types.DataType; import org.apache.flink.test.junit5.MiniClusterExtension; @@ -40,6 +44,7 @@ import org.apache.flink.types.RowKind; import org.apache.flink.util.CloseableIterator; import org.apache.flink.util.Preconditions; +import org.jetbrains.annotations.NotNull; import org.junit.jupiter.api.TestInstance; import org.junit.jupiter.api.extension.RegisterExtension; import org.junit.jupiter.api.function.Executable; @@ -51,10 +56,12 @@ import org.junit.jupiter.params.provider.MethodSource; import javax.annotation.Nullable; import java.util.ArrayList; +import java.util.Arrays; import java.util.List; import java.util.Set; import java.util.function.Function; import java.util.stream.Collectors; +import java.util.stream.IntStream; import java.util.stream.Stream; import static org.apache.flink.runtime.state.StateBackendLoader.HASHMAP_STATE_BACKEND_NAME; @@ -69,7 +76,7 @@ import static org.assertj.core.api.Assertions.assertThat; abstract class BuiltInAggregateFunctionTestBase { @RegisterExtension - private static final MiniClusterExtension MINI_CLUSTER_EXTENSION = new MiniClusterExtension(); + public static final MiniClusterExtension MINI_CLUSTER_EXTENSION = new MiniClusterExtension(); abstract Stream<TestSpec> getTestCaseSpecs(); @@ -145,6 +152,27 @@ abstract class BuiltInAggregateFunctionTestBase { // --------------------------------------------------------------------------------------------- + protected static final class TableApiAggSpec { + private final List<Expression> selectExpr; + private final List<Expression> groupByExpr; + + public TableApiAggSpec(List<Expression> selectExpr, List<Expression> groupByExpr) { + this.selectExpr = selectExpr; + this.groupByExpr = groupByExpr; + } + + public static TableApiAggSpec groupBySelect( + List<Expression> groupByExpr, Expression... 
selectExpr) { + return new TableApiAggSpec( + Arrays.stream(selectExpr).collect(Collectors.toList()), groupByExpr); + } + + public static TableApiAggSpec select(Expression... selectExpr) { + return new TableApiAggSpec( + Arrays.stream(selectExpr).collect(Collectors.toList()), null); + } + } + /** Test specification. */ protected static class TestSpec { @@ -182,16 +210,29 @@ abstract class BuiltInAggregateFunctionTestBase { } TestSpec testApiResult( - Function<Table, Table> tableApiSpec, + List<Expression> selectExpr, + List<Expression> groupByExpr, DataType expectedRowType, List<Row> expectedRows) { - this.testItems.add(new TableApiTestItem(tableApiSpec, expectedRowType, expectedRows)); + this.testItems.add( + new TableApiTestItem(selectExpr, groupByExpr, expectedRowType, expectedRows)); + return this; + } + + TestSpec testApiSqlResult( + List<Expression> selectExpr, + List<Expression> groupByExpr, + DataType expectedRowType, + List<Row> expectedRows) { + this.testItems.add( + new TableApiSqlResultTestItem( + selectExpr, groupByExpr, expectedRowType, expectedRows)); return this; } TestSpec testResult( Function<Table, String> sqlSpec, - Function<Table, Table> tableApiSpec, + TableApiAggSpec tableApiSpec, DataType expectedRowType, List<Row> expectedRows) { return testResult( @@ -200,12 +241,21 @@ abstract class BuiltInAggregateFunctionTestBase { TestSpec testResult( Function<Table, String> sqlSpec, - Function<Table, Table> tableApiSpec, + TableApiAggSpec tableApiSpec, DataType expectedSqlRowType, DataType expectedTableApiRowType, List<Row> expectedRows) { testSqlResult(sqlSpec, expectedSqlRowType, expectedRows); - testApiResult(tableApiSpec, expectedTableApiRowType, expectedRows); + testApiResult( + tableApiSpec.selectExpr, + tableApiSpec.groupByExpr, + expectedTableApiRowType, + expectedRows); + testApiSqlResult( + tableApiSpec.selectExpr, + tableApiSpec.groupByExpr, + expectedSqlRowType, + expectedRows); return this; } @@ -312,19 +362,134 @@ abstract class 
BuiltInAggregateFunctionTestBase { } private static class TableApiTestItem extends SuccessItem { - private final Function<Table, Table> spec; + private final List<Expression> selectExpr; + private final List<Expression> groupByExpr; public TableApiTestItem( - Function<Table, Table> spec, + List<Expression> selectExpr, + @Nullable List<Expression> groupByExpr, @Nullable DataType expectedRowType, @Nullable List<Row> expectedRows) { super(expectedRowType, expectedRows); - this.spec = spec; + this.selectExpr = selectExpr; + this.groupByExpr = groupByExpr; + } + + @Override + protected TableResult getResult(TableEnvironment tEnv, Table sourceTable) { + if (groupByExpr != null) { + return sourceTable + .groupBy(groupByExpr.toArray(new Expression[0])) + .select(selectExpr.toArray(new Expression[0])) + .execute(); + } else { + return sourceTable.select(selectExpr.toArray(new Expression[0])).execute(); + } + } + } + + private static class TableApiSqlResultTestItem extends SuccessItem { + + private final List<Expression> selectExpr; + private final List<Expression> groupByExpr; + + public TableApiSqlResultTestItem( + List<Expression> selectExpr, + @Nullable List<Expression> groupByExpr, + @Nullable DataType expectedRowType, + @Nullable List<Row> expectedRows) { + super(expectedRowType, expectedRows); + this.selectExpr = selectExpr; + this.groupByExpr = groupByExpr; } @Override protected TableResult getResult(TableEnvironment tEnv, Table sourceTable) { - return spec.apply(sourceTable).execute(); + final Table select; + if (groupByExpr != null) { + select = + sourceTable + .groupBy(groupByExpr.toArray(new Expression[0])) + .select(selectExpr.toArray(new Expression[0])); + + } else { + select = sourceTable.select(selectExpr.toArray(new Expression[0])); + } + final ProjectQueryOperation projectQueryOperation = + (ProjectQueryOperation) select.getQueryOperation(); + final AggregateQueryOperation aggQueryOperation = + (AggregateQueryOperation) 
select.getQueryOperation().getChildren().get(0); + + final List<ResolvedExpression> selectExpr = + recreateSelectList(aggQueryOperation, projectQueryOperation); + + final String selectAsSerializableString = toSerializableExpr(selectExpr); + final String groupByAsSerializableString = + toSerializableExpr(aggQueryOperation.getGroupingExpressions()); + + final StringBuilder stringBuilder = new StringBuilder(); + stringBuilder + .append("SELECT ") + .append(selectAsSerializableString) + .append(" FROM ") + .append(sourceTable); + if (!groupByAsSerializableString.isEmpty()) { + stringBuilder.append(" GROUP BY ").append(groupByAsSerializableString); + } + + return tEnv.sqlQuery(stringBuilder.toString()).execute(); + } + + @NotNull + private static List<ResolvedExpression> recreateSelectList( + AggregateQueryOperation aggQueryOperation, + ProjectQueryOperation projectQueryOperation) { + final List<String> projectSchemaFields = + projectQueryOperation.getResolvedSchema().getColumnNames(); + final List<String> aggSchemaFields = + aggQueryOperation.getResolvedSchema().getColumnNames(); + return IntStream.range(0, projectSchemaFields.size()) + .mapToObj( + idx -> { + final int indexInAgg = + aggSchemaFields.indexOf(projectSchemaFields.get(idx)); + if (indexInAgg >= 0) { + final int groupingExprCount = + aggQueryOperation.getGroupingExpressions().size(); + if (indexInAgg < groupingExprCount) { + return aggQueryOperation + .getGroupingExpressions() + .get(indexInAgg); + } else { + return aggQueryOperation + .getAggregateExpressions() + .get(indexInAgg - groupingExprCount); + } + } else { + return projectQueryOperation.getProjectList().get(idx); + } + }) + .collect(Collectors.toList()); + } + + private static String toSerializableExpr(List<ResolvedExpression> expressions) { + return expressions.stream() + .map(ResolvedExpression::asSerializableString) + .collect(Collectors.joining(", ")); + } + + @Override + public String toString() { + return String.format( + "[API as SQL] 
select: [%s] groupBy: [%s]", + selectExpr.stream() + .map(Expression::asSummaryString) + .collect(Collectors.joining(", ")), + groupByExpr != null + ? groupByExpr.stream() + .map(Expression::asSummaryString) + .collect(Collectors.joining(", ")) + : ""); } } diff --git a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/functions/BuiltInFunctionTestBase.java b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/functions/BuiltInFunctionTestBase.java index b7d0ae13294..566a0106083 100644 --- a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/functions/BuiltInFunctionTestBase.java +++ b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/functions/BuiltInFunctionTestBase.java @@ -28,8 +28,10 @@ import org.apache.flink.table.api.ValidationException; import org.apache.flink.table.api.internal.TableEnvironmentInternal; import org.apache.flink.table.catalog.DataTypeFactory; import org.apache.flink.table.expressions.Expression; +import org.apache.flink.table.expressions.ResolvedExpression; import org.apache.flink.table.functions.BuiltInFunctionDefinition; import org.apache.flink.table.functions.UserDefinedFunction; +import org.apache.flink.table.operations.ProjectQueryOperation; import org.apache.flink.table.types.AbstractDataType; import org.apache.flink.table.types.DataType; import org.apache.flink.test.junit5.MiniClusterExtension; @@ -76,7 +78,7 @@ import static org.assertj.core.api.Assertions.catchThrowable; abstract class BuiltInFunctionTestBase { @RegisterExtension - private static final MiniClusterExtension MINI_CLUSTER_EXTENSION = new MiniClusterExtension(); + public static final MiniClusterExtension MINI_CLUSTER_EXTENSION = new MiniClusterExtension(); Configuration getConfiguration() { return new Configuration(); @@ -283,6 +285,7 @@ abstract class BuiltInFunctionTestBase { List<AbstractDataType<?>> tableApiDataType, List<AbstractDataType<?>> sqlDataType) { 
testItems.add(new TableApiResultTestItem(expression, result, tableApiDataType)); + testItems.add(new TableApiSqlResultTestItem(expression, result, tableApiDataType)); testItems.add( new SqlResultTestItem(String.join(",", sqlExpression), result, sqlDataType)); return this; @@ -457,6 +460,36 @@ abstract class BuiltInFunctionTestBase { } } + private static class TableApiSqlResultTestItem extends ResultTestItem<List<Expression>> { + + TableApiSqlResultTestItem( + List<Expression> expressions, + List<Object> results, + List<AbstractDataType<?>> dataTypes) { + super(expressions, results, dataTypes); + } + + @Override + Table query(TableEnvironment env, Table inputTable) { + final Table select = inputTable.select(expression.toArray(new Expression[] {})); + final ProjectQueryOperation projectQueryOperation = + (ProjectQueryOperation) select.getQueryOperation(); + final String exprAsSerializableString = + projectQueryOperation.getProjectList().stream() + .map(ResolvedExpression::asSerializableString) + .collect(Collectors.joining(", ")); + return env.sqlQuery("SELECT " + exprAsSerializableString + " FROM " + inputTable); + } + + @Override + public String toString() { + return "[API as SQL] " + + expression.stream() + .map(Expression::asSummaryString) + .collect(Collectors.joining(", ")); + } + } + private static class TableApiErrorTestItem extends ErrorTestItem<Expression> { TableApiErrorTestItem( diff --git a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/functions/IfThenElseFunctionITCase.java b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/functions/IfThenElseFunctionITCase.java new file mode 100644 index 00000000000..a01aeac65e9 --- /dev/null +++ b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/functions/IfThenElseFunctionITCase.java @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.planner.functions; + +import org.apache.flink.table.api.DataTypes; +import org.apache.flink.table.api.Expressions; +import org.apache.flink.table.functions.BuiltInFunctionDefinitions; + +import java.util.stream.Stream; + +import static org.apache.flink.table.api.Expressions.$; +import static org.apache.flink.table.api.Expressions.lit; + +/** IT tests for {@link BuiltInFunctionDefinitions#IF}. 
*/ +class IfThenElseFunctionITCase extends BuiltInFunctionTestBase { + + @Override + Stream<TestSetSpec> getTestSetSpecs() { + return Stream.of( + TestSetSpec.forFunction(BuiltInFunctionDefinitions.IF) + .onFieldsWithData(2) + .andDataTypes(DataTypes.INT()) + .testResult( + Expressions.ifThenElse( + $("f0").isGreater(lit(0)), lit("GREATER"), lit("SMALLER")), + "CASE WHEN f0 > 0 THEN 'GREATER' ELSE 'SMALLER' END", + "GREATER", + DataTypes.CHAR(7).notNull()) + .testResult( + Expressions.ifThenElse( + $("f0").isGreater(lit(0)), + lit("GREATER"), + Expressions.ifThenElse( + $("f0").isEqual(0), lit("EQUAL"), lit("SMALLER"))), + "CASE WHEN f0 > 0 THEN 'GREATER' ELSE CASE WHEN f0 = 0 THEN 'EQUAL' ELSE 'SMALLER' END END", + "GREATER", + DataTypes.VARCHAR(7).notNull())); + } +} diff --git a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/functions/JsonAggregationFunctionsITCase.java b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/functions/JsonAggregationFunctionsITCase.java index cec9f237e68..34b00813d75 100644 --- a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/functions/JsonAggregationFunctionsITCase.java +++ b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/functions/JsonAggregationFunctionsITCase.java @@ -53,9 +53,8 @@ class JsonAggregationFunctionsITCase extends BuiltInAggregateFunctionTestBase { Row.ofKind(INSERT, "C", 3))) .testResult( source -> "SELECT JSON_OBJECTAGG(f0 VALUE f1) FROM " + source, - source -> - source.select( - jsonObjectAgg(JsonOnNull.NULL, $("f0"), $("f1"))), + TableApiAggSpec.select( + jsonObjectAgg(JsonOnNull.NULL, $("f0"), $("f1"))), ROW(VARCHAR(2000).notNull()), ROW(STRING().notNull()), Collections.singletonList(Row.of("{\"A\":1,\"B\":null,\"C\":3}"))), @@ -71,9 +70,8 @@ class JsonAggregationFunctionsITCase extends BuiltInAggregateFunctionTestBase { source -> "SELECT JSON_OBJECTAGG(f0 VALUE f1 ABSENT ON NULL) FROM " + source, 
- source -> - source.select( - jsonObjectAgg(JsonOnNull.ABSENT, $("f0"), $("f1"))), + TableApiAggSpec.select( + jsonObjectAgg(JsonOnNull.ABSENT, $("f0"), $("f1"))), ROW(VARCHAR(2000).notNull()), ROW(STRING().notNull()), Collections.singletonList(Row.of("{\"A\":1,\"C\":3}"))), @@ -88,9 +86,8 @@ class JsonAggregationFunctionsITCase extends BuiltInAggregateFunctionTestBase { Row.ofKind(DELETE, "B", 2))) .testResult( source -> "SELECT JSON_OBJECTAGG(f0 VALUE f1) FROM " + source, - source -> - source.select( - jsonObjectAgg(JsonOnNull.NULL, $("f0"), $("f1"))), + TableApiAggSpec.select( + jsonObjectAgg(JsonOnNull.NULL, $("f0"), $("f1"))), ROW(VARCHAR(2000).notNull()), ROW(STRING().notNull()), Collections.singletonList(Row.of("{\"A\":1,\"C\":3}"))), @@ -108,12 +105,10 @@ class JsonAggregationFunctionsITCase extends BuiltInAggregateFunctionTestBase { "SELECT f0, JSON_OBJECTAGG(f1 VALUE f2) FROM " + source + " GROUP BY f0", - source -> - source.groupBy($("f0")) - .select( - $("f0"), - jsonObjectAgg( - JsonOnNull.NULL, $("f1"), $("f2"))), + TableApiAggSpec.groupBySelect( + Collections.singletonList($("f0")), + $("f0"), + jsonObjectAgg(JsonOnNull.NULL, $("f1"), $("f2"))), ROW(INT(), VARCHAR(2000).notNull()), ROW(INT(), STRING().notNull()), Arrays.asList( @@ -131,10 +126,9 @@ class JsonAggregationFunctionsITCase extends BuiltInAggregateFunctionTestBase { source -> "SELECT max(f1), JSON_OBJECTAGG(f0 VALUE f1) FROM " + source, - source -> - source.select( - $("f1").max(), - jsonObjectAgg(JsonOnNull.NULL, $("f0"), $("f1"))), + TableApiAggSpec.select( + $("f1").max(), + jsonObjectAgg(JsonOnNull.NULL, $("f0"), $("f1"))), ROW(INT(), VARCHAR(2000).notNull()), ROW(INT(), STRING().notNull()), Collections.singletonList( @@ -153,13 +147,11 @@ class JsonAggregationFunctionsITCase extends BuiltInAggregateFunctionTestBase { "SELECT f0, JSON_OBJECTAGG(f1 VALUE f2), max(f2) FROM " + source + " GROUP BY f0", - source -> - source.groupBy($("f0")) - .select( - $("f0"), - jsonObjectAgg( - 
JsonOnNull.NULL, $("f1"), $("f2")), - $("f2").max()), + TableApiAggSpec.groupBySelect( + Collections.singletonList($("f0")), + $("f0"), + jsonObjectAgg(JsonOnNull.NULL, $("f1"), $("f2")), + $("f2").max()), ROW(INT(), VARCHAR(2000).notNull(), INT()), ROW(INT(), STRING().notNull(), INT()), Arrays.asList( @@ -177,7 +169,7 @@ class JsonAggregationFunctionsITCase extends BuiltInAggregateFunctionTestBase { Row.ofKind(INSERT, "C"))) .testResult( source -> "SELECT JSON_ARRAYAGG(f0) FROM " + source, - source -> source.select(jsonArrayAgg(JsonOnNull.ABSENT, $("f0"))), + TableApiAggSpec.select(jsonArrayAgg(JsonOnNull.ABSENT, $("f0"))), ROW(VARCHAR(2000).notNull()), ROW(STRING().notNull()), Collections.singletonList(Row.of("[\"A\",\"C\"]"))), @@ -191,7 +183,7 @@ class JsonAggregationFunctionsITCase extends BuiltInAggregateFunctionTestBase { Row.ofKind(INSERT, "C"))) .testResult( source -> "SELECT JSON_ARRAYAGG(f0 NULL ON NULL) FROM " + source, - source -> source.select(jsonArrayAgg(JsonOnNull.NULL, $("f0"))), + TableApiAggSpec.select(jsonArrayAgg(JsonOnNull.NULL, $("f0"))), ROW(VARCHAR(2000).notNull()), ROW(STRING().notNull()), Collections.singletonList(Row.of("[\"A\",null,\"C\"]"))), @@ -206,7 +198,7 @@ class JsonAggregationFunctionsITCase extends BuiltInAggregateFunctionTestBase { Row.ofKind(DELETE, 2))) .testResult( source -> "SELECT JSON_ARRAYAGG(f0) FROM " + source, - source -> source.select(jsonArrayAgg(JsonOnNull.ABSENT, $("f0"))), + TableApiAggSpec.select(jsonArrayAgg(JsonOnNull.ABSENT, $("f0"))), ROW(VARCHAR(2000).notNull()), ROW(STRING().notNull()), Collections.singletonList(Row.of("[1,3]"))), @@ -220,10 +212,8 @@ class JsonAggregationFunctionsITCase extends BuiltInAggregateFunctionTestBase { Row.ofKind(INSERT, "C"))) .testResult( source -> "SELECT max(f0), JSON_ARRAYAGG(f0) FROM " + source, - source -> - source.select( - $("f0").max(), - jsonArrayAgg(JsonOnNull.ABSENT, $("f0"))), + TableApiAggSpec.select( + $("f0").max(), jsonArrayAgg(JsonOnNull.ABSENT, $("f0"))), 
ROW(STRING(), VARCHAR(2000).notNull()), ROW(STRING(), STRING().notNull()), Collections.singletonList(Row.of("C", "[\"A\",\"C\"]"))), @@ -241,12 +231,11 @@ class JsonAggregationFunctionsITCase extends BuiltInAggregateFunctionTestBase { "SELECT f0, max(f1), JSON_ARRAYAGG(f1)FROM " + source + " GROUP BY f0", - source -> - source.groupBy($("f0")) - .select( - $("f0"), - $("f1").max(), - jsonArrayAgg(JsonOnNull.ABSENT, $("f1"))), + TableApiAggSpec.groupBySelect( + Collections.singletonList($("f0")), + $("f0"), + $("f1").max(), + jsonArrayAgg(JsonOnNull.ABSENT, $("f1"))), ROW(INT(), STRING(), VARCHAR(2000).notNull()), ROW(INT(), STRING(), STRING().notNull()), Arrays.asList(