This is an automated email from the ASF dual-hosted git repository. wenchen pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:

     new 376de8a502f [SPARK-45616][CORE] Avoid ParVector, which does not propagate ThreadLocals or SparkSession

376de8a502f is described below

commit 376de8a502fca6b46d7f21560a60024d643144ea
Author: Ankur Dave <ankurd...@gmail.com>
AuthorDate: Mon Oct 23 10:47:42 2023 +0800

[SPARK-45616][CORE] Avoid ParVector, which does not propagate ThreadLocals or SparkSession

### What changes were proposed in this pull request?

`CastSuiteBase` and `ExpressionInfoSuite` use `ParVector.foreach()` to run Spark SQL queries in parallel. They incorrectly assume that each parallel operation will inherit the main thread’s active SparkSession. This is only true when these parallel operations run in freshly-created threads. However, when other code has already run some parallel operations before Spark was started, then there may be existing threads that do not have an active SparkSession. In that case, these tests fail [...]

The fix is to use the existing method `ThreadUtils.parmap()`. This method creates fresh threads that inherit the current active SparkSession, and it propagates the Spark ThreadLocals.

This PR also adds an error-level scalastyle check against creating a ParVector.

### Why are the changes needed?

This change makes `CastSuiteBase` and `ExpressionInfoSuite` less brittle to future changes that may run parallel operations during test startup.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

Reproduced the test failures by running a ParVector operation before Spark starts. Verified that this PR fixes the test failures in this condition.

```scala
protected override def beforeAll(): Unit = {
  // Run a ParVector operation before initializing the SparkSession. This starts some Scala
  // execution context threads that have no active SparkSession. These threads will be reused for
  // later ParVector operations, reproducing SPARK-45616.
  new ParVector((0 until 100).toVector).foreach { _ => }
  super.beforeAll()
}
```

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes #43466 from ankurdave/SPARK-45616.

Authored-by: Ankur Dave <ankurd...@gmail.com>
Signed-off-by: Wenchen Fan <wenc...@databricks.com>
---
 core/src/main/scala/org/apache/spark/rdd/UnionRDD.scala | 2 ++
 core/src/main/scala/org/apache/spark/util/ThreadUtils.scala | 4 ++++
 scalastyle-config.xml | 12 ++++++++++++
 .../spark/sql/catalyst/expressions/CastSuiteBase.scala | 9 ++++++---
 .../scala/org/apache/spark/sql/execution/command/ddl.scala | 2 ++
 .../apache/spark/sql/expressions/ExpressionInfoSuite.scala | 11 ++++++-----
 .../main/scala/org/apache/spark/streaming/DStreamGraph.scala | 4 ++++
 .../apache/spark/streaming/util/FileBasedWriteAheadLog.scala | 2 ++
 8 files changed, 38 insertions(+), 8 deletions(-)

diff --git a/core/src/main/scala/org/apache/spark/rdd/UnionRDD.scala b/core/src/main/scala/org/apache/spark/rdd/UnionRDD.scala index 0a930234437..3c1451a0185 100644 --- a/core/src/main/scala/org/apache/spark/rdd/UnionRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/UnionRDD.scala @@ -76,8 +76,10 @@ class UnionRDD[T: ClassTag]( override def getPartitions: Array[Partition] = { val parRDDs = if (isPartitionListingParallel) { + // scalastyle:off parvector val parArray = new ParVector(rdds.toVector) parArray.tasksupport = UnionRDD.partitionEvalTaskSupport + // scalastyle:on parvector parArray } else { rdds diff --git a/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala b/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala index 16d7de56c39..2d3d6ec89ff 100644 --- a/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala +++ b/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala @@ -363,6 +363,10 @@ private[spark] object ThreadUtils { * Comparing to the map() method of Scala parallel collections, this method can be interrupted * at any time.
This is useful on canceling of task execution, for example. * + * Functions are guaranteed to be executed in freshly-created threads that inherit the calling + * thread's Spark thread-local variables. These threads also inherit the calling thread's active + * SparkSession. + * * @param in - the input collection which should be transformed in parallel. * @param prefix - the prefix assigned to the underlying thread pool. * @param maxThreads - maximum number of thread can be created during execution. diff --git a/scalastyle-config.xml b/scalastyle-config.xml index 987b4235c19..2077769c71d 100644 --- a/scalastyle-config.xml +++ b/scalastyle-config.xml @@ -227,6 +227,18 @@ This file is divided into 3 sections: ]]></customMessage> </check> + <check customId="parvector" level="error" class="org.scalastyle.file.RegexChecker" enabled="true"> + <parameters><parameter name="regex">new.*ParVector</parameter></parameters> + <customMessage><![CDATA[ + Are you sure you want to create a ParVector? It will not automatically propagate Spark ThreadLocals or the + active SparkSession for the submitted tasks. In most cases, you should use ThreadUtils.parmap instead. + If you must use ParVector, then wrap your creation of the ParVector with + // scalastyle:off parvector + ...ParVector... 
+ // scalastyle:on parvector + ]]></customMessage> + </check> + <check customId="caselocale" level="error" class="org.scalastyle.file.RegexChecker" enabled="true"> <parameters><parameter name="regex">(\.toUpperCase|\.toLowerCase)(?!(\(|\(Locale.ROOT\)))</parameter></parameters> <customMessage><![CDATA[ diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuiteBase.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuiteBase.scala index 0172fd9b3e4..1ce311a5544 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuiteBase.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuiteBase.scala @@ -22,8 +22,6 @@ import java.time.{Duration, LocalDate, LocalDateTime, Period} import java.time.temporal.ChronoUnit import java.util.{Calendar, Locale, TimeZone} -import scala.collection.parallel.immutable.ParVector - import org.apache.spark.SparkFunSuite import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.InternalRow @@ -42,6 +40,7 @@ import org.apache.spark.sql.types.DayTimeIntervalType.{DAY, HOUR, MINUTE, SECOND import org.apache.spark.sql.types.UpCastRule.numericPrecedence import org.apache.spark.sql.types.YearMonthIntervalType.{MONTH, YEAR} import org.apache.spark.unsafe.types.UTF8String +import org.apache.spark.util.ThreadUtils /** * Common test suite for [[Cast]] with ansi mode on and off. 
It only includes test cases that work @@ -126,7 +125,11 @@ abstract class CastSuiteBase extends SparkFunSuite with ExpressionEvalHelper { } test("cast string to timestamp") { - new ParVector(ALL_TIMEZONES.toVector).foreach { zid => + ThreadUtils.parmap( + ALL_TIMEZONES, + prefix = "CastSuiteBase-cast-string-to-timestamp", + maxThreads = Runtime.getRuntime.availableProcessors + ) { zid => def checkCastStringToTimestamp(str: String, expected: Timestamp): Unit = { checkEvaluation(cast(Literal(str), TimestampType, Option(zid.getId)), expected) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala index 1465e32924a..a30734abfa7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala @@ -759,8 +759,10 @@ case class RepairTableCommand( val statusPar: Seq[FileStatus] = if (partitionNames.length > 1 && statuses.length > threshold || partitionNames.length > 2) { // parallelize the list of partitions here, then we can have better parallelism later. 
+ // scalastyle:off parvector val parArray = new ParVector(statuses.toVector) parArray.tasksupport = evalTaskSupport + // scalastyle:on parvector parArray.seq } else { statuses diff --git a/sql/core/src/test/scala/org/apache/spark/sql/expressions/ExpressionInfoSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/expressions/ExpressionInfoSuite.scala index fd6f0adccf7..f8dde124b31 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/expressions/ExpressionInfoSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/expressions/ExpressionInfoSuite.scala @@ -17,8 +17,6 @@ package org.apache.spark.sql.expressions -import scala.collection.parallel.immutable.ParVector - import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.{FunctionIdentifier, InternalRow} import org.apache.spark.sql.catalyst.expressions._ @@ -26,7 +24,7 @@ import org.apache.spark.sql.execution.HiveResult.hiveResultString import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSparkSession import org.apache.spark.tags.SlowSQLTest -import org.apache.spark.util.Utils +import org.apache.spark.util.{ThreadUtils, Utils} @SlowSQLTest class ExpressionInfoSuite extends SparkFunSuite with SharedSparkSession { @@ -201,8 +199,11 @@ class ExpressionInfoSuite extends SparkFunSuite with SharedSparkSession { // The encrypt expression includes a random initialization vector to its encrypted result classOf[AesEncrypt].getName) - val parFuncs = new ParVector(spark.sessionState.functionRegistry.listFunction().toVector) - parFuncs.foreach { funcId => + ThreadUtils.parmap( + spark.sessionState.functionRegistry.listFunction(), + prefix = "ExpressionInfoSuite-check-outputs-of-expression-examples", + maxThreads = Runtime.getRuntime.availableProcessors + ) { funcId => // Examples can change settings. We clone the session to prevent tests clashing. val clonedSpark = spark.cloneSession() // Coalescing partitions can change result order, so disable it. 
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/DStreamGraph.scala b/streaming/src/main/scala/org/apache/spark/streaming/DStreamGraph.scala index 43aaa7e1eea..a8f55c8b4d6 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/DStreamGraph.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/DStreamGraph.scala @@ -52,7 +52,9 @@ final private[streaming] class DStreamGraph extends Serializable with Logging { outputStreams.foreach(_.validateAtStart()) numReceivers = inputStreams.count(_.isInstanceOf[ReceiverInputDStream[_]]) inputStreamNameAndID = inputStreams.map(is => (is.name, is.id)).toSeq + // scalastyle:off parvector new ParVector(inputStreams.toVector).foreach(_.start()) + // scalastyle:on parvector } } @@ -62,7 +64,9 @@ final private[streaming] class DStreamGraph extends Serializable with Logging { def stop(): Unit = { this.synchronized { + // scalastyle:off parvector new ParVector(inputStreams.toVector).foreach(_.stop()) + // scalastyle:on parvector } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLog.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLog.scala index c3f2a04d1f0..908d155908f 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLog.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLog.scala @@ -314,8 +314,10 @@ private[streaming] object FileBasedWriteAheadLog { val groupSize = taskSupport.parallelismLevel.max(8) source.grouped(groupSize).flatMap { group => + // scalastyle:off parvector val parallelCollection = new ParVector(group.toVector) parallelCollection.tasksupport = taskSupport + // scalastyle:on parvector parallelCollection.map(handler) }.flatten } --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org