srowen closed pull request #22723: [SPARK-25729][CORE] It is better to replace `minPartitions` with `defaultParallelism` when `minPartitions` is less than `defaultParallelism`
URL: https://github.com/apache/spark/pull/22723
 
 
   

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:

As this is a foreign pull request (from a fork), the diff is supplied
below, because GitHub does not display it otherwise:

diff --git a/core/src/main/scala/org/apache/spark/input/WholeTextFileInputFormat.scala b/core/src/main/scala/org/apache/spark/input/WholeTextFileInputFormat.scala
index 04c5c4b90e8a1..9400879f27048 100644
--- a/core/src/main/scala/org/apache/spark/input/WholeTextFileInputFormat.scala
+++ b/core/src/main/scala/org/apache/spark/input/WholeTextFileInputFormat.scala
@@ -46,13 +46,15 @@ private[spark] class WholeTextFileInputFormat
 
   /**
    * Allow minPartitions set by end-user in order to keep compatibility with old Hadoop API,
-   * which is set through setMaxSplitSize
+   * which is set through setMaxSplitSize. But when minPartitions is less than defaultParallelism,
+   * it is better to replace minPartitions with defaultParallelism, because this can improve
+   * parallelism.
    */
-  def setMinPartitions(context: JobContext, minPartitions: Int) {
+  def setMinPartitions(defaultParallelism: Int, context: JobContext, minPartitions: Int) {
     val files = listStatus(context).asScala
     val totalLen = files.map(file => if (file.isDirectory) 0L else file.getLen).sum
-    val maxSplitSize = Math.ceil(totalLen * 1.0 /
-      (if (minPartitions == 0) 1 else minPartitions)).toLong
+    val minPartNum = Math.max(defaultParallelism, minPartitions)
+    val maxSplitSize = Math.ceil(totalLen * 1.0 / minPartNum).toLong
 
     // For small files we need to ensure the min split size per node & rack <= maxSplitSize
     val config = context.getConfiguration
diff --git a/core/src/main/scala/org/apache/spark/rdd/WholeTextFileRDD.scala b/core/src/main/scala/org/apache/spark/rdd/WholeTextFileRDD.scala
index 9f3d0745c33c9..6377b677ed10c 100644
--- a/core/src/main/scala/org/apache/spark/rdd/WholeTextFileRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/WholeTextFileRDD.scala
@@ -30,7 +30,7 @@ import org.apache.spark.input.WholeTextFileInputFormat
  * An RDD that reads a bunch of text files in, and each text file becomes one record.
  */
 private[spark] class WholeTextFileRDD(
-    sc : SparkContext,
+    @transient private val sc: SparkContext,
     inputFormatClass: Class[_ <: WholeTextFileInputFormat],
     keyClass: Class[Text],
     valueClass: Class[Text],
@@ -51,7 +51,7 @@ private[spark] class WholeTextFileRDD(
       case _ =>
     }
     val jobContext = new JobContextImpl(conf, jobId)
-    inputFormat.setMinPartitions(jobContext, minPartitions)
+    inputFormat.setMinPartitions(sc.defaultParallelism, jobContext, minPartitions)
     val rawSplits = inputFormat.getSplits(jobContext).toArray
     val result = new Array[Partition](rawSplits.size)
     for (i <- 0 until rawSplits.size) {
diff --git a/core/src/test/scala/org/apache/spark/input/WholeTextFileInputFormatSuite.scala b/core/src/test/scala/org/apache/spark/input/WholeTextFileInputFormatSuite.scala
index 817dc082b7d38..531ac936a4d5d 100644
--- a/core/src/test/scala/org/apache/spark/input/WholeTextFileInputFormatSuite.scala
+++ b/core/src/test/scala/org/apache/spark/input/WholeTextFileInputFormatSuite.scala
@@ -38,7 +38,7 @@ class WholeTextFileInputFormatSuite extends SparkFunSuite with BeforeAndAfterAll
   override def beforeAll() {
     super.beforeAll()
     val conf = new SparkConf()
-    sc = new SparkContext("local", "test", conf)
+    sc = new SparkContext("local[2]", "test", conf)
   }
 
   override def afterAll() {
@@ -79,6 +79,22 @@ class WholeTextFileInputFormatSuite extends SparkFunSuite with BeforeAndAfterAll
       Utils.deleteRecursively(dir)
     }
   }
+
+  test("Test the number of partitions for WholeTextFileRDD") {
+    var dir: File = null
+    try {
+      dir = Utils.createTempDir()
+      WholeTextFileInputFormatSuite.files.foreach { case (filename, contents) =>
+        createNativeFile(dir, filename, contents, true)
+      }
+      // set `minPartitions = 1`
+      val rdd = sc.wholeTextFiles(dir.toString, 1)
+      // The number of partitions is equal to 2, not equal to 1, because the defaultParallelism is 2
+      assert(rdd.getNumPartitions === 2)
+    } finally {
+      Utils.deleteRecursively(dir)
+    }
+  }
 }
 
 /**
@@ -88,7 +104,7 @@ object WholeTextFileInputFormatSuite {
   private val testWords: IndexedSeq[Byte] = "Spark is easy to use.\n".map(_.toByte)
 
   private val fileNames = Array("part-00000", "part-00001", "part-00002")
-  private val fileLengths = Array(10, 100, 1000)
+  private val fileLengths = Array(10, 100, 100)
 
   private val files = fileLengths.zip(fileNames).map { case (upperBound, filename) =>
     filename -> Stream.continually(testWords.toList.toStream).flatten.take(upperBound).toArray


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


With regards,
Apache Git Services

---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For additional commands, e-mail: reviews-h...@spark.apache.org

Reply via email to