This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new e1af3a992e0 [SPARK-41383][SPARK-41692][SPARK-41693] Implement `rollup`, `cube` and `pivot`
e1af3a992e0 is described below

commit e1af3a992e06aeb5185501db908dc272b449c62b
Author: Ruifeng Zheng <ruife...@apache.org>
AuthorDate: Fri Dec 23 19:51:44 2022 +0900

    [SPARK-41383][SPARK-41692][SPARK-41693] Implement `rollup`, `cube` and `pivot`
    
    ### What changes were proposed in this pull request?
    Implement `rollup`, `cube` and `pivot`:
    
    1. `DataFrame.rollup`
    2. `DataFrame.cube`
    3. `DataFrame.groupBy.pivot`
    
    ### Why are the changes needed?
    For API coverage: `rollup`, `cube` and `pivot` exist in the regular DataFrame API but were missing from the Spark Connect client.
    
    ### Does this PR introduce _any_ user-facing change?
    Yes, `DataFrame.rollup`, `DataFrame.cube` and `GroupedData.pivot` become available in the Spark Connect Python client.
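
    For example, a minimal sketch of the new Python client API (here `spark` is assumed to be a Spark Connect `SparkSession`, and the sample data is illustrative):

        from pyspark.sql.connect import functions as CF

        df = spark.sql(
            "SELECT * FROM VALUES ('James', 'Sales', 3000), ('Maria', 'Finance', 3000) "
            "AS T(name, department, salary)"
        )

        # grouping-set variants of groupBy
        df.rollup("name", "department").agg(CF.sum(df.salary)).toPandas()
        df.cube("name", df.department).agg(CF.min(df.salary)).toPandas()

        # pivot; omitting the value list makes the server collect the distinct values
        df.groupBy("name").pivot("department", ["Sales", "Finance"]).agg(CF.sum(df.salary)).toPandas()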
    
    ### How was this patch tested?
    Added unit tests (`SparkConnectProtoSuite` and `test_connect_basic.py`).
    
    Closes #39191 from zhengruifeng/connect_groupby_refactor.
    
    Authored-by: Ruifeng Zheng <ruife...@apache.org>
    Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
---
 .../main/protobuf/spark/connect/relations.proto    |  34 +++++-
 .../org/apache/spark/sql/connect/dsl/package.scala |  50 +++++++-
 .../planner/LiteralValueProtoConverter.scala       |   2 +-
 .../sql/connect/planner/SparkConnectPlanner.scala  |  82 +++++++++----
 .../connect/planner/SparkConnectPlannerSuite.scala |   3 +-
 .../connect/planner/SparkConnectProtoSuite.scala   |  77 ++++++++++++
 python/pyspark/sql/connect/dataframe.py            |  48 +++++++-
 python/pyspark/sql/connect/group.py                |  82 +++++++++++--
 python/pyspark/sql/connect/plan.py                 |  63 +++++++---
 python/pyspark/sql/connect/proto/relations_pb2.py  | 112 +++++++++--------
 python/pyspark/sql/connect/proto/relations_pb2.pyi |  90 ++++++++++++--
 .../sql/tests/connect/test_connect_basic.py        | 136 +++++++++++++++++++++
 12 files changed, 667 insertions(+), 112 deletions(-)

diff --git a/connector/connect/common/src/main/protobuf/spark/connect/relations.proto b/connector/connect/common/src/main/protobuf/spark/connect/relations.proto
index c4f040c03d6..912ee1fdc63 100644
--- a/connector/connect/common/src/main/protobuf/spark/connect/relations.proto
+++ b/connector/connect/common/src/main/protobuf/spark/connect/relations.proto
@@ -235,11 +235,39 @@ message Tail {
 
 // Relation of type [[Aggregate]].
 message Aggregate {
-  // (Required) Input relation for a Aggregate.
+  // (Required) Input relation for a RelationalGroupedDataset.
   Relation input = 1;
 
-  repeated Expression grouping_expressions = 2;
-  repeated Expression result_expressions = 3;
+  // (Required) How the RelationalGroupedDataset was built.
+  GroupType group_type = 2;
+
+  // (Required) Expressions for grouping keys
+  repeated Expression grouping_expressions = 3;
+
+  // (Required) List of values that will be translated to columns in the output DataFrame.
+  repeated Expression aggregate_expressions = 4;
+
+  // (Optional) Pivots a column of the current `DataFrame` and performs the specified aggregation.
+  Pivot pivot = 5;
+
+  enum GroupType {
+    GROUP_TYPE_UNSPECIFIED = 0;
+    GROUP_TYPE_GROUPBY = 1;
+    GROUP_TYPE_ROLLUP = 2;
+    GROUP_TYPE_CUBE = 3;
+    GROUP_TYPE_PIVOT = 4;
+  }
+
+  message Pivot {
+    // (Required) The column to pivot
+    Expression col = 1;
+
+    // (Optional) List of values that will be translated to columns in the output DataFrame.
+    //
+    // Note that if it is empty, the server side will immediately trigger a job to collect
+    // the distinct values of the column.
+    repeated Expression.Literal values = 2;
+  }
 }
 
 // Relation of type [[Sort]].
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala
index b15e46293ab..e6d230d9eef 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala
@@ -601,16 +601,64 @@ package object dsl {
       def groupBy(groupingExprs: Expression*)(aggregateExprs: Expression*): Relation = {
         val agg = Aggregate.newBuilder()
         agg.setInput(logicalPlan)
+        agg.setGroupType(proto.Aggregate.GroupType.GROUP_TYPE_GROUPBY)
 
         for (groupingExpr <- groupingExprs) {
           agg.addGroupingExpressions(groupingExpr)
         }
         for (aggregateExpr <- aggregateExprs) {
-          agg.addResultExpressions(aggregateExpr)
+          agg.addAggregateExpressions(aggregateExpr)
         }
         Relation.newBuilder().setAggregate(agg.build()).build()
       }
 
+      def rollup(groupingExprs: Expression*)(aggregateExprs: Expression*): Relation = {
+        val agg = Aggregate.newBuilder()
+        agg.setInput(logicalPlan)
+        agg.setGroupType(proto.Aggregate.GroupType.GROUP_TYPE_ROLLUP)
+
+        for (groupingExpr <- groupingExprs) {
+          agg.addGroupingExpressions(groupingExpr)
+        }
+        for (aggregateExpr <- aggregateExprs) {
+          agg.addAggregateExpressions(aggregateExpr)
+        }
+        Relation.newBuilder().setAggregate(agg.build()).build()
+      }
+
+      def cube(groupingExprs: Expression*)(aggregateExprs: Expression*): Relation = {
+        val agg = Aggregate.newBuilder()
+        agg.setInput(logicalPlan)
+        agg.setGroupType(proto.Aggregate.GroupType.GROUP_TYPE_CUBE)
+
+        for (groupingExpr <- groupingExprs) {
+          agg.addGroupingExpressions(groupingExpr)
+        }
+        for (aggregateExpr <- aggregateExprs) {
+          agg.addAggregateExpressions(aggregateExpr)
+        }
+        Relation.newBuilder().setAggregate(agg.build()).build()
+      }
+
+      def pivot(groupingExprs: Expression*)(
+          pivotCol: Expression,
+          pivotValues: Seq[proto.Expression.Literal])(aggregateExprs: Expression*): Relation = {
+        val agg = Aggregate.newBuilder()
+        agg.setInput(logicalPlan)
+        agg.setGroupType(proto.Aggregate.GroupType.GROUP_TYPE_PIVOT)
+
+        for (groupingExpr <- groupingExprs) {
+          agg.addGroupingExpressions(groupingExpr)
+        }
+        for (aggregateExpr <- aggregateExprs) {
+          agg.addAggregateExpressions(aggregateExpr)
+        }
+        agg.setPivot(
+          Aggregate.Pivot.newBuilder().setCol(pivotCol).addAllValues(pivotValues.asJava).build())
+
+        Relation.newBuilder().setAggregate(agg.build()).build()
+      }
+
       def except(otherPlan: Relation, isAll: Boolean): Relation = {
         Relation
           .newBuilder()
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/LiteralValueProtoConverter.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/LiteralValueProtoConverter.scala
index abfaaf7a1d3..82ffa4f5246 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/LiteralValueProtoConverter.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/LiteralValueProtoConverter.scala
@@ -30,7 +30,7 @@ object LiteralValueProtoConverter {
    * @return
    *   Expression
    */
-  def toCatalystExpression(lit: proto.Expression.Literal): expressions.Expression = {
+  def toCatalystExpression(lit: proto.Expression.Literal): expressions.Literal = {
     lit.getLiteralTypeCase match {
       case proto.Expression.Literal.LiteralTypeCase.NULL =>
         expressions.Literal(null, NullType)
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index 4abeec0d00b..dce3a8c8e55 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -874,32 +874,72 @@ class SparkConnectPlanner(session: SparkSession) {
   }
 
   private def transformAggregate(rel: proto.Aggregate): LogicalPlan = {
-    assert(rel.hasInput)
+    if (!rel.hasInput) {
+      throw InvalidPlanInput("Aggregate needs a plan input")
+    }
+    val input = transformRelation(rel.getInput)
+
+    def toNamedExpression(expr: Expression): NamedExpression = expr match {
+      case named: NamedExpression => named
+      case expr => UnresolvedAlias(expr)
+    }
 
-    val groupingExprs =
-      rel.getGroupingExpressionsList.asScala
-        .map(transformExpression)
-        .map {
-          case ua @ UnresolvedAttribute(_) => ua
-          case a @ Alias(_, _) => a
-          case x => UnresolvedAlias(x)
+    val groupingExprs = rel.getGroupingExpressionsList.asScala.toSeq.map(transformExpression)
+    val aggExprs = rel.getAggregateExpressionsList.asScala.toSeq.map(transformExpression)
+    val aliasedAgg = (groupingExprs ++ aggExprs).map(toNamedExpression)
+
+    rel.getGroupType match {
+      case proto.Aggregate.GroupType.GROUP_TYPE_GROUPBY =>
+        logical.Aggregate(
+          groupingExpressions = groupingExprs,
+          aggregateExpressions = aliasedAgg,
+          child = input)
+
+      case proto.Aggregate.GroupType.GROUP_TYPE_ROLLUP =>
+        logical.Aggregate(
+          groupingExpressions = Seq(Rollup(groupingExprs.map(Seq(_)))),
+          aggregateExpressions = aliasedAgg,
+          child = input)
+
+      case proto.Aggregate.GroupType.GROUP_TYPE_CUBE =>
+        logical.Aggregate(
+          groupingExpressions = Seq(Cube(groupingExprs.map(Seq(_)))),
+          aggregateExpressions = aliasedAgg,
+          child = input)
+
+      case proto.Aggregate.GroupType.GROUP_TYPE_PIVOT =>
+        if (!rel.hasPivot) {
+          throw InvalidPlanInput("Aggregate with GROUP_TYPE_PIVOT requires a Pivot")
         }
 
-    // Retain group columns in aggregate expressions:
-    val aggExprs =
-      groupingExprs ++ rel.getResultExpressionsList.asScala.map(transformResultExpression)
+        val pivotExpr = transformExpression(rel.getPivot.getCol)
+
+        var valueExprs = rel.getPivot.getValuesList.asScala.toSeq.map(transformLiteral)
+        if (valueExprs.isEmpty) {
+          // This is to prevent unintended OOM errors when the number of distinct values is large
+          val maxValues = session.sessionState.conf.dataFramePivotMaxValues
+          // Get the distinct values of the column and sort them so its consistent
+          val pivotCol = Column(pivotExpr)
+          valueExprs = Dataset
+            .ofRows(session, input)
+            .select(pivotCol)
+            .distinct()
+            .limit(maxValues + 1)
+            .sort(pivotCol) // ensure that the output columns are in a consistent logical order
+            .collect()
+            .map(_.get(0))
+            .toSeq
+            .map(expressions.Literal.apply)
+        }
 
-    logical.Aggregate(
-      child = transformRelation(rel.getInput),
-      groupingExpressions = groupingExprs.toSeq,
-      aggregateExpressions = aggExprs.toSeq)
-  }
+        logical.Pivot(
+          groupByExprsOpt = Some(groupingExprs.map(toNamedExpression)),
+          pivotColumn = pivotExpr,
+          pivotValues = valueExprs,
+          aggregates = aggExprs,
+          child = input)
 
-  private def transformResultExpression(exp: proto.Expression): expressions.NamedExpression = {
-    if (exp.hasAlias) {
-      transformAlias(exp.getAlias)
-    } else {
-      UnresolvedAlias(transformExpression(exp))
+      case other => throw InvalidPlanInput(s"Unknown Group Type $other")
     }
   }
 
diff --git a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala
index 93cb97b4421..1142a3386f9 100644
--- a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala
+++ b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala
@@ -303,8 +303,9 @@ class SparkConnectPlannerSuite extends SparkFunSuite with SparkConnectPlanTest {
 
     val agg = proto.Aggregate.newBuilder
       .setInput(readRel)
-      .addResultExpressions(sum)
+      .addAggregateExpressions(sum)
       .addGroupingExpressions(unresolvedAttribute)
+      .setGroupType(proto.Aggregate.GroupType.GROUP_TYPE_GROUPBY)
       .build()
 
     val res = transform(proto.Relation.newBuilder.setAggregate(agg).build())
diff --git a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala
index 34a30bcd4f0..66a019ef853 100644
--- a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala
+++ b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala
@@ -33,6 +33,7 @@ import org.apache.spark.sql.connect.dsl.MockRemoteSession
 import org.apache.spark.sql.connect.dsl.commands._
 import org.apache.spark.sql.connect.dsl.expressions._
 import org.apache.spark.sql.connect.dsl.plans._
+import org.apache.spark.sql.connect.planner.LiteralValueProtoConverter.toConnectProtoValue
 import org.apache.spark.sql.execution.arrow.ArrowConverters
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.types.{ArrayType, BooleanType, ByteType, DataType, DateType, DecimalType, DoubleType, FloatType, IntegerType, LongType, MapType, Metadata, ShortType, StringType, StructField, StructType}
@@ -222,6 +223,82 @@ class SparkConnectProtoSuite extends PlanTest with SparkConnectPlanTest {
     comparePlans(connectPlan2, sparkPlan2)
   }
 
+  test("Rollup expressions") {
+    val connectPlan1 =
+      connectTestRelation.rollup("id".protoAttr)(proto_min("name".protoAttr))
+    val sparkPlan1 =
+      sparkTestRelation.rollup(Column("id")).agg(min(Column("name")))
+    comparePlans(connectPlan1, sparkPlan1)
+
+    val connectPlan2 =
+      connectTestRelation.rollup("id".protoAttr)(proto_min("name".protoAttr).as("agg1"))
+    val sparkPlan2 =
+      sparkTestRelation.rollup(Column("id")).agg(min(Column("name")).as("agg1"))
+    comparePlans(connectPlan2, sparkPlan2)
+
+    val connectPlan3 =
+      connectTestRelation.rollup("id".protoAttr, "name".protoAttr)(
+        proto_min(proto.Expression.newBuilder().setLiteral(toConnectProtoValue(1)).build())
+          .as("agg1"))
+    val sparkPlan3 =
+      sparkTestRelation
+        .rollup(Column("id"), Column("name"))
+        .agg(min(lit(1)).as("agg1"))
+    comparePlans(connectPlan3, sparkPlan3)
+  }
+
+  test("Cube expressions") {
+    val connectPlan1 =
+      connectTestRelation.cube("id".protoAttr)(proto_min("name".protoAttr))
+    val sparkPlan1 =
+      sparkTestRelation.cube(Column("id")).agg(min(Column("name")))
+    comparePlans(connectPlan1, sparkPlan1)
+
+    val connectPlan2 =
+      connectTestRelation.cube("id".protoAttr)(proto_min("name".protoAttr).as("agg1"))
+    val sparkPlan2 =
+      sparkTestRelation.cube(Column("id")).agg(min(Column("name")).as("agg1"))
+    comparePlans(connectPlan2, sparkPlan2)
+
+    val connectPlan3 =
+      connectTestRelation.cube("id".protoAttr, "name".protoAttr)(
+        proto_min(proto.Expression.newBuilder().setLiteral(toConnectProtoValue(1)).build())
+          .as("agg1"))
+    val sparkPlan3 =
+      sparkTestRelation
+        .cube(Column("id"), Column("name"))
+        .agg(min(lit(1)).as("agg1"))
+    comparePlans(connectPlan3, sparkPlan3)
+  }
+
+  test("Pivot expressions") {
+    val connectPlan1 =
+      connectTestRelation.pivot("id".protoAttr)(
+        "name".protoAttr,
+        Seq("a", "b", "c").map(toConnectProtoValue))(
+        proto_min(proto.Expression.newBuilder().setLiteral(toConnectProtoValue(1)).build())
+          .as("agg1"))
+    val sparkPlan1 =
+      sparkTestRelation
+        .groupBy(Column("id"))
+        .pivot(Column("name"), Seq("a", "b", "c"))
+        .agg(min(lit(1)).as("agg1"))
+    comparePlans(connectPlan1, sparkPlan1)
+
+    val connectPlan2 =
+      connectTestRelation.pivot("name".protoAttr)(
+        "id".protoAttr,
+        Seq(1, 2, 3).map(toConnectProtoValue))(
+        proto_min(proto.Expression.newBuilder().setLiteral(toConnectProtoValue(1)).build())
+          .as("agg1"))
+    val sparkPlan2 =
+      sparkTestRelation
+        .groupBy(Column("name"))
+        .pivot(Column("id"), Seq(1, 2, 3))
+        .agg(min(lit(1)).as("agg1"))
+    comparePlans(connectPlan2, sparkPlan2)
+  }
+
   test("Test as(alias: String)") {
     val connectPlan = connectTestRelation.as("target_table")
     val sparkPlan = sparkTestRelation.as("target_table")
diff --git a/python/pyspark/sql/connect/dataframe.py b/python/pyspark/sql/connect/dataframe.py
index 2200be16b17..b0b3a949d30 100644
--- a/python/pyspark/sql/connect/dataframe.py
+++ b/python/pyspark/sql/connect/dataframe.py
@@ -60,11 +60,10 @@ if TYPE_CHECKING:
     from pyspark.sql.connect.session import SparkSession
 
 
-class DataFrame(object):
+class DataFrame:
     def __init__(
         self,
         session: "SparkSession",
-        data: Optional[List[Any]] = None,
         schema: Optional[StructType] = None,
     ):
         """Creates a new data frame"""
@@ -246,10 +245,53 @@ class DataFrame(object):
     first.__doc__ = PySparkDataFrame.first.__doc__
 
     def groupBy(self, *cols: "ColumnOrName") -> GroupedData:
-        return GroupedData(self, *cols)
+        _cols: List[Column] = []
+        for c in cols:
+            if isinstance(c, Column):
+                _cols.append(c)
+            elif isinstance(c, str):
+                _cols.append(self[c])
+            else:
+                raise TypeError(
+                    f"groupBy requires all cols be Column or str, but got {type(c).__name__} {c}"
+                )
+
+        return GroupedData(df=self, group_type="groupby", grouping_cols=_cols)
 
     groupBy.__doc__ = PySparkDataFrame.groupBy.__doc__
 
+    def rollup(self, *cols: "ColumnOrName") -> "GroupedData":
+        _cols: List[Column] = []
+        for c in cols:
+            if isinstance(c, Column):
+                _cols.append(c)
+            elif isinstance(c, str):
+                _cols.append(self[c])
+            else:
+                raise TypeError(
+                    f"rollup requires all cols be Column or str, but got {type(c).__name__} {c}"
+                )
+
+        return GroupedData(df=self, group_type="rollup", grouping_cols=_cols)
+
+    rollup.__doc__ = PySparkDataFrame.rollup.__doc__
+
+    def cube(self, *cols: "ColumnOrName") -> "GroupedData":
+        _cols: List[Column] = []
+        for c in cols:
+            if isinstance(c, Column):
+                _cols.append(c)
+            elif isinstance(c, str):
+                _cols.append(self[c])
+            else:
+                raise TypeError(
+                    f"cube requires all cols be Column or str, but got {type(c).__name__} {c}"
+                )
+
+        return GroupedData(df=self, group_type="cube", grouping_cols=_cols)
+
+    cube.__doc__ = PySparkDataFrame.cube.__doc__
+
     @overload
     def head(self) -> Optional[Row]:
         ...
diff --git a/python/pyspark/sql/connect/group.py b/python/pyspark/sql/connect/group.py
index c275edc9a2a..004ebd50196 100644
--- a/python/pyspark/sql/connect/group.py
+++ b/python/pyspark/sql/connect/group.py
@@ -16,31 +16,55 @@
 #
 
 from typing import (
+    Any,
     Dict,
     List,
     Sequence,
     Union,
     TYPE_CHECKING,
+    Optional,
     overload,
     cast,
 )
 
+from pyspark.sql.group import GroupedData as PySparkGroupedData
+
 import pyspark.sql.connect.plan as plan
-from pyspark.sql.connect.column import (
-    Column,
-    scalar_function,
-)
+from pyspark.sql.connect.column import Column, scalar_function
 from pyspark.sql.connect.functions import col, lit
-from pyspark.sql.group import GroupedData as PySparkGroupedData
 
 if TYPE_CHECKING:
+    from pyspark.sql.connect._typing import LiteralType
     from pyspark.sql.connect.dataframe import DataFrame
 
 
-class GroupedData(object):
-    def __init__(self, df: "DataFrame", *grouping_cols: Union[Column, str]) -> None:
+class GroupedData:
+    def __init__(
+        self,
+        df: "DataFrame",
+        group_type: str,
+        grouping_cols: Sequence["Column"],
+        pivot_col: Optional["Column"] = None,
+        pivot_values: Optional[Sequence["LiteralType"]] = None,
+    ) -> None:
+        from pyspark.sql.connect.dataframe import DataFrame
+
+        assert isinstance(df, DataFrame)
         self._df = df
-        self._grouping_cols = [x if isinstance(x, Column) else df[x] for x in grouping_cols]
+
+        assert isinstance(group_type, str) and group_type in ["groupby", "rollup", "cube", "pivot"]
+        self._group_type = group_type
+
+        assert isinstance(grouping_cols, list) and all(isinstance(g, Column) for g in grouping_cols)
+        self._grouping_cols: List[Column] = grouping_cols
+
+        self._pivot_col: Optional["Column"] = None
+        self._pivot_values: Optional[List[Any]] = None
+        if group_type == "pivot":
+            assert pivot_col is not None and isinstance(pivot_col, Column)
+            assert pivot_values is None or isinstance(pivot_values, list)
+            self._pivot_col = pivot_col
+            self._pivot_values = pivot_values
 
     @overload
     def agg(self, *exprs: Column) -> "DataFrame":
@@ -56,17 +80,20 @@ class GroupedData(object):
         assert exprs, "exprs should not be empty"
         if len(exprs) == 1 and isinstance(exprs[0], dict):
             # Convert the dict into key value pairs
-            measures = [scalar_function(exprs[0][k], col(k)) for k in exprs[0]]
+            aggregate_cols = [scalar_function(exprs[0][k], col(k)) for k in exprs[0]]
         else:
             # Columns
             assert all(isinstance(c, Column) for c in exprs), "all exprs should be Column"
-            measures = cast(List[Column], list(exprs))
+            aggregate_cols = cast(List[Column], list(exprs))
 
         res = DataFrame.withPlan(
             plan.Aggregate(
                 child=self._df._plan,
+                group_type=self._group_type,
                 grouping_cols=self._grouping_cols,
-                measures=measures,
+                aggregate_cols=aggregate_cols,
+                pivot_col=self._pivot_col,
+                pivot_values=self._pivot_values,
             ),
             session=self._df._session,
         )
@@ -108,5 +135,38 @@ class GroupedData(object):
 
     count.__doc__ = PySparkGroupedData.count.__doc__
 
+    def pivot(self, pivot_col: str, values: Optional[List["LiteralType"]] = None) -> "GroupedData":
+        if self._group_type != "groupby":
+            if self._group_type == "pivot":
+                raise Exception("Repeated PIVOT operation is not supported!")
+            else:
+                raise Exception(f"PIVOT after {self._group_type.upper()} is not supported!")
+
+        if not isinstance(pivot_col, str):
+            raise TypeError(
+                f"pivot_col should be a str, but got {type(pivot_col).__name__} {pivot_col}"
+            )
+
+        if values is not None:
+            if not isinstance(values, list):
+                raise TypeError(
+                    f"values should be a list, but got {type(values).__name__} {values}"
+                )
+            for v in values:
+                if not isinstance(v, (bool, float, int, str)):
+                    raise TypeError(
+                        f"value should be a bool, float, int or str, but got {type(v).__name__} {v}"
+                    )
+
+        return GroupedData(
+            df=self._df,
+            group_type="pivot",
+            grouping_cols=self._grouping_cols,
+            pivot_col=self._df[pivot_col],
+            pivot_values=values,
+        )
+
+    pivot.__doc__ = PySparkGroupedData.pivot.__doc__
+
 
 GroupedData.__doc__ = PySparkGroupedData.__doc__
diff --git a/python/pyspark/sql/connect/plan.py b/python/pyspark/sql/connect/plan.py
index 4e081832d01..d12256adec7 100644
--- a/python/pyspark/sql/connect/plan.py
+++ b/python/pyspark/sql/connect/plan.py
@@ -692,36 +692,71 @@ class Aggregate(LogicalPlan):
     def __init__(
         self,
         child: Optional["LogicalPlan"],
-        grouping_cols: List[Column],
-        measures: Sequence[Column],
+        group_type: str,
+        grouping_cols: Sequence[Column],
+        aggregate_cols: Sequence[Column],
+        pivot_col: Optional[Column],
+        pivot_values: Optional[Sequence[Any]],
     ) -> None:
         super().__init__(child)
-        self.grouping_cols = grouping_cols
-        self.measures = measures
 
-    def _convert_measure(self, m: Column, session: "SparkConnectClient") -> proto.Expression:
-        proto_expr = proto.Expression()
-        proto_expr.CopyFrom(m.to_plan(session))
-        return proto_expr
+        assert isinstance(group_type, str) and group_type in ["groupby", "rollup", "cube", "pivot"]
+        self._group_type = group_type
+
+        assert isinstance(grouping_cols, list) and all(isinstance(c, Column) for c in grouping_cols)
+        self._grouping_cols = grouping_cols
+
+        assert isinstance(aggregate_cols, list) and all(
+            isinstance(c, Column) for c in aggregate_cols
+        )
+        self._aggregate_cols = aggregate_cols
+
+        if group_type == "pivot":
+            assert pivot_col is not None and isinstance(pivot_col, Column)
+            assert pivot_values is None or isinstance(pivot_values, list)
+        else:
+            assert pivot_col is None
+            assert pivot_values is None
+
+        self._pivot_col = pivot_col
+        self._pivot_values = pivot_values
 
     def plan(self, session: "SparkConnectClient") -> proto.Relation:
+        from pyspark.sql.connect.functions import lit
+
         assert self._child is not None
-        groupings = [x.to_plan(session) for x in self.grouping_cols]
 
         agg = proto.Relation()
+
         agg.aggregate.input.CopyFrom(self._child.plan(session))
-        agg.aggregate.result_expressions.extend(
-            list(map(lambda x: self._convert_measure(x, session), self.measures))
+
+        agg.aggregate.grouping_expressions.extend([c.to_plan(session) for c in self._grouping_cols])
+        agg.aggregate.aggregate_expressions.extend(
+            [c.to_plan(session) for c in self._aggregate_cols]
         )
 
-        agg.aggregate.grouping_expressions.extend(groupings)
+        if self._group_type == "groupby":
+            agg.aggregate.group_type = proto.Aggregate.GroupType.GROUP_TYPE_GROUPBY
+        elif self._group_type == "rollup":
+            agg.aggregate.group_type = proto.Aggregate.GroupType.GROUP_TYPE_ROLLUP
+        elif self._group_type == "cube":
+            agg.aggregate.group_type = proto.Aggregate.GroupType.GROUP_TYPE_CUBE
+        elif self._group_type == "pivot":
+            agg.aggregate.group_type = proto.Aggregate.GroupType.GROUP_TYPE_PIVOT
+            assert self._pivot_col is not None
+            agg.aggregate.pivot.col.CopyFrom(self._pivot_col.to_plan(session))
+            if self._pivot_values is not None and len(self._pivot_values) > 0:
+                agg.aggregate.pivot.values.extend(
+                    [lit(v).to_plan(session).literal for v in self._pivot_values]
+                )
+
         return agg
 
     def print(self, indent: int = 0) -> str:
         c_buf = self._child.print(indent + LogicalPlan.INDENT) if self._child else ""
         return (
-            f"{' ' * indent}<Sort columns={self.grouping_cols}"
-            f"measures={self.measures}>\n{c_buf}"
+            f"{' ' * indent}<Groupby={self._grouping_cols}"
+            f"Aggregate={self._aggregate_cols}>\n{c_buf}"
         )
 
     def _repr_html_(self) -> str:
diff --git a/python/pyspark/sql/connect/proto/relations_pb2.py b/python/pyspark/sql/connect/proto/relations_pb2.py
index b310e2c8464..5f259e75caa 100644
--- a/python/pyspark/sql/connect/proto/relations_pb2.py
+++ b/python/pyspark/sql/connect/proto/relations_pb2.py
@@ -34,7 +34,7 @@ from pyspark.sql.connect.proto import types_pb2 as spark_dot_connect_dot_types__
 
 
 DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
-    
b'\n\x1dspark/connect/relations.proto\x12\rspark.connect\x1a\x1fspark/connect/expressions.proto\x1a\x19spark/connect/types.proto"\xf7\x0e\n\x08Relation\x12\x35\n\x06\x63ommon\x18\x01
 
\x01(\x0b\x32\x1d.spark.connect.RelationCommonR\x06\x63ommon\x12)\n\x04read\x18\x02
 
\x01(\x0b\x32\x13.spark.connect.ReadH\x00R\x04read\x12\x32\n\x07project\x18\x03 
\x01(\x0b\x32\x16.spark.connect.ProjectH\x00R\x07project\x12/\n\x06\x66ilter\x18\x04
 \x01(\x0b\x32\x15.spark.connect.FilterH\x00R\x06\x66ilte [...]
+    
b'\n\x1dspark/connect/relations.proto\x12\rspark.connect\x1a\x1fspark/connect/expressions.proto\x1a\x19spark/connect/types.proto"\xf7\x0e\n\x08Relation\x12\x35\n\x06\x63ommon\x18\x01
 
\x01(\x0b\x32\x1d.spark.connect.RelationCommonR\x06\x63ommon\x12)\n\x04read\x18\x02
 
\x01(\x0b\x32\x13.spark.connect.ReadH\x00R\x04read\x12\x32\n\x07project\x18\x03 
\x01(\x0b\x32\x16.spark.connect.ProjectH\x00R\x07project\x12/\n\x06\x66ilter\x18\x04
 \x01(\x0b\x32\x15.spark.connect.FilterH\x00R\x06\x66ilte [...]
 )
 
 
@@ -54,6 +54,7 @@ _LIMIT = DESCRIPTOR.message_types_by_name["Limit"]
 _OFFSET = DESCRIPTOR.message_types_by_name["Offset"]
 _TAIL = DESCRIPTOR.message_types_by_name["Tail"]
 _AGGREGATE = DESCRIPTOR.message_types_by_name["Aggregate"]
+_AGGREGATE_PIVOT = _AGGREGATE.nested_types_by_name["Pivot"]
 _SORT = DESCRIPTOR.message_types_by_name["Sort"]
 _DROP = DESCRIPTOR.message_types_by_name["Drop"]
 _DEDUPLICATE = DESCRIPTOR.message_types_by_name["Deduplicate"]
@@ -81,6 +82,7 @@ _UNPIVOT = DESCRIPTOR.message_types_by_name["Unpivot"]
 _TOSCHEMA = DESCRIPTOR.message_types_by_name["ToSchema"]
 _JOIN_JOINTYPE = _JOIN.enum_types_by_name["JoinType"]
 _SETOPERATION_SETOPTYPE = _SETOPERATION.enum_types_by_name["SetOpType"]
+_AGGREGATE_GROUPTYPE = _AGGREGATE.enum_types_by_name["GroupType"]
 Relation = _reflection.GeneratedProtocolMessageType(
     "Relation",
     (_message.Message,),
@@ -247,12 +249,22 @@ Aggregate = _reflection.GeneratedProtocolMessageType(
     "Aggregate",
     (_message.Message,),
     {
+        "Pivot": _reflection.GeneratedProtocolMessageType(
+            "Pivot",
+            (_message.Message,),
+            {
+                "DESCRIPTOR": _AGGREGATE_PIVOT,
+                "__module__": "spark.connect.relations_pb2"
+                # @@protoc_insertion_point(class_scope:spark.connect.Aggregate.Pivot)
+            },
+        ),
         "DESCRIPTOR": _AGGREGATE,
         "__module__": "spark.connect.relations_pb2"
         # @@protoc_insertion_point(class_scope:spark.connect.Aggregate)
     },
 )
 _sym_db.RegisterMessage(Aggregate)
+_sym_db.RegisterMessage(Aggregate.Pivot)
 
 Sort = _reflection.GeneratedProtocolMessageType(
     "Sort",
@@ -548,51 +560,55 @@ if _descriptor._USE_C_DESCRIPTORS == False:
     _TAIL._serialized_start = 3807
     _TAIL._serialized_end = 3882
     _AGGREGATE._serialized_start = 3885
-    _AGGREGATE._serialized_end = 4095
-    _SORT._serialized_start = 4098
-    _SORT._serialized_end = 4258
-    _DROP._serialized_start = 4260
-    _DROP._serialized_end = 4360
-    _DEDUPLICATE._serialized_start = 4363
-    _DEDUPLICATE._serialized_end = 4534
-    _LOCALRELATION._serialized_start = 4537
-    _LOCALRELATION._serialized_end = 4674
-    _SAMPLE._serialized_start = 4677
-    _SAMPLE._serialized_end = 4972
-    _RANGE._serialized_start = 4975
-    _RANGE._serialized_end = 5120
-    _SUBQUERYALIAS._serialized_start = 5122
-    _SUBQUERYALIAS._serialized_end = 5236
-    _REPARTITION._serialized_start = 5239
-    _REPARTITION._serialized_end = 5381
-    _SHOWSTRING._serialized_start = 5384
-    _SHOWSTRING._serialized_end = 5525
-    _STATSUMMARY._serialized_start = 5527
-    _STATSUMMARY._serialized_end = 5619
-    _STATDESCRIBE._serialized_start = 5621
-    _STATDESCRIBE._serialized_end = 5702
-    _STATCROSSTAB._serialized_start = 5704
-    _STATCROSSTAB._serialized_end = 5805
-    _NAFILL._serialized_start = 5808
-    _NAFILL._serialized_end = 5942
-    _NADROP._serialized_start = 5945
-    _NADROP._serialized_end = 6079
-    _NAREPLACE._serialized_start = 6082
-    _NAREPLACE._serialized_end = 6378
-    _NAREPLACE_REPLACEMENT._serialized_start = 6237
-    _NAREPLACE_REPLACEMENT._serialized_end = 6378
-    _RENAMECOLUMNSBYSAMELENGTHNAMES._serialized_start = 6380
-    _RENAMECOLUMNSBYSAMELENGTHNAMES._serialized_end = 6494
-    _RENAMECOLUMNSBYNAMETONAMEMAP._serialized_start = 6497
-    _RENAMECOLUMNSBYNAMETONAMEMAP._serialized_end = 6756
-    _RENAMECOLUMNSBYNAMETONAMEMAP_RENAMECOLUMNSMAPENTRY._serialized_start = 6689
-    _RENAMECOLUMNSBYNAMETONAMEMAP_RENAMECOLUMNSMAPENTRY._serialized_end = 6756
-    _WITHCOLUMNS._serialized_start = 6759
-    _WITHCOLUMNS._serialized_end = 6890
-    _HINT._serialized_start = 6893
-    _HINT._serialized_end = 7033
-    _UNPIVOT._serialized_start = 7036
-    _UNPIVOT._serialized_end = 7282
-    _TOSCHEMA._serialized_start = 7284
-    _TOSCHEMA._serialized_end = 7390
+    _AGGREGATE._serialized_end = 4467
+    _AGGREGATE_PIVOT._serialized_start = 4224
+    _AGGREGATE_PIVOT._serialized_end = 4335
+    _AGGREGATE_GROUPTYPE._serialized_start = 4338
+    _AGGREGATE_GROUPTYPE._serialized_end = 4467
+    _SORT._serialized_start = 4470
+    _SORT._serialized_end = 4630
+    _DROP._serialized_start = 4632
+    _DROP._serialized_end = 4732
+    _DEDUPLICATE._serialized_start = 4735
+    _DEDUPLICATE._serialized_end = 4906
+    _LOCALRELATION._serialized_start = 4909
+    _LOCALRELATION._serialized_end = 5046
+    _SAMPLE._serialized_start = 5049
+    _SAMPLE._serialized_end = 5344
+    _RANGE._serialized_start = 5347
+    _RANGE._serialized_end = 5492
+    _SUBQUERYALIAS._serialized_start = 5494
+    _SUBQUERYALIAS._serialized_end = 5608
+    _REPARTITION._serialized_start = 5611
+    _REPARTITION._serialized_end = 5753
+    _SHOWSTRING._serialized_start = 5756
+    _SHOWSTRING._serialized_end = 5897
+    _STATSUMMARY._serialized_start = 5899
+    _STATSUMMARY._serialized_end = 5991
+    _STATDESCRIBE._serialized_start = 5993
+    _STATDESCRIBE._serialized_end = 6074
+    _STATCROSSTAB._serialized_start = 6076
+    _STATCROSSTAB._serialized_end = 6177
+    _NAFILL._serialized_start = 6180
+    _NAFILL._serialized_end = 6314
+    _NADROP._serialized_start = 6317
+    _NADROP._serialized_end = 6451
+    _NAREPLACE._serialized_start = 6454
+    _NAREPLACE._serialized_end = 6750
+    _NAREPLACE_REPLACEMENT._serialized_start = 6609
+    _NAREPLACE_REPLACEMENT._serialized_end = 6750
+    _RENAMECOLUMNSBYSAMELENGTHNAMES._serialized_start = 6752
+    _RENAMECOLUMNSBYSAMELENGTHNAMES._serialized_end = 6866
+    _RENAMECOLUMNSBYNAMETONAMEMAP._serialized_start = 6869
+    _RENAMECOLUMNSBYNAMETONAMEMAP._serialized_end = 7128
+    _RENAMECOLUMNSBYNAMETONAMEMAP_RENAMECOLUMNSMAPENTRY._serialized_start = 7061
+    _RENAMECOLUMNSBYNAMETONAMEMAP_RENAMECOLUMNSMAPENTRY._serialized_end = 7128
+    _WITHCOLUMNS._serialized_start = 7131
+    _WITHCOLUMNS._serialized_end = 7262
+    _HINT._serialized_start = 7265
+    _HINT._serialized_end = 7405
+    _UNPIVOT._serialized_start = 7408
+    _UNPIVOT._serialized_end = 7654
+    _TOSCHEMA._serialized_start = 7656
+    _TOSCHEMA._serialized_end = 7762
 # @@protoc_insertion_point(module_scope)
diff --git a/python/pyspark/sql/connect/proto/relations_pb2.pyi b/python/pyspark/sql/connect/proto/relations_pb2.pyi
index 62308eaaa81..f9032be6a49 100644
--- a/python/pyspark/sql/connect/proto/relations_pb2.pyi
+++ b/python/pyspark/sql/connect/proto/relations_pb2.pyi
@@ -909,49 +909,121 @@ class Aggregate(google.protobuf.message.Message):
 
     DESCRIPTOR: google.protobuf.descriptor.Descriptor
 
+    class _GroupType:
+        ValueType = typing.NewType("ValueType", builtins.int)
+        V: typing_extensions.TypeAlias = ValueType
+
+    class _GroupTypeEnumTypeWrapper(
+        google.protobuf.internal.enum_type_wrapper._EnumTypeWrapper[Aggregate._GroupType.ValueType],
+        builtins.type,
+    ):  # noqa: F821
+        DESCRIPTOR: google.protobuf.descriptor.EnumDescriptor
+        GROUP_TYPE_UNSPECIFIED: Aggregate._GroupType.ValueType  # 0
+        GROUP_TYPE_GROUPBY: Aggregate._GroupType.ValueType  # 1
+        GROUP_TYPE_ROLLUP: Aggregate._GroupType.ValueType  # 2
+        GROUP_TYPE_CUBE: Aggregate._GroupType.ValueType  # 3
+        GROUP_TYPE_PIVOT: Aggregate._GroupType.ValueType  # 4
+
+    class GroupType(_GroupType, metaclass=_GroupTypeEnumTypeWrapper): ...
+    GROUP_TYPE_UNSPECIFIED: Aggregate.GroupType.ValueType  # 0
+    GROUP_TYPE_GROUPBY: Aggregate.GroupType.ValueType  # 1
+    GROUP_TYPE_ROLLUP: Aggregate.GroupType.ValueType  # 2
+    GROUP_TYPE_CUBE: Aggregate.GroupType.ValueType  # 3
+    GROUP_TYPE_PIVOT: Aggregate.GroupType.ValueType  # 4
+
+    class Pivot(google.protobuf.message.Message):
+        DESCRIPTOR: google.protobuf.descriptor.Descriptor
+
+        COL_FIELD_NUMBER: builtins.int
+        VALUES_FIELD_NUMBER: builtins.int
+        @property
+        def col(self) -> pyspark.sql.connect.proto.expressions_pb2.Expression:
+            """(Required) The column to pivot"""
+        @property
+        def values(
+            self,
+        ) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[
+            pyspark.sql.connect.proto.expressions_pb2.Expression.Literal
+        ]:
+            """(Optional) List of values that will be translated to columns in the output DataFrame.
+
+            Note that if it is empty, the server side will immediately trigger a job to collect
+            the distinct values of the column.
+            """
+        def __init__(
+            self,
+            *,
+            col: pyspark.sql.connect.proto.expressions_pb2.Expression | None = ...,
+            values: collections.abc.Iterable[
+                pyspark.sql.connect.proto.expressions_pb2.Expression.Literal
+            ]
+            | None = ...,
+        ) -> None: ...
+        def HasField(
+            self, field_name: typing_extensions.Literal["col", b"col"]
+        ) -> builtins.bool: ...
+        def ClearField(
+            self, field_name: typing_extensions.Literal["col", b"col", "values", b"values"]
+        ) -> None: ...
+
     INPUT_FIELD_NUMBER: builtins.int
+    GROUP_TYPE_FIELD_NUMBER: builtins.int
     GROUPING_EXPRESSIONS_FIELD_NUMBER: builtins.int
-    RESULT_EXPRESSIONS_FIELD_NUMBER: builtins.int
+    AGGREGATE_EXPRESSIONS_FIELD_NUMBER: builtins.int
+    PIVOT_FIELD_NUMBER: builtins.int
     @property
     def input(self) -> global___Relation:
-        """(Required) Input relation for a Aggregate."""
+        """(Required) Input relation for a RelationalGroupedDataset."""
+    group_type: global___Aggregate.GroupType.ValueType
+    """(Required) How the RelationalGroupedDataset was built."""
     @property
     def grouping_expressions(
         self,
     ) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[
         pyspark.sql.connect.proto.expressions_pb2.Expression
-    ]: ...
+    ]:
+        """(Required) Expressions for grouping keys"""
     @property
-    def result_expressions(
+    def aggregate_expressions(
         self,
     ) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[
         pyspark.sql.connect.proto.expressions_pb2.Expression
-    ]: ...
+    ]:
+        """(Required) List of values that will be translated to columns in the output DataFrame."""
+    @property
+    def pivot(self) -> global___Aggregate.Pivot:
+        """(Optional) Pivots a column of the current `DataFrame` and performs the specified aggregation."""
     def __init__(
         self,
         *,
         input: global___Relation | None = ...,
+        group_type: global___Aggregate.GroupType.ValueType = ...,
         grouping_expressions: collections.abc.Iterable[
             pyspark.sql.connect.proto.expressions_pb2.Expression
         ]
         | None = ...,
-        result_expressions: collections.abc.Iterable[
+        aggregate_expressions: collections.abc.Iterable[
             pyspark.sql.connect.proto.expressions_pb2.Expression
         ]
         | None = ...,
+        pivot: global___Aggregate.Pivot | None = ...,
     ) -> None: ...
     def HasField(
-        self, field_name: typing_extensions.Literal["input", b"input"]
+        self, field_name: typing_extensions.Literal["input", b"input", "pivot", b"pivot"]
     ) -> builtins.bool: ...
     def ClearField(
         self,
         field_name: typing_extensions.Literal[
+            "aggregate_expressions",
+            b"aggregate_expressions",
+            "group_type",
+            b"group_type",
             "grouping_expressions",
             b"grouping_expressions",
             "input",
             b"input",
-            "result_expressions",
-            b"result_expressions",
+            "pivot",
+            b"pivot",
         ],
     ) -> None: ...
 
diff --git a/python/pyspark/sql/tests/connect/test_connect_basic.py b/python/pyspark/sql/tests/connect/test_connect_basic.py
index 3e977d95541..bced3fd5e7e 100644
--- a/python/pyspark/sql/tests/connect/test_connect_basic.py
+++ b/python/pyspark/sql/tests/connect/test_connect_basic.py
@@ -1155,6 +1155,142 @@ class SparkConnectTests(SparkConnectSQLTestCase):
             set(spark_df.select("id").crossJoin(other=spark_df.select("name")).toPandas()),
         )
 
+    def test_grouped_data(self):
+        from pyspark.sql import functions as SF
+        from pyspark.sql.connect import functions as CF
+
+        query = """
+            SELECT * FROM VALUES
+                ('James', 'Sales', 3000, 2020),
+                ('Michael', 'Sales', 4600, 2020),
+                ('Robert', 'Sales', 4100, 2020),
+                ('Maria', 'Finance', 3000, 2020),
+                ('James', 'Sales', 3000, 2019),
+                ('Scott', 'Finance', 3300, 2020),
+                ('Jen', 'Finance', 3900, 2020),
+                ('Jeff', 'Marketing', 3000, 2020),
+                ('Kumar', 'Marketing', 2000, 2020),
+                ('Saif', 'Sales', 4100, 2020)
+            AS T(name, department, salary, year)
+            """
+
+        # +-------+----------+------+----+
+        # |   name|department|salary|year|
+        # +-------+----------+------+----+
+        # |  James|     Sales|  3000|2020|
+        # |Michael|     Sales|  4600|2020|
+        # | Robert|     Sales|  4100|2020|
+        # |  Maria|   Finance|  3000|2020|
+        # |  James|     Sales|  3000|2019|
+        # |  Scott|   Finance|  3300|2020|
+        # |    Jen|   Finance|  3900|2020|
+        # |   Jeff| Marketing|  3000|2020|
+        # |  Kumar| Marketing|  2000|2020|
+        # |   Saif|     Sales|  4100|2020|
+        # +-------+----------+------+----+
+
+        cdf = self.connect.sql(query)
+        sdf = self.spark.sql(query)
+
+        # test groupby
+        self.assert_eq(
+            cdf.groupBy("name").agg(CF.sum(cdf.salary)).toPandas(),
+            sdf.groupBy("name").agg(SF.sum(sdf.salary)).toPandas(),
+        )
+        self.assert_eq(
+            cdf.groupBy("name", cdf.department).agg(CF.max("year"), CF.min(cdf.salary)).toPandas(),
+            sdf.groupBy("name", sdf.department).agg(SF.max("year"), SF.min(sdf.salary)).toPandas(),
+        )
+
+        # test rollup
+        self.assert_eq(
+            cdf.rollup("name").agg(CF.sum(cdf.salary)).toPandas(),
+            sdf.rollup("name").agg(SF.sum(sdf.salary)).toPandas(),
+        )
+        self.assert_eq(
+            cdf.rollup("name", cdf.department).agg(CF.max("year"), CF.min(cdf.salary)).toPandas(),
+            sdf.rollup("name", sdf.department).agg(SF.max("year"), SF.min(sdf.salary)).toPandas(),
+        )
+
+        # test cube
+        self.assert_eq(
+            cdf.cube("name").agg(CF.sum(cdf.salary)).toPandas(),
+            sdf.cube("name").agg(SF.sum(sdf.salary)).toPandas(),
+        )
+        self.assert_eq(
+            cdf.cube("name", cdf.department).agg(CF.max("year"), CF.min(cdf.salary)).toPandas(),
+            sdf.cube("name", sdf.department).agg(SF.max("year"), SF.min(sdf.salary)).toPandas(),
+        )
+
+        # test pivot
+        # pivot with values
+        self.assert_eq(
+            cdf.groupBy("name")
+            .pivot("department", ["Sales", "Marketing"])
+            .agg(CF.sum(cdf.salary))
+            .toPandas(),
+            sdf.groupBy("name")
+            .pivot("department", ["Sales", "Marketing"])
+            .agg(SF.sum(sdf.salary))
+            .toPandas(),
+        )
+        self.assert_eq(
+            cdf.groupBy(cdf.name)
+            .pivot("department", ["Sales", "Finance", "Marketing"])
+            .agg(CF.sum(cdf.salary))
+            .toPandas(),
+            sdf.groupBy(sdf.name)
+            .pivot("department", ["Sales", "Finance", "Marketing"])
+            .agg(SF.sum(sdf.salary))
+            .toPandas(),
+        )
+        self.assert_eq(
+            cdf.groupBy(cdf.name)
+            .pivot("department", ["Sales", "Finance", "Unknown"])
+            .agg(CF.sum(cdf.salary))
+            .toPandas(),
+            sdf.groupBy(sdf.name)
+            .pivot("department", ["Sales", "Finance", "Unknown"])
+            .agg(SF.sum(sdf.salary))
+            .toPandas(),
+        )
+
+        # pivot without values
+        self.assert_eq(
+            cdf.groupBy("name").pivot("department").agg(CF.sum(cdf.salary)).toPandas(),
+            sdf.groupBy("name").pivot("department").agg(SF.sum(sdf.salary)).toPandas(),
+        )
+
+        self.assert_eq(
+            cdf.groupBy("name").pivot("year").agg(CF.sum(cdf.salary)).toPandas(),
+            sdf.groupBy("name").pivot("year").agg(SF.sum(sdf.salary)).toPandas(),
+        )
+
+        # check error
+        with self.assertRaisesRegex(
+            Exception,
+            "PIVOT after ROLLUP is not supported",
+        ):
+            cdf.rollup("name").pivot("department").agg(CF.sum(cdf.salary))
+
+        with self.assertRaisesRegex(
+            Exception,
+            "PIVOT after CUBE is not supported",
+        ):
+            cdf.cube("name").pivot("department").agg(CF.sum(cdf.salary))
+
+        with self.assertRaisesRegex(
+            Exception,
+            "Repeated PIVOT operation is not supported",
+        ):
+            cdf.groupBy("name").pivot("year").pivot("year").agg(CF.sum(cdf.salary))
+
+        with self.assertRaisesRegex(
+            TypeError,
+            "value should be a bool, float, int or str, but got bytes",
+        ):
+            cdf.groupBy("name").pivot("department", ["Sales", b"Marketing"]).agg(CF.sum(cdf.salary))
+
 
 @unittest.skipIf(not should_test_connect, connect_requirement_message)
 class ChannelBuilderTests(ReusedPySparkTestCase):


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

