This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
     new 940946515bd  [SPARK-41440][CONNECT][PYTHON] Implement `DataFrame.randomSplit`
940946515bd is described below

commit 940946515bd199930051be89f9fd557a35f2af0d
Author: Jiaan Geng <belie...@163.com>
AuthorDate: Wed Dec 21 11:31:00 2022 +0900

    [SPARK-41440][CONNECT][PYTHON] Implement `DataFrame.randomSplit`

    ### What changes were proposed in this pull request?
    Implement `DataFrame.randomSplit` with a proto message.
    Implement `DataFrame.randomSplit` for the Scala API.
    Implement `DataFrame.randomSplit` for the Python API.

    ### Why are the changes needed?
    For Connect API coverage.

    ### Does this PR introduce _any_ user-facing change?
    'No'. New API.

    ### How was this patch tested?
    New test cases.

    Closes #39017 from beliefer/SPARK-41440.

    Authored-by: Jiaan Geng <belie...@163.com>
    Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
---
 .../main/protobuf/spark/connect/relations.proto    |  4 ++
 .../org/apache/spark/sql/connect/dsl/package.scala | 34 ++++++++++
 .../sql/connect/planner/SparkConnectPlanner.scala  | 24 +++++++-
 .../connect/planner/SparkConnectProtoSuite.scala   | 16 +++++
 python/pyspark/sql/connect/dataframe.py            | 57 +++++++++++++++++
 python/pyspark/sql/connect/plan.py                 |  3 +
 python/pyspark/sql/connect/proto/relations_pb2.py  | 72 +++++++++++-----------
 python/pyspark/sql/connect/proto/relations_pb2.pyi | 18 ++++++
 .../sql/tests/connect/test_connect_basic.py        | 15 +++++
 .../sql/tests/connect/test_connect_plan_only.py    | 39 ++++++++++++
 10 files changed, 244 insertions(+), 38 deletions(-)

diff --git a/connector/connect/common/src/main/protobuf/spark/connect/relations.proto b/connector/connect/common/src/main/protobuf/spark/connect/relations.proto
index 2f83db1176a..42471821634 100644
--- a/connector/connect/common/src/main/protobuf/spark/connect/relations.proto
+++ b/connector/connect/common/src/main/protobuf/spark/connect/relations.proto
@@ -317,6 +317,10 @@ message Sample {
 
   // (Optional) The random seed.
   optional int64 seed = 5;
+
+  // (Optional) Explicitly sort the underlying plan to make the ordering deterministic.
+  // This flag is only used when randomly splitting a DataFrame with the provided weights.
+  optional bool force_stable_sort = 6;
 }
 
 // Relation of type [[Range]] that generates a sequence of integers.
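A note on the wire format: there is no dedicated RandomSplit message. A client-side randomSplit(weights) call is encoded as one Sample relation per split, whose [lower_bound, upper_bound) ranges partition [0.0, 1.0], sharing a single seed and setting the new force_stable_sort flag. A minimal sketch against the regenerated Python bindings (field values are illustrative; the required input relation is omitted for brevity):

    from pyspark.sql.connect.proto import relations_pb2

    # Two Sample messages standing in for randomSplit([1.0, 1.0]) with seed 42:
    # contiguous bounds, one shared seed, force_stable_sort set on both.
    first = relations_pb2.Sample(
        lower_bound=0.0, upper_bound=0.5, with_replacement=False,
        seed=42, force_stable_sort=True,
    )
    second = relations_pb2.Sample(
        lower_bound=0.5, upper_bound=1.0, with_replacement=False,
        seed=42, force_stable_sort=True,
    )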
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala
index 8211dc21bde..bce8d390fcb 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala
@@ -27,6 +27,7 @@ import org.apache.spark.connect.proto.SetOperation.SetOpType
 import org.apache.spark.sql.SaveMode
 import org.apache.spark.sql.connect.planner.DataTypeProtoConverter
 import org.apache.spark.sql.connect.planner.LiteralValueProtoConverter.toConnectProtoValue
+import org.apache.spark.util.Utils
 
 /**
  * A collection of implicit conversions that create a DSL for constructing connect protos.
@@ -775,6 +776,39 @@ package object dsl {
           valueColumnName: String): Relation =
         unpivot(ids, variableColumnName, valueColumnName)
 
+      def randomSplit(weights: Array[Double], seed: Long): Array[Relation] = {
+        require(
+          weights.forall(_ >= 0),
+          s"Weights must be nonnegative, but got ${weights.mkString("[", ",", "]")}")
+        require(
+          weights.sum > 0,
+          s"Sum of weights must be positive, but got ${weights.mkString("[", ",", "]")}")
+
+        val sum = weights.toSeq.sum
+        val normalizedCumWeights = weights.map(_ / sum).scanLeft(0.0d)(_ + _)
+        normalizedCumWeights
+          .sliding(2)
+          .map { x =>
+            Relation
+              .newBuilder()
+              .setSample(
+                Sample
+                  .newBuilder()
+                  .setInput(logicalPlan)
+                  .setLowerBound(x(0))
+                  .setUpperBound(x(1))
+                  .setWithReplacement(false)
+                  .setSeed(seed)
+                  .setForceStableSort(true)
+                  .build())
+              .build()
+          }
+          .toArray
+      }
+
+      def randomSplit(weights: Array[Double]): Array[Relation] =
+        randomSplit(weights, Utils.random.nextLong)
+
       private def createSetOperation(
           left: Relation,
           right: Relation,
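The bound computation above is worth spelling out: the weights are normalized by their sum, scanLeft produces the cumulative prefix sums starting at 0.0, and sliding(2) pairs adjacent bounds into one Sample per split. The same arithmetic in plain Python, shown here purely for illustration:

    from itertools import accumulate

    weights = [1.0, 2.0, 3.0]
    total = sum(weights)
    # scanLeft(0.0)(_ + _) over the normalized weights
    bounds = [0.0] + list(accumulate(w / total for w in weights))
    # sliding(2): adjacent pairs become (lower_bound, upper_bound) per split
    pairs = list(zip(bounds, bounds[1:]))
    assert pairs == [(0.0, 1.0 / 6.0), (1.0 / 6.0, 0.5), (0.5, 1.0)]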
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index 9fe9acd354d..cad6c3c5c61 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -33,7 +33,7 @@ import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.optimizer.CombineUnions
 import org.apache.spark.sql.catalyst.parser.{CatalystSqlParser, ParseException, ParserUtils}
 import org.apache.spark.sql.catalyst.plans.{logical, Cross, FullOuter, Inner, JoinType, LeftAnti, LeftOuter, LeftSemi, RightOuter, UsingJoin}
-import org.apache.spark.sql.catalyst.plans.logical.{Deduplicate, Except, Intersect, LocalRelation, LogicalPlan, Sample, SubqueryAlias, Union, Unpivot, UnresolvedHint}
+import org.apache.spark.sql.catalyst.plans.logical.{Deduplicate, Except, Intersect, LocalRelation, LogicalPlan, Sample, Sort, SubqueryAlias, Union, Unpivot, UnresolvedHint}
 import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
 import org.apache.spark.sql.connect.planner.LiteralValueProtoConverter.{toCatalystExpression, toCatalystValue}
 import org.apache.spark.sql.errors.QueryCompilationErrors
@@ -132,12 +132,32 @@ class SparkConnectPlanner(session: SparkSession) {
   * wrap such fields into proto messages.
   */
  private def transformSample(rel: proto.Sample): LogicalPlan = {
+    val input = Dataset.ofRows(session, transformRelation(rel.getInput))
+    val plan = if (rel.getForceStableSort) {
+      // It is possible that the underlying dataframe doesn't guarantee the ordering of rows in
+      // its constituent partitions each time a split is materialized, which could result in
+      // overlapping splits. To prevent this, we explicitly sort each input partition to make the
+      // ordering deterministic. Note that MapTypes cannot be sorted and are explicitly pruned out
+      // from the sort order.
+      val sortOrder = input.logicalPlan.output
+        .filter(attr => RowOrdering.isOrderable(attr.dataType))
+        .map(SortOrder(_, Ascending))
+      if (sortOrder.nonEmpty) {
+        Sort(sortOrder, global = false, input.logicalPlan)
+      } else {
+        input.logicalPlan
+      }
+    } else {
+      input.cache()
+      input.logicalPlan
+    }
+
     Sample(
       rel.getLowerBound,
       rel.getUpperBound,
       rel.getWithReplacement,
       if (rel.hasSeed) rel.getSeed else Utils.random.nextLong,
-      transformRelation(rel.getInput))
+      plan)
   }
 
   private def transformRepartition(rel: proto.Repartition): LogicalPlan = {
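Why the per-partition sort matters: each split issues an independent Sample scan over the same input, and Bernoulli range sampling assigns a row to a split based on the random draw at the row's position within its partition. If the scan order can differ between splits, a row may land in two splits or in none. A toy, pure-Python analogy of the failure mode (not Spark code, just the positional-sampling idea):

    import random

    rows = list(range(10))

    def sample(rows, lb, ub, seed):
        # One independent scan per split: keep a row when its positional
        # draw lands in [lb, ub), mirroring Bernoulli range sampling.
        rng = random.Random(seed)
        return [r for r in rows if lb <= rng.random() < ub]

    # Stable order across scans: the two halves partition the input exactly.
    a = sample(rows, 0.0, 0.5, seed=7)
    b = sample(rows, 0.5, 1.0, seed=7)
    assert sorted(a + b) == rows

    # If the second scan sees a different order, rows can be duplicated or
    # dropped across splits, which is what force_stable_sort prevents.
    c = sample(rows[::-1], 0.5, 1.0, seed=7)
    print(sorted(a + c))  # may contain duplicates and gaps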
diff --git a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala
index fcd0ac49fb2..354b7693dab 100644
--- a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala
+++ b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala
@@ -598,6 +598,22 @@ class SparkConnectProtoSuite extends PlanTest with SparkConnectPlanTest {
     comparePlans(connectPlan1, sparkPlan1)
   }
 
+  test("Test RandomSplit") {
+    val splitRelations0 = connectTestRelation.randomSplit(Array[Double](1, 2, 3), 1)
+    val splits0 = sparkTestRelation.randomSplit(Array[Double](1, 2, 3), 1)
+    assert(splitRelations0.length == splits0.length)
+    splitRelations0.zip(splits0).foreach { case (connectPlan, sparkPlan) =>
+      comparePlans(connectPlan, sparkPlan)
+    }
+
+    val splitRelations1 = connectTestRelation.randomSplit(Array[Double](1, 2, 3))
+    val splits1 = sparkTestRelation.randomSplit(Array[Double](1, 2, 3))
+    assert(splitRelations1.length == splits1.length)
+    splitRelations1.zip(splits1).foreach { case (connectPlan, sparkPlan) =>
+      comparePlans(connectPlan, sparkPlan)
+    }
+  }
+
   private def createLocalRelationProtoByAttributeReferences(
       attrs: Seq[AttributeReference]): proto.Relation = {
     val localRelationBuilder = proto.LocalRelation.newBuilder()
diff --git a/python/pyspark/sql/connect/dataframe.py b/python/pyspark/sql/connect/dataframe.py
index 547ed1d110a..8e19b8892b2 100644
--- a/python/pyspark/sql/connect/dataframe.py
+++ b/python/pyspark/sql/connect/dataframe.py
@@ -30,6 +30,8 @@ from typing import (
     Type,
 )
 
+import sys
+import random
 import pandas
 import warnings
 from collections.abc import Iterable
@@ -918,6 +920,61 @@ class DataFrame(object):
             session=self._session,
         )
 
+    def randomSplit(
+        self,
+        weights: List[float],
+        seed: Optional[int] = None,
+    ) -> List["DataFrame"]:
+        """Randomly splits this :class:`DataFrame` with the provided weights.
+
+        .. versionadded:: 3.4.0
+
+        Parameters
+        ----------
+        weights : list
+            list of doubles as weights with which to split the :class:`DataFrame`.
+            Weights will be normalized if they don't sum up to 1.0.
+        seed : int, optional
+            The seed for sampling.
+
+        Returns
+        -------
+        list
+            List of DataFrames.
+        """
+        for w in weights:
+            if w < 0.0:
+                raise ValueError("Weights must be positive. Found weight value: %s" % w)
+        seed = seed if seed is not None else random.randint(0, sys.maxsize)
+        total = sum(weights)
+        if total <= 0:
+            raise ValueError("Sum of weights must be positive, but got: %s" % total)
+        proportions = list(map(lambda x: x / total, weights))
+        normalizedCumWeights = [0.0]
+        for v in proportions:
+            normalizedCumWeights.append(normalizedCumWeights[-1] + v)
+        j = 1
+        length = len(normalizedCumWeights)
+        splits = []
+        while j < length:
+            lowerBound = normalizedCumWeights[j - 1]
+            upperBound = normalizedCumWeights[j]
+            samplePlan = DataFrame.withPlan(
+                plan.Sample(
+                    child=self._plan,
+                    lower_bound=lowerBound,
+                    upper_bound=upperBound,
+                    with_replacement=False,
+                    seed=int(seed),
+                    force_stable_sort=True,
+                ),
+                session=self._session,
+            )
+            splits.append(samplePlan)
+            j += 1
+
+        return splits
+
     def show(self, n: int = 20, truncate: Union[bool, int] = True, vertical: bool = False) -> None:
         """
         Prints the first ``n`` rows to the console.
diff --git a/python/pyspark/sql/connect/plan.py b/python/pyspark/sql/connect/plan.py
index 27c8461491c..320ed07425f 100644
--- a/python/pyspark/sql/connect/plan.py
+++ b/python/pyspark/sql/connect/plan.py
@@ -642,12 +642,14 @@ class Sample(LogicalPlan):
         upper_bound: float,
         with_replacement: bool,
         seed: Optional[int],
+        force_stable_sort: bool = False,
     ) -> None:
         super().__init__(child)
         self.lower_bound = lower_bound
         self.upper_bound = upper_bound
         self.with_replacement = with_replacement
         self.seed = seed
+        self.force_stable_sort = force_stable_sort
 
     def plan(self, session: "SparkConnectClient") -> proto.Relation:
         assert self._child is not None
@@ -658,6 +660,7 @@ class Sample(LogicalPlan):
         plan.sample.with_replacement = self.with_replacement
         if self.seed is not None:
             plan.sample.seed = self.seed
+        plan.sample.force_stable_sort = self.force_stable_sort
         return plan
 
     def print(self, indent: int = 0) -> str:
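With the planner and client pieces in place, the Connect DataFrame behaves like the classic API. A hedged usage sketch (assumes a Spark Connect session `spark` and an existing table `t`, both placeholders):

    df = spark.read.table("t")
    # Weights are normalized: 1+2+3 -> expected fractions 1/6, 1/3, 1/2.
    small, medium, large = df.randomSplit([1.0, 2.0, 3.0], seed=42)
    # Because force_stable_sort makes each Sample scan deterministic,
    # the splits should be disjoint and cover every row exactly once.
    assert small.count() + medium.count() + large.count() == df.count()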
) @@ -546,39 +546,39 @@ if _descriptor._USE_C_DESCRIPTORS == False: _LOCALRELATION._serialized_start = 4481 _LOCALRELATION._serialized_end = 4618 _SAMPLE._serialized_start = 4621 - _SAMPLE._serialized_end = 4845 - _RANGE._serialized_start = 4848 - _RANGE._serialized_end = 4993 - _SUBQUERYALIAS._serialized_start = 4995 - _SUBQUERYALIAS._serialized_end = 5109 - _REPARTITION._serialized_start = 5112 - _REPARTITION._serialized_end = 5254 - _SHOWSTRING._serialized_start = 5257 - _SHOWSTRING._serialized_end = 5398 - _STATSUMMARY._serialized_start = 5400 - _STATSUMMARY._serialized_end = 5492 - _STATDESCRIBE._serialized_start = 5494 - _STATDESCRIBE._serialized_end = 5575 - _STATCROSSTAB._serialized_start = 5577 - _STATCROSSTAB._serialized_end = 5678 - _NAFILL._serialized_start = 5681 - _NAFILL._serialized_end = 5815 - _NADROP._serialized_start = 5818 - _NADROP._serialized_end = 5952 - _NAREPLACE._serialized_start = 5955 - _NAREPLACE._serialized_end = 6251 - _NAREPLACE_REPLACEMENT._serialized_start = 6110 - _NAREPLACE_REPLACEMENT._serialized_end = 6251 - _RENAMECOLUMNSBYSAMELENGTHNAMES._serialized_start = 6253 - _RENAMECOLUMNSBYSAMELENGTHNAMES._serialized_end = 6367 - _RENAMECOLUMNSBYNAMETONAMEMAP._serialized_start = 6370 - _RENAMECOLUMNSBYNAMETONAMEMAP._serialized_end = 6629 - _RENAMECOLUMNSBYNAMETONAMEMAP_RENAMECOLUMNSMAPENTRY._serialized_start = 6562 - _RENAMECOLUMNSBYNAMETONAMEMAP_RENAMECOLUMNSMAPENTRY._serialized_end = 6629 - _WITHCOLUMNS._serialized_start = 6632 - _WITHCOLUMNS._serialized_end = 6763 - _HINT._serialized_start = 6766 - _HINT._serialized_end = 6906 - _UNPIVOT._serialized_start = 6909 - _UNPIVOT._serialized_end = 7155 + _SAMPLE._serialized_end = 4916 + _RANGE._serialized_start = 4919 + _RANGE._serialized_end = 5064 + _SUBQUERYALIAS._serialized_start = 5066 + _SUBQUERYALIAS._serialized_end = 5180 + _REPARTITION._serialized_start = 5183 + _REPARTITION._serialized_end = 5325 + _SHOWSTRING._serialized_start = 5328 + _SHOWSTRING._serialized_end = 5469 + _STATSUMMARY._serialized_start = 5471 + _STATSUMMARY._serialized_end = 5563 + _STATDESCRIBE._serialized_start = 5565 + _STATDESCRIBE._serialized_end = 5646 + _STATCROSSTAB._serialized_start = 5648 + _STATCROSSTAB._serialized_end = 5749 + _NAFILL._serialized_start = 5752 + _NAFILL._serialized_end = 5886 + _NADROP._serialized_start = 5889 + _NADROP._serialized_end = 6023 + _NAREPLACE._serialized_start = 6026 + _NAREPLACE._serialized_end = 6322 + _NAREPLACE_REPLACEMENT._serialized_start = 6181 + _NAREPLACE_REPLACEMENT._serialized_end = 6322 + _RENAMECOLUMNSBYSAMELENGTHNAMES._serialized_start = 6324 + _RENAMECOLUMNSBYSAMELENGTHNAMES._serialized_end = 6438 + _RENAMECOLUMNSBYNAMETONAMEMAP._serialized_start = 6441 + _RENAMECOLUMNSBYNAMETONAMEMAP._serialized_end = 6700 + _RENAMECOLUMNSBYNAMETONAMEMAP_RENAMECOLUMNSMAPENTRY._serialized_start = 6633 + _RENAMECOLUMNSBYNAMETONAMEMAP_RENAMECOLUMNSMAPENTRY._serialized_end = 6700 + _WITHCOLUMNS._serialized_start = 6703 + _WITHCOLUMNS._serialized_end = 6834 + _HINT._serialized_start = 6837 + _HINT._serialized_end = 6977 + _UNPIVOT._serialized_start = 6980 + _UNPIVOT._serialized_end = 7226 # @@protoc_insertion_point(module_scope) diff --git a/python/pyspark/sql/connect/proto/relations_pb2.pyi b/python/pyspark/sql/connect/proto/relations_pb2.pyi index a2a6538b32d..99a584ab7b9 100644 --- a/python/pyspark/sql/connect/proto/relations_pb2.pyi +++ b/python/pyspark/sql/connect/proto/relations_pb2.pyi @@ -1161,6 +1161,7 @@ class Sample(google.protobuf.message.Message): UPPER_BOUND_FIELD_NUMBER: builtins.int 
diff --git a/python/pyspark/sql/connect/proto/relations_pb2.pyi b/python/pyspark/sql/connect/proto/relations_pb2.pyi
index a2a6538b32d..99a584ab7b9 100644
--- a/python/pyspark/sql/connect/proto/relations_pb2.pyi
+++ b/python/pyspark/sql/connect/proto/relations_pb2.pyi
@@ -1161,6 +1161,7 @@ class Sample(google.protobuf.message.Message):
     UPPER_BOUND_FIELD_NUMBER: builtins.int
     WITH_REPLACEMENT_FIELD_NUMBER: builtins.int
     SEED_FIELD_NUMBER: builtins.int
+    FORCE_STABLE_SORT_FIELD_NUMBER: builtins.int
     @property
     def input(self) -> global___Relation:
         """(Required) Input relation for a Sample."""
@@ -1172,6 +1173,10 @@ class Sample(google.protobuf.message.Message):
     """(Optional) Whether to sample with replacement."""
     seed: builtins.int
     """(Optional) The random seed."""
+    force_stable_sort: builtins.bool
+    """(Optional) Explicitly sort the underlying plan to make the ordering deterministic.
+    This flag is only used when randomly splitting a DataFrame with the provided weights.
+    """
     def __init__(
         self,
         *,
@@ -1180,14 +1185,19 @@ class Sample(google.protobuf.message.Message):
         upper_bound: builtins.float = ...,
         with_replacement: builtins.bool | None = ...,
         seed: builtins.int | None = ...,
+        force_stable_sort: builtins.bool | None = ...,
     ) -> None: ...
     def HasField(
         self,
         field_name: typing_extensions.Literal[
+            "_force_stable_sort",
+            b"_force_stable_sort",
             "_seed",
             b"_seed",
             "_with_replacement",
            b"_with_replacement",
+            "force_stable_sort",
+            b"force_stable_sort",
             "input",
             b"input",
             "seed",
@@ -1199,10 +1209,14 @@ class Sample(google.protobuf.message.Message):
     def ClearField(
         self,
         field_name: typing_extensions.Literal[
+            "_force_stable_sort",
+            b"_force_stable_sort",
             "_seed",
             b"_seed",
             "_with_replacement",
             b"_with_replacement",
+            "force_stable_sort",
+            b"force_stable_sort",
             "input",
             b"input",
             "lower_bound",
@@ -1216,6 +1230,10 @@ class Sample(google.protobuf.message.Message):
         ],
     ) -> None: ...
     @typing.overload
+    def WhichOneof(
+        self, oneof_group: typing_extensions.Literal["_force_stable_sort", b"_force_stable_sort"]
+    ) -> typing_extensions.Literal["force_stable_sort"] | None: ...
+    @typing.overload
     def WhichOneof(
         self, oneof_group: typing_extensions.Literal["_seed", b"_seed"]
     ) -> typing_extensions.Literal["seed"] | None: ...
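One detail visible in the regenerated stubs: `force_stable_sort` is a proto3 `optional` field, so its presence is tracked through a synthesized `_force_stable_sort` oneof, just like `seed`. A small sketch against the generated bindings (standalone message; the `input` field is omitted):

    from pyspark.sql.connect.proto import relations_pb2

    s = relations_pb2.Sample(lower_bound=0.0, upper_bound=0.5, force_stable_sort=True)
    assert s.HasField("force_stable_sort")  # explicitly set, so present
    assert not s.HasField("seed")           # unset optional field
    assert s.WhichOneof("_seed") is None    # synthesized oneof is empty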
diff --git a/python/pyspark/sql/tests/connect/test_connect_basic.py b/python/pyspark/sql/tests/connect/test_connect_basic.py
index 3da0566a4e9..f08cb38edf3 100644
--- a/python/pyspark/sql/tests/connect/test_connect_basic.py
+++ b/python/pyspark/sql/tests/connect/test_connect_basic.py
@@ -816,6 +816,21 @@ class SparkConnectTests(SparkConnectSQLTestCase):
             .toPandas(),
         )
 
+    def test_random_split(self):
+        # SPARK-41440: test randomSplit(weights, seed).
+        relations = (
+            self.connect.read.table(self.tbl_name).filter("id > 3").randomSplit([1.0, 2.0, 3.0], 2)
+        )
+        datasets = (
+            self.spark.read.table(self.tbl_name).filter("id > 3").randomSplit([1.0, 2.0, 3.0], 2)
+        )
+
+        self.assertTrue(len(relations) == len(datasets))
+        i = 0
+        while i < len(relations):
+            self.assert_eq(relations[i].toPandas(), datasets[i].toPandas())
+            i += 1
+
     def test_with_columns(self):
         # SPARK-41256: test withColumn(s).
         self.assert_eq(
diff --git a/python/pyspark/sql/tests/connect/test_connect_plan_only.py b/python/pyspark/sql/tests/connect/test_connect_plan_only.py
index 703e27c46d6..66106fb7f2a 100644
--- a/python/pyspark/sql/tests/connect/test_connect_plan_only.py
+++ b/python/pyspark/sql/tests/connect/test_connect_plan_only.py
@@ -25,6 +25,7 @@ from pyspark.testing.connectutils import (
 if should_test_connect:
     import pyspark.sql.connect.proto as proto
     from pyspark.sql.connect.column import Column
+    from pyspark.sql.connect.dataframe import DataFrame
     from pyspark.sql.connect.plan import WriteOperation
     from pyspark.sql.connect.readwriter import DataFrameReader
     from pyspark.sql.connect.function_builder import UserDefinedFunction, udf
@@ -228,6 +229,42 @@ class SparkConnectTestsPlanOnly(PlanOnlyTestFixture):
         self.assertEqual(plan.root.unpivot.variable_column_name, "variable")
         self.assertEqual(plan.root.unpivot.value_column_name, "value")
 
+    def test_random_split(self):
+        # SPARK-41440: test randomSplit(weights, seed).
+        from typing import List
+
+        df = self.connect.readTable(table_name=self.tbl_name)
+
+        def checkRelations(relations: List["DataFrame"]):
+            self.assertTrue(len(relations) == 3)
+
+            plan = relations[0]._plan.to_proto(self.connect)
+            self.assertEqual(plan.root.sample.lower_bound, 0.0)
+            self.assertEqual(plan.root.sample.upper_bound, 0.16666666666666666)
+            self.assertEqual(plan.root.sample.with_replacement, False)
+            self.assertEqual(plan.root.sample.HasField("seed"), True)
+            self.assertEqual(plan.root.sample.force_stable_sort, True)
+
+            plan = relations[1]._plan.to_proto(self.connect)
+            self.assertEqual(plan.root.sample.lower_bound, 0.16666666666666666)
+            self.assertEqual(plan.root.sample.upper_bound, 0.5)
+            self.assertEqual(plan.root.sample.with_replacement, False)
+            self.assertEqual(plan.root.sample.HasField("seed"), True)
+            self.assertEqual(plan.root.sample.force_stable_sort, True)
+
+            plan = relations[2]._plan.to_proto(self.connect)
+            self.assertEqual(plan.root.sample.lower_bound, 0.5)
+            self.assertEqual(plan.root.sample.upper_bound, 1.0)
+            self.assertEqual(plan.root.sample.with_replacement, False)
+            self.assertEqual(plan.root.sample.HasField("seed"), True)
+            self.assertEqual(plan.root.sample.force_stable_sort, True)
+
+        relations = df.filter(df.col_name > 3).randomSplit([1.0, 2.0, 3.0], 1)
+        checkRelations(relations)
+
+        relations = df.filter(df.col_name > 3).randomSplit([1.0, 2.0, 3.0])
+        checkRelations(relations)
+
     def test_summary(self):
         df = self.connect.readTable(table_name=self.tbl_name)
         plan = df.filter(df.col_name > 3).summary()._plan.to_proto(self.connect)
@@ -281,6 +318,7 @@ class SparkConnectTestsPlanOnly(PlanOnlyTestFixture):
         self.assertEqual(plan.root.sample.upper_bound, 0.3)
         self.assertEqual(plan.root.sample.with_replacement, False)
         self.assertEqual(plan.root.sample.HasField("seed"), False)
+        self.assertEqual(plan.root.sample.force_stable_sort, False)
 
         plan = (
             df.filter(df.col_name > 3)
@@ -291,6 +329,7 @@ class SparkConnectTestsPlanOnly(PlanOnlyTestFixture):
         self.assertEqual(plan.root.sample.upper_bound, 0.4)
         self.assertEqual(plan.root.sample.with_replacement, True)
         self.assertEqual(plan.root.sample.seed, -1)
+        self.assertEqual(plan.root.sample.force_stable_sort, False)
 
     def test_sort(self):
         df = self.connect.readTable(table_name=self.tbl_name)

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org