This is an automated email from the ASF dual-hosted git repository.

jark pushed a commit to branch release-1.10
in repository https://gitbox.apache.org/repos/asf/flink.git


The following commit(s) were added to refs/heads/release-1.10 by this push:
     new ce72e50  [FLINK-15631][table-planner-blink] Fix equals code generation 
for RAW and TIMESTAMP type
ce72e50 is described below

commit ce72e50451cfd5ab7e134b5e3d1ac29228d8a763
Author: Jingsong Lee <lzljs3620...@aliyun.com>
AuthorDate: Tue Jan 21 11:20:57 2020 +0800

    [FLINK-15631][table-planner-blink] Fix equals code generation for RAW and 
TIMESTAMP type
    
    This fixes the issue that generic types can't be used as the result of an 
AggregateFunction in the Blink planner.
    
    This closes #10896
---
 .../planner/codegen/EqualiserCodeGenerator.scala   | 28 +++---
 .../planner/codegen/calls/ScalarOperatorGens.scala | 20 +++++
 .../codegen/EqualiserCodeGeneratorTest.java        | 99 ++++++++++++++++++++++
 .../planner/expressions/ScalarFunctionsTest.scala  | 11 +++
 .../expressions/utils/ScalarTypesTestBase.scala    | 12 ++-
 .../runtime/stream/sql/AggregateITCase.scala       | 17 +++-
 .../utils/UserDefinedFunctionTestUtils.scala       | 28 +++++-
 .../flink/table/dataformat/LazyBinaryFormat.java   |  4 +
 8 files changed, 203 insertions(+), 16 deletions(-)

diff --git 
a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/codegen/EqualiserCodeGenerator.scala
 
b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/codegen/EqualiserCodeGenerator.scala
index 92a7820..456d376 100644
--- 
a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/codegen/EqualiserCodeGenerator.scala
+++ 
b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/codegen/EqualiserCodeGenerator.scala
@@ -20,14 +20,15 @@ package org.apache.flink.table.planner.codegen
 import org.apache.flink.table.api.TableConfig
 import org.apache.flink.table.planner.codegen.CodeGenUtils._
 import org.apache.flink.table.planner.codegen.Indenter.toISC
+import 
org.apache.flink.table.planner.codegen.calls.ScalarOperatorGens.generateEquals
 import org.apache.flink.table.runtime.generated.{GeneratedRecordEqualiser, 
RecordEqualiser}
 import org.apache.flink.table.runtime.types.PlannerTypeUtils
 import org.apache.flink.table.types.logical.LogicalTypeRoot._
 import org.apache.flink.table.types.logical.{LogicalType, RowType}
 
-import scala.collection.JavaConversions._
+import scala.collection.JavaConverters._
 
-class EqualiserCodeGenerator(fieldTypes: Seq[LogicalType]) {
+class EqualiserCodeGenerator(fieldTypes: Array[LogicalType]) {
 
   private val RECORD_EQUALISER = className[RecordEqualiser]
   private val LEFT_INPUT = "left"
@@ -52,11 +53,13 @@ class EqualiserCodeGenerator(fieldTypes: Seq[LogicalType]) {
       val rightNullTerm = "rightIsNull$" + i
       val leftFieldTerm = "leftField$" + i
       val rightFieldTerm = "rightField$" + i
-      val equalsCode = if (isInternalPrimitive(fieldType)) {
-        s"$leftFieldTerm == $rightFieldTerm"
+
+      // TODO merge ScalarOperatorGens.generateEquals.
+      val (equalsCode, equalsResult) = if (isInternalPrimitive(fieldType)) {
+        ("", s"$leftFieldTerm == $rightFieldTerm")
       } else if (isBaseRow(fieldType)) {
-        val equaliserGenerator =
-          new 
EqualiserCodeGenerator(fieldType.asInstanceOf[RowType].getChildren)
+        val equaliserGenerator = new EqualiserCodeGenerator(
+          fieldType.asInstanceOf[RowType].getChildren.asScala.toArray)
         val generatedEqualiser = equaliserGenerator
           .generateRecordEqualiser("field$" + i + "GeneratedEqualiser")
         val generatedEqualiserTerm = ctx.addReusableObject(
@@ -69,9 +72,12 @@ class EqualiserCodeGenerator(fieldTypes: Seq[LogicalType]) {
              |$equaliserTerm = ($equaliserTypeTerm)
              |  
$generatedEqualiserTerm.newInstance(Thread.currentThread().getContextClassLoader());
              |""".stripMargin)
-        s"$equaliserTerm.equalsWithoutHeader($leftFieldTerm, $rightFieldTerm)"
+        ("", s"$equaliserTerm.equalsWithoutHeader($leftFieldTerm, 
$rightFieldTerm)")
       } else {
-        s"$leftFieldTerm.equals($rightFieldTerm)"
+        val left = GeneratedExpression(leftFieldTerm, leftNullTerm, "", 
fieldType)
+        val right = GeneratedExpression(rightFieldTerm, rightNullTerm, "", 
fieldType)
+        val gen = generateEquals(ctx, left, right)
+        (gen.code, gen.resultTerm)
       }
       val leftReadCode = baseRowFieldReadAccess(ctx, i, LEFT_INPUT, fieldType)
       val rightReadCode = baseRowFieldReadAccess(ctx, i, RIGHT_INPUT, 
fieldType)
@@ -86,7 +92,8 @@ class EqualiserCodeGenerator(fieldTypes: Seq[LogicalType]) {
          |} else {
          |  $fieldTypeTerm $leftFieldTerm = $leftReadCode;
          |  $fieldTypeTerm $rightFieldTerm = $rightReadCode;
-         |  $result = $equalsCode;
+         |  $equalsCode
+         |  $result = $equalsResult;
          |}
          |if (!$result) {
          |  return false;
@@ -135,8 +142,7 @@ class EqualiserCodeGenerator(fieldTypes: Seq[LogicalType]) {
   private def isInternalPrimitive(t: LogicalType): Boolean = t.getTypeRoot 
match {
     case _ if PlannerTypeUtils.isPrimitive(t) => true
 
-    case DATE | TIME_WITHOUT_TIME_ZONE | TIMESTAMP_WITHOUT_TIME_ZONE |
-         TIMESTAMP_WITH_LOCAL_TIME_ZONE | INTERVAL_YEAR_MONTH 
|INTERVAL_DAY_TIME => true
+    case DATE | TIME_WITHOUT_TIME_ZONE | INTERVAL_YEAR_MONTH 
|INTERVAL_DAY_TIME => true
     case _ => false
   }
 
diff --git 
a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/codegen/calls/ScalarOperatorGens.scala
 
b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/codegen/calls/ScalarOperatorGens.scala
index 01acc1f..e165164 100644
--- 
a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/codegen/calls/ScalarOperatorGens.scala
+++ 
b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/codegen/calls/ScalarOperatorGens.scala
@@ -388,6 +388,26 @@ object ScalarOperatorGens {
     else if (isComparable(left.resultType) && canEqual) {
       generateComparison(ctx, "==", left, right)
     }
+    // generic types of same type
+    else if (isRaw(left.resultType) && canEqual) {
+      val Seq(resultTerm, nullTerm) = newNames("result", "isNull")
+      val genericSer = ctx.addReusableTypeSerializer(left.resultType)
+      val ser = s"$genericSer.getInnerSerializer()"
+      val resultType = new BooleanType()
+      val code = s"""
+         |${left.code}
+         |${right.code}
+         |boolean $nullTerm = ${left.nullTerm} || ${right.nullTerm};
+         |boolean $resultTerm = ${primitiveDefaultValue(resultType)};
+         |if (!$nullTerm) {
+         |  ${left.resultTerm}.ensureMaterialized($ser);
+         |  ${right.resultTerm}.ensureMaterialized($ser);
+         |  $resultTerm =
+         |    
${left.resultTerm}.getBinarySection().equals(${right.resultTerm}.getBinarySection());
+         |}
+         |""".stripMargin
+      GeneratedExpression(resultTerm, nullTerm, code, resultType)
+    }
     // support date/time/timestamp equalTo string.
     // for performance, we cast literal string to literal time.
     else if (isTimePoint(left.resultType) && 
isCharacterString(right.resultType)) {
diff --git 
a/flink-table/flink-table-planner-blink/src/test/java/org/apache/flink/table/planner/codegen/EqualiserCodeGeneratorTest.java
 
b/flink-table/flink-table-planner-blink/src/test/java/org/apache/flink/table/planner/codegen/EqualiserCodeGeneratorTest.java
new file mode 100644
index 0000000..f4f2c70
--- /dev/null
+++ 
b/flink-table/flink-table-planner-blink/src/test/java/org/apache/flink/table/planner/codegen/EqualiserCodeGeneratorTest.java
@@ -0,0 +1,99 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.planner.codegen;
+
+import org.apache.flink.api.common.typeinfo.Types;
+import org.apache.flink.api.common.typeutils.base.IntSerializer;
+import org.apache.flink.table.dataformat.BinaryGeneric;
+import org.apache.flink.table.dataformat.BinaryRow;
+import org.apache.flink.table.dataformat.BinaryRowWriter;
+import org.apache.flink.table.dataformat.GenericRow;
+import org.apache.flink.table.dataformat.SqlTimestamp;
+import org.apache.flink.table.runtime.generated.RecordEqualiser;
+import org.apache.flink.table.runtime.typeutils.BinaryGenericSerializer;
+import org.apache.flink.table.types.logical.LogicalType;
+import org.apache.flink.table.types.logical.TimestampType;
+import org.apache.flink.table.types.logical.TypeInformationRawType;
+
+import org.junit.Assert;
+import org.junit.Test;
+
+import java.util.function.Function;
+
+import static org.apache.flink.table.dataformat.SqlTimestamp.fromEpochMillis;
+
+/**
+ * Test for {@link EqualiserCodeGenerator}.
+ */
+public class EqualiserCodeGeneratorTest {
+
+       @Test
+       public void testRaw() {
+               RecordEqualiser equaliser = new EqualiserCodeGenerator(
+                               new LogicalType[]{new 
TypeInformationRawType<>(Types.INT)})
+                               .generateRecordEqualiser("RAW")
+                               
.newInstance(Thread.currentThread().getContextClassLoader());
+               Function<BinaryGeneric, BinaryRow> func = o -> {
+                       BinaryRow row = new BinaryRow(1);
+                       BinaryRowWriter writer = new BinaryRowWriter(row);
+                       writer.writeGeneric(0, o, new 
BinaryGenericSerializer<>(IntSerializer.INSTANCE));
+                       writer.complete();
+                       return row;
+               };
+               assertBoolean(equaliser, func, new BinaryGeneric<>(1), new 
BinaryGeneric<>(1), true);
+               assertBoolean(equaliser, func, new BinaryGeneric<>(1), new 
BinaryGeneric<>(2), false);
+       }
+
+       @Test
+       public void testTimestamp() {
+               RecordEqualiser equaliser = new EqualiserCodeGenerator(
+                               new LogicalType[]{new TimestampType()})
+                               .generateRecordEqualiser("TIMESTAMP")
+                               
.newInstance(Thread.currentThread().getContextClassLoader());
+               Function<SqlTimestamp, BinaryRow> func = o -> {
+                       BinaryRow row = new BinaryRow(1);
+                       BinaryRowWriter writer = new BinaryRowWriter(row);
+                       writer.writeTimestamp(0, o, 9);
+                       writer.complete();
+                       return row;
+               };
+               assertBoolean(equaliser, func, fromEpochMillis(1024), 
fromEpochMillis(1024), true);
+               assertBoolean(equaliser, func, fromEpochMillis(1024), 
fromEpochMillis(1025), false);
+       }
+
+       private static <T> void assertBoolean(
+                       RecordEqualiser equaliser,
+                       Function<T, BinaryRow> toBinaryRow,
+                       T o1,
+                       T o2,
+                       boolean bool) {
+               Assert.assertEquals(bool, equaliser.equalsWithoutHeader(
+                               GenericRow.of(o1),
+                               GenericRow.of(o2)));
+               Assert.assertEquals(bool, equaliser.equalsWithoutHeader(
+                               toBinaryRow.apply(o1),
+                               GenericRow.of(o2)));
+               Assert.assertEquals(bool, equaliser.equalsWithoutHeader(
+                               GenericRow.of(o1),
+                               toBinaryRow.apply(o2)));
+               Assert.assertEquals(bool, equaliser.equalsWithoutHeader(
+                               toBinaryRow.apply(o1),
+                               toBinaryRow.apply(o2)));
+       }
+}
diff --git 
a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/expressions/ScalarFunctionsTest.scala
 
b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/expressions/ScalarFunctionsTest.scala
index f4c6dd5..f38a6d2 100644
--- 
a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/expressions/ScalarFunctionsTest.scala
+++ 
b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/expressions/ScalarFunctionsTest.scala
@@ -4159,4 +4159,15 @@ class ScalarFunctionsTest extends ScalarTypesTestBase {
       "IS_ALPHA(f33)",
       "false")
   }
+
+  @Test
+  def testRawTypeEquality(): Unit = {
+    testSqlApi(
+      "f55=f56",
+      "false")
+
+    testSqlApi(
+      "f55=f57",
+      "true")
+  }
 }
diff --git 
a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/expressions/utils/ScalarTypesTestBase.scala
 
b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/expressions/utils/ScalarTypesTestBase.scala
index b296cb6..de4b021 100644
--- 
a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/expressions/utils/ScalarTypesTestBase.scala
+++ 
b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/expressions/utils/ScalarTypesTestBase.scala
@@ -19,7 +19,7 @@
 package org.apache.flink.table.planner.expressions.utils
 
 import org.apache.flink.api.common.typeinfo.{PrimitiveArrayTypeInfo, Types}
-import org.apache.flink.api.java.typeutils.RowTypeInfo
+import org.apache.flink.api.java.typeutils.{GenericTypeInfo, RowTypeInfo}
 import org.apache.flink.table.dataformat.Decimal
 import org.apache.flink.table.planner.utils.DateTimeTestUtil._
 import org.apache.flink.table.runtime.typeutils.DecimalTypeInfo
@@ -31,7 +31,7 @@ import java.nio.charset.StandardCharsets
 abstract class ScalarTypesTestBase extends ExpressionTestBase {
 
   override def testData: Row = {
-    val testData = new Row(55)
+    val testData = new Row(58)
     testData.setField(0, "This is a test String.")
     testData.setField(1, true)
     testData.setField(2, 42.toByte)
@@ -87,6 +87,9 @@ abstract class ScalarTypesTestBase extends ExpressionTestBase 
{
     testData.setField(52, localDateTime("1997-11-11 09:44:55.333"))
     testData.setField(53, "hello world".getBytes(StandardCharsets.UTF_8))
     testData.setField(54, "This is a testing 
string.".getBytes(StandardCharsets.UTF_8))
+    testData.setField(55, 1)
+    testData.setField(56, 2)
+    testData.setField(57, 1)
     testData
   }
 
@@ -146,6 +149,9 @@ abstract class ScalarTypesTestBase extends 
ExpressionTestBase {
       /* 51 */ Types.LOCAL_TIME,
       /* 52 */ Types.LOCAL_DATE_TIME,
       /* 53 */ Types.PRIMITIVE_ARRAY(Types.BYTE),
-      /* 54 */ Types.PRIMITIVE_ARRAY(Types.BYTE))
+      /* 54 */ Types.PRIMITIVE_ARRAY(Types.BYTE),
+      /* 55 */ new GenericTypeInfo[Integer](classOf[Integer]),
+      /* 56 */ new GenericTypeInfo[Integer](classOf[Integer]),
+      /* 57 */ new GenericTypeInfo[Integer](classOf[Integer]))
   }
 }
diff --git 
a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/runtime/stream/sql/AggregateITCase.scala
 
b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/runtime/stream/sql/AggregateITCase.scala
index 46bda45..005aee7 100644
--- 
a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/runtime/stream/sql/AggregateITCase.scala
+++ 
b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/runtime/stream/sql/AggregateITCase.scala
@@ -17,6 +17,7 @@
  */
 package org.apache.flink.table.planner.runtime.stream.sql
 
+import org.apache.flink.api.common.time.Time
 import org.apache.flink.api.java.typeutils.RowTypeInfo
 import org.apache.flink.api.scala._
 import org.apache.flink.streaming.api.TimeCharacteristic
@@ -31,7 +32,7 @@ import 
org.apache.flink.table.planner.runtime.utils.StreamingWithMiniBatchTestBa
 import 
org.apache.flink.table.planner.runtime.utils.StreamingWithStateTestBase.StateBackendMode
 import 
org.apache.flink.table.planner.runtime.utils.TimeTestUtil.TimestampAndWatermarkWithOffset
 import 
org.apache.flink.table.planner.runtime.utils.UserDefinedFunctionTestUtils._
-import org.apache.flink.table.planner.runtime.utils.{StreamingWithAggTestBase, 
TestData, TestingRetractSink}
+import org.apache.flink.table.planner.runtime.utils.{GenericAggregateFunction, 
StreamingWithAggTestBase, TestData, TestingRetractSink}
 import org.apache.flink.table.planner.utils.DateTimeTestUtil.{localDate, 
localDateTime, localTime => mLocalTime}
 import org.apache.flink.table.runtime.typeutils.BigDecimalTypeInfo
 import org.apache.flink.types.Row
@@ -1208,4 +1209,18 @@ class AggregateITCase(
     assertEquals(expected, sink.getRetractResults)
   }
 
+  @Test
+  def testGenericTypesWithoutStateClean(): Unit = {
+    // because we don't provide a way to disable state cleanup.
+    // TODO verify all tests with state cleanup closed.
+    tEnv.getConfig.setIdleStateRetentionTime(Time.days(0), Time.days(0))
+    val t = failingDataSource(Seq(1, 2, 3)).toTable(tEnv, 'a)
+    val results = t
+        .select(new GenericAggregateFunction()('a))
+        .toRetractStream[Row]
+
+    val sink = new TestingRetractSink
+    results.addSink(sink).setParallelism(1)
+    env.execute()
+  }
 }
diff --git 
a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/runtime/utils/UserDefinedFunctionTestUtils.scala
 
b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/runtime/utils/UserDefinedFunctionTestUtils.scala
index 231f2e5..08be4c4 100644
--- 
a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/runtime/utils/UserDefinedFunctionTestUtils.scala
+++ 
b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/runtime/utils/UserDefinedFunctionTestUtils.scala
@@ -20,7 +20,7 @@ package org.apache.flink.table.planner.runtime.utils
 
 import org.apache.flink.api.common.typeinfo.TypeInformation
 import org.apache.flink.api.java.tuple.{Tuple1, Tuple2}
-import org.apache.flink.api.java.typeutils.{ListTypeInfo, PojoField, 
PojoTypeInfo, RowTypeInfo}
+import org.apache.flink.api.java.typeutils.{GenericTypeInfo, ListTypeInfo, 
PojoField, PojoTypeInfo, RowTypeInfo}
 import org.apache.flink.api.scala.ExecutionEnvironment
 import org.apache.flink.api.scala.typeutils.Types
 import org.apache.flink.configuration.Configuration
@@ -443,3 +443,29 @@ object UserDefinedFunctionTestUtils {
     tempFile.getAbsolutePath
   }
 }
+
+class RandomClass(var i: Int)
+
+class GenericAggregateFunction extends AggregateFunction[java.lang.Integer, 
RandomClass] {
+  override def getValue(accumulator: RandomClass): java.lang.Integer = 
accumulator.i
+
+  override def createAccumulator(): RandomClass = new RandomClass(0)
+
+  override def getResultType: TypeInformation[java.lang.Integer] =
+    new GenericTypeInfo[Integer](classOf[Integer])
+
+  override def getAccumulatorType: TypeInformation[RandomClass] = new 
GenericTypeInfo[RandomClass](
+    classOf[RandomClass])
+
+  def accumulate(acc: RandomClass, value: Int): Unit = {
+    acc.i = value
+  }
+
+  def retract(acc: RandomClass, value: Int): Unit = {
+    acc.i = value
+  }
+
+  def resetAccumulator(acc: RandomClass): Unit = {
+    acc.i = 0
+  }
+}
diff --git 
a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/dataformat/LazyBinaryFormat.java
 
b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/dataformat/LazyBinaryFormat.java
index 5f91820..2003c18 100644
--- 
a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/dataformat/LazyBinaryFormat.java
+++ 
b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/dataformat/LazyBinaryFormat.java
@@ -71,6 +71,10 @@ public abstract class LazyBinaryFormat<T> implements 
BinaryFormat {
                return javaObject;
        }
 
+       public BinarySection getBinarySection() {
+               return binarySection;
+       }
+
        /**
         * Must be public as it is used during code generation.
         */

Reply via email to