This is an automated email from the ASF dual-hosted git repository.

zhuzh pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git

commit 56ba88bf0b4a9fb9956ab81505eefcd9f7074b77
Author: sunxia <[email protected]>
AuthorDate: Thu Oct 31 11:53:59 2024 +0800

    [FLINK-36634][table] Introduce AdaptiveJoin to support dynamically 
generating join operators at runtime.
---
 .../adaptive/AdaptiveJoinOperatorGenerator.java    | 158 ++++++++++
 .../planner/plan/utils/HashJoinOperatorUtil.java   | 195 ++++++++++++
 .../plan/utils/SorMergeJoinOperatorUtil.java       |  43 ++-
 .../flink/table/planner/plan/utils/SortUtil.scala  |   4 +-
 .../AdaptiveJoinOperatorGeneratorTest.java         | 333 +++++++++++++++++++++
 .../operators/join/adaptive/AdaptiveJoin.java      |  56 ++++
 .../join/adaptive/AdaptiveJoinOperatorFactory.java | 134 +++++++++
 .../join/Int2HashJoinOperatorTestBase.java         |   2 +-
 8 files changed, 920 insertions(+), 5 deletions(-)

diff --git 
a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/adaptive/AdaptiveJoinOperatorGenerator.java
 
b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/adaptive/AdaptiveJoinOperatorGenerator.java
new file mode 100644
index 00000000000..04fbdc2d32d
--- /dev/null
+++ 
b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/adaptive/AdaptiveJoinOperatorGenerator.java
@@ -0,0 +1,158 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.planner.adaptive;
+
+import org.apache.flink.configuration.ReadableConfig;
+import org.apache.flink.streaming.api.operators.StreamOperatorFactory;
+import org.apache.flink.table.planner.plan.utils.HashJoinOperatorUtil;
+import org.apache.flink.table.planner.plan.utils.OperatorType;
+import org.apache.flink.table.planner.plan.utils.SorMergeJoinOperatorUtil;
+import org.apache.flink.table.runtime.generated.GeneratedJoinCondition;
+import org.apache.flink.table.runtime.operators.join.FlinkJoinType;
+import org.apache.flink.table.runtime.operators.join.adaptive.AdaptiveJoin;
+import org.apache.flink.table.types.logical.RowType;
+
+import static org.apache.flink.util.Preconditions.checkState;
+
+/**
+ * Implementation class for {@link AdaptiveJoin}. It can selectively generate 
broadcast hash join,
+ * shuffle hash join or shuffle merge join operator based on actual conditions.
+ */
+public class AdaptiveJoinOperatorGenerator implements AdaptiveJoin {
+
+    private final int[] leftKeys;
+
+    private final int[] rightKeys;
+
+    private final FlinkJoinType joinType;
+
+    private final boolean[] filterNulls;
+
+    private final RowType leftType;
+
+    private final RowType rightType;
+
+    private final GeneratedJoinCondition condFunc;
+
+    private final int leftRowSize;
+
+    private final long leftRowCount;
+
+    private final int rightRowSize;
+
+    private final long rightRowCount;
+
+    private final boolean tryDistinctBuildRow;
+
+    private final long managedMemory;
+
+    private final OperatorType originalJoin;
+
+    private boolean leftIsBuild;
+
+    private boolean isBroadcastJoin;
+
+    public AdaptiveJoinOperatorGenerator(
+            int[] leftKeys,
+            int[] rightKeys,
+            FlinkJoinType joinType,
+            boolean[] filterNulls,
+            RowType leftType,
+            RowType rightType,
+            GeneratedJoinCondition condFunc,
+            int leftRowSize,
+            int rightRowSize,
+            long leftRowCount,
+            long rightRowCount,
+            boolean tryDistinctBuildRow,
+            long managedMemory,
+            boolean leftIsBuild,
+            OperatorType originalJoin) {
+        this.leftKeys = leftKeys;
+        this.rightKeys = rightKeys;
+        this.joinType = joinType;
+        this.filterNulls = filterNulls;
+        this.leftType = leftType;
+        this.rightType = rightType;
+        this.condFunc = condFunc;
+        this.leftRowSize = leftRowSize;
+        this.rightRowSize = rightRowSize;
+        this.leftRowCount = leftRowCount;
+        this.rightRowCount = rightRowCount;
+        this.tryDistinctBuildRow = tryDistinctBuildRow;
+        this.managedMemory = managedMemory;
+        checkState(
+                originalJoin == OperatorType.ShuffleHashJoin
+                        || originalJoin == OperatorType.SortMergeJoin,
+                String.format(
+                        "Adaptive join "
+                                + "currently only supports adaptive 
optimization for ShuffleHashJoin and "
+                                + "SortMergeJoin, not including %s.",
+                        originalJoin.toString()));
+        this.leftIsBuild = leftIsBuild;
+        this.originalJoin = originalJoin;
+    }
+
+    @Override
+    public StreamOperatorFactory<?> genOperatorFactory(
+            ClassLoader classLoader, ReadableConfig config) {
+        if (isBroadcastJoin || originalJoin == OperatorType.ShuffleHashJoin) {
+            return HashJoinOperatorUtil.generateOperatorFactory(
+                    leftKeys,
+                    rightKeys,
+                    joinType,
+                    filterNulls,
+                    leftType,
+                    rightType,
+                    condFunc,
+                    leftIsBuild,
+                    leftRowSize,
+                    rightRowSize,
+                    leftRowCount,
+                    rightRowCount,
+                    tryDistinctBuildRow,
+                    managedMemory,
+                    config,
+                    classLoader);
+        } else {
+            return SorMergeJoinOperatorUtil.generateOperatorFactory(
+                    condFunc,
+                    leftType,
+                    rightType,
+                    leftKeys,
+                    rightKeys,
+                    joinType,
+                    config,
+                    leftIsBuild,
+                    filterNulls,
+                    managedMemory,
+                    classLoader);
+        }
+    }
+
+    @Override
+    public FlinkJoinType getJoinType() {
+        return joinType;
+    }
+
+    @Override
+    public void markAsBroadcastJoin(boolean canBroadcast, boolean leftIsBuild) 
{
+        this.isBroadcastJoin = canBroadcast;
+        this.leftIsBuild = leftIsBuild;
+    }
+}
diff --git 
a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/utils/HashJoinOperatorUtil.java
 
b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/utils/HashJoinOperatorUtil.java
new file mode 100644
index 00000000000..cd72e9a7223
--- /dev/null
+++ 
b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/utils/HashJoinOperatorUtil.java
@@ -0,0 +1,195 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.planner.plan.utils;
+
+import org.apache.flink.configuration.ReadableConfig;
+import org.apache.flink.streaming.api.operators.SimpleOperatorFactory;
+import org.apache.flink.streaming.api.operators.StreamOperatorFactory;
+import org.apache.flink.table.api.config.ExecutionConfigOptions;
+import org.apache.flink.table.data.RowData;
+import org.apache.flink.table.planner.codegen.CodeGeneratorContext;
+import org.apache.flink.table.planner.codegen.LongHashJoinGenerator;
+import org.apache.flink.table.planner.codegen.ProjectionCodeGenerator;
+import org.apache.flink.table.runtime.generated.GeneratedJoinCondition;
+import org.apache.flink.table.runtime.generated.GeneratedProjection;
+import org.apache.flink.table.runtime.operators.join.FlinkJoinType;
+import org.apache.flink.table.runtime.operators.join.HashJoinOperator;
+import org.apache.flink.table.runtime.operators.join.HashJoinType;
+import org.apache.flink.table.runtime.operators.join.SortMergeJoinFunction;
+import org.apache.flink.table.types.logical.LogicalType;
+import org.apache.flink.table.types.logical.RowType;
+
+import java.util.stream.IntStream;
+
+/**
+ * Utility for generating hash join operator factory, including {@link 
HashJoinOperator} and the
+ * codegen version of LongHashJoinOperator generated by {@link 
LongHashJoinGenerator}.
+ */
+public class HashJoinOperatorUtil {
+
+    public static StreamOperatorFactory<RowData> generateOperatorFactory(
+            int[] leftKeys,
+            int[] rightKeys,
+            FlinkJoinType joinType,
+            boolean[] filterNulls,
+            RowType leftType,
+            RowType rightType,
+            GeneratedJoinCondition condFunc,
+            boolean leftIsBuild,
+            int estimatedLeftAvgRowSize,
+            int estimatedRightAvgRowSize,
+            long estimatedLeftRowCount,
+            long estimatedRightRowCount,
+            boolean tryDistinctBuildRow,
+            long managedMemory,
+            ReadableConfig config,
+            ClassLoader classLoader) {
+        LogicalType[] keyFieldTypes =
+                
IntStream.of(leftKeys).mapToObj(leftType::getTypeAt).toArray(LogicalType[]::new);
+        RowType keyType = RowType.of(keyFieldTypes);
+
+        // projection for equals
+        GeneratedProjection leftProj =
+                ProjectionCodeGenerator.generateProjection(
+                        new CodeGeneratorContext(config, classLoader),
+                        "HashJoinLeftProjection",
+                        leftType,
+                        keyType,
+                        leftKeys);
+        GeneratedProjection rightProj =
+                ProjectionCodeGenerator.generateProjection(
+                        new CodeGeneratorContext(config, classLoader),
+                        "HashJoinRightProjection",
+                        rightType,
+                        keyType,
+                        rightKeys);
+
+        GeneratedProjection buildProj;
+        GeneratedProjection probeProj;
+        int[] buildKeys;
+        int[] probeKeys;
+        RowType buildType;
+        RowType probeType;
+        int buildRowSize;
+        long buildRowCount;
+        long probeRowCount;
+        boolean reverseJoin = !leftIsBuild;
+        if (leftIsBuild) {
+            buildProj = leftProj;
+            buildType = leftType;
+            buildRowSize = estimatedLeftAvgRowSize;
+            buildRowCount = estimatedLeftRowCount;
+            buildKeys = leftKeys;
+
+            probeProj = rightProj;
+            probeType = rightType;
+            probeRowCount = estimatedRightRowCount;
+            probeKeys = rightKeys;
+        } else {
+            buildProj = rightProj;
+            buildType = rightType;
+            buildRowSize = estimatedRightAvgRowSize;
+            buildRowCount = estimatedRightRowCount;
+            buildKeys = rightKeys;
+
+            probeProj = leftProj;
+            probeType = leftType;
+            probeRowCount = estimatedLeftRowCount;
+            probeKeys = leftKeys;
+        }
+
+        // operator
+        StreamOperatorFactory<RowData> operator;
+        HashJoinType hashJoinType =
+                HashJoinType.of(
+                        leftIsBuild,
+                        joinType.isLeftOuter(),
+                        joinType.isRightOuter(),
+                        joinType == FlinkJoinType.SEMI,
+                        joinType == FlinkJoinType.ANTI);
+
+        long externalBufferMemory =
+                
config.get(ExecutionConfigOptions.TABLE_EXEC_RESOURCE_EXTERNAL_BUFFER_MEMORY)
+                        .getBytes();
+
+        // sort merge join function
+        SortMergeJoinFunction sortMergeJoinFunction =
+                SorMergeJoinOperatorUtil.getSortMergeJoinFunction(
+                        classLoader,
+                        config,
+                        joinType,
+                        leftType,
+                        rightType,
+                        leftKeys,
+                        rightKeys,
+                        keyType,
+                        leftIsBuild,
+                        filterNulls,
+                        condFunc,
+                        1.0 * externalBufferMemory / managedMemory);
+
+        boolean compressionEnabled =
+                
config.get(ExecutionConfigOptions.TABLE_EXEC_SPILL_COMPRESSION_ENABLED);
+        int compressionBlockSize =
+                (int)
+                        
config.get(ExecutionConfigOptions.TABLE_EXEC_SPILL_COMPRESSION_BLOCK_SIZE)
+                                .getBytes();
+        if (LongHashJoinGenerator.support(hashJoinType, keyType, filterNulls)) 
{
+            operator =
+                    LongHashJoinGenerator.gen(
+                            config,
+                            classLoader,
+                            hashJoinType,
+                            keyType,
+                            buildType,
+                            probeType,
+                            buildKeys,
+                            probeKeys,
+                            buildRowSize,
+                            buildRowCount,
+                            reverseJoin,
+                            condFunc,
+                            leftIsBuild,
+                            compressionEnabled,
+                            compressionBlockSize,
+                            sortMergeJoinFunction);
+        } else {
+            operator =
+                    SimpleOperatorFactory.of(
+                            HashJoinOperator.newHashJoinOperator(
+                                    hashJoinType,
+                                    leftIsBuild,
+                                    compressionEnabled,
+                                    compressionBlockSize,
+                                    condFunc,
+                                    reverseJoin,
+                                    filterNulls,
+                                    buildProj,
+                                    probeProj,
+                                    tryDistinctBuildRow,
+                                    buildRowSize,
+                                    buildRowCount,
+                                    probeRowCount,
+                                    keyType,
+                                    sortMergeJoinFunction));
+        }
+
+        return operator;
+    }
+}
diff --git 
a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/utils/SorMergeJoinOperatorUtil.java
 
b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/utils/SorMergeJoinOperatorUtil.java
index 891aa185786..a88b8e89d1f 100644
--- 
a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/utils/SorMergeJoinOperatorUtil.java
+++ 
b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/utils/SorMergeJoinOperatorUtil.java
@@ -18,15 +18,18 @@
 
 package org.apache.flink.table.planner.plan.utils;
 
+import org.apache.flink.configuration.ReadableConfig;
+import org.apache.flink.streaming.api.operators.SimpleOperatorFactory;
 import org.apache.flink.table.api.config.ExecutionConfigOptions;
+import org.apache.flink.table.data.RowData;
 import org.apache.flink.table.planner.codegen.CodeGeneratorContext;
 import org.apache.flink.table.planner.codegen.ProjectionCodeGenerator;
 import org.apache.flink.table.planner.codegen.sort.SortCodeGenerator;
-import org.apache.flink.table.planner.plan.nodes.exec.ExecNodeConfig;
 import org.apache.flink.table.runtime.generated.GeneratedJoinCondition;
 import org.apache.flink.table.runtime.operators.join.FlinkJoinType;
 import org.apache.flink.table.runtime.operators.join.SortMergeJoinFunction;
 import org.apache.flink.table.runtime.operators.join.SortMergeJoinOperator;
+import org.apache.flink.table.types.logical.LogicalType;
 import org.apache.flink.table.types.logical.RowType;
 
 import java.util.stream.IntStream;
@@ -36,7 +39,7 @@ public class SorMergeJoinOperatorUtil {
 
     public static SortMergeJoinFunction getSortMergeJoinFunction(
             ClassLoader classLoader,
-            ExecNodeConfig config,
+            ReadableConfig config,
             FlinkJoinType joinType,
             RowType leftType,
             RowType rightType,
@@ -95,5 +98,41 @@ public class SorMergeJoinOperatorUtil {
                 filterNulls);
     }
 
+    public static SimpleOperatorFactory<RowData> generateOperatorFactory(
+            GeneratedJoinCondition condFunc,
+            RowType leftType,
+            RowType rightType,
+            int[] leftKeys,
+            int[] rightKeys,
+            FlinkJoinType joinType,
+            ReadableConfig config,
+            boolean leftIsSmaller,
+            boolean[] filterNulls,
+            long managedMemory,
+            ClassLoader classLoader) {
+        LogicalType[] keyFieldTypes =
+                
IntStream.of(leftKeys).mapToObj(leftType::getTypeAt).toArray(LogicalType[]::new);
+        RowType keyType = RowType.of(keyFieldTypes);
+        long externalBufferMemory =
+                
config.get(ExecutionConfigOptions.TABLE_EXEC_RESOURCE_EXTERNAL_BUFFER_MEMORY)
+                        .getBytes();
+
+        SortMergeJoinFunction sortMergeJoinFunction =
+                SorMergeJoinOperatorUtil.getSortMergeJoinFunction(
+                        classLoader,
+                        config,
+                        joinType,
+                        leftType,
+                        rightType,
+                        leftKeys,
+                        rightKeys,
+                        keyType,
+                        leftIsSmaller,
+                        filterNulls,
+                        condFunc,
+                        1.0 * externalBufferMemory / managedMemory);
+        return SimpleOperatorFactory.of(new 
SortMergeJoinOperator(sortMergeJoinFunction));
+    }
+
     private SorMergeJoinOperatorUtil() {}
 }
diff --git 
a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/utils/SortUtil.scala
 
b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/utils/SortUtil.scala
index e3c7cfc5e86..da6c1e9fbaf 100644
--- 
a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/utils/SortUtil.scala
+++ 
b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/utils/SortUtil.scala
@@ -18,10 +18,10 @@
 package org.apache.flink.table.planner.plan.utils
 
 import org.apache.flink.api.common.operators.Order
+import org.apache.flink.configuration.ReadableConfig
 import org.apache.flink.table.api.TableException
 import org.apache.flink.table.planner.calcite.FlinkPlannerImpl
 import org.apache.flink.table.planner.codegen.sort.SortCodeGenerator
-import org.apache.flink.table.planner.plan.nodes.exec.ExecNodeConfig
 import org.apache.flink.table.planner.plan.nodes.exec.spec.SortSpec
 import org.apache.flink.table.types.logical.RowType
 
@@ -123,7 +123,7 @@ object SortUtil {
   }
 
   def newSortGen(
-      config: ExecNodeConfig,
+      config: ReadableConfig,
       classLoader: ClassLoader,
       originalKeys: Array[Int],
       inputType: RowType): SortCodeGenerator = {
diff --git 
a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/adaptive/AdaptiveJoinOperatorGeneratorTest.java
 
b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/adaptive/AdaptiveJoinOperatorGeneratorTest.java
new file mode 100644
index 00000000000..3baa5030cbe
--- /dev/null
+++ 
b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/adaptive/AdaptiveJoinOperatorGeneratorTest.java
@@ -0,0 +1,333 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.planner.adaptive;
+
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.streaming.api.operators.SimpleOperatorFactory;
+import org.apache.flink.table.data.binary.BinaryRowData;
+import org.apache.flink.table.planner.plan.utils.OperatorType;
+import org.apache.flink.table.runtime.generated.GeneratedJoinCondition;
+import org.apache.flink.table.runtime.generated.JoinCondition;
+import org.apache.flink.table.runtime.operators.CodeGenOperatorFactory;
+import org.apache.flink.table.runtime.operators.join.FlinkJoinType;
+import org.apache.flink.table.runtime.operators.join.HashJoinOperator;
+import 
org.apache.flink.table.runtime.operators.join.Int2HashJoinOperatorTestBase;
+import org.apache.flink.table.runtime.operators.join.SortMergeJoinOperator;
+import org.apache.flink.table.runtime.operators.join.adaptive.AdaptiveJoin;
+import org.apache.flink.table.runtime.util.UniformBinaryRowGenerator;
+import org.apache.flink.table.types.logical.IntType;
+import org.apache.flink.table.types.logical.RowType;
+import org.apache.flink.util.MutableObjectIterator;
+
+import org.junit.jupiter.api.Test;
+
+import static 
org.apache.flink.table.api.config.ExecutionConfigOptions.TABLE_EXEC_RESOURCE_HASH_JOIN_MEMORY;
+import static 
org.apache.flink.table.planner.plan.utils.OperatorType.BroadcastHashJoin;
+import static 
org.apache.flink.table.planner.plan.utils.OperatorType.ShuffleHashJoin;
+import static 
org.apache.flink.table.planner.plan.utils.OperatorType.SortMergeJoin;
+import static org.apache.flink.table.runtime.util.JoinUtil.getJoinType;
+import static org.assertj.core.api.Assertions.assertThat;
+
+/** Test for {@link AdaptiveJoinOperatorGenerator}. */
+class AdaptiveJoinOperatorGeneratorTest extends Int2HashJoinOperatorTestBase {
+
+    @Test
+    void testShuffleHashJoinTransformationCorrectness() throws Exception {
+
+        // all cases to ShuffleHashJoin
+        testInnerJoin(true, ShuffleHashJoin, false, ShuffleHashJoin);
+        testInnerJoin(false, ShuffleHashJoin, false, ShuffleHashJoin);
+
+        testLeftOutJoin(true, ShuffleHashJoin, false, ShuffleHashJoin);
+        testLeftOutJoin(false, ShuffleHashJoin, false, ShuffleHashJoin);
+
+        testRightOutJoin(true, ShuffleHashJoin, false, ShuffleHashJoin);
+        testRightOutJoin(false, ShuffleHashJoin, false, ShuffleHashJoin);
+
+        testSemiJoin(ShuffleHashJoin, false, ShuffleHashJoin);
+
+        testAntiJoin(ShuffleHashJoin, false, ShuffleHashJoin);
+
+        // all cases to BroadcastHashJoin
+        testInnerJoin(true, ShuffleHashJoin, true, BroadcastHashJoin);
+        testInnerJoin(false, ShuffleHashJoin, true, BroadcastHashJoin);
+
+        testLeftOutJoin(false, ShuffleHashJoin, true, BroadcastHashJoin);
+
+        testRightOutJoin(true, ShuffleHashJoin, true, BroadcastHashJoin);
+
+        testSemiJoin(ShuffleHashJoin, true, BroadcastHashJoin);
+
+        testAntiJoin(ShuffleHashJoin, true, BroadcastHashJoin);
+    }
+
+    @Test
+    void testSortMergeJoinTransformationCorrectness() throws Exception {
+        // all cases to SortMergeJoin
+        testInnerJoin(true, SortMergeJoin, false, SortMergeJoin);
+        testInnerJoin(false, SortMergeJoin, false, SortMergeJoin);
+
+        testLeftOutJoin(true, SortMergeJoin, false, SortMergeJoin);
+
+        testRightOutJoin(true, SortMergeJoin, false, SortMergeJoin);
+
+        testAntiJoin(SortMergeJoin, false, SortMergeJoin);
+
+        testAntiJoin(SortMergeJoin, false, SortMergeJoin);
+
+        // all cases to BroadcastHashJoin
+        testInnerJoin(true, SortMergeJoin, true, BroadcastHashJoin);
+        testInnerJoin(false, SortMergeJoin, true, BroadcastHashJoin);
+
+        testLeftOutJoin(false, SortMergeJoin, true, BroadcastHashJoin);
+
+        testRightOutJoin(true, SortMergeJoin, true, BroadcastHashJoin);
+
+        testSemiJoin(SortMergeJoin, true, BroadcastHashJoin);
+
+        testAntiJoin(SortMergeJoin, true, BroadcastHashJoin);
+    }
+
+    private void testInnerJoin(
+            boolean isBuildLeft,
+            OperatorType originalJoinType,
+            boolean isBroadcast,
+            OperatorType expectedOperatorType)
+            throws Exception {
+        int numKeys = 100;
+        int buildValsPerKey = 3;
+        int probeValsPerKey = 10;
+        MutableObjectIterator<BinaryRowData> buildInput =
+                new UniformBinaryRowGenerator(numKeys, buildValsPerKey, false);
+        MutableObjectIterator<BinaryRowData> probeInput =
+                new UniformBinaryRowGenerator(numKeys, probeValsPerKey, true);
+
+        buildJoin(
+                buildInput,
+                probeInput,
+                originalJoinType,
+                expectedOperatorType,
+                false,
+                false,
+                isBuildLeft,
+                isBroadcast,
+                numKeys * buildValsPerKey * probeValsPerKey,
+                numKeys,
+                165);
+    }
+
+    private void testLeftOutJoin(
+            boolean isBuildLeft,
+            OperatorType originalJoinType,
+            boolean isBroadcast,
+            OperatorType expectedOperatorType)
+            throws Exception {
+        int numKeys1 = 9;
+        int numKeys2 = 10;
+        int buildValsPerKey = 3;
+        int probeValsPerKey = 10;
+        MutableObjectIterator<BinaryRowData> buildInput =
+                new UniformBinaryRowGenerator(
+                        isBuildLeft ? numKeys1 : numKeys2, buildValsPerKey, 
true);
+        MutableObjectIterator<BinaryRowData> probeInput =
+                new UniformBinaryRowGenerator(
+                        isBuildLeft ? numKeys2 : numKeys1, probeValsPerKey, 
true);
+
+        buildJoin(
+                buildInput,
+                probeInput,
+                originalJoinType,
+                expectedOperatorType,
+                true,
+                false,
+                isBuildLeft,
+                isBroadcast,
+                numKeys1 * buildValsPerKey * probeValsPerKey,
+                numKeys1,
+                165);
+    }
+
+    private void testRightOutJoin(
+            boolean isBuildLeft,
+            OperatorType originalJoinType,
+            boolean isBroadcast,
+            OperatorType expectedOperatorType)
+            throws Exception {
+        int numKeys1 = 9;
+        int numKeys2 = 10;
+        int buildValsPerKey = 3;
+        int probeValsPerKey = 10;
+        MutableObjectIterator<BinaryRowData> buildInput =
+                new UniformBinaryRowGenerator(numKeys1, buildValsPerKey, true);
+        MutableObjectIterator<BinaryRowData> probeInput =
+                new UniformBinaryRowGenerator(numKeys2, probeValsPerKey, true);
+
+        buildJoin(
+                buildInput,
+                probeInput,
+                originalJoinType,
+                expectedOperatorType,
+                false,
+                true,
+                isBuildLeft,
+                isBroadcast,
+                isBuildLeft ? 280 : 270,
+                numKeys2,
+                -1);
+    }
+
+    private void testSemiJoin(
+            OperatorType originalJoinType, boolean isBroadcast, OperatorType 
expectedOperatorType)
+            throws Exception {
+        int numKeys1 = 9;
+        int numKeys2 = 10;
+        int buildValsPerKey = 3;
+        int probeValsPerKey = 10;
+        if (originalJoinType == SortMergeJoin && !isBroadcast) {
+            numKeys1 = 10;
+            numKeys2 = 9;
+            buildValsPerKey = 10;
+            probeValsPerKey = 3;
+        }
+        MutableObjectIterator<BinaryRowData> buildInput =
+                new UniformBinaryRowGenerator(numKeys1, buildValsPerKey, true);
+        MutableObjectIterator<BinaryRowData> probeInput =
+                new UniformBinaryRowGenerator(numKeys2, probeValsPerKey, true);
+
+        Object operator = newOperator(FlinkJoinType.SEMI, false, isBroadcast, 
originalJoinType);
+        assertOperatorType(operator, expectedOperatorType);
+        joinAndAssert(operator, buildInput, probeInput, 90, 9, 45, true);
+    }
+
+    private void testAntiJoin(
+            OperatorType originalJoinType, boolean isBroadcast, OperatorType 
expectedOperatorType)
+            throws Exception {
+        int numKeys1 = 9;
+        int numKeys2 = 10;
+        int buildValsPerKey = 3;
+        int probeValsPerKey = 10;
+        if (originalJoinType == SortMergeJoin && !isBroadcast) {
+            numKeys1 = 10;
+            numKeys2 = 9;
+            buildValsPerKey = 10;
+            probeValsPerKey = 3;
+        }
+        MutableObjectIterator<BinaryRowData> buildInput =
+                new UniformBinaryRowGenerator(numKeys1, buildValsPerKey, true);
+        MutableObjectIterator<BinaryRowData> probeInput =
+                new UniformBinaryRowGenerator(numKeys2, probeValsPerKey, true);
+
+        Object operator = newOperator(FlinkJoinType.ANTI, false, isBroadcast, 
originalJoinType);
+        assertOperatorType(operator, expectedOperatorType);
+        joinAndAssert(operator, buildInput, probeInput, 10, 1, 45, true);
+    }
+
+    public void buildJoin(
+            MutableObjectIterator<BinaryRowData> buildInput,
+            MutableObjectIterator<BinaryRowData> probeInput,
+            OperatorType originalJoinType,
+            OperatorType expectedOperatorType,
+            boolean leftOut,
+            boolean rightOut,
+            boolean buildLeft,
+            boolean isBroadcast,
+            int expectOutSize,
+            int expectOutKeySize,
+            int expectOutVal)
+            throws Exception {
+        FlinkJoinType flinkJoinType = getJoinType(leftOut, rightOut);
+        Object operator = newOperator(flinkJoinType, buildLeft, isBroadcast, 
originalJoinType);
+        assertOperatorType(operator, expectedOperatorType);
+        joinAndAssert(
+                operator,
+                buildInput,
+                probeInput,
+                expectOutSize,
+                expectOutKeySize,
+                expectOutVal,
+                false);
+    }
+
+    public Object newOperator(
+            FlinkJoinType flinkJoinType,
+            boolean buildLeft,
+            boolean isBroadcast,
+            OperatorType operatorType) {
+        AdaptiveJoin adaptiveJoin = genAdaptiveJoin(flinkJoinType, 
operatorType);
+        adaptiveJoin.markAsBroadcastJoin(isBroadcast, buildLeft);
+
+        return adaptiveJoin.genOperatorFactory(getClass().getClassLoader(), 
new Configuration());
+    }
+
+    public void assertOperatorType(Object operator, OperatorType 
expectedOperatorType) {
+        switch (expectedOperatorType) {
+            case BroadcastHashJoin:
+            case ShuffleHashJoin:
+                if (operator instanceof CodeGenOperatorFactory) {
+                    assertThat(
+                                    ((CodeGenOperatorFactory<?>) operator)
+                                            .getGeneratedClass()
+                                            .getClassName())
+                            .contains("LongHashJoinOperator");
+                } else {
+                    
assertThat(operator).isInstanceOf(SimpleOperatorFactory.class);
+                    assertThat(((SimpleOperatorFactory<?>) 
operator).getOperator())
+                            .isInstanceOf(HashJoinOperator.class);
+                }
+                break;
+            case SortMergeJoin:
+                assertThat(operator).isInstanceOf(SimpleOperatorFactory.class);
+                assertThat(((SimpleOperatorFactory<?>) operator).getOperator())
+                        .isInstanceOf(SortMergeJoinOperator.class);
+                break;
+            default:
+                throw new IllegalArgumentException(
+                        String.format("Unexpected operator type %s.", 
expectedOperatorType));
+        }
+    }
+
+    public AdaptiveJoin genAdaptiveJoin(FlinkJoinType flinkJoinType, 
OperatorType operatorType) {
+        GeneratedJoinCondition condFuncCode =
+                new GeneratedJoinCondition(
+                        
Int2HashJoinOperatorTestBase.MyJoinCondition.class.getCanonicalName(),
+                        "",
+                        new Object[0]) {
+                    @Override
+                    public JoinCondition newInstance(ClassLoader classLoader) {
+                        return new 
Int2HashJoinOperatorTestBase.MyJoinCondition(new Object[0]);
+                    }
+                };
+
+        return new AdaptiveJoinOperatorGenerator(
+                new int[] {0},
+                new int[] {0},
+                flinkJoinType,
+                new boolean[] {true},
+                RowType.of(new IntType(), new IntType()),
+                RowType.of(new IntType(), new IntType()),
+                condFuncCode,
+                20,
+                10000,
+                20,
+                10000,
+                false,
+                TABLE_EXEC_RESOURCE_HASH_JOIN_MEMORY.defaultValue().getBytes(),
+                true,
+                operatorType);
+    }
+}
diff --git 
a/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/join/adaptive/AdaptiveJoin.java
 
b/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/join/adaptive/AdaptiveJoin.java
new file mode 100755
index 00000000000..01a44788b01
--- /dev/null
+++ 
b/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/join/adaptive/AdaptiveJoin.java
@@ -0,0 +1,56 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.runtime.operators.join.adaptive;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.configuration.ReadableConfig;
+import org.apache.flink.streaming.api.operators.StreamOperatorFactory;
+import org.apache.flink.table.runtime.operators.join.FlinkJoinType;
+
+import java.io.Serializable;
+
/**
 * Interface for implementing an adaptive join operator, i.e. a join whose concrete physical
 * operator is generated dynamically at runtime rather than fixed at planning time.
 */
@Internal
public interface AdaptiveJoin extends Serializable {

    /**
     * Generates a StreamOperatorFactory for this join operator using the provided ClassLoader and
     * config.
     *
     * @param classLoader the ClassLoader to be used for loading classes.
     * @param config the configuration to be applied for creating the operator factory.
     * @return a StreamOperatorFactory instance.
     */
    StreamOperatorFactory<?> genOperatorFactory(ClassLoader classLoader, ReadableConfig config);

    /**
     * Get the join type of the join operator.
     *
     * @return the join type.
     */
    FlinkJoinType getJoinType();

    /**
     * Record whether this adaptive join operator may be optimized into a broadcast hash join,
     * and if so, which input side serves as the build (broadcast) side.
     *
     * @param canBeBroadcast whether the join operator can be optimized to broadcast hash join.
     * @param leftIsBuild whether the left input side is the build side.
     */
    void markAsBroadcastJoin(boolean canBeBroadcast, boolean leftIsBuild);
}
diff --git 
a/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/join/adaptive/AdaptiveJoinOperatorFactory.java
 
b/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/join/adaptive/AdaptiveJoinOperatorFactory.java
new file mode 100644
index 00000000000..35c47ad68cc
--- /dev/null
+++ 
b/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/join/adaptive/AdaptiveJoinOperatorFactory.java
@@ -0,0 +1,134 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.runtime.operators.join.adaptive;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.configuration.ReadableConfig;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperatorFactory;
+import org.apache.flink.streaming.api.operators.StreamOperator;
+import org.apache.flink.streaming.api.operators.StreamOperatorFactory;
+import org.apache.flink.streaming.api.operators.StreamOperatorParameters;
+import org.apache.flink.table.planner.loader.PlannerModule;
+import org.apache.flink.table.runtime.operators.join.FlinkJoinType;
+import org.apache.flink.util.InstantiationUtil;
+
+import javax.annotation.Nullable;
+
+import java.io.IOException;
+
+import static org.apache.flink.util.Preconditions.checkNotNull;
+
+/**
+ * Adaptive join factory.
+ *
+ * <p>Note: This class will hold an {@link AdaptiveJoin} and serve as a proxy 
class to provide an
+ * interface externally. Due to runtime access visibility constraints with the 
table-planner module,
+ * the {@link AdaptiveJoin} object will be serialized during the Table Planner 
phase and will only
+ * be lazily deserialized before the dynamic generation of the JobGraph.
+ *
+ * @param <OUT> The output type of the operator
+ */
+@Internal
+public class AdaptiveJoinOperatorFactory<OUT> extends 
AbstractStreamOperatorFactory<OUT>
+        implements AdaptiveJoin {
+    private static final long serialVersionUID = 1L;
+
+    private final byte[] adaptiveJoinSerialized;
+
+    @Nullable private transient AdaptiveJoin adaptiveJoin;
+
+    @Nullable private StreamOperatorFactory<OUT> finalFactory;
+
+    public AdaptiveJoinOperatorFactory(byte[] adaptiveJoinSerialized) {
+        this.adaptiveJoinSerialized = checkNotNull(adaptiveJoinSerialized);
+    }
+
+    @Override
+    public StreamOperatorFactory<?> genOperatorFactory(
+            ClassLoader classLoader, ReadableConfig config) {
+        checkAndLazyInitialize();
+        this.finalFactory =
+                (StreamOperatorFactory<OUT>) 
adaptiveJoin.genOperatorFactory(classLoader, config);
+        return this.finalFactory;
+    }
+
+    @Override
+    public FlinkJoinType getJoinType() {
+        checkAndLazyInitialize();
+        return adaptiveJoin.getJoinType();
+    }
+
+    @Override
+    public void markAsBroadcastJoin(boolean canBeBroadcast, boolean 
leftIsBuild) {
+        checkAndLazyInitialize();
+        adaptiveJoin.markAsBroadcastJoin(canBeBroadcast, leftIsBuild);
+    }
+
+    private void checkAndLazyInitialize() {
+        if (this.adaptiveJoin == null) {
+            lazyInitialize();
+        }
+    }
+
+    @Override
+    public <T extends StreamOperator<OUT>> T createStreamOperator(
+            StreamOperatorParameters<OUT> parameters) {
+        checkNotNull(
+                finalFactory,
+                String.format(
+                        "The OperatorFactory of task [%s] have not been 
initialized.",
+                        parameters.getContainingTask()));
+        if (finalFactory instanceof AbstractStreamOperatorFactory) {
+            ((AbstractStreamOperatorFactory<OUT>) finalFactory)
+                    .setProcessingTimeService(processingTimeService);
+        }
+        StreamOperator<OUT> operator = 
finalFactory.createStreamOperator(parameters);
+        return (T) operator;
+    }
+
+    @Override
+    public Class<? extends StreamOperator> getStreamOperatorClass(ClassLoader 
classLoader) {
+        throw new UnsupportedOperationException(
+                "The method should not be invoked in the "
+                        + "adaptive join operator for batch jobs.");
+    }
+
+    private void lazyInitialize() {
+        if 
(!tryInitializeAdaptiveJoin(Thread.currentThread().getContextClassLoader())) {
+            boolean isSuccess =
+                    tryInitializeAdaptiveJoin(
+                            
PlannerModule.getInstance().getSubmoduleClassLoader());
+            if (!isSuccess) {
+                throw new RuntimeException(
+                        "Failed to deserialize AdaptiveJoin instance. "
+                                + "Please check whether the 
flink-table-planner-loader.jar is in the classpath.");
+            }
+        }
+    }
+
+    private boolean tryInitializeAdaptiveJoin(ClassLoader classLoader) {
+        try {
+            this.adaptiveJoin =
+                    
InstantiationUtil.deserializeObject(adaptiveJoinSerialized, classLoader);
+        } catch (ClassNotFoundException | IOException e) {
+            return false;
+        }
+
+        return true;
+    }
+}
diff --git 
a/flink-table/flink-table-runtime/src/test/java/org/apache/flink/table/runtime/operators/join/Int2HashJoinOperatorTestBase.java
 
b/flink-table/flink-table-runtime/src/test/java/org/apache/flink/table/runtime/operators/join/Int2HashJoinOperatorTestBase.java
index f74cf86db23..53a8dab724f 100644
--- 
a/flink-table/flink-table-runtime/src/test/java/org/apache/flink/table/runtime/operators/join/Int2HashJoinOperatorTestBase.java
+++ 
b/flink-table/flink-table-runtime/src/test/java/org/apache/flink/table/runtime/operators/join/Int2HashJoinOperatorTestBase.java
@@ -58,7 +58,7 @@ import static 
org.apache.flink.table.runtime.util.JoinUtil.getJoinType;
 import static org.assertj.core.api.Assertions.assertThat;
 
 /** Base test class for {@link HashJoinOperator}. */
-abstract class Int2HashJoinOperatorTestBase implements Serializable {
+public abstract class Int2HashJoinOperatorTestBase implements Serializable {
 
     public void buildJoin(
             MutableObjectIterator<BinaryRowData> buildInput,


Reply via email to