This is an automated email from the ASF dual-hosted git repository.

dongjoon pushed a commit to branch branch-3.5
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.5 by this push:
     new a762f3175fcd [SPARK-48184][PYTHON][CONNECT] Always set the seed of `Dataframe.sample` in Client side
a762f3175fcd is described below

commit a762f3175fcdb7b069faa0c2bfce93d295cb1f10
Author: Ruifeng Zheng <ruife...@apache.org>
AuthorDate: Wed May 8 07:44:22 2024 -0700

    [SPARK-48184][PYTHON][CONNECT] Always set the seed of `Dataframe.sample` in Client side
    
    ### What changes were proposed in this pull request?
    Always set the seed of `DataFrame.sample` on the client side.
    
    ### Why are the changes needed?
    Bug fix
    
    If the seed is not set on the client, it will be set on the server side with a random int:
    
    
https://github.com/apache/spark/blob/c4df12cc884cddefcfcf8324b4d7b9349fb4f6a0/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala#L386
    
    which causes inconsistent results across multiple executions.
    
    In Spark Classic:
    ```
    In [1]: df = spark.range(10000).sample(0.1)
    
    In [2]: [df.count() for i in range(10)]
    Out[2]: [1006, 1006, 1006, 1006, 1006, 1006, 1006, 1006, 1006, 1006]
    ```
    
    In Spark Connect:
    
    before:
    ```
    In [1]: df = spark.range(10000).sample(0.1)
    
    In [2]: [df.count() for i in range(10)]
    Out[2]: [969, 1005, 958, 996, 987, 1026, 991, 1020, 1012, 979]
    ```
    
    after:
    ```
    In [1]: df = spark.range(10000).sample(0.1)
    
    In [2]: [df.count() for i in range(10)]
    Out[2]: [1032, 1032, 1032, 1032, 1032, 1032, 1032, 1032, 1032, 1032]
    ```
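    
    Conceptually, the fix pins the seed when the sample plan is built on the client, so every execution of the same plan reuses it. A minimal standalone sketch of that pattern (the `build_sample_plan` helper is hypothetical, not the actual Spark Connect API; the real change is the one-liner in `python/pyspark/sql/connect/dataframe.py` below):
    
    ```
    import random
    import sys
    
    def build_sample_plan(fraction, seed=None):
        # Hypothetical helper: pin the seed once at plan-construction time on
        # the client; if the caller did not supply one, draw a random value so
        # that every re-execution of this plan reuses the same seed.
        seed = int(seed) if seed is not None else random.randint(0, sys.maxsize)
        return {"fraction": fraction, "seed": seed}
    
    plan = build_sample_plan(0.1)
    print(plan["seed"])  # stays fixed for the lifetime of this plan
    ```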
    
    ### Does this PR introduce _any_ user-facing change?
    Yes, bug fix.
    
    ### How was this patch tested?
    CI
    
    ### Was this patch authored or co-authored using generative AI tooling?
    No.
    
    Closes #46456 from zhengruifeng/py_connect_sample_seed.
    
    Authored-by: Ruifeng Zheng <ruife...@apache.org>
    Signed-off-by: Dongjoon Hyun <dh...@apple.com>
    (cherry picked from commit 47afe77242abf639a1d6966ce60cfd170a9d7d20)
    Signed-off-by: Dongjoon Hyun <dh...@apple.com>
---
 python/pyspark/sql/connect/dataframe.py               | 2 +-
 python/pyspark/sql/tests/connect/test_connect_plan.py | 2 +-
 python/pyspark/sql/tests/test_dataframe.py            | 5 +++++
 3 files changed, 7 insertions(+), 2 deletions(-)

diff --git a/python/pyspark/sql/connect/dataframe.py b/python/pyspark/sql/connect/dataframe.py
index ff6191642025..6f23a15fb4ad 100644
--- a/python/pyspark/sql/connect/dataframe.py
+++ b/python/pyspark/sql/connect/dataframe.py
@@ -687,7 +687,7 @@ class DataFrame:
         if withReplacement is None:
             withReplacement = False
 
-        seed = int(seed) if seed is not None else None
+        seed = int(seed) if seed is not None else random.randint(0, sys.maxsize)
 
         return DataFrame.withPlan(
             plan.Sample(
diff --git a/python/pyspark/sql/tests/connect/test_connect_plan.py b/python/pyspark/sql/tests/connect/test_connect_plan.py
index c39fb6be24cd..88ef37511a66 100644
--- a/python/pyspark/sql/tests/connect/test_connect_plan.py
+++ b/python/pyspark/sql/tests/connect/test_connect_plan.py
@@ -430,7 +430,7 @@ class SparkConnectPlanTests(PlanOnlyTestFixture):
         self.assertEqual(plan.root.sample.lower_bound, 0.0)
         self.assertEqual(plan.root.sample.upper_bound, 0.3)
         self.assertEqual(plan.root.sample.with_replacement, False)
-        self.assertEqual(plan.root.sample.HasField("seed"), False)
+        self.assertEqual(plan.root.sample.HasField("seed"), True)
         self.assertEqual(plan.root.sample.deterministic_order, False)
 
         plan = (
diff --git a/python/pyspark/sql/tests/test_dataframe.py b/python/pyspark/sql/tests/test_dataframe.py
index 5907c8c09fb4..887648018cf3 100644
--- a/python/pyspark/sql/tests/test_dataframe.py
+++ b/python/pyspark/sql/tests/test_dataframe.py
@@ -1045,6 +1045,11 @@ class DataFrameTestsMixin:
             IllegalArgumentException, lambda: self.spark.range(1).sample(-1.0).count()
         )
 
+    def test_sample_with_random_seed(self):
+        df = self.spark.range(10000).sample(0.1)
+        cnts = [df.count() for i in range(10)]
+        self.assertEqual(1, len(set(cnts)))
+
     def test_toDF_with_string(self):
         df = self.spark.createDataFrame([("John", 30), ("Alice", 25), ("Bob", 28)])
         data = [("John", 30), ("Alice", 25), ("Bob", 28)]


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org
