Repository: flink
Updated Branches:
  refs/heads/master 37df826e4 -> 16b088218


[FLINK-6226] [table] Add tests for UDFs with Byte, Short, and Float arguments.


Project: http://git-wip-us.apache.org/repos/asf/flink/repo
Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/6e118d1d
Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/6e118d1d
Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/6e118d1d

Branch: refs/heads/master
Commit: 6e118d1dc97b3a8c0b013d2002fad80219751253
Parents: 37df826
Author: Fabian Hueske <[email protected]>
Authored: Thu Nov 2 21:10:03 2017 +0100
Committer: Fabian Hueske <[email protected]>
Committed: Thu Nov 2 23:10:09 2017 +0100

----------------------------------------------------------------------
 .../UserDefinedScalarFunctionTest.scala         | 28 ++++++++++++++++++--
 .../utils/userDefinedScalarFunctions.scala      |  6 +++++
 .../runtime/batch/table/CorrelateITCase.scala   | 23 +++++++++++++++-
 .../table/utils/UserDefinedTableFunctions.scala | 12 ++++++++-
 4 files changed, 65 insertions(+), 4 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/flink/blob/6e118d1d/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/expressions/UserDefinedScalarFunctionTest.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/expressions/UserDefinedScalarFunctionTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/expressions/UserDefinedScalarFunctionTest.scala
index 71ff70d..a3b2f07 100644
--- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/expressions/UserDefinedScalarFunctionTest.scala
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/expressions/UserDefinedScalarFunctionTest.scala
@@ -48,6 +48,24 @@ class UserDefinedScalarFunctionTest extends ExpressionTestBase {
       "43")
 
     testAllApis(
+      Func1('f11),
+      "Func1(f11)",
+      "Func1(f11)",
+      "4")
+
+    testAllApis(
+      Func1('f12),
+      "Func1(f12)",
+      "Func1(f12)",
+      "4")
+
+    testAllApis(
+      Func1('f13),
+      "Func1(f13)",
+      "Func1(f13)",
+      "4.0")
+
+    testAllApis(
       Func2('f0, 'f1, 'f3),
       "Func2(f0, f1, f3)",
       "Func2(f0, f1, f3)",
@@ -360,7 +378,7 @@ class UserDefinedScalarFunctionTest extends ExpressionTestBase {
   // 
----------------------------------------------------------------------------------------------
 
   override def testData: Any = {
-    val testData = new Row(11)
+    val testData = new Row(14)
     testData.setField(0, 42)
     testData.setField(1, "Test")
     testData.setField(2, null)
@@ -372,6 +390,9 @@ class UserDefinedScalarFunctionTest extends ExpressionTestBase {
     testData.setField(8, 1000L)
     testData.setField(9, Seq("Hello", "World"))
     testData.setField(10, Array[Integer](1, 2, null))
+    testData.setField(11, 3.toByte)
+    testData.setField(12, 3.toShort)
+    testData.setField(13, 3.toFloat)
     testData
   }
 
@@ -387,7 +408,10 @@ class UserDefinedScalarFunctionTest extends ExpressionTestBase {
       Types.INTERVAL_MONTHS,
       Types.INTERVAL_MILLIS,
       TypeInformation.of(classOf[Seq[String]]),
-      BasicArrayTypeInfo.INT_ARRAY_TYPE_INFO
+      BasicArrayTypeInfo.INT_ARRAY_TYPE_INFO,
+      Types.BYTE,
+      Types.SHORT,
+      Types.FLOAT
     ).asInstanceOf[TypeInformation[Any]]
   }
 

http://git-wip-us.apache.org/repos/asf/flink/blob/6e118d1d/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/expressions/utils/userDefinedScalarFunctions.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/expressions/utils/userDefinedScalarFunctions.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/expressions/utils/userDefinedScalarFunctions.scala
index 5285569..9535cdf 100644
--- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/expressions/utils/userDefinedScalarFunctions.scala
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/expressions/utils/userDefinedScalarFunctions.scala
@@ -41,6 +41,12 @@ object Func1 extends ScalarFunction {
   def eval(index: Integer): Integer = {
     index + 1
   }
+
+  def eval(b: Byte): Byte = (1 + b).toByte // narrowing: Byte.MaxValue wraps to MinValue
+
+  def eval(s: Short): Short = (1 + s).toShort // narrowing: Short.MaxValue wraps to MinValue
+
+  def eval(f: Float): Float = 1 + f // Int literal widens to Float, so no cast is needed
 }
 
 object Func2 extends ScalarFunction {

http://git-wip-us.apache.org/repos/asf/flink/blob/6e118d1d/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/batch/table/CorrelateITCase.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/batch/table/CorrelateITCase.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/batch/table/CorrelateITCase.scala
index b109752..79243dd 100644
--- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/batch/table/CorrelateITCase.scala
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/batch/table/CorrelateITCase.scala
@@ -22,7 +22,7 @@ import java.sql.{Date, Timestamp}
 
 import org.apache.flink.api.scala._
 import org.apache.flink.api.scala.util.CollectionDataSets
-import org.apache.flink.table.api.{TableEnvironment, TableException, 
ValidationException}
+import org.apache.flink.table.api.{TableEnvironment, TableException, Types, 
ValidationException}
 import 
org.apache.flink.table.runtime.utils.JavaUserDefinedTableFunctions.JavaTableFunc0
 import org.apache.flink.table.api.scala._
 import org.apache.flink.table.expressions.utils.{Func1, Func13, Func18, 
RichFunc2}
@@ -231,6 +231,27 @@ class CorrelateITCase(
   }
 
   @Test
+  def testByteShortFloatArguments(): Unit = { // Byte/Short/Float args must reach a UDTF eval
+    val env = ExecutionEnvironment.getExecutionEnvironment
+    val tableEnv = TableEnvironment.getTableEnvironment(env, config)
+    val in = testData(env).toTable(tableEnv).as('a, 'b, 'c)
+    val tFunc = new TableFunc4 // eval(b: Byte, s: Short, f: Float)
+
+    val result = in
+      .select('a.cast(Types.BYTE) as 'a, 'a.cast(Types.SHORT) as 'b, 'b.cast(Types.FLOAT) as 'c)
+      .join(tFunc('a, 'b, 'c) as ('a2, 'b2, 'c2))
+      .toDataSet[Row]
+
+    val results = result.collect()
+    val expected = Seq( // cast inputs a, b, c followed by TableFunc4's labelled echoes
+      "1,1,1.0,Byte=1,Short=1,Float=1.0",
+      "2,2,2.0,Byte=2,Short=2,Float=2.0",
+      "3,3,2.0,Byte=3,Short=3,Float=2.0",
+      "4,4,3.0,Byte=4,Short=4,Float=3.0").mkString("\n")
+    TestBaseUtils.compareResultAsText(results.asJava, expected)
+  }
+
+  @Test
   def testUserDefinedTableFunctionWithParameter(): Unit = {
     val env = ExecutionEnvironment.getExecutionEnvironment
     val tEnv = TableEnvironment.getTableEnvironment(env)

http://git-wip-us.apache.org/repos/asf/flink/blob/6e118d1d/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/utils/UserDefinedTableFunctions.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/utils/UserDefinedTableFunctions.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/utils/UserDefinedTableFunctions.scala
index d0ffade..e1af23b 100644
--- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/utils/UserDefinedTableFunctions.scala
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/utils/UserDefinedTableFunctions.scala
@@ -22,7 +22,7 @@ import java.lang.Boolean
 import org.apache.flink.api.common.typeinfo.{BasicTypeInfo, TypeInformation}
 import org.apache.flink.api.java.tuple.Tuple3
 import org.apache.flink.api.java.typeutils.RowTypeInfo
-import org.apache.flink.table.api.ValidationException
+import org.apache.flink.table.api.{Types, ValidationException}
 import org.apache.flink.table.functions.{FunctionContext, TableFunction}
 import org.apache.flink.types.Row
 import org.junit.Assert
@@ -109,6 +109,16 @@ class TableFunc3(data: String, conf: Map[String, String]) extends TableFunction[
   }
 }
 
+/** Per call, emits one row of three strings labelling its Byte, Short and Float arguments. */
+class TableFunc4 extends TableFunction[Row] {
+  override def getResultType: TypeInformation[Row] =
+    new RowTypeInfo(Types.STRING, Types.STRING, Types.STRING)
+
+  // Each value is rendered behind its type name, e.g. "Byte=3", "Float=3.0".
+  def eval(b: Byte, s: Short, f: Float): Unit =
+    collect(Row.of(s"Byte=$b", s"Short=$s", s"Float=$f"))
+}
+
 class HierarchyTableFunction extends SplittableTableFunction[Boolean, Integer] 
{
   def eval(user: String) {
     if (user.contains("#")) {

Reply via email to