This is an automated email from the ASF dual-hosted git repository.

ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new e2d2ab510632 [SPARK-49552][PYTHON] Add DataFrame API support for new 'randstr' and 'uniform' SQL functions
e2d2ab510632 is described below

commit e2d2ab510632cc1948cb6b4500e9da49036a96bd
Author: Daniel Tenedorio <daniel.tenedo...@databricks.com>
AuthorDate: Wed Sep 25 10:57:44 2024 +0800

    [SPARK-49552][PYTHON] Add DataFrame API support for new 'randstr' and 'uniform' SQL functions
    
    ### What changes were proposed in this pull request?
    
    In https://github.com/apache/spark/pull/48004 we added new SQL functions `randstr` and `uniform`. This PR adds DataFrame API support for them.
    
    For example, in Scala:
    
    ```
    sql("create table t(col int not null) using csv")
    sql("insert into t values (0)")
    val df = sql("select col from t")
    df.select(randstr(lit(5), lit(0)).alias("x")).select(length(col("x")))
    > 5
    
    df.select(uniform(lit(10), lit(20), lit(0)).alias("x")).selectExpr("x > 5")
    > true
    ```
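    
    For comparison, a rough Python equivalent (a sketch only: it assumes an active SparkSession named `spark`, and that `randstr` and `uniform` are importable from `pyspark.sql.functions`, matching the documentation entries added in this PR):
    
    ```
    from pyspark.sql.functions import col, length, lit, randstr, uniform
    
    df = spark.createDataFrame([(0,)], ["col"])
    
    # randstr(length, seed): a random 5-character alphanumeric string per row.
    df.select(randstr(lit(5), lit(0)).alias("x")).select(length(col("x")))
    # > 5
    
    # uniform(min, max, seed): a random value in [10, 20]; the seed is optional.
    df.select(uniform(lit(10), lit(20), lit(0)).alias("x")).selectExpr("x > 5")
    # > true
    ```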
    
    ### Why are the changes needed?
    
    This improves DataFrame parity with the SQL API.
    
    ### Does this PR introduce _any_ user-facing change?
    
    Yes, see above.
    
    ### How was this patch tested?
    
    This PR adds unit test coverage.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No.
    
    Closes #48143 from dtenedor/dataframes-uniform-randstr.
    
    Authored-by: Daniel Tenedorio <daniel.tenedo...@databricks.com>
    Signed-off-by: Ruifeng Zheng <ruife...@apache.org>
---
 .../source/reference/pyspark.sql/functions.rst     |   2 +
 python/pyspark/sql/connect/functions/builtin.py    |  28 ++++++
 python/pyspark/sql/functions/builtin.py            |  92 ++++++++++++++++++
 python/pyspark/sql/tests/test_functions.py         |  21 ++++-
 .../scala/org/apache/spark/sql/functions.scala     |  45 +++++++++
 .../catalyst/expressions/randomExpressions.scala   |  49 ++++++++--
 .../apache/spark/sql/DataFrameFunctionsSuite.scala | 104 +++++++++++++++++++++
 7 files changed, 331 insertions(+), 10 deletions(-)

diff --git a/python/docs/source/reference/pyspark.sql/functions.rst b/python/docs/source/reference/pyspark.sql/functions.rst
index 4910a5b59273..6248e7133165 100644
--- a/python/docs/source/reference/pyspark.sql/functions.rst
+++ b/python/docs/source/reference/pyspark.sql/functions.rst
@@ -148,6 +148,7 @@ Mathematical Functions
     try_multiply
     try_subtract
     unhex
+    uniform
     width_bucket
 
 
@@ -189,6 +190,7 @@ String Functions
     overlay
     position
     printf
+    randstr
     regexp_count
     regexp_extract
     regexp_extract_all
diff --git a/python/pyspark/sql/connect/functions/builtin.py b/python/pyspark/sql/connect/functions/builtin.py
index 6953230f5b42..27b12fff3c0a 100644
--- a/python/pyspark/sql/connect/functions/builtin.py
+++ b/python/pyspark/sql/connect/functions/builtin.py
@@ -1007,6 +1007,22 @@ def unhex(col: "ColumnOrName") -> Column:
 unhex.__doc__ = pysparkfuncs.unhex.__doc__
 
 
+def uniform(
+    min: Union[Column, int, float],
+    max: Union[Column, int, float],
+    seed: Optional[Union[Column, int]] = None,
+) -> Column:
+    if seed is None:
+        return _invoke_function_over_columns(
+            "uniform", lit(min), lit(max), lit(random.randint(0, sys.maxsize))
+        )
+    else:
+        return _invoke_function_over_columns("uniform", lit(min), lit(max), lit(seed))
+
+
+uniform.__doc__ = pysparkfuncs.uniform.__doc__
+
+
 def approxCountDistinct(col: "ColumnOrName", rsd: Optional[float] = None) -> Column:
     warnings.warn("Deprecated in 3.4, use approx_count_distinct instead.", FutureWarning)
     return approx_count_distinct(col, rsd)
@@ -2581,6 +2597,18 @@ def regexp_like(str: "ColumnOrName", regexp: "ColumnOrName") -> Column:
 regexp_like.__doc__ = pysparkfuncs.regexp_like.__doc__
 
 
+def randstr(length: Union[Column, int], seed: Optional[Union[Column, int]] = None) -> Column:
+    if seed is None:
+        return _invoke_function_over_columns(
+            "randstr", lit(length), lit(random.randint(0, sys.maxsize))
+        )
+    else:
+        return _invoke_function_over_columns("randstr", lit(length), lit(seed))
+
+
+randstr.__doc__ = pysparkfuncs.randstr.__doc__
+
+
 def regexp_count(str: "ColumnOrName", regexp: "ColumnOrName") -> Column:
     return _invoke_function_over_columns("regexp_count", str, regexp)
 
diff --git a/python/pyspark/sql/functions/builtin.py b/python/pyspark/sql/functions/builtin.py
index 09a286fe7c94..4ca39562cb20 100644
--- a/python/pyspark/sql/functions/builtin.py
+++ b/python/pyspark/sql/functions/builtin.py
@@ -11973,6 +11973,47 @@ def regexp_like(str: "ColumnOrName", regexp: "ColumnOrName") -> Column:
     return _invoke_function_over_columns("regexp_like", str, regexp)
 
 
+@_try_remote_functions
+def randstr(length: Union[Column, int], seed: Optional[Union[Column, int]] = None) -> Column:
+    """Returns a string of the specified length whose characters are chosen uniformly at random from
+    the following pool of characters: 0-9, a-z, A-Z. The random seed is optional. The string length
+    must be a constant two-byte or four-byte integer (SMALLINT or INT, respectively).
+
+    .. versionadded:: 4.0.0
+
+    Parameters
+    ----------
+    length : :class:`~pyspark.sql.Column` or int
+        Number of characters in the string to generate.
+    seed : :class:`~pyspark.sql.Column` or int
+        Optional random number seed to use.
+
+    Returns
+    -------
+    :class:`~pyspark.sql.Column`
+        The generated random string with the specified length.
+
+    Examples
+    --------
+    >>> spark.createDataFrame([('3',)], ['a']) \\
+    ...   .select(randstr(lit(5), lit(0)).alias('result')) \\
+    ...   .selectExpr("length(result) > 0").show()
+    +--------------------+
+    |(length(result) > 0)|
+    +--------------------+
+    |                true|
+    +--------------------+
+    """
+    length = _enum_to_value(length)
+    length = lit(length)
+    if seed is None:
+        return _invoke_function_over_columns("randstr", length)
+    else:
+        seed = _enum_to_value(seed)
+        seed = lit(seed)
+        return _invoke_function_over_columns("randstr", length, seed)
+
+
 @_try_remote_functions
 def regexp_count(str: "ColumnOrName", regexp: "ColumnOrName") -> Column:
     r"""Returns a count of the number of times that the Java regex pattern `regexp` is matched
@@ -12339,6 +12380,57 @@ def unhex(col: "ColumnOrName") -> Column:
     return _invoke_function_over_columns("unhex", col)
 
 
+@_try_remote_functions
+def uniform(
+    min: Union[Column, int, float],
+    max: Union[Column, int, float],
+    seed: Optional[Union[Column, int]] = None,
+) -> Column:
+    """Returns a random value with independent and identically distributed (i.i.d.) values with the
+    specified range of numbers. The random seed is optional. The provided numbers specifying the
+    minimum and maximum values of the range must be constant. If both of these numbers are integers,
+    then the result will also be an integer. Otherwise if one or both of these are floating-point
+    numbers, then the result will also be a floating-point number.
+
+    .. versionadded:: 4.0.0
+
+    Parameters
+    ----------
+    min : :class:`~pyspark.sql.Column`, int, or float
+        Minimum value in the range.
+    max : :class:`~pyspark.sql.Column`, int, or float
+        Maximum value in the range.
+    seed : :class:`~pyspark.sql.Column` or int
+        Optional random number seed to use.
+
+    Returns
+    -------
+    :class:`~pyspark.sql.Column`
+        The generated random number within the specified range.
+
+    Examples
+    --------
+    >>> spark.createDataFrame([('3',)], ['a']) \\
+    ...    .select(uniform(lit(0), lit(10), lit(0)).alias('result')) \\
+    ...    .selectExpr("result < 15").show()
+    +-------------+
+    |(result < 15)|
+    +-------------+
+    |         true|
+    +-------------+
+    """
+    min = _enum_to_value(min)
+    min = lit(min)
+    max = _enum_to_value(max)
+    max = lit(max)
+    if seed is None:
+        return _invoke_function_over_columns("uniform", min, max)
+    else:
+        seed = _enum_to_value(seed)
+        seed = lit(seed)
+        return _invoke_function_over_columns("uniform", min, max, seed)
+
+
 @_try_remote_functions
 def length(col: "ColumnOrName") -> Column:
     """Computes the character length of string data or number of bytes of binary data.
diff --git a/python/pyspark/sql/tests/test_functions.py b/python/pyspark/sql/tests/test_functions.py
index a0ab9bc9c7d4..a51156e895c6 100644
--- a/python/pyspark/sql/tests/test_functions.py
+++ b/python/pyspark/sql/tests/test_functions.py
@@ -29,7 +29,7 @@ from pyspark.errors import PySparkTypeError, PySparkValueError, SparkRuntimeExce
 from pyspark.sql import Row, Window, functions as F, types
 from pyspark.sql.avro.functions import from_avro, to_avro
 from pyspark.sql.column import Column
-from pyspark.sql.functions.builtin import nullifzero, zeroifnull
+from pyspark.sql.functions.builtin import nullifzero, randstr, uniform, zeroifnull
 from pyspark.testing.sqlutils import ReusedSQLTestCase, SQLTestUtils
 from pyspark.testing.utils import have_numpy
 
@@ -1610,6 +1610,25 @@ class FunctionsTestsMixin:
         result = df.select(zeroifnull(df.a).alias("r")).collect()
         self.assertEqual([Row(r=0), Row(r=1)], result)
 
+    def test_randstr_uniform(self):
+        df = self.spark.createDataFrame([(0,)], ["a"])
+        result = df.select(randstr(F.lit(5), F.lit(0)).alias("x")).selectExpr("length(x)").collect()
+        self.assertEqual([Row(5)], result)
+        # The random seed is optional.
+        result = df.select(randstr(F.lit(5)).alias("x")).selectExpr("length(x)").collect()
+        self.assertEqual([Row(5)], result)
+
+        df = self.spark.createDataFrame([(0,)], ["a"])
+        result = (
+            df.select(uniform(F.lit(10), F.lit(20), F.lit(0)).alias("x"))
+            .selectExpr("x > 5")
+            .collect()
+        )
+        self.assertEqual([Row(True)], result)
+        # The random seed is optional.
+        result = df.select(uniform(F.lit(10), F.lit(20)).alias("x")).selectExpr("x > 5").collect()
+        self.assertEqual([Row(True)], result)
+
 
 class FunctionsTests(ReusedSQLTestCase, FunctionsTestsMixin):
     pass
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/functions.scala b/sql/api/src/main/scala/org/apache/spark/sql/functions.scala
index ab69789c75f5..93bff2262105 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/functions.scala
@@ -1896,6 +1896,26 @@ object functions {
    */
   def randn(): Column = randn(SparkClassUtils.random.nextLong)
 
+  /**
+   * Returns a string of the specified length whose characters are chosen uniformly at random from
+   * the following pool of characters: 0-9, a-z, A-Z. The string length must be a constant
+   * two-byte or four-byte integer (SMALLINT or INT, respectively).
+   *
+   * @group string_funcs
+   * @since 4.0.0
+   */
+  def randstr(length: Column): Column = Column.fn("randstr", length)
+
+  /**
+   * Returns a string of the specified length whose characters are chosen uniformly at random from
+   * the following pool of characters: 0-9, a-z, A-Z, with the chosen random seed. The string
+   * length must be a constant two-byte or four-byte integer (SMALLINT or INT, respectively).
+   *
+   * @group string_funcs
+   * @since 4.0.0
+   */
+  def randstr(length: Column, seed: Column): Column = Column.fn("randstr", length, seed)
+
   /**
    * Partition ID.
    *
@@ -3740,6 +3760,31 @@ object functions {
    */
   def stack(cols: Column*): Column = Column.fn("stack", cols: _*)
 
+  /**
+   * Returns a random value with independent and identically distributed (i.i.d.) values with the
+   * specified range of numbers. The provided numbers specifying the minimum and maximum values of
+   * the range must be constant. If both of these numbers are integers, then the result will also
+   * be an integer. Otherwise if one or both of these are floating-point numbers, then the result
+   * will also be a floating-point number.
+   *
+   * @group math_funcs
+   * @since 4.0.0
+   */
+  def uniform(min: Column, max: Column): Column = Column.fn("uniform", min, max)
+
+  /**
+   * Returns a random value with independent and identically distributed (i.i.d.) values with the
+   * specified range of numbers, with the chosen random seed. The provided numbers specifying the
+   * minimum and maximum values of the range must be constant. If both of these numbers are
+   * integers, then the result will also be an integer. Otherwise if one or both of these are
+   * floating-point numbers, then the result will also be a floating-point number.
+   *
+   * @group math_funcs
+   * @since 4.0.0
+   */
+  def uniform(min: Column, max: Column, seed: Column): Column =
+    Column.fn("uniform", min, max, seed)
+
   /**
    * Returns a random value with independent and identically distributed (i.i.d.) uniformly
    * distributed values in [0, 1).
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala
index f329f8346b0d..ada0a73a6795 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala
@@ -206,15 +206,18 @@ object Randn {
   """,
   since = "4.0.0",
   group = "math_funcs")
-case class Uniform(min: Expression, max: Expression, seedExpression: Expression)
+case class Uniform(min: Expression, max: Expression, seedExpression: Expression, hideSeed: Boolean)
   extends RuntimeReplaceable with TernaryLike[Expression] with RDG {
-  def this(min: Expression, max: Expression) = this(min, max, UnresolvedSeed)
+  def this(min: Expression, max: Expression) =
+    this(min, max, UnresolvedSeed, hideSeed = true)
+  def this(min: Expression, max: Expression, seedExpression: Expression) =
+    this(min, max, seedExpression, hideSeed = false)
 
   final override lazy val deterministic: Boolean = false
   override val nodePatterns: Seq[TreePattern] =
     Seq(RUNTIME_REPLACEABLE, EXPRESSION_WITH_RANDOM_SEED)
 
-  override val dataType: DataType = {
+  override def dataType: DataType = {
     val first = min.dataType
     val second = max.dataType
     (min.dataType, max.dataType) match {
@@ -240,6 +243,10 @@ case class Uniform(min: Expression, max: Expression, seedExpression: Expression)
     case _ => false
   }
 
+  override def sql: String = {
+    s"uniform(${min.sql}, ${max.sql}${if (hideSeed) "" else s", ${seedExpression.sql}"})"
+  }
+
   override def checkInputDataTypes(): TypeCheckResult = {
     var result: TypeCheckResult = TypeCheckResult.TypeCheckSuccess
     def requiredType = "integer or floating-point"
@@ -277,11 +284,11 @@ case class Uniform(min: Expression, max: Expression, seedExpression: Expression)
   override def third: Expression = seedExpression
 
   override def withNewSeed(newSeed: Long): Expression =
-    Uniform(min, max, Literal(newSeed, LongType))
+    Uniform(min, max, Literal(newSeed, LongType), hideSeed)
 
   override def withNewChildrenInternal(
       newFirst: Expression, newSecond: Expression, newThird: Expression): Expression =
-    Uniform(newFirst, newSecond, newThird)
+    Uniform(newFirst, newSecond, newThird, hideSeed)
 
   override def replacement: Expression = {
     if (Seq(min, max, seedExpression).exists(_.dataType == NullType)) {
@@ -300,6 +307,13 @@ case class Uniform(min: Expression, max: Expression, seedExpression: Expression)
   }
 }
 
+object Uniform {
+  def apply(min: Expression, max: Expression): Uniform =
+    Uniform(min, max, UnresolvedSeed, hideSeed = true)
+  def apply(min: Expression, max: Expression, seedExpression: Expression): Uniform =
+    Uniform(min, max, seedExpression, hideSeed = false)
+}
+
 @ExpressionDescription(
   usage = """
     _FUNC_(length[, seed]) - Returns a string of the specified length whose characters are chosen
@@ -315,9 +329,13 @@ case class Uniform(min: Expression, max: Expression, seedExpression: Expression)
   """,
   since = "4.0.0",
   group = "string_funcs")
-case class RandStr(length: Expression, override val seedExpression: Expression)
+case class RandStr(
+    length: Expression, override val seedExpression: Expression, hideSeed: Boolean)
   extends ExpressionWithRandomSeed with BinaryLike[Expression] with Nondeterministic {
-  def this(length: Expression) = this(length, UnresolvedSeed)
+  def this(length: Expression) =
+    this(length, UnresolvedSeed, hideSeed = true)
+  def this(length: Expression, seedExpression: Expression) =
+    this(length, seedExpression, hideSeed = false)
 
   override def nullable: Boolean = false
   override def dataType: DataType = StringType
@@ -339,9 +357,14 @@ case class RandStr(length: Expression, override val seedExpression: Expression)
     rng = new XORShiftRandom(seed + partitionIndex)
   }
 
-  override def withNewSeed(newSeed: Long): Expression = RandStr(length, Literal(newSeed, LongType))
+  override def withNewSeed(newSeed: Long): Expression =
+    RandStr(length, Literal(newSeed, LongType), hideSeed)
   override def withNewChildrenInternal(newFirst: Expression, newSecond: Expression): Expression =
-    RandStr(newFirst, newSecond)
+    RandStr(newFirst, newSecond, hideSeed)
+
+  override def sql: String = {
+    s"randstr(${length.sql}${if (hideSeed) "" else s", ${seedExpression.sql}"})"
+  }
 
   override def checkInputDataTypes(): TypeCheckResult = {
     var result: TypeCheckResult = TypeCheckResult.TypeCheckSuccess
@@ -422,3 +445,11 @@ case class RandStr(length: Expression, override val seedExpression: Expression)
       isNull = FalseLiteral)
   }
 }
+
+object RandStr {
+  def apply(length: Expression): RandStr =
+    RandStr(length, UnresolvedSeed, hideSeed = true)
+  def apply(length: Expression, seedExpression: Expression): RandStr =
+    RandStr(length, seedExpression, hideSeed = false)
+}
+
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
index 0842b92e5d53..016803635ff6 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
@@ -411,6 +411,110 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession {
     checkAnswer(df.select(nvl2(col("b"), col("a"), col("c"))), Seq(Row(null)))
   }
 
+  test("randstr function") {
+    withTable("t") {
+      sql("create table t(col int not null) using csv")
+      sql("insert into t values (0)")
+      val df = sql("select col from t")
+      checkAnswer(
+        df.select(randstr(lit(5), lit(0)).alias("x")).select(length(col("x"))),
+        Seq(Row(5)))
+      // The random seed is optional.
+      checkAnswer(
+        df.select(randstr(lit(5)).alias("x")).select(length(col("x"))),
+        Seq(Row(5)))
+    }
+    // Here we exercise some error cases.
+    val df = Seq((0)).toDF("a")
+    var expr = randstr(lit(10), lit("a"))
+    checkError(
+      intercept[AnalysisException](df.select(expr)),
+      condition = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE",
+      parameters = Map(
+        "sqlExpr" -> "\"randstr(10, a)\"",
+        "paramIndex" -> "second",
+        "inputSql" -> "\"a\"",
+        "inputType" -> "\"STRING\"",
+        "requiredType" -> "INT or SMALLINT"),
+      context = ExpectedContext(
+        contextType = QueryContextType.DataFrame,
+        fragment = "randstr",
+        objectType = "",
+        objectName = "",
+        callSitePattern = "",
+        startIndex = 0,
+        stopIndex = 0))
+    expr = randstr(col("a"), lit(10))
+    checkError(
+      intercept[AnalysisException](df.select(expr)),
+      condition = "DATATYPE_MISMATCH.NON_FOLDABLE_INPUT",
+      parameters = Map(
+        "inputName" -> "length",
+        "inputType" -> "INT or SMALLINT",
+        "inputExpr" -> "\"a\"",
+        "sqlExpr" -> "\"randstr(a, 10)\""),
+      context = ExpectedContext(
+        contextType = QueryContextType.DataFrame,
+        fragment = "randstr",
+        objectType = "",
+        objectName = "",
+        callSitePattern = "",
+        startIndex = 0,
+        stopIndex = 0))
+  }
+
+  test("uniform function") {
+    withTable("t") {
+      sql("create table t(col int not null) using csv")
+      sql("insert into t values (0)")
+      val df = sql("select col from t")
+      checkAnswer(
+        df.select(uniform(lit(10), lit(20), lit(0)).alias("x")).selectExpr("x > 5"),
+        Seq(Row(true)))
+      // The random seed is optional.
+      checkAnswer(
+        df.select(uniform(lit(10), lit(20)).alias("x")).selectExpr("x > 5"),
+        Seq(Row(true)))
+    }
+    // Here we exercise some error cases.
+    val df = Seq((0)).toDF("a")
+    var expr = uniform(lit(10), lit("a"))
+    checkError(
+      intercept[AnalysisException](df.select(expr)),
+      condition = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE",
+      parameters = Map(
+        "sqlExpr" -> "\"uniform(10, a)\"",
+        "paramIndex" -> "second",
+        "inputSql" -> "\"a\"",
+        "inputType" -> "\"STRING\"",
+        "requiredType" -> "integer or floating-point"),
+      context = ExpectedContext(
+        contextType = QueryContextType.DataFrame,
+        fragment = "uniform",
+        objectType = "",
+        objectName = "",
+        callSitePattern = "",
+        startIndex = 0,
+        stopIndex = 0))
+    expr = uniform(col("a"), lit(10))
+    checkError(
+      intercept[AnalysisException](df.select(expr)),
+      condition = "DATATYPE_MISMATCH.NON_FOLDABLE_INPUT",
+      parameters = Map(
+        "inputName" -> "min",
+        "inputType" -> "integer or floating-point",
+        "inputExpr" -> "\"a\"",
+        "sqlExpr" -> "\"uniform(a, 10)\""),
+      context = ExpectedContext(
+        contextType = QueryContextType.DataFrame,
+        fragment = "uniform",
+        objectType = "",
+        objectName = "",
+        callSitePattern = "",
+        startIndex = 0,
+        stopIndex = 0))
+  }
+
   test("zeroifnull function") {
     withTable("t") {
       // Here we exercise a non-nullable, non-foldable column.

