[scala] [streaming] Modified aggregations to work on Scala tuples

Project: http://git-wip-us.apache.org/repos/asf/flink/repo
Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/f7291ea1
Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/f7291ea1
Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/f7291ea1

Branch: refs/heads/release-0.8
Commit: f7291ea1c9fa4a0484ab6bc13e4a594ff6d7c2d2
Parents: 40a3b6b
Author: Gyula Fora <[email protected]>
Authored: Sat Dec 20 23:46:35 2014 +0100
Committer: mbalassi <[email protected]>
Committed: Mon Jan 5 17:57:44 2015 +0100

----------------------------------------------------------------------
 .../aggregation/AggregationFunction.java        |   2 +-
 .../aggregation/ComparableAggregator.java       |   8 +-
 .../api/function/aggregation/SumFunction.java   |  12 +-
 .../operator/StreamReduceInvokable.java         |   1 +
 .../streaming/ScalaStreamingAggregator.java     | 111 +++++++++++++++++++
 .../flink/api/scala/streaming/DataStream.scala  |  52 +++++----
 .../scala/streaming/WindowedDataStream.scala    |  64 ++++-------
 7 files changed, 171 insertions(+), 79 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/flink/blob/f7291ea1/flink-addons/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/function/aggregation/AggregationFunction.java
----------------------------------------------------------------------
diff --git 
a/flink-addons/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/function/aggregation/AggregationFunction.java
 
b/flink-addons/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/function/aggregation/AggregationFunction.java
index d95c37e..1c273d3 100644
--- 
a/flink-addons/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/function/aggregation/AggregationFunction.java
+++ 
b/flink-addons/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/function/aggregation/AggregationFunction.java
@@ -22,7 +22,7 @@ import org.apache.flink.api.common.functions.ReduceFunction;
 public abstract class AggregationFunction<T> implements ReduceFunction<T> {
        private static final long serialVersionUID = 1L;
 
-       int position;
+       public int position;
 
        public AggregationFunction(int pos) {
                this.position = pos;

http://git-wip-us.apache.org/repos/asf/flink/blob/f7291ea1/flink-addons/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/function/aggregation/ComparableAggregator.java
----------------------------------------------------------------------
diff --git 
a/flink-addons/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/function/aggregation/ComparableAggregator.java
 
b/flink-addons/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/function/aggregation/ComparableAggregator.java
index 6e2a400..5fb8f62 100644
--- 
a/flink-addons/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/function/aggregation/ComparableAggregator.java
+++ 
b/flink-addons/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/function/aggregation/ComparableAggregator.java
@@ -35,11 +35,11 @@ public abstract class ComparableAggregator<T> extends 
AggregationFunction<T> {
 
        private static final long serialVersionUID = 1L;
 
-       Comparator comparator;
-       boolean byAggregate;
-       boolean first;
+       public Comparator comparator;
+       public boolean byAggregate;
+       public boolean first;
 
-       private ComparableAggregator(int pos, AggregationType aggregationType, 
boolean first) {
+       public ComparableAggregator(int pos, AggregationType aggregationType, 
boolean first) {
                super(pos);
                this.comparator = Comparator.getForAggregation(aggregationType);
                this.byAggregate = (aggregationType == AggregationType.MAXBY)

http://git-wip-us.apache.org/repos/asf/flink/blob/f7291ea1/flink-addons/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/function/aggregation/SumFunction.java
----------------------------------------------------------------------
diff --git 
a/flink-addons/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/function/aggregation/SumFunction.java
 
b/flink-addons/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/function/aggregation/SumFunction.java
index 1ac236d..2aef19c 100644
--- 
a/flink-addons/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/function/aggregation/SumFunction.java
+++ 
b/flink-addons/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/function/aggregation/SumFunction.java
@@ -45,7 +45,7 @@ public abstract class SumFunction implements Serializable{
                }
        }
 
-       private static class IntSum extends SumFunction {
+       public static class IntSum extends SumFunction {
                private static final long serialVersionUID = 1L;
 
                @Override
@@ -54,7 +54,7 @@ public abstract class SumFunction implements Serializable{
                }
        }
 
-       private static class LongSum extends SumFunction {
+       public static class LongSum extends SumFunction {
                private static final long serialVersionUID = 1L;
 
                @Override
@@ -63,7 +63,7 @@ public abstract class SumFunction implements Serializable{
                }
        }
 
-       private static class DoubleSum extends SumFunction {
+       public static class DoubleSum extends SumFunction {
 
                private static final long serialVersionUID = 1L;
 
@@ -73,7 +73,7 @@ public abstract class SumFunction implements Serializable{
                }
        }
 
-       private static class ShortSum extends SumFunction {
+       public static class ShortSum extends SumFunction {
                private static final long serialVersionUID = 1L;
 
                @Override
@@ -82,7 +82,7 @@ public abstract class SumFunction implements Serializable{
                }
        }
 
-       private static class FloatSum extends SumFunction {
+       public static class FloatSum extends SumFunction {
                private static final long serialVersionUID = 1L;
 
                @Override
@@ -91,7 +91,7 @@ public abstract class SumFunction implements Serializable{
                }
        }
 
-       private static class ByteSum extends SumFunction {
+       public static class ByteSum extends SumFunction {
                private static final long serialVersionUID = 1L;
 
                @Override

http://git-wip-us.apache.org/repos/asf/flink/blob/f7291ea1/flink-addons/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/invokable/operator/StreamReduceInvokable.java
----------------------------------------------------------------------
diff --git 
a/flink-addons/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/invokable/operator/StreamReduceInvokable.java
 
b/flink-addons/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/invokable/operator/StreamReduceInvokable.java
index 4bb78b8..5f5cb12 100644
--- 
a/flink-addons/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/invokable/operator/StreamReduceInvokable.java
+++ 
b/flink-addons/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/invokable/operator/StreamReduceInvokable.java
@@ -52,6 +52,7 @@ public class StreamReduceInvokable<IN> extends 
StreamInvokable<IN, IN> {
                        currentValue = reducer.reduce(currentValue, nextValue);
                } else {
                        currentValue = nextValue;
+
                }
                collector.collect(currentValue);
 

http://git-wip-us.apache.org/repos/asf/flink/blob/f7291ea1/flink-scala/src/main/java/org/apache/flink/api/scala/streaming/ScalaStreamingAggregator.java
----------------------------------------------------------------------
diff --git 
a/flink-scala/src/main/java/org/apache/flink/api/scala/streaming/ScalaStreamingAggregator.java
 
b/flink-scala/src/main/java/org/apache/flink/api/scala/streaming/ScalaStreamingAggregator.java
new file mode 100644
index 0000000..2f587d7
--- /dev/null
+++ 
b/flink-scala/src/main/java/org/apache/flink/api/scala/streaming/ScalaStreamingAggregator.java
@@ -0,0 +1,111 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.api.scala.streaming;
+
+import java.io.Serializable;
+
+import org.apache.flink.api.common.typeutils.TypeSerializer;
+import org.apache.flink.api.java.typeutils.runtime.TupleSerializerBase;
+import org.apache.flink.streaming.api.function.aggregation.AggregationFunction;
+import 
org.apache.flink.streaming.api.function.aggregation.ComparableAggregator;
+import org.apache.flink.streaming.api.function.aggregation.SumFunction;
+
+import scala.Product;
+
+public class ScalaStreamingAggregator<IN extends Product> implements 
Serializable {
+
+       private static final long serialVersionUID = 1L;
+
+       TupleSerializerBase<IN> serializer;
+       Object[] fields;
+       int length;
+       int position;
+
+       public ScalaStreamingAggregator(TypeSerializer<IN> serializer, int pos) 
{
+               this.serializer = (TupleSerializerBase<IN>) serializer;
+               this.length = this.serializer.getArity();
+               this.fields = new Object[this.length];
+               this.position = pos;
+       }
+
+       public class Sum extends AggregationFunction<IN> {
+               private static final long serialVersionUID = 1L;
+               SumFunction sumFunction;
+
+               public Sum(SumFunction func) {
+                       super(ScalaStreamingAggregator.this.position);
+                       this.sumFunction = func;
+               }
+
+               @Override
+               public IN reduce(IN value1, IN value2) throws Exception {
+                       for (int i = 0; i < length; i++) {
+                               fields[i] = value2.productElement(i);
+                       }
+
+                       fields[position] = sumFunction.add(fields[position], 
value1.productElement(position));
+
+                       return serializer.createInstance(fields);
+               }
+       }
+
+       public class ProductComparableAggregator extends 
ComparableAggregator<IN> {
+
+               private static final long serialVersionUID = 1L;
+
+               public 
ProductComparableAggregator(AggregationFunction.AggregationType aggregationType,
+                               boolean first) {
+                       super(ScalaStreamingAggregator.this.position, 
aggregationType, first);
+               }
+
+               @SuppressWarnings("unchecked")
+               @Override
+               public IN reduce(IN value1, IN value2) throws Exception {
+                       Object v1 = value1.productElement(position);
+                       Object v2 = value2.productElement(position);
+
+                       int c = comparator.isExtremal((Comparable<Object>) v1, 
v2);
+
+                       if (byAggregate) {
+                               if (c == 1) {
+                                       return value1;
+                               }
+                               if (first) {
+                                       if (c == 0) {
+                                               return value1;
+                                       }
+                               }
+
+                               return value2;
+                       } else {
+                               for (int i = 0; i < length; i++) {
+                                       fields[i] = value2.productElement(i);
+                               }
+
+                               if (c == 1) {
+                                       fields[position] = v1;
+                               }
+
+                               return serializer.createInstance(fields);
+                       }
+               }
+
+       }
+
+}

http://git-wip-us.apache.org/repos/asf/flink/blob/f7291ea1/flink-scala/src/main/scala/org/apache/flink/api/scala/streaming/DataStream.scala
----------------------------------------------------------------------
diff --git 
a/flink-scala/src/main/scala/org/apache/flink/api/scala/streaming/DataStream.scala
 
b/flink-scala/src/main/scala/org/apache/flink/api/scala/streaming/DataStream.scala
index 42ec709..ecf5615 100644
--- 
a/flink-scala/src/main/scala/org/apache/flink/api/scala/streaming/DataStream.scala
+++ 
b/flink-scala/src/main/scala/org/apache/flink/api/scala/streaming/DataStream.scala
@@ -44,6 +44,11 @@ import 
org.apache.flink.streaming.api.windowing.policy.TriggerPolicy
 import org.apache.flink.streaming.api.collector.OutputSelector
 import scala.collection.JavaConversions._
 import java.util.HashMap
+import org.apache.flink.streaming.api.function.aggregation.SumFunction
+import org.apache.flink.api.java.typeutils.TupleTypeInfoBase
+import org.apache.flink.streaming.api.function.aggregation.AggregationFunction
+import 
org.apache.flink.streaming.api.function.aggregation.AggregationFunction.AggregationType
+import 
com.amazonaws.services.cloudfront_2012_03_15.model.InvalidArgumentException
 
 class DataStream[T](javaStream: JavaStream[T]) {
 
@@ -230,53 +235,52 @@ class DataStream[T](javaStream: JavaStream[T]) {
    * the given position.
    *
    */
-  def max(field: Any): DataStream[T] = field match {
-    case field: Int => return new DataStream[T](javaStream.max(field))
-    case field: String => return new DataStream[T](javaStream.max(field))
-    case _ => throw new IllegalArgumentException("Aggregations are only 
supported by field position (Int) or field expression (String)")
-  }
+  def max(position: Int): DataStream[T] = aggregate(AggregationType.MAX, 
position)
 
   /**
    * Applies an aggregation that gives the current minimum of the data 
stream at
    * the given position.
    *
    */
-  def min(field: Any): DataStream[T] = field match {
-    case field: Int => return new DataStream[T](javaStream.min(field))
-    case field: String => return new DataStream[T](javaStream.min(field))
-    case _ => throw new IllegalArgumentException("Aggregations are only 
supported by field position (Int) or field expression (String)")
-  }
+  def min(position: Int): DataStream[T] = aggregate(AggregationType.MIN, 
position)
 
   /**
    * Applies an aggregation that sums the data stream at the given position.
    *
    */
-  def sum(field: Any): DataStream[T] = field match {
-    case field: Int => return new DataStream[T](javaStream.sum(field))
-    case field: String => return new DataStream[T](javaStream.sum(field))
-    case _ => throw new IllegalArgumentException("Aggregations are only 
supported by field position (Int) or field expression (String)")
-  }
+  def sum(position: Int): DataStream[T] = aggregate(AggregationType.SUM, 
position)
 
   /**
    * Applies an aggregation that gives the current minimum element of the 
data stream by
    * the given position. In case of equality, the user can choose to get the first or 
last element with the minimal value.
    *
    */
-  def minBy(field: Any, first: Boolean = true): DataStream[T] = field match {
-    case field: Int => return new DataStream[T](javaStream.minBy(field, first))
-    case field: String => return new DataStream[T](javaStream.minBy(field, 
first))
-    case _ => throw new IllegalArgumentException("Aggregations are only 
supported by field position (Int) or field expression (String)")
-  }
+  def minBy(position: Int, first: Boolean = true): DataStream[T] = 
aggregate(AggregationType.MINBY, position, first)
 
   /**
    * Applies an aggregation that gives the current maximum element of the 
data stream by
    * the given position. In case of equality, the user can choose to get the first or 
last element with the maximal value.
    *
    */
-  def maxBy(field: Any, first: Boolean = true): DataStream[T] = field match {
-    case field: Int => return new DataStream[T](javaStream.maxBy(field, first))
-    case field: String => return new DataStream[T](javaStream.maxBy(field, 
first))
-    case _ => throw new IllegalArgumentException("Aggregations are only 
supported by field position (Int) or field expression (String)")
+  def maxBy(position: Int, first: Boolean = true): DataStream[T] = 
aggregate(AggregationType.MAXBY, position, first)
+
+  private def aggregate(aggregationType: AggregationType, position: Int, 
first: Boolean = true): DataStream[T] = {
+
+    val jStream = javaStream.asInstanceOf[JavaStream[Product]]
+    val outType = jStream.getType().asInstanceOf[TupleTypeInfoBase[_]]
+
+    val agg = new 
ScalaStreamingAggregator[Product](jStream.getType().createSerializer(), 
position)
+
+    val reducer = aggregationType match {
+      case AggregationType.SUM => new 
agg.Sum(SumFunction.getForClass(outType.getTypeAt(position).getTypeClass()));
+      case _ => new agg.ProductComparableAggregator(aggregationType, first)
+    }
+
+    val invokable = jStream match {
+      case groupedStream: GroupedDataStream[_] => new 
GroupedReduceInvokable(reducer, groupedStream.getKeySelector())
+      case _ => new StreamReduceInvokable(reducer)
+    }
+    new DataStream[Product](jStream.transform("aggregation", 
jStream.getType(), invokable)).asInstanceOf[DataStream[T]]
   }
 
   /**

http://git-wip-us.apache.org/repos/asf/flink/blob/f7291ea1/flink-scala/src/main/scala/org/apache/flink/api/scala/streaming/WindowedDataStream.scala
----------------------------------------------------------------------
diff --git 
a/flink-scala/src/main/scala/org/apache/flink/api/scala/streaming/WindowedDataStream.scala
 
b/flink-scala/src/main/scala/org/apache/flink/api/scala/streaming/WindowedDataStream.scala
index c686497..c037305 100644
--- 
a/flink-scala/src/main/scala/org/apache/flink/api/scala/streaming/WindowedDataStream.scala
+++ 
b/flink-scala/src/main/scala/org/apache/flink/api/scala/streaming/WindowedDataStream.scala
@@ -36,6 +36,9 @@ import 
org.apache.flink.streaming.api.windowing.helper.WindowingHelper
 import org.apache.flink.api.common.functions.GroupReduceFunction
 import org.apache.flink.streaming.api.invokable.StreamInvokable
 import scala.collection.JavaConversions._
+import 
org.apache.flink.streaming.api.function.aggregation.AggregationFunction.AggregationType
+import org.apache.flink.api.java.typeutils.TupleTypeInfoBase
+import org.apache.flink.streaming.api.function.aggregation.SumFunction
 
 class WindowedDataStream[T](javaStream: JavaWStream[T]) {
 
@@ -158,75 +161,48 @@ class WindowedDataStream[T](javaStream: JavaWStream[T]) {
    * the given position.
    *
    */
-  def max(field: Any): DataStream[T] = field match {
-    case field: Int => return new DataStream[T](javaStream.max(field))
-    case field: String => return new DataStream[T](javaStream.max(field))
-    case _ => throw new IllegalArgumentException("Aggregations are only 
supported by field position (Int) or field expression (String)")
-  }
+  def max(position: Int): DataStream[T] = aggregate(AggregationType.MAX, 
position)
 
   /**
    * Applies an aggregation that gives the minimum of the elements in the 
window at
    * the given position.
    *
    */
-  def min(field: Any): DataStream[T] = field match {
-    case field: Int => return new DataStream[T](javaStream.min(field))
-    case field: String => return new DataStream[T](javaStream.min(field))
-    case _ => throw new IllegalArgumentException("Aggregations are only 
supported by field position (Int) or field expression (String)")
-  }
+  def min(position: Int): DataStream[T] = aggregate(AggregationType.MIN, 
position)
 
   /**
    * Applies an aggregation that sums the elements in the window at the given 
position.
    *
    */
-  def sum(field: Any): DataStream[T] = field match {
-    case field: Int => return new DataStream[T](javaStream.sum(field))
-    case field: String => return new DataStream[T](javaStream.sum(field))
-    case _ => throw new IllegalArgumentException("Aggregations are only 
supported by field position (Int) or field expression (String)")
-  }
+  def sum(position: Int): DataStream[T] = aggregate(AggregationType.SUM, 
position)
 
   /**
    * Applies an aggregation that gives the maximum element of the window 
by
    * the given position. In case of equality, returns the first.
    *
    */
-  def maxBy(field: Any): DataStream[T] = field match {
-    case field: Int => return new DataStream[T](javaStream.maxBy(field))
-    case field: String => return new DataStream[T](javaStream.maxBy(field))
-    case _ => throw new IllegalArgumentException("Aggregations are only 
supported by field position (Int) or field expression (String)")
-  }
+  def maxBy(position: Int, first: Boolean = true): DataStream[T] = 
aggregate(AggregationType.MAXBY, position, first)
 
   /**
    * Applies an aggregation that gives the minimum element of the window 
by
    * the given position. In case of equality, returns the first.
    *
    */
-  def minBy(field: Any): DataStream[T] = field match {
-    case field: Int => return new DataStream[T](javaStream.minBy(field))
-    case field: String => return new DataStream[T](javaStream.minBy(field))
-    case _ => throw new IllegalArgumentException("Aggregations are only 
supported by field position (Int) or field expression (String)")
-  }
+  def minBy(position: Int, first: Boolean = true): DataStream[T] = 
aggregate(AggregationType.MINBY, position, first)
 
-  /**
-   * Applies an aggregation that that gives the minimum element of the window 
by
-   * the given position. When equality, the user can set to get the first or 
last element with the minimal value.
-   *
-   */
-  def minBy(field: Any, first: Boolean): DataStream[T] = field match {
-    case field: Int => return new DataStream[T](javaStream.minBy(field, first))
-    case field: String => return new DataStream[T](javaStream.minBy(field, 
first))
-    case _ => throw new IllegalArgumentException("Aggregations are only 
supported by field position (Int) or field expression (String)")
-  }
+  def aggregate(aggregationType: AggregationType, position: Int, first: 
Boolean = true): DataStream[T] = {
 
-  /**
-   * Applies an aggregation that that gives the maximum element of the window 
by
-   * the given position. When equality, the user can set to get the first or 
last element with the maximal value.
-   *
-   */
-  def maxBy(field: Any, first: Boolean): DataStream[T] = field match {
-    case field: Int => return new DataStream[T](javaStream.maxBy(field, first))
-    case field: String => return new DataStream[T](javaStream.maxBy(field, 
first))
-    case _ => throw new IllegalArgumentException("Aggregations are only 
supported by field position (Int) or field expression (String)")
+    val jStream = javaStream.asInstanceOf[JavaWStream[Product]]
+    val outType = jStream.getType().asInstanceOf[TupleTypeInfoBase[_]]
+
+    val agg = new 
ScalaStreamingAggregator[Product](jStream.getType().createSerializer(), 
position)
+
+    val reducer = aggregationType match {
+      case AggregationType.SUM => new 
agg.Sum(SumFunction.getForClass(outType.getTypeAt(position).getTypeClass()));
+      case _ => new agg.ProductComparableAggregator(aggregationType, first)
+    }
+
+    new 
DataStream[Product](jStream.reduce(reducer)).asInstanceOf[DataStream[T]]
   }
 
 }
\ No newline at end of file

Reply via email to