[scala] [streaming] Added groupBy support for case class fields

Project: http://git-wip-us.apache.org/repos/asf/flink/repo
Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/14134842
Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/14134842
Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/14134842

Branch: refs/heads/release-0.8
Commit: 141348426eeae2ea0e4404a0d8cff8182c601886
Parents: 1f7b6ea
Author: Gyula Fora <[email protected]>
Authored: Sun Dec 21 03:54:30 2014 +0100
Committer: mbalassi <[email protected]>
Committed: Mon Jan 5 17:59:43 2015 +0100

----------------------------------------------------------------------
 .../aggregation/ComparableAggregator.java       |  8 +---
 .../scala/streaming/CaseClassKeySelector.scala  | 45 ++++++++++++++++++++
 .../flink/api/scala/streaming/DataStream.scala  | 19 +++++++--
 .../api/scala/streaming/FieldsKeySelector.scala | 23 +++-------
 .../scala/streaming/StreamJoinOperator.scala    | 19 ++++++---
 .../scala/streaming/WindowedDataStream.scala    | 11 ++++-
 .../api/scala/typeutils/CaseClassTypeInfo.scala |  1 +
 7 files changed, 92 insertions(+), 34 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/flink/blob/14134842/flink-addons/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/function/aggregation/ComparableAggregator.java
----------------------------------------------------------------------
diff --git a/flink-addons/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/function/aggregation/ComparableAggregator.java b/flink-addons/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/function/aggregation/ComparableAggregator.java
index 5fb8f62..7ea7ba1 100644
--- a/flink-addons/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/function/aggregation/ComparableAggregator.java
+++ b/flink-addons/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/function/aggregation/ComparableAggregator.java
@@ -143,8 +143,6 @@ public abstract class ComparableAggregator<T> extends AggregationFunction<T> {
                        } else {
                                if (c == 1) {
                                        Array.set(array2, position, v1);
-                               } else {
-                                       Array.set(array2, position, v2);
                                }
 
                                return array2;
@@ -230,10 +228,8 @@ public abstract class ComparableAggregator<T> extends AggregationFunction<T> {
                        } else {
                                if (c == 1) {
                                        keyFields[0].set(value2, field1);
-                               } else {
-                                       keyFields[0].set(value2, field2);
-                               }
-
+                               } 
+                               
                                return value2;
                        }
                }

http://git-wip-us.apache.org/repos/asf/flink/blob/14134842/flink-scala/src/main/scala/org/apache/flink/api/scala/streaming/CaseClassKeySelector.scala
----------------------------------------------------------------------
diff --git a/flink-scala/src/main/scala/org/apache/flink/api/scala/streaming/CaseClassKeySelector.scala b/flink-scala/src/main/scala/org/apache/flink/api/scala/streaming/CaseClassKeySelector.scala
new file mode 100644
index 0000000..63410a9
--- /dev/null
+++ b/flink-scala/src/main/scala/org/apache/flink/api/scala/streaming/CaseClassKeySelector.scala
@@ -0,0 +1,45 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.flink.api.scala.streaming
+
+import org.apache.flink.api.scala.typeutils.CaseClassTypeInfo
+import java.util.ArrayList
+import org.apache.flink.api.common.typeutils.CompositeType.FlatFieldDescriptor
+import org.apache.flink.api.java.functions.KeySelector
+
+class CaseClassKeySelector[T <: Product](@transient val typeInfo: CaseClassTypeInfo[T],
+  val keyFields: String*) extends KeySelector[T, Seq[Any]] {
+
+  val numOfKeys: Int = keyFields.length;
+
+  @transient val fieldDescriptors = new ArrayList[FlatFieldDescriptor]();
+  for (field <- keyFields) {
+    typeInfo.getKey(field, 0, fieldDescriptors);
+  }
+
+  val logicalKeyPositions = new Array[Int](numOfKeys)
+  val orders = new Array[Boolean](numOfKeys)
+
+  for (i <- 0 to numOfKeys - 1) {
+    logicalKeyPositions(i) = fieldDescriptors.get(i).getPosition();
+  }
+
+  def getKey(value: T): Seq[Any] = {
+    for (i <- 0 to numOfKeys - 1) yield value.productElement(logicalKeyPositions(i))
+  }
+}

http://git-wip-us.apache.org/repos/asf/flink/blob/14134842/flink-scala/src/main/scala/org/apache/flink/api/scala/streaming/DataStream.scala
----------------------------------------------------------------------
diff --git a/flink-scala/src/main/scala/org/apache/flink/api/scala/streaming/DataStream.scala b/flink-scala/src/main/scala/org/apache/flink/api/scala/streaming/DataStream.scala
index 0cf4a60..6df4b25 100644
--- a/flink-scala/src/main/scala/org/apache/flink/api/scala/streaming/DataStream.scala
+++ b/flink-scala/src/main/scala/org/apache/flink/api/scala/streaming/DataStream.scala
@@ -49,6 +49,7 @@ import org.apache.flink.api.java.typeutils.TupleTypeInfoBase
 import org.apache.flink.streaming.api.function.aggregation.AggregationFunction
 import org.apache.flink.streaming.api.function.aggregation.AggregationFunction.AggregationType
 import com.amazonaws.services.cloudfront_2012_03_15.model.InvalidArgumentException
+import org.apache.flink.api.scala.typeutils.CaseClassTypeInfo
 
 class DataStream[T](javaStream: JavaStream[T]) {
 
@@ -122,9 +123,14 @@ class DataStream[T](javaStream: JavaStream[T]) {
    * be used with grouped operators like grouped reduce or grouped aggregations
    *
    */
-  def groupBy(firstField: String, otherFields: String*): DataStream[T] =
-    new DataStream[T](javaStream.groupBy(firstField +: otherFields.toArray: _*))
-
+  def groupBy(firstField: String, otherFields: String*): DataStream[T] = 
+    javaStream.getType() match {
+      case ccInfo: CaseClassTypeInfo[T] => new DataStream[T](javaStream.groupBy(
+          new CaseClassKeySelector[T](ccInfo, firstField +: otherFields.toArray: _*)))
+      case _ =>  new DataStream[T](javaStream.groupBy(
+          firstField +: otherFields.toArray: _*))    
+    }
+  
   /**
    * Groups the elements of a DataStream by the given K key to
    * be used with grouped operators like grouped reduce or grouped aggregations
@@ -155,7 +161,12 @@ class DataStream[T](javaStream: JavaStream[T]) {
    *
    */
   def partitionBy(firstField: String, otherFields: String*): DataStream[T] =
-    new DataStream[T](javaStream.partitionBy(firstField +: otherFields.toArray: _*))
+    javaStream.getType() match {
+      case ccInfo: CaseClassTypeInfo[T] => new DataStream[T](javaStream.partitionBy(
+          new CaseClassKeySelector[T](ccInfo, firstField +: otherFields.toArray: _*)))
+      case _ =>  new DataStream[T](javaStream.partitionBy(
+          firstField +: otherFields.toArray: _*))    
+    }
 
   /**
    * Sets the partitioning of the DataStream so that the output is

http://git-wip-us.apache.org/repos/asf/flink/blob/14134842/flink-scala/src/main/scala/org/apache/flink/api/scala/streaming/FieldsKeySelector.scala
----------------------------------------------------------------------
diff --git a/flink-scala/src/main/scala/org/apache/flink/api/scala/streaming/FieldsKeySelector.scala b/flink-scala/src/main/scala/org/apache/flink/api/scala/streaming/FieldsKeySelector.scala
index b50d346..bc79fca 100644
--- a/flink-scala/src/main/scala/org/apache/flink/api/scala/streaming/FieldsKeySelector.scala
+++ b/flink-scala/src/main/scala/org/apache/flink/api/scala/streaming/FieldsKeySelector.scala
@@ -22,25 +22,16 @@ import org.apache.flink.streaming.util.keys.{ FieldsKeySelector => JavaSelector
 import org.apache.flink.api.java.functions.KeySelector
 import org.apache.flink.api.java.tuple.Tuple
 
-class FieldsKeySelector[IN](fields: Int*) extends KeySelector[IN, Tuple] {
+class FieldsKeySelector[IN](fields: Int*) extends KeySelector[IN, Seq[Any]] {
 
-  val t: Tuple = JavaSelector.tupleClasses(fields.length - 1).newInstance()
-
-  override def getKey(value: IN): Tuple =
+  override def getKey(value: IN): Seq[Any] =
 
     value match {
-      case prod: Product => {
-        for (i <- 0 to fields.length - 1) {
-          t.setField(prod.productElement(fields(i)), i)
-        }
-        t
-      }
-      case tuple: Tuple => {
-        for (i <- 0 to fields.length - 1) {
-          t.setField(tuple.getField(fields(i)), i)
-        }
-        t
-      }
+      case prod: Product => 
+        for (i <- 0 to fields.length - 1) yield prod.productElement(fields(i))
+      case tuple: Tuple => 
+        for (i <- 0 to fields.length - 1) yield tuple.getField(fields(i))
+      
       case _ => throw new RuntimeException("Only tuple types are supported")
     }
 

http://git-wip-us.apache.org/repos/asf/flink/blob/14134842/flink-scala/src/main/scala/org/apache/flink/api/scala/streaming/StreamJoinOperator.scala
----------------------------------------------------------------------
diff --git a/flink-scala/src/main/scala/org/apache/flink/api/scala/streaming/StreamJoinOperator.scala b/flink-scala/src/main/scala/org/apache/flink/api/scala/streaming/StreamJoinOperator.scala
index 7a39da5..4095645 100644
--- a/flink-scala/src/main/scala/org/apache/flink/api/scala/streaming/StreamJoinOperator.scala
+++ b/flink-scala/src/main/scala/org/apache/flink/api/scala/streaming/StreamJoinOperator.scala
@@ -61,9 +61,12 @@ object StreamJoinOperator {
      * The resulting incomplete join can be completed by JoinPredicate.equalTo()
      * to define the second key.
      */
-    def where(firstField: String, otherFields: String*) = {
-      new JoinPredicate[I1, I2](op, new PojoKeySelector[I1](op.input1.getType(),
-        (firstField +: otherFields): _*))
+    def where(firstField: String, otherFields: String*) = 
+      op.input1.getType() match {
+      case ccInfo: CaseClassTypeInfo[I1] => new JoinPredicate[I1, I2](op,
+          new CaseClassKeySelector[I1](ccInfo, firstField +: otherFields.toArray: _*))
+      case _ =>  new JoinPredicate[I1, I2](op, new PojoKeySelector[I1](
+          op.input1.getType(), (firstField +: otherFields): _*))  
     }
 
     /**
@@ -104,9 +107,13 @@ object StreamJoinOperator {
      * (first, second)
      * To define a custom wrapping, use JoinedStream.apply(...)
      */
-    def equalTo(firstField: String, otherFields: String*): JoinedStream[I1, I2] = {
-      finish(new PojoKeySelector[I2](op.input2.getType(), (firstField +: otherFields): _*))
-    }
+    def equalTo(firstField: String, otherFields: String*): JoinedStream[I1, I2] = 
+      op.input2.getType() match {
+      case ccInfo: CaseClassTypeInfo[I2] => finish(
+          new CaseClassKeySelector[I2](ccInfo, firstField +: otherFields.toArray: _*))
+      case _ => finish(new PojoKeySelector[I2](op.input2.getType(), 
+          (firstField +: otherFields): _*))
+    }    
 
     /**
      * Creates a temporal join transformation by defining the second join key.

http://git-wip-us.apache.org/repos/asf/flink/blob/14134842/flink-scala/src/main/scala/org/apache/flink/api/scala/streaming/WindowedDataStream.scala
----------------------------------------------------------------------
diff --git a/flink-scala/src/main/scala/org/apache/flink/api/scala/streaming/WindowedDataStream.scala b/flink-scala/src/main/scala/org/apache/flink/api/scala/streaming/WindowedDataStream.scala
index 8c763fc..11f042d 100644
--- a/flink-scala/src/main/scala/org/apache/flink/api/scala/streaming/WindowedDataStream.scala
+++ b/flink-scala/src/main/scala/org/apache/flink/api/scala/streaming/WindowedDataStream.scala
@@ -39,6 +39,7 @@ import scala.collection.JavaConversions._
 import org.apache.flink.streaming.api.function.aggregation.AggregationFunction.AggregationType
 import org.apache.flink.api.java.typeutils.TupleTypeInfoBase
 import org.apache.flink.streaming.api.function.aggregation.SumFunction
+import org.apache.flink.api.scala.typeutils.CaseClassTypeInfo
 
 class WindowedDataStream[T](javaStream: JavaWStream[T]) {
 
@@ -77,8 +78,14 @@ class WindowedDataStream[T](javaStream: JavaWStream[T]) {
    *
    */
   def groupBy(firstField: String, otherFields: String*): WindowedDataStream[T] =
-    new WindowedDataStream[T](javaStream.groupBy(firstField +: otherFields.toArray: _*))
-
+    javaStream.getType() match {
+      case ccInfo: CaseClassTypeInfo[T] => new WindowedDataStream[T](javaStream.groupBy(
+          new CaseClassKeySelector[T](ccInfo, firstField +: otherFields.toArray: _*)))
+      case _ =>  new WindowedDataStream[T](javaStream.groupBy(
+          firstField +: otherFields.toArray: _*))    
+    }
+    
+    
   /**
    * Groups the elements of the WindowedDataStream using the given
    * KeySelector function. The window sizes (evictions) and slide sizes

http://git-wip-us.apache.org/repos/asf/flink/blob/14134842/flink-scala/src/main/scala/org/apache/flink/api/scala/typeutils/CaseClassTypeInfo.scala
----------------------------------------------------------------------
diff --git a/flink-scala/src/main/scala/org/apache/flink/api/scala/typeutils/CaseClassTypeInfo.scala b/flink-scala/src/main/scala/org/apache/flink/api/scala/typeutils/CaseClassTypeInfo.scala
index 53d1dea..e0d1155 100644
--- a/flink-scala/src/main/scala/org/apache/flink/api/scala/typeutils/CaseClassTypeInfo.scala
+++ b/flink-scala/src/main/scala/org/apache/flink/api/scala/typeutils/CaseClassTypeInfo.scala
@@ -15,6 +15,7 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
+
 package org.apache.flink.api.scala.typeutils
 
 import org.apache.flink.api.common.typeinfo.TypeInformation

Reply via email to