This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 4f096dba9d2 [SPARK-40852][CONNECT][PYTHON] Introduce `StatFunction` in proto and implement `DataFrame.summary`
4f096dba9d2 is described below

commit 4f096dba9d2c28cfd8595ac58417025fdb2d7073
Author: Ruifeng Zheng <ruife...@apache.org>
AuthorDate: Wed Nov 9 09:19:50 2022 +0800

    [SPARK-40852][CONNECT][PYTHON] Introduce `StatFunction` in proto and implement `DataFrame.summary`
    
    ### What changes were proposed in this pull request?
    Implement `DataFrame.summary`.
    
    There is a set of DataFrame APIs implemented in [`StatFunctions`](https://github.com/apache/spark/blob/9cae423075145d3dd81d53f4b82d4f2af6fe7c15/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala), [`DataFrameStatFunctions`](https://github.com/apache/spark/blob/b69c26833c99337bb17922f21dd72ee3a12e0c0a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala) and [`DataFrameNaFunctions`](https://github.com/apache/spark/blob/5d74ace648422e7a [...]
    
    1. depend on Catalyst's analysis (most of them);
    ~~2. implemented in RDD operations (like `summary`, `approxQuantile`);~~ (resolved by reimpl)
    ~~3. internally trigger jobs (like `summary`);~~ (resolved by reimpl)
    
    This PR introduces a new proto message, `StatFunction`, to support the `StatFunctions` methods.
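    
    For illustration, a minimal usage sketch from the Connect Python client (the session object `spark` and the table name `tbl` are assumptions for this example, not part of the change):
    
    ```python
    # Hedged sketch: assumes a Connect RemoteSparkSession `spark` and an existing table `tbl`.
    df = spark.readTable(table_name="tbl")
    # Client-side this only builds a StatFunction(summary) plan node;
    # the statistics are computed on the server when the plan is executed.
    summarized = df.summary("count", "mean", "stddev")
    ```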
    
    ### Why are the changes needed?
    For Connect API coverage.
    
    ### Does this PR introduce _any_ user-facing change?
    Yes, this adds a new API, `DataFrame.summary`.
    
    ### How was this patch tested?
    Added a unit test.
    
    Closes #38318 from zhengruifeng/connect_df_summary.
    
    Authored-by: Ruifeng Zheng <ruife...@apache.org>
    Signed-off-by: Wenchen Fan <wenc...@databricks.com>
---
 .../main/protobuf/spark/connect/relations.proto    |  20 ++++
 .../org/apache/spark/sql/connect/dsl/package.scala |  16 +++
 .../sql/connect/planner/SparkConnectPlanner.scala  |  18 +++-
 .../connect/planner/SparkConnectProtoSuite.scala   |   6 ++
 python/pyspark/sql/connect/dataframe.py            |  10 ++
 python/pyspark/sql/connect/plan.py                 |  38 +++++++
 python/pyspark/sql/connect/proto/relations_pb2.py  | 120 +++++++++++----------
 python/pyspark/sql/connect/proto/relations_pb2.pyi |  68 ++++++++++++
 .../sql/tests/connect/test_connect_plan_only.py    |  15 +++
 9 files changed, 252 insertions(+), 59 deletions(-)

diff --git a/connector/connect/src/main/protobuf/spark/connect/relations.proto b/connector/connect/src/main/protobuf/spark/connect/relations.proto
index 36113e2a30c..dd03bd86940 100644
--- a/connector/connect/src/main/protobuf/spark/connect/relations.proto
+++ b/connector/connect/src/main/protobuf/spark/connect/relations.proto
@@ -48,6 +48,8 @@ message Relation {
     SubqueryAlias subquery_alias = 16;
     Repartition repartition = 17;
 
+    StatFunction stat_function = 100;
+
     Unknown unknown = 999;
   }
 }
@@ -254,3 +256,21 @@ message Repartition {
   // Optional. Default value is false.
   bool shuffle = 3;
 }
+
+// StatFunction
+message StatFunction {
+  // Required. The input relation.
+  Relation input = 1;
+  // Required. The function and its parameters.
+  oneof function {
+    Summary summary = 2;
+
+    Unknown unknown = 999;
+  }
+
+  // StatFunctions.summary
+  message Summary {
+    repeated string statistics = 1;
+  }
+}
+
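
The new message keeps the function-specific parameters in a `oneof`, so a `StatFunction` relation carries exactly one function. As a hedged sketch (not part of the patch), this is how the generated Python bindings would build it; `child` here is a placeholder for any previously built input relation:

```python
from pyspark.sql.connect.proto import relations_pb2 as proto

# Placeholder input relation; in practice this comes from the child plan.
child = proto.Relation()

rel = proto.Relation()
rel.stat_function.input.CopyFrom(child)
# Writing into the nested Summary message selects the `summary` oneof branch.
rel.stat_function.summary.statistics.extend(["count", "mean"])
assert rel.stat_function.WhichOneof("function") == "summary"
```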
diff --git a/connector/connect/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala b/connector/connect/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala
index 2755727de11..3e68b101057 100644
--- a/connector/connect/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala
+++ b/connector/connect/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala
@@ -441,6 +441,22 @@ package object dsl {
             Repartition.newBuilder().setInput(logicalPlan).setNumPartitions(num).setShuffle(true))
           .build()
 
+      def summary(statistics: String*): Relation = {
+        Relation
+          .newBuilder()
+          .setStatFunction(
+            proto.StatFunction
+              .newBuilder()
+              .setInput(logicalPlan)
+              .setSummary(
+                proto.StatFunction.Summary
+                  .newBuilder()
+                  .addAllStatistics(statistics.toSeq.asJava)
+                  .build())
+              .build())
+          .build()
+      }
+
       private def createSetOperation(
           left: Relation,
           right: Relation,
diff --git a/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index 1615fc56ab6..6a5808bc77f 100644
--- a/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ b/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -21,7 +21,7 @@ import scala.annotation.elidable.byName
 import scala.collection.JavaConverters._
 
 import org.apache.spark.connect.proto
-import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.{Dataset, SparkSession}
 import org.apache.spark.sql.catalyst.AliasIdentifier
 import org.apache.spark.sql.catalyst.analysis.{UnresolvedAlias, UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar}
 import org.apache.spark.sql.catalyst.expressions
@@ -32,6 +32,7 @@ import org.apache.spark.sql.catalyst.plans.{logical, FullOuter, Inner, JoinType,
 import org.apache.spark.sql.catalyst.plans.logical.{Deduplicate, Except, Intersect, LogicalPlan, Sample, SubqueryAlias, Union}
 import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
 import org.apache.spark.sql.execution.QueryExecution
+import org.apache.spark.sql.execution.stat.StatFunctions
 import org.apache.spark.sql.types._
 import org.apache.spark.util.Utils
 
@@ -73,6 +74,8 @@ class SparkConnectPlanner(plan: proto.Relation, session: SparkSession) {
       case proto.Relation.RelTypeCase.SUBQUERY_ALIAS =>
         transformSubqueryAlias(rel.getSubqueryAlias)
       case proto.Relation.RelTypeCase.REPARTITION => transformRepartition(rel.getRepartition)
+      case proto.Relation.RelTypeCase.STAT_FUNCTION =>
+        transformStatFunction(rel.getStatFunction)
       case proto.Relation.RelTypeCase.RELTYPE_NOT_SET =>
         throw new IndexOutOfBoundsException("Expected Relation to be set, but is empty.")
       case _ => throw InvalidPlanInput(s"${rel.getUnknown} not supported.")
@@ -124,6 +127,19 @@ class SparkConnectPlanner(plan: proto.Relation, session: SparkSession) {
     logical.Range(start, end, step, numPartitions)
   }
 
+  private def transformStatFunction(rel: proto.StatFunction): LogicalPlan = {
+    val child = transformRelation(rel.getInput)
+
+    rel.getFunctionCase match {
+      case proto.StatFunction.FunctionCase.SUMMARY =>
+        StatFunctions
+          .summary(Dataset.ofRows(session, child), rel.getSummary.getStatisticsList.asScala.toSeq)
+          .logicalPlan
+
+      case _ => throw InvalidPlanInput(s"StatFunction ${rel.getUnknown} not supported.")
+    }
+  }
+
   private def transformDeduplicate(rel: proto.Deduplicate): LogicalPlan = {
     if (!rel.hasInput) {
       throw InvalidPlanInput("Deduplicate needs a plan input")
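
On the server, the planner dispatches on whichever `oneof` branch is set. For reference, a hedged Python sketch of the same dispatch over the wire format (continuing the `rel` message built in the earlier sketch):

```python
# Mirrors the Scala getFunctionCase match above, using protobuf's WhichOneof.
case = rel.stat_function.WhichOneof("function")
if case == "summary":
    statistics = list(rel.stat_function.summary.statistics)
else:
    raise ValueError(f"StatFunction {case} not supported.")
```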
diff --git a/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala b/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala
index 72dae674721..c5b6f4fc0ee 100644
--- a/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala
+++ b/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala
@@ -261,6 +261,12 @@ class SparkConnectProtoSuite extends PlanTest with SparkConnectPlanTest {
     comparePlans(connectPlan2, sparkPlan2)
   }
 
+  test("Test summary") {
+    comparePlans(
+      connectTestRelation.summary("count", "mean", "stddev"),
+      sparkTestRelation.summary("count", "mean", "stddev"))
+  }
+
   private def createLocalRelationProtoByQualifiedAttributes(
       attrs: Seq[proto.Expression.QualifiedAttribute]): proto.Relation = {
     val localRelationBuilder = proto.LocalRelation.newBuilder()
diff --git a/python/pyspark/sql/connect/dataframe.py b/python/pyspark/sql/connect/dataframe.py
index 9eecdbb7145..64b2e54f0ef 100644
--- a/python/pyspark/sql/connect/dataframe.py
+++ b/python/pyspark/sql/connect/dataframe.py
@@ -376,6 +376,16 @@ class DataFrame(object):
     def where(self, condition: Expression) -> "DataFrame":
         return self.filter(condition)
 
+    def summary(self, *statistics: str) -> "DataFrame":
+        _statistics: List[str] = list(statistics)
+        for s in _statistics:
+            if not isinstance(s, str):
+                raise TypeError(f"'statistics' must be list[str], but got {type(s).__name__}")
+        return DataFrame.withPlan(
+            plan.StatFunction(child=self._plan, function="summary", statistics=_statistics),
+            session=self._session,
+        )
+
     def _get_alias(self) -> Optional[str]:
         p = self._plan
         while p is not None:
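
The client-side `summary` validates its arguments eagerly, before anything is sent to the server. A quick hedged sketch (`df` stands for any Connect DataFrame):

```python
df.summary("count", "25%")  # ok: builds the plan with these statistics
df.summary("count", 42)     # raises TypeError on the client; nothing is sent
```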
diff --git a/python/pyspark/sql/connect/plan.py b/python/pyspark/sql/connect/plan.py
index 4b28e6cb80a..1d5c80f510e 100644
--- a/python/pyspark/sql/connect/plan.py
+++ b/python/pyspark/sql/connect/plan.py
@@ -16,6 +16,7 @@
 #
 
 from typing import (
+    Any,
     List,
     Optional,
     Sequence,
@@ -750,3 +751,40 @@ class Range(LogicalPlan):
             </li>
         </uL>
         """
+
+
+class StatFunction(LogicalPlan):
+    def __init__(self, child: Optional["LogicalPlan"], function: str, **kwargs: Any) -> None:
+        super().__init__(child)
+        assert function in ["summary"]
+        self.function = function
+        self.kwargs = kwargs
+
+    def plan(self, session: Optional["RemoteSparkSession"]) -> proto.Relation:
+        assert self._child is not None
+
+        plan = proto.Relation()
+        plan.stat_function.input.CopyFrom(self._child.plan(session))
+
+        if self.function == "summary":
+            plan.stat_function.summary.statistics.extend(self.kwargs.get("statistics", []))
+        else:
+            raise Exception(f"Unknown function {self.function}.")
+
+        return plan
+
+    def print(self, indent: int = 0) -> str:
+        i = " " * indent
+        return f"""{i}<StatFunction function='{self.function}' arguments='{self.kwargs}'>"""
+
+    def _repr_html_(self) -> str:
+        return f"""
+        <ul>
+           <li>
+              <b>StatFunction</b><br />
+              Function: {self.function} <br />
+              Arguments: {self.kwargs} <br />
+              {self._child_repr_()}
+           </li>
+        </ul>
+        """
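
`StatFunction` is deliberately generic: `function` selects the proto branch and the keyword arguments carry its parameters. A hedged sketch of using the node directly (`child_plan` and `session` are assumed to exist, not defined here):

```python
# Assumes an existing LogicalPlan `child_plan` and a RemoteSparkSession `session`.
node = StatFunction(child=child_plan, function="summary", statistics=["count", "mean"])
rel = node.plan(session)
assert list(rel.stat_function.summary.statistics) == ["count", "mean"]
```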
diff --git a/python/pyspark/sql/connect/proto/relations_pb2.py b/python/pyspark/sql/connect/proto/relations_pb2.py
index e43a5de583e..b11a4b0e91a 100644
--- a/python/pyspark/sql/connect/proto/relations_pb2.py
+++ b/python/pyspark/sql/connect/proto/relations_pb2.py
@@ -32,7 +32,7 @@ from pyspark.sql.connect.proto import expressions_pb2 as spark_dot_connect_dot_e
 
 
 DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
-    b'\n\x1dspark/connect/relations.proto\x12\rspark.connect\x1a\x1fspark/connect/expressions.proto"\xcc\x07\n\x08Relation\x12\x35\n\x06\x63ommon\x18\x01 \x01(\x0b\x32\x1d.spark.connect.RelationCommonR\x06\x63ommon\x12)\n\x04read\x18\x02 \x01(\x0b\x32\x13.spark.connect.ReadH\x00R\x04read\x12\x32\n\x07project\x18\x03 \x01(\x0b\x32\x16.spark.connect.ProjectH\x00R\x07project\x12/\n\x06\x66ilter\x18\x04 \x01(\x0b\x32\x15.spark.connect.FilterH\x00R\x06\x66ilter\x12)\n\x04join\x18\x05 \x01(\x0 [...]
+    b'\n\x1dspark/connect/relations.proto\x12\rspark.connect\x1a\x1fspark/connect/expressions.proto"\x90\x08\n\x08Relation\x12\x35\n\x06\x63ommon\x18\x01 \x01(\x0b\x32\x1d.spark.connect.RelationCommonR\x06\x63ommon\x12)\n\x04read\x18\x02 \x01(\x0b\x32\x13.spark.connect.ReadH\x00R\x04read\x12\x32\n\x07project\x18\x03 \x01(\x0b\x32\x16.spark.connect.ProjectH\x00R\x07project\x12/\n\x06\x66ilter\x18\x04 \x01(\x0b\x32\x15.spark.connect.FilterH\x00R\x06\x66ilter\x12)\n\x04join\x18\x05 \x01(\x0 [...]
 )
 
 _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals())
@@ -44,61 +44,65 @@ if _descriptor._USE_C_DESCRIPTORS == False:
     _READ_DATASOURCE_OPTIONSENTRY._options = None
     _READ_DATASOURCE_OPTIONSENTRY._serialized_options = b"8\001"
     _RELATION._serialized_start = 82
-    _RELATION._serialized_end = 1054
-    _UNKNOWN._serialized_start = 1056
-    _UNKNOWN._serialized_end = 1065
-    _RELATIONCOMMON._serialized_start = 1067
-    _RELATIONCOMMON._serialized_end = 1116
-    _SQL._serialized_start = 1118
-    _SQL._serialized_end = 1145
-    _READ._serialized_start = 1148
-    _READ._serialized_end = 1558
-    _READ_NAMEDTABLE._serialized_start = 1290
-    _READ_NAMEDTABLE._serialized_end = 1351
-    _READ_DATASOURCE._serialized_start = 1354
-    _READ_DATASOURCE._serialized_end = 1545
-    _READ_DATASOURCE_OPTIONSENTRY._serialized_start = 1487
-    _READ_DATASOURCE_OPTIONSENTRY._serialized_end = 1545
-    _PROJECT._serialized_start = 1560
-    _PROJECT._serialized_end = 1677
-    _FILTER._serialized_start = 1679
-    _FILTER._serialized_end = 1791
-    _JOIN._serialized_start = 1794
-    _JOIN._serialized_end = 2244
-    _JOIN_JOINTYPE._serialized_start = 2057
-    _JOIN_JOINTYPE._serialized_end = 2244
-    _SETOPERATION._serialized_start = 2247
-    _SETOPERATION._serialized_end = 2610
-    _SETOPERATION_SETOPTYPE._serialized_start = 2496
-    _SETOPERATION_SETOPTYPE._serialized_end = 2610
-    _LIMIT._serialized_start = 2612
-    _LIMIT._serialized_end = 2688
-    _OFFSET._serialized_start = 2690
-    _OFFSET._serialized_end = 2769
-    _AGGREGATE._serialized_start = 2772
-    _AGGREGATE._serialized_end = 2982
-    _SORT._serialized_start = 2985
-    _SORT._serialized_end = 3516
-    _SORT_SORTFIELD._serialized_start = 3134
-    _SORT_SORTFIELD._serialized_end = 3322
-    _SORT_SORTDIRECTION._serialized_start = 3324
-    _SORT_SORTDIRECTION._serialized_end = 3432
-    _SORT_SORTNULLS._serialized_start = 3434
-    _SORT_SORTNULLS._serialized_end = 3516
-    _DEDUPLICATE._serialized_start = 3519
-    _DEDUPLICATE._serialized_end = 3661
-    _LOCALRELATION._serialized_start = 3663
-    _LOCALRELATION._serialized_end = 3756
-    _SAMPLE._serialized_start = 3759
-    _SAMPLE._serialized_end = 3999
-    _SAMPLE_SEED._serialized_start = 3973
-    _SAMPLE_SEED._serialized_end = 3999
-    _RANGE._serialized_start = 4002
-    _RANGE._serialized_end = 4200
-    _RANGE_NUMPARTITIONS._serialized_start = 4146
-    _RANGE_NUMPARTITIONS._serialized_end = 4200
-    _SUBQUERYALIAS._serialized_start = 4202
-    _SUBQUERYALIAS._serialized_end = 4316
-    _REPARTITION._serialized_start = 4318
-    _REPARTITION._serialized_end = 4443
+    _RELATION._serialized_end = 1122
+    _UNKNOWN._serialized_start = 1124
+    _UNKNOWN._serialized_end = 1133
+    _RELATIONCOMMON._serialized_start = 1135
+    _RELATIONCOMMON._serialized_end = 1184
+    _SQL._serialized_start = 1186
+    _SQL._serialized_end = 1213
+    _READ._serialized_start = 1216
+    _READ._serialized_end = 1626
+    _READ_NAMEDTABLE._serialized_start = 1358
+    _READ_NAMEDTABLE._serialized_end = 1419
+    _READ_DATASOURCE._serialized_start = 1422
+    _READ_DATASOURCE._serialized_end = 1613
+    _READ_DATASOURCE_OPTIONSENTRY._serialized_start = 1555
+    _READ_DATASOURCE_OPTIONSENTRY._serialized_end = 1613
+    _PROJECT._serialized_start = 1628
+    _PROJECT._serialized_end = 1745
+    _FILTER._serialized_start = 1747
+    _FILTER._serialized_end = 1859
+    _JOIN._serialized_start = 1862
+    _JOIN._serialized_end = 2312
+    _JOIN_JOINTYPE._serialized_start = 2125
+    _JOIN_JOINTYPE._serialized_end = 2312
+    _SETOPERATION._serialized_start = 2315
+    _SETOPERATION._serialized_end = 2678
+    _SETOPERATION_SETOPTYPE._serialized_start = 2564
+    _SETOPERATION_SETOPTYPE._serialized_end = 2678
+    _LIMIT._serialized_start = 2680
+    _LIMIT._serialized_end = 2756
+    _OFFSET._serialized_start = 2758
+    _OFFSET._serialized_end = 2837
+    _AGGREGATE._serialized_start = 2840
+    _AGGREGATE._serialized_end = 3050
+    _SORT._serialized_start = 3053
+    _SORT._serialized_end = 3584
+    _SORT_SORTFIELD._serialized_start = 3202
+    _SORT_SORTFIELD._serialized_end = 3390
+    _SORT_SORTDIRECTION._serialized_start = 3392
+    _SORT_SORTDIRECTION._serialized_end = 3500
+    _SORT_SORTNULLS._serialized_start = 3502
+    _SORT_SORTNULLS._serialized_end = 3584
+    _DEDUPLICATE._serialized_start = 3587
+    _DEDUPLICATE._serialized_end = 3729
+    _LOCALRELATION._serialized_start = 3731
+    _LOCALRELATION._serialized_end = 3824
+    _SAMPLE._serialized_start = 3827
+    _SAMPLE._serialized_end = 4067
+    _SAMPLE_SEED._serialized_start = 4041
+    _SAMPLE_SEED._serialized_end = 4067
+    _RANGE._serialized_start = 4070
+    _RANGE._serialized_end = 4268
+    _RANGE_NUMPARTITIONS._serialized_start = 4214
+    _RANGE_NUMPARTITIONS._serialized_end = 4268
+    _SUBQUERYALIAS._serialized_start = 4270
+    _SUBQUERYALIAS._serialized_end = 4384
+    _REPARTITION._serialized_start = 4386
+    _REPARTITION._serialized_end = 4511
+    _STATFUNCTION._serialized_start = 4514
+    _STATFUNCTION._serialized_end = 4748
+    _STATFUNCTION_SUMMARY._serialized_start = 4695
+    _STATFUNCTION_SUMMARY._serialized_end = 4736
 # @@protoc_insertion_point(module_scope)
diff --git a/python/pyspark/sql/connect/proto/relations_pb2.pyi b/python/pyspark/sql/connect/proto/relations_pb2.pyi
index 30c1dddf885..6ee3c46d7c5 100644
--- a/python/pyspark/sql/connect/proto/relations_pb2.pyi
+++ b/python/pyspark/sql/connect/proto/relations_pb2.pyi
@@ -76,6 +76,7 @@ class Relation(google.protobuf.message.Message):
     RANGE_FIELD_NUMBER: builtins.int
     SUBQUERY_ALIAS_FIELD_NUMBER: builtins.int
     REPARTITION_FIELD_NUMBER: builtins.int
+    STAT_FUNCTION_FIELD_NUMBER: builtins.int
     UNKNOWN_FIELD_NUMBER: builtins.int
     @property
     def common(self) -> global___RelationCommon: ...
@@ -112,6 +113,8 @@ class Relation(google.protobuf.message.Message):
     @property
     def repartition(self) -> global___Repartition: ...
     @property
+    def stat_function(self) -> global___StatFunction: ...
+    @property
     def unknown(self) -> global___Unknown: ...
     def __init__(
         self,
@@ -133,6 +136,7 @@ class Relation(google.protobuf.message.Message):
         range: global___Range | None = ...,
         subquery_alias: global___SubqueryAlias | None = ...,
         repartition: global___Repartition | None = ...,
+        stat_function: global___StatFunction | None = ...,
         unknown: global___Unknown | None = ...,
     ) -> None: ...
     def HasField(
@@ -172,6 +176,8 @@ class Relation(google.protobuf.message.Message):
             b"sort",
             "sql",
             b"sql",
+            "stat_function",
+            b"stat_function",
             "subquery_alias",
             b"subquery_alias",
             "unknown",
@@ -215,6 +221,8 @@ class Relation(google.protobuf.message.Message):
             b"sort",
             "sql",
             b"sql",
+            "stat_function",
+            b"stat_function",
             "subquery_alias",
             b"subquery_alias",
             "unknown",
@@ -240,6 +248,7 @@ class Relation(google.protobuf.message.Message):
         "range",
         "subquery_alias",
         "repartition",
+        "stat_function",
         "unknown",
     ] | None: ...
 
@@ -1065,3 +1074,62 @@ class Repartition(google.protobuf.message.Message):
     ) -> None: ...
 
 global___Repartition = Repartition
+
+class StatFunction(google.protobuf.message.Message):
+    """StatFunction"""
+
+    DESCRIPTOR: google.protobuf.descriptor.Descriptor
+
+    class Summary(google.protobuf.message.Message):
+        """StatFunctions.summary"""
+
+        DESCRIPTOR: google.protobuf.descriptor.Descriptor
+
+        STATISTICS_FIELD_NUMBER: builtins.int
+        @property
+        def statistics(
+            self,
+        ) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.str]: ...
+        def __init__(
+            self,
+            *,
+            statistics: collections.abc.Iterable[builtins.str] | None = ...,
+        ) -> None: ...
+        def ClearField(
+            self, field_name: typing_extensions.Literal["statistics", b"statistics"]
+        ) -> None: ...
+
+    INPUT_FIELD_NUMBER: builtins.int
+    SUMMARY_FIELD_NUMBER: builtins.int
+    UNKNOWN_FIELD_NUMBER: builtins.int
+    @property
+    def input(self) -> global___Relation:
+        """Required. The input relation."""
+    @property
+    def summary(self) -> global___StatFunction.Summary: ...
+    @property
+    def unknown(self) -> global___Unknown: ...
+    def __init__(
+        self,
+        *,
+        input: global___Relation | None = ...,
+        summary: global___StatFunction.Summary | None = ...,
+        unknown: global___Unknown | None = ...,
+    ) -> None: ...
+    def HasField(
+        self,
+        field_name: typing_extensions.Literal[
+            "function", b"function", "input", b"input", "summary", b"summary", "unknown", b"unknown"
+        ],
+    ) -> builtins.bool: ...
+    def ClearField(
+        self,
+        field_name: typing_extensions.Literal[
+            "function", b"function", "input", b"input", "summary", b"summary", "unknown", b"unknown"
+        ],
+    ) -> None: ...
+    def WhichOneof(
+        self, oneof_group: typing_extensions.Literal["function", b"function"]
+    ) -> typing_extensions.Literal["summary", "unknown"] | None: ...
+
+global___StatFunction = StatFunction
diff --git a/python/pyspark/sql/tests/connect/test_connect_plan_only.py b/python/pyspark/sql/tests/connect/test_connect_plan_only.py
index 8a9b98e73fd..468099cb5c9 100644
--- a/python/pyspark/sql/tests/connect/test_connect_plan_only.py
+++ b/python/pyspark/sql/tests/connect/test_connect_plan_only.py
@@ -70,6 +70,21 @@ class SparkConnectTestsPlanOnly(PlanOnlyTestFixture):
         self.assertEqual(plan.root.filter.condition.unresolved_function.parts, [">"])
         self.assertEqual(len(plan.root.filter.condition.unresolved_function.arguments), 2)
 
+    def test_summary(self):
+        df = self.connect.readTable(table_name=self.tbl_name)
+        plan = df.filter(df.col_name > 3).summary()._plan.to_proto(self.connect)
+        self.assertEqual(plan.root.stat_function.summary.statistics, [])
+
+        plan = (
+            df.filter(df.col_name > 3)
+            .summary("count", "mean", "stddev", "min", "25%")
+            ._plan.to_proto(self.connect)
+        )
+        self.assertEqual(
+            plan.root.stat_function.summary.statistics,
+            ["count", "mean", "stddev", "min", "25%"],
+        )
+
     def test_limit(self):
         df = self.connect.readTable(table_name=self.tbl_name)
         limit_plan = df.limit(10)._plan.to_proto(self.connect)

