This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new ce7a889ad9fa [SPARK-48461][SQL] Replace NullPointerExceptions with error class in AssertNotNull expression
ce7a889ad9fa is described below

commit ce7a889ad9fae55ca6ffdd262d538239f60be1ca
Author: Daniel Tenedorio <daniel.tenedo...@databricks.com>
AuthorDate: Fri May 31 08:36:12 2024 +0900

    [SPARK-48461][SQL] Replace NullPointerExceptions with error class in AssertNotNull expression
    
    ### What changes were proposed in this pull request?
    
    This PR replaces `NullPointerException`s with a new error class in the `AssertNotNull` expression.
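    
    A minimal sketch of the new control flow in `AssertNotNull.eval` (simplified
    from the diff below; the codegen path changes the same way):
    
    ```scala
    override def eval(input: InternalRow): Any = {
      val result = child.eval(input)
      if (result == null) {
        // Raises a SparkRuntimeException carrying the NOT_NULL_ASSERT_VIOLATION
        // error class instead of a bare NullPointerException.
        throw QueryExecutionErrors.notNullAssertViolation(errMsg)
      }
      result
    }
    ```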
    
    ### Why are the changes needed?
    
    This brings the advantages of the Spark error class framework to this case, enabling a better user experience and error classification.
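    
    For example, callers can now branch on structured error metadata instead of
    matching free-form message text (a sketch; `df` stands for any Dataset whose
    evaluation hits a NULL in a non-nullable field):
    
    ```scala
    import org.apache.spark.SparkRuntimeException
    
    try {
      df.collect()
    } catch {
      // The error class is stable, machine-readable metadata; previously the
      // only option was string-matching the NullPointerException message.
      case e: SparkRuntimeException
          if e.getErrorClass == "NOT_NULL_ASSERT_VIOLATION" =>
        println(s"null check failed: ${e.getMessage}")
    }
    ```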
    
    ### Does this PR introduce _any_ user-facing change?
    
    Yes. Operations that previously failed with a `NullPointerException` now fail with a `SparkRuntimeException` carrying the new `NOT_NULL_ASSERT_VIOLATION` error class (SQLSTATE 42000).
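    
    Adapted from the updated `DatasetSuite` test below (the rendered message is
    illustrative; the error class and SQLSTATE are the ones added by this PR):
    
    ```scala
    val ds = Seq(Some(1), None).toDS().as[Int]
    ds.map(_ * 2).collect()
    // Before: java.lang.NullPointerException: Null value appeared in
    //         non-nullable field: ...
    // After:  org.apache.spark.SparkRuntimeException: [NOT_NULL_ASSERT_VIOLATION]
    //         NULL value appeared in non-nullable field: ... SQLSTATE: 42000
    ```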
    
    ### How was this patch tested?
    
    This PR includes unit test coverage.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    Yes, GitHub Copilot.
    
    Closes #46793 from dtenedor/fix-npe.
    
    Authored-by: Daniel Tenedorio <daniel.tenedo...@databricks.com>
    Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
---
 .../src/main/resources/error/error-conditions.json |  6 +++
 .../sql/catalyst/expressions/objects/objects.scala | 10 ++---
 .../spark/sql/errors/QueryExecutionErrors.scala    |  9 ++++
 .../catalyst/encoders/EncoderResolutionSuite.scala | 16 +++++--
 .../sql/catalyst/encoders/RowEncoderSuite.scala    |  8 ++--
 .../expressions/NullExpressionsSuite.scala         | 13 +++---
 .../scala/org/apache/spark/sql/DatasetSuite.scala  | 51 ++++++++++++++--------
 .../spark/sql/RuntimeNullChecksV2Writes.scala      | 30 ++++++-------
 .../spark/sql/connector/DataSourceV2SQLSuite.scala | 14 +++---
 .../sql/connector/MergeIntoTableSuiteBase.scala    | 14 +++---
 .../spark/sql/connector/UpdateTableSuiteBase.scala | 13 +++---
 .../org/apache/spark/sql/sources/InsertSuite.scala |  8 ++--
 12 files changed, 114 insertions(+), 78 deletions(-)

diff --git a/common/utils/src/main/resources/error/error-conditions.json b/common/utils/src/main/resources/error/error-conditions.json
index 66708649e564..3914c0f177dc 100644
--- a/common/utils/src/main/resources/error/error-conditions.json
+++ b/common/utils/src/main/resources/error/error-conditions.json
@@ -3204,6 +3204,12 @@
     ],
     "sqlState" : "42809"
   },
+  "NOT_NULL_ASSERT_VIOLATION" : {
+    "message" : [
+      "NULL value appeared in non-nullable field: <walkedTypePath>If the 
schema is inferred from a Scala tuple/case class, or a Java bean, please try to 
use scala.Option[_] or other nullable types (such as java.lang.Integer instead 
of int/scala.Int)."
+    ],
+    "sqlState" : "42000"
+  },
   "NOT_NULL_CONSTRAINT_VIOLATION" : {
     "message" : [
       "Assigning a NULL is not allowed here."
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
index 462facd180c4..32d8eebd01ce 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
@@ -1917,16 +1917,12 @@ case class AssertNotNull(child: Expression, walkedTypePath: Seq[String] = Nil)
 
   override def flatArguments: Iterator[Any] = Iterator(child)
 
-  private val errMsg = "Null value appeared in non-nullable field:" +
-    walkedTypePath.mkString("\n", "\n", "\n") +
-    "If the schema is inferred from a Scala tuple/case class, or a Java bean, 
" +
-    "please try to use scala.Option[_] or other nullable types " +
-    "(e.g. java.lang.Integer instead of int/scala.Int)."
+  private val errMsg = walkedTypePath.mkString("\n", "\n", "\n")
 
   override def eval(input: InternalRow): Any = {
     val result = child.eval(input)
     if (result == null) {
-      throw new NullPointerException(errMsg)
+      throw QueryExecutionErrors.notNullAssertViolation(errMsg)
     }
     result
   }
@@ -1940,7 +1936,7 @@ case class AssertNotNull(child: Expression, walkedTypePath: Seq[String] = Nil)
 
     val code = childGen.code + code"""
       if (${childGen.isNull}) {
-        throw new NullPointerException($errMsgField);
+        throw QueryExecutionErrors.notNullAssertViolation($errMsgField);
       }
      """
     ev.copy(code = code, isNull = FalseLiteral, value = childGen.value)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala
index 1f3283ebed05..f587d87284f3 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala
@@ -2773,4 +2773,13 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase with ExecutionE
       )
     )
   }
+
+  def notNullAssertViolation(walkedTypePath: String): SparkRuntimeException = {
+    new SparkRuntimeException(
+      errorClass = "NOT_NULL_ASSERT_VIOLATION",
+      messageParameters = Map(
+        "walkedTypePath" -> walkedTypePath
+      )
+    )
+  }
 }
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala
index 82238de31f9f..9ca990b607db 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.encoders
 
 import scala.reflect.runtime.universe.TypeTag
 
+import org.apache.spark.SparkRuntimeException
 import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.dsl.expressions._
@@ -169,10 +170,17 @@ class EncoderResolutionSuite extends PlanTest {
     fromRow(InternalRow(new GenericArrayData(Array(1, 2))))
 
     // If there is null value, it should throw runtime exception
-    val e = intercept[RuntimeException] {
-      fromRow(InternalRow(new GenericArrayData(Array(1, null))))
-    }
-    assert(e.getCause.getMessage.contains("Null value appeared in non-nullable field"))
+    checkError(
+      exception = intercept[SparkRuntimeException] {
+        fromRow(InternalRow(new GenericArrayData(Array(1, null))))
+      },
+      errorClass = "EXPRESSION_DECODING_FAILED",
+      sqlState = "42846",
+      parameters = Map(
+        "expressions" ->
+          ("mapobjects(lambdavariable(MapObject, IntegerType, true, -1), " +
+          "assertnotnull(lambdavariable(MapObject, IntegerType, true, -1)), " +
+          "input[0, array<int>, true], Some(interface 
scala.collection.immutable.Seq))")))
   }
 
   test("the real number of fields doesn't match encoder schema: tuple 
encoder") {
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala
index df73d50fdcd6..943499fde84f 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.encoders
 import scala.collection.mutable
 import scala.util.Random
 
+import org.apache.spark.SparkRuntimeException
 import org.apache.spark.sql.{RandomDataGenerator, Row}
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.plans.CodegenInterpretedPlanTest
@@ -275,9 +276,10 @@ class RowEncoderSuite extends CodegenInterpretedPlanTest {
   test("RowEncoder should throw RuntimeException if input row object is null") 
{
     val schema = new StructType().add("int", IntegerType)
     val encoder = ExpressionEncoder(schema)
-    val e = intercept[RuntimeException](toRow(encoder, null))
-    assert(e.getCause.getMessage.contains("Null value appeared in non-nullable field"))
-    assert(e.getCause.getMessage.contains("top level Product or row object"))
+    // Check the error class only since the parameters may change depending on how we are running
+    // this test case.
+    val exception = intercept[SparkRuntimeException](toRow(encoder, null))
+    assert(exception.getErrorClass == "EXPRESSION_ENCODING_FAILED")
   }
 
   test("RowEncoder should validate external type") {
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullExpressionsSuite.scala
index da8e11c0433e..ace017b1cddc 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullExpressionsSuite.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.expressions
 
 import java.sql.Timestamp
 
-import org.apache.spark.SparkFunSuite
+import org.apache.spark.{SparkFunSuite, SparkRuntimeException}
 import org.apache.spark.sql.catalyst.analysis.SimpleAnalyzer
 import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext
 import org.apache.spark.sql.catalyst.expressions.objects.AssertNotNull
@@ -53,10 +53,13 @@ class NullExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
   }
 
   test("AssertNotNUll") {
-    val ex = intercept[RuntimeException] {
-      evaluateWithoutCodegen(AssertNotNull(Literal(null)))
-    }.getMessage
-    assert(ex.contains("Null value appeared in non-nullable field"))
+    checkError(
+      exception = intercept[SparkRuntimeException] {
+        evaluateWithoutCodegen(AssertNotNull(Literal(null)))
+      },
+      errorClass = "NOT_NULL_ASSERT_VIOLATION",
+      sqlState = "42000",
+      parameters = Map("walkedTypePath" -> "\n\n"))
   }
 
   test("IsNaN") {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index 16a493b52909..10d6f045db39 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -29,7 +29,7 @@ import org.scalatest.Assertions._
 import org.scalatest.exceptions.TestFailedException
 import org.scalatest.prop.TableDrivenPropertyChecks._
 
-import org.apache.spark.{SparkConf, SparkException, SparkRuntimeException, SparkUnsupportedOperationException, TaskContext}
+import org.apache.spark.{SparkConf, SparkRuntimeException, SparkUnsupportedOperationException, TaskContext}
 import org.apache.spark.TestUtils.withListener
 import org.apache.spark.internal.config.MAX_RESULT_SIZE
 import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart}
@@ -1251,11 +1251,10 @@ class DatasetSuite extends QueryTest
     // Shouldn't throw runtime exception when parent object (`ClassData`) is null
     assert(buildDataset(Row(null)).collect() === Array(NestedStruct(null)))
 
-    val message = intercept[RuntimeException] {
+    // Just check the error class here to avoid flakiness due to different parameters.
+    assert(intercept[SparkRuntimeException] {
       buildDataset(Row(Row("hello", null))).collect()
-    }.getCause.getMessage
-
-    assert(message.contains("Null value appeared in non-nullable field"))
+    }.getErrorClass == "EXPRESSION_DECODING_FAILED")
   }
 
   test("SPARK-12478: top level null field") {
@@ -1593,9 +1592,8 @@ class DatasetSuite extends QueryTest
   }
 
   test("Dataset should throw RuntimeException if top-level product input 
object is null") {
-    val e = intercept[RuntimeException](Seq(ClassData("a", 1), null).toDS())
-    assert(e.getCause.getMessage.contains("Null value appeared in non-nullable field"))
-    assert(e.getCause.getMessage.contains("top level Product or row object"))
+    val e = intercept[SparkRuntimeException](Seq(ClassData("a", 1), null).toDS())
+    assert(e.getErrorClass == "EXPRESSION_ENCODING_FAILED")
   }
 
   test("dropDuplicates") {
@@ -2038,19 +2036,33 @@ class DatasetSuite extends QueryTest
   test("SPARK-22472: add null check for top-level primitive values") {
     // If the primitive values are from Option, we need to do runtime null check.
     val ds = Seq(Some(1), None).toDS().as[Int]
-    val e1 = intercept[RuntimeException](ds.collect())
-    assert(e1.getCause.isInstanceOf[NullPointerException])
-    val e2 = intercept[SparkException](ds.map(_ * 2).collect())
-    assert(e2.getCause.isInstanceOf[NullPointerException])
+    val errorClass = "EXPRESSION_DECODING_FAILED"
+    val sqlState = "42846"
+    checkError(
+      exception = intercept[SparkRuntimeException](ds.collect()),
+      errorClass = "EXPRESSION_DECODING_FAILED",
+      sqlState = "42846",
+      parameters = Map("expressions" -> "assertnotnull(input[0, int, true])"))
+    checkError(
+      exception = intercept[SparkRuntimeException](ds.map(_ * 2).collect()),
+      errorClass = "NOT_NULL_ASSERT_VIOLATION",
+      sqlState = "42000",
+      parameters = Map("walkedTypePath" -> "\n- root class: \"int\"\n"))
 
     withTempPath { path =>
       Seq(Integer.valueOf(1), null).toDF("i").write.parquet(path.getCanonicalPath)
       // If the primitive values are from files, we need to do runtime null check.
       val ds = spark.read.parquet(path.getCanonicalPath).as[Int]
-      val e1 = intercept[RuntimeException](ds.collect())
-      assert(e1.getCause.isInstanceOf[NullPointerException])
-      val e2 = intercept[SparkException](ds.map(_ * 2).collect())
-      assert(e2.getCause.isInstanceOf[NullPointerException])
+      checkError(
+        exception = intercept[SparkRuntimeException](ds.collect()),
+        errorClass = "EXPRESSION_DECODING_FAILED",
+        sqlState = "42846",
+        parameters = Map("expressions" -> "assertnotnull(input[0, int, true])"))
+      checkError(
+        exception = intercept[SparkRuntimeException](ds.map(_ * 2).collect()),
+        errorClass = "NOT_NULL_ASSERT_VIOLATION",
+        sqlState = "42000",
+        parameters = Map("walkedTypePath" -> "\n- root class: \"int\"\n"))
     }
   }
 
@@ -2068,8 +2080,11 @@ class DatasetSuite extends QueryTest
 
   test("SPARK-23835: null primitive data type should throw 
NullPointerException") {
     val ds = Seq[(Option[Int], Option[Int])]((Some(1), None)).toDS()
-    val e = intercept[RuntimeException](ds.as[(Int, Int)].collect())
-    assert(e.getCause.isInstanceOf[NullPointerException])
+    checkError(
+      exception = intercept[SparkRuntimeException](ds.as[(Int, Int)].collect()),
+      errorClass = "EXPRESSION_DECODING_FAILED",
+      sqlState = "42846",
+      parameters = Map("expressions" -> "newInstance(class scala.Tuple2)"))
   }
 
   test("SPARK-24569: Option of primitive types are mistakenly mapped to struct 
type") {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/RuntimeNullChecksV2Writes.scala b/sql/core/src/test/scala/org/apache/spark/sql/RuntimeNullChecksV2Writes.scala
index fbdd1428ba9b..754c46cc5cd3 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/RuntimeNullChecksV2Writes.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/RuntimeNullChecksV2Writes.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql
 
 import java.util.Collections
 
-import org.apache.spark.{SparkConf, SparkException}
+import org.apache.spark.{SparkConf, SparkRuntimeException}
+import org.apache.spark.sql.connector.catalog.{Column => ColumnV2, Identifier, InMemoryTableCatalog}
 import org.apache.spark.sql.connector.expressions.Transform
 import org.apache.spark.sql.internal.SQLConf
@@ -56,7 +56,7 @@ class RuntimeNullChecksV2Writes extends QueryTest with SQLTestUtils with SharedS
     withTable("t") {
       sql(s"CREATE TABLE t (s STRING, i INT NOT NULL) USING $FORMAT")
 
-      val e = intercept[SparkException] {
+      val e = intercept[SparkRuntimeException] {
         if (byName) {
           val inputDF = sql("SELECT 'txt' AS s, null AS i")
           inputDF.writeTo("t").append()
@@ -64,7 +64,7 @@ class RuntimeNullChecksV2Writes extends QueryTest with SQLTestUtils with SharedS
           sql("INSERT INTO t VALUES ('txt', null)")
         }
       }
-      assertNotNullException(e, Seq("i"))
+      assert(e.getErrorClass == "NOT_NULL_ASSERT_VIOLATION")
     }
   }
 
@@ -88,7 +88,7 @@ class RuntimeNullChecksV2Writes extends QueryTest with SQLTestUtils with SharedS
            |USING $FORMAT
          """.stripMargin)
 
-      val e1 = intercept[SparkException] {
+      val e1 = intercept[SparkRuntimeException] {
         if (byName) {
           val inputDF = sql(
             s"""SELECT
@@ -106,7 +106,7 @@ class RuntimeNullChecksV2Writes extends QueryTest with SQLTestUtils with SharedS
       }
       assertNotNullException(e1, Seq("s", "ns"))
 
-      val e2 = intercept[SparkException] {
+      val e2 = intercept[SparkRuntimeException] {
         if (byName) {
           val inputDF = sql(
             s"""SELECT
@@ -124,7 +124,7 @@ class RuntimeNullChecksV2Writes extends QueryTest with SQLTestUtils with SharedS
       }
       assertNotNullException(e2, Seq("s", "arr"))
 
-      val e3 = intercept[SparkException] {
+      val e3 = intercept[SparkRuntimeException] {
         if (byName) {
           val inputDF = sql(
             s"""SELECT
@@ -177,7 +177,7 @@ class RuntimeNullChecksV2Writes extends QueryTest with SQLTestUtils with SharedS
       }
       checkAnswer(spark.table("t"), Row(1, Row(1, null)))
 
-      val e = intercept[SparkException] {
+      val e = intercept[SparkRuntimeException] {
         if (byName) {
           val inputDF = sql(
             s"""SELECT
@@ -224,7 +224,7 @@ class RuntimeNullChecksV2Writes extends QueryTest with SQLTestUtils with SharedS
       }
       checkAnswer(spark.table("t"), Row(1, null))
 
-      val e = intercept[SparkException] {
+      val e = intercept[SparkRuntimeException] {
         if (byName) {
           val inputDF = sql(
             s"""SELECT
@@ -279,7 +279,7 @@ class RuntimeNullChecksV2Writes extends QueryTest with SQLTestUtils with SharedS
       }
       checkAnswer(spark.table("t"), Row(1, List(null, Row(1, 1))))
 
-      val e = intercept[SparkException] {
+      val e = intercept[SparkRuntimeException] {
         if (byName) {
           val inputDF = sql(
             s"""SELECT
@@ -325,7 +325,7 @@ class RuntimeNullChecksV2Writes extends QueryTest with SQLTestUtils with SharedS
       }
       checkAnswer(spark.table("t"), Row(1, null))
 
-      val e = intercept[SparkException] {
+      val e = intercept[SparkRuntimeException] {
         if (byName) {
           val inputDF = sql("SELECT 1 AS i, map(1, null) AS m")
           inputDF.writeTo("t").append()
@@ -364,7 +364,7 @@ class RuntimeNullChecksV2Writes extends QueryTest with SQLTestUtils with SharedS
       }
       checkAnswer(spark.table("t"), Row(1, Map(Row(1, 1) -> null)))
 
-      val e1 = intercept[SparkException] {
+      val e1 = intercept[SparkRuntimeException] {
         if (byName) {
           val inputDF = sql(
             s"""SELECT
@@ -382,7 +382,7 @@ class RuntimeNullChecksV2Writes extends QueryTest with SQLTestUtils with SharedS
       }
       assertNotNullException(e1, Seq("m", "key", "x"))
 
-      val e2 = intercept[SparkException] {
+      val e2 = intercept[SparkRuntimeException] {
         if (byName) {
           val inputDF = sql(
             s"""SELECT
@@ -402,11 +402,9 @@ class RuntimeNullChecksV2Writes extends QueryTest with SQLTestUtils with SharedS
     }
   }
 
-  private def assertNotNullException(e: SparkException, colPath: Seq[String]): Unit = {
+  private def assertNotNullException(e: SparkRuntimeException, colPath: Seq[String]): Unit = {
     e.getCause match {
-      case npe: NullPointerException =>
-        assert(npe.getMessage.contains("Null value appeared in non-nullable field"))
-        assert(npe.getMessage.contains(colPath.mkString("\n", "\n", "\n")))
+      case _ if e.getErrorClass == "NOT_NULL_ASSERT_VIOLATION" =>
       case other =>
         fail(s"Unexpected exception cause: $other")
     }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala
index d89c0a2525fd..14b9feb2951a 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala
@@ -24,7 +24,7 @@ import java.util.Locale
 import scala.concurrent.duration.MICROSECONDS
 import scala.jdk.CollectionConverters._
 
-import org.apache.spark.{SparkException, SparkUnsupportedOperationException}
+import org.apache.spark.{SparkException, SparkRuntimeException, SparkUnsupportedOperationException}
 import org.apache.spark.sql._
 import org.apache.spark.sql.catalyst.CurrentUserContext.CURRENT_USER
 import org.apache.spark.sql.catalyst.InternalRow
@@ -814,14 +814,10 @@ class DataSourceV2SQLSuiteV1Filter
         if (nullable) {
           insertNullValueAndCheck()
         } else {
-          // TODO assign a error-classes name
-          checkError(
-            exception = intercept[SparkException] {
-              insertNullValueAndCheck()
-            },
-            errorClass = null,
-            parameters = Map.empty
-          )
+          val exception = intercept[SparkRuntimeException] {
+            insertNullValueAndCheck()
+          }
+          assert(exception.getErrorClass == "NOT_NULL_ASSERT_VIOLATION")
         }
     }
   }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/MergeIntoTableSuiteBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/MergeIntoTableSuiteBase.scala
index 0b643ca534e3..9d4e4fc01672 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/MergeIntoTableSuiteBase.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/MergeIntoTableSuiteBase.scala
@@ -17,7 +17,7 @@
 
 package org.apache.spark.sql.connector
 
-import org.apache.spark.{SparkException, SparkRuntimeException}
+import org.apache.spark.SparkRuntimeException
 import org.apache.spark.sql.{AnalysisException, Row}
 import org.apache.spark.sql.catalyst.expressions.{AttributeReference, EqualTo, In, Not}
 import org.apache.spark.sql.catalyst.optimizer.BuildLeft
@@ -1317,7 +1317,7 @@ abstract class MergeIntoTableSuiteBase extends RowLevelOperationSuiteBase {
 
       Seq(1, 4).toDF("pk").createOrReplaceTempView("source")
 
-      val e1 = intercept[SparkException] {
+      val e1 = intercept[SparkRuntimeException] {
         sql(
           s"""MERGE INTO $tableNameAsString t
              |USING source s
@@ -1326,9 +1326,9 @@ abstract class MergeIntoTableSuiteBase extends RowLevelOperationSuiteBase {
              | UPDATE SET s = named_struct('n_i', null, 'n_l', -1L)
              |""".stripMargin)
       }
-      assert(e1.getCause.getMessage.contains("Null value appeared in non-nullable field"))
+      assert(e1.getErrorClass == "NOT_NULL_ASSERT_VIOLATION")
 
-      val e2 = intercept[SparkException] {
+      val e2 = intercept[SparkRuntimeException] {
         sql(
           s"""MERGE INTO $tableNameAsString t
              |USING source s
@@ -1337,9 +1337,9 @@ abstract class MergeIntoTableSuiteBase extends RowLevelOperationSuiteBase {
              | UPDATE SET s = named_struct('n_i', null, 'n_l', -1L)
              |""".stripMargin)
       }
-      assert(e2.getCause.getMessage.contains("Null value appeared in non-nullable field"))
+      assert(e2.getErrorClass == "NOT_NULL_ASSERT_VIOLATION")
 
-      val e3 = intercept[SparkException] {
+      val e3 = intercept[SparkRuntimeException] {
         sql(
           s"""MERGE INTO $tableNameAsString t
              |USING source s
@@ -1348,7 +1348,7 @@ abstract class MergeIntoTableSuiteBase extends RowLevelOperationSuiteBase {
              | INSERT (pk, s, dep) VALUES (s.pk, named_struct('n_i', null, 'n_l', -1L), 'invalid')
              |""".stripMargin)
       }
-      assert(e3.getCause.getMessage.contains("Null value appeared in non-nullable field"))
+      assert(e3.getErrorClass == "NOT_NULL_ASSERT_VIOLATION")
     }
   }
 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/UpdateTableSuiteBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/UpdateTableSuiteBase.scala
index b43101c2e025..c2ae5f40cfaf 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/UpdateTableSuiteBase.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/UpdateTableSuiteBase.scala
@@ -17,7 +17,7 @@
 
 package org.apache.spark.sql.connector
 
-import org.apache.spark.SparkException
+import org.apache.spark.SparkRuntimeException
 import org.apache.spark.sql.Row
 import org.apache.spark.sql.connector.catalog.{Column, ColumnDefaultValue}
 import org.apache.spark.sql.connector.expressions.LiteralValue
@@ -575,9 +575,12 @@ abstract class UpdateTableSuiteBase extends RowLevelOperationSuiteBase {
         |{ "pk": 3, "s": { "n_i": 3, "n_l": 33 }, "dep": "hr" }
         |""".stripMargin)
 
-    val e = intercept[SparkException] {
-      sql(s"UPDATE $tableNameAsString SET s = named_struct('n_i', null, 'n_l', 
-1L) WHERE pk = 1")
-    }
-    assert(e.getCause.getMessage.contains("Null value appeared in non-nullable field"))
+    checkError(
+      exception = intercept[SparkRuntimeException] {
+        sql(s"UPDATE $tableNameAsString SET s = named_struct('n_i', null, 
'n_l', -1L) WHERE pk = 1")
+      },
+      errorClass = "NOT_NULL_ASSERT_VIOLATION",
+      sqlState = "42000",
+      parameters = Map("walkedTypePath" -> "\ns\nn_i\n"))
   }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala
index 93698fdd7bc0..e3e385e9d181 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala
@@ -23,7 +23,7 @@ import java.time.{Duration, Period}
 
 import org.apache.hadoop.fs.{FileAlreadyExistsException, FSDataOutputStream, Path, RawLocalFileSystem}
 
-import org.apache.spark.{SparkArithmeticException, SparkException}
+import org.apache.spark.{SparkArithmeticException, SparkException, SparkRuntimeException}
 import org.apache.spark.sql._
 import org.apache.spark.sql.catalyst.TableIdentifier
 import org.apache.spark.sql.catalyst.catalog.{CatalogStorageFormat, CatalogTable, CatalogTableType}
@@ -953,10 +953,10 @@ class InsertSuite extends DataSourceTest with SharedSparkSession {
       spark.sessionState.catalog.createTable(newTable, false)
 
       sql("INSERT INTO TABLE test_table SELECT 1, 'a'")
-      val msg = intercept[SparkException] {
+      val msg = intercept[SparkRuntimeException] {
         sql("INSERT INTO TABLE test_table SELECT 2, null")
-      }.getCause.getMessage
-      assert(msg.contains("Null value appeared in non-nullable field"))
+      }
+      assert(msg.getErrorClass == "NOT_NULL_ASSERT_VIOLATION")
     }
   }
 

