Repository: spark
Updated Branches:
  refs/heads/master ab2eafb3c -> 8d54bf79f


[SPARK-26099][SQL] Verification of the corrupt column in from_csv/from_json

## What changes were proposed in this pull request?

The corrupt column specified via the JSON/CSV option *columnNameOfCorruptRecord* 
must have the `string` type and be `nullable`. This was already checked in 
`DataFrameReader.csv`/`json` and in `Json`/`CsvFileFormat`, but not in 
`from_json`/`from_csv`. This PR adds the same check inside those functions as well.
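
For illustration, a minimal sketch of the failure this introduces (the DataFrame `df` with a string column `value` and the column names are hypothetical):

```scala
import org.apache.spark.sql.functions.{col, from_json}
import org.apache.spark.sql.types._

// The corrupt column has a non-string type, so evaluating the expression
// now throws org.apache.spark.sql.AnalysisException:
//   The field for corrupt records must be string type and nullable
val schema = new StructType()
  .add("i", IntegerType)
  .add("_corrupt_record", BooleanType) // wrong: must be a nullable StringType

df.select(from_json(col("value"), schema,
  Map("columnNameOfCorruptRecord" -> "_corrupt_record"))).collect()
```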

## How was this patch tested?

Added tests to `Json`/`CsvExpressionsSuite` that check the type of the corrupt 
column. They don't check the `nullable` property because the `schema` is forcibly 
cast to nullable before the check runs.
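
The `nullable` branch is unreachable through these expressions because the diff below passes the already-normalized `nullableSchema` to the check. A minimal sketch under that assumption (constructor arguments mirror the new tests; `Some("GMT")` stands in for the suites' `gmtId`):

```scala
import org.apache.spark.sql.catalyst.expressions.{CsvToStructs, Literal}
import org.apache.spark.sql.types._

// Even a corrupt column declared non-nullable passes the new check,
// because the expression turns the whole schema nullable before validating it.
val expr = CsvToStructs(
  schema = StructType(Seq(
    StructField("i", IntegerType),
    StructField("_unparsed", StringType, nullable = false))),
  options = Map("columnNameOfCorruptRecord" -> "_unparsed"),
  child = Literal.create("a"),
  timeZoneId = Some("GMT"))
// Evaluating expr parses normally instead of throwing AnalysisException.
```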

Closes #23070 from MaxGekk/verify-corrupt-column-csv-json.

Authored-by: Maxim Gekk <max.g...@gmail.com>
Signed-off-by: hyukjinkwon <gurwls...@apache.org>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/8d54bf79
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/8d54bf79
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/8d54bf79

Branch: refs/heads/master
Commit: 8d54bf79f215378fbd95794591a87604a5eaf7a3
Parents: ab2eafb
Author: Maxim Gekk <max.g...@gmail.com>
Authored: Thu Nov 22 10:57:19 2018 +0800
Committer: hyukjinkwon <gurwls...@apache.org>
Committed: Thu Nov 22 10:57:19 2018 +0800

----------------------------------------------------------------------
 .../sql/catalyst/expressions/ExprUtils.scala    | 16 +++++++++++++++
 .../catalyst/expressions/csvExpressions.scala   |  4 ++++
 .../catalyst/expressions/jsonExpressions.scala  |  1 +
 .../expressions/CsvExpressionsSuite.scala       | 11 ++++++++++
 .../expressions/JsonExpressionsSuite.scala      | 11 ++++++++++
 .../org/apache/spark/sql/DataFrameReader.scala  | 21 +++-----------------
 .../datasources/csv/CSVFileFormat.scala         |  9 ++-------
 .../datasources/json/JsonFileFormat.scala       | 11 +++-------
 8 files changed, 51 insertions(+), 33 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/8d54bf79/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExprUtils.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExprUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExprUtils.scala
index 040b56c..89e9071 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExprUtils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExprUtils.scala
@@ -67,4 +67,20 @@ object ExprUtils {
     case _ =>
       throw new AnalysisException("Must use a map() function for options")
   }
+
+  /**
+   * A convenient function for schema validation in datasources supporting
+   * `columnNameOfCorruptRecord` as an option.
+   */
+  def verifyColumnNameOfCorruptRecord(
+      schema: StructType,
+      columnNameOfCorruptRecord: String): Unit = {
+    schema.getFieldIndex(columnNameOfCorruptRecord).foreach { corruptFieldIndex =>
+      val f = schema(corruptFieldIndex)
+      if (f.dataType != StringType || !f.nullable) {
+        throw new AnalysisException(
+          "The field for corrupt records must be string type and nullable")
+      }
+    }
+  }
 }
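
For reference, a hedged sketch of the new helper in isolation (the schemas are illustrative; `StructType.fromDDL` produces nullable fields by default, matching the new tests):

```scala
import org.apache.spark.sql.catalyst.expressions.ExprUtils
import org.apache.spark.sql.types.StructType

// Passes: the corrupt column is a nullable string.
ExprUtils.verifyColumnNameOfCorruptRecord(
  StructType.fromDDL("i INT, _corrupt STRING"), "_corrupt")

// Throws AnalysisException: "The field for corrupt records must be string type and nullable".
ExprUtils.verifyColumnNameOfCorruptRecord(
  StructType.fromDDL("i INT, _corrupt BOOLEAN"), "_corrupt")

// No-op: getFieldIndex returns None when the corrupt column is absent.
ExprUtils.verifyColumnNameOfCorruptRecord(
  StructType.fromDDL("i INT"), "_corrupt")
```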

http://git-wip-us.apache.org/repos/asf/spark/blob/8d54bf79/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/csvExpressions.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/csvExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/csvExpressions.scala
index aff372b..1e4e1c6 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/csvExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/csvExpressions.scala
@@ -106,6 +106,10 @@ case class CsvToStructs(
       throw new AnalysisException(s"from_csv() doesn't support the ${mode.name} mode. " +
         s"Acceptable modes are ${PermissiveMode.name} and ${FailFastMode.name}.")
     }
+    ExprUtils.verifyColumnNameOfCorruptRecord(
+      nullableSchema,
+      parsedOptions.columnNameOfCorruptRecord)
+
     val actualSchema =
       StructType(nullableSchema.filterNot(_.name == parsedOptions.columnNameOfCorruptRecord))
     val rawParser = new UnivocityParser(actualSchema, actualSchema, parsedOptions)

http://git-wip-us.apache.org/repos/asf/spark/blob/8d54bf79/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala
index 543c6c4..47304d8 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala
@@ -579,6 +579,7 @@ case class JsonToStructs(
     }
     val (parserSchema, actualSchema) = nullableSchema match {
       case s: StructType =>
+        ExprUtils.verifyColumnNameOfCorruptRecord(s, parsedOptions.columnNameOfCorruptRecord)
         (s, StructType(s.filterNot(_.name == parsedOptions.columnNameOfCorruptRecord)))
       case other =>
         (StructType(StructField("value", other) :: Nil), other)

http://git-wip-us.apache.org/repos/asf/spark/blob/8d54bf79/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CsvExpressionsSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CsvExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CsvExpressionsSuite.scala
index f5aaaec..98c93a4 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CsvExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CsvExpressionsSuite.scala
@@ -23,6 +23,7 @@ import java.util.{Calendar, Locale}
 import org.scalatest.exceptions.TestFailedException
 
 import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.plans.PlanTestBase
 import org.apache.spark.sql.catalyst.util._
@@ -226,4 +227,14 @@ class CsvExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper with P
         InternalRow(17836)) // number of days from 1970-01-01
     }
   }
+
+  test("verify corrupt column") {
+    checkExceptionInExpression[AnalysisException](
+      CsvToStructs(
+        schema = StructType.fromDDL("i int, _unparsed boolean"),
+        options = Map("columnNameOfCorruptRecord" -> "_unparsed"),
+        child = Literal.create("a"),
+        timeZoneId = gmtId),
+      expectedErrMsg = "The field for corrupt records must be string type and nullable")
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/8d54bf79/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala
index 34bd2a9..9b89a27 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala
@@ -23,6 +23,7 @@ import java.util.{Calendar, Locale}
 import org.scalatest.exceptions.TestFailedException
 
 import org.apache.spark.{SparkException, SparkFunSuite}
+import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.errors.TreeNodeException
 import org.apache.spark.sql.catalyst.plans.PlanTestBase
@@ -754,4 +755,14 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper with
         InternalRow(17836)) // number of days from 1970-01-01
     }
   }
+
+  test("verify corrupt column") {
+    checkExceptionInExpression[AnalysisException](
+      JsonToStructs(
+        schema = StructType.fromDDL("i int, _unparsed boolean"),
+        options = Map("columnNameOfCorruptRecord" -> "_unparsed"),
+        child = Literal.create("""{"i":"a"}"""),
+        timeZoneId = gmtId),
+      expectedErrMsg = "The field for corrupt records must be string type and nullable")
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/8d54bf79/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
index 52df13d..f08fd64 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
@@ -29,6 +29,7 @@ import org.apache.spark.api.java.JavaRDD
 import org.apache.spark.internal.Logging
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.csv.{CSVHeaderChecker, CSVOptions, UnivocityParser}
+import org.apache.spark.sql.catalyst.expressions.ExprUtils
 import org.apache.spark.sql.catalyst.json.{CreateJacksonParser, JacksonParser, JSONOptions}
 import org.apache.spark.sql.catalyst.util.FailureSafeParser
 import org.apache.spark.sql.execution.command.DDLUtils
@@ -442,7 +443,7 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
       TextInputJsonDataSource.inferFromDataset(jsonDataset, parsedOptions)
     }
 
-    verifyColumnNameOfCorruptRecord(schema, parsedOptions.columnNameOfCorruptRecord)
+    ExprUtils.verifyColumnNameOfCorruptRecord(schema, parsedOptions.columnNameOfCorruptRecord)
     val actualSchema =
       StructType(schema.filterNot(_.name == parsedOptions.columnNameOfCorruptRecord))
 
@@ -504,7 +505,7 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
         parsedOptions)
     }
 
-    verifyColumnNameOfCorruptRecord(schema, parsedOptions.columnNameOfCorruptRecord)
+    ExprUtils.verifyColumnNameOfCorruptRecord(schema, parsedOptions.columnNameOfCorruptRecord)
     val actualSchema =
       StructType(schema.filterNot(_.name == parsedOptions.columnNameOfCorruptRecord))
 
@@ -765,22 +766,6 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
     }
   }
 
-  /**
-   * A convenient function for schema validation in datasources supporting
-   * `columnNameOfCorruptRecord` as an option.
-   */
-  private def verifyColumnNameOfCorruptRecord(
-      schema: StructType,
-      columnNameOfCorruptRecord: String): Unit = {
-    schema.getFieldIndex(columnNameOfCorruptRecord).foreach { corruptFieldIndex =>
-      val f = schema(corruptFieldIndex)
-      if (f.dataType != StringType || !f.nullable) {
-        throw new AnalysisException(
-          "The field for corrupt records must be string type and nullable")
-      }
-    }
-  }
-
   ///////////////////////////////////////////////////////////////////////////////////////
   // Builder pattern config options
   ///////////////////////////////////////////////////////////////////////////////////////

http://git-wip-us.apache.org/repos/asf/spark/blob/8d54bf79/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala
index 964b56e..ff1911d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala
@@ -27,6 +27,7 @@ import org.apache.spark.internal.Logging
 import org.apache.spark.sql.{AnalysisException, SparkSession}
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.csv.{CSVHeaderChecker, CSVOptions, UnivocityGenerator, UnivocityParser}
+import org.apache.spark.sql.catalyst.expressions.ExprUtils
 import org.apache.spark.sql.catalyst.util.CompressionCodecs
 import org.apache.spark.sql.execution.datasources._
 import org.apache.spark.sql.sources._
@@ -110,13 +111,7 @@ class CSVFileFormat extends TextBasedFileFormat with DataSourceRegister {
       sparkSession.sessionState.conf.columnNameOfCorruptRecord)
 
     // Check a field requirement for corrupt records here to throw an exception in a driver side
-    dataSchema.getFieldIndex(parsedOptions.columnNameOfCorruptRecord).foreach { corruptFieldIndex =>
-      val f = dataSchema(corruptFieldIndex)
-      if (f.dataType != StringType || !f.nullable) {
-        throw new AnalysisException(
-          "The field for corrupt records must be string type and nullable")
-      }
-    }
+    ExprUtils.verifyColumnNameOfCorruptRecord(dataSchema, parsedOptions.columnNameOfCorruptRecord)
 
     if (requiredSchema.length == 1 &&
       requiredSchema.head.name == parsedOptions.columnNameOfCorruptRecord) {

http://git-wip-us.apache.org/repos/asf/spark/blob/8d54bf79/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala
index 1f7c9d7..610f0d1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala
@@ -26,7 +26,8 @@ import org.apache.hadoop.mapreduce.{Job, TaskAttemptContext}
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.{AnalysisException, SparkSession}
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.json.{JacksonGenerator, JacksonParser, JSONOptions, JSONOptionsInRead}
+import org.apache.spark.sql.catalyst.expressions.ExprUtils
+import org.apache.spark.sql.catalyst.json._
 import org.apache.spark.sql.catalyst.util.CompressionCodecs
 import org.apache.spark.sql.execution.datasources._
 import org.apache.spark.sql.sources._
@@ -107,13 +108,7 @@ class JsonFileFormat extends TextBasedFileFormat with DataSourceRegister {
     val actualSchema =
       StructType(requiredSchema.filterNot(_.name == parsedOptions.columnNameOfCorruptRecord))
     // Check a field requirement for corrupt records here to throw an exception in a driver side
-    dataSchema.getFieldIndex(parsedOptions.columnNameOfCorruptRecord).foreach { corruptFieldIndex =>
-      val f = dataSchema(corruptFieldIndex)
-      if (f.dataType != StringType || !f.nullable) {
-        throw new AnalysisException(
-          "The field for corrupt records must be string type and nullable")
-      }
-    }
+    ExprUtils.verifyColumnNameOfCorruptRecord(dataSchema, parsedOptions.columnNameOfCorruptRecord)
 
     if (requiredSchema.length == 1 &&
       requiredSchema.head.name == parsedOptions.columnNameOfCorruptRecord) {

