Repository: spark
Updated Branches:
  refs/heads/branch-2.0 739a333f6 -> fd828e14b


[SPARK-16409][SQL] regexp_extract with optional groups causes NPE

## What changes were proposed in this pull request?

regexp_extract incorrectly returns null when the regex matches overall but the 
requested optional group did not participate in the match. This change makes it 
return an empty string in that case, as apparently designed.

## How was this patch tested?

Additional unit test

Author: Sean Owen <so...@cloudera.com>

Closes #14504 from srowen/SPARK-16409.

(cherry picked from commit 8d8725208771a8815a60160a5a30dc6ea87a7e6a)
Signed-off-by: Sean Owen <so...@cloudera.com>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/fd828e14
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/fd828e14
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/fd828e14

Branch: refs/heads/branch-2.0
Commit: fd828e14ba968122c22a52c87fc3979dce853157
Parents: 739a333
Author: Sean Owen <so...@cloudera.com>
Authored: Sun Aug 7 12:20:07 2016 +0100
Committer: Sean Owen <so...@cloudera.com>
Committed: Sun Aug 7 12:20:16 2016 +0100

----------------------------------------------------------------------
 python/pyspark/sql/functions.py                        |  3 +++
 .../sql/catalyst/expressions/regexpExpressions.scala   | 13 +++++++++++--
 .../org/apache/spark/sql/StringFunctionsSuite.scala    |  8 ++++++++
 3 files changed, 22 insertions(+), 2 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/fd828e14/python/pyspark/sql/functions.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index e422363..8a01805 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -1445,6 +1445,9 @@ def regexp_extract(str, pattern, idx):
     >>> df = spark.createDataFrame([('100-200',)], ['str'])
     >>> df.select(regexp_extract('str', '(\d+)-(\d+)', 1).alias('d')).collect()
     [Row(d=u'100')]
+    >>> df = spark.createDataFrame([('aaaac',)], ['str'])
+    >>> df.select(regexp_extract('str', '(a+)(b)?(c)', 2).alias('d')).collect()
+    [Row(d=u'')]
     """
     sc = SparkContext._active_spark_context
     jc = sc._jvm.functions.regexp_extract(_to_java_column(str), pattern, idx)

http://git-wip-us.apache.org/repos/asf/spark/blob/fd828e14/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala
index be82b3b..d25da3f 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala
@@ -329,7 +329,12 @@ case class RegExpExtract(subject: Expression, regexp: 
Expression, idx: Expressio
     val m = pattern.matcher(s.toString)
     if (m.find) {
       val mr: MatchResult = m.toMatchResult
-      UTF8String.fromString(mr.group(r.asInstanceOf[Int]))
+      val group = mr.group(r.asInstanceOf[Int])
+      if (group == null) { // Pattern matched, but not optional group
+        UTF8String.EMPTY_UTF8
+      } else {
+        UTF8String.fromString(group)
+      }
     } else {
       UTF8String.EMPTY_UTF8
     }
@@ -367,7 +372,11 @@ case class RegExpExtract(subject: Expression, regexp: 
Expression, idx: Expressio
         ${termPattern}.matcher($subject.toString());
       if (${matcher}.find()) {
         java.util.regex.MatchResult ${matchResult} = 
${matcher}.toMatchResult();
-        ${ev.value} = UTF8String.fromString(${matchResult}.group($idx));
+        if (${matchResult}.group($idx) == null) {
+          ${ev.value} = UTF8String.EMPTY_UTF8;
+        } else {
+          ${ev.value} = UTF8String.fromString(${matchResult}.group($idx));
+        }
         $setEvNotNull
       } else {
         ${ev.value} = UTF8String.EMPTY_UTF8;

http://git-wip-us.apache.org/repos/asf/spark/blob/fd828e14/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala
index 7bea2f6..73bd7e6 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala
@@ -92,6 +92,14 @@ class StringFunctionsSuite extends QueryTest with 
SharedSQLContext {
       Row("300", "100") :: Row("400", "100") :: Row("400-400", "100") :: Nil)
   }
 
+  test("non-matching optional group") {
+    val df = Seq("aaaac").toDF("s")
+    checkAnswer(
+      df.select(regexp_extract($"s", "(a+)(b)?(c)", 2)),
+      Row("")
+    )
+  }
+
   test("string ascii function") {
     val df = Seq(("abc", "")).toDF("a", "b")
     checkAnswer(


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to