This is an automated email from the ASF dual-hosted git repository.

hvanhovell pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
     new 80118e2c688  [SPARK-44807][CONNECT] Add Dataset.metadataColumn to Scala Client
80118e2c688 is described below

commit 80118e2c688cdeebf49925385dfec376079b003b
Author: Herman van Hovell <her...@databricks.com>
AuthorDate: Wed Aug 16 18:14:33 2023 +0200

    [SPARK-44807][CONNECT] Add Dataset.metadataColumn to Scala Client

    ### What changes were proposed in this pull request?
    This PR adds Dataset.metadataColumn to the Spark Connect Scala Client.

    ### Why are the changes needed?
    We want the Scala client to be as compatible as possible with the API provided by sql/core.

    ### Does this PR introduce _any_ user-facing change?
    Yes, it adds a new method to the Spark Connect Scala Client.

    ### How was this patch tested?
    I added a test to `ClientE2ETestSuite`.

    Closes #42492 from hvanhovell/SPARK-44807.

    Authored-by: Herman van Hovell <her...@databricks.com>
    Signed-off-by: Herman van Hovell <her...@databricks.com>
---
 .../main/scala/org/apache/spark/sql/Dataset.scala  | 16 +++++
 .../org/apache/spark/sql/ClientE2ETestSuite.scala  | 22 +++++++
 .../CheckConnectJvmClientCompatibility.scala       |  1 -
 .../main/protobuf/spark/connect/expressions.proto  |  3 +
 .../sql/connect/planner/SparkConnectPlanner.scala  |  3 +
 .../pyspark/sql/connect/proto/expressions_pb2.py   | 68 +++++++++++-----------
 .../pyspark/sql/connect/proto/expressions_pb2.pyi  | 25 +++++++-
 .../catalyst/analysis/ColumnResolutionHelper.scala |  9 ++-
 .../sql/catalyst/plans/logical/LogicalPlan.scala   |  1 +
 9 files changed, 111 insertions(+), 37 deletions(-)

diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala
index cb7d2c84df5..3c89e649020 100644
--- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -1043,6 +1043,22 @@ class Dataset[T] private[sql] (
     Column.apply(colName, getPlanId)
   }
 
+  /**
+   * Selects a metadata column based on its logical column name, and returns it as a [[Column]].
+   *
+   * A metadata column can be accessed this way even if the underlying data source defines a data
+   * column with a conflicting name.
+   *
+   * @group untypedrel
+   * @since 3.5.0
+   */
+  def metadataColumn(colName: String): Column = Column { builder =>
+    val attributeBuilder = builder.getUnresolvedAttributeBuilder
+      .setUnparsedIdentifier(colName)
+      .setIsMetadataColumn(true)
+    getPlanId.foreach(attributeBuilder.setPlanId)
+  }
+
   /**
    * Selects column based on the column name specified as a regex and returns it as [[Column]].
    * @group untypedrel
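For context, a minimal usage sketch of the new method (not part of the
commit): it assumes a Spark Connect SparkSession named `spark` and an
illustrative Parquet path, and reads the `file_path` field of the built-in
Parquet `_metadata` column, mirroring the test added below.

    // Parquet sources expose a hidden `_metadata` column with per-file fields.
    val df = spark.read.parquet("/tmp/example_table")  // illustrative path

    // Resolve the metadata column explicitly; this works even if the data
    // itself defines a conflicting `_metadata` data column.
    val filePath = df.metadataColumn("_metadata").getField("file_path")
    df.select(filePath).show()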
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala
index 074cf170dd3..7b9b5f43e80 100644
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala
@@ -1203,6 +1203,28 @@ class ClientE2ETestSuite extends RemoteSparkSession with SQLHelper with PrivateM
       .dropDuplicatesWithinWatermark("newcol")
     testAndVerify(result2)
   }
+
+  test("Dataset.metadataColumn") {
+    val session: SparkSession = spark
+    import session.implicits._
+    withTempPath { file =>
+      val path = file.getAbsoluteFile.toURI.toString
+      spark
+        .range(0, 100, 1, 1)
+        .withColumn("_metadata", concat(lit("lol_"), col("id")))
+        .write
+        .parquet(file.toPath.toAbsolutePath.toString)
+
+      val df = spark.read.parquet(path)
+      val (filepath, rc) = df
+        .groupBy(df.metadataColumn("_metadata").getField("file_path"))
+        .count()
+        .as[(String, Long)]
+        .head()
+      assert(filepath.startsWith(path))
+      assert(rc == 100)
+    }
+  }
 }
 
 private[sql] case class ClassData(a: String, b: Int)
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala
index 8f226eb2f7e..3d7a80b1fb6 100644
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala
@@ -179,7 +179,6 @@ object CheckConnectJvmClientCompatibility {
       ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.ObservationListener$"),
       ProblemFilters.exclude[Problem]("org.apache.spark.sql.Dataset.queryExecution"),
       ProblemFilters.exclude[Problem]("org.apache.spark.sql.Dataset.sqlContext"),
-      ProblemFilters.exclude[Problem]("org.apache.spark.sql.Dataset.metadataColumn"),
       ProblemFilters.exclude[Problem]("org.apache.spark.sql.Dataset.selectUntyped"), // protected
       ProblemFilters.exclude[Problem]("org.apache.spark.sql.Dataset.rdd"),
       ProblemFilters.exclude[Problem]("org.apache.spark.sql.Dataset.toJavaRDD"),
diff --git a/connector/connect/common/src/main/protobuf/spark/connect/expressions.proto b/connector/connect/common/src/main/protobuf/spark/connect/expressions.proto
index b222f663cd0..4aac2bcc612 100644
--- a/connector/connect/common/src/main/protobuf/spark/connect/expressions.proto
+++ b/connector/connect/common/src/main/protobuf/spark/connect/expressions.proto
@@ -223,6 +223,9 @@ message Expression {
 
     // (Optional) The id of corresponding connect plan.
     optional int64 plan_id = 2;
+
+    // (Optional) The requested column is a metadata column.
+    optional bool is_metadata_column = 3;
   }
 
   // An unresolved function is not explicitly bound to one explicit function, but the function
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index 0a72d1f70c6..a7e0d9aa0c7 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -1404,6 +1404,9 @@ class SparkConnectPlanner(val sessionHolder: SessionHolder) extends Logging {
     if (attr.hasPlanId) {
       expr.setTagValue(LogicalPlan.PLAN_ID_TAG, attr.getPlanId)
     }
+    if (attr.hasIsMetadataColumn) {
+      expr.setTagValue(LogicalPlan.IS_METADATA_COL, attr.getIsMetadataColumn)
+    }
     expr
   }
 
diff --git a/python/pyspark/sql/connect/proto/expressions_pb2.py b/python/pyspark/sql/connect/proto/expressions_pb2.py
index 51ad47bb1c8..eb125fab39c 100644
--- a/python/pyspark/sql/connect/proto/expressions_pb2.py
+++ b/python/pyspark/sql/connect/proto/expressions_pb2.py
@@ -33,7 +33,7 @@ from pyspark.sql.connect.proto import types_pb2 as spark_dot_connect_dot_types__
 
 DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
-    b'\n\x1fspark/connect/expressions.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x19spark/connect/types.proto"\xbf,\n\nExpression\x12=\n\x07literal\x18\x01 \x01(\x0b\x32!.spark.connect.Expression.LiteralH\x00R\x07literal\x12\x62\n\x14unresolved_attribute\x18\x02 \x01(\x0b\x32-.spark.connect.Expression.UnresolvedAttributeH\x00R\x13unresolvedAttribute\x12_\n\x13unresolved_function\x18\x03 \x01(\x0b\x32,.spark.connect.Expression.UnresolvedFunctionH\x00R\x12unresolvedFunct [...]
+    b'\n\x1fspark/connect/expressions.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x19spark/connect/types.proto"\x8a-\n\nExpression\x12=\n\x07literal\x18\x01 \x01(\x0b\x32!.spark.connect.Expression.LiteralH\x00R\x07literal\x12\x62\n\x14unresolved_attribute\x18\x02 \x01(\x0b\x32-.spark.connect.Expression.UnresolvedAttributeH\x00R\x13unresolvedAttribute\x12_\n\x13unresolved_function\x18\x03 \x01(\x0b\x32,.spark.connect.Expression.UnresolvedFunctionH\x00R\x12unresolvedFunct [...]
 )
 
 _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals())
@@ -45,7 +45,7 @@ if _descriptor._USE_C_DESCRIPTORS == False:
         b"\n\036org.apache.spark.connect.protoP\001Z\022internal/generated"
     )
     _EXPRESSION._serialized_start = 105
-    _EXPRESSION._serialized_end = 5800
+    _EXPRESSION._serialized_end = 5875
     _EXPRESSION_WINDOW._serialized_start = 1645
     _EXPRESSION_WINDOW._serialized_end = 2428
     _EXPRESSION_WINDOW_WINDOWFRAME._serialized_start = 1935
@@ -74,36 +74,36 @@ if _descriptor._USE_C_DESCRIPTORS == False:
     _EXPRESSION_LITERAL_MAP._serialized_end = 4422
     _EXPRESSION_LITERAL_STRUCT._serialized_start = 4425
     _EXPRESSION_LITERAL_STRUCT._serialized_end = 4554
-    _EXPRESSION_UNRESOLVEDATTRIBUTE._serialized_start = 4572
-    _EXPRESSION_UNRESOLVEDATTRIBUTE._serialized_end = 4684
-    _EXPRESSION_UNRESOLVEDFUNCTION._serialized_start = 4687
-    _EXPRESSION_UNRESOLVEDFUNCTION._serialized_end = 4891
-    _EXPRESSION_EXPRESSIONSTRING._serialized_start = 4893
-    _EXPRESSION_EXPRESSIONSTRING._serialized_end = 4943
-    _EXPRESSION_UNRESOLVEDSTAR._serialized_start = 4945
-    _EXPRESSION_UNRESOLVEDSTAR._serialized_end = 5027
-    _EXPRESSION_UNRESOLVEDREGEX._serialized_start = 5029
-    _EXPRESSION_UNRESOLVEDREGEX._serialized_end = 5115
-    _EXPRESSION_UNRESOLVEDEXTRACTVALUE._serialized_start = 5118
-    _EXPRESSION_UNRESOLVEDEXTRACTVALUE._serialized_end = 5250
-    _EXPRESSION_UPDATEFIELDS._serialized_start = 5253
-    _EXPRESSION_UPDATEFIELDS._serialized_end = 5440
-    _EXPRESSION_ALIAS._serialized_start = 5442
-    _EXPRESSION_ALIAS._serialized_end = 5562
-    _EXPRESSION_LAMBDAFUNCTION._serialized_start = 5565
-    _EXPRESSION_LAMBDAFUNCTION._serialized_end = 5723
-    _EXPRESSION_UNRESOLVEDNAMEDLAMBDAVARIABLE._serialized_start = 5725
-    _EXPRESSION_UNRESOLVEDNAMEDLAMBDAVARIABLE._serialized_end = 5787
-    _COMMONINLINEUSERDEFINEDFUNCTION._serialized_start = 5803
-    _COMMONINLINEUSERDEFINEDFUNCTION._serialized_end = 6167
-    _PYTHONUDF._serialized_start = 6170
-    _PYTHONUDF._serialized_end = 6325
-    _SCALARSCALAUDF._serialized_start = 6328
-    _SCALARSCALAUDF._serialized_end = 6512
-    _JAVAUDF._serialized_start = 6515
-    _JAVAUDF._serialized_end = 6664
-    _CALLFUNCTION._serialized_start = 6666
-    _CALLFUNCTION._serialized_end = 6774
-    _NAMEDARGUMENTEXPRESSION._serialized_start = 6776
-    _NAMEDARGUMENTEXPRESSION._serialized_end = 6868
+    _EXPRESSION_UNRESOLVEDATTRIBUTE._serialized_start = 4573
+    _EXPRESSION_UNRESOLVEDATTRIBUTE._serialized_end = 4759
+    _EXPRESSION_UNRESOLVEDFUNCTION._serialized_start = 4762
+    _EXPRESSION_UNRESOLVEDFUNCTION._serialized_end = 4966
+    _EXPRESSION_EXPRESSIONSTRING._serialized_start = 4968
+    _EXPRESSION_EXPRESSIONSTRING._serialized_end = 5018
+    _EXPRESSION_UNRESOLVEDSTAR._serialized_start = 5020
+    _EXPRESSION_UNRESOLVEDSTAR._serialized_end = 5102
+    _EXPRESSION_UNRESOLVEDREGEX._serialized_start = 5104
+    _EXPRESSION_UNRESOLVEDREGEX._serialized_end = 5190
+    _EXPRESSION_UNRESOLVEDEXTRACTVALUE._serialized_start = 5193
+    _EXPRESSION_UNRESOLVEDEXTRACTVALUE._serialized_end = 5325
+    _EXPRESSION_UPDATEFIELDS._serialized_start = 5328
+    _EXPRESSION_UPDATEFIELDS._serialized_end = 5515
+    _EXPRESSION_ALIAS._serialized_start = 5517
+    _EXPRESSION_ALIAS._serialized_end = 5637
+    _EXPRESSION_LAMBDAFUNCTION._serialized_start = 5640
+    _EXPRESSION_LAMBDAFUNCTION._serialized_end = 5798
+    _EXPRESSION_UNRESOLVEDNAMEDLAMBDAVARIABLE._serialized_start = 5800
+    _EXPRESSION_UNRESOLVEDNAMEDLAMBDAVARIABLE._serialized_end = 5862
+    _COMMONINLINEUSERDEFINEDFUNCTION._serialized_start = 5878
+    _COMMONINLINEUSERDEFINEDFUNCTION._serialized_end = 6242
+    _PYTHONUDF._serialized_start = 6245
+    _PYTHONUDF._serialized_end = 6400
+    _SCALARSCALAUDF._serialized_start = 6403
+    _SCALARSCALAUDF._serialized_end = 6587
+    _JAVAUDF._serialized_start = 6590
+    _JAVAUDF._serialized_end = 6739
+    _CALLFUNCTION._serialized_start = 6741
+    _CALLFUNCTION._serialized_end = 6849
+    _NAMEDARGUMENTEXPRESSION._serialized_start = 6851
+    _NAMEDARGUMENTEXPRESSION._serialized_end = 6943
 # @@protoc_insertion_point(module_scope)
diff --git a/python/pyspark/sql/connect/proto/expressions_pb2.pyi b/python/pyspark/sql/connect/proto/expressions_pb2.pyi
index 2b418ef23f6..b590d22da2c 100644
--- a/python/pyspark/sql/connect/proto/expressions_pb2.pyi
+++ b/python/pyspark/sql/connect/proto/expressions_pb2.pyi
@@ -750,33 +750,56 @@ class Expression(google.protobuf.message.Message):
 
         UNPARSED_IDENTIFIER_FIELD_NUMBER: builtins.int
         PLAN_ID_FIELD_NUMBER: builtins.int
+        IS_METADATA_COLUMN_FIELD_NUMBER: builtins.int
         unparsed_identifier: builtins.str
         """(Required) An identifier that will be parsed by Catalyst parser. This should follow the
         Spark SQL identifier syntax.
         """
         plan_id: builtins.int
         """(Optional) The id of corresponding connect plan."""
+        is_metadata_column: builtins.bool
+        """(Optional) The requested column is a metadata column."""
         def __init__(
             self,
             *,
             unparsed_identifier: builtins.str = ...,
             plan_id: builtins.int | None = ...,
+            is_metadata_column: builtins.bool | None = ...,
         ) -> None: ...
         def HasField(
             self,
-            field_name: typing_extensions.Literal["_plan_id", b"_plan_id", "plan_id", b"plan_id"],
+            field_name: typing_extensions.Literal[
+                "_is_metadata_column",
+                b"_is_metadata_column",
+                "_plan_id",
+                b"_plan_id",
+                "is_metadata_column",
+                b"is_metadata_column",
+                "plan_id",
+                b"plan_id",
+            ],
         ) -> builtins.bool: ...
         def ClearField(
             self,
             field_name: typing_extensions.Literal[
+                "_is_metadata_column",
+                b"_is_metadata_column",
                 "_plan_id",
                 b"_plan_id",
+                "is_metadata_column",
+                b"is_metadata_column",
                 "plan_id",
                 b"plan_id",
                 "unparsed_identifier",
                 b"unparsed_identifier",
             ],
         ) -> None: ...
+        @typing.overload
+        def WhichOneof(
+            self,
+            oneof_group: typing_extensions.Literal["_is_metadata_column", b"_is_metadata_column"],
+        ) -> typing_extensions.Literal["is_metadata_column"] | None: ...
+        @typing.overload
         def WhichOneof(
             self, oneof_group: typing_extensions.Literal["_plan_id", b"_plan_id"]
         ) -> typing_extensions.Literal["plan_id"] | None: ...
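On the wire, the new flag rides on Expression.UnresolvedAttribute. A minimal
sketch (not part of the commit) of the message that Dataset.metadataColumn
builds, using the generated protobuf classes in
`org.apache.spark.connect.proto`; the identifier value is illustrative:

    import org.apache.spark.connect.proto

    // Build the UnresolvedAttribute that metadataColumn emits; field names
    // follow expressions.proto above (unparsed_identifier, is_metadata_column).
    val attr = proto.Expression.UnresolvedAttribute
      .newBuilder()
      .setUnparsedIdentifier("_metadata")
      .setIsMetadataColumn(true)
      .build()

    // The optional field is explicitly present, so the server-side planner
    // sees hasIsMetadataColumn == true and tags the plan accordingly.
    assert(attr.hasIsMetadataColumn && attr.getIsMetadataColumn)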
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala
index b631d1fd8b6..56d1f5f3a10 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala
@@ -511,8 +511,15 @@ trait ColumnResolutionHelper extends Logging {
     }
     val plan = planOpt.get
+    val isMetadataAccess = u.getTagValue(LogicalPlan.IS_METADATA_COL).contains(true)
     try {
-      plan.resolve(u.nameParts, conf.resolver)
+      if (!isMetadataAccess) {
+        plan.resolve(u.nameParts, conf.resolver)
+      } else if (u.nameParts.size == 1) {
+        plan.getMetadataAttributeByNameOpt(u.nameParts.head)
+      } else {
+        None
+      }
     } catch {
       case e: AnalysisException =>
         logDebug(s"Fail to resolve $u with $plan due to $e")
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
index 374eb070db1..a57c9d7a162 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
@@ -196,6 +196,7 @@ object LogicalPlan {
   // 3, resolve this expression with the matching node. If any error occurs, analyzer fallbacks
   //    to the old code path.
   private[spark] val PLAN_ID_TAG = TreeNodeTag[Long]("plan_id")
+  private[spark] val IS_METADATA_COL = TreeNodeTag[Boolean]("is_metadata_col")
 }
 
 /**

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org