hequn8128 commented on a change in pull request #8230: [FLINK-10977][table] Add streaming non-window FlatAggregate to Table API URL: https://github.com/apache/flink/pull/8230#discussion_r279434307
########## File path: flink-table/flink-table-planner/src/main/java/org/apache/flink/table/operations/AggregateOperationFactory.java ########## @@ -465,4 +535,98 @@ protected Void defaultMethod(Expression expression) { return null; } } + + /** + * Extract a table aggregate Expression and it's aliases. + */ + public Tuple2<Expression, List<String>> extractTableAggFunctionAndAliases(Expression callExpr) { + TableAggFunctionCallVisitor visitor = new TableAggFunctionCallVisitor(); + return Tuple2.of(callExpr.accept(visitor), visitor.getAlias()); + } + + private class TableAggFunctionCallVisitor extends ApiExpressionDefaultVisitor<Expression> { + + private List<String> alias = new LinkedList<>(); + + public List<String> getAlias() { + return alias; + } + + @Override + public Expression visitCall(CallExpression call) { + FunctionDefinition definition = call.getFunctionDefinition(); + if (definition.equals(AS)) { + return unwrapFromAlias(call); + } else if (definition instanceof AggregateFunctionDefinition) { + if (!isTableAggFunctionCall(call)) { + throw fail(); + } + return call; + } else { + return defaultMethod(call); + } + } + + private Expression unwrapFromAlias(CallExpression call) { + List<Expression> children = call.getChildren(); + List<String> aliases = children.subList(1, children.size()) + .stream() + .map(alias -> ExpressionUtils.extractValue(alias, Types.STRING) + .orElseThrow(() -> new ValidationException("Unexpected alias: " + alias))) + .collect(toList()); + + if (!isTableAggFunctionCall(children.get(0))) { + throw fail(); + } + + validateAlias(aliases, (AggregateFunctionDefinition) ((CallExpression) children.get(0)).getFunctionDefinition()); + alias = aliases; + return children.get(0); + } + + private void validateAlias( + List<String> aliases, + AggregateFunctionDefinition aggFunctionDefinition) { + + TypeInformation resultType = aggFunctionDefinition.getResultTypeInfo(); + + int callArity = resultType.getTotalFields(); + int aliasesSize = 
aliases.size(); + + if (aliasesSize > 0 && aliasesSize != callArity) { + throw new ValidationException(String.format( + "List of column aliases must have same degree as table; " + + "the returned table of function '%s' has " + + "%d columns, whereas alias list has %d columns", + aggFunctionDefinition.getName(), + callArity, + aliasesSize)); + } + } + + @Override + protected AggFunctionCall defaultMethod(Expression expression) { + throw fail(); + } + + private ValidationException fail() { + return new ValidationException( + "A flatAggregate only accepts an expression which defines a table aggregate " + + "function that might be followed by some alias."); + } + } + + /** + * Return true if the input {@link Expression} is a {@link CallExpression} of table aggregate function. + */ + public static boolean isTableAggFunctionCall(Expression expression) { + return Collections.singletonList(expression).stream() + .filter(p -> p instanceof CallExpression) + .map(p -> (CallExpression) p) + .filter(p -> p.getFunctionDefinition().getType() == AGGREGATE_FUNCTION) + .filter(p -> p.getFunctionDefinition() instanceof AggregateFunctionDefinition) + .map(p -> (AggregateFunctionDefinition) p.getFunctionDefinition()) + .filter(p -> p.getAggregateFunction() instanceof TableAggregateFunction) + .collect(Collectors.toList()).size() == 1; Review comment: Thank you for your suggestions. I will keep the original code, and I have also made it simpler by using `Stream.of` and `anyMatch`. :-) ---------------------------------------------------------------- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. For queries about this service, please contact Infrastructure at: us...@infra.apache.org With regards, Apache Git Services