[FLINK-6216] [table] Add non-windowed GroupBy aggregation for streams. This closes #3646.
Project: http://git-wip-us.apache.org/repos/asf/flink/repo Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/8f78824b Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/8f78824b Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/8f78824b Branch: refs/heads/master Commit: 8f78824be2e6b5e2029f142b70f7cac15d98abd3 Parents: 24bf61c Author: shaoxuan-wang <wshaox...@gmail.com> Authored: Thu Mar 30 03:57:58 2017 +0800 Committer: Fabian Hueske <fhue...@apache.org> Committed: Sat May 6 01:51:54 2017 +0200 ---------------------------------------------------------------------- .../flink/table/plan/logical/operators.scala | 3 - .../nodes/datastream/DataStreamAggregate.scala | 300 ------------------- .../datastream/DataStreamGroupAggregate.scala | 152 ++++++++++ .../DataStreamGroupWindowAggregate.scala | 300 +++++++++++++++++++ .../flink/table/plan/rules/FlinkRuleSets.scala | 3 +- .../datastream/DataStreamAggregateRule.scala | 77 ----- .../DataStreamGroupAggregateRule.scala | 81 +++++ .../DataStreamGroupWindowAggregateRule.scala | 77 +++++ .../table/runtime/aggregate/AggregateUtil.scala | 66 +++- .../aggregate/GroupAggProcessFunction.scala | 90 ++++++ .../aggregate/ProcTimeBoundedRangeOver.scala | 2 +- .../scala/batch/table/FieldProjectionTest.scala | 4 +- .../api/scala/stream/TableSourceTest.scala | 2 +- .../table/api/scala/stream/sql/SqlITCase.scala | 21 ++ .../scala/stream/sql/WindowAggregateTest.scala | 51 ++-- .../scala/stream/table/AggregationsITCase.scala | 197 ------------ .../stream/table/GroupAggregationsITCase.scala | 132 ++++++++ .../stream/table/GroupAggregationsTest.scala | 214 +++++++++++++ .../table/GroupWindowAggregationsITCase.scala | 197 ++++++++++++ .../scala/stream/table/GroupWindowTest.scala | 195 +++++++----- .../scala/stream/table/UnsupportedOpsTest.scala | 7 - 21 files changed, 1480 insertions(+), 691 deletions(-) ---------------------------------------------------------------------- 
http://git-wip-us.apache.org/repos/asf/flink/blob/8f78824b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/logical/operators.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/logical/operators.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/logical/operators.scala index 3839145..36067eb 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/logical/operators.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/logical/operators.scala @@ -215,9 +215,6 @@ case class Aggregate( } override def validate(tableEnv: TableEnvironment): LogicalNode = { - if (tableEnv.isInstanceOf[StreamTableEnvironment]) { - failValidation(s"Aggregate on stream tables is currently not supported.") - } val resolvedAggregate = super.validate(tableEnv).asInstanceOf[Aggregate] val groupingExprs = resolvedAggregate.groupingExpressions http://git-wip-us.apache.org/repos/asf/flink/blob/8f78824b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamAggregate.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamAggregate.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamAggregate.scala deleted file mode 100644 index 5697449..0000000 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamAggregate.scala +++ /dev/null @@ -1,300 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. 
The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.flink.table.plan.nodes.datastream - -import org.apache.calcite.plan.{RelOptCluster, RelTraitSet} -import org.apache.calcite.rel.`type`.RelDataType -import org.apache.calcite.rel.core.AggregateCall -import org.apache.calcite.rel.{RelNode, RelWriter, SingleRel} -import org.apache.flink.api.java.tuple.Tuple -import org.apache.flink.streaming.api.datastream.{AllWindowedStream, DataStream, KeyedStream, WindowedStream} -import org.apache.flink.streaming.api.windowing.assigners._ -import org.apache.flink.streaming.api.windowing.windows.{Window => DataStreamWindow} -import org.apache.flink.table.api.StreamTableEnvironment -import org.apache.flink.table.calcite.FlinkRelBuilder.NamedWindowProperty -import org.apache.flink.table.codegen.CodeGenerator -import org.apache.flink.table.expressions.ExpressionUtils._ -import org.apache.flink.table.plan.logical._ -import org.apache.flink.table.plan.nodes.CommonAggregate -import org.apache.flink.table.plan.nodes.datastream.DataStreamAggregate._ -import org.apache.flink.table.plan.schema.RowSchema -import org.apache.flink.table.runtime.aggregate.AggregateUtil._ -import org.apache.flink.table.runtime.aggregate._ -import org.apache.flink.table.typeutils.TypeCheckUtils.isTimeInterval -import org.apache.flink.types.Row - -class DataStreamAggregate( - window: LogicalWindow, - namedProperties: Seq[NamedWindowProperty], - cluster: 
RelOptCluster, - traitSet: RelTraitSet, - inputNode: RelNode, - namedAggregates: Seq[CalcitePair[AggregateCall, String]], - schema: RowSchema, - inputSchema: RowSchema, - grouping: Array[Int]) - extends SingleRel(cluster, traitSet, inputNode) with CommonAggregate with DataStreamRel { - - override def deriveRowType(): RelDataType = schema.logicalType - - override def copy(traitSet: RelTraitSet, inputs: java.util.List[RelNode]): RelNode = { - new DataStreamAggregate( - window, - namedProperties, - cluster, - traitSet, - inputs.get(0), - namedAggregates, - schema, - inputSchema, - grouping) - } - - override def toString: String = { - s"Aggregate(${ - if (!grouping.isEmpty) { - s"groupBy: (${groupingToString(inputSchema.logicalType, grouping)}), " - } else { - "" - } - }window: ($window), " + - s"select: (${ - aggregationToString( - inputSchema.logicalType, - grouping, - getRowType, - namedAggregates, - namedProperties) - }))" - } - - override def explainTerms(pw: RelWriter): RelWriter = { - super.explainTerms(pw) - .itemIf("groupBy", groupingToString(inputSchema.logicalType, grouping), !grouping.isEmpty) - .item("window", window) - .item( - "select", aggregationToString( - inputSchema.logicalType, - grouping, - schema.logicalType, - namedAggregates, - namedProperties)) - } - - override def translateToPlan(tableEnv: StreamTableEnvironment): DataStream[Row] = { - - val inputDS = input.asInstanceOf[DataStreamRel].translateToPlan(tableEnv) - val physicalNamedAggregates = namedAggregates.map { namedAggregate => - new CalcitePair[AggregateCall, String]( - inputSchema.mapAggregateCall(namedAggregate.left), - namedAggregate.right) - } - - val aggString = aggregationToString( - inputSchema.logicalType, - grouping, - schema.logicalType, - namedAggregates, - namedProperties) - - val keyedAggOpName = s"groupBy: (${groupingToString(schema.logicalType, grouping)}), " + - s"window: ($window), " + - s"select: ($aggString)" - val nonKeyedAggOpName = s"window: ($window), select: 
($aggString)" - - val generator = new CodeGenerator( - tableEnv.getConfig, - false, - inputDS.getType) - - val needMerge = window match { - case SessionGroupWindow(_, _, _) => true - case _ => false - } - val physicalGrouping = grouping.map(inputSchema.mapIndex) - - // grouped / keyed aggregation - if (physicalGrouping.length > 0) { - val windowFunction = AggregateUtil.createAggregationGroupWindowFunction( - window, - physicalGrouping.length, - physicalNamedAggregates.size, - schema.physicalArity, - namedProperties) - - val keyedStream = inputDS.keyBy(physicalGrouping: _*) - val windowedStream = - createKeyedWindowedStream(window, keyedStream) - .asInstanceOf[WindowedStream[Row, Tuple, DataStreamWindow]] - - val (aggFunction, accumulatorRowType, aggResultRowType) = - AggregateUtil.createDataStreamAggregateFunction( - generator, - physicalNamedAggregates, - inputSchema.physicalType, - inputSchema.physicalFieldTypeInfo, - schema.physicalType, - needMerge) - - windowedStream - .aggregate( - aggFunction, - windowFunction, - accumulatorRowType, - aggResultRowType, - schema.physicalTypeInfo) - .name(keyedAggOpName) - } - // global / non-keyed aggregation - else { - val windowFunction = AggregateUtil.createAggregationAllWindowFunction( - window, - schema.physicalArity, - namedProperties) - - val windowedStream = - createNonKeyedWindowedStream(window, inputDS) - .asInstanceOf[AllWindowedStream[Row, DataStreamWindow]] - - val (aggFunction, accumulatorRowType, aggResultRowType) = - AggregateUtil.createDataStreamAggregateFunction( - generator, - physicalNamedAggregates, - inputSchema.physicalType, - inputSchema.physicalFieldTypeInfo, - schema.physicalType, - needMerge) - - windowedStream - .aggregate( - aggFunction, - windowFunction, - accumulatorRowType, - aggResultRowType, - schema.physicalTypeInfo) - .name(nonKeyedAggOpName) - } - } -} - -object DataStreamAggregate { - - - private def createKeyedWindowedStream(groupWindow: LogicalWindow, stream: KeyedStream[Row, Tuple]) - 
: WindowedStream[Row, Tuple, _ <: DataStreamWindow] = groupWindow match { - - case TumblingGroupWindow(_, timeField, size) - if isProctimeAttribute(timeField) && isTimeIntervalLiteral(size)=> - stream.window(TumblingProcessingTimeWindows.of(toTime(size))) - - case TumblingGroupWindow(_, timeField, size) - if isProctimeAttribute(timeField) && isRowCountLiteral(size)=> - stream.countWindow(toLong(size)) - - case TumblingGroupWindow(_, timeField, size) - if isRowtimeAttribute(timeField) && isTimeIntervalLiteral(size) => - stream.window(TumblingEventTimeWindows.of(toTime(size))) - - case TumblingGroupWindow(_, _, size) => - // TODO: EventTimeTumblingGroupWindow should sort the stream on event time - // before applying the windowing logic. Otherwise, this would be the same as a - // ProcessingTimeTumblingGroupWindow - throw new UnsupportedOperationException( - "Event-time grouping windows on row intervals are currently not supported.") - - case SlidingGroupWindow(_, timeField, size, slide) - if isProctimeAttribute(timeField) && isTimeIntervalLiteral(slide) => - stream.window(SlidingProcessingTimeWindows.of(toTime(size), toTime(slide))) - - case SlidingGroupWindow(_, timeField, size, slide) - if isProctimeAttribute(timeField) && isRowCountLiteral(size) => - stream.countWindow(toLong(size), toLong(slide)) - - case SlidingGroupWindow(_, timeField, size, slide) - if isRowtimeAttribute(timeField) && isTimeIntervalLiteral(size)=> - stream.window(SlidingEventTimeWindows.of(toTime(size), toTime(slide))) - - case SlidingGroupWindow(_, _, size, slide) => - // TODO: EventTimeTumblingGroupWindow should sort the stream on event time - // before applying the windowing logic. 
Otherwise, this would be the same as a - // ProcessingTimeTumblingGroupWindow - throw new UnsupportedOperationException( - "Event-time grouping windows on row intervals are currently not supported.") - - case SessionGroupWindow(_, timeField, gap) - if isProctimeAttribute(timeField) => - stream.window(ProcessingTimeSessionWindows.withGap(toTime(gap))) - - case SessionGroupWindow(_, timeField, gap) - if isRowtimeAttribute(timeField) => - stream.window(EventTimeSessionWindows.withGap(toTime(gap))) - } - - private def createNonKeyedWindowedStream(groupWindow: LogicalWindow, stream: DataStream[Row]) - : AllWindowedStream[Row, _ <: DataStreamWindow] = groupWindow match { - - case TumblingGroupWindow(_, timeField, size) - if isProctimeAttribute(timeField) && isTimeIntervalLiteral(size) => - stream.windowAll(TumblingProcessingTimeWindows.of(toTime(size))) - - case TumblingGroupWindow(_, timeField, size) - if isProctimeAttribute(timeField) && isRowCountLiteral(size)=> - stream.countWindowAll(toLong(size)) - - case TumblingGroupWindow(_, _, size) if isTimeInterval(size.resultType) => - stream.windowAll(TumblingEventTimeWindows.of(toTime(size))) - - case TumblingGroupWindow(_, _, size) => - // TODO: EventTimeTumblingGroupWindow should sort the stream on event time - // before applying the windowing logic. 
Otherwise, this would be the same as a - // ProcessingTimeTumblingGroupWindow - throw new UnsupportedOperationException( - "Event-time grouping windows on row intervals are currently not supported.") - - case SlidingGroupWindow(_, timeField, size, slide) - if isProctimeAttribute(timeField) && isTimeIntervalLiteral(size) => - stream.windowAll(SlidingProcessingTimeWindows.of(toTime(size), toTime(slide))) - - case SlidingGroupWindow(_, timeField, size, slide) - if isProctimeAttribute(timeField) && isRowCountLiteral(size)=> - stream.countWindowAll(toLong(size), toLong(slide)) - - case SlidingGroupWindow(_, timeField, size, slide) - if isRowtimeAttribute(timeField) && isTimeIntervalLiteral(size)=> - stream.windowAll(SlidingEventTimeWindows.of(toTime(size), toTime(slide))) - - case SlidingGroupWindow(_, _, size, slide) => - // TODO: EventTimeTumblingGroupWindow should sort the stream on event time - // before applying the windowing logic. Otherwise, this would be the same as a - // ProcessingTimeTumblingGroupWindow - throw new UnsupportedOperationException( - "Event-time grouping windows on row intervals are currently not supported.") - - case SessionGroupWindow(_, timeField, gap) - if isProctimeAttribute(timeField) && isTimeIntervalLiteral(gap) => - stream.windowAll(ProcessingTimeSessionWindows.withGap(toTime(gap))) - - case SessionGroupWindow(_, timeField, gap) - if isRowtimeAttribute(timeField) && isTimeIntervalLiteral(gap) => - stream.windowAll(EventTimeSessionWindows.withGap(toTime(gap))) - } - - -} - http://git-wip-us.apache.org/repos/asf/flink/blob/8f78824b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamGroupAggregate.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamGroupAggregate.scala 
b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamGroupAggregate.scala new file mode 100644 index 0000000..19f90c7 --- /dev/null +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamGroupAggregate.scala @@ -0,0 +1,152 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.flink.table.plan.nodes.datastream + +import org.apache.calcite.plan.{RelOptCluster, RelTraitSet} +import org.apache.calcite.rel.`type`.RelDataType +import org.apache.calcite.rel.core.AggregateCall +import org.apache.calcite.rel.{RelNode, RelWriter, SingleRel} +import org.apache.flink.api.java.functions.NullByteKeySelector +import org.apache.flink.streaming.api.datastream.DataStream +import org.apache.flink.table.api.StreamTableEnvironment +import org.apache.flink.table.codegen.CodeGenerator +import org.apache.flink.table.runtime.aggregate._ +import org.apache.flink.table.plan.nodes.CommonAggregate +import org.apache.flink.table.plan.schema.RowSchema +import org.apache.flink.types.Row +import org.apache.flink.table.runtime.aggregate.AggregateUtil.CalcitePair + +/** + * + * Flink RelNode for data stream unbounded group aggregate + * + * @param cluster Cluster of the RelNode, represent for an environment of related + * relational expressions during the optimization of a query. 
+ * @param traitSet Trait set of the RelNode + * @param inputNode The input RelNode of aggregation + * @param namedAggregates List of calls to aggregate functions and their output field names + * @param rowRelDataType The type of the rows of the RelNode + * @param inputSchema The type of the rows consumed by this RelNode + * @param schema The type of the rows emitted by this RelNode + * @param groupings The position (in the input Row) of the grouping keys + */ +class DataStreamGroupAggregate( + cluster: RelOptCluster, + traitSet: RelTraitSet, + inputNode: RelNode, + namedAggregates: Seq[CalcitePair[AggregateCall, String]], + rowRelDataType: RelDataType, + schema: RowSchema, + inputSchema: RowSchema, + groupings: Array[Int]) + extends SingleRel(cluster, traitSet, inputNode) + with CommonAggregate + with DataStreamRel { + + override def deriveRowType() = schema.logicalType + + override def copy(traitSet: RelTraitSet, inputs: java.util.List[RelNode]): RelNode = { + new DataStreamGroupAggregate( + cluster, + traitSet, + inputs.get(0), + namedAggregates, + getRowType, + schema, + inputSchema, + groupings) + } + + override def toString: String = { + s"Aggregate(${ + if (!groupings.isEmpty) { + s"groupBy: (${groupingToString(inputSchema.logicalType, groupings)}), " + } else { + "" + } + }select:(${aggregationToString( + inputSchema.logicalType, groupings, getRowType, namedAggregates, Nil)}))" + } + + override def explainTerms(pw: RelWriter): RelWriter = { + super.explainTerms(pw) + .itemIf("groupBy", groupingToString( + inputSchema.logicalType, groupings), !groupings.isEmpty) + .item("select", aggregationToString( + inputSchema.logicalType, groupings, getRowType, namedAggregates, Nil)) + } + + override def translateToPlan(tableEnv: StreamTableEnvironment): DataStream[Row] = { + + val inputDS = input.asInstanceOf[DataStreamRel].translateToPlan(tableEnv) + val physicalNamedAggregates = namedAggregates.map { namedAggregate => + new CalcitePair[AggregateCall, String]( + 
inputSchema.mapAggregateCall(namedAggregate.left), + namedAggregate.right) + } + + val generator = new CodeGenerator( + tableEnv.getConfig, + false, + inputDS.getType) + + val aggString = aggregationToString( + inputSchema.logicalType, + groupings, + getRowType, + namedAggregates, + Nil) + + val keyedAggOpName = s"groupBy: (${groupingToString(inputSchema.logicalType, groupings)}), " + + s"select: ($aggString)" + val nonKeyedAggOpName = s"select: ($aggString)" + + val physicalGrouping = groupings.map(inputSchema.mapIndex) + + val processFunction = AggregateUtil.createGroupAggregateFunction( + generator, + physicalNamedAggregates, + inputSchema.physicalType, // physical row type, matching the physically-mapped aggregate calls above + inputSchema.physicalFieldTypeInfo, + physicalGrouping) // NOTE(review): physical key positions for the physical input row; confirm against AggregateUtil + + val result: DataStream[Row] = + // grouped / keyed aggregation + if (physicalGrouping.nonEmpty) { + inputDS + .keyBy(physicalGrouping: _*) // inputDS carries physical rows, so key on physical positions (as the window aggregate does) + .process(processFunction) + .returns(schema.physicalTypeInfo) + .name(keyedAggOpName) + .asInstanceOf[DataStream[Row]] + } + // global / non-keyed aggregation + else { + inputDS + .keyBy(new NullByteKeySelector[Row]) + .process(processFunction) + .setParallelism(1) + .setMaxParallelism(1) + .returns(schema.physicalTypeInfo) + .name(nonKeyedAggOpName) + .asInstanceOf[DataStream[Row]] + } + result + } +} + http://git-wip-us.apache.org/repos/asf/flink/blob/8f78824b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamGroupWindowAggregate.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamGroupWindowAggregate.scala 
b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamGroupWindowAggregate.scala new file mode 100644 index 0000000..5aced66 --- /dev/null +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamGroupWindowAggregate.scala @@ -0,0 +1,300 @@ +/* + * 
Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.plan.nodes.datastream + +import org.apache.calcite.plan.{RelOptCluster, RelTraitSet} +import org.apache.calcite.rel.`type`.RelDataType +import org.apache.calcite.rel.core.AggregateCall +import org.apache.calcite.rel.{RelNode, RelWriter, SingleRel} +import org.apache.flink.api.java.tuple.Tuple +import org.apache.flink.streaming.api.datastream.{AllWindowedStream, DataStream, KeyedStream, WindowedStream} +import org.apache.flink.streaming.api.windowing.assigners._ +import org.apache.flink.streaming.api.windowing.windows.{Window => DataStreamWindow} +import org.apache.flink.table.api.StreamTableEnvironment +import org.apache.flink.table.calcite.FlinkRelBuilder.NamedWindowProperty +import org.apache.flink.table.codegen.CodeGenerator +import org.apache.flink.table.expressions.ExpressionUtils._ +import org.apache.flink.table.plan.logical._ +import org.apache.flink.table.plan.nodes.CommonAggregate +import org.apache.flink.table.plan.schema.RowSchema +import org.apache.flink.table.plan.nodes.datastream.DataStreamGroupWindowAggregate._ +import org.apache.flink.table.runtime.aggregate.AggregateUtil._ +import org.apache.flink.table.runtime.aggregate._ 
+import org.apache.flink.table.typeutils.TypeCheckUtils.isTimeInterval +import org.apache.flink.types.Row + +class DataStreamGroupWindowAggregate( + window: LogicalWindow, + namedProperties: Seq[NamedWindowProperty], + cluster: RelOptCluster, + traitSet: RelTraitSet, + inputNode: RelNode, + namedAggregates: Seq[CalcitePair[AggregateCall, String]], + schema: RowSchema, + inputSchema: RowSchema, + grouping: Array[Int]) + extends SingleRel(cluster, traitSet, inputNode) with CommonAggregate with DataStreamRel { + + override def deriveRowType(): RelDataType = schema.logicalType + + override def copy(traitSet: RelTraitSet, inputs: java.util.List[RelNode]): RelNode = { + new DataStreamGroupWindowAggregate( + window, + namedProperties, + cluster, + traitSet, + inputs.get(0), + namedAggregates, + schema, + inputSchema, + grouping) + } + + override def toString: String = { + s"Aggregate(${ + if (!grouping.isEmpty) { + s"groupBy: (${groupingToString(inputSchema.logicalType, grouping)}), " + } else { + "" + } + }window: ($window), " + + s"select: (${ + aggregationToString( + inputSchema.logicalType, + grouping, + getRowType, + namedAggregates, + namedProperties) + }))" + } + + override def explainTerms(pw: RelWriter): RelWriter = { + super.explainTerms(pw) + .itemIf("groupBy", groupingToString(inputSchema.logicalType, grouping), !grouping.isEmpty) + .item("window", window) + .item( + "select", aggregationToString( + inputSchema.logicalType, + grouping, + schema.logicalType, + namedAggregates, + namedProperties)) + } + + override def translateToPlan(tableEnv: StreamTableEnvironment): DataStream[Row] = { + + val inputDS = input.asInstanceOf[DataStreamRel].translateToPlan(tableEnv) + val physicalNamedAggregates = namedAggregates.map { namedAggregate => + new CalcitePair[AggregateCall, String]( + inputSchema.mapAggregateCall(namedAggregate.left), + namedAggregate.right) + } + + val aggString = aggregationToString( + inputSchema.logicalType, + grouping, + schema.logicalType, + 
namedAggregates, + namedProperties) + + val keyedAggOpName = s"groupBy: (${groupingToString(inputSchema.logicalType, grouping)}), " + // grouping indices refer to input fields, as in toString/explainTerms + s"window: ($window), " + + s"select: ($aggString)" + val nonKeyedAggOpName = s"window: ($window), select: ($aggString)" + + val generator = new CodeGenerator( + tableEnv.getConfig, + false, + inputDS.getType) + + val needMerge = window match { + case SessionGroupWindow(_, _, _) => true + case _ => false + } + val physicalGrouping = grouping.map(inputSchema.mapIndex) + + // grouped / keyed aggregation + if (physicalGrouping.length > 0) { + val windowFunction = AggregateUtil.createAggregationGroupWindowFunction( + window, + physicalGrouping.length, + physicalNamedAggregates.size, + schema.physicalArity, + namedProperties) + + val keyedStream = inputDS.keyBy(physicalGrouping: _*) + val windowedStream = + createKeyedWindowedStream(window, keyedStream) + .asInstanceOf[WindowedStream[Row, Tuple, DataStreamWindow]] + + val (aggFunction, accumulatorRowType, aggResultRowType) = + AggregateUtil.createDataStreamAggregateFunction( + generator, + physicalNamedAggregates, + inputSchema.physicalType, + inputSchema.physicalFieldTypeInfo, + schema.physicalType, + needMerge) + + windowedStream + .aggregate( + aggFunction, + windowFunction, +
accumulatorRowType, + aggResultRowType, + schema.physicalTypeInfo) + .name(nonKeyedAggOpName) + } + } +} + +object DataStreamGroupWindowAggregate { + + + private def createKeyedWindowedStream(groupWindow: LogicalWindow, stream: KeyedStream[Row, Tuple]) + : WindowedStream[Row, Tuple, _ <: DataStreamWindow] = groupWindow match { + + case TumblingGroupWindow(_, timeField, size) + if isProctimeAttribute(timeField) && isTimeIntervalLiteral(size)=> + stream.window(TumblingProcessingTimeWindows.of(toTime(size))) + + case TumblingGroupWindow(_, timeField, size) + if isProctimeAttribute(timeField) && isRowCountLiteral(size)=> + stream.countWindow(toLong(size)) + + case TumblingGroupWindow(_, timeField, size) + if isRowtimeAttribute(timeField) && isTimeIntervalLiteral(size) => + stream.window(TumblingEventTimeWindows.of(toTime(size))) + + case TumblingGroupWindow(_, _, size) => + // TODO: EventTimeTumblingGroupWindow should sort the stream on event time + // before applying the windowing logic. Otherwise, this would be the same as a + // ProcessingTimeTumblingGroupWindow + throw new UnsupportedOperationException( + "Event-time grouping windows on row intervals are currently not supported.") + + case SlidingGroupWindow(_, timeField, size, slide) + if isProctimeAttribute(timeField) && isTimeIntervalLiteral(slide) => + stream.window(SlidingProcessingTimeWindows.of(toTime(size), toTime(slide))) + + case SlidingGroupWindow(_, timeField, size, slide) + if isProctimeAttribute(timeField) && isRowCountLiteral(size) => + stream.countWindow(toLong(size), toLong(slide)) + + case SlidingGroupWindow(_, timeField, size, slide) + if isRowtimeAttribute(timeField) && isTimeIntervalLiteral(size)=> + stream.window(SlidingEventTimeWindows.of(toTime(size), toTime(slide))) + + case SlidingGroupWindow(_, _, size, slide) => + // TODO: EventTimeTumblingGroupWindow should sort the stream on event time + // before applying the windowing logic. 
Otherwise, this would be the same as a + // ProcessingTimeTumblingGroupWindow + throw new UnsupportedOperationException( + "Event-time grouping windows on row intervals are currently not supported.") + + case SessionGroupWindow(_, timeField, gap) + if isProctimeAttribute(timeField) => + stream.window(ProcessingTimeSessionWindows.withGap(toTime(gap))) + + case SessionGroupWindow(_, timeField, gap) + if isRowtimeAttribute(timeField) => + stream.window(EventTimeSessionWindows.withGap(toTime(gap))) + } + + private def createNonKeyedWindowedStream(groupWindow: LogicalWindow, stream: DataStream[Row]) + : AllWindowedStream[Row, _ <: DataStreamWindow] = groupWindow match { + + case TumblingGroupWindow(_, timeField, size) + if isProctimeAttribute(timeField) && isTimeIntervalLiteral(size) => + stream.windowAll(TumblingProcessingTimeWindows.of(toTime(size))) + + case TumblingGroupWindow(_, timeField, size) + if isProctimeAttribute(timeField) && isRowCountLiteral(size)=> + stream.countWindowAll(toLong(size)) + + case TumblingGroupWindow(_, _, size) if isTimeInterval(size.resultType) => + stream.windowAll(TumblingEventTimeWindows.of(toTime(size))) + + case TumblingGroupWindow(_, _, size) => + // TODO: EventTimeTumblingGroupWindow should sort the stream on event time + // before applying the windowing logic. 
Otherwise, this would be the same as a + // ProcessingTimeTumblingGroupWindow + throw new UnsupportedOperationException( + "Event-time grouping windows on row intervals are currently not supported.") + + case SlidingGroupWindow(_, timeField, size, slide) + if isProctimeAttribute(timeField) && isTimeIntervalLiteral(size) => + stream.windowAll(SlidingProcessingTimeWindows.of(toTime(size), toTime(slide))) + + case SlidingGroupWindow(_, timeField, size, slide) + if isProctimeAttribute(timeField) && isRowCountLiteral(size)=> + stream.countWindowAll(toLong(size), toLong(slide)) + + case SlidingGroupWindow(_, timeField, size, slide) + if isRowtimeAttribute(timeField) && isTimeIntervalLiteral(size)=> + stream.windowAll(SlidingEventTimeWindows.of(toTime(size), toTime(slide))) + + case SlidingGroupWindow(_, _, size, slide) => + // TODO: EventTimeTumblingGroupWindow should sort the stream on event time + // before applying the windowing logic. Otherwise, this would be the same as a + // ProcessingTimeTumblingGroupWindow + throw new UnsupportedOperationException( + "Event-time grouping windows on row intervals are currently not supported.") + + case SessionGroupWindow(_, timeField, gap) + if isProctimeAttribute(timeField) && isTimeIntervalLiteral(gap) => + stream.windowAll(ProcessingTimeSessionWindows.withGap(toTime(gap))) + + case SessionGroupWindow(_, timeField, gap) + if isRowtimeAttribute(timeField) && isTimeIntervalLiteral(gap) => + stream.windowAll(EventTimeSessionWindows.withGap(toTime(gap))) + } + + +} + http://git-wip-us.apache.org/repos/asf/flink/blob/8f78824b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/FlinkRuleSets.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/FlinkRuleSets.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/FlinkRuleSets.scala index 838cb22..f4de651 100644 --- 
a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/FlinkRuleSets.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/FlinkRuleSets.scala @@ -176,8 +176,9 @@ object FlinkRuleSets { */ val DATASTREAM_OPT_RULES: RuleSet = RuleSets.ofList( // translate to DataStream nodes + DataStreamGroupAggregateRule.INSTANCE, DataStreamOverAggregateRule.INSTANCE, - DataStreamAggregateRule.INSTANCE, + DataStreamGroupWindowAggregateRule.INSTANCE, DataStreamCalcRule.INSTANCE, DataStreamScanRule.INSTANCE, DataStreamUnionRule.INSTANCE, http://git-wip-us.apache.org/repos/asf/flink/blob/8f78824b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/datastream/DataStreamAggregateRule.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/datastream/DataStreamAggregateRule.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/datastream/DataStreamAggregateRule.scala deleted file mode 100644 index fc65403..0000000 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/datastream/DataStreamAggregateRule.scala +++ /dev/null @@ -1,77 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
- * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.flink.table.plan.rules.datastream - -import org.apache.calcite.plan.{RelOptRule, RelOptRuleCall, RelTraitSet} -import org.apache.calcite.rel.RelNode -import org.apache.calcite.rel.convert.ConverterRule -import org.apache.flink.table.api.TableException -import org.apache.flink.table.plan.nodes.FlinkConventions -import org.apache.flink.table.plan.nodes.datastream.DataStreamAggregate -import org.apache.flink.table.plan.nodes.logical.FlinkLogicalWindowAggregate -import org.apache.flink.table.plan.schema.RowSchema - -import scala.collection.JavaConversions._ - -class DataStreamAggregateRule - extends ConverterRule( - classOf[FlinkLogicalWindowAggregate], - FlinkConventions.LOGICAL, - FlinkConventions.DATASTREAM, - "DataStreamAggregateRule") { - - override def matches(call: RelOptRuleCall): Boolean = { - val agg: FlinkLogicalWindowAggregate = call.rel(0).asInstanceOf[FlinkLogicalWindowAggregate] - - // check if we have distinct aggregates - val distinctAggs = agg.getAggCallList.exists(_.isDistinct) - if (distinctAggs) { - throw TableException("DISTINCT aggregates are currently not supported.") - } - - // check if we have grouping sets - val groupSets = agg.getGroupSets.size() != 1 || agg.getGroupSets.get(0) != agg.getGroupSet - if (groupSets || agg.indicator) { - throw TableException("GROUPING SETS are currently not supported.") - } - - !distinctAggs && !groupSets && !agg.indicator - } - - override def convert(rel: RelNode): RelNode = { - val agg: FlinkLogicalWindowAggregate = rel.asInstanceOf[FlinkLogicalWindowAggregate] - val traitSet: RelTraitSet = rel.getTraitSet.replace(FlinkConventions.DATASTREAM) - val convInput: RelNode = RelOptRule.convert(agg.getInput, FlinkConventions.DATASTREAM) - - new DataStreamAggregate( - agg.getWindow, - agg.getNamedProperties, - rel.getCluster, - traitSet, - convInput, - agg.getNamedAggCalls, - new 
RowSchema(rel.getRowType), - new RowSchema(agg.getInput.getRowType), - agg.getGroupSet.toArray) - } - } - -object DataStreamAggregateRule { - val INSTANCE: RelOptRule = new DataStreamAggregateRule -} http://git-wip-us.apache.org/repos/asf/flink/blob/8f78824b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/datastream/DataStreamGroupAggregateRule.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/datastream/DataStreamGroupAggregateRule.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/datastream/DataStreamGroupAggregateRule.scala new file mode 100644 index 0000000..fd7619c --- /dev/null +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/datastream/DataStreamGroupAggregateRule.scala @@ -0,0 +1,81 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.table.plan.rules.datastream + +import org.apache.calcite.plan.{RelOptRule, RelOptRuleCall, RelTraitSet} +import org.apache.calcite.rel.RelNode +import org.apache.calcite.rel.convert.ConverterRule +import org.apache.calcite.rel.logical.LogicalAggregate +import org.apache.flink.table.api.TableException +import org.apache.flink.table.plan.nodes.FlinkConventions +import org.apache.flink.table.plan.nodes.datastream.DataStreamGroupAggregate +import org.apache.flink.table.plan.nodes.logical.FlinkLogicalAggregate +import org.apache.flink.table.plan.schema.RowSchema + +import scala.collection.JavaConversions._ + +/** + * Rule to convert a non-windowed [[FlinkLogicalAggregate]] into a [[DataStreamGroupAggregate]]. + */ +class DataStreamGroupAggregateRule + extends ConverterRule( + classOf[FlinkLogicalAggregate], + FlinkConventions.LOGICAL, + FlinkConventions.DATASTREAM, + "DataStreamGroupAggregateRule") { + + override def matches(call: RelOptRuleCall): Boolean = { + val agg: FlinkLogicalAggregate = call.rel(0).asInstanceOf[FlinkLogicalAggregate] + + // check if we have distinct aggregates (rejected with a TableException rather than declining the match) + val distinctAggs = agg.getAggCallList.exists(_.isDistinct) + if (distinctAggs) { + throw TableException("DISTINCT aggregates are currently not supported.") + } + + // check if we have grouping sets (also rejected with a TableException, as is the indicator flag) + val groupSets = agg.getGroupSets.size() != 1 || agg.getGroupSets.get(0) != agg.getGroupSet + if (groupSets || agg.indicator) { + throw TableException("GROUPING SETS are currently not supported.") + } + + !distinctAggs && !groupSets && !agg.indicator + } + + override def convert(rel: RelNode): RelNode = { + val agg: FlinkLogicalAggregate = rel.asInstanceOf[FlinkLogicalAggregate] + val traitSet: RelTraitSet = rel.getTraitSet.replace(FlinkConventions.DATASTREAM) + val convInput: RelNode = RelOptRule.convert(agg.getInput, FlinkConventions.DATASTREAM) + + new DataStreamGroupAggregate( + rel.getCluster, + traitSet, + convInput, + agg.getNamedAggCalls, + rel.getRowType, // NOTE(review): output row type is passed both raw and as a RowSchema below — confirm against the DataStreamGroupAggregate constructor + new 
RowSchema(rel.getRowType), + new RowSchema(agg.getInput.getRowType), + agg.getGroupSet.toArray) + } +} + +object DataStreamGroupAggregateRule { + val INSTANCE: RelOptRule = new DataStreamGroupAggregateRule +} + http://git-wip-us.apache.org/repos/asf/flink/blob/8f78824b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/datastream/DataStreamGroupWindowAggregateRule.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/datastream/DataStreamGroupWindowAggregateRule.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/datastream/DataStreamGroupWindowAggregateRule.scala new file mode 100644 index 0000000..3beeb47 --- /dev/null +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/datastream/DataStreamGroupWindowAggregateRule.scala @@ -0,0 +1,77 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.table.plan.rules.datastream + +import org.apache.calcite.plan.{RelOptRule, RelOptRuleCall, RelTraitSet} +import org.apache.calcite.rel.RelNode +import org.apache.calcite.rel.convert.ConverterRule +import org.apache.flink.table.api.TableException +import org.apache.flink.table.plan.nodes.FlinkConventions +import org.apache.flink.table.plan.nodes.datastream.DataStreamGroupWindowAggregate +import org.apache.flink.table.plan.nodes.logical.FlinkLogicalWindowAggregate +import org.apache.flink.table.plan.schema.RowSchema + +import scala.collection.JavaConversions._ + +class DataStreamGroupWindowAggregateRule + extends ConverterRule( + classOf[FlinkLogicalWindowAggregate], + FlinkConventions.LOGICAL, + FlinkConventions.DATASTREAM, + "DataStreamGroupWindowAggregateRule") { + + override def matches(call: RelOptRuleCall): Boolean = { + val agg: FlinkLogicalWindowAggregate = call.rel(0).asInstanceOf[FlinkLogicalWindowAggregate] + + // check if we have distinct aggregates + val distinctAggs = agg.getAggCallList.exists(_.isDistinct) + if (distinctAggs) { + throw TableException("DISTINCT aggregates are currently not supported.") + } + + // check if we have grouping sets + val groupSets = agg.getGroupSets.size() != 1 || agg.getGroupSets.get(0) != agg.getGroupSet + if (groupSets || agg.indicator) { + throw TableException("GROUPING SETS are currently not supported.") + } + + !distinctAggs && !groupSets && !agg.indicator + } + + override def convert(rel: RelNode): RelNode = { + val agg: FlinkLogicalWindowAggregate = rel.asInstanceOf[FlinkLogicalWindowAggregate] + val traitSet: RelTraitSet = rel.getTraitSet.replace(FlinkConventions.DATASTREAM) + val convInput: RelNode = RelOptRule.convert(agg.getInput, FlinkConventions.DATASTREAM) + + new DataStreamGroupWindowAggregate( + agg.getWindow, + agg.getNamedProperties, + rel.getCluster, + traitSet, + convInput, + agg.getNamedAggCalls, + new RowSchema(rel.getRowType), + new RowSchema(agg.getInput.getRowType), 
+ agg.getGroupSet.toArray) + } + } + +object DataStreamGroupWindowAggregateRule { + val INSTANCE: RelOptRule = new DataStreamGroupWindowAggregateRule +} http://git-wip-us.apache.org/repos/asf/flink/blob/8f78824b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateUtil.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateUtil.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateUtil.scala index dfed34a..5e9efd0 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateUtil.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateUtil.scala @@ -87,8 +87,7 @@ object AggregateUtil { inputType, needRetraction = false) - val aggregationStateType: RowTypeInfo = - createDataSetAggregateBufferDataType(Array(), aggregates, inputType) + val aggregationStateType: RowTypeInfo = createAccumulatorRowType(aggregates) val forwardMapping = (0 until inputType.getFieldCount).toArray val aggMapping = aggregates.indices.map(x => x + inputType.getFieldCount).toArray @@ -139,7 +138,58 @@ object AggregateUtil { } /** - * Create an [[org.apache.flink.streaming.api.functions.ProcessFunction]] for + * Create an [[org.apache.flink.streaming.api.functions.ProcessFunction]] for group (without + * window) aggregate to evaluate final aggregate value. 
+ * + * @param generator code generator instance + * @param namedAggregates List of calls to aggregate functions and their output field names + * @param inputRowType Input row type + * @param inputFieldTypes Types of the physical input fields + * @param groupings the position (in the input Row) of the grouping keys + * @return [[org.apache.flink.streaming.api.functions.ProcessFunction]] + */ + private[flink] def createGroupAggregateFunction( + generator: CodeGenerator, + namedAggregates: Seq[CalcitePair[AggregateCall, String]], + inputRowType: RelDataType, + inputFieldTypes: Seq[TypeInformation[_]], + groupings: Array[Int]): ProcessFunction[Row, Row] = { + + val (aggFields, aggregates) = + transformToAggregateFunctions( + namedAggregates.map(_.getKey), + inputRowType, + needRetraction = false) + val aggMapping = aggregates.indices.map(_ + groupings.length).toArray + + val outputArity = groupings.length + aggregates.length + + val aggregationStateType: RowTypeInfo = createAccumulatorRowType(aggregates) + + val genFunction = generator.generateAggregations( + "NonWindowedAggregationHelper", + generator, + inputFieldTypes, + aggregates, + aggFields, + aggMapping, + partialResults = false, + groupings, + None, + None, + outputArity, + needRetract = false, + needMerge = false, + needReset = false + ) + + new GroupAggProcessFunction( + genFunction, + aggregationStateType) + } + + /** + * Create an [[org.apache.flink.streaming.api.functions.ProcessFunction]] for ROWS clause * bounded OVER window to evaluate final aggregate value. 
* * @param generator code generator instance @@ -266,7 +316,7 @@ object AggregateUtil { needRetract) val mapReturnType: RowTypeInfo = - createDataSetAggregateBufferDataType( + createRowTypeForKeysAndAggregates( groupings, aggregates, inputType, @@ -370,7 +420,7 @@ object AggregateUtil { physicalInputRowType, needRetract) - val returnType: RowTypeInfo = createDataSetAggregateBufferDataType( + val returnType: RowTypeInfo = createRowTypeForKeysAndAggregates( groupings, aggregates, physicalInputRowType, @@ -637,7 +687,7 @@ object AggregateUtil { window match { case SessionGroupWindow(_, _, gap) => val combineReturnType: RowTypeInfo = - createDataSetAggregateBufferDataType( + createRowTypeForKeysAndAggregates( groupings, aggregates, physicalInputRowType, @@ -711,7 +761,7 @@ object AggregateUtil { case SessionGroupWindow(_, _, gap) => val combineReturnType: RowTypeInfo = - createDataSetAggregateBufferDataType( + createRowTypeForKeysAndAggregates( groupings, aggregates, physicalInputRowType, @@ -1365,7 +1415,7 @@ object AggregateUtil { aggTypes } - private def createDataSetAggregateBufferDataType( + private def createRowTypeForKeysAndAggregates( groupings: Array[Int], aggregates: Array[TableAggregateFunction[_, _]], inputType: RelDataType, http://git-wip-us.apache.org/repos/asf/flink/blob/8f78824b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/GroupAggProcessFunction.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/GroupAggProcessFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/GroupAggProcessFunction.scala new file mode 100644 index 0000000..81c900c --- /dev/null +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/GroupAggProcessFunction.scala @@ -0,0 +1,90 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under 
one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.flink.table.runtime.aggregate + +import org.apache.flink.configuration.Configuration +import org.apache.flink.streaming.api.functions.ProcessFunction +import org.apache.flink.types.Row +import org.apache.flink.util.Collector +import org.apache.flink.api.common.state.ValueStateDescriptor +import org.apache.flink.api.java.typeutils.RowTypeInfo +import org.apache.flink.api.common.state.ValueState +import org.apache.flink.table.codegen.{Compiler, GeneratedAggregationsFunction} +import org.slf4j.LoggerFactory + +/** + * ProcessFunction for non-windowed GROUP BY aggregation: keeps the accumulators in a ValueState and emits the updated aggregate row for every input row. + * + * @param genAggregations Generated aggregations helper (name and code), compiled in open() + * @param aggregationStateType Row type of the accumulator row stored in the ValueState + */ +class GroupAggProcessFunction( + private val genAggregations: GeneratedAggregationsFunction, + private val aggregationStateType: RowTypeInfo) + extends ProcessFunction[Row, Row] + with Compiler[GeneratedAggregations] { + + val LOG = LoggerFactory.getLogger(this.getClass) + private var function: GeneratedAggregations = _ + + private var output: Row = _ + private var state: ValueState[Row] = _ + + override def open(config: Configuration) { + LOG.debug(s"Compiling AggregateHelper:
${genAggregations.name} \n\n " + + s"Code:\n${genAggregations.code}") + val clazz = compile( + getRuntimeContext.getUserCodeClassLoader, + genAggregations.name, + genAggregations.code) + LOG.debug("Instantiating AggregateHelper.") + function = clazz.newInstance() + output = function.createOutputRow() + + val stateDescriptor: ValueStateDescriptor[Row] = + new ValueStateDescriptor[Row]("GroupAggregateState", aggregationStateType) + state = getRuntimeContext.getState(stateDescriptor) + } + + override def processElement( + input: Row, + ctx: ProcessFunction[Row, Row]#Context, + out: Collector[Row]): Unit = { + + // get accumulators from state; create fresh ones if none are stored yet + var accumulators = state.value() + if (null == accumulators) { + accumulators = function.createAccumulators() + } + + // Set group keys value to the final output + function.setForwardedFields(input, output) + + // accumulate new input row + function.accumulate(accumulators, input) + + // set aggregation results to output + function.setAggregationResults(accumulators, output) + + // update accumulators + state.update(accumulators) + + out.collect(output) // the same output Row instance (created once in open) is reused for every element + } + +} http://git-wip-us.apache.org/repos/asf/flink/blob/8f78824b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/ProcTimeBoundedRangeOver.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/ProcTimeBoundedRangeOver.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/ProcTimeBoundedRangeOver.scala index 7f87e50..b63eb81 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/ProcTimeBoundedRangeOver.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/ProcTimeBoundedRangeOver.scala @@ -38,7 +38,7 @@ import org.slf4j.LoggerFactory * Process Function used for the aggregate in bounded proc-time OVER window * 
[[org.apache.flink.streaming.api.datastream.DataStream]] * - * @param genAggregations Generated aggregate helper function + * @param genAggregations Generated aggregate helper function * @param precedingTimeBoundary Is used to indicate the processing time boundaries * @param aggregatesTypeInfo row type info of aggregation * @param inputType row type info of input row http://git-wip-us.apache.org/repos/asf/flink/blob/8f78824b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/table/FieldProjectionTest.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/table/FieldProjectionTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/table/FieldProjectionTest.scala index b484293..93e25f8 100644 --- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/table/FieldProjectionTest.scala +++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/table/FieldProjectionTest.scala @@ -230,7 +230,7 @@ class FieldProjectionTest extends TableTestBase { val expected = unaryNode( - "DataStreamAggregate", + "DataStreamGroupWindowAggregate", unaryNode( "DataStreamCalc", streamTableNode(0), @@ -259,7 +259,7 @@ class FieldProjectionTest extends TableTestBase { val expected = unaryNode( "DataStreamCalc", unaryNode( - "DataStreamAggregate", + "DataStreamGroupWindowAggregate", unaryNode( "DataStreamCalc", streamTableNode(0), http://git-wip-us.apache.org/repos/asf/flink/blob/8f78824b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/TableSourceTest.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/TableSourceTest.scala 
b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/TableSourceTest.scala index 7673266..18066c9 100644 --- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/TableSourceTest.scala +++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/TableSourceTest.scala @@ -63,7 +63,7 @@ class TableSourceTest extends TableTestBase { unaryNode( "DataStreamCalc", unaryNode( - "DataStreamAggregate", + "DataStreamGroupWindowAggregate", unaryNode( "DataStreamCalc", "StreamTableSourceScan(table=[[rowTimeT]], fields=[id, val, name, addTime])", http://git-wip-us.apache.org/repos/asf/flink/blob/8f78824b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/sql/SqlITCase.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/sql/SqlITCase.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/sql/SqlITCase.scala index 6bab4b3..abbcbdd 100644 --- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/sql/SqlITCase.scala +++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/sql/SqlITCase.scala @@ -47,6 +47,27 @@ class SqlITCase extends StreamingWithStateTestBase { (8L, 8, "Hello World"), (20L, 20, "Hello World")) + /** test unbounded groupby (without window) **/ + @Test + def testUnboundedGroupby(): Unit = { + + val env = StreamExecutionEnvironment.getExecutionEnvironment + val tEnv = TableEnvironment.getTableEnvironment(env) + StreamITCase.testResults = mutable.MutableList() + + val sqlQuery = "SELECT b, COUNT(a) FROM MyTable GROUP BY b" + + val t = StreamTestData.getSmall3TupleDataStream(env).toTable(tEnv).as('a, 'b, 'c) + tEnv.registerTable("MyTable", t) + + val result = tEnv.sql(sqlQuery).toDataStream[Row] + result.addSink(new 
StreamITCase.StringSink) + env.execute() + + val expected = mutable.MutableList("1,1", "2,1", "2,2") + assertEquals(expected.sorted, StreamITCase.testResults.sorted) + } + /** test selection **/ @Test def testSelectExpressionFromTable(): Unit = { http://git-wip-us.apache.org/repos/asf/flink/blob/8f78824b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/sql/WindowAggregateTest.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/sql/WindowAggregateTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/sql/WindowAggregateTest.scala index f84ae3d..4c1d6e6 100644 --- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/sql/WindowAggregateTest.scala +++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/sql/WindowAggregateTest.scala @@ -60,10 +60,13 @@ class WindowAggregateTest extends TableTestBase { @Test def testPartitionedProcessingTimeBoundedWindow() = { - val sqlQuery = "SELECT a, AVG(c) OVER (PARTITION BY a ORDER BY proctime " + - "RANGE BETWEEN INTERVAL '2' HOUR PRECEDING AND CURRENT ROW) AS avgA " + + val sqlQuery = + "SELECT a, " + + " AVG(c) OVER (PARTITION BY a ORDER BY proctime " + + " RANGE BETWEEN INTERVAL '2' HOUR PRECEDING AND CURRENT ROW) AS avgA " + "FROM MyTable" - val expected = + + val expected = unaryNode( "DataStreamCalc", unaryNode( @@ -85,6 +88,27 @@ class WindowAggregateTest extends TableTestBase { } @Test + def testGroupbyWithoutWindow() = { + val sql = "SELECT COUNT(a) FROM MyTable GROUP BY b" + val expected = + unaryNode( + "DataStreamCalc", + unaryNode( + "DataStreamGroupAggregate", + unaryNode( + "DataStreamCalc", + streamTableNode(0), + term("select", "b", "a") + ), + term("groupBy", "b"), + term("select", "b", "COUNT(a) AS EXPR$0") + ), + term("select", "EXPR$0") + ) + 
streamUtil.verifySql(sql, expected) + } + + @Test def testTumbleFunction() = { streamUtil.tEnv.registerFunction("weightedAvg", new WeightedAvgWithMerge) @@ -97,7 +121,7 @@ class WindowAggregateTest extends TableTestBase { "GROUP BY TUMBLE(rowtime, INTERVAL '15' MINUTE)" val expected = unaryNode( - "DataStreamAggregate", + "DataStreamGroupWindowAggregate", unaryNode( "DataStreamCalc", streamTableNode(0), @@ -125,7 +149,7 @@ class WindowAggregateTest extends TableTestBase { "GROUP BY HOP(proctime, INTERVAL '15' MINUTE, INTERVAL '1' HOUR)" val expected = unaryNode( - "DataStreamAggregate", + "DataStreamGroupWindowAggregate", unaryNode( "DataStreamCalc", streamTableNode(0), @@ -154,7 +178,7 @@ class WindowAggregateTest extends TableTestBase { "GROUP BY SESSION(proctime, INTERVAL '15' MINUTE)" val expected = unaryNode( - "DataStreamAggregate", + "DataStreamGroupWindowAggregate", unaryNode( "DataStreamCalc", streamTableNode(0), @@ -206,21 +230,6 @@ class WindowAggregateTest extends TableTestBase { streamUtil.verifySql(sql, "n/a") } - @Test(expected = classOf[TableException]) - def testMultiWindow() = { - val sql = "SELECT COUNT(*) FROM MyTable GROUP BY " + - "FLOOR(rowtime TO HOUR), FLOOR(rowtime TO MINUTE)" - val expected = "" - streamUtil.verifySql(sql, expected) - } - - @Test(expected = classOf[TableException]) - def testInvalidWindowExpression() = { - val sql = "SELECT COUNT(*) FROM MyTable GROUP BY FLOOR(localTimestamp TO HOUR)" - val expected = "" - streamUtil.verifySql(sql, expected) - } - @Test(expected = classOf[ValidationException]) def testWindowUdAggInvalidArgs(): Unit = { streamUtil.tEnv.registerFunction("weightedAvg", new WeightedAvgWithMerge) http://git-wip-us.apache.org/repos/asf/flink/blob/8f78824b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/table/AggregationsITCase.scala ---------------------------------------------------------------------- diff --git 
a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/table/AggregationsITCase.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/table/AggregationsITCase.scala deleted file mode 100644 index 4a6a616..0000000 --- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/table/AggregationsITCase.scala +++ /dev/null @@ -1,197 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.flink.table.api.scala.stream.table - -import org.apache.flink.api.scala._ -import org.apache.flink.streaming.api.TimeCharacteristic -import org.apache.flink.streaming.api.functions.AssignerWithPunctuatedWatermarks -import org.apache.flink.streaming.api.scala.StreamExecutionEnvironment -import org.apache.flink.streaming.api.watermark.Watermark -import org.apache.flink.streaming.util.StreamingMultipleProgramsTestBase -import org.apache.flink.table.api.TableEnvironment -import org.apache.flink.table.api.java.utils.UserDefinedAggFunctions.{WeightedAvg, WeightedAvgWithMerge} -import org.apache.flink.table.api.scala._ -import org.apache.flink.table.api.scala.stream.table.AggregationsITCase.TimestampAndWatermarkWithOffset -import org.apache.flink.table.api.scala.stream.utils.StreamITCase -import org.apache.flink.table.functions.aggfunctions.CountAggFunction -import org.apache.flink.types.Row -import org.junit.Assert._ -import org.junit.Test - -import scala.collection.mutable - -/** - * We only test some aggregations until better testing of constructed DataStream - * programs is possible. 
- */ -class AggregationsITCase extends StreamingMultipleProgramsTestBase { - - val data = List( - (1L, 1, "Hi"), - (2L, 2, "Hello"), - (4L, 2, "Hello"), - (8L, 3, "Hello world"), - (16L, 3, "Hello world")) - - @Test - def testProcessingTimeSlidingGroupWindowOverCount(): Unit = { - val env = StreamExecutionEnvironment.getExecutionEnvironment - env.setParallelism(1) - val tEnv = TableEnvironment.getTableEnvironment(env) - StreamITCase.testResults = mutable.MutableList() - - val stream = env.fromCollection(data) - val table = stream.toTable(tEnv, 'long, 'int, 'string, 'proctime.proctime) - - val countFun = new CountAggFunction - val weightAvgFun = new WeightedAvg - - val windowedTable = table - .window(Slide over 2.rows every 1.rows on 'proctime as 'w) - .groupBy('w, 'string) - .select('string, countFun('int), 'int.avg, - weightAvgFun('long, 'int), weightAvgFun('int, 'int)) - - val results = windowedTable.toDataStream[Row] - results.addSink(new StreamITCase.StringSink) - env.execute() - - val expected = Seq("Hello world,1,3,8,3", "Hello world,2,3,12,3", "Hello,1,2,2,2", - "Hello,2,2,3,2", "Hi,1,1,1,1") - assertEquals(expected.sorted, StreamITCase.testResults.sorted) - } - - @Test - def testEventTimeSessionGroupWindowOverTime(): Unit = { - //To verify the "merge" functionality, we create this test with the following characteristics: - // 1. set the Parallelism to 1, and have the test data out of order - // 2. 
create a waterMark with 10ms offset to delay the window emission by 10ms - val sessionWindowTestdata = List( - (1L, 1, "Hello"), - (2L, 2, "Hello"), - (8L, 8, "Hello"), - (9L, 9, "Hello World"), - (4L, 4, "Hello"), - (16L, 16, "Hello")) - - val env = StreamExecutionEnvironment.getExecutionEnvironment - env.setStreamTimeCharacteristic(TimeCharacteristic.EventTime) - env.setParallelism(1) - val tEnv = TableEnvironment.getTableEnvironment(env) - StreamITCase.testResults = mutable.MutableList() - - val countFun = new CountAggFunction - val weightAvgFun = new WeightedAvgWithMerge - - val stream = env - .fromCollection(sessionWindowTestdata) - .assignTimestampsAndWatermarks(new TimestampAndWatermarkWithOffset(10L)) - val table = stream.toTable(tEnv, 'long, 'int, 'string, 'rowtime.rowtime) - - val windowedTable = table - .window(Session withGap 5.milli on 'rowtime as 'w) - .groupBy('w, 'string) - .select('string, countFun('int), 'int.avg, - weightAvgFun('long, 'int), weightAvgFun('int, 'int)) - - val results = windowedTable.toDataStream[Row] - results.addSink(new StreamITCase.StringSink) - env.execute() - - val expected = Seq("Hello World,1,9,9,9", "Hello,1,16,16,16", "Hello,4,3,5,5") - assertEquals(expected.sorted, StreamITCase.testResults.sorted) - } - - @Test - def testAllProcessingTimeTumblingGroupWindowOverCount(): Unit = { - val env = StreamExecutionEnvironment.getExecutionEnvironment - env.setParallelism(1) - val tEnv = TableEnvironment.getTableEnvironment(env) - StreamITCase.testResults = mutable.MutableList() - - val stream = env.fromCollection(data) - val table = stream.toTable(tEnv, 'long, 'int, 'string, 'proctime.proctime) - val countFun = new CountAggFunction - val weightAvgFun = new WeightedAvg - - val windowedTable = table - .window(Tumble over 2.rows on 'proctime as 'w) - .groupBy('w) - .select(countFun('string), 'int.avg, - weightAvgFun('long, 'int), weightAvgFun('int, 'int)) - - val results = windowedTable.toDataStream[Row] - results.addSink(new 
StreamITCase.StringSink) - env.execute() - - val expected = Seq("2,1,1,1", "2,2,6,2") - assertEquals(expected.sorted, StreamITCase.testResults.sorted) - } - - @Test - def testEventTimeTumblingWindow(): Unit = { - val env = StreamExecutionEnvironment.getExecutionEnvironment - env.setStreamTimeCharacteristic(TimeCharacteristic.EventTime) - val tEnv = TableEnvironment.getTableEnvironment(env) - StreamITCase.testResults = mutable.MutableList() - - val stream = env - .fromCollection(data) - .assignTimestampsAndWatermarks(new TimestampAndWatermarkWithOffset(0L)) - val table = stream.toTable(tEnv, 'long, 'int, 'string, 'rowtime.rowtime) - val countFun = new CountAggFunction - val weightAvgFun = new WeightedAvg - - val windowedTable = table - .window(Tumble over 5.milli on 'rowtime as 'w) - .groupBy('w, 'string) - .select('string, countFun('string), 'int.avg, weightAvgFun('long, 'int), - weightAvgFun('int, 'int), 'int.min, 'int.max, 'int.sum, 'w.start, 'w.end) - - val results = windowedTable.toDataStream[Row] - results.addSink(new StreamITCase.StringSink) - env.execute() - - val expected = Seq( - "Hello world,1,3,8,3,3,3,3,1970-01-01 00:00:00.005,1970-01-01 00:00:00.01", - "Hello world,1,3,16,3,3,3,3,1970-01-01 00:00:00.015,1970-01-01 00:00:00.02", - "Hello,2,2,3,2,2,2,4,1970-01-01 00:00:00.0,1970-01-01 00:00:00.005", - "Hi,1,1,1,1,1,1,1,1970-01-01 00:00:00.0,1970-01-01 00:00:00.005") - assertEquals(expected.sorted, StreamITCase.testResults.sorted) - } -} - -object AggregationsITCase { - class TimestampAndWatermarkWithOffset( - offset: Long) extends AssignerWithPunctuatedWatermarks[(Long, Int, String)] { - - override def checkAndGetNextWatermark( - lastElement: (Long, Int, String), - extractedTimestamp: Long) - : Watermark = { - new Watermark(extractedTimestamp - offset) - } - - override def extractTimestamp( - element: (Long, Int, String), - previousElementTimestamp: Long): Long = { - element._1 - } - } -} 
http://git-wip-us.apache.org/repos/asf/flink/blob/8f78824b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/table/GroupAggregationsITCase.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/table/GroupAggregationsITCase.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/table/GroupAggregationsITCase.scala new file mode 100644 index 0000000..271e90b --- /dev/null +++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/table/GroupAggregationsITCase.scala @@ -0,0 +1,132 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.table.api.scala.stream.table + +import org.apache.flink.api.scala._ +import org.apache.flink.streaming.api.scala.StreamExecutionEnvironment +import org.apache.flink.table.api.scala._ +import org.apache.flink.table.api.scala.stream.utils.{StreamITCase, StreamTestData, StreamingWithStateTestBase} +import org.apache.flink.table.api.TableEnvironment +import org.apache.flink.types.Row +import org.junit.Assert.assertEquals +import org.junit.Test + +import scala.collection.mutable + +/** + * Tests of groupby (without window) aggregations + */ +class GroupAggregationsITCase extends StreamingWithStateTestBase { + + @Test + def testNonKeyedGroupAggregate(): Unit = { + val env = StreamExecutionEnvironment.getExecutionEnvironment + env.setStateBackend(getStateBackend) + env.setParallelism(1) + val tEnv = TableEnvironment.getTableEnvironment(env) + StreamITCase.clear + + val t = StreamTestData.get3TupleDataStream(env).toTable(tEnv, 'a, 'b, 'c) + .select('a.sum, 'b.sum) + + val results = t.toDataStream[Row] + results.addSink(new StreamITCase.StringSink) + env.execute() + + val expected = mutable.MutableList( + "1,1", "3,3", "6,5", "10,8", "15,11", "21,14", "28,18", "36,22", "45,26", "55,30", "66,35", + "78,40", "91,45", "105,50", "120,55", "136,61", "153,67", "171,73", "190,79", "210,85", + "231,91") + assertEquals(expected.sorted, StreamITCase.testResults.sorted) + } + + @Test + def testGroupAggregate(): Unit = { + val env = StreamExecutionEnvironment.getExecutionEnvironment + env.setStateBackend(getStateBackend) + env.setParallelism(1) + val tEnv = TableEnvironment.getTableEnvironment(env) + StreamITCase.clear + + val t = StreamTestData.get3TupleDataStream(env).toTable(tEnv, 'a, 'b, 'c) + .groupBy('b) + .select('b, 'a.sum) + + val results = t.toDataStream[Row] + results.addSink(new StreamITCase.StringSink) + env.execute() + + val expected = mutable.MutableList( + "1,1", "2,2", "2,5", "3,4", "3,9", "3,15", "4,7", "4,15", + "4,24", "4,34", 
"5,11", "5,23", "5,36", "5,50", "5,65", "6,16", "6,33", "6,51", "6,70", + "6,90", "6,111") + assertEquals(expected.sorted, StreamITCase.testResults.sorted) + } + + @Test + def testDoubleGroupAggregation(): Unit = { + val env = StreamExecutionEnvironment.getExecutionEnvironment + env.setStateBackend(getStateBackend) + val tEnv = TableEnvironment.getTableEnvironment(env) + StreamITCase.clear + + val t = StreamTestData.get3TupleDataStream(env).toTable(tEnv, 'a, 'b, 'c) + .groupBy('b) + .select('a.sum as 'd, 'b) + .groupBy('b, 'd) + .select('b) + + val results = t.toDataStream[Row] + results.addSink(new StreamITCase.StringSink) + env.execute() + + val expected = mutable.MutableList( + "1", + "2", "2", + "3", "3", "3", + "4", "4", "4", "4", + "5", "5", "5", "5", "5", + "6", "6", "6", "6", "6", "6") + + assertEquals(expected.sorted, StreamITCase.testResults.sorted) + } + + @Test + def testGroupAggregateWithExpression(): Unit = { + val env = StreamExecutionEnvironment.getExecutionEnvironment + env.setStateBackend(getStateBackend) + env.setParallelism(1) + val tEnv = TableEnvironment.getTableEnvironment(env) + StreamITCase.clear + + val t = StreamTestData.get5TupleDataStream(env).toTable(tEnv, 'a, 'b, 'c, 'd, 'e) + .groupBy('e, 'b % 3) + .select('c.min, 'e, 'a.avg, 'd.count) + + val results = t.toDataStream[Row] + results.addSink(new StreamITCase.StringSink) + env.execute() + + val expected = mutable.MutableList( + "0,1,1,1", "1,2,2,1", "2,1,2,1", "3,2,3,1", "1,2,2,2", + "5,3,3,1", "3,2,3,2", "7,1,4,1", "2,1,3,2", "3,2,3,3", "7,1,4,2", "5,3,4,2", "12,3,5,1", + "1,2,3,3", "14,2,5,1") + assertEquals(expected.sorted, StreamITCase.testResults.sorted) + } +} http://git-wip-us.apache.org/repos/asf/flink/blob/8f78824b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/table/GroupAggregationsTest.scala ---------------------------------------------------------------------- diff --git 
a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/table/GroupAggregationsTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/table/GroupAggregationsTest.scala new file mode 100644 index 0000000..520592c --- /dev/null +++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/table/GroupAggregationsTest.scala @@ -0,0 +1,214 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.api.scala.stream.table + +import org.apache.flink.api.common.typeinfo.BasicTypeInfo +import org.apache.flink.table.api.ValidationException +import org.apache.flink.table.utils.TableTestBase +import org.junit.Test +import org.apache.flink.table.api.scala._ +import org.apache.flink.api.scala._ +import org.apache.flink.table.utils.TableTestUtil._ + +class GroupAggregationsTest extends TableTestBase { + + @Test(expected = classOf[ValidationException]) + def testGroupingOnNonExistentField(): Unit = { + val util = streamTestUtil() + val table = util.addTable[(Long, Int, String)]('a, 'b, 'c) + + val ds = table + // must fail. 
'_foo is not a valid field + .groupBy('_foo) + .select('a.avg) + } + + @Test(expected = classOf[ValidationException]) + def testGroupingInvalidSelection(): Unit = { + val util = streamTestUtil() + val table = util.addTable[(Long, Int, String)]('a, 'b, 'c) + + val ds = table + .groupBy('a, 'b) + // must fail. 'c is not a grouping key or aggregation + .select('c) + } + + @Test + def testGroupbyWithoutWindow() = { + val util = streamTestUtil() + val table = util.addTable[(Long, Int, String)]('a, 'b, 'c) + + val resultTable = table + .groupBy('b) + .select('a.count) + + val expected = + unaryNode( + "DataStreamCalc", + unaryNode( + "DataStreamGroupAggregate", + unaryNode( + "DataStreamCalc", + streamTableNode(0), + term("select", "a", "b") + ), + term("groupBy", "b"), + term("select", "b", "COUNT(a) AS TMP_0") + ), + term("select", "TMP_0") + ) + util.verifyTable(resultTable, expected) + } + + + @Test + def testGroupAggregateWithConstant1(): Unit = { + val util = streamTestUtil() + val table = util.addTable[(Long, Int, String)]('a, 'b, 'c) + + val resultTable = table + .select('a, 4 as 'four, 'b) + .groupBy('four, 'a) + .select('four, 'b.sum) + + val expected = + unaryNode( + "DataStreamCalc", + unaryNode( + "DataStreamGroupAggregate", + unaryNode( + "DataStreamCalc", + streamTableNode(0), + term("select", "4 AS four", "b", "a") + ), + term("groupBy", "four", "a"), + term("select", "four", "a", "SUM(b) AS TMP_0") + ), + term("select", "4 AS four", "TMP_0") + ) + util.verifyTable(resultTable, expected) + } + + @Test + def testGroupAggregateWithConstant2(): Unit = { + val util = streamTestUtil() + val table = util.addTable[(Long, Int, String)]('a, 'b, 'c) + + val resultTable = table + .select('b, 4 as 'four, 'a) + .groupBy('b, 'four) + .select('four, 'a.sum) + + val expected = + unaryNode( + "DataStreamCalc", + unaryNode( + "DataStreamGroupAggregate", + unaryNode( + "DataStreamCalc", + streamTableNode(0), + term("select", "4 AS four", "a", "b") + ), + term("groupBy", 
"four", "b"), + term("select", "four", "b", "SUM(a) AS TMP_0") + ), + term("select", "4 AS four", "TMP_0") + ) + util.verifyTable(resultTable, expected) + } + + @Test + def testGroupAggregateWithExpressionInSelect(): Unit = { + val util = streamTestUtil() + val table = util.addTable[(Long, Int, String)]('a, 'b, 'c) + + val resultTable = table + .select('a as 'a, 'b % 3 as 'd, 'c as 'c) + .groupBy('d) + .select('c.min, 'a.avg) + + val expected = + unaryNode( + "DataStreamCalc", + unaryNode( + "DataStreamGroupAggregate", + unaryNode( + "DataStreamCalc", + streamTableNode(0), + term("select", "a", "MOD(b, 3) AS d", "c") + ), + term("groupBy", "d"), + term("select", "d", "MIN(c) AS TMP_0", "AVG(a) AS TMP_1") + ), + term("select", "TMP_0", "TMP_1") + ) + util.verifyTable(resultTable, expected) + } + + @Test + def testGroupAggregateWithFilter(): Unit = { + val util = streamTestUtil() + val table = util.addTable[(Long, Int, String)]('a, 'b, 'c) + + val resultTable = table + .groupBy('b) + .select('b, 'a.sum) + .where('b === 2) + + val expected = + unaryNode( + "DataStreamGroupAggregate", + unaryNode( + "DataStreamCalc", + streamTableNode(0), + term("select", "b", "a"), + term("where", "=(b, 2)") + ), + term("groupBy", "b"), + term("select", "b", "SUM(a) AS TMP_0") + ) + util.verifyTable(resultTable, expected) + } + + @Test + def testGroupAggregateWithAverage(): Unit = { + val util = streamTestUtil() + val table = util.addTable[(Long, Int, String)]('a, 'b, 'c) + + val resultTable = table + .groupBy('b) + .select('b, 'a.cast(BasicTypeInfo.DOUBLE_TYPE_INFO).avg) + + val expected = + unaryNode( + "DataStreamGroupAggregate", + unaryNode( + "DataStreamCalc", + streamTableNode(0), + term("select", "b", "a", "CAST(a) AS a0") + ), + term("groupBy", "b"), + term("select", "b", "AVG(a0) AS TMP_0") + ) + + util.verifyTable(resultTable, expected) + } + +}