This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 1ed1b4d  [SPARK-26637][SQL] Makes GetArrayItem nullability more precise
1ed1b4d is described below

commit 1ed1b4d8e1a5b9ca0ec8b15f36542d7a63eebf94
Author: Takeshi Yamamuro <yamam...@apache.org>
AuthorDate: Wed Jan 23 15:33:02 2019 +0800

    [SPARK-26637][SQL] Makes GetArrayItem nullability more precise
    
    ## What changes were proposed in this pull request?
    On master, `GetArrayItem.nullable` is always true:
    
https://github.com/apache/spark/blob/cf133e611020ed178f90358464a1b88cdd9b7889/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala#L236
    
    However, if the input array size is constant and the ordinal is foldable, we can make 
`GetArrayItem` nullability more precise (for example, element 0 of an array built from a 
non-nullable attribute is itself non-nullable). This PR adds code to do so.
    
    ## How was this patch tested?
    Added tests in `ComplexTypeSuite`.
    
    Closes #23566 from maropu/GetArrayItemNullability.
    
    Authored-by: Takeshi Yamamuro <yamam...@apache.org>
    Signed-off-by: Wenchen Fan <wenc...@databricks.com>
---
 .../expressions/complexTypeExtractors.scala        | 15 +++++++++-
 .../catalyst/expressions/ComplexTypeSuite.scala    | 33 ++++++++++++++++++++++
 2 files changed, 47 insertions(+), 1 deletion(-)

diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala
index 8994eef..104ad98 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala
@@ -233,7 +233,20 @@ case class GetArrayItem(child: Expression, ordinal: 
Expression)
   override def right: Expression = ordinal
 
   /** `Null` is returned for invalid ordinals. */
-  override def nullable: Boolean = true
+  override def nullable: Boolean = if (ordinal.foldable && !ordinal.nullable) {
+    val intOrdinal = ordinal.eval().asInstanceOf[Number].intValue() // constant ordinal
+    child match {
+      case CreateArray(ar) if intOrdinal >= 0 && intOrdinal < ar.length =>
+        ar(intOrdinal).nullable // nullability of the selected element
+      case GetArrayStructFields(CreateArray(elements), field, _, _, _)
+          if intOrdinal >= 0 && intOrdinal < elements.length =>
+        elements(intOrdinal).nullable || field.nullable // null struct or null field
+      case _ => // negative or too-large ordinal, or a child that is not matched above
+        true
+    }
+  } else {
+    true // ordinal is not a non-null constant
+  }
 
   override def dataType: DataType = 
child.dataType.asInstanceOf[ArrayType].elementType
 
diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala
 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala
index dc60464..d8d6571 100644
--- 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala
+++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala
@@ -59,6 +59,39 @@ class ComplexTypeSuite extends SparkFunSuite with 
ExpressionEvalHelper {
     checkEvaluation(GetArrayItem(nestedArray, Literal(0)), Seq(1))
   }
 
+  test("SPARK-26637 handles GetArrayItem nullability correctly when input 
array size is constant") {
+    // CreateArray case: result nullability follows the element picked by the ordinal
+    val a = AttributeReference("a", IntegerType, nullable = false)()
+    val b = AttributeReference("b", IntegerType, nullable = true)()
+    val array = CreateArray(a :: b :: Nil)
+    assert(!GetArrayItem(array, Literal(0)).nullable) // element 0 is `a`, non-nullable
+    assert(GetArrayItem(array, Literal(1)).nullable) // element 1 is `b`, nullable
+    assert(!GetArrayItem(array, Subtract(Literal(2), Literal(2))).nullable) // 2 - 2 = 0
+    assert(GetArrayItem(array, AttributeReference("ordinal", 
IntegerType)()).nullable) // ordinal is an attribute, not a constant
+
+    // GetArrayStructFields case: nullable if the picked struct element or the field is nullable
+    val f1 = StructField("a", IntegerType, nullable = false)
+    val f2 = StructField("b", IntegerType, nullable = true)
+    val structType = StructType(f1 :: f2 :: Nil)
+    val c = AttributeReference("c", structType, nullable = false)()
+    val inputArray1 = CreateArray(c :: Nil)
+    val inputArray1ContainsNull = c.nullable
+    val stArray1 = GetArrayStructFields(inputArray1, f1, 0, 2, 
inputArray1ContainsNull)
+    assert(!GetArrayItem(stArray1, Literal(0)).nullable) // `c` and field `a` are non-nullable
+    val stArray2 = GetArrayStructFields(inputArray1, f2, 1, 2, 
inputArray1ContainsNull)
+    assert(GetArrayItem(stArray2, Literal(0)).nullable) // field `b` is nullable
+
+    val d = AttributeReference("d", structType, nullable = true)()
+    val inputArray2 = CreateArray(c :: d :: Nil)
+    val inputArray2ContainsNull = c.nullable || d.nullable
+    val stArray3 = GetArrayStructFields(inputArray2, f1, 0, 2, 
inputArray2ContainsNull)
+    assert(!GetArrayItem(stArray3, Literal(0)).nullable) // element 0 is `c`, field is `a`
+    assert(GetArrayItem(stArray3, Literal(1)).nullable) // element 1 is `d`, nullable
+    val stArray4 = GetArrayStructFields(inputArray2, f2, 1, 2, 
inputArray2ContainsNull)
+    assert(GetArrayItem(stArray4, Literal(0)).nullable) // field `b` is nullable
+    assert(GetArrayItem(stArray4, Literal(1)).nullable) // field `b` is nullable
+  }
+
   test("GetMapValue") {
     val typeM = MapType(StringType, StringType)
     val map = Literal.create(Map("a" -> "b"), typeM)


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to