This is an automated email from the ASF dual-hosted git repository. wenchen pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new 1ed1b4d [SPARK-26637][SQL] Makes GetArrayItem nullability more precise 1ed1b4d is described below commit 1ed1b4d8e1a5b9ca0ec8b15f36542d7a63eebf94 Author: Takeshi Yamamuro <yamam...@apache.org> AuthorDate: Wed Jan 23 15:33:02 2019 +0800 [SPARK-26637][SQL] Makes GetArrayItem nullability more precise ## What changes were proposed in this pull request? In master, the `nullable` flag of GetArrayItem is always true: https://github.com/apache/spark/blob/cf133e611020ed178f90358464a1b88cdd9b7889/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala#L236 However, if the input array size is constant and the ordinal is foldable, we can make GetArrayItem nullability more precise. This PR adds code to make `GetArrayItem` nullability more precise. ## How was this patch tested? Added tests in `ComplexTypeSuite`. Closes #23566 from maropu/GetArrayItemNullability. Authored-by: Takeshi Yamamuro <yamam...@apache.org> Signed-off-by: Wenchen Fan <wenc...@databricks.com> --- .../expressions/complexTypeExtractors.scala | 15 +++++++++- .../catalyst/expressions/ComplexTypeSuite.scala | 33 ++++++++++++++++++++++ 2 files changed, 47 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala index 8994eef..104ad98 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala @@ -233,7 +233,20 @@ case class GetArrayItem(child: Expression, ordinal: Expression) override def right: Expression = ordinal /** `Null` is returned for invalid ordinals. 
*/ - override def nullable: Boolean = true + override def nullable: Boolean = if (ordinal.foldable && !ordinal.nullable) { + val intOrdinal = ordinal.eval().asInstanceOf[Number].intValue() + child match { + case CreateArray(ar) if intOrdinal < ar.length => + ar(intOrdinal).nullable + case GetArrayStructFields(CreateArray(elements), field, _, _, _) + if intOrdinal < elements.length => + elements(intOrdinal).nullable || field.nullable + case _ => + true + } + } else { + true + } override def dataType: DataType = child.dataType.asInstanceOf[ArrayType].elementType diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala index dc60464..d8d6571 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala @@ -59,6 +59,39 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(GetArrayItem(nestedArray, Literal(0)), Seq(1)) } + test("SPARK-26637 handles GetArrayItem nullability correctly when input array size is constant") { + // CreateArray case + val a = AttributeReference("a", IntegerType, nullable = false)() + val b = AttributeReference("b", IntegerType, nullable = true)() + val array = CreateArray(a :: b :: Nil) + assert(!GetArrayItem(array, Literal(0)).nullable) + assert(GetArrayItem(array, Literal(1)).nullable) + assert(!GetArrayItem(array, Subtract(Literal(2), Literal(2))).nullable) + assert(GetArrayItem(array, AttributeReference("ordinal", IntegerType)()).nullable) + + // GetArrayStructFields case + val f1 = StructField("a", IntegerType, nullable = false) + val f2 = StructField("b", IntegerType, nullable = true) + val structType = StructType(f1 :: f2 :: Nil) + val c = AttributeReference("c", structType, nullable = false)() + val inputArray1 = 
CreateArray(c :: Nil) + val inputArray1ContainsNull = c.nullable + val stArray1 = GetArrayStructFields(inputArray1, f1, 0, 2, inputArray1ContainsNull) + assert(!GetArrayItem(stArray1, Literal(0)).nullable) + val stArray2 = GetArrayStructFields(inputArray1, f2, 1, 2, inputArray1ContainsNull) + assert(GetArrayItem(stArray2, Literal(0)).nullable) + + val d = AttributeReference("d", structType, nullable = true)() + val inputArray2 = CreateArray(c :: d :: Nil) + val inputArray2ContainsNull = c.nullable || d.nullable + val stArray3 = GetArrayStructFields(inputArray2, f1, 0, 2, inputArray2ContainsNull) + assert(!GetArrayItem(stArray3, Literal(0)).nullable) + assert(GetArrayItem(stArray3, Literal(1)).nullable) + val stArray4 = GetArrayStructFields(inputArray2, f2, 1, 2, inputArray2ContainsNull) + assert(GetArrayItem(stArray4, Literal(0)).nullable) + assert(GetArrayItem(stArray4, Literal(1)).nullable) + } + test("GetMapValue") { val typeM = MapType(StringType, StringType) val map = Literal.create(Map("a" -> "b"), typeM) --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org