/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.flink.table.runtime.aggregate

import org.apache.flink.api.common.typeinfo.TypeInformation
import org.apache.flink.types.Row

/**
 * The interface for all Flink aggregate functions, expressed in terms of initiate(),
 * prepare(), merge() and evaluate(). The aggregate functions are executed in two phases:
 *
 *  -- In the Map phase, prepare() transforms the aggregate field value into an
 *     intermediate aggregate value.
 *  -- In the GroupReduce phase, merge() merges grouped intermediate aggregate values
 *     into an aggregate buffer. Then evaluate() calculates the final aggregated value.
 *
 * Associative decomposable aggregate functions support partial aggregation. To optimize
 * performance, a Combine phase may be added between the Map and GroupReduce phases:
 *
 *  -- In the Combine phase, merge() merges sub-grouped intermediate aggregate values
 *     into an aggregate buffer.
 *
 * The intermediate aggregate value is stored inside a Row; aggOffsetInRow is the start
 * field index in the Row, so different aggregate functions can share the same Row as
 * intermediate aggregate value / aggregate buffer, as their aggregate values are stored
 * in distinct fields of the Row without conflict. The intermediate aggregate value is
 * required to be a sequence of JVM primitives, and Flink uses intermediateDataType()
 * to get its data types on the SQL side.
 *
 * @tparam T Aggregated value type.
 */
trait Aggregate[T] extends Serializable {

  /**
   * Transform the aggregate field value into intermediate aggregate data.
   *
   * @param value The value to insert into the intermediate aggregate row.
   * @param intermediate The intermediate aggregate row into which the value is inserted.
   */
  def prepare(value: Any, intermediate: Row): Unit

  /**
   * Initiate the intermediate aggregate value in the Row.
   *
   * @param intermediate The intermediate aggregate row to initiate.
   */
  def initiate(intermediate: Row): Unit

  /**
   * Merge intermediate aggregate data into the aggregate buffer.
   *
   * @param intermediate The intermediate aggregate row to merge.
   * @param buffer The aggregate buffer into which the intermediate is merged.
   */
  def merge(intermediate: Row, buffer: Row): Unit

  /**
   * Calculate the final aggregated result based on the aggregate buffer.
   *
   * @param buffer The aggregate buffer from which the final aggregate is computed.
   * @return The final result of the aggregate.
   */
  def evaluate(buffer: Row): T

  /**
   * Intermediate aggregate value types.
   *
   * @return The types of the intermediate fields of this aggregate.
   */
  def intermediateDataType: Array[TypeInformation[_]]

  /**
   * Set the aggregate data offset in the Row.
   *
   * Declared with an explicit `Unit` result type; the original used the deprecated
   * procedure syntax (no result type), which is discouraged and removed in Scala 3.
   *
   * @param aggOffset The offset of this aggregate in the intermediate aggregate rows.
   */
  def setAggOffsetInRow(aggOffset: Int): Unit

  /**
   * Whether the aggregate function supports partial aggregation.
   *
   * @return True if the aggregate supports partial aggregation, false otherwise.
   */
  def supportPartial: Boolean = false
}
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.flink.table.runtime.aggregate

import java.lang.Iterable

import org.apache.flink.api.common.functions.RichGroupReduceFunction
import org.apache.flink.types.Row
import org.apache.flink.configuration.Configuration
import org.apache.flink.streaming.api.windowing.windows.TimeWindow
import org.apache.flink.util.Collector

/**
 * Non-keyed window function for time windows. Evaluates the wrapped group-reduce
 * function and routes its output through a [[TimeWindowPropertyCollector]] so that
 * window start / end positions (if requested) can be attached to the emitted rows.
 *
 * @param groupReduceFunction The wrapped aggregation function.
 * @param windowStartPos Optional output position of the window start property.
 * @param windowEndPos Optional output position of the window end property.
 */
class AggregateAllTimeWindowFunction(
    groupReduceFunction: RichGroupReduceFunction[Row, Row],
    windowStartPos: Option[Int],
    windowEndPos: Option[Int])
  extends AggregateAllWindowFunction[TimeWindow](groupReduceFunction) {

  private var windowPropertyCollector: TimeWindowPropertyCollector = _

  override def open(parameters: Configuration): Unit = {
    windowPropertyCollector = new TimeWindowPropertyCollector(windowStartPos, windowEndPos)
    super.open(parameters)
  }

  override def apply(window: TimeWindow, input: Iterable[Row], out: Collector[Row]): Unit = {
    // Point the property collector at the downstream collector and current window,
    // then delegate to the wrapped reduce function via the superclass.
    windowPropertyCollector.wrappedCollector = out
    windowPropertyCollector.timeWindow = window
    super.apply(window, input, windowPropertyCollector)
  }
}

/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.flink.table.runtime.aggregate

import java.lang.Iterable

import org.apache.flink.api.common.functions.RichGroupReduceFunction
import org.apache.flink.types.Row
import org.apache.flink.configuration.Configuration
import org.apache.flink.streaming.api.functions.windowing.RichAllWindowFunction
import org.apache.flink.streaming.api.windowing.windows.Window
import org.apache.flink.util.Collector

/**
 * Non-keyed window function that adapts a [[RichGroupReduceFunction]] to the
 * all-window function interface: each window's elements are handed to the wrapped
 * function's reduce() as one group.
 *
 * @param groupReduceFunction The wrapped aggregation function.
 * @tparam W The window type.
 */
class AggregateAllWindowFunction[W <: Window](
    groupReduceFunction: RichGroupReduceFunction[Row, Row])
  extends RichAllWindowFunction[Row, Row, W] {

  // Forward lifecycle to the wrapped rich function so it can initialize itself.
  override def open(parameters: Configuration): Unit =
    groupReduceFunction.open(parameters)

  override def apply(window: W, input: Iterable[Row], out: Collector[Row]): Unit =
    groupReduceFunction.reduce(input, out)
}
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.flink.table.runtime.aggregate

import org.apache.flink.api.common.functions.RichMapFunction
import org.apache.flink.api.common.typeinfo.TypeInformation
import org.apache.flink.api.java.typeutils.ResultTypeQueryable
import org.apache.flink.types.Row
import org.apache.flink.configuration.Configuration
import org.apache.flink.util.Preconditions

/**
 * Map function of the aggregation's prepare phase. For each input row it writes the
 * prepared intermediate values of all aggregates plus the grouping key fields into a
 * single reused intermediate Row.
 *
 * @param aggregates The aggregate functions.
 * @param aggFields Input field index consumed by each aggregate (parallel to aggregates).
 * @param groupingKeys Input field indexes of the grouping keys.
 * @param returnType Type information of the produced intermediate rows.
 */
class AggregateMapFunction[IN, OUT](
    private val aggregates: Array[Aggregate[_]],
    private val aggFields: Array[Int],
    private val groupingKeys: Array[Int],
    @transient private val returnType: TypeInformation[OUT])
  extends RichMapFunction[IN, OUT]
  with ResultTypeQueryable[OUT] {

  // Reused across map() calls to avoid per-record allocation.
  private var intermediateRow: Row = _

  override def open(config: Configuration): Unit = {
    Preconditions.checkNotNull(aggregates)
    Preconditions.checkNotNull(aggFields)
    Preconditions.checkArgument(aggregates.length == aggFields.length)
    // Arity = grouping keys followed by the intermediate fields of every aggregate.
    val partialRowLength =
      groupingKeys.length + aggregates.map(_.intermediateDataType.length).sum
    intermediateRow = new Row(partialRowLength)
  }

  override def map(value: IN): OUT = {
    val input = value.asInstanceOf[Row]

    // Let every aggregate prepare its intermediate value from its input field.
    var i = 0
    while (i < aggregates.length) {
      aggregates(i).prepare(input.getField(aggFields(i)), intermediateRow)
      i += 1
    }

    // Copy the grouping key fields to the front of the intermediate row.
    var k = 0
    while (k < groupingKeys.length) {
      intermediateRow.setField(k, input.getField(groupingKeys(k)))
      k += 1
    }

    intermediateRow.asInstanceOf[OUT]
  }

  override def getProducedType: TypeInformation[OUT] = returnType
}
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.flink.table.runtime.aggregate

import java.lang.Iterable

import org.apache.flink.api.common.functions.CombineFunction
import org.apache.flink.types.Row

/**
 * It wraps the aggregate logic inside of
 * [[org.apache.flink.api.java.operators.GroupReduceOperator]] and
 * [[org.apache.flink.api.java.operators.GroupCombineOperator]].
 *
 * @param aggregates The aggregate functions.
 * @param groupKeysMapping The index mapping of group keys between intermediate aggregate Row
 *                         and output Row.
 * @param aggregateMapping The index mapping between aggregate function list and aggregated
 *                         value index in output Row.
 * @param intermediateRowArity Arity of the intermediate aggregate rows.
 * @param finalRowArity Arity of the output rows.
 */
class AggregateReduceCombineFunction(
    private val aggregates: Array[Aggregate[_ <: Any]],
    private val groupKeysMapping: Array[(Int, Int)],
    private val aggregateMapping: Array[(Int, Int)],
    private val intermediateRowArity: Int,
    private val finalRowArity: Int)
  extends AggregateReduceGroupFunction(
    aggregates,
    groupKeysMapping,
    aggregateMapping,
    intermediateRowArity,
    finalRowArity)
  with CombineFunction[Row, Row] {

  /**
   * For sub-grouped intermediate aggregate Rows, merge all of them into the aggregate buffer.
   *
   * @param records Sub-grouped intermediate aggregate Rows iterator (never empty — the
   *                runtime only calls combine for non-empty groups).
   * @return Combined intermediate aggregate Row.
   */
  override def combine(records: Iterable[Row]): Row = {

    // Initiate the intermediate aggregate value.
    aggregates.foreach(_.initiate(aggregateBuffer))

    // Merge intermediate aggregate values into the buffer. Iterate the Java Iterable
    // directly instead of using the deprecated implicit scala.collection.JavaConversions.
    val iterator = records.iterator()
    var last: Row = null
    while (iterator.hasNext) {
      val record = iterator.next()
      aggregates.foreach(_.merge(record, aggregateBuffer))
      last = record
    }

    // Copy the group key fields from the last record into the buffer; every record
    // of the group carries identical key fields.
    var i = 0
    while (i < groupKeysMapping.length) {
      aggregateBuffer.setField(i, last.getField(i))
      i += 1
    }

    aggregateBuffer
  }
}
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.flink.table.runtime.aggregate

import java.lang.Iterable

import org.apache.flink.api.common.functions.RichGroupReduceFunction
import org.apache.flink.types.Row
import org.apache.flink.configuration.Configuration
import org.apache.flink.util.{Collector, Preconditions}

/**
 * It wraps the aggregate logic inside of
 * [[org.apache.flink.api.java.operators.GroupReduceOperator]].
 *
 * @param aggregates The aggregate functions.
 * @param groupKeysMapping The index mapping of group keys between intermediate aggregate Row
 *                         and output Row.
 * @param aggregateMapping The index mapping between aggregate function list and aggregated
 *                         value index in output Row.
 * @param intermediateRowArity Arity of the intermediate aggregate rows.
 * @param finalRowArity Arity of the output rows.
 */
class AggregateReduceGroupFunction(
    private val aggregates: Array[Aggregate[_ <: Any]],
    private val groupKeysMapping: Array[(Int, Int)],
    private val aggregateMapping: Array[(Int, Int)],
    private val intermediateRowArity: Int,
    private val finalRowArity: Int)
  extends RichGroupReduceFunction[Row, Row] {

  // Protected so that the combining subclass can reuse the same buffer instance.
  protected var aggregateBuffer: Row = _
  private var output: Row = _

  override def open(config: Configuration): Unit = {
    Preconditions.checkNotNull(aggregates)
    Preconditions.checkNotNull(groupKeysMapping)
    aggregateBuffer = new Row(intermediateRowArity)
    output = new Row(finalRowArity)
  }

  /**
   * For grouped intermediate aggregate Rows, merge all of them into the aggregate buffer,
   * calculate the aggregated values from the aggregate buffer, and set them into the output
   * Row based on the mapping relation between intermediate aggregate data and output data.
   *
   * @param records Grouped intermediate aggregate Rows iterator (never empty — the runtime
   *                only calls reduce for non-empty groups).
   * @param out The collector to hand results to.
   */
  override def reduce(records: Iterable[Row], out: Collector[Row]): Unit = {

    // Initiate the intermediate aggregate value.
    aggregates.foreach(_.initiate(aggregateBuffer))

    // Merge intermediate aggregate values into the buffer. Iterate the Java Iterable
    // directly instead of using the deprecated implicit scala.collection.JavaConversions.
    val iterator = records.iterator()
    var last: Row = null
    while (iterator.hasNext) {
      val record = iterator.next()
      aggregates.foreach(_.merge(record, aggregateBuffer))
      last = record
    }

    // Set group key values to the final output.
    groupKeysMapping.foreach {
      case (after, previous) =>
        output.setField(after, last.getField(previous))
    }

    // Evaluate the final aggregate values and set them into the output.
    aggregateMapping.foreach {
      case (after, previous) =>
        output.setField(after, aggregates(previous).evaluate(aggregateBuffer))
    }

    out.collect(output)
  }
}
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.flink.table.runtime.aggregate

import java.lang.Iterable

import org.apache.flink.api.common.functions.RichGroupReduceFunction
import org.apache.flink.api.java.tuple.Tuple
import org.apache.flink.types.Row
import org.apache.flink.configuration.Configuration
import org.apache.flink.streaming.api.windowing.windows.TimeWindow
import org.apache.flink.util.Collector

/**
 * Keyed window function for time windows. Evaluates the wrapped group-reduce function
 * and routes its output through a [[TimeWindowPropertyCollector]] so that window start /
 * end positions (if requested) can be attached to the emitted rows.
 *
 * @param groupReduceFunction The wrapped aggregation function.
 * @param windowStartPos Optional output position of the window start property.
 * @param windowEndPos Optional output position of the window end property.
 */
class AggregateTimeWindowFunction(
    groupReduceFunction: RichGroupReduceFunction[Row, Row],
    windowStartPos: Option[Int],
    windowEndPos: Option[Int])
  extends AggregateWindowFunction[TimeWindow](groupReduceFunction) {

  private var windowPropertyCollector: TimeWindowPropertyCollector = _

  override def open(parameters: Configuration): Unit = {
    windowPropertyCollector = new TimeWindowPropertyCollector(windowStartPos, windowEndPos)
    super.open(parameters)
  }

  override def apply(
      key: Tuple,
      window: TimeWindow,
      input: Iterable[Row],
      out: Collector[Row]): Unit = {
    // Point the property collector at the downstream collector and current window,
    // then delegate to the wrapped reduce function via the superclass.
    windowPropertyCollector.wrappedCollector = out
    windowPropertyCollector.timeWindow = window
    super.apply(key, window, input, windowPropertyCollector)
  }
}
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.flink.table.runtime.aggregate + +import java.util + +import org.apache.calcite.rel.`type`._ +import org.apache.calcite.rel.core.AggregateCall +import org.apache.calcite.sql.{SqlAggFunction, SqlKind} +import org.apache.calcite.sql.`type`.SqlTypeName._ +import org.apache.calcite.sql.`type`.{SqlTypeFactoryImpl, SqlTypeName} +import org.apache.calcite.sql.fun._ +import org.apache.flink.api.common.functions.{MapFunction, RichGroupReduceFunction} +import org.apache.flink.api.common.typeinfo.TypeInformation +import org.apache.flink.api.java.tuple.Tuple +import org.apache.flink.api.java.typeutils.RowTypeInfo +import org.apache.flink.table.calcite.{FlinkRelBuilder, FlinkTypeFactory} +import FlinkRelBuilder.NamedWindowProperty +import org.apache.flink.table.expressions.{WindowEnd, WindowStart} +import org.apache.flink.table.plan.logical._ +import org.apache.flink.table.typeutils.TypeCheckUtils._ +import org.apache.flink.streaming.api.functions.windowing.{AllWindowFunction, WindowFunction} +import org.apache.flink.streaming.api.windowing.windows.{Window => DataStreamWindow} +import org.apache.flink.table.api.TableException +import org.apache.flink.types.Row + +import scala.collection.JavaConversions._ +import scala.collection.mutable.ArrayBuffer + +object 
AggregateUtil { + + type CalcitePair[T, R] = org.apache.calcite.util.Pair[T, R] + type JavaList[T] = java.util.List[T] + + /** + * Create a [[org.apache.flink.api.common.functions.MapFunction]] that prepares for aggregates. + * The function returns intermediate aggregate values of all aggregate function which are + * organized by the following format: + * + * {{{ + * avg(x) aggOffsetInRow = 2 count(z) aggOffsetInRow = 5 + * | | + * v v + * +---------+---------+--------+--------+--------+--------+ + * |groupKey1|groupKey2| sum1 | count1 | sum2 | count2 | + * +---------+---------+--------+--------+--------+--------+ + * ^ + * | + * sum(y) aggOffsetInRow = 4 + * }}} + * + */ + private[flink] def createPrepareMapFunction( + namedAggregates: Seq[CalcitePair[AggregateCall, String]], + groupings: Array[Int], + inputType: RelDataType): MapFunction[Any, Row] = { + + val (aggFieldIndexes,aggregates) = transformToAggregateFunctions( + namedAggregates.map(_.getKey), + inputType, + groupings.length) + + val mapReturnType: RowTypeInfo = + createAggregateBufferDataType(groupings, aggregates, inputType) + + val mapFunction = new AggregateMapFunction[Row, Row]( + aggregates, + aggFieldIndexes, + groupings, + mapReturnType.asInstanceOf[RowTypeInfo]).asInstanceOf[MapFunction[Any, Row]] + + mapFunction + } + + /** + * Create a [[org.apache.flink.api.common.functions.GroupReduceFunction]] to compute aggregates. + * If all aggregates support partial aggregation, the + * [[org.apache.flink.api.common.functions.GroupReduceFunction]] implements + * [[org.apache.flink.api.common.functions.CombineFunction]] as well. 
+ * + */ + private[flink] def createAggregateGroupReduceFunction( + namedAggregates: Seq[CalcitePair[AggregateCall, String]], + inputType: RelDataType, + outputType: RelDataType, + groupings: Array[Int]): RichGroupReduceFunction[Row, Row] = { + + val aggregates = transformToAggregateFunctions( + namedAggregates.map(_.getKey), + inputType, + groupings.length)._2 + + val (groupingOffsetMapping, aggOffsetMapping) = + getGroupingOffsetAndAggOffsetMapping( + namedAggregates, + inputType, + outputType, + groupings) + + val allPartialAggregate: Boolean = aggregates.forall(_.supportPartial) + + val intermediateRowArity = groupings.length + + aggregates.map(_.intermediateDataType.length).sum + + val groupReduceFunction = + if (allPartialAggregate) { + new AggregateReduceCombineFunction( + aggregates, + groupingOffsetMapping, + aggOffsetMapping, + intermediateRowArity, + outputType.getFieldCount) + } + else { + new AggregateReduceGroupFunction( + aggregates, + groupingOffsetMapping, + aggOffsetMapping, + intermediateRowArity, + outputType.getFieldCount) + } + groupReduceFunction + } + + /** + * Create a [[org.apache.flink.api.common.functions.ReduceFunction]] for incremental window + * aggregation. 
+ * + */ + private[flink] def createIncrementalAggregateReduceFunction( + namedAggregates: Seq[CalcitePair[AggregateCall, String]], + inputType: RelDataType, + outputType: RelDataType, + groupings: Array[Int]): IncrementalAggregateReduceFunction = { + + val aggregates = transformToAggregateFunctions( + namedAggregates.map(_.getKey),inputType,groupings.length)._2 + + val groupingOffsetMapping = + getGroupingOffsetAndAggOffsetMapping( + namedAggregates, + inputType, + outputType, + groupings)._1 + + val intermediateRowArity = groupings.length + aggregates.map(_.intermediateDataType.length).sum + val reduceFunction = new IncrementalAggregateReduceFunction( + aggregates, + groupingOffsetMapping, + intermediateRowArity) + reduceFunction + } + + /** + * Create an [[AllWindowFunction]] to compute non-partitioned group window aggregates. + */ + private[flink] def createAllWindowAggregationFunction( + window: LogicalWindow, + namedAggregates: Seq[CalcitePair[AggregateCall, String]], + inputType: RelDataType, + outputType: RelDataType, + groupings: Array[Int], + properties: Seq[NamedWindowProperty]) + : AllWindowFunction[Row, Row, DataStreamWindow] = { + + val aggFunction = + createAggregateGroupReduceFunction( + namedAggregates, + inputType, + outputType, + groupings) + + if (isTimeWindow(window)) { + val (startPos, endPos) = computeWindowStartEndPropertyPos(properties) + new AggregateAllTimeWindowFunction(aggFunction, startPos, endPos) + .asInstanceOf[AllWindowFunction[Row, Row, DataStreamWindow]] + } else { + new AggregateAllWindowFunction(aggFunction) + } + } + + /** + * Create a [[WindowFunction]] to compute partitioned group window aggregates. 
+ * + */ + private[flink] def createWindowAggregationFunction( + window: LogicalWindow, + namedAggregates: Seq[CalcitePair[AggregateCall, String]], + inputType: RelDataType, + outputType: RelDataType, + groupings: Array[Int], + properties: Seq[NamedWindowProperty]) + : WindowFunction[Row, Row, Tuple, DataStreamWindow] = { + + val aggFunction = + createAggregateGroupReduceFunction( + namedAggregates, + inputType, + outputType, + groupings) + + if (isTimeWindow(window)) { + val (startPos, endPos) = computeWindowStartEndPropertyPos(properties) + new AggregateTimeWindowFunction(aggFunction, startPos, endPos) + .asInstanceOf[WindowFunction[Row, Row, Tuple, DataStreamWindow]] + } else { + new AggregateWindowFunction(aggFunction) + } + } + + /** + * Create an [[AllWindowFunction]] to finalize incrementally pre-computed non-partitioned + * window aggreagtes. + */ + private[flink] def createAllWindowIncrementalAggregationFunction( + window: LogicalWindow, + namedAggregates: Seq[CalcitePair[AggregateCall, String]], + inputType: RelDataType, + outputType: RelDataType, + groupings: Array[Int], + properties: Seq[NamedWindowProperty]): AllWindowFunction[Row, Row, DataStreamWindow] = { + + val aggregates = transformToAggregateFunctions( + namedAggregates.map(_.getKey),inputType,groupings.length)._2 + + val (groupingOffsetMapping, aggOffsetMapping) = + getGroupingOffsetAndAggOffsetMapping( + namedAggregates, + inputType, + outputType, + groupings) + + val finalRowArity = outputType.getFieldCount + + if (isTimeWindow(window)) { + val (startPos, endPos) = computeWindowStartEndPropertyPos(properties) + new IncrementalAggregateAllTimeWindowFunction( + aggregates, + groupingOffsetMapping, + aggOffsetMapping, + finalRowArity, + startPos, + endPos) + .asInstanceOf[AllWindowFunction[Row, Row, DataStreamWindow]] + } else { + new IncrementalAggregateAllWindowFunction( + aggregates, + groupingOffsetMapping, + aggOffsetMapping, + finalRowArity) + } + } + + /** + * Create a [[WindowFunction]] 
to finalize incrementally pre-computed window aggregates. + */ + private[flink] def createWindowIncrementalAggregationFunction( + window: LogicalWindow, + namedAggregates: Seq[CalcitePair[AggregateCall, String]], + inputType: RelDataType, + outputType: RelDataType, + groupings: Array[Int], + properties: Seq[NamedWindowProperty]): WindowFunction[Row, Row, Tuple, DataStreamWindow] = { + + val aggregates = transformToAggregateFunctions( + namedAggregates.map(_.getKey),inputType,groupings.length)._2 + + val (groupingOffsetMapping, aggOffsetMapping) = + getGroupingOffsetAndAggOffsetMapping( + namedAggregates, + inputType, + outputType, + groupings) + + val finalRowArity = outputType.getFieldCount + + if (isTimeWindow(window)) { + val (startPos, endPos) = computeWindowStartEndPropertyPos(properties) + new IncrementalAggregateTimeWindowFunction( + aggregates, + groupingOffsetMapping, + aggOffsetMapping, + finalRowArity, + startPos, + endPos) + .asInstanceOf[WindowFunction[Row, Row, Tuple, DataStreamWindow]] + } else { + new IncrementalAggregateWindowFunction( + aggregates, + groupingOffsetMapping, + aggOffsetMapping, + finalRowArity) + } + } + + /** + * Return true if all aggregates can be partially computed. False otherwise. + */ + private[flink] def doAllSupportPartialAggregation( + aggregateCalls: Seq[AggregateCall], + inputType: RelDataType, + groupKeysCount: Int): Boolean = { + transformToAggregateFunctions( + aggregateCalls, + inputType, + groupKeysCount)._2.forall(_.supportPartial) + } + + /** + * @return groupingOffsetMapping (mapping relation between field index of intermediate + * aggregate Row and output Row.) + * and aggOffsetMapping (the mapping relation between aggregate function index in list + * and its corresponding field index in output Row.) 
+ */ + private def getGroupingOffsetAndAggOffsetMapping( + namedAggregates: Seq[CalcitePair[AggregateCall, String]], + inputType: RelDataType, + outputType: RelDataType, + groupings: Array[Int]): (Array[(Int, Int)], Array[(Int, Int)]) = { + + // the mapping relation between field index of intermediate aggregate Row and output Row. + val groupingOffsetMapping = getGroupKeysMapping(inputType, outputType, groupings) + + // the mapping relation between aggregate function index in list and its corresponding + // field index in output Row. + val aggOffsetMapping = getAggregateMapping(namedAggregates, outputType) + + if (groupingOffsetMapping.length != groupings.length || + aggOffsetMapping.length != namedAggregates.length) { + throw new TableException( + "Could not find output field in input data type " + + "or aggregate functions.") + } + (groupingOffsetMapping, aggOffsetMapping) + } + + private def isTimeWindow(window: LogicalWindow) = { + window match { + case ProcessingTimeTumblingGroupWindow(_, size) => isTimeInterval(size.resultType) + case ProcessingTimeSlidingGroupWindow(_, size, _) => isTimeInterval(size.resultType) + case ProcessingTimeSessionGroupWindow(_, _) => true + case EventTimeTumblingGroupWindow(_, _, size) => isTimeInterval(size.resultType) + case EventTimeSlidingGroupWindow(_, _, size, _) => isTimeInterval(size.resultType) + case EventTimeSessionGroupWindow(_, _, _) => true + } + } + + private def computeWindowStartEndPropertyPos( + properties: Seq[NamedWindowProperty]): (Option[Int], Option[Int]) = { + + val propPos = properties.foldRight((None: Option[Int], None: Option[Int], 0)) { + (p, x) => p match { + case NamedWindowProperty(name, prop) => + prop match { + case WindowStart(_) if x._1.isDefined => + throw new TableException("Duplicate WindowStart property encountered. 
This is a bug.") + case WindowStart(_) => + (Some(x._3), x._2, x._3 - 1) + case WindowEnd(_) if x._2.isDefined => + throw new TableException("Duplicate WindowEnd property encountered. This is a bug.") + case WindowEnd(_) => + (x._1, Some(x._3), x._3 - 1) + } + } + } + (propPos._1, propPos._2) + } + + private def transformToAggregateFunctions( + aggregateCalls: Seq[AggregateCall], + inputType: RelDataType, + groupKeysCount: Int): (Array[Int], Array[Aggregate[_ <: Any]]) = { + + // store the aggregate fields of each aggregate function, by the same order of aggregates. + val aggFieldIndexes = new Array[Int](aggregateCalls.size) + val aggregates = new Array[Aggregate[_ <: Any]](aggregateCalls.size) + + // set the start offset of aggregate buffer value to group keys' length, + // as all the group keys would be moved to the start fields of intermediate + // aggregate data. + var aggOffset = groupKeysCount + + // create aggregate function instances by function type and aggregate field data type. 
+ aggregateCalls.zipWithIndex.foreach { case (aggregateCall, index) => + val argList: util.List[Integer] = aggregateCall.getArgList + if (argList.isEmpty) { + if (aggregateCall.getAggregation.isInstanceOf[SqlCountAggFunction]) { + aggFieldIndexes(index) = 0 + } else { + throw new TableException("Aggregate fields should not be empty.") + } + } else { + if (argList.size() > 1) { + throw new TableException("Currently, do not support aggregate on multi fields.") + } + aggFieldIndexes(index) = argList.get(0) + } + val sqlTypeName = inputType.getFieldList.get(aggFieldIndexes(index)).getType.getSqlTypeName + aggregateCall.getAggregation match { + case _: SqlSumAggFunction | _: SqlSumEmptyIsZeroAggFunction => { + aggregates(index) = sqlTypeName match { + case TINYINT => + new ByteSumAggregate + case SMALLINT => + new ShortSumAggregate + case INTEGER => + new IntSumAggregate + case BIGINT => + new LongSumAggregate + case FLOAT => + new FloatSumAggregate + case DOUBLE => + new DoubleSumAggregate + case DECIMAL => + new DecimalSumAggregate + case sqlType: SqlTypeName => + throw new TableException("Sum aggregate does no support type:" + sqlType) + } + } + case _: SqlAvgAggFunction => { + aggregates(index) = sqlTypeName match { + case TINYINT => + new ByteAvgAggregate + case SMALLINT => + new ShortAvgAggregate + case INTEGER => + new IntAvgAggregate + case BIGINT => + new LongAvgAggregate + case FLOAT => + new FloatAvgAggregate + case DOUBLE => + new DoubleAvgAggregate + case DECIMAL => + new DecimalAvgAggregate + case sqlType: SqlTypeName => + throw new TableException("Avg aggregate does no support type:" + sqlType) + } + } + case sqlMinMaxFunction: SqlMinMaxAggFunction => { + aggregates(index) = if (sqlMinMaxFunction.getKind == SqlKind.MIN) { + sqlTypeName match { + case TINYINT => + new ByteMinAggregate + case SMALLINT => + new ShortMinAggregate + case INTEGER => + new IntMinAggregate + case BIGINT => + new LongMinAggregate + case FLOAT => + new FloatMinAggregate + case 
DOUBLE => + new DoubleMinAggregate + case DECIMAL => + new DecimalMinAggregate + case BOOLEAN => + new BooleanMinAggregate + case sqlType: SqlTypeName => + throw new TableException("Min aggregate does no support type:" + sqlType) + } + } else { + sqlTypeName match { + case TINYINT => + new ByteMaxAggregate + case SMALLINT => + new ShortMaxAggregate + case INTEGER => + new IntMaxAggregate + case BIGINT => + new LongMaxAggregate + case FLOAT => + new FloatMaxAggregate + case DOUBLE => + new DoubleMaxAggregate + case DECIMAL => + new DecimalMaxAggregate + case BOOLEAN => + new BooleanMaxAggregate + case sqlType: SqlTypeName => + throw new TableException("Max aggregate does no support type:" + sqlType) + } + } + } + case _: SqlCountAggFunction => + aggregates(index) = new CountAggregate + case unSupported: SqlAggFunction => + throw new TableException("unsupported Function: " + unSupported.getName) + } + setAggregateDataOffset(index) + } + + // set the aggregate intermediate data start index in Row, and update current value. + def setAggregateDataOffset(index: Int): Unit = { + aggregates(index).setAggOffsetInRow(aggOffset) + aggOffset += aggregates(index).intermediateDataType.length + } + + (aggFieldIndexes, aggregates) + } + + private def createAggregateBufferDataType( + groupings: Array[Int], + aggregates: Array[Aggregate[_]], + inputType: RelDataType): RowTypeInfo = { + + // get the field data types of group keys. 
+ val groupingTypes: Seq[TypeInformation[_]] = groupings + .map(inputType.getFieldList.get(_).getType) + .map(FlinkTypeFactory.toTypeInfo) + + val aggPartialNameSuffix = "agg_buffer_" + val factory = new SqlTypeFactoryImpl(RelDataTypeSystem.DEFAULT) + + // get all field data types of all intermediate aggregates + val aggTypes: Seq[TypeInformation[_]] = aggregates.flatMap(_.intermediateDataType) + + // concat group key types and aggregation types + val allFieldTypes = groupingTypes ++: aggTypes + val partialType = new RowTypeInfo(allFieldTypes: _*) + partialType + } + + // Find the mapping between the index of aggregate list and aggregated value index in output Row. + private def getAggregateMapping( + namedAggregates: Seq[CalcitePair[AggregateCall, String]], + outputType: RelDataType): Array[(Int, Int)] = { + + // the mapping relation between aggregate function index in list and its corresponding + // field index in output Row. + var aggOffsetMapping = ArrayBuffer[(Int, Int)]() + + outputType.getFieldList.zipWithIndex.foreach{ + case (outputFieldType, outputIndex) => + namedAggregates.zipWithIndex.foreach { + case (namedAggCall, aggregateIndex) => + if (namedAggCall.getValue.equals(outputFieldType.getName) && + namedAggCall.getKey.getType.equals(outputFieldType.getType)) { + aggOffsetMapping += ((outputIndex, aggregateIndex)) + } + } + } + + aggOffsetMapping.toArray + } + + // Find the mapping between the index of group key in intermediate aggregate Row and its index + // in output Row. + private def getGroupKeysMapping( + inputDatType: RelDataType, + outputType: RelDataType, + groupKeys: Array[Int]): Array[(Int, Int)] = { + + // the mapping relation between field index of intermediate aggregate Row and output Row. + var groupingOffsetMapping = ArrayBuffer[(Int, Int)]() + + outputType.getFieldList.zipWithIndex.foreach { + case (outputFieldType, outputIndex) => + inputDatType.getFieldList.zipWithIndex.foreach { + // find the field index in input data type. 
+ case (inputFieldType, inputIndex) => + if (outputFieldType.getName.equals(inputFieldType.getName) && + outputFieldType.getType.equals(inputFieldType.getType)) { + // as aggregated field in output data type would not have a matched field in + // input data, so if inputIndex is not -1, it must be a group key. Then we can + // find the field index in buffer data by the group keys index mapping between + // input data and buffer data. + for (i <- groupKeys.indices) { + if (inputIndex == groupKeys(i)) { + groupingOffsetMapping += ((outputIndex, i)) + } + } + } + } + } + + groupingOffsetMapping.toArray + } +} + http://git-wip-us.apache.org/repos/asf/flink/blob/ffe9ec8e/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateWindowFunction.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateWindowFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateWindowFunction.scala new file mode 100644 index 0000000..5491b1d --- /dev/null +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateWindowFunction.scala @@ -0,0 +1,46 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */

package org.apache.flink.table.runtime.aggregate

import java.lang.Iterable

import org.apache.flink.api.common.functions.RichGroupReduceFunction
import org.apache.flink.api.java.tuple.Tuple
import org.apache.flink.types.Row
import org.apache.flink.configuration.Configuration
import org.apache.flink.streaming.api.functions.windowing.RichWindowFunction
import org.apache.flink.streaming.api.windowing.windows.Window
import org.apache.flink.util.Collector

/**
 * A [[RichWindowFunction]] that evaluates a [[RichGroupReduceFunction]] on the rows of a
 * keyed window pane and forwards the wrapped function's lifecycle calls.
 *
 * @param groupReduceFunction The wrapped group-reduce function that computes the aggregates.
 */
class AggregateWindowFunction[W <: Window](groupReduceFunction: RichGroupReduceFunction[Row, Row])
  extends RichWindowFunction[Row, Row, Tuple, W] {

  override def open(parameters: Configuration): Unit = {
    // forward the lifecycle call to the wrapped rich function
    groupReduceFunction.open(parameters)
  }

  // Fix: the wrapped rich function's close() was never invoked, so resources it
  // acquired in open() could leak. Forward close() as well.
  override def close(): Unit = {
    groupReduceFunction.close()
  }

  override def apply(
    key: Tuple,
    window: W,
    input: Iterable[Row],
    out: Collector[Row]): Unit = {

    groupReduceFunction.reduce(input, out)
  }
}

http://git-wip-us.apache.org/repos/asf/flink/blob/ffe9ec8e/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AvgAggregate.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AvgAggregate.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AvgAggregate.scala
new file mode 100644
index 0000000..cb94ca1
--- /dev/null
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AvgAggregate.scala
@@ -0,0 +1,296 @@
+/*
+ * Licensed to the Apache Software Foundation
(ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
package org.apache.flink.table.runtime.aggregate

import com.google.common.math.LongMath
import org.apache.flink.api.common.typeinfo.BasicTypeInfo
import org.apache.flink.types.Row
import java.math.BigDecimal
import java.math.BigInteger
import java.math.MathContext

/**
 * Base class for all AVG aggregates. The intermediate state is a (sum, count) pair stored
 * in two consecutive fields of the shared intermediate Row.
 */
abstract class AvgAggregate[T] extends Aggregate[T] {
  // indexes of the sum and count fields inside the shared intermediate Row
  protected var partialSumIndex: Int = _
  protected var partialCountIndex: Int = _

  override def supportPartial: Boolean = true

  override def setAggOffsetInRow(aggOffset: Int): Unit = {
    partialSumIndex = aggOffset
    partialCountIndex = aggOffset + 1
  }
}

/**
 * AVG base class for integral input types; sums are accumulated in a Long with
 * overflow-checked addition.
 */
abstract class IntegralAvgAggregate[T] extends AvgAggregate[T] {

  override def initiate(partial: Row): Unit = {
    partial.setField(partialSumIndex, 0L)
    partial.setField(partialCountIndex, 0L)
  }

  override def prepare(value: Any, partial: Row): Unit = {
    if (value == null) {
      // null input contributes neither to the sum nor to the count
      partial.setField(partialSumIndex, 0L)
      partial.setField(partialCountIndex, 0L)
    } else {
      doPrepare(value, partial)
    }
  }

  override def merge(partial: Row, buffer: Row): Unit = {
    val partialSum = partial.getField(partialSumIndex).asInstanceOf[Long]
    val partialCount = partial.getField(partialCountIndex).asInstanceOf[Long]
    val bufferSum = buffer.getField(partialSumIndex).asInstanceOf[Long]
    val bufferCount = buffer.getField(partialCountIndex).asInstanceOf[Long]
    // checkedAdd throws on Long overflow instead of silently wrapping
    buffer.setField(partialSumIndex, LongMath.checkedAdd(partialSum, bufferSum))
    buffer.setField(partialCountIndex, LongMath.checkedAdd(partialCount, bufferCount))
  }

  override def evaluate(buffer : Row): T = {
    doEvaluate(buffer).asInstanceOf[T]
  }

  override def intermediateDataType = Array(
    BasicTypeInfo.LONG_TYPE_INFO,
    BasicTypeInfo.LONG_TYPE_INFO)

  /** Writes the non-null input value into the (sum, count) pair. */
  def doPrepare(value: Any, partial: Row): Unit

  /** Computes the final average; returns null for an empty (count == 0) buffer. */
  def doEvaluate(buffer: Row): Any
}

class ByteAvgAggregate extends IntegralAvgAggregate[Byte] {
  override def doPrepare(value: Any, partial: Row): Unit = {
    val input = value.asInstanceOf[Byte]
    partial.setField(partialSumIndex, input.toLong)
    partial.setField(partialCountIndex, 1L)
  }

  override def doEvaluate(buffer: Row): Any = {
    val bufferSum = buffer.getField(partialSumIndex).asInstanceOf[Long]
    val bufferCount = buffer.getField(partialCountIndex).asInstanceOf[Long]
    if (bufferCount == 0L) {
      null
    } else {
      (bufferSum / bufferCount).toByte
    }
  }
}

class ShortAvgAggregate extends IntegralAvgAggregate[Short] {

  override def doPrepare(value: Any, partial: Row): Unit = {
    val input = value.asInstanceOf[Short]
    partial.setField(partialSumIndex, input.toLong)
    partial.setField(partialCountIndex, 1L)
  }

  override def doEvaluate(buffer: Row): Any = {
    val bufferSum = buffer.getField(partialSumIndex).asInstanceOf[Long]
    val bufferCount = buffer.getField(partialCountIndex).asInstanceOf[Long]
    if (bufferCount == 0L) {
      null
    } else {
      (bufferSum / bufferCount).toShort
    }
  }
}

class IntAvgAggregate extends IntegralAvgAggregate[Int] {

  override def doPrepare(value: Any, partial: Row): Unit = {
    val input = value.asInstanceOf[Int]
    partial.setField(partialSumIndex, input.toLong)
    partial.setField(partialCountIndex, 1L)
  }

  override def doEvaluate(buffer: Row): Any = {
    val bufferSum = buffer.getField(partialSumIndex).asInstanceOf[Long]
    val bufferCount = buffer.getField(partialCountIndex).asInstanceOf[Long]
    if (bufferCount == 0L) {
      null
    } else {
      (bufferSum / bufferCount).toInt
    }
  }
}

/**
 * AVG for Long input; the sum is kept in a BigInteger because summing Longs into a Long
 * could overflow.
 */
class LongAvgAggregate extends IntegralAvgAggregate[Long] {

  override def intermediateDataType = Array(
    BasicTypeInfo.BIG_INT_TYPE_INFO,
    BasicTypeInfo.LONG_TYPE_INFO)

  override def initiate(partial: Row): Unit = {
    partial.setField(partialSumIndex, BigInteger.ZERO)
    partial.setField(partialCountIndex, 0L)
  }

  override def prepare(value: Any, partial: Row): Unit = {
    if (value == null) {
      partial.setField(partialSumIndex, BigInteger.ZERO)
      partial.setField(partialCountIndex, 0L)
    } else {
      doPrepare(value, partial)
    }
  }

  override def doPrepare(value: Any, partial: Row): Unit = {
    val input = value.asInstanceOf[Long]
    partial.setField(partialSumIndex, BigInteger.valueOf(input))
    partial.setField(partialCountIndex, 1L)
  }

  override def merge(partial: Row, buffer: Row): Unit = {
    val partialSum = partial.getField(partialSumIndex).asInstanceOf[BigInteger]
    val partialCount = partial.getField(partialCountIndex).asInstanceOf[Long]
    val bufferSum = buffer.getField(partialSumIndex).asInstanceOf[BigInteger]
    val bufferCount = buffer.getField(partialCountIndex).asInstanceOf[Long]
    buffer.setField(partialSumIndex, partialSum.add(bufferSum))
    buffer.setField(partialCountIndex, LongMath.checkedAdd(partialCount, bufferCount))
  }

  override def doEvaluate(buffer: Row): Any = {
    val bufferSum = buffer.getField(partialSumIndex).asInstanceOf[BigInteger]
    val bufferCount = buffer.getField(partialCountIndex).asInstanceOf[Long]
    if (bufferCount == 0L) {
      null
    } else {
      // integer division is exact for BigInteger; no rounding context needed
      bufferSum.divide(BigInteger.valueOf(bufferCount)).longValue()
    }
  }
}

/** AVG base class for floating-point input types; sums are accumulated in a Double. */
abstract class FloatingAvgAggregate[T: Numeric] extends AvgAggregate[T] {

  override def initiate(partial: Row): Unit = {
    partial.setField(partialSumIndex, 0D)
    partial.setField(partialCountIndex, 0L)
  }

  override def prepare(value: Any, partial: Row): Unit = {
    if (value == null) {
      partial.setField(partialSumIndex, 0D)
      partial.setField(partialCountIndex, 0L)
    } else {
      doPrepare(value, partial)
    }
  }

  override def merge(partial: Row, buffer: Row): Unit = {
    val partialSum = partial.getField(partialSumIndex).asInstanceOf[Double]
    val partialCount = partial.getField(partialCountIndex).asInstanceOf[Long]
    val bufferSum = buffer.getField(partialSumIndex).asInstanceOf[Double]
    val bufferCount = buffer.getField(partialCountIndex).asInstanceOf[Long]

    buffer.setField(partialSumIndex, partialSum + bufferSum)
    buffer.setField(partialCountIndex, partialCount + bufferCount)
  }

  override def evaluate(buffer : Row): T = {
    doEvaluate(buffer).asInstanceOf[T]
  }

  override def intermediateDataType = Array(
    BasicTypeInfo.DOUBLE_TYPE_INFO,
    BasicTypeInfo.LONG_TYPE_INFO)

  /** Writes the non-null input value into the (sum, count) pair. */
  def doPrepare(value: Any, partial: Row): Unit

  /** Computes the final average; returns null for an empty (count == 0) buffer. */
  def doEvaluate(buffer: Row): Any
}

class FloatAvgAggregate extends FloatingAvgAggregate[Float] {

  override def doPrepare(value: Any, partial: Row): Unit = {
    val input = value.asInstanceOf[Float]
    partial.setField(partialSumIndex, input.toDouble)
    partial.setField(partialCountIndex, 1L)
  }

  override def doEvaluate(buffer: Row): Any = {
    val bufferSum = buffer.getField(partialSumIndex).asInstanceOf[Double]
    val bufferCount = buffer.getField(partialCountIndex).asInstanceOf[Long]
    if (bufferCount == 0L) {
      null
    } else {
      (bufferSum / bufferCount).toFloat
    }
  }
}

class DoubleAvgAggregate extends FloatingAvgAggregate[Double] {

  override def doPrepare(value: Any, partial: Row): Unit = {
    val input = value.asInstanceOf[Double]
    partial.setField(partialSumIndex, input)
    partial.setField(partialCountIndex, 1L)
  }

  override def doEvaluate(buffer: Row): Any = {
    val bufferSum = buffer.getField(partialSumIndex).asInstanceOf[Double]
    val bufferCount = buffer.getField(partialCountIndex).asInstanceOf[Long]
    if (bufferCount == 0L) {
      null
    } else {
      (bufferSum / bufferCount)
    }
  }
}

class DecimalAvgAggregate extends AvgAggregate[BigDecimal] {

  override def intermediateDataType = Array(
    BasicTypeInfo.BIG_DEC_TYPE_INFO,
    BasicTypeInfo.LONG_TYPE_INFO)

  override def initiate(partial: Row): Unit = {
    partial.setField(partialSumIndex, BigDecimal.ZERO)
    partial.setField(partialCountIndex, 0L)
  }

  override def prepare(value: Any, partial: Row): Unit = {
    if (value == null) {
      initiate(partial)
    } else {
      val input = value.asInstanceOf[BigDecimal]
      partial.setField(partialSumIndex, input)
      partial.setField(partialCountIndex, 1L)
    }
  }

  override def merge(partial: Row, buffer: Row): Unit = {
    val partialSum = partial.getField(partialSumIndex).asInstanceOf[BigDecimal]
    val partialCount = partial.getField(partialCountIndex).asInstanceOf[Long]
    val bufferSum = buffer.getField(partialSumIndex).asInstanceOf[BigDecimal]
    val bufferCount = buffer.getField(partialCountIndex).asInstanceOf[Long]
    buffer.setField(partialSumIndex, partialSum.add(bufferSum))
    buffer.setField(partialCountIndex, LongMath.checkedAdd(partialCount, bufferCount))
  }

  override def evaluate(buffer: Row): BigDecimal = {
    val bufferCount = buffer.getField(partialCountIndex).asInstanceOf[Long]
    if (bufferCount != 0) {
      val bufferSum = buffer.getField(partialSumIndex).asInstanceOf[BigDecimal]
      // Fix: divide without a rounding context throws ArithmeticException for
      // non-terminating decimal expansions (e.g. sum = 1, count = 3). Use DECIMAL128
      // precision with HALF_EVEN rounding instead.
      bufferSum.divide(BigDecimal.valueOf(bufferCount), MathContext.DECIMAL128)
    } else {
      null.asInstanceOf[BigDecimal]
    }
  }

}

http://git-wip-us.apache.org/repos/asf/flink/blob/ffe9ec8e/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/CountAggregate.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/CountAggregate.scala
b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/CountAggregate.scala new file mode 100644 index 0000000..ea8e1d8 --- /dev/null +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/CountAggregate.scala @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.flink.table.runtime.aggregate + +import org.apache.flink.api.common.typeinfo.{BasicTypeInfo, TypeInformation} +import org.apache.flink.types.Row + +class CountAggregate extends Aggregate[Long] { + private var countIndex: Int = _ + + override def initiate(intermediate: Row): Unit = { + intermediate.setField(countIndex, 0L) + } + + override def merge(intermediate: Row, buffer: Row): Unit = { + val partialCount = intermediate.getField(countIndex).asInstanceOf[Long] + val bufferCount = buffer.getField(countIndex).asInstanceOf[Long] + buffer.setField(countIndex, partialCount + bufferCount) + } + + override def evaluate(buffer: Row): Long = { + buffer.getField(countIndex).asInstanceOf[Long] + } + + override def prepare(value: Any, intermediate: Row): Unit = { + if (value == null) { + intermediate.setField(countIndex, 0L) + } else { + intermediate.setField(countIndex, 1L) + } + } + + override def intermediateDataType = Array(BasicTypeInfo.LONG_TYPE_INFO) + + override def supportPartial: Boolean = true + + override def setAggOffsetInRow(aggIndex: Int): Unit = { + countIndex = aggIndex + } +} http://git-wip-us.apache.org/repos/asf/flink/blob/ffe9ec8e/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/IncrementalAggregateAllTimeWindowFunction.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/IncrementalAggregateAllTimeWindowFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/IncrementalAggregateAllTimeWindowFunction.scala new file mode 100644 index 0000000..5d7a94b --- /dev/null +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/IncrementalAggregateAllTimeWindowFunction.scala @@ -0,0 +1,68 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.flink.table.runtime.aggregate + +import java.lang.Iterable + +import org.apache.flink.types.Row +import org.apache.flink.configuration.Configuration +import org.apache.flink.streaming.api.windowing.windows.TimeWindow +import org.apache.flink.util.Collector +/** + * + * Computes the final aggregate value from incrementally computed aggreagtes. + * + * @param aggregates The aggregate functions. + * @param groupKeysMapping The index mapping of group keys between intermediate aggregate Row + * and output Row. + * @param aggregateMapping The index mapping between aggregate function list and aggregated value + * index in output Row. + * @param finalRowArity The arity of the final output row. 
+ */ +class IncrementalAggregateAllTimeWindowFunction( + private val aggregates: Array[Aggregate[_ <: Any]], + private val groupKeysMapping: Array[(Int, Int)], + private val aggregateMapping: Array[(Int, Int)], + private val finalRowArity: Int, + private val windowStartPos: Option[Int], + private val windowEndPos: Option[Int]) + extends IncrementalAggregateAllWindowFunction[TimeWindow]( + aggregates, + groupKeysMapping, + aggregateMapping, + finalRowArity) { + + private var collector: TimeWindowPropertyCollector = _ + + override def open(parameters: Configuration): Unit = { + collector = new TimeWindowPropertyCollector(windowStartPos, windowEndPos) + super.open(parameters) + } + + override def apply( + window: TimeWindow, + records: Iterable[Row], + out: Collector[Row]): Unit = { + + // set collector and window + collector.wrappedCollector = out + collector.timeWindow = window + + super.apply(window, records, collector) + } +} http://git-wip-us.apache.org/repos/asf/flink/blob/ffe9ec8e/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/IncrementalAggregateAllWindowFunction.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/IncrementalAggregateAllWindowFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/IncrementalAggregateAllWindowFunction.scala new file mode 100644 index 0000000..3c41a62 --- /dev/null +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/IncrementalAggregateAllWindowFunction.scala @@ -0,0 +1,79 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.flink.table.runtime.aggregate + +import java.lang.Iterable + +import org.apache.flink.types.Row +import org.apache.flink.configuration.Configuration +import org.apache.flink.streaming.api.functions.windowing.RichAllWindowFunction +import org.apache.flink.streaming.api.windowing.windows.Window +import org.apache.flink.util.{Collector, Preconditions} + +/** + * Computes the final aggregate value from incrementally computed aggreagtes. + * + * @param aggregates The aggregate functions. + * @param groupKeysMapping The index mapping of group keys between intermediate aggregate Row + * and output Row. + * @param aggregateMapping The index mapping between aggregate function list and aggregated value + * index in output Row. + * @param finalRowArity The arity of the final output row. 
+ */ +class IncrementalAggregateAllWindowFunction[W <: Window]( + private val aggregates: Array[Aggregate[_ <: Any]], + private val groupKeysMapping: Array[(Int, Int)], + private val aggregateMapping: Array[(Int, Int)], + private val finalRowArity: Int) + extends RichAllWindowFunction[Row, Row, W] { + + private var output: Row = _ + + override def open(parameters: Configuration): Unit = { + Preconditions.checkNotNull(aggregates) + Preconditions.checkNotNull(groupKeysMapping) + output = new Row(finalRowArity) + } + + /** + * Calculate aggregated values output by aggregate buffer, and set them into output + * Row based on the mapping relation between intermediate aggregate data and output data. + */ + override def apply( + window: W, + records: Iterable[Row], + out: Collector[Row]): Unit = { + + val iterator = records.iterator + + if (iterator.hasNext) { + val record = iterator.next() + // Set group keys value to final output. + groupKeysMapping.foreach { + case (after, previous) => + output.setField(after, record.getField(previous)) + } + // Evaluate final aggregate value and set to output. 
+ aggregateMapping.foreach { + case (after, previous) => + output.setField(after, aggregates(previous).evaluate(record)) + } + out.collect(output) + } + } +} http://git-wip-us.apache.org/repos/asf/flink/blob/ffe9ec8e/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/IncrementalAggregateReduceFunction.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/IncrementalAggregateReduceFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/IncrementalAggregateReduceFunction.scala new file mode 100644 index 0000000..14b44e8 --- /dev/null +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/IncrementalAggregateReduceFunction.scala @@ -0,0 +1,63 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.flink.table.runtime.aggregate + +import org.apache.flink.api.common.functions.ReduceFunction +import org.apache.flink.types.Row +import org.apache.flink.util.Preconditions + +/** + * Incrementally computes group window aggregates. 
+ * + * @param aggregates The aggregate functions. + * @param groupKeysMapping The index mapping of group keys between intermediate aggregate Row + * and output Row. + */ +class IncrementalAggregateReduceFunction( + private val aggregates: Array[Aggregate[_]], + private val groupKeysMapping: Array[(Int, Int)], + private val intermediateRowArity: Int) + extends ReduceFunction[Row] { + + Preconditions.checkNotNull(aggregates) + Preconditions.checkNotNull(groupKeysMapping) + + /** + * For Incremental intermediate aggregate Rows, merge value1 and value2 + * into aggregate buffer, return aggregate buffer. + * + * @param value1 The first value to combined. + * @param value2 The second value to combined. + * @return accumulatorRow A resulting row that combines two input values. + * + */ + override def reduce(value1: Row, value2: Row): Row = { + + // TODO: once FLINK-5105 is solved, we can avoid creating a new row for each invocation + // and directly merge value1 and value2. + val accumulatorRow = new Row(intermediateRowArity) + + // copy all fields of value1 into accumulatorRow + (0 until intermediateRowArity) + .foreach(i => accumulatorRow.setField(i, value1.getField(i))) + // merge value2 to accumulatorRow + aggregates.foreach(_.merge(value2, accumulatorRow)) + + accumulatorRow + } +} http://git-wip-us.apache.org/repos/asf/flink/blob/ffe9ec8e/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/IncrementalAggregateTimeWindowFunction.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/IncrementalAggregateTimeWindowFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/IncrementalAggregateTimeWindowFunction.scala new file mode 100644 index 0000000..a96ce7a --- /dev/null +++ 
/**
 * Computes the final aggregate value from incrementally computed aggregates and
 * attaches window start/end properties to the emitted rows.
 *
 * @param aggregates The aggregate functions.
 * @param groupKeysMapping The index mapping of group keys between intermediate aggregate Row
 *                         and output Row.
 * @param aggregateMapping The index mapping between aggregate function list and aggregated value
 *                         index in output Row.
 * @param finalRowArity The arity of the final output row.
 * @param windowStartPos Optional output position of the window start property.
 * @param windowEndPos Optional output position of the window end property.
 */
class IncrementalAggregateTimeWindowFunction(
    private val aggregates: Array[Aggregate[_ <: Any]],
    private val groupKeysMapping: Array[(Int, Int)],
    private val aggregateMapping: Array[(Int, Int)],
    private val finalRowArity: Int,
    private val windowStartPos: Option[Int],
    private val windowEndPos: Option[Int])
  extends IncrementalAggregateWindowFunction[TimeWindow](
    aggregates,
    groupKeysMapping,
    aggregateMapping,
    finalRowArity) {

  // Wraps the downstream collector to append window start/end properties.
  private var collector: TimeWindowPropertyCollector = _

  override def open(parameters: Configuration): Unit = {
    collector = new TimeWindowPropertyCollector(windowStartPos, windowEndPos)
    super.open(parameters)
  }

  override def apply(
      key: Tuple,
      window: TimeWindow,
      records: Iterable[Row],
      out: Collector[Row]): Unit = {

    // Point the property collector at the current output and window,
    // then delegate the actual aggregation to the superclass.
    collector.wrappedCollector = out
    collector.timeWindow = window

    super.apply(key, window, records, collector)
  }
}
/**
 * Computes the final aggregate value from incrementally computed aggregates
 * for keyed group windows.
 *
 * @param aggregates The aggregate functions.
 * @param groupKeysMapping The index mapping of group keys between intermediate aggregate Row
 *                         and output Row.
 * @param aggregateMapping The index mapping between aggregate function list and aggregated value
 *                         index in output Row.
 * @param finalRowArity The arity of the final output row.
 */
class IncrementalAggregateWindowFunction[W <: Window](
    private val aggregates: Array[Aggregate[_ <: Any]],
    private val groupKeysMapping: Array[(Int, Int)],
    private val aggregateMapping: Array[(Int, Int)],
    private val finalRowArity: Int)
  extends RichWindowFunction[Row, Row, Tuple, W] {

  // Reused for every window result to avoid per-invocation allocation.
  private var output: Row = _

  override def open(parameters: Configuration): Unit = {
    Preconditions.checkNotNull(aggregates)
    Preconditions.checkNotNull(groupKeysMapping)
    output = new Row(finalRowArity)
  }

  /**
   * Reads the single pre-aggregated row of the window, copies the group key values and the
   * evaluated aggregate results into the reused output row, and emits it.
   */
  override def apply(
      key: Tuple,
      window: W,
      records: Iterable[Row],
      out: Collector[Row]): Unit = {

    val it = records.iterator

    if (it.hasNext) {
      val buffer = it.next()
      // Copy group key values into their final positions.
      var i = 0
      while (i < groupKeysMapping.length) {
        val (after, previous) = groupKeysMapping(i)
        output.setField(after, buffer.getField(previous))
        i += 1
      }
      // Evaluate each aggregate on the buffer and write its result to the output row.
      var j = 0
      while (j < aggregateMapping.length) {
        val (after, previous) = aggregateMapping(j)
        output.setField(after, aggregates(previous).evaluate(buffer))
        j += 1
      }
      out.collect(output)
    }
  }
}
/**
 * Base class for MAX aggregate functions over types with an [[Ordering]].
 * The current maximum is kept in a single field of the shared intermediate Row
 * at offset `maxIndex`; `null` in that field means "no value seen yet".
 *
 * @tparam T The aggregated value type.
 */
abstract class MaxAggregate[T](implicit ord: Ordering[T]) extends Aggregate[T] {

  // Field offset of this aggregate's slot in the shared intermediate Row;
  // assigned via setAggOffsetInRow() before use.
  protected var maxIndex = -1

  /**
   * Initiate the intermediate aggregate value in Row.
   *
   * @param intermediate The intermediate aggregate row to initiate.
   */
  override def initiate(intermediate: Row): Unit = {
    intermediate.setField(maxIndex, null)
  }

  /**
   * Accessed in MapFunction, prepare the input of partial aggregate.
   * A null input leaves the slot initialized to null (ignored by merge).
   *
   * @param value The input value.
   * @param intermediate The intermediate aggregate row to write into.
   */
  override def prepare(value: Any, intermediate: Row): Unit = {
    if (value == null) {
      initiate(intermediate)
    } else {
      intermediate.setField(maxIndex, value)
    }
  }

  /**
   * Accessed in CombineFunction and GroupReduceFunction, merge partial
   * aggregate result into aggregate buffer.
   *
   * @param intermediate The intermediate aggregate row to merge.
   * @param buffer The aggregate buffer to merge into.
   */
  override def merge(intermediate: Row, buffer: Row): Unit = {
    // NOTE: the != null checks work despite T being erased because Row stores
    // boxed values; null marks an absent value.
    val partialValue = intermediate.getField(maxIndex).asInstanceOf[T]
    if (partialValue != null) {
      val bufferValue = buffer.getField(maxIndex).asInstanceOf[T]
      if (bufferValue != null) {
        val max : T = if (ord.compare(partialValue, bufferValue) > 0) partialValue else bufferValue
        buffer.setField(maxIndex, max)
      } else {
        buffer.setField(maxIndex, partialValue)
      }
    }
  }

  /**
   * Return the final aggregated result based on aggregate buffer.
   *
   * @param buffer The aggregate buffer holding the current maximum (or null).
   * @return The maximum value, or null if no value was aggregated.
   */
  override def evaluate(buffer: Row): T = {
    buffer.getField(maxIndex).asInstanceOf[T]
  }

  override def supportPartial: Boolean = true

  override def setAggOffsetInRow(aggOffset: Int): Unit = {
    maxIndex = aggOffset
  }
}

class ByteMaxAggregate extends MaxAggregate[Byte] {

  override def intermediateDataType = Array(BasicTypeInfo.BYTE_TYPE_INFO)

}

class ShortMaxAggregate extends MaxAggregate[Short] {

  override def intermediateDataType = Array(BasicTypeInfo.SHORT_TYPE_INFO)

}

class IntMaxAggregate extends MaxAggregate[Int] {

  override def intermediateDataType = Array(BasicTypeInfo.INT_TYPE_INFO)

}

class LongMaxAggregate extends MaxAggregate[Long] {

  override def intermediateDataType = Array(BasicTypeInfo.LONG_TYPE_INFO)

}

class FloatMaxAggregate extends MaxAggregate[Float] {

  override def intermediateDataType = Array(BasicTypeInfo.FLOAT_TYPE_INFO)

}

class DoubleMaxAggregate extends MaxAggregate[Double] {

  override def intermediateDataType = Array(BasicTypeInfo.DOUBLE_TYPE_INFO)

}

class BooleanMaxAggregate extends MaxAggregate[Boolean] {

  override def intermediateDataType = Array(BasicTypeInfo.BOOLEAN_TYPE_INFO)

}

/**
 * MAX aggregate for [[java.math.BigDecimal]], which has no implicit scala
 * [[Ordering]] here and therefore uses compareTo directly.
 */
class DecimalMaxAggregate extends Aggregate[BigDecimal] {

  // Renamed from the copy-pasted "minIndex": this slot holds the current MAX.
  protected var maxIndex: Int = _

  override def intermediateDataType = Array(BasicTypeInfo.BIG_DEC_TYPE_INFO)

  /** Initializes this aggregate's slot to null ("no value seen yet"). */
  override def initiate(intermediate: Row): Unit = {
    intermediate.setField(maxIndex, null)
  }

  /** Writes the raw input value (or null) into this aggregate's slot. */
  override def prepare(value: Any, partial: Row): Unit = {
    if (value == null) {
      initiate(partial)
    } else {
      partial.setField(maxIndex, value)
    }
  }

  /** Keeps the larger of the partial value and the buffered value. */
  override def merge(partial: Row, buffer: Row): Unit = {
    val partialValue = partial.getField(maxIndex).asInstanceOf[BigDecimal]
    if (partialValue != null) {
      val bufferValue = buffer.getField(maxIndex).asInstanceOf[BigDecimal]
      if (bufferValue != null) {
        val max = if (partialValue.compareTo(bufferValue) > 0) partialValue else bufferValue
        buffer.setField(maxIndex, max)
      } else {
        buffer.setField(maxIndex, partialValue)
      }
    }
  }

  /** Returns the buffered maximum, or null if no value was aggregated. */
  override def evaluate(buffer: Row): BigDecimal = {
    buffer.getField(maxIndex).asInstanceOf[BigDecimal]
  }

  override def supportPartial: Boolean = true

  override def setAggOffsetInRow(aggOffset: Int): Unit = {
    maxIndex = aggOffset
  }
}