This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
     new e1af3a992e0  [SPARK-41383][SPARK-41692][SPARK-41693] Implement `rollup`, `cube` and `pivot`
e1af3a992e0 is described below

commit e1af3a992e06aeb5185501db908dc272b449c62b
Author: Ruifeng Zheng <ruife...@apache.org>
AuthorDate: Fri Dec 23 19:51:44 2022 +0900

    [SPARK-41383][SPARK-41692][SPARK-41693] Implement `rollup`, `cube` and `pivot`

    ### What changes were proposed in this pull request?
    Implement `rollup`, `cube` and `pivot`:
    1. `DataFrame.rollup`
    2. `DataFrame.cube`
    3. `DataFrame.groupBy.pivot`

    ### Why are the changes needed?
    For API coverage.

    ### Does this PR introduce _any_ user-facing change?
    Yes.

    ### How was this patch tested?
    Added unit tests.

    Closes #39191 from zhengruifeng/connect_groupby_refactor.

    Authored-by: Ruifeng Zheng <ruife...@apache.org>
    Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
---
 .../main/protobuf/spark/connect/relations.proto     |  34 +++++-
 .../org/apache/spark/sql/connect/dsl/package.scala  |  50 +++++++-
 .../planner/LiteralValueProtoConverter.scala        |   2 +-
 .../sql/connect/planner/SparkConnectPlanner.scala   |  82 +++++++++----
 .../connect/planner/SparkConnectPlannerSuite.scala  |   3 +-
 .../connect/planner/SparkConnectProtoSuite.scala    |  77 ++++++++++++
 python/pyspark/sql/connect/dataframe.py             |  48 +++++++-
 python/pyspark/sql/connect/group.py                 |  82 +++++++++++--
 python/pyspark/sql/connect/plan.py                  |  63 +++++++---
 python/pyspark/sql/connect/proto/relations_pb2.py   | 112 +++++++++--------
 python/pyspark/sql/connect/proto/relations_pb2.pyi  |  90 ++++++++++++--
 .../sql/tests/connect/test_connect_basic.py         | 136 +++++++++++++++++++++
 12 files changed, 667 insertions(+), 112 deletions(-)

diff --git a/connector/connect/common/src/main/protobuf/spark/connect/relations.proto b/connector/connect/common/src/main/protobuf/spark/connect/relations.proto
index c4f040c03d6..912ee1fdc63 100644
--- a/connector/connect/common/src/main/protobuf/spark/connect/relations.proto
+++ b/connector/connect/common/src/main/protobuf/spark/connect/relations.proto
@@ -235,11 +235,39 @@ message Tail {
 // Relation of type [[Aggregate]].
 message Aggregate {
-  // (Required) Input relation for a Aggregate.
+  // (Required) Input relation for a RelationalGroupedDataset.
   Relation input = 1;
 
-  repeated Expression grouping_expressions = 2;
-  repeated Expression result_expressions = 3;
+  // (Required) How the RelationalGroupedDataset was built.
+  GroupType group_type = 2;
+
+  // (Required) Expressions for grouping keys
+  repeated Expression grouping_expressions = 3;
+
+  // (Required) List of values that will be translated to columns in the output DataFrame.
+  repeated Expression aggregate_expressions = 4;
+
+  // (Optional) Pivots a column of the current `DataFrame` and performs the specified aggregation.
+  Pivot pivot = 5;
+
+  enum GroupType {
+    GROUP_TYPE_UNSPECIFIED = 0;
+    GROUP_TYPE_GROUPBY = 1;
+    GROUP_TYPE_ROLLUP = 2;
+    GROUP_TYPE_CUBE = 3;
+    GROUP_TYPE_PIVOT = 4;
+  }
+
+  message Pivot {
+    // (Required) The column to pivot
+    Expression col = 1;
+
+    // (Optional) List of values that will be translated to columns in the output DataFrame.
+    //
+    // Note that if it is empty, the server side will immediately trigger a job to collect
+    // the distinct values of the column.
+    repeated Expression.Literal values = 2;
+  }
 }
 
 // Relation of type [[Sort]].
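For orientation, the user-facing surface that maps onto the new `GroupType` values looks like the following minimal sketch. It is not part of the patch; `spark` is assumed to be a Spark Connect session, and the table and column names are illustrative only.

    from pyspark.sql.connect import functions as CF

    df = spark.sql(
        "SELECT * FROM VALUES ('James', 'Sales', 3000) AS T(name, department, salary)"
    )

    # GROUP_TYPE_GROUPBY
    df.groupBy("department").agg(CF.sum(df.salary))

    # GROUP_TYPE_ROLLUP
    df.rollup("department", df.name).agg(CF.min(df.salary))

    # GROUP_TYPE_CUBE
    df.cube("department").agg(CF.max(df.salary))

    # GROUP_TYPE_PIVOT: pivot is only reachable through groupBy(); the optional
    # value list becomes Aggregate.Pivot.values in the message defined above.
    df.groupBy("name").pivot("department", ["Sales", "Finance"]).agg(CF.sum(df.salary))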
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala
index b15e46293ab..e6d230d9eef 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala
@@ -601,16 +601,64 @@ package object dsl {
       def groupBy(groupingExprs: Expression*)(aggregateExprs: Expression*): Relation = {
         val agg = Aggregate.newBuilder()
         agg.setInput(logicalPlan)
+        agg.setGroupType(proto.Aggregate.GroupType.GROUP_TYPE_GROUPBY)
 
         for (groupingExpr <- groupingExprs) {
           agg.addGroupingExpressions(groupingExpr)
         }
         for (aggregateExpr <- aggregateExprs) {
-          agg.addResultExpressions(aggregateExpr)
+          agg.addAggregateExpressions(aggregateExpr)
         }
         Relation.newBuilder().setAggregate(agg.build()).build()
       }
 
+      def rollup(groupingExprs: Expression*)(aggregateExprs: Expression*): Relation = {
+        val agg = Aggregate.newBuilder()
+        agg.setInput(logicalPlan)
+        agg.setGroupType(proto.Aggregate.GroupType.GROUP_TYPE_ROLLUP)
+
+        for (groupingExpr <- groupingExprs) {
+          agg.addGroupingExpressions(groupingExpr)
+        }
+        for (aggregateExpr <- aggregateExprs) {
+          agg.addAggregateExpressions(aggregateExpr)
+        }
+        Relation.newBuilder().setAggregate(agg.build()).build()
+      }
+
+      def cube(groupingExprs: Expression*)(aggregateExprs: Expression*): Relation = {
+        val agg = Aggregate.newBuilder()
+        agg.setInput(logicalPlan)
+        agg.setGroupType(proto.Aggregate.GroupType.GROUP_TYPE_CUBE)
+
+        for (groupingExpr <- groupingExprs) {
+          agg.addGroupingExpressions(groupingExpr)
+        }
+        for (aggregateExpr <- aggregateExprs) {
+          agg.addAggregateExpressions(aggregateExpr)
+        }
+        Relation.newBuilder().setAggregate(agg.build()).build()
+      }
+
+      def pivot(groupingExprs: Expression*)(
+          pivotCol: Expression,
+          pivotValues: Seq[proto.Expression.Literal])(aggregateExprs: Expression*): Relation = {
+        val agg = Aggregate.newBuilder()
+        agg.setInput(logicalPlan)
+        agg.setGroupType(proto.Aggregate.GroupType.GROUP_TYPE_PIVOT)
+
+        for (groupingExpr <- groupingExprs) {
+          agg.addGroupingExpressions(groupingExpr)
+        }
+        for (aggregateExpr <- aggregateExprs) {
+          agg.addAggregateExpressions(aggregateExpr)
+        }
+        agg.setPivot(
+          Aggregate.Pivot.newBuilder().setCol(pivotCol).addAllValues(pivotValues.asJava).build())
+
+        Relation.newBuilder().setAggregate(agg.build()).build()
+      }
+
       def except(otherPlan: Relation, isAll: Boolean): Relation = {
         Relation
           .newBuilder()
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/LiteralValueProtoConverter.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/LiteralValueProtoConverter.scala
index abfaaf7a1d3..82ffa4f5246 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/LiteralValueProtoConverter.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/LiteralValueProtoConverter.scala
@@ -30,7 +30,7 @@ object LiteralValueProtoConverter {
    * @return
    *   Expression
    */
-  def toCatalystExpression(lit: proto.Expression.Literal): expressions.Expression = {
+  def toCatalystExpression(lit: proto.Expression.Literal): expressions.Literal = {
     lit.getLiteralTypeCase match {
       case proto.Expression.Literal.LiteralTypeCase.NULL =>
         expressions.Literal(null, NullType)
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index 4abeec0d00b..dce3a8c8e55 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -874,32 +874,72 @@ class SparkConnectPlanner(session: SparkSession) {
   }
 
   private def transformAggregate(rel: proto.Aggregate): LogicalPlan = {
-    assert(rel.hasInput)
+    if (!rel.hasInput) {
+      throw InvalidPlanInput("Aggregate needs a plan input")
+    }
+    val input = transformRelation(rel.getInput)
+
+    def toNamedExpression(expr: Expression): NamedExpression = expr match {
+      case named: NamedExpression => named
+      case expr => UnresolvedAlias(expr)
+    }
 
-    val groupingExprs =
-      rel.getGroupingExpressionsList.asScala
-        .map(transformExpression)
-        .map {
-          case ua @ UnresolvedAttribute(_) => ua
-          case a @ Alias(_, _) => a
-          case x => UnresolvedAlias(x)
+    val groupingExprs = rel.getGroupingExpressionsList.asScala.toSeq.map(transformExpression)
+    val aggExprs = rel.getAggregateExpressionsList.asScala.toSeq.map(transformExpression)
+    val aliasedAgg = (groupingExprs ++ aggExprs).map(toNamedExpression)
+
+    rel.getGroupType match {
+      case proto.Aggregate.GroupType.GROUP_TYPE_GROUPBY =>
+        logical.Aggregate(
+          groupingExpressions = groupingExprs,
+          aggregateExpressions = aliasedAgg,
+          child = input)
+
+      case proto.Aggregate.GroupType.GROUP_TYPE_ROLLUP =>
+        logical.Aggregate(
+          groupingExpressions = Seq(Rollup(groupingExprs.map(Seq(_)))),
+          aggregateExpressions = aliasedAgg,
+          child = input)
+
+      case proto.Aggregate.GroupType.GROUP_TYPE_CUBE =>
+        logical.Aggregate(
+          groupingExpressions = Seq(Cube(groupingExprs.map(Seq(_)))),
+          aggregateExpressions = aliasedAgg,
+          child = input)
+
+      case proto.Aggregate.GroupType.GROUP_TYPE_PIVOT =>
+        if (!rel.hasPivot) {
+          throw InvalidPlanInput("Aggregate with GROUP_TYPE_PIVOT requires a Pivot")
         }
-    // Retain group columns in aggregate expressions:
-    val aggExprs =
-      groupingExprs ++ rel.getResultExpressionsList.asScala.map(transformResultExpression)
+        val pivotExpr = transformExpression(rel.getPivot.getCol)
+
+        var valueExprs = rel.getPivot.getValuesList.asScala.toSeq.map(transformLiteral)
+        if (valueExprs.isEmpty) {
+          // This is to prevent unintended OOM errors when the number of distinct values is large
+          val maxValues = session.sessionState.conf.dataFramePivotMaxValues
+          // Get the distinct values of the column and sort them so its consistent
+          val pivotCol = Column(pivotExpr)
+          valueExprs = Dataset
+            .ofRows(session, input)
+            .select(pivotCol)
+            .distinct()
+            .limit(maxValues + 1)
+            .sort(pivotCol) // ensure that the output columns are in a consistent logical order
+            .collect()
+            .map(_.get(0))
+            .toSeq
+            .map(expressions.Literal.apply)
+        }
 
-    logical.Aggregate(
-      child = transformRelation(rel.getInput),
-      groupingExpressions = groupingExprs.toSeq,
-      aggregateExpressions = aggExprs.toSeq)
-  }
+        logical.Pivot(
+          groupByExprsOpt = Some(groupingExprs.map(toNamedExpression)),
+          pivotColumn = pivotExpr,
+          pivotValues = valueExprs,
+          aggregates = aggExprs,
+          child = input)
 
-  private def transformResultExpression(exp: proto.Expression): expressions.NamedExpression = {
-    if (exp.hasAlias) {
-      transformAlias(exp.getAlias)
-    } else {
-      UnresolvedAlias(transformExpression(exp))
+      case other => throw InvalidPlanInput(s"Unknown Group Type $other")
     }
   }
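Worth noting from the planner change above: when the client sends a `GROUP_TYPE_PIVOT` aggregate with an empty `values` list, the server runs a distinct-and-sort job over the pivot column, capped by the `spark.sql.pivotMaxValues` configuration (the `dataFramePivotMaxValues` conf read in `transformAggregate`). A minimal client-side sketch of the two forms, taking `cdf` and `CF` as in the tests added later in this patch (names are illustrative, not part of the patch):

    # Explicit values: no extra job; the output columns are exactly these, in this order.
    cdf.groupBy("name").pivot("department", ["Sales", "Finance", "Marketing"]).agg(CF.sum(cdf.salary))

    # No values: the server first collects the distinct values of `department`
    # (at most spark.sql.pivotMaxValues of them) to decide the output columns.
    cdf.groupBy("name").pivot("department").agg(CF.sum(cdf.salary))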
diff --git a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala
index 93cb97b4421..1142a3386f9 100644
--- a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala
+++ b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala
@@ -303,8 +303,9 @@ class SparkConnectPlannerSuite extends SparkFunSuite with SparkConnectPlanTest {
     val agg = proto.Aggregate.newBuilder
       .setInput(readRel)
-      .addResultExpressions(sum)
+      .addAggregateExpressions(sum)
       .addGroupingExpressions(unresolvedAttribute)
+      .setGroupType(proto.Aggregate.GroupType.GROUP_TYPE_GROUPBY)
       .build()
 
     val res = transform(proto.Relation.newBuilder.setAggregate(agg).build())
diff --git a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala
index 34a30bcd4f0..66a019ef853 100644
--- a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala
+++ b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala
@@ -33,6 +33,7 @@ import org.apache.spark.sql.connect.dsl.MockRemoteSession
 import org.apache.spark.sql.connect.dsl.commands._
 import org.apache.spark.sql.connect.dsl.expressions._
 import org.apache.spark.sql.connect.dsl.plans._
+import org.apache.spark.sql.connect.planner.LiteralValueProtoConverter.toConnectProtoValue
 import org.apache.spark.sql.execution.arrow.ArrowConverters
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.types.{ArrayType, BooleanType, ByteType, DataType, DateType, DecimalType, DoubleType, FloatType, IntegerType, LongType, MapType, Metadata, ShortType, StringType, StructField, StructType}
@@ -222,6 +223,82 @@ class SparkConnectProtoSuite extends PlanTest with SparkConnectPlanTest {
     comparePlans(connectPlan2, sparkPlan2)
   }
 
+  test("Rollup expressions") {
+    val connectPlan1 =
+      connectTestRelation.rollup("id".protoAttr)(proto_min("name".protoAttr))
+    val sparkPlan1 =
+      sparkTestRelation.rollup(Column("id")).agg(min(Column("name")))
+    comparePlans(connectPlan1, sparkPlan1)
+
+    val connectPlan2 =
+      connectTestRelation.rollup("id".protoAttr)(proto_min("name".protoAttr).as("agg1"))
+    val sparkPlan2 =
+      sparkTestRelation.rollup(Column("id")).agg(min(Column("name")).as("agg1"))
+    comparePlans(connectPlan2, sparkPlan2)
+
+    val connectPlan3 =
+      connectTestRelation.rollup("id".protoAttr, "name".protoAttr)(
+        proto_min(proto.Expression.newBuilder().setLiteral(toConnectProtoValue(1)).build())
+          .as("agg1"))
+    val sparkPlan3 =
+      sparkTestRelation
+        .rollup(Column("id"), Column("name"))
+        .agg(min(lit(1)).as("agg1"))
+    comparePlans(connectPlan3, sparkPlan3)
+  }
+
+  test("Cube expressions") {
+    val connectPlan1 =
+      connectTestRelation.cube("id".protoAttr)(proto_min("name".protoAttr))
+    val sparkPlan1 =
+      sparkTestRelation.cube(Column("id")).agg(min(Column("name")))
+    comparePlans(connectPlan1, sparkPlan1)
+
+    val connectPlan2 =
+      connectTestRelation.cube("id".protoAttr)(proto_min("name".protoAttr).as("agg1"))
+    val sparkPlan2 =
+      sparkTestRelation.cube(Column("id")).agg(min(Column("name")).as("agg1"))
+    comparePlans(connectPlan2, sparkPlan2)
+
+    val connectPlan3 =
+      connectTestRelation.cube("id".protoAttr, "name".protoAttr)(
+        proto_min(proto.Expression.newBuilder().setLiteral(toConnectProtoValue(1)).build())
+          .as("agg1"))
+    val sparkPlan3 =
+      sparkTestRelation
+        .cube(Column("id"), Column("name"))
+        .agg(min(lit(1)).as("agg1"))
+    comparePlans(connectPlan3, sparkPlan3)
+  }
+
+  test("Pivot expressions") {
+    val connectPlan1 =
+      connectTestRelation.pivot("id".protoAttr)(
+        "name".protoAttr,
+        Seq("a", "b", "c").map(toConnectProtoValue))(
+        proto_min(proto.Expression.newBuilder().setLiteral(toConnectProtoValue(1)).build())
+          .as("agg1"))
+    val sparkPlan1 =
+      sparkTestRelation
+        .groupBy(Column("id"))
+        .pivot(Column("name"), Seq("a", "b", "c"))
+        .agg(min(lit(1)).as("agg1"))
+    comparePlans(connectPlan1, sparkPlan1)
+
+    val connectPlan2 =
+      connectTestRelation.pivot("name".protoAttr)(
+        "id".protoAttr,
+        Seq(1, 2, 3).map(toConnectProtoValue))(
+        proto_min(proto.Expression.newBuilder().setLiteral(toConnectProtoValue(1)).build())
+          .as("agg1"))
+    val sparkPlan2 =
+      sparkTestRelation
+        .groupBy(Column("name"))
+        .pivot(Column("id"), Seq(1, 2, 3))
+        .agg(min(lit(1)).as("agg1"))
+    comparePlans(connectPlan2, sparkPlan2)
+  }
+
   test("Test as(alias: String)") {
     val connectPlan = connectTestRelation.as("target_table")
     val sparkPlan = sparkTestRelation.as("target_table")
diff --git a/python/pyspark/sql/connect/dataframe.py b/python/pyspark/sql/connect/dataframe.py
index 2200be16b17..b0b3a949d30 100644
--- a/python/pyspark/sql/connect/dataframe.py
+++ b/python/pyspark/sql/connect/dataframe.py
@@ -60,11 +60,10 @@ if TYPE_CHECKING:
     from pyspark.sql.connect.session import SparkSession
 
 
-class DataFrame(object):
+class DataFrame:
     def __init__(
         self,
         session: "SparkSession",
-        data: Optional[List[Any]] = None,
         schema: Optional[StructType] = None,
     ):
         """Creates a new data frame"""
@@ -246,10 +245,53 @@ class DataFrame(object):
     first.__doc__ = PySparkDataFrame.first.__doc__
 
     def groupBy(self, *cols: "ColumnOrName") -> GroupedData:
-        return GroupedData(self, *cols)
+        _cols: List[Column] = []
+        for c in cols:
+            if isinstance(c, Column):
+                _cols.append(c)
+            elif isinstance(c, str):
+                _cols.append(self[c])
+            else:
+                raise TypeError(
+                    f"groupBy requires all cols be Column or str, but got {type(c).__name__} {c}"
+                )
+
+        return GroupedData(df=self, group_type="groupby", grouping_cols=_cols)
 
     groupBy.__doc__ = PySparkDataFrame.groupBy.__doc__
 
+    def rollup(self, *cols: "ColumnOrName") -> "GroupedData":
+        _cols: List[Column] = []
+        for c in cols:
+            if isinstance(c, Column):
+                _cols.append(c)
+            elif isinstance(c, str):
+                _cols.append(self[c])
+            else:
+                raise TypeError(
+                    f"rollup requires all cols be Column or str, but got {type(c).__name__} {c}"
+                )
+
+        return GroupedData(df=self, group_type="rollup", grouping_cols=_cols)
+
+    rollup.__doc__ = PySparkDataFrame.rollup.__doc__
+
+    def cube(self, *cols: "ColumnOrName") -> "GroupedData":
+        _cols: List[Column] = []
+        for c in cols:
+            if isinstance(c, Column):
+                _cols.append(c)
+            elif isinstance(c, str):
+                _cols.append(self[c])
+            else:
+                raise TypeError(
+                    f"cube requires all cols be Column or str, but got {type(c).__name__} {c}"
+                )
+
+        return GroupedData(df=self, group_type="cube", grouping_cols=_cols)
+
+    cube.__doc__ = PySparkDataFrame.cube.__doc__
+
     @overload
     def head(self) -> Optional[Row]:
         ...
diff --git a/python/pyspark/sql/connect/group.py b/python/pyspark/sql/connect/group.py
index c275edc9a2a..004ebd50196 100644
--- a/python/pyspark/sql/connect/group.py
+++ b/python/pyspark/sql/connect/group.py
@@ -16,31 +16,55 @@
 #
 
 from typing import (
+    Any,
     Dict,
     List,
     Sequence,
     Union,
     TYPE_CHECKING,
+    Optional,
     overload,
     cast,
 )
 
+from pyspark.sql.group import GroupedData as PySparkGroupedData
+
 import pyspark.sql.connect.plan as plan
-from pyspark.sql.connect.column import (
-    Column,
-    scalar_function,
-)
+from pyspark.sql.connect.column import Column, scalar_function
 from pyspark.sql.connect.functions import col, lit
-from pyspark.sql.group import GroupedData as PySparkGroupedData
 
 if TYPE_CHECKING:
+    from pyspark.sql.connect._typing import LiteralType
     from pyspark.sql.connect.dataframe import DataFrame
 
 
-class GroupedData(object):
-    def __init__(self, df: "DataFrame", *grouping_cols: Union[Column, str]) -> None:
+class GroupedData:
+    def __init__(
+        self,
+        df: "DataFrame",
+        group_type: str,
+        grouping_cols: Sequence["Column"],
+        pivot_col: Optional["Column"] = None,
+        pivot_values: Optional[Sequence["LiteralType"]] = None,
+    ) -> None:
+        from pyspark.sql.connect.dataframe import DataFrame
+
+        assert isinstance(df, DataFrame)
         self._df = df
-        self._grouping_cols = [x if isinstance(x, Column) else df[x] for x in grouping_cols]
+
+        assert isinstance(group_type, str) and group_type in ["groupby", "rollup", "cube", "pivot"]
+        self._group_type = group_type
+
+        assert isinstance(grouping_cols, list) and all(isinstance(g, Column) for g in grouping_cols)
+        self._grouping_cols: List[Column] = grouping_cols
+
+        self._pivot_col: Optional["Column"] = None
+        self._pivot_values: Optional[List[Any]] = None
+        if group_type == "pivot":
+            assert pivot_col is not None and isinstance(pivot_col, Column)
+            assert pivot_values is None or isinstance(pivot_values, list)
+            self._pivot_col = pivot_col
+            self._pivot_values = pivot_values
 
     @overload
     def agg(self, *exprs: Column) -> "DataFrame":
@@ -56,17 +80,20 @@ class GroupedData(object):
         assert exprs, "exprs should not be empty"
         if len(exprs) == 1 and isinstance(exprs[0], dict):
             # Convert the dict into key value pairs
-            measures = [scalar_function(exprs[0][k], col(k)) for k in exprs[0]]
+            aggregate_cols = [scalar_function(exprs[0][k], col(k)) for k in exprs[0]]
         else:
             # Columns
             assert all(isinstance(c, Column) for c in exprs), "all exprs should be Column"
-            measures = cast(List[Column], list(exprs))
+            aggregate_cols = cast(List[Column], list(exprs))
 
         res = DataFrame.withPlan(
             plan.Aggregate(
                 child=self._df._plan,
+                group_type=self._group_type,
                 grouping_cols=self._grouping_cols,
-                measures=measures,
+                aggregate_cols=aggregate_cols,
+                pivot_col=self._pivot_col,
+                pivot_values=self._pivot_values,
             ),
             session=self._df._session,
         )
@@ -108,5 +135,38 @@ class GroupedData(object):
 
     count.__doc__ = PySparkGroupedData.count.__doc__
 
+    def pivot(self, pivot_col: str, values: Optional[List["LiteralType"]] = None) -> "GroupedData":
+        if self._group_type != "groupby":
+            if self._group_type == "pivot":
+                raise Exception("Repeated PIVOT operation is not supported!")
+            else:
+                raise Exception(f"PIVOT after {self._group_type.upper()} is not supported!")
+
+        if not isinstance(pivot_col, str):
+            raise TypeError(
+                f"pivot_col should be a str, but got {type(pivot_col).__name__} {pivot_col}"
+            )
+
+        if values is not None:
+            if not isinstance(values, list):
+                raise TypeError(
+                    f"values should be a list, but got {type(values).__name__} {values}"
+                )
+            for v in values:
+                if not isinstance(v, (bool, float, int, str)):
+                    raise TypeError(
+                        f"value should be a bool, float, int or str, but got {type(v).__name__} {v}"
+                    )
+
+        return GroupedData(
+            df=self._df,
+            group_type="pivot",
+            grouping_cols=self._grouping_cols,
+            pivot_col=self._df[pivot_col],
+            pivot_values=values,
+        )
+
+    pivot.__doc__ = PySparkGroupedData.pivot.__doc__
+
 
 GroupedData.__doc__ = PySparkGroupedData.__doc__
diff --git a/python/pyspark/sql/connect/plan.py b/python/pyspark/sql/connect/plan.py
index 4e081832d01..d12256adec7 100644
--- a/python/pyspark/sql/connect/plan.py
+++ b/python/pyspark/sql/connect/plan.py
@@ -692,36 +692,71 @@ class Aggregate(LogicalPlan):
     def __init__(
         self,
         child: Optional["LogicalPlan"],
-        grouping_cols: List[Column],
-        measures: Sequence[Column],
+        group_type: str,
+        grouping_cols: Sequence[Column],
+        aggregate_cols: Sequence[Column],
+        pivot_col: Optional[Column],
+        pivot_values: Optional[Sequence[Any]],
     ) -> None:
         super().__init__(child)
-        self.grouping_cols = grouping_cols
-        self.measures = measures
 
-    def _convert_measure(self, m: Column, session: "SparkConnectClient") -> proto.Expression:
-        proto_expr = proto.Expression()
-        proto_expr.CopyFrom(m.to_plan(session))
-        return proto_expr
+        assert isinstance(group_type, str) and group_type in ["groupby", "rollup", "cube", "pivot"]
+        self._group_type = group_type
+
+        assert isinstance(grouping_cols, list) and all(isinstance(c, Column) for c in grouping_cols)
+        self._grouping_cols = grouping_cols
+
+        assert isinstance(aggregate_cols, list) and all(
+            isinstance(c, Column) for c in aggregate_cols
+        )
+        self._aggregate_cols = aggregate_cols
+
+        if group_type == "pivot":
+            assert pivot_col is not None and isinstance(pivot_col, Column)
+            assert pivot_values is None or isinstance(pivot_values, list)
+        else:
+            assert pivot_col is None
+            assert pivot_values is None
+
+        self._pivot_col = pivot_col
+        self._pivot_values = pivot_values
 
     def plan(self, session: "SparkConnectClient") -> proto.Relation:
+        from pyspark.sql.connect.functions import lit
+
         assert self._child is not None
-        groupings = [x.to_plan(session) for x in self.grouping_cols]
 
         agg = proto.Relation()
+
         agg.aggregate.input.CopyFrom(self._child.plan(session))
-        agg.aggregate.result_expressions.extend(
-            list(map(lambda x: self._convert_measure(x, session), self.measures))
+
+        agg.aggregate.grouping_expressions.extend([c.to_plan(session) for c in self._grouping_cols])
+        agg.aggregate.aggregate_expressions.extend(
+            [c.to_plan(session) for c in self._aggregate_cols]
         )
-        agg.aggregate.grouping_expressions.extend(groupings)
+
+        if self._group_type == "groupby":
+            agg.aggregate.group_type = proto.Aggregate.GroupType.GROUP_TYPE_GROUPBY
+        elif self._group_type == "rollup":
+            agg.aggregate.group_type = proto.Aggregate.GroupType.GROUP_TYPE_ROLLUP
+        elif self._group_type == "cube":
+            agg.aggregate.group_type = proto.Aggregate.GroupType.GROUP_TYPE_CUBE
+        elif self._group_type == "pivot":
+            agg.aggregate.group_type = proto.Aggregate.GroupType.GROUP_TYPE_PIVOT
+            assert self._pivot_col is not None
+            agg.aggregate.pivot.col.CopyFrom(self._pivot_col.to_plan(session))
+            if self._pivot_values is not None and len(self._pivot_values) > 0:
+                agg.aggregate.pivot.values.extend(
+                    [lit(v).to_plan(session).literal for v in self._pivot_values]
+                )
+
         return agg
 
     def print(self, indent: int = 0) -> str:
         c_buf = self._child.print(indent + LogicalPlan.INDENT) if self._child else ""
         return (
-            f"{' ' * indent}<Sort columns={self.grouping_cols}"
-            f"measures={self.measures}>\n{c_buf}"
+            f"{' ' *
indent}<Groupby={self._grouping_cols}" + f"Aggregate={self._aggregate_cols}>\n{c_buf}" ) def _repr_html_(self) -> str: diff --git a/python/pyspark/sql/connect/proto/relations_pb2.py b/python/pyspark/sql/connect/proto/relations_pb2.py index b310e2c8464..5f259e75caa 100644 --- a/python/pyspark/sql/connect/proto/relations_pb2.py +++ b/python/pyspark/sql/connect/proto/relations_pb2.py @@ -34,7 +34,7 @@ from pyspark.sql.connect.proto import types_pb2 as spark_dot_connect_dot_types__ DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( - b'\n\x1dspark/connect/relations.proto\x12\rspark.connect\x1a\x1fspark/connect/expressions.proto\x1a\x19spark/connect/types.proto"\xf7\x0e\n\x08Relation\x12\x35\n\x06\x63ommon\x18\x01 \x01(\x0b\x32\x1d.spark.connect.RelationCommonR\x06\x63ommon\x12)\n\x04read\x18\x02 \x01(\x0b\x32\x13.spark.connect.ReadH\x00R\x04read\x12\x32\n\x07project\x18\x03 \x01(\x0b\x32\x16.spark.connect.ProjectH\x00R\x07project\x12/\n\x06\x66ilter\x18\x04 \x01(\x0b\x32\x15.spark.connect.FilterH\x00R\x06\x66ilte [...] + b'\n\x1dspark/connect/relations.proto\x12\rspark.connect\x1a\x1fspark/connect/expressions.proto\x1a\x19spark/connect/types.proto"\xf7\x0e\n\x08Relation\x12\x35\n\x06\x63ommon\x18\x01 \x01(\x0b\x32\x1d.spark.connect.RelationCommonR\x06\x63ommon\x12)\n\x04read\x18\x02 \x01(\x0b\x32\x13.spark.connect.ReadH\x00R\x04read\x12\x32\n\x07project\x18\x03 \x01(\x0b\x32\x16.spark.connect.ProjectH\x00R\x07project\x12/\n\x06\x66ilter\x18\x04 \x01(\x0b\x32\x15.spark.connect.FilterH\x00R\x06\x66ilte [...] ) @@ -54,6 +54,7 @@ _LIMIT = DESCRIPTOR.message_types_by_name["Limit"] _OFFSET = DESCRIPTOR.message_types_by_name["Offset"] _TAIL = DESCRIPTOR.message_types_by_name["Tail"] _AGGREGATE = DESCRIPTOR.message_types_by_name["Aggregate"] +_AGGREGATE_PIVOT = _AGGREGATE.nested_types_by_name["Pivot"] _SORT = DESCRIPTOR.message_types_by_name["Sort"] _DROP = DESCRIPTOR.message_types_by_name["Drop"] _DEDUPLICATE = DESCRIPTOR.message_types_by_name["Deduplicate"] @@ -81,6 +82,7 @@ _UNPIVOT = DESCRIPTOR.message_types_by_name["Unpivot"] _TOSCHEMA = DESCRIPTOR.message_types_by_name["ToSchema"] _JOIN_JOINTYPE = _JOIN.enum_types_by_name["JoinType"] _SETOPERATION_SETOPTYPE = _SETOPERATION.enum_types_by_name["SetOpType"] +_AGGREGATE_GROUPTYPE = _AGGREGATE.enum_types_by_name["GroupType"] Relation = _reflection.GeneratedProtocolMessageType( "Relation", (_message.Message,), @@ -247,12 +249,22 @@ Aggregate = _reflection.GeneratedProtocolMessageType( "Aggregate", (_message.Message,), { + "Pivot": _reflection.GeneratedProtocolMessageType( + "Pivot", + (_message.Message,), + { + "DESCRIPTOR": _AGGREGATE_PIVOT, + "__module__": "spark.connect.relations_pb2" + # @@protoc_insertion_point(class_scope:spark.connect.Aggregate.Pivot) + }, + ), "DESCRIPTOR": _AGGREGATE, "__module__": "spark.connect.relations_pb2" # @@protoc_insertion_point(class_scope:spark.connect.Aggregate) }, ) _sym_db.RegisterMessage(Aggregate) +_sym_db.RegisterMessage(Aggregate.Pivot) Sort = _reflection.GeneratedProtocolMessageType( "Sort", @@ -548,51 +560,55 @@ if _descriptor._USE_C_DESCRIPTORS == False: _TAIL._serialized_start = 3807 _TAIL._serialized_end = 3882 _AGGREGATE._serialized_start = 3885 - _AGGREGATE._serialized_end = 4095 - _SORT._serialized_start = 4098 - _SORT._serialized_end = 4258 - _DROP._serialized_start = 4260 - _DROP._serialized_end = 4360 - _DEDUPLICATE._serialized_start = 4363 - _DEDUPLICATE._serialized_end = 4534 - _LOCALRELATION._serialized_start = 4537 - _LOCALRELATION._serialized_end = 4674 - _SAMPLE._serialized_start = 
4677 - _SAMPLE._serialized_end = 4972 - _RANGE._serialized_start = 4975 - _RANGE._serialized_end = 5120 - _SUBQUERYALIAS._serialized_start = 5122 - _SUBQUERYALIAS._serialized_end = 5236 - _REPARTITION._serialized_start = 5239 - _REPARTITION._serialized_end = 5381 - _SHOWSTRING._serialized_start = 5384 - _SHOWSTRING._serialized_end = 5525 - _STATSUMMARY._serialized_start = 5527 - _STATSUMMARY._serialized_end = 5619 - _STATDESCRIBE._serialized_start = 5621 - _STATDESCRIBE._serialized_end = 5702 - _STATCROSSTAB._serialized_start = 5704 - _STATCROSSTAB._serialized_end = 5805 - _NAFILL._serialized_start = 5808 - _NAFILL._serialized_end = 5942 - _NADROP._serialized_start = 5945 - _NADROP._serialized_end = 6079 - _NAREPLACE._serialized_start = 6082 - _NAREPLACE._serialized_end = 6378 - _NAREPLACE_REPLACEMENT._serialized_start = 6237 - _NAREPLACE_REPLACEMENT._serialized_end = 6378 - _RENAMECOLUMNSBYSAMELENGTHNAMES._serialized_start = 6380 - _RENAMECOLUMNSBYSAMELENGTHNAMES._serialized_end = 6494 - _RENAMECOLUMNSBYNAMETONAMEMAP._serialized_start = 6497 - _RENAMECOLUMNSBYNAMETONAMEMAP._serialized_end = 6756 - _RENAMECOLUMNSBYNAMETONAMEMAP_RENAMECOLUMNSMAPENTRY._serialized_start = 6689 - _RENAMECOLUMNSBYNAMETONAMEMAP_RENAMECOLUMNSMAPENTRY._serialized_end = 6756 - _WITHCOLUMNS._serialized_start = 6759 - _WITHCOLUMNS._serialized_end = 6890 - _HINT._serialized_start = 6893 - _HINT._serialized_end = 7033 - _UNPIVOT._serialized_start = 7036 - _UNPIVOT._serialized_end = 7282 - _TOSCHEMA._serialized_start = 7284 - _TOSCHEMA._serialized_end = 7390 + _AGGREGATE._serialized_end = 4467 + _AGGREGATE_PIVOT._serialized_start = 4224 + _AGGREGATE_PIVOT._serialized_end = 4335 + _AGGREGATE_GROUPTYPE._serialized_start = 4338 + _AGGREGATE_GROUPTYPE._serialized_end = 4467 + _SORT._serialized_start = 4470 + _SORT._serialized_end = 4630 + _DROP._serialized_start = 4632 + _DROP._serialized_end = 4732 + _DEDUPLICATE._serialized_start = 4735 + _DEDUPLICATE._serialized_end = 4906 + _LOCALRELATION._serialized_start = 4909 + _LOCALRELATION._serialized_end = 5046 + _SAMPLE._serialized_start = 5049 + _SAMPLE._serialized_end = 5344 + _RANGE._serialized_start = 5347 + _RANGE._serialized_end = 5492 + _SUBQUERYALIAS._serialized_start = 5494 + _SUBQUERYALIAS._serialized_end = 5608 + _REPARTITION._serialized_start = 5611 + _REPARTITION._serialized_end = 5753 + _SHOWSTRING._serialized_start = 5756 + _SHOWSTRING._serialized_end = 5897 + _STATSUMMARY._serialized_start = 5899 + _STATSUMMARY._serialized_end = 5991 + _STATDESCRIBE._serialized_start = 5993 + _STATDESCRIBE._serialized_end = 6074 + _STATCROSSTAB._serialized_start = 6076 + _STATCROSSTAB._serialized_end = 6177 + _NAFILL._serialized_start = 6180 + _NAFILL._serialized_end = 6314 + _NADROP._serialized_start = 6317 + _NADROP._serialized_end = 6451 + _NAREPLACE._serialized_start = 6454 + _NAREPLACE._serialized_end = 6750 + _NAREPLACE_REPLACEMENT._serialized_start = 6609 + _NAREPLACE_REPLACEMENT._serialized_end = 6750 + _RENAMECOLUMNSBYSAMELENGTHNAMES._serialized_start = 6752 + _RENAMECOLUMNSBYSAMELENGTHNAMES._serialized_end = 6866 + _RENAMECOLUMNSBYNAMETONAMEMAP._serialized_start = 6869 + _RENAMECOLUMNSBYNAMETONAMEMAP._serialized_end = 7128 + _RENAMECOLUMNSBYNAMETONAMEMAP_RENAMECOLUMNSMAPENTRY._serialized_start = 7061 + _RENAMECOLUMNSBYNAMETONAMEMAP_RENAMECOLUMNSMAPENTRY._serialized_end = 7128 + _WITHCOLUMNS._serialized_start = 7131 + _WITHCOLUMNS._serialized_end = 7262 + _HINT._serialized_start = 7265 + _HINT._serialized_end = 7405 + _UNPIVOT._serialized_start = 7408 + 
_UNPIVOT._serialized_end = 7654 + _TOSCHEMA._serialized_start = 7656 + _TOSCHEMA._serialized_end = 7762 # @@protoc_insertion_point(module_scope) diff --git a/python/pyspark/sql/connect/proto/relations_pb2.pyi b/python/pyspark/sql/connect/proto/relations_pb2.pyi index 62308eaaa81..f9032be6a49 100644 --- a/python/pyspark/sql/connect/proto/relations_pb2.pyi +++ b/python/pyspark/sql/connect/proto/relations_pb2.pyi @@ -909,49 +909,121 @@ class Aggregate(google.protobuf.message.Message): DESCRIPTOR: google.protobuf.descriptor.Descriptor + class _GroupType: + ValueType = typing.NewType("ValueType", builtins.int) + V: typing_extensions.TypeAlias = ValueType + + class _GroupTypeEnumTypeWrapper( + google.protobuf.internal.enum_type_wrapper._EnumTypeWrapper[Aggregate._GroupType.ValueType], + builtins.type, + ): # noqa: F821 + DESCRIPTOR: google.protobuf.descriptor.EnumDescriptor + GROUP_TYPE_UNSPECIFIED: Aggregate._GroupType.ValueType # 0 + GROUP_TYPE_GROUPBY: Aggregate._GroupType.ValueType # 1 + GROUP_TYPE_ROLLUP: Aggregate._GroupType.ValueType # 2 + GROUP_TYPE_CUBE: Aggregate._GroupType.ValueType # 3 + GROUP_TYPE_PIVOT: Aggregate._GroupType.ValueType # 4 + + class GroupType(_GroupType, metaclass=_GroupTypeEnumTypeWrapper): ... + GROUP_TYPE_UNSPECIFIED: Aggregate.GroupType.ValueType # 0 + GROUP_TYPE_GROUPBY: Aggregate.GroupType.ValueType # 1 + GROUP_TYPE_ROLLUP: Aggregate.GroupType.ValueType # 2 + GROUP_TYPE_CUBE: Aggregate.GroupType.ValueType # 3 + GROUP_TYPE_PIVOT: Aggregate.GroupType.ValueType # 4 + + class Pivot(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + COL_FIELD_NUMBER: builtins.int + VALUES_FIELD_NUMBER: builtins.int + @property + def col(self) -> pyspark.sql.connect.proto.expressions_pb2.Expression: + """(Required) The column to pivot""" + @property + def values( + self, + ) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[ + pyspark.sql.connect.proto.expressions_pb2.Expression.Literal + ]: + """(Optional) List of values that will be translated to columns in the output DataFrame. + + Note that if it is empty, the server side will immediately trigger a job to collect + the distinct values of the column. + """ + def __init__( + self, + *, + col: pyspark.sql.connect.proto.expressions_pb2.Expression | None = ..., + values: collections.abc.Iterable[ + pyspark.sql.connect.proto.expressions_pb2.Expression.Literal + ] + | None = ..., + ) -> None: ... + def HasField( + self, field_name: typing_extensions.Literal["col", b"col"] + ) -> builtins.bool: ... + def ClearField( + self, field_name: typing_extensions.Literal["col", b"col", "values", b"values"] + ) -> None: ... + INPUT_FIELD_NUMBER: builtins.int + GROUP_TYPE_FIELD_NUMBER: builtins.int GROUPING_EXPRESSIONS_FIELD_NUMBER: builtins.int - RESULT_EXPRESSIONS_FIELD_NUMBER: builtins.int + AGGREGATE_EXPRESSIONS_FIELD_NUMBER: builtins.int + PIVOT_FIELD_NUMBER: builtins.int @property def input(self) -> global___Relation: - """(Required) Input relation for a Aggregate.""" + """(Required) Input relation for a RelationalGroupedDataset.""" + group_type: global___Aggregate.GroupType.ValueType + """(Required) How the RelationalGroupedDataset was built.""" @property def grouping_expressions( self, ) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[ pyspark.sql.connect.proto.expressions_pb2.Expression - ]: ... 
+ ]: + """(Required) Expressions for grouping keys""" @property - def result_expressions( + def aggregate_expressions( self, ) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[ pyspark.sql.connect.proto.expressions_pb2.Expression - ]: ... + ]: + """(Required) List of values that will be translated to columns in the output DataFrame.""" + @property + def pivot(self) -> global___Aggregate.Pivot: + """(Optional) Pivots a column of the current `DataFrame` and performs the specified aggregation.""" def __init__( self, *, input: global___Relation | None = ..., + group_type: global___Aggregate.GroupType.ValueType = ..., grouping_expressions: collections.abc.Iterable[ pyspark.sql.connect.proto.expressions_pb2.Expression ] | None = ..., - result_expressions: collections.abc.Iterable[ + aggregate_expressions: collections.abc.Iterable[ pyspark.sql.connect.proto.expressions_pb2.Expression ] | None = ..., + pivot: global___Aggregate.Pivot | None = ..., ) -> None: ... def HasField( - self, field_name: typing_extensions.Literal["input", b"input"] + self, field_name: typing_extensions.Literal["input", b"input", "pivot", b"pivot"] ) -> builtins.bool: ... def ClearField( self, field_name: typing_extensions.Literal[ + "aggregate_expressions", + b"aggregate_expressions", + "group_type", + b"group_type", "grouping_expressions", b"grouping_expressions", "input", b"input", - "result_expressions", - b"result_expressions", + "pivot", + b"pivot", ], ) -> None: ... diff --git a/python/pyspark/sql/tests/connect/test_connect_basic.py b/python/pyspark/sql/tests/connect/test_connect_basic.py index 3e977d95541..bced3fd5e7e 100644 --- a/python/pyspark/sql/tests/connect/test_connect_basic.py +++ b/python/pyspark/sql/tests/connect/test_connect_basic.py @@ -1155,6 +1155,142 @@ class SparkConnectTests(SparkConnectSQLTestCase): set(spark_df.select("id").crossJoin(other=spark_df.select("name")).toPandas()), ) + def test_grouped_data(self): + from pyspark.sql import functions as SF + from pyspark.sql.connect import functions as CF + + query = """ + SELECT * FROM VALUES + ('James', 'Sales', 3000, 2020), + ('Michael', 'Sales', 4600, 2020), + ('Robert', 'Sales', 4100, 2020), + ('Maria', 'Finance', 3000, 2020), + ('James', 'Sales', 3000, 2019), + ('Scott', 'Finance', 3300, 2020), + ('Jen', 'Finance', 3900, 2020), + ('Jeff', 'Marketing', 3000, 2020), + ('Kumar', 'Marketing', 2000, 2020), + ('Saif', 'Sales', 4100, 2020) + AS T(name, department, salary, year) + """ + + # +-------+----------+------+----+ + # | name|department|salary|year| + # +-------+----------+------+----+ + # | James| Sales| 3000|2020| + # |Michael| Sales| 4600|2020| + # | Robert| Sales| 4100|2020| + # | Maria| Finance| 3000|2020| + # | James| Sales| 3000|2019| + # | Scott| Finance| 3300|2020| + # | Jen| Finance| 3900|2020| + # | Jeff| Marketing| 3000|2020| + # | Kumar| Marketing| 2000|2020| + # | Saif| Sales| 4100|2020| + # +-------+----------+------+----+ + + cdf = self.connect.sql(query) + sdf = self.spark.sql(query) + + # test groupby + self.assert_eq( + cdf.groupBy("name").agg(CF.sum(cdf.salary)).toPandas(), + sdf.groupBy("name").agg(SF.sum(sdf.salary)).toPandas(), + ) + self.assert_eq( + cdf.groupBy("name", cdf.department).agg(CF.max("year"), CF.min(cdf.salary)).toPandas(), + sdf.groupBy("name", sdf.department).agg(SF.max("year"), SF.min(sdf.salary)).toPandas(), + ) + + # test rollup + self.assert_eq( + cdf.rollup("name").agg(CF.sum(cdf.salary)).toPandas(), + sdf.rollup("name").agg(SF.sum(sdf.salary)).toPandas(), + ) + self.assert_eq( + 
cdf.rollup("name", cdf.department).agg(CF.max("year"), CF.min(cdf.salary)).toPandas(), + sdf.rollup("name", sdf.department).agg(SF.max("year"), SF.min(sdf.salary)).toPandas(), + ) + + # test cube + self.assert_eq( + cdf.cube("name").agg(CF.sum(cdf.salary)).toPandas(), + sdf.cube("name").agg(SF.sum(sdf.salary)).toPandas(), + ) + self.assert_eq( + cdf.cube("name", cdf.department).agg(CF.max("year"), CF.min(cdf.salary)).toPandas(), + sdf.cube("name", sdf.department).agg(SF.max("year"), SF.min(sdf.salary)).toPandas(), + ) + + # test pivot + # pivot with values + self.assert_eq( + cdf.groupBy("name") + .pivot("department", ["Sales", "Marketing"]) + .agg(CF.sum(cdf.salary)) + .toPandas(), + sdf.groupBy("name") + .pivot("department", ["Sales", "Marketing"]) + .agg(SF.sum(sdf.salary)) + .toPandas(), + ) + self.assert_eq( + cdf.groupBy(cdf.name) + .pivot("department", ["Sales", "Finance", "Marketing"]) + .agg(CF.sum(cdf.salary)) + .toPandas(), + sdf.groupBy(sdf.name) + .pivot("department", ["Sales", "Finance", "Marketing"]) + .agg(SF.sum(sdf.salary)) + .toPandas(), + ) + self.assert_eq( + cdf.groupBy(cdf.name) + .pivot("department", ["Sales", "Finance", "Unknown"]) + .agg(CF.sum(cdf.salary)) + .toPandas(), + sdf.groupBy(sdf.name) + .pivot("department", ["Sales", "Finance", "Unknown"]) + .agg(SF.sum(sdf.salary)) + .toPandas(), + ) + + # pivot without values + self.assert_eq( + cdf.groupBy("name").pivot("department").agg(CF.sum(cdf.salary)).toPandas(), + sdf.groupBy("name").pivot("department").agg(SF.sum(sdf.salary)).toPandas(), + ) + + self.assert_eq( + cdf.groupBy("name").pivot("year").agg(CF.sum(cdf.salary)).toPandas(), + sdf.groupBy("name").pivot("year").agg(SF.sum(sdf.salary)).toPandas(), + ) + + # check error + with self.assertRaisesRegex( + Exception, + "PIVOT after ROLLUP is not supported", + ): + cdf.rollup("name").pivot("department").agg(CF.sum(cdf.salary)) + + with self.assertRaisesRegex( + Exception, + "PIVOT after CUBE is not supported", + ): + cdf.cube("name").pivot("department").agg(CF.sum(cdf.salary)) + + with self.assertRaisesRegex( + Exception, + "Repeated PIVOT operation is not supported", + ): + cdf.groupBy("name").pivot("year").pivot("year").agg(CF.sum(cdf.salary)) + + with self.assertRaisesRegex( + TypeError, + "value should be a bool, float, int or str, but got bytes", + ): + cdf.groupBy("name").pivot("department", ["Sales", b"Marketing"]).agg(CF.sum(cdf.salary)) + @unittest.skipIf(not should_test_connect, connect_requirement_message) class ChannelBuilderTests(ReusedPySparkTestCase): --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org