This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 940946515bd [SPARK-41440][CONNECT][PYTHON] Implement `DataFrame.randomSplit`
940946515bd is described below

commit 940946515bd199930051be89f9fd557a35f2af0d
Author: Jiaan Geng <belie...@163.com>
AuthorDate: Wed Dec 21 11:31:00 2022 +0900

    [SPARK-41440][CONNECT][PYTHON] Implement `DataFrame.randomSplit`
    
    ### What changes were proposed in this pull request?
    Implement `DataFrame.randomSplit` with a proto message.

    Implement `DataFrame.randomSplit` for the Scala API.
    Implement `DataFrame.randomSplit` for the Python API.
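
    For illustration only (not part of the patch), a minimal usage sketch on the
    Python client; `spark` is assumed to be an existing Spark Connect session and
    "tbl" an existing table:

        # Hypothetical usage; the session and table name are assumptions.
        df = spark.read.table("tbl")
        train, validation, test = df.randomSplit([0.6, 0.2, 0.2], seed=42)
        # Each split is a DataFrame backed by a disjoint Sample of the input;
        # weights are normalized if they do not sum to 1.0.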
    
    ### Why are the changes needed?
    For Connect API coverage.
    
    ### Does this PR introduce _any_ user-facing change?
    'No'. New API
    
    ### How was this patch tested?
    New test cases.
    
    Closes #39017 from beliefer/SPARK-41440.
    
    Authored-by: Jiaan Geng <belie...@163.com>
    Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
---
 .../main/protobuf/spark/connect/relations.proto    |  4 ++
 .../org/apache/spark/sql/connect/dsl/package.scala | 34 ++++++++++
 .../sql/connect/planner/SparkConnectPlanner.scala  | 24 +++++++-
 .../connect/planner/SparkConnectProtoSuite.scala   | 16 +++++
 python/pyspark/sql/connect/dataframe.py            | 57 +++++++++++++++++
 python/pyspark/sql/connect/plan.py                 |  3 +
 python/pyspark/sql/connect/proto/relations_pb2.py  | 72 +++++++++++-----------
 python/pyspark/sql/connect/proto/relations_pb2.pyi | 18 ++++++
 .../sql/tests/connect/test_connect_basic.py        | 15 +++++
 .../sql/tests/connect/test_connect_plan_only.py    | 39 ++++++++++++
 10 files changed, 244 insertions(+), 38 deletions(-)

diff --git a/connector/connect/common/src/main/protobuf/spark/connect/relations.proto b/connector/connect/common/src/main/protobuf/spark/connect/relations.proto
index 2f83db1176a..42471821634 100644
--- a/connector/connect/common/src/main/protobuf/spark/connect/relations.proto
+++ b/connector/connect/common/src/main/protobuf/spark/connect/relations.proto
@@ -317,6 +317,10 @@ message Sample {
 
   // (Optional) The random seed.
   optional int64 seed = 5;
+
+  // (Optional) Explicitly sort the underlying plan to make the ordering deterministic.
+  // This flag is only used when randomly splitting a DataFrame with the provided weights.
+  optional bool force_stable_sort = 6;
 }
 
 // Relation of type [[Range]] that generates a sequence of integers.
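
As a rough sketch (illustrative, not part of the patch) of how the new optional
field can be exercised through the Python bindings regenerated in this change:

    # Assumes the regenerated bindings from this patch are importable.
    import pyspark.sql.connect.proto as proto

    sample = proto.Sample(
        lower_bound=0.0,
        upper_bound=0.5,
        with_replacement=False,
        seed=1,
        force_stable_sort=True,  # the new field added above
    )
    assert sample.HasField("force_stable_sort") and sample.force_stable_sort
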
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala
index 8211dc21bde..bce8d390fcb 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala
@@ -27,6 +27,7 @@ import org.apache.spark.connect.proto.SetOperation.SetOpType
 import org.apache.spark.sql.SaveMode
 import org.apache.spark.sql.connect.planner.DataTypeProtoConverter
 import org.apache.spark.sql.connect.planner.LiteralValueProtoConverter.toConnectProtoValue
+import org.apache.spark.util.Utils
 
 /**
  * A collection of implicit conversions that create a DSL for constructing connect protos.
@@ -775,6 +776,39 @@ package object dsl {
           valueColumnName: String): Relation =
         unpivot(ids, variableColumnName, valueColumnName)
 
+      def randomSplit(weights: Array[Double], seed: Long): Array[Relation] = {
+        require(
+          weights.forall(_ >= 0),
+          s"Weights must be nonnegative, but got ${weights.mkString("[", ",", 
"]")}")
+        require(
+          weights.sum > 0,
+          s"Sum of weights must be positive, but got ${weights.mkString("[", 
",", "]")}")
+
+        val sum = weights.toSeq.sum
+        val normalizedCumWeights = weights.map(_ / sum).scanLeft(0.0d)(_ + _)
+        normalizedCumWeights
+          .sliding(2)
+          .map { x =>
+            Relation
+              .newBuilder()
+              .setSample(
+                Sample
+                  .newBuilder()
+                  .setInput(logicalPlan)
+                  .setLowerBound(x(0))
+                  .setUpperBound(x(1))
+                  .setWithReplacement(false)
+                  .setSeed(seed)
+                  .setForceStableSort(true)
+                  .build())
+              .build()
+          }
+          .toArray
+      }
+
+      def randomSplit(weights: Array[Double]): Array[Relation] =
+        randomSplit(weights, Utils.random.nextLong)
+
       private def createSetOperation(
           left: Relation,
           right: Relation,
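
For readers following the arithmetic above: the weights are normalized, turned
into cumulative bounds, and each adjacent pair of bounds becomes one Sample
relation. A small illustrative sketch (values match the tests further down):

    # Illustrative only: weights [1, 2, 3] -> cumulative sampling bounds.
    weights = [1.0, 2.0, 3.0]
    total = sum(weights)                       # 6.0
    bounds = [0.0]
    for w in weights:
        bounds.append(bounds[-1] + w / total)  # [0.0, 1/6, 1/2, 1.0]
    intervals = list(zip(bounds, bounds[1:]))
    # [(0.0, 0.1666...), (0.1666..., 0.5), (0.5, 1.0)] -- one Sample per interval
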
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index 9fe9acd354d..cad6c3c5c61 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -33,7 +33,7 @@ import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.optimizer.CombineUnions
 import org.apache.spark.sql.catalyst.parser.{CatalystSqlParser, ParseException, ParserUtils}
 import org.apache.spark.sql.catalyst.plans.{logical, Cross, FullOuter, Inner, JoinType, LeftAnti, LeftOuter, LeftSemi, RightOuter, UsingJoin}
-import org.apache.spark.sql.catalyst.plans.logical.{Deduplicate, Except, Intersect, LocalRelation, LogicalPlan, Sample, SubqueryAlias, Union, Unpivot, UnresolvedHint}
+import org.apache.spark.sql.catalyst.plans.logical.{Deduplicate, Except, Intersect, LocalRelation, LogicalPlan, Sample, Sort, SubqueryAlias, Union, Unpivot, UnresolvedHint}
 import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
 import org.apache.spark.sql.connect.planner.LiteralValueProtoConverter.{toCatalystExpression, toCatalystValue}
 import org.apache.spark.sql.errors.QueryCompilationErrors
@@ -132,12 +132,32 @@ class SparkConnectPlanner(session: SparkSession) {
    * wrap such fields into proto messages.
    */
   private def transformSample(rel: proto.Sample): LogicalPlan = {
+    val input = Dataset.ofRows(session, transformRelation(rel.getInput))
+    val plan = if (rel.getForceStableSort) {
+      // It is possible that the underlying dataframe doesn't guarantee the ordering of rows in its
+      // constituent partitions each time a split is materialized which could result in
+      // overlapping splits. To prevent this, we explicitly sort each input partition to make the
+      // ordering deterministic. Note that MapTypes cannot be sorted and are explicitly pruned out
+      // from the sort order.
+      val sortOrder = input.logicalPlan.output
+        .filter(attr => RowOrdering.isOrderable(attr.dataType))
+        .map(SortOrder(_, Ascending))
+      if (sortOrder.nonEmpty) {
+        Sort(sortOrder, global = false, input.logicalPlan)
+      } else {
+        input.logicalPlan
+      }
+    } else {
+      input.cache()
+      input.logicalPlan
+    }
+
     Sample(
       rel.getLowerBound,
       rel.getUpperBound,
       rel.getWithReplacement,
       if (rel.hasSeed) rel.getSeed else Utils.random.nextLong,
-      transformRelation(rel.getInput))
+      plan)
   }
 
   private def transformRepartition(rel: proto.Repartition): LogicalPlan = {
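
A toy sketch (not Spark code) of why the deterministic ordering matters: with a
fixed seed each row effectively draws one uniform value and falls into exactly
one [lower, upper) interval, so the splits partition the data only if the rows,
and hence their draws, are identical every time the input is re-evaluated:

    # Illustrative simulation only; names and values are made up.
    import random

    rows = list(range(10))
    rng = random.Random(1)                     # fixed seed, like Sample.seed
    draw = {r: rng.random() for r in rows}     # one uniform draw per row
    bins = [(0.0, 1 / 6), (1 / 6, 0.5), (0.5, 1.0)]
    splits = [[r for r in rows if lo <= draw[r] < hi] for lo, hi in bins]
    assert sum(len(s) for s in splits) == len(rows)   # disjoint and exhaustive
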
diff --git a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala
index fcd0ac49fb2..354b7693dab 100644
--- a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala
+++ b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala
@@ -598,6 +598,22 @@ class SparkConnectProtoSuite extends PlanTest with SparkConnectPlanTest {
     comparePlans(connectPlan1, sparkPlan1)
   }
 
+  test("Test RandomSplit") {
+    val splitRelations0 = connectTestRelation.randomSplit(Array[Double](1, 2, 3), 1)
+    val splits0 = sparkTestRelation.randomSplit(Array[Double](1, 2, 3), 1)
+    assert(splitRelations0.length == splits0.length)
+    splitRelations0.zip(splits0).foreach { case (connectPlan, sparkPlan) =>
+      comparePlans(connectPlan, sparkPlan)
+    }
+
+    val splitRelations1 = connectTestRelation.randomSplit(Array[Double](1, 2, 3))
+    val splits1 = sparkTestRelation.randomSplit(Array[Double](1, 2, 3))
+    assert(splitRelations1.length == splits1.length)
+    splitRelations1.zip(splits1).foreach { case (connectPlan, sparkPlan) =>
+      comparePlans(connectPlan, sparkPlan)
+    }
+  }
+
   private def createLocalRelationProtoByAttributeReferences(
       attrs: Seq[AttributeReference]): proto.Relation = {
     val localRelationBuilder = proto.LocalRelation.newBuilder()
diff --git a/python/pyspark/sql/connect/dataframe.py b/python/pyspark/sql/connect/dataframe.py
index 547ed1d110a..8e19b8892b2 100644
--- a/python/pyspark/sql/connect/dataframe.py
+++ b/python/pyspark/sql/connect/dataframe.py
@@ -30,6 +30,8 @@ from typing import (
     Type,
 )
 
+import sys
+import random
 import pandas
 import warnings
 from collections.abc import Iterable
@@ -918,6 +920,61 @@ class DataFrame(object):
             session=self._session,
         )
 
+    def randomSplit(
+        self,
+        weights: List[float],
+        seed: Optional[int] = None,
+    ) -> List["DataFrame"]:
+        """Randomly splits this :class:`DataFrame` with the provided weights.
+
+        .. versionadded:: 3.4.0
+
+        Parameters
+        ----------
+        weights : list
+            list of doubles as weights with which to split the :class:`DataFrame`.
+            Weights will be normalized if they don't sum up to 1.0.
+        seed : int, optional
+            The seed for sampling.
+
+        Returns
+        -------
+        list
+            List of DataFrames.
+        """
+        for w in weights:
+            if w < 0.0:
+                raise ValueError("Weights must be positive. Found weight 
value: %s" % w)
+        seed = seed if seed is not None else random.randint(0, sys.maxsize)
+        total = sum(weights)
+        if total <= 0:
+            raise ValueError("Sum of weights must be positive, but got: %s" % 
w)
+        proportions = list(map(lambda x: x / total, weights))
+        normalizedCumWeights = [0.0]
+        for v in proportions:
+            normalizedCumWeights.append(normalizedCumWeights[-1] + v)
+        j = 1
+        length = len(normalizedCumWeights)
+        splits = []
+        while j < length:
+            lowerBound = normalizedCumWeights[j - 1]
+            upperBound = normalizedCumWeights[j]
+            samplePlan = DataFrame.withPlan(
+                plan.Sample(
+                    child=self._plan,
+                    lower_bound=lowerBound,
+                    upper_bound=upperBound,
+                    with_replacement=False,
+                    seed=int(seed),
+                    force_stable_sort=True,
+                ),
+                session=self._session,
+            )
+            splits.append(samplePlan)
+            j += 1
+
+        return splits
+
     def show(self, n: int = 20, truncate: Union[bool, int] = True, vertical: bool = False) -> None:
         """
         Prints the first ``n`` rows to the console.
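
To see what the method above actually builds, a hedged inspection sketch that
mirrors the plan-only test later in this patch (`df` is a Connect DataFrame and
`client` stands in for its SparkConnectClient; both are assumptions here):

    # Illustrative only.
    splits = df.randomSplit([1.0, 2.0, 3.0], seed=1)
    p = splits[0]._plan.to_proto(client)
    print(p.root.sample.lower_bound, p.root.sample.upper_bound)  # 0.0 0.16666666666666666
    print(p.root.sample.force_stable_sort)                       # True
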
diff --git a/python/pyspark/sql/connect/plan.py b/python/pyspark/sql/connect/plan.py
index 27c8461491c..320ed07425f 100644
--- a/python/pyspark/sql/connect/plan.py
+++ b/python/pyspark/sql/connect/plan.py
@@ -642,12 +642,14 @@ class Sample(LogicalPlan):
         upper_bound: float,
         with_replacement: bool,
         seed: Optional[int],
+        force_stable_sort: bool = False,
     ) -> None:
         super().__init__(child)
         self.lower_bound = lower_bound
         self.upper_bound = upper_bound
         self.with_replacement = with_replacement
         self.seed = seed
+        self.force_stable_sort = force_stable_sort
 
     def plan(self, session: "SparkConnectClient") -> proto.Relation:
         assert self._child is not None
@@ -658,6 +660,7 @@ class Sample(LogicalPlan):
         plan.sample.with_replacement = self.with_replacement
         if self.seed is not None:
             plan.sample.seed = self.seed
+        plan.sample.force_stable_sort = self.force_stable_sort
         return plan
 
     def print(self, indent: int = 0) -> str:
diff --git a/python/pyspark/sql/connect/proto/relations_pb2.py b/python/pyspark/sql/connect/proto/relations_pb2.py
index 0ddda63aa42..f7ae79c3dfb 100644
--- a/python/pyspark/sql/connect/proto/relations_pb2.py
+++ b/python/pyspark/sql/connect/proto/relations_pb2.py
@@ -34,7 +34,7 @@ from pyspark.sql.connect.proto import types_pb2 as spark_dot_connect_dot_types__
 
 
 DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
-    b'\n\x1dspark/connect/relations.proto\x12\rspark.connect\x1a\x1fspark/connect/expressions.proto\x1a\x19spark/connect/types.proto"\xbf\x0e\n\x08Relation\x12\x35\n\x06\x63ommon\x18\x01 \x01(\x0b\x32\x1d.spark.connect.RelationCommonR\x06\x63ommon\x12)\n\x04read\x18\x02 \x01(\x0b\x32\x13.spark.connect.ReadH\x00R\x04read\x12\x32\n\x07project\x18\x03 \x01(\x0b\x32\x16.spark.connect.ProjectH\x00R\x07project\x12/\n\x06\x66ilter\x18\x04 \x01(\x0b\x32\x15.spark.connect.FilterH\x00R\x06\x66ilte [...]
+    b'\n\x1dspark/connect/relations.proto\x12\rspark.connect\x1a\x1fspark/connect/expressions.proto\x1a\x19spark/connect/types.proto"\xbf\x0e\n\x08Relation\x12\x35\n\x06\x63ommon\x18\x01 \x01(\x0b\x32\x1d.spark.connect.RelationCommonR\x06\x63ommon\x12)\n\x04read\x18\x02 \x01(\x0b\x32\x13.spark.connect.ReadH\x00R\x04read\x12\x32\n\x07project\x18\x03 \x01(\x0b\x32\x16.spark.connect.ProjectH\x00R\x07project\x12/\n\x06\x66ilter\x18\x04 \x01(\x0b\x32\x15.spark.connect.FilterH\x00R\x06\x66ilte [...]
 )
 
 
@@ -546,39 +546,39 @@ if _descriptor._USE_C_DESCRIPTORS == False:
     _LOCALRELATION._serialized_start = 4481
     _LOCALRELATION._serialized_end = 4618
     _SAMPLE._serialized_start = 4621
-    _SAMPLE._serialized_end = 4845
-    _RANGE._serialized_start = 4848
-    _RANGE._serialized_end = 4993
-    _SUBQUERYALIAS._serialized_start = 4995
-    _SUBQUERYALIAS._serialized_end = 5109
-    _REPARTITION._serialized_start = 5112
-    _REPARTITION._serialized_end = 5254
-    _SHOWSTRING._serialized_start = 5257
-    _SHOWSTRING._serialized_end = 5398
-    _STATSUMMARY._serialized_start = 5400
-    _STATSUMMARY._serialized_end = 5492
-    _STATDESCRIBE._serialized_start = 5494
-    _STATDESCRIBE._serialized_end = 5575
-    _STATCROSSTAB._serialized_start = 5577
-    _STATCROSSTAB._serialized_end = 5678
-    _NAFILL._serialized_start = 5681
-    _NAFILL._serialized_end = 5815
-    _NADROP._serialized_start = 5818
-    _NADROP._serialized_end = 5952
-    _NAREPLACE._serialized_start = 5955
-    _NAREPLACE._serialized_end = 6251
-    _NAREPLACE_REPLACEMENT._serialized_start = 6110
-    _NAREPLACE_REPLACEMENT._serialized_end = 6251
-    _RENAMECOLUMNSBYSAMELENGTHNAMES._serialized_start = 6253
-    _RENAMECOLUMNSBYSAMELENGTHNAMES._serialized_end = 6367
-    _RENAMECOLUMNSBYNAMETONAMEMAP._serialized_start = 6370
-    _RENAMECOLUMNSBYNAMETONAMEMAP._serialized_end = 6629
-    _RENAMECOLUMNSBYNAMETONAMEMAP_RENAMECOLUMNSMAPENTRY._serialized_start = 6562
-    _RENAMECOLUMNSBYNAMETONAMEMAP_RENAMECOLUMNSMAPENTRY._serialized_end = 6629
-    _WITHCOLUMNS._serialized_start = 6632
-    _WITHCOLUMNS._serialized_end = 6763
-    _HINT._serialized_start = 6766
-    _HINT._serialized_end = 6906
-    _UNPIVOT._serialized_start = 6909
-    _UNPIVOT._serialized_end = 7155
+    _SAMPLE._serialized_end = 4916
+    _RANGE._serialized_start = 4919
+    _RANGE._serialized_end = 5064
+    _SUBQUERYALIAS._serialized_start = 5066
+    _SUBQUERYALIAS._serialized_end = 5180
+    _REPARTITION._serialized_start = 5183
+    _REPARTITION._serialized_end = 5325
+    _SHOWSTRING._serialized_start = 5328
+    _SHOWSTRING._serialized_end = 5469
+    _STATSUMMARY._serialized_start = 5471
+    _STATSUMMARY._serialized_end = 5563
+    _STATDESCRIBE._serialized_start = 5565
+    _STATDESCRIBE._serialized_end = 5646
+    _STATCROSSTAB._serialized_start = 5648
+    _STATCROSSTAB._serialized_end = 5749
+    _NAFILL._serialized_start = 5752
+    _NAFILL._serialized_end = 5886
+    _NADROP._serialized_start = 5889
+    _NADROP._serialized_end = 6023
+    _NAREPLACE._serialized_start = 6026
+    _NAREPLACE._serialized_end = 6322
+    _NAREPLACE_REPLACEMENT._serialized_start = 6181
+    _NAREPLACE_REPLACEMENT._serialized_end = 6322
+    _RENAMECOLUMNSBYSAMELENGTHNAMES._serialized_start = 6324
+    _RENAMECOLUMNSBYSAMELENGTHNAMES._serialized_end = 6438
+    _RENAMECOLUMNSBYNAMETONAMEMAP._serialized_start = 6441
+    _RENAMECOLUMNSBYNAMETONAMEMAP._serialized_end = 6700
+    _RENAMECOLUMNSBYNAMETONAMEMAP_RENAMECOLUMNSMAPENTRY._serialized_start = 6633
+    _RENAMECOLUMNSBYNAMETONAMEMAP_RENAMECOLUMNSMAPENTRY._serialized_end = 6700
+    _WITHCOLUMNS._serialized_start = 6703
+    _WITHCOLUMNS._serialized_end = 6834
+    _HINT._serialized_start = 6837
+    _HINT._serialized_end = 6977
+    _UNPIVOT._serialized_start = 6980
+    _UNPIVOT._serialized_end = 7226
 # @@protoc_insertion_point(module_scope)
diff --git a/python/pyspark/sql/connect/proto/relations_pb2.pyi b/python/pyspark/sql/connect/proto/relations_pb2.pyi
index a2a6538b32d..99a584ab7b9 100644
--- a/python/pyspark/sql/connect/proto/relations_pb2.pyi
+++ b/python/pyspark/sql/connect/proto/relations_pb2.pyi
@@ -1161,6 +1161,7 @@ class Sample(google.protobuf.message.Message):
     UPPER_BOUND_FIELD_NUMBER: builtins.int
     WITH_REPLACEMENT_FIELD_NUMBER: builtins.int
     SEED_FIELD_NUMBER: builtins.int
+    FORCE_STABLE_SORT_FIELD_NUMBER: builtins.int
     @property
     def input(self) -> global___Relation:
         """(Required) Input relation for a Sample."""
@@ -1172,6 +1173,10 @@ class Sample(google.protobuf.message.Message):
     """(Optional) Whether to sample with replacement."""
     seed: builtins.int
     """(Optional) The random seed."""
+    force_stable_sort: builtins.bool
+    """(Optional) Explicitly sort the underlying plan to make the ordering 
deterministic.
+    This flag is only used to randomly splits DataFrame with the provided 
weights.
+    """
     def __init__(
         self,
         *,
@@ -1180,14 +1185,19 @@ class Sample(google.protobuf.message.Message):
         upper_bound: builtins.float = ...,
         with_replacement: builtins.bool | None = ...,
         seed: builtins.int | None = ...,
+        force_stable_sort: builtins.bool | None = ...,
     ) -> None: ...
     def HasField(
         self,
         field_name: typing_extensions.Literal[
+            "_force_stable_sort",
+            b"_force_stable_sort",
             "_seed",
             b"_seed",
             "_with_replacement",
             b"_with_replacement",
+            "force_stable_sort",
+            b"force_stable_sort",
             "input",
             b"input",
             "seed",
@@ -1199,10 +1209,14 @@ class Sample(google.protobuf.message.Message):
     def ClearField(
         self,
         field_name: typing_extensions.Literal[
+            "_force_stable_sort",
+            b"_force_stable_sort",
             "_seed",
             b"_seed",
             "_with_replacement",
             b"_with_replacement",
+            "force_stable_sort",
+            b"force_stable_sort",
             "input",
             b"input",
             "lower_bound",
@@ -1216,6 +1230,10 @@ class Sample(google.protobuf.message.Message):
         ],
     ) -> None: ...
     @typing.overload
+    def WhichOneof(
+        self, oneof_group: typing_extensions.Literal["_force_stable_sort", 
b"_force_stable_sort"]
+    ) -> typing_extensions.Literal["force_stable_sort"] | None: ...
+    @typing.overload
     def WhichOneof(
         self, oneof_group: typing_extensions.Literal["_seed", b"_seed"]
     ) -> typing_extensions.Literal["seed"] | None: ...
diff --git a/python/pyspark/sql/tests/connect/test_connect_basic.py b/python/pyspark/sql/tests/connect/test_connect_basic.py
index 3da0566a4e9..f08cb38edf3 100644
--- a/python/pyspark/sql/tests/connect/test_connect_basic.py
+++ b/python/pyspark/sql/tests/connect/test_connect_basic.py
@@ -816,6 +816,21 @@ class SparkConnectTests(SparkConnectSQLTestCase):
             .toPandas(),
         )
 
+    def test_random_split(self):
+        # SPARK-41440: test randomSplit(weights, seed).
+        relations = (
+            self.connect.read.table(self.tbl_name).filter("id > 
3").randomSplit([1.0, 2.0, 3.0], 2)
+        )
+        datasets = (
+            self.spark.read.table(self.tbl_name).filter("id > 
3").randomSplit([1.0, 2.0, 3.0], 2)
+        )
+
+        self.assertTrue(len(relations) == len(datasets))
+        i = 0
+        while i < len(relations):
+            self.assert_eq(relations[i].toPandas(), datasets[i].toPandas())
+            i += 1
+
     def test_with_columns(self):
         # SPARK-41256: test withColumn(s).
         self.assert_eq(
diff --git a/python/pyspark/sql/tests/connect/test_connect_plan_only.py b/python/pyspark/sql/tests/connect/test_connect_plan_only.py
index 703e27c46d6..66106fb7f2a 100644
--- a/python/pyspark/sql/tests/connect/test_connect_plan_only.py
+++ b/python/pyspark/sql/tests/connect/test_connect_plan_only.py
@@ -25,6 +25,7 @@ from pyspark.testing.connectutils import (
 if should_test_connect:
     import pyspark.sql.connect.proto as proto
     from pyspark.sql.connect.column import Column
+    from pyspark.sql.connect.dataframe import DataFrame
     from pyspark.sql.connect.plan import WriteOperation
     from pyspark.sql.connect.readwriter import DataFrameReader
     from pyspark.sql.connect.function_builder import UserDefinedFunction, udf
@@ -228,6 +229,42 @@ class SparkConnectTestsPlanOnly(PlanOnlyTestFixture):
         self.assertEqual(plan.root.unpivot.variable_column_name, "variable")
         self.assertEqual(plan.root.unpivot.value_column_name, "value")
 
+    def test_random_split(self):
+        # SPARK-41440: test randomSplit(weights, seed).
+        from typing import List
+
+        df = self.connect.readTable(table_name=self.tbl_name)
+
+        def checkRelations(relations: List["DataFrame"]):
+            self.assertTrue(len(relations) == 3)
+
+            plan = relations[0]._plan.to_proto(self.connect)
+            self.assertEqual(plan.root.sample.lower_bound, 0.0)
+            self.assertEqual(plan.root.sample.upper_bound, 0.16666666666666666)
+            self.assertEqual(plan.root.sample.with_replacement, False)
+            self.assertEqual(plan.root.sample.HasField("seed"), True)
+            self.assertEqual(plan.root.sample.force_stable_sort, True)
+
+            plan = relations[1]._plan.to_proto(self.connect)
+            self.assertEqual(plan.root.sample.lower_bound, 0.16666666666666666)
+            self.assertEqual(plan.root.sample.upper_bound, 0.5)
+            self.assertEqual(plan.root.sample.with_replacement, False)
+            self.assertEqual(plan.root.sample.HasField("seed"), True)
+            self.assertEqual(plan.root.sample.force_stable_sort, True)
+
+            plan = relations[2]._plan.to_proto(self.connect)
+            self.assertEqual(plan.root.sample.lower_bound, 0.5)
+            self.assertEqual(plan.root.sample.upper_bound, 1.0)
+            self.assertEqual(plan.root.sample.with_replacement, False)
+            self.assertEqual(plan.root.sample.HasField("seed"), True)
+            self.assertEqual(plan.root.sample.force_stable_sort, True)
+
+        relations = df.filter(df.col_name > 3).randomSplit([1.0, 2.0, 3.0], 1)
+        checkRelations(relations)
+
+        relations = df.filter(df.col_name > 3).randomSplit([1.0, 2.0, 3.0])
+        checkRelations(relations)
+
     def test_summary(self):
         df = self.connect.readTable(table_name=self.tbl_name)
        plan = df.filter(df.col_name > 3).summary()._plan.to_proto(self.connect)
@@ -281,6 +318,7 @@ class SparkConnectTestsPlanOnly(PlanOnlyTestFixture):
         self.assertEqual(plan.root.sample.upper_bound, 0.3)
         self.assertEqual(plan.root.sample.with_replacement, False)
         self.assertEqual(plan.root.sample.HasField("seed"), False)
+        self.assertEqual(plan.root.sample.force_stable_sort, False)
 
         plan = (
             df.filter(df.col_name > 3)
@@ -291,6 +329,7 @@ class SparkConnectTestsPlanOnly(PlanOnlyTestFixture):
         self.assertEqual(plan.root.sample.upper_bound, 0.4)
         self.assertEqual(plan.root.sample.with_replacement, True)
         self.assertEqual(plan.root.sample.seed, -1)
+        self.assertEqual(plan.root.sample.force_stable_sort, False)
 
     def test_sort(self):
         df = self.connect.readTable(table_name=self.tbl_name)

