This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new bfa0d13f3f6b [SPARK-49083][CONNECT] Allow from_xml and from_json to natively work with JSON schemas
bfa0d13f3f6b is described below

commit bfa0d13f3f6b4b662ad0f355a8db00dd1244a698
Author: Herman van Hovell <her...@databricks.com>
AuthorDate: Tue Aug 6 10:54:09 2024 +0900

    [SPARK-49083][CONNECT] Allow from_xml and from_json to natively work with JSON schemas
    
    ### What changes were proposed in this pull request?
    We allow the `JsonToStructs` and `XmlToStructs` expressions to accept a JSON-formatted schema.
    
    ### Why are the changes needed?
    A couple of reasons:
    - We want to reference the `from_json` and `from_xml` methods in the Column API, which makes it possible to unify the Classic and Connect Scala clients.
    - Reduce the amount of duplication between the Function API and the SparkConnectPlanner.
    - Make the DataFrame and SQL APIs behave the same.
    
    ### Does this PR introduce _any_ user-facing change?
    No
    
    ### How was this patch tested?
    Existing tests.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    No
    
    Closes #47573 from hvanhovell/SPARK-49083.
    
    Authored-by: Herman van Hovell <her...@databricks.com>
    Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
---
 .../org/apache/spark/sql/types/DataType.scala      |  3 +-
 .../spark/sql/catalyst/expressions/ExprUtils.scala |  5 +++-
 .../sql/connect/planner/SparkConnectPlanner.scala  | 32 +-------------------
 .../scala/org/apache/spark/sql/functions.scala     | 22 ++++----------
 .../org/apache/spark/sql/JsonFunctionsSuite.scala  | 34 ++++++++--------------
 5 files changed, 24 insertions(+), 72 deletions(-)

diff --git a/sql/api/src/main/scala/org/apache/spark/sql/types/DataType.scala 
b/sql/api/src/main/scala/org/apache/spark/sql/types/DataType.scala
index 6d783be11277..277d5c9458d6 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/types/DataType.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/types/DataType.scala
@@ -148,7 +148,8 @@ object DataType {
         try {
           fallbackParser(schema)
         } catch {
-          case NonFatal(_) =>
+          case NonFatal(suppressed) =>
+            e.addSuppressed(suppressed)
             if (e.isInstanceOf[SparkThrowable]) {
               throw e
             }
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExprUtils.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExprUtils.scala
index fde209346087..749152f135e9 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExprUtils.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExprUtils.scala
@@ -38,7 +38,10 @@ object ExprUtils extends QueryErrorsBase {
     if (exp.foldable) {
       exp.eval() match {
         case s: UTF8String if s != null =>
-          val dataType = DataType.fromDDL(s.toString)
+          val dataType = DataType.parseTypeWithFallback(
+            s.toString,
+            DataType.fromDDL,
+            DataType.fromJson)
           CharVarcharUtils.failIfHasCharVarchar(dataType)
         case _ => throw QueryCompilationErrors.unexpectedSchemaTypeError(exp)
 
diff --git 
a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
 
b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index 7bfacf7cf064..9b6e0e461ec6 100644
--- 
a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ 
b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -65,7 +65,7 @@ import 
org.apache.spark.sql.connect.config.Connect.CONNECT_GRPC_ARROW_MAX_BATCH_
 import org.apache.spark.sql.connect.plugin.SparkConnectPluginRegistry
 import org.apache.spark.sql.connect.service.{ExecuteHolder, SessionHolder, 
SparkConnectService}
 import org.apache.spark.sql.connect.utils.MetricGenerator
-import org.apache.spark.sql.errors.{DataTypeErrors, QueryCompilationErrors}
+import org.apache.spark.sql.errors.QueryCompilationErrors
 import org.apache.spark.sql.execution.QueryExecution
 import org.apache.spark.sql.execution.aggregate.TypedAggregateExpression
 import org.apache.spark.sql.execution.arrow.ArrowConverters
@@ -1879,36 +1879,6 @@ class SparkConnectPlanner(
       case "unwrap_udt" if fun.getArgumentsCount == 1 =>
         Some(UnwrapUDT(transformExpression(fun.getArguments(0))))
 
-      case "from_json" if Seq(2, 3).contains(fun.getArgumentsCount) =>
-        // JsonToStructs constructor doesn't accept JSON-formatted schema.
-        extractDataTypeFromJSON(fun.getArguments(1)).map { dataType =>
-          val children = fun.getArgumentsList.asScala.map(transformExpression)
-          val schema = CharVarcharUtils.failIfHasCharVarchar(dataType)
-          var options = Map.empty[String, String]
-          if (children.length == 3) {
-            options = extractMapData(children(2), "Options")
-          }
-          JsonToStructs(schema = schema, options = options, child = 
children.head)
-        }
-
-      case "from_xml" if Seq(2, 3).contains(fun.getArgumentsCount) =>
-        // XmlToStructs constructor doesn't accept JSON-formatted schema.
-        extractDataTypeFromJSON(fun.getArguments(1)).map { dataType =>
-          val children = fun.getArgumentsList.asScala.map(transformExpression)
-          val schema = dataType match {
-            case t: StructType =>
-              CharVarcharUtils
-                .failIfHasCharVarchar(t)
-                .asInstanceOf[StructType]
-            case _ => throw 
DataTypeErrors.failedParsingStructTypeError(dataType.sql)
-          }
-          var options = Map.empty[String, String]
-          if (children.length == 3) {
-            options = extractMapData(children(2), "Options")
-          }
-          XmlToStructs(schema = schema, options = options, child = 
children.head)
-        }
-
       // Avro-specific functions
       case "from_avro" if Seq(2, 3).contains(fun.getArgumentsCount) =>
         val children = fun.getArgumentsList.asScala.map(transformExpression)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index 0e62e05900a5..f5945eb32336 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -31,13 +31,11 @@ import 
org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate._
 import org.apache.spark.sql.catalyst.plans.logical.{BROADCAST, HintInfo, 
ResolvedHint}
-import org.apache.spark.sql.catalyst.util.CharVarcharUtils
-import org.apache.spark.sql.errors.{DataTypeErrors, QueryCompilationErrors}
+import org.apache.spark.sql.errors.QueryCompilationErrors
 import org.apache.spark.sql.execution.SparkSqlParser
 import org.apache.spark.sql.expressions.{Aggregator, SparkUserDefinedFunction, 
UserDefinedAggregator, UserDefinedFunction}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
-import org.apache.spark.sql.types.DataType.parseTypeWithFallback
 import org.apache.spark.util.Utils
 
 /**
@@ -6531,7 +6529,7 @@ object functions {
    */
   // scalastyle:on line.size.limit
   def from_json(e: Column, schema: DataType, options: java.util.Map[String, 
String]): Column = {
-    from_json(e, CharVarcharUtils.failIfHasCharVarchar(schema), 
options.asScala.toMap)
+    from_json(e, schema, options.asScala.toMap)
   }
 
   /**
@@ -6604,11 +6602,7 @@ object functions {
    */
   // scalastyle:on line.size.limit
   def from_json(e: Column, schema: String, options: Map[String, String]): 
Column = {
-    val dataType = parseTypeWithFallback(
-      schema,
-      DataType.fromJson,
-      fallbackParser = DataType.fromDDL)
-    from_json(e, dataType, options)
+    from_json(e, lit(schema), options.asJava)
   }
 
   /**
@@ -7301,7 +7295,7 @@ object functions {
    */
   // scalastyle:on line.size.limit
   def from_xml(e: Column, schema: StructType, options: java.util.Map[String, 
String]): Column =
-    from_xml(e, lit(CharVarcharUtils.failIfHasCharVarchar(schema).sql), 
options.asScala.iterator)
+    from_xml(e, lit(schema.sql), options.asScala.iterator)
 
   // scalastyle:off line.size.limit
   /**
@@ -7322,13 +7316,7 @@ object functions {
    */
   // scalastyle:on line.size.limit
   def from_xml(e: Column, schema: String, options: java.util.Map[String, 
String]): Column = {
-    val dataType =
-      parseTypeWithFallback(schema, DataType.fromJson, fallbackParser = 
DataType.fromDDL)
-    val structType = dataType match {
-      case t: StructType => t
-      case _ => throw DataTypeErrors.failedParsingStructTypeError(schema)
-    }
-    from_xml(e, structType, options)
+    from_xml(e, lit(schema), options)
   }
 
   // scalastyle:off line.size.limit
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala
index 7dbad885169d..67b4a3e319e4 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala
@@ -23,7 +23,7 @@ import java.util.Locale
 
 import scala.jdk.CollectionConverters._
 
-import org.apache.spark.{SparkException, SparkIllegalArgumentException, 
SparkRuntimeException}
+import org.apache.spark.{SparkException, SparkRuntimeException}
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.{Literal, StructsToJson}
 import org.apache.spark.sql.catalyst.expressions.Cast._
@@ -1204,42 +1204,32 @@ class JsonFunctionsSuite extends QueryTest with 
SharedSparkSession {
     val df = Seq("""{"a":1}""").toDF("json")
     val invalidJsonSchema = """{"fields": [{"a":123}], "type": "struct"}"""
     checkError(
-      exception = intercept[SparkIllegalArgumentException] {
+      exception = intercept[AnalysisException] {
         df.select(from_json($"json", invalidJsonSchema, Map.empty[String, 
String])).collect()
       },
-      errorClass = "INVALID_JSON_DATA_TYPE",
-      parameters = Map("invalidType" -> """{"a":123}"""))
+      errorClass = "PARSE_SYNTAX_ERROR",
+      parameters = Map("error" -> "'{'", "hint" -> ""),
+      ExpectedContext("from_json", getCurrentClassCallSitePattern)
+    )
 
     val invalidDataType = "MAP<INT, cow>"
-    val invalidDataTypeReason = "Unrecognized token 'MAP': " +
-      "was expecting (JSON String, Number, Array, Object or token 'null', 
'true' or 'false')\n " +
-      "at [Source: REDACTED (`StreamReadFeature.INCLUDE_SOURCE_IN_LOCATION` 
disabled); " +
-      "line: 1, column: 4]"
     checkError(
       exception = intercept[AnalysisException] {
         df.select(from_json($"json", invalidDataType, Map.empty[String, 
String])).collect()
       },
-      errorClass = "INVALID_SCHEMA.PARSE_ERROR",
-      parameters = Map(
-        "inputSchema" -> "\"MAP<INT, cow>\"",
-        "reason" -> invalidDataTypeReason
-      )
+      errorClass = "UNSUPPORTED_DATATYPE",
+      parameters = Map("typeName" -> "\"COW\""),
+      ExpectedContext("from_json", getCurrentClassCallSitePattern)
     )
 
     val invalidTableSchema = "x INT, a cow"
-    val invalidTableSchemaReason = "Unrecognized token 'x': " +
-      "was expecting (JSON String, Number, Array, Object or token 'null', 
'true' or 'false')\n" +
-      " at [Source: REDACTED (`StreamReadFeature.INCLUDE_SOURCE_IN_LOCATION` 
disabled); " +
-      "line: 1, column: 2]"
     checkError(
       exception = intercept[AnalysisException] {
         df.select(from_json($"json", invalidTableSchema, Map.empty[String, 
String])).collect()
       },
-      errorClass = "INVALID_SCHEMA.PARSE_ERROR",
-      parameters = Map(
-        "inputSchema" -> "\"x INT, a cow\"",
-        "reason" -> invalidTableSchemaReason
-      )
+      errorClass = "PARSE_SYNTAX_ERROR",
+      parameters = Map("error" -> "'INT'", "hint" -> ""),
+      ExpectedContext("from_json", getCurrentClassCallSitePattern)
     )
   }
 


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to