[scala] [streaming] Modified aggregations to work on scala tuples
Project: http://git-wip-us.apache.org/repos/asf/flink/repo Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/f7291ea1 Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/f7291ea1 Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/f7291ea1 Branch: refs/heads/release-0.8 Commit: f7291ea1c9fa4a0484ab6bc13e4a594ff6d7c2d2 Parents: 40a3b6b Author: Gyula Fora <[email protected]> Authored: Sat Dec 20 23:46:35 2014 +0100 Committer: mbalassi <[email protected]> Committed: Mon Jan 5 17:57:44 2015 +0100 ---------------------------------------------------------------------- .../aggregation/AggregationFunction.java | 2 +- .../aggregation/ComparableAggregator.java | 8 +- .../api/function/aggregation/SumFunction.java | 12 +- .../operator/StreamReduceInvokable.java | 1 + .../streaming/ScalaStreamingAggregator.java | 111 +++++++++++++++++++ .../flink/api/scala/streaming/DataStream.scala | 52 +++++---- .../scala/streaming/WindowedDataStream.scala | 64 ++++------- 7 files changed, 171 insertions(+), 79 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/flink/blob/f7291ea1/flink-addons/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/function/aggregation/AggregationFunction.java ---------------------------------------------------------------------- diff --git a/flink-addons/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/function/aggregation/AggregationFunction.java b/flink-addons/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/function/aggregation/AggregationFunction.java index d95c37e..1c273d3 100644 --- a/flink-addons/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/function/aggregation/AggregationFunction.java +++ b/flink-addons/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/function/aggregation/AggregationFunction.java @@ -22,7 +22,7 @@ import org.apache.flink.api.common.functions.ReduceFunction; public abstract class AggregationFunction<T> implements ReduceFunction<T> { private static final long serialVersionUID = 1L; - int position; + public int position; public AggregationFunction(int pos) { this.position = pos; http://git-wip-us.apache.org/repos/asf/flink/blob/f7291ea1/flink-addons/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/function/aggregation/ComparableAggregator.java ---------------------------------------------------------------------- diff --git a/flink-addons/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/function/aggregation/ComparableAggregator.java b/flink-addons/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/function/aggregation/ComparableAggregator.java index 6e2a400..5fb8f62 100644 --- a/flink-addons/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/function/aggregation/ComparableAggregator.java +++ b/flink-addons/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/function/aggregation/ComparableAggregator.java @@ -35,11 +35,11 @@ public abstract class ComparableAggregator<T> extends AggregationFunction<T> { private static final long serialVersionUID = 1L; - Comparator comparator; - boolean byAggregate; - boolean first; + public Comparator comparator; + public boolean byAggregate; + public boolean first; - private ComparableAggregator(int pos, AggregationType aggregationType, boolean first) { + public ComparableAggregator(int pos, AggregationType aggregationType, boolean first) { super(pos); this.comparator = Comparator.getForAggregation(aggregationType); this.byAggregate = (aggregationType == AggregationType.MAXBY) http://git-wip-us.apache.org/repos/asf/flink/blob/f7291ea1/flink-addons/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/function/aggregation/SumFunction.java ---------------------------------------------------------------------- diff --git a/flink-addons/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/function/aggregation/SumFunction.java b/flink-addons/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/function/aggregation/SumFunction.java index 1ac236d..2aef19c 100644 --- a/flink-addons/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/function/aggregation/SumFunction.java +++ b/flink-addons/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/function/aggregation/SumFunction.java @@ -45,7 +45,7 @@ public abstract class SumFunction implements Serializable{ } } - private static class IntSum extends SumFunction { + public static class IntSum extends SumFunction { private static final long serialVersionUID = 1L; @Override @@ -54,7 +54,7 @@ public abstract class SumFunction implements Serializable{ } } - private static class LongSum extends SumFunction { + public static class LongSum extends SumFunction { private static final long serialVersionUID = 1L; @Override @@ -63,7 +63,7 @@ public abstract class SumFunction implements Serializable{ } } - private static class DoubleSum extends SumFunction { + public static class DoubleSum extends SumFunction { private static final long serialVersionUID = 1L; @@ -73,7 +73,7 @@ public abstract class SumFunction implements Serializable{ } } - private static class ShortSum extends SumFunction { + public static class ShortSum extends SumFunction { private static final long serialVersionUID = 1L; @Override @@ -82,7 +82,7 @@ public abstract class SumFunction implements Serializable{ } } - private static class FloatSum extends SumFunction { + public static class FloatSum extends SumFunction { private static final long serialVersionUID = 1L; @Override @@ -91,7 +91,7 @@ public abstract class SumFunction implements Serializable{ } } - private static class ByteSum extends SumFunction { + public static class ByteSum extends SumFunction { private static final long serialVersionUID = 1L; @Override http://git-wip-us.apache.org/repos/asf/flink/blob/f7291ea1/flink-addons/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/invokable/operator/StreamReduceInvokable.java ---------------------------------------------------------------------- diff --git a/flink-addons/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/invokable/operator/StreamReduceInvokable.java b/flink-addons/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/invokable/operator/StreamReduceInvokable.java index 4bb78b8..5f5cb12 100644 --- a/flink-addons/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/invokable/operator/StreamReduceInvokable.java +++ b/flink-addons/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/invokable/operator/StreamReduceInvokable.java @@ -52,6 +52,7 @@ public class StreamReduceInvokable<IN> extends StreamInvokable<IN, IN> { currentValue = reducer.reduce(currentValue, nextValue); } else { currentValue = nextValue; + } collector.collect(currentValue); http://git-wip-us.apache.org/repos/asf/flink/blob/f7291ea1/flink-scala/src/main/java/org/apache/flink/api/scala/streaming/ScalaStreamingAggregator.java ---------------------------------------------------------------------- diff --git a/flink-scala/src/main/java/org/apache/flink/api/scala/streaming/ScalaStreamingAggregator.java b/flink-scala/src/main/java/org/apache/flink/api/scala/streaming/ScalaStreamingAggregator.java new file mode 100644 index 0000000..2f587d7 --- /dev/null +++ b/flink-scala/src/main/java/org/apache/flink/api/scala/streaming/ScalaStreamingAggregator.java @@ -0,0 +1,111 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.api.scala.streaming; + +import java.io.Serializable; + +import org.apache.flink.api.common.typeutils.TypeSerializer; +import org.apache.flink.api.java.typeutils.runtime.TupleSerializerBase; +import org.apache.flink.streaming.api.function.aggregation.AggregationFunction; +import org.apache.flink.streaming.api.function.aggregation.ComparableAggregator; +import org.apache.flink.streaming.api.function.aggregation.SumFunction; + +import scala.Product; + +public class ScalaStreamingAggregator<IN extends Product> implements Serializable { + + private static final long serialVersionUID = 1L; + + TupleSerializerBase<IN> serializer; + Object[] fields; + int length; + int position; + + public ScalaStreamingAggregator(TypeSerializer<IN> serializer, int pos) { + this.serializer = (TupleSerializerBase<IN>) serializer; + this.length = this.serializer.getArity(); + this.fields = new Object[this.length]; + this.position = pos; + } + + public class Sum extends AggregationFunction<IN> { + private static final long serialVersionUID = 1L; + SumFunction sumFunction; + + public Sum(SumFunction func) { + super(ScalaStreamingAggregator.this.position); + this.sumFunction = func; + } + + @Override + public IN reduce(IN value1, IN value2) throws Exception { + for (int i = 0; i < length; i++) { + fields[i] = value2.productElement(i); + } + + fields[position] = sumFunction.add(fields[position], value1.productElement(position)); + + return serializer.createInstance(fields); + } + } + + public class ProductComparableAggregator extends ComparableAggregator<IN> { + + private static final long serialVersionUID = 1L; + + public ProductComparableAggregator(AggregationFunction.AggregationType aggregationType, + boolean first) { + super(ScalaStreamingAggregator.this.position, aggregationType, first); + } + + @SuppressWarnings("unchecked") + @Override + public IN reduce(IN value1, IN value2) throws Exception { + Object v1 = value1.productElement(position); + Object v2 = value2.productElement(position); + + int c = comparator.isExtremal((Comparable<Object>) v1, v2); + + if (byAggregate) { + if (c == 1) { + return value1; + } + if (first) { + if (c == 0) { + return value1; + } + } + + return value2; + } else { + for (int i = 0; i < length; i++) { + fields[i] = value2.productElement(i); + } + + if (c == 1) { + fields[position] = v1; + } + + return serializer.createInstance(fields); + } + } + + } + +} http://git-wip-us.apache.org/repos/asf/flink/blob/f7291ea1/flink-scala/src/main/scala/org/apache/flink/api/scala/streaming/DataStream.scala ---------------------------------------------------------------------- diff --git a/flink-scala/src/main/scala/org/apache/flink/api/scala/streaming/DataStream.scala b/flink-scala/src/main/scala/org/apache/flink/api/scala/streaming/DataStream.scala index 42ec709..ecf5615 100644 --- a/flink-scala/src/main/scala/org/apache/flink/api/scala/streaming/DataStream.scala +++ b/flink-scala/src/main/scala/org/apache/flink/api/scala/streaming/DataStream.scala @@ -44,6 +44,11 @@ import org.apache.flink.streaming.api.windowing.policy.TriggerPolicy import org.apache.flink.streaming.api.collector.OutputSelector import scala.collection.JavaConversions._ import java.util.HashMap +import org.apache.flink.streaming.api.function.aggregation.SumFunction +import org.apache.flink.api.java.typeutils.TupleTypeInfoBase +import org.apache.flink.streaming.api.function.aggregation.AggregationFunction +import org.apache.flink.streaming.api.function.aggregation.AggregationFunction.AggregationType +import com.amazonaws.services.cloudfront_2012_03_15.model.InvalidArgumentException class DataStream[T](javaStream: JavaStream[T]) { @@ -230,53 +235,52 @@ class DataStream[T](javaStream: JavaStream[T]) { * the given position. * */ - def max(field: Any): DataStream[T] = field match { - case field: Int => return new DataStream[T](javaStream.max(field)) - case field: String => return new DataStream[T](javaStream.max(field)) - case _ => throw new IllegalArgumentException("Aggregations are only supported by field position (Int) or field expression (String)") - } + def max(position: Int): DataStream[T] = aggregate(AggregationType.MAX, position) /** * Applies an aggregation that that gives the current minimum of the data stream at * the given position. * */ - def min(field: Any): DataStream[T] = field match { - case field: Int => return new DataStream[T](javaStream.min(field)) - case field: String => return new DataStream[T](javaStream.min(field)) - case _ => throw new IllegalArgumentException("Aggregations are only supported by field position (Int) or field expression (String)") - } + def min(position: Int): DataStream[T] = aggregate(AggregationType.MIN, position) /** * Applies an aggregation that sums the data stream at the given position. * */ - def sum(field: Any): DataStream[T] = field match { - case field: Int => return new DataStream[T](javaStream.sum(field)) - case field: String => return new DataStream[T](javaStream.sum(field)) - case _ => throw new IllegalArgumentException("Aggregations are only supported by field position (Int) or field expression (String)") - } + def sum(position: Int): DataStream[T] = aggregate(AggregationType.SUM, position) /** * Applies an aggregation that that gives the current minimum element of the data stream by * the given position. When equality, the user can set to get the first or last element with the minimal value. * */ - def minBy(field: Any, first: Boolean = true): DataStream[T] = field match { - case field: Int => return new DataStream[T](javaStream.minBy(field, first)) - case field: String => return new DataStream[T](javaStream.minBy(field, first)) - case _ => throw new IllegalArgumentException("Aggregations are only supported by field position (Int) or field expression (String)") - } + def minBy(position: Int, first: Boolean = true): DataStream[T] = aggregate(AggregationType.MINBY, position, first) /** * Applies an aggregation that that gives the current maximum element of the data stream by * the given position. When equality, the user can set to get the first or last element with the maximal value. * */ - def maxBy(field: Any, first: Boolean = true): DataStream[T] = field match { - case field: Int => return new DataStream[T](javaStream.maxBy(field, first)) - case field: String => return new DataStream[T](javaStream.maxBy(field, first)) - case _ => throw new IllegalArgumentException("Aggregations are only supported by field position (Int) or field expression (String)") + def maxBy(position: Int, first: Boolean = true): DataStream[T] = aggregate(AggregationType.MAXBY, position, first) + + private def aggregate(aggregationType: AggregationType, position: Int, first: Boolean = true): DataStream[T] = { + + val jStream = javaStream.asInstanceOf[JavaStream[Product]] + val outType = jStream.getType().asInstanceOf[TupleTypeInfoBase[_]] + + val agg = new ScalaStreamingAggregator[Product](jStream.getType().createSerializer(), position) + + val reducer = aggregationType match { + case AggregationType.SUM => new agg.Sum(SumFunction.getForClass(outType.getTypeAt(position).getTypeClass())); + case _ => new agg.ProductComparableAggregator(aggregationType, first) + } + + val invokable = jStream match { + case groupedStream: GroupedDataStream[_] => new GroupedReduceInvokable(reducer, groupedStream.getKeySelector()) + case _ => new StreamReduceInvokable(reducer) + } + new DataStream[Product](jStream.transform("aggregation", jStream.getType(), invokable)).asInstanceOf[DataStream[T]] } /** http://git-wip-us.apache.org/repos/asf/flink/blob/f7291ea1/flink-scala/src/main/scala/org/apache/flink/api/scala/streaming/WindowedDataStream.scala ---------------------------------------------------------------------- diff --git a/flink-scala/src/main/scala/org/apache/flink/api/scala/streaming/WindowedDataStream.scala b/flink-scala/src/main/scala/org/apache/flink/api/scala/streaming/WindowedDataStream.scala index c686497..c037305 100644 --- a/flink-scala/src/main/scala/org/apache/flink/api/scala/streaming/WindowedDataStream.scala +++ b/flink-scala/src/main/scala/org/apache/flink/api/scala/streaming/WindowedDataStream.scala @@ -36,6 +36,9 @@ import org.apache.flink.streaming.api.windowing.helper.WindowingHelper import org.apache.flink.api.common.functions.GroupReduceFunction import org.apache.flink.streaming.api.invokable.StreamInvokable import scala.collection.JavaConversions._ +import org.apache.flink.streaming.api.function.aggregation.AggregationFunction.AggregationType +import org.apache.flink.api.java.typeutils.TupleTypeInfoBase +import org.apache.flink.streaming.api.function.aggregation.SumFunction class WindowedDataStream[T](javaStream: JavaWStream[T]) { @@ -158,75 +161,48 @@ class WindowedDataStream[T](javaStream: JavaWStream[T]) { * the given position. * */ - def max(field: Any): DataStream[T] = field match { - case field: Int => return new DataStream[T](javaStream.max(field)) - case field: String => return new DataStream[T](javaStream.max(field)) - case _ => throw new IllegalArgumentException("Aggregations are only supported by field position (Int) or field expression (String)") - } + def max(position: Int): DataStream[T] = aggregate(AggregationType.MAX, position) /** * Applies an aggregation that that gives the minimum of the elements in the window at * the given position. * */ - def min(field: Any): DataStream[T] = field match { - case field: Int => return new DataStream[T](javaStream.min(field)) - case field: String => return new DataStream[T](javaStream.min(field)) - case _ => throw new IllegalArgumentException("Aggregations are only supported by field position (Int) or field expression (String)") - } + def min(position: Int): DataStream[T] = aggregate(AggregationType.MIN, position) /** * Applies an aggregation that sums the elements in the window at the given position. * */ - def sum(field: Any): DataStream[T] = field match { - case field: Int => return new DataStream[T](javaStream.sum(field)) - case field: String => return new DataStream[T](javaStream.sum(field)) - case _ => throw new IllegalArgumentException("Aggregations are only supported by field position (Int) or field expression (String)") - } + def sum(position: Int): DataStream[T] = aggregate(AggregationType.SUM, position) /** * Applies an aggregation that that gives the maximum element of the window by * the given position. When equality, returns the first. * */ - def maxBy(field: Any): DataStream[T] = field match { - case field: Int => return new DataStream[T](javaStream.maxBy(field)) - case field: String => return new DataStream[T](javaStream.maxBy(field)) - case _ => throw new IllegalArgumentException("Aggregations are only supported by field position (Int) or field expression (String)") - } + def maxBy(position: Int, first: Boolean = true): DataStream[T] = aggregate(AggregationType.MAXBY, position, first) /** * Applies an aggregation that that gives the minimum element of the window by * the given position. When equality, returns the first. * */ - def minBy(field: Any): DataStream[T] = field match { - case field: Int => return new DataStream[T](javaStream.minBy(field)) - case field: String => return new DataStream[T](javaStream.minBy(field)) - case _ => throw new IllegalArgumentException("Aggregations are only supported by field position (Int) or field expression (String)") - } + def minBy(position: Int, first: Boolean = true): DataStream[T] = aggregate(AggregationType.MINBY, position, first) - /** - * Applies an aggregation that that gives the minimum element of the window by - * the given position. When equality, the user can set to get the first or last element with the minimal value. - * - */ - def minBy(field: Any, first: Boolean): DataStream[T] = field match { - case field: Int => return new DataStream[T](javaStream.minBy(field, first)) - case field: String => return new DataStream[T](javaStream.minBy(field, first)) - case _ => throw new IllegalArgumentException("Aggregations are only supported by field position (Int) or field expression (String)") - } + def aggregate(aggregationType: AggregationType, position: Int, first: Boolean = true): DataStream[T] = { - /** - * Applies an aggregation that that gives the maximum element of the window by - * the given position. When equality, the user can set to get the first or last element with the maximal value. - * - */ - def maxBy(field: Any, first: Boolean): DataStream[T] = field match { - case field: Int => return new DataStream[T](javaStream.maxBy(field, first)) - case field: String => return new DataStream[T](javaStream.maxBy(field, first)) - case _ => throw new IllegalArgumentException("Aggregations are only supported by field position (Int) or field expression (String)") + val jStream = javaStream.asInstanceOf[JavaWStream[Product]] + val outType = jStream.getType().asInstanceOf[TupleTypeInfoBase[_]] + + val agg = new ScalaStreamingAggregator[Product](jStream.getType().createSerializer(), position) + + val reducer = aggregationType match { + case AggregationType.SUM => new agg.Sum(SumFunction.getForClass(outType.getTypeAt(position).getTypeClass())); + case _ => new agg.ProductComparableAggregator(aggregationType, first) + } + + new DataStream[Product](jStream.reduce(reducer)).asInstanceOf[DataStream[T]] } } \ No newline at end of file
