This is an automated email from the ASF dual-hosted git repository.

twalthr pushed a commit to branch release-1.13
in repository https://gitbox.apache.org/repos/asf/flink.git


The following commit(s) were added to refs/heads/release-1.13 by this push:
     new 8f5b012  [FLINK-22788][table-planner-blink] Support equalisers for many fields
8f5b012 is described below

commit 8f5b0126d8f0521f01a731143bd5b293b3939b92
Author: Ingo Bürk <ingo.bu...@tngtech.com>
AuthorDate: Fri Jun 11 12:43:09 2021 +0200

    [FLINK-22788][table-planner-blink] Support equalisers for many fields
    
    When working with hundreds of fields, generated equalisers can fail to
    compile because the single generated equals() method exceeds the JVM's
    64 KB limit on method bytecode. With this change, instead of generating
    all comparison code into one method, we generate a dedicated method per
    field and call each of those methods from equals(). This does not remove
    the limit entirely (equals() still contains one call per field), but it
    supports roughly ten times as many fields, which is currently deemed
    sufficient.
    
    This closes #16213.
---
 .../planner/codegen/EqualiserCodeGenerator.scala   | 162 ++++++++++++---------
 .../codegen/EqualiserCodeGeneratorTest.java        |  34 +++++
 2 files changed, 127 insertions(+), 69 deletions(-)

diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/codegen/EqualiserCodeGenerator.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/codegen/EqualiserCodeGenerator.scala
index be8d480..79d8098 100644
--- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/codegen/EqualiserCodeGenerator.scala
+++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/codegen/EqualiserCodeGenerator.scala
@@ -44,73 +44,16 @@ class EqualiserCodeGenerator(fieldTypes: Array[LogicalType]) {
     // ignore time zone
     val ctx = CodeGeneratorContext(new TableConfig)
     val className = newName(name)
-    val header =
-      s"""
-         |if ($LEFT_INPUT.getRowKind() != $RIGHT_INPUT.getRowKind()) {
-         |  return false;
-         |}
-       """.stripMargin
-
-    val codes = for (i <- fieldTypes.indices) yield {
-      val fieldType = fieldTypes(i)
-      val fieldTypeTerm = primitiveTypeTermForType(fieldType)
-      val result = s"cmp$i"
-      val leftNullTerm = "leftIsNull$" + i
-      val rightNullTerm = "rightIsNull$" + i
-      val leftFieldTerm = "leftField$" + i
-      val rightFieldTerm = "rightField$" + i
-
-      // TODO merge ScalarOperatorGens.generateEquals.
-      val (equalsCode, equalsResult) = if (isInternalPrimitive(fieldType)) {
-        ("", s"$leftFieldTerm == $rightFieldTerm")
-      } else if (isCompositeType(fieldType)) {
-        val equaliserGenerator = new EqualiserCodeGenerator(
-          getFieldTypes(fieldType).asScala.toArray)
-        val generatedEqualiser = equaliserGenerator
-          .generateRecordEqualiser("field$" + i + "GeneratedEqualiser")
-        val generatedEqualiserTerm = ctx.addReusableObject(
-          generatedEqualiser, "field$" + i + "GeneratedEqualiser")
-        val equaliserTypeTerm = classOf[RecordEqualiser].getCanonicalName
-        val equaliserTerm = newName("equaliser")
-        ctx.addReusableMember(s"private $equaliserTypeTerm $equaliserTerm = 
null;")
-        ctx.addReusableInitStatement(
-          s"""
-             |$equaliserTerm = ($equaliserTypeTerm)
-             |  
$generatedEqualiserTerm.newInstance(Thread.currentThread().getContextClassLoader());
-             |""".stripMargin)
-        ("", s"$equaliserTerm.equals($leftFieldTerm, $rightFieldTerm)")
-      } else {
-        val left = GeneratedExpression(leftFieldTerm, leftNullTerm, "", 
fieldType)
-        val right = GeneratedExpression(rightFieldTerm, rightNullTerm, "", 
fieldType)
-        val gen = generateEquals(ctx, left, right)
-        (gen.code, gen.resultTerm)
-      }
-      val leftReadCode = rowFieldReadAccess(ctx, i, LEFT_INPUT, fieldType)
-      val rightReadCode = rowFieldReadAccess(ctx, i, RIGHT_INPUT, fieldType)
-      s"""
-         |boolean $leftNullTerm = $LEFT_INPUT.isNullAt($i);
-         |boolean $rightNullTerm = $RIGHT_INPUT.isNullAt($i);
-         |boolean $result;
-         |if ($leftNullTerm && $rightNullTerm) {
-         |  $result = true;
-         |} else if ($leftNullTerm|| $rightNullTerm) {
-         |  $result = false;
-         |} else {
-         |  $fieldTypeTerm $leftFieldTerm = $leftReadCode;
-         |  $fieldTypeTerm $rightFieldTerm = $rightReadCode;
-         |  $equalsCode
-         |  $result = $equalsResult;
-         |}
-         |if (!$result) {
-         |  return false;
-         |}
-      """.stripMargin
+
+    val equalsMethodCodes = for (idx <- fieldTypes.indices) yield 
generateEqualsMethod(ctx, idx)
+    val equalsMethodCalls = for (idx <- fieldTypes.indices) yield {
+      val methodName = getEqualsMethodName(idx)
+      s"""result = result && $methodName($LEFT_INPUT, $RIGHT_INPUT);"""
     }
 
-    val functionCode =
+    val classCode =
       j"""
         public final class $className implements $RECORD_EQUALISER {
-
           ${ctx.reuseMemberCode()}
 
           public $className(Object[] references) throws Exception {
@@ -121,17 +64,98 @@ class EqualiserCodeGenerator(fieldTypes: Array[LogicalType]) {
           public boolean equals($ROW_DATA $LEFT_INPUT, $ROW_DATA $RIGHT_INPUT) 
{
             if ($LEFT_INPUT instanceof $BINARY_ROW && $RIGHT_INPUT instanceof 
$BINARY_ROW) {
               return $LEFT_INPUT.equals($RIGHT_INPUT);
-            } else {
-              $header
-              ${ctx.reuseLocalVariableCode()}
-              ${codes.mkString("\n")}
-              return true;
             }
+
+            if ($LEFT_INPUT.getRowKind() != $RIGHT_INPUT.getRowKind()) {
+              return false;
+            }
+
+            boolean result = true;
+            ${equalsMethodCalls.mkString("\n")}
+            return result;
           }
+
+          ${equalsMethodCodes.mkString("\n")}
         }
       """.stripMargin
 
-    new GeneratedRecordEqualiser(className, functionCode, 
ctx.references.toArray)
+    new GeneratedRecordEqualiser(className, classCode, ctx.references.toArray)
+  }
+
+  private def getEqualsMethodName(idx: Int) = s"""equalsAtIndex$idx"""
+
+  private def generateEqualsMethod(ctx: CodeGeneratorContext, idx: Int): 
String = {
+    val methodName = getEqualsMethodName(idx)
+    ctx.startNewLocalVariableStatement(methodName)
+
+    val Seq(leftNullTerm, rightNullTerm) = ctx.addReusableLocalVariables(
+      ("boolean", "isNullLeft"),
+      ("boolean", "isNullRight")
+    )
+
+    val fieldType = fieldTypes(idx)
+    val fieldTypeTerm = primitiveTypeTermForType(fieldType)
+    val Seq(leftFieldTerm, rightFieldTerm) = ctx.addReusableLocalVariables(
+      (fieldTypeTerm, "leftField"),
+      (fieldTypeTerm, "rightField")
+    )
+
+    val leftReadCode = rowFieldReadAccess(ctx, idx, LEFT_INPUT, fieldType)
+    val rightReadCode = rowFieldReadAccess(ctx, idx, RIGHT_INPUT, fieldType)
+
+    val (equalsCode, equalsResult) = generateEqualsCode(ctx, fieldType,
+      leftFieldTerm, rightFieldTerm, leftNullTerm, rightNullTerm)
+
+    s"""
+       |private boolean $methodName($ROW_DATA $LEFT_INPUT, $ROW_DATA 
$RIGHT_INPUT) {
+       |  ${ctx.reuseLocalVariableCode(methodName)}
+       |
+       |  $leftNullTerm = $LEFT_INPUT.isNullAt($idx);
+       |  $rightNullTerm = $RIGHT_INPUT.isNullAt($idx);
+       |  if ($leftNullTerm && $rightNullTerm) {
+       |    return true;
+       |  }
+       |
+       |  if ($leftNullTerm || $rightNullTerm) {
+       |    return false;
+       |  }
+       |
+       |  $leftFieldTerm = $leftReadCode;
+       |  $rightFieldTerm = $rightReadCode;
+       |  $equalsCode
+       |
+       |  return $equalsResult;
+       |}
+      """.stripMargin
+  }
+
+  private def generateEqualsCode(ctx: CodeGeneratorContext, fieldType: 
LogicalType,
+                  leftFieldTerm: String, rightFieldTerm: String,
+                  leftNullTerm: String, rightNullTerm: String) = {
+    // TODO merge ScalarOperatorGens.generateEquals.
+    if (isInternalPrimitive(fieldType)) {
+      ("", s"$leftFieldTerm == $rightFieldTerm")
+    } else if (isCompositeType(fieldType)) {
+      val equaliserGenerator = new EqualiserCodeGenerator(
+        getFieldTypes(fieldType).asScala.toArray)
+      val generatedEqualiser = 
equaliserGenerator.generateRecordEqualiser("fieldGeneratedEqualiser")
+      val generatedEqualiserTerm = ctx.addReusableObject(
+        generatedEqualiser, "fieldGeneratedEqualiser")
+      val equaliserTypeTerm = classOf[RecordEqualiser].getCanonicalName
+      val equaliserTerm = newName("equaliser")
+      ctx.addReusableMember(s"private $equaliserTypeTerm $equaliserTerm = 
null;")
+      ctx.addReusableInitStatement(
+        s"""
+           |$equaliserTerm = ($equaliserTypeTerm)
+           |  
$generatedEqualiserTerm.newInstance(Thread.currentThread().getContextClassLoader());
+           |""".stripMargin)
+      ("", s"$equaliserTerm.equals($leftFieldTerm, $rightFieldTerm)")
+    } else {
+      val left = GeneratedExpression(leftFieldTerm, leftNullTerm, "", 
fieldType)
+      val right = GeneratedExpression(rightFieldTerm, rightNullTerm, "", 
fieldType)
+      val gen = generateEquals(ctx, left, right)
+      (gen.code, gen.resultTerm)
+    }
   }
 
   @tailrec
diff --git a/flink-table/flink-table-planner-blink/src/test/java/org/apache/flink/table/planner/codegen/EqualiserCodeGeneratorTest.java b/flink-table/flink-table-planner-blink/src/test/java/org/apache/flink/table/planner/codegen/EqualiserCodeGeneratorTest.java
index fe0e9a6..9ba9d00 100644
--- a/flink-table/flink-table-planner-blink/src/test/java/org/apache/flink/table/planner/codegen/EqualiserCodeGeneratorTest.java
+++ b/flink-table/flink-table-planner-blink/src/test/java/org/apache/flink/table/planner/codegen/EqualiserCodeGeneratorTest.java
@@ -22,6 +22,7 @@ import org.apache.flink.api.common.typeinfo.Types;
 import org.apache.flink.api.common.typeutils.base.IntSerializer;
 import org.apache.flink.table.data.GenericRowData;
 import org.apache.flink.table.data.RawValueData;
+import org.apache.flink.table.data.StringData;
 import org.apache.flink.table.data.TimestampData;
 import org.apache.flink.table.data.binary.BinaryRowData;
 import org.apache.flink.table.data.writer.BinaryRowWriter;
@@ -30,13 +31,16 @@ import org.apache.flink.table.runtime.typeutils.RawValueDataSerializer;
 import org.apache.flink.table.types.logical.LogicalType;
 import org.apache.flink.table.types.logical.TimestampType;
 import org.apache.flink.table.types.logical.TypeInformationRawType;
+import org.apache.flink.table.types.logical.VarCharType;
 
 import org.junit.Assert;
 import org.junit.Test;
 
 import java.util.function.Function;
+import java.util.stream.IntStream;
 
 import static org.apache.flink.table.data.TimestampData.fromEpochMillis;
+import static org.junit.Assert.assertTrue;
 
 /** Test for {@link EqualiserCodeGenerator}. */
 public class EqualiserCodeGeneratorTest {
@@ -81,6 +85,36 @@ public class EqualiserCodeGeneratorTest {
         assertBoolean(equaliser, func, fromEpochMillis(1024), 
fromEpochMillis(1025), false);
     }
 
+    @Test
+    public void testManyFields() {
+        final LogicalType[] fieldTypes =
+                IntStream.range(0, 999)
+                        .mapToObj(i -> new VarCharType())
+                        .toArray(LogicalType[]::new);
+
+        RecordEqualiser equaliser;
+        try {
+            equaliser =
+                    new EqualiserCodeGenerator(fieldTypes)
+                            .generateRecordEqualiser("ManyFields")
+                            
.newInstance(Thread.currentThread().getContextClassLoader());
+        } catch (Exception e) {
+            Assert.fail("Expected compilation to succeed");
+
+            // Unreachable
+            throw e;
+        }
+
+        final StringData[] fields =
+                IntStream.range(0, 999)
+                        .mapToObj(i -> StringData.fromString("Entry " + i))
+                        .toArray(StringData[]::new);
+        assertTrue(
+                equaliser.equals(
+                        GenericRowData.of((Object[]) fields),
+                        GenericRowData.of((Object[]) fields)));
+    }
+
     private static <T> void assertBoolean(
             RecordEqualiser equaliser,
             Function<T, BinaryRowData> toBinaryRow,

Reply via email to