This is an automated email from the ASF dual-hosted git repository.

zhuzh pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git
commit 56ba88bf0b4a9fb9956ab81505eefcd9f7074b77
Author: sunxia <[email protected]>
AuthorDate: Thu Oct 31 11:53:59 2024 +0800

    [FLINK-36634][table] Introduce AdaptiveJoin to support dynamic generate join operator at runtime.
---
 .../adaptive/AdaptiveJoinOperatorGenerator.java    | 158 ++++++++++
 .../planner/plan/utils/HashJoinOperatorUtil.java   | 195 ++++++++++++
 .../plan/utils/SorMergeJoinOperatorUtil.java       |  43 ++-
 .../flink/table/planner/plan/utils/SortUtil.scala  |   4 +-
 .../AdaptiveJoinOperatorGeneratorTest.java         | 333 +++++++++++++++++++++
 .../operators/join/adaptive/AdaptiveJoin.java      |  56 ++++
 .../join/adaptive/AdaptiveJoinOperatorFactory.java | 134 +++++++++
 .../join/Int2HashJoinOperatorTestBase.java         |   2 +-
 8 files changed, 920 insertions(+), 5 deletions(-)

diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/adaptive/AdaptiveJoinOperatorGenerator.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/adaptive/AdaptiveJoinOperatorGenerator.java
new file mode 100644
index 00000000000..04fbdc2d32d
--- /dev/null
+++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/adaptive/AdaptiveJoinOperatorGenerator.java
@@ -0,0 +1,158 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.planner.adaptive;
+
+import org.apache.flink.configuration.ReadableConfig;
+import org.apache.flink.streaming.api.operators.StreamOperatorFactory;
+import org.apache.flink.table.planner.plan.utils.HashJoinOperatorUtil;
+import org.apache.flink.table.planner.plan.utils.OperatorType;
+import org.apache.flink.table.planner.plan.utils.SorMergeJoinOperatorUtil;
+import org.apache.flink.table.runtime.generated.GeneratedJoinCondition;
+import org.apache.flink.table.runtime.operators.join.FlinkJoinType;
+import org.apache.flink.table.runtime.operators.join.adaptive.AdaptiveJoin;
+import org.apache.flink.table.types.logical.RowType;
+
+import static org.apache.flink.util.Preconditions.checkState;
+
+/**
+ * Implementation class for {@link AdaptiveJoin}. It can selectively generate a broadcast hash
+ * join, shuffle hash join or sort merge join operator based on actual runtime conditions.
+ */
+public class AdaptiveJoinOperatorGenerator implements AdaptiveJoin {
+
+    private final int[] leftKeys;
+
+    private final int[] rightKeys;
+
+    private final FlinkJoinType joinType;
+
+    private final boolean[] filterNulls;
+
+    private final RowType leftType;
+
+    private final RowType rightType;
+
+    private final GeneratedJoinCondition condFunc;
+
+    private final int leftRowSize;
+
+    private final long leftRowCount;
+
+    private final int rightRowSize;
+
+    private final long rightRowCount;
+
+    private final boolean tryDistinctBuildRow;
+
+    private final long managedMemory;
+
+    private final OperatorType originalJoin;
+
+    private boolean leftIsBuild;
+
+    private boolean isBroadcastJoin;
+
+    public AdaptiveJoinOperatorGenerator(
+            int[] leftKeys,
+            int[] rightKeys,
+            FlinkJoinType joinType,
+            boolean[] filterNulls,
+            RowType leftType,
+            RowType rightType,
+            GeneratedJoinCondition condFunc,
+            int leftRowSize,
+            int rightRowSize,
+            long leftRowCount,
+            long rightRowCount,
+            boolean tryDistinctBuildRow,
+            long managedMemory,
+            boolean leftIsBuild,
+            OperatorType originalJoin) {
+        this.leftKeys = leftKeys;
+        this.rightKeys = rightKeys;
+        this.joinType = joinType;
+        this.filterNulls = filterNulls;
+        this.leftType = leftType;
+        this.rightType = rightType;
+        this.condFunc = condFunc;
+        this.leftRowSize = leftRowSize;
+        this.rightRowSize = rightRowSize;
+        this.leftRowCount = leftRowCount;
+        this.rightRowCount = rightRowCount;
+        this.tryDistinctBuildRow = tryDistinctBuildRow;
+        this.managedMemory = managedMemory;
+        checkState(
+                originalJoin == OperatorType.ShuffleHashJoin
+                        || originalJoin == OperatorType.SortMergeJoin,
+                String.format(
+                        "Adaptive join "
+                                + "currently only supports adaptive optimization for ShuffleHashJoin and "
+                                + "SortMergeJoin, not including %s.",
+                        originalJoin.toString()));
+        this.leftIsBuild = leftIsBuild;
+        this.originalJoin = originalJoin;
+    }
+
+    @Override
+    public StreamOperatorFactory<?> genOperatorFactory(
+            ClassLoader classLoader, ReadableConfig config) {
+        if (isBroadcastJoin || originalJoin == OperatorType.ShuffleHashJoin) {
+            return HashJoinOperatorUtil.generateOperatorFactory(
+                    leftKeys,
+                    rightKeys,
+                    joinType,
+                    filterNulls,
+                    leftType,
+                    rightType,
+                    condFunc,
+                    leftIsBuild,
+                    leftRowSize,
+                    rightRowSize,
+                    leftRowCount,
+                    rightRowCount,
+                    tryDistinctBuildRow,
+                    managedMemory,
+                    config,
+                    classLoader);
+        } else {
+            return SorMergeJoinOperatorUtil.generateOperatorFactory(
+                    condFunc,
+                    leftType,
+                    rightType,
+                    leftKeys,
+                    rightKeys,
+                    joinType,
+                    config,
+                    leftIsBuild,
+                    filterNulls,
+                    managedMemory,
+                    classLoader);
+        }
+    }
+
+    @Override
+    public FlinkJoinType getJoinType() {
+        return joinType;
+    }
+
+    @Override
+    public void markAsBroadcastJoin(boolean canBroadcast, boolean leftIsBuild) {
+        this.isBroadcastJoin = canBroadcast;
+        this.leftIsBuild = leftIsBuild;
+    }
+}
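
A wiring sketch for the constructor above may help readers of the long positional argument list. All values here are illustrative assumptions (they are not part of this commit), and `condFunc` is assumed to come from the planner's join-condition code generation:

    import org.apache.flink.table.planner.adaptive.AdaptiveJoinOperatorGenerator;
    import org.apache.flink.table.planner.plan.utils.OperatorType;
    import org.apache.flink.table.runtime.generated.GeneratedJoinCondition;
    import org.apache.flink.table.runtime.operators.join.FlinkJoinType;
    import org.apache.flink.table.runtime.operators.join.adaptive.AdaptiveJoin;
    import org.apache.flink.table.types.logical.IntType;
    import org.apache.flink.table.types.logical.RowType;

    final class AdaptiveJoinWiringSketch {
        // condFunc is assumed to be produced by the planner's codegen for the join predicate.
        static AdaptiveJoin create(GeneratedJoinCondition condFunc) {
            RowType rowType = RowType.of(new IntType(), new IntType());
            return new AdaptiveJoinOperatorGenerator(
                    new int[] {0},                 // leftKeys
                    new int[] {0},                 // rightKeys
                    FlinkJoinType.INNER,           // joinType
                    new boolean[] {true},          // filterNulls
                    rowType,                       // leftType
                    rowType,                       // rightType
                    condFunc,                      // generated join condition
                    20,                            // leftRowSize (estimated average bytes)
                    20,                            // rightRowSize
                    1_000L,                        // leftRowCount (estimate)
                    1_000_000L,                    // rightRowCount
                    false,                         // tryDistinctBuildRow
                    64L << 20,                     // managedMemory, assumed 64 MiB
                    true,                          // leftIsBuild
                    OperatorType.ShuffleHashJoin); // original join: ShuffleHashJoin or SortMergeJoin only
        }
    }
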
diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/utils/HashJoinOperatorUtil.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/utils/HashJoinOperatorUtil.java
new file mode 100644
index 00000000000..cd72e9a7223
--- /dev/null
+++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/utils/HashJoinOperatorUtil.java
@@ -0,0 +1,195 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.planner.plan.utils;
+
+import org.apache.flink.configuration.ReadableConfig;
+import org.apache.flink.streaming.api.operators.SimpleOperatorFactory;
+import org.apache.flink.streaming.api.operators.StreamOperatorFactory;
+import org.apache.flink.table.api.config.ExecutionConfigOptions;
+import org.apache.flink.table.data.RowData;
+import org.apache.flink.table.planner.codegen.CodeGeneratorContext;
+import org.apache.flink.table.planner.codegen.LongHashJoinGenerator;
+import org.apache.flink.table.planner.codegen.ProjectionCodeGenerator;
+import org.apache.flink.table.runtime.generated.GeneratedJoinCondition;
+import org.apache.flink.table.runtime.generated.GeneratedProjection;
+import org.apache.flink.table.runtime.operators.join.FlinkJoinType;
+import org.apache.flink.table.runtime.operators.join.HashJoinOperator;
+import org.apache.flink.table.runtime.operators.join.HashJoinType;
+import org.apache.flink.table.runtime.operators.join.SortMergeJoinFunction;
+import org.apache.flink.table.types.logical.LogicalType;
+import org.apache.flink.table.types.logical.RowType;
+
+import java.util.stream.IntStream;
+
+/**
+ * Utility for generating hash join operator factories, including {@link HashJoinOperator} and the
+ * codegen version of LongHashJoinOperator generated by {@link LongHashJoinGenerator}.
+ */
+public class HashJoinOperatorUtil {
+
+    public static StreamOperatorFactory<RowData> generateOperatorFactory(
+            int[] leftKeys,
+            int[] rightKeys,
+            FlinkJoinType joinType,
+            boolean[] filterNulls,
+            RowType leftType,
+            RowType rightType,
+            GeneratedJoinCondition condFunc,
+            boolean leftIsBuild,
+            int estimatedLeftAvgRowSize,
+            int estimatedRightAvgRowSize,
+            long estimatedLeftRowCount,
+            long estimatedRightRowCount,
+            boolean tryDistinctBuildRow,
+            long managedMemory,
+            ReadableConfig config,
+            ClassLoader classLoader) {
+        LogicalType[] keyFieldTypes =
+                IntStream.of(leftKeys).mapToObj(leftType::getTypeAt).toArray(LogicalType[]::new);
+        RowType keyType = RowType.of(keyFieldTypes);
+
+        // projection for equals
+        GeneratedProjection leftProj =
+                ProjectionCodeGenerator.generateProjection(
+                        new CodeGeneratorContext(config, classLoader),
+                        "HashJoinLeftProjection",
+                        leftType,
+                        keyType,
+                        leftKeys);
+        GeneratedProjection rightProj =
+                ProjectionCodeGenerator.generateProjection(
+                        new CodeGeneratorContext(config, classLoader),
+                        "HashJoinRightProjection",
+                        rightType,
+                        keyType,
+                        rightKeys);
+
+        GeneratedProjection buildProj;
+        GeneratedProjection probeProj;
+        int[] buildKeys;
+        int[] probeKeys;
+        RowType buildType;
+        RowType probeType;
+        int buildRowSize;
+        long buildRowCount;
+        long probeRowCount;
+        boolean reverseJoin = !leftIsBuild;
+        if (leftIsBuild) {
+            buildProj = leftProj;
+            buildType = leftType;
+            buildRowSize = estimatedLeftAvgRowSize;
+            buildRowCount = estimatedLeftRowCount;
+            buildKeys = leftKeys;
+
+            probeProj = rightProj;
+            probeType = rightType;
+            probeRowCount = estimatedRightRowCount;
+            probeKeys = rightKeys;
+        } else {
+            buildProj = rightProj;
+            buildType = rightType;
+            buildRowSize = estimatedRightAvgRowSize;
+            buildRowCount = estimatedRightRowCount;
+            buildKeys = rightKeys;
+
+            probeProj = leftProj;
+            probeType = leftType;
+            probeRowCount = estimatedLeftRowCount;
+            probeKeys = leftKeys;
+        }
+
+        // operator
+        StreamOperatorFactory<RowData> operator;
+        HashJoinType hashJoinType =
+                HashJoinType.of(
+                        leftIsBuild,
+                        joinType.isLeftOuter(),
+                        joinType.isRightOuter(),
+                        joinType == FlinkJoinType.SEMI,
+                        joinType == FlinkJoinType.ANTI);
+
+        long externalBufferMemory =
+                config.get(ExecutionConfigOptions.TABLE_EXEC_RESOURCE_EXTERNAL_BUFFER_MEMORY)
+                        .getBytes();
+
+        // sort merge join function
+        SortMergeJoinFunction sortMergeJoinFunction =
+                SorMergeJoinOperatorUtil.getSortMergeJoinFunction(
+                        classLoader,
+                        config,
+                        joinType,
+                        leftType,
+                        rightType,
+                        leftKeys,
+                        rightKeys,
+                        keyType,
+                        leftIsBuild,
+                        filterNulls,
+                        condFunc,
+                        1.0 * externalBufferMemory / managedMemory);
+
+        boolean compressionEnabled =
+                config.get(ExecutionConfigOptions.TABLE_EXEC_SPILL_COMPRESSION_ENABLED);
+        int compressionBlockSize =
+                (int)
+                        config.get(ExecutionConfigOptions.TABLE_EXEC_SPILL_COMPRESSION_BLOCK_SIZE)
+                                .getBytes();
+        if (LongHashJoinGenerator.support(hashJoinType, keyType, filterNulls)) {
+            operator =
+                    LongHashJoinGenerator.gen(
+                            config,
+                            classLoader,
+                            hashJoinType,
+                            keyType,
+                            buildType,
+                            probeType,
+                            buildKeys,
+                            probeKeys,
+                            buildRowSize,
+                            buildRowCount,
+                            reverseJoin,
+                            condFunc,
+                            leftIsBuild,
+                            compressionEnabled,
+                            compressionBlockSize,
+                            sortMergeJoinFunction);
+        } else {
+            operator =
+                    SimpleOperatorFactory.of(
+                            HashJoinOperator.newHashJoinOperator(
+                                    hashJoinType,
+                                    leftIsBuild,
+                                    compressionEnabled,
+                                    compressionBlockSize,
+                                    condFunc,
+                                    reverseJoin,
+                                    filterNulls,
+                                    buildProj,
+                                    probeProj,
+                                    tryDistinctBuildRow,
+                                    buildRowSize,
+                                    buildRowCount,
+                                    probeRowCount,
+                                    keyType,
+                                    sortMergeJoinFunction));
+        }
+
+        return operator;
+    }
+}
diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/utils/SorMergeJoinOperatorUtil.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/utils/SorMergeJoinOperatorUtil.java
index 891aa185786..a88b8e89d1f 100644
--- a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/utils/SorMergeJoinOperatorUtil.java
+++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/utils/SorMergeJoinOperatorUtil.java
@@ -18,15 +18,18 @@
 package org.apache.flink.table.planner.plan.utils;
 
+import org.apache.flink.configuration.ReadableConfig;
+import org.apache.flink.streaming.api.operators.SimpleOperatorFactory;
 import org.apache.flink.table.api.config.ExecutionConfigOptions;
+import org.apache.flink.table.data.RowData;
 import org.apache.flink.table.planner.codegen.CodeGeneratorContext;
 import org.apache.flink.table.planner.codegen.ProjectionCodeGenerator;
 import org.apache.flink.table.planner.codegen.sort.SortCodeGenerator;
-import org.apache.flink.table.planner.plan.nodes.exec.ExecNodeConfig;
 import org.apache.flink.table.runtime.generated.GeneratedJoinCondition;
 import org.apache.flink.table.runtime.operators.join.FlinkJoinType;
 import org.apache.flink.table.runtime.operators.join.SortMergeJoinFunction;
 import org.apache.flink.table.runtime.operators.join.SortMergeJoinOperator;
+import org.apache.flink.table.types.logical.LogicalType;
 import org.apache.flink.table.types.logical.RowType;
 
 import java.util.stream.IntStream;
@@ -36,7 +39,7 @@ public class SorMergeJoinOperatorUtil {
 
     public static SortMergeJoinFunction getSortMergeJoinFunction(
             ClassLoader classLoader,
-            ExecNodeConfig config,
+            ReadableConfig config,
             FlinkJoinType joinType,
             RowType leftType,
             RowType rightType,
@@ -95,5 +98,41 @@ public class SorMergeJoinOperatorUtil {
                 filterNulls);
     }
 
+    public static SimpleOperatorFactory<RowData> generateOperatorFactory(
+            GeneratedJoinCondition condFunc,
+            RowType leftType,
+            RowType rightType,
+            int[] leftKeys,
+            int[] rightKeys,
+            FlinkJoinType joinType,
+            ReadableConfig config,
+            boolean leftIsSmaller,
+            boolean[] filterNulls,
+            long managedMemory,
+            ClassLoader classLoader) {
+        LogicalType[] keyFieldTypes =
+                IntStream.of(leftKeys).mapToObj(leftType::getTypeAt).toArray(LogicalType[]::new);
+        RowType keyType = RowType.of(keyFieldTypes);
+
+        long externalBufferMemory =
+                config.get(ExecutionConfigOptions.TABLE_EXEC_RESOURCE_EXTERNAL_BUFFER_MEMORY)
+                        .getBytes();
+
+        SortMergeJoinFunction sortMergeJoinFunction =
+                SorMergeJoinOperatorUtil.getSortMergeJoinFunction(
+                        classLoader,
+                        config,
+                        joinType,
+                        leftType,
+                        rightType,
+                        leftKeys,
+                        rightKeys,
+                        keyType,
+                        leftIsSmaller,
+                        filterNulls,
+                        condFunc,
+                        1.0 * externalBufferMemory / managedMemory);
+
+        return SimpleOperatorFactory.of(new SortMergeJoinOperator(sortMergeJoinFunction));
+    }
+
     private SorMergeJoinOperatorUtil() {}
 }
diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/utils/SortUtil.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/utils/SortUtil.scala
index e3c7cfc5e86..da6c1e9fbaf 100644
--- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/utils/SortUtil.scala
+++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/utils/SortUtil.scala
@@ -18,10 +18,10 @@
 package org.apache.flink.table.planner.plan.utils
 
 import org.apache.flink.api.common.operators.Order
+import org.apache.flink.configuration.ReadableConfig
 import org.apache.flink.table.api.TableException
 import org.apache.flink.table.planner.calcite.FlinkPlannerImpl
 import org.apache.flink.table.planner.codegen.sort.SortCodeGenerator
-import org.apache.flink.table.planner.plan.nodes.exec.ExecNodeConfig
 import org.apache.flink.table.planner.plan.nodes.exec.spec.SortSpec
 import org.apache.flink.table.types.logical.RowType
 
@@ -123,7 +123,7 @@ object SortUtil {
   }
 
   def newSortGen(
-      config: ExecNodeConfig,
+      config: ReadableConfig,
       classLoader: ClassLoader,
       originalKeys: Array[Int],
       inputType: RowType): SortCodeGenerator = {
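
Both utilities pass the external sort buffer to SortMergeJoinFunction as a fraction of the operator's managed memory. A quick check of that `1.0 * externalBufferMemory / managedMemory` expression, with assumed sizes:

    long externalBufferMemory = 10L << 20;  // assumed table.exec.resource.external-buffer-memory: 10 MiB
    long managedMemory = 128L << 20;        // assumed managed memory for the operator: 128 MiB
    // The 1.0 factor forces floating-point division; plain long division would truncate to 0.
    double externalBufferFraction = 1.0 * externalBufferMemory / managedMemory; // 0.078125
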
diff --git a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/adaptive/AdaptiveJoinOperatorGeneratorTest.java b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/adaptive/AdaptiveJoinOperatorGeneratorTest.java
new file mode 100644
index 00000000000..3baa5030cbe
--- /dev/null
+++ b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/adaptive/AdaptiveJoinOperatorGeneratorTest.java
@@ -0,0 +1,333 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.planner.adaptive;
+
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.streaming.api.operators.SimpleOperatorFactory;
+import org.apache.flink.table.data.binary.BinaryRowData;
+import org.apache.flink.table.planner.plan.utils.OperatorType;
+import org.apache.flink.table.runtime.generated.GeneratedJoinCondition;
+import org.apache.flink.table.runtime.generated.JoinCondition;
+import org.apache.flink.table.runtime.operators.CodeGenOperatorFactory;
+import org.apache.flink.table.runtime.operators.join.FlinkJoinType;
+import org.apache.flink.table.runtime.operators.join.HashJoinOperator;
+import org.apache.flink.table.runtime.operators.join.Int2HashJoinOperatorTestBase;
+import org.apache.flink.table.runtime.operators.join.SortMergeJoinOperator;
+import org.apache.flink.table.runtime.operators.join.adaptive.AdaptiveJoin;
+import org.apache.flink.table.runtime.util.UniformBinaryRowGenerator;
+import org.apache.flink.table.types.logical.IntType;
+import org.apache.flink.table.types.logical.RowType;
+import org.apache.flink.util.MutableObjectIterator;
+
+import org.junit.jupiter.api.Test;
+
+import static org.apache.flink.table.api.config.ExecutionConfigOptions.TABLE_EXEC_RESOURCE_HASH_JOIN_MEMORY;
+import static org.apache.flink.table.planner.plan.utils.OperatorType.BroadcastHashJoin;
+import static org.apache.flink.table.planner.plan.utils.OperatorType.ShuffleHashJoin;
+import static org.apache.flink.table.planner.plan.utils.OperatorType.SortMergeJoin;
+import static org.apache.flink.table.runtime.util.JoinUtil.getJoinType;
+import static org.assertj.core.api.Assertions.assertThat;
+
+/** Test for {@link AdaptiveJoinOperatorGenerator}. */
+class AdaptiveJoinOperatorGeneratorTest extends Int2HashJoinOperatorTestBase {
+
+    @Test
+    void testShuffleHashJoinTransformationCorrectness() throws Exception {
+
+        // all cases to ShuffleHashJoin
+        testInnerJoin(true, ShuffleHashJoin, false, ShuffleHashJoin);
+        testInnerJoin(false, ShuffleHashJoin, false, ShuffleHashJoin);
+
+        testLeftOutJoin(true, ShuffleHashJoin, false, ShuffleHashJoin);
+        testLeftOutJoin(false, ShuffleHashJoin, false, ShuffleHashJoin);
+
+        testRightOutJoin(true, ShuffleHashJoin, false, ShuffleHashJoin);
+        testRightOutJoin(false, ShuffleHashJoin, false, ShuffleHashJoin);
+
+        testSemiJoin(ShuffleHashJoin, false, ShuffleHashJoin);
+
+        testAntiJoin(ShuffleHashJoin, false, ShuffleHashJoin);
+
+        // all cases to BroadcastHashJoin
+        testInnerJoin(true, ShuffleHashJoin, true, BroadcastHashJoin);
+        testInnerJoin(false, ShuffleHashJoin, true, BroadcastHashJoin);
+
+        testLeftOutJoin(false, ShuffleHashJoin, true, BroadcastHashJoin);
+
+        testRightOutJoin(true, ShuffleHashJoin, true, BroadcastHashJoin);
+
+        testSemiJoin(ShuffleHashJoin, true, BroadcastHashJoin);
+
+        testAntiJoin(ShuffleHashJoin, true, BroadcastHashJoin);
+    }
+
+    @Test
+    void testSortMergeJoinTransformationCorrectness() throws Exception {
+        // all cases to SortMergeJoin
+        testInnerJoin(true, SortMergeJoin, false, SortMergeJoin);
+        testInnerJoin(false, SortMergeJoin, false, SortMergeJoin);
+
+        testLeftOutJoin(true, SortMergeJoin, false, SortMergeJoin);
+
+        testRightOutJoin(true, SortMergeJoin, false, SortMergeJoin);
+
+        testSemiJoin(SortMergeJoin, false, SortMergeJoin);
+
+        testAntiJoin(SortMergeJoin, false, SortMergeJoin);
+
+        // all cases to BroadcastHashJoin
+        testInnerJoin(true, SortMergeJoin, true, BroadcastHashJoin);
+        testInnerJoin(false, SortMergeJoin, true, BroadcastHashJoin);
+
+        testLeftOutJoin(false, SortMergeJoin, true, BroadcastHashJoin);
+
+        testRightOutJoin(true, SortMergeJoin, true, BroadcastHashJoin);
+
+        testSemiJoin(SortMergeJoin, true, BroadcastHashJoin);
+
+        testAntiJoin(SortMergeJoin, true, BroadcastHashJoin);
+    }
+
+    private void testInnerJoin(
+            boolean isBuildLeft,
+            OperatorType originalJoinType,
+            boolean isBroadcast,
+            OperatorType expectedOperatorType)
+            throws Exception {
+        int numKeys = 100;
+        int buildValsPerKey = 3;
+        int probeValsPerKey = 10;
+        MutableObjectIterator<BinaryRowData> buildInput =
+                new UniformBinaryRowGenerator(numKeys, buildValsPerKey, false);
+        MutableObjectIterator<BinaryRowData> probeInput =
+                new UniformBinaryRowGenerator(numKeys, probeValsPerKey, true);
+
+        buildJoin(
+                buildInput,
+                probeInput,
+                originalJoinType,
+                expectedOperatorType,
+                false,
+                false,
+                isBuildLeft,
+                isBroadcast,
+                numKeys * buildValsPerKey * probeValsPerKey,
+                numKeys,
+                165);
+    }
+
+    private void testLeftOutJoin(
+            boolean isBuildLeft,
+            OperatorType originalJoinType,
+            boolean isBroadcast,
+            OperatorType expectedOperatorType)
+            throws Exception {
+        int numKeys1 = 9;
+        int numKeys2 = 10;
+        int buildValsPerKey = 3;
+        int probeValsPerKey = 10;
+        MutableObjectIterator<BinaryRowData> buildInput =
+                new UniformBinaryRowGenerator(
+                        isBuildLeft ? numKeys1 : numKeys2, buildValsPerKey, true);
+        MutableObjectIterator<BinaryRowData> probeInput =
+                new UniformBinaryRowGenerator(
+                        isBuildLeft ? numKeys2 : numKeys1, probeValsPerKey, true);
+
+        buildJoin(
+                buildInput,
+                probeInput,
+                originalJoinType,
+                expectedOperatorType,
+                true,
+                false,
+                isBuildLeft,
+                isBroadcast,
+                numKeys1 * buildValsPerKey * probeValsPerKey,
+                numKeys1,
+                165);
+    }
+
+    private void testRightOutJoin(
+            boolean isBuildLeft,
+            OperatorType originalJoinType,
+            boolean isBroadcast,
+            OperatorType expectedOperatorType)
+            throws Exception {
+        int numKeys1 = 9;
+        int numKeys2 = 10;
+        int buildValsPerKey = 3;
+        int probeValsPerKey = 10;
+        MutableObjectIterator<BinaryRowData> buildInput =
+                new UniformBinaryRowGenerator(numKeys1, buildValsPerKey, true);
+        MutableObjectIterator<BinaryRowData> probeInput =
+                new UniformBinaryRowGenerator(numKeys2, probeValsPerKey, true);
+
+        buildJoin(
+                buildInput,
+                probeInput,
+                originalJoinType,
+                expectedOperatorType,
+                false,
+                true,
+                isBuildLeft,
+                isBroadcast,
+                isBuildLeft ? 280 : 270,
+                numKeys2,
+                -1);
+    }
+
+    private void testSemiJoin(
+            OperatorType originalJoinType, boolean isBroadcast, OperatorType expectedOperatorType)
+            throws Exception {
+        int numKeys1 = 9;
+        int numKeys2 = 10;
+        int buildValsPerKey = 3;
+        int probeValsPerKey = 10;
+        if (originalJoinType == SortMergeJoin && !isBroadcast) {
+            numKeys1 = 10;
+            numKeys2 = 9;
+            buildValsPerKey = 10;
+            probeValsPerKey = 3;
+        }
+        MutableObjectIterator<BinaryRowData> buildInput =
+                new UniformBinaryRowGenerator(numKeys1, buildValsPerKey, true);
+        MutableObjectIterator<BinaryRowData> probeInput =
+                new UniformBinaryRowGenerator(numKeys2, probeValsPerKey, true);
+
+        Object operator = newOperator(FlinkJoinType.SEMI, false, isBroadcast, originalJoinType);
+        assertOperatorType(operator, expectedOperatorType);
+        joinAndAssert(operator, buildInput, probeInput, 90, 9, 45, true);
+    }
+
+    private void testAntiJoin(
+            OperatorType originalJoinType, boolean isBroadcast, OperatorType expectedOperatorType)
+            throws Exception {
+        int numKeys1 = 9;
+        int numKeys2 = 10;
+        int buildValsPerKey = 3;
+        int probeValsPerKey = 10;
+        if (originalJoinType == SortMergeJoin && !isBroadcast) {
+            numKeys1 = 10;
+            numKeys2 = 9;
+            buildValsPerKey = 10;
+            probeValsPerKey = 3;
+        }
+        MutableObjectIterator<BinaryRowData> buildInput =
+                new UniformBinaryRowGenerator(numKeys1, buildValsPerKey, true);
+        MutableObjectIterator<BinaryRowData> probeInput =
+                new UniformBinaryRowGenerator(numKeys2, probeValsPerKey, true);
+
+        Object operator = newOperator(FlinkJoinType.ANTI, false, isBroadcast, originalJoinType);
+        assertOperatorType(operator, expectedOperatorType);
+        joinAndAssert(operator, buildInput, probeInput, 10, 1, 45, true);
+    }
+
+    public void buildJoin(
+            MutableObjectIterator<BinaryRowData> buildInput,
+            MutableObjectIterator<BinaryRowData> probeInput,
+            OperatorType originalJoinType,
+            OperatorType expectedOperatorType,
+            boolean leftOut,
+            boolean rightOut,
+            boolean buildLeft,
+            boolean isBroadcast,
+            int expectOutSize,
+            int expectOutKeySize,
+            int expectOutVal)
+            throws Exception {
+        FlinkJoinType flinkJoinType = getJoinType(leftOut, rightOut);
+        Object operator = newOperator(flinkJoinType, buildLeft, isBroadcast, originalJoinType);
+        assertOperatorType(operator, expectedOperatorType);
+        joinAndAssert(
+                operator,
+                buildInput,
+                probeInput,
+                expectOutSize,
+                expectOutKeySize,
+                expectOutVal,
+                false);
+    }
+
+    public Object newOperator(
+            FlinkJoinType flinkJoinType,
+            boolean buildLeft,
+            boolean isBroadcast,
+            OperatorType operatorType) {
+        AdaptiveJoin adaptiveJoin = genAdaptiveJoin(flinkJoinType, operatorType);
+        adaptiveJoin.markAsBroadcastJoin(isBroadcast, buildLeft);
+
+        return adaptiveJoin.genOperatorFactory(getClass().getClassLoader(), new Configuration());
+    }
+
+    public void assertOperatorType(Object operator, OperatorType expectedOperatorType) {
+        switch (expectedOperatorType) {
+            case BroadcastHashJoin:
+            case ShuffleHashJoin:
+                if (operator instanceof CodeGenOperatorFactory) {
+                    assertThat(
+                                    ((CodeGenOperatorFactory<?>) operator)
+                                            .getGeneratedClass()
+                                            .getClassName())
+                            .contains("LongHashJoinOperator");
+                } else {
+                    assertThat(operator).isInstanceOf(SimpleOperatorFactory.class);
+                    assertThat(((SimpleOperatorFactory<?>) operator).getOperator())
+                            .isInstanceOf(HashJoinOperator.class);
+                }
+                break;
+            case SortMergeJoin:
+                assertThat(operator).isInstanceOf(SimpleOperatorFactory.class);
+                assertThat(((SimpleOperatorFactory<?>) operator).getOperator())
+                        .isInstanceOf(SortMergeJoinOperator.class);
+                break;
+            default:
+                throw new IllegalArgumentException(
+                        String.format("Unexpected operator type %s.", expectedOperatorType));
+        }
+    }
+
+    public AdaptiveJoin genAdaptiveJoin(FlinkJoinType flinkJoinType, OperatorType operatorType) {
+        GeneratedJoinCondition condFuncCode =
+                new GeneratedJoinCondition(
+                        Int2HashJoinOperatorTestBase.MyJoinCondition.class.getCanonicalName(),
+                        "",
+                        new Object[0]) {
+                    @Override
+                    public JoinCondition newInstance(ClassLoader classLoader) {
+                        return new Int2HashJoinOperatorTestBase.MyJoinCondition(new Object[0]);
+                    }
+                };
+
+        return new AdaptiveJoinOperatorGenerator(
+                new int[] {0},
+                new int[] {0},
+                flinkJoinType,
+                new boolean[] {true},
+                RowType.of(new IntType(), new IntType()),
+                RowType.of(new IntType(), new IntType()),
+                condFuncCode,
+                20,
+                10000,
+                20,
+                10000,
+                false,
+                TABLE_EXEC_RESOURCE_HASH_JOIN_MEMORY.defaultValue().getBytes(),
+                true,
+                operatorType);
+    }
+}
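
The inner-join expectations above follow directly from the row generators: both inputs contain the same 100 keys, so each key contributes buildValsPerKey x probeValsPerKey joined rows. A quick check of that arithmetic:

    int numKeys = 100, buildValsPerKey = 3, probeValsPerKey = 10;
    int expectOutSize = numKeys * buildValsPerKey * probeValsPerKey; // 3000 joined rows
    int expectOutKeySize = numKeys;                                  // 100 distinct keys
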
diff --git a/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/join/adaptive/AdaptiveJoin.java b/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/join/adaptive/AdaptiveJoin.java
new file mode 100755
index 00000000000..01a44788b01
--- /dev/null
+++ b/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/join/adaptive/AdaptiveJoin.java
@@ -0,0 +1,56 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.runtime.operators.join.adaptive;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.configuration.ReadableConfig;
+import org.apache.flink.streaming.api.operators.StreamOperatorFactory;
+import org.apache.flink.table.runtime.operators.join.FlinkJoinType;
+
+import java.io.Serializable;
+
+/** Interface for implementing an adaptive join operator. */
+@Internal
+public interface AdaptiveJoin extends Serializable {
+
+    /**
+     * Generates a StreamOperatorFactory for this join operator using the provided ClassLoader and
+     * config.
+     *
+     * @param classLoader the ClassLoader to be used for loading classes.
+     * @param config the configuration to be applied for creating the operator factory.
+     * @return a StreamOperatorFactory instance.
+     */
+    StreamOperatorFactory<?> genOperatorFactory(ClassLoader classLoader, ReadableConfig config);
+
+    /**
+     * Get the join type of the join operator.
+     *
+     * @return the join type.
+     */
+    FlinkJoinType getJoinType();
+
+    /**
+     * Determine whether the adaptive join operator can be optimized as a broadcast hash join, and
+     * decide which input side should be the build (i.e., smaller) side.
+     *
+     * @param canBeBroadcast whether the join operator can be optimized to a broadcast hash join.
+     * @param leftIsBuild whether the left input side is the build side.
+     */
+    void markAsBroadcastJoin(boolean canBeBroadcast, boolean leftIsBuild);
+}
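
A sketch of how a runtime component might drive this interface. The broadcast threshold and input statistics below are invented for illustration and are not part of this commit:

    import org.apache.flink.configuration.Configuration;
    import org.apache.flink.streaming.api.operators.StreamOperatorFactory;
    import org.apache.flink.table.runtime.operators.join.adaptive.AdaptiveJoin;

    final class AdaptiveJoinDriverSketch {
        static StreamOperatorFactory<?> decide(
                AdaptiveJoin join, long smallerInputBytes, boolean leftIsSmaller) {
            long broadcastThreshold = 10L << 20; // assumed 10 MiB limit, hypothetical
            // Request a broadcast hash join only if the smaller side fits the threshold,
            // and make that smaller side the build side.
            join.markAsBroadcastJoin(smallerInputBytes <= broadcastThreshold, leftIsSmaller);
            return join.genOperatorFactory(
                    Thread.currentThread().getContextClassLoader(), new Configuration());
        }
    }
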
diff --git a/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/join/adaptive/AdaptiveJoinOperatorFactory.java b/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/join/adaptive/AdaptiveJoinOperatorFactory.java
new file mode 100644
index 00000000000..35c47ad68cc
--- /dev/null
+++ b/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/join/adaptive/AdaptiveJoinOperatorFactory.java
@@ -0,0 +1,134 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.runtime.operators.join.adaptive;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.configuration.ReadableConfig;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperatorFactory;
+import org.apache.flink.streaming.api.operators.StreamOperator;
+import org.apache.flink.streaming.api.operators.StreamOperatorFactory;
+import org.apache.flink.streaming.api.operators.StreamOperatorParameters;
+import org.apache.flink.table.planner.loader.PlannerModule;
+import org.apache.flink.table.runtime.operators.join.FlinkJoinType;
+import org.apache.flink.util.InstantiationUtil;
+
+import javax.annotation.Nullable;
+
+import java.io.IOException;
+
+import static org.apache.flink.util.Preconditions.checkNotNull;
+
+/**
+ * Adaptive join factory.
+ *
+ * <p>Note: This class will hold an {@link AdaptiveJoin} and serve as a proxy class to provide an
+ * interface externally. Due to runtime access visibility constraints with the table-planner
+ * module, the {@link AdaptiveJoin} object will be serialized during the Table Planner phase and
+ * will only be lazily deserialized before the dynamic generation of the JobGraph.
+ *
+ * @param <OUT> The output type of the operator
+ */
+@Internal
+public class AdaptiveJoinOperatorFactory<OUT> extends AbstractStreamOperatorFactory<OUT>
+        implements AdaptiveJoin {
+    private static final long serialVersionUID = 1L;
+
+    private final byte[] adaptiveJoinSerialized;
+
+    @Nullable private transient AdaptiveJoin adaptiveJoin;
+
+    @Nullable private StreamOperatorFactory<OUT> finalFactory;
+
+    public AdaptiveJoinOperatorFactory(byte[] adaptiveJoinSerialized) {
+        this.adaptiveJoinSerialized = checkNotNull(adaptiveJoinSerialized);
+    }
+
+    @Override
+    public StreamOperatorFactory<?> genOperatorFactory(
+            ClassLoader classLoader, ReadableConfig config) {
+        checkAndLazyInitialize();
+        this.finalFactory =
+                (StreamOperatorFactory<OUT>) adaptiveJoin.genOperatorFactory(classLoader, config);
+        return this.finalFactory;
+    }
+
+    @Override
+    public FlinkJoinType getJoinType() {
+        checkAndLazyInitialize();
+        return adaptiveJoin.getJoinType();
+    }
+
+    @Override
+    public void markAsBroadcastJoin(boolean canBeBroadcast, boolean leftIsBuild) {
+        checkAndLazyInitialize();
+        adaptiveJoin.markAsBroadcastJoin(canBeBroadcast, leftIsBuild);
+    }
+
+    private void checkAndLazyInitialize() {
+        if (this.adaptiveJoin == null) {
+            lazyInitialize();
+        }
+    }
+
+    @Override
+    public <T extends StreamOperator<OUT>> T createStreamOperator(
+            StreamOperatorParameters<OUT> parameters) {
+        checkNotNull(
+                finalFactory,
+                String.format(
+                        "The OperatorFactory of task [%s] has not been initialized.",
+                        parameters.getContainingTask()));
+        if (finalFactory instanceof AbstractStreamOperatorFactory) {
+            ((AbstractStreamOperatorFactory<OUT>) finalFactory)
+                    .setProcessingTimeService(processingTimeService);
+        }
+        StreamOperator<OUT> operator = finalFactory.createStreamOperator(parameters);
+        return (T) operator;
+    }
+
+    @Override
+    public Class<? extends StreamOperator> getStreamOperatorClass(ClassLoader classLoader) {
+        throw new UnsupportedOperationException(
+                "The method should not be invoked in the "
+                        + "adaptive join operator for batch jobs.");
+    }
+
+    private void lazyInitialize() {
+        if (!tryInitializeAdaptiveJoin(Thread.currentThread().getContextClassLoader())) {
+            boolean isSuccess =
+                    tryInitializeAdaptiveJoin(
+                            PlannerModule.getInstance().getSubmoduleClassLoader());
+            if (!isSuccess) {
+                throw new RuntimeException(
+                        "Failed to deserialize AdaptiveJoin instance. "
+                                + "Please check whether the flink-table-planner-loader.jar is in the classpath.");
+            }
+        }
+    }
+
+    private boolean tryInitializeAdaptiveJoin(ClassLoader classLoader) {
+        try {
+            this.adaptiveJoin =
+                    InstantiationUtil.deserializeObject(adaptiveJoinSerialized, classLoader);
+        } catch (ClassNotFoundException | IOException e) {
+            return false;
+        }
+
+        return true;
+    }
+}
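
The planner-side half of the proxy pattern described in the class comment might look like the sketch below: the AdaptiveJoin implementation is serialized with InstantiationUtil so that this factory can live in flink-table-runtime without a compile-time dependency on the planner class. The wrapper class here is hypothetical:

    import org.apache.flink.table.data.RowData;
    import org.apache.flink.table.runtime.operators.join.adaptive.AdaptiveJoin;
    import org.apache.flink.table.runtime.operators.join.adaptive.AdaptiveJoinOperatorFactory;
    import org.apache.flink.util.InstantiationUtil;

    import java.io.IOException;

    final class FactoryWiringSketch {
        static AdaptiveJoinOperatorFactory<RowData> wrap(AdaptiveJoin adaptiveJoin)
                throws IOException {
            // The bytes are deserialized lazily (see lazyInitialize above), first with the
            // context classloader and then with the planner-loader submodule classloader.
            byte[] serialized = InstantiationUtil.serializeObject(adaptiveJoin);
            return new AdaptiveJoinOperatorFactory<>(serialized);
        }
    }
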
diff --git a/flink-table/flink-table-runtime/src/test/java/org/apache/flink/table/runtime/operators/join/Int2HashJoinOperatorTestBase.java b/flink-table/flink-table-runtime/src/test/java/org/apache/flink/table/runtime/operators/join/Int2HashJoinOperatorTestBase.java
index f74cf86db23..53a8dab724f 100644
--- a/flink-table/flink-table-runtime/src/test/java/org/apache/flink/table/runtime/operators/join/Int2HashJoinOperatorTestBase.java
+++ b/flink-table/flink-table-runtime/src/test/java/org/apache/flink/table/runtime/operators/join/Int2HashJoinOperatorTestBase.java
@@ -58,7 +58,7 @@ import static org.apache.flink.table.runtime.util.JoinUtil.getJoinType;
 import static org.assertj.core.api.Assertions.assertThat;
 
 /** Base test class for {@link HashJoinOperator}. */
-abstract class Int2HashJoinOperatorTestBase implements Serializable {
+public abstract class Int2HashJoinOperatorTestBase implements Serializable {
 
     public void buildJoin(
             MutableObjectIterator<BinaryRowData> buildInput,