This is an automated email from the ASF dual-hosted git repository.
yihua pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/hudi.git
The following commit(s) were added to refs/heads/master by this push:
new dca76ca279fc fix: Support data pruning using nested partition columns (#18126)
dca76ca279fc is described below
commit dca76ca279fc98c18bb6900e8de510b5d159cfc1
Author: Lin Liu <[email protected]>
AuthorDate: Fri May 15 14:22:50 2026 -0700
fix: Support data pruning using nested partition columns (#18126)
Co-authored-by: Vinish Reddy <[email protected]>
Co-authored-by: Y Ethan Guo <[email protected]>
---
.../scala/org/apache/hudi/HoodieFileIndex.scala | 122 +++++++--
.../apache/hudi/SparkHoodieTableFileIndex.scala | 165 +++++++++---
.../org/apache/hudi/TestHoodieFileIndex.scala | 94 ++++++-
.../apache/hudi/functional/TestCOWDataSource.scala | 296 ++++++++++++++++++++-
.../apache/hudi/functional/TestMORDataSource.scala | 5 +
.../Spark3HoodiePruneFileSourcePartitions.scala | 18 +-
.../Spark33HoodiePruneFileSourcePartitions.scala | 2 +-
.../Spark4HoodiePruneFileSourcePartitions.scala | 16 +-
.../hudi/utilities/HiveIncrementalPuller.java | 22 +-
9 files changed, 671 insertions(+), 69 deletions(-)
diff --git
a/hudi-spark-datasource/hudi-spark-common/src/main/scala/org/apache/hudi/HoodieFileIndex.scala
b/hudi-spark-datasource/hudi-spark-common/src/main/scala/org/apache/hudi/HoodieFileIndex.scala
index 660592cf791e..2187e5ab6903 100644
---
a/hudi-spark-datasource/hudi-spark-common/src/main/scala/org/apache/hudi/HoodieFileIndex.scala
+++
b/hudi-spark-datasource/hudi-spark-common/src/main/scala/org/apache/hudi/HoodieFileIndex.scala
@@ -37,7 +37,7 @@ import org.apache.spark.internal.Logging
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.{InternalRow, TableIdentifier}
import org.apache.spark.sql.catalyst.catalog.HoodieCatalogTable
-import org.apache.spark.sql.catalyst.expressions.{Expression, Literal}
+import org.apache.spark.sql.catalyst.expressions.{AttributeReference,
Expression, GetStructField, Literal}
import org.apache.spark.sql.execution.datasources.{FileIndex, FileStatusCache,
NoopCache, PartitionDirectory}
import org.apache.spark.sql.hudi.HoodieSqlCommonUtils
import org.apache.spark.sql.internal.SQLConf
@@ -105,6 +105,10 @@ case class HoodieFileIndex(spark: SparkSession,
@transient protected var hasPushedDownPartitionPredicates: Boolean = false
+ /**
+  * Whether at least one partition column is a nested (dot-separated) field path,
+  * such as "nested_record.level". Evaluated once when the index is constructed.
+  */
+ private val hasNestedPartitionColumns: Boolean =
+   getPartitionColumns.exists(column => column.contains("."))
+
/**
* NOTE: [[indicesSupport]] is a transient state, since it's only relevant
while logical plan
* is handled by the Spark's driver
@@ -167,19 +171,44 @@ case class HoodieFileIndex(spark: SparkSession,
/**
* Invoked by Spark to fetch list of latest base files per partition.
*
- * @param partitionFilters partition column filters
- * @param dataFilters data columns filters
- * @return list of PartitionDirectory containing partition to base files mapping
+ * For regular partition columns, Spark passes correct `partitionFilters` directly.
+ *
+ * For nested partition columns (e.g. `nested_record.level`), Spark cannot match
+ * [[GetStructField]] expressions against the flat dot-path partition schema, so the nested
+ * predicates land in `dataFilters` instead of `partitionFilters`. Whenever the table has nested
+ * partition columns, we re-extract them via [[extractNestedPartitionFilters]] and append the
+ * ones Spark did not already supply to `partitionFilters`.
+ *
+ * Example: `SELECT * FROM t WHERE nested_record.level = 'INFO' AND int_field > 0`
+ * - Spark passes: `partitionFilters = []`, `dataFilters = [nested_record.level = 'INFO', int_field > 0]`
+ * - We use: `effectivePartitionFilters = [nested_record.level = 'INFO']`
+ *
+ * Mixed flat+nested partitions (e.g. `["country", "nested_record.level"]`) are handled the same
+ * way: if Spark passes `partitionFilters = [country = 'US']`, the nested predicate found in
+ * `dataFilters` is appended, so both take part in partition pruning.
+ *
+ * This is stateless — safe under AQE re-planning, subqueries, and FileIndex reuse.
+ *
+ * @param partitionFilters partition column filters supplied by Spark
+ * @param dataFilters data column filters (used for data skipping; also scanned for nested partition predicates)
+ * @return list of PartitionDirectory containing partition to base files mapping
*/
override def listFiles(partitionFilters: Seq[Expression], dataFilters: Seq[Expression]): Seq[PartitionDirectory] = {
- val slices = filterFileSlices(dataFilters, partitionFilters).flatMap(
+   val effectivePartitionFilters = if (hasNestedPartitionColumns) {
+     // Keep Spark's own partition filters and add the nested-partition predicates it left in
+     // dataFilters (skipping exact duplicates), so they also prune partitions.
+     val nestedPartitionFilters = extractNestedPartitionFilters(dataFilters)
+       .filterNot(filter => partitionFilters.contains(filter))
+     partitionFilters ++ nestedPartitionFilters
+   } else {
+     partitionFilters
+   }
+
+   val slices = filterFileSlices(dataFilters, effectivePartitionFilters).flatMap(
{ case (partitionOpt, fileSlices) =>
- fileSlices.filter(!_.isEmpty).map(fs => ( InternalRow.fromSeq(partitionOpt.get.getValues), fs))
+       fileSlices.filter(!_.isEmpty).map(fs => (InternalRow.fromSeq(partitionOpt.get.getValues), fs))
}
)
prepareFileSlices(slices)
}
+ /**
+  * Picks the nested-partition predicates out of `dataFilters`, using this index's partition
+  * columns; see [[HoodieFileIndex.extractNestedPartitionFilters]].
+  */
+ private def extractNestedPartitionFilters(dataFilters: Seq[Expression]): Seq[Expression] = {
+   val partitionColumnNames = getPartitionColumns.toSet
+   HoodieFileIndex.extractNestedPartitionFilters(dataFilters, partitionColumnNames)
+ }
+
protected def prepareFileSlices(slices: Seq[(InternalRow, FileSlice)]):
Seq[PartitionDirectory] = {
hasPushedDownPartitionPredicates = true
@@ -212,25 +241,25 @@ case class HoodieFileIndex(spark: SparkSession,
}
/**
- * The functions prunes the partition paths based on the input partition
filters. For every partition path, the file
- * slices are further filtered after querying metadata table based on the
data filters.
+ * Prunes partitions by `partitionFilters`, then optionally applies data
skipping via metadata
+ * table indices (column stats, record-level index, etc.) to filter file
slices.
*
- * @param dataFilters data columns filters
- * @param partitionFilters partition column filters
- * @param partitionPrune for HoodiePruneFileSourcePartitions rule only prune
partitions
- * @return A sequence of pruned partitions and corresponding filtered file
slices
+ * @param dataFilters data column filters (used for data skipping)
+ * @param partitionFilters partition column filters (used for partition
pruning)
+ * @param isPartitionPruneOnly when true, skip data skipping. Used by
[[HoodiePruneFileSourcePartitions]]
+ * during planning (data skipping runs later in
[[listFiles]]).
*/
- def filterFileSlices(dataFilters: Seq[Expression], partitionFilters:
Seq[Expression], isPartitionPruned: Boolean = false)
+ def filterFileSlices(dataFilters: Seq[Expression], partitionFilters:
Seq[Expression],
+ isPartitionPruneOnly: Boolean = false)
: Seq[(Option[BaseHoodieTableFileIndex.PartitionPath], Seq[FileSlice])] = {
val (isPruned, prunedPartitionsAndFileSlices) =
prunePartitionsAndGetFileSlices(dataFilters, partitionFilters)
hasPushedDownPartitionPredicates = true
- // If there are no data filters, return all the file slices.
- // If isPartitionPurge is true, this fun is trigger by
HoodiePruneFileSourcePartitions, don't look up candidate files
- // If there are no file slices, return empty list.
- if (prunedPartitionsAndFileSlices.isEmpty || dataFilters.isEmpty ||
isPartitionPruned ) {
+ // Skip data skipping when: no file slices, no data filters, or
partition-prune-only mode
+ // (planning phase — data skipping runs later during execution).
+ if (prunedPartitionsAndFileSlices.isEmpty || dataFilters.isEmpty ||
isPartitionPruneOnly) {
prunedPartitionsAndFileSlices
} else {
// Look up candidate files names in the col-stats or record level index,
if all of the following conditions are true
@@ -502,6 +531,65 @@ object HoodieFileIndex extends Logging {
val Strict: Val = Val("strict")
}
+ /**
+  * Selects the filters in `dataFilters` that constrain only (nested) partition columns.
+  *
+  * Every outermost [[GetStructField]] chain in a filter is resolved to its full dot-path and
+  * compared against the partition column names. Matching on the struct root alone is not enough,
+  * because sibling fields share it (e.g. `nested_record.level` and `nested_record.nested_int`
+  * both hang off `nested_record`).
+  *
+  * Given partition column `nested_record.level` and
+  * {{{
+  * dataFilters = [nested_record.level = 'INFO', nested_record.nested_int > 0, int_field = 5]
+  * }}}
+  * the result is `[nested_record.level = 'INFO']`.
+  *
+  * Known limitations vs regular partition columns:
+  *  - `(nested_record.level = 'INFO' AND d = 2) OR (nested_record.level = 'ERROR')` is dropped
+  *    entirely because it references both partition and data columns. A weaker predicate such as
+  *    `nested_record.level IN ('INFO', 'ERROR')` could be derived but is not. Spark has the same
+  *    OR limitation for regular partition columns.
+  *
+  * @param dataFilters          filters to scan for nested partition predicates
+  * @param partitionColumnNames partition column dot-paths, e.g. `Set("nested_record.level")`
+  * @return the filters that access at least one nested partition column and whose struct-field
+  *         paths and attribute references all belong to partition columns
+  */
+ private[hudi] def extractNestedPartitionFilters(dataFilters: Seq[Expression],
+                                                 partitionColumnNames: Set[String]): Seq[Expression] = {
+   // Top-level attribute names the partition columns live under, e.g. "nested_record".
+   val rootAttributeNames = partitionColumnNames.map(path => path.split("\\.", 2).head)
+
+   def isNestedPartitionFilter(filter: Expression): Boolean = {
+     val fieldPaths = collectOutermostStructFieldPaths(filter)
+     // Requires at least one nested access and nothing outside the partition columns: no
+     // non-partition struct paths and no attributes other than partition column roots
+     // (rules out mixes like "nested_record.level = 'INFO' AND int_field > 0").
+     fieldPaths.nonEmpty &&
+       fieldPaths.forall(path => partitionColumnNames.contains(path)) &&
+       filter.references.forall(attr => rootAttributeNames.contains(attr.name))
+   }
+
+   dataFilters.filter(isNestedPartitionFilter)
+ }
+
+ /**
+  * Returns the full dot-paths of the outermost [[GetStructField]] chains found in `expr`.
+  * Only the outermost access of a chain is reported: `EqualTo(a.b.c, 1)` yields `Seq("a.b.c")`,
+  * never the intermediate `"a.b"`. A chain that [[resolveGetStructFieldPath]] cannot resolve
+  * contributes nothing, and its subtree is not searched further.
+  */
+ private[hudi] def collectOutermostStructFieldPaths(expr: Expression): Seq[String] = expr match {
+   case structAccess: GetStructField =>
+     resolveGetStructFieldPath(structAccess).toList
+   case other =>
+     other.children.flatMap(child => collectOutermostStructFieldPaths(child))
+ }
+
+ /**
+  * Resolves a [[GetStructField]] chain rooted at an [[AttributeReference]] to its full dot-path:
+  * `attr("a").b.c` → `"a.b.c"`.
+  *
+  * Each segment comes from [[GetStructField.extractFieldName]]: the explicit `name` when present,
+  * otherwise the name of the struct field at `ordinal`. Without that fallback, an unnamed access
+  * would be invisible to [[collectOutermostStructFieldPaths]]. A filter that also reads a
+  * non-partition sibling field could then be misclassified by [[extractNestedPartitionFilters]].
+  *
+  * @return the dot-path, or `None` when the chain is not rooted at an attribute reference
+  */
+ private[hudi] def resolveGetStructFieldPath(expr: Expression): Option[String] = expr match {
+   case access @ GetStructField(child: AttributeReference, _, _) =>
+     Some(child.name + "." + access.extractFieldName)
+   case access @ GetStructField(child: GetStructField, _, _) =>
+     resolveGetStructFieldPath(child).map(_ + "." + access.extractFieldName)
+   case _ => None
+ }
+
def collectReferencedColumns(spark: SparkSession, queryFilters:
Seq[Expression], schema: StructType): Seq[String] = {
val resolver = spark.sessionState.analyzer.resolver
val refs = queryFilters.flatMap(_.references)
diff --git
a/hudi-spark-datasource/hudi-spark-common/src/main/scala/org/apache/hudi/SparkHoodieTableFileIndex.scala
b/hudi-spark-datasource/hudi-spark-common/src/main/scala/org/apache/hudi/SparkHoodieTableFileIndex.scala
index 1ba9628af3b7..b80eb204823a 100644
---
a/hudi-spark-datasource/hudi-spark-common/src/main/scala/org/apache/hudi/SparkHoodieTableFileIndex.scala
+++
b/hudi-spark-datasource/hudi-spark-common/src/main/scala/org/apache/hudi/SparkHoodieTableFileIndex.scala
@@ -46,11 +46,11 @@ import org.apache.spark.api.java.JavaSparkContext
import org.apache.spark.internal.Logging
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.{expressions, InternalRow}
-import org.apache.spark.sql.catalyst.expressions.{AttributeReference,
BasePredicate, BoundReference, EmptyRow, EqualTo, Expression,
InterpretedPredicate, Literal}
+import org.apache.spark.sql.catalyst.expressions.{AttributeReference,
BasePredicate, BoundReference, EmptyRow, EqualTo, Expression, GetStructField,
Literal}
import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.execution.datasources.{FileStatusCache, NoopCache}
import org.apache.spark.sql.internal.SQLConf
-import org.apache.spark.sql.types.{ByteType, DateType, IntegerType, LongType,
ShortType, StringType, StructField, StructType}
+import org.apache.spark.sql.types.{ByteType, DataType, DateType, IntegerType,
LongType, ShortType, StringType, StructField, StructType}
import org.slf4j.LoggerFactory
import javax.annotation.concurrent.NotThreadSafe
@@ -59,6 +59,7 @@ import java.lang.reflect.{Array => JArray}
import java.util.Collections
import scala.collection.JavaConverters._
+import scala.collection.mutable.LinkedHashMap
import scala.language.implicitConversions
import scala.util.{Success, Try}
@@ -201,6 +202,25 @@ class SparkHoodieTableFileIndex(spark: SparkSession,
}
}
+ /**
+  * Partition schema as Spark should see it, with nested partition columns rebuilt as nested structs.
+  *
+  * NOTE: Hudi's [[partitionSchema]] deliberately exposes a *flat* schema whose field names are full
+  * dot-paths (e.g. "a.b.c"), which keeps them from colliding with top-level data columns. Parts of
+  * Spark's planner/analyzer, however, treat nested columns as nested [[StructType]]s and need that
+  * shape to resolve [[GetStructField]] chains.
+  *
+  * Returns an empty schema when the table is not read as a partitioned table; otherwise the nested
+  * schema produced by [[SparkHoodieTableFileIndex.buildNestedPartitionSchema]] from the flat
+  * partition schema (same leaf data types, field order taken from the flat schema).
+  */
+ def partitionSchemaForSpark: StructType =
+   if (shouldReadAsPartitionedTable) {
+     SparkHoodieTableFileIndex.buildNestedPartitionSchema(_partitionSchemaFromProperties)
+   } else {
+     new StructType()
+   }
+
/**
* Fetch list of latest base files w/ corresponding log files, after
performing
* partition pruning
@@ -238,13 +258,32 @@ class SparkHoodieTableFileIndex(spark: SparkSession,
def listMatchingPartitionPaths(predicates: Seq[Expression]):
Seq[PartitionPath] = {
val resolve = spark.sessionState.analyzer.resolver
val partitionColumnNames = getPartitionColumns
- val partitionPruningPredicates = predicates.filter {
- _.references.map(_.name).forall { ref =>
- // NOTE: We're leveraging Spark's resolver here to appropriately
handle case-sensitivity
- partitionColumnNames.exists(partCol => resolve(ref, partCol))
- }
+
+ // Resolves GetStructField chain to full dot-path:
GetStructField(attr("a"), _, "b") → "a.b"
+ def getFieldPath(expr: Expression): Option[String] = expr match {
+ case a: AttributeReference => Some(a.name)
+ case GetStructField(child, _, Some(fieldName)) =>
+ getFieldPath(child).map(_ + "." + fieldName)
+ case _ => None
}
+ // True if every column reference in expr resolves to a partition column.
+ // For nested columns, walks GetStructField chains to match the full
dot-path.
+ // Example: partition = "nested_record.level"
+ // nested_record.level = 'INFO' → GetStructField path
"nested_record.level" → true
+ // nested_record.nested_int = 10 → GetStructField path
"nested_record.nested_int" → false
+ // IsNotNull(nested_record) → AttributeReference "nested_record"
not in partitionColumnNames → false
+ def referencesOnlyPartitionColumns(expr: Expression): Boolean = expr match
{
+ case g: GetStructField =>
+ getFieldPath(g).exists(path => partitionColumnNames.exists(pc =>
resolve(path, pc)))
+ case a: AttributeReference =>
+ partitionColumnNames.exists(pc => resolve(a.name, pc))
+ case _ =>
+ expr.children.forall(referencesOnlyPartitionColumns)
+ }
+
+ val partitionPruningPredicates =
predicates.filter(referencesOnlyPartitionColumns)
+
if (partitionPruningPredicates.isEmpty) {
val queryPartitionPaths = getAllQueryPartitionPaths.asScala.toSeq
logInfo(s"No partition predicates provided, listing full table
(${queryPartitionPaths.size} partitions)")
@@ -269,10 +308,18 @@ class SparkHoodieTableFileIndex(spark: SparkSession,
// the whole table
if (haveProperPartitionValues(partitionPaths.toSeq) &&
partitionSchema.nonEmpty) {
val predicate = partitionPruningPredicates.reduce(expressions.And)
+ val partitionFieldNames = partitionSchema.fieldNames
val transformedPredicate = predicate.transform {
+ case g @ GetStructField(_, _, Some(_)) =>
+ getFieldPath(g).flatMap { path =>
+ val idx = partitionFieldNames.indexWhere(name => resolve(path,
name))
+ if (idx >= 0) Some(BoundReference(idx,
partitionSchema(idx).dataType, nullable = true))
+ else None
+ }.getOrElse(g)
case a: AttributeReference =>
- val index = partitionSchema.indexWhere(a.name == _.name)
- BoundReference(index, partitionSchema(index).dataType, nullable =
true)
+ val index = partitionSchema.indexWhere(sf => resolve(a.name,
sf.name))
+ if (index >= 0) BoundReference(index,
partitionSchema(index).dataType, nullable = true)
+ else a
}
val boundPredicate: BasePredicate = try {
// Try using 1-arg constructor via reflection
@@ -488,6 +535,76 @@ object SparkHoodieTableFileIndex extends
SparkAdapterSupport {
private val LOG = LoggerFactory.getLogger(classOf[SparkHoodieTableFileIndex])
private val PUT_LEAF_FILES_METHOD_NAME = "putLeafFiles"
+ /**
+  * Tree node used by `buildNestedPartitionSchema`: `leafType` holds the data type of a leaf field,
+  * and `children` is a mutable, insertion-ordered map of nested fields keyed by name.
+  */
+ private case class NestedFieldNode(leafType: Option[DataType],
+                                    children: LinkedHashMap[String, NestedFieldNode])
+
+ /**
+  * Rebuilds a nested partition schema from a flat one whose field names are dot-paths.
+  *
+  * Flat fields ["a.b": int, "a.c": string, "d": long] become:
+  * {{{
+  * StructType(
+  *   StructField("a", StructType(StructField("b", int), StructField("c", string))),
+  *   StructField("d", long)
+  * )
+  * }}}
+  * At every level, fields keep the order in which they first appear in the flat schema. Leaf data
+  * types are carried over, and every produced field is marked nullable.
+  *
+  * @throws IllegalStateException (raised through `checkState`) when a path has an empty segment,
+  *                               when a name is used both as a leaf and as a struct, or when the
+  *                               same leaf path appears with different data types
+  */
+ private[hudi] def buildNestedPartitionSchema(flatPartitionSchema: StructType): StructType = {
+   if (flatPartitionSchema.isEmpty) {
+     new StructType()
+   } else {
+     val root = NestedFieldNode(None, LinkedHashMap.empty)
+
+     // Walks (creating nodes as needed) the path of one flat field and records its type at the leaf.
+     def insertField(field: StructField): Unit = {
+       val segments = field.name.split("\\.", -1)
+       checkState(segments.forall(p => p.nonEmpty),
+         s"Invalid partition field path '${field.name}' in partition schema")
+       val leafIndex = segments.length - 1
+       segments.zipWithIndex.foldLeft(root) { case (parent, (segment, i)) =>
+         val child = parent.children.getOrElseUpdate(segment, NestedFieldNode(None, LinkedHashMap.empty))
+         if (i == leafIndex) {
+           checkState(child.children.isEmpty,
+             s"Conflicting partition schema: '${field.name}' collides with nested fields under '${segments.take(i + 1).mkString(".")}'")
+           checkState(child.leafType.isEmpty || child.leafType.contains(field.dataType),
+             s"Conflicting partition schema: '${field.name}' has inconsistent types (${child.leafType.orNull} vs ${field.dataType})")
+           val leaf = child.copy(leafType = Some(field.dataType))
+           parent.children.update(segment, leaf)
+           leaf
+         } else {
+           checkState(child.leafType.isEmpty,
+             s"Conflicting partition schema: '${field.name}' requires struct at '${segments.take(i + 1).mkString(".")}', but a leaf is defined")
+           child
+         }
+       }
+     }
+
+     flatPartitionSchema.fields.foreach(insertField)
+
+     // Leaves become typed fields and interior nodes become nested structs; all fields are nullable.
+     def toStructType(node: NestedFieldNode): StructType =
+       StructType(node.children.toArray.map { case (name, child) =>
+         val dataType = child.leafType match {
+           case Some(dt) if child.children.isEmpty => dt
+           case _ => toStructType(child)
+         }
+         StructField(name, dataType, nullable = true)
+       })
+
+     toStructType(root)
+   }
+ }
+
private def haveProperPartitionValues(partitionPaths: Seq[PartitionPath]) = {
partitionPaths.forall(_.getValues.length > 0)
}
@@ -520,27 +637,10 @@ object SparkHoodieTableFileIndex extends
SparkAdapterSupport {
}
/**
- * This method unravels [[StructType]] into a [[Map]] of pairs of dot-path
notation with corresponding
- * [[StructField]] object for every field of the provided [[StructType]],
recursively.
- *
- * For example, following struct
- * <pre>
- * StructType(
- * StructField("a",
- * StructType(
- * StructField("b", StringType),
- * StructField("c", IntType)
- * )
- * )
- * )
- * </pre>
- *
- * will be converted into following mapping:
- *
- * <pre>
- * "a.b" -> StructField("b", StringType),
- * "a.c" -> StructField("c", IntType),
- * </pre>
+ * Maps every leaf field in `structType` to its dot-path name.
+ * Both the key and [[StructField.name]] use the full path.
+ * E.g. `StructType(StructField("a", StructType(StructField("b",
IntegerType))))`
+ * → `Map("a.b" -> StructField("a.b", IntegerType))`.
*/
private def generateFieldMap(structType: StructType) : Map[String,
StructField] = {
def traverse(structField: Either[StructField, StructType]) : Map[String,
StructField] = {
@@ -548,7 +648,10 @@ object SparkHoodieTableFileIndex extends
SparkAdapterSupport {
case Right(struct) => struct.fields.flatMap(f =>
traverse(Left(f))).toMap
case Left(field) => field.dataType match {
case struct: StructType => traverse(Right(struct)).map {
- case (key, structField) => (s"${field.name}.$key", structField)
+ case (key, structField) => {
+ val fullPath = s"${field.name}.$key"
+ (fullPath, structField.copy(name = fullPath))
+ }
}
case _ => Map(field.name -> field)
}
diff --git
a/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/hudi/TestHoodieFileIndex.scala
b/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/hudi/TestHoodieFileIndex.scala
index d5e0c6a927ac..8d06e257d178 100644
---
a/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/hudi/TestHoodieFileIndex.scala
+++
b/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/hudi/TestHoodieFileIndex.scala
@@ -44,14 +44,14 @@ import org.apache.hudi.testutils.HoodieSparkClientTestBase
import org.apache.hudi.util.JFunction
import org.apache.spark.sql._
-import org.apache.spark.sql.catalyst.expressions.{And, AttributeReference,
EqualTo, GreaterThanOrEqual, LessThan, Literal}
+import org.apache.spark.sql.catalyst.expressions.{And, AttributeReference,
EqualTo, Expression, GetStructField, GreaterThanOrEqual, LessThan, Literal, Or}
import org.apache.spark.sql.execution.datasources.{NoopCache,
PartitionDirectory}
import org.apache.spark.sql.functions.{lit, struct}
import org.apache.spark.sql.hudi.HoodieSparkSessionExtension
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
import org.junit.jupiter.api.{BeforeEach, Test}
-import org.junit.jupiter.api.Assertions.{assertEquals, assertTrue}
+import org.junit.jupiter.api.Assertions.{assertEquals, assertThrows,
assertTrue}
import org.junit.jupiter.params.ParameterizedTest
import org.junit.jupiter.params.provider.{Arguments, CsvSource, MethodSource,
ValueSource}
@@ -858,6 +858,35 @@ class TestHoodieFileIndex extends
HoodieSparkClientTestBase with ScalaAssertionS
partitionValues.mkString(StoragePath.SEPARATOR)
}
}
+
+ // ---- buildNestedPartitionSchema tests ----
+
+ /** Checks each flat → nested schema conversion case from `buildNestedPartitionSchemaCases`. */
+ @ParameterizedTest
+ @MethodSource(Array("buildNestedPartitionSchemaCases"))
+ def testBuildNestedPartitionSchema(name: String, flat: StructType, expected: StructType): Unit = {
+   val actual = SparkHoodieTableFileIndex.buildNestedPartitionSchema(flat)
+   assertEquals(expected, actual)
+ }
+
+ @Test
+ def testBuildNestedPartitionSchemaConflictThrows(): Unit = {
+   // "a" is declared both as a leaf and as the parent struct of "a.b"; this must be rejected.
+   val conflicting = StructType(Seq(
+     StructField("a", StringType),
+     StructField("a.b", IntegerType)))
+   assertThrows(classOf[IllegalStateException]) {
+     SparkHoodieTableFileIndex.buildNestedPartitionSchema(conflicting)
+   }
+ }
+
+ // ---- extractNestedPartitionFilters tests ----
+
+ /** Checks each case from `extractNestedPartitionFiltersCases` against the companion helper. */
+ @ParameterizedTest
+ @MethodSource(Array("extractNestedPartitionFiltersCases"))
+ def testExtractNestedPartitionFilters(name: String,
+                                       filters: Seq[Expression],
+                                       partitionColumns: Set[String],
+                                       expected: Seq[Expression]): Unit = {
+   val extracted = HoodieFileIndex.extractNestedPartitionFilters(filters, partitionColumns)
+   assertEquals(expected, extracted)
+ }
+
}
object TestHoodieFileIndex {
@@ -870,4 +899,65 @@ object TestHoodieFileIndex {
Arguments.arguments("org.apache.hudi.keygen.TimestampBasedKeyGenerator")
)
}
+
+ /** Cases for `testBuildNestedPartitionSchema`: (case name, flat schema, expected nested schema). */
+ def buildNestedPartitionSchemaCases(): java.util.stream.Stream[Arguments] = {
+   // Expected fields are explicitly nullable, matching what the builder produces.
+   def field(name: String, dataType: DataType): StructField = StructField(name, dataType, nullable = true)
+   def structOf(fields: StructField*): StructType = StructType(fields)
+
+   val nested = structOf(field("nested_record", structOf(field("level", StringType))))
+   val twoLevel = structOf(field("a", structOf(field("b", structOf(field("c", IntegerType))))))
+   val siblings = structOf(field("a", structOf(field("b", StringType), field("c", IntegerType))))
+   val mixed = structOf(
+     field("country", StringType),
+     field("nested_record", structOf(field("level", StringType))))
+
+   java.util.stream.Stream.of(
+     Arguments.of("empty", new StructType(), new StructType()),
+     Arguments.of("flat",
+       StructType(Seq(StructField("country", StringType))),
+       structOf(field("country", StringType))),
+     Arguments.of("singleNested",
+       StructType(Seq(StructField("nested_record.level", StringType))),
+       nested),
+     Arguments.of("twoLevelNesting",
+       StructType(Seq(StructField("a.b.c", IntegerType))),
+       twoLevel),
+     Arguments.of("siblingFields",
+       StructType(Seq(StructField("a.b", StringType), StructField("a.c", IntegerType))),
+       siblings),
+     Arguments.of("mixedFlatAndNested",
+       StructType(Seq(StructField("country", StringType), StructField("nested_record.level", StringType))),
+       mixed)
+   )
+ }
+
+ /**
+  * Cases for `testExtractNestedPartitionFilters`: (case name, input filters, partition columns,
+  * expected filters). Expected values reuse the input expression objects, so attribute exprIds match.
+  */
+ def extractNestedPartitionFiltersCases(): java.util.stream.Stream[Arguments] = {
+   val levelStruct = StructType(Seq(StructField("level", StringType)))
+   val multiFieldStruct = StructType(Seq(
+     StructField("nested_int", IntegerType),
+     StructField("level", StringType)))
+
+   // nested_record.<fieldName> on a fresh `nested_record` attribute of the given struct type.
+   def nestedField(structType: StructType, ordinal: Int, fieldName: String): GetStructField =
+     GetStructField(AttributeReference("nested_record", structType)(), ordinal, Some(fieldName))
+
+   val partFilter = EqualTo(nestedField(levelStruct, 0, "level"), Literal("INFO"))
+   val dataFilter = EqualTo(AttributeReference("int_field", IntegerType)(), Literal(5))
+   val siblingFilter = EqualTo(nestedField(multiFieldStruct, 0, "nested_int"), Literal(10))
+   val orFilter = Or(
+     EqualTo(nestedField(levelStruct, 0, "level"), Literal("INFO")),
+     EqualTo(nestedField(levelStruct, 0, "level"), Literal("ERROR")))
+   val levelPartition = Set("nested_record.level")
+
+   java.util.stream.Stream.of(
+     Arguments.of("partitionFilterExtractedDataFilterDropped",
+       Seq(partFilter, dataFilter), levelPartition, Seq(partFilter)),
+     Arguments.of("siblingFieldExcluded",
+       Seq(siblingFilter), levelPartition, Seq.empty[Expression]),
+     Arguments.of("orWithOnlyPartitionColumnsExtracted",
+       Seq(orFilter), levelPartition, Seq(orFilter))
+   )
+ }
}
diff --git
a/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/hudi/functional/TestCOWDataSource.scala
b/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/hudi/functional/TestCOWDataSource.scala
index db5ca1ee7800..106c1f09a7ae 100644
---
a/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/hudi/functional/TestCOWDataSource.scala
+++
b/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/hudi/functional/TestCOWDataSource.scala
@@ -17,7 +17,7 @@
package org.apache.hudi.functional
-import org.apache.hudi.{AvroConversionUtils, DataSourceReadOptions,
DataSourceWriteOptions, HoodieDataSourceHelpers, HoodieSchemaConversionUtils,
HoodieSparkUtils, QuickstartUtils, ScalaAssertionSupport}
+import org.apache.hudi.{AvroConversionUtils, DataSourceReadOptions,
DataSourceWriteOptions, HoodieBaseRelation, HoodieDataSourceHelpers,
HoodieFileIndex, HoodieSchemaConversionUtils, HoodieSparkUtils,
QuickstartUtils, ScalaAssertionSupport}
import org.apache.hudi.DataSourceWriteOptions.{INLINE_CLUSTERING_ENABLE,
KEYGENERATOR_CLASS_NAME}
import org.apache.hudi.HoodieConversionUtils.toJavaOption
import org.apache.hudi.QuickstartUtils.{convertToStringList,
getQuickstartWriteConfigs}
@@ -42,13 +42,14 @@ import org.apache.hudi.hive.HiveSyncConfigHolder
import org.apache.hudi.keygen.{ComplexKeyGenerator, CustomKeyGenerator,
GlobalDeleteKeyGenerator, NonpartitionedKeyGenerator, SimpleKeyGenerator,
TimestampBasedKeyGenerator}
import org.apache.hudi.keygen.constant.{KeyGeneratorOptions, KeyGeneratorType}
import org.apache.hudi.metrics.{Metrics, MetricsReporterType}
-import org.apache.hudi.storage.{StoragePath, StoragePathFilter}
+import org.apache.hudi.storage.{HoodieStorage, StoragePath, StoragePathFilter}
import org.apache.hudi.table.HoodieSparkTable
import org.apache.hudi.testutils.{DataSourceTestUtils,
HoodieSparkClientTestBase}
import org.apache.hudi.util.JFunction
import org.apache.hadoop.fs.FileSystem
import org.apache.spark.sql.{DataFrame, DataFrameWriter, Dataset, Encoders,
Row, SaveMode, SparkSession, SparkSessionExtensions}
+import org.apache.spark.sql.execution.datasources.{HadoopFsRelation,
LogicalRelation}
import org.apache.spark.sql.functions.{col, concat, lit, udf, when}
import org.apache.spark.sql.hudi.HoodieSparkSessionExtension
import org.apache.spark.sql.types.{ArrayType, DataTypes, DateType,
IntegerType, LongType, MapType, StringType, StructField, StructType,
TimestampType}
@@ -2616,9 +2617,300 @@ class TestCOWDataSource extends
HoodieSparkClientTestBase with ScalaAssertionSup
assertEquals("row3", results(2).getAs[String]("_row_key"))
assertEquals("value3", results(2).getAs[String]("data"))
}
+
+  /**
+   * Nested-partition-column scenario (`nested_record.level`) on a copy-on-write table; the
+   * writes and assertions are shared with the MOR suite via TestCOWDataSource.runNestedFieldPartitionTest.
+   */
+  @Test
+  def testNestedFieldPartition(): Unit =
+    TestCOWDataSource.runNestedFieldPartitionTest(spark, basePath, storage, tableType = "COW")
}
object TestCOWDataSource {
+
+  /**
+   * Shared test logic for a table partitioned by the nested field `nested_record.level` (COW and MOR).
+   * Used by TestCOWDataSource.testNestedFieldPartition and TestMORDataSource.testNestedFieldPartition.
+   *
+   * Writes three commits (one insert, two upserts) in which the top-level `level` column deliberately
+   * differs from `nested_record.level`, then checks that:
+   *  - exactly the DEBUG, ERROR and INFO partition directories are created;
+   *  - the file index reports `nested_record.level` as its only partition field and has the
+   *    partition predicate pushed down to it;
+   *  - snapshot, time-travel and incremental reads filtered on the nested column return exactly the
+   *    expected rows;
+   *  - a filter on the top-level `level` data column returns the matching row, i.e. it is not
+   *    confused with the nested partition column.
+   *
+   * @param spark     active Spark session
+   * @param basePath  table base path; the first commit overwrites anything already there
+   * @param storage   storage used to list partition directories and look up commit completion times
+   * @param tableType "MOR" writes a merge-on-read table; any other value writes copy-on-write
+   */
+  def runNestedFieldPartitionTest(spark: SparkSession, basePath: String, storage: HoodieStorage, tableType: String): Unit = {
+    // nested_record.level is the partition column; top-level `level` is an ordinary data column
+    val nestedSchema = StructType(Seq(
+      StructField("nested_int", IntegerType, nullable = false),
+      StructField("level", StringType, nullable = false)
+    ))
+    val schema = StructType(Seq(
+      StructField("key", StringType, nullable = false),
+      StructField("ts", LongType, nullable = false),
+      StructField("level", StringType, nullable = false),
+      StructField("int_field", IntegerType, nullable = false),
+      StructField("string_field", StringType, nullable = true),
+      StructField("nested_record", nestedSchema, nullable = true)
+    ))
+    val isMor = tableType == "MOR"
+
+    val baseWriteOpts = Map(
+      "hoodie.insert.shuffle.parallelism" -> "4",
+      "hoodie.upsert.shuffle.parallelism" -> "4",
+      DataSourceWriteOptions.RECORDKEY_FIELD.key -> "key",
+      DataSourceWriteOptions.PARTITIONPATH_FIELD.key -> "nested_record.level",
+      HoodieTableConfig.ORDERING_FIELDS.key -> "ts",
+      HoodieWriteConfig.TBL_NAME.key -> "test_nested_partition",
+      DataSourceWriteOptions.TABLE_TYPE.key ->
+        (if (isMor) DataSourceWriteOptions.MOR_TABLE_TYPE_OPT_VAL else DataSourceWriteOptions.COW_TABLE_TYPE_OPT_VAL)
+    )
+    // Inline compaction is switched off for MOR (presumably so upserts stay in log files and the
+    // reads below exercise the merge path)
+    val writeOpts = if (isMor) baseWriteOpts + ("hoodie.compact.inline" -> "false") else baseWriteOpts
+
+    // Writes `rows` as one commit with the given operation/save mode and returns the completion
+    // time of the latest commit afterwards
+    def writeCommit(rows: Seq[Row], operation: String, mode: SaveMode) = {
+      spark.createDataFrame(spark.sparkContext.parallelize(rows), schema)
+        .write.format("hudi")
+        .options(writeOpts)
+        .option(DataSourceWriteOptions.OPERATION.key, operation)
+        .mode(mode)
+        .save(basePath)
+      DataSourceTestUtils.latestCommitCompletionTime(storage, basePath)
+    }
+
+    val projection = Seq("key", "ts", "level", "int_field", "string_field", "nested_record")
+
+    // Reads the table with `readOpts`, keeps rows whose nested partition value equals `level`,
+    // and orders them by record key so results can be compared positionally
+    def readLevel(level: String, readOpts: Map[String, String]): DataFrame =
+      spark.read.format("hudi")
+        .options(readOpts)
+        .load(basePath)
+        .filter(s"nested_record.level = '$level'")
+        .select(projection.head, projection.tail: _*)
+        .orderBy("key")
+
+    def incrementalOpts(startCommit: String, endCommit: String): Map[String, String] = Map(
+      DataSourceReadOptions.QUERY_TYPE.key -> DataSourceReadOptions.QUERY_TYPE_INCREMENTAL_OPT_VAL,
+      DataSourceReadOptions.START_COMMIT.key -> startCommit,
+      DataSourceReadOptions.END_COMMIT.key -> endCommit
+    )
+
+    // Compares whole result sets (size and every row) so a failure prints expected vs. actual rows
+    def assertRows(expected: Seq[Row], actual: Array[Row], context: String): Unit =
+      assertEquals(expected, actual.toSeq, s"$context mismatch for $tableType")
+
+    // Commit 1 - initial insert spread across the INFO, ERROR and DEBUG partitions
+    val commit1 = writeCommit(Seq(
+      Row("key1", 1L, "L1", 1, "str1", Row(10, "INFO")),
+      Row("key2", 2L, "L2", 2, "str2", Row(20, "ERROR")),
+      Row("key3", 3L, "L3", 3, "str3", Row(30, "INFO")),
+      Row("key4", 4L, "L4", 4, "str4", Row(40, "DEBUG")),
+      Row("key5", 5L, "L5", 5, "str5", Row(50, "INFO"))
+    ), DataSourceWriteOptions.INSERT_OPERATION_OPT_VAL, SaveMode.Overwrite)
+
+    // Commit 2 - upsert: update key1 (int_field 1 -> 100), insert key6 (INFO)
+    val commit2 = writeCommit(Seq(
+      Row("key1", 10L, "L1", 100, "str1", Row(10, "INFO")),
+      Row("key6", 6L, "L6", 6, "str6", Row(60, "INFO"))
+    ), DataSourceWriteOptions.UPSERT_OPERATION_OPT_VAL, SaveMode.Append)
+
+    // Commit 3 - upsert: update key3 (int_field 3 -> 300), insert key7 (INFO)
+    val commit3 = writeCommit(Seq(
+      Row("key3", 30L, "L3", 300, "str3", Row(30, "INFO")),
+      Row("key7", 7L, "L7", 7, "str7", Row(70, "INFO"))
+    ), DataSourceWriteOptions.UPSERT_OPERATION_OPT_VAL, SaveMode.Append)
+
+    // Partition directories are named after nested_record.level values only
+    val allPartitions = storage.listDirectEntries(new StoragePath(basePath)).asScala
+      .filter(_.isDirectory)
+      .map(_.getPath.getName)
+      .filterNot(_.startsWith(".")) // skip .hoodie and other hidden directories
+      .sorted
+      .toSeq
+    assertEquals(Seq("DEBUG", "ERROR", "INFO"), allPartitions,
+      s"Expected partitions DEBUG, ERROR, INFO for $tableType, but got: ${allPartitions.mkString(", ")}")
+
+    // Snapshot read filtered on the nested partition column (latest state: 5 INFO records)
+    val snapshotDF = readLevel("INFO", Map.empty)
+
+    // VERIFICATION 1: the file index exposes the nested field as its single partition column
+    val relation = snapshotDF.queryExecution.optimizedPlan.collectFirst {
+      case lr: LogicalRelation => lr.relation
+    }
+    assertTrue(relation.isDefined, s"LogicalRelation should exist for $tableType")
+    val fileIndexOpt = relation.flatMap {
+      case fsRelation: HadoopFsRelation =>
+        Option(fsRelation.location).collect { case hoodieIndex: HoodieFileIndex => hoodieIndex }
+      case baseRelation: HoodieBaseRelation => Option(baseRelation.fileIndex)
+      case _ => None
+    }
+    assertTrue(fileIndexOpt.isDefined, s"HoodieFileIndex should be available for $tableType")
+    val fileIndex = fileIndexOpt.get
+    assertEquals(1, fileIndex.partitionSchema.fields.length,
+      s"Partition schema should have 1 field for $tableType")
+    assertEquals("nested_record.level", fileIndex.partitionSchema.fields(0).name,
+      s"Partition field should be 'nested_record.level' for $tableType")
+
+    // VERIFICATION 2: the partition predicate reached the file index during optimization
+    assertTrue(fileIndex.hasPredicatesPushedDown,
+      s"Partition predicates should be pushed down to FileIndex for $tableType")
+
+    // VERIFICATION 3: snapshot returns the latest version of every INFO record
+    assertRows(Seq(
+      Row("key1", 10L, "L1", 100, "str1", Row(10, "INFO")),
+      Row("key3", 30L, "L3", 300, "str3", Row(30, "INFO")),
+      Row("key5", 5L, "L5", 5, "str5", Row(50, "INFO")),
+      Row("key6", 6L, "L6", 6, "str6", Row(60, "INFO")),
+      Row("key7", 7L, "L7", 7, "str7", Row(70, "INFO"))
+    ), snapshotDF.collect(), "Snapshot (INFO)")
+
+    // VERIFICATION 4: time travel as of commit1 sees only the initial INFO records
+    assertRows(Seq(
+      Row("key1", 1L, "L1", 1, "str1", Row(10, "INFO")),
+      Row("key3", 3L, "L3", 3, "str3", Row(30, "INFO")),
+      Row("key5", 5L, "L5", 5, "str5", Row(50, "INFO"))
+    ), readLevel("INFO", Map(DataSourceReadOptions.TIME_TRAVEL_AS_OF_INSTANT.key -> commit1)).collect(),
+      "Time travel to commit1 (INFO)")
+
+    // VERIFICATION 5: time travel as of commit2 sees the key1 update and the key6 insert
+    assertRows(Seq(
+      Row("key1", 10L, "L1", 100, "str1", Row(10, "INFO")),
+      Row("key3", 3L, "L3", 3, "str3", Row(30, "INFO")),
+      Row("key5", 5L, "L5", 5, "str5", Row(50, "INFO")),
+      Row("key6", 6L, "L6", 6, "str6", Row(60, "INFO"))
+    ), readLevel("INFO", Map(DataSourceReadOptions.TIME_TRAVEL_AS_OF_INSTANT.key -> commit2)).collect(),
+      "Time travel to commit2 (INFO)")
+
+    // VERIFICATION 6: incremental commit1 -> commit2 returns only that commit's INFO changes
+    assertRows(Seq(
+      Row("key1", 10L, "L1", 100, "str1", Row(10, "INFO")),
+      Row("key6", 6L, "L6", 6, "str6", Row(60, "INFO"))
+    ), readLevel("INFO", incrementalOpts(commit1, commit2)).collect(),
+      "Incremental (commit1->commit2, INFO)")
+
+    // VERIFICATION 7: incremental commit2 -> commit3 returns only that commit's INFO changes
+    assertRows(Seq(
+      Row("key3", 30L, "L3", 300, "str3", Row(30, "INFO")),
+      Row("key7", 7L, "L7", 7, "str7", Row(70, "INFO"))
+    ), readLevel("INFO", incrementalOpts(commit2, commit3)).collect(),
+      "Incremental (commit2->commit3, INFO)")
+
+    // VERIFICATION 8: the other partition values are filtered independently
+    assertEquals(Seq("key2"), readLevel("ERROR", Map.empty).collect().map(_.getString(0)).toSeq,
+      s"ERROR partition should contain only key2 for $tableType")
+    assertEquals(Seq("key4"), readLevel("DEBUG", Map.empty).collect().map(_.getString(0)).toSeq,
+      s"DEBUG partition should contain only key4 for $tableType")
+
+    // VERIFICATION 9: a filter on the top-level `level` data column (not nested_record.level)
+    // must still return the matching row; key1 has level = 'L1' and lives in the INFO partition
+    val topLevelKeys = spark.read.format("hudi")
+      .load(basePath)
+      .filter("level = 'L1'")
+      .select("key", "level", "nested_record")
+      .collect()
+      .map(_.getString(0))
+      .toSeq
+    assertEquals(Seq("key1"), topLevelKeys,
+      s"Top-level level='L1' should return only key1 for $tableType")
+  }
+
def convertColumnsToNullable(df: DataFrame, cols: String*): DataFrame = {
cols.foldLeft(df) { (df, c) =>
// NOTE: This is the trick to make Spark convert a non-null column "c"
into a nullable
diff --git
a/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/hudi/functional/TestMORDataSource.scala
b/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/hudi/functional/TestMORDataSource.scala
index ab7bbcd097d2..81d049d43243 100644
---
a/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/hudi/functional/TestMORDataSource.scala
+++
b/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/hudi/functional/TestMORDataSource.scala
@@ -2347,6 +2347,11 @@ class TestMORDataSource extends
HoodieSparkClientTestBase with SparkDatasetMixin
assertEquals("row4", results(3).getAs[String]("_row_key"))
assertEquals("value4", results(3).getAs[String]("data"))
}
+
+  /**
+   * Nested-partition-column scenario (`nested_record.level`) on a merge-on-read table; the
+   * writes and assertions are shared with the COW suite via TestCOWDataSource.runNestedFieldPartitionTest.
+   */
+  @Test
+  def testNestedFieldPartition(): Unit =
+    TestCOWDataSource.runNestedFieldPartitionTest(spark, basePath, storage, tableType = "MOR")
}
object TestMORDataSource {
diff --git
a/hudi-spark-datasource/hudi-spark3-common/src/main/scala/org/apache/spark/sql/hudi/analysis/Spark3HoodiePruneFileSourcePartitions.scala
b/hudi-spark-datasource/hudi-spark3-common/src/main/scala/org/apache/spark/sql/hudi/analysis/Spark3HoodiePruneFileSourcePartitions.scala
index 589ee9774d3b..3cb9e8250da9 100644
---
a/hudi-spark-datasource/hudi-spark3-common/src/main/scala/org/apache/spark/sql/hudi/analysis/Spark3HoodiePruneFileSourcePartitions.scala
+++
b/hudi-spark-datasource/hudi-spark3-common/src/main/scala/org/apache/spark/sql/hudi/analysis/Spark3HoodiePruneFileSourcePartitions.scala
@@ -36,7 +36,7 @@ import org.apache.spark.sql.types.StructType
* Prune the partitions of Hudi table based relations by the means of pushing
down the
* partition filters
*
- * NOTE: [[HoodiePruneFileSourcePartitions]] is a replica in kind to Spark's
[[PruneFileSourcePartitions]]
+ * NOTE: [[HoodiePruneFileSourcePartitions]] is a replica in kind to Spark's
[[org.apache.spark.sql.execution.datasources.PruneFileSourcePartitions]]
*/
case class Spark3HoodiePruneFileSourcePartitions(spark: SparkSession) extends
Rule[LogicalPlan] {
@@ -48,11 +48,11 @@ case class Spark3HoodiePruneFileSourcePartitions(spark:
SparkSession) extends Ru
val normalizedFilters = exprUtils.normalizeExprs(deterministicFilters,
lr.output)
val (partitionPruningFilters, dataFilters) =
- getPartitionFiltersAndDataFilters(fileIndex.partitionSchema,
normalizedFilters)
+ getPartitionFiltersAndDataFilters(fileIndex.partitionSchemaForSpark,
normalizedFilters)
// [[HudiFileIndex]] is a caching one, therefore we don't need to
reconstruct new relation,
// instead we simply just refresh the index and update the stats
- fileIndex.filterFileSlices(dataFilters, partitionPruningFilters,
isPartitionPruned = true)
+ fileIndex.filterFileSlices(dataFilters, partitionPruningFilters,
isPartitionPruneOnly = true)
if (partitionPruningFilters.nonEmpty) {
// Change table stats based on the sizeInBytes of pruned files
@@ -105,11 +105,21 @@ private object Spark3HoodiePruneFileSourcePartitions
extends PredicateHelper {
Project(projects, withFilter)
}
+  /**
+   * Tells whether `attr` names one of the top-level fields of `partitionSchema` (exact,
+   * case-sensitive name match).
+   *
+   * In this rule the schema passed in is `fileIndex.partitionSchemaForSpark`. For a nested partition
+   * column such as `nested_record.level`, that schema is presumably rooted at the struct
+   * (`nested_record`), so a reference to the struct attribute counts as a partition-column reference.
+   *
+   * NOTE(review): matching is by top-level name only, so a predicate on a sibling sub-field of the
+   * same struct (e.g. `nested_record.nested_int`) is also classified as a partition filter — confirm
+   * that the file index tolerates or excludes such predicates.
+   */
+  private def isPartitionColumnReference(attr: AttributeReference, partitionSchema: StructType): Boolean =
+    partitionSchema.exists(field => field.name == attr.name)
+
def getPartitionFiltersAndDataFilters(partitionSchema: StructType,
normalizedFilters: Seq[Expression]):
(Seq[Expression], Seq[Expression]) = {
val partitionColumns = normalizedFilters.flatMap { expr =>
expr.collect {
- case attr: AttributeReference if
partitionSchema.names.contains(attr.name) =>
+ case attr: AttributeReference if isPartitionColumnReference(attr,
partitionSchema) =>
attr
}
}
diff --git
a/hudi-spark-datasource/hudi-spark3.3.x/src/main/scala/org/apache/spark/sql/hudi/analysis/Spark33HoodiePruneFileSourcePartitions.scala
b/hudi-spark-datasource/hudi-spark3.3.x/src/main/scala/org/apache/spark/sql/hudi/analysis/Spark33HoodiePruneFileSourcePartitions.scala
index 7d7240231cd0..add1b7aaebf1 100644
---
a/hudi-spark-datasource/hudi-spark3.3.x/src/main/scala/org/apache/spark/sql/hudi/analysis/Spark33HoodiePruneFileSourcePartitions.scala
+++
b/hudi-spark-datasource/hudi-spark3.3.x/src/main/scala/org/apache/spark/sql/hudi/analysis/Spark33HoodiePruneFileSourcePartitions.scala
@@ -60,7 +60,7 @@ case class Spark33HoodiePruneFileSourcePartitions(spark:
SparkSession) extends R
// [[HudiFileIndex]] is a caching one, therefore we don't need to
reconstruct new relation,
// instead we simply just refresh the index and update the stats
- fileIndex.filterFileSlices(dataFilters, partitionPruningFilters,
isPartitionPruned = true)
+ fileIndex.filterFileSlices(dataFilters, partitionPruningFilters,
isPartitionPruneOnly = true)
if (partitionPruningFilters.nonEmpty) {
// Change table stats based on the sizeInBytes of pruned files
diff --git
a/hudi-spark-datasource/hudi-spark4-common/src/main/scala/org/apache/spark/sql/hudi/analysis/Spark4HoodiePruneFileSourcePartitions.scala
b/hudi-spark-datasource/hudi-spark4-common/src/main/scala/org/apache/spark/sql/hudi/analysis/Spark4HoodiePruneFileSourcePartitions.scala
index 8412018c22db..0f6cf87da86f 100644
---
a/hudi-spark-datasource/hudi-spark4-common/src/main/scala/org/apache/spark/sql/hudi/analysis/Spark4HoodiePruneFileSourcePartitions.scala
+++
b/hudi-spark-datasource/hudi-spark4-common/src/main/scala/org/apache/spark/sql/hudi/analysis/Spark4HoodiePruneFileSourcePartitions.scala
@@ -48,11 +48,11 @@ case class Spark4HoodiePruneFileSourcePartitions(spark:
SparkSession) extends Ru
val normalizedFilters = exprUtils.normalizeExprs(deterministicFilters,
lr.output)
val (partitionPruningFilters, dataFilters) =
- getPartitionFiltersAndDataFilters(fileIndex.partitionSchema,
normalizedFilters)
+ getPartitionFiltersAndDataFilters(fileIndex.partitionSchemaForSpark,
normalizedFilters)
// [[HudiFileIndex]] is a caching one, therefore we don't need to
reconstruct new relation,
// instead we simply just refresh the index and update the stats
- fileIndex.filterFileSlices(dataFilters, partitionPruningFilters,
isPartitionPruned = true)
+ fileIndex.filterFileSlices(dataFilters, partitionPruningFilters,
isPartitionPruneOnly = true)
if (partitionPruningFilters.nonEmpty) {
// Change table stats based on the sizeInBytes of pruned files
@@ -105,11 +105,21 @@ private object Spark4HoodiePruneFileSourcePartitions
extends PredicateHelper {
Project(projects, withFilter)
}
+  /**
+   * Tells whether `attr` names one of the top-level fields of `partitionSchema` (exact,
+   * case-sensitive name match).
+   *
+   * In this rule the schema passed in is `fileIndex.partitionSchemaForSpark`. For a nested partition
+   * column such as `nested_record.level`, that schema is presumably rooted at the struct
+   * (`nested_record`), so a reference to the struct attribute counts as a partition-column reference.
+   *
+   * NOTE(review): matching is by top-level name only, so a predicate on a sibling sub-field of the
+   * same struct (e.g. `nested_record.nested_int`) is also classified as a partition filter — confirm
+   * that the file index tolerates or excludes such predicates.
+   */
+  private def isPartitionColumnReference(attr: AttributeReference, partitionSchema: StructType): Boolean =
+    partitionSchema.exists(field => field.name == attr.name)
+
def getPartitionFiltersAndDataFilters(partitionSchema: StructType,
normalizedFilters: Seq[Expression]):
(Seq[Expression], Seq[Expression]) = {
val partitionColumns = normalizedFilters.flatMap { expr =>
expr.collect {
- case attr: AttributeReference if
partitionSchema.names.contains(attr.name) =>
+ case attr: AttributeReference if isPartitionColumnReference(attr,
partitionSchema) =>
attr
}
}
diff --git
a/hudi-utilities/src/main/java/org/apache/hudi/utilities/HiveIncrementalPuller.java
b/hudi-utilities/src/main/java/org/apache/hudi/utilities/HiveIncrementalPuller.java
index fede1b8fba03..2510edce72a8 100644
---
a/hudi-utilities/src/main/java/org/apache/hudi/utilities/HiveIncrementalPuller.java
+++
b/hudi-utilities/src/main/java/org/apache/hudi/utilities/HiveIncrementalPuller.java
@@ -20,9 +20,9 @@ package org.apache.hudi.utilities;
import org.apache.hudi.common.table.HoodieTableMetaClient;
import org.apache.hudi.common.table.timeline.HoodieInstant;
-import org.apache.hudi.common.util.FileIOUtils;
import org.apache.hudi.common.util.Option;
import org.apache.hudi.exception.HoodieException;
+import org.apache.hudi.hadoop.fs.HadoopFSUtils;
import org.apache.hudi.utilities.exception.HoodieIncrementalPullException;
import org.apache.hudi.utilities.exception.HoodieIncrementalPullSQLException;
@@ -50,6 +50,8 @@ import java.util.List;
import java.util.Scanner;
import java.util.stream.Collectors;
+import static org.apache.hudi.io.util.FileIOUtils.readAsUTFString;
+
/**
* Utility to pull data after a given commit, based on the supplied HiveQL and
save the delta as another hive temporary
* table. This temporary table can be further read using {@link
org.apache.hudi.utilities.sources.HiveIncrPullSource} and the changes can
@@ -115,7 +117,7 @@ public class HiveIncrementalPuller {
this.config = config;
validateConfig(config);
String templateContent =
-
FileIOUtils.readAsUTFString(this.getClass().getResourceAsStream("/IncrementalPull.sqltemplate"));
+
readAsUTFString(this.getClass().getResourceAsStream("/IncrementalPull.sqltemplate"));
incrementalPullSQLTemplate = new ST(templateContent);
}
@@ -298,12 +300,13 @@ public class HiveIncrementalPuller {
if (!fs.exists(new Path(targetDataPath)) || !fs.exists(new
Path(targetDataPath + "/.hoodie"))) {
return "0";
}
- HoodieTableMetaClient metadata =
HoodieTableMetaClient.builder().setConf(fs.getConf()).setBasePath(targetDataPath).build();
+ HoodieTableMetaClient metadata = HoodieTableMetaClient.builder()
+
.setConf(HadoopFSUtils.getStorageConfWithCopy(fs.getConf())).setBasePath(targetDataPath).build();
Option<HoodieInstant> lastCommit =
metadata.getActiveTimeline().getCommitsTimeline().filterCompletedInstants().lastInstant();
if (lastCommit.isPresent()) {
- return lastCommit.get().getTimestamp();
+ return lastCommit.get().requestedTime();
}
return "0";
}
@@ -331,14 +334,15 @@ public class HiveIncrementalPuller {
}
private String getLastCommitTimePulled(FileSystem fs, String
sourceTableLocation) {
- HoodieTableMetaClient metadata =
HoodieTableMetaClient.builder().setConf(fs.getConf()).setBasePath(sourceTableLocation).build();
+ HoodieTableMetaClient metadata = HoodieTableMetaClient.builder()
+ .setConf(HadoopFSUtils.getStorageConfWithCopy(fs.getConf()))
+ .setBasePath(sourceTableLocation).build();
List<String> commitsToSync =
metadata.getActiveTimeline().getCommitsTimeline().filterCompletedInstants()
- .findInstantsAfter(config.fromCommitTime,
config.maxCommits).getInstantsAsStream().map(HoodieInstant::getTimestamp)
+ .findInstantsAfter(config.fromCommitTime,
config.maxCommits).getInstantsAsStream().map(HoodieInstant::requestedTime)
.collect(Collectors.toList());
if (commitsToSync.isEmpty()) {
- LOG.info("Nothing to sync. All commits in {} are {} and from commit time
is {}", config.sourceTable,
-
metadata.getActiveTimeline().getCommitsTimeline().filterCompletedInstants().getInstants(),
- config.fromCommitTime);
+ LOG.info("Nothing to sync. All commits in {} are {} and from commit time
is {}", config.sourceTable, metadata.getActiveTimeline().getCommitsTimeline()
+ .filterCompletedInstants().getInstants(), config.fromCommitTime);
return null;
}
LOG.info("Syncing commits {}", commitsToSync);