This is an automated email from the ASF dual-hosted git repository.

hvanhovell pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 80118e2c688 [SPARK-44807][CONNECT] Add Dataset.metadataColumn to Scala Client
80118e2c688 is described below

commit 80118e2c688cdeebf49925385dfec376079b003b
Author: Herman van Hovell <her...@databricks.com>
AuthorDate: Wed Aug 16 18:14:33 2023 +0200

    [SPARK-44807][CONNECT] Add Dataset.metadataColumn to Scala Client
    
    ### What changes were proposed in this pull request?
    This PR adds Dataset.metadataColumn to the Spark Connect Scala Client.
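    
    For example (a minimal usage sketch, assuming a Spark Connect `SparkSession` named `spark` and a Parquet directory at `path`; this mirrors the new `ClientE2ETestSuite` test below):
    ```scala
    val df = spark.read.parquet(path)
    // Read the Parquet _metadata column even though the data defines a column with the same name.
    val filePaths = df.select(df.metadataColumn("_metadata").getField("file_path"))
    ```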
    
    ### Why are the changes needed?
    We want the Scala client to be as compatible as possible with the API provided by sql/core.
    
    ### Does this PR introduce _any_ user-facing change?
    Yes, it adds a new method to the Spark Connect Scala Client.
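    
    The added signature (as it appears in the `Dataset.scala` hunk below):
    ```scala
    def metadataColumn(colName: String): Column
    ```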
    
    ### How was this patch tested?
    I added a test to `ClientE2ETestSuite`.
    
    Closes #42492 from hvanhovell/SPARK-44807.
    
    Authored-by: Herman van Hovell <her...@databricks.com>
    Signed-off-by: Herman van Hovell <her...@databricks.com>
---
 .../main/scala/org/apache/spark/sql/Dataset.scala  | 16 +++++
 .../org/apache/spark/sql/ClientE2ETestSuite.scala  | 22 +++++++
 .../CheckConnectJvmClientCompatibility.scala       |  1 -
 .../main/protobuf/spark/connect/expressions.proto  |  3 +
 .../sql/connect/planner/SparkConnectPlanner.scala  |  3 +
 .../pyspark/sql/connect/proto/expressions_pb2.py   | 68 +++++++++++-----------
 .../pyspark/sql/connect/proto/expressions_pb2.pyi  | 25 +++++++-
 .../catalyst/analysis/ColumnResolutionHelper.scala |  9 ++-
 .../sql/catalyst/plans/logical/LogicalPlan.scala   |  1 +
 9 files changed, 111 insertions(+), 37 deletions(-)

diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala
index cb7d2c84df5..3c89e649020 100644
--- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -1043,6 +1043,22 @@ class Dataset[T] private[sql] (
     Column.apply(colName, getPlanId)
   }
 
+  /**
+   * Selects a metadata column based on its logical column name, and returns it as a [[Column]].
+   *
+   * A metadata column can be accessed this way even if the underlying data source defines a data
+   * column with a conflicting name.
+   *
+   * @group untypedrel
+   * @since 3.5.0
+   */
+  def metadataColumn(colName: String): Column = Column { builder =>
+    val attributeBuilder = builder.getUnresolvedAttributeBuilder
+      .setUnparsedIdentifier(colName)
+      .setIsMetadataColumn(true)
+    getPlanId.foreach(attributeBuilder.setPlanId)
+  }
+
   /**
   * Selects column based on the column name specified as a regex and returns it as [[Column]].
    * @group untypedrel
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala
index 074cf170dd3..7b9b5f43e80 100644
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala
@@ -1203,6 +1203,28 @@ class ClientE2ETestSuite extends RemoteSparkSession with SQLHelper with PrivateM
       .dropDuplicatesWithinWatermark("newcol")
     testAndVerify(result2)
   }
+
+  test("Dataset.metadataColumn") {
+    val session: SparkSession = spark
+    import session.implicits._
+    withTempPath { file =>
+      val path = file.getAbsoluteFile.toURI.toString
+      spark
+        .range(0, 100, 1, 1)
+        .withColumn("_metadata", concat(lit("lol_"), col("id")))
+        .write
+        .parquet(file.toPath.toAbsolutePath.toString)
+
+      val df = spark.read.parquet(path)
+      val (filepath, rc) = df
+        .groupBy(df.metadataColumn("_metadata").getField("file_path"))
+        .count()
+        .as[(String, Long)]
+        .head()
+      assert(filepath.startsWith(path))
+      assert(rc == 100)
+    }
+  }
 }
 
 private[sql] case class ClassData(a: String, b: Int)
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala
index 8f226eb2f7e..3d7a80b1fb6 100644
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala
@@ -179,7 +179,6 @@ object CheckConnectJvmClientCompatibility {
       ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.ObservationListener$"),
       ProblemFilters.exclude[Problem]("org.apache.spark.sql.Dataset.queryExecution"),
       ProblemFilters.exclude[Problem]("org.apache.spark.sql.Dataset.sqlContext"),
-      ProblemFilters.exclude[Problem]("org.apache.spark.sql.Dataset.metadataColumn"),
       ProblemFilters.exclude[Problem]("org.apache.spark.sql.Dataset.selectUntyped"), // protected
       ProblemFilters.exclude[Problem]("org.apache.spark.sql.Dataset.rdd"),
       ProblemFilters.exclude[Problem]("org.apache.spark.sql.Dataset.toJavaRDD"),
diff --git a/connector/connect/common/src/main/protobuf/spark/connect/expressions.proto b/connector/connect/common/src/main/protobuf/spark/connect/expressions.proto
index b222f663cd0..4aac2bcc612 100644
--- a/connector/connect/common/src/main/protobuf/spark/connect/expressions.proto
+++ b/connector/connect/common/src/main/protobuf/spark/connect/expressions.proto
@@ -223,6 +223,9 @@ message Expression {
 
     // (Optional) The id of corresponding connect plan.
     optional int64 plan_id = 2;
+
+    // (Optional) The requested column is a metadata column.
+    optional bool is_metadata_column = 3;
   }
 
  // An unresolved function is not explicitly bound to one explicit function, but the function
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index 0a72d1f70c6..a7e0d9aa0c7 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -1404,6 +1404,9 @@ class SparkConnectPlanner(val sessionHolder: SessionHolder) extends Logging {
     if (attr.hasPlanId) {
       expr.setTagValue(LogicalPlan.PLAN_ID_TAG, attr.getPlanId)
     }
+    if (attr.hasIsMetadataColumn) {
+      expr.setTagValue(LogicalPlan.IS_METADATA_COL, attr.getIsMetadataColumn)
+    }
     expr
   }
 
diff --git a/python/pyspark/sql/connect/proto/expressions_pb2.py b/python/pyspark/sql/connect/proto/expressions_pb2.py
index 51ad47bb1c8..eb125fab39c 100644
--- a/python/pyspark/sql/connect/proto/expressions_pb2.py
+++ b/python/pyspark/sql/connect/proto/expressions_pb2.py
@@ -33,7 +33,7 @@ from pyspark.sql.connect.proto import types_pb2 as spark_dot_connect_dot_types__
 
 
 DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
-    b'\n\x1fspark/connect/expressions.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x19spark/connect/types.proto"\xbf,\n\nExpression\x12=\n\x07literal\x18\x01 \x01(\x0b\x32!.spark.connect.Expression.LiteralH\x00R\x07literal\x12\x62\n\x14unresolved_attribute\x18\x02 \x01(\x0b\x32-.spark.connect.Expression.UnresolvedAttributeH\x00R\x13unresolvedAttribute\x12_\n\x13unresolved_function\x18\x03 \x01(\x0b\x32,.spark.connect.Expression.UnresolvedFunctionH\x00R\x12unresolvedFunct [...]
+    b'\n\x1fspark/connect/expressions.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x19spark/connect/types.proto"\x8a-\n\nExpression\x12=\n\x07literal\x18\x01 \x01(\x0b\x32!.spark.connect.Expression.LiteralH\x00R\x07literal\x12\x62\n\x14unresolved_attribute\x18\x02 \x01(\x0b\x32-.spark.connect.Expression.UnresolvedAttributeH\x00R\x13unresolvedAttribute\x12_\n\x13unresolved_function\x18\x03 \x01(\x0b\x32,.spark.connect.Expression.UnresolvedFunctionH\x00R\x12unresolvedFunct [...]
 )
 
 _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals())
@@ -45,7 +45,7 @@ if _descriptor._USE_C_DESCRIPTORS == False:
         b"\n\036org.apache.spark.connect.protoP\001Z\022internal/generated"
     )
     _EXPRESSION._serialized_start = 105
-    _EXPRESSION._serialized_end = 5800
+    _EXPRESSION._serialized_end = 5875
     _EXPRESSION_WINDOW._serialized_start = 1645
     _EXPRESSION_WINDOW._serialized_end = 2428
     _EXPRESSION_WINDOW_WINDOWFRAME._serialized_start = 1935
@@ -74,36 +74,36 @@ if _descriptor._USE_C_DESCRIPTORS == False:
     _EXPRESSION_LITERAL_MAP._serialized_end = 4422
     _EXPRESSION_LITERAL_STRUCT._serialized_start = 4425
     _EXPRESSION_LITERAL_STRUCT._serialized_end = 4554
-    _EXPRESSION_UNRESOLVEDATTRIBUTE._serialized_start = 4572
-    _EXPRESSION_UNRESOLVEDATTRIBUTE._serialized_end = 4684
-    _EXPRESSION_UNRESOLVEDFUNCTION._serialized_start = 4687
-    _EXPRESSION_UNRESOLVEDFUNCTION._serialized_end = 4891
-    _EXPRESSION_EXPRESSIONSTRING._serialized_start = 4893
-    _EXPRESSION_EXPRESSIONSTRING._serialized_end = 4943
-    _EXPRESSION_UNRESOLVEDSTAR._serialized_start = 4945
-    _EXPRESSION_UNRESOLVEDSTAR._serialized_end = 5027
-    _EXPRESSION_UNRESOLVEDREGEX._serialized_start = 5029
-    _EXPRESSION_UNRESOLVEDREGEX._serialized_end = 5115
-    _EXPRESSION_UNRESOLVEDEXTRACTVALUE._serialized_start = 5118
-    _EXPRESSION_UNRESOLVEDEXTRACTVALUE._serialized_end = 5250
-    _EXPRESSION_UPDATEFIELDS._serialized_start = 5253
-    _EXPRESSION_UPDATEFIELDS._serialized_end = 5440
-    _EXPRESSION_ALIAS._serialized_start = 5442
-    _EXPRESSION_ALIAS._serialized_end = 5562
-    _EXPRESSION_LAMBDAFUNCTION._serialized_start = 5565
-    _EXPRESSION_LAMBDAFUNCTION._serialized_end = 5723
-    _EXPRESSION_UNRESOLVEDNAMEDLAMBDAVARIABLE._serialized_start = 5725
-    _EXPRESSION_UNRESOLVEDNAMEDLAMBDAVARIABLE._serialized_end = 5787
-    _COMMONINLINEUSERDEFINEDFUNCTION._serialized_start = 5803
-    _COMMONINLINEUSERDEFINEDFUNCTION._serialized_end = 6167
-    _PYTHONUDF._serialized_start = 6170
-    _PYTHONUDF._serialized_end = 6325
-    _SCALARSCALAUDF._serialized_start = 6328
-    _SCALARSCALAUDF._serialized_end = 6512
-    _JAVAUDF._serialized_start = 6515
-    _JAVAUDF._serialized_end = 6664
-    _CALLFUNCTION._serialized_start = 6666
-    _CALLFUNCTION._serialized_end = 6774
-    _NAMEDARGUMENTEXPRESSION._serialized_start = 6776
-    _NAMEDARGUMENTEXPRESSION._serialized_end = 6868
+    _EXPRESSION_UNRESOLVEDATTRIBUTE._serialized_start = 4573
+    _EXPRESSION_UNRESOLVEDATTRIBUTE._serialized_end = 4759
+    _EXPRESSION_UNRESOLVEDFUNCTION._serialized_start = 4762
+    _EXPRESSION_UNRESOLVEDFUNCTION._serialized_end = 4966
+    _EXPRESSION_EXPRESSIONSTRING._serialized_start = 4968
+    _EXPRESSION_EXPRESSIONSTRING._serialized_end = 5018
+    _EXPRESSION_UNRESOLVEDSTAR._serialized_start = 5020
+    _EXPRESSION_UNRESOLVEDSTAR._serialized_end = 5102
+    _EXPRESSION_UNRESOLVEDREGEX._serialized_start = 5104
+    _EXPRESSION_UNRESOLVEDREGEX._serialized_end = 5190
+    _EXPRESSION_UNRESOLVEDEXTRACTVALUE._serialized_start = 5193
+    _EXPRESSION_UNRESOLVEDEXTRACTVALUE._serialized_end = 5325
+    _EXPRESSION_UPDATEFIELDS._serialized_start = 5328
+    _EXPRESSION_UPDATEFIELDS._serialized_end = 5515
+    _EXPRESSION_ALIAS._serialized_start = 5517
+    _EXPRESSION_ALIAS._serialized_end = 5637
+    _EXPRESSION_LAMBDAFUNCTION._serialized_start = 5640
+    _EXPRESSION_LAMBDAFUNCTION._serialized_end = 5798
+    _EXPRESSION_UNRESOLVEDNAMEDLAMBDAVARIABLE._serialized_start = 5800
+    _EXPRESSION_UNRESOLVEDNAMEDLAMBDAVARIABLE._serialized_end = 5862
+    _COMMONINLINEUSERDEFINEDFUNCTION._serialized_start = 5878
+    _COMMONINLINEUSERDEFINEDFUNCTION._serialized_end = 6242
+    _PYTHONUDF._serialized_start = 6245
+    _PYTHONUDF._serialized_end = 6400
+    _SCALARSCALAUDF._serialized_start = 6403
+    _SCALARSCALAUDF._serialized_end = 6587
+    _JAVAUDF._serialized_start = 6590
+    _JAVAUDF._serialized_end = 6739
+    _CALLFUNCTION._serialized_start = 6741
+    _CALLFUNCTION._serialized_end = 6849
+    _NAMEDARGUMENTEXPRESSION._serialized_start = 6851
+    _NAMEDARGUMENTEXPRESSION._serialized_end = 6943
 # @@protoc_insertion_point(module_scope)
diff --git a/python/pyspark/sql/connect/proto/expressions_pb2.pyi b/python/pyspark/sql/connect/proto/expressions_pb2.pyi
index 2b418ef23f6..b590d22da2c 100644
--- a/python/pyspark/sql/connect/proto/expressions_pb2.pyi
+++ b/python/pyspark/sql/connect/proto/expressions_pb2.pyi
@@ -750,33 +750,56 @@ class Expression(google.protobuf.message.Message):
 
         UNPARSED_IDENTIFIER_FIELD_NUMBER: builtins.int
         PLAN_ID_FIELD_NUMBER: builtins.int
+        IS_METADATA_COLUMN_FIELD_NUMBER: builtins.int
         unparsed_identifier: builtins.str
         """(Required) An identifier that will be parsed by Catalyst parser. 
This should follow the
         Spark SQL identifier syntax.
         """
         plan_id: builtins.int
         """(Optional) The id of corresponding connect plan."""
+        is_metadata_column: builtins.bool
+        """(Optional) The requested column is a metadata column."""
         def __init__(
             self,
             *,
             unparsed_identifier: builtins.str = ...,
             plan_id: builtins.int | None = ...,
+            is_metadata_column: builtins.bool | None = ...,
         ) -> None: ...
         def HasField(
             self,
-            field_name: typing_extensions.Literal["_plan_id", b"_plan_id", "plan_id", b"plan_id"],
+            field_name: typing_extensions.Literal[
+                "_is_metadata_column",
+                b"_is_metadata_column",
+                "_plan_id",
+                b"_plan_id",
+                "is_metadata_column",
+                b"is_metadata_column",
+                "plan_id",
+                b"plan_id",
+            ],
         ) -> builtins.bool: ...
         def ClearField(
             self,
             field_name: typing_extensions.Literal[
+                "_is_metadata_column",
+                b"_is_metadata_column",
                 "_plan_id",
                 b"_plan_id",
+                "is_metadata_column",
+                b"is_metadata_column",
                 "plan_id",
                 b"plan_id",
                 "unparsed_identifier",
                 b"unparsed_identifier",
             ],
         ) -> None: ...
+        @typing.overload
+        def WhichOneof(
+            self,
+            oneof_group: typing_extensions.Literal["_is_metadata_column", b"_is_metadata_column"],
+        ) -> typing_extensions.Literal["is_metadata_column"] | None: ...
+        @typing.overload
         def WhichOneof(
            self, oneof_group: typing_extensions.Literal["_plan_id", b"_plan_id"]
         ) -> typing_extensions.Literal["plan_id"] | None: ...
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala
index b631d1fd8b6..56d1f5f3a10 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala
@@ -511,8 +511,15 @@ trait ColumnResolutionHelper extends Logging {
     }
     val plan = planOpt.get
 
+    val isMetadataAccess = u.getTagValue(LogicalPlan.IS_METADATA_COL).contains(true)
     try {
-      plan.resolve(u.nameParts, conf.resolver)
+      if (!isMetadataAccess) {
+        plan.resolve(u.nameParts, conf.resolver)
+      } else if (u.nameParts.size == 1) {
+        plan.getMetadataAttributeByNameOpt(u.nameParts.head)
+      } else {
+        None
+      }
     } catch {
       case e: AnalysisException =>
         logDebug(s"Fail to resolve $u with $plan due to $e")
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
index 374eb070db1..a57c9d7a162 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
@@ -196,6 +196,7 @@ object LogicalPlan {
   //    3, resolve this expression with the matching node. If any error occurs, analyzer fallbacks
   //    to the old code path.
   private[spark] val PLAN_ID_TAG = TreeNodeTag[Long]("plan_id")
+  private[spark] val IS_METADATA_COL = TreeNodeTag[Boolean]("is_metadata_col")
 }
 
 /**

