This is an automated email from the ASF dual-hosted git repository.

danny0405 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/hudi.git


The following commit(s) were added to refs/heads/master by this push:
     new 51c9c0e226a [HUDI-7906] Improve the parallelism deduce in rdd write (#11470)
51c9c0e226a is described below

commit 51c9c0e226ab158556de87dc0e5c3e6530b6b8c1
Author: KnightChess <981159...@qq.com>
AuthorDate: Sat Jun 22 12:29:35 2024 +0800

    [HUDI-7906] Improve the parallelism deduce in rdd write (#11470)
---
 .../org/apache/hudi/config/HoodieIndexConfig.java  |  2 +-
 .../hudi/index/simple/HoodieGlobalSimpleIndex.java |  5 +++-
 .../hudi/index/simple/HoodieSimpleIndex.java       | 10 ++++----
 .../table/action/commit/HoodieDeleteHelper.java    |  2 +-
 .../table/action/commit/HoodieWriteHelper.java     |  2 +-
 .../table/action/commit/TestWriterHelperBase.java  | 19 ---------------
 .../org/apache/hudi/data/HoodieJavaPairRDD.java    | 23 ++++++++++++++++++
 .../java/org/apache/hudi/data/HoodieJavaRDD.java   | 23 ++++++++++++++++++
 .../index/bloom/SparkHoodieBloomIndexHelper.java   |  2 +-
 .../scala/org/apache/hudi/HoodieSparkUtils.scala   |  4 ++--
 .../org/apache/hudi/data/TestHoodieJavaRDD.java    | 28 ++++++++++++++++++++++
 .../table/action/commit/TestSparkWriteHelper.java  | 23 ++++++++++++++++++
 .../org/apache/hudi/common/data/HoodieData.java    |  5 ++++
 .../apache/hudi/common/data/HoodieListData.java    |  5 ++++
 .../hudi/common/data/HoodieListPairData.java       |  5 ++++
 .../apache/hudi/common/data/HoodiePairData.java    |  5 ++++
 .../spark/sql/hudi/dml/TestInsertTable.scala       |  2 ++
 17 files changed, 134 insertions(+), 31 deletions(-)
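
For reference, the deduceNumPartitions() method introduced below (in HoodieJavaRDD and HoodieJavaPairRDD) resolves the write parallelism in this order: an existing partitioner on the RDD, then spark.sql.shuffle.partitions, then spark.default.parallelism, and finally the RDD's own partition count. The standalone sketch below restates that fallback chain; the ParallelismSketch class and resolveParallelism helper are illustrative names only and are not part of this commit:

    import org.apache.spark.api.java.JavaRDD;
    import org.apache.spark.sql.internal.SQLConf;

    final class ParallelismSketch {
      // Mirrors the fallback order used by HoodieJavaRDD#deduceNumPartitions in this commit.
      static <T> int resolveParallelism(JavaRDD<T> rdd) {
        // 1. Prefer an explicit partitioner, if present with a positive partition count.
        if (rdd.partitioner().isPresent() && rdd.partitioner().get().numPartitions() > 0) {
          return rdd.partitioner().get().numPartitions();
        }
        // 2. Otherwise use spark.sql.shuffle.partitions when it is configured.
        if (SQLConf.get().contains(SQLConf.SHUFFLE_PARTITIONS().key())) {
          return Integer.parseInt(SQLConf.get().getConfString(SQLConf.SHUFFLE_PARTITIONS().key()));
        }
        // 3. Otherwise use spark.default.parallelism when it is explicitly set on the context.
        if (rdd.context().conf().contains("spark.default.parallelism")) {
          return rdd.context().defaultParallelism();
        }
        // 4. Fall back to the input RDD's current number of partitions.
        return rdd.getNumPartitions();
      }
    }

In the updated index and write helpers, this deduced value is used whenever the configured parallelism is 0, which is also the new default for hoodie.global.simple.index.parallelism.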

diff --git a/hudi-client/hudi-client-common/src/main/java/org/apache/hudi/config/HoodieIndexConfig.java b/hudi-client/hudi-client-common/src/main/java/org/apache/hudi/config/HoodieIndexConfig.java
index c80c5a2de8a..385532917c4 100644
--- a/hudi-client/hudi-client-common/src/main/java/org/apache/hudi/config/HoodieIndexConfig.java
+++ b/hudi-client/hudi-client-common/src/main/java/org/apache/hudi/config/HoodieIndexConfig.java
@@ -168,7 +168,7 @@ public class HoodieIndexConfig extends HoodieConfig {
 
  public static final ConfigProperty<String> GLOBAL_SIMPLE_INDEX_PARALLELISM = ConfigProperty
       .key("hoodie.global.simple.index.parallelism")
-      .defaultValue("100")
+      .defaultValue("0")
       .markAdvanced()
       .withDocumentation("Only applies if index type is GLOBAL_SIMPLE. "
           + "This limits the parallelism of fetching records from the base 
files of all table "
diff --git a/hudi-client/hudi-client-common/src/main/java/org/apache/hudi/index/simple/HoodieGlobalSimpleIndex.java b/hudi-client/hudi-client-common/src/main/java/org/apache/hudi/index/simple/HoodieGlobalSimpleIndex.java
index 7432d606839..3c76ff17935 100644
--- a/hudi-client/hudi-client-common/src/main/java/org/apache/hudi/index/simple/HoodieGlobalSimpleIndex.java
+++ b/hudi-client/hudi-client-common/src/main/java/org/apache/hudi/index/simple/HoodieGlobalSimpleIndex.java
@@ -69,8 +69,11 @@ public class HoodieGlobalSimpleIndex extends HoodieSimpleIndex {
       HoodieData<HoodieRecord<R>> inputRecords, HoodieEngineContext context,
       HoodieTable hoodieTable) {
     List<Pair<String, HoodieBaseFile>> latestBaseFiles = getAllBaseFilesInTable(context, hoodieTable);
+    int configuredSimpleIndexParallelism = config.getGlobalSimpleIndexParallelism();
+    int fetchParallelism =
+        configuredSimpleIndexParallelism > 0 ? configuredSimpleIndexParallelism : inputRecords.deduceNumPartitions();
     HoodiePairData<String, HoodieRecordGlobalLocation> allKeysAndLocations =
-        fetchRecordGlobalLocations(context, hoodieTable, config.getGlobalSimpleIndexParallelism(), latestBaseFiles);
+        fetchRecordGlobalLocations(context, hoodieTable, fetchParallelism, latestBaseFiles);
     boolean mayContainDuplicateLookup = hoodieTable.getMetaClient().getTableType() == MERGE_ON_READ;
     boolean shouldUpdatePartitionPath = config.getGlobalSimpleIndexUpdatePartitionPath() && hoodieTable.isPartitioned();
     return tagGlobalLocationBackToRecords(inputRecords, allKeysAndLocations,
diff --git a/hudi-client/hudi-client-common/src/main/java/org/apache/hudi/index/simple/HoodieSimpleIndex.java b/hudi-client/hudi-client-common/src/main/java/org/apache/hudi/index/simple/HoodieSimpleIndex.java
index cca7a43d1f9..99ffc1b47e6 100644
--- a/hudi-client/hudi-client-common/src/main/java/org/apache/hudi/index/simple/HoodieSimpleIndex.java
+++ b/hudi-client/hudi-client-common/src/main/java/org/apache/hudi/index/simple/HoodieSimpleIndex.java
@@ -107,16 +107,16 @@ public class HoodieSimpleIndex
           .getString(HoodieIndexConfig.SIMPLE_INDEX_INPUT_STORAGE_LEVEL_VALUE));
     }
 
-    int inputParallelism = inputRecords.getNumPartitions();
+    int deduceNumParallelism = inputRecords.deduceNumPartitions();
     int configuredSimpleIndexParallelism = config.getSimpleIndexParallelism();
     // NOTE: Target parallelism could be overridden by the config
-    int targetParallelism =
-        configuredSimpleIndexParallelism > 0 ? configuredSimpleIndexParallelism : inputParallelism;
+    int fetchParallelism =
+        configuredSimpleIndexParallelism > 0 ? configuredSimpleIndexParallelism : deduceNumParallelism;
     HoodiePairData<HoodieKey, HoodieRecord<R>> keyedInputRecords =
         inputRecords.mapToPair(record -> new ImmutablePair<>(record.getKey(), record));
     HoodiePairData<HoodieKey, HoodieRecordLocation> existingLocationsOnTable =
         fetchRecordLocationsForAffectedPartitions(keyedInputRecords.keys(), context, hoodieTable,
-            targetParallelism);
+            fetchParallelism);
 
     HoodieData<HoodieRecord<R>> taggedRecords =
         keyedInputRecords.leftOuterJoin(existingLocationsOnTable).map(entry -> {
@@ -144,7 +144,7 @@ public class HoodieSimpleIndex
       HoodieData<HoodieKey> hoodieKeys, HoodieEngineContext context, HoodieTable hoodieTable,
       int parallelism) {
     List<String> affectedPartitionPathList =
-        hoodieKeys.map(HoodieKey::getPartitionPath).distinct().collectAsList();
+        hoodieKeys.map(HoodieKey::getPartitionPath).distinct(hoodieKeys.deduceNumPartitions()).collectAsList();
     List<Pair<String, HoodieBaseFile>> latestBaseFiles =
         getLatestBaseFilesForAllPartitions(affectedPartitionPathList, context, hoodieTable);
     return fetchRecordLocations(context, hoodieTable, parallelism, latestBaseFiles);
diff --git a/hudi-client/hudi-client-common/src/main/java/org/apache/hudi/table/action/commit/HoodieDeleteHelper.java b/hudi-client/hudi-client-common/src/main/java/org/apache/hudi/table/action/commit/HoodieDeleteHelper.java
index 63899a4e40b..17dd4282e14 100644
--- a/hudi-client/hudi-client-common/src/main/java/org/apache/hudi/table/action/commit/HoodieDeleteHelper.java
+++ b/hudi-client/hudi-client-common/src/main/java/org/apache/hudi/table/action/commit/HoodieDeleteHelper.java
@@ -50,7 +50,7 @@ public class HoodieDeleteHelper<T, R> extends
     BaseDeleteHelper<T, HoodieData<HoodieRecord<T>>, HoodieData<HoodieKey>, HoodieData<WriteStatus>, R> {
 
   private HoodieDeleteHelper() {
-    super(HoodieData::getNumPartitions);
+    super(HoodieData::deduceNumPartitions);
   }
 
   private static class DeleteHelperHolder {
diff --git a/hudi-client/hudi-client-common/src/main/java/org/apache/hudi/table/action/commit/HoodieWriteHelper.java b/hudi-client/hudi-client-common/src/main/java/org/apache/hudi/table/action/commit/HoodieWriteHelper.java
index b56ac08e16f..37bb5b64e3b 100644
--- a/hudi-client/hudi-client-common/src/main/java/org/apache/hudi/table/action/commit/HoodieWriteHelper.java
+++ b/hudi-client/hudi-client-common/src/main/java/org/apache/hudi/table/action/commit/HoodieWriteHelper.java
@@ -38,7 +38,7 @@ public class HoodieWriteHelper<T, R> extends BaseWriteHelper<T, HoodieData<Hoodi
     HoodieData<HoodieKey>, HoodieData<WriteStatus>, R> {
 
   private HoodieWriteHelper() {
-    super(HoodieData::getNumPartitions);
+    super(HoodieData::deduceNumPartitions);
   }
 
   private static class WriteHelperHolder {
diff --git a/hudi-client/hudi-client-common/src/test/java/org/apache/hudi/table/action/commit/TestWriterHelperBase.java b/hudi-client/hudi-client-common/src/test/java/org/apache/hudi/table/action/commit/TestWriterHelperBase.java
index 2d43b414608..fe2fbb6800e 100644
--- a/hudi-client/hudi-client-common/src/test/java/org/apache/hudi/table/action/commit/TestWriterHelperBase.java
+++ b/hudi-client/hudi-client-common/src/test/java/org/apache/hudi/table/action/commit/TestWriterHelperBase.java
@@ -19,7 +19,6 @@
 
 package org.apache.hudi.table.action.commit;
 
-import org.apache.hudi.common.data.HoodieData;
 import org.apache.hudi.common.engine.HoodieEngineContext;
 import org.apache.hudi.common.model.HoodieRecord;
 import org.apache.hudi.common.testutils.HoodieCommonTestHarness;
@@ -27,13 +26,10 @@ import org.apache.hudi.table.HoodieTable;
 
 import org.junit.jupiter.api.AfterEach;
 import org.junit.jupiter.api.BeforeEach;
-import org.junit.jupiter.params.ParameterizedTest;
-import org.junit.jupiter.params.provider.CsvSource;
 
 import java.io.IOException;
 import java.util.List;
 
-import static org.junit.jupiter.api.Assertions.assertEquals;
 
 /**
  * Tests for write helpers
@@ -61,21 +57,6 @@ public abstract class TestWriterHelperBase<I> extends HoodieCommonTestHarness {
     cleanupResources();
   }
 
-  @ParameterizedTest
-  @CsvSource({"true,0", "true,50", "false,0", "false,50"})
-  public void testCombineParallelism(boolean shouldCombine, int configuredShuffleParallelism) {
-    int inputParallelism = 5;
-    inputRecords = getInputRecords(
-        dataGen.generateInserts("20230915000000000", 10), inputParallelism);
-    HoodieData<HoodieRecord> outputRecords = (HoodieData<HoodieRecord>) writeHelper.combineOnCondition(
-        shouldCombine, inputRecords, configuredShuffleParallelism, table);
-    if (!shouldCombine || configuredShuffleParallelism == 0) {
-      assertEquals(inputParallelism, outputRecords.getNumPartitions());
-    } else {
-      assertEquals(configuredShuffleParallelism, outputRecords.getNumPartitions());
-    }
-  }
-
   private void initResources() throws IOException {
     initPath("dataset" + runNo);
     runNo++;
diff --git a/hudi-client/hudi-spark-client/src/main/java/org/apache/hudi/data/HoodieJavaPairRDD.java b/hudi-client/hudi-spark-client/src/main/java/org/apache/hudi/data/HoodieJavaPairRDD.java
index 9019fb43ff0..5b422b7fe8a 100644
--- a/hudi-client/hudi-spark-client/src/main/java/org/apache/hudi/data/HoodieJavaPairRDD.java
+++ b/hudi-client/hudi-spark-client/src/main/java/org/apache/hudi/data/HoodieJavaPairRDD.java
@@ -28,7 +28,10 @@ import org.apache.hudi.common.util.Option;
 import org.apache.hudi.common.util.collection.ImmutablePair;
 import org.apache.hudi.common.util.collection.Pair;
 
+import org.apache.spark.Partitioner;
 import org.apache.spark.api.java.JavaPairRDD;
+import org.apache.spark.api.java.Optional;
+import org.apache.spark.sql.internal.SQLConf;
 import org.apache.spark.storage.StorageLevel;
 
 import java.util.List;
@@ -146,4 +149,24 @@ public class HoodieJavaPairRDD<K, V> implements HoodiePairData<K, V> {
   public List<Pair<K, V>> collectAsList() {
     return pairRDDData.map(t -> Pair.of(t._1, t._2)).collect();
   }
+
+  @Override
+  public int deduceNumPartitions() {
+    // for source rdd, the partitioner is None
+    final Optional<Partitioner> partitioner = pairRDDData.partitioner();
+    if (partitioner.isPresent()) {
+      int partPartitions = partitioner.get().numPartitions();
+      if (partPartitions > 0) {
+        return partPartitions;
+      }
+    }
+
+    if (SQLConf.get().contains(SQLConf.SHUFFLE_PARTITIONS().key())) {
+      return Integer.parseInt(SQLConf.get().getConfString(SQLConf.SHUFFLE_PARTITIONS().key()));
+    } else if (pairRDDData.context().conf().contains("spark.default.parallelism")) {
+      return pairRDDData.context().defaultParallelism();
+    } else {
+      return pairRDDData.getNumPartitions();
+    }
+  }
 }
diff --git a/hudi-client/hudi-spark-client/src/main/java/org/apache/hudi/data/HoodieJavaRDD.java b/hudi-client/hudi-spark-client/src/main/java/org/apache/hudi/data/HoodieJavaRDD.java
index a712ee0640e..faec42368ca 100644
--- a/hudi-client/hudi-spark-client/src/main/java/org/apache/hudi/data/HoodieJavaRDD.java
+++ b/hudi-client/hudi-spark-client/src/main/java/org/apache/hudi/data/HoodieJavaRDD.java
@@ -28,8 +28,11 @@ import org.apache.hudi.common.function.SerializablePairFunction;
 import org.apache.hudi.common.util.collection.MappingIterator;
 import org.apache.hudi.common.util.collection.Pair;
 
+import org.apache.spark.Partitioner;
 import org.apache.spark.api.java.JavaPairRDD;
 import org.apache.spark.api.java.JavaRDD;
+import org.apache.spark.api.java.Optional;
+import org.apache.spark.sql.internal.SQLConf;
 import org.apache.spark.storage.StorageLevel;
 
 import java.util.Iterator;
@@ -120,6 +123,26 @@ public class HoodieJavaRDD<T> implements HoodieData<T> {
     return rddData.getNumPartitions();
   }
 
+  @Override
+  public int deduceNumPartitions() {
+    // for source rdd, the partitioner is None
+    final Optional<Partitioner> partitioner = rddData.partitioner();
+    if (partitioner.isPresent()) {
+      int partPartitions = partitioner.get().numPartitions();
+      if (partPartitions > 0) {
+        return partPartitions;
+      }
+    }
+
+    if (SQLConf.get().contains(SQLConf.SHUFFLE_PARTITIONS().key())) {
+      return Integer.parseInt(SQLConf.get().getConfString(SQLConf.SHUFFLE_PARTITIONS().key()));
+    } else if (rddData.context().conf().contains("spark.default.parallelism")) {
+      return rddData.context().defaultParallelism();
+    } else {
+      return rddData.getNumPartitions();
+    }
+  }
+
   @Override
   public <O> HoodieData<O> map(SerializableFunction<T, O> func) {
     return HoodieJavaRDD.of(rddData.map(func::apply));
diff --git a/hudi-client/hudi-spark-client/src/main/java/org/apache/hudi/index/bloom/SparkHoodieBloomIndexHelper.java b/hudi-client/hudi-spark-client/src/main/java/org/apache/hudi/index/bloom/SparkHoodieBloomIndexHelper.java
index 5f17a78bad8..b0c1e284465 100644
--- a/hudi-client/hudi-spark-client/src/main/java/org/apache/hudi/index/bloom/SparkHoodieBloomIndexHelper.java
+++ b/hudi-client/hudi-spark-client/src/main/java/org/apache/hudi/index/bloom/SparkHoodieBloomIndexHelper.java
@@ -88,7 +88,7 @@ public class SparkHoodieBloomIndexHelper extends BaseHoodieBloomIndexHelper {
       Map<String, List<BloomIndexFileInfo>> partitionToFileInfo,
       Map<String, Long> recordsPerPartition) {
 
-    int inputParallelism = HoodieJavaPairRDD.getJavaPairRDD(partitionRecordKeyPairs).getNumPartitions();
+    int inputParallelism = partitionRecordKeyPairs.deduceNumPartitions();
     int configuredBloomIndexParallelism = config.getBloomIndexParallelism();
 
     // NOTE: Target parallelism could be overridden by the config
diff --git a/hudi-client/hudi-spark-client/src/main/scala/org/apache/hudi/HoodieSparkUtils.scala b/hudi-client/hudi-spark-client/src/main/scala/org/apache/hudi/HoodieSparkUtils.scala
index ac78b77097e..82e4f218f65 100644
--- a/hudi-client/hudi-spark-client/src/main/scala/org/apache/hudi/HoodieSparkUtils.scala
+++ b/hudi-client/hudi-spark-client/src/main/scala/org/apache/hudi/HoodieSparkUtils.scala
@@ -108,7 +108,7 @@ object HoodieSparkUtils extends SparkAdapterSupport with SparkVersionsSupport wi
     //       Additionally, we have to explicitly wrap around resulting [[RDD]] into the one
     //       injecting [[SQLConf]], which by default isn't propagated by Spark to the executor(s).
     //       [[SQLConf]] is required by [[AvroSerializer]]
-    injectSQLConf(df.queryExecution.toRdd.mapPartitions { rows =>
+    injectSQLConf(df.queryExecution.toRdd.mapPartitions (rows => {
       if (rows.isEmpty) {
         Iterator.empty
       } else {
@@ -126,7 +126,7 @@ object HoodieSparkUtils extends SparkAdapterSupport with SparkVersionsSupport wi
 
         rows.map { ir => transform(convert(ir)) }
       }
-    }, SQLConf.get)
+    }, preservesPartitioning = true), SQLConf.get)
   }
 
   def injectSQLConf[T: ClassTag](rdd: RDD[T], conf: SQLConf): RDD[T] =
diff --git a/hudi-client/hudi-spark-client/src/test/java/org/apache/hudi/data/TestHoodieJavaRDD.java b/hudi-client/hudi-spark-client/src/test/java/org/apache/hudi/data/TestHoodieJavaRDD.java
index 75958883048..a2617b592d6 100644
--- a/hudi-client/hudi-spark-client/src/test/java/org/apache/hudi/data/TestHoodieJavaRDD.java
+++ b/hudi-client/hudi-spark-client/src/test/java/org/apache/hudi/data/TestHoodieJavaRDD.java
@@ -20,8 +20,11 @@
 package org.apache.hudi.data;
 
 import org.apache.hudi.common.data.HoodieData;
+import org.apache.hudi.common.data.HoodiePairData;
+import org.apache.hudi.common.util.collection.Pair;
 import org.apache.hudi.testutils.HoodieClientTestBase;
 
+import org.apache.spark.sql.internal.SQLConf;
 import org.junit.jupiter.api.Test;
 
 import java.util.stream.Collectors;
@@ -37,4 +40,29 @@ public class TestHoodieJavaRDD extends HoodieClientTestBase {
         IntStream.rangeClosed(0, 100).boxed().collect(Collectors.toList()), numPartitions));
     assertEquals(numPartitions, rddData.getNumPartitions());
   }
+
+  @Test
+  public void testDeduceNumPartitions() {
+    int numPartitions = 100;
+    jsc.sc().conf().remove("spark.default.parallelism");
+    SQLConf.get().unsetConf("spark.sql.shuffle.partitions");
+
+    // rdd parallelize
+    SQLConf.get().setConfString("spark.sql.shuffle.partitions", "5");
+    HoodieData<Integer> rddData = HoodieJavaRDD.of(jsc.parallelize(
+        IntStream.rangeClosed(0, 100).boxed().collect(Collectors.toList()), numPartitions));
+    assertEquals(5, rddData.deduceNumPartitions());
+
+    // sql parallelize
+    SQLConf.get().unsetConf("spark.sql.shuffle.partitions");
+    jsc.sc().conf().set("spark.default.parallelism", "6");
+    rddData = HoodieJavaRDD.of(jsc.parallelize(
+        IntStream.rangeClosed(0, 100).boxed().collect(Collectors.toList()), numPartitions));
+    assertEquals(6, rddData.deduceNumPartitions());
+
+    // use partitioner num
+    HoodiePairData<Integer, Integer> shuffleRDD = rddData.mapToPair(key -> Pair.of(key, 1))
+        .reduceByKey((p1, p2) -> p1, 11);
+    assertEquals(11, shuffleRDD.deduceNumPartitions());
+  }
 }
diff --git a/hudi-client/hudi-spark-client/src/test/java/org/apache/hudi/table/action/commit/TestSparkWriteHelper.java b/hudi-client/hudi-spark-client/src/test/java/org/apache/hudi/table/action/commit/TestSparkWriteHelper.java
index 5689de996eb..54f0558f5c6 100644
--- a/hudi-client/hudi-spark-client/src/test/java/org/apache/hudi/table/action/commit/TestSparkWriteHelper.java
+++ b/hudi-client/hudi-spark-client/src/test/java/org/apache/hudi/table/action/commit/TestSparkWriteHelper.java
@@ -30,6 +30,8 @@ import org.apache.hudi.testutils.HoodieClientTestUtils;
 import org.apache.spark.api.java.JavaSparkContext;
 import org.junit.jupiter.api.AfterEach;
 import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.params.ParameterizedTest;
+import org.junit.jupiter.params.provider.CsvSource;
 
 import java.util.List;
 
@@ -73,4 +75,25 @@ public class TestSparkWriteHelper extends TestWriterHelperBase<HoodieData<Hoodie
     }
     this.context = null;
   }
+
+  @ParameterizedTest
+  @CsvSource({"true,0", "true,50", "false,0", "false,50"})
+  public void testCombineParallelism(boolean shouldCombine, int configuredShuffleParallelism) {
+    int inputParallelism = 5;
+    int expectDefaultParallelism = 4;
+    inputRecords = getInputRecords(
+        dataGen.generateInserts("20230915000000000", 10), inputParallelism);
+    HoodieData<HoodieRecord> outputRecords = (HoodieData<HoodieRecord>) writeHelper.combineOnCondition(
+        shouldCombine, inputRecords, configuredShuffleParallelism, table);
+
+    if (shouldCombine) {
+      if (configuredShuffleParallelism == 0) {
+        assertEquals(expectDefaultParallelism, outputRecords.getNumPartitions());
+      } else {
+        assertEquals(configuredShuffleParallelism, outputRecords.getNumPartitions());
+      }
+    } else {
+      assertEquals(inputParallelism, outputRecords.getNumPartitions());
+    }
+  }
 }
diff --git a/hudi-common/src/main/java/org/apache/hudi/common/data/HoodieData.java b/hudi-common/src/main/java/org/apache/hudi/common/data/HoodieData.java
index 60820d5a0ce..e65a8f426bd 100644
--- a/hudi-common/src/main/java/org/apache/hudi/common/data/HoodieData.java
+++ b/hudi-common/src/main/java/org/apache/hudi/common/data/HoodieData.java
@@ -93,6 +93,11 @@ public interface HoodieData<T> extends Serializable {
    */
   int getNumPartitions();
 
+  /**
+   * @return the deduced number of shuffle partitions
+   */
+  int deduceNumPartitions();
+
   /**
    * Maps every element in the collection using provided mapping {@code func}.
    * <p>
diff --git a/hudi-common/src/main/java/org/apache/hudi/common/data/HoodieListData.java b/hudi-common/src/main/java/org/apache/hudi/common/data/HoodieListData.java
index 4d9980a3575..690ab71c090 100644
--- a/hudi-common/src/main/java/org/apache/hudi/common/data/HoodieListData.java
+++ b/hudi-common/src/main/java/org/apache/hudi/common/data/HoodieListData.java
@@ -201,6 +201,11 @@ public class HoodieListData<T> extends HoodieBaseListData<T> implements HoodieDa
     return 1;
   }
 
+  @Override
+  public int deduceNumPartitions() {
+    return 1;
+  }
+
   @Override
   public List<T> collectAsList() {
     return super.collectAsList();
diff --git a/hudi-common/src/main/java/org/apache/hudi/common/data/HoodieListPairData.java b/hudi-common/src/main/java/org/apache/hudi/common/data/HoodieListPairData.java
index 39ce1411575..b55d2f5be98 100644
--- a/hudi-common/src/main/java/org/apache/hudi/common/data/HoodieListPairData.java
+++ b/hudi-common/src/main/java/org/apache/hudi/common/data/HoodieListPairData.java
@@ -201,6 +201,11 @@ public class HoodieListPairData<K, V> extends HoodieBaseListData<Pair<K, V>> imp
     return super.collectAsList();
   }
 
+  @Override
+  public int deduceNumPartitions() {
+    return 1;
+  }
+
   public static <K, V> HoodieListPairData<K, V> lazy(List<Pair<K, V>> data) {
     return new HoodieListPairData<>(data, true);
   }
diff --git a/hudi-common/src/main/java/org/apache/hudi/common/data/HoodiePairData.java b/hudi-common/src/main/java/org/apache/hudi/common/data/HoodiePairData.java
index 1d3622786fd..d9815063b86 100644
--- a/hudi-common/src/main/java/org/apache/hudi/common/data/HoodiePairData.java
+++ b/hudi-common/src/main/java/org/apache/hudi/common/data/HoodiePairData.java
@@ -129,4 +129,9 @@ public interface HoodiePairData<K, V> extends Serializable {
    * This is a terminal operation
    */
   List<Pair<K, V>> collectAsList();
+
+  /**
+   * @return the deduced number of shuffle partitions
+   */
+  int deduceNumPartitions();
 }
diff --git a/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/spark/sql/hudi/dml/TestInsertTable.scala b/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/spark/sql/hudi/dml/TestInsertTable.scala
index 3994fc1bcca..ec9e90be7c1 100644
--- a/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/spark/sql/hudi/dml/TestInsertTable.scala
+++ b/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/spark/sql/hudi/dml/TestInsertTable.scala
@@ -69,6 +69,8 @@ class TestInsertTable extends HoodieSparkSqlTestBase {
              location '${tablePath}'
              """.stripMargin)
 
+      spark.sql("set spark.sql.shuffle.partitions = 11")
+
       spark.sql(
         s"""
            |insert into ${targetTable}
