Repository: flink
Updated Branches:
  refs/heads/master 9215b7242 -> 20fe2af8b


[FLINK-3087] [Table API] Support multiple count aggregations in a single select.


Project: http://git-wip-us.apache.org/repos/asf/flink/repo
Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/20fe2af8
Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/20fe2af8
Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/20fe2af8

Branch: refs/heads/master
Commit: 20fe2af8b3037d6284033b7dccd3471c45c78c35
Parents: 9215b72
Author: chengxiang li <chengxiang...@intel.com>
Authored: Tue Dec 1 10:40:47 2015 +0800
Committer: Aljoscha Krettek <aljoscha.kret...@gmail.com>
Committed: Wed Dec 2 10:20:11 2015 +0100

----------------------------------------------------------------------
 .../table/codegen/ExpressionCodeGenerator.scala | 15 ++++++++-----
 .../api/table/plan/ExpandAggregations.scala     |  7 ++++--
 .../api/java/table/test/AggregationsITCase.java | 23 ++++++++++++++++++++
 .../scala/table/test/AggregationsITCase.scala   | 11 ++++++++++
 4 files changed, 48 insertions(+), 8 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/flink/blob/20fe2af8/flink-staging/flink-table/src/main/scala/org/apache/flink/api/table/codegen/ExpressionCodeGenerator.scala
----------------------------------------------------------------------
diff --git a/flink-staging/flink-table/src/main/scala/org/apache/flink/api/table/codegen/ExpressionCodeGenerator.scala b/flink-staging/flink-table/src/main/scala/org/apache/flink/api/table/codegen/ExpressionCodeGenerator.scala
index 42dec0f..10f5859 100644
--- a/flink-staging/flink-table/src/main/scala/org/apache/flink/api/table/codegen/ExpressionCodeGenerator.scala
+++ b/flink-staging/flink-table/src/main/scala/org/apache/flink/api/table/codegen/ExpressionCodeGenerator.scala
@@ -123,14 +123,17 @@ abstract class ExpressionCodeGenerator[R](
       }
     }
 
-    val cleanedExpr = expr match {
-      case expressions.Naming(namedExpr, _) => namedExpr
-      case _ => expr
+    def cleanedExpr(e: Expression): Expression =  {
+      e match {
+        case expressions.Naming(namedExpr, _) => cleanedExpr(namedExpr)
+        case _ => e
+      }
     }
-    
-    val resultTpe = typeTermForTypeInfo(cleanedExpr.typeInfo)
 
-    val code: String = cleanedExpr match {
+    val cleanedExpression = cleanedExpr(expr)
+    val resultTpe = typeTermForTypeInfo(cleanedExpression.typeInfo)
+
+    val code: String = cleanedExpression match {
 
       case expressions.Literal(null, typeInfo) =>
         if (nullCheck) {

http://git-wip-us.apache.org/repos/asf/flink/blob/20fe2af8/flink-staging/flink-table/src/main/scala/org/apache/flink/api/table/plan/ExpandAggregations.scala
----------------------------------------------------------------------
diff --git a/flink-staging/flink-table/src/main/scala/org/apache/flink/api/table/plan/ExpandAggregations.scala b/flink-staging/flink-table/src/main/scala/org/apache/flink/api/table/plan/ExpandAggregations.scala
index 65728c2..2e09f39 100644
--- a/flink-staging/flink-table/src/main/scala/org/apache/flink/api/table/plan/ExpandAggregations.scala
+++ b/flink-staging/flink-table/src/main/scala/org/apache/flink/api/table/plan/ExpandAggregations.scala
@@ -51,20 +51,23 @@ object ExpandAggregations {
       val aggregationIntermediates = mutable.HashMap[Aggregation, Seq[Expression]]()
 
       var intermediateCount = 0
+      var resultCount = 0
       selection foreach {  f =>
         f.transformPre {
           case agg: Aggregation =>
             val intermediateReferences = agg.getIntermediateFields.zip(agg.getAggregations) map {
               case (expr, basicAgg) =>
+                resultCount += 1
+                val resultName = s"result.$resultCount"
                 aggregations.get((expr, basicAgg)) match {
                   case Some(intermediateName) =>
-                    ResolvedFieldReference(intermediateName, expr.typeInfo)
+                    Naming(ResolvedFieldReference(intermediateName, expr.typeInfo), resultName)
                   case None =>
                     intermediateCount = intermediateCount + 1
                     val intermediateName = s"intermediate.$intermediateCount"
                     intermediateFields += Naming(expr, intermediateName)
                     aggregations((expr, basicAgg)) = intermediateName
-                    ResolvedFieldReference(intermediateName, expr.typeInfo)
+                    Naming(ResolvedFieldReference(intermediateName, expr.typeInfo), resultName)
                 }
             }
 

http://git-wip-us.apache.org/repos/asf/flink/blob/20fe2af8/flink-staging/flink-table/src/test/java/org/apache/flink/api/java/table/test/AggregationsITCase.java
----------------------------------------------------------------------
diff --git a/flink-staging/flink-table/src/test/java/org/apache/flink/api/java/table/test/AggregationsITCase.java b/flink-staging/flink-table/src/test/java/org/apache/flink/api/java/table/test/AggregationsITCase.java
index 3e0147c..bdebfb1 100644
--- a/flink-staging/flink-table/src/test/java/org/apache/flink/api/java/table/test/AggregationsITCase.java
+++ b/flink-staging/flink-table/src/test/java/org/apache/flink/api/java/table/test/AggregationsITCase.java
@@ -137,6 +137,29 @@ public class AggregationsITCase extends MultipleProgramsTestBase {
                compareResultAsText(results, expected);
        }
 
+       @Test
+       public void testAggregationWithTwoCount() throws Exception {
+               ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
+               TableEnvironment tableEnv = new TableEnvironment();
+
+               DataSource<Tuple2<Float, String>> input =
+                       env.fromElements(
+                               new Tuple2<>(1f, "Hello"),
+                               new Tuple2<>(2f, "Ciao"));
+
+               Table table =
+                       tableEnv.fromDataSet(input);
+
+               Table result =
+                       table.select("f0.count, f1.count");
+
+
+               DataSet<Row> ds = tableEnv.toDataSet(result, Row.class);
+               List<Row> results = ds.collect();
+               String expected = "2,2";
+               compareResultAsText(results, expected);
+       }
+
        @Test(expected = ExpressionException.class)
        public void testNonWorkingDataTypes() throws Exception {
                ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();

http://git-wip-us.apache.org/repos/asf/flink/blob/20fe2af8/flink-staging/flink-table/src/test/scala/org/apache/flink/api/scala/table/test/AggregationsITCase.scala
----------------------------------------------------------------------
diff --git a/flink-staging/flink-table/src/test/scala/org/apache/flink/api/scala/table/test/AggregationsITCase.scala b/flink-staging/flink-table/src/test/scala/org/apache/flink/api/scala/table/test/AggregationsITCase.scala
index 08ad1f4..ee5d9e8 100644
--- a/flink-staging/flink-table/src/test/scala/org/apache/flink/api/scala/table/test/AggregationsITCase.scala
+++ b/flink-staging/flink-table/src/test/scala/org/apache/flink/api/scala/table/test/AggregationsITCase.scala
@@ -80,6 +80,17 @@ class AggregationsITCase(mode: TestExecutionMode) extends MultipleProgramsTestBa
     TestBaseUtils.compareResultAsText(results.asJava, expected)
   }
 
+  @Test
+  def testAggregationWithTwoCount(): Unit = {
+
+    val env = ExecutionEnvironment.getExecutionEnvironment
+    val ds = env.fromElements((1f, "Hello"), (2f, "Ciao")).toTable
+      .select('_1.count, '_2.count).toDataSet[Row]
+    val expected = "2,2"
+    val results = ds.collect()
+    TestBaseUtils.compareResultAsText(results.asJava, expected)
+  }
+
   @Test(expected = classOf[ExpressionException])
   def testNonWorkingAggregationDataTypes(): Unit = {
 

Reply via email to