Repository: flink
Updated Branches:
  refs/heads/master e13a7f80e -> 1cc1bb41e


[FLINK-6834] [table] Support scalar functions on Over Window

This closes #4070.


Project: http://git-wip-us.apache.org/repos/asf/flink/repo
Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/1cc1bb41
Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/1cc1bb41
Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/1cc1bb41

Branch: refs/heads/master
Commit: 1cc1bb41e94b585200d9f7179bdeaa0bec0dcc5d
Parents: 8ae4f2b
Author: Jark Wu <wuchong...@alibaba-inc.com>
Authored: Sat Jun 3 23:31:00 2017 +0800
Committer: zentol <ches...@apache.org>
Committed: Wed Jun 7 23:06:08 2017 +0200

----------------------------------------------------------------------
 .../flink/table/api/scala/expressionDsl.scala   |  2 +-
 .../flink/table/plan/ProjectionTranslator.scala | 81 +++++++++++++++-----
 .../scala/stream/table/OverWindowITCase.scala   | 33 ++++----
 .../api/scala/stream/table/OverWindowTest.scala | 44 +++++++++++
 .../OverWindowStringExpressionTest.scala        | 35 +++++++++
 5 files changed, 160 insertions(+), 35 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/flink/blob/1cc1bb41/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/scala/expressionDsl.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/scala/expressionDsl.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/scala/expressionDsl.scala
index b87bb6d..7b424b2 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/scala/expressionDsl.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/scala/expressionDsl.scala
@@ -456,7 +456,7 @@ trait ImplicitExpressionOperations {
     *   .window(Over partitionBy 'c orderBy 'rowtime preceding 2.rows following CURRENT_ROW as 'w)
     *   .select('c, 'a, 'a.count over 'w, 'a.sum over 'w)
     */
-  def over(alias: Expression) = {
+  def over(alias: Expression): Expression = {
     expr match {
       case _: Aggregation => UnresolvedOverCall(
         expr.asInstanceOf[Aggregation],

http://git-wip-us.apache.org/repos/asf/flink/blob/1cc1bb41/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/ProjectionTranslator.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/ProjectionTranslator.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/ProjectionTranslator.scala
index 69b437a..b3799d1 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/ProjectionTranslator.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/ProjectionTranslator.scala
@@ -226,30 +226,69 @@ object ProjectionTranslator {
       overWindows: Array[OverWindow],
       tEnv: TableEnvironment): Seq[Expression] = {
 
-    def resolveOverWindow(unresolvedCall: UnresolvedOverCall): Expression = {
-
-      val overWindow = overWindows.find(_.alias.equals(unresolvedCall.alias))
-      if (overWindow.isDefined) {
-        OverCall(
-          unresolvedCall.agg,
-          overWindow.get.partitionBy,
-          overWindow.get.orderBy,
-          overWindow.get.preceding,
-          overWindow.get.following)
-      } else {
-        unresolvedCall
-      }
-    }
+    exprs.map(e => replaceOverCall(e, overWindows, tEnv))
+  }
 
-    val projectList = new ListBuffer[Expression]
-    exprs.foreach {
-      case Alias(u: UnresolvedOverCall, name, _) =>
-        projectList += Alias(resolveOverWindow(u), name)
+  /**
+    * Find and replace UnresolvedOverCall with OverCall
+    *
+    * @param expr    the expression to check
+    * @return an expression with correct resolved OverCall
+    */
+  private def replaceOverCall(
+    expr: Expression,
+    overWindows: Array[OverWindow],
+    tableEnv: TableEnvironment): Expression = {
+
+    expr match {
       case u: UnresolvedOverCall =>
-        projectList += resolveOverWindow(u)
-      case e: Expression => projectList += e
+        val overWindow = overWindows.find(_.alias.equals(u.alias))
+        if (overWindow.isDefined) {
+          OverCall(
+            u.agg,
+            overWindow.get.partitionBy,
+            overWindow.get.orderBy,
+            overWindow.get.preceding,
+            overWindow.get.following)
+        } else {
+          u
+        }
+
+      case u: UnaryExpression =>
+        val c = replaceOverCall(u.child, overWindows, tableEnv)
+        u.makeCopy(Array(c))
+
+      case b: BinaryExpression =>
+        val l = replaceOverCall(b.left, overWindows, tableEnv)
+        val r = replaceOverCall(b.right, overWindows, tableEnv)
+        b.makeCopy(Array(l, r))
+
+      // Function calls
+      case c @ Call(name, args: Seq[Expression]) =>
+        val newArgs =
+          args.map(
+            (exp: Expression) =>
+              replaceOverCall(exp, overWindows, tableEnv))
+        c.makeCopy(Array(name, newArgs))
+
+      // Scalar functions
+      case sfc @ ScalarFunctionCall(clazz, args: Seq[Expression]) =>
+        val newArgs: Seq[Expression] =
+          args.map(
+            (exp: Expression) =>
+              replaceOverCall(exp, overWindows, tableEnv))
+        sfc.makeCopy(Array(clazz, newArgs))
+
+      // Array constructor
+      case c @ ArrayConstructor(args) =>
+        val newArgs =
+          c.elements
+            .map((exp: Expression) => replaceOverCall(exp, overWindows, tableEnv))
+        c.makeCopy(Array(newArgs))
+
+      // Other expressions
+      case e: Expression => e
     }
-    projectList
   }
 
 

http://git-wip-us.apache.org/repos/asf/flink/blob/1cc1bb41/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/table/OverWindowITCase.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/table/OverWindowITCase.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/table/OverWindowITCase.scala
index dc7d5dc..133328e 100644
--- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/table/OverWindowITCase.scala
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/table/OverWindowITCase.scala
@@ -26,6 +26,7 @@ import org.apache.flink.streaming.api.scala.StreamExecutionEnvironment
 import org.apache.flink.streaming.api.watermark.Watermark
 import org.apache.flink.table.api.TableEnvironment
 import org.apache.flink.table.api.java.utils.UserDefinedAggFunctions.WeightedAvg
+import org.apache.flink.table.api.java.utils.UserDefinedScalarFunctions.JavaFunc0
 import org.apache.flink.table.api.scala._
 import org.apache.flink.table.api.scala.stream.table.OverWindowITCase.RowTimeSourceFunction
 import org.apache.flink.table.api.scala.stream.utils.{StreamITCase, StreamingWithStateTestBase}
@@ -110,6 +111,7 @@ class OverWindowITCase extends StreamingWithStateTestBase {
       .toTable(tEnv, 'a, 'b, 'c, 'rowtime.rowtime)
     val countFun = new CountAggFunction
     val weightAvgFun = new WeightedAvg
+    val plusOne = new JavaFunc0
 
     val windowedTable = table
       .window(Over partitionBy 'a orderBy 'rowtime preceding UNBOUNDED_RANGE following
@@ -117,10 +119,15 @@ class OverWindowITCase extends StreamingWithStateTestBase {
       .select(
         'a, 'b, 'c,
         'b.sum over 'w,
+        "SUM:".toExpr + ('b.sum over 'w),
         countFun('b) over 'w,
+        (countFun('b) over 'w) + 1,
+        plusOne(countFun('b) over 'w),
+        array('b.avg over 'w, 'b.max over 'w),
         'b.avg over 'w,
         'b.max over 'w,
         'b.min over 'w,
+        ('b.min over 'w).abs(),
         weightAvgFun('b, 'a) over 'w)
 
     val result = windowedTable.toAppendStream[Row]
@@ -128,19 +135,19 @@ class OverWindowITCase extends StreamingWithStateTestBase {
     env.execute()
 
     val expected = mutable.MutableList(
-      "1,1,Hello,6,3,2,3,1,2",
-      "1,2,Hello,6,3,2,3,1,2",
-      "1,3,Hello world,6,3,2,3,1,2",
-      "1,1,Hi,7,4,1,3,1,1",
-      "2,1,Hello,1,1,1,1,1,1",
-      "2,2,Hello world,6,3,2,3,1,2",
-      "2,3,Hello world,6,3,2,3,1,2",
-      "1,4,Hello world,11,5,2,4,1,2",
-      "1,5,Hello world,29,8,3,7,1,3",
-      "1,6,Hello world,29,8,3,7,1,3",
-      "1,7,Hello world,29,8,3,7,1,3",
-      "2,4,Hello world,15,5,3,5,1,3",
-      "2,5,Hello world,15,5,3,5,1,3"
+      "1,1,Hello,6,SUM:6,3,4,4,[2, 3],2,3,1,1,2",
+      "1,2,Hello,6,SUM:6,3,4,4,[2, 3],2,3,1,1,2",
+      "1,3,Hello world,6,SUM:6,3,4,4,[2, 3],2,3,1,1,2",
+      "1,1,Hi,7,SUM:7,4,5,5,[1, 3],1,3,1,1,1",
+      "2,1,Hello,1,SUM:1,1,2,2,[1, 1],1,1,1,1,1",
+      "2,2,Hello world,6,SUM:6,3,4,4,[2, 3],2,3,1,1,2",
+      "2,3,Hello world,6,SUM:6,3,4,4,[2, 3],2,3,1,1,2",
+      "1,4,Hello world,11,SUM:11,5,6,6,[2, 4],2,4,1,1,2",
+      "1,5,Hello world,29,SUM:29,8,9,9,[3, 7],3,7,1,1,3",
+      "1,6,Hello world,29,SUM:29,8,9,9,[3, 7],3,7,1,1,3",
+      "1,7,Hello world,29,SUM:29,8,9,9,[3, 7],3,7,1,1,3",
+      "2,4,Hello world,15,SUM:15,5,6,6,[3, 5],3,5,1,1,3",
+      "2,5,Hello world,15,SUM:15,5,6,6,[3, 5],3,5,1,1,3"
     )
 
     assertEquals(expected.sorted, StreamITCase.testResults.sorted)

http://git-wip-us.apache.org/repos/asf/flink/blob/1cc1bb41/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/table/OverWindowTest.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/table/OverWindowTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/table/OverWindowTest.scala
index 96e5eb5..49a210c 100644
--- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/table/OverWindowTest.scala
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/table/OverWindowTest.scala
@@ -21,6 +21,7 @@ import org.apache.flink.api.scala._
 import org.apache.flink.table.api.java.utils.UserDefinedAggFunctions.WeightedAvgWithRetract
 import org.apache.flink.table.api.{Table, ValidationException}
 import org.apache.flink.table.api.scala._
+import org.apache.flink.table.expressions.utils.Func1
 import org.apache.flink.table.utils.TableTestUtil._
 import org.apache.flink.table.utils.{StreamTableTestUtil, TableTestBase}
 import org.junit.Test
@@ -106,6 +107,49 @@ class OverWindowTest extends TableTestBase {
     streamUtil.tEnv.optimize(result.getRelNode, updatesAsRetraction = true)
   }
 
+
+  @Test
+  def testScalarFunctionsOnOverWindow() = {
+    val weightedAvg = new WeightedAvgWithRetract
+    val plusOne = Func1
+
+    val result = table
+      .window(Over partitionBy 'b orderBy 'proctime preceding UNBOUNDED_ROW as 'w)
+      .select(
+        plusOne('a.sum over 'w as 'wsum) as 'd,
+        ('a.count over 'w).exp(),
+        (weightedAvg('c, 'a) over 'w) + 1,
+        "AVG:".toExpr + (weightedAvg('c, 'a) over 'w),
+        array(weightedAvg('c, 'a) over 'w, 'a.count over 'w))
+
+    val expected =
+      unaryNode(
+        "DataStreamCalc",
+        unaryNode(
+          "DataStreamOverAggregate",
+          unaryNode(
+            "DataStreamCalc",
+            streamTableNode(0),
+            term("select", "a", "b", "c", "proctime")
+          ),
+          term("partitionBy", "b"),
+          term("orderBy", "proctime"),
+          term("rows", "BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW"),
+          term("select", "a", "b", "c", "proctime",
+               "SUM(a) AS w0$o0",
+               "COUNT(a) AS w0$o1",
+               "WeightedAvgWithRetract(c, a) AS w0$o2")
+        ),
+        term("select",
+             s"${plusOne.functionIdentifier}(w0$$o0) AS d",
+             "EXP(CAST(w0$o1)) AS _c1",
+             "+(w0$o2, 1) AS _c2",
+             "||('AVG:', CAST(w0$o2)) AS _c3",
+             "ARRAY(w0$o2, w0$o1) AS _c4")
+      )
+    streamUtil.verifyTable(result, expected)
+  }
+
   @Test
   def testProcTimeBoundedPartitionedRowsOver() = {
     val weightedAvg = new WeightedAvgWithRetract

http://git-wip-us.apache.org/repos/asf/flink/blob/1cc1bb41/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/table/stringexpr/OverWindowStringExpressionTest.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/table/stringexpr/OverWindowStringExpressionTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/table/stringexpr/OverWindowStringExpressionTest.scala
index 04016f1..4c95916 100644
--- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/table/stringexpr/OverWindowStringExpressionTest.scala
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/table/stringexpr/OverWindowStringExpressionTest.scala
@@ -19,8 +19,10 @@
 package org.apache.flink.table.api.scala.stream.table.stringexpr
 
 import org.apache.flink.api.scala._
+import org.apache.flink.table.api.java.utils.UserDefinedAggFunctions.WeightedAvgWithRetract
 import org.apache.flink.table.api.java.{Over => JOver}
 import org.apache.flink.table.api.scala.{Over => SOver, _}
+import org.apache.flink.table.expressions.utils.Func1
 import org.apache.flink.table.utils.TableTestBase
 import org.junit.Test
 
@@ -147,5 +149,38 @@ class OverWindowStringExpressionTest extends TableTestBase {
     verifyTableEquals(resScala, resJava)
   }
 
+  @Test
+  def testScalarFunctionsOnOverWindow(): Unit = {
+    val util = streamTestUtil()
+    val t = util.addTable[(Long, Int, String, Int, Long)]('a, 'b, 'c, 'd, 'e, 'rowtime.rowtime)
+
+    val weightedAvg = new WeightedAvgWithRetract
+    val plusOne = Func1
+    util.addFunction("plusOne", plusOne)
+    util.addFunction("weightedAvg", weightedAvg)
+
+    val resScala = t
+      .window(SOver partitionBy 'a orderBy 'rowtime preceding UNBOUNDED_ROW as 'w)
+      .select(
+        array('a.sum over 'w, 'a.count over 'w),
+        plusOne('b.sum over 'w as 'wsum) as 'd,
+        ('a.count over 'w).exp(),
+        (weightedAvg('a, 'b) over 'w) + 1,
+        "AVG:".toExpr + (weightedAvg('a, 'b) over 'w))
+
+    val resJava = t
+      .window(JOver.partitionBy("a").orderBy("rowtime").preceding("unbounded_row").as("w"))
+      .select(
+        s"""
+           |ARRAY(SUM(a) OVER w, COUNT(a) OVER w),
+           |plusOne(SUM(b) OVER w AS wsum) AS d,
+           |EXP(COUNT(a) OVER w),
+           |(weightedAvg(a, b) OVER w) + 1,
+           |'AVG:' + (weightedAvg(a, b) OVER w)
+         """.stripMargin)
+
+    verifyTableEquals(resScala, resJava)
+  }
+
 
 }

Reply via email to