This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch branch-3.5
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.5 by this push:
     new 75a38b9024a [SPARK-45616][CORE] Avoid ParVector, which does not 
propagate ThreadLocals or SparkSession
75a38b9024a is described below

commit 75a38b9024af3c9cfd85e916c46359f7e7315c87
Author: Ankur Dave <ankurd...@gmail.com>
AuthorDate: Mon Oct 23 10:47:42 2023 +0800

    [SPARK-45616][CORE] Avoid ParVector, which does not propagate ThreadLocals 
or SparkSession
    
    ### What changes were proposed in this pull request?
    `CastSuiteBase` and `ExpressionInfoSuite` use `ParVector.foreach()` to run 
Spark SQL queries in parallel. They incorrectly assume that each parallel 
operation will inherit the main thread’s active SparkSession. This is only true 
when these parallel operations run in freshly-created threads. However, when 
other code has already run some parallel operations before Spark was started, 
then there may be existing threads that do not have an active SparkSession. In 
that case, these tests fail, because the queries they run on those reused 
threads find no active SparkSession.
    
    The fix is to use the existing method `ThreadUtils.parmap()`. This method 
creates fresh threads that inherit the current active SparkSession, and it 
propagates the Spark ThreadLocals.
    
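    This PR also adds a scalastyle check (reported at error level) that flags 
creation of ParVector instances; uses that are still required must be wrapped 
in `// scalastyle:off parvector` / `// scalastyle:on parvector`.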
    
    ### Why are the changes needed?
    This change makes `CastSuiteBase` and `ExpressionInfoSuite` less brittle to 
future changes that may run parallel operations during test startup.
    
    ### Does this PR introduce _any_ user-facing change?
    No.
    
    ### How was this patch tested?
    Reproduced the test failures by running a ParVector operation before Spark 
starts. Verified that this PR fixes the test failures under this condition.
    
    ```scala
      protected override def beforeAll(): Unit = {
        // Run a ParVector operation before initializing the SparkSession. This 
starts some Scala
        // execution context threads that have no active SparkSession. These 
threads will be reused for
        // later ParVector operations, reproducing SPARK-45616.
        new ParVector((0 until 100).toVector).foreach { _ => }
    
        super.beforeAll()
      }
    ```
    
    ### Was this patch authored or co-authored using generative AI tooling?
    No.
    
    Closes #43466 from ankurdave/SPARK-45616.
    
    Authored-by: Ankur Dave <ankurd...@gmail.com>
    Signed-off-by: Wenchen Fan <wenc...@databricks.com>
    (cherry picked from commit 376de8a502fca6b46d7f21560a60024d643144ea)
    Signed-off-by: Wenchen Fan <wenc...@databricks.com>
---
 core/src/main/scala/org/apache/spark/rdd/UnionRDD.scala      |  2 ++
 core/src/main/scala/org/apache/spark/util/ThreadUtils.scala  |  4 ++++
 scalastyle-config.xml                                        | 12 ++++++++++++
 .../spark/sql/catalyst/expressions/CastSuiteBase.scala       |  9 ++++++---
 .../scala/org/apache/spark/sql/execution/command/ddl.scala   |  2 ++
 .../apache/spark/sql/expressions/ExpressionInfoSuite.scala   | 11 ++++++-----
 .../main/scala/org/apache/spark/streaming/DStreamGraph.scala |  4 ++++
 .../apache/spark/streaming/util/FileBasedWriteAheadLog.scala |  2 ++
 8 files changed, 38 insertions(+), 8 deletions(-)

diff --git a/core/src/main/scala/org/apache/spark/rdd/UnionRDD.scala 
b/core/src/main/scala/org/apache/spark/rdd/UnionRDD.scala
index 0a930234437..3c1451a0185 100644
--- a/core/src/main/scala/org/apache/spark/rdd/UnionRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/UnionRDD.scala
@@ -76,8 +76,10 @@ class UnionRDD[T: ClassTag](
 
   override def getPartitions: Array[Partition] = {
     val parRDDs = if (isPartitionListingParallel) {
+      // scalastyle:off parvector
       val parArray = new ParVector(rdds.toVector)
       parArray.tasksupport = UnionRDD.partitionEvalTaskSupport
+      // scalastyle:on parvector
       parArray
     } else {
       rdds
diff --git a/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala 
b/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala
index 16d7de56c39..2d3d6ec89ff 100644
--- a/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala
+++ b/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala
@@ -363,6 +363,10 @@ private[spark] object ThreadUtils {
    * Comparing to the map() method of Scala parallel collections, this method 
can be interrupted
    * at any time. This is useful on canceling of task execution, for example.
    *
+   * Functions are guaranteed to be executed in freshly-created threads that 
inherit the calling
+   * thread's Spark thread-local variables. These threads also inherit the 
calling thread's active
+   * SparkSession.
+   *
    * @param in - the input collection which should be transformed in parallel.
    * @param prefix - the prefix assigned to the underlying thread pool.
    * @param maxThreads - maximum number of thread can be created during 
execution.
diff --git a/scalastyle-config.xml b/scalastyle-config.xml
index 74e8480deaf..0ccd937e72e 100644
--- a/scalastyle-config.xml
+++ b/scalastyle-config.xml
@@ -227,6 +227,18 @@ This file is divided into 3 sections:
     ]]></customMessage>
   </check>
 
+  <check customId="parvector" level="error" 
class="org.scalastyle.file.RegexChecker" enabled="true">
+    <parameters><parameter name="regex">new.*ParVector</parameter></parameters>
+    <customMessage><![CDATA[
+      Are you sure you want to create a ParVector? It will not automatically 
propagate Spark ThreadLocals or the
+      active SparkSession for the submitted tasks. In most cases, you should 
use ThreadUtils.parmap instead.
+      If you must use ParVector, then wrap your creation of the ParVector with
+      // scalastyle:off parvector
+      ...ParVector...
+      // scalastyle:on parvector
+    ]]></customMessage>
+  </check>
+
   <check customId="caselocale" level="error" 
class="org.scalastyle.file.RegexChecker" enabled="true">
     <parameters><parameter 
name="regex">(\.toUpperCase|\.toLowerCase)(?!(\(|\(Locale.ROOT\)))</parameter></parameters>
     <customMessage><![CDATA[
diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuiteBase.scala
 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuiteBase.scala
index 0172fd9b3e4..1ce311a5544 100644
--- 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuiteBase.scala
+++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuiteBase.scala
@@ -22,8 +22,6 @@ import java.time.{Duration, LocalDate, LocalDateTime, Period}
 import java.time.temporal.ChronoUnit
 import java.util.{Calendar, Locale, TimeZone}
 
-import scala.collection.parallel.immutable.ParVector
-
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql.Row
 import org.apache.spark.sql.catalyst.InternalRow
@@ -42,6 +40,7 @@ import org.apache.spark.sql.types.DayTimeIntervalType.{DAY, 
HOUR, MINUTE, SECOND
 import org.apache.spark.sql.types.UpCastRule.numericPrecedence
 import org.apache.spark.sql.types.YearMonthIntervalType.{MONTH, YEAR}
 import org.apache.spark.unsafe.types.UTF8String
+import org.apache.spark.util.ThreadUtils
 
 /**
  * Common test suite for [[Cast]] with ansi mode on and off. It only includes 
test cases that work
@@ -126,7 +125,11 @@ abstract class CastSuiteBase extends SparkFunSuite with 
ExpressionEvalHelper {
   }
 
   test("cast string to timestamp") {
-    new ParVector(ALL_TIMEZONES.toVector).foreach { zid =>
+    ThreadUtils.parmap(
+      ALL_TIMEZONES,
+      prefix = "CastSuiteBase-cast-string-to-timestamp",
+      maxThreads = Runtime.getRuntime.availableProcessors
+    ) { zid =>
       def checkCastStringToTimestamp(str: String, expected: Timestamp): Unit = 
{
         checkEvaluation(cast(Literal(str), TimestampType, Option(zid.getId)), 
expected)
       }
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala
index a8f7cdb2600..bb8fea71019 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala
@@ -755,8 +755,10 @@ case class RepairTableCommand(
     val statusPar: Seq[FileStatus] =
       if (partitionNames.length > 1 && statuses.length > threshold || 
partitionNames.length > 2) {
         // parallelize the list of partitions here, then we can have better 
parallelism later.
+        // scalastyle:off parvector
         val parArray = new ParVector(statuses.toVector)
         parArray.tasksupport = evalTaskSupport
+        // scalastyle:on parvector
         parArray.seq
       } else {
         statuses
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/expressions/ExpressionInfoSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/expressions/ExpressionInfoSuite.scala
index 4dd93983e87..a02137a56aa 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/expressions/ExpressionInfoSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/expressions/ExpressionInfoSuite.scala
@@ -17,8 +17,6 @@
 
 package org.apache.spark.sql.expressions
 
-import scala.collection.parallel.immutable.ParVector
-
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql.catalyst.{FunctionIdentifier, InternalRow}
 import org.apache.spark.sql.catalyst.expressions._
@@ -26,7 +24,7 @@ import 
org.apache.spark.sql.execution.HiveResult.hiveResultString
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SharedSparkSession
 import org.apache.spark.tags.SlowSQLTest
-import org.apache.spark.util.Utils
+import org.apache.spark.util.{ThreadUtils, Utils}
 
 @SlowSQLTest
 class ExpressionInfoSuite extends SparkFunSuite with SharedSparkSession {
@@ -197,8 +195,11 @@ class ExpressionInfoSuite extends SparkFunSuite with 
SharedSparkSession {
       // The encrypt expression includes a random initialization vector to its 
encrypted result
       classOf[AesEncrypt].getName)
 
-    val parFuncs = new 
ParVector(spark.sessionState.functionRegistry.listFunction().toVector)
-    parFuncs.foreach { funcId =>
+    ThreadUtils.parmap(
+      spark.sessionState.functionRegistry.listFunction(),
+      prefix = "ExpressionInfoSuite-check-outputs-of-expression-examples",
+      maxThreads = Runtime.getRuntime.availableProcessors
+    ) { funcId =>
       // Examples can change settings. We clone the session to prevent tests 
clashing.
       val clonedSpark = spark.cloneSession()
       // Coalescing partitions can change result order, so disable it.
diff --git 
a/streaming/src/main/scala/org/apache/spark/streaming/DStreamGraph.scala 
b/streaming/src/main/scala/org/apache/spark/streaming/DStreamGraph.scala
index 43aaa7e1eea..a8f55c8b4d6 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/DStreamGraph.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/DStreamGraph.scala
@@ -52,7 +52,9 @@ final private[streaming] class DStreamGraph extends 
Serializable with Logging {
       outputStreams.foreach(_.validateAtStart())
       numReceivers = 
inputStreams.count(_.isInstanceOf[ReceiverInputDStream[_]])
       inputStreamNameAndID = inputStreams.map(is => (is.name, is.id)).toSeq
+      // scalastyle:off parvector
       new ParVector(inputStreams.toVector).foreach(_.start())
+      // scalastyle:on parvector
     }
   }
 
@@ -62,7 +64,9 @@ final private[streaming] class DStreamGraph extends 
Serializable with Logging {
 
   def stop(): Unit = {
     this.synchronized {
+      // scalastyle:off parvector
       new ParVector(inputStreams.toVector).foreach(_.stop())
+      // scalastyle:on parvector
     }
   }
 
diff --git 
a/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLog.scala
 
b/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLog.scala
index d1f9dfb7913..4e65bc75e43 100644
--- 
a/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLog.scala
+++ 
b/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLog.scala
@@ -314,8 +314,10 @@ private[streaming] object FileBasedWriteAheadLog {
     val groupSize = taskSupport.parallelismLevel.max(8)
 
     source.grouped(groupSize).flatMap { group =>
+      // scalastyle:off parvector
       val parallelCollection = new ParVector(group.toVector)
       parallelCollection.tasksupport = taskSupport
+      // scalastyle:on parvector
       parallelCollection.map(handler)
     }.flatten
   }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to