Repository: spark Updated Branches: refs/heads/branch-2.0 739a333f6 -> fd828e14b
[SPARK-16409][SQL] regexp_extract with optional groups causes NPE ## What changes were proposed in this pull request? regexp_extract incorrectly returns null when the regex matches but the requested optional group does not participate in the match. This patch makes it return an empty string in that case, as apparently designed. ## How was this patch tested? Additional unit test Author: Sean Owen <so...@cloudera.com> Closes #14504 from srowen/SPARK-16409. (cherry picked from commit 8d8725208771a8815a60160a5a30dc6ea87a7e6a) Signed-off-by: Sean Owen <so...@cloudera.com> Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/fd828e14 Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/fd828e14 Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/fd828e14 Branch: refs/heads/branch-2.0 Commit: fd828e14ba968122c22a52c87fc3979dce853157 Parents: 739a333 Author: Sean Owen <so...@cloudera.com> Authored: Sun Aug 7 12:20:07 2016 +0100 Committer: Sean Owen <so...@cloudera.com> Committed: Sun Aug 7 12:20:16 2016 +0100 ---------------------------------------------------------------------- python/pyspark/sql/functions.py | 3 +++ .../sql/catalyst/expressions/regexpExpressions.scala | 13 +++++++++++-- .../org/apache/spark/sql/StringFunctionsSuite.scala | 8 ++++++++ 3 files changed, 22 insertions(+), 2 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/fd828e14/python/pyspark/sql/functions.py ---------------------------------------------------------------------- diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index e422363..8a01805 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -1445,6 +1445,9 @@ def regexp_extract(str, pattern, idx): >>> df = spark.createDataFrame([('100-200',)], ['str']) >>> df.select(regexp_extract('str', '(\d+)-(\d+)', 1).alias('d')).collect() [Row(d=u'100')] + >>> df = 
spark.createDataFrame([('aaaac',)], ['str']) + >>> df.select(regexp_extract('str', '(a+)(b)?(c)', 2).alias('d')).collect() + [Row(d=u'')] """ sc = SparkContext._active_spark_context jc = sc._jvm.functions.regexp_extract(_to_java_column(str), pattern, idx) http://git-wip-us.apache.org/repos/asf/spark/blob/fd828e14/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala index be82b3b..d25da3f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala @@ -329,7 +329,12 @@ case class RegExpExtract(subject: Expression, regexp: Expression, idx: Expressio val m = pattern.matcher(s.toString) if (m.find) { val mr: MatchResult = m.toMatchResult - UTF8String.fromString(mr.group(r.asInstanceOf[Int])) + val group = mr.group(r.asInstanceOf[Int]) + if (group == null) { // Pattern matched, but not optional group + UTF8String.EMPTY_UTF8 + } else { + UTF8String.fromString(group) + } } else { UTF8String.EMPTY_UTF8 } @@ -367,7 +372,11 @@ case class RegExpExtract(subject: Expression, regexp: Expression, idx: Expressio ${termPattern}.matcher($subject.toString()); if (${matcher}.find()) { java.util.regex.MatchResult ${matchResult} = ${matcher}.toMatchResult(); - ${ev.value} = UTF8String.fromString(${matchResult}.group($idx)); + if (${matchResult}.group($idx) == null) { + ${ev.value} = UTF8String.EMPTY_UTF8; + } else { + ${ev.value} = UTF8String.fromString(${matchResult}.group($idx)); + } $setEvNotNull } else { ${ev.value} = UTF8String.EMPTY_UTF8; 
http://git-wip-us.apache.org/repos/asf/spark/blob/fd828e14/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala index 7bea2f6..73bd7e6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala @@ -92,6 +92,14 @@ class StringFunctionsSuite extends QueryTest with SharedSQLContext { Row("300", "100") :: Row("400", "100") :: Row("400-400", "100") :: Nil) } + test("non-matching optional group") { + val df = Seq("aaaac").toDF("s") + checkAnswer( + df.select(regexp_extract($"s", "(a+)(b)?(c)", 2)), + Row("") + ) + } + test("string ascii function") { val df = Seq(("abc", "")).toDF("a", "b") checkAnswer( --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org