This is an automated email from the ASF dual-hosted git repository.

sarutak pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 0666f5c  [SPARK-36751][SQL][PYTHON][R] Add bit/octet_length APIs to Scala, Python and R
0666f5c is described below

commit 0666f5c00393acccecdd82d3794e5a2b88f3210b
Author: Leona Yoda <yo...@oss.nttdata.com>
AuthorDate: Wed Sep 15 16:27:13 2021 +0900

    [SPARK-36751][SQL][PYTHON][R] Add bit/octet_length APIs to Scala, Python and R
    
    ### What changes were proposed in this pull request?
    
    octet_length: calculates the byte length of strings
    bit_length: calculates the bit length of strings
    These two string-related functions are implemented only in Spark SQL, not in the Scala, Python, and R APIs.
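
    As a quick illustration, here is a minimal PySpark sketch of the new DataFrame API (the local session setup and the column name `s` are illustrative; the expected values mirror the unit tests in this patch):

    ```python
    from pyspark.sql import SparkSession
    from pyspark.sql.functions import bit_length, length, octet_length

    spark = SparkSession.builder.master("local[1]").getOrCreate()

    # "cat" is three ASCII characters; the cat emoji (U+1F408) is a single
    # character that takes four bytes (32 bits) in UTF-8.
    df = spark.createDataFrame([("cat",), ("\U0001F408",)], ["s"])
    df.select(length("s"), octet_length("s"), bit_length("s")).show()
    # Expected:
    # +---------+---------------+-------------+
    # |length(s)|octet_length(s)|bit_length(s)|
    # +---------+---------------+-------------+
    # |        3|              3|           24|
    # |        1|              4|           32|
    # +---------+---------------+-------------+
    ```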
    
    ### Why are the changes needed?
    
    These functions are useful for users who handle multi-byte characters and mainly work with Scala, Python, or R, since for such text a string's character count differs from its byte count.
    
    ### Does this PR introduce _any_ user-facing change?
    
    Yes. Users can call the octet_length/bit_length APIs from Scala (DataFrame), Python, and R.
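
    Before this patch, the two functions were reachable from the DataFrame APIs only through SQL expression strings; a sketch of the equivalence, reusing the df from the example above:

    ```python
    # Before: only via a SQL expression string
    df.selectExpr("octet_length(s)", "bit_length(s)").show()

    # After: native DataFrame functions in Python (Scala and R gain equivalents)
    from pyspark.sql.functions import bit_length, octet_length
    df.select(octet_length("s"), bit_length("s")).show()
    ```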
    
    ### How was this patch tested?
    
    unit tests
    
    Closes #33992 from yoda-mon/add-bit-octet-length.
    
    Authored-by: Leona Yoda <yo...@oss.nttdata.com>
    Signed-off-by: Kousuke Saruta <saru...@oss.nttdata.com>
---
 R/pkg/NAMESPACE                                    |  2 +
 R/pkg/R/functions.R                                | 26 +++++++++++
 R/pkg/R/generics.R                                 |  8 ++++
 R/pkg/tests/fulltests/test_sparkSQL.R              | 11 +++++
 python/docs/source/reference/pyspark.sql.rst       |  2 +
 python/pyspark/sql/functions.py                    | 52 ++++++++++++++++++++++
 python/pyspark/sql/functions.pyi                   |  2 +
 python/pyspark/sql/tests/test_functions.py         | 14 +++++-
 .../scala/org/apache/spark/sql/functions.scala     | 16 +++++++
 .../apache/spark/sql/StringFunctionsSuite.scala    | 52 ++++++++++++++++++++++
 10 files changed, 184 insertions(+), 1 deletion(-)

diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE
index 7fa8085..686a49e 100644
--- a/R/pkg/NAMESPACE
+++ b/R/pkg/NAMESPACE
@@ -243,6 +243,7 @@ exportMethods("%<=>%",
               "base64",
               "between",
               "bin",
+              "bit_length",
               "bitwise_not",
               "bitwiseNOT",
               "bround",
@@ -364,6 +365,7 @@ exportMethods("%<=>%",
               "not",
               "nth_value",
               "ntile",
+              "octet_length",
               "otherwise",
               "over",
               "overlay",
diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R
index 62066da1..f0768c7 100644
--- a/R/pkg/R/functions.R
+++ b/R/pkg/R/functions.R
@@ -647,6 +647,19 @@ setMethod("bin",
           })
 
 #' @details
+#' \code{bit_length}: Calculates the bit length for the specified string column.
+#'
+#' @rdname column_string_functions
+#' @aliases bit_length bit_length,Column-method
+#' @note bit_length since 3.3.0
+setMethod("bit_length",
+          signature(x = "Column"),
+          function(x) {
+            jc <- callJStatic("org.apache.spark.sql.functions", "bit_length", x@jc)
+            column(jc)
+          })
+
+#' @details
 #' \code{bitwise_not}: Computes bitwise NOT.
 #'
 #' @rdname column_nonaggregate_functions
@@ -1570,6 +1583,19 @@ setMethod("negate",
           })
 
 #' @details
+#' \code{octet_length}: Calculates the byte length for the specified string column.
+#'
+#' @rdname column_string_functions
+#' @aliases octet_length octet_length,Column-method
+#' @note octet_length since 3.3.0
+setMethod("octet_length",
+          signature(x = "Column"),
+          function(x) {
+            jc <- callJStatic("org.apache.spark.sql.functions", "octet_length", x@jc)
+            column(jc)
+          })
+
+#' @details
 #' \code{overlay}: Overlay the specified portion of \code{x} with \code{replace},
 #' starting from byte position \code{pos} of \code{src} and proceeding for
 #' \code{len} bytes.
diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R
index 9ebea3f..1abde65 100644
--- a/R/pkg/R/generics.R
+++ b/R/pkg/R/generics.R
@@ -884,6 +884,10 @@ setGeneric("base64", function(x) { standardGeneric("base64") })
 #' @name NULL
 setGeneric("bin", function(x) { standardGeneric("bin") })
 
+#' @rdname column_string_functions
+#' @name NULL
+setGeneric("bit_length", function(x, ...) { standardGeneric("bit_length") })
+
 #' @rdname column_nonaggregate_functions
 #' @name NULL
 setGeneric("bitwise_not", function(x) { standardGeneric("bitwise_not") })
@@ -1232,6 +1236,10 @@ setGeneric("n_distinct", function(x, ...) { standardGeneric("n_distinct") })
 
 #' @rdname column_string_functions
 #' @name NULL
+setGeneric("octet_length", function(x, ...) { standardGeneric("octet_length") 
})
+
+#' @rdname column_string_functions
+#' @name NULL
 setGeneric("overlay", function(x, replace, pos, ...) { 
standardGeneric("overlay") })
 
 #' @rdname column_window_functions
diff --git a/R/pkg/tests/fulltests/test_sparkSQL.R b/R/pkg/tests/fulltests/test_sparkSQL.R
index b97c500..f0cb274 100644
--- a/R/pkg/tests/fulltests/test_sparkSQL.R
+++ b/R/pkg/tests/fulltests/test_sparkSQL.R
@@ -1988,6 +1988,17 @@ test_that("string operators", {
     collect(select(df5, repeat_string(df5$a, -1)))[1, 1],
     ""
   )
+
+  l6 <- list(list("cat"), list("\ud83d\udc08"))
+  df6 <- createDataFrame(l6)
+  expect_equal(
+    collect(select(df6, octet_length(df6$"_1")))[, 1],
+    c(3, 4)
+  )
+  expect_equal(
+    collect(select(df6, bit_length(df6$"_1")))[, 1],
+    c(24, 32)
+  )
 })
 
 test_that("date functions on a DataFrame", {
diff --git a/python/docs/source/reference/pyspark.sql.rst b/python/docs/source/reference/pyspark.sql.rst
index 7653ce4..326b83b 100644
--- a/python/docs/source/reference/pyspark.sql.rst
+++ b/python/docs/source/reference/pyspark.sql.rst
@@ -367,6 +367,7 @@ Functions
     avg
     base64
     bin
+    bit_length
     bitwise_not
     bitwiseNOT
     broadcast
@@ -484,6 +485,7 @@ Functions
     next_day
     nth_value
     ntile
+    octet_length
     overlay
     pandas_udf
     percent_rank
diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index e418c0d..105727e 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -3098,6 +3098,58 @@ def length(col):
     return Column(sc._jvm.functions.length(_to_java_column(col)))
 
 
+def octet_length(col):
+    """
+    Calculates the byte length for the specified string column.
+
+    .. versionadded:: 3.3.0
+
+    Parameters
+    ----------
+    col : :class:`~pyspark.sql.Column` or str
+        Source column or strings
+
+    Returns
+    -------
+    :class:`~pyspark.sql.Column`
+        Byte length of the col
+
+    Examples
+    --------
+    >>> from pyspark.sql.functions import octet_length
+    >>> spark.createDataFrame([('cat',), ( '\U0001F408',)], ['cat']) \
+            .select(octet_length('cat')).collect()
+        [Row(octet_length(cat)=3), Row(octet_length(cat)=4)]
+    """
+    return _invoke_function_over_column("octet_length", col)
+
+
+def bit_length(col):
+    """
+    Calculates the bit length for the specified string column.
+
+    .. versionadded:: 3.3.0
+
+    Parameters
+    ----------
+    col : :class:`~pyspark.sql.Column` or str
+        Source column or strings
+
+    Returns
+    -------
+    :class:`~pyspark.sql.Column`
+        Bit length of the col
+
+    Examples
+    --------
+    >>> from pyspark.sql.functions import bit_length
+    >>> spark.createDataFrame([('cat',), ( '\U0001F408',)], ['cat']) \
+            .select(bit_length('cat')).collect()
+        [Row(bit_length(cat)=24), Row(bit_length(cat)=32)]
+    """
+    return _invoke_function_over_column("bit_length", col)
+
+
 def translate(srcCol, matching, replace):
     """A function translate any character in the `srcCol` by a character in 
`matching`.
     The characters in `replace` is corresponding to the characters in 
`matching`.
diff --git a/python/pyspark/sql/functions.pyi b/python/pyspark/sql/functions.pyi
index 143fa13..1a0a61e 100644
--- a/python/pyspark/sql/functions.pyi
+++ b/python/pyspark/sql/functions.pyi
@@ -174,6 +174,8 @@ def bin(col: ColumnOrName) -> Column: ...
 def hex(col: ColumnOrName) -> Column: ...
 def unhex(col: ColumnOrName) -> Column: ...
 def length(col: ColumnOrName) -> Column: ...
+def octet_length(col: ColumnOrName) -> Column: ...
+def bit_length(col: ColumnOrName) -> Column: ...
 def translate(srcCol: ColumnOrName, matching: str, replace: str) -> Column: ...
 def map_from_arrays(col1: ColumnOrName, col2: ColumnOrName) -> Column: ...
 def create_map(*cols: ColumnOrName) -> Column: ...
diff --git a/python/pyspark/sql/tests/test_functions.py b/python/pyspark/sql/tests/test_functions.py
index 082d61b..00a2660 100644
--- a/python/pyspark/sql/tests/test_functions.py
+++ b/python/pyspark/sql/tests/test_functions.py
@@ -23,7 +23,7 @@ from py4j.protocol import Py4JJavaError
 from pyspark.sql import Row, Window
 from pyspark.sql.functions import udf, input_file_name, col, percentile_approx, \
     lit, assert_true, sum_distinct, sumDistinct, shiftleft, shiftLeft, shiftRight, \
-    shiftright, shiftrightunsigned, shiftRightUnsigned
+    shiftright, shiftrightunsigned, shiftRightUnsigned, octet_length, bit_length
 from pyspark.testing.sqlutils import ReusedSQLTestCase
 
 
@@ -197,6 +197,18 @@ class FunctionsTests(ReusedSQLTestCase):
                 df.select(getattr(functions, name)("name")).first()[0],
                 df.select(getattr(functions, name)(col("name"))).first()[0])
 
+    def test_octet_length_function(self):
+        # SPARK-36751: add octet length api for python
+        df = self.spark.createDataFrame([('cat',), ('\U0001F408',)], ['cat'])
+        actual = df.select(octet_length('cat')).collect()
+        self.assertEqual([Row(3), Row(4)], actual)
+
+    def test_bit_length_function(self):
+        # SPARK-36751: add bit length api for python
+        df = self.spark.createDataFrame([('cat',), ('\U0001F408',)], ['cat'])
+        actual = df.select(bit_length('cat')).collect()
+        self.assertEqual([Row(24), Row(32)], actual)
+
     def test_array_contains_function(self):
         from pyspark.sql.functions import array_contains
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index 781a2dd..2d12d5f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -2542,6 +2542,14 @@ object functions {
   def base64(e: Column): Column = withExpr { Base64(e.expr) }
 
   /**
+   * Calculates the bit length for the specified string column.
+   *
+   * @group string_funcs
+   * @since 3.3.0
+   */
+  def bit_length(e: Column): Column = withExpr { BitLength(e.expr) }
+
+  /**
    * Concatenates multiple input string columns together into a single string column,
    * using the given separator.
    *
@@ -2707,6 +2715,14 @@ object functions {
   }
 
   /**
+   * Calculates the byte length for the specified string column.
+   *
+   * @group string_funcs
+   * @since 3.3.0
+   */
+  def octet_length(e: Column): Column = withExpr { OctetLength(e.expr) }
+
+  /**
    * Extract a specific group matched by a Java regex, from the specified string column.
    * If the regex did not match, or the specified group did not match, an empty string is returned.
    * if the specified group index exceeds the group count of regex, an IllegalArgumentException
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala
index 00074b0..30a6600 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala
@@ -486,6 +486,58 @@ class StringFunctionsSuite extends QueryTest with SharedSparkSession {
     )
   }
 
+  test("SPARK-36751: add octet length api for scala") {
+    // scalastyle:off
+    // non-ASCII characters are not allowed in the code, so we disable scalastyle here.
+    val df = Seq(("123", Array[Byte](1, 2, 3, 4), 123, 2.0f, 3.015, "\ud83d\udc08"))
+      .toDF("a", "b", "c", "d", "e", "f")
+    // string and binary input
+    checkAnswer(
+      df.select(octet_length($"a"), octet_length($"b")),
+      Row(3, 4))
+    // string and binary input
+    checkAnswer(
+      df.selectExpr("octet_length(a)", "octet_length(b)"),
+      Row(3, 4))
+    // integer, float and double input
+    checkAnswer(
+      df.selectExpr("octet_length(c)", "octet_length(d)", "octet_length(e)"),
+      Row(3, 3, 5)
+    )
+    // multi-byte character input
+    checkAnswer(
+      df.selectExpr("octet_length(f)"),
+      Row(4)
+    )
+    // scalastyle:on
+  }
+
+  test("SPARK-36751: add bit length api for scala") {
+    // scalastyle:off
+    // non-ASCII characters are not allowed in the code, so we disable scalastyle here.
+    val df = Seq(("123", Array[Byte](1, 2, 3, 4), 123, 2.0f, 3.015, "\ud83d\udc08"))
+      .toDF("a", "b", "c", "d", "e", "f")
+    // string and binary input
+    checkAnswer(
+      df.select(bit_length($"a"), bit_length($"b")),
+      Row(24, 32))
+    // string and binary input
+    checkAnswer(
+      df.selectExpr("bit_length(a)", "bit_length(b)"),
+      Row(24, 32))
+    // integer, float and double input
+    checkAnswer(
+      df.selectExpr("bit_length(c)", "bit_length(d)", "bit_length(e)"),
+      Row(24, 24, 40)
+    )
+    // multi-byte character input
+    checkAnswer(
+      df.selectExpr("bit_length(f)"),
+      Row(32)
+    )
+    // scalastyle:on
+  }
+
   test("initcap function") {
     val df = Seq(("ab", "a B", "sParK")).toDF("x", "y", "z")
     checkAnswer(
