This is an automated email from the ASF dual-hosted git repository. godfrey pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/flink.git
The following commit(s) were added to refs/heads/master by this push: new 606f297198a [FLINK-29719][hive] Supports native count function for hive dialect 606f297198a is described below commit 606f297198acd74a5c1a39700bd84ad9e26e7b82 Author: fengli <ldliu...@163.com> AuthorDate: Wed Jan 4 15:31:03 2023 +0800 [FLINK-29719][hive] Supports native count function for hive dialect This closes #21596 --- .../table/functions/hive/HiveCountAggFunction.java | 116 +++++++++++++++++++++ .../apache/flink/table/module/hive/HiveModule.java | 6 +- .../connectors/hive/HiveDialectAggITCase.java | 72 +++++++++++-- .../connectors/hive/HiveDialectQueryPlanTest.java | 24 ++++- .../explain/testCountAggFunctionFallbackPlan.out | 35 +++++++ .../resources/explain/testCountAggFunctionPlan.out | 27 +++++ 6 files changed, 268 insertions(+), 12 deletions(-) diff --git a/flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/table/functions/hive/HiveCountAggFunction.java b/flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/table/functions/hive/HiveCountAggFunction.java new file mode 100644 index 00000000000..e15a0cbaf3d --- /dev/null +++ b/flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/table/functions/hive/HiveCountAggFunction.java @@ -0,0 +1,116 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.functions.hive; + +import org.apache.flink.table.api.DataTypes; +import org.apache.flink.table.api.TableException; +import org.apache.flink.table.expressions.Expression; +import org.apache.flink.table.expressions.UnresolvedReferenceExpression; +import org.apache.flink.table.planner.expressions.ExpressionBuilder; +import org.apache.flink.table.types.DataType; +import org.apache.flink.table.types.inference.CallContext; + +import java.util.ArrayList; +import java.util.List; + +import static org.apache.flink.table.expressions.ApiExpressionUtils.unresolvedRef; +import static org.apache.flink.table.planner.expressions.ExpressionBuilder.ifThenElse; +import static org.apache.flink.table.planner.expressions.ExpressionBuilder.literal; +import static org.apache.flink.table.planner.expressions.ExpressionBuilder.plus; + +/** built-in hive count aggregate function. 
*/ +public class HiveCountAggFunction extends HiveDeclarativeAggregateFunction { + + private final UnresolvedReferenceExpression count = unresolvedRef("count"); + private Integer arguments; + private boolean countLiteral; + + @Override + public int operandCount() { + return arguments; + } + + @Override + public UnresolvedReferenceExpression[] aggBufferAttributes() { + return new UnresolvedReferenceExpression[] {count}; + } + + @Override + public DataType[] getAggBufferTypes() { + return new DataType[] {DataTypes.BIGINT()}; + } + + @Override + public DataType getResultType() { + return DataTypes.BIGINT(); + } + + @Override + public Expression[] initialValuesExpressions() { + return new Expression[] {/* count = */ literal(0L, getResultType().notNull())}; + } + + @Override + public Expression[] accumulateExpressions() { + // count(*) and count(literal) mean that count all elements + if (arguments == 0 || countLiteral) { + return new Expression[] {/* count = */ plus(count, literal(1L))}; + } + + // other case need to determine the value of the element + List<Expression> operandExpressions = new ArrayList<>(); + for (int i = 0; i < arguments; i++) { + operandExpressions.add(operand(i)); + } + Expression operandExpression = + operandExpressions.stream() + .map(ExpressionBuilder::isNull) + .reduce(ExpressionBuilder::or) + .get(); + return new Expression[] { + /* count = */ ifThenElse(operandExpression, count, plus(count, literal(1L))) + }; + } + + @Override + public Expression[] retractExpressions() { + throw new TableException("Count aggregate function does not support retraction."); + } + + @Override + public Expression[] mergeExpressions() { + return new Expression[] {/* count = */ plus(count, mergeOperand(count))}; + } + + @Override + public Expression getValueExpression() { + return count; + } + + @Override + public void setArguments(CallContext callContext) { + if (arguments == null) { + arguments = callContext.getArgumentDataTypes().size(); + if (arguments == 1) { 
+ // If the argument is a literal, it indicates count(literal) + countLiteral = callContext.isArgumentLiteral(0); + } + } +} diff --git a/flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/table/module/hive/HiveModule.java b/flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/table/module/hive/HiveModule.java index 6ca6ca84dd9..bb598891fef 100644 --- a/flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/table/module/hive/HiveModule.java +++ b/flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/table/module/hive/HiveModule.java @@ -26,6 +26,7 @@ import org.apache.flink.table.catalog.hive.client.HiveShimLoader; import org.apache.flink.table.catalog.hive.factories.HiveFunctionDefinitionFactory; import org.apache.flink.table.factories.FunctionDefinitionFactory; import org.apache.flink.table.functions.FunctionDefinition; +import org.apache.flink.table.functions.hive.HiveCountAggFunction; import org.apache.flink.table.functions.hive.HiveMinAggFunction; import org.apache.flink.table.functions.hive.HiveSumAggFunction; import org.apache.flink.table.module.Module; @@ -86,7 +87,7 @@ public class HiveModule implements Module { "tumble_start"))); static final Set<String> BUILTIN_NATIVE_AGG_FUNC = - Collections.unmodifiableSet(new HashSet<>(Arrays.asList("sum", "min"))); + Collections.unmodifiableSet(new HashSet<>(Arrays.asList("sum", "count", "min"))); private final HiveFunctionDefinitionFactory factory; private final String hiveVersion; @@ -206,6 +207,9 @@ public class HiveModule implements Module { case "sum": // We override Hive's sum function by native implementation to supports hash-agg return Optional.of(new HiveSumAggFunction()); + case "count": + // We override Hive's count function by native implementation to supports hash-agg + return Optional.of(new HiveCountAggFunction()); case "min": // We override Hive's min function by native implementation to supports hash-agg return Optional.of(new 
HiveMinAggFunction()); diff --git a/flink-connectors/flink-connector-hive/src/test/java/org/apache/flink/connectors/hive/HiveDialectAggITCase.java b/flink-connectors/flink-connector-hive/src/test/java/org/apache/flink/connectors/hive/HiveDialectAggITCase.java index bd124c7d554..3af77ad72b6 100644 --- a/flink-connectors/flink-connector-hive/src/test/java/org/apache/flink/connectors/hive/HiveDialectAggITCase.java +++ b/flink-connectors/flink-connector-hive/src/test/java/org/apache/flink/connectors/hive/HiveDialectAggITCase.java @@ -30,7 +30,6 @@ import org.apache.flink.util.CollectionUtil; import org.apache.hadoop.hive.conf.HiveConf; import org.apache.hadoop.hive.ql.exec.UDFArgumentTypeException; -import org.junit.Before; import org.junit.BeforeClass; import org.junit.ClassRule; import org.junit.Test; @@ -58,6 +57,8 @@ public class HiveDialectAggITCase { hiveCatalog.getHiveConf().setVar(HiveConf.ConfVars.HIVE_QUOTEDID_SUPPORT, "none"); hiveCatalog.open(); tableEnv = getTableEnvWithHiveCatalog(); + // enable native hive agg function + tableEnv.getConfig().set(TABLE_EXEC_HIVE_NATIVE_AGG_FUNCTION_ENABLED, true); // create tables tableEnv.executeSql("create table foo (x int, y int)"); @@ -71,12 +72,6 @@ public class HiveDialectAggITCase { .commit(); } - @Before - public void before() { - // enable native hive agg function - tableEnv.getConfig().set(TABLE_EXEC_HIVE_NATIVE_AGG_FUNCTION_ENABLED, true); - } - @Test public void testSimpleSumAggFunction() throws Exception { tableEnv.executeSql( @@ -167,6 +162,69 @@ public class HiveDialectAggITCase { tableEnv.executeSql("drop table test_sum_group"); } + @Test + public void testSimpleCount() throws Exception { + tableEnv.executeSql("create table test_count(a int, x string, y string, z int, d bigint)"); + tableEnv.executeSql( + "insert into test_count values (1, NULL, '2', 1, 2), " + + "(1, NULL, 'b', 2, NULL), " + + "(2, NULL, '4', 1, 2), " + + "(2, NULL, NULL, 4, 3)") + .await(); + + // test count(*) + List<Row> result = + 
CollectionUtil.iteratorToList( + tableEnv.executeSql("select count(*) from test_count").collect()); + assertThat(result.toString()).isEqualTo("[+I[4]]"); + + // test count(1) + List<Row> result2 = + CollectionUtil.iteratorToList( + tableEnv.executeSql("select count(1) from test_count").collect()); + assertThat(result2.toString()).isEqualTo("[+I[4]]"); + + // test count(col1) + List<Row> result3 = + CollectionUtil.iteratorToList( + tableEnv.executeSql("select count(y) from test_count").collect()); + assertThat(result3.toString()).isEqualTo("[+I[3]]"); + + // test count(distinct col1) + List<Row> result4 = + CollectionUtil.iteratorToList( + tableEnv.executeSql("select count(distinct z) from test_count").collect()); + assertThat(result4.toString()).isEqualTo("[+I[3]]"); + + // test count(distinct col1, col2) + List<Row> result5 = + CollectionUtil.iteratorToList( + tableEnv.executeSql("select count(distinct z, d) from test_count") + .collect()); + assertThat(result5.toString()).isEqualTo("[+I[2]]"); + + tableEnv.executeSql("drop table test_count"); + } + + @Test + public void testCountAggWithGroupKey() throws Exception { + tableEnv.executeSql( + "create table test_count_group(a int, x string, y string, z int, d bigint)"); + tableEnv.executeSql( + "insert into test_count_group values (1, NULL, '2', 1, 2), " + + "(1, NULL, '2', 2, NULL), " + + "(2, NULL, '4', 1, 2), " + + "(2, NULL, 3, 4, 3)") + .await(); + + List<Row> result = + CollectionUtil.iteratorToList( + tableEnv.executeSql( + "select count(*), count(x), count(distinct y), count(distinct z, d) from test_count_group group by a") + .collect()); + assertThat(result.toString()).isEqualTo("[+I[2, 0, 1, 1], +I[2, 0, 2, 2]]"); + } + @Test public void testMinAggFunction() throws Exception { tableEnv.executeSql( diff --git a/flink-connectors/flink-connector-hive/src/test/java/org/apache/flink/connectors/hive/HiveDialectQueryPlanTest.java 
b/flink-connectors/flink-connector-hive/src/test/java/org/apache/flink/connectors/hive/HiveDialectQueryPlanTest.java index 69b8ca9179f..48cc8b913d1 100644 --- a/flink-connectors/flink-connector-hive/src/test/java/org/apache/flink/connectors/hive/HiveDialectQueryPlanTest.java +++ b/flink-connectors/flink-connector-hive/src/test/java/org/apache/flink/connectors/hive/HiveDialectQueryPlanTest.java @@ -70,25 +70,41 @@ public class HiveDialectQueryPlanTest { @Test public void testSumAggFunctionPlan() { // test explain - String actualPlan = explainSql("select x, sum(y) from foo group by x"); + String sql = "select x, sum(y) from foo group by x"; + String actualPlan = explainSql(sql); assertThat(actualPlan).isEqualTo(readFromResource("/explain/testSumAggFunctionPlan.out")); // test fallback to hive sum udaf tableEnv.getConfig().set(TABLE_EXEC_HIVE_NATIVE_AGG_FUNCTION_ENABLED, false); - String actualSortAggPlan = explainSql("select x, sum(y) from foo group by x"); + String actualSortAggPlan = explainSql(sql); assertThat(actualSortAggPlan) .isEqualTo(readFromResource("/explain/testSumAggFunctionFallbackPlan.out")); } + @Test + public void testCountAggFunctionPlan() { + // test explain + String sql = "select x, count(*), count(y), count(distinct y) from foo group by x"; + String actualPlan = explainSql(sql); + assertThat(actualPlan).isEqualTo(readFromResource("/explain/testCountAggFunctionPlan.out")); + + // test fallback to hive count udaf + tableEnv.getConfig().set(TABLE_EXEC_HIVE_NATIVE_AGG_FUNCTION_ENABLED, false); + String actualSortAggPlan = explainSql(sql); + assertThat(actualSortAggPlan) + .isEqualTo(readFromResource("/explain/testCountAggFunctionFallbackPlan.out")); + } + @Test public void testMinAggFunctionPlan() { // test explain - String actualPlan = explainSql("select x, min(y) from foo group by x"); + String sql = "select x, min(y) from foo group by x"; + String actualPlan = explainSql(sql); 
assertThat(actualPlan).isEqualTo(readFromResource("/explain/testMinAggFunctionPlan.out")); // test fallback to hive min udaf tableEnv.getConfig().set(TABLE_EXEC_HIVE_NATIVE_AGG_FUNCTION_ENABLED, false); - String actualSortAggPlan = explainSql("select x, min(y) from foo group by x"); + String actualSortAggPlan = explainSql(sql); assertThat(actualSortAggPlan) .isEqualTo(readFromResource("/explain/testMinAggFunctionFallbackPlan.out")); } diff --git a/flink-connectors/flink-connector-hive/src/test/resources/explain/testCountAggFunctionFallbackPlan.out b/flink-connectors/flink-connector-hive/src/test/resources/explain/testCountAggFunctionFallbackPlan.out new file mode 100644 index 00000000000..e356f402212 --- /dev/null +++ b/flink-connectors/flink-connector-hive/src/test/resources/explain/testCountAggFunctionFallbackPlan.out @@ -0,0 +1,35 @@ +== Abstract Syntax Tree == +LogicalProject(x=[$0], _o__c1=[$1], _o__c2=[$2], _o__c3=[$3]) ++- LogicalAggregate(group=[{0}], agg#0=[count()], agg#1=[count($1)], agg#2=[count(DISTINCT $1)]) + +- LogicalProject($f0=[$0], $f1=[$1]) + +- LogicalTableScan(table=[[test-catalog, default, foo]]) + +== Optimized Physical Plan == +SortAggregate(isMerge=[true], groupBy=[x], select=[x, Final_MIN(min$0) AS $f1, Final_MIN(min$1) AS $f2, Final_count($f3) AS $f3]) ++- Sort(orderBy=[x ASC]) + +- Exchange(distribution=[hash[x]]) + +- LocalSortAggregate(groupBy=[x], select=[x, Partial_MIN($f1) FILTER $g_1 AS min$0, Partial_MIN($f2) FILTER $g_1 AS min$1, Partial_count(y) FILTER $g_0 AS $f3]) + +- Calc(select=[x, y, $f1, $f2, =(CASE(=($e, 0), 0, 1), 0) AS $g_0, =(CASE(=($e, 0), 0, 1), 1) AS $g_1]) + +- Sort(orderBy=[x ASC]) + +- SortAggregate(isMerge=[true], groupBy=[x, y, $e], select=[x, y, $e, Final_count($f1) AS $f1, Final_count($f2) AS $f2]) + +- Sort(orderBy=[x ASC, y ASC, $e ASC]) + +- Exchange(distribution=[hash[x, y, $e]]) + +- LocalSortAggregate(groupBy=[x, y, $e], select=[x, y, $e, Partial_count(*) AS $f1, Partial_count(y_0) AS $f2]) + +- 
Sort(orderBy=[x ASC, y ASC, $e ASC]) + +- Expand(projects=[{x, y, 0 AS $e, y AS y_0}, {x, null AS y, 1 AS $e, y AS y_0}]) + +- TableSourceScan(table=[[test-catalog, default, foo]], fields=[x, y]) + +== Optimized Execution Plan == +SortAggregate(isMerge=[true], groupBy=[x], select=[x, Final_MIN(min$0) AS $f1, Final_MIN(min$1) AS $f2, Final_count($f3) AS $f3]) ++- Sort(orderBy=[x ASC]) + +- Exchange(distribution=[hash[x]]) + +- LocalSortAggregate(groupBy=[x], select=[x, Partial_MIN($f1) FILTER $g_1 AS min$0, Partial_MIN($f2) FILTER $g_1 AS min$1, Partial_count(y) FILTER $g_0 AS $f3]) + +- Calc(select=[x, y, $f1, $f2, (CASE(($e = 0), 0, 1) = 0) AS $g_0, (CASE(($e = 0), 0, 1) = 1) AS $g_1]) + +- Sort(orderBy=[x ASC]) + +- SortAggregate(isMerge=[true], groupBy=[x, y, $e], select=[x, y, $e, Final_count($f1) AS $f1, Final_count($f2) AS $f2]) + +- Sort(orderBy=[x ASC, y ASC, $e ASC]) + +- Exchange(distribution=[hash[x, y, $e]]) + +- LocalSortAggregate(groupBy=[x, y, $e], select=[x, y, $e, Partial_count(*) AS $f1, Partial_count(y_0) AS $f2]) + +- Sort(orderBy=[x ASC, y ASC, $e ASC]) + +- Expand(projects=[{x, y, 0 AS $e, y AS y_0}, {x, null AS y, 1 AS $e, y AS y_0}]) + +- TableSourceScan(table=[[test-catalog, default, foo]], fields=[x, y]) diff --git a/flink-connectors/flink-connector-hive/src/test/resources/explain/testCountAggFunctionPlan.out b/flink-connectors/flink-connector-hive/src/test/resources/explain/testCountAggFunctionPlan.out new file mode 100644 index 00000000000..fc6e8b6d8cb --- /dev/null +++ b/flink-connectors/flink-connector-hive/src/test/resources/explain/testCountAggFunctionPlan.out @@ -0,0 +1,27 @@ +== Abstract Syntax Tree == +LogicalProject(x=[$0], _o__c1=[$1], _o__c2=[$2], _o__c3=[$3]) ++- LogicalAggregate(group=[{0}], agg#0=[count()], agg#1=[count($1)], agg#2=[count(DISTINCT $1)]) + +- LogicalProject($f0=[$0], $f1=[$1]) + +- LogicalTableScan(table=[[test-catalog, default, foo]]) + +== Optimized Physical Plan == +HashAggregate(isMerge=[true], 
groupBy=[x], select=[x, Final_MIN(min$0) AS $f1, Final_MIN(min$1) AS $f2, Final_count(count$2) AS $f3]) ++- Exchange(distribution=[hash[x]]) + +- LocalHashAggregate(groupBy=[x], select=[x, Partial_MIN($f1) FILTER $g_1 AS min$0, Partial_MIN($f2) FILTER $g_1 AS min$1, Partial_count(y) FILTER $g_0 AS count$2]) + +- Calc(select=[x, y, $f1, $f2, =(CASE(=($e, 0), 0, 1), 0) AS $g_0, =(CASE(=($e, 0), 0, 1), 1) AS $g_1]) + +- HashAggregate(isMerge=[true], groupBy=[x, y, $e], select=[x, y, $e, Final_count(count$0) AS $f1, Final_count(count$1) AS $f2]) + +- Exchange(distribution=[hash[x, y, $e]]) + +- LocalHashAggregate(groupBy=[x, y, $e], select=[x, y, $e, Partial_count(*) AS count$0, Partial_count(y_0) AS count$1]) + +- Expand(projects=[{x, y, 0 AS $e, y AS y_0}, {x, null AS y, 1 AS $e, y AS y_0}]) + +- TableSourceScan(table=[[test-catalog, default, foo]], fields=[x, y]) + +== Optimized Execution Plan == +HashAggregate(isMerge=[true], groupBy=[x], select=[x, Final_MIN(min$0) AS $f1, Final_MIN(min$1) AS $f2, Final_count(count$2) AS $f3]) ++- Exchange(distribution=[hash[x]]) + +- LocalHashAggregate(groupBy=[x], select=[x, Partial_MIN($f1) FILTER $g_1 AS min$0, Partial_MIN($f2) FILTER $g_1 AS min$1, Partial_count(y) FILTER $g_0 AS count$2]) + +- Calc(select=[x, y, $f1, $f2, (CASE(($e = 0), 0, 1) = 0) AS $g_0, (CASE(($e = 0), 0, 1) = 1) AS $g_1]) + +- HashAggregate(isMerge=[true], groupBy=[x, y, $e], select=[x, y, $e, Final_count(count$0) AS $f1, Final_count(count$1) AS $f2]) + +- Exchange(distribution=[hash[x, y, $e]]) + +- LocalHashAggregate(groupBy=[x, y, $e], select=[x, y, $e, Partial_count(*) AS count$0, Partial_count(y_0) AS count$1]) + +- Expand(projects=[{x, y, 0 AS $e, y AS y_0}, {x, null AS y, 1 AS $e, y AS y_0}]) + +- TableSourceScan(table=[[test-catalog, default, foo]], fields=[x, y])