This is an automated email from the ASF dual-hosted git repository. wenchen pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new e1382c566b7 [SPARK-40992][CONNECT] Support toDF(columnNames) in Connect DSL e1382c566b7 is described below commit e1382c566b7b2ba324fec1aed6556325ebe43f7b Author: Rui Wang <rui.w...@databricks.com> AuthorDate: Wed Nov 9 15:48:24 2022 +0800 [SPARK-40992][CONNECT] Support toDF(columnNames) in Connect DSL ### What changes were proposed in this pull request? Add `RenameColumns` to proto to support the implementation for `toDF(columnNames: String*)` which renames the input relation to a different set of column names. ### Why are the changes needed? Improve API coverage. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? UT Closes #38475 from amaliujia/SPARK-40992. Authored-by: Rui Wang <rui.w...@databricks.com> Signed-off-by: Wenchen Fan <wenc...@databricks.com> --- .../main/protobuf/spark/connect/relations.proto | 12 ++ .../org/apache/spark/sql/connect/dsl/package.scala | 10 ++ .../sql/connect/planner/SparkConnectPlanner.scala | 9 ++ .../connect/planner/SparkConnectProtoSuite.scala | 4 + python/pyspark/sql/connect/proto/relations_pb2.py | 126 +++++++++++---------- python/pyspark/sql/connect/proto/relations_pb2.pyi | 44 +++++++ 6 files changed, 143 insertions(+), 62 deletions(-) diff --git a/connector/connect/src/main/protobuf/spark/connect/relations.proto b/connector/connect/src/main/protobuf/spark/connect/relations.proto index dd03bd86940..cce9f3b939e 100644 --- a/connector/connect/src/main/protobuf/spark/connect/relations.proto +++ b/connector/connect/src/main/protobuf/spark/connect/relations.proto @@ -47,6 +47,7 @@ message Relation { Range range = 15; SubqueryAlias subquery_alias = 16; Repartition repartition = 17; + RenameColumns rename_columns = 18; StatFunction stat_function = 100; @@ -274,3 +275,14 @@ message StatFunction { } } +// Rename columns on the input relation. +message RenameColumns { + // Required. The input relation. + Relation input = 1; + + // Required. + // + // The number of columns of the input relation must be equal to the length + // of this field. If this is not true, an exception will be returned. + repeated string column_names = 2; +} diff --git a/connector/connect/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala b/connector/connect/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala index 3e68b101057..d6f7a6756c3 100644 --- a/connector/connect/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala +++ b/connector/connect/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala @@ -457,6 +457,16 @@ package object dsl { .build() } + def toDF(columnNames: String*): Relation = + Relation + .newBuilder() + .setRenameColumns( + RenameColumns + .newBuilder() + .setInput(logicalPlan) + .addAllColumnNames(columnNames.asJava)) + .build() + private def createSetOperation( left: Relation, right: Relation, diff --git a/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala index 3bbdbf80276..87716c702b5 100644 --- a/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala +++ b/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala @@ -69,6 +69,8 @@ class SparkConnectPlanner(plan: proto.Relation, session: SparkSession) { case proto.Relation.RelTypeCase.REPARTITION => transformRepartition(rel.getRepartition) case proto.Relation.RelTypeCase.STAT_FUNCTION => transformStatFunction(rel.getStatFunction) + case proto.Relation.RelTypeCase.RENAME_COLUMNS => + transformRenameColumns(rel.getRenameColumns) case proto.Relation.RelTypeCase.RELTYPE_NOT_SET => throw new IndexOutOfBoundsException("Expected Relation to be set, but is empty.") case _ => throw InvalidPlanInput(s"${rel.getUnknown} not supported.") @@ -133,6 +135,13 @@ class SparkConnectPlanner(plan: proto.Relation, session: SparkSession) { } } + private def transformRenameColumns(rel: proto.RenameColumns): LogicalPlan = { + Dataset + .ofRows(session, transformRelation(rel.getInput)) + .toDF(rel.getColumnNamesList.asScala.toSeq: _*) + .logicalPlan + } + private def transformDeduplicate(rel: proto.Deduplicate): LogicalPlan = { if (!rel.hasInput) { throw InvalidPlanInput("Deduplicate needs a plan input") diff --git a/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala b/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala index c5b6f4fc0ee..2339c676a38 100644 --- a/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala +++ b/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala @@ -267,6 +267,10 @@ class SparkConnectProtoSuite extends PlanTest with SparkConnectPlanTest { sparkTestRelation.summary("count", "mean", "stddev")) } + test("Test toDF") { + comparePlans(connectTestRelation.toDF("col1", "col2"), sparkTestRelation.toDF("col1", "col2")) + } + private def createLocalRelationProtoByQualifiedAttributes( attrs: Seq[proto.Expression.QualifiedAttribute]): proto.Relation = { val localRelationBuilder = proto.LocalRelation.newBuilder() diff --git a/python/pyspark/sql/connect/proto/relations_pb2.py b/python/pyspark/sql/connect/proto/relations_pb2.py index b11a4b0e91a..06b59ea5f45 100644 --- a/python/pyspark/sql/connect/proto/relations_pb2.py +++ b/python/pyspark/sql/connect/proto/relations_pb2.py @@ -32,7 +32,7 @@ from pyspark.sql.connect.proto import expressions_pb2 as spark_dot_connect_dot_e DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( - b'\n\x1dspark/connect/relations.proto\x12\rspark.connect\x1a\x1fspark/connect/expressions.proto"\x90\x08\n\x08Relation\x12\x35\n\x06\x63ommon\x18\x01 \x01(\x0b\x32\x1d.spark.connect.RelationCommonR\x06\x63ommon\x12)\n\x04read\x18\x02 \x01(\x0b\x32\x13.spark.connect.ReadH\x00R\x04read\x12\x32\n\x07project\x18\x03 \x01(\x0b\x32\x16.spark.connect.ProjectH\x00R\x07project\x12/\n\x06\x66ilter\x18\x04 \x01(\x0b\x32\x15.spark.connect.FilterH\x00R\x06\x66ilter\x12)\n\x04join\x18\x05 \x01(\x0 [...] + b'\n\x1dspark/connect/relations.proto\x12\rspark.connect\x1a\x1fspark/connect/expressions.proto"\xd7\x08\n\x08Relation\x12\x35\n\x06\x63ommon\x18\x01 \x01(\x0b\x32\x1d.spark.connect.RelationCommonR\x06\x63ommon\x12)\n\x04read\x18\x02 \x01(\x0b\x32\x13.spark.connect.ReadH\x00R\x04read\x12\x32\n\x07project\x18\x03 \x01(\x0b\x32\x16.spark.connect.ProjectH\x00R\x07project\x12/\n\x06\x66ilter\x18\x04 \x01(\x0b\x32\x15.spark.connect.FilterH\x00R\x06\x66ilter\x12)\n\x04join\x18\x05 \x01(\x0 [...] ) _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals()) @@ -44,65 +44,67 @@ if _descriptor._USE_C_DESCRIPTORS == False: _READ_DATASOURCE_OPTIONSENTRY._options = None _READ_DATASOURCE_OPTIONSENTRY._serialized_options = b"8\001" _RELATION._serialized_start = 82 - _RELATION._serialized_end = 1122 - _UNKNOWN._serialized_start = 1124 - _UNKNOWN._serialized_end = 1133 - _RELATIONCOMMON._serialized_start = 1135 - _RELATIONCOMMON._serialized_end = 1184 - _SQL._serialized_start = 1186 - _SQL._serialized_end = 1213 - _READ._serialized_start = 1216 - _READ._serialized_end = 1626 - _READ_NAMEDTABLE._serialized_start = 1358 - _READ_NAMEDTABLE._serialized_end = 1419 - _READ_DATASOURCE._serialized_start = 1422 - _READ_DATASOURCE._serialized_end = 1613 - _READ_DATASOURCE_OPTIONSENTRY._serialized_start = 1555 - _READ_DATASOURCE_OPTIONSENTRY._serialized_end = 1613 - _PROJECT._serialized_start = 1628 - _PROJECT._serialized_end = 1745 - _FILTER._serialized_start = 1747 - _FILTER._serialized_end = 1859 - _JOIN._serialized_start = 1862 - _JOIN._serialized_end = 2312 - _JOIN_JOINTYPE._serialized_start = 2125 - _JOIN_JOINTYPE._serialized_end = 2312 - _SETOPERATION._serialized_start = 2315 - _SETOPERATION._serialized_end = 2678 - _SETOPERATION_SETOPTYPE._serialized_start = 2564 - _SETOPERATION_SETOPTYPE._serialized_end = 2678 - _LIMIT._serialized_start = 2680 - _LIMIT._serialized_end = 2756 - _OFFSET._serialized_start = 2758 - _OFFSET._serialized_end = 2837 - _AGGREGATE._serialized_start = 2840 - _AGGREGATE._serialized_end = 3050 - _SORT._serialized_start = 3053 - _SORT._serialized_end = 3584 - _SORT_SORTFIELD._serialized_start = 3202 - _SORT_SORTFIELD._serialized_end = 3390 - _SORT_SORTDIRECTION._serialized_start = 3392 - _SORT_SORTDIRECTION._serialized_end = 3500 - _SORT_SORTNULLS._serialized_start = 3502 - _SORT_SORTNULLS._serialized_end = 3584 - _DEDUPLICATE._serialized_start = 3587 - _DEDUPLICATE._serialized_end = 3729 - _LOCALRELATION._serialized_start = 3731 - _LOCALRELATION._serialized_end = 3824 - _SAMPLE._serialized_start = 3827 - _SAMPLE._serialized_end = 4067 - _SAMPLE_SEED._serialized_start = 4041 - _SAMPLE_SEED._serialized_end = 4067 - _RANGE._serialized_start = 4070 - _RANGE._serialized_end = 4268 - _RANGE_NUMPARTITIONS._serialized_start = 4214 - _RANGE_NUMPARTITIONS._serialized_end = 4268 - _SUBQUERYALIAS._serialized_start = 4270 - _SUBQUERYALIAS._serialized_end = 4384 - _REPARTITION._serialized_start = 4386 - _REPARTITION._serialized_end = 4511 - _STATFUNCTION._serialized_start = 4514 - _STATFUNCTION._serialized_end = 4748 - _STATFUNCTION_SUMMARY._serialized_start = 4695 - _STATFUNCTION_SUMMARY._serialized_end = 4736 + _RELATION._serialized_end = 1193 + _UNKNOWN._serialized_start = 1195 + _UNKNOWN._serialized_end = 1204 + _RELATIONCOMMON._serialized_start = 1206 + _RELATIONCOMMON._serialized_end = 1255 + _SQL._serialized_start = 1257 + _SQL._serialized_end = 1284 + _READ._serialized_start = 1287 + _READ._serialized_end = 1697 + _READ_NAMEDTABLE._serialized_start = 1429 + _READ_NAMEDTABLE._serialized_end = 1490 + _READ_DATASOURCE._serialized_start = 1493 + _READ_DATASOURCE._serialized_end = 1684 + _READ_DATASOURCE_OPTIONSENTRY._serialized_start = 1626 + _READ_DATASOURCE_OPTIONSENTRY._serialized_end = 1684 + _PROJECT._serialized_start = 1699 + _PROJECT._serialized_end = 1816 + _FILTER._serialized_start = 1818 + _FILTER._serialized_end = 1930 + _JOIN._serialized_start = 1933 + _JOIN._serialized_end = 2383 + _JOIN_JOINTYPE._serialized_start = 2196 + _JOIN_JOINTYPE._serialized_end = 2383 + _SETOPERATION._serialized_start = 2386 + _SETOPERATION._serialized_end = 2749 + _SETOPERATION_SETOPTYPE._serialized_start = 2635 + _SETOPERATION_SETOPTYPE._serialized_end = 2749 + _LIMIT._serialized_start = 2751 + _LIMIT._serialized_end = 2827 + _OFFSET._serialized_start = 2829 + _OFFSET._serialized_end = 2908 + _AGGREGATE._serialized_start = 2911 + _AGGREGATE._serialized_end = 3121 + _SORT._serialized_start = 3124 + _SORT._serialized_end = 3655 + _SORT_SORTFIELD._serialized_start = 3273 + _SORT_SORTFIELD._serialized_end = 3461 + _SORT_SORTDIRECTION._serialized_start = 3463 + _SORT_SORTDIRECTION._serialized_end = 3571 + _SORT_SORTNULLS._serialized_start = 3573 + _SORT_SORTNULLS._serialized_end = 3655 + _DEDUPLICATE._serialized_start = 3658 + _DEDUPLICATE._serialized_end = 3800 + _LOCALRELATION._serialized_start = 3802 + _LOCALRELATION._serialized_end = 3895 + _SAMPLE._serialized_start = 3898 + _SAMPLE._serialized_end = 4138 + _SAMPLE_SEED._serialized_start = 4112 + _SAMPLE_SEED._serialized_end = 4138 + _RANGE._serialized_start = 4141 + _RANGE._serialized_end = 4339 + _RANGE_NUMPARTITIONS._serialized_start = 4285 + _RANGE_NUMPARTITIONS._serialized_end = 4339 + _SUBQUERYALIAS._serialized_start = 4341 + _SUBQUERYALIAS._serialized_end = 4455 + _REPARTITION._serialized_start = 4457 + _REPARTITION._serialized_end = 4582 + _STATFUNCTION._serialized_start = 4585 + _STATFUNCTION._serialized_end = 4819 + _STATFUNCTION_SUMMARY._serialized_start = 4766 + _STATFUNCTION_SUMMARY._serialized_end = 4807 + _RENAMECOLUMNS._serialized_start = 4821 + _RENAMECOLUMNS._serialized_end = 4918 # @@protoc_insertion_point(module_scope) diff --git a/python/pyspark/sql/connect/proto/relations_pb2.pyi b/python/pyspark/sql/connect/proto/relations_pb2.pyi index 6ee3c46d7c5..bef74b03659 100644 --- a/python/pyspark/sql/connect/proto/relations_pb2.pyi +++ b/python/pyspark/sql/connect/proto/relations_pb2.pyi @@ -76,6 +76,7 @@ class Relation(google.protobuf.message.Message): RANGE_FIELD_NUMBER: builtins.int SUBQUERY_ALIAS_FIELD_NUMBER: builtins.int REPARTITION_FIELD_NUMBER: builtins.int + RENAME_COLUMNS_FIELD_NUMBER: builtins.int STAT_FUNCTION_FIELD_NUMBER: builtins.int UNKNOWN_FIELD_NUMBER: builtins.int @property @@ -113,6 +114,8 @@ class Relation(google.protobuf.message.Message): @property def repartition(self) -> global___Repartition: ... @property + def rename_columns(self) -> global___RenameColumns: ... + @property def stat_function(self) -> global___StatFunction: ... @property def unknown(self) -> global___Unknown: ... @@ -136,6 +139,7 @@ class Relation(google.protobuf.message.Message): range: global___Range | None = ..., subquery_alias: global___SubqueryAlias | None = ..., repartition: global___Repartition | None = ..., + rename_columns: global___RenameColumns | None = ..., stat_function: global___StatFunction | None = ..., unknown: global___Unknown | None = ..., ) -> None: ... @@ -166,6 +170,8 @@ class Relation(google.protobuf.message.Message): b"read", "rel_type", b"rel_type", + "rename_columns", + b"rename_columns", "repartition", b"repartition", "sample", @@ -211,6 +217,8 @@ class Relation(google.protobuf.message.Message): b"read", "rel_type", b"rel_type", + "rename_columns", + b"rename_columns", "repartition", b"repartition", "sample", @@ -248,6 +256,7 @@ class Relation(google.protobuf.message.Message): "range", "subquery_alias", "repartition", + "rename_columns", "stat_function", "unknown", ] | None: ... @@ -1133,3 +1142,38 @@ class StatFunction(google.protobuf.message.Message): ) -> typing_extensions.Literal["summary", "unknown"] | None: ... global___StatFunction = StatFunction + +class RenameColumns(google.protobuf.message.Message): + """Rename columns on the input relation.""" + + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + INPUT_FIELD_NUMBER: builtins.int + COLUMN_NAMES_FIELD_NUMBER: builtins.int + @property + def input(self) -> global___Relation: + """Required. The input relation.""" + @property + def column_names( + self, + ) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.str]: + """Required. + + The number of columns of the input relation must be equal to the length + of this field. If this is not true, an exception will be returned. + """ + def __init__( + self, + *, + input: global___Relation | None = ..., + column_names: collections.abc.Iterable[builtins.str] | None = ..., + ) -> None: ... + def HasField( + self, field_name: typing_extensions.Literal["input", b"input"] + ) -> builtins.bool: ... + def ClearField( + self, + field_name: typing_extensions.Literal["column_names", b"column_names", "input", b"input"], + ) -> None: ... + +global___RenameColumns = RenameColumns --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org