This is an automated email from the ASF dual-hosted git repository.

voonhous pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/hudi.git


The following commit(s) were added to refs/heads/master by this push:
     new 853cbef630b3 feat(vector_search): Implement functionality for pre-filters and maxD… (#18797)
853cbef630b3 is described below

commit 853cbef630b3ae93fdf385acd57a9f6437aa1851
Author: Rahil C <[email protected]>
AuthorDate: Wed May 27 22:31:30 2026 -0700

    feat(vector_search): Implement functionality for pre-filters and maxD… (#18797)
    
    * feat(vector_search): Implement functionality for pre-filters and maxDistance threshold parameters
    
    * feat(vector_search): address review feedback on prefilter + maxDistance
    
    - Narrow applyFilter to catch only ParseException + AnalysisException so
      unrelated runtime errors are not misreported as filter problems.
    - Reject non-numeric literals (including strings whose contents parse as
      numbers) for max_distance via explicit NumericType guard in
      parseOptionalDouble; the previous toString.toDouble path silently
      accepted '0.5' as a string.
    - Document on both parseArgs Javadocs that NULL, empty string, and
      whitespace-only strings all mean "no filter," matching what the
      parseOptionalString helper already implements.
    - Add failure-mode tests for invalid filter syntax, unknown filter
      column, non-string filter literal, non-numeric max_distance literal,
      integer max_distance widening, negative max_distance, empty/whitespace
      filter equivalence with no-filter, and batch-mode error wrapping.
    
    Co-Authored-By: Claude Opus 4.7 (1M context) <[email protected]>
    
    * minor fixes
    
    * feat(vector_search): address voonhous review comments
    
    - Fix stray asterisk in BruteForceSearchAlgorithm Javadoc
    - Drop redundant default parameter values on BruteForceSearchAlgorithm
      overrides (inherited from the trait)
    - Clarify in buildBatchQueryPlan Javadoc that the filter applies to the
      corpus only, not the query side
    - Add test exercising parseOptionalDouble with non-Double numeric
      literals (Int and Decimal) to lock in the widening contract
    
    * test(vector_search): fix testMaxDistanceAcceptsNonDoubleNumericLiterals for inclusive threshold
    
    The Int-literal sub-case expected 3 rows with max_distance=1, on the
    incorrect assumption that the threshold filter is strict (keep only
    distance < max_distance). The actual implementation uses <= (consistent with
    testSingleQueryMaxDistanceExcludesAll, which depends on doc_1 at distance 0.0
    being kept when max_distance=0.0).
    
    With <= and max_distance=1.0, doc_2 and doc_3 (both at distance 1.0) are
    included, returning all 5 corpus rows. Updated the expected count and
    flipped the negative doc_2 assertion to positive doc_2 + doc_3 checks
    with messages naming the inclusive-boundary contract, so the test now
    also acts as a regression guard for that semantics.
    
    The Decimal sub-case was unaffected.
    
    Co-Authored-By: Claude Opus 4.7 <[email protected]>
    
    * fix issue
    
    * refactor(vector_search): use direct pattern match for numeric eval result
    
    Replace the toString.toDouble round-trip in parseOptionalDouble with a
    direct pattern match on Number / Decimal, avoiding the unnecessary string
    conversion (especially wasteful for Decimal literals).
    
    Co-Authored-By: Claude Opus 4.7 (1M context) <[email protected]>
    
    ---------
    
    Co-authored-by: Claude Opus 4.7 (1M context) <[email protected]>
---
 .../HoodieVectorSearchTableValuedFunction.scala    |  90 +++-
 .../hudi/analysis/HoodieSparkBaseAnalysis.scala    |   6 +-
 .../analysis/HoodieVectorSearchPlanBuilder.scala   |  74 ++-
 .../TestHoodieVectorSearchFunction.scala           | 520 ++++++++++++++++++++-
 4 files changed, 657 insertions(+), 33 deletions(-)

diff --git 
a/hudi-spark-datasource/hudi-spark-common/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/HoodieVectorSearchTableValuedFunction.scala
 
b/hudi-spark-datasource/hudi-spark-common/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/HoodieVectorSearchTableValuedFunction.scala
index 39102cddce34..30603853351c 100644
--- 
a/hudi-spark-datasource/hudi-spark-common/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/HoodieVectorSearchTableValuedFunction.scala
+++ 
b/hudi-spark-datasource/hudi-spark-common/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/HoodieVectorSearchTableValuedFunction.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.plans.logical
 
 import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, 
Literal}
 import org.apache.spark.sql.hudi.command.exception.HoodieAnalysisException
-import org.apache.spark.sql.types.StringType
+import org.apache.spark.sql.types.{Decimal, NullType, NumericType, StringType}
 
 object HoodieVectorSearchTableValuedFunction {
 
@@ -53,21 +53,28 @@ object HoodieVectorSearchTableValuedFunction {
     queryVectorExpr: Expression,
     k: Int,
     metric: DistanceMetric.Value,
-    algorithm: SearchAlgorithm.Value
+    algorithm: SearchAlgorithm.Value,
+    filter: Option[String],
+    maxDistance: Option[Double]
   )
 
   /**
    * Parse arguments for the hudi_vector_search TVF (single-query mode).
    *
-   * Signature (4–6 args):
-   *   hudi_vector_search('table', 'embedding_col', ARRAY(1.0, 2.0, ...), k [, 
'metric'] [, 'algorithm'])
+   * Signature (4–8 args):
+   *   hudi_vector_search('table', 'embedding_col', ARRAY(1.0, 2.0, ...), k
+   *     [, 'metric'] [, 'algorithm'] [, 'filter_predicate' | NULL] [, 
max_distance | NULL])
    *   metric defaults to 'cosine'; algorithm defaults to 'brute_force'.
+   *   filter is a SQL predicate applied to the corpus before distance 
computation;
+   *     NULL, the empty string, and whitespace-only strings all mean "no 
filter."
+   *   max_distance excludes results whose distance exceeds the given 
threshold;
+   *     NULL means "no threshold." Must be a numeric literal when specified.
    */
   def parseArgs(exprs: Seq[Expression]): ParsedArgs = {
-    if (exprs.size < 4 || exprs.size > 6) {
+    if (exprs.size < 4 || exprs.size > 8) {
       throw new HoodieAnalysisException(
-        s"Function '$FUNC_NAME' expects 4-6 arguments: " +
-          "(table, embedding_col, query_vector, k [, metric] [, algorithm]).")
+        s"Function '$FUNC_NAME' expects 4-8 arguments: " +
+          "(table, embedding_col, query_vector, k [, metric] [, algorithm] [, 
filter] [, max_distance]).")
     }
 
     def requireStringLiteral(expr: Expression, argName: String): String = expr 
match {
@@ -84,7 +91,9 @@ object HoodieVectorSearchTableValuedFunction {
     else DistanceMetric.COSINE
     val algorithm = if (exprs.size >= 6) 
SearchAlgorithm.fromString(requireStringLiteral(exprs(5), "algorithm"))
     else SearchAlgorithm.BRUTE_FORCE
-    ParsedArgs(table, embeddingCol, queryVectorExpr, k, metric, algorithm)
+    val filter = if (exprs.size >= 7) parseOptionalString(FUNC_NAME, exprs(6), 
"filter") else None
+    val maxDistance = if (exprs.size >= 8) parseOptionalDouble(FUNC_NAME, 
exprs(7), "max_distance") else None
+    ParsedArgs(table, embeddingCol, queryVectorExpr, k, metric, algorithm, 
filter, maxDistance)
   }
 
   private[logical] def parseK(funcName: String, expr: Expression): Int = {
@@ -106,6 +115,43 @@ object HoodieVectorSearchTableValuedFunction {
     }
     kValue
   }
+
+  /** Parses a string argument that may be NULL (meaning "not specified"). */
+  private[logical] def parseOptionalString(
+      funcName: String, expr: Expression, argName: String): Option[String] = 
expr match {
+    case Literal(null, _) => None
+    case Literal(v, StringType) if v != null =>
+      val s = v.toString.trim
+      if (s.isEmpty) None else Some(s)
+    case _ => throw new HoodieAnalysisException(
+      s"Function '$funcName': argument '$argName' must be a string literal or 
NULL, got: ${expr.sql}")
+  }
+
+  /**
+   * Parses a numeric argument that may be NULL (meaning "not specified"). 
Accepts
+   * any foldable expression of [[NumericType]] (or [[NullType]] for an untyped
+   * NULL keyword) — including a bare literal or {@code CAST(literal AS 
numeric)} —
+   * and widens to Double. String literals are rejected even when their 
contents
+   * happen to parse as a number, so the type contract surfaces at parse time.
+   */
+  private[logical] def parseOptionalDouble(
+      funcName: String, expr: Expression, argName: String): Option[Double] = {
+    val numericOrNull = expr.dataType match {
+      case _: NumericType | NullType => true
+      case _ => false
+    }
+    if (!expr.foldable || !numericOrNull) {
+      throw new HoodieAnalysisException(
+        s"Function '$funcName': argument '$argName' must be a numeric literal 
or NULL, got: ${expr.sql}")
+    }
+    Option(expr.eval()).map {
+      case d: Decimal => d.toDouble
+      case n: Number => n.doubleValue()
+      case other => throw new HoodieAnalysisException(
+        s"Function '$funcName': argument '$argName' has unexpected runtime 
type: " +
+          s"${other.getClass.getName}")
+    }
+  }
 }
 
 /**
@@ -142,21 +188,28 @@ object HoodieVectorSearchBatchTableValuedFunction {
     queryEmbeddingCol: String,
     k: Int,
     metric: HoodieVectorSearchTableValuedFunction.DistanceMetric.Value,
-    algorithm: HoodieVectorSearchTableValuedFunction.SearchAlgorithm.Value
+    algorithm: HoodieVectorSearchTableValuedFunction.SearchAlgorithm.Value,
+    filter: Option[String],
+    maxDistance: Option[Double]
   )
 
   /**
    * Parse arguments for the hudi_vector_search_batch TVF (batch-query mode).
    *
-   * Signature (5–7 args):
-   *   hudi_vector_search_batch('corpus_table', 'corpus_col', 'query_table', 
'query_col', k [, 'metric'] [, 'algorithm'])
+   * Signature (5–9 args):
+   *   hudi_vector_search_batch('corpus_table', 'corpus_col', 'query_table', 
'query_col', k
+   *     [, 'metric'] [, 'algorithm'] [, 'filter_predicate' | NULL] [, 
max_distance | NULL])
    *   metric defaults to 'cosine'; algorithm defaults to 'brute_force'.
+   *   filter is a SQL predicate applied to the corpus before distance 
computation;
+   *     NULL, the empty string, and whitespace-only strings all mean "no 
filter."
+   *   max_distance excludes results whose distance exceeds the given 
threshold;
+   *     NULL means "no threshold." Must be a numeric literal when specified.
    */
   def parseArgs(exprs: Seq[Expression]): ParsedArgs = {
-    if (exprs.size < 5 || exprs.size > 7) {
+    if (exprs.size < 5 || exprs.size > 9) {
       throw new HoodieAnalysisException(
-        s"Function '$FUNC_NAME' expects 5-7 arguments: " +
-          "(corpus_table, corpus_col, query_table, query_col, k [, metric] [, 
algorithm]).")
+        s"Function '$FUNC_NAME' expects 5-9 arguments: " +
+          "(corpus_table, corpus_col, query_table, query_col, k [, metric] [, 
algorithm] [, filter] [, max_distance]).")
     }
 
     def requireStringLiteral(expr: Expression, argName: String): String = expr 
match {
@@ -176,7 +229,14 @@ object HoodieVectorSearchBatchTableValuedFunction {
     val algorithm = if (exprs.size >= 7)
       
HoodieVectorSearchTableValuedFunction.SearchAlgorithm.fromString(requireStringLiteral(exprs(6),
 "algorithm"))
     else HoodieVectorSearchTableValuedFunction.SearchAlgorithm.BRUTE_FORCE
-    ParsedArgs(corpusTable, corpusEmbeddingCol, queryTable, queryEmbeddingCol, 
k, metric, algorithm)
+    val filter = if (exprs.size >= 8)
+      HoodieVectorSearchTableValuedFunction.parseOptionalString(FUNC_NAME, 
exprs(7), "filter")
+    else None
+    val maxDistance = if (exprs.size >= 9)
+      HoodieVectorSearchTableValuedFunction.parseOptionalDouble(FUNC_NAME, 
exprs(8), "max_distance")
+    else None
+    ParsedArgs(corpusTable, corpusEmbeddingCol, queryTable, queryEmbeddingCol,
+      k, metric, algorithm, filter, maxDistance)
   }
 }
 
diff --git 
a/hudi-spark-datasource/hudi-spark-common/src/main/scala/org/apache/spark/sql/hudi/analysis/HoodieSparkBaseAnalysis.scala
 
b/hudi-spark-datasource/hudi-spark-common/src/main/scala/org/apache/spark/sql/hudi/analysis/HoodieSparkBaseAnalysis.scala
index 88ad684c1e2f..82ac87b7bcbc 100644
--- 
a/hudi-spark-datasource/hudi-spark-common/src/main/scala/org/apache/spark/sql/hudi/analysis/HoodieSparkBaseAnalysis.scala
+++ 
b/hudi-spark-datasource/hudi-spark-common/src/main/scala/org/apache/spark/sql/hudi/analysis/HoodieSparkBaseAnalysis.scala
@@ -156,7 +156,8 @@ case class ResolveReferences(spark: SparkSession) extends 
Rule[LogicalPlan]
       val searchAlgorithm = 
HoodieVectorSearchPlanBuilder.resolveAlgorithm(a.algorithm)
       val corpusDf = resolveTableToDf(a.table, 
HoodieVectorSearchTableValuedFunction.FUNC_NAME)
       val queryVector = evaluateQueryVector(a.queryVectorExpr)
-      searchAlgorithm.buildSingleQueryPlan(spark, corpusDf, a.embeddingCol, 
queryVector, a.k, a.metric)
+      searchAlgorithm.buildSingleQueryPlan(
+        spark, corpusDf, a.embeddingCol, queryVector, a.k, a.metric, a.filter, 
a.maxDistance)
 
     case HoodieVectorSearchBatchTableValuedFunction(args) =>
       val a = HoodieVectorSearchBatchTableValuedFunction.parseArgs(args)
@@ -164,7 +165,8 @@ case class ResolveReferences(spark: SparkSession) extends 
Rule[LogicalPlan]
       val corpusDf = resolveTableToDf(a.corpusTable, 
HoodieVectorSearchBatchTableValuedFunction.FUNC_NAME)
       val queryDf = resolveTableToDf(a.queryTable, 
HoodieVectorSearchBatchTableValuedFunction.FUNC_NAME)
       searchAlgorithm.buildBatchQueryPlan(
-        spark, corpusDf, a.corpusEmbeddingCol, queryDf, a.queryEmbeddingCol, 
a.k, a.metric)
+        spark, corpusDf, a.corpusEmbeddingCol, queryDf, a.queryEmbeddingCol,
+        a.k, a.metric, a.filter, a.maxDistance)
 
     case mO@MatchMergeIntoTable(targetTableO, sourceTableO, _)
       // START: custom Hudi change: don't want to go to the spark mit 
resolution so we resolve the source and target
diff --git 
a/hudi-spark-datasource/hudi-spark-common/src/main/scala/org/apache/spark/sql/hudi/analysis/HoodieVectorSearchPlanBuilder.scala
 
b/hudi-spark-datasource/hudi-spark-common/src/main/scala/org/apache/spark/sql/hudi/analysis/HoodieVectorSearchPlanBuilder.scala
index e99079b4966f..9459ff0c0624 100644
--- 
a/hudi-spark-datasource/hudi-spark-common/src/main/scala/org/apache/spark/sql/hudi/analysis/HoodieVectorSearchPlanBuilder.scala
+++ 
b/hudi-spark-datasource/hudi-spark-common/src/main/scala/org/apache/spark/sql/hudi/analysis/HoodieVectorSearchPlanBuilder.scala
@@ -19,7 +19,8 @@ package org.apache.spark.sql.hudi.analysis
 
 import org.apache.hudi.common.schema.HoodieSchema
 
-import org.apache.spark.sql.{DataFrame, SparkSession}
+import org.apache.spark.sql.{AnalysisException, DataFrame, SparkSession}
+import org.apache.spark.sql.catalyst.parser.ParseException
 import 
org.apache.spark.sql.catalyst.plans.logical.HoodieVectorSearchTableValuedFunction.{DistanceMetric,
 SearchAlgorithm}
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.expressions.Window
@@ -64,6 +65,12 @@ trait VectorSearchAlgorithm {
    * @param queryVector    the query vector, normalized to Array[Double]
    * @param k              number of nearest neighbors to return
    * @param metric         distance metric (COSINE, L2, DOT_PRODUCT)
+   * @param filter         optional SQL predicate applied to the corpus before 
distance computation,
+   *                       String literals inside the predicate must use 
double quotes (e.g. {@code label = "z-axis"})
+   *                       because the filter is passed as a single-quoted TVF 
argument and
+   *                       Spark's lexer does not support {@code ''} as an 
escape sequence
+   *                       inside string literals.
+   * @param maxDistance    optional distance threshold; results exceeding this 
value are excluded
    * @return an analyzed LogicalPlan whose output matches the single-query 
schema contract
    */
   def buildSingleQueryPlan(
@@ -72,7 +79,9 @@ trait VectorSearchAlgorithm {
       embeddingCol: String,
       queryVector: Array[Double],
       k: Int,
-      metric: DistanceMetric.Value): LogicalPlan
+      metric: DistanceMetric.Value,
+      filter: Option[String] = None,
+      maxDistance: Option[Double] = None): LogicalPlan
 
   /**
    * Build a plan that finds the k nearest corpus rows for each row in the 
query table.
@@ -84,6 +93,13 @@ trait VectorSearchAlgorithm {
    * @param queryEmbeddingCol  name of the embedding column in queryDf
    * @param k                  number of nearest neighbors per query
    * @param metric             distance metric (COSINE, L2, DOT_PRODUCT)
+   * @param filter             optional SQL predicate applied to the corpus 
before distance
+   *                           computation, and shrinking cross-join 
cardinality. Applied to the
+   *                           corpus only — to restrict the query side, apply 
a projection or
+   *                           filter on the query table before invoking the 
TVF. See
+   *                           [[buildSingleQueryPlan]] for quoting 
requirements.
+   * @param maxDistance        optional distance threshold; results exceeding 
this value are
+   *                           excluded before per-query top-K selection, 
reducing shuffle volume.
    * @return an analyzed LogicalPlan whose output matches the batch-query 
schema contract
    * @note Batch mode broadcasts the query table to all executors via a 
cross-join.
    *       This is designed for small-to-medium query sets (tens to low 
hundreds of rows).
@@ -96,7 +112,9 @@ trait VectorSearchAlgorithm {
       queryDf: DataFrame,
       queryEmbeddingCol: String,
       k: Int,
-      metric: DistanceMetric.Value): LogicalPlan
+      metric: DistanceMetric.Value,
+      filter: Option[String] = None,
+      maxDistance: Option[Double] = None): LogicalPlan
 }
 
 /**
@@ -232,6 +250,10 @@ object HoodieVectorSearchPlanBuilder {
  * and select top-K per query. The cross-join produces O(|corpus| * |queries|)
  * intermediate rows, so this is suitable for small-to-medium query sets
  * (tens to low hundreds of queries) against moderate corpora.
+ *
+ * <p>Both modes support an optional {@code filter} predicate (applied to the 
corpus before
+ * distance computation), and an optional {@code maxDistance} threshold 
(results beyond this
+ * distance are excluded before top-K selection, reducing shuffle and sort 
volume).
  */
 object BruteForceSearchAlgorithm extends VectorSearchAlgorithm {
 
@@ -239,18 +261,39 @@ object BruteForceSearchAlgorithm extends 
VectorSearchAlgorithm {
 
   override val name: String = "brute_force"
 
+  /**
+   * Applies a user-supplied SQL predicate to the corpus DataFrame, wrapping
+   * [[ParseException]] (predicate syntax error) and [[AnalysisException]]
+   * (unknown column, type mismatch, etc.) in a [[HoodieAnalysisException]] 
that
+   * echoes the offending expression. Other exception types are allowed to
+   * propagate untouched so they aren't misreported as a filter problem.
+   */
+  private def applyFilter(df: DataFrame, filterExpr: String): DataFrame = {
+    try {
+      df.filter(filterExpr)
+    } catch {
+      case e @ (_: ParseException | _: AnalysisException) =>
+        throw new HoodieAnalysisException(
+          s"Invalid pre-filter expression '$filterExpr': ${e.getMessage}")
+    }
+  }
+
   override def buildSingleQueryPlan(
       spark: SparkSession,
       corpusDf: DataFrame,
       embeddingCol: String,
       queryVector: Array[Double],
       k: Int,
-      metric: DistanceMetric.Value): LogicalPlan = {
+      metric: DistanceMetric.Value,
+      filter: Option[String],
+      maxDistance: Option[Double]): LogicalPlan = {
     validateEmbeddingColumn(corpusDf, embeddingCol)
     validateQueryVectorDimension(corpusDf, embeddingCol, queryVector.length)
 
     val elemType = getElementType(corpusDf, embeddingCol)
-    val filteredDf = corpusDf.filter(col(embeddingCol).isNotNull)
+    // Apply pre-filter before distance computation to enable reducing the 
number of rows that need distance computation.
+    var filteredDf = corpusDf.filter(col(embeddingCol).isNotNull)
+    filter.foreach(f => filteredDf = applyFilter(filteredDf, f))
 
     // Validate byte corpus query vector values before creating the UDF.
     if (elemType == ByteType) {
@@ -269,11 +312,12 @@ object BruteForceSearchAlgorithm extends 
VectorSearchAlgorithm {
     // so only the corpus column is passed per row.
     val distanceUdf = VectorDistanceUtils.createSingleQueryDistanceUdf(metric, 
elemType, queryVector)
 
-    val result = filteredDf
+    var scored = filteredDf
       .withColumn(DISTANCE_COL, distanceUdf(col(embeddingCol)))
       .drop(embeddingCol)
-      .orderBy(col(DISTANCE_COL).asc)
-      .limit(k)
+    // Apply max-distance threshold before ordering to shrink the sort input.
+    maxDistance.foreach(d => scored = scored.filter(col(DISTANCE_COL) <= d))
+    val result = scored.orderBy(col(DISTANCE_COL).asc).limit(k)
 
     result.queryExecution.analyzed
   }
@@ -285,7 +329,9 @@ object BruteForceSearchAlgorithm extends 
VectorSearchAlgorithm {
       queryDf: DataFrame,
       queryEmbeddingCol: String,
       k: Int,
-      metric: DistanceMetric.Value): LogicalPlan = {
+      metric: DistanceMetric.Value,
+      filter: Option[String],
+      maxDistance: Option[Double]): LogicalPlan = {
     validateEmbeddingColumn(corpusDf, corpusEmbeddingCol)
     validateEmbeddingColumn(queryDf, queryEmbeddingCol)
     validateElementTypeCompatibility(corpusDf, corpusEmbeddingCol, queryDf, 
queryEmbeddingCol)
@@ -293,7 +339,10 @@ object BruteForceSearchAlgorithm extends 
VectorSearchAlgorithm {
 
     val corpusElemType = getElementType(corpusDf, corpusEmbeddingCol)
     val distanceUdf = VectorDistanceUtils.createDistanceUdf(metric, 
corpusElemType)
-    val filteredCorpus = corpusDf.filter(col(corpusEmbeddingCol).isNotNull)
+    // Apply pre-filter before the cross-join to enable Hudi partition pruning 
and
+    // data skipping, reducing the cross-join cardinality.
+    var filteredCorpus = corpusDf.filter(col(corpusEmbeddingCol).isNotNull)
+    filter.foreach(f => filteredCorpus = applyFilter(filteredCorpus, f))
 
     // Prefix clashing query columns with "_hudi_query_" to avoid cross-join 
ambiguity when
     // corpus and query share column names (e.g. both have "id" or 
"embedding").
@@ -310,12 +359,15 @@ object BruteForceSearchAlgorithm extends 
VectorSearchAlgorithm {
     }: _*)
 
     // Cross join corpus with broadcast queries, compute distance, then rank
-    val scored = filteredCorpus.crossJoin(broadcast(renamedQuery))
+    var scored = filteredCorpus.crossJoin(broadcast(renamedQuery))
       .withColumn(DISTANCE_COL,
         distanceUdf(col(corpusEmbeddingCol), col(QUERY_EMB_ALIAS)))
       .drop(corpusEmbeddingCol)
       .drop(QUERY_EMB_ALIAS)
 
+    // Apply max-distance threshold before windowing to reduce shuffle volume.
+    maxDistance.foreach(d => scored = scored.filter(col(DISTANCE_COL) <= d))
+
     val window = 
Window.partitionBy(QUERY_ID_COL).orderBy(col(DISTANCE_COL).asc)
     val result = scored
       .withColumn(RANK_COL, row_number().over(window))
diff --git 
a/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/hudi/functional/TestHoodieVectorSearchFunction.scala
 
b/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/hudi/functional/TestHoodieVectorSearchFunction.scala
index 817b890c8d0e..8ae49360f2b9 100644
--- 
a/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/hudi/functional/TestHoodieVectorSearchFunction.scala
+++ 
b/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/hudi/functional/TestHoodieVectorSearchFunction.scala
@@ -500,19 +500,20 @@ class TestHoodieVectorSearchFunction extends 
HoodieSparkClientTestBase {
         s"""SELECT * FROM hudi_vector_search('$corpusViewName', 
'embedding')""".stripMargin
       ).collect()
     })
-    assertTrue(exFew.getMessage.contains("expects 4-6 arguments") ||
-      exFew.getCause.getMessage.contains("expects 4-6 arguments"))
+    assertTrue(exFew.getMessage.contains("expects 4-8 arguments") ||
+      exFew.getCause.getMessage.contains("expects 4-8 arguments"))
 
-    // Too many arguments
+    // Too many arguments — 9 is one over the new max of 8.
     val exMany = assertThrows(classOf[Exception], () => {
       spark.sql(
         s"""SELECT * FROM hudi_vector_search(
-           |  '$corpusViewName', 'embedding', ARRAY(1.0, 0.0, 0.0), 3, 
'cosine', 'brute_force', 'extra_arg'
+           |  '$corpusViewName', 'embedding', ARRAY(1.0, 0.0, 0.0), 3,
+           |  'cosine', 'brute_force', NULL, 0.5, 'extra_arg'
            |)""".stripMargin
       ).collect()
     })
     val msg = if (exMany.getCause != null) exMany.getCause.getMessage else 
exMany.getMessage
-    assertTrue(msg.contains("4-6 arguments"), s"Expected arg-count error, got: 
$msg")
+    assertTrue(msg.contains("4-8 arguments"), s"Expected arg-count error, got: 
$msg")
   }
 
   @Test
@@ -1356,4 +1357,513 @@ class TestHoodieVectorSearchFunction extends 
HoodieSparkClientTestBase {
     spark.catalog.dropTempView("zero_q_corpus")
   }
 
+  // ─── Pre-filter tests (single-query mode) ───
+
+  @Test
+  def testSingleQueryWithPreFilter(): Unit = {
+    // Corpus labels: x-axis, y-axis, z-axis, xy-diagonal, xyz-diagonal.
+    // Filter to {x-axis, xy-diagonal} then search for [1,0,0].
+    val result = spark.sql(
+      s"""
+         |SELECT id, label, _hudi_distance
+         |FROM hudi_vector_search(
+         |  '$corpusViewName',
+         |  'embedding',
+         |  ARRAY(1.0, 0.0, 0.0),
+         |  5,
+         |  'cosine',
+         |  'brute_force',
+         |  'label IN ("x-axis", "xy-diagonal")'
+         |)
+         |ORDER BY _hudi_distance
+         |""".stripMargin
+    ).collect()
+
+    assertEquals(2, result.length)
+    assertEquals("doc_1", result(0).getAs[String]("id"))
+    assertEquals(0.0, result(0).getAs[Double]("_hudi_distance"), 1e-5)
+    assertEquals("doc_4", result(1).getAs[String]("id"))
+    assertEquals(1.0 - 0.70710678, result(1).getAs[Double]("_hudi_distance"), 
1e-4)
+  }
+
+  @Test
+  def testSingleQueryWithPreFilterNarrowResult(): Unit = {
+    // Filter to only z-axis then search for [1,0,0] (orthogonal → cosine 
distance 1.0).
+    val result = spark.sql(
+      s"""
+         |SELECT id, _hudi_distance
+         |FROM hudi_vector_search(
+         |  '$corpusViewName',
+         |  'embedding',
+         |  ARRAY(1.0, 0.0, 0.0),
+         |  5,
+         |  'cosine',
+         |  'brute_force',
+         |  'label = "z-axis"'
+         |)
+         |""".stripMargin
+    ).collect()
+
+    assertEquals(1, result.length)
+    assertEquals("doc_3", result(0).getAs[String]("id"))
+    assertEquals(1.0, result(0).getAs[Double]("_hudi_distance"), 1e-5)
+  }
+
+  @Test
+  def testSingleQueryWithNullFilter(): Unit = {
+    // NULL filter behaves as no filter — all rows are considered.
+    val result = spark.sql(
+      s"""
+         |SELECT id, _hudi_distance
+         |FROM hudi_vector_search(
+         |  '$corpusViewName',
+         |  'embedding',
+         |  ARRAY(1.0, 0.0, 0.0),
+         |  3,
+         |  'cosine',
+         |  'brute_force',
+         |  NULL
+         |)
+         |ORDER BY _hudi_distance
+         |""".stripMargin
+    ).collect()
+
+    assertEquals(3, result.length)
+    assertEquals("doc_1", result(0).getAs[String]("id"))
+    assertEquals("doc_4", result(1).getAs[String]("id"))
+    assertEquals("doc_5", result(2).getAs[String]("id"))
+  }
+
+  // ─── Max-distance tests (single-query mode) ───
+
+  @Test
+  def testSingleQueryWithMaxDistance(): Unit = {
+    // For cosine distance with query [1,0,0]:
+    //   doc_1 → 0.0 (included)
+    //   doc_4 → ~0.293 (included)
+    //   doc_5 → ~0.423 (excluded by 0.3)
+    //   doc_2/doc_3 → 1.0 (excluded)
+    val result = spark.sql(
+      s"""
+         |SELECT id, _hudi_distance
+         |FROM hudi_vector_search(
+         |  '$corpusViewName',
+         |  'embedding',
+         |  ARRAY(1.0, 0.0, 0.0),
+         |  5,
+         |  'cosine',
+         |  'brute_force',
+         |  NULL,
+         |  0.3
+         |)
+         |ORDER BY _hudi_distance
+         |""".stripMargin
+    ).collect()
+
+    assertEquals(2, result.length)
+    assertEquals("doc_1", result(0).getAs[String]("id"))
+    assertEquals(0.0, result(0).getAs[Double]("_hudi_distance"), 1e-5)
+    assertEquals("doc_4", result(1).getAs[String]("id"))
+    assertEquals(1.0 - 0.70710678, result(1).getAs[Double]("_hudi_distance"), 
1e-4)
+  }
+
+  @Test
+  def testSingleQueryMaxDistanceExcludesAll(): Unit = {
+    // max_distance = 0 keeps only exact matches.
+    val result = spark.sql(
+      s"""
+         |SELECT id, _hudi_distance
+         |FROM hudi_vector_search(
+         |  '$corpusViewName',
+         |  'embedding',
+         |  ARRAY(1.0, 0.0, 0.0),
+         |  5,
+         |  'cosine',
+         |  'brute_force',
+         |  NULL,
+         |  0.0
+         |)
+         |""".stripMargin
+    ).collect()
+
+    assertEquals(1, result.length)
+    assertEquals("doc_1", result(0).getAs[String]("id"))
+    assertEquals(0.0, result(0).getAs[Double]("_hudi_distance"), 1e-5)
+  }
+
+  @Test
+  def testSingleQueryMaxDistanceWithNullThreshold(): Unit = {
+    // NULL max_distance behaves as no threshold.
+    val result = spark.sql(
+      s"""
+         |SELECT id, _hudi_distance
+         |FROM hudi_vector_search(
+         |  '$corpusViewName',
+         |  'embedding',
+         |  ARRAY(1.0, 0.0, 0.0),
+         |  3,
+         |  'cosine',
+         |  'brute_force',
+         |  NULL,
+         |  NULL
+         |)
+         |ORDER BY _hudi_distance
+         |""".stripMargin
+    ).collect()
+
+    assertEquals(3, result.length)
+    assertEquals("doc_1", result(0).getAs[String]("id"))
+  }
+
+  // ─── Combined filter + max-distance tests (single-query mode) ───
+
+  @Test
+  def testSingleQueryFilterAndMaxDistanceCombined(): Unit = {
+    // Filter excludes z-axis, max_distance = 0.5 excludes anything farther.
+    // Remaining: doc_1 (0.0), doc_4 (~0.293), doc_5 (~0.423) — doc_2 (1.0) is 
excluded.
+    val result = spark.sql(
+      s"""
+         |SELECT id, _hudi_distance
+         |FROM hudi_vector_search(
+         |  '$corpusViewName',
+         |  'embedding',
+         |  ARRAY(1.0, 0.0, 0.0),
+         |  5,
+         |  'cosine',
+         |  'brute_force',
+         |  'label != "z-axis"',
+         |  0.5
+         |)
+         |ORDER BY _hudi_distance
+         |""".stripMargin
+    ).collect()
+
+    assertEquals(3, result.length)
+    assertEquals("doc_1", result(0).getAs[String]("id"))
+    assertEquals("doc_4", result(1).getAs[String]("id"))
+    assertEquals("doc_5", result(2).getAs[String]("id"))
+    val ids = result.map(_.getAs[String]("id")).toSet
+    assertFalse(ids.contains("doc_2"), "doc_2 should be excluded by 
max_distance")
+    assertFalse(ids.contains("doc_3"), "doc_3 should be excluded by filter")
+  }
+
+  @Test
+  def testSingleQueryFilterWithL2AndMaxDistance(): Unit = {
+    // Combines the L2 metric with a LIKE pre-filter and a distance cap.
+    // The pattern '%axis%' admits doc_1, doc_2 and doc_3; their L2 distances to [1,0,0]
+    // are 0.0, sqrt(2) (~1.414) and sqrt(2) respectively, so a cap of 1.0 leaves doc_1
+    // as the only match.
+    val l2Query =
+      s"""
+         |SELECT id, _hudi_distance
+         |FROM hudi_vector_search(
+         |  '$corpusViewName',
+         |  'embedding',
+         |  ARRAY(1.0, 0.0, 0.0),
+         |  5,
+         |  'l2',
+         |  'brute_force',
+         |  'label LIKE "%axis%"',
+         |  1.0
+         |)
+         |ORDER BY _hudi_distance
+         |""".stripMargin
+    val rows = spark.sql(l2Query).collect()
+
+    assertEquals(1, rows.length)
+    val nearest = rows.head
+    assertEquals("doc_1", nearest.getAs[String]("id"))
+    assertEquals(0.0, nearest.getAs[Double]("_hudi_distance"), 1e-5)
+  }
+
+  // ─── Batch-mode filter + max-distance tests ───
+
+  @Test
+  def testBatchQueryWithPreFilter(): Unit = {
+    // Batch mode with only a pre-filter: the IN predicate restricts the corpus to the
+    // x-axis and y-axis documents before each query vector is ranked.
+    createFloatQueryView("batch_filter_queries", "qid", "qvec", Seq(
+      ("q_x", Seq(1.0f, 0.0f, 0.0f)),
+      ("q_y", Seq(0.0f, 1.0f, 0.0f))
+    ))
+
+    // Drop the temp view in `finally` so a failing assertion does not leave it
+    // registered in the session for subsequent tests.
+    try {
+      val result = spark.sql(
+        s"""
+           |SELECT *
+           |FROM hudi_vector_search_batch(
+           |  '$corpusViewName',
+           |  'embedding',
+           |  'batch_filter_queries',
+           |  'qvec',
+           |  5,
+           |  'cosine',
+           |  'brute_force',
+           |  'label IN ("x-axis", "y-axis")'
+           |)
+           |""".stripMargin
+      ).collect()
+
+      // Filter narrows corpus to {doc_1, doc_2}, 2 queries → at most 2 results per query → 4 rows.
+      assertEquals(4, result.length)
+      val ids = result.map(_.getAs[String]("id")).toSet
+      assertEquals(Set("doc_1", "doc_2"), ids)
+    } finally {
+      spark.catalog.dropTempView("batch_filter_queries")
+    }
+  }
+
+  @Test
+  def testBatchQueryWithMaxDistance(): Unit = {
+    // Batch mode with only a distance cap (NULL filter): every returned row, for every
+    // query, must lie within max_distance = 0.5.
+    createFloatQueryView("batch_maxd_queries", "qid", "qvec", Seq(
+      ("q_x", Seq(1.0f, 0.0f, 0.0f)),
+      ("q_z", Seq(0.0f, 0.0f, 1.0f))
+    ))
+
+    // Drop the temp view in `finally` so a failing assertion does not leave it
+    // registered in the session for subsequent tests.
+    try {
+      val result = spark.sql(
+        s"""
+           |SELECT id, _hudi_distance, _hudi_query_index
+           |FROM hudi_vector_search_batch(
+           |  '$corpusViewName',
+           |  'embedding',
+           |  'batch_maxd_queries',
+           |  'qvec',
+           |  5,
+           |  'cosine',
+           |  'brute_force',
+           |  NULL,
+           |  0.5
+           |)
+           |ORDER BY _hudi_query_index, _hudi_distance
+           |""".stripMargin
+      ).collect()
+
+      // q_x (cosine to [1,0,0]) with max 0.5 keeps doc_1(0.0), doc_4(~0.293), doc_5(~0.423) → 3.
+      // q_z (cosine to [0,0,1]) with max 0.5 keeps doc_3(0.0), doc_5(~0.423) → 2.
+      assertEquals(5, result.length)
+      result.foreach { row =>
+        val d = row.getAs[Double]("_hudi_distance")
+        assertTrue(d <= 0.5 + 1e-5,
+          s"Distance $d exceeds max_distance 0.5")
+      }
+    } finally {
+      spark.catalog.dropTempView("batch_maxd_queries")
+    }
+  }
+
+  @Test
+  def testBatchQueryFilterAndMaxDistanceCombined(): Unit = {
+    // Batch mode with both a pre-filter and a distance cap applied to a single query.
+    createFloatQueryView("batch_combo_queries", "qid", "qvec", Seq(
+      ("q_x", Seq(1.0f, 0.0f, 0.0f))
+    ))
+
+    // Drop the temp view in `finally` so a failing assertion does not leave it
+    // registered in the session for subsequent tests.
+    try {
+      val result = spark.sql(
+        s"""
+           |SELECT id, _hudi_distance
+           |FROM hudi_vector_search_batch(
+           |  '$corpusViewName',
+           |  'embedding',
+           |  'batch_combo_queries',
+           |  'qvec',
+           |  5,
+           |  'cosine',
+           |  'brute_force',
+           |  'label != "z-axis"',
+           |  0.3
+           |)
+           |ORDER BY _hudi_distance
+           |""".stripMargin
+      ).collect()
+
+      // Filter (no z-axis) → {doc_1, doc_2, doc_4, doc_5}; cosine to [1,0,0]:
+      //   doc_1=0.0, doc_4~=0.293, doc_5~=0.423, doc_2=1.0.
+      // max 0.3 keeps doc_1 and doc_4.
+      assertEquals(2, result.length)
+      assertEquals("doc_1", result(0).getAs[String]("id"))
+      assertEquals("doc_4", result(1).getAs[String]("id"))
+    } finally {
+      spark.catalog.dropTempView("batch_combo_queries")
+    }
+  }
+
+  // ─── Failure-mode tests for filter + max_distance ───
+
+  @Test
+  def testInvalidFilterSyntaxIsWrapped(): Unit = {
+    // Garbage predicate text should surface as a HoodieAnalysisException whose
+    // message includes the offending expression.
+    val ex = assertThrows(classOf[Exception], () => {
+      spark.sql(
+        s"""
+           |SELECT * FROM hudi_vector_search(
+           |  '$corpusViewName', 'embedding', ARRAY(1.0, 0.0, 0.0), 5,
+           |  'cosine', 'brute_force',
+           |  'this is not valid SQL !!'
+           |)
+           |""".stripMargin
+      ).collect()
+    })
+    // Gather every non-null message along the cause chain instead of assuming the wrapped
+    // error sits at one fixed depth; getMessage may be null, which would otherwise NPE.
+    val messages = Iterator.iterate[Throwable](ex)(_.getCause)
+      .takeWhile(_ != null)
+      .flatMap(t => Option(t.getMessage))
+      .toList
+    val allMessages = messages.mkString(" | ")
+    val wrapped = messages.find(_.contains("Invalid pre-filter expression"))
+    assertTrue(wrapped.isDefined,
+      s"Expected wrapped pre-filter error, got: $allMessages")
+    // The same wrapped message (not merely some cause) must echo the user's expression.
+    assertTrue(wrapped.exists(_.contains("this is not valid SQL")),
+      s"Expected error message to echo the offending expression, got: $allMessages")
+  }
+
+  @Test
+  def testFilterReferencingUnknownColumnIsWrapped(): Unit = {
+    // Unknown column in the filter should also be surfaced as a wrapped error.
+    val ex = assertThrows(classOf[Exception], () => {
+      spark.sql(
+        s"""
+           |SELECT * FROM hudi_vector_search(
+           |  '$corpusViewName', 'embedding', ARRAY(1.0, 0.0, 0.0), 5,
+           |  'cosine', 'brute_force',
+           |  'no_such_column = "x"'
+           |)
+           |""".stripMargin
+      ).collect()
+    })
+    // Gather every non-null message along the cause chain instead of assuming the wrapped
+    // error sits at one fixed depth; getMessage may be null, which would otherwise NPE.
+    val messages = Iterator.iterate[Throwable](ex)(_.getCause)
+      .takeWhile(_ != null)
+      .flatMap(t => Option(t.getMessage))
+      .toList
+    val allMessages = messages.mkString(" | ")
+    val wrapped = messages.find(_.contains("Invalid pre-filter expression"))
+    assertTrue(wrapped.isDefined,
+      s"Expected wrapped pre-filter error, got: $allMessages")
+    // The same wrapped message (not merely some cause) must echo the user's expression.
+    assertTrue(wrapped.exists(_.contains("no_such_column")),
+      s"Expected error message to echo the offending expression, got: $allMessages")
+  }
+
+  @Test
+  def testFilterMustBeStringOrNull(): Unit = {
+    // Passing an integer literal where a string filter is expected must throw.
+    val ex = assertThrows(classOf[Exception], () => {
+      spark.sql(
+        s"""
+           |SELECT * FROM hudi_vector_search(
+           |  '$corpusViewName', 'embedding', ARRAY(1.0, 0.0, 0.0), 5,
+           |  'cosine', 'brute_force',
+           |  42
+           |)
+           |""".stripMargin
+      ).collect()
+    })
+    // Search the whole cause chain (skipping null messages) rather than one fixed level.
+    val messages = Iterator.iterate[Throwable](ex)(_.getCause)
+      .takeWhile(_ != null)
+      .flatMap(t => Option(t.getMessage))
+      .toList
+    val allMessages = messages.mkString(" | ")
+    assertTrue(messages.exists(_.contains("must be a string literal or NULL")),
+      s"Expected filter-type error, got: $allMessages")
+  }
+
+  @Test
+  def testMaxDistanceMustBeNumericOrNull(): Unit = {
+    // Passing a string literal where a numeric max_distance is expected must throw.
+    val ex = assertThrows(classOf[Exception], () => {
+      spark.sql(
+        s"""
+           |SELECT * FROM hudi_vector_search(
+           |  '$corpusViewName', 'embedding', ARRAY(1.0, 0.0, 0.0), 5,
+           |  'cosine', 'brute_force',
+           |  NULL,
+           |  'not_a_number'
+           |)
+           |""".stripMargin
+      ).collect()
+    })
+    // Search the whole cause chain (skipping null messages) rather than one fixed level.
+    val messages = Iterator.iterate[Throwable](ex)(_.getCause)
+      .takeWhile(_ != null)
+      .flatMap(t => Option(t.getMessage))
+      .toList
+    val allMessages = messages.mkString(" | ")
+    assertTrue(messages.exists(_.contains("must be a numeric literal or NULL")),
+      s"Expected max_distance-type error, got: $allMessages")
+  }
+
+  @Test
+  def testMaxDistanceAcceptsNonDoubleNumericLiterals(): Unit = {
+    // parseOptionalDouble widens any NumericType literal (Int, Long, Short, Byte, Decimal, Float)
+    // to Double. Exercise that contract with Int literals and a Decimal literal.
+    //
+    // Int literal under cosine: max_distance = 1 → widens to 1.0. The threshold is inclusive (<=),
+    // so all five corpus rows are kept: doc_1 (0.0), doc_4 (~0.293), doc_5 (~0.423), doc_2 (1.0),
+    // doc_3 (1.0). Every cosine distance in this corpus is <= 1.0, so this case only shows that
+    // an Int literal is accepted and that the bound is inclusive. It cannot tell a widened
+    // threshold apart from an ignored one; the L2 case below does.
+    val intResult = spark.sql(
+      s"""
+         |SELECT id, _hudi_distance
+         |FROM hudi_vector_search(
+         |  '$corpusViewName',
+         |  'embedding',
+         |  ARRAY(1.0, 0.0, 0.0),
+         |  5,
+         |  'cosine',
+         |  'brute_force',
+         |  NULL,
+         |  1
+         |)
+         |ORDER BY _hudi_distance
+         |""".stripMargin
+    ).collect()
+
+    assertEquals(5, intResult.length)
+    val intIds = intResult.map(_.getAs[String]("id")).toSet
+    assertTrue(intIds.contains("doc_1"))
+    assertTrue(intIds.contains("doc_4"))
+    assertTrue(intIds.contains("doc_5"))
+    assertTrue(intIds.contains("doc_2"),
+      "doc_2 at distance 1.0 should be included by inclusive (<=) threshold")
+    assertTrue(intIds.contains("doc_3"),
+      "doc_3 at distance 1.0 should be included by inclusive (<=) threshold")
+
+    // Int literal under L2: max_distance = 1 → 1.0. doc_2 and doc_3 sit at sqrt(2) (~1.414) from
+    // [1,0,0], so a correctly widened threshold must drop them while keeping doc_1 (0.0).
+    // If the Int literal were ignored, k = 5 would let doc_2 and doc_3 through.
+    val intL2Result = spark.sql(
+      s"""
+         |SELECT id, _hudi_distance
+         |FROM hudi_vector_search(
+         |  '$corpusViewName',
+         |  'embedding',
+         |  ARRAY(1.0, 0.0, 0.0),
+         |  5,
+         |  'l2',
+         |  'brute_force',
+         |  NULL,
+         |  1
+         |)
+         |ORDER BY _hudi_distance
+         |""".stripMargin
+    ).collect()
+
+    val intL2Ids = intL2Result.map(_.getAs[String]("id")).toSet
+    assertTrue(intL2Ids.contains("doc_1"), "doc_1 at L2 distance 0.0 should be kept")
+    assertFalse(intL2Ids.contains("doc_2"), "doc_2 at L2 distance ~1.414 should be excluded")
+    assertFalse(intL2Ids.contains("doc_3"), "doc_3 at L2 distance ~1.414 should be excluded")
+    intL2Result.foreach { row =>
+      val d = row.getAs[Double]("_hudi_distance")
+      assertTrue(d <= 1.0 + 1e-5, s"Distance $d exceeds max_distance 1")
+    }
+
+    // Decimal literal: CAST(0.3 AS DECIMAL(10,2)) → keeps doc_1 (0.0) and doc_4 (~0.293).
+    val decResult = spark.sql(
+      s"""
+         |SELECT id, _hudi_distance
+         |FROM hudi_vector_search(
+         |  '$corpusViewName',
+         |  'embedding',
+         |  ARRAY(1.0, 0.0, 0.0),
+         |  5,
+         |  'cosine',
+         |  'brute_force',
+         |  NULL,
+         |  CAST(0.3 AS DECIMAL(10,2))
+         |)
+         |ORDER BY _hudi_distance
+         |""".stripMargin
+    ).collect()
+
+    assertEquals(2, decResult.length)
+    assertEquals("doc_1", decResult(0).getAs[String]("id"))
+    assertEquals("doc_4", decResult(1).getAs[String]("id"))
+  }
+
+  @Test
+  def testNegativeMaxDistanceExcludesAll(): Unit = {
+    // Cosine and L2 distances are never negative, and a row is kept only when its
+    // distance is <= the threshold, so a threshold below zero matches nothing.
+    // The expected outcome is an empty result, not an error.
+    val negativeThresholdQuery =
+      s"""
+         |SELECT id FROM hudi_vector_search(
+         |  '$corpusViewName', 'embedding', ARRAY(1.0, 0.0, 0.0), 5,
+         |  'cosine', 'brute_force',
+         |  NULL,
+         |  -0.5
+         |)
+         |""".stripMargin
+    val rows = spark.sql(negativeThresholdQuery).collect()
+
+    assertEquals(0, rows.size)
+  }
+
+  @Test
+  def testBatchQueryInvalidFilterIsWrapped(): Unit = {
+    // Batch mode should also wrap predicate parse / analysis errors.
+    createFloatQueryView("batch_invalid_filter_queries", "qid", "qvec", Seq(
+      ("q_x", Seq(1.0f, 0.0f, 0.0f))
+    ))
+
+    // Drop the temp view in `finally` so a failing assertion does not leave it
+    // registered in the session for subsequent tests.
+    try {
+      val ex = assertThrows(classOf[Exception], () => {
+        spark.sql(
+          s"""
+             |SELECT * FROM hudi_vector_search_batch(
+             |  '$corpusViewName', 'embedding',
+             |  'batch_invalid_filter_queries', 'qvec',
+             |  5, 'cosine', 'brute_force',
+             |  'no_such_column = "x"'
+             |)
+             |""".stripMargin
+        ).collect()
+      })
+      // Search the whole cause chain (skipping null messages) rather than one fixed level.
+      val messages = Iterator.iterate[Throwable](ex)(_.getCause)
+        .takeWhile(_ != null)
+        .flatMap(t => Option(t.getMessage))
+        .toList
+      val allMessages = messages.mkString(" | ")
+      assertTrue(messages.exists(_.contains("Invalid pre-filter expression")),
+        s"Expected wrapped pre-filter error, got: $allMessages")
+    } finally {
+      spark.catalog.dropTempView("batch_invalid_filter_queries")
+    }
+  }
+
 }


Reply via email to