This is an automated email from the ASF dual-hosted git repository.

ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new c5d27603f29 [SPARK-41064][CONNECT][PYTHON] Implement `DataFrame.crosstab` and `DataFrame.stat.crosstab`
c5d27603f29 is described below

commit c5d27603f29437f1686cac70727594c19410a273
Author: Ruifeng Zheng <ruife...@apache.org>
AuthorDate: Thu Nov 10 18:15:54 2022 +0800

    [SPARK-41064][CONNECT][PYTHON] Implement `DataFrame.crosstab` and `DataFrame.stat.crosstab`
    
    ### What changes were proposed in this pull request?
    Implement `DataFrame.crosstab` and `DataFrame.stat.crosstab`
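
    A minimal usage sketch (illustrative only; the data and column names below are hypothetical and not part of this patch):

    ```python
    # Build a small DataFrame and cross-tabulate two of its columns.
    df = spark.range(10).selectExpr("id % 3 AS key", "id % 2 AS value")

    df.crosstab("key", "value").show()       # via DataFrame.crosstab
    df.stat.crosstab("key", "value").show()  # via DataFrame.stat.crosstab
    ```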
    
    ### Why are the changes needed?
    To improve API coverage in Spark Connect.
    
    ### Does this PR introduce _any_ user-facing change?
    Yes, this adds the new APIs `DataFrame.crosstab` and `DataFrame.stat.crosstab`.
    
    ### How was this patch tested?
    Added unit tests.
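
    The plan-only test asserts that the generated Connect proto carries the column names, roughly as follows (condensed from the added test; `df` and `self.connect` come from the test fixture):

    ```python
    plan = df.stat.crosstab("col_a", "col_b")._plan.to_proto(self.connect)
    assert plan.root.crosstab.col1 == "col_a"
    assert plan.root.crosstab.col2 == "col_b"
    ```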
    
    Closes #38578 from zhengruifeng/connect_df_crosstab.
    
    Authored-by: Ruifeng Zheng <ruife...@apache.org>
    Signed-off-by: Ruifeng Zheng <ruife...@apache.org>
---
 .../main/protobuf/spark/connect/relations.proto    |  19 +++
 .../org/apache/spark/sql/connect/dsl/package.scala |  17 +++
 .../sql/connect/planner/SparkConnectPlanner.scala  |  10 ++
 .../connect/planner/SparkConnectProtoSuite.scala   |   6 +
 python/pyspark/sql/connect/dataframe.py            |  62 ++++++++++
 python/pyspark/sql/connect/plan.py                 |  32 +++++
 python/pyspark/sql/connect/proto/relations_pb2.py  | 134 +++++++++++----------
 python/pyspark/sql/connect/proto/relations_pb2.pyi |  50 ++++++++
 .../sql/tests/connect/test_connect_plan_only.py    |  10 ++
 9 files changed, 274 insertions(+), 66 deletions(-)

diff --git a/connector/connect/src/main/protobuf/spark/connect/relations.proto b/connector/connect/src/main/protobuf/spark/connect/relations.proto
index b3613fc908d..639d1bafce5 100644
--- a/connector/connect/src/main/protobuf/spark/connect/relations.proto
+++ b/connector/connect/src/main/protobuf/spark/connect/relations.proto
@@ -52,6 +52,7 @@ message Relation {
 
     // stat functions
     StatSummary summary = 100;
+    StatCrosstab crosstab = 101;
 
     Unknown unknown = 999;
   }
@@ -284,6 +285,24 @@ message StatSummary {
   repeated string statistics = 2;
 }
 
+// Computes a pair-wise frequency table of the given columns. Also known as a contingency table.
+// It will invoke 'Dataset.stat.crosstab' (same as 'StatFunctions.crossTabulate')
+// to compute the results.
+message StatCrosstab {
+  // (Required) The input relation.
+  Relation input = 1;
+
+  // (Required) The name of the first column.
+  //
+  // Distinct items will make the first item of each row.
+  string col1 = 2;
+
+  // (Required) The name of the second column.
+  //
+  // Distinct items will make the column names of the DataFrame.
+  string col2 = 3;
+}
+
 // Rename columns on the input relation by the same length of names.
 message RenameColumnsBySameLengthNames {
   // Required. The input relation.
diff --git a/connector/connect/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala b/connector/connect/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala
index 381cbf7a9a8..5e7a94da347 100644
--- a/connector/connect/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala
+++ b/connector/connect/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala
@@ -227,6 +227,21 @@ package object dsl {
       }
     }
 
+    implicit class DslStatFunctions(val logicalPlan: Relation) {
+      def crosstab(col1: String, col2: String): Relation = {
+        Relation
+          .newBuilder()
+          .setCrosstab(
+            proto.StatCrosstab
+              .newBuilder()
+              .setInput(logicalPlan)
+              .setCol1(col1)
+              .setCol2(col2)
+              .build())
+          .build()
+      }
+    }
+
     implicit class DslLogicalPlan(val logicalPlan: Relation) {
       def select(exprs: Expression*): Relation = {
         Relation
@@ -463,6 +478,8 @@ package object dsl {
             Repartition.newBuilder().setInput(logicalPlan).setNumPartitions(num).setShuffle(true))
           .build()
 
+      def stat: DslStatFunctions = new DslStatFunctions(logicalPlan)
+
       def summary(statistics: String*): Relation = {
         Relation
           .newBuilder()
diff --git a/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index 148f5569683..04ce880a925 100644
--- a/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ b/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -67,6 +67,8 @@ class SparkConnectPlanner(plan: proto.Relation, session: SparkSession) {
         transformSubqueryAlias(rel.getSubqueryAlias)
       case proto.Relation.RelTypeCase.REPARTITION => transformRepartition(rel.getRepartition)
       case proto.Relation.RelTypeCase.SUMMARY => transformStatSummary(rel.getSummary)
+      case proto.Relation.RelTypeCase.CROSSTAB =>
+        transformStatCrosstab(rel.getCrosstab)
       case proto.Relation.RelTypeCase.RENAME_COLUMNS_BY_SAME_LENGTH_NAMES =>
         transformRenameColumnsBySamelenghtNames(rel.getRenameColumnsBySameLengthNames)
       case proto.Relation.RelTypeCase.RENAME_COLUMNS_BY_NAME_TO_NAME_MAP =>
@@ -129,6 +131,14 @@ class SparkConnectPlanner(plan: proto.Relation, session: SparkSession) {
       .logicalPlan
   }
 
+  private def transformStatCrosstab(rel: proto.StatCrosstab): LogicalPlan = {
+    Dataset
+      .ofRows(session, transformRelation(rel.getInput))
+      .stat
+      .crosstab(rel.getCol1, rel.getCol2)
+      .logicalPlan
+  }
+
   private def transformRenameColumnsBySamelenghtNames(
       rel: proto.RenameColumnsBySameLengthNames): LogicalPlan = {
     Dataset
diff --git a/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala b/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala
index 3612c5e0d0a..5052b451047 100644
--- a/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala
+++ b/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala
@@ -273,6 +273,12 @@ class SparkConnectProtoSuite extends PlanTest with SparkConnectPlanTest {
       sparkTestRelation.summary("count", "mean", "stddev"))
   }
 
+  test("Test crosstab") {
+    comparePlans(
+      connectTestRelation.stat.crosstab("id", "name"),
+      sparkTestRelation.stat.crosstab("id", "name"))
+  }
+
   test("Test toDF") {
     comparePlans(connectTestRelation.toDF("col1", "col2"), sparkTestRelation.toDF("col1", "col2"))
   }
diff --git a/python/pyspark/sql/connect/dataframe.py b/python/pyspark/sql/connect/dataframe.py
index 6bf3ce0dcc9..e3116ea1250 100644
--- a/python/pyspark/sql/connect/dataframe.py
+++ b/python/pyspark/sql/connect/dataframe.py
@@ -501,6 +501,18 @@ class DataFrame(object):
     def where(self, condition: Expression) -> "DataFrame":
         return self.filter(condition)
 
+    @property
+    def stat(self) -> "DataFrameStatFunctions":
+        """Returns a :class:`DataFrameStatFunctions` for statistic functions.
+
+        .. versionadded:: 3.4.0
+
+        Returns
+        -------
+        :class:`DataFrameStatFunctions`
+        """
+        return DataFrameStatFunctions(self)
+
     def summary(self, *statistics: str) -> "DataFrame":
         _statistics: List[str] = list(statistics)
         for s in _statistics:
@@ -511,6 +523,41 @@ class DataFrame(object):
             session=self._session,
         )
 
+    def crosstab(self, col1: str, col2: str) -> "DataFrame":
+        """
+        Computes a pair-wise frequency table of the given columns. Also known as a contingency
+        table. The number of distinct values for each column should be less than 1e4. At most 1e6
+        non-zero pair frequencies will be returned.
+        The first column of each row will be the distinct values of `col1` and the column names
+        will be the distinct values of `col2`. The name of the first column will be `$col1_$col2`.
+        Pairs that have no occurrences will have zero as their counts.
+        :func:`DataFrame.crosstab` and :func:`DataFrameStatFunctions.crosstab` are aliases.
+
+        .. versionadded:: 3.4.0
+
+        Parameters
+        ----------
+        col1 : str
+            The name of the first column. Distinct items will make the first item of
+            each row.
+        col2 : str
+            The name of the second column. Distinct items will make the column names
+            of the :class:`DataFrame`.
+
+        Returns
+        -------
+        :class:`DataFrame`
+            Frequency matrix of two columns.
+        """
+        if not isinstance(col1, str):
+            raise TypeError(f"'col1' must be str, but got 
{type(col1).__name__}")
+        if not isinstance(col2, str):
+            raise TypeError(f"'col2' must be str, but got 
{type(col2).__name__}")
+        return DataFrame.withPlan(
+            plan.StatCrosstab(child=self._plan, col1=col1, col2=col2),
+            session=self._session,
+        )
+
     def _get_alias(self) -> Optional[str]:
         p = self._plan
         while p is not None:
@@ -579,3 +626,18 @@ class DataFrame(object):
             return self._session.explain_string(query)
         else:
             return ""
+
+
+class DataFrameStatFunctions:
+    """Functionality for statistic functions with :class:`DataFrame`.
+
+    .. versionadded:: 3.4.0
+    """
+
+    def __init__(self, df: DataFrame):
+        self.df = df
+
+    def crosstab(self, col1: str, col2: str) -> DataFrame:
+        return self.df.crosstab(col1, col2)
+
+    crosstab.__doc__ = DataFrame.crosstab.__doc__
diff --git a/python/pyspark/sql/connect/plan.py b/python/pyspark/sql/connect/plan.py
index 047c9f2ce0f..926119c5457 100644
--- a/python/pyspark/sql/connect/plan.py
+++ b/python/pyspark/sql/connect/plan.py
@@ -830,3 +830,35 @@ class StatSummary(LogicalPlan):
            </li>
         </ul>
         """
+
+
+class StatCrosstab(LogicalPlan):
+    def __init__(self, child: Optional["LogicalPlan"], col1: str, col2: str) -> None:
+        super().__init__(child)
+        self.col1 = col1
+        self.col2 = col2
+
+    def plan(self, session: "RemoteSparkSession") -> proto.Relation:
+        assert self._child is not None
+
+        plan = proto.Relation()
+        plan.crosstab.input.CopyFrom(self._child.plan(session))
+        plan.crosstab.col1 = self.col1
+        plan.crosstab.col2 = self.col2
+        return plan
+
+    def print(self, indent: int = 0) -> str:
+        i = " " * indent
+        return f"""{i}<Crosstab col1='{self.col1}' col2='{self.col2}'>"""
+
+    def _repr_html_(self) -> str:
+        return f"""
+        <ul>
+           <li>
+              <b>Crosstab</b><br />
+              Col1: {self.col1} <br />
+              Col2: {self.col2} <br />
+              {self._child_repr_()}
+           </li>
+        </ul>
+        """
diff --git a/python/pyspark/sql/connect/proto/relations_pb2.py b/python/pyspark/sql/connect/proto/relations_pb2.py
index d8c85596727..323eb8e7690 100644
--- a/python/pyspark/sql/connect/proto/relations_pb2.py
+++ b/python/pyspark/sql/connect/proto/relations_pb2.py
@@ -32,7 +32,7 @@ from pyspark.sql.connect.proto import expressions_pb2 as spark_dot_connect_dot_e
 
 
 DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
-    b'\n\x1dspark/connect/relations.proto\x12\rspark.connect\x1a\x1fspark/connect/expressions.proto"\xfb\t\n\x08Relation\x12\x35\n\x06\x63ommon\x18\x01 \x01(\x0b\x32\x1d.spark.connect.RelationCommonR\x06\x63ommon\x12)\n\x04read\x18\x02 \x01(\x0b\x32\x13.spark.connect.ReadH\x00R\x04read\x12\x32\n\x07project\x18\x03 \x01(\x0b\x32\x16.spark.connect.ProjectH\x00R\x07project\x12/\n\x06\x66ilter\x18\x04 \x01(\x0b\x32\x15.spark.connect.FilterH\x00R\x06\x66ilter\x12)\n\x04join\x18\x05 \x01(\x0b\ [...]
+    b'\n\x1dspark/connect/relations.proto\x12\rspark.connect\x1a\x1fspark/connect/expressions.proto"\xb6\n\n\x08Relation\x12\x35\n\x06\x63ommon\x18\x01 \x01(\x0b\x32\x1d.spark.connect.RelationCommonR\x06\x63ommon\x12)\n\x04read\x18\x02 \x01(\x0b\x32\x13.spark.connect.ReadH\x00R\x04read\x12\x32\n\x07project\x18\x03 \x01(\x0b\x32\x16.spark.connect.ProjectH\x00R\x07project\x12/\n\x06\x66ilter\x18\x04 \x01(\x0b\x32\x15.spark.connect.FilterH\x00R\x06\x66ilter\x12)\n\x04join\x18\x05 \x01(\x0b\ [...]
 )
 
 _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals())
@@ -46,69 +46,71 @@ if _descriptor._USE_C_DESCRIPTORS == False:
     _RENAMECOLUMNSBYNAMETONAMEMAP_RENAMECOLUMNSMAPENTRY._options = None
     _RENAMECOLUMNSBYNAMETONAMEMAP_RENAMECOLUMNSMAPENTRY._serialized_options = b"8\001"
     _RELATION._serialized_start = 82
-    _RELATION._serialized_end = 1357
-    _UNKNOWN._serialized_start = 1359
-    _UNKNOWN._serialized_end = 1368
-    _RELATIONCOMMON._serialized_start = 1370
-    _RELATIONCOMMON._serialized_end = 1419
-    _SQL._serialized_start = 1421
-    _SQL._serialized_end = 1448
-    _READ._serialized_start = 1451
-    _READ._serialized_end = 1861
-    _READ_NAMEDTABLE._serialized_start = 1593
-    _READ_NAMEDTABLE._serialized_end = 1654
-    _READ_DATASOURCE._serialized_start = 1657
-    _READ_DATASOURCE._serialized_end = 1848
-    _READ_DATASOURCE_OPTIONSENTRY._serialized_start = 1790
-    _READ_DATASOURCE_OPTIONSENTRY._serialized_end = 1848
-    _PROJECT._serialized_start = 1863
-    _PROJECT._serialized_end = 1980
-    _FILTER._serialized_start = 1982
-    _FILTER._serialized_end = 2094
-    _JOIN._serialized_start = 2097
-    _JOIN._serialized_end = 2547
-    _JOIN_JOINTYPE._serialized_start = 2360
-    _JOIN_JOINTYPE._serialized_end = 2547
-    _SETOPERATION._serialized_start = 2550
-    _SETOPERATION._serialized_end = 2913
-    _SETOPERATION_SETOPTYPE._serialized_start = 2799
-    _SETOPERATION_SETOPTYPE._serialized_end = 2913
-    _LIMIT._serialized_start = 2915
-    _LIMIT._serialized_end = 2991
-    _OFFSET._serialized_start = 2993
-    _OFFSET._serialized_end = 3072
-    _AGGREGATE._serialized_start = 3075
-    _AGGREGATE._serialized_end = 3285
-    _SORT._serialized_start = 3288
-    _SORT._serialized_end = 3819
-    _SORT_SORTFIELD._serialized_start = 3437
-    _SORT_SORTFIELD._serialized_end = 3625
-    _SORT_SORTDIRECTION._serialized_start = 3627
-    _SORT_SORTDIRECTION._serialized_end = 3735
-    _SORT_SORTNULLS._serialized_start = 3737
-    _SORT_SORTNULLS._serialized_end = 3819
-    _DEDUPLICATE._serialized_start = 3822
-    _DEDUPLICATE._serialized_end = 3964
-    _LOCALRELATION._serialized_start = 3966
-    _LOCALRELATION._serialized_end = 4059
-    _SAMPLE._serialized_start = 4062
-    _SAMPLE._serialized_end = 4302
-    _SAMPLE_SEED._serialized_start = 4276
-    _SAMPLE_SEED._serialized_end = 4302
-    _RANGE._serialized_start = 4305
-    _RANGE._serialized_end = 4503
-    _RANGE_NUMPARTITIONS._serialized_start = 4449
-    _RANGE_NUMPARTITIONS._serialized_end = 4503
-    _SUBQUERYALIAS._serialized_start = 4505
-    _SUBQUERYALIAS._serialized_end = 4619
-    _REPARTITION._serialized_start = 4621
-    _REPARTITION._serialized_end = 4746
-    _STATSUMMARY._serialized_start = 4748
-    _STATSUMMARY._serialized_end = 4840
-    _RENAMECOLUMNSBYSAMELENGTHNAMES._serialized_start = 4842
-    _RENAMECOLUMNSBYSAMELENGTHNAMES._serialized_end = 4956
-    _RENAMECOLUMNSBYNAMETONAMEMAP._serialized_start = 4959
-    _RENAMECOLUMNSBYNAMETONAMEMAP._serialized_end = 5218
-    _RENAMECOLUMNSBYNAMETONAMEMAP_RENAMECOLUMNSMAPENTRY._serialized_start = 5151
-    _RENAMECOLUMNSBYNAMETONAMEMAP_RENAMECOLUMNSMAPENTRY._serialized_end = 5218
+    _RELATION._serialized_end = 1416
+    _UNKNOWN._serialized_start = 1418
+    _UNKNOWN._serialized_end = 1427
+    _RELATIONCOMMON._serialized_start = 1429
+    _RELATIONCOMMON._serialized_end = 1478
+    _SQL._serialized_start = 1480
+    _SQL._serialized_end = 1507
+    _READ._serialized_start = 1510
+    _READ._serialized_end = 1920
+    _READ_NAMEDTABLE._serialized_start = 1652
+    _READ_NAMEDTABLE._serialized_end = 1713
+    _READ_DATASOURCE._serialized_start = 1716
+    _READ_DATASOURCE._serialized_end = 1907
+    _READ_DATASOURCE_OPTIONSENTRY._serialized_start = 1849
+    _READ_DATASOURCE_OPTIONSENTRY._serialized_end = 1907
+    _PROJECT._serialized_start = 1922
+    _PROJECT._serialized_end = 2039
+    _FILTER._serialized_start = 2041
+    _FILTER._serialized_end = 2153
+    _JOIN._serialized_start = 2156
+    _JOIN._serialized_end = 2606
+    _JOIN_JOINTYPE._serialized_start = 2419
+    _JOIN_JOINTYPE._serialized_end = 2606
+    _SETOPERATION._serialized_start = 2609
+    _SETOPERATION._serialized_end = 2972
+    _SETOPERATION_SETOPTYPE._serialized_start = 2858
+    _SETOPERATION_SETOPTYPE._serialized_end = 2972
+    _LIMIT._serialized_start = 2974
+    _LIMIT._serialized_end = 3050
+    _OFFSET._serialized_start = 3052
+    _OFFSET._serialized_end = 3131
+    _AGGREGATE._serialized_start = 3134
+    _AGGREGATE._serialized_end = 3344
+    _SORT._serialized_start = 3347
+    _SORT._serialized_end = 3878
+    _SORT_SORTFIELD._serialized_start = 3496
+    _SORT_SORTFIELD._serialized_end = 3684
+    _SORT_SORTDIRECTION._serialized_start = 3686
+    _SORT_SORTDIRECTION._serialized_end = 3794
+    _SORT_SORTNULLS._serialized_start = 3796
+    _SORT_SORTNULLS._serialized_end = 3878
+    _DEDUPLICATE._serialized_start = 3881
+    _DEDUPLICATE._serialized_end = 4023
+    _LOCALRELATION._serialized_start = 4025
+    _LOCALRELATION._serialized_end = 4118
+    _SAMPLE._serialized_start = 4121
+    _SAMPLE._serialized_end = 4361
+    _SAMPLE_SEED._serialized_start = 4335
+    _SAMPLE_SEED._serialized_end = 4361
+    _RANGE._serialized_start = 4364
+    _RANGE._serialized_end = 4562
+    _RANGE_NUMPARTITIONS._serialized_start = 4508
+    _RANGE_NUMPARTITIONS._serialized_end = 4562
+    _SUBQUERYALIAS._serialized_start = 4564
+    _SUBQUERYALIAS._serialized_end = 4678
+    _REPARTITION._serialized_start = 4680
+    _REPARTITION._serialized_end = 4805
+    _STATSUMMARY._serialized_start = 4807
+    _STATSUMMARY._serialized_end = 4899
+    _STATCROSSTAB._serialized_start = 4901
+    _STATCROSSTAB._serialized_end = 5002
+    _RENAMECOLUMNSBYSAMELENGTHNAMES._serialized_start = 5004
+    _RENAMECOLUMNSBYSAMELENGTHNAMES._serialized_end = 5118
+    _RENAMECOLUMNSBYNAMETONAMEMAP._serialized_start = 5121
+    _RENAMECOLUMNSBYNAMETONAMEMAP._serialized_end = 5380
+    _RENAMECOLUMNSBYNAMETONAMEMAP_RENAMECOLUMNSMAPENTRY._serialized_start = 5313
+    _RENAMECOLUMNSBYNAMETONAMEMAP_RENAMECOLUMNSMAPENTRY._serialized_end = 5380
 # @@protoc_insertion_point(module_scope)
diff --git a/python/pyspark/sql/connect/proto/relations_pb2.pyi b/python/pyspark/sql/connect/proto/relations_pb2.pyi
index 5569e4db4ef..53f75b7520f 100644
--- a/python/pyspark/sql/connect/proto/relations_pb2.pyi
+++ b/python/pyspark/sql/connect/proto/relations_pb2.pyi
@@ -79,6 +79,7 @@ class Relation(google.protobuf.message.Message):
     RENAME_COLUMNS_BY_SAME_LENGTH_NAMES_FIELD_NUMBER: builtins.int
     RENAME_COLUMNS_BY_NAME_TO_NAME_MAP_FIELD_NUMBER: builtins.int
     SUMMARY_FIELD_NUMBER: builtins.int
+    CROSSTAB_FIELD_NUMBER: builtins.int
     UNKNOWN_FIELD_NUMBER: builtins.int
     @property
     def common(self) -> global___RelationCommon: ...
@@ -122,6 +123,8 @@ class Relation(google.protobuf.message.Message):
     def summary(self) -> global___StatSummary:
         """stat functions"""
     @property
+    def crosstab(self) -> global___StatCrosstab: ...
+    @property
     def unknown(self) -> global___Unknown: ...
     def __init__(
         self,
@@ -146,6 +149,7 @@ class Relation(google.protobuf.message.Message):
         rename_columns_by_same_length_names: global___RenameColumnsBySameLengthNames | None = ...,
         rename_columns_by_name_to_name_map: global___RenameColumnsByNameToNameMap | None = ...,
         summary: global___StatSummary | None = ...,
+        crosstab: global___StatCrosstab | None = ...,
         unknown: global___Unknown | None = ...,
     ) -> None: ...
     def HasField(
@@ -155,6 +159,8 @@ class Relation(google.protobuf.message.Message):
             b"aggregate",
             "common",
             b"common",
+            "crosstab",
+            b"crosstab",
             "deduplicate",
             b"deduplicate",
             "filter",
@@ -204,6 +210,8 @@ class Relation(google.protobuf.message.Message):
             b"aggregate",
             "common",
             b"common",
+            "crosstab",
+            b"crosstab",
             "deduplicate",
             b"deduplicate",
             "filter",
@@ -268,6 +276,7 @@ class Relation(google.protobuf.message.Message):
         "rename_columns_by_same_length_names",
         "rename_columns_by_name_to_name_map",
         "summary",
+        "crosstab",
         "unknown",
     ] | None: ...
 
@@ -1141,6 +1150,47 @@ class StatSummary(google.protobuf.message.Message):
 
 global___StatSummary = StatSummary
 
+class StatCrosstab(google.protobuf.message.Message):
+    """Computes a pair-wise frequency table of the given columns. Also known 
as a contingency table.
+    It will invoke 'Dataset.stat.crosstab' (same as 
'StatFunctions.crossTabulate')
+    to compute the results.
+    """
+
+    DESCRIPTOR: google.protobuf.descriptor.Descriptor
+
+    INPUT_FIELD_NUMBER: builtins.int
+    COL1_FIELD_NUMBER: builtins.int
+    COL2_FIELD_NUMBER: builtins.int
+    @property
+    def input(self) -> global___Relation:
+        """(Required) The input relation."""
+    col1: builtins.str
+    """(Required) The name of the first column.
+
+    Distinct items will make the first item of each row.
+    """
+    col2: builtins.str
+    """(Required) The name of the second column.
+
+    Distinct items will make the column names of the DataFrame.
+    """
+    def __init__(
+        self,
+        *,
+        input: global___Relation | None = ...,
+        col1: builtins.str = ...,
+        col2: builtins.str = ...,
+    ) -> None: ...
+    def HasField(
+        self, field_name: typing_extensions.Literal["input", b"input"]
+    ) -> builtins.bool: ...
+    def ClearField(
+        self,
+        field_name: typing_extensions.Literal["col1", b"col1", "col2", b"col2", "input", b"input"],
+    ) -> None: ...
+
+global___StatCrosstab = StatCrosstab
+
 class RenameColumnsBySameLengthNames(google.protobuf.message.Message):
     """Rename columns on the input relation by the same length of names."""
 
diff --git a/python/pyspark/sql/tests/connect/test_connect_plan_only.py b/python/pyspark/sql/tests/connect/test_connect_plan_only.py
index 0164fec11ff..c46d4d10624 100644
--- a/python/pyspark/sql/tests/connect/test_connect_plan_only.py
+++ b/python/pyspark/sql/tests/connect/test_connect_plan_only.py
@@ -85,6 +85,16 @@ class SparkConnectTestsPlanOnly(PlanOnlyTestFixture):
             ["count", "mean", "stddev", "min", "25%"],
         )
 
+    def test_crosstab(self):
+        df = self.connect.readTable(table_name=self.tbl_name)
+        plan = df.filter(df.col_name > 3).crosstab("col_a", "col_b")._plan.to_proto(self.connect)
+        self.assertEqual(plan.root.crosstab.col1, "col_a")
+        self.assertEqual(plan.root.crosstab.col2, "col_b")
+
+        plan = df.stat.crosstab("col_a", "col_b")._plan.to_proto(self.connect)
+        self.assertEqual(plan.root.crosstab.col1, "col_a")
+        self.assertEqual(plan.root.crosstab.col2, "col_b")
+
     def test_limit(self):
         df = self.connect.readTable(table_name=self.tbl_name)
         limit_plan = df.limit(10)._plan.to_proto(self.connect)

