This is an automated email from the ASF dual-hosted git repository. jincheng pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/flink.git
The following commit(s) were added to refs/heads/master by this push: new a0012aa [FLINK-13087][table] Add group window Aggregate operator to Table API a0012aa is described below commit a0012aae89ec7a56133642b58b04e5f7b155c0f4 Author: hequn8128 <chenghe...@gmail.com> AuthorDate: Thu Jul 4 11:11:01 2019 +0800 [FLINK-13087][table] Add group window Aggregate operator to Table API This closes #8979 --- docs/dev/table/tableApi.md | 41 +++++- .../apache/flink/table/api/WindowGroupedTable.java | 39 +++++- .../apache/flink/table/api/internal/TableImpl.java | 109 +++++++++++++--- .../operations/utils/OperationTreeBuilder.java | 143 ++++++++++++++++++--- .../table/api/stream/table/AggregateTest.scala | 45 ++++++- .../stringexpr/AggregateStringExpressionTest.scala | 25 ++++ .../table/validation/AggregateValidationTest.scala | 21 ++- .../GroupWindowTableAggregateValidationTest.scala | 15 +++ .../validation/GroupWindowValidationTest.scala | 35 ++++- .../runtime/stream/table/GroupWindowITCase.scala | 32 +++++ 10 files changed, 461 insertions(+), 44 deletions(-) diff --git a/docs/dev/table/tableApi.md b/docs/dev/table/tableApi.md index bd6bd38..744a82c 100644 --- a/docs/dev/table/tableApi.md +++ b/docs/dev/table/tableApi.md @@ -2643,6 +2643,26 @@ Table table = input <tr> <td> + <strong>Group Window Aggregate</strong><br> + <span class="label label-primary">Batch</span> <span class="label label-primary">Streaming</span> + </td> + <td> + <p>Groups and aggregates a table on a <a href="#group-windows">group window</a> and possibly one or more grouping keys. You have to close the "aggregate" with a select statement. 
The select statement does not support "*" or aggregate functions.</p> +{% highlight java %} +AggregateFunction myAggFunc = new MyMinMax(); +tableEnv.registerFunction("myAggFunc", myAggFunc); + +Table table = input + .window(Tumble.over("5.minutes").on("rowtime").as("w")) // define window + .groupBy("key, w") // group by key and window + .aggregate("myAggFunc(a) as (x, y)") + .select("key, x, y, w.start, w.end"); // access window properties and aggregate results +{% endhighlight %} + </td> + </tr> + + <tr> + <td> <strong>FlatAggregate</strong><br> <span class="label label-primary">Streaming</span><br> <span class="label label-info">Result Updating</span> @@ -2837,7 +2857,7 @@ class MyMinMax extends AggregateFunction[Row, MyMinMaxAcc] { } } -val myAggFunc: AggregateFunction = new MyMinMax +val myAggFunc = new MyMinMax val table = input .groupBy('key) .aggregate(myAggFunc('a) as ('x, 'y)) @@ -2848,6 +2868,25 @@ val table = input <tr> <td> + <strong>Group Window Aggregate</strong><br> + <span class="label label-primary">Batch</span> <span class="label label-primary">Streaming</span> + </td> + <td> + <p>Groups and aggregates a table on a <a href="#group-windows">group window</a> and possibly one or more grouping keys. You have to close the "aggregate" with a select statement. 
The select statement does not support "*" or aggregate functions.</p> +{% highlight scala %} +val myAggFunc = new MyMinMax +val table = input + .window(Tumble over 5.minutes on 'rowtime as 'w) // define window + .groupBy('key, 'w) // group by key and window + .aggregate(myAggFunc('a) as ('x, 'y)) + .select('key, 'x, 'y, 'w.start, 'w.end) // access window properties and aggregate results + +{% endhighlight %} + </td> + </tr> + + <tr> + <td> <strong>FlatAggregate</strong><br> <span class="label label-primary">Streaming</span><br> <span class="label label-info">Result Updating</span> diff --git a/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/api/WindowGroupedTable.java b/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/api/WindowGroupedTable.java index 0e1cf84..7e5a3ac 100644 --- a/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/api/WindowGroupedTable.java +++ b/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/api/WindowGroupedTable.java @@ -56,6 +56,43 @@ public interface WindowGroupedTable { Table select(Expression... fields); /** + * Performs an aggregate operation on a window grouped table. You have to close the + * {@link #aggregate(String)} with a select statement. The output will be flattened if the + * output type is a composite type. 
+ * + * <p>Scala Example: + * + * <pre> + * {@code + * val aggFunc = new MyAggregateFunction + * windowGroupedTable + * .aggregate(aggFunc('a, 'b) as ('x, 'y, 'z)) + * .select('key, 'window.start, 'x, 'y, 'z) + * } + * </pre> + */ + AggregatedTable aggregate(Expression aggregateFunction); + + /** * Performs a flatAggregate operation on a window grouped table. FlatAggregate takes a * TableAggregateFunction which returns multiple rows. Use a selection after flatAggregate. * @@ -63,7 +100,7 @@ public interface WindowGroupedTable { * * <pre> * {@code - * TableAggregateFunction tableAggFunc = new MyTableAggregateFunction + * TableAggregateFunction tableAggFunc = new MyTableAggregateFunction(); * tableEnv.registerFunction("tableAggFunc", tableAggFunc); * windowGroupedTable * .flatAggregate("tableAggFunc(a, b) as (x, y, z)") diff --git a/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/api/internal/TableImpl.java b/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/api/internal/TableImpl.java index 8d33e10..ae08d24 100644 --- a/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/api/internal/TableImpl.java +++ b/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/api/internal/TableImpl.java @@ -36,6 +36,7 @@ import org.apache.flink.table.api.WindowGroupedTable; import org.apache.flink.table.catalog.FunctionLookup; import org.apache.flink.table.expressions.Expression; import org.apache.flink.table.expressions.ExpressionParser; +import org.apache.flink.table.expressions.UnresolvedReferenceExpression; import org.apache.flink.table.expressions.resolver.LookupCallResolver; import org.apache.flink.table.functions.TemporalTableFunction; import org.apache.flink.table.functions.TemporalTableFunctionImpl; @@ -762,6 +763,16 @@ public class TableImpl implements Table { } @Override + public AggregatedTable aggregate(String aggregateFunction) { + return 
aggregate(ExpressionParser.parseExpression(aggregateFunction)); + } + + @Override + public AggregatedTable aggregate(Expression aggregateFunction) { + return new WindowAggregatedTableImpl(table, groupKeys, aggregateFunction, window); + } + + @Override public FlatAggregateTable flatAggregate(String tableAggregateFunction) { return flatAggregate(ExpressionParser.parseExpression(tableAggregateFunction)); } @@ -772,6 +783,62 @@ public class TableImpl implements Table { } } + private static final class WindowAggregatedTableImpl implements AggregatedTable { + private final TableImpl table; + private final List<Expression> groupKeys; + private final Expression aggregateFunction; + private final GroupWindow window; + + private WindowAggregatedTableImpl( + TableImpl table, + List<Expression> groupKeys, + Expression aggregateFunction, + GroupWindow window) { + this.table = table; + this.groupKeys = groupKeys; + this.aggregateFunction = aggregateFunction; + this.window = window; + } + + @Override + public Table select(String fields) { + return select(ExpressionParser.parseExpressionList(fields).toArray(new Expression[0])); + } + + @Override + public Table select(Expression... 
fields) { + List<Expression> expressionsWithResolvedCalls = Arrays.stream(fields) + .map(f -> f.accept(table.lookupResolver)) + .collect(Collectors.toList()); + CategorizedExpressions extracted = OperationExpressionsUtils.extractAggregationsAndProperties( + expressionsWithResolvedCalls + ); + + if (!extracted.getAggregations().isEmpty()) { + throw new ValidationException("Aggregate functions cannot be used in the select right " + + "after the aggregate."); + } + + if (extracted.getProjections().stream() + .anyMatch(p -> (p instanceof UnresolvedReferenceExpression) + && "*".equals(((UnresolvedReferenceExpression) p).getName()))) { + throw new ValidationException("Can not use * for window aggregate!"); + } + + return table.createTable( + table.operationTreeBuilder.project( + extracted.getProjections(), + table.operationTreeBuilder.windowAggregate( + groupKeys, + window, + extracted.getWindowProperties(), + aggregateFunction, + table.operationTree + ) + )); + } + } + private static final class WindowFlatAggregateTableImpl implements FlatAggregateTable { private final TableImpl table; @@ -804,24 +871,30 @@ public class TableImpl implements Table { expressionsWithResolvedCalls ); - if (!extracted.getAggregations().isEmpty()) { - throw new ValidationException("Aggregate functions cannot be used in the select right " + - "after the flatAggregate."); - } - - return table.createTable( - table.operationTreeBuilder.project( - extracted.getProjections(), - table.operationTreeBuilder.windowTableAggregate( - groupKeys, - window, - extracted.getWindowProperties(), - tableAggFunction, - table.operationTree - ), - // required for proper resolution of the time attribute in multi-windows - true - )); + if (!extracted.getAggregations().isEmpty()) { + throw new ValidationException("Aggregate functions cannot be used in the select right " + + "after the flatAggregate."); + } + + if (extracted.getProjections().stream() + .anyMatch(p -> (p instanceof UnresolvedReferenceExpression) + && 
"*".equals(((UnresolvedReferenceExpression) p).getName()))) { + throw new ValidationException("Can not use * for window aggregate!"); + } + + return table.createTable( + table.operationTreeBuilder.project( + extracted.getProjections(), + table.operationTreeBuilder.windowTableAggregate( + groupKeys, + window, + extracted.getWindowProperties(), + tableAggFunction, + table.operationTree + ), + // required for proper resolution of the time attribute in multi-windows + true + )); } } diff --git a/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/operations/utils/OperationTreeBuilder.java b/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/operations/utils/OperationTreeBuilder.java index 03406de..7510cb9 100644 --- a/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/operations/utils/OperationTreeBuilder.java +++ b/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/operations/utils/OperationTreeBuilder.java @@ -58,6 +58,7 @@ import org.apache.flink.table.operations.utils.factories.SortOperationFactory; import org.apache.flink.table.types.DataType; import org.apache.flink.table.types.logical.LogicalTypeRoot; import org.apache.flink.table.types.logical.utils.LogicalTypeChecks; +import org.apache.flink.table.types.utils.TypeConversions; import org.apache.flink.table.typeutils.FieldInfoUtils; import org.apache.flink.util.Preconditions; @@ -253,6 +254,73 @@ public final class OperationTreeBuilder { child); } + public QueryOperation windowAggregate( + List<Expression> groupingExpressions, + GroupWindow window, + List<Expression> windowProperties, + Expression aggregateFunction, + QueryOperation child) { + + ExpressionResolver resolver = getResolver(child); + Expression resolvedAggregate = aggregateFunction.accept(lookupResolver); + AggregateWithAlias aggregateWithAlias = resolvedAggregate.accept(new ExtractAliasAndAggregate(true, resolver)); + + List<Expression> groupsAndAggregate = new 
ArrayList<>(groupingExpressions); + groupsAndAggregate.add(aggregateWithAlias.aggregate); + List<Expression> namedGroupsAndAggregate = addAliasToTheCallInAggregate( + Arrays.asList(child.getTableSchema().getFieldNames()), + groupsAndAggregate); + + // Step1: add a default name to the call in the grouping expressions, e.g., groupBy(a % 5) to + // groupBy(a % 5 as TMP_0). We need a name for every column so that to perform alias for the + // table aggregate function in Step6. + List<Expression> newGroupingExpressions = namedGroupsAndAggregate.subList(0, groupingExpressions.size()); + + // Step2: turn agg to a named agg, because it will be verified later. + Expression aggregateRenamed = namedGroupsAndAggregate.get(groupingExpressions.size()); + + // Step3: resolve expressions, including grouping, aggregates and window properties. + ResolvedGroupWindow resolvedWindow = aggregateOperationFactory.createResolvedWindow(window, resolver); + ExpressionResolver resolverWithWindowReferences = ExpressionResolver.resolverFor( + tableReferenceLookup, + functionCatalog, + child) + .withLocalReferences( + new LocalReferenceExpression( + resolvedWindow.getAlias(), + resolvedWindow.getTimeAttribute().getOutputDataType())) + .build(); + + List<ResolvedExpression> convertedGroupings = resolverWithWindowReferences.resolve(newGroupingExpressions); + List<ResolvedExpression> convertedAggregates = resolverWithWindowReferences.resolve(Collections.singletonList( + aggregateRenamed)); + List<ResolvedExpression> convertedProperties = resolverWithWindowReferences.resolve(windowProperties); + + // Step4: create window agg operation + QueryOperation aggregateOperation = aggregateOperationFactory.createWindowAggregate( + convertedGroupings, + Collections.singletonList(convertedAggregates.get(0)), + convertedProperties, + resolvedWindow, + child); + + // Step5: flatten the aggregate function + String[] aggNames = aggregateOperation.getTableSchema().getFieldNames(); + List<Expression> 
flattenedExpressions = Arrays.stream(aggNames) + .map(ApiExpressionUtils::unresolvedRef) + .collect(Collectors.toCollection(ArrayList::new)); + flattenedExpressions.set( + groupingExpressions.size(), + unresolvedCall( + BuiltInFunctionDefinitions.FLATTEN, + unresolvedRef(aggNames[groupingExpressions.size()]))); + QueryOperation flattenedProjection = this.project(flattenedExpressions, aggregateOperation); + + // Step6: add a top project to alias the output fields of the aggregate. Also, project the + // window attribute. + return aliasBackwardFields(flattenedProjection, aggregateWithAlias.aliases, groupingExpressions.size()); + } + public QueryOperation join( QueryOperation left, QueryOperation right, @@ -405,21 +473,30 @@ public final class OperationTreeBuilder { public QueryOperation aggregate(List<Expression> groupingExpressions, Expression aggregate, QueryOperation child) { Expression resolvedAggregate = aggregate.accept(lookupResolver); - AggregateWithAlias aggregateWithAlias = resolvedAggregate.accept(new ExtractAliasAndAggregate()); + AggregateWithAlias aggregateWithAlias = + resolvedAggregate.accept(new ExtractAliasAndAggregate(true, getResolver(child))); - // turn agg to a named agg, because it will be verified later. - String[] childNames = child.getTableSchema().getFieldNames(); - Expression aggregateRenamed = addAliasToTheCallInGroupings( - Arrays.asList(childNames), - Collections.singletonList(aggregateWithAlias.aggregate)).get(0); + List<Expression> groupsAndAggregate = new ArrayList<>(groupingExpressions); + groupsAndAggregate.add(aggregateWithAlias.aggregate); + List<Expression> namedGroupsAndAggregate = addAliasToTheCallInAggregate( + Arrays.asList(child.getTableSchema().getFieldNames()), + groupsAndAggregate); - // get agg table + // Step1: add a default name to the call in the grouping expressions, e.g., groupBy(a % 5) to + // groupBy(a % 5 as TMP_0). 
We need a name for every column so that to perform alias for the + // aggregate function in Step5. + List<Expression> newGroupingExpressions = namedGroupsAndAggregate.subList(0, groupingExpressions.size()); + + // Step2: turn agg to a named agg, because it will be verified later. + Expression aggregateRenamed = namedGroupsAndAggregate.get(groupingExpressions.size()); + + // Step3: get agg table QueryOperation aggregateOperation = this.aggregate( - groupingExpressions, + newGroupingExpressions, Collections.singletonList(aggregateRenamed), child); - // flatten the aggregate function + // Step4: flatten the aggregate function String[] aggNames = aggregateOperation.getTableSchema().getFieldNames(); List<Expression> flattenedExpressions = Arrays.asList(aggNames) .subList(0, groupingExpressions.size()) @@ -433,7 +510,7 @@ public final class OperationTreeBuilder { QueryOperation flattenedProjection = this.project(flattenedExpressions, aggregateOperation); - // add alias + // Step5: add alias return aliasBackwardFields(flattenedProjection, aggregateWithAlias.aliases, groupingExpressions.size()); } @@ -448,6 +525,16 @@ public final class OperationTreeBuilder { } private static class ExtractAliasAndAggregate extends ApiExpressionDefaultVisitor<AggregateWithAlias> { + + // need this flag to validate alias, i.e., the length of alias and function result type should be same. 
+ private boolean isRowbasedAggregate = false; + private ExpressionResolver resolver = null; + + public ExtractAliasAndAggregate(boolean isRowbasedAggregate, ExpressionResolver resolver) { + this.isRowbasedAggregate = isRowbasedAggregate; + this.resolver = resolver; + } + @Override public AggregateWithAlias visit(UnresolvedCallExpression unresolvedCall) { if (ApiExpressionUtils.isFunction(unresolvedCall, BuiltInFunctionDefinitions.AS)) { @@ -489,6 +576,9 @@ public final class OperationTreeBuilder { fieldNames = Collections.emptyList(); } } else { + ResolvedExpression resolvedExpression = + resolver.resolve(Collections.singletonList(unresolvedCall)).get(0); + validateAlias(aliases, resolvedExpression, isRowbasedAggregate); fieldNames = aliases; } return Optional.of(new AggregateWithAlias(unresolvedCall, fieldNames)); @@ -501,6 +591,27 @@ public final class OperationTreeBuilder { protected AggregateWithAlias defaultMethod(Expression expression) { throw new ValidationException("Aggregate function expected. Got: " + expression); } + + private void validateAlias( + List<String> aliases, + ResolvedExpression resolvedExpression, + Boolean isRowbasedAggregate) { + + int length = TypeConversions + .fromDataTypeToLegacyInfo(resolvedExpression.getOutputDataType()).getArity(); + int callArity = isRowbasedAggregate ? length : 1; + int aliasesSize = aliases.size(); + + if ((0 < aliasesSize) && (aliasesSize != callArity)) { + throw new ValidationException(String.format( + "List of column aliases must have same degree as table; " + + "the returned table of function '%s' has " + + "%d columns, whereas alias list has %d columns", + resolvedExpression, + callArity, + aliasesSize)); + } + } } public QueryOperation tableAggregate( @@ -511,7 +622,7 @@ public final class OperationTreeBuilder { // Step1: add a default name to the call in the grouping expressions, e.g., groupBy(a % 5) to // groupBy(a % 5 as TMP_0). 
We need a name for every column so that to perform alias for the // table aggregate function in Step4. - List<Expression> newGroupingExpressions = addAliasToTheCallInGroupings( + List<Expression> newGroupingExpressions = addAliasToTheCallInAggregate( Arrays.asList(child.getTableSchema().getFieldNames()), groupingExpressions); @@ -540,7 +651,7 @@ public final class OperationTreeBuilder { // Step1: add a default name to the call in the grouping expressions, e.g., groupBy(a % 5) to // groupBy(a % 5 as TMP_0). We need a name for every column so that to perform alias for the // table aggregate function in Step4. - List<Expression> newGroupingExpressions = addAliasToTheCallInGroupings( + List<Expression> newGroupingExpressions = addAliasToTheCallInAggregate( Arrays.asList(child.getTableSchema().getFieldNames()), groupingExpressions); @@ -605,17 +716,17 @@ public final class OperationTreeBuilder { /** * Add a default name to the call in the grouping expressions, e.g., groupBy(a % 5) to - * groupBy(a % 5 as TMP_0). + * groupBy(a % 5 as TMP_0) or make aggregate a named aggregate. 
*/ - private List<Expression> addAliasToTheCallInGroupings( + private List<Expression> addAliasToTheCallInAggregate( List<String> inputFieldNames, - List<Expression> groupingExpressions) { + List<Expression> expressions) { int attrNameCntr = 0; Set<String> usedFieldNames = new HashSet<>(inputFieldNames); List<Expression> result = new ArrayList<>(); - for (Expression groupingExpression : groupingExpressions) { + for (Expression groupingExpression : expressions) { if (groupingExpression instanceof UnresolvedCallExpression && !ApiExpressionUtils.isFunction(groupingExpression, BuiltInFunctionDefinitions.AS)) { String tempName = getUniqueName("TMP_" + attrNameCntr, usedFieldNames); diff --git a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/stream/table/AggregateTest.scala b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/stream/table/AggregateTest.scala index 41ba5fc..2501e63 100644 --- a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/stream/table/AggregateTest.scala +++ b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/stream/table/AggregateTest.scala @@ -345,14 +345,14 @@ class AggregateTest extends TableTestBase { } @Test - def testSelectStar(): Unit = { + def testSelectStarAndGroupByCall(): Unit = { val util = streamTestUtil() val table = util.addTable[(Int, Long, String)]( "MyTable", 'a, 'b, 'c) val testAgg = new CountMinMax val resultTable = table - .groupBy('b) + .groupBy('b % 5) .aggregate(testAgg('a)) .select('*) @@ -364,12 +364,12 @@ class AggregateTest extends TableTestBase { unaryNode( "DataStreamCalc", streamTableNode(table), - term("select", "a", "b") + term("select", "a", "MOD(b, 5) AS TMP_0") ), - term("groupBy", "b"), - term("select", "b", "CountMinMax(a) AS TMP_0") + term("groupBy", "TMP_0"), + term("select", "TMP_0", "CountMinMax(a) AS TMP_1") ), - term("select", "b", "TMP_0.f0 AS f0", "TMP_0.f1 AS f1", "TMP_0.f2 AS f2") + term("select", "TMP_0", 
"TMP_1.f0 AS f0", "TMP_1.f1 AS f1", "TMP_1.f2 AS f2") ) util.verifyTable(resultTable, expected) } @@ -428,4 +428,37 @@ class AggregateTest extends TableTestBase { ) util.verifyTable(resultTable, expected) } + + @Test + def testAggregateOnWindowedTable(): Unit = { + val util = streamTestUtil() + val table = util.addTable[(Int, Long, String)]( + "MyTable", 'a, 'b, 'c, 'rowtime.rowtime) + val testAgg = new CountMinMax + + val result = table + .window(Tumble over 15.minute on 'rowtime as 'w) + .groupBy('w, 'b % 3) + .aggregate(testAgg('a) as ('x, 'y, 'z)) + .select('w.start, 'x, 'y) + + val expected = + unaryNode( + "DataStreamCalc", + unaryNode( + "DataStreamGroupWindowAggregate", + unaryNode( + "DataStreamCalc", + streamTableNode(table), + term("select", "a", "rowtime", "MOD(b, 3) AS TMP_0") + ), + term("groupBy", "TMP_0"), + term("window", "TumblingGroupWindow('w, 'rowtime, 900000.millis)"), + term("select", "TMP_0", "CountMinMax(a) AS TMP_1", "start('w) AS EXPR$0") + ), + term("select", "EXPR$0", "TMP_1.f0 AS x", "TMP_1.f1 AS y") + ) + + util.verifyTable(result, expected) + } } diff --git a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/stream/table/stringexpr/AggregateStringExpressionTest.scala b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/stream/table/stringexpr/AggregateStringExpressionTest.scala index 5ee4471..7de35f3 100644 --- a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/stream/table/stringexpr/AggregateStringExpressionTest.scala +++ b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/stream/table/stringexpr/AggregateStringExpressionTest.scala @@ -235,6 +235,31 @@ class AggregateStringExpressionTest extends TableTestBase { verifyTableEquals(resScala, resJava) } + + @Test + def testAggregateWithWindow(): Unit = { + val util = streamTestUtil() + val t = util.addTable[TestPojo]('int, 'long.rowtime as 'rowtime, 'string) + + val testAgg = new CountMinMax 
+ util.tableEnv.registerFunction("testAgg", testAgg) + + // Expression / Scala API + val resScala = t + .window(Tumble over 50.milli on 'rowtime as 'w1) + .groupBy('w1, 'string) + .aggregate(testAgg('int) as ('x, 'y, 'z)) + .select('string, 'x, 'y, 'w1.start, 'w1.end) + + // String / Java API + val resJava = t + .window(Tumble.over("50.milli").on("rowtime").as("w1")) + .groupBy("w1, string") + .aggregate("testAgg(int) as (x, y, z)") + .select("string, x, y, w1.start, w1.end") + + verifyTableEquals(resJava, resScala) + } } class TestPojo() { diff --git a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/stream/table/validation/AggregateValidationTest.scala b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/stream/table/validation/AggregateValidationTest.scala index a8de009..4b25ea2 100644 --- a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/stream/table/validation/AggregateValidationTest.scala +++ b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/stream/table/validation/AggregateValidationTest.scala @@ -21,7 +21,7 @@ package org.apache.flink.table.api.stream.table.validation import org.apache.flink.api.scala._ import org.apache.flink.table.api.{ExpressionParserException, ValidationException} import org.apache.flink.table.api.scala._ -import org.apache.flink.table.utils.{TableFunc0, TableTestBase} +import org.apache.flink.table.utils.{CountMinMax, TableFunc0, TableTestBase} import org.junit.Test class AggregateValidationTest extends TableTestBase { @@ -123,4 +123,23 @@ class AggregateValidationTest extends TableTestBase { // must fail. 
Only one AggregateFunction can be used in aggregate .aggregate("sum(c), count(b)") } + + @Test + def testInvalidAlias(): Unit = { + expectedException.expect(classOf[ValidationException]) + expectedException.expectMessage("List of column aliases must have same degree as " + + "table; the returned table of function 'minMax(b)' has 3 columns, " + + "whereas alias list has 2 columns") + + val util = streamTestUtil() + val table = util.addTable[(Long, Int, String)]('a, 'b, 'c) + val minMax = new CountMinMax + + util.tableEnv.registerFunction("minMax", minMax) + table + .groupBy('a) + // must fail. Invalid alias length + .aggregate("minMax(b) as (x, y)") + .select("x, y") + } } diff --git a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/stream/table/validation/GroupWindowTableAggregateValidationTest.scala b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/stream/table/validation/GroupWindowTableAggregateValidationTest.scala index 21369ca..714e1dd 100644 --- a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/stream/table/validation/GroupWindowTableAggregateValidationTest.scala +++ b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/stream/table/validation/GroupWindowTableAggregateValidationTest.scala @@ -254,4 +254,19 @@ class GroupWindowTableAggregateValidationTest extends TableTestBase { .flatAggregate(top3('int)) .select('string, 'f0.count) } + + @Test + def testInvalidStarInSelection(): Unit = { + expectedException.expect(classOf[ValidationException]) + expectedException.expectMessage("Can not use * for window aggregate!") + + val util = streamTestUtil() + val table = util.addTable[(Long, Int, String)]('long, 'int, 'string, 'proctime.proctime) + + table + .window(Tumble over 2.rows on 'proctime as 'w) + .groupBy('string, 'w) + .flatAggregate(top3('int)) + .select('*) + } } diff --git 
a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/stream/table/validation/GroupWindowValidationTest.scala b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/stream/table/validation/GroupWindowValidationTest.scala index 7b7ff87..2b03bc6 100644 --- a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/stream/table/validation/GroupWindowValidationTest.scala +++ b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/stream/table/validation/GroupWindowValidationTest.scala @@ -22,7 +22,7 @@ import org.apache.flink.api.scala._ import org.apache.flink.table.api.{Session, Slide, Tumble, ValidationException} import org.apache.flink.table.runtime.utils.JavaUserDefinedAggFunctions.WeightedAvgWithMerge import org.apache.flink.table.api.scala._ -import org.apache.flink.table.utils.TableTestBase +import org.apache.flink.table.utils.{CountMinMax, TableTestBase} import org.junit.Test class GroupWindowValidationTest extends TableTestBase { @@ -290,4 +290,37 @@ class GroupWindowValidationTest extends TableTestBase { .groupBy('w, 'string) .select('string, 'w.start, 'w.end) // invalid start/end on rows-count window } + + @Test + def testInvalidAggregateInSelection(): Unit = { + expectedException.expect(classOf[ValidationException]) + expectedException.expectMessage("Aggregate functions cannot be used in the select " + + "right after the aggregate.") + + val util = streamTestUtil() + val table = util.addTable[(Long, Int, String)]('long, 'int, 'string, 'proctime.proctime) + val testAgg = new CountMinMax + + table + .window(Tumble over 2.rows on 'proctime as 'w) + .groupBy('string, 'w) + .aggregate(testAgg('int)) + .select('string, 'f0.count) + } + + @Test + def testInvalidStarInSelection(): Unit = { + expectedException.expect(classOf[ValidationException]) + expectedException.expectMessage("Can not use * for window aggregate!") + + val util = streamTestUtil() + val table = util.addTable[(Long, 
Int, String)]('long, 'int, 'string, 'proctime.proctime) + val testAgg = new CountMinMax + + table + .window(Tumble over 2.rows on 'proctime as 'w) + .groupBy('string, 'w) + .aggregate(testAgg('int)) + .select('*) + } } diff --git a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/runtime/stream/table/GroupWindowITCase.scala b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/runtime/stream/table/GroupWindowITCase.scala index 9054bc8..130019d 100644 --- a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/runtime/stream/table/GroupWindowITCase.scala +++ b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/runtime/stream/table/GroupWindowITCase.scala @@ -32,6 +32,7 @@ import org.apache.flink.table.functions.aggfunctions.CountAggFunction import org.apache.flink.table.runtime.stream.table.GroupWindowITCase._ import org.apache.flink.table.runtime.utils.JavaUserDefinedAggFunctions.{CountDistinct, CountDistinctWithMerge, WeightedAvg, WeightedAvgWithMerge} import org.apache.flink.table.runtime.utils.StreamITCase +import org.apache.flink.table.utils.CountMinMax import org.apache.flink.test.util.AbstractTestBase import org.apache.flink.types.Row import org.junit.Assert._ @@ -442,6 +443,37 @@ class GroupWindowITCase extends AbstractTestBase { "null,1,1970-01-01 00:00:00.03,1970-01-01 00:00:00.033") assertEquals(expected.sorted, StreamITCase.testResults.sorted) } + + @Test + def testRowbasedAggregateWithEventTimeTumblingWindow(): Unit = { + val env = StreamExecutionEnvironment.getExecutionEnvironment + env.setStreamTimeCharacteristic(TimeCharacteristic.EventTime) + val tEnv = StreamTableEnvironment.create(env) + StreamITCase.testResults = mutable.MutableList() + + val stream = env + .fromCollection(data) + .assignTimestampsAndWatermarks(new TimestampAndWatermarkWithOffset[(Long, Int, String)](0L)) + val table = stream.toTable(tEnv, 'long, 'int, 'string, 'rowtime.rowtime) + val minMax = new 
CountMinMax + + val windowedTable = table + .window(Tumble over 5.milli on 'rowtime as 'w) + .groupBy('w, 'string) + .aggregate(minMax('int) as ('x, 'y, 'z)) + .select('string, 'x, 'y, 'z, 'w.start, 'w.end) + + val results = windowedTable.toAppendStream[Row] + results.addSink(new StreamITCase.StringSink[Row]) + env.execute() + + val expected = Seq( + "Hello world,1,3,3,1970-01-01 00:00:00.005,1970-01-01 00:00:00.01", + "Hello world,1,3,3,1970-01-01 00:00:00.015,1970-01-01 00:00:00.02", + "Hello,2,2,2,1970-01-01 00:00:00.0,1970-01-01 00:00:00.005", + "Hi,1,1,1,1970-01-01 00:00:00.0,1970-01-01 00:00:00.005") + assertEquals(expected.sorted, StreamITCase.testResults.sorted) + } } object GroupWindowITCase {