This is an automated email from the ASF dual-hosted git repository.
voonhous pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/hudi.git
The following commit(s) were added to refs/heads/master by this push:
new 853cbef630b3 feat(vector_search): Implement functionality for
pre-filters and maxD… (#18797)
853cbef630b3 is described below
commit 853cbef630b3ae93fdf385acd57a9f6437aa1851
Author: Rahil C <[email protected]>
AuthorDate: Wed May 27 22:31:30 2026 -0700
feat(vector_search): Implement functionality for pre-filters and maxD…
(#18797)
* feat(vector_search): Implement functionality for pre-filters and
maxDistance threshold parameters
* feat(vector_search): address review feedback on prefilter + maxDistance
- Narrow applyFilter to catch only ParseException + AnalysisException so
unrelated runtime errors are not misreported as filter problems.
- Reject non-numeric literals (including strings whose contents parse as
numbers) for max_distance via an explicit NumericType guard in
parseOptionalDouble; the previous toString.toDouble path silently
accepted '0.5' as a string.
- Document on both parseArgs Javadocs that NULL, empty string, and
whitespace-only strings all mean "no filter," matching what the
parseOptionalString helper already implements.
- Add failure-mode tests for invalid filter syntax, unknown filter
column, non-string filter literal, non-numeric max_distance literal,
integer max_distance widening, negative max_distance, empty/whitespace
filter equivalence with no-filter, and batch-mode error wrapping.
Co-Authored-By: Claude Opus 4.7 (1M context) <[email protected]>
* minor fixes
* feat(vector_search): address voonhous review comments
- Fix stray asterisk in BruteForceSearchAlgorithm Javadoc
- Drop redundant default parameter values on BruteForceSearchAlgorithm
overrides (inherited from the trait)
- Clarify in buildBatchQueryPlan Javadoc that the filter applies to the
corpus only, not the query side
- Add test exercising parseOptionalDouble with non-Double numeric
literals (Int and Decimal) to lock in the widening contract
* test(vector_search): fix testMaxDistanceAcceptsNonDoubleNumericLiterals
for inclusive threshold
The Int-literal sub-case expected 3 rows with max_distance=1, on the
incorrect assumption that the threshold comparison is strict
(distance < max_distance). The actual
implementation uses <= (consistent with
testSingleQueryMaxDistanceExcludesAll
which depends on doc_1 at distance 0.0 being kept when max_distance=0.0).
With <= and max_distance=1.0, doc_2 and doc_3 (both at distance 1.0) are
included, returning all 5 corpus rows. Updated the expected count and
flipped the negative doc_2 assertion to positive doc_2 + doc_3 checks
with messages naming the inclusive-boundary contract, so the test now
also acts as a regression guard for that semantics.
The Decimal sub-case was unaffected.
Co-Authored-By: Claude Opus 4.7 <[email protected]>
* fix issue
* refactor(vector_search): use direct pattern match for numeric eval result
Replace the toString.toDouble round-trip in parseOptionalDouble with a
direct pattern match on Number / Decimal, avoiding the unnecessary string
conversion (especially wasteful for Decimal literals).
Co-Authored-By: Claude Opus 4.7 (1M context) <[email protected]>
---------
Co-authored-by: Claude Opus 4.7 (1M context) <[email protected]>
---
.../HoodieVectorSearchTableValuedFunction.scala | 90 +++-
.../hudi/analysis/HoodieSparkBaseAnalysis.scala | 6 +-
.../analysis/HoodieVectorSearchPlanBuilder.scala | 74 ++-
.../TestHoodieVectorSearchFunction.scala | 520 ++++++++++++++++++++-
4 files changed, 657 insertions(+), 33 deletions(-)
diff --git
a/hudi-spark-datasource/hudi-spark-common/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/HoodieVectorSearchTableValuedFunction.scala
b/hudi-spark-datasource/hudi-spark-common/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/HoodieVectorSearchTableValuedFunction.scala
index 39102cddce34..30603853351c 100644
---
a/hudi-spark-datasource/hudi-spark-common/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/HoodieVectorSearchTableValuedFunction.scala
+++
b/hudi-spark-datasource/hudi-spark-common/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/HoodieVectorSearchTableValuedFunction.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.plans.logical
import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression,
Literal}
import org.apache.spark.sql.hudi.command.exception.HoodieAnalysisException
-import org.apache.spark.sql.types.StringType
+import org.apache.spark.sql.types.{Decimal, NullType, NumericType, StringType}
object HoodieVectorSearchTableValuedFunction {
@@ -53,21 +53,28 @@ object HoodieVectorSearchTableValuedFunction {
queryVectorExpr: Expression,
k: Int,
metric: DistanceMetric.Value,
- algorithm: SearchAlgorithm.Value
+ algorithm: SearchAlgorithm.Value,
+ filter: Option[String],
+ maxDistance: Option[Double]
)
/**
* Parse arguments for the hudi_vector_search TVF (single-query mode).
*
- * Signature (4–6 args):
- * hudi_vector_search('table', 'embedding_col', ARRAY(1.0, 2.0, ...), k [,
'metric'] [, 'algorithm'])
+ * Signature (4–8 args):
+ * hudi_vector_search('table', 'embedding_col', ARRAY(1.0, 2.0, ...), k
+ * [, 'metric'] [, 'algorithm'] [, 'filter_predicate' | NULL] [,
max_distance | NULL])
* metric defaults to 'cosine'; algorithm defaults to 'brute_force'.
+ * filter is a SQL predicate applied to the corpus before distance
computation;
+ * NULL, the empty string, and whitespace-only strings all mean "no
filter."
+ * max_distance excludes results whose distance exceeds the given
threshold;
+ * NULL means "no threshold." Must be a numeric literal when specified.
*/
def parseArgs(exprs: Seq[Expression]): ParsedArgs = {
- if (exprs.size < 4 || exprs.size > 6) {
+ if (exprs.size < 4 || exprs.size > 8) {
throw new HoodieAnalysisException(
- s"Function '$FUNC_NAME' expects 4-6 arguments: " +
- "(table, embedding_col, query_vector, k [, metric] [, algorithm]).")
+ s"Function '$FUNC_NAME' expects 4-8 arguments: " +
+ "(table, embedding_col, query_vector, k [, metric] [, algorithm] [,
filter] [, max_distance]).")
}
def requireStringLiteral(expr: Expression, argName: String): String = expr
match {
@@ -84,7 +91,9 @@ object HoodieVectorSearchTableValuedFunction {
else DistanceMetric.COSINE
val algorithm = if (exprs.size >= 6)
SearchAlgorithm.fromString(requireStringLiteral(exprs(5), "algorithm"))
else SearchAlgorithm.BRUTE_FORCE
- ParsedArgs(table, embeddingCol, queryVectorExpr, k, metric, algorithm)
+ val filter = if (exprs.size >= 7) parseOptionalString(FUNC_NAME, exprs(6),
"filter") else None
+ val maxDistance = if (exprs.size >= 8) parseOptionalDouble(FUNC_NAME,
exprs(7), "max_distance") else None
+ ParsedArgs(table, embeddingCol, queryVectorExpr, k, metric, algorithm,
filter, maxDistance)
}
private[logical] def parseK(funcName: String, expr: Expression): Int = {
@@ -106,6 +115,43 @@ object HoodieVectorSearchTableValuedFunction {
}
kValue
}
+
+ /** Parses a string argument that may be NULL (meaning "not specified"). */
+ private[logical] def parseOptionalString(
+ funcName: String, expr: Expression, argName: String): Option[String] =
expr match {
+ case Literal(null, _) => None
+ case Literal(v, StringType) if v != null =>
+ val s = v.toString.trim
+ if (s.isEmpty) None else Some(s)
+ case _ => throw new HoodieAnalysisException(
+ s"Function '$funcName': argument '$argName' must be a string literal or
NULL, got: ${expr.sql}")
+ }
+
+ /**
+ * Parses a numeric argument that may be NULL (meaning "not specified").
Accepts
+ * any foldable expression of [[NumericType]] (or [[NullType]] for an untyped
+ * NULL keyword) — including a bare literal or {@code CAST(literal AS
numeric)} —
+ * and widens to Double. String literals are rejected even when their
contents
+ * happen to parse as a number, so the type contract surfaces at parse time.
+ */
+ private[logical] def parseOptionalDouble(
+ funcName: String, expr: Expression, argName: String): Option[Double] = {
+ val numericOrNull = expr.dataType match {
+ case _: NumericType | NullType => true
+ case _ => false
+ }
+ if (!expr.foldable || !numericOrNull) {
+ throw new HoodieAnalysisException(
+ s"Function '$funcName': argument '$argName' must be a numeric literal
or NULL, got: ${expr.sql}")
+ }
+ Option(expr.eval()).map {
+ case d: Decimal => d.toDouble
+ case n: Number => n.doubleValue()
+ case other => throw new HoodieAnalysisException(
+ s"Function '$funcName': argument '$argName' has unexpected runtime
type: " +
+ s"${other.getClass.getName}")
+ }
+ }
}
/**
@@ -142,21 +188,28 @@ object HoodieVectorSearchBatchTableValuedFunction {
queryEmbeddingCol: String,
k: Int,
metric: HoodieVectorSearchTableValuedFunction.DistanceMetric.Value,
- algorithm: HoodieVectorSearchTableValuedFunction.SearchAlgorithm.Value
+ algorithm: HoodieVectorSearchTableValuedFunction.SearchAlgorithm.Value,
+ filter: Option[String],
+ maxDistance: Option[Double]
)
/**
* Parse arguments for the hudi_vector_search_batch TVF (batch-query mode).
*
- * Signature (5–7 args):
- * hudi_vector_search_batch('corpus_table', 'corpus_col', 'query_table',
'query_col', k [, 'metric'] [, 'algorithm'])
+ * Signature (5–9 args):
+ * hudi_vector_search_batch('corpus_table', 'corpus_col', 'query_table',
'query_col', k
+ * [, 'metric'] [, 'algorithm'] [, 'filter_predicate' | NULL] [,
max_distance | NULL])
* metric defaults to 'cosine'; algorithm defaults to 'brute_force'.
+ * filter is a SQL predicate applied to the corpus before distance
computation;
+ * NULL, the empty string, and whitespace-only strings all mean "no
filter."
+ * max_distance excludes results whose distance exceeds the given
threshold;
+ * NULL means "no threshold." Must be a numeric literal when specified.
*/
def parseArgs(exprs: Seq[Expression]): ParsedArgs = {
- if (exprs.size < 5 || exprs.size > 7) {
+ if (exprs.size < 5 || exprs.size > 9) {
throw new HoodieAnalysisException(
- s"Function '$FUNC_NAME' expects 5-7 arguments: " +
- "(corpus_table, corpus_col, query_table, query_col, k [, metric] [,
algorithm]).")
+ s"Function '$FUNC_NAME' expects 5-9 arguments: " +
+ "(corpus_table, corpus_col, query_table, query_col, k [, metric] [,
algorithm] [, filter] [, max_distance]).")
}
def requireStringLiteral(expr: Expression, argName: String): String = expr
match {
@@ -176,7 +229,14 @@ object HoodieVectorSearchBatchTableValuedFunction {
val algorithm = if (exprs.size >= 7)
HoodieVectorSearchTableValuedFunction.SearchAlgorithm.fromString(requireStringLiteral(exprs(6),
"algorithm"))
else HoodieVectorSearchTableValuedFunction.SearchAlgorithm.BRUTE_FORCE
- ParsedArgs(corpusTable, corpusEmbeddingCol, queryTable, queryEmbeddingCol,
k, metric, algorithm)
+ val filter = if (exprs.size >= 8)
+ HoodieVectorSearchTableValuedFunction.parseOptionalString(FUNC_NAME,
exprs(7), "filter")
+ else None
+ val maxDistance = if (exprs.size >= 9)
+ HoodieVectorSearchTableValuedFunction.parseOptionalDouble(FUNC_NAME,
exprs(8), "max_distance")
+ else None
+ ParsedArgs(corpusTable, corpusEmbeddingCol, queryTable, queryEmbeddingCol,
+ k, metric, algorithm, filter, maxDistance)
}
}
diff --git
a/hudi-spark-datasource/hudi-spark-common/src/main/scala/org/apache/spark/sql/hudi/analysis/HoodieSparkBaseAnalysis.scala
b/hudi-spark-datasource/hudi-spark-common/src/main/scala/org/apache/spark/sql/hudi/analysis/HoodieSparkBaseAnalysis.scala
index 88ad684c1e2f..82ac87b7bcbc 100644
---
a/hudi-spark-datasource/hudi-spark-common/src/main/scala/org/apache/spark/sql/hudi/analysis/HoodieSparkBaseAnalysis.scala
+++
b/hudi-spark-datasource/hudi-spark-common/src/main/scala/org/apache/spark/sql/hudi/analysis/HoodieSparkBaseAnalysis.scala
@@ -156,7 +156,8 @@ case class ResolveReferences(spark: SparkSession) extends
Rule[LogicalPlan]
val searchAlgorithm =
HoodieVectorSearchPlanBuilder.resolveAlgorithm(a.algorithm)
val corpusDf = resolveTableToDf(a.table,
HoodieVectorSearchTableValuedFunction.FUNC_NAME)
val queryVector = evaluateQueryVector(a.queryVectorExpr)
- searchAlgorithm.buildSingleQueryPlan(spark, corpusDf, a.embeddingCol,
queryVector, a.k, a.metric)
+ searchAlgorithm.buildSingleQueryPlan(
+ spark, corpusDf, a.embeddingCol, queryVector, a.k, a.metric, a.filter,
a.maxDistance)
case HoodieVectorSearchBatchTableValuedFunction(args) =>
val a = HoodieVectorSearchBatchTableValuedFunction.parseArgs(args)
@@ -164,7 +165,8 @@ case class ResolveReferences(spark: SparkSession) extends
Rule[LogicalPlan]
val corpusDf = resolveTableToDf(a.corpusTable,
HoodieVectorSearchBatchTableValuedFunction.FUNC_NAME)
val queryDf = resolveTableToDf(a.queryTable,
HoodieVectorSearchBatchTableValuedFunction.FUNC_NAME)
searchAlgorithm.buildBatchQueryPlan(
- spark, corpusDf, a.corpusEmbeddingCol, queryDf, a.queryEmbeddingCol,
a.k, a.metric)
+ spark, corpusDf, a.corpusEmbeddingCol, queryDf, a.queryEmbeddingCol,
+ a.k, a.metric, a.filter, a.maxDistance)
case mO@MatchMergeIntoTable(targetTableO, sourceTableO, _)
// START: custom Hudi change: don't want to go to the spark mit
resolution so we resolve the source and target
diff --git
a/hudi-spark-datasource/hudi-spark-common/src/main/scala/org/apache/spark/sql/hudi/analysis/HoodieVectorSearchPlanBuilder.scala
b/hudi-spark-datasource/hudi-spark-common/src/main/scala/org/apache/spark/sql/hudi/analysis/HoodieVectorSearchPlanBuilder.scala
index e99079b4966f..9459ff0c0624 100644
---
a/hudi-spark-datasource/hudi-spark-common/src/main/scala/org/apache/spark/sql/hudi/analysis/HoodieVectorSearchPlanBuilder.scala
+++
b/hudi-spark-datasource/hudi-spark-common/src/main/scala/org/apache/spark/sql/hudi/analysis/HoodieVectorSearchPlanBuilder.scala
@@ -19,7 +19,8 @@ package org.apache.spark.sql.hudi.analysis
import org.apache.hudi.common.schema.HoodieSchema
-import org.apache.spark.sql.{DataFrame, SparkSession}
+import org.apache.spark.sql.{AnalysisException, DataFrame, SparkSession}
+import org.apache.spark.sql.catalyst.parser.ParseException
import
org.apache.spark.sql.catalyst.plans.logical.HoodieVectorSearchTableValuedFunction.{DistanceMetric,
SearchAlgorithm}
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.expressions.Window
@@ -64,6 +65,12 @@ trait VectorSearchAlgorithm {
* @param queryVector the query vector, normalized to Array[Double]
* @param k number of nearest neighbors to return
* @param metric distance metric (COSINE, L2, DOT_PRODUCT)
+ * @param filter optional SQL predicate applied to the corpus before
distance computation.
+ * String literals inside the predicate must use
double quotes (e.g. {@code label = "z-axis"})
+ * because the filter is passed as a single-quoted TVF
argument and
+ * Spark's lexer does not support {@code ''} as an
escape sequence
+ * inside string literals.
+ * @param maxDistance optional distance threshold; results exceeding this
value are excluded
* @return an analyzed LogicalPlan whose output matches the single-query
schema contract
*/
def buildSingleQueryPlan(
@@ -72,7 +79,9 @@ trait VectorSearchAlgorithm {
embeddingCol: String,
queryVector: Array[Double],
k: Int,
- metric: DistanceMetric.Value): LogicalPlan
+ metric: DistanceMetric.Value,
+ filter: Option[String] = None,
+ maxDistance: Option[Double] = None): LogicalPlan
/**
* Build a plan that finds the k nearest corpus rows for each row in the
query table.
@@ -84,6 +93,13 @@ trait VectorSearchAlgorithm {
* @param queryEmbeddingCol name of the embedding column in queryDf
* @param k number of nearest neighbors per query
* @param metric distance metric (COSINE, L2, DOT_PRODUCT)
+ * @param filter optional SQL predicate applied to the corpus
before distance
+ * computation, which shrinks the cross-join
cardinality. Applied to the
+ * corpus only — to restrict the query side, apply
a projection or
+ * filter on the query table before invoking the
TVF. See
+ * [[buildSingleQueryPlan]] for quoting
requirements.
+ * @param maxDistance optional distance threshold; results exceeding
this value are
+ * excluded before per-query top-K selection,
reducing shuffle volume.
* @return an analyzed LogicalPlan whose output matches the batch-query
schema contract
* @note Batch mode broadcasts the query table to all executors via a
cross-join.
* This is designed for small-to-medium query sets (tens to low
hundreds of rows).
@@ -96,7 +112,9 @@ trait VectorSearchAlgorithm {
queryDf: DataFrame,
queryEmbeddingCol: String,
k: Int,
- metric: DistanceMetric.Value): LogicalPlan
+ metric: DistanceMetric.Value,
+ filter: Option[String] = None,
+ maxDistance: Option[Double] = None): LogicalPlan
}
/**
@@ -232,6 +250,10 @@ object HoodieVectorSearchPlanBuilder {
* and select top-K per query. The cross-join produces O(|corpus| * |queries|)
* intermediate rows, so this is suitable for small-to-medium query sets
* (tens to low hundreds of queries) against moderate corpora.
+ *
+ * <p>Both modes support an optional {@code filter} predicate (applied to the
corpus before
+ * distance computation), and an optional {@code maxDistance} threshold
(results beyond this
+ * distance are excluded before top-K selection, reducing shuffle and sort
volume).
*/
object BruteForceSearchAlgorithm extends VectorSearchAlgorithm {
@@ -239,18 +261,39 @@ object BruteForceSearchAlgorithm extends
VectorSearchAlgorithm {
override val name: String = "brute_force"
+ /**
+ * Applies a user-supplied SQL predicate to the corpus DataFrame, wrapping
+ * [[ParseException]] (predicate syntax error) and [[AnalysisException]]
+ * (unknown column, type mismatch, etc.) in a [[HoodieAnalysisException]]
that
+ * echoes the offending expression. Other exception types are allowed to
+ * propagate untouched so they aren't misreported as a filter problem.
+ */
+ private def applyFilter(df: DataFrame, filterExpr: String): DataFrame = {
+ try {
+ df.filter(filterExpr)
+ } catch {
+ case e @ (_: ParseException | _: AnalysisException) =>
+ throw new HoodieAnalysisException(
+ s"Invalid pre-filter expression '$filterExpr': ${e.getMessage}")
+ }
+ }
+
override def buildSingleQueryPlan(
spark: SparkSession,
corpusDf: DataFrame,
embeddingCol: String,
queryVector: Array[Double],
k: Int,
- metric: DistanceMetric.Value): LogicalPlan = {
+ metric: DistanceMetric.Value,
+ filter: Option[String],
+ maxDistance: Option[Double]): LogicalPlan = {
validateEmbeddingColumn(corpusDf, embeddingCol)
validateQueryVectorDimension(corpusDf, embeddingCol, queryVector.length)
val elemType = getElementType(corpusDf, embeddingCol)
- val filteredDf = corpusDf.filter(col(embeddingCol).isNotNull)
+ // Apply pre-filter before distance computation to enable reducing the
number of rows that need distance computation.
+ var filteredDf = corpusDf.filter(col(embeddingCol).isNotNull)
+ filter.foreach(f => filteredDf = applyFilter(filteredDf, f))
// Validate byte corpus query vector values before creating the UDF.
if (elemType == ByteType) {
@@ -269,11 +312,12 @@ object BruteForceSearchAlgorithm extends
VectorSearchAlgorithm {
// so only the corpus column is passed per row.
val distanceUdf = VectorDistanceUtils.createSingleQueryDistanceUdf(metric,
elemType, queryVector)
- val result = filteredDf
+ var scored = filteredDf
.withColumn(DISTANCE_COL, distanceUdf(col(embeddingCol)))
.drop(embeddingCol)
- .orderBy(col(DISTANCE_COL).asc)
- .limit(k)
+ // Apply max-distance threshold before ordering to shrink the sort input.
+ maxDistance.foreach(d => scored = scored.filter(col(DISTANCE_COL) <= d))
+ val result = scored.orderBy(col(DISTANCE_COL).asc).limit(k)
result.queryExecution.analyzed
}
@@ -285,7 +329,9 @@ object BruteForceSearchAlgorithm extends
VectorSearchAlgorithm {
queryDf: DataFrame,
queryEmbeddingCol: String,
k: Int,
- metric: DistanceMetric.Value): LogicalPlan = {
+ metric: DistanceMetric.Value,
+ filter: Option[String],
+ maxDistance: Option[Double]): LogicalPlan = {
validateEmbeddingColumn(corpusDf, corpusEmbeddingCol)
validateEmbeddingColumn(queryDf, queryEmbeddingCol)
validateElementTypeCompatibility(corpusDf, corpusEmbeddingCol, queryDf,
queryEmbeddingCol)
@@ -293,7 +339,10 @@ object BruteForceSearchAlgorithm extends
VectorSearchAlgorithm {
val corpusElemType = getElementType(corpusDf, corpusEmbeddingCol)
val distanceUdf = VectorDistanceUtils.createDistanceUdf(metric,
corpusElemType)
- val filteredCorpus = corpusDf.filter(col(corpusEmbeddingCol).isNotNull)
+ // Apply pre-filter before the cross-join to enable Hudi partition pruning
and
+ // data skipping, reducing the cross-join cardinality.
+ var filteredCorpus = corpusDf.filter(col(corpusEmbeddingCol).isNotNull)
+ filter.foreach(f => filteredCorpus = applyFilter(filteredCorpus, f))
// Prefix clashing query columns with "_hudi_query_" to avoid cross-join
ambiguity when
// corpus and query share column names (e.g. both have "id" or
"embedding").
@@ -310,12 +359,15 @@ object BruteForceSearchAlgorithm extends
VectorSearchAlgorithm {
}: _*)
// Cross join corpus with broadcast queries, compute distance, then rank
- val scored = filteredCorpus.crossJoin(broadcast(renamedQuery))
+ var scored = filteredCorpus.crossJoin(broadcast(renamedQuery))
.withColumn(DISTANCE_COL,
distanceUdf(col(corpusEmbeddingCol), col(QUERY_EMB_ALIAS)))
.drop(corpusEmbeddingCol)
.drop(QUERY_EMB_ALIAS)
+ // Apply max-distance threshold before windowing to reduce shuffle volume.
+ maxDistance.foreach(d => scored = scored.filter(col(DISTANCE_COL) <= d))
+
val window =
Window.partitionBy(QUERY_ID_COL).orderBy(col(DISTANCE_COL).asc)
val result = scored
.withColumn(RANK_COL, row_number().over(window))
diff --git
a/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/hudi/functional/TestHoodieVectorSearchFunction.scala
b/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/hudi/functional/TestHoodieVectorSearchFunction.scala
index 817b890c8d0e..8ae49360f2b9 100644
---
a/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/hudi/functional/TestHoodieVectorSearchFunction.scala
+++
b/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/hudi/functional/TestHoodieVectorSearchFunction.scala
@@ -500,19 +500,20 @@ class TestHoodieVectorSearchFunction extends
HoodieSparkClientTestBase {
s"""SELECT * FROM hudi_vector_search('$corpusViewName',
'embedding')""".stripMargin
).collect()
})
- assertTrue(exFew.getMessage.contains("expects 4-6 arguments") ||
- exFew.getCause.getMessage.contains("expects 4-6 arguments"))
+ assertTrue(exFew.getMessage.contains("expects 4-8 arguments") ||
+ exFew.getCause.getMessage.contains("expects 4-8 arguments"))
- // Too many arguments
+ // Too many arguments — 9 is one over the new max of 8.
val exMany = assertThrows(classOf[Exception], () => {
spark.sql(
s"""SELECT * FROM hudi_vector_search(
- | '$corpusViewName', 'embedding', ARRAY(1.0, 0.0, 0.0), 3,
'cosine', 'brute_force', 'extra_arg'
+ | '$corpusViewName', 'embedding', ARRAY(1.0, 0.0, 0.0), 3,
+ | 'cosine', 'brute_force', NULL, 0.5, 'extra_arg'
|)""".stripMargin
).collect()
})
val msg = if (exMany.getCause != null) exMany.getCause.getMessage else
exMany.getMessage
- assertTrue(msg.contains("4-6 arguments"), s"Expected arg-count error, got:
$msg")
+ assertTrue(msg.contains("4-8 arguments"), s"Expected arg-count error, got:
$msg")
}
@Test
@@ -1356,4 +1357,513 @@ class TestHoodieVectorSearchFunction extends
HoodieSparkClientTestBase {
spark.catalog.dropTempView("zero_q_corpus")
}
+ // ─── Pre-filter tests (single-query mode) ───
+
+ @Test
+ def testSingleQueryWithPreFilter(): Unit = {
+ // Corpus labels: x-axis, y-axis, z-axis, xy-diagonal, xyz-diagonal.
+ // Filter to {x-axis, xy-diagonal} then search for [1,0,0].
+ val result = spark.sql(
+ s"""
+ |SELECT id, label, _hudi_distance
+ |FROM hudi_vector_search(
+ | '$corpusViewName',
+ | 'embedding',
+ | ARRAY(1.0, 0.0, 0.0),
+ | 5,
+ | 'cosine',
+ | 'brute_force',
+ | 'label IN ("x-axis", "xy-diagonal")'
+ |)
+ |ORDER BY _hudi_distance
+ |""".stripMargin
+ ).collect()
+
+ assertEquals(2, result.length)
+ assertEquals("doc_1", result(0).getAs[String]("id"))
+ assertEquals(0.0, result(0).getAs[Double]("_hudi_distance"), 1e-5)
+ assertEquals("doc_4", result(1).getAs[String]("id"))
+ assertEquals(1.0 - 0.70710678, result(1).getAs[Double]("_hudi_distance"),
1e-4)
+ }
+
+ @Test
+ def testSingleQueryWithPreFilterNarrowResult(): Unit = {
+ // Filter to only z-axis then search for [1,0,0] (orthogonal → cosine
distance 1.0).
+ val result = spark.sql(
+ s"""
+ |SELECT id, _hudi_distance
+ |FROM hudi_vector_search(
+ | '$corpusViewName',
+ | 'embedding',
+ | ARRAY(1.0, 0.0, 0.0),
+ | 5,
+ | 'cosine',
+ | 'brute_force',
+ | 'label = "z-axis"'
+ |)
+ |""".stripMargin
+ ).collect()
+
+ assertEquals(1, result.length)
+ assertEquals("doc_3", result(0).getAs[String]("id"))
+ assertEquals(1.0, result(0).getAs[Double]("_hudi_distance"), 1e-5)
+ }
+
+ @Test
+ def testSingleQueryWithNullFilter(): Unit = {
+ // NULL filter behaves as no filter — all rows are considered.
+ val result = spark.sql(
+ s"""
+ |SELECT id, _hudi_distance
+ |FROM hudi_vector_search(
+ | '$corpusViewName',
+ | 'embedding',
+ | ARRAY(1.0, 0.0, 0.0),
+ | 3,
+ | 'cosine',
+ | 'brute_force',
+ | NULL
+ |)
+ |ORDER BY _hudi_distance
+ |""".stripMargin
+ ).collect()
+
+ assertEquals(3, result.length)
+ assertEquals("doc_1", result(0).getAs[String]("id"))
+ assertEquals("doc_4", result(1).getAs[String]("id"))
+ assertEquals("doc_5", result(2).getAs[String]("id"))
+ }
+
+ // ─── Max-distance tests (single-query mode) ───
+
+ @Test
+ def testSingleQueryWithMaxDistance(): Unit = {
+ // For cosine distance with query [1,0,0]:
+ // doc_1 → 0.0 (included)
+ // doc_4 → ~0.293 (included)
+ // doc_5 → ~0.423 (excluded by 0.3)
+ // doc_2/doc_3 → 1.0 (excluded)
+ val result = spark.sql(
+ s"""
+ |SELECT id, _hudi_distance
+ |FROM hudi_vector_search(
+ | '$corpusViewName',
+ | 'embedding',
+ | ARRAY(1.0, 0.0, 0.0),
+ | 5,
+ | 'cosine',
+ | 'brute_force',
+ | NULL,
+ | 0.3
+ |)
+ |ORDER BY _hudi_distance
+ |""".stripMargin
+ ).collect()
+
+ assertEquals(2, result.length)
+ assertEquals("doc_1", result(0).getAs[String]("id"))
+ assertEquals(0.0, result(0).getAs[Double]("_hudi_distance"), 1e-5)
+ assertEquals("doc_4", result(1).getAs[String]("id"))
+ assertEquals(1.0 - 0.70710678, result(1).getAs[Double]("_hudi_distance"),
1e-4)
+ }
+
+ @Test
+ def testSingleQueryMaxDistanceExcludesAll(): Unit = {
+ // max_distance = 0 keeps only exact matches.
+ val result = spark.sql(
+ s"""
+ |SELECT id, _hudi_distance
+ |FROM hudi_vector_search(
+ | '$corpusViewName',
+ | 'embedding',
+ | ARRAY(1.0, 0.0, 0.0),
+ | 5,
+ | 'cosine',
+ | 'brute_force',
+ | NULL,
+ | 0.0
+ |)
+ |""".stripMargin
+ ).collect()
+
+ assertEquals(1, result.length)
+ assertEquals("doc_1", result(0).getAs[String]("id"))
+ assertEquals(0.0, result(0).getAs[Double]("_hudi_distance"), 1e-5)
+ }
+
+ @Test
+ def testSingleQueryMaxDistanceWithNullThreshold(): Unit = {
+ // NULL max_distance behaves as no threshold.
+ val result = spark.sql(
+ s"""
+ |SELECT id, _hudi_distance
+ |FROM hudi_vector_search(
+ | '$corpusViewName',
+ | 'embedding',
+ | ARRAY(1.0, 0.0, 0.0),
+ | 3,
+ | 'cosine',
+ | 'brute_force',
+ | NULL,
+ | NULL
+ |)
+ |ORDER BY _hudi_distance
+ |""".stripMargin
+ ).collect()
+
+ assertEquals(3, result.length)
+ assertEquals("doc_1", result(0).getAs[String]("id"))
+ }
+
+ // ─── Combined filter + max-distance tests (single-query mode) ───
+
+ @Test
+ def testSingleQueryFilterAndMaxDistanceCombined(): Unit = {
+ // Filter excludes z-axis, max_distance = 0.5 excludes anything farther.
+ // Remaining: doc_1 (0.0), doc_4 (~0.293), doc_5 (~0.423) — doc_2 (1.0) is
excluded.
+ val result = spark.sql(
+ s"""
+ |SELECT id, _hudi_distance
+ |FROM hudi_vector_search(
+ | '$corpusViewName',
+ | 'embedding',
+ | ARRAY(1.0, 0.0, 0.0),
+ | 5,
+ | 'cosine',
+ | 'brute_force',
+ | 'label != "z-axis"',
+ | 0.5
+ |)
+ |ORDER BY _hudi_distance
+ |""".stripMargin
+ ).collect()
+
+ assertEquals(3, result.length)
+ assertEquals("doc_1", result(0).getAs[String]("id"))
+ assertEquals("doc_4", result(1).getAs[String]("id"))
+ assertEquals("doc_5", result(2).getAs[String]("id"))
+ val ids = result.map(_.getAs[String]("id")).toSet
+ assertFalse(ids.contains("doc_2"), "doc_2 should be excluded by
max_distance")
+ assertFalse(ids.contains("doc_3"), "doc_3 should be excluded by filter")
+ }
+
+ @Test
+ def testSingleQueryFilterWithL2AndMaxDistance(): Unit = {
+ // L2 distance with filter and max_distance.
+ // label LIKE '%axis%' matches doc_1, doc_2, doc_3.
+ // L2 to [1,0,0]: doc_1=0.0, doc_2=sqrt(2)~=1.414, doc_3=sqrt(2)~=1.414.
+ // max_distance = 1.0 keeps only doc_1.
+ val result = spark.sql(
+ s"""
+ |SELECT id, _hudi_distance
+ |FROM hudi_vector_search(
+ | '$corpusViewName',
+ | 'embedding',
+ | ARRAY(1.0, 0.0, 0.0),
+ | 5,
+ | 'l2',
+ | 'brute_force',
+ | 'label LIKE "%axis%"',
+ | 1.0
+ |)
+ |ORDER BY _hudi_distance
+ |""".stripMargin
+ ).collect()
+
+ assertEquals(1, result.length)
+ assertEquals("doc_1", result(0).getAs[String]("id"))
+ assertEquals(0.0, result(0).getAs[Double]("_hudi_distance"), 1e-5)
+ }
+
+ // ─── Batch-mode filter + max-distance tests ───
+
+ @Test
+ def testBatchQueryWithPreFilter(): Unit = {
+ createFloatQueryView("batch_filter_queries", "qid", "qvec", Seq(
+ ("q_x", Seq(1.0f, 0.0f, 0.0f)),
+ ("q_y", Seq(0.0f, 1.0f, 0.0f))
+ ))
+
+ val result = spark.sql(
+ s"""
+ |SELECT *
+ |FROM hudi_vector_search_batch(
+ | '$corpusViewName',
+ | 'embedding',
+ | 'batch_filter_queries',
+ | 'qvec',
+ | 5,
+ | 'cosine',
+ | 'brute_force',
+ | 'label IN ("x-axis", "y-axis")'
+ |)
+ |""".stripMargin
+ ).collect()
+
+ // Filter narrows corpus to {doc_1, doc_2}, 2 queries → at most 2 results
per query → 4 rows.
+ assertEquals(4, result.length)
+ val ids = result.map(_.getAs[String]("id")).toSet
+ assertEquals(Set("doc_1", "doc_2"), ids)
+
+ spark.catalog.dropTempView("batch_filter_queries")
+ }
+
+ @Test
+ def testBatchQueryWithMaxDistance(): Unit = {
+ createFloatQueryView("batch_maxd_queries", "qid", "qvec", Seq(
+ ("q_x", Seq(1.0f, 0.0f, 0.0f)),
+ ("q_z", Seq(0.0f, 0.0f, 1.0f))
+ ))
+
+ val result = spark.sql(
+ s"""
+ |SELECT id, _hudi_distance, _hudi_query_index
+ |FROM hudi_vector_search_batch(
+ | '$corpusViewName',
+ | 'embedding',
+ | 'batch_maxd_queries',
+ | 'qvec',
+ | 5,
+ | 'cosine',
+ | 'brute_force',
+ | NULL,
+ | 0.5
+ |)
+ |ORDER BY _hudi_query_index, _hudi_distance
+ |""".stripMargin
+ ).collect()
+
+ // q_x (cosine to [1,0,0]) with max 0.5 keeps doc_1(0.0), doc_4(~0.293),
doc_5(~0.423) → 3.
+ // q_z (cosine to [0,0,1]) with max 0.5 keeps doc_3(0.0), doc_5(~0.423) →
2.
+ assertEquals(5, result.length)
+ result.foreach { row =>
+ val d = row.getAs[Double]("_hudi_distance")
+ assertTrue(d <= 0.5 + 1e-5,
+ s"Distance $d exceeds max_distance 0.5")
+ }
+
+ spark.catalog.dropTempView("batch_maxd_queries")
+ }
+
+ @Test
+ def testBatchQueryFilterAndMaxDistanceCombined(): Unit = {
+ createFloatQueryView("batch_combo_queries", "qid", "qvec", Seq(
+ ("q_x", Seq(1.0f, 0.0f, 0.0f))
+ ))
+
+ val result = spark.sql(
+ s"""
+ |SELECT id, _hudi_distance
+ |FROM hudi_vector_search_batch(
+ | '$corpusViewName',
+ | 'embedding',
+ | 'batch_combo_queries',
+ | 'qvec',
+ | 5,
+ | 'cosine',
+ | 'brute_force',
+ | 'label != "z-axis"',
+ | 0.3
+ |)
+ |ORDER BY _hudi_distance
+ |""".stripMargin
+ ).collect()
+
+ // Filter (no z-axis) → {doc_1, doc_2, doc_4, doc_5}; cosine to [1,0,0]:
+ // doc_1=0.0, doc_4~=0.293, doc_5~=0.423, doc_2=1.0.
+ // max 0.3 keeps doc_1 and doc_4.
+ assertEquals(2, result.length)
+ assertEquals("doc_1", result(0).getAs[String]("id"))
+ assertEquals("doc_4", result(1).getAs[String]("id"))
+
+ spark.catalog.dropTempView("batch_combo_queries")
+ }
+
+ // ─── Failure-mode tests for filter + max_distance ───
+
+  @Test
+  def testInvalidFilterSyntaxIsWrapped(): Unit = {
+    // Garbage predicate text must be rejected with a wrapped pre-filter error
+    // whose message echoes the offending expression. Only the message is
+    // asserted here; the concrete exception class is not pinned.
+    val ex = assertThrows(classOf[Exception], () => {
+      spark.sql(
+        s"""
+           |SELECT * FROM hudi_vector_search(
+           | '$corpusViewName', 'embedding', ARRAY(1.0, 0.0, 0.0), 5,
+           | 'cosine', 'brute_force',
+           | 'this is not valid SQL !!'
+           |)
+           |""".stripMargin
+      ).collect()
+    })
+    // Presumably the wrapped error is rethrown inside another exception, so the
+    // cause's message is inspected when a cause is present.
+    val msg = if (ex.getCause != null) ex.getCause.getMessage else ex.getMessage
+    assertTrue(msg.contains("Invalid pre-filter expression"),
+      s"Expected wrapped pre-filter error, got: $msg")
+    assertTrue(msg.contains("this is not valid SQL"),
+      s"Expected error message to echo the offending expression, got: $msg")
+  }
+
+  @Test
+  def testFilterReferencingUnknownColumnIsWrapped(): Unit = {
+    // A syntactically valid filter that references a column missing from the
+    // corpus must fail analysis and surface the same wrapped pre-filter error,
+    // naming the offending column in the message.
+    val ex = assertThrows(classOf[Exception], () => {
+      spark.sql(
+        s"""
+           |SELECT * FROM hudi_vector_search(
+           | '$corpusViewName', 'embedding', ARRAY(1.0, 0.0, 0.0), 5,
+           | 'cosine', 'brute_force',
+           | 'no_such_column = "x"'
+           |)
+           |""".stripMargin
+      ).collect()
+    })
+    val msg = if (ex.getCause != null) ex.getCause.getMessage else ex.getMessage
+    assertTrue(msg.contains("Invalid pre-filter expression"),
+      s"Expected wrapped pre-filter error, got: $msg")
+    assertTrue(msg.contains("no_such_column"),
+      s"Expected error message to echo the offending expression, got: $msg")
+  }
+
+  @Test
+  def testFilterMustBeStringOrNull(): Unit = {
+    // A non-string literal (here the integer 42) in the filter argument position
+    // must be rejected, with a message stating the accepted argument types.
+    val ex = assertThrows(classOf[Exception], () => {
+      spark.sql(
+        s"""
+           |SELECT * FROM hudi_vector_search(
+           | '$corpusViewName', 'embedding', ARRAY(1.0, 0.0, 0.0), 5,
+           | 'cosine', 'brute_force',
+           | 42
+           |)
+           |""".stripMargin
+      ).collect()
+    })
+    val msg = if (ex.getCause != null) ex.getCause.getMessage else ex.getMessage
+    assertTrue(msg.contains("must be a string literal or NULL"),
+      s"Expected filter-type error, got: $msg")
+  }
+
+  @Test
+  def testMaxDistanceMustBeNumericOrNull(): Unit = {
+    // Passing a string literal where a numeric max_distance is expected must throw.
+    // '0.5' is included on purpose: its contents parse as a number, and the
+    // NumericType guard must still reject it because the literal is a string.
+    Seq("'not_a_number'", "'0.5'").foreach { maxDistanceLiteral =>
+      val ex = assertThrows(classOf[Exception], () => {
+        spark.sql(
+          s"""
+             |SELECT * FROM hudi_vector_search(
+             | '$corpusViewName', 'embedding', ARRAY(1.0, 0.0, 0.0), 5,
+             | 'cosine', 'brute_force',
+             | NULL,
+             | $maxDistanceLiteral
+             |)
+             |""".stripMargin
+        ).collect()
+      })
+      val msg = if (ex.getCause != null) ex.getCause.getMessage else ex.getMessage
+      assertTrue(msg.contains("must be a numeric literal or NULL"),
+        s"Expected max_distance-type error for $maxDistanceLiteral, got: $msg")
+    }
+  }
+
+  @Test
+  def testMaxDistanceAcceptsNonDoubleNumericLiterals(): Unit = {
+    // parseOptionalDouble widens any NumericType literal (Int, Long, Short, Byte, Decimal, Float)
+    // to Double. Exercise that contract with an Int literal and a Decimal literal.
+    //
+    // Int literal: max_distance = 1 → widens to 1.0. The threshold is inclusive (<=), so all
+    // five corpus rows are kept: doc_1 (0.0), doc_4 (~0.293), doc_5 (~0.423), doc_2 (1.0),
+    // doc_3 (1.0). Because k = 5 already covers the whole five-row corpus, this case shows that
+    // the Int literal is accepted and that rows sitting exactly at 1.0 survive. It cannot show
+    // pruning; the Decimal case below does that.
+    val intResult = spark.sql(
+      s"""
+         |SELECT id, _hudi_distance
+         |FROM hudi_vector_search(
+         | '$corpusViewName',
+         | 'embedding',
+         | ARRAY(1.0, 0.0, 0.0),
+         | 5,
+         | 'cosine',
+         | 'brute_force',
+         | NULL,
+         | 1
+         |)
+         |ORDER BY _hudi_distance
+         |""".stripMargin
+    ).collect()
+
+    assertEquals(5, intResult.length)
+    val intIds = intResult.map(_.getAs[String]("id")).toSet
+    assertTrue(intIds.contains("doc_1"))
+    assertTrue(intIds.contains("doc_4"))
+    assertTrue(intIds.contains("doc_5"))
+    assertTrue(intIds.contains("doc_2"),
+      "doc_2 at distance 1.0 should be included by inclusive (<=) threshold")
+    assertTrue(intIds.contains("doc_3"),
+      "doc_3 at distance 1.0 should be included by inclusive (<=) threshold")
+
+    // Decimal literal: CAST(0.3 AS DECIMAL(10,2)) → keeps doc_1 (0.0) and doc_4 (~0.293),
+    // pruning doc_5, doc_2 and doc_3.
+    val decResult = spark.sql(
+      s"""
+         |SELECT id, _hudi_distance
+         |FROM hudi_vector_search(
+         | '$corpusViewName',
+         | 'embedding',
+         | ARRAY(1.0, 0.0, 0.0),
+         | 5,
+         | 'cosine',
+         | 'brute_force',
+         | NULL,
+         | CAST(0.3 AS DECIMAL(10,2))
+         |)
+         |ORDER BY _hudi_distance
+         |""".stripMargin
+    ).collect()
+
+    assertEquals(2, decResult.length)
+    assertEquals("doc_1", decResult(0).getAs[String]("id"))
+    assertEquals("doc_4", decResult(1).getAs[String]("id"))
+  }
+
+  @Test
+  def testNegativeMaxDistanceExcludesAll(): Unit = {
+    // Cosine and L2 distances are never below zero, so no corpus row can satisfy
+    // `distance <= -0.5`. The search must return an empty result instead of failing.
+    val negativeThresholdQuery =
+      s"""
+         |SELECT id FROM hudi_vector_search(
+         | '$corpusViewName', 'embedding', ARRAY(1.0, 0.0, 0.0), 5,
+         | 'cosine', 'brute_force',
+         | NULL,
+         | -0.5
+         |)
+         |""".stripMargin
+    val rows = spark.sql(negativeThresholdQuery).collect()
+
+    assertEquals(0, rows.length)
+  }
+
+  @Test
+  def testBatchQueryInvalidFilterIsWrapped(): Unit = {
+    // Batch mode must also wrap predicate parse / analysis errors in the
+    // "Invalid pre-filter expression" message, as the single-query path does.
+    createFloatQueryView("batch_invalid_filter_queries", "qid", "qvec", Seq(
+      ("q_x", Seq(1.0f, 0.0f, 0.0f))
+    ))
+
+    // Drop the query view even when an assertion fails, so it cannot outlive
+    // this test and affect others that reuse the session.
+    try {
+      val ex = assertThrows(classOf[Exception], () => {
+        spark.sql(
+          s"""
+             |SELECT * FROM hudi_vector_search_batch(
+             | '$corpusViewName', 'embedding',
+             | 'batch_invalid_filter_queries', 'qvec',
+             | 5, 'cosine', 'brute_force',
+             | 'no_such_column = "x"'
+             |)
+             |""".stripMargin
+        ).collect()
+      })
+      val msg = if (ex.getCause != null) ex.getCause.getMessage else ex.getMessage
+      assertTrue(msg.contains("Invalid pre-filter expression"),
+        s"Expected wrapped pre-filter error, got: $msg")
+    } finally {
+      spark.catalog.dropTempView("batch_invalid_filter_queries")
+    }
+  }
+
}