[FLINK-5906] [table] Add support to register UDAGGs for Table API and SQL. This closes #3809.
Project: http://git-wip-us.apache.org/repos/asf/flink/repo Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/981dea41 Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/981dea41 Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/981dea41 Branch: refs/heads/master Commit: 981dea41e593f3db763af3d0366bf7adbdd1d3bf Parents: d6435e8 Author: shaoxuan-wang <wshaox...@gmail.com> Authored: Tue May 2 23:00:51 2017 +0800 Committer: Fabian Hueske <fhue...@apache.org> Committed: Thu May 4 23:40:31 2017 +0200 ---------------------------------------------------------------------- .../flink/table/api/TableEnvironment.scala | 25 ++- .../table/api/java/BatchTableEnvironment.scala | 22 ++- .../table/api/java/StreamTableEnvironment.scala | 22 ++- .../table/api/scala/BatchTableEnvironment.scala | 18 +- .../api/scala/StreamTableEnvironment.scala | 18 +- .../flink/table/api/scala/expressionDsl.scala | 3 + .../org/apache/flink/table/api/table.scala | 21 +- .../flink/table/codegen/CodeGenerator.scala | 97 +++++++-- .../codegen/calls/ScalarFunctionCallGen.scala | 4 +- .../codegen/calls/TableFunctionCallGen.scala | 2 +- .../table/expressions/UDAGGExpression.scala | 36 ++++ .../flink/table/expressions/aggregations.scala | 75 ++++++- .../apache/flink/table/expressions/call.scala | 13 +- .../table/functions/AggregateFunction.scala | 14 +- .../aggfunctions/CountAggFunction.scala | 3 + .../table/functions/utils/AggSqlFunction.scala | 179 +++++++++++++++++ .../functions/utils/ScalarSqlFunction.scala | 33 +--- .../utils/UserDefinedFunctionUtils.scala | 198 ++++++++++++++----- .../flink/table/plan/ProjectionTranslator.scala | 57 ++++++ .../flink/table/plan/logical/operators.scala | 4 +- .../table/plan/nodes/CommonAggregate.scala | 2 +- .../flink/table/plan/nodes/OverAggregate.scala | 2 +- .../nodes/datastream/DataStreamAggregate.scala | 12 +- .../table/runtime/aggregate/AggregateUtil.scala | 48 +++-- .../flink/table/validate/FunctionCatalog.scala | 13 +- .../api/java/utils/UserDefinedAggFunctions.java | 95 +++++++++ .../scala/batch/sql/AggregationsITCase.scala | 43 ++-- .../scala/batch/sql/WindowAggregateTest.scala | 43 +++- .../scala/batch/table/AggregationsITCase.scala | 9 +- .../api/scala/batch/table/GroupWindowTest.scala | 155 +++++++++++++++ .../AggregationsStringExpressionTest.scala | 58 ++++++ .../validation/AggregationsValidationTest.scala | 99 ++++++++++ .../scala/stream/sql/WindowAggregateTest.scala | 74 ++++--- .../scala/stream/table/AggregationsITCase.scala | 43 ++-- .../scala/stream/table/GroupWindowTest.scala | 120 ++++++++++- .../scala/stream/table/OverWindowITCase.scala | 45 +++-- .../api/scala/stream/table/OverWindowTest.scala | 88 ++++++--- .../GroupWindowStringExpressionTest.scala | 65 ++++++ 38 files changed, 1617 insertions(+), 241 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/flink/blob/981dea41/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/TableEnvironment.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/TableEnvironment.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/TableEnvironment.scala index 45267d2..06c405e 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/TableEnvironment.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/TableEnvironment.scala @@ -50,8 +50,8 @@ import 
org.apache.flink.table.calcite.{FlinkPlannerImpl, FlinkRelBuilder, FlinkT import org.apache.flink.table.catalog.{ExternalCatalog, ExternalCatalogSchema} import org.apache.flink.table.codegen.{CodeGenerator, ExpressionReducer} import org.apache.flink.table.expressions.{Alias, Expression, UnresolvedFieldReference} -import org.apache.flink.table.functions.utils.UserDefinedFunctionUtils.{checkForInstantiation, checkNotSingleton, createScalarSqlFunction, createTableSqlFunctions} -import org.apache.flink.table.functions.{ScalarFunction, TableFunction} +import org.apache.flink.table.functions.utils.UserDefinedFunctionUtils._ +import org.apache.flink.table.functions.{ScalarFunction, TableFunction, AggregateFunction} import org.apache.flink.table.plan.cost.DataSetCostFactory import org.apache.flink.table.plan.logical.{CatalogNode, LogicalRelNode} import org.apache.flink.table.plan.rules.FlinkRuleSets @@ -352,6 +352,27 @@ abstract class TableEnvironment(val config: TableConfig) { } /** + * Registers an [[AggregateFunction]] under a unique name. Replaces already existing + * user-defined functions under this name. + */ + private[flink] def registerAggregateFunctionInternal[T: TypeInformation, ACC]( + name: String, function: AggregateFunction[T, ACC]): Unit = { + // check if class not Scala object + checkNotSingleton(function.getClass) + // check if class could be instantiated + checkForInstantiation(function.getClass) + + val typeInfo: TypeInformation[_] = implicitly[TypeInformation[T]] + + // register in Table API + functionCatalog.registerFunction(name, function.getClass) + + // register in SQL API + val sqlFunctions = createAggregateSqlFunction(name, function, typeInfo, typeFactory) + functionCatalog.registerSqlFunction(sqlFunctions) + } + + /** * Registers a [[Table]] under a unique name in the TableEnvironment's catalog. * Registered tables can be referenced in SQL queries. * http://git-wip-us.apache.org/repos/asf/flink/blob/981dea41/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/java/BatchTableEnvironment.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/java/BatchTableEnvironment.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/java/BatchTableEnvironment.scala index de5f789..03fb77e 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/java/BatchTableEnvironment.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/java/BatchTableEnvironment.scala @@ -22,7 +22,7 @@ import org.apache.flink.api.java.typeutils.TypeExtractor import org.apache.flink.api.java.{DataSet, ExecutionEnvironment} import org.apache.flink.table.expressions.ExpressionParser import org.apache.flink.table.api._ -import org.apache.flink.table.functions.TableFunction +import org.apache.flink.table.functions.{AggregateFunction, TableFunction} /** * The [[TableEnvironment]] for a Java batch [[DataSet]] @@ -178,4 +178,24 @@ class BatchTableEnvironment( registerTableFunctionInternal[T](name, tf) } + + /** + * Registers an [[AggregateFunction]] under a unique name in the TableEnvironment's catalog. + * Registered functions can be referenced in Table API and SQL queries. + * + * @param name The name under which the function is registered. + * @param f The AggregateFunction to register. + * @tparam T The type of the output value. + * @tparam ACC The type of aggregate accumulator. 
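For reference, a minimal sketch of a user-defined aggregate and its registration through this new overload; the WeightedAvg class, its accumulator, and all names below are hypothetical and not part of this commit:

// hypothetical accumulator: a plain class with public, mutable state
class WeightedAvgAccum {
  var sum: Long = 0L
  var count: Int = 0
}

// hypothetical UDAGG: createAccumulator/getValue implement the base class,
// accumulate is the contract method resolved by reflection
class WeightedAvg extends AggregateFunction[Long, WeightedAvgAccum] {
  override def createAccumulator(): WeightedAvgAccum = new WeightedAvgAccum

  override def getValue(acc: WeightedAvgAccum): Long =
    if (acc.count == 0) 0L else acc.sum / acc.count

  def accumulate(acc: WeightedAvgAccum, value: Long, weight: Int): Unit = {
    acc.sum += value * weight
    acc.count += weight
  }
}

// after registration the function is visible to both the Table API and SQL
tableEnv.registerFunction("wAvg", new WeightedAvg)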
+ */ + def registerFunction[T, ACC]( + name: String, + f: AggregateFunction[T, ACC]) + : Unit = { + implicit val typeInfo: TypeInformation[T] = TypeExtractor + .createTypeInfo(f, classOf[AggregateFunction[T, ACC]], f.getClass, 0) + .asInstanceOf[TypeInformation[T]] + + registerAggregateFunctionInternal[T, ACC](name, f) + } } http://git-wip-us.apache.org/repos/asf/flink/blob/981dea41/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/java/StreamTableEnvironment.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/java/StreamTableEnvironment.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/java/StreamTableEnvironment.scala index 4d9f1e1..a649584 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/java/StreamTableEnvironment.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/java/StreamTableEnvironment.scala @@ -20,7 +20,7 @@ package org.apache.flink.table.api.java import org.apache.flink.api.common.typeinfo.TypeInformation import org.apache.flink.api.java.typeutils.TypeExtractor import org.apache.flink.table.api._ -import org.apache.flink.table.functions.TableFunction +import org.apache.flink.table.functions.{AggregateFunction, TableFunction} import org.apache.flink.table.expressions.ExpressionParser import org.apache.flink.streaming.api.datastream.DataStream import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment @@ -180,4 +180,24 @@ class StreamTableEnvironment( registerTableFunctionInternal[T](name, tf) } + + /** + * Registers an [[AggregateFunction]] under a unique name in the TableEnvironment's catalog. + * Registered functions can be referenced in Table API and SQL queries. + * + * @param name The name under which the function is registered. + * @param f The AggregateFunction to register. + * @tparam T The type of the output value. + * @tparam ACC The type of aggregate accumulator. 
+ */ + def registerFunction[T, ACC]( + name: String, + f: AggregateFunction[T, ACC]) + : Unit = { + implicit val typeInfo: TypeInformation[T] = TypeExtractor + .createTypeInfo(f, classOf[AggregateFunction[T, ACC]], f.getClass, 0) + .asInstanceOf[TypeInformation[T]] + + registerAggregateFunctionInternal[T, ACC](name, f) + } } http://git-wip-us.apache.org/repos/asf/flink/blob/981dea41/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/scala/BatchTableEnvironment.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/scala/BatchTableEnvironment.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/scala/BatchTableEnvironment.scala index 3ae8c31..0dd7ca0 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/scala/BatchTableEnvironment.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/scala/BatchTableEnvironment.scala @@ -21,7 +21,7 @@ import org.apache.flink.api.common.typeinfo.TypeInformation import org.apache.flink.api.scala._ import org.apache.flink.table.api._ import org.apache.flink.table.expressions.Expression -import org.apache.flink.table.functions.TableFunction +import org.apache.flink.table.functions.{AggregateFunction, TableFunction} import _root_.scala.reflect.ClassTag @@ -151,4 +151,20 @@ class BatchTableEnvironment( def registerFunction[T: TypeInformation](name: String, tf: TableFunction[T]): Unit = { registerTableFunctionInternal(name, tf) } + + /** + * Registers an [[AggregateFunction]] under a unique name in the TableEnvironment's catalog. + * Registered functions can be referenced in Table API and SQL queries. + * + * @param name The name under which the function is registered. + * @param f The AggregateFunction to register. + * @tparam T The type of the output value. + * @tparam ACC The type of aggregate accumulator. 
+ */ + def registerFunction[T: TypeInformation, ACC]( + name: String, + f: AggregateFunction[T, ACC]) + : Unit = { + registerAggregateFunctionInternal[T, ACC](name, f) + } } http://git-wip-us.apache.org/repos/asf/flink/blob/981dea41/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/scala/StreamTableEnvironment.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/scala/StreamTableEnvironment.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/scala/StreamTableEnvironment.scala index 0113146..0552d7c 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/scala/StreamTableEnvironment.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/scala/StreamTableEnvironment.scala @@ -19,7 +19,7 @@ package org.apache.flink.table.api.scala import org.apache.flink.api.common.typeinfo.TypeInformation import org.apache.flink.table.api.{TableEnvironment, Table, TableConfig} -import org.apache.flink.table.functions.TableFunction +import org.apache.flink.table.functions.{AggregateFunction, TableFunction} import org.apache.flink.table.expressions.Expression import org.apache.flink.streaming.api.scala.{DataStream, StreamExecutionEnvironment} import org.apache.flink.streaming.api.scala.asScalaStream @@ -152,4 +152,20 @@ class StreamTableEnvironment( def registerFunction[T: TypeInformation](name: String, tf: TableFunction[T]): Unit = { registerTableFunctionInternal(name, tf) } + + /** + * Registers an [[AggregateFunction]] under a unique name in the TableEnvironment's catalog. + * Registered functions can be referenced in Table API and SQL queries. + * + * @param name The name under which the function is registered. + * @param f The AggregateFunction to register. + * @tparam T The type of the output value. + * @tparam ACC The type of aggregate accumulator. 
+ */ + def registerFunction[T: TypeInformation, ACC]( + name: String, + f: AggregateFunction[T, ACC]) + : Unit = { + registerAggregateFunctionInternal[T, ACC](name, f) + } } http://git-wip-us.apache.org/repos/asf/flink/blob/981dea41/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/scala/expressionDsl.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/scala/expressionDsl.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/scala/expressionDsl.scala index b3de4a4..cc58ff5 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/scala/expressionDsl.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/scala/expressionDsl.scala @@ -26,6 +26,7 @@ import org.apache.flink.table.api.{TableException, CurrentRow, CurrentRange, Unb import org.apache.flink.table.expressions.ExpressionUtils.{convertArray, toMilliInterval, toMonthInterval, toRowInterval} import org.apache.flink.table.expressions.TimeIntervalUnit.TimeIntervalUnit import org.apache.flink.table.expressions._ +import org.apache.flink.table.functions.AggregateFunction import scala.language.implicitConversions @@ -773,6 +774,8 @@ trait ImplicitExpressionConversions { implicit def sqlTimestamp2Literal(sqlTimestamp: Timestamp): Expression = Literal(sqlTimestamp) implicit def array2ArrayConstructor(array: Array[_]): Expression = convertArray(array) + implicit def userDefinedAggFunctionConstructor[T: TypeInformation, ACC] + (udagg: AggregateFunction[T, ACC]): UDAGGExpression[T, ACC] = UDAGGExpression(udagg) } // ------------------------------------------------------------------------------------------------ http://git-wip-us.apache.org/repos/asf/flink/blob/981dea41/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/table.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/table.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/table.scala index 9606979..87dde0a 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/table.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/table.scala @@ -151,7 +151,9 @@ class Table( */ def select(fields: String): Table = { val fieldExprs = ExpressionParser.parseExpressionList(fields) - select(fieldExprs: _*) + //get the correct expression for AggFunctionCall + val withResolvedAggFunctionCall = fieldExprs.map(replaceAggFunctionCall(_, tableEnv)) + select(withResolvedAggFunctionCall: _*) } /** @@ -167,7 +169,7 @@ class Table( def as(fields: Expression*): Table = { logicalPlan match { - case functionCall: LogicalTableFunctionCall if functionCall.child == null => { + case functionCall: LogicalTableFunctionCall if functionCall.child == null => // If the logical plan is a TableFunctionCall, we replace its field names to avoid special // cases during the validation. 
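The implicit conversion added to ImplicitExpressionConversions above is what lets a Scala Table API program apply an AggregateFunction instance directly to expressions, while replaceAggFunctionCall (used in the select methods below) resolves the same calls in the string-based variant. A hedged usage sketch, reusing the hypothetical WeightedAvg from the earlier sketch with illustrative table and field names:

val wAvg = new WeightedAvg

// Scala expression API: the implicit conversion wraps wAvg in a UDAGGExpression
val result = orders
  .groupBy('user)
  .select('user, wAvg('points, 'level))

// string-based API: the parsed call is resolved via replaceAggFunctionCall;
// assumes wAvg was registered as tableEnv.registerFunction("wAvg", wAvg)
val result2 = orders
  .groupBy("user")
  .select("user, wAvg(points, level)")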
if (fields.length != functionCall.output.length) { @@ -181,7 +183,7 @@ class Table( } new Table( tableEnv, - new LogicalTableFunctionCall( + LogicalTableFunctionCall( functionCall.functionName, functionCall.tableFunction, functionCall.parameters, @@ -189,7 +191,6 @@ class Table( fields.map(_.asInstanceOf[UnresolvedFieldReference].name).toArray, functionCall.child) ) - } case _ => // prepend an AliasNode new Table(tableEnv, AliasNode(fields, logicalPlan).validate(tableEnv)) @@ -908,7 +909,9 @@ class GroupedTable( */ def select(fields: String): Table = { val fieldExprs = ExpressionParser.parseExpressionList(fields) - select(fieldExprs: _*) + //get the correct expression for AggFunctionCall + val withResolvedAggFunctionCall = fieldExprs.map(replaceAggFunctionCall(_, table.tableEnv)) + select(withResolvedAggFunctionCall: _*) } } @@ -983,7 +986,9 @@ class OverWindowedTable( def select(fields: String): Table = { val fieldExprs = ExpressionParser.parseExpressionList(fields) - select(fieldExprs: _*) + //get the correct expression for AggFunctionCall + val withResolvedAggFunctionCall = fieldExprs.map(replaceAggFunctionCall(_, table.tableEnv)) + select(withResolvedAggFunctionCall: _*) } } @@ -1043,7 +1048,9 @@ class WindowGroupedTable( */ def select(fields: String): Table = { val fieldExprs = ExpressionParser.parseExpressionList(fields) - select(fieldExprs: _*) + //get the correct expression for AggFunctionCall + val withResolvedAggFunctionCall = fieldExprs.map(replaceAggFunctionCall(_, table.tableEnv)) + select(withResolvedAggFunctionCall: _*) } } http://git-wip-us.apache.org/repos/asf/flink/blob/981dea41/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/CodeGenerator.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/CodeGenerator.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/CodeGenerator.scala index 648efe6..5bb3b0e 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/CodeGenerator.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/CodeGenerator.scala @@ -18,6 +18,8 @@ package org.apache.flink.table.codegen +import java.lang.reflect.ParameterizedType +import java.lang.{Iterable => JIterable} import java.math.{BigDecimal => JBigDecimal} import org.apache.calcite.avatica.util.DateTimeUtils @@ -45,6 +47,7 @@ import org.apache.flink.table.functions.utils.UserDefinedFunctionUtils import org.apache.flink.table.runtime.TableFunctionCollector import org.apache.flink.table.typeutils.TypeCheckUtils._ import org.apache.flink.types.Row +import org.apache.flink.table.functions.utils.UserDefinedFunctionUtils.{getUserDefinedMethod, signatureToString} import scala.collection.JavaConversions._ import scala.collection.mutable @@ -258,6 +261,9 @@ class CodeGenerator( * @param constantFlags An optional parameter to define where to set constant boolean flags in * the output row. * @param outputArity The number of fields in the output row. 
+ * @param needRetract a flag to indicate if the aggregate needs the retract method + * @param needMerge a flag to indicate if the aggregate needs the merge method + * @param needReset a flag to indicate if the aggregate needs the resetAccumulator method * * @return A GeneratedAggregationsFunction */ @@ -274,7 +280,8 @@ constantFlags: Option[Array[(Int, Boolean)]], outputArity: Int, needRetract: Boolean, - needMerge: Boolean) + needMerge: Boolean, + needReset: Boolean) : GeneratedAggregationsFunction = { // get unique function name @@ -282,19 +289,80 @@ // register UDAGGs val aggs = aggregates.map(a => generator.addReusableFunction(a)) // get java types of accumulators - val accTypes = aggregates.map { a => - a.getClass.getMethod("createAccumulator").getReturnType.getCanonicalName + val accTypeClasses = aggregates.map { a => + a.getClass.getMethod("createAccumulator").getReturnType } + val accTypes = accTypeClasses.map(_.getCanonicalName) - // get java types of input fields - val javaTypes = inputType.getFieldList - .map(f => FlinkTypeFactory.toTypeInfo(f.getType)) - .map(t => t.getTypeClass.getCanonicalName) + // get java classes of input fields + val javaClasses = inputType.getFieldList + .map(f => FlinkTypeFactory.toTypeInfo(f.getType).getTypeClass) // get parameter lists for aggregation functions - val parameters = aggFields.map {inFields => - val fields = for (f <- inFields) yield s"(${javaTypes(f)}) input.getField($f)" + val parameters = aggFields.map { inFields => + val fields = for (f <- inFields) yield + s"(${javaClasses(f).getCanonicalName}) input.getField($f)" fields.mkString(", ") } + val methodSignaturesList = aggFields.map { + inFields => for (f <- inFields) yield javaClasses(f) + } + + // check and validate the needed methods + aggregates.zipWithIndex.map { + case (a, i) => { + getUserDefinedMethod(a, "accumulate", Array(accTypeClasses(i)) ++ methodSignaturesList(i)) + .getOrElse( + throw new CodeGenException( + s"No matching accumulate method found for AggregateFunction " + + s"'${a.getClass.getCanonicalName}' " + + s"with parameters '${signatureToString(methodSignaturesList(i))}'.") + ) + + if (needRetract) { + getUserDefinedMethod(a, "retract", Array(accTypeClasses(i)) ++ methodSignaturesList(i)) + .getOrElse( + throw new CodeGenException( + s"No matching retract method found for AggregateFunction " + + s"'${a.getClass.getCanonicalName}' " + + s"with parameters '${signatureToString(methodSignaturesList(i))}'.") + ) + } + + if (needMerge) { + val method = + getUserDefinedMethod(a, "merge", Array(accTypeClasses(i), classOf[JIterable[Any]])) + .getOrElse( + throw new CodeGenException( + s"No matching merge method found for AggregateFunction " + + s"'${a.getClass.getCanonicalName}'.") + ) + + var iterableTypeClass = method.getGenericParameterTypes.apply(1) + .asInstanceOf[ParameterizedType].getActualTypeArguments.apply(0) + // further extract iterableTypeClass if the accumulator has a generic type + iterableTypeClass match { + case impl: ParameterizedType => iterableTypeClass = impl.getRawType + case _ => + } + + if (iterableTypeClass != accTypeClasses(i)) { + throw new CodeGenException( + s"merge method in AggregateFunction ${a.getClass.getCanonicalName} does not have " + + s"the correct Iterable type. Actual: ${iterableTypeClass.toString}.
" + + s"Expected: ${accTypeClasses(i).toString}") + } + } + + if (needReset) { + getUserDefinedMethod(a, "resetAccumulator", Array(accTypeClasses(i))) + .getOrElse( + throw new CodeGenException( + s"No matching resetAccumulator method found for " + + s"aggregate ${a.getClass.getCanonicalName}'.") + ) + } + } + } def genSetAggregationResults: String = { @@ -529,9 +597,14 @@ class CodeGenerator( | ((${accTypes(i)}) accs.getField($i)));""".stripMargin }.mkString("\n") - j"""$sig { - |$reset - | }""".stripMargin + if (needReset) { + j"""$sig { + |$reset + | }""".stripMargin + } else { + j"""$sig { + | }""".stripMargin + } } var funcCode = http://git-wip-us.apache.org/repos/asf/flink/blob/981dea41/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/calls/ScalarFunctionCallGen.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/calls/ScalarFunctionCallGen.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/calls/ScalarFunctionCallGen.scala index b0b4e09..07a8708 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/calls/ScalarFunctionCallGen.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/calls/ScalarFunctionCallGen.scala @@ -44,10 +44,10 @@ class ScalarFunctionCallGen( operands: Seq[GeneratedExpression]) : GeneratedExpression = { // determine function method and result class - val matchingMethod = getEvalMethod(scalarFunction, signature) + val matchingMethod = getUserDefinedMethod(scalarFunction, "eval", typeInfoToClass(signature)) .getOrElse(throw new CodeGenException("No matching signature found.")) val matchingSignature = matchingMethod.getParameterTypes - val resultClass = getResultTypeClass(scalarFunction, matchingSignature) + val resultClass = getResultTypeClassOfScalarFunction(scalarFunction, matchingSignature) // zip for variable signatures var paramToOperands = matchingSignature.zip(operands) http://git-wip-us.apache.org/repos/asf/flink/blob/981dea41/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/calls/TableFunctionCallGen.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/calls/TableFunctionCallGen.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/calls/TableFunctionCallGen.scala index ba90292..a3609c1 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/calls/TableFunctionCallGen.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/calls/TableFunctionCallGen.scala @@ -45,7 +45,7 @@ class TableFunctionCallGen( operands: Seq[GeneratedExpression]) : GeneratedExpression = { // determine function method - val matchingMethod = getEvalMethod(tableFunction, signature) + val matchingMethod = getUserDefinedMethod(tableFunction, "eval", typeInfoToClass(signature)) .getOrElse(throw new CodeGenException("No matching signature found.")) val matchingSignature = matchingMethod.getParameterTypes http://git-wip-us.apache.org/repos/asf/flink/blob/981dea41/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/expressions/UDAGGExpression.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/expressions/UDAGGExpression.scala 
b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/expressions/UDAGGExpression.scala new file mode 100644 index 0000000..c0e213d --- /dev/null +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/expressions/UDAGGExpression.scala @@ -0,0 +1,36 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.flink.table.expressions + +import org.apache.flink.api.common.typeinfo.TypeInformation +import org.apache.flink.table.functions.AggregateFunction + +/** + * A class which creates a call to an aggregateFunction + */ +case class UDAGGExpression[T: TypeInformation, ACC](aggregateFunction: AggregateFunction[T, ACC]) { + + /** + * Creates a call to an [[AggregateFunction]]. + * + * @param params actual parameters of function + * @return a [[AggFunctionCall]] + */ + def apply(params: Expression*): AggFunctionCall = + AggFunctionCall(aggregateFunction, params) +} http://git-wip-us.apache.org/repos/asf/flink/blob/981dea41/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/expressions/aggregations.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/expressions/aggregations.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/expressions/aggregations.scala index 72e7e4b..7b180ae 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/expressions/aggregations.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/expressions/aggregations.scala @@ -23,13 +23,18 @@ import org.apache.calcite.sql.SqlKind._ import org.apache.calcite.sql.fun._ import org.apache.calcite.tools.RelBuilder import org.apache.calcite.tools.RelBuilder.AggCall +import org.apache.flink.api.common.typeinfo.TypeInformation +import org.apache.flink.table.functions.AggregateFunction +import org.apache.flink.table.functions.utils.AggSqlFunction +import org.apache.flink.table.typeutils.TypeCheckUtils import org.apache.flink.api.common.typeinfo.BasicTypeInfo import org.apache.flink.table.calcite.FlinkTypeFactory -import org.apache.flink.table.typeutils.TypeCheckUtils +import org.apache.flink.table.functions.utils.UserDefinedFunctionUtils._ +import org.apache.flink.table.validate.{ValidationFailure, ValidationResult, ValidationSuccess} -abstract sealed class Aggregation extends UnaryExpression { +abstract sealed class Aggregation extends Expression { - override def toString = s"Aggregate($child)" + override def toString = s"Aggregate" override private[flink] def toRexNode(implicit relBuilder: RelBuilder): RexNode = throw new UnsupportedOperationException("Aggregate cannot be transformed to RexNode") @@ -47,6 +52,7 @@ abstract sealed class Aggregation extends 
UnaryExpression { } case class Sum(child: Expression) extends Aggregation { + override private[flink] def children: Seq[Expression] = Seq(child) override def toString = s"sum($child)" override private[flink] def toAggCall(name: String)(implicit relBuilder: RelBuilder): AggCall = { @@ -67,6 +73,7 @@ case class Sum(child: Expression) extends Aggregation { } case class Sum0(child: Expression) extends Aggregation { + override private[flink] def children: Seq[Expression] = Seq(child) override def toString = s"sum0($child)" override private[flink] def toAggCall(name: String)(implicit relBuilder: RelBuilder): AggCall = { @@ -83,6 +90,7 @@ case class Sum0(child: Expression) extends Aggregation { } case class Min(child: Expression) extends Aggregation { + override private[flink] def children: Seq[Expression] = Seq(child) override def toString = s"min($child)" override private[flink] def toAggCall(name: String)(implicit relBuilder: RelBuilder): AggCall = { @@ -100,6 +108,7 @@ case class Min(child: Expression) extends Aggregation { } case class Max(child: Expression) extends Aggregation { + override private[flink] def children: Seq[Expression] = Seq(child) override def toString = s"max($child)" override private[flink] def toAggCall(name: String)(implicit relBuilder: RelBuilder): AggCall = { @@ -117,6 +126,7 @@ case class Max(child: Expression) extends Aggregation { } case class Count(child: Expression) extends Aggregation { + override private[flink] def children: Seq[Expression] = Seq(child) override def toString = s"count($child)" override private[flink] def toAggCall(name: String)(implicit relBuilder: RelBuilder): AggCall = { @@ -131,6 +141,7 @@ case class Count(child: Expression) extends Aggregation { } case class Avg(child: Expression) extends Aggregation { + override private[flink] def children: Seq[Expression] = Seq(child) override def toString = s"avg($child)" override private[flink] def toAggCall(name: String)(implicit relBuilder: RelBuilder): AggCall = { @@ -148,6 +159,7 @@ case class Avg(child: Expression) extends Aggregation { } case class StddevPop(child: Expression) extends Aggregation { + override private[flink] def children: Seq[Expression] = Seq(child) override def toString = s"stddev_pop($child)" override private[flink] def toAggCall(name: String)(implicit relBuilder: RelBuilder): AggCall = { @@ -164,6 +176,7 @@ case class StddevPop(child: Expression) extends Aggregation { } case class StddevSamp(child: Expression) extends Aggregation { + override private[flink] def children: Seq[Expression] = Seq(child) override def toString = s"stddev_samp($child)" override private[flink] def toAggCall(name: String)(implicit relBuilder: RelBuilder): AggCall = { @@ -180,6 +193,7 @@ case class StddevSamp(child: Expression) extends Aggregation { } case class VarPop(child: Expression) extends Aggregation { + override private[flink] def children: Seq[Expression] = Seq(child) override def toString = s"var_pop($child)" override private[flink] def toAggCall(name: String)(implicit relBuilder: RelBuilder): AggCall = { @@ -196,6 +210,7 @@ case class VarPop(child: Expression) extends Aggregation { } case class VarSamp(child: Expression) extends Aggregation { + override private[flink] def children: Seq[Expression] = Seq(child) override def toString = s"var_samp($child)" override private[flink] def toAggCall(name: String)(implicit relBuilder: RelBuilder): AggCall = { @@ -210,3 +225,57 @@ case class VarSamp(child: Expression) extends Aggregation { override private[flink] def getSqlAggFunction()(implicit 
relBuilder: RelBuilder) = new SqlAvgAggFunction(VAR_SAMP) } + +case class AggFunctionCall( + aggregateFunction: AggregateFunction[_, _], + args: Seq[Expression]) + extends Aggregation { + + override private[flink] def children: Seq[Expression] = args + + override def resultType: TypeInformation[_] = getResultTypeOfAggregateFunction(aggregateFunction) + + override def validateInput(): ValidationResult = { + val signature = children.map(_.resultType) + // look for a signature that matches the input types + val foundSignature = getAccumulateMethodSignature(aggregateFunction, signature) + if (foundSignature.isEmpty) { + ValidationFailure(s"Given parameters do not match any signature. \n" + + s"Actual: ${signatureToString(signature)} \n" + + s"Expected: ${ + getMethodSignatures(aggregateFunction, "accumulate").drop(1) + .map(signatureToString).mkString(", ")}") + } else { + ValidationSuccess + } + } + + override def toString(): String = s"${aggregateFunction.getClass.getSimpleName}($args)" + + override def toAggCall(name: String)(implicit relBuilder: RelBuilder): AggCall = { + val typeFactory = relBuilder.getTypeFactory.asInstanceOf[FlinkTypeFactory] + val sqlFunction = AggSqlFunction(aggregateFunction.getClass.getSimpleName, + aggregateFunction, + resultType, + typeFactory) + relBuilder.aggregateCall(sqlFunction, false, null, name, args.map(_.toRexNode): _*) + } + + override private[flink] def getSqlAggFunction()(implicit relBuilder: RelBuilder) = { + val typeFactory = relBuilder.getTypeFactory.asInstanceOf[FlinkTypeFactory] + AggSqlFunction(aggregateFunction.getClass.getSimpleName, + aggregateFunction, + resultType, + typeFactory) + } + + override private[flink] def toRexNode(implicit relBuilder: RelBuilder): RexNode = { + val typeFactory = relBuilder.getTypeFactory.asInstanceOf[FlinkTypeFactory] + relBuilder.call( + AggSqlFunction(aggregateFunction.getClass.getSimpleName, + aggregateFunction, + resultType, + typeFactory), + args.map(_.toRexNode): _*) + } +} http://git-wip-us.apache.org/repos/asf/flink/blob/981dea41/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/expressions/call.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/expressions/call.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/expressions/call.scala index 68ed688..5f7204a 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/expressions/call.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/expressions/call.scala @@ -106,8 +106,8 @@ case class OverCall( .getTypeFactory.asInstanceOf[FlinkTypeFactory] .createTypeFromTypeInfo(agg.resultType) - val aggChildName = agg.asInstanceOf[Aggregation].child.asInstanceOf[ResolvedFieldReference].name - val aggExprs = List(relBuilder.field(aggChildName).asInstanceOf[RexNode]).asJava + // assemble exprs by agg children + val aggExprs = agg.asInstanceOf[Aggregation].children.map(_.toRexNode(relBuilder)).asJava // assemble order by key val orderKey = orderBy match { @@ -281,16 +281,19 @@ case class ScalarFunctionCall( override def toString = s"${scalarFunction.getClass.getCanonicalName}(${parameters.mkString(", ")})" - override private[flink] def resultType = getResultType(scalarFunction, foundSignature.get) + override private[flink] def resultType = + getResultTypeOfScalarFunction( + scalarFunction, + foundSignature.get) override private[flink] def validateInput(): ValidationResult = { val 
signature = children.map(_.resultType) // look for a signature that matches the input types - foundSignature = getSignature(scalarFunction, signature) + foundSignature = getEvalMethodSignature(scalarFunction, signature) if (foundSignature.isEmpty) { ValidationFailure(s"Given parameters do not match any signature. \n" + s"Actual: ${signatureToString(signature)} \n" + - s"Expected: ${signaturesToString(scalarFunction)}") + s"Expected: ${signaturesToString(scalarFunction, "eval")}") } else { ValidationSuccess } http://git-wip-us.apache.org/repos/asf/flink/blob/981dea41/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/AggregateFunction.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/AggregateFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/AggregateFunction.scala index 7a74112..9c79439 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/AggregateFunction.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/AggregateFunction.scala @@ -89,7 +89,7 @@ package org.apache.flink.table.functions * * {{{ * Returns the [[org.apache.flink.api.common.typeinfo.TypeInformation]] of the accumulator. This - * function is optional and can be implemented if the accumulator type cannot automatically + * function is optional and can be implemented if the accumulator type cannot be automatically * inferred from the instance returned by createAccumulator method. * * @return the type information for the accumulator. @@ -98,6 +98,18 @@ package org.apache.flink.table.functions * }}} * * + * {{{ + * Returns the [[org.apache.flink.api.common.typeinfo.TypeInformation]] of the return value. This + * function is optional and needed in case Flink's type extraction facilities are not sufficient + * to extract the TypeInformation. Flink's type extraction facilities can handle basic types or + * simple POJOs but might be wrong for more complex, custom, or composite types. + * + * @return the type information for the return value. + * + * def getResultType: TypeInformation[_] + * }}} + * + * * @tparam T the type of the aggregation result * @tparam ACC base class for aggregate Accumulator. The accumulator is used to keep the aggregated * values which are needed to compute an aggregation result. 
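CountAggFunction below shows the simplest case of these optional hooks; for a custom accumulator where type extraction might fall back to generic serialization, overriding the result type could look roughly like this (a sketch, all names hypothetical):

import org.apache.flink.api.common.typeinfo.{BasicTypeInfo, TypeInformation}

class RangeAcc { var min: Long = Long.MaxValue; var max: Long = Long.MinValue }

class RangeAgg extends AggregateFunction[Long, RangeAcc] {
  override def createAccumulator(): RangeAcc = new RangeAcc
  override def getValue(acc: RangeAcc): Long = acc.max - acc.min

  def accumulate(acc: RangeAcc, v: Long): Unit = {
    if (v < acc.min) acc.min = v
    if (v > acc.max) acc.max = v
  }

  // optional hook: looked up reflectively, bypasses TypeExtractor for the result
  def getResultType: TypeInformation[_] = BasicTypeInfo.LONG_TYPE_INFO
}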
AggregateFunction http://git-wip-us.apache.org/repos/asf/flink/blob/981dea41/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/CountAggFunction.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/CountAggFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/CountAggFunction.scala index 77341cd..2b8ec14 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/CountAggFunction.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/CountAggFunction.scala @@ -68,4 +68,7 @@ class CountAggFunction extends AggregateFunction[Long, CountAccumulator] { def getAccumulatorType(): TypeInformation[_] = { new TupleTypeInfo((new CountAccumulator).getClass, BasicTypeInfo.LONG_TYPE_INFO) } + + def getResultType(): TypeInformation[_] = + BasicTypeInfo.LONG_TYPE_INFO.asInstanceOf[TypeInformation[_]] } http://git-wip-us.apache.org/repos/asf/flink/blob/981dea41/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/utils/AggSqlFunction.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/utils/AggSqlFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/utils/AggSqlFunction.scala new file mode 100644 index 0000000..c3f6c4c --- /dev/null +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/utils/AggSqlFunction.scala @@ -0,0 +1,179 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.functions.utils + +import org.apache.calcite.rel.`type`.RelDataType +import org.apache.calcite.sql._ +import org.apache.calcite.sql.`type`._ +import org.apache.calcite.sql.`type`.SqlOperandTypeChecker.Consistency +import org.apache.calcite.sql.parser.SqlParserPos +import org.apache.calcite.sql.validate.SqlUserDefinedAggFunction +import org.apache.flink.api.common.typeinfo._ +import org.apache.flink.table.api.ValidationException +import org.apache.flink.table.calcite.FlinkTypeFactory +import org.apache.flink.table.functions.AggregateFunction +import org.apache.flink.table.functions.utils.AggSqlFunction.{createOperandTypeChecker, createOperandTypeInference, createReturnTypeInference} +import org.apache.flink.table.functions.utils.UserDefinedFunctionUtils._ + +/** + * Calcite wrapper for user-defined aggregate functions. 
+ * + * @param name function name (used by SQL parser) + * @param aggregateFunction aggregate function to be called + * @param returnType the type information of the returned value + * @param typeFactory type factory for converting between Flink's and Calcite's types + */ +class AggSqlFunction( + name: String, + aggregateFunction: AggregateFunction[_, _], + returnType: TypeInformation[_], + typeFactory: FlinkTypeFactory) + extends SqlUserDefinedAggFunction( + new SqlIdentifier(name, SqlParserPos.ZERO), + createReturnTypeInference(returnType, typeFactory), + createOperandTypeInference(aggregateFunction, typeFactory), + createOperandTypeChecker(aggregateFunction), + // No Calcite aggregate function needs to be provided here. The Flink aggregation function + // will be generated when translating the Calcite RelNode into the Flink runtime execution plan. + null + ) { + + def getFunction: AggregateFunction[_, _] = aggregateFunction +} + +object AggSqlFunction { + + def apply( + name: String, + aggregateFunction: AggregateFunction[_, _], + returnType: TypeInformation[_], + typeFactory: FlinkTypeFactory): AggSqlFunction = { + + new AggSqlFunction(name, aggregateFunction, returnType, typeFactory) + } + + private[flink] def createOperandTypeInference( + aggregateFunction: AggregateFunction[_, _], + typeFactory: FlinkTypeFactory) + : SqlOperandTypeInference = { + /** + * Operand type inference based on the given [[AggregateFunction]]. + */ + new SqlOperandTypeInference { + override def inferOperandTypes( + callBinding: SqlCallBinding, + returnType: RelDataType, + operandTypes: Array[RelDataType]): Unit = { + + val operandTypeInfo = getOperandTypeInfo(callBinding) + + val foundSignature = getAccumulateMethodSignature(aggregateFunction, operandTypeInfo) + .getOrElse( + throw new ValidationException( + s"Operand types of ${signatureToString(operandTypeInfo)} could not be inferred.")) + + val inferredTypes = getParameterTypes(aggregateFunction, foundSignature.drop(1)) + .map(typeFactory.createTypeFromTypeInfo) + + for (i <- operandTypes.indices) { + if (i < inferredTypes.length - 1) { + operandTypes(i) = inferredTypes(i) + } else if (null != inferredTypes.last.getComponentType) { + // the last argument is an array, so use its component type + operandTypes(i) = inferredTypes.last.getComponentType + } else { + operandTypes(i) = inferredTypes.last + } + } + } + } + } + + private[flink] def createReturnTypeInference( + resultType: TypeInformation[_], + typeFactory: FlinkTypeFactory) + : SqlReturnTypeInference = { + + new SqlReturnTypeInference { + override def inferReturnType(opBinding: SqlOperatorBinding): RelDataType = { + typeFactory.createTypeFromTypeInfo(resultType) + } + } + } + + private[flink] def createOperandTypeChecker(aggregateFunction: AggregateFunction[_, _]) + : SqlOperandTypeChecker = { + + val signatures = getMethodSignatures(aggregateFunction, "accumulate") + + /** + * Operand type checker based on the given [[AggregateFunction]].
+ */ + new SqlOperandTypeChecker { + override def getAllowedSignatures(op: SqlOperator, opName: String): String = { + s"$opName[${signaturesToString(aggregateFunction, "accumulate")}]" + } + + override def getOperandCountRange: SqlOperandCountRange = { + var min = 255 + var max = -1 + signatures.foreach( + sig => { + //do not count accumulator as input + val inputSig = sig.drop(1) + var len = inputSig.length + if (len > 0 && inputSig.last.isArray) { + max = 253 // according to JVM spec 4.3.3 + len = sig.length - 1 + } + max = Math.max(len, max) + min = Math.min(len, min) + }) + SqlOperandCountRanges.between(min, max) + } + + override def checkOperandTypes( + callBinding: SqlCallBinding, + throwOnFailure: Boolean) + : Boolean = { + val operandTypeInfo = getOperandTypeInfo(callBinding) + + val foundSignature = getAccumulateMethodSignature(aggregateFunction, operandTypeInfo) + + if (foundSignature.isEmpty) { + if (throwOnFailure) { + throw new ValidationException( + s"Given parameters of function do not match any signature. \n" + + s"Actual: ${signatureToString(operandTypeInfo)} \n" + + s"Expected: ${signaturesToString(aggregateFunction, "accumulate")}") + } else { + false + } + } else { + true + } + } + + override def isOptional(i: Int): Boolean = false + + override def getConsistency: Consistency = Consistency.NONE + + } + } +} http://git-wip-us.apache.org/repos/asf/flink/blob/981dea41/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/utils/ScalarSqlFunction.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/utils/ScalarSqlFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/utils/ScalarSqlFunction.scala index e2cd272..bbfa3aa 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/utils/ScalarSqlFunction.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/utils/ScalarSqlFunction.scala @@ -23,12 +23,11 @@ import org.apache.calcite.sql._ import org.apache.calcite.sql.`type`.SqlOperandTypeChecker.Consistency import org.apache.calcite.sql.`type`._ import org.apache.calcite.sql.parser.SqlParserPos -import org.apache.flink.api.common.typeinfo.TypeInformation import org.apache.flink.table.api.ValidationException import org.apache.flink.table.calcite.FlinkTypeFactory import org.apache.flink.table.functions.ScalarFunction -import org.apache.flink.table.functions.utils.ScalarSqlFunction.{createOperandTypeChecker, createOperandTypeInference, createReturnTypeInference} -import org.apache.flink.table.functions.utils.UserDefinedFunctionUtils.{getResultType, getSignature, getSignatures, signatureToString, signaturesToString} +import org.apache.flink.table.functions.utils.ScalarSqlFunction._ +import org.apache.flink.table.functions.utils.UserDefinedFunctionUtils._ import scala.collection.JavaConverters._ @@ -77,14 +76,14 @@ object ScalarSqlFunction { FlinkTypeFactory.toTypeInfo(operandType) } } - val foundSignature = getSignature(scalarFunction, parameters) + val foundSignature = getEvalMethodSignature(scalarFunction, parameters) if (foundSignature.isEmpty) { throw new ValidationException( s"Given parameters of function '$name' do not match any signature. 
\n" + s"Actual: ${signatureToString(parameters)} \n" + - s"Expected: ${signaturesToString(scalarFunction)}") + s"Expected: ${signaturesToString(scalarFunction, "eval")}") } - val resultType = getResultType(scalarFunction, foundSignature.get) + val resultType = getResultTypeOfScalarFunction(scalarFunction, foundSignature.get) val t = typeFactory.createTypeFromTypeInfo(resultType) typeFactory.createTypeWithNullability(t, nullable = true) } @@ -106,7 +105,7 @@ object ScalarSqlFunction { val operandTypeInfo = getOperandTypeInfo(callBinding) - val foundSignature = getSignature(scalarFunction, operandTypeInfo) + val foundSignature = getEvalMethodSignature(scalarFunction, operandTypeInfo) .getOrElse(throw new ValidationException(s"Operand types of could not be inferred.")) val inferredTypes = scalarFunction @@ -132,14 +131,14 @@ object ScalarSqlFunction { scalarFunction: ScalarFunction) : SqlOperandTypeChecker = { - val signatures = getSignatures(scalarFunction) + val signatures = getMethodSignatures(scalarFunction, "eval") /** * Operand type checker based on [[ScalarFunction]] given information. */ new SqlOperandTypeChecker { override def getAllowedSignatures(op: SqlOperator, opName: String): String = { - s"$opName[${signaturesToString(scalarFunction)}]" + s"$opName[${signaturesToString(scalarFunction, "eval")}]" } override def getOperandCountRange: SqlOperandCountRange = { @@ -163,14 +162,14 @@ object ScalarSqlFunction { : Boolean = { val operandTypeInfo = getOperandTypeInfo(callBinding) - val foundSignature = getSignature(scalarFunction, operandTypeInfo) + val foundSignature = getEvalMethodSignature(scalarFunction, operandTypeInfo) if (foundSignature.isEmpty) { if (throwOnFailure) { throw new ValidationException( s"Given parameters of function '$name' do not match any signature. 
\n" + s"Actual: ${signatureToString(operandTypeInfo)} \n" + - s"Expected: ${signaturesToString(scalarFunction)}") + s"Expected: ${signaturesToString(scalarFunction, "eval")}") } else { false } @@ -185,16 +184,4 @@ object ScalarSqlFunction { } } - - private[flink] def getOperandTypeInfo(callBinding: SqlCallBinding): Seq[TypeInformation[_]] = { - val operandTypes = for (i <- 0 until callBinding.getOperandCount) - yield callBinding.getOperandType(i) - operandTypes.map { operandType => - if (operandType.getSqlTypeName == SqlTypeName.NULL) { - null - } else { - FlinkTypeFactory.toTypeInfo(operandType) - } - } - } } http://git-wip-us.apache.org/repos/asf/flink/blob/981dea41/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/utils/UserDefinedFunctionUtils.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/utils/UserDefinedFunctionUtils.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/utils/UserDefinedFunctionUtils.scala index d108e31..689bf0e 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/utils/UserDefinedFunctionUtils.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/utils/UserDefinedFunctionUtils.scala @@ -25,15 +25,16 @@ import java.sql.{Date, Time, Timestamp} import org.apache.commons.codec.binary.Base64 import com.google.common.primitives.Primitives -import org.apache.calcite.sql.SqlFunction +import org.apache.calcite.sql.`type`.SqlTypeName +import org.apache.calcite.sql.{SqlCallBinding, SqlFunction} import org.apache.flink.api.common.functions.InvalidTypesException import org.apache.flink.api.common.typeinfo.TypeInformation import org.apache.flink.api.java.typeutils.TypeExtractor import org.apache.flink.table.calcite.FlinkTypeFactory import org.apache.flink.table.api.{TableEnvironment, TableException, ValidationException} import org.apache.flink.table.expressions._ -import org.apache.flink.table.functions.{ScalarFunction, TableFunction, UserDefinedFunction} import org.apache.flink.table.plan.logical._ +import org.apache.flink.table.functions.{AggregateFunction, ScalarFunction, TableFunction, UserDefinedFunction} import org.apache.flink.table.plan.schema.FlinkTableFunctionImpl import org.apache.flink.util.InstantiationUtil @@ -69,52 +70,89 @@ object UserDefinedFunctionUtils { } // ---------------------------------------------------------------------------------------------- - // Utilities for eval methods + // Utilities for user-defined methods // ---------------------------------------------------------------------------------------------- /** - * Returns signatures matching the given signature of [[TypeInformation]]. + * Returns signatures of eval methods matching the given signature of [[TypeInformation]]. * Elements of the signature can be null (act as a wildcard). */ - def getSignature( - function: UserDefinedFunction, - signature: Seq[TypeInformation[_]]) + def getEvalMethodSignature( + function: UserDefinedFunction, + signature: Seq[TypeInformation[_]]) : Option[Array[Class[_]]] = { - getEvalMethod(function, signature).map(_.getParameterTypes) + getUserDefinedMethod(function, "eval", typeInfoToClass(signature)).map(_.getParameterTypes) } /** - * Returns eval method matching the given signature of [[TypeInformation]]. + * Returns signatures of accumulate methods matching the given signature of [[TypeInformation]]. 
+ * Elements of the signature can be null (act as a wildcard). */ - def getEvalMethod( - function: UserDefinedFunction, + def getAccumulateMethodSignature( + function: AggregateFunction[_, _], signature: Seq[TypeInformation[_]]) + : Option[Array[Class[_]]] = { + val accType = TypeExtractor.createTypeInfo( + function, classOf[AggregateFunction[_, _]], function.getClass, 1) + val input = (Array(accType) ++ signature).toSeq + getUserDefinedMethod( + function, + "accumulate", + typeInfoToClass(input)).map(_.getParameterTypes) + } + + def getParameterTypes( + function: UserDefinedFunction, + signature: Array[Class[_]]): Array[TypeInformation[_]] = { + signature.map { c => + try { + TypeExtractor.getForClass(c) + } catch { + case ite: InvalidTypesException => + throw new ValidationException( + s"Parameter types of function '${function.getClass.getCanonicalName}' cannot be " + + s"automatically determined. Please provide type information manually.") + } + } + } + + /** + * Returns user defined method matching the given name and signature. + * + * @param function function instance + * @param methodName method name + * @param methodSignature an array of raw Java classes. We compare the raw Java classes not the + * TypeInformation. TypeInformation does not matter during runtime (e.g. + * within a MapFunction) + */ + def getUserDefinedMethod( + function: UserDefinedFunction, + methodName: String, + methodSignature: Array[Class[_]]) : Option[Method] = { - // We compare the raw Java classes not the TypeInformation. - // TypeInformation does not matter during runtime (e.g. within a MapFunction). - val actualSignature = typeInfoToClass(signature) - val evalMethods = checkAndExtractEvalMethods(function) - val filtered = evalMethods - // go over all eval methods and filter out matching methods + val methods = checkAndExtractMethods(function, methodName) + + val filtered = methods + // go over all the methods and filter out matching methods .filter { case cur if !cur.isVarArgs => val signatures = cur.getParameterTypes // match parameters of signature to actual parameters - actualSignature.length == signatures.length && + methodSignature.length == signatures.length && signatures.zipWithIndex.forall { case (clazz, i) => - parameterTypeEquals(actualSignature(i), clazz) + parameterTypeEquals(methodSignature(i), clazz) } case cur if cur.isVarArgs => val signatures = cur.getParameterTypes - actualSignature.zipWithIndex.forall { + methodSignature.zipWithIndex.forall { // non-varargs case (clazz, i) if i < signatures.length - 1 => parameterTypeEquals(clazz, signatures(i)) // varargs case (clazz, i) if i >= signatures.length - 1 => parameterTypeEquals(clazz, signatures.last.getComponentType) - } || (actualSignature.isEmpty && signatures.length == 1) // empty varargs + } || (methodSignature.isEmpty && signatures.length == 1) // empty varargs } // if there is a fixed method, compiler will call this method preferentially @@ -126,19 +164,21 @@ object UserDefinedFunctionUtils { // check if there is a Scala varargs annotation if (found.isEmpty && - evalMethods.exists { evalMethod => - val signatures = evalMethod.getParameterTypes + methods.exists { method => + val signatures = method.getParameterTypes signatures.zipWithIndex.forall { case (clazz, i) if i < signatures.length - 1 => - parameterTypeEquals(actualSignature(i), clazz) + parameterTypeEquals(methodSignature(i), clazz) case (clazz, i) if i == signatures.length - 1 => clazz.getName.equals("scala.collection.Seq") } }) { - throw new ValidationException("Scala-style 
variable arguments in 'eval' methods are not " + - "supported. Please add a @scala.annotation.varargs annotation.") + throw new ValidationException( + s"Scala-style variable arguments in '${methodName}' methods are not supported. Please " + + s"add a @scala.annotation.varargs annotation.") } else if (found.length > 1) { - throw new ValidationException("Found multiple 'eval' methods which match the signature.") + throw new ValidationException( + s"Found multiple '${methodName}' methods which match the signature.") } found.headOption } @@ -157,16 +197,18 @@ object UserDefinedFunctionUtils { } /** - * Extracts "eval" methods and throws a [[ValidationException]] if no implementation + * Extracts methods and throws a [[ValidationException]] if no implementation * can be found, or implementation does not match the requirements. */ - def checkAndExtractEvalMethods(function: UserDefinedFunction): Array[Method] = { + def checkAndExtractMethods( + function: UserDefinedFunction, + methodName: String): Array[Method] = { val methods = function .getClass - .getDeclaredMethods + .getMethods .filter { m => val modifiers = m.getModifiers - m.getName == "eval" && + m.getName == methodName && Modifier.isPublic(modifiers) && !Modifier.isAbstract(modifiers) && !(function.isInstanceOf[TableFunction[_]] && Modifier.isStatic(modifiers)) @@ -175,15 +217,17 @@ object UserDefinedFunctionUtils { if (methods.isEmpty) { throw new ValidationException( s"Function class '${function.getClass.getCanonicalName}' does not implement at least " + - s"one method named 'eval' which is public, not abstract and " + + s"one method named '${methodName}' which is public, not abstract and " + s"(in case of table functions) not static.") } methods } - def getSignatures(function: UserDefinedFunction): Array[Array[Class[_]]] = { - checkAndExtractEvalMethods(function).map(_.getParameterTypes) + def getMethodSignatures( + function: UserDefinedFunction, + methodName: String): Array[Array[Class[_]]] = { + checkAndExtractMethods(function, methodName).map(_.getParameterTypes) } // ---------------------------------------------------------------------------------------------- @@ -222,7 +266,7 @@ object UserDefinedFunctionUtils { typeFactory: FlinkTypeFactory) : Seq[SqlFunction] = { val (fieldNames, fieldIndexes, _) = UserDefinedFunctionUtils.getFieldInfo(resultType) - val evalMethods = checkAndExtractEvalMethods(tableFunction) + val evalMethods = checkAndExtractMethods(tableFunction, "eval") evalMethods.map { method => val function = new FlinkTableFunctionImpl(resultType, fieldIndexes, fieldNames, method) @@ -230,29 +274,75 @@ object UserDefinedFunctionUtils { } } + /** + * Creates a [[SqlFunction]] for an [[AggregateFunction]]. + * + * @param name function name + * @param aggFunction aggregate function + * @param typeInfo the extracted result type information of the aggregate function + * @param typeFactory type factory + * @return the AggSqlFunction + */ + def createAggregateSqlFunction( + name: String, + aggFunction: AggregateFunction[_, _], + typeInfo: TypeInformation[_], + typeFactory: FlinkTypeFactory) + : SqlFunction = { + // check if a qualified accumulate method exists before creating the SQL function + checkAndExtractMethods(aggFunction, "accumulate") + val resultType: TypeInformation[_] = getResultTypeOfAggregateFunction(aggFunction, typeInfo) + AggSqlFunction(name, aggFunction, resultType, typeFactory) + }
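[Editor's sketch of the registration path, not part of this commit: WeightedAvg and its accumulator are hypothetical, and typeFactory is assumed to be the planner's FlinkTypeFactory. A UDAGG would be turned into a Calcite SqlFunction roughly like this:]

    import org.apache.flink.api.java.typeutils.TypeExtractor
    import org.apache.flink.table.functions.AggregateFunction

    // Hypothetical accumulator and UDAGG, used only for this sketch.
    class WAvgAccum { var sum: Long = 0L; var count: Int = 0 }

    class WeightedAvg extends AggregateFunction[java.lang.Long, WAvgAccum] {
      override def createAccumulator(): WAvgAccum = new WAvgAccum
      override def getValue(acc: WAvgAccum): java.lang.Long =
        if (acc.count == 0) null else java.lang.Long.valueOf(acc.sum / acc.count)
      // resolved reflectively via checkAndExtractMethods(aggFunction, "accumulate")
      def accumulate(acc: WAvgAccum, value: Long, weight: Int): Unit = {
        acc.sum += value * weight
        acc.count += weight
      }
    }

    val wAvg = new WeightedAvg
    // The TableEnvironment extracts the result type from the generic signature
    // (type parameter 0 of AggregateFunction) before calling the factory.
    val implicitResultType = TypeExtractor.createTypeInfo(
      wAvg, classOf[AggregateFunction[_, _]], wAvg.getClass, 0)
    val sqlFunction = createAggregateSqlFunction("wAvg", wAvg, implicitResultType, typeFactory)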
+ // ---------------------------------------------------------------------------------------------- - // Utilities for scalar functions + // Utilities for user-defined functions // ---------------------------------------------------------------------------------------------- /** + * Internal method of [[AggregateFunction#getResultType()]] that does some pre-checking and uses + * [[TypeExtractor]] as default return type inference. + */ + def getResultTypeOfAggregateFunction( + aggregateFunction: AggregateFunction[_, _], + extractedType: TypeInformation[_] = null) + : TypeInformation[_] = { + + val resultType = try { + val method: Method = aggregateFunction.getClass.getMethod("getResultType") + method.invoke(aggregateFunction).asInstanceOf[TypeInformation[_]] + } catch { + case _: NoSuchMethodException => null + case ite: Throwable => throw new TableException("Unexpected exception:", ite) + } + if (resultType != null) { + resultType + } else if (extractedType != null) { + extractedType + } else { + TypeExtractor + .createTypeInfo(aggregateFunction, + classOf[AggregateFunction[_, _]], + aggregateFunction.getClass, + 0) + .asInstanceOf[TypeInformation[_]] + } + } + + /** * Internal method of [[ScalarFunction#getResultType()]] that does some pre-checking and uses * [[TypeExtractor]] as default return type inference. */ - def getResultType( + def getResultTypeOfScalarFunction( function: ScalarFunction, signature: Array[Class[_]]) : TypeInformation[_] = { - // find method for signature - val evalMethod = checkAndExtractEvalMethods(function) - .find(m => signature.sameElements(m.getParameterTypes)) - .getOrElse(throw new ValidationException("Given signature is invalid.")) val userDefinedTypeInfo = function.getResultType(signature) if (userDefinedTypeInfo != null) { userDefinedTypeInfo } else { try { - TypeExtractor.getForClass(evalMethod.getReturnType) + TypeExtractor.getForClass(getResultTypeClassOfScalarFunction(function, signature)) } catch { case ite: InvalidTypesException => throw new ValidationException( @@ -265,12 +355,12 @@ object UserDefinedFunctionUtils { /** * Returns the return type of the evaluation method matching the given signature. */ - def getResultTypeClass( + def getResultTypeClassOfScalarFunction( function: ScalarFunction, signature: Array[Class[_]]) : Class[_] = { // find method for signature - val evalMethod = checkAndExtractEvalMethods(function) + val evalMethod = checkAndExtractMethods(function, "eval") .find(m => signature.sameElements(m.getParameterTypes)) .getOrElse(throw new IllegalArgumentException("Given signature is invalid.")) evalMethod.getReturnType @@ -317,16 +407,16 @@ object UserDefinedFunctionUtils { } /** - * Prints all eval methods signatures of a class. + * Prints all signatures of methods with the given name in a class. */ - def signaturesToString(function: UserDefinedFunction): String = { - getSignatures(function).map(signatureToString).mkString(", ") + def signaturesToString(function: UserDefinedFunction, name: String): String = { + getMethodSignatures(function, name).map(signatureToString).mkString(", ") } /** * Extracts type classes of [[TypeInformation]] in a null-aware way. */ - private def typeInfoToClass(typeInfos: Seq[TypeInformation[_]]): Array[Class[_]] = + def typeInfoToClass(typeInfos: Seq[TypeInformation[_]]): Array[Class[_]] = typeInfos.map { typeInfo => if (typeInfo == null) { null } else { typeInfo.getTypeClass } }.toArray @@ -393,4 +483,16 @@ object UserDefinedFunctionUtils { .as(alias).toLogicalTableFunctionCall(child = null) functionCall } + + def getOperandTypeInfo(callBinding: SqlCallBinding): Seq[TypeInformation[_]] = { + val operandTypes = for (i <- 0 until callBinding.getOperandCount) + yield callBinding.getOperandType(i) + operandTypes.map { operandType => + if (operandType.getSqlTypeName == SqlTypeName.NULL) { + null + } else { + FlinkTypeFactory.toTypeInfo(operandType) + } + } + } }
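[Editor's sketch, not part of this commit: the UDAGG below is hypothetical and only illustrates the inference order implemented by getResultTypeOfAggregateFunction above — an explicit getResultType() on the function wins, then the caller-extracted type, and only then the TypeExtractor fallback on the generic signature.]

    import org.apache.flink.api.common.typeinfo.{BasicTypeInfo, TypeInformation}
    import org.apache.flink.table.functions.AggregateFunction

    class CountAcc { var count: Long = 0L }

    class CountNonNull extends AggregateFunction[java.lang.Long, CountAcc] {
      override def createAccumulator(): CountAcc = new CountAcc
      override def getValue(acc: CountAcc): java.lang.Long = java.lang.Long.valueOf(acc.count)
      def accumulate(acc: CountAcc, in: AnyRef): Unit = {
        if (in != null) acc.count += 1
      }
      // Found via getClass.getMethod("getResultType"); short-circuits both the
      // caller-provided type and the TypeExtractor fallback.
      def getResultType: TypeInformation[java.lang.Long] = BasicTypeInfo.LONG_TYPE_INFO
    }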
http://git-wip-us.apache.org/repos/asf/flink/blob/981dea41/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/ProjectionTranslator.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/ProjectionTranslator.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/ProjectionTranslator.scala index 0d45a37..d26cdcf 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/ProjectionTranslator.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/ProjectionTranslator.scala @@ -303,6 +303,11 @@ object ProjectionTranslator { (fieldReferences, expr) => identifyFieldReferences(expr, fieldReferences) } + case aggfc @ AggFunctionCall(clazz, args) => + args.foldLeft(fieldReferences) { + (fieldReferences, expr) => identifyFieldReferences(expr, fieldReferences) + } + // array constructor case c @ ArrayConstructor(args) => args.foldLeft(fieldReferences) { @@ -327,4 +332,56 @@ } } + /** + * Finds and replaces UDAGG function calls with AggFunctionCall expressions. + * + * @param field the expression to check + * @param tableEnv the TableEnvironment + * @return an expression with the correct AggFunctionCall type for UDAGG functions + */ + def replaceAggFunctionCall(field: Expression, tableEnv: TableEnvironment): Expression = { + field match { + case l: LeafExpression => l + + case u: UnaryExpression => + val c = replaceAggFunctionCall(u.child, tableEnv) + u.makeCopy(Array(c)) + + case b: BinaryExpression => + val l = replaceAggFunctionCall(b.left, tableEnv) + val r = replaceAggFunctionCall(b.right, tableEnv) + b.makeCopy(Array(l, r)) + + // Function calls + case c @ Call(name, args) => + val function = tableEnv.getFunctionCatalog.lookupFunction(name, args) + if (function.isInstanceOf[AggFunctionCall]) { + function + } else { + val newArgs = + args.map( + (exp: Expression) => + replaceAggFunctionCall(exp, tableEnv)) + c.makeCopy(Array(name, newArgs)) + } + + // Scalar function calls + case sfc @ ScalarFunctionCall(clazz, args) => + val newArgs: Seq[Expression] = + args.map( + (exp: Expression) => + replaceAggFunctionCall(exp, tableEnv)) + sfc.makeCopy(Array(clazz, newArgs)) + + // Array constructor + case c @ ArrayConstructor(args) => + val newArgs = + c.elements + .map((exp: Expression) => replaceAggFunctionCall(exp, tableEnv)) + c.makeCopy(Array(newArgs)) + + // Other expressions + case e: Expression => e + } + } } http://git-wip-us.apache.org/repos/asf/flink/blob/981dea41/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/logical/operators.scala ---------------------------------------------------------------------- diff --git 
a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/logical/operators.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/logical/operators.scala index c67bfd1..5f2394c 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/logical/operators.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/logical/operators.scala @@ -685,12 +685,12 @@ case class LogicalTableFunctionCall( checkForInstantiation(tableFunction.getClass) // look for a signature that matches the input types val signature = node.parameters.map(_.resultType) - val foundMethod = getEvalMethod(tableFunction, signature) + val foundMethod = getUserDefinedMethod(tableFunction, "eval", typeInfoToClass(signature)) if (foundMethod.isEmpty) { failValidation( s"Given parameters of function '$functionName' do not match any signature. \n" + s"Actual: ${signatureToString(signature)} \n" + - s"Expected: ${signaturesToString(tableFunction)}") + s"Expected: ${signaturesToString(tableFunction, "eval")}") } else { node.evalMethod = foundMethod.get } http://git-wip-us.apache.org/repos/asf/flink/blob/981dea41/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/CommonAggregate.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/CommonAggregate.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/CommonAggregate.scala index 3883b14..e95747c 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/CommonAggregate.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/CommonAggregate.scala @@ -50,7 +50,7 @@ trait CommonAggregate { val aggs = namedAggregates.map(_.getKey) val aggStrings = aggs.map( a => s"${a.getAggregation}(${ if (a.getArgList.size() > 0) { - inFields(a.getArgList.get(0)) + a.getArgList.asScala.map(inFields(_)).mkString(", ") } else { "*" } http://git-wip-us.apache.org/repos/asf/flink/blob/981dea41/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/OverAggregate.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/OverAggregate.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/OverAggregate.scala index 91c8cef..6878473 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/OverAggregate.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/OverAggregate.scala @@ -88,7 +88,7 @@ trait OverAggregate { val aggStrings = namedAggregates.map(_.getKey).map( a => s"${a.getAggregation}(${ if (a.getArgList.size() > 0) { - inFields(a.getArgList.get(0)) + a.getArgList.asScala.map(inFields(_)).mkString(", ") } else { "*" } http://git-wip-us.apache.org/repos/asf/flink/blob/981dea41/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamAggregate.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamAggregate.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamAggregate.scala index 187773d..c232a71 100644 --- 
a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamAggregate.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamAggregate.scala @@ -122,6 +122,12 @@ class DataStreamAggregate( false, inputDS.getType) + val needMerge = window match { + case ProcessingTimeSessionGroupWindow(_, _) => true + case EventTimeSessionGroupWindow(_, _, _) => true + case _ => false + } + // grouped / keyed aggregation if (grouping.length > 0) { val windowFunction = AggregateUtil.createAggregationGroupWindowFunction( @@ -141,7 +147,8 @@ class DataStreamAggregate( generator, namedAggregates, inputType, - rowRelDataType) + rowRelDataType, + needMerge) windowedStream .aggregate(aggFunction, windowFunction, accumulatorRowType, aggResultRowType, rowTypeInfo) @@ -163,7 +170,8 @@ class DataStreamAggregate( generator, namedAggregates, inputType, - rowRelDataType) + rowRelDataType, + needMerge) windowedStream .aggregate(aggFunction, windowFunction, accumulatorRowType, aggResultRowType, rowTypeInfo) http://git-wip-us.apache.org/repos/asf/flink/blob/981dea41/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateUtil.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateUtil.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateUtil.scala index 2950a78..e38207d 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateUtil.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateUtil.scala @@ -39,6 +39,7 @@ import org.apache.flink.table.calcite.FlinkTypeFactory import org.apache.flink.table.codegen.CodeGenerator import org.apache.flink.table.expressions._ import org.apache.flink.table.functions.aggfunctions._ +import org.apache.flink.table.functions.utils.AggSqlFunction import org.apache.flink.table.functions.utils.UserDefinedFunctionUtils._ import org.apache.flink.table.functions.{AggregateFunction => TableAggregateFunction} import org.apache.flink.table.plan.logical._ @@ -101,7 +102,8 @@ object AggregateUtil { None, outputArity, needRetract, - needMerge = false + needMerge = false, + needReset = false ) if (isRowTimeType) { @@ -178,7 +180,8 @@ object AggregateUtil { None, outputArity, needRetract, - needMerge = false + needMerge = false, + needReset = true ) if (isRowTimeType) { @@ -303,7 +306,8 @@ object AggregateUtil { None, outputArity, needRetract, - needMerge = false + needMerge = false, + needReset = true ) new DataSetWindowAggMapFunction( @@ -374,12 +378,13 @@ object AggregateUtil { aggFieldIndexes, aggregates.indices.map(_ + groupings.length).toArray, partialResults = true, - groupings, + groupings.indices.toArray, Some(aggregates.indices.map(_ + groupings.length).toArray), None, keysAndAggregatesArity + 1, needRetract, - needMerge = true + needMerge = true, + needReset = true ) new DataSetSlideTimeWindowAggReduceGroupFunction( genFunction, @@ -481,7 +486,8 @@ object AggregateUtil { None, outputType.getFieldCount, needRetract, - needMerge = true + needMerge = true, + needReset = true ) val genFinalAggFunction = generator.generateAggregations( @@ -497,7 +503,8 @@ object AggregateUtil { None, outputType.getFieldCount, needRetract, - needMerge = true + needMerge = true, + needReset = true ) val 
keysAndAggregatesArity = groupings.length + namedAggregates.length @@ -636,7 +643,8 @@ object AggregateUtil { None, groupings.length + aggregates.length + 2, needRetract, - needMerge = true + needMerge = true, + needReset = true ) new DataSetSessionWindowAggregatePreProcessor( @@ -708,7 +716,8 @@ object AggregateUtil { None, groupings.length + aggregates.length + 2, needRetract, - needMerge = true + needMerge = true, + needReset = true ) new DataSetSessionWindowAggregatePreProcessor( @@ -787,7 +796,8 @@ object AggregateUtil { None, groupings.length + aggregates.length, needRetract, - needMerge = false + needMerge = false, + needReset = true ) // compute mapping of forwarded grouping keys @@ -813,7 +823,8 @@ object AggregateUtil { constantFlags, outputType.getFieldCount, needRetract, - needMerge = true + needMerge = true, + needReset = true ) ( @@ -836,7 +847,8 @@ object AggregateUtil { constantFlags, outputType.getFieldCount, needRetract, - needMerge = false + needMerge = false, + needReset = true ) ( @@ -902,7 +914,8 @@ object AggregateUtil { generator: CodeGenerator, namedAggregates: Seq[CalcitePair[AggregateCall, String]], inputType: RelDataType, - outputType: RelDataType) + outputType: RelDataType, + needMerge: Boolean) : (DataStreamAggFunction[Row, Row, Row], RowTypeInfo, RowTypeInfo) = { val needRetract = false @@ -928,7 +941,8 @@ object AggregateUtil { None, outputArity, needRetract, - needMerge = true + needMerge, + needReset = false ) val aggResultTypes = namedAggregates.map(a => FlinkTypeFactory.toTypeInfo(a.left.getType)) @@ -1083,9 +1097,6 @@ object AggregateUtil { throw new TableException("Aggregate fields should not be empty.") } } else { - if (argList.size() > 1) { - throw new TableException("Currently, do not support aggregate on multi fields.") - } aggFieldIndexes(index) = argList.asScala.map(i => i.intValue).toArray } val sqlTypeName = inputType.getFieldList.get(aggFieldIndexes(index)(0)).getType @@ -1298,6 +1309,9 @@ object AggregateUtil { case _: SqlCountAggFunction => aggregates(index) = new CountAggFunction + case udagg: AggSqlFunction => + aggregates(index) = udagg.getFunction + case unSupported: SqlAggFunction => throw new TableException("Unsupported function: " + unSupported.getName) }
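[Editor's end-to-end sketch, not part of this commit: table, field, and function names are illustrative, and WeightedAvg is the hypothetical UDAGG from the earlier sketch. Registration stores an AggSqlFunction in the catalog shown below; Table API calls are resolved via lookupFunction into an AggFunctionCall, and SQL calls are resolved through the AggSqlFunction itself.]

    import org.apache.flink.api.scala._
    import org.apache.flink.table.api.scala._

    // tableEnv and orders are assumed to exist: a (Stream/Batch)TableEnvironment
    // and a Table with fields 'user, 'points, 'level.
    val wAvg = new WeightedAvg                // hypothetical UDAGG from the sketch above
    tableEnv.registerFunction("wAvg", wAvg)   // registers an AggSqlFunction in the catalog

    // Table API: resolved by FunctionCatalog.lookupFunction and rewritten to an
    // AggFunctionCall during projection translation.
    val result = orders.groupBy('user).select('user, wAvg('points, 'level))

    // SQL: resolved through the registered AggSqlFunction.
    val sqlResult = tableEnv.sql("SELECT user, wAvg(points, level) FROM Orders GROUP BY user")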
http://git-wip-us.apache.org/repos/asf/flink/blob/981dea41/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/validate/FunctionCatalog.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/validate/FunctionCatalog.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/validate/FunctionCatalog.scala index 1022e4d..63dc1ae 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/validate/FunctionCatalog.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/validate/FunctionCatalog.scala @@ -23,8 +23,8 @@ import org.apache.calcite.sql.util.{ChainedSqlOperatorTable, ListSqlOperatorTabl import org.apache.calcite.sql.{SqlFunction, SqlOperator, SqlOperatorTable} import org.apache.flink.table.api.ValidationException import org.apache.flink.table.expressions._ -import org.apache.flink.table.functions.utils.{ScalarSqlFunction, TableSqlFunction} -import org.apache.flink.table.functions.{EventTimeExtractor, RowTime, ScalarFunction, TableFunction, _} +import org.apache.flink.table.functions.utils.{AggSqlFunction, ScalarSqlFunction, TableSqlFunction} +import org.apache.flink.table.functions.{AggregateFunction, EventTimeExtractor, RowTime, ScalarFunction, TableFunction, _} import scala.collection.JavaConversions._ import scala.collection.mutable @@ -97,6 +97,15 @@ class FunctionCatalog { val function = tableSqlFunction.getTableFunction TableFunctionCall(name, function, children, typeInfo) + // user-defined aggregate function call + case af if classOf[AggregateFunction[_, _]].isAssignableFrom(af) => + val aggregateFunction = sqlFunctions + .find(f => f.getName.equalsIgnoreCase(name) && f.isInstanceOf[AggSqlFunction]) + .getOrElse(throw ValidationException(s"Undefined aggregate function: $name")) + .asInstanceOf[AggSqlFunction] + val function = aggregateFunction.getFunction + AggFunctionCall(function, children) + // general expression call case expression if classOf[Expression].isAssignableFrom(expression) => // try to find a constructor that accepts `Seq[Expression]`