This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch branch-3.1
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.1 by this push:
     new cf21e88  [SPARK-34233][SQL] FIX NPE for char padding in binary comparison
cf21e88 is described below

commit cf21e8898ab484a833b6696d0cf4bb0c871e7ff6
Author: Kent Yao <y...@apache.org>
AuthorDate: Wed Jan 27 14:59:53 2021 +0800

    [SPARK-34233][SQL] FIX NPE for char padding in binary comparison
    
    ### What changes were proposed in this pull request?
    
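    We need to check whether the evaluated literal is null before calling
    `numChars()` on it, both when padding `IN` lists and when padding binary
    comparisons against a CHAR column in `ApplyCharTypePadding`.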
    
    ### Why are the changes needed?
    
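    `ApplyCharTypePadding` throws a `NullPointerException` when a CHAR column is
    compared with a null literal, e.g. `c = null` or `c IN ('e', null)`.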
    
    ### Does this PR introduce _any_ user-facing change?
    
    no
    
    ### How was this patch tested?
    
    New tests in `CharVarcharTestSuite`.
    
    Closes #31336 from yaooqinn/SPARK-34233.
    
    Authored-by: Kent Yao <y...@apache.org>
    Signed-off-by: Wenchen Fan <wenc...@databricks.com>
    (cherry picked from commit 764582c07a263ae0bef4a080a84a66be60d1aab9)
    Signed-off-by: Wenchen Fan <wenc...@databricks.com>
---
 .../spark/sql/catalyst/analysis/Analyzer.scala     | 22 +++++++----
 .../apache/spark/sql/CharVarcharTestSuite.scala    | 43 +++++++++++++++++++++-
 2 files changed, 55 insertions(+), 10 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index fb95323a..6fd6901 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -3888,13 +3888,15 @@ object ApplyCharTypePadding extends Rule[LogicalPlan] {
           if attr.dataType == StringType && list.forall(_.foldable) =>
           CharVarcharUtils.getRawType(attr.metadata).flatMap {
             case CharType(length) =>
-              val literalCharLengths = list.map(_.eval().asInstanceOf[UTF8String].numChars())
+              val (nulls, nonNulls) = list.map(e => (e, e.eval())).partition(_._2 == null)
+              val literalCharLengths = nonNulls.map(_._2.asInstanceOf[UTF8String].numChars())
               val targetLen = (length +: literalCharLengths).max
               Some(i.copy(
                 value = addPadding(attr, length, targetLen),
-                list = list.zip(literalCharLengths).map {
-                  case (lit, charLength) => addPadding(lit, charLength, targetLen)
-                }))
+                // Zip only non-null items; zipping all of `list` misaligns pairs after a null.
+                list = nonNulls.map(_._1).zip(literalCharLengths).map {
+                  case (lit, charLength) => addPadding(lit, charLength, targetLen)
+                } ++ nulls.map(_ => Literal.create(null, StringType))))
             case _ => None
           }.getOrElse(i)
 
@@ -3915,13 +3917,17 @@ object ApplyCharTypePadding extends Rule[LogicalPlan] {
       CharVarcharUtils.getRawType(attr.metadata).flatMap {
         case CharType(length) =>
           val str = lit.eval().asInstanceOf[UTF8String]
-          val stringLitLen = str.numChars()
-          if (length < stringLitLen) {
-            Some(Seq(StringRPad(attr, Literal(stringLitLen)), lit))
-          } else if (length > stringLitLen) {
-            Some(Seq(attr, StringRPad(lit, Literal(length))))
-          } else {
+          if (str == null) {
             None
+          } else {
+            val stringLitLen = str.numChars()
+            if (length < stringLitLen) {
+              Some(Seq(StringRPad(attr, Literal(stringLitLen)), lit))
+            } else if (length > stringLitLen) {
+              Some(Seq(attr, StringRPad(lit, Literal(length))))
+            } else {
+              None
+            }
           }
         case _ => None
       }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CharVarcharTestSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CharVarcharTestSuite.scala
index ff8820a..744757b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/CharVarcharTestSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/CharVarcharTestSuite.scala
@@ -152,6 +152,22 @@ trait CharVarcharTestSuite extends QueryTest with SQLTestUtils {
     }
   }
 
+  test("SPARK-34233: char/varchar with null value for partitioned columns") {
+    Seq("CHAR(5)", "VARCHAR(5)").foreach { typ =>
+      withTable("t") {
+        sql(s"CREATE TABLE t(i STRING, c $typ) USING $format PARTITIONED BY (c)")
+        sql("INSERT INTO t VALUES ('1', null)")
+        checkPlainResult(spark.table("t"), typ, null)
+        sql("INSERT OVERWRITE t VALUES ('1', null)")
+        checkPlainResult(spark.table("t"), typ, null)
+        sql("INSERT OVERWRITE t PARTITION (c=null) VALUES ('1')")
+        checkPlainResult(spark.table("t"), typ, null)
+        sql("ALTER TABLE t DROP PARTITION(c=null)")
+        checkAnswer(spark.table("t"), Nil)
+      }
+    }
+  }
+
   test("char/varchar type values length check: partitioned columns of other types") {
     // DSV2 doesn't support DROP PARTITION yet.
     assume(format != "foo")
@@ -435,7 +451,8 @@ trait CharVarcharTestSuite extends QueryTest with SQLTestUtils {
         ("c1 IN ('a', 'b')", true),
         ("c1 = c2", true),
         ("c1 < c2", false),
-        ("c1 IN (c2)", true)))
+        ("c1 IN (c2)", true),
+        ("c1 <=> null", false)))
     }
   }
 
@@ -451,7 +468,29 @@ trait CharVarcharTestSuite extends QueryTest with SQLTestUtils {
         ("c1 IN ('a', 'b')", true),
         ("c1 = c2", true),
         ("c1 < c2", false),
-        ("c1 IN (c2)", true)))
+        ("c1 IN (c2)", true),
+        ("c1 <=> null", false)))
+    }
+  }
+
+  private def testNullConditions(df: DataFrame, conditions: Seq[String]): Unit = {
+    conditions.foreach { cond =>
+      checkAnswer(df.selectExpr(cond), Row(null))
+    }
+  }
+
+  test("SPARK-34233: char type comparison with null values") {
+    val conditions = Seq("c = null", "c IN ('e', null)", "c IN (null)")
+    withTable("t") {
+      sql(s"CREATE TABLE t(c CHAR(2)) USING $format")
+      sql("INSERT INTO t VALUES ('a')")
+      testNullConditions(spark.table("t"), conditions)
+    }
+
+    withTable("t") {
+      sql(s"CREATE TABLE t(i INT, c CHAR(2)) USING $format PARTITIONED BY (c)")
+      sql("INSERT INTO t VALUES (1, 'a')")
+      testNullConditions(spark.table("t"), conditions)
     }
   }
 


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to