This is an automated email from the ASF dual-hosted git repository.

ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new a3b918eb94c1 [SPARK-49540][PS] Unify the usage of `distributed_sequence_id`
a3b918eb94c1 is described below

commit a3b918eb94c1ad49bf8bdfddf31d40a346e0fafb
Author: Ruifeng Zheng <ruife...@apache.org>
AuthorDate: Mon Sep 9 12:26:23 2024 +0800

    [SPARK-49540][PS] Unify the usage of `distributed_sequence_id`
    
    ### What changes were proposed in this pull request?
    
    In PySpark Classic, `distributed_sequence_id` was used via the DataFrame
    method `withSequenceColumn`, while in PySpark Connect it was used as an
    internal function.
    
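    This PR unifies the usage of `distributed_sequence_id`: both PySpark Classic
    and PySpark Connect now select it as a column through the shared helper
    `SF.distributed_sequence_id()`.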
    
    ### Why are the changes needed?
    Code refactoring.
    
    ### Does this PR introduce _any_ user-facing change?
    No
    
    ### How was this patch tested?
    Updated the existing tests.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    No
    
    Closes #48028 from zhengruifeng/func_withSequenceColumn.
    
    Authored-by: Ruifeng Zheng <ruife...@apache.org>
    Signed-off-by: Ruifeng Zheng <ruife...@apache.org>
---
 python/pyspark/pandas/internal.py                      | 18 +++++-------------
 python/pyspark/pandas/spark/functions.py               | 12 ++++++++++++
 .../src/main/scala/org/apache/spark/sql/Dataset.scala  |  8 --------
 .../apache/spark/sql/api/python/PythonSQLUtils.scala   |  7 +++++++
 .../org/apache/spark/sql/DataFrameSelfJoinSuite.scala  |  5 +++--
 .../scala/org/apache/spark/sql/DataFrameSuite.scala    |  4 +++-
 6 files changed, 30 insertions(+), 24 deletions(-)

diff --git a/python/pyspark/pandas/internal.py 
b/python/pyspark/pandas/internal.py
index 92d4a3357319..4be345201ba6 100644
--- a/python/pyspark/pandas/internal.py
+++ b/python/pyspark/pandas/internal.py
@@ -43,6 +43,7 @@ from pyspark.sql.types import (  # noqa: F401
 )
 from pyspark.sql.utils import is_timestamp_ntz_preferred, is_remote
 from pyspark import pandas as ps
+from pyspark.pandas.spark import functions as SF
 from pyspark.pandas._typing import Label
 from pyspark.pandas.spark.utils import as_nullable_spark_type, 
force_decimal_precision_scale
 from pyspark.pandas.data_type_ops.base import DataTypeOps
@@ -938,19 +939,10 @@ class InternalFrame:
         +--------+---+
         """
         if len(sdf.columns) > 0:
-            if is_remote():
-                from pyspark.sql.connect.column import Column as ConnectColumn
-                from pyspark.sql.connect.expressions import 
DistributedSequenceID
-
-                return sdf.select(
-                    ConnectColumn(DistributedSequenceID()).alias(column_name),
-                    "*",
-                )
-            else:
-                return PySparkDataFrame(
-                    sdf._jdf.toDF().withSequenceColumn(column_name),
-                    sdf.sparkSession,
-                )
+            return sdf.select(
+                SF.distributed_sequence_id().alias(column_name),
+                "*",
+            )
         else:
             cnt = sdf.count()
             if cnt > 0:
diff --git a/python/pyspark/pandas/spark/functions.py 
b/python/pyspark/pandas/spark/functions.py
index 6aaa63956c14..4bcf07f6f650 100644
--- a/python/pyspark/pandas/spark/functions.py
+++ b/python/pyspark/pandas/spark/functions.py
@@ -174,6 +174,18 @@ def null_index(col: Column) -> Column:
         return Column(sc._jvm.PythonSQLUtils.nullIndex(col._jc))
 
 
+def distributed_sequence_id() -> Column:
+    if is_remote():
+        from pyspark.sql.connect.functions.builtin import _invoke_function
+
+        return _invoke_function("distributed_sequence_id")
+    else:
+        from pyspark import SparkContext
+
+        sc = SparkContext._active_spark_context
+        return Column(sc._jvm.PythonSQLUtils.distributed_sequence_id())
+
+
 def collect_top_k(col: Column, num: int, reverse: bool) -> Column:
     if is_remote():
         from pyspark.sql.connect.functions.builtin import 
_invoke_function_over_columns
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index 870571b533d0..0fab60a94842 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -2010,14 +2010,6 @@ class Dataset[T] private[sql](
   // For Python API
   ////////////////////////////////////////////////////////////////////////////
 
-  /**
-   * It adds a new long column with the name `name` that increases one by one.
-   * This is for 'distributed-sequence' default index in pandas API on Spark.
-   */
-  private[sql] def withSequenceColumn(name: String) = {
-    select(column(DistributedSequenceID()).alias(name), col("*"))
-  }
-
   /**
    * Converts a JavaRDD to a PythonRDD.
    */
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala
index 7dbc586f6473..93082740cca6 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala
@@ -176,6 +176,13 @@ private[sql] object PythonSQLUtils extends Logging {
   def pandasCovar(col1: Column, col2: Column, ddof: Int): Column =
     Column.internalFn("pandas_covar", col1, col2, lit(ddof))
 
+  /**
+   * A long column that increases one by one.
+   * This is for 'distributed-sequence' default index in pandas API on Spark.
+   */
+  def distributed_sequence_id(): Column =
+    Column.internalFn("distributed_sequence_id")
+
   def unresolvedNamedLambdaVariable(name: String): Column =
     Column(internal.UnresolvedNamedLambdaVariable.apply(name))
 
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSelfJoinSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSelfJoinSuite.scala
index 310b5a62c908..d888b09d76ea 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSelfJoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSelfJoinSuite.scala
@@ -18,10 +18,11 @@
 package org.apache.spark.sql
 
 import org.apache.spark.api.python.PythonEvalType
+import org.apache.spark.sql.api.python.PythonSQLUtils.distributed_sequence_id
 import org.apache.spark.sql.catalyst.expressions.{Alias, Ascending, 
AttributeReference, PythonUDF, SortOrder}
 import org.apache.spark.sql.catalyst.plans.logical.{Expand, Generate, 
ScriptInputOutputSchema, ScriptTransformation, Window => WindowPlan}
 import org.apache.spark.sql.expressions.Window
-import org.apache.spark.sql.functions.{count, explode, sum, year}
+import org.apache.spark.sql.functions.{col, count, explode, sum, year}
 import org.apache.spark.sql.internal.ExpressionUtils.column
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SharedSparkSession
@@ -404,7 +405,7 @@ class DataFrameSelfJoinSuite extends QueryTest with 
SharedSparkSession {
     assertAmbiguousSelfJoin(df12.join(df11, df11("x") === df12("y")))
 
     // Test for AttachDistributedSequence
-    val df13 = df1.withSequenceColumn("seq")
+    val df13 = df1.select(distributed_sequence_id().alias("seq"), col("*"))
     val df14 = df13.filter($"value" === "A2")
     assertAmbiguousSelfJoin(df13.join(df14, df13("key1") === df14("key2")))
     assertAmbiguousSelfJoin(df14.join(df13, df13("key1") === df14("key2")))
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index b1c41033fd76..9bfbdda33c36 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -29,6 +29,7 @@ import org.scalatest.matchers.should.Matchers._
 
 import org.apache.spark.SparkException
 import org.apache.spark.api.python.PythonEvalType
+import org.apache.spark.sql.api.python.PythonSQLUtils.distributed_sequence_id
 import org.apache.spark.sql.catalyst.{InternalRow, TableIdentifier}
 import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation
 import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
@@ -2316,7 +2317,8 @@ class DataFrameSuite extends QueryTest
   }
 
   test("SPARK-36338: DataFrame.withSequenceColumn should append unique 
sequence IDs") {
-    val ids = 
spark.range(10).repartition(5).withSequenceColumn("default_index")
+    val ids = spark.range(10).repartition(5).select(
+      distributed_sequence_id().alias("default_index"), col("id"))
     assert(ids.collect().map(_.getLong(0)).toSet === Range(0, 10).toSet)
     assert(ids.take(5).map(_.getLong(0)).toSet === Range(0, 5).toSet)
   }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to