This is an automated email from the ASF dual-hosted git repository.

sarutak pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
     new 0666f5c  [SPARK-36751][SQL][PYTHON][R] Add bit/octet_length APIs to Scala, Python and R
0666f5c is described below

commit 0666f5c00393acccecdd82d3794e5a2b88f3210b
Author: Leona Yoda <yo...@oss.nttdata.com>
AuthorDate: Wed Sep 15 16:27:13 2021 +0900

    [SPARK-36751][SQL][PYTHON][R] Add bit/octet_length APIs to Scala, Python and R

    ### What changes were proposed in this pull request?

    octet_length: calculate the byte length of strings
    bit_length: calculate the bit length of strings

    These two string functions are currently implemented only in Spark SQL, not in the Scala, Python and R APIs.

    ### Why are the changes needed?

    These functions are useful for users who handle multi-byte characters and mainly work with Scala, Python or R.

    ### Does this PR introduce _any_ user-facing change?

    Yes. Users can call the octet_length/bit_length APIs from Scala (DataFrame), Python, and R.

    ### How was this patch tested?

    Unit tests.

    Closes #33992 from yoda-mon/add-bit-octet-length.

    Authored-by: Leona Yoda <yo...@oss.nttdata.com>
    Signed-off-by: Kousuke Saruta <saru...@oss.nttdata.com>
---
 R/pkg/NAMESPACE                                    |  2 +
 R/pkg/R/functions.R                                | 26 +++++++++++
 R/pkg/R/generics.R                                 |  8 ++++
 R/pkg/tests/fulltests/test_sparkSQL.R              | 11 +++++
 python/docs/source/reference/pyspark.sql.rst       |  2 +
 python/pyspark/sql/functions.py                    | 52 ++++++++++++++++++++++
 python/pyspark/sql/functions.pyi                   |  2 +
 python/pyspark/sql/tests/test_functions.py         | 14 +++++-
 .../scala/org/apache/spark/sql/functions.scala     | 16 +++++++
 .../apache/spark/sql/StringFunctionsSuite.scala    | 52 ++++++++++++++++++++++
 10 files changed, 184 insertions(+), 1 deletion(-)

diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE
index 7fa8085..686a49e 100644
--- a/R/pkg/NAMESPACE
+++ b/R/pkg/NAMESPACE
@@ -243,6 +243,7 @@ exportMethods("%<=>%",
               "base64",
               "between",
               "bin",
+              "bit_length",
               "bitwise_not",
               "bitwiseNOT",
               "bround",
@@ -364,6 +365,7 @@ exportMethods("%<=>%",
               "not",
               "nth_value",
               "ntile",
+              "octet_length",
               "otherwise",
               "over",
               "overlay",
diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R
index 62066da1..f0768c7 100644
--- a/R/pkg/R/functions.R
+++ b/R/pkg/R/functions.R
@@ -647,6 +647,19 @@ setMethod("bin",
           })
 
 #' @details
+#' \code{bit_length}: Calculates the bit length for the specified string column.
+#'
+#' @rdname column_string_functions
+#' @aliases bit_length bit_length,Column-method
+#' @note bit_length since 3.3.0
+setMethod("bit_length",
+          signature(x = "Column"),
+          function(x) {
+            jc <- callJStatic("org.apache.spark.sql.functions", "bit_length", x@jc)
+            column(jc)
+          })
+
+#' @details
 #' \code{bitwise_not}: Computes bitwise NOT.
 #'
 #' @rdname column_nonaggregate_functions
@@ -1570,6 +1583,19 @@ setMethod("negate",
           })
 
 #' @details
+#' \code{octet_length}: Calculates the byte length for the specified string column.
+#'
+#' @rdname column_string_functions
+#' @aliases octet_length octet_length,Column-method
+#' @note octet_length since 3.3.0
+setMethod("octet_length",
+          signature(x = "Column"),
+          function(x) {
+            jc <- callJStatic("org.apache.spark.sql.functions", "octet_length", x@jc)
+            column(jc)
+          })
+
+#' @details
 #' \code{overlay}: Overlay the specified portion of \code{x} with \code{replace},
 #' starting from byte position \code{pos} of \code{src} and proceeding for
 #' \code{len} bytes.
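The expected values asserted by the new tests throughout this diff all follow
from UTF-8 encoding: "cat" is three one-byte ASCII characters, while the cat
emoji (U+1F408, written as the surrogate pair "\ud83d\udc08") encodes to four
bytes, and bit_length is always eight times octet_length. A minimal plain-Scala
sketch of that arithmetic, independent of Spark and purely illustrative (the
object name is a placeholder):

    import java.nio.charset.StandardCharsets

    object Utf8Lengths {
      def main(args: Array[String]): Unit = {
        // "cat" -> 3 UTF-8 bytes; U+1F408 (cat emoji) -> 4 UTF-8 bytes
        for (s <- Seq("cat", "\ud83d\udc08")) {
          val octets = s.getBytes(StandardCharsets.UTF_8).length
          println(s"octets=$octets, bits=${octets * 8}")
        }
        // Prints: octets=3, bits=24
        //         octets=4, bits=32
      }
    }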
diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R
index 9ebea3f..1abde65 100644
--- a/R/pkg/R/generics.R
+++ b/R/pkg/R/generics.R
@@ -884,6 +884,10 @@ setGeneric("base64", function(x) { standardGeneric("base64") })
 #' @name NULL
 setGeneric("bin", function(x) { standardGeneric("bin") })
 
+#' @rdname column_string_functions
+#' @name NULL
+setGeneric("bit_length", function(x, ...) { standardGeneric("bit_length") })
+
 #' @rdname column_nonaggregate_functions
 #' @name NULL
 setGeneric("bitwise_not", function(x) { standardGeneric("bitwise_not") })
@@ -1232,6 +1236,10 @@ setGeneric("n_distinct", function(x, ...) { standardGeneric("n_distinct") })
 
 #' @rdname column_string_functions
 #' @name NULL
+setGeneric("octet_length", function(x, ...) { standardGeneric("octet_length") })
+
+#' @rdname column_string_functions
+#' @name NULL
 setGeneric("overlay", function(x, replace, pos, ...) { standardGeneric("overlay") })
 
 #' @rdname column_window_functions
diff --git a/R/pkg/tests/fulltests/test_sparkSQL.R b/R/pkg/tests/fulltests/test_sparkSQL.R
index b97c500..f0cb274 100644
--- a/R/pkg/tests/fulltests/test_sparkSQL.R
+++ b/R/pkg/tests/fulltests/test_sparkSQL.R
@@ -1988,6 +1988,17 @@ test_that("string operators", {
     collect(select(df5, repeat_string(df5$a, -1)))[1, 1],
     ""
   )
+
+  l6 <- list(list("cat"), list("\ud83d\udc08"))
+  df6 <- createDataFrame(l6)
+  expect_equal(
+    collect(select(df6, octet_length(df6$"_1")))[, 1],
+    c(3, 4)
+  )
+  expect_equal(
+    collect(select(df6, bit_length(df6$"_1")))[, 1],
+    c(24, 32)
+  )
 })
 
 test_that("date functions on a DataFrame", {
diff --git a/python/docs/source/reference/pyspark.sql.rst b/python/docs/source/reference/pyspark.sql.rst
index 7653ce4..326b83b 100644
--- a/python/docs/source/reference/pyspark.sql.rst
+++ b/python/docs/source/reference/pyspark.sql.rst
@@ -367,6 +367,7 @@ Functions
     avg
     base64
     bin
+    bit_length
     bitwise_not
     bitwiseNOT
     broadcast
@@ -484,6 +485,7 @@ Functions
     next_day
     nth_value
     ntile
+    octet_length
     overlay
     pandas_udf
     percent_rank
diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index e418c0d..105727e 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -3098,6 +3098,58 @@ def length(col):
     return Column(sc._jvm.functions.length(_to_java_column(col)))
 
 
+def octet_length(col):
+    """
+    Calculates the byte length for the specified string column.
+
+    .. versionadded:: 3.3.0
+
+    Parameters
+    ----------
+    col : :class:`~pyspark.sql.Column` or str
+        Source column or column name
+
+    Returns
+    -------
+    :class:`~pyspark.sql.Column`
+        Byte length of the column
+
+    Examples
+    --------
+    >>> from pyspark.sql.functions import octet_length
+    >>> spark.createDataFrame([('cat',), ('\U0001F408',)], ['cat']) \
+    ...     .select(octet_length('cat')).collect()
+    [Row(octet_length(cat)=3), Row(octet_length(cat)=4)]
+    """
+    return _invoke_function_over_column("octet_length", col)
+
+
+def bit_length(col):
+    """
+    Calculates the bit length for the specified string column.
+
+    .. versionadded:: 3.3.0
+
+    Parameters
+    ----------
+    col : :class:`~pyspark.sql.Column` or str
+        Source column or column name
+
+    Returns
+    -------
+    :class:`~pyspark.sql.Column`
+        Bit length of the column
+
+    Examples
+    --------
+    >>> from pyspark.sql.functions import bit_length
+    >>> spark.createDataFrame([('cat',), ('\U0001F408',)], ['cat']) \
+    ...     .select(bit_length('cat')).collect()
+    [Row(bit_length(cat)=24), Row(bit_length(cat)=32)]
+    """
+    return _invoke_function_over_column("bit_length", col)
+
+
 def translate(srcCol, matching, replace):
     """A function translate any character in the `srcCol` by a character in `matching`.
     The characters in `replace` is corresponding to the characters in `matching`.
diff --git a/python/pyspark/sql/functions.pyi b/python/pyspark/sql/functions.pyi
index 143fa13..1a0a61e 100644
--- a/python/pyspark/sql/functions.pyi
+++ b/python/pyspark/sql/functions.pyi
@@ -174,6 +174,8 @@ def bin(col: ColumnOrName) -> Column: ...
 def hex(col: ColumnOrName) -> Column: ...
 def unhex(col: ColumnOrName) -> Column: ...
 def length(col: ColumnOrName) -> Column: ...
+def octet_length(col: ColumnOrName) -> Column: ...
+def bit_length(col: ColumnOrName) -> Column: ...
 def translate(srcCol: ColumnOrName, matching: str, replace: str) -> Column: ...
 def map_from_arrays(col1: ColumnOrName, col2: ColumnOrName) -> Column: ...
 def create_map(*cols: ColumnOrName) -> Column: ...
diff --git a/python/pyspark/sql/tests/test_functions.py b/python/pyspark/sql/tests/test_functions.py
index 082d61b..00a2660 100644
--- a/python/pyspark/sql/tests/test_functions.py
+++ b/python/pyspark/sql/tests/test_functions.py
@@ -23,7 +23,7 @@ from py4j.protocol import Py4JJavaError
 from pyspark.sql import Row, Window
 from pyspark.sql.functions import udf, input_file_name, col, percentile_approx, \
     lit, assert_true, sum_distinct, sumDistinct, shiftleft, shiftLeft, shiftRight, \
-    shiftright, shiftrightunsigned, shiftRightUnsigned
+    shiftright, shiftrightunsigned, shiftRightUnsigned, octet_length, bit_length
 from pyspark.testing.sqlutils import ReusedSQLTestCase
 
 
@@ -197,6 +197,18 @@ class FunctionsTests(ReusedSQLTestCase):
             df.select(getattr(functions, name)("name")).first()[0],
             df.select(getattr(functions, name)(col("name"))).first()[0])
 
+    def test_octet_length_function(self):
+        # SPARK-36751: add octet length api for python
+        df = self.spark.createDataFrame([('cat',), ('\U0001F408',)], ['cat'])
+        actual = df.select(octet_length('cat')).collect()
+        self.assertEqual([Row(3), Row(4)], actual)
+
+    def test_bit_length_function(self):
+        # SPARK-36751: add bit length api for python
+        df = self.spark.createDataFrame([('cat',), ('\U0001F408',)], ['cat'])
+        actual = df.select(bit_length('cat')).collect()
+        self.assertEqual([Row(24), Row(32)], actual)
+
     def test_array_contains_function(self):
         from pyspark.sql.functions import array_contains
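One detail worth spelling out before the Scala changes and their tests below:
for non-string inputs, the underlying SQL expressions measure the value's
string representation via an implicit cast, which is where the integer, float
and double expectations in the test suite come from. A plain-Scala sketch of
that arithmetic, purely illustrative and assuming the implicit cast matches
the values' toString form:

    object NonStringOctets extends App {
      // Assumed implicit string casts:
      //   123   -> "123"   -> 3 octets / 24 bits
      //   2.0f  -> "2.0"   -> 3 octets / 24 bits
      //   3.015 -> "3.015" -> 5 octets / 40 bits
      val octets = Seq(123.toString, 2.0f.toString, 3.015.toString)
        .map(_.getBytes("UTF-8").length)
      println(octets) // List(3, 3, 5)
    }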
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index 781a2dd..2d12d5f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -2542,6 +2542,14 @@ object functions {
   def base64(e: Column): Column = withExpr { Base64(e.expr) }
 
   /**
+   * Calculates the bit length for the specified string column.
+   *
+   * @group string_funcs
+   * @since 3.3.0
+   */
+  def bit_length(e: Column): Column = withExpr { BitLength(e.expr) }
+
+  /**
    * Concatenates multiple input string columns together into a single string column,
    * using the given separator.
    *
@@ -2707,6 +2715,14 @@ object functions {
   }
 
   /**
+   * Calculates the byte length for the specified string column.
+   *
+   * @group string_funcs
+   * @since 3.3.0
+   */
+  def octet_length(e: Column): Column = withExpr { OctetLength(e.expr) }
+
+  /**
    * Extract a specific group matched by a Java regex, from the specified string column.
    * If the regex did not match, or the specified group did not match, an empty string is returned.
    * if the specified group index exceeds the group count of regex, an IllegalArgumentException
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala
index 00074b0..30a6600 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala
@@ -486,6 +486,58 @@ class StringFunctionsSuite extends QueryTest with SharedSparkSession {
     )
   }
 
+  test("SPARK-36751: add octet length api for scala") {
+    // scalastyle:off
+    // Non-ASCII characters are not allowed in the code, so we disable scalastyle here.
+    val df = Seq(("123", Array[Byte](1, 2, 3, 4), 123, 2.0f, 3.015, "\ud83d\udc08"))
+      .toDF("a", "b", "c", "d", "e", "f")
+    // string and binary input
+    checkAnswer(
+      df.select(octet_length($"a"), octet_length($"b")),
+      Row(3, 4))
+    // string and binary input
+    checkAnswer(
+      df.selectExpr("octet_length(a)", "octet_length(b)"),
+      Row(3, 4))
+    // integer, float and double input
+    checkAnswer(
+      df.selectExpr("octet_length(c)", "octet_length(d)", "octet_length(e)"),
+      Row(3, 3, 5)
+    )
+    // multi-byte character input
+    checkAnswer(
+      df.selectExpr("octet_length(f)"),
+      Row(4)
+    )
+    // scalastyle:on
+  }
+
+  test("SPARK-36751: add bit length api for scala") {
+    // scalastyle:off
+    // Non-ASCII characters are not allowed in the code, so we disable scalastyle here.
+    val df = Seq(("123", Array[Byte](1, 2, 3, 4), 123, 2.0f, 3.015, "\ud83d\udc08"))
+      .toDF("a", "b", "c", "d", "e", "f")
+    // string and binary input
+    checkAnswer(
+      df.select(bit_length($"a"), bit_length($"b")),
+      Row(24, 32))
+    // string and binary input
+    checkAnswer(
+      df.selectExpr("bit_length(a)", "bit_length(b)"),
+      Row(24, 32))
+    // integer, float and double input
+    checkAnswer(
+      df.selectExpr("bit_length(c)", "bit_length(d)", "bit_length(e)"),
+      Row(24, 24, 40)
+    )
+    // multi-byte character input
+    checkAnswer(
+      df.selectExpr("bit_length(f)"),
+      Row(32)
+    )
+    // scalastyle:on
+  }
+
   test("initcap function") {
     val df = Seq(("ab", "a B", "sParK")).toDF("x", "y", "z")
     checkAnswer(
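Taken together, the new DataFrame API can be exercised end to end roughly as
follows. This is a sketch, not part of the patch: it assumes a Spark build
that already includes this commit (3.3.0+), and the object name, master and
app name are placeholders.

    import org.apache.spark.sql.SparkSession
    import org.apache.spark.sql.functions.{bit_length, octet_length}

    object BitOctetLengthExample {
      def main(args: Array[String]): Unit = {
        // Local session for illustration only.
        val spark = SparkSession.builder()
          .master("local[1]")
          .appName("bit-octet-length")
          .getOrCreate()
        import spark.implicits._

        val df = Seq("cat", "\ud83d\udc08").toDF("s")
        df.select(octet_length($"s"), bit_length($"s")).show()
        // +---------------+-------------+
        // |octet_length(s)|bit_length(s)|
        // +---------------+-------------+
        // |              3|           24|
        // |              4|           32|
        // +---------------+-------------+

        spark.stop()
      }
    }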