Repository: spark
Updated Branches:
  refs/heads/master 30ffb53ca -> 0a9172a05
[SPARK-23835][SQL] Add not-null check to Tuples' arguments deserialization

## What changes were proposed in this pull request?

There was no nullability check on the arguments of `Tuple`s. This could lead to surprising behavior when a null value had to be deserialized into a non-nullable Scala object: in those cases, the `null` was silently transformed into a valid value (such as `-1` for `Int`), corresponding to the default value used in the SQL codebase. This situation was very likely to happen when deserializing to a Tuple of primitive Scala types (like Double, Int, ...).

The PR adds `AssertNotNull` to tuple arguments that are converted to non-nullable types, so a null value now fails fast instead of being silently replaced.

## How was this patch tested?

Added unit tests.

Author: Marco Gaido <marcogaid...@gmail.com>

Closes #20976 from mgaido91/SPARK-23835.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/0a9172a0
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/0a9172a0
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/0a9172a0

Branch: refs/heads/master
Commit: 0a9172a05e604a4a94adbb9208c8c02362afca00
Parents: 30ffb53
Author: Marco Gaido <marcogaid...@gmail.com>
Authored: Tue Apr 17 21:45:20 2018 +0800
Committer: Wenchen Fan <wenc...@databricks.com>
Committed: Tue Apr 17 21:45:20 2018 +0800

----------------------------------------------------------------------
 .../spark/sql/kafka010/KafkaContinuousSinkSuite.scala |  6 +++---
 .../apache/spark/sql/kafka010/KafkaSinkSuite.scala    |  2 +-
 .../apache/spark/sql/catalyst/ScalaReflection.scala   | 14 +++++++-------
 .../spark/sql/catalyst/ScalaReflectionSuite.scala     | 12 +++++++++++-
 .../scala/org/apache/spark/sql/DatasetSuite.scala     |  5 +++++
 5 files changed, 27 insertions(+), 12 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/0a9172a0/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSinkSuite.scala
----------------------------------------------------------------------
diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSinkSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSinkSuite.scala
index fc890a0..ddfc0c1 100644
--- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSinkSuite.scala
+++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSinkSuite.scala
@@ -79,7 +79,7 @@ class KafkaContinuousSinkSuite extends KafkaContinuousTest {
     val reader = createKafkaReader(topic)
       .selectExpr("CAST(key as STRING) key", "CAST(value as STRING) value")
       .selectExpr("CAST(key as INT) key", "CAST(value as INT) value")
-      .as[(Int, Int)]
+      .as[(Option[Int], Int)]
       .map(_._2)

     try {
@@ -119,7 +119,7 @@ class KafkaContinuousSinkSuite extends KafkaContinuousTest {
     val reader = createKafkaReader(topic)
       .selectExpr("CAST(key as STRING) key", "CAST(value as STRING) value")
       .selectExpr("CAST(key as INT) key", "CAST(value as INT) value")
-      .as[(Int, Int)]
+      .as[(Option[Int], Int)]
       .map(_._2)

     try {
@@ -167,7 +167,7 @@ class KafkaContinuousSinkSuite extends KafkaContinuousTest {
     val reader = createKafkaReader(topic)
       .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)")
       .selectExpr("CAST(key AS INT)", "CAST(value AS INT)")
-      .as[(Int, Int)]
+      .as[(Option[Int], Int)]
       .map(_._2)

     try {
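The hunks above, and the KafkaSinkSuite hunk below, all make the same change: the casted key column can be null at runtime, so once this patch adds the null check it can no longer be decoded into a primitive `Int`. A minimal sketch of the difference in decoding behavior (illustrative only, not the suites' code; assumes a SparkSession named `spark` with its implicits imported):

    import spark.implicits._

    // A frame whose first column is nullable, like the casted Kafka key:
    val df = Seq((None: Option[Int], 1), (Some(0), 2)).toDF("key", "value")

    df.as[(Option[Int], Int)].map(_._2).collect()  // Array(1, 2): null key decodes to None
    // df.as[(Int, Int)].map(_._2).collect()       // with this patch: NullPointerException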
http://git-wip-us.apache.org/repos/asf/spark/blob/0a9172a0/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSinkSuite.scala
----------------------------------------------------------------------
diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSinkSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSinkSuite.scala
index 42f8b4c..7079ac6 100644
--- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSinkSuite.scala
+++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSinkSuite.scala
@@ -138,7 +138,7 @@ class KafkaSinkSuite extends StreamTest with SharedSQLContext {
     val reader = createKafkaReader(topic)
       .selectExpr("CAST(key as STRING) key", "CAST(value as STRING) value")
       .selectExpr("CAST(key as INT) key", "CAST(value as INT) value")
-      .as[(Int, Int)]
+      .as[(Option[Int], Int)]
       .map(_._2)

     try {
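The core change in `ScalaReflection.deserializerFor` comes next. Previously, only the non-tuple branch bound its result to `constructor` and wrapped it in `AssertNotNull`; the tuple branch returned its deserializer unchecked. The diff hoists the branch result into `constructor` and applies the check to both paths. For readability, the post-change logic reads as follows (excerpted and re-indented from the diff below; `cls`, `fieldType`, `i`, `fieldName`, `dataType`, `newTypePath`, and `nullable` are provided by the surrounding method):

    val constructor = if (cls.getName startsWith "scala.Tuple") {
      deserializerFor(                                   // tuple fields: by ordinal
        fieldType,
        Some(addToPathOrdinal(i, dataType, newTypePath)),
        newTypePath)
    } else {
      deserializerFor(                                   // case-class fields: by name
        fieldType,
        Some(addToPath(fieldName, dataType, newTypePath)),
        newTypePath)
    }

    // The fix: the null check now wraps the result of either branch, so tuple
    // arguments targeting non-nullable types are guarded as well.
    if (!nullable) {
      AssertNotNull(constructor, newTypePath)
    } else {
      constructor
    }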
http://git-wip-us.apache.org/repos/asf/spark/blob/0a9172a0/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
index 1aae3ae..e4274aa 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
@@ -382,22 +382,22 @@ object ScalaReflection extends ScalaReflection {
           val clsName = getClassNameFromType(fieldType)
           val newTypePath = s"""- field (class: "$clsName", name: "$fieldName")""" +: walkedTypePath
           // For tuples, we grab the inner fields by ordinal instead of by name.
-          if (cls.getName startsWith "scala.Tuple") {
+          val constructor = if (cls.getName startsWith "scala.Tuple") {
             deserializerFor(
               fieldType,
               Some(addToPathOrdinal(i, dataType, newTypePath)),
               newTypePath)
           } else {
-            val constructor = deserializerFor(
+            deserializerFor(
               fieldType,
               Some(addToPath(fieldName, dataType, newTypePath)),
               newTypePath)
+          }

-            if (!nullable) {
-              AssertNotNull(constructor, newTypePath)
-            } else {
-              constructor
-            }
+          if (!nullable) {
+            AssertNotNull(constructor, newTypePath)
+          } else {
+            constructor
           }
         }


http://git-wip-us.apache.org/repos/asf/spark/blob/0a9172a0/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala
index 8c3db48..353b834 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala
@@ -21,7 +21,7 @@ import java.sql.{Date, Timestamp}

 import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
-import org.apache.spark.sql.catalyst.expressions.{BoundReference, Literal, SpecificInternalRow, UpCast}
+import org.apache.spark.sql.catalyst.expressions.{BoundReference, Expression, Literal, SpecificInternalRow, UpCast}
 import org.apache.spark.sql.catalyst.expressions.objects.{AssertNotNull, NewInstance}
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.UTF8String
@@ -365,4 +365,14 @@ class ScalaReflectionSuite extends SparkFunSuite {
         StructField("_2", NullType, nullable = true))),
       nullable = true))
   }
+
+  test("SPARK-23835: add null check to non-nullable types in Tuples") {
+    def numberOfCheckedArguments(deserializer: Expression): Int = {
+      assert(deserializer.isInstanceOf[NewInstance])
+      deserializer.asInstanceOf[NewInstance].arguments.count(_.isInstanceOf[AssertNotNull])
+    }
+    assert(numberOfCheckedArguments(deserializerFor[(Double, Double)]) == 2)
+    assert(numberOfCheckedArguments(deserializerFor[(java.lang.Double, Int)]) == 1)
+    assert(numberOfCheckedArguments(deserializerFor[(java.lang.Integer, java.lang.Integer)]) == 0)
+  }
 }


http://git-wip-us.apache.org/repos/asf/spark/blob/0a9172a0/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index 9b745be..e0f4d2b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -1453,6 +1453,11 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
     val group2 = cached.groupBy("x").agg(min(col("z")) as "value")
     checkAnswer(group1.union(group2), Row(4, 5) :: Row(1, 2) :: Row(4, 6) :: Row(1, 3) :: Nil)
   }
+
+  test("SPARK-23835: null primitive data type should throw NullPointerException") {
+    val ds = Seq[(Option[Int], Option[Int])]((Some(1), None)).toDS()
+    intercept[NullPointerException](ds.as[(Int, Int)].collect())
+  }
 }

 case class TestDataUnion(x: Int, y: Int, z: Int)
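As the new ScalaReflectionSuite test shows, the deserializer for a tuple is a `NewInstance` whose constructor arguments now carry `AssertNotNull` exactly where the target type is non-nullable: both slots for `(Double, Double)`, only the `Int` slot for `(java.lang.Double, Int)`, and neither slot for `(java.lang.Integer, java.lang.Integer)`, since boxed types may legitimately hold null. A standalone sketch of the same inspection (mirrors the suite's helper; `deserializerFor` is Catalyst-internal API and may change between versions):

    import org.apache.spark.sql.catalyst.ScalaReflection.deserializerFor
    import org.apache.spark.sql.catalyst.expressions.objects.{AssertNotNull, NewInstance}

    // Count how many tuple constructor arguments are guarded by AssertNotNull:
    def checkedArgs(d: org.apache.spark.sql.catalyst.expressions.Expression): Int =
      d.asInstanceOf[NewInstance].arguments.count(_.isInstanceOf[AssertNotNull])

    checkedArgs(deserializerFor[(Double, Double)])                        // 2
    checkedArgs(deserializerFor[(java.lang.Double, Int)])                 // 1
    checkedArgs(deserializerFor[(java.lang.Integer, java.lang.Integer)])  // 0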