This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 5f9f5db9f363 [SPARK-47689][SQL][FOLLOWUP] More accurate file path in TASK_WRITE_FAILED error
5f9f5db9f363 is described below

commit 5f9f5db9f3631b9ba4cde7b1b4d9b74674baeba1
Author: Wenchen Fan <wenc...@databricks.com>
AuthorDate: Thu Apr 4 22:05:09 2024 +0800

    [SPARK-47689][SQL][FOLLOWUP] More accurate file path in TASK_WRITE_FAILED error
    
    ### What changes were proposed in this pull request?
    
    This is a follow-up to https://github.com/apache/spark/pull/45797 .
    Instead of detecting query execution errors and leaving them unwrapped,
    it is better to do the error wrapping only in the data writer, which has
    more context: it can report the specific file path being written when the
    error happened, instead of just the destination directory name.
    
    ### Why are the changes needed?
    
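    A more accurate error message: TASK_WRITE_FAILED now reports the file that was being written instead of the destination directory.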
    
    ### Does this PR introduce _any_ user-facing change?
    
    no
    
    ### How was this patch tested?
    
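    Updated the existing TASK_WRITE_FAILED checks in CSVSuite, JsonSuite and XmlSuite to match the more specific file path.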
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No
    
    Closes #45844 from cloud-fan/write-error.
    
    Authored-by: Wenchen Fan <wenc...@databricks.com>
    Signed-off-by: Wenchen Fan <wenc...@databricks.com>
---
 .../datasources/FileFormatDataWriter.scala         | 27 +++++++--
 .../execution/datasources/FileFormatWriter.scala   | 65 +++++-----------------
 .../sql/execution/datasources/csv/CSVSuite.scala   |  4 +-
 .../sql/execution/datasources/json/JsonSuite.scala |  4 +-
 .../sql/execution/datasources/xml/XmlSuite.scala   | 11 ++--
 5 files changed, 44 insertions(+), 67 deletions(-)

diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatDataWriter.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatDataWriter.scala
index b9e8475e4859..1dbb6ce26f69 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatDataWriter.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatDataWriter.scala
@@ -18,17 +18,20 @@ package org.apache.spark.sql.execution.datasources
 
 import scala.collection.mutable
 
-import org.apache.hadoop.fs.Path
+import org.apache.hadoop.fs.{FileAlreadyExistsException, Path}
 import org.apache.hadoop.mapreduce.TaskAttemptContext
 
+import org.apache.spark.TaskOutputFileAlreadyExistException
 import org.apache.spark.internal.Logging
 import org.apache.spark.internal.io.{FileCommitProtocol, FileNameSpec}
 import org.apache.spark.internal.io.FileCommitProtocol.TaskCommitMessage
+import org.apache.spark.shuffle.FetchFailedException
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
 import org.apache.spark.sql.catalyst.catalog.ExternalCatalogUtils
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.connector.write.{DataWriter, WriterCommitMessage}
+import org.apache.spark.sql.errors.QueryExecutionErrors
 import 
org.apache.spark.sql.execution.datasources.FileFormatWriter.ConcurrentOutputWriterSpec
 import org.apache.spark.sql.execution.metric.{CustomMetrics, SQLMetric}
 import org.apache.spark.sql.internal.SQLConf
@@ -76,6 +79,18 @@ abstract class FileFormatDataWriter(
     releaseCurrentWriter()
   }
 
+  private def enrichWriteError[T](path: => String)(f: => T): T = try {
+    f
+  } catch {
+    case e: FetchFailedException =>
+      throw e
+    case f: FileAlreadyExistsException if SQLConf.get.fastFailFileFormatOutput 
=>
+      // If any output file to write already exists, it does not make sense to 
re-run this task.
+      // We throw the exception and let Executor throw ExceptionFailure to 
abort the job.
+      throw new TaskOutputFileAlreadyExistException(f)
+    case t: Throwable => throw 
QueryExecutionErrors.taskFailedWhileWritingRowsError(path, t)
+  }
+
   /** Writes a record. */
   def write(record: InternalRow): Unit
 
@@ -83,7 +98,9 @@ abstract class FileFormatDataWriter(
     if (count % CustomMetrics.NUM_ROWS_PER_UPDATE == 0) {
       CustomMetrics.updateMetrics(currentMetricsValues.toImmutableArraySeq, 
customMetrics)
     }
-    write(record)
+    
enrichWriteError(Option(currentWriter).map(_.path()).getOrElse(description.path))
 {
+      write(record)
+    }
   }
 
   /** Write an iterator of records. */
@@ -102,7 +119,7 @@ abstract class FileFormatDataWriter(
    * to the driver and used to update the catalog. Other information will be 
sent back to the
    * driver too and used to e.g. update the metrics in UI.
    */
-  override def commit(): WriteTaskResult = {
+  final override def commit(): WriteTaskResult = 
enrichWriteError(description.path) {
     releaseResources()
     val (taskCommitMessage, taskCommitTime) = Utils.timeTakenMs {
       committer.commitTask(taskAttemptContext)
@@ -113,7 +130,7 @@ abstract class FileFormatDataWriter(
     WriteTaskResult(taskCommitMessage, summary)
   }
 
-  def abort(): Unit = {
+  final def abort(): Unit = enrichWriteError(description.path) {
     try {
       releaseResources()
     } finally {
@@ -121,7 +138,7 @@ abstract class FileFormatDataWriter(
     }
   }
 
-  override def close(): Unit = {}
+  final override def close(): Unit = {}
 }
 
 /** FileFormatWriteTask for empty partitions */
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala
index 1df63aa14b4b..3bfa3413f679 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.execution.datasources
 import java.util.{Date, UUID}
 
 import org.apache.hadoop.conf.Configuration
-import org.apache.hadoop.fs.{FileAlreadyExistsException, Path}
+import org.apache.hadoop.fs.Path
 import org.apache.hadoop.mapreduce._
 import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat
 import org.apache.hadoop.mapreduce.task.TaskAttemptContextImpl
@@ -28,7 +28,6 @@ import org.apache.hadoop.mapreduce.task.TaskAttemptContextImpl
 import org.apache.spark._
 import org.apache.spark.internal.Logging
 import org.apache.spark.internal.io.{FileCommitProtocol, 
SparkHadoopWriterUtils}
-import org.apache.spark.shuffle.FetchFailedException
 import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.catalog.BucketSpec
@@ -37,11 +36,9 @@ import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.BindReferences.bindReferences
 import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils}
 import org.apache.spark.sql.connector.write.WriterCommitMessage
-import org.apache.spark.sql.errors.QueryExecutionErrors
 import org.apache.spark.sql.execution.{ProjectExec, SortExec, SparkPlan, 
SQLExecution, UnsafeExternalRowSorter}
 import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanExec
-import org.apache.spark.sql.internal.SQLConf
-import org.apache.spark.util.{NextIterator, SerializableConfiguration, Utils}
+import org.apache.spark.util.{SerializableConfiguration, Utils}
 import org.apache.spark.util.ArrayImplicits._
 
 
@@ -400,31 +397,17 @@ object FileFormatWriter extends Logging {
         }
       }
 
-    try {
-      val queryFailureCapturedIterator = new 
QueryFailureCapturedIterator(iterator)
-      Utils.tryWithSafeFinallyAndFailureCallbacks(block = {
-        // Execute the task to write rows out and commit the task.
-        dataWriter.writeWithIterator(queryFailureCapturedIterator)
-        dataWriter.commit()
-      })(catchBlock = {
-        // If there is an error, abort the task
-        dataWriter.abort()
-        logError(s"Job $jobId aborted.")
-      }, finallyBlock = {
-        dataWriter.close()
-      })
-    } catch {
-      case e: QueryFailureDuringWrite =>
-        throw e.queryFailure
-      case e: FetchFailedException =>
-        throw e
-      case f: FileAlreadyExistsException if 
SQLConf.get.fastFailFileFormatOutput =>
-        // If any output file to write already exists, it does not make sense 
to re-run this task.
-        // We throw the exception and let Executor throw ExceptionFailure to 
abort the job.
-        throw new TaskOutputFileAlreadyExistException(f)
-      case t: Throwable =>
-        throw 
QueryExecutionErrors.taskFailedWhileWritingRowsError(description.path, t)
-    }
+    Utils.tryWithSafeFinallyAndFailureCallbacks(block = {
+      // Execute the task to write rows out and commit the task.
+      dataWriter.writeWithIterator(iterator)
+      dataWriter.commit()
+    })(catchBlock = {
+      // If there is an error, abort the task
+      dataWriter.abort()
+      logError(s"Job $jobId aborted.")
+    }, finallyBlock = {
+      dataWriter.close()
+    })
   }
 
   /**
@@ -455,25 +438,3 @@ object FileFormatWriter extends Logging {
     }
   }
 }
-
-// A exception wrapper to indicate that the error was thrown when executing 
the query, not writing
-// the data
-private class QueryFailureDuringWrite(val queryFailure: Throwable) extends 
Throwable
-
-// An iterator wrapper to rethrow any error from the given iterator with 
`QueryFailureDuringWrite`.
-private class QueryFailureCapturedIterator(data: Iterator[InternalRow])
-  extends NextIterator[InternalRow] {
-
-  override protected def getNext(): InternalRow = try {
-    if (data.hasNext) {
-      data.next()
-    } else {
-      finished = true
-      null
-    }
-  } catch {
-    case t: Throwable => throw new QueryFailureDuringWrite(t)
-  }
-
-  override protected def close(): Unit = {}
-}
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala
index 964d1ec85e15..22ea133ee19a 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala
@@ -1250,10 +1250,10 @@ abstract class CSVSuite
         val ex = intercept[SparkException] {
           exp.write.format("csv").option("timestampNTZFormat", 
pattern).save(path.getAbsolutePath)
         }
-        checkError(
+        checkErrorMatchPVals(
           exception = ex,
           errorClass = "TASK_WRITE_FAILED",
-          parameters = Map("path" -> actualPath))
+          parameters = Map("path" -> s"$actualPath.*"))
         val msg = ex.getCause.getMessage
         assert(
           msg.contains("Unsupported field: OffsetSeconds") ||
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala
index 9af7511ca913..5c96df98dd23 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala
@@ -3043,10 +3043,10 @@ abstract class JsonSuite
         val err = intercept[SparkException] {
           exp.write.option("timestampNTZFormat", 
pattern).json(path.getAbsolutePath)
         }
-        checkError(
+        checkErrorMatchPVals(
           exception = err,
           errorClass = "TASK_WRITE_FAILED",
-          parameters = Map("path" -> actualPath))
+          parameters = Map("path" -> s"$actualPath.*"))
 
         val msg = err.getCause.getMessage
         assert(
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/xml/XmlSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/xml/XmlSuite.scala
index 9dedd5795370..ddb49657144d 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/xml/XmlSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/xml/XmlSuite.scala
@@ -30,7 +30,6 @@ import scala.collection.mutable
 import scala.io.Source
 import scala.jdk.CollectionConverters._
 
-import org.apache.commons.lang3.StringUtils
 import org.apache.commons.lang3.exception.ExceptionUtils
 import org.apache.hadoop.conf.Configuration
 import org.apache.hadoop.fs.FSDataInputStream
@@ -2451,10 +2450,10 @@ class XmlSuite
           exp.write.option("timestampNTZFormat", pattern)
             .option("rowTag", "ROW").xml(path.getAbsolutePath)
         }
-        checkError(
+        checkErrorMatchPVals(
           exception = err,
           errorClass = "TASK_WRITE_FAILED",
-          parameters = Map("path" -> actualPath))
+          parameters = Map("path" -> s"$actualPath.*"))
         val msg = err.getCause.getMessage
         assert(
           msg.contains("Unsupported field: OffsetSeconds") ||
@@ -2948,11 +2947,11 @@ class XmlSuite
                 .mode(SaveMode.Overwrite)
                 .xml(path)
             }
-            val actualPath = Path.of(dir.getAbsolutePath).toUri.toURL.toString
-            checkError(
+            val actualPath = 
Path.of(dir.getAbsolutePath).toUri.toURL.toString.stripSuffix("/")
+            checkErrorMatchPVals(
               exception = e,
               errorClass = "TASK_WRITE_FAILED",
-              parameters = Map("path" -> StringUtils.removeEnd(actualPath, 
"/")))
+              parameters = Map("path" -> s"$actualPath.*"))
             assert(e.getCause.isInstanceOf[XMLStreamException])
             assert(e.getCause.getMessage.contains(errorMsg))
         }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to