This is an automated email from the ASF dual-hosted git repository.

ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
     new 3212fa96016 [SPARK-41439][CONNECT][PYTHON] Implement `DataFrame.melt` and `DataFrame.unpivot`
3212fa96016 is described below

commit 3212fa960169b1f1c29d63185aa96d535798fcc4
Author: Jiaan Geng <belie...@163.com>
AuthorDate: Fri Dec 9 16:27:01 2022 +0800

    [SPARK-41439][CONNECT][PYTHON] Implement `DataFrame.melt` and `DataFrame.unpivot`

    ### What changes were proposed in this pull request?
    Implement `DataFrame.melt` and `DataFrame.unpivot`, backed by a new proto message:
    1. Implement `DataFrame.melt` and `DataFrame.unpivot` for the Scala API.
    2. Implement `DataFrame.melt` and `DataFrame.unpivot` for the Python API.

    ### Why are the changes needed?
    For Connect API coverage.

    ### Does this PR introduce _any_ user-facing change?
    No. This only adds new API.

    ### How was this patch tested?
    New test cases.

    Closes #38973 from beliefer/SPARK-41439.

    Authored-by: Jiaan Geng <belie...@163.com>
    Signed-off-by: Ruifeng Zheng <ruife...@apache.org>
---
 .../main/protobuf/spark/connect/relations.proto    |  19 +++
 .../org/apache/spark/sql/connect/dsl/package.scala |  47 ++++++
 .../sql/connect/planner/SparkConnectPlanner.scala  |  33 +++-
 .../connect/planner/SparkConnectProtoSuite.scala   |  28 ++++
 python/pyspark/sql/connect/dataframe.py            |  35 ++++
 python/pyspark/sql/connect/plan.py                 |  59 +++++++
 python/pyspark/sql/connect/proto/relations_pb2.py  | 182 +++++++++++----------
 python/pyspark/sql/connect/proto/relations_pb2.pyi |  72 ++++++++
 .../sql/tests/connect/test_connect_plan_only.py    |  58 +++++++
 9 files changed, 448 insertions(+), 85 deletions(-)

diff --git a/connector/connect/common/src/main/protobuf/spark/connect/relations.proto b/connector/connect/common/src/main/protobuf/spark/connect/relations.proto
index ece8767c06c..30468501236 100644
--- a/connector/connect/common/src/main/protobuf/spark/connect/relations.proto
+++ b/connector/connect/common/src/main/protobuf/spark/connect/relations.proto
@@ -54,6 +54,7 @@ message Relation {
     Tail tail = 22;
     WithColumns with_columns = 23;
     Hint hint = 24;
+    Unpivot unpivot = 25;
 
     // NA functions
     NAFill fill_na = 90;
@@ -570,3 +571,21 @@ message Hint {
   // (Optional) Hint parameters.
   repeated Expression.Literal parameters = 3;
 }
+
+// Unpivot a DataFrame from wide format to long format, optionally leaving identifier columns set.
+message Unpivot {
+  // (Required) The input relation.
+  Relation input = 1;
+
+  // (Required) Id columns.
+  repeated Expression ids = 2;
+
+  // (Optional) Value columns to unpivot.
+  repeated Expression values = 3;
+
+  // (Required) Name of the variable column.
+  string variable_column_name = 4;
+
+  // (Required) Name of the value column.
+  string value_column_name = 5;
+}
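For orientation, this is roughly what building the new message by hand looks like from Python, using the generated classes this patch adds. A minimal sketch only: the empty input relation and the column names are placeholders, not part of the patch.

    from pyspark.sql.connect.proto import expressions_pb2, relations_pb2

    # An id column travels as an unresolved attribute, the same shape the
    # plan-only tests below assert on.
    id_col = expressions_pb2.Expression()
    id_col.unresolved_attribute.unparsed_identifier = "id"

    rel = relations_pb2.Relation()
    rel.unpivot.input.CopyFrom(relations_pb2.Relation())  # placeholder input
    rel.unpivot.ids.append(id_col)
    # Leaving `values` empty asks the server to unpivot all non-id columns.
    rel.unpivot.variable_column_name = "variable"
    rel.unpivot.value_column_name = "value"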
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala
index fb79243ba37..545c2aaaf04 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala
@@ -719,6 +719,53 @@ package object dsl {
         .build()
     }
 
+    def unpivot(
+        ids: Seq[Expression],
+        values: Seq[Expression],
+        variableColumnName: String,
+        valueColumnName: String): Relation = {
+      Relation
+        .newBuilder()
+        .setUnpivot(
+          Unpivot
+            .newBuilder()
+            .setInput(logicalPlan)
+            .addAllIds(ids.asJava)
+            .addAllValues(values.asJava)
+            .setVariableColumnName(variableColumnName)
+            .setValueColumnName(valueColumnName))
+        .build()
+    }
+
+    def unpivot(
+        ids: Seq[Expression],
+        variableColumnName: String,
+        valueColumnName: String): Relation = {
+      Relation
+        .newBuilder()
+        .setUnpivot(
+          Unpivot
+            .newBuilder()
+            .setInput(logicalPlan)
+            .addAllIds(ids.asJava)
+            .setVariableColumnName(variableColumnName)
+            .setValueColumnName(valueColumnName))
+        .build()
+    }
+
+    def melt(
+        ids: Seq[Expression],
+        values: Seq[Expression],
+        variableColumnName: String,
+        valueColumnName: String): Relation =
+      unpivot(ids, values, variableColumnName, valueColumnName)
+
+    def melt(
+        ids: Seq[Expression],
+        variableColumnName: String,
+        valueColumnName: String): Relation =
+      unpivot(ids, variableColumnName, valueColumnName)
+
     private def createSetOperation(
         left: Relation,
         right: Relation,
InvalidPlanInput(s"${rel.getUnknown} not supported.") @@ -309,6 +310,36 @@ class SparkConnectPlanner(session: SparkSession) { UnresolvedHint(rel.getName, params, transformRelation(rel.getInput)) } + private def transformUnpivot(rel: proto.Unpivot): LogicalPlan = { + val ids = rel.getIdsList.asScala.toArray.map { expr => + Column(transformExpression(expr)) + } + + if (rel.getValuesList.isEmpty) { + Unpivot( + Some(ids.map(_.named)), + None, + None, + rel.getVariableColumnName, + Seq(rel.getValueColumnName), + transformRelation(rel.getInput) + ) + } else { + val values = rel.getValuesList.asScala.toArray.map { expr => + Column(transformExpression(expr)) + } + + Unpivot( + Some(ids.map(_.named)), + Some(values.map(v => Seq(v.named))), + None, + rel.getVariableColumnName, + Seq(rel.getValueColumnName), + transformRelation(rel.getInput) + ) + } + } + private def transformDeduplicate(rel: proto.Deduplicate): LogicalPlan = { if (!rel.hasInput) { throw InvalidPlanInput("Deduplicate needs a plan input") diff --git a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala index c04d7bde746..8611ba45f75 100644 --- a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala +++ b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala @@ -560,6 +560,34 @@ class SparkConnectProtoSuite extends PlanTest with SparkConnectPlanTest { comparePlans(connectTestRelation.hint("COALESCE", 3), sparkTestRelation.hint("COALESCE", 3)) } + test("Test Unpivot") { + val connectPlan0 = + connectTestRelation.unpivot(Seq("id".protoAttr), Seq("name".protoAttr), "variable", "value") + val sparkPlan0 = + sparkTestRelation.unpivot(Array(Column("id")), Array(Column("name")), "variable", "value") + comparePlans(connectPlan0, sparkPlan0) + + val connectPlan1 = + connectTestRelation.unpivot(Seq("id".protoAttr), "variable", "value") + val sparkPlan1 = + sparkTestRelation.unpivot(Array(Column("id")), "variable", "value") + comparePlans(connectPlan1, sparkPlan1) + } + + test("Test Melt") { + val connectPlan0 = + connectTestRelation.melt(Seq("id".protoAttr), Seq("name".protoAttr), "variable", "value") + val sparkPlan0 = + sparkTestRelation.melt(Array(Column("id")), Array(Column("name")), "variable", "value") + comparePlans(connectPlan0, sparkPlan0) + + val connectPlan1 = + connectTestRelation.melt(Seq("id".protoAttr), "variable", "value") + val sparkPlan1 = + sparkTestRelation.melt(Array(Column("id")), "variable", "value") + comparePlans(connectPlan1, sparkPlan1) + } + private def createLocalRelationProtoByAttributeReferences( attrs: Seq[AttributeReference]): proto.Relation = { val localRelationBuilder = proto.LocalRelation.newBuilder() diff --git a/python/pyspark/sql/connect/dataframe.py b/python/pyspark/sql/connect/dataframe.py index f268dc431b8..4c1956cc577 100644 --- a/python/pyspark/sql/connect/dataframe.py +++ b/python/pyspark/sql/connect/dataframe.py @@ -824,6 +824,41 @@ class DataFrame(object): session=self._session, ) + def unpivot( + self, + ids: List["ColumnOrName"], + values: List["ColumnOrName"], + variableColumnName: str, + valueColumnName: str, + ) -> "DataFrame": + """ + Returns a new :class:`DataFrame` by unpivot a DataFrame from wide format to long format, + optionally leaving identifier columns set. + + .. 
diff --git a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala
index c04d7bde746..8611ba45f75 100644
--- a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala
+++ b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala
@@ -560,6 +560,34 @@ class SparkConnectProtoSuite extends PlanTest with SparkConnectPlanTest {
     comparePlans(connectTestRelation.hint("COALESCE", 3), sparkTestRelation.hint("COALESCE", 3))
   }
 
+  test("Test Unpivot") {
+    val connectPlan0 =
+      connectTestRelation.unpivot(Seq("id".protoAttr), Seq("name".protoAttr), "variable", "value")
+    val sparkPlan0 =
+      sparkTestRelation.unpivot(Array(Column("id")), Array(Column("name")), "variable", "value")
+    comparePlans(connectPlan0, sparkPlan0)
+
+    val connectPlan1 =
+      connectTestRelation.unpivot(Seq("id".protoAttr), "variable", "value")
+    val sparkPlan1 =
+      sparkTestRelation.unpivot(Array(Column("id")), "variable", "value")
+    comparePlans(connectPlan1, sparkPlan1)
+  }
+
+  test("Test Melt") {
+    val connectPlan0 =
+      connectTestRelation.melt(Seq("id".protoAttr), Seq("name".protoAttr), "variable", "value")
+    val sparkPlan0 =
+      sparkTestRelation.melt(Array(Column("id")), Array(Column("name")), "variable", "value")
+    comparePlans(connectPlan0, sparkPlan0)
+
+    val connectPlan1 =
+      connectTestRelation.melt(Seq("id".protoAttr), "variable", "value")
+    val sparkPlan1 =
+      sparkTestRelation.melt(Array(Column("id")), "variable", "value")
+    comparePlans(connectPlan1, sparkPlan1)
+  }
+
   private def createLocalRelationProtoByAttributeReferences(
       attrs: Seq[AttributeReference]): proto.Relation = {
     val localRelationBuilder = proto.LocalRelation.newBuilder()
diff --git a/python/pyspark/sql/connect/dataframe.py b/python/pyspark/sql/connect/dataframe.py
index f268dc431b8..4c1956cc577 100644
--- a/python/pyspark/sql/connect/dataframe.py
+++ b/python/pyspark/sql/connect/dataframe.py
@@ -824,6 +824,41 @@ class DataFrame(object):
             session=self._session,
         )
 
+    def unpivot(
+        self,
+        ids: List["ColumnOrName"],
+        values: List["ColumnOrName"],
+        variableColumnName: str,
+        valueColumnName: str,
+    ) -> "DataFrame":
+        """
+        Returns a new :class:`DataFrame` by unpivoting a DataFrame from wide format to
+        long format, optionally leaving identifier columns set.
+
+        .. versionadded:: 3.4.0
+
+        Parameters
+        ----------
+        ids : list
+            Id columns.
+        values : list, optional
+            Value columns to unpivot.
+        variableColumnName : str
+            Name of the variable column.
+        valueColumnName : str
+            Name of the value column.
+
+        Returns
+        -------
+        :class:`DataFrame`
+        """
+        return DataFrame.withPlan(
+            plan.Unpivot(self._plan, ids, values, variableColumnName, valueColumnName),
+            self._session,
+        )
+
+    melt = unpivot
+
     def show(self, n: int = 20, truncate: Union[bool, int] = True, vertical: bool = False) -> None:
         """
         Prints the first ``n`` rows to the console.
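Putting the user-facing piece together, a minimal usage sketch against a hypothetical Connect session `spark` (the table and column names are invented for illustration):

    df = spark.read.table("sales")  # assume columns: id, q1, q2
    long_df = df.unpivot(["id"], ["q1", "q2"], "quarter", "amount")
    # `melt` is a plain alias for `unpivot`, so this builds the identical plan:
    long_df = df.melt(["id"], ["q1", "q2"], "quarter", "amount")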
diff --git a/python/pyspark/sql/connect/plan.py b/python/pyspark/sql/connect/plan.py
index 2de0dbb40c3..748a353e0c3 100644
--- a/python/pyspark/sql/connect/plan.py
+++ b/python/pyspark/sql/connect/plan.py
@@ -984,6 +984,65 @@ class RenameColumnsNameByName(LogicalPlan):
         """
 
 
+class Unpivot(LogicalPlan):
+    """Logical plan object for an unpivot operation."""
+
+    def __init__(
+        self,
+        child: Optional["LogicalPlan"],
+        ids: List["ColumnOrName"],
+        values: List["ColumnOrName"],
+        variable_column_name: str,
+        value_column_name: str,
+    ) -> None:
+        super().__init__(child)
+        self.ids = ids
+        self.values = values
+        self.variable_column_name = variable_column_name
+        self.value_column_name = value_column_name
+
+    def col_to_expr(self, col: "ColumnOrName", session: "SparkConnectClient") -> proto.Expression:
+        if isinstance(col, Column):
+            return col.to_plan(session)
+        else:
+            return self.unresolved_attr(col)
+
+    def plan(self, session: "SparkConnectClient") -> proto.Relation:
+        assert self._child is not None
+
+        plan = proto.Relation()
+        plan.unpivot.input.CopyFrom(self._child.plan(session))
+        plan.unpivot.ids.extend([self.col_to_expr(x, session) for x in self.ids])
+        plan.unpivot.values.extend([self.col_to_expr(x, session) for x in self.values])
+        plan.unpivot.variable_column_name = self.variable_column_name
+        plan.unpivot.value_column_name = self.value_column_name
+        return plan
+
+    def print(self, indent: int = 0) -> str:
+        c_buf = self._child.print(indent + LogicalPlan.INDENT) if self._child else ""
+        return (
+            f"{' ' * indent}"
+            f"<Unpivot ids={self.ids}, values={self.values}, "
+            f"variable_column_name={self.variable_column_name}, "
+            f"value_column_name={self.value_column_name}>"
+            f"\n{c_buf}"
+        )
+
+    def _repr_html_(self) -> str:
+        return f"""
+        <ul>
+            <li>
+                <b>Unpivot</b><br />
+                ids: {self.ids}
+                values: {self.values}
+                variable_column_name: {self.variable_column_name}
+                value_column_name: {self.value_column_name}
+                {self._child._repr_html_() if self._child is not None else ""}
+            </li>
+        </ul>
+        """
+
+
 class NAFill(LogicalPlan):
     def __init__(
         self, child: Optional["LogicalPlan"], cols: Optional[List[str]], values: List[Any]
diff --git a/python/pyspark/sql/connect/proto/relations_pb2.py b/python/pyspark/sql/connect/proto/relations_pb2.py
index 06cf18417d2..68e4c423cc4 100644
--- a/python/pyspark/sql/connect/proto/relations_pb2.py
+++ b/python/pyspark/sql/connect/proto/relations_pb2.py
@@ -33,7 +33,7 @@ from pyspark.sql.connect.proto import expressions_pb2 as spark_dot_connect_dot_e
 
 DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
-    b'\n\x1dspark/connect/relations.proto\x12\rspark.connect\x1a\x1fspark/connect/expressions.proto"\x8b\x0e\n\x08Relation\x12\x35\n\x06\x63ommon\x18\x01 \x01(\x0b\x32\x1d.spark.connect.RelationCommonR\x06\x63ommon\x12)\n\x04read\x18\x02 \x01(\x0b\x32\x13.spark.connect.ReadH\x00R\x04read\x12\x32\n\x07project\x18\x03 \x01(\x0b\x32\x16.spark.connect.ProjectH\x00R\x07project\x12/\n\x06\x66ilter\x18\x04 \x01(\x0b\x32\x15.spark.connect.FilterH\x00R\x06\x66ilter\x12)\n\x04join\x18\x05 \x01(\x0 [...]
+    b'\n\x1dspark/connect/relations.proto\x12\rspark.connect\x1a\x1fspark/connect/expressions.proto"\xbf\x0e\n\x08Relation\x12\x35\n\x06\x63ommon\x18\x01 \x01(\x0b\x32\x1d.spark.connect.RelationCommonR\x06\x63ommon\x12)\n\x04read\x18\x02 \x01(\x0b\x32\x13.spark.connect.ReadH\x00R\x04read\x12\x32\n\x07project\x18\x03 \x01(\x0b\x32\x16.spark.connect.ProjectH\x00R\x07project\x12/\n\x06\x66ilter\x18\x04 \x01(\x0b\x32\x15.spark.connect.FilterH\x00R\x06\x66ilter\x12)\n\x04join\x18\x05 \x01(\x0 [...]
 )
@@ -77,6 +77,7 @@ _RENAMECOLUMNSBYNAMETONAMEMAP_RENAMECOLUMNSMAPENTRY = (
 )
 _WITHCOLUMNS = DESCRIPTOR.message_types_by_name["WithColumns"]
 _HINT = DESCRIPTOR.message_types_by_name["Hint"]
+_UNPIVOT = DESCRIPTOR.message_types_by_name["Unpivot"]
 _JOIN_JOINTYPE = _JOIN.enum_types_by_name["JoinType"]
 _SETOPERATION_SETOPTYPE = _SETOPERATION.enum_types_by_name["SetOpType"]
 _SORT_SORTDIRECTION = _SORT.enum_types_by_name["SortDirection"]
@@ -493,6 +494,17 @@ Hint = _reflection.GeneratedProtocolMessageType(
 )
 _sym_db.RegisterMessage(Hint)
 
+Unpivot = _reflection.GeneratedProtocolMessageType(
+    "Unpivot",
+    (_message.Message,),
+    {
+        "DESCRIPTOR": _UNPIVOT,
+        "__module__": "spark.connect.relations_pb2"
+        # @@protoc_insertion_point(class_scope:spark.connect.Unpivot)
+    },
+)
+_sym_db.RegisterMessage(Unpivot)
+
 if _descriptor._USE_C_DESCRIPTORS == False:
 
     DESCRIPTOR._options = None
@@ -502,87 +514,89 @@ if _descriptor._USE_C_DESCRIPTORS == False:
     _RENAMECOLUMNSBYNAMETONAMEMAP_RENAMECOLUMNSMAPENTRY._options = None
     _RENAMECOLUMNSBYNAMETONAMEMAP_RENAMECOLUMNSMAPENTRY._serialized_options = b"8\001"
     _RELATION._serialized_start = 82
-    _RELATION._serialized_end = 1885
-    _UNKNOWN._serialized_start = 1887
-    _UNKNOWN._serialized_end = 1896
-    _RELATIONCOMMON._serialized_start = 1898
-    _RELATIONCOMMON._serialized_end = 1947
-    _SQL._serialized_start = 1949
-    _SQL._serialized_end = 1976
-    _READ._serialized_start = 1979
-    _READ._serialized_end = 2405
-    _READ_NAMEDTABLE._serialized_start = 2121
-    _READ_NAMEDTABLE._serialized_end = 2182
-    _READ_DATASOURCE._serialized_start = 2185
-    _READ_DATASOURCE._serialized_end = 2392
-    _READ_DATASOURCE_OPTIONSENTRY._serialized_start = 2323
-    _READ_DATASOURCE_OPTIONSENTRY._serialized_end = 2381
-    _PROJECT._serialized_start = 2407
-    _PROJECT._serialized_end = 2524
-    _FILTER._serialized_start = 2526
-    _FILTER._serialized_end = 2638
-    _JOIN._serialized_start = 2641
-    _JOIN._serialized_end = 3112
-    _JOIN_JOINTYPE._serialized_start = 2904
-    _JOIN_JOINTYPE._serialized_end = 3112
-    _SETOPERATION._serialized_start = 3115
-    _SETOPERATION._serialized_end = 3511
-    _SETOPERATION_SETOPTYPE._serialized_start = 3374
-    _SETOPERATION_SETOPTYPE._serialized_end = 3488
-    _LIMIT._serialized_start = 3513
-    _LIMIT._serialized_end = 3589
-    _OFFSET._serialized_start = 3591
-    _OFFSET._serialized_end = 3670
-    _TAIL._serialized_start = 3672
-    _TAIL._serialized_end = 3747
-    _AGGREGATE._serialized_start = 3750
-    _AGGREGATE._serialized_end = 3960
-    _SORT._serialized_start = 3963
-    _SORT._serialized_end = 4513
-    _SORT_SORTFIELD._serialized_start = 4117
-    _SORT_SORTFIELD._serialized_end = 4305
-    _SORT_SORTDIRECTION._serialized_start = 4307
-    _SORT_SORTDIRECTION._serialized_end = 4415
-    _SORT_SORTNULLS._serialized_start = 4417
-    _SORT_SORTNULLS._serialized_end = 4499
-    _DROP._serialized_start = 4515
-    _DROP._serialized_end = 4615
-    _DEDUPLICATE._serialized_start = 4618
-    _DEDUPLICATE._serialized_end = 4789
-    _LOCALRELATION._serialized_start = 4791
-    _LOCALRELATION._serialized_end = 4826
-    _SAMPLE._serialized_start = 4829
-    _SAMPLE._serialized_end = 5053
-    _RANGE._serialized_start = 5056
-    _RANGE._serialized_end = 5201
-    _SUBQUERYALIAS._serialized_start = 5203
-    _SUBQUERYALIAS._serialized_end = 5317
-    _REPARTITION._serialized_start = 5320
-    _REPARTITION._serialized_end = 5462
-    _SHOWSTRING._serialized_start = 5465
-    _SHOWSTRING._serialized_end = 5606
-    _STATSUMMARY._serialized_start = 5608
-    _STATSUMMARY._serialized_end = 5700
-    _STATDESCRIBE._serialized_start = 5702
-    _STATDESCRIBE._serialized_end = 5783
-    _STATCROSSTAB._serialized_start = 5785
-    _STATCROSSTAB._serialized_end = 5886
-    _NAFILL._serialized_start = 5889
-    _NAFILL._serialized_end = 6023
-    _NADROP._serialized_start = 6026
-    _NADROP._serialized_end = 6160
-    _NAREPLACE._serialized_start = 6163
-    _NAREPLACE._serialized_end = 6459
-    _NAREPLACE_REPLACEMENT._serialized_start = 6318
-    _NAREPLACE_REPLACEMENT._serialized_end = 6459
-    _RENAMECOLUMNSBYSAMELENGTHNAMES._serialized_start = 6461
-    _RENAMECOLUMNSBYSAMELENGTHNAMES._serialized_end = 6575
-    _RENAMECOLUMNSBYNAMETONAMEMAP._serialized_start = 6578
-    _RENAMECOLUMNSBYNAMETONAMEMAP._serialized_end = 6837
-    _RENAMECOLUMNSBYNAMETONAMEMAP_RENAMECOLUMNSMAPENTRY._serialized_start = 6770
-    _RENAMECOLUMNSBYNAMETONAMEMAP_RENAMECOLUMNSMAPENTRY._serialized_end = 6837
-    _WITHCOLUMNS._serialized_start = 6840
-    _WITHCOLUMNS._serialized_end = 6971
-    _HINT._serialized_start = 6974
-    _HINT._serialized_end = 7114
+    _RELATION._serialized_end = 1937
+    _UNKNOWN._serialized_start = 1939
+    _UNKNOWN._serialized_end = 1948
+    _RELATIONCOMMON._serialized_start = 1950
+    _RELATIONCOMMON._serialized_end = 1999
+    _SQL._serialized_start = 2001
+    _SQL._serialized_end = 2028
+    _READ._serialized_start = 2031
+    _READ._serialized_end = 2457
+    _READ_NAMEDTABLE._serialized_start = 2173
+    _READ_NAMEDTABLE._serialized_end = 2234
+    _READ_DATASOURCE._serialized_start = 2237
+    _READ_DATASOURCE._serialized_end = 2444
+    _READ_DATASOURCE_OPTIONSENTRY._serialized_start = 2375
+    _READ_DATASOURCE_OPTIONSENTRY._serialized_end = 2433
+    _PROJECT._serialized_start = 2459
+    _PROJECT._serialized_end = 2576
+    _FILTER._serialized_start = 2578
+    _FILTER._serialized_end = 2690
+    _JOIN._serialized_start = 2693
+    _JOIN._serialized_end = 3164
+    _JOIN_JOINTYPE._serialized_start = 2956
+    _JOIN_JOINTYPE._serialized_end = 3164
+    _SETOPERATION._serialized_start = 3167
+    _SETOPERATION._serialized_end = 3563
+    _SETOPERATION_SETOPTYPE._serialized_start = 3426
+    _SETOPERATION_SETOPTYPE._serialized_end = 3540
+    _LIMIT._serialized_start = 3565
+    _LIMIT._serialized_end = 3641
+    _OFFSET._serialized_start = 3643
+    _OFFSET._serialized_end = 3722
+    _TAIL._serialized_start = 3724
+    _TAIL._serialized_end = 3799
+    _AGGREGATE._serialized_start = 3802
+    _AGGREGATE._serialized_end = 4012
+    _SORT._serialized_start = 4015
+    _SORT._serialized_end = 4565
+    _SORT_SORTFIELD._serialized_start = 4169
+    _SORT_SORTFIELD._serialized_end = 4357
+    _SORT_SORTDIRECTION._serialized_start = 4359
+    _SORT_SORTDIRECTION._serialized_end = 4467
+    _SORT_SORTNULLS._serialized_start = 4469
+    _SORT_SORTNULLS._serialized_end = 4551
+    _DROP._serialized_start = 4567
+    _DROP._serialized_end = 4667
+    _DEDUPLICATE._serialized_start = 4670
+    _DEDUPLICATE._serialized_end = 4841
+    _LOCALRELATION._serialized_start = 4843
+    _LOCALRELATION._serialized_end = 4878
+    _SAMPLE._serialized_start = 4881
+    _SAMPLE._serialized_end = 5105
+    _RANGE._serialized_start = 5108
+    _RANGE._serialized_end = 5253
+    _SUBQUERYALIAS._serialized_start = 5255
+    _SUBQUERYALIAS._serialized_end = 5369
+    _REPARTITION._serialized_start = 5372
+    _REPARTITION._serialized_end = 5514
+    _SHOWSTRING._serialized_start = 5517
+    _SHOWSTRING._serialized_end = 5658
+    _STATSUMMARY._serialized_start = 5660
+    _STATSUMMARY._serialized_end = 5752
+    _STATDESCRIBE._serialized_start = 5754
+    _STATDESCRIBE._serialized_end = 5835
+    _STATCROSSTAB._serialized_start = 5837
+    _STATCROSSTAB._serialized_end = 5938
+    _NAFILL._serialized_start = 5941
+    _NAFILL._serialized_end = 6075
+    _NADROP._serialized_start = 6078
+    _NADROP._serialized_end = 6212
+    _NAREPLACE._serialized_start = 6215
+    _NAREPLACE._serialized_end = 6511
+    _NAREPLACE_REPLACEMENT._serialized_start = 6370
+    _NAREPLACE_REPLACEMENT._serialized_end = 6511
+    _RENAMECOLUMNSBYSAMELENGTHNAMES._serialized_start = 6513
+    _RENAMECOLUMNSBYSAMELENGTHNAMES._serialized_end = 6627
+    _RENAMECOLUMNSBYNAMETONAMEMAP._serialized_start = 6630
+    _RENAMECOLUMNSBYNAMETONAMEMAP._serialized_end = 6889
+    _RENAMECOLUMNSBYNAMETONAMEMAP_RENAMECOLUMNSMAPENTRY._serialized_start = 6822
+    _RENAMECOLUMNSBYNAMETONAMEMAP_RENAMECOLUMNSMAPENTRY._serialized_end = 6889
+    _WITHCOLUMNS._serialized_start = 6892
+    _WITHCOLUMNS._serialized_end = 7023
+    _HINT._serialized_start = 7026
+    _HINT._serialized_end = 7166
+    _UNPIVOT._serialized_start = 7169
+    _UNPIVOT._serialized_end = 7415
 # @@protoc_insertion_point(module_scope)
diff --git a/python/pyspark/sql/connect/proto/relations_pb2.pyi b/python/pyspark/sql/connect/proto/relations_pb2.pyi
index f1336613687..db872092002 100644
--- a/python/pyspark/sql/connect/proto/relations_pb2.pyi
+++ b/python/pyspark/sql/connect/proto/relations_pb2.pyi
@@ -83,6 +83,7 @@ class Relation(google.protobuf.message.Message):
     TAIL_FIELD_NUMBER: builtins.int
     WITH_COLUMNS_FIELD_NUMBER: builtins.int
     HINT_FIELD_NUMBER: builtins.int
+    UNPIVOT_FIELD_NUMBER: builtins.int
     FILL_NA_FIELD_NUMBER: builtins.int
     DROP_NA_FIELD_NUMBER: builtins.int
     REPLACE_FIELD_NUMBER: builtins.int
@@ -139,6 +140,8 @@ class Relation(google.protobuf.message.Message):
     @property
     def hint(self) -> global___Hint: ...
     @property
+    def unpivot(self) -> global___Unpivot: ...
+    @property
     def fill_na(self) -> global___NAFill:
         """NA functions"""
     @property
@@ -181,6 +184,7 @@ class Relation(google.protobuf.message.Message):
         tail: global___Tail | None = ...,
         with_columns: global___WithColumns | None = ...,
         hint: global___Hint | None = ...,
+        unpivot: global___Unpivot | None = ...,
         fill_na: global___NAFill | None = ...,
         drop_na: global___NADrop | None = ...,
         replace: global___NAReplace | None = ...,
@@ -254,6 +258,8 @@ class Relation(google.protobuf.message.Message):
             b"tail",
             "unknown",
             b"unknown",
+            "unpivot",
+            b"unpivot",
             "with_columns",
             b"with_columns",
         ],
@@ -323,6 +329,8 @@ class Relation(google.protobuf.message.Message):
             b"tail",
             "unknown",
             b"unknown",
+            "unpivot",
+            b"unpivot",
             "with_columns",
             b"with_columns",
         ],
@@ -353,6 +361,7 @@ class Relation(google.protobuf.message.Message):
         "tail",
         "with_columns",
         "hint",
+        "unpivot",
         "fill_na",
         "drop_na",
         "replace",
@@ -1963,3 +1972,66 @@ class Hint(google.protobuf.message.Message):
     ) -> None: ...
 
 global___Hint = Hint
+
+class Unpivot(google.protobuf.message.Message):
+    """Unpivot a DataFrame from wide format to long format, optionally leaving identifier columns set."""
+
+    DESCRIPTOR: google.protobuf.descriptor.Descriptor
+
+    INPUT_FIELD_NUMBER: builtins.int
+    IDS_FIELD_NUMBER: builtins.int
+    VALUES_FIELD_NUMBER: builtins.int
+    VARIABLE_COLUMN_NAME_FIELD_NUMBER: builtins.int
+    VALUE_COLUMN_NAME_FIELD_NUMBER: builtins.int
+    @property
+    def input(self) -> global___Relation:
+        """(Required) The input relation."""
+    @property
+    def ids(
+        self,
+    ) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[
+        pyspark.sql.connect.proto.expressions_pb2.Expression
+    ]:
+        """(Required) Id columns."""
+    @property
+    def values(
+        self,
+    ) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[
+        pyspark.sql.connect.proto.expressions_pb2.Expression
+    ]:
+        """(Optional) Value columns to unpivot."""
+    variable_column_name: builtins.str
+    """(Required) Name of the variable column."""
+    value_column_name: builtins.str
+    """(Required) Name of the value column."""
+    def __init__(
+        self,
+        *,
+        input: global___Relation | None = ...,
+        ids: collections.abc.Iterable[pyspark.sql.connect.proto.expressions_pb2.Expression]
+        | None = ...,
+        values: collections.abc.Iterable[pyspark.sql.connect.proto.expressions_pb2.Expression]
+        | None = ...,
+        variable_column_name: builtins.str = ...,
+        value_column_name: builtins.str = ...,
+    ) -> None: ...
+    def HasField(
+        self, field_name: typing_extensions.Literal["input", b"input"]
+    ) -> builtins.bool: ...
+    def ClearField(
+        self,
+        field_name: typing_extensions.Literal[
+            "ids",
+            b"ids",
+            "input",
+            b"input",
+            "value_column_name",
+            b"value_column_name",
+            "values",
+            b"values",
+            "variable_column_name",
+            b"variable_column_name",
+        ],
+    ) -> None: ...
+
+global___Unpivot = Unpivot
diff --git a/python/pyspark/sql/tests/connect/test_connect_plan_only.py b/python/pyspark/sql/tests/connect/test_connect_plan_only.py
index b9695eea785..7cd97f02d32 100644
--- a/python/pyspark/sql/tests/connect/test_connect_plan_only.py
+++ b/python/pyspark/sql/tests/connect/test_connect_plan_only.py
@@ -168,6 +168,64 @@ class SparkConnectTestsPlanOnly(PlanOnlyTestFixture):
         self.assertEqual(plan.root.replace.replacements[1].old_value.string, "Bob")
         self.assertEqual(plan.root.replace.replacements[1].new_value.string, "B")
 
+    def test_unpivot(self):
+        df = self.connect.readTable(table_name=self.tbl_name)
+
+        plan = (
+            df.filter(df.col_name > 3)
+            .unpivot(["id"], ["name"], "variable", "value")
+            ._plan.to_proto(self.connect)
+        )
+        self.assertTrue(all(isinstance(c, proto.Expression) for c in plan.root.unpivot.ids))
+        self.assertEqual(plan.root.unpivot.ids[0].unresolved_attribute.unparsed_identifier, "id")
+        self.assertTrue(all(isinstance(c, proto.Expression) for c in plan.root.unpivot.values))
+        self.assertEqual(
+            plan.root.unpivot.values[0].unresolved_attribute.unparsed_identifier, "name"
+        )
+        self.assertEqual(plan.root.unpivot.variable_column_name, "variable")
+        self.assertEqual(plan.root.unpivot.value_column_name, "value")
+
+        plan = (
+            df.filter(df.col_name > 3)
+            .unpivot(["id"], [], "variable", "value")
+            ._plan.to_proto(self.connect)
+        )
+        self.assertTrue(len(plan.root.unpivot.ids) == 1)
+        self.assertTrue(all(isinstance(c, proto.Expression) for c in plan.root.unpivot.ids))
+        self.assertEqual(plan.root.unpivot.ids[0].unresolved_attribute.unparsed_identifier, "id")
+        self.assertTrue(len(plan.root.unpivot.values) == 0)
+        self.assertEqual(plan.root.unpivot.variable_column_name, "variable")
+        self.assertEqual(plan.root.unpivot.value_column_name, "value")
+
+    def test_melt(self):
+        df = self.connect.readTable(table_name=self.tbl_name)
+
+        plan = (
+            df.filter(df.col_name > 3)
+            .melt(["id"], ["name"], "variable", "value")
+            ._plan.to_proto(self.connect)
+        )
+        self.assertTrue(all(isinstance(c, proto.Expression) for c in plan.root.unpivot.ids))
+        self.assertEqual(plan.root.unpivot.ids[0].unresolved_attribute.unparsed_identifier, "id")
+        self.assertTrue(all(isinstance(c, proto.Expression) for c in plan.root.unpivot.values))
+        self.assertEqual(
+            plan.root.unpivot.values[0].unresolved_attribute.unparsed_identifier, "name"
+        )
+        self.assertEqual(plan.root.unpivot.variable_column_name, "variable")
+        self.assertEqual(plan.root.unpivot.value_column_name, "value")
+
+        plan = (
+            df.filter(df.col_name > 3)
+            .melt(["id"], [], "variable", "value")
+            ._plan.to_proto(self.connect)
+        )
+        self.assertTrue(len(plan.root.unpivot.ids) == 1)
+        self.assertTrue(all(isinstance(c, proto.Expression) for c in plan.root.unpivot.ids))
+        self.assertEqual(plan.root.unpivot.ids[0].unresolved_attribute.unparsed_identifier, "id")
+        self.assertTrue(len(plan.root.unpivot.values) == 0)
+        self.assertEqual(plan.root.unpivot.variable_column_name, "variable")
+        self.assertEqual(plan.root.unpivot.value_column_name, "value")
+
     def test_summary(self):
         df = self.connect.readTable(table_name=self.tbl_name)
         plan = df.filter(df.col_name > 3).summary()._plan.to_proto(self.connect)

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org