This is an automated email from the ASF dual-hosted git repository. twalthr pushed a commit to branch release-1.13 in repository https://gitbox.apache.org/repos/asf/flink.git
The following commit(s) were added to refs/heads/release-1.13 by this push: new 8f5b012 [FLINK-22788][table-planner-blink] Support equalisers for many fields 8f5b012 is described below commit 8f5b0126d8f0521f01a731143bd5b293b3939b92 Author: Ingo Bürk <ingo.bu...@tngtech.com> AuthorDate: Fri Jun 11 12:43:09 2021 +0200 [FLINK-22788][table-planner-blink] Support equalisers for many fields When working with hundreds of fields, equalisers can fail to compile because the method body grows beyond 64kb. With this change, instead of generating all code into one method, we generate a dedicated method per field and then call all of those methods. This doesn't entirely remove the problem, but supports roughly a factor of 10 more fields and is currently deemed sufficient. This closes #16213. --- .../planner/codegen/EqualiserCodeGenerator.scala | 162 ++++++++++++--------- .../codegen/EqualiserCodeGeneratorTest.java | 34 +++++ 2 files changed, 127 insertions(+), 69 deletions(-) diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/codegen/EqualiserCodeGenerator.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/codegen/EqualiserCodeGenerator.scala index be8d480..79d8098 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/codegen/EqualiserCodeGenerator.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/codegen/EqualiserCodeGenerator.scala @@ -44,73 +44,16 @@ class EqualiserCodeGenerator(fieldTypes: Array[LogicalType]) { // ignore time zone val ctx = CodeGeneratorContext(new TableConfig) val className = newName(name) - val header = - s""" - |if ($LEFT_INPUT.getRowKind() != $RIGHT_INPUT.getRowKind()) { - | return false; - |} - """.stripMargin - - val codes = for (i <- fieldTypes.indices) yield { - val fieldType = fieldTypes(i) - val fieldTypeTerm = primitiveTypeTermForType(fieldType) - val result = s"cmp$i" - val leftNullTerm = "leftIsNull$" + i - val rightNullTerm = "rightIsNull$" + i - val leftFieldTerm = "leftField$" + i - val rightFieldTerm = "rightField$" + i - - // TODO merge ScalarOperatorGens.generateEquals. - val (equalsCode, equalsResult) = if (isInternalPrimitive(fieldType)) { - ("", s"$leftFieldTerm == $rightFieldTerm") - } else if (isCompositeType(fieldType)) { - val equaliserGenerator = new EqualiserCodeGenerator( - getFieldTypes(fieldType).asScala.toArray) - val generatedEqualiser = equaliserGenerator - .generateRecordEqualiser("field$" + i + "GeneratedEqualiser") - val generatedEqualiserTerm = ctx.addReusableObject( - generatedEqualiser, "field$" + i + "GeneratedEqualiser") - val equaliserTypeTerm = classOf[RecordEqualiser].getCanonicalName - val equaliserTerm = newName("equaliser") - ctx.addReusableMember(s"private $equaliserTypeTerm $equaliserTerm = null;") - ctx.addReusableInitStatement( - s""" - |$equaliserTerm = ($equaliserTypeTerm) - | $generatedEqualiserTerm.newInstance(Thread.currentThread().getContextClassLoader()); - |""".stripMargin) - ("", s"$equaliserTerm.equals($leftFieldTerm, $rightFieldTerm)") - } else { - val left = GeneratedExpression(leftFieldTerm, leftNullTerm, "", fieldType) - val right = GeneratedExpression(rightFieldTerm, rightNullTerm, "", fieldType) - val gen = generateEquals(ctx, left, right) - (gen.code, gen.resultTerm) - } - val leftReadCode = rowFieldReadAccess(ctx, i, LEFT_INPUT, fieldType) - val rightReadCode = rowFieldReadAccess(ctx, i, RIGHT_INPUT, fieldType) - s""" - |boolean $leftNullTerm = $LEFT_INPUT.isNullAt($i); - |boolean $rightNullTerm = $RIGHT_INPUT.isNullAt($i); - |boolean $result; - |if ($leftNullTerm && $rightNullTerm) { - | $result = true; - |} else if ($leftNullTerm|| $rightNullTerm) { - | $result = false; - |} else { - | $fieldTypeTerm $leftFieldTerm = $leftReadCode; - | $fieldTypeTerm $rightFieldTerm = $rightReadCode; - | $equalsCode - | $result = $equalsResult; - |} - |if (!$result) { - | return false; - |} - """.stripMargin + + val equalsMethodCodes = for (idx <- fieldTypes.indices) yield generateEqualsMethod(ctx, idx) + val equalsMethodCalls = for (idx <- fieldTypes.indices) yield { + val methodName = getEqualsMethodName(idx) + s"""result = result && $methodName($LEFT_INPUT, $RIGHT_INPUT);""" } - val functionCode = + val classCode = j""" public final class $className implements $RECORD_EQUALISER { - ${ctx.reuseMemberCode()} public $className(Object[] references) throws Exception { @@ -121,17 +64,98 @@ class EqualiserCodeGenerator(fieldTypes: Array[LogicalType]) { public boolean equals($ROW_DATA $LEFT_INPUT, $ROW_DATA $RIGHT_INPUT) { if ($LEFT_INPUT instanceof $BINARY_ROW && $RIGHT_INPUT instanceof $BINARY_ROW) { return $LEFT_INPUT.equals($RIGHT_INPUT); - } else { - $header - ${ctx.reuseLocalVariableCode()} - ${codes.mkString("\n")} - return true; } + + if ($LEFT_INPUT.getRowKind() != $RIGHT_INPUT.getRowKind()) { + return false; + } + + boolean result = true; + ${equalsMethodCalls.mkString("\n")} + return result; } + + ${equalsMethodCodes.mkString("\n")} } """.stripMargin - new GeneratedRecordEqualiser(className, functionCode, ctx.references.toArray) + new GeneratedRecordEqualiser(className, classCode, ctx.references.toArray) + } + + private def getEqualsMethodName(idx: Int) = s"""equalsAtIndex$idx""" + + private def generateEqualsMethod(ctx: CodeGeneratorContext, idx: Int): String = { + val methodName = getEqualsMethodName(idx) + ctx.startNewLocalVariableStatement(methodName) + + val Seq(leftNullTerm, rightNullTerm) = ctx.addReusableLocalVariables( + ("boolean", "isNullLeft"), + ("boolean", "isNullRight") + ) + + val fieldType = fieldTypes(idx) + val fieldTypeTerm = primitiveTypeTermForType(fieldType) + val Seq(leftFieldTerm, rightFieldTerm) = ctx.addReusableLocalVariables( + (fieldTypeTerm, "leftField"), + (fieldTypeTerm, "rightField") + ) + + val leftReadCode = rowFieldReadAccess(ctx, idx, LEFT_INPUT, fieldType) + val rightReadCode = rowFieldReadAccess(ctx, idx, RIGHT_INPUT, fieldType) + + val (equalsCode, equalsResult) = generateEqualsCode(ctx, fieldType, + leftFieldTerm, rightFieldTerm, leftNullTerm, rightNullTerm) + + s""" + |private boolean $methodName($ROW_DATA $LEFT_INPUT, $ROW_DATA $RIGHT_INPUT) { + | ${ctx.reuseLocalVariableCode(methodName)} + | + | $leftNullTerm = $LEFT_INPUT.isNullAt($idx); + | $rightNullTerm = $RIGHT_INPUT.isNullAt($idx); + | if ($leftNullTerm && $rightNullTerm) { + | return true; + | } + | + | if ($leftNullTerm || $rightNullTerm) { + | return false; + | } + | + | $leftFieldTerm = $leftReadCode; + | $rightFieldTerm = $rightReadCode; + | $equalsCode + | + | return $equalsResult; + |} + """.stripMargin + } + + private def generateEqualsCode(ctx: CodeGeneratorContext, fieldType: LogicalType, + leftFieldTerm: String, rightFieldTerm: String, + leftNullTerm: String, rightNullTerm: String) = { + // TODO merge ScalarOperatorGens.generateEquals. + if (isInternalPrimitive(fieldType)) { + ("", s"$leftFieldTerm == $rightFieldTerm") + } else if (isCompositeType(fieldType)) { + val equaliserGenerator = new EqualiserCodeGenerator( + getFieldTypes(fieldType).asScala.toArray) + val generatedEqualiser = equaliserGenerator.generateRecordEqualiser("fieldGeneratedEqualiser") + val generatedEqualiserTerm = ctx.addReusableObject( + generatedEqualiser, "fieldGeneratedEqualiser") + val equaliserTypeTerm = classOf[RecordEqualiser].getCanonicalName + val equaliserTerm = newName("equaliser") + ctx.addReusableMember(s"private $equaliserTypeTerm $equaliserTerm = null;") + ctx.addReusableInitStatement( + s""" + |$equaliserTerm = ($equaliserTypeTerm) + | $generatedEqualiserTerm.newInstance(Thread.currentThread().getContextClassLoader()); + |""".stripMargin) + ("", s"$equaliserTerm.equals($leftFieldTerm, $rightFieldTerm)") + } else { + val left = GeneratedExpression(leftFieldTerm, leftNullTerm, "", fieldType) + val right = GeneratedExpression(rightFieldTerm, rightNullTerm, "", fieldType) + val gen = generateEquals(ctx, left, right) + (gen.code, gen.resultTerm) + } } @tailrec diff --git a/flink-table/flink-table-planner-blink/src/test/java/org/apache/flink/table/planner/codegen/EqualiserCodeGeneratorTest.java b/flink-table/flink-table-planner-blink/src/test/java/org/apache/flink/table/planner/codegen/EqualiserCodeGeneratorTest.java index fe0e9a6..9ba9d00 100644 --- a/flink-table/flink-table-planner-blink/src/test/java/org/apache/flink/table/planner/codegen/EqualiserCodeGeneratorTest.java +++ b/flink-table/flink-table-planner-blink/src/test/java/org/apache/flink/table/planner/codegen/EqualiserCodeGeneratorTest.java @@ -22,6 +22,7 @@ import org.apache.flink.api.common.typeinfo.Types; import org.apache.flink.api.common.typeutils.base.IntSerializer; import org.apache.flink.table.data.GenericRowData; import org.apache.flink.table.data.RawValueData; +import org.apache.flink.table.data.StringData; import org.apache.flink.table.data.TimestampData; import org.apache.flink.table.data.binary.BinaryRowData; import org.apache.flink.table.data.writer.BinaryRowWriter; @@ -30,13 +31,16 @@ import org.apache.flink.table.runtime.typeutils.RawValueDataSerializer; import org.apache.flink.table.types.logical.LogicalType; import org.apache.flink.table.types.logical.TimestampType; import org.apache.flink.table.types.logical.TypeInformationRawType; +import org.apache.flink.table.types.logical.VarCharType; import org.junit.Assert; import org.junit.Test; import java.util.function.Function; +import java.util.stream.IntStream; import static org.apache.flink.table.data.TimestampData.fromEpochMillis; +import static org.junit.Assert.assertTrue; /** Test for {@link EqualiserCodeGenerator}. */ public class EqualiserCodeGeneratorTest { @@ -81,6 +85,36 @@ public class EqualiserCodeGeneratorTest { assertBoolean(equaliser, func, fromEpochMillis(1024), fromEpochMillis(1025), false); } + @Test + public void testManyFields() { + final LogicalType[] fieldTypes = + IntStream.range(0, 999) + .mapToObj(i -> new VarCharType()) + .toArray(LogicalType[]::new); + + RecordEqualiser equaliser; + try { + equaliser = + new EqualiserCodeGenerator(fieldTypes) + .generateRecordEqualiser("ManyFields") + .newInstance(Thread.currentThread().getContextClassLoader()); + } catch (Exception e) { + Assert.fail("Expected compilation to succeed"); + + // Unreachable + throw e; + } + + final StringData[] fields = + IntStream.range(0, 999) + .mapToObj(i -> StringData.fromString("Entry " + i)) + .toArray(StringData[]::new); + assertTrue( + equaliser.equals( + GenericRowData.of((Object[]) fields), + GenericRowData.of((Object[]) fields))); + } + private static <T> void assertBoolean( RecordEqualiser equaliser, Function<T, BinaryRowData> toBinaryRow,