This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch branch-3.3
in repository https://gitbox.apache.org/repos/asf/spark.git

The following commit(s) were added to refs/heads/branch-3.3 by this push:
     new 072968d  [SPARK-38063][SQL] Support split_part Function
072968d is described below

commit 072968d730863e89635c903999a397fc0233ea87
Author: Rui Wang <rui.w...@databricks.com>
AuthorDate: Thu Mar 24 14:28:32 2022 +0800

    [SPARK-38063][SQL] Support split_part Function

    ### What changes were proposed in this pull request?

    `split_part()` is a function commonly supported by other systems such as Postgres. The Spark equivalent is `element_at(split(arg, delim), part)`.

    ### Why are the changes needed?

    This adds a new SQL function.

    ### Does this PR introduce _any_ user-facing change?

    Yes. This PR adds a new function, so there is no previous behavior. The new function in more detail:

    Syntax: `split_part(str, delimiter, partNum)`

    The function splits `str` by `delimiter` and returns the requested part of the split (1-based). If any input is null, it returns null. If `partNum` is out of range of the split parts, it returns an empty string. If `partNum` is 0, it throws an ArrayIndexOutOfBoundsException. `str` and `delimiter` are of `string` type; `partNum` is of `integer` type.

    Examples:
    ```
    > SELECT _FUNC_('11.12.13', '.', 3);
     13
    > SELECT _FUNC_(NULL, '.', 3);
     NULL
    ```

    ### How was this patch tested?

    Unit tests.

    Closes #35352 from amaliujia/splitpart.

    Authored-by: Rui Wang <rui.w...@databricks.com>
    Signed-off-by: Wenchen Fan <wenc...@databricks.com>
    (cherry picked from commit 3858bf0fbd02e3d8fd18e967f3841c50b9294414)
    Signed-off-by: Wenchen Fan <wenc...@databricks.com>
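For context on the equivalence mentioned above: `split_part` takes its delimiter literally, while the `element_at(split(...))` idiom needs the delimiter escaped, since `split` treats it as a regex. A minimal spark-shell sketch in Scala (hypothetical; assumes a running SparkSession named `spark`, and is not part of this patch):

```
// The new function: the delimiter '.' is taken literally.
spark.sql("SELECT split_part('11.12.13', '.', 3)").collect()
// Array([13])

// The pre-existing idiom: split() treats its second argument as a regex,
// so a literal dot must be escaped (here with a character class, [.]).
spark.sql("SELECT element_at(split('11.12.13', '[.]'), 3)").collect()
// Array([13])
```

The two also diverge on out-of-range indexes: under ANSI mode `element_at` raises an error, while `split_part` always returns an empty string, as the golden files below show.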
---
 .../org/apache/spark/unsafe/types/UTF8String.java |ドキュメント 21 +++++-
 .../sql/catalyst/analysis/FunctionRegistry.scala  |  1 +
 .../expressions/collectionOperations.scala        | 22 ++++--
 .../catalyst/expressions/stringExpressions.scala  | 75 ++++++++++++++++++-
 .../sql-functions/sql-expression-schema.md        |  3 +-
 .../sql-tests/inputs/string-functions.sql         | 12 ++++
 .../results/ansi/string-functions.sql.out         | 83 +++++++++++++++++++++-
 .../sql-tests/results/string-functions.sql.out    | 83 +++++++++++++++++++++-
 8 files changed, 291 insertions(+), 9 deletions(-)

diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java
index 98c61cf..0f9d653 100644
--- a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java
+++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java
@@ -23,6 +23,7 @@ import java.nio.ByteBuffer;
 import java.nio.charset.StandardCharsets;
 import java.util.Arrays;
 import java.util.Map;
+import java.util.regex.Pattern;
 
 import com.esotericsoftware.kryo.Kryo;
 import com.esotericsoftware.kryo.KryoSerializable;
@@ -999,13 +1000,31 @@ public final class UTF8String implements Comparable<UTF8String>, Externalizable,
   }
 
   public UTF8String[] split(UTF8String pattern, int limit) {
+    return split(pattern.toString(), limit);
+  }
+
+  public UTF8String[] splitSQL(UTF8String delimiter, int limit) {
+    // If the delimiter is an empty string, skip the regex-based splitting entirely:
+    // an empty regex matches anything, so return the input directly instead.
+    if (delimiter.numBytes() == 0) {
+      return new UTF8String[]{this};
+    } else {
+      // We do not treat the delimiter as a regex; the whole delimiter string is the
+      // separator used to split the string. Java String's split, however, only accepts
+      // a regex as the pattern to split by, so we quote the delimiter to escape any
+      // special characters in it.
+      return split(Pattern.quote(delimiter.toString()), limit);
+    }
+  }
+
+  private UTF8String[] split(String delimiter, int limit) {
     // Java String's split method supports "ignore empty string" behavior when the limit is 0
     // whereas other languages do not. To avoid this java specific behavior, we fall back to
     // -1 when the limit is 0.
     if (limit == 0) {
       limit = -1;
     }
-    String[] splits = toString().split(pattern.toString(), limit);
+    String[] splits = toString().split(delimiter, limit);
     UTF8String[] res = new UTF8String[splits.length];
     for (int i = 0; i < res.length; i++) {
       res[i] = fromString(splits[i]);
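A note on the `Pattern.quote` call above, since it is what makes `splitSQL` delimiter-literal: quoting wraps the delimiter in `\Q...\E`, so regex metacharacters lose their special meaning. A standalone Scala sketch of the difference (illustrative only, not part of the patch):

```
import java.util.regex.Pattern

object QuoteSketch extends App {
  // Unquoted, "." is the regex wildcard: every character matches, so the
  // split yields only empty fields (limit -1 keeps trailing empties).
  println("11.12.13".split(".", -1).length)                       // 9

  // Pattern.quote(".") yields "\Q.\E", which matches a literal dot only.
  println("11.12.13".split(Pattern.quote("."), -1).mkString("|")) // 11|12|13
}
```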
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index a37d4b2..a06112a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -549,6 +549,7 @@ object FunctionRegistry {
     expression[SoundEx]("soundex"),
     expression[StringSpace]("space"),
     expression[StringSplit]("split"),
+    expression[SplitPart]("split_part"),
     expression[Substring]("substr", true),
     expression[Substring]("substring"),
     expression[Left]("left"),
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
index 363c531..ca00839 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
@@ -2095,10 +2095,12 @@ case class ArrayPosition(left: Expression, right: Expression)
 case class ElementAt(
     left: Expression,
     right: Expression,
+    // The value to return if the index is out of bound
+    defaultValueOutOfBound: Option[Literal] = None,
     failOnError: Boolean = SQLConf.get.ansiEnabled)
   extends GetMapValueUtil with GetArrayItemUtil with NullIntolerant {
 
-  def this(left: Expression, right: Expression) = this(left, right, SQLConf.get.ansiEnabled)
+  def this(left: Expression, right: Expression) = this(left, right, None, SQLConf.get.ansiEnabled)
 
   @transient private lazy val mapKeyType = left.dataType.asInstanceOf[MapType].keyType
@@ -2179,7 +2181,10 @@ case class ElementAt(
           if (failOnError) {
             throw QueryExecutionErrors.invalidElementAtIndexError(index, array.numElements())
           } else {
-            null
+            defaultValueOutOfBound match {
+              case Some(value) => value.eval()
+              case None => null
+            }
           }
         } else {
           val idx = if (index == 0) {
@@ -2218,7 +2223,16 @@ case class ElementAt(
         val indexOutOfBoundBranch = if (failOnError) {
           s"throw QueryExecutionErrors.invalidElementAtIndexError($index, $eval1.numElements());"
         } else {
-          s"${ev.isNull} = true;"
+          defaultValueOutOfBound match {
+            case Some(value) =>
+              val defaultValueEval = value.genCode(ctx)
+              s"""
+                ${defaultValueEval.code}
+                ${ev.isNull} = ${defaultValueEval.isNull};
+                ${ev.value} = ${defaultValueEval.value};
+              """.stripMargin
+            case None => s"${ev.isNull} = true;"
+          }
         }
 
         s"""
@@ -2278,7 +2292,7 @@ case class ElementAt(
 case class TryElementAt(left: Expression, right: Expression, replacement: Expression)
   extends RuntimeReplaceable with InheritAnalysisRules {
   def this(left: Expression, right: Expression) =
-    this(left, right, ElementAt(left, right, failOnError = false))
+    this(left, right, ElementAt(left, right, None, failOnError = false))
 
   override def prettyName: String = "try_element_at"
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
index fc73216..a08ab84 100755
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
@@ -35,7 +35,7 @@ import org.apache.spark.sql.catalyst.trees.TreePattern.{TreePattern, UPPER_OR_LO
 import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData, TypeUtils}
 import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors}
 import org.apache.spark.sql.internal.SQLConf
-import org.apache.spark.sql.types._
+import org.apache.spark.sql.types.{StringType, _}
 import org.apache.spark.unsafe.UTF8StringBuilder
 import org.apache.spark.unsafe.array.ByteArrayMethods
 import org.apache.spark.unsafe.types.{ByteArray, UTF8String}
@@ -2943,3 +2943,76 @@ case class Sentences(
     copy(str = newFirst, language = newSecond, country = newThird)
 }
+
+/**
+ * Splits a given string by a specified delimiter and returns the splits in a
+ * GenericArrayData. This expression differs from the `split` function in that
+ * `split` takes a regex as the pattern to split by, while this expression takes
+ * a plain delimiter string (its characters carry no special meaning, i.e. it is
+ * not treated as a regex).
+ */
+case class StringSplitSQL(
+    str: Expression,
+    delimiter: Expression) extends BinaryExpression with NullIntolerant {
+  override def dataType: DataType = ArrayType(StringType, containsNull = false)
+  override def left: Expression = str
+  override def right: Expression = delimiter
+
+  override def nullSafeEval(string: Any, delimiter: Any): Any = {
+    val strings = string.asInstanceOf[UTF8String].splitSQL(
+      delimiter.asInstanceOf[UTF8String], -1)
+    new GenericArrayData(strings.asInstanceOf[Array[Any]])
+  }
+
+  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+    val arrayClass = classOf[GenericArrayData].getName
+    nullSafeCodeGen(ctx, ev, (str, delimiter) => {
+      // Array in Java is covariant, so we don't need to cast UTF8String[] to Object[].
+      s"${ev.value} = new $arrayClass($str.splitSQL($delimiter, -1));"
+    })
+  }
+
+  override def withNewChildrenInternal(
+      newFirst: Expression, newSecond: Expression): StringSplitSQL =
+    copy(str = newFirst, delimiter = newSecond)
+}
+
+/**
+ * Splits a given string by a specified delimiter and returns the requested part.
+ * If any input is null, returns null.
+ * If `partNum` is out of range of the split parts, returns an empty string.
+ * If `partNum` is 0, throws an ArrayIndexOutOfBoundsException.
+ */
+@ExpressionDescription(
+  usage =
+    """
+    _FUNC_(str, delimiter, partNum) - Splits `str` by `delimiter` and returns the
+      requested part of the split (1-based). If any input is null, returns null.
+      If `partNum` is out of range of the split parts, returns an empty string.
+      If `partNum` is 0, throws an error. If `partNum` is negative, the parts are
+      counted backward from the end of the string. If `delimiter` is an empty
+      string, `str` is not split.
+    """,
+  examples =
+    """
+    Examples:
+      > SELECT _FUNC_('11.12.13', '.', 3);
+       13
+  """,
+  since = "3.3.0",
+  group = "string_funcs")
+case class SplitPart (
+    str: Expression,
+    delimiter: Expression,
+    partNum: Expression)
+  extends RuntimeReplaceable with ImplicitCastInputTypes {
+  override lazy val replacement: Expression =
+    ElementAt(StringSplitSQL(str, delimiter), partNum, Some(Literal.create("", StringType)),
+      false)
+  override def nodeName: String = "split_part"
+  override def inputTypes: Seq[DataType] = Seq(StringType, StringType, IntegerType)
+  def children: Seq[Expression] = Seq(str, delimiter, partNum)
+  protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Expression = {
+    copy(str = newChildren.apply(0), delimiter = newChildren.apply(1),
+      partNum = newChildren.apply(2))
+  }
+}
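Putting the pieces together: `SplitPart` evaluates nothing itself; as a `RuntimeReplaceable` its `replacement` wires `StringSplitSQL` into `ElementAt` with an empty-string out-of-bound default and `failOnError = false`. A plain-Scala model of the semantics this composition produces (the names are illustrative stand-ins, not the Catalyst implementation; null propagation, which `NullIntolerant` handles, is omitted):

```
import java.util.regex.Pattern

object SplitPartModel extends App {
  // Mirrors UTF8String.splitSQL: an empty delimiter means no split at all.
  def splitSQL(str: String, delim: String): Array[String] =
    if (delim.isEmpty) Array(str)
    else str.split(Pattern.quote(delim), -1)

  // Mirrors ElementAt with defaultValueOutOfBound = Some("") and
  // failOnError = false, as constructed in SplitPart.replacement.
  def elementAt(parts: Array[String], partNum: Int): String =
    if (partNum == 0) {
      throw new ArrayIndexOutOfBoundsException("SQL array indices start at 1")
    } else if (math.abs(partNum) > parts.length) {
      ""                                // the empty-string default kicks in
    } else if (partNum > 0) {
      parts(partNum - 1)                // 1-based from the front
    } else {
      parts(parts.length + partNum)     // negative: counted from the back
    }

  def splitPart(str: String, delim: String, partNum: Int): String =
    elementAt(splitSQL(str, delim), partNum)

  assert(splitPart("11.12.13", ".", 3) == "13")
  assert(splitPart("11.12.13", ".", -1) == "13")
  assert(splitPart("11.12.13", ".", 5) == "")         // out of range
  assert(splitPart("11.12.13", "", 1) == "11.12.13")  // empty delimiter
}
```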
diff --git a/sql/core/src/test/resources/sql-functions/sql-expression-schema.md b/sql/core/src/test/resources/sql-functions/sql-expression-schema.md
index 1afba46..166c761 100644
--- a/sql/core/src/test/resources/sql-functions/sql-expression-schema.md
+++ b/sql/core/src/test/resources/sql-functions/sql-expression-schema.md
@@ -1,6 +1,6 @@
 <!-- Automatically generated by ExpressionsSchemaSuite -->
 ## Summary
-  - Number of queries: 383
+  - Number of queries: 384
   - Number of expressions that missing example: 12
   - Expressions missing examples: bigint,binary,boolean,date,decimal,double,float,int,smallint,string,timestamp,tinyint
 ## Schema of Built-in Functions
@@ -275,6 +275,7 @@
 | org.apache.spark.sql.catalyst.expressions.SoundEx | soundex | SELECT soundex('Miller') | struct<soundex(Miller):string> |
 | org.apache.spark.sql.catalyst.expressions.SparkPartitionID | spark_partition_id | SELECT spark_partition_id() | struct<SPARK_PARTITION_ID():int> |
 | org.apache.spark.sql.catalyst.expressions.SparkVersion | version | SELECT version() | struct<version():string> |
+| org.apache.spark.sql.catalyst.expressions.SplitPart | split_part | SELECT split_part('11.12.13', '.', 3) | struct<split_part(11.12.13, ., 3):string> |
 | org.apache.spark.sql.catalyst.expressions.Sqrt | sqrt | SELECT sqrt(4) | struct<SQRT(4):double> |
 | org.apache.spark.sql.catalyst.expressions.Stack | stack | SELECT stack(2, 1, 2, 3) | struct<col0:int,col1:int> |
 | org.apache.spark.sql.catalyst.expressions.StartsWithExpressionBuilder$ | startswith | SELECT startswith('Spark SQL', 'Spark') | struct<startswith(Spark SQL, Spark):boolean> |
diff --git a/sql/core/src/test/resources/sql-tests/inputs/string-functions.sql b/sql/core/src/test/resources/sql-tests/inputs/string-functions.sql
index e7c01a6..7d22e79 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/string-functions.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/string-functions.sql
@@ -27,6 +27,18 @@ select right("abcd", -2), right("abcd", 0), right("abcd", 'a');
 SELECT split('aa1cc2ee3', '[1-9]+');
 SELECT split('aa1cc2ee3', '[1-9]+', 2);
 
+-- split_part function
+SELECT split_part('11.12.13', '.', 2);
+SELECT split_part('11.12.13', '.', -1);
+SELECT split_part('11.12.13', '.', -3);
+SELECT split_part('11.12.13', '', 1);
+SELECT split_part('11ab12ab13', 'ab', 1);
+SELECT split_part('11.12.13', '.', 0);
+SELECT split_part('11.12.13', '.', 4);
+SELECT split_part('11.12.13', '.', 5);
+SELECT split_part('11.12.13', '.', -5);
+SELECT split_part(null, '.', 1);
+
 -- substring function
 SELECT substr('Spark SQL', 5);
 SELECT substr('Spark SQL', -3);
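The expected outputs for these inputs live in the two golden files below, generated by SQLQueryTestSuite (typically re-recorded with the SPARK_GENERATE_GOLDEN_FILES=1 environment variable, though that workflow is not part of this patch). As a rough Scala rendering of the same coverage, with a hypothetical helper and an assumed SparkSession `spark`:

```
// Hypothetical unit-style rendering of the golden-file cases below.
def firstValue(query: String): Any = spark.sql(query).head().get(0)

assert(firstValue("SELECT split_part('11.12.13', '.', 2)") == "12")
assert(firstValue("SELECT split_part('11.12.13', '.', -1)") == "13")
assert(firstValue("SELECT split_part('11.12.13', '', 1)") == "11.12.13")
assert(firstValue("SELECT split_part('11ab12ab13', 'ab', 1)") == "11")
assert(firstValue("SELECT split_part('11.12.13', '.', 4)") == "")  // out of range
assert(firstValue("SELECT split_part(null, '.', 1)") == null)      // null input
// split_part('11.12.13', '.', 0) is expected to throw
// java.lang.ArrayIndexOutOfBoundsException ("SQL array indices start at 1").
```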
diff --git a/sql/core/src/test/resources/sql-tests/results/ansi/string-functions.sql.out b/sql/core/src/test/resources/sql-tests/results/ansi/string-functions.sql.out
index b182b5c..01213bd 100644
--- a/sql/core/src/test/resources/sql-tests/results/ansi/string-functions.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/ansi/string-functions.sql.out
@@ -1,5 +1,5 @@
 -- Automatically generated by SQLQueryTestSuite
--- Number of queries: 131
+-- Number of queries: 141
 
 
 -- !query
@@ -127,6 +127,87 @@ struct<split(aa1cc2ee3, [1-9]+, 2):array<string>>
 
 
 -- !query
+SELECT split_part('11.12.13', '.', 2)
+-- !query schema
+struct<split_part(11.12.13, ., 2):string>
+-- !query output
+12
+
+
+-- !query
+SELECT split_part('11.12.13', '.', -1)
+-- !query schema
+struct<split_part(11.12.13, ., -1):string>
+-- !query output
+13
+
+
+-- !query
+SELECT split_part('11.12.13', '.', -3)
+-- !query schema
+struct<split_part(11.12.13, ., -3):string>
+-- !query output
+11
+
+
+-- !query
+SELECT split_part('11.12.13', '', 1)
+-- !query schema
+struct<split_part(11.12.13, , 1):string>
+-- !query output
+11.12.13
+
+
+-- !query
+SELECT split_part('11ab12ab13', 'ab', 1)
+-- !query schema
+struct<split_part(11ab12ab13, ab, 1):string>
+-- !query output
+11
+
+
+-- !query
+SELECT split_part('11.12.13', '.', 0)
+-- !query schema
+struct<>
+-- !query output
+java.lang.ArrayIndexOutOfBoundsException
+SQL array indices start at 1
+
+
+-- !query
+SELECT split_part('11.12.13', '.', 4)
+-- !query schema
+struct<split_part(11.12.13, ., 4):string>
+-- !query output
+
+
+
+-- !query
+SELECT split_part('11.12.13', '.', 5)
+-- !query schema
+struct<split_part(11.12.13, ., 5):string>
+-- !query output
+
+
+
+-- !query
+SELECT split_part('11.12.13', '.', -5)
+-- !query schema
+struct<split_part(11.12.13, ., -5):string>
+-- !query output
+
+
+
+-- !query
+SELECT split_part(null, '.', 1)
+-- !query schema
+struct<split_part(NULL, ., 1):string>
+-- !query output
+NULL
+
+
+-- !query
 SELECT substr('Spark SQL', 5)
 -- !query schema
 struct<substr(Spark SQL, 5, 2147483647):string>
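Note that the ANSI golden outputs above are identical to the non-ANSI ones below for every `split_part` case: the replacement pins `failOnError = false`, so ANSI mode does not turn out-of-range indexes into errors (only `partNum = 0` throws, in both modes). A quick hypothetical check, again assuming a SparkSession `spark`:

```
// Even with ANSI mode on, an out-of-range partNum yields "" (not an error),
// because SplitPart hardcodes failOnError = false in its replacement.
spark.conf.set("spark.sql.ansi.enabled", "true")
spark.sql("SELECT split_part('11.12.13', '.', 5)").collect() // Array([])
```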
diff --git a/sql/core/src/test/resources/sql-tests/results/string-functions.sql.out b/sql/core/src/test/resources/sql-tests/results/string-functions.sql.out
index 4307df7..3a7f197 100644
--- a/sql/core/src/test/resources/sql-tests/results/string-functions.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/string-functions.sql.out
@@ -1,5 +1,5 @@
 -- Automatically generated by SQLQueryTestSuite
--- Number of queries: 131
+-- Number of queries: 141
 
 
 -- !query
@@ -125,6 +125,87 @@ struct<split(aa1cc2ee3, [1-9]+, 2):array<string>>
 
 
 -- !query
+SELECT split_part('11.12.13', '.', 2)
+-- !query schema
+struct<split_part(11.12.13, ., 2):string>
+-- !query output
+12
+
+
+-- !query
+SELECT split_part('11.12.13', '.', -1)
+-- !query schema
+struct<split_part(11.12.13, ., -1):string>
+-- !query output
+13
+
+
+-- !query
+SELECT split_part('11.12.13', '.', -3)
+-- !query schema
+struct<split_part(11.12.13, ., -3):string>
+-- !query output
+11
+
+
+-- !query
+SELECT split_part('11.12.13', '', 1)
+-- !query schema
+struct<split_part(11.12.13, , 1):string>
+-- !query output
+11.12.13
+
+
+-- !query
+SELECT split_part('11ab12ab13', 'ab', 1)
+-- !query schema
+struct<split_part(11ab12ab13, ab, 1):string>
+-- !query output
+11
+
+
+-- !query
+SELECT split_part('11.12.13', '.', 0)
+-- !query schema
+struct<>
+-- !query output
+java.lang.ArrayIndexOutOfBoundsException
+SQL array indices start at 1
+
+
+-- !query
+SELECT split_part('11.12.13', '.', 4)
+-- !query schema
+struct<split_part(11.12.13, ., 4):string>
+-- !query output
+
+
+
+-- !query
+SELECT split_part('11.12.13', '.', 5)
+-- !query schema
+struct<split_part(11.12.13, ., 5):string>
+-- !query output
+
+
+
+-- !query
+SELECT split_part('11.12.13', '.', -5)
+-- !query schema
+struct<split_part(11.12.13, ., -5):string>
+-- !query output
+
+
+
+-- !query
+SELECT split_part(null, '.', 1)
+-- !query schema
+struct<split_part(NULL, ., 1):string>
+-- !query output
+NULL
+
+
+-- !query
 SELECT substr('Spark SQL', 5)
 -- !query schema
 struct<substr(Spark SQL, 5, 2147483647):string>

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org