This is an automated email from the ASF dual-hosted git repository.
lzljs3620320 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/paimon.git
The following commit(s) were added to refs/heads/master by this push:
new 708ea8fc75 [spark] Spark support vector search (#6950)
708ea8fc75 is described below
commit 708ea8fc7592ac95d480cd62c7fcccb9daf7f184
Author: jerry <[email protected]>
AuthorDate: Tue Jan 6 13:07:29 2026 +0800
[spark] Spark support vector search (#6950)
---
.../org/apache/paimon/table/VectorSearchTable.java | 101 ++++++++++++++
.../paimon/spark/PaimonBaseScanBuilder.scala | 5 +-
.../scala/org/apache/paimon/spark/PaimonScan.scala | 3 +-
.../apache/paimon/spark/PaimonScanBuilder.scala | 9 +-
.../scala/org/apache/paimon/spark/PaimonScan.scala | 3 +-
.../spark/sql/VectorSearchPushDownTest.scala | 145 +++++++++++++++++++++
.../paimon/spark/PaimonBaseScanBuilder.scala | 3 +-
.../scala/org/apache/paimon/spark/PaimonScan.scala | 3 +-
.../apache/paimon/spark/PaimonScanBuilder.scala | 17 ++-
.../plans/logical/PaimonTableValuedFunctions.scala | 117 +++++++++++++++--
.../org/apache/paimon/spark/scan/BaseScan.scala | 7 +-
.../spark/sql/BaseVectorSearchPushDownTest.scala | 100 ++++++++++++++
12 files changed, 494 insertions(+), 19 deletions(-)
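For context, a minimal usage sketch assembled from the tests added in this commit (the 'test' database name, the 3-dimensional vectors, and the table properties simply mirror what the tests below use):

    // create a table with a float-array column, build a Lucene KNN global index on it,
    // then fetch the 5 nearest rows to a query vector via the new vector_search table function
    spark.sql("CREATE TABLE T (id INT, v ARRAY<FLOAT>) TBLPROPERTIES (" +
      "'bucket' = '-1', 'global-index.row-count-per-shard' = '10000', " +
      "'row-tracking.enabled' = 'true', 'data-evolution.enabled' = 'true')")
    spark.sql("CALL sys.create_global_index(table => 'test.T', index_column => 'v', " +
      "index_type => 'lucene-vector-knn', options => 'vector.dim=3')")
    spark.sql("SELECT * FROM vector_search('T', 'v', array(50.0f, 51.0f, 52.0f), 5)").show()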
diff --git a/paimon-core/src/main/java/org/apache/paimon/table/VectorSearchTable.java b/paimon-core/src/main/java/org/apache/paimon/table/VectorSearchTable.java
new file mode 100644
index 0000000000..cb98e25055
--- /dev/null
+++ b/paimon-core/src/main/java/org/apache/paimon/table/VectorSearchTable.java
@@ -0,0 +1,101 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.paimon.table;
+
+import org.apache.paimon.fs.FileIO;
+import org.apache.paimon.predicate.VectorSearch;
+import org.apache.paimon.table.source.InnerTableRead;
+import org.apache.paimon.table.source.InnerTableScan;
+import org.apache.paimon.types.RowType;
+
+import java.util.List;
+import java.util.Map;
+
+/**
+ * A table wrapper to hold vector search information. This is used to pass vector search pushdown
+ * information from logical plan optimization to physical plan execution. For now, it is only used
+ * internally by the Spark engine.
+ */
+public class VectorSearchTable implements ReadonlyTable {
+
+ private final InnerTable origin;
+ private final VectorSearch vectorSearch;
+
+ private VectorSearchTable(InnerTable origin, VectorSearch vectorSearch) {
+ this.origin = origin;
+ this.vectorSearch = vectorSearch;
+ }
+
+ public static VectorSearchTable create(InnerTable origin, VectorSearch vectorSearch) {
+ return new VectorSearchTable(origin, vectorSearch);
+ }
+
+ public VectorSearch vectorSearch() {
+ return vectorSearch;
+ }
+
+ public InnerTable origin() {
+ return origin;
+ }
+
+ @Override
+ public String name() {
+ return origin.name();
+ }
+
+ @Override
+ public RowType rowType() {
+ return origin.rowType();
+ }
+
+ @Override
+ public List<String> primaryKeys() {
+ return origin.primaryKeys();
+ }
+
+ @Override
+ public List<String> partitionKeys() {
+ return origin.partitionKeys();
+ }
+
+ @Override
+ public Map<String, String> options() {
+ return origin.options();
+ }
+
+ @Override
+ public FileIO fileIO() {
+ return origin.fileIO();
+ }
+
+ @Override
+ public InnerTableRead newRead() {
+ return origin.newRead();
+ }
+
+ @Override
+ public InnerTableScan newScan() {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public Table copy(Map<String, String> dynamicOptions) {
+ return new VectorSearchTable((InnerTable) origin.copy(dynamicOptions), vectorSearch);
+ }
+}
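The wrapper above carries the search request from logical-plan rewriting to the Spark scan builder, which unwraps it again; a condensed sketch of that hand-off, using only the methods defined in this file and in the PaimonScanBuilder change further below (variable names are illustrative):

    // at table-valued-function resolution time: bundle the original table with the parsed VectorSearch
    val wrapped = VectorSearchTable.create(innerTable, vectorSearch)
    // at scan-build time: recover the original table and the search to push into the scan
    val actualTable = wrapped.origin()
    val pushed = Option(wrapped.vectorSearch())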
diff --git a/paimon-spark/paimon-spark-3.2/src/main/scala/org/apache/paimon/spark/PaimonBaseScanBuilder.scala b/paimon-spark/paimon-spark-3.2/src/main/scala/org/apache/paimon/spark/PaimonBaseScanBuilder.scala
index 41a1d552f1..4f5451c95d 100644
--- a/paimon-spark/paimon-spark-3.2/src/main/scala/org/apache/paimon/spark/PaimonBaseScanBuilder.scala
+++ b/paimon-spark/paimon-spark-3.2/src/main/scala/org/apache/paimon/spark/PaimonBaseScanBuilder.scala
@@ -21,7 +21,7 @@ package org.apache.paimon.spark
import org.apache.paimon.CoreOptions
import org.apache.paimon.partition.PartitionPredicate
import org.apache.paimon.partition.PartitionPredicate.splitPartitionPredicatesAndDataPredicates
-import org.apache.paimon.predicate.{PartitionPredicateVisitor, Predicate}
+import org.apache.paimon.predicate.{PartitionPredicateVisitor, Predicate, TopN, VectorSearch}
import org.apache.paimon.table.SpecialFields.rowTypeWithRowTracking
import org.apache.paimon.table.Table
import org.apache.paimon.types.RowType
@@ -50,6 +50,9 @@ abstract class PaimonBaseScanBuilder
protected var pushedPartitionFilters: Array[PartitionPredicate] = Array.empty
protected var pushedDataFilters: Array[Predicate] = Array.empty
+ protected var pushedLimit: Option[Int] = None
+ protected var pushedTopN: Option[TopN] = None
+ protected var pushedVectorSearch: Option[VectorSearch] = None
protected var requiredSchema: StructType = SparkTypeUtils.fromPaimonRowType(table.rowType())
diff --git a/paimon-spark/paimon-spark-3.2/src/main/scala/org/apache/paimon/spark/PaimonScan.scala b/paimon-spark/paimon-spark-3.2/src/main/scala/org/apache/paimon/spark/PaimonScan.scala
index e9eaa7d6cc..d6292ad8cf 100644
--- a/paimon-spark/paimon-spark-3.2/src/main/scala/org/apache/paimon/spark/PaimonScan.scala
+++ b/paimon-spark/paimon-spark-3.2/src/main/scala/org/apache/paimon/spark/PaimonScan.scala
@@ -19,7 +19,7 @@
package org.apache.paimon.spark
import org.apache.paimon.partition.PartitionPredicate
-import org.apache.paimon.predicate.{Predicate, TopN}
+import org.apache.paimon.predicate.{Predicate, TopN, VectorSearch}
import org.apache.paimon.table.InnerTable
import org.apache.spark.sql.PaimonUtils.fieldReference
@@ -37,6 +37,7 @@ case class PaimonScan(
pushedDataFilters: Seq[Predicate],
override val pushedLimit: Option[Int] = None,
override val pushedTopN: Option[TopN] = None,
+ override val pushedVectorSearch: Option[VectorSearch] = None,
bucketedScanDisabled: Boolean = true)
extends PaimonBaseScan(table)
with SupportsRuntimeFiltering {
diff --git a/paimon-spark/paimon-spark-3.2/src/main/scala/org/apache/paimon/spark/PaimonScanBuilder.scala b/paimon-spark/paimon-spark-3.2/src/main/scala/org/apache/paimon/spark/PaimonScanBuilder.scala
index 21ab46dabc..770bd8f802 100644
--- a/paimon-spark/paimon-spark-3.2/src/main/scala/org/apache/paimon/spark/PaimonScanBuilder.scala
+++ b/paimon-spark/paimon-spark-3.2/src/main/scala/org/apache/paimon/spark/PaimonScanBuilder.scala
@@ -25,6 +25,13 @@ import org.apache.spark.sql.connector.read.Scan
class PaimonScanBuilder(val table: InnerTable) extends PaimonBaseScanBuilder {
override def build(): Scan = {
- PaimonScan(table, requiredSchema, pushedPartitionFilters, pushedDataFilters)
+ PaimonScan(
+ table,
+ requiredSchema,
+ pushedPartitionFilters,
+ pushedDataFilters,
+ pushedLimit,
+ pushedTopN,
+ pushedVectorSearch)
}
}
diff --git a/paimon-spark/paimon-spark-3.3/src/main/scala/org/apache/paimon/spark/PaimonScan.scala b/paimon-spark/paimon-spark-3.3/src/main/scala/org/apache/paimon/spark/PaimonScan.scala
index 3afca15303..8d06751f57 100644
--- a/paimon-spark/paimon-spark-3.3/src/main/scala/org/apache/paimon/spark/PaimonScan.scala
+++ b/paimon-spark/paimon-spark-3.3/src/main/scala/org/apache/paimon/spark/PaimonScan.scala
@@ -19,7 +19,7 @@
package org.apache.paimon.spark
import org.apache.paimon.partition.PartitionPredicate
-import org.apache.paimon.predicate.{Predicate, TopN}
+import org.apache.paimon.predicate.{Predicate, TopN, VectorSearch}
import org.apache.paimon.table.{BucketMode, FileStoreTable, InnerTable}
import org.apache.paimon.table.source.{DataSplit, Split}
@@ -39,6 +39,7 @@ case class PaimonScan(
pushedDataFilters: Seq[Predicate],
override val pushedLimit: Option[Int],
override val pushedTopN: Option[TopN],
+ override val pushedVectorSearch: Option[VectorSearch],
bucketedScanDisabled: Boolean = false)
extends PaimonBaseScan(table)
with SupportsRuntimeFiltering
diff --git a/paimon-spark/paimon-spark-4.0/src/test/scala/org/apache/paimon/spark/sql/VectorSearchPushDownTest.scala b/paimon-spark/paimon-spark-4.0/src/test/scala/org/apache/paimon/spark/sql/VectorSearchPushDownTest.scala
new file mode 100644
index 0000000000..7ac3c5df0d
--- /dev/null
+++ b/paimon-spark/paimon-spark-4.0/src/test/scala/org/apache/paimon/spark/sql/VectorSearchPushDownTest.scala
@@ -0,0 +1,145 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.paimon.spark.sql
+
+import org.apache.paimon.spark.PaimonScan
+
+/** Tests for vector search table-valued function with global vector index. */
+class VectorSearchPushDownTest extends BaseVectorSearchPushDownTest {
+ test("vector search with global index") {
+ withTable("T") {
+ spark.sql("""
+ |CREATE TABLE T (id INT, v ARRAY<FLOAT>)
+ |TBLPROPERTIES (
+ | 'bucket' = '-1',
+ | 'global-index.row-count-per-shard' = '10000',
+ | 'row-tracking.enabled' = 'true',
+ | 'data-evolution.enabled' = 'true')
+ |""".stripMargin)
+
+ // Insert 100 rows with predictable vectors
+ val values = (0 until 100)
+ .map(
+ i => s"($i, array(cast($i as float), cast(${i + 1} as float), cast(${i + 2} as float)))")
+ .mkString(",")
+ spark.sql(s"INSERT INTO T VALUES $values")
+
+ // Create vector index
+ val output = spark
+ .sql("CALL sys.create_global_index(table => 'test.T', index_column => 'v', index_type => 'lucene-vector-knn', options => 'vector.dim=3')")
+ .collect()
+ .head
+ assert(output.getBoolean(0))
+
+ // Test vector search with table-valued function syntax
+ val result = spark
+ .sql("""
+ |SELECT * FROM vector_search('T', 'v', array(50.0f, 51.0f, 52.0f), 5)
+ |""".stripMargin)
+ .collect()
+
+ // The result should contain 5 rows
+ assert(result.length == 5)
+
+ // Vector (50, 51, 52) should be most similar to the row with id=50
+ assert(result.map(_.getInt(0)).contains(50))
+ }
+ }
+
+ test("vector search pushdown is applied in plan") {
+ withTable("T") {
+ spark.sql("""
+ |CREATE TABLE T (id INT, v ARRAY<FLOAT>)
+ |TBLPROPERTIES (
+ | 'bucket' = '-1',
+ | 'global-index.row-count-per-shard' = '10000',
+ | 'row-tracking.enabled' = 'true',
+ | 'data-evolution.enabled' = 'true')
+ |""".stripMargin)
+
+ val values = (0 until 10)
+ .map(
+ i => s"($i, array(cast($i as float), cast(${i + 1} as float), cast(${i + 2} as float)))")
+ .mkString(",")
+ spark.sql(s"INSERT INTO T VALUES $values")
+
+ // Create vector index
+ spark
+ .sql("CALL sys.create_global_index(table => 'test.T', index_column => 'v', index_type => 'lucene-vector-knn', options => 'vector.dim=3')")
+ .collect()
+
+ // Check that vector search is pushed down with table function syntax
+ val df = spark.sql("""
+ |SELECT * FROM vector_search('T', 'v', array(50.0f, 51.0f, 52.0f), 5)
+ |""".stripMargin)
+
+ // Get the scan from the executed plan (physical plan)
+ val executedPlan = df.queryExecution.executedPlan
+ val batchScans = executedPlan.collect {
+ case scan: org.apache.spark.sql.execution.datasources.v2.BatchScanExec => scan
+ }
+
+ assert(batchScans.nonEmpty, "Should have a BatchScanExec in executed plan")
+ val paimonScans = batchScans.filter(_.scan.isInstanceOf[PaimonScan])
+ assert(paimonScans.nonEmpty, "Should have a PaimonScan in executed plan")
+
+ val paimonScan = paimonScans.head.scan.asInstanceOf[PaimonScan]
+ assert(paimonScan.pushedVectorSearch.isDefined, "Vector search should be pushed down")
+ assert(paimonScan.pushedVectorSearch.get.fieldName() == "v", "Field name should be 'v'")
+ assert(paimonScan.pushedVectorSearch.get.limit() == 5, "Limit should be 5")
+ }
+ }
+
+ test("vector search topk returns correct results") {
+ withTable("T") {
+ spark.sql("""
+ |CREATE TABLE T (id INT, v ARRAY<FLOAT>)
+ |TBLPROPERTIES (
+ | 'bucket' = '-1',
+ | 'global-index.row-count-per-shard' = '10000',
+ | 'row-tracking.enabled' = 'true',
+ | 'data-evolution.enabled' = 'true')
+ |""".stripMargin)
+
+ // Insert rows with distinct vectors
+ val values = (1 to 100)
+ .map {
+ i =>
+ val v = math.sqrt(3.0 * i * i)
+ val normalized = i.toFloat / v.toFloat
+ s"($i, array($normalized, $normalized, $normalized))"
+ }
+ .mkString(",")
+ spark.sql(s"INSERT INTO T VALUES $values")
+
+ // Create vector index
+ spark.sql(
+ "CALL sys.create_global_index(table => 'test.T', index_column => 'v', index_type => 'lucene-vector-knn', options => 'vector.dim=3')")
+
+ // Query for top 10 similar to (1, 1, 1) normalized
+ val result = spark
+ .sql("""
+ |SELECT * FROM vector_search('T', 'v', array(0.577f, 0.577f, 0.577f), 10)
+ |""".stripMargin)
+ .collect()
+
+ assert(result.length == 10)
+ }
+ }
+}
diff --git a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonBaseScanBuilder.scala b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonBaseScanBuilder.scala
index 8179f504b3..47723171e4 100644
--- a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonBaseScanBuilder.scala
+++ b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonBaseScanBuilder.scala
@@ -21,7 +21,7 @@ package org.apache.paimon.spark
import org.apache.paimon.CoreOptions
import org.apache.paimon.partition.PartitionPredicate
import org.apache.paimon.partition.PartitionPredicate.splitPartitionPredicatesAndDataPredicates
-import org.apache.paimon.predicate.{PartitionPredicateVisitor, Predicate, TopN}
+import org.apache.paimon.predicate.{PartitionPredicateVisitor, Predicate, TopN, VectorSearch}
import org.apache.paimon.table.{SpecialFields, Table}
import org.apache.paimon.types.RowType
@@ -52,6 +52,7 @@ abstract class PaimonBaseScanBuilder
protected var pushedDataFilters: Array[Predicate] = Array.empty
protected var pushedLimit: Option[Int] = None
protected var pushedTopN: Option[TopN] = None
+ protected var pushedVectorSearch: Option[VectorSearch] = None
protected var requiredSchema: StructType = SparkTypeUtils.fromPaimonRowType(table.rowType())
diff --git a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonScan.scala b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonScan.scala
index 06a97ee8b4..c9f0e9506e 100644
--- a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonScan.scala
+++ b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonScan.scala
@@ -20,7 +20,7 @@ package org.apache.paimon.spark
import org.apache.paimon.CoreOptions.BucketFunctionType
import org.apache.paimon.partition.PartitionPredicate
-import org.apache.paimon.predicate.{Predicate, TopN}
+import org.apache.paimon.predicate.{Predicate, TopN, VectorSearch}
import org.apache.paimon.spark.commands.BucketExpression.quote
import org.apache.paimon.table.{BucketMode, FileStoreTable, InnerTable}
import org.apache.paimon.table.source.{DataSplit, Split}
@@ -41,6 +41,7 @@ case class PaimonScan(
pushedDataFilters: Seq[Predicate],
override val pushedLimit: Option[Int],
override val pushedTopN: Option[TopN],
+ override val pushedVectorSearch: Option[VectorSearch],
bucketedScanDisabled: Boolean = false)
extends PaimonBaseScan(table)
with SupportsReportPartitioning
diff --git a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonScanBuilder.scala b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonScanBuilder.scala
index 6eeaaf7b93..de75bb823d 100644
--- a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonScanBuilder.scala
+++ b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonScanBuilder.scala
@@ -128,13 +128,26 @@ class PaimonScanBuilder(val table: InnerTable)
localScan match {
case Some(scan) => scan
case None =>
+ val (actualTable, vectorSearch) = table match {
+ case vst: org.apache.paimon.table.VectorSearchTable =>
+ val tableVectorSearch = Option(vst.vectorSearch())
+ val vs = (tableVectorSearch, pushedVectorSearch) match {
+ case (Some(_), _) => tableVectorSearch
+ case (None, Some(_)) => pushedVectorSearch
+ case (None, None) => None
+ }
+ (vst.origin(), vs)
+ case _ => (table, pushedVectorSearch)
+ }
+
PaimonScan(
- table,
+ actualTable,
requiredSchema,
pushedPartitionFilters,
pushedDataFilters,
pushedLimit,
- pushedTopN)
+ pushedTopN,
+ vectorSearch)
}
}
}
diff --git a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/catalyst/plans/logical/PaimonTableValuedFunctions.scala b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/catalyst/plans/logical/PaimonTableValuedFunctions.scala
index e4f5e7856c..6bb6004db8 100644
--- a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/catalyst/plans/logical/PaimonTableValuedFunctions.scala
+++ b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/catalyst/plans/logical/PaimonTableValuedFunctions.scala
@@ -19,9 +19,10 @@
package org.apache.paimon.spark.catalyst.plans.logical
import org.apache.paimon.CoreOptions
+import org.apache.paimon.predicate.VectorSearch
import org.apache.paimon.spark.SparkTable
import org.apache.paimon.spark.catalyst.plans.logical.PaimonTableValuedFunctions._
-import org.apache.paimon.table.DataTable
+import org.apache.paimon.table.{DataTable, InnerTable, VectorSearchTable}
import org.apache.paimon.table.source.snapshot.TimeTravelUtil.InconsistentTagBucketException
import org.apache.spark.sql.PaimonUtils.createDataset
@@ -29,7 +30,7 @@ import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.FunctionIdentifier
import org.apache.spark.sql.catalyst.analysis.FunctionRegistryBase
import org.apache.spark.sql.catalyst.analysis.TableFunctionRegistry.TableFunctionBuilder
-import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, ExpressionInfo}
+import org.apache.spark.sql.catalyst.expressions.{Attribute, CreateArray, Expression, ExpressionInfo, Literal}
import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan}
import org.apache.spark.sql.connector.catalog.{Identifier, Table, TableCatalog}
import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation
@@ -42,9 +43,10 @@ object PaimonTableValuedFunctions {
val INCREMENTAL_QUERY = "paimon_incremental_query"
val INCREMENTAL_BETWEEN_TIMESTAMP = "paimon_incremental_between_timestamp"
val INCREMENTAL_TO_AUTO_TAG = "paimon_incremental_to_auto_tag"
+ val VECTOR_SEARCH = "vector_search"
val supportedFnNames: Seq[String] =
- Seq(INCREMENTAL_QUERY, INCREMENTAL_BETWEEN_TIMESTAMP, INCREMENTAL_TO_AUTO_TAG)
+ Seq(INCREMENTAL_QUERY, INCREMENTAL_BETWEEN_TIMESTAMP, INCREMENTAL_TO_AUTO_TAG, VECTOR_SEARCH)
private type TableFunctionDescription = (FunctionIdentifier, ExpressionInfo, TableFunctionBuilder)
@@ -56,6 +58,8 @@ object PaimonTableValuedFunctions {
FunctionRegistryBase.build[IncrementalBetweenTimestamp](fnName, since = None)
case INCREMENTAL_TO_AUTO_TAG =>
FunctionRegistryBase.build[IncrementalToAutoTag](fnName, since = None)
+ case VECTOR_SEARCH =>
+ FunctionRegistryBase.build[VectorSearchQuery](fnName, since = None)
case _ =>
throw new Exception(s"Function $fnName isn't a supported table valued function.")
}
@@ -85,17 +89,45 @@ object PaimonTableValuedFunctions {
val sparkCatalog = catalogManager.catalog(catalogName).asInstanceOf[TableCatalog]
val ident: Identifier = Identifier.of(Array(dbName), tableName)
val sparkTable = sparkCatalog.loadTable(ident)
- val options = tvf.parseArgs(args.tail)
- usingSparkIncrementQuery(tvf, sparkTable, options) match {
- case Some(snapshotIdPair: (Long, Long)) =>
- sparkIncrementQuery(spark, sparkTable, sparkCatalog, ident, options, snapshotIdPair)
+ // Handle vector_search specially
+ tvf match {
+ case vsq: VectorSearchQuery =>
+ resolveVectorSearchQuery(sparkTable, sparkCatalog, ident, vsq, args.tail)
case _ =>
+ val options = tvf.parseArgs(args.tail)
+ usingSparkIncrementQuery(tvf, sparkTable, options) match {
+ case Some(snapshotIdPair: (Long, Long)) =>
+ sparkIncrementQuery(spark, sparkTable, sparkCatalog, ident, options, snapshotIdPair)
+ case _ =>
+ DataSourceV2Relation.create(
+ sparkTable,
+ Some(sparkCatalog),
+ Some(ident),
+ new CaseInsensitiveStringMap(options.asJava))
+ }
+ }
+ }
+
+ private def resolveVectorSearchQuery(
+ sparkTable: Table,
+ sparkCatalog: TableCatalog,
+ ident: Identifier,
+ vsq: VectorSearchQuery,
+ argsWithoutTable: Seq[Expression]): LogicalPlan = {
+ sparkTable match {
+ case st @ SparkTable(innerTable: InnerTable) =>
+ val vectorSearch = vsq.createVectorSearch(innerTable, argsWithoutTable)
+ val vectorSearchTable = VectorSearchTable.create(innerTable, vectorSearch)
DataSourceV2Relation.create(
- sparkTable,
+ st.copy(table = vectorSearchTable),
Some(sparkCatalog),
Some(ident),
- new CaseInsensitiveStringMap(options.asJava))
+ CaseInsensitiveStringMap.empty())
+ case _ =>
+ throw new RuntimeException(
+ "vector_search only supports Paimon SparkTable backed by InnerTable, " +
+ s"but got table implementation: ${sparkTable.getClass.getName}")
}
}
@@ -207,3 +239,70 @@ case class IncrementalToAutoTag(override val args: Seq[Expression])
Map(CoreOptions.INCREMENTAL_TO_AUTO_TAG.key -> endTagName)
}
}
+
+/**
+ * Plan for the [[VECTOR_SEARCH]] table-valued function.
+ *
+ * Usage: vector_search(table_name, column_name, query_vector, limit)
+ * - table_name: the Paimon table to search
+ * - column_name: the vector column name
+ * - query_vector: array of floats representing the query vector
+ * - limit: the number of top results to return
+ *
+ * Example: SELECT * FROM vector_search('T', 'v', array(50.0f, 51.0f, 52.0f), 5)
+ */
+case class VectorSearchQuery(override val args: Seq[Expression])
+ extends PaimonTableValueFunction(VECTOR_SEARCH) {
+
+ override def parseArgs(args: Seq[Expression]): Map[String, String] = {
+ // This method is not used for VectorSearchQuery as we handle it specially
+ Map.empty
+ }
+
+ def createVectorSearch(
+ innerTable: InnerTable,
+ argsWithoutTable: Seq[Expression]): VectorSearch = {
+ if (argsWithoutTable.size != 3) {
+ throw new RuntimeException(
+ s"$VECTOR_SEARCH needs three parameters after table_name: column_name, query_vector, limit. " +
+ s"Got ${argsWithoutTable.size} parameters after table_name."
+ )
+ }
+ val columnName = argsWithoutTable.head.eval().toString
+ if (!innerTable.rowType().containsField(columnName)) {
+ throw new RuntimeException(
+ s"Column $columnName does not exist in table ${innerTable.name()}"
+ )
+ }
+ val queryVector = extractQueryVector(argsWithoutTable(1))
+ val limit = argsWithoutTable(2).eval() match {
+ case i: Int => i
+ case l: Long => l.toInt
+ case other => throw new RuntimeException(s"Invalid limit type: ${other.getClass.getName}")
+ }
+ if (limit <= 0) {
+ throw new IllegalArgumentException(
+ s"Limit must be a positive integer, but got: $limit"
+ )
+ }
+ new VectorSearch(queryVector, limit, columnName)
+ }
+
+ private def extractQueryVector(expr: Expression): Array[Float] = {
+ expr match {
+ case Literal(arrayData, _) if arrayData != null =>
+ val arr = arrayData.asInstanceOf[org.apache.spark.sql.catalyst.util.ArrayData]
+ arr.toFloatArray()
+ case CreateArray(elements, _) if elements != null =>
+ elements.map {
+ case Literal(v: Float, _) => v
+ case Literal(v: Double, _) => v.toFloat
+ case Literal(v: java.lang.Float, _) if v != null => v.floatValue()
+ case Literal(v: java.lang.Double, _) if v != null => v.floatValue()
+ case other => throw new RuntimeException(s"Cannot extract float from: $other")
+ }.toArray
+ case _ =>
+ throw new RuntimeException(s"Cannot extract query vector from expression: $expr")
+ }
+ }
+}
diff --git a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/scan/BaseScan.scala b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/scan/BaseScan.scala
index dcd3dda67a..3e6a3f0319 100644
--- a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/scan/BaseScan.scala
+++ b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/scan/BaseScan.scala
@@ -20,7 +20,7 @@ package org.apache.paimon.spark.scan
import org.apache.paimon.CoreOptions
import org.apache.paimon.partition.PartitionPredicate
-import org.apache.paimon.predicate.{Predicate, TopN}
+import org.apache.paimon.predicate.{Predicate, TopN, VectorSearch}
import org.apache.paimon.spark.{PaimonBatch, PaimonInputPartition, PaimonNumSplitMetric, PaimonPartitionSizeMetric, PaimonReadBatchTimeMetric, PaimonResultedTableFilesMetric, PaimonResultedTableFilesTaskMetric, SparkTypeUtils}
import org.apache.paimon.spark.schema.PaimonMetadataColumn
import org.apache.paimon.spark.schema.PaimonMetadataColumn._
@@ -49,6 +49,7 @@ trait BaseScan extends Scan with SupportsReportStatistics with Logging {
def pushedDataFilters: Seq[Predicate]
def pushedLimit: Option[Int] = None
def pushedTopN: Option[TopN] = None
+ def pushedVectorSearch: Option[VectorSearch] = None
// Input splits
def inputSplits: Array[Split]
@@ -104,6 +105,7 @@ trait BaseScan extends Scan with SupportsReportStatistics with Logging {
}
pushedLimit.foreach(_readBuilder.withLimit)
pushedTopN.foreach(_readBuilder.withTopN)
+ pushedVectorSearch.foreach(_readBuilder.withVectorSearch)
_readBuilder.dropStats()
}
@@ -173,6 +175,7 @@ trait BaseScan extends Scan with SupportsReportStatistics with Logging {
pushedPartitionFiltersStr +
pushedDataFiltersStr +
pushedTopN.map(topN => s", TopN: [$topN]").getOrElse("") +
- pushedLimit.map(limit => s", Limit: [$limit]").getOrElse("")
+ pushedLimit.map(limit => s", Limit: [$limit]").getOrElse("") +
+ pushedVectorSearch.map(vs => s", VectorSearch: [$vs]").getOrElse("")
}
}
diff --git a/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/sql/BaseVectorSearchPushDownTest.scala b/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/sql/BaseVectorSearchPushDownTest.scala
new file mode 100644
index 0000000000..c283326cf3
--- /dev/null
+++ b/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/sql/BaseVectorSearchPushDownTest.scala
@@ -0,0 +1,100 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.paimon.spark.sql
+
+import org.apache.paimon.spark.PaimonSparkTestBase
+
+import org.apache.spark.sql.streaming.StreamTest
+
+/** Tests for vector search table-valued function. */
+class BaseVectorSearchPushDownTest extends PaimonSparkTestBase with StreamTest {
+
+ test("vector_search table function basic syntax") {
+ withTable("T") {
+ spark.sql("""
+ |CREATE TABLE T (id INT, v ARRAY<FLOAT>)
+ |TBLPROPERTIES (
+ | 'bucket' = '-1',
+ | 'row-tracking.enabled' = 'true',
+ | 'data-evolution.enabled' = 'true')
+ |""".stripMargin)
+
+ // Insert data with known vectors
+ spark.sql("""
+ |INSERT INTO T VALUES
+ |(1, array(1.0, 0.0, 0.0)),
+ |(2, array(0.0, 1.0, 0.0)),
+ |(3, array(0.0, 0.0, 1.0)),
+ |(4, array(1.0, 1.0, 0.0)),
+ |(5, array(1.0, 1.0, 1.0))
+ |""".stripMargin)
+
+ // Test vector_search table function syntax
+ // Note: Without a global vector index, this will scan all rows
+ val result = spark
+ .sql("""
+ |SELECT * FROM vector_search('T', 'v', array(1.0f, 0.0f, 0.0f), 3)
+ |""".stripMargin)
+ .collect()
+
+ // Should return results (actual filtering depends on vector index)
+ assert(result.nonEmpty)
+
+ // Test invalid limit (negative)
+ val ex1 = intercept[Exception] {
+ spark
+ .sql("""
+ |SELECT * FROM vector_search('T', 'v', array(1.0f, 0.0f, 0.0f), -3)
+ |""".stripMargin)
+ .collect()
+ }
+ assert(ex1.getMessage.contains("Limit must be a positive integer"))
+
+ // Test invalid limit (zero)
+ val ex2 = intercept[Exception] {
+ spark
+ .sql("""
+ |SELECT * FROM vector_search('T', 'v', array(1.0f, 0.0f, 0.0f), 0)
+ |""".stripMargin)
+ .collect()
+ }
+ assert(ex2.getMessage.contains("Limit must be a positive integer"))
+
+ // Test missing parameters
+ val ex3 = intercept[Exception] {
+ spark
+ .sql("""
+ |SELECT * FROM vector_search('T', 'v', array(1.0f, 0.0f, 0.0f))
+ |""".stripMargin)
+ .collect()
+ }
+ assert(ex3.getMessage.contains("vector_search needs three parameters after table_name"))
+
+ // Test non-existent column
+ val ex4 = intercept[Exception] {
+ spark
+ .sql("""
+ |SELECT * FROM vector_search('T', 'non_existent_col', array(1.0f, 0.0f, 0.0f), 3)
+ |""".stripMargin)
+ .collect()
+ }
+ assert(ex4.getMessage.nonEmpty)
+ }
+ }
+}