Repository: flink Updated Branches: refs/heads/master 9d3471574 -> 11218a35d
[FLINK-8013] [table] Support aggregate functions with generic arrays This closes #5011. Project: http://git-wip-us.apache.org/repos/asf/flink/repo Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/11218a35 Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/11218a35 Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/11218a35 Branch: refs/heads/master Commit: 11218a35dc0fdd7439142a313e6628c51cffe689 Parents: 9d34715 Author: twalthr <twal...@apache.org> Authored: Tue Nov 14 11:06:54 2017 +0100 Committer: twalthr <twal...@apache.org> Committed: Wed Nov 15 15:11:39 2017 +0100 ---------------------------------------------------------------------- .../api/java/typeutils/TypeExtractionUtils.java | 16 +++ .../codegen/AggregationCodeGenerator.scala | 20 ++- .../flink/table/expressions/aggregations.scala | 6 +- .../utils/UserDefinedFunctionUtils.scala | 9 +- .../runtime/batch/table/AggregateITCase.scala | 25 +++- .../flink/table/utils/TableTestBase.scala | 2 +- .../table/utils/UserDefinedAggFunctions.scala | 126 +++++++++++++++++++ .../api/scala/util/CollectionDataSets.scala | 1 - 8 files changed, 185 insertions(+), 20 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/flink/blob/11218a35/flink-core/src/main/java/org/apache/flink/api/java/typeutils/TypeExtractionUtils.java ---------------------------------------------------------------------- diff --git a/flink-core/src/main/java/org/apache/flink/api/java/typeutils/TypeExtractionUtils.java b/flink-core/src/main/java/org/apache/flink/api/java/typeutils/TypeExtractionUtils.java index c5c2565..56fcf82 100644 --- a/flink-core/src/main/java/org/apache/flink/api/java/typeutils/TypeExtractionUtils.java +++ b/flink-core/src/main/java/org/apache/flink/api/java/typeutils/TypeExtractionUtils.java @@ -18,7 +18,9 @@ package org.apache.flink.api.java.typeutils; +import java.lang.reflect.Array; import java.lang.reflect.Constructor; +import java.lang.reflect.GenericArrayType; import java.lang.reflect.Method; import java.lang.reflect.Modifier; import java.lang.reflect.ParameterizedType; @@ -322,4 +324,18 @@ public class TypeExtractionUtils { } return false; } + + /** + * Returns the raw class of both parameterized types and generic arrays. + * Returns java.lang.Object for all other types. + */ + public static Class<?> getRawClass(Type t) { + if (isClassType(t)) { + return typeToClass(t); + } else if (t instanceof GenericArrayType) { + Type component = ((GenericArrayType)t).getGenericComponentType(); + return Array.newInstance(getRawClass(component), 0).getClass(); + } + return Object.class; + } } http://git-wip-us.apache.org/repos/asf/flink/blob/11218a35/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/AggregationCodeGenerator.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/AggregationCodeGenerator.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/AggregationCodeGenerator.scala index c85b111..32cbde2 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/AggregationCodeGenerator.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/AggregationCodeGenerator.scala @@ -17,17 +17,18 @@ */ package org.apache.flink.table.codegen -import java.lang.reflect.{Modifier, ParameterizedType} +import java.lang.reflect.Modifier import java.lang.{Iterable => JIterable} import org.apache.calcite.rex.RexLiteral import org.apache.commons.codec.binary.Base64 import org.apache.flink.api.common.state.{State, StateDescriptor} import org.apache.flink.api.common.typeinfo.TypeInformation +import org.apache.flink.api.java.typeutils.TypeExtractionUtils.{extractTypeArgument, getRawClass} import org.apache.flink.table.api.TableConfig import org.apache.flink.table.api.dataview._ -import org.apache.flink.table.codegen.Indenter.toISC import org.apache.flink.table.codegen.CodeGenUtils.{newName, reflectiveFieldWriteAccess} +import org.apache.flink.table.codegen.Indenter.toISC import org.apache.flink.table.functions.AggregateFunction import org.apache.flink.table.functions.utils.UserDefinedFunctionUtils import org.apache.flink.table.functions.utils.UserDefinedFunctionUtils.{getUserDefinedMethod, signatureToString} @@ -175,7 +176,7 @@ class AggregationCodeGenerator( } if (needMerge) { - val methods = + val method = getUserDefinedMethod(a, "merge", Array(accTypeClasses(i), classOf[JIterable[Any]])) .getOrElse( throw new CodeGenException( @@ -183,17 +184,14 @@ class AggregationCodeGenerator( s"${a.getClass.getCanonicalName}'.") ) - var iterableTypeClass = methods.getGenericParameterTypes.apply(1) - .asInstanceOf[ParameterizedType].getActualTypeArguments.apply(0) - // further extract iterableTypeClass if the accumulator has generic type - iterableTypeClass match { - case impl: ParameterizedType => iterableTypeClass = impl.getRawType - case _ => - } + // use the TypeExtractionUtils here to support nested GenericArrayTypes and + // other complex types + val iterableGenericType = extractTypeArgument(method.getGenericParameterTypes()(1), 0) + val iterableTypeClass = getRawClass(iterableGenericType) if (iterableTypeClass != accTypeClasses(i)) { throw new CodeGenException( - s"merge method in AggregateFunction ${a.getClass.getCanonicalName} does not have " + + s"Merge method in AggregateFunction ${a.getClass.getCanonicalName} does not have " + s"the correct Iterable type. Actually: ${iterableTypeClass.toString}. " + s"Expected: ${accTypeClasses(i).toString}") } http://git-wip-us.apache.org/repos/asf/flink/blob/11218a35/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/expressions/aggregations.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/expressions/aggregations.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/expressions/aggregations.scala index 1ffcb12..3adaaa9 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/expressions/aggregations.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/expressions/aggregations.scala @@ -245,8 +245,10 @@ case class AggFunctionCall( ValidationFailure(s"Given parameters do not match any signature. \n" + s"Actual: ${signatureToString(signature)} \n" + s"Expected: ${ - getMethodSignatures(aggregateFunction, "accumulate").drop(1) - .map(signatureToString).mkString(", ")}") + getMethodSignatures(aggregateFunction, "accumulate") + .map(_.drop(1)) + .map(signatureToString) + .mkString(", ")}") } else { ValidationSuccess } http://git-wip-us.apache.org/repos/asf/flink/blob/11218a35/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/utils/UserDefinedFunctionUtils.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/utils/UserDefinedFunctionUtils.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/utils/UserDefinedFunctionUtils.scala index 3cd694a..4a34732 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/utils/UserDefinedFunctionUtils.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/utils/UserDefinedFunctionUtils.scala @@ -607,13 +607,14 @@ object UserDefinedFunctionUtils { candidate == expected || expected == classOf[Object] || expected.isPrimitive && Primitives.wrap(expected) == candidate || + // time types candidate == classOf[Date] && (expected == classOf[Int] || expected == classOf[JInt]) || candidate == classOf[Time] && (expected == classOf[Int] || expected == classOf[JInt]) || candidate == classOf[Timestamp] && (expected == classOf[Long] || expected == classOf[JLong]) || - (candidate.isArray && - expected.isArray && - candidate.getComponentType.isInstanceOf[Object] && - expected.getComponentType == classOf[Object]) + // arrays + (candidate.isArray && expected.isArray && + (candidate.getComponentType == expected.getComponentType || + expected.getComponentType == classOf[Object])) @throws[Exception] def serialize(function: UserDefinedFunction): String = { http://git-wip-us.apache.org/repos/asf/flink/blob/11218a35/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/batch/table/AggregateITCase.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/batch/table/AggregateITCase.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/batch/table/AggregateITCase.scala index cf96d19..e1348f6 100644 --- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/batch/table/AggregateITCase.scala +++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/batch/table/AggregateITCase.scala @@ -22,12 +22,13 @@ import java.math.BigDecimal import org.apache.flink.api.scala._ import org.apache.flink.api.scala.util.CollectionDataSets -import org.apache.flink.table.api.TableEnvironment +import org.apache.flink.table.api.{TableEnvironment, Types} import org.apache.flink.table.runtime.utils.JavaUserDefinedAggFunctions.{CountDistinctWithMergeAndReset, WeightedAvgWithMergeAndReset} import org.apache.flink.table.api.scala._ import org.apache.flink.table.functions.aggfunctions.CountAggFunction import org.apache.flink.table.runtime.utils.TableProgramsCollectionTestBase import org.apache.flink.table.runtime.utils.TableProgramsTestBase.TableConfigMode +import org.apache.flink.table.utils.Top10 import org.apache.flink.test.util.TestBaseUtils import org.apache.flink.types.Row import org.junit._ @@ -392,6 +393,28 @@ class AggregationsITCase( val results = res.toDataSet[Row].collect() TestBaseUtils.compareResultAsText(results.asJava, expected) } + + @Test + def testComplexAggregate(): Unit = { + + val env = ExecutionEnvironment.getExecutionEnvironment + val tEnv = TableEnvironment.getTableEnvironment(env, config) + val top10Fun = new Top10 + + val t = CollectionDataSets.get3TupleDataSet(env).toTable(tEnv, 'a, 'b, 'c) + .groupBy('b) + .select('b, top10Fun('b.cast(Types.INT), 'a.cast(Types.FLOAT))) + + val expected = + "1,[(1,1.0), null, null, null, null, null, null, null, null, null]\n" + + "2,[(2,3.0), (2,2.0), null, null, null, null, null, null, null, null]\n" + + "3,[(3,6.0), (3,5.0), (3,4.0), null, null, null, null, null, null, null]\n" + + "4,[(4,10.0), (4,9.0), (4,8.0), (4,7.0), null, null, null, null, null, null]\n" + + "5,[(5,15.0), (5,14.0), (5,13.0), (5,12.0), (5,11.0), null, null, null, null, null]\n" + + "6,[(6,21.0), (6,20.0), (6,19.0), (6,18.0), (6,17.0), (6,16.0), null, null, null, null]" + val results = t.toDataSet[Row].collect() + TestBaseUtils.compareResultAsText(results.asJava, expected) + } } case class WC(word: String, frequency: Long) http://git-wip-us.apache.org/repos/asf/flink/blob/11218a35/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/utils/TableTestBase.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/utils/TableTestBase.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/utils/TableTestBase.scala index 3829314..804fad8 100644 --- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/utils/TableTestBase.scala +++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/utils/TableTestBase.scala @@ -91,7 +91,7 @@ abstract class TableTestUtil { val actual = RelOptUtil.toString(optimized) // we remove the charset for testing because it // depends on the native machine (Little/Big Endian) - val actualNoCharset = actual.replace("_UTF-16LE'", "'") + val actualNoCharset = actual.replace("_UTF-16LE'", "'").replace("_UTF-16BE'", "'") assertEquals( expected.split("\n").map(_.trim).mkString("\n"), actualNoCharset.split("\n").map(_.trim).mkString("\n")) http://git-wip-us.apache.org/repos/asf/flink/blob/11218a35/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/utils/UserDefinedAggFunctions.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/utils/UserDefinedAggFunctions.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/utils/UserDefinedAggFunctions.scala new file mode 100644 index 0000000..7d4393c --- /dev/null +++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/utils/UserDefinedAggFunctions.scala @@ -0,0 +1,126 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.utils + +import org.apache.flink.table.functions.AggregateFunction +import org.apache.flink.api.java.tuple.{Tuple2 => JTuple2} +import java.lang.{Integer => JInt} +import java.lang.{Float => JFloat} +import java.util + +import org.apache.flink.api.common.typeinfo.TypeInformation +import org.apache.flink.api.java.typeutils.{ObjectArrayTypeInfo, TupleTypeInfo} +import org.apache.flink.table.api.Types + +/** + * User-defined aggregation function to compute the top 10 most visited Int IDs + * with the highest Float values. We use an Array[Tuple2[Int, Float]] as accumulator to + * store the top 10 entries. + * + * The result is emitted as Array as well. + */ +class Top10 extends AggregateFunction[Array[JTuple2[JInt, JFloat]], Array[JTuple2[JInt, JFloat]]] { + + @Override + def createAccumulator(): Array[JTuple2[JInt, JFloat]] = { + new Array[JTuple2[JInt, JFloat]](10) + } + + /** + * Adds a new entry and count to the top 10 entries if necessary. + * + * @param acc The current top 10 + * @param id The ID + * @param value The value for the ID + */ + def accumulate(acc: Array[JTuple2[JInt, JFloat]], id: Int, value: Float) { + + var i = 9 + var skipped = 0 + + // skip positions without records + while (i >= 0 && acc(i) == null) { + if (acc(i) == null) { + // continue until first entry in the top10 list + i -= 1 + } + } + // backward linear search for insert position + while (i >= 0 && value > acc(i).f1) { + // check next entry + skipped += 1 + i -= 1 + } + + // set if necessary + if (i < 9) { + // move entries with lower count by one position + if (i < 8 && skipped > 0) { + System.arraycopy(acc, i + 1, acc, i + 2, skipped) + } + + // add ID to top10 list + acc(i + 1) = JTuple2.of(id, value) + } + } + + override def getValue(acc: Array[JTuple2[JInt, JFloat]]): Array[JTuple2[JInt, JFloat]] = acc + + def resetAccumulator(acc: Array[JTuple2[JInt, JFloat]]): Unit = { + util.Arrays.fill(acc.asInstanceOf[Array[Object]], null) + } + + def merge( + acc: Array[JTuple2[JInt, JFloat]], + its: java.lang.Iterable[Array[JTuple2[JInt, JFloat]]]): Unit = { + + val it = its.iterator() + while(it.hasNext) { + val acc2 = it.next() + + var i = 0 + var i2 = 0 + while (i < 10 && i2 < 10 && acc2(i2) != null) { + if (acc(i) == null) { + // copy to empty place + acc(i) = acc2(i2) + i += 1 + i2 += 1 + } else if (acc(i).f1.asInstanceOf[Float] >= acc2(i2).f1.asInstanceOf[Float]) { + // forward to next + i += 1 + } else { + // shift and copy + System.arraycopy(acc, i, acc, i + 1, 9 - i) + acc(i) = acc2(i2) + i += 1 + i2 += 1 + } + } + } + } + + override def getAccumulatorType: TypeInformation[Array[JTuple2[JInt, JFloat]]] = { + ObjectArrayTypeInfo.getInfoFor(new TupleTypeInfo[JTuple2[JInt, JFloat]](Types.INT, Types.FLOAT)) + } + + override def getResultType: TypeInformation[Array[JTuple2[JInt, JFloat]]] = { + ObjectArrayTypeInfo.getInfoFor(new TupleTypeInfo[JTuple2[JInt, JFloat]](Types.INT, Types.FLOAT)) + } +} http://git-wip-us.apache.org/repos/asf/flink/blob/11218a35/flink-tests/src/test/scala/org/apache/flink/api/scala/util/CollectionDataSets.scala ---------------------------------------------------------------------- diff --git a/flink-tests/src/test/scala/org/apache/flink/api/scala/util/CollectionDataSets.scala b/flink-tests/src/test/scala/org/apache/flink/api/scala/util/CollectionDataSets.scala index ec1a810..1cb5b52 100644 --- a/flink-tests/src/test/scala/org/apache/flink/api/scala/util/CollectionDataSets.scala +++ b/flink-tests/src/test/scala/org/apache/flink/api/scala/util/CollectionDataSets.scala @@ -55,7 +55,6 @@ object CollectionDataSets { data.+=((19, 6L, "Comment#13")) data.+=((20, 6L, "Comment#14")) data.+=((21, 6L, "Comment#15")) - Random.shuffle(data) env.fromCollection(Random.shuffle(data)) }