This is an automated email from the ASF dual-hosted git repository.

lzljs3620320 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/paimon.git


The following commit(s) were added to refs/heads/master by this push:
     new 708ea8fc75 [spark] Spark support vector search (#6950)
708ea8fc75 is described below

commit 708ea8fc7592ac95d480cd62c7fcccb9daf7f184
Author: jerry <[email protected]>
AuthorDate: Tue Jan 6 13:07:29 2026 +0800

    [spark] Spark support vector search (#6950)
---
 .../org/apache/paimon/table/VectorSearchTable.java | 101 ++++++++++++++
 .../paimon/spark/PaimonBaseScanBuilder.scala       |   5 +-
 .../scala/org/apache/paimon/spark/PaimonScan.scala |   3 +-
 .../apache/paimon/spark/PaimonScanBuilder.scala    |   9 +-
 .../scala/org/apache/paimon/spark/PaimonScan.scala |   3 +-
 .../spark/sql/VectorSearchPushDownTest.scala       | 145 +++++++++++++++++++++
 .../paimon/spark/PaimonBaseScanBuilder.scala       |   3 +-
 .../scala/org/apache/paimon/spark/PaimonScan.scala |   3 +-
 .../apache/paimon/spark/PaimonScanBuilder.scala    |  17 ++-
 .../plans/logical/PaimonTableValuedFunctions.scala | 117 +++++++++++++++--
 .../org/apache/paimon/spark/scan/BaseScan.scala    |   7 +-
 .../spark/sql/BaseVectorSearchPushDownTest.scala   | 100 ++++++++++++++
 12 files changed, 494 insertions(+), 19 deletions(-)

diff --git 
a/paimon-core/src/main/java/org/apache/paimon/table/VectorSearchTable.java 
b/paimon-core/src/main/java/org/apache/paimon/table/VectorSearchTable.java
new file mode 100644
index 0000000000..cb98e25055
--- /dev/null
+++ b/paimon-core/src/main/java/org/apache/paimon/table/VectorSearchTable.java
@@ -0,0 +1,101 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.paimon.table;
+
+import org.apache.paimon.fs.FileIO;
+import org.apache.paimon.predicate.VectorSearch;
+import org.apache.paimon.table.source.InnerTableRead;
+import org.apache.paimon.table.source.InnerTableScan;
+import org.apache.paimon.types.RowType;
+
+import java.util.List;
+import java.util.Map;
+
+/**
+ * A read-only {@link InnerTable} wrapper that carries a {@link VectorSearch} from logical plan
+ * resolution to scan building. It is only used internally by the Spark engine; {@link #newScan()}
+ * is unsupported, so callers must unwrap {@link #origin()} before scanning.
+ */
+public class VectorSearchTable implements ReadonlyTable {
+
+    private final InnerTable origin;
+    private final VectorSearch vectorSearch;
+
+    private VectorSearchTable(InnerTable origin, VectorSearch vectorSearch) {
+        this.origin = origin;
+        this.vectorSearch = vectorSearch;
+    }
+
+    public static VectorSearchTable create(InnerTable origin, VectorSearch vectorSearch) {
+        return new VectorSearchTable(origin, vectorSearch);
+    }
+
+    public VectorSearch vectorSearch() {
+        return vectorSearch;
+    }
+
+    public InnerTable origin() {
+        return origin;
+    }
+
+    @Override
+    public String name() {
+        return origin.name();
+    }
+
+    @Override
+    public RowType rowType() {
+        return origin.rowType();
+    }
+
+    @Override
+    public List<String> primaryKeys() {
+        return origin.primaryKeys();
+    }
+
+    @Override
+    public List<String> partitionKeys() {
+        return origin.partitionKeys();
+    }
+
+    @Override
+    public Map<String, String> options() {
+        return origin.options();
+    }
+
+    @Override
+    public FileIO fileIO() {
+        return origin.fileIO();
+    }
+
+    @Override
+    public InnerTableRead newRead() {
+        return origin.newRead();
+    }
+
+    @Override
+    public InnerTableScan newScan() {
+        throw new UnsupportedOperationException("Use origin() to scan this table.");
+    }
+
+    @Override
+    public Table copy(Map<String, String> dynamicOptions) {
+        return create((InnerTable) origin.copy(dynamicOptions), vectorSearch);
+    }
+}
diff --git 
a/paimon-spark/paimon-spark-3.2/src/main/scala/org/apache/paimon/spark/PaimonBaseScanBuilder.scala
 
b/paimon-spark/paimon-spark-3.2/src/main/scala/org/apache/paimon/spark/PaimonBaseScanBuilder.scala
index 41a1d552f1..4f5451c95d 100644
--- 
a/paimon-spark/paimon-spark-3.2/src/main/scala/org/apache/paimon/spark/PaimonBaseScanBuilder.scala
+++ 
b/paimon-spark/paimon-spark-3.2/src/main/scala/org/apache/paimon/spark/PaimonBaseScanBuilder.scala
@@ -21,7 +21,7 @@ package org.apache.paimon.spark
 import org.apache.paimon.CoreOptions
 import org.apache.paimon.partition.PartitionPredicate
 import 
org.apache.paimon.partition.PartitionPredicate.splitPartitionPredicatesAndDataPredicates
-import org.apache.paimon.predicate.{PartitionPredicateVisitor, Predicate}
+import org.apache.paimon.predicate.{PartitionPredicateVisitor, Predicate, 
TopN, VectorSearch}
 import org.apache.paimon.table.SpecialFields.rowTypeWithRowTracking
 import org.apache.paimon.table.Table
 import org.apache.paimon.types.RowType
@@ -50,6 +50,9 @@ abstract class PaimonBaseScanBuilder
 
   protected var pushedPartitionFilters: Array[PartitionPredicate] = Array.empty
   protected var pushedDataFilters: Array[Predicate] = Array.empty
+  protected var pushedLimit: Option[Int] = None
+  protected var pushedTopN: Option[TopN] = None
+  protected var pushedVectorSearch: Option[VectorSearch] = None
 
   protected var requiredSchema: StructType = 
SparkTypeUtils.fromPaimonRowType(table.rowType())
 
diff --git 
a/paimon-spark/paimon-spark-3.2/src/main/scala/org/apache/paimon/spark/PaimonScan.scala
 
b/paimon-spark/paimon-spark-3.2/src/main/scala/org/apache/paimon/spark/PaimonScan.scala
index e9eaa7d6cc..d6292ad8cf 100644
--- 
a/paimon-spark/paimon-spark-3.2/src/main/scala/org/apache/paimon/spark/PaimonScan.scala
+++ 
b/paimon-spark/paimon-spark-3.2/src/main/scala/org/apache/paimon/spark/PaimonScan.scala
@@ -19,7 +19,7 @@
 package org.apache.paimon.spark
 
 import org.apache.paimon.partition.PartitionPredicate
-import org.apache.paimon.predicate.{Predicate, TopN}
+import org.apache.paimon.predicate.{Predicate, TopN, VectorSearch}
 import org.apache.paimon.table.InnerTable
 
 import org.apache.spark.sql.PaimonUtils.fieldReference
@@ -37,6 +37,7 @@ case class PaimonScan(
     pushedDataFilters: Seq[Predicate],
     override val pushedLimit: Option[Int] = None,
     override val pushedTopN: Option[TopN] = None,
+    override val pushedVectorSearch: Option[VectorSearch] = None,
     bucketedScanDisabled: Boolean = true)
   extends PaimonBaseScan(table)
   with SupportsRuntimeFiltering {
diff --git 
a/paimon-spark/paimon-spark-3.2/src/main/scala/org/apache/paimon/spark/PaimonScanBuilder.scala
 
b/paimon-spark/paimon-spark-3.2/src/main/scala/org/apache/paimon/spark/PaimonScanBuilder.scala
index 21ab46dabc..770bd8f802 100644
--- 
a/paimon-spark/paimon-spark-3.2/src/main/scala/org/apache/paimon/spark/PaimonScanBuilder.scala
+++ 
b/paimon-spark/paimon-spark-3.2/src/main/scala/org/apache/paimon/spark/PaimonScanBuilder.scala
@@ -25,6 +25,13 @@ import org.apache.spark.sql.connector.read.Scan
 class PaimonScanBuilder(val table: InnerTable) extends PaimonBaseScanBuilder {
 
   override def build(): Scan = {
-    PaimonScan(table, requiredSchema, pushedPartitionFilters, pushedDataFilters)
+    PaimonScan(
+      table,
+      requiredSchema,
+      pushedPartitionFilters,
+      pushedDataFilters,
+      pushedLimit,
+      pushedTopN,
+      pushedVectorSearch)
   }
 }
diff --git 
a/paimon-spark/paimon-spark-3.3/src/main/scala/org/apache/paimon/spark/PaimonScan.scala
 
b/paimon-spark/paimon-spark-3.3/src/main/scala/org/apache/paimon/spark/PaimonScan.scala
index 3afca15303..8d06751f57 100644
--- 
a/paimon-spark/paimon-spark-3.3/src/main/scala/org/apache/paimon/spark/PaimonScan.scala
+++ 
b/paimon-spark/paimon-spark-3.3/src/main/scala/org/apache/paimon/spark/PaimonScan.scala
@@ -19,7 +19,7 @@
 package org.apache.paimon.spark
 
 import org.apache.paimon.partition.PartitionPredicate
-import org.apache.paimon.predicate.{Predicate, TopN}
+import org.apache.paimon.predicate.{Predicate, TopN, VectorSearch}
 import org.apache.paimon.table.{BucketMode, FileStoreTable, InnerTable}
 import org.apache.paimon.table.source.{DataSplit, Split}
 
@@ -39,6 +39,7 @@ case class PaimonScan(
     pushedDataFilters: Seq[Predicate],
     override val pushedLimit: Option[Int],
     override val pushedTopN: Option[TopN],
+    override val pushedVectorSearch: Option[VectorSearch],
     bucketedScanDisabled: Boolean = false)
   extends PaimonBaseScan(table)
   with SupportsRuntimeFiltering
diff --git 
a/paimon-spark/paimon-spark-4.0/src/test/scala/org/apache/paimon/spark/sql/VectorSearchPushDownTest.scala
 
b/paimon-spark/paimon-spark-4.0/src/test/scala/org/apache/paimon/spark/sql/VectorSearchPushDownTest.scala
new file mode 100644
index 0000000000..7ac3c5df0d
--- /dev/null
+++ 
b/paimon-spark/paimon-spark-4.0/src/test/scala/org/apache/paimon/spark/sql/VectorSearchPushDownTest.scala
@@ -0,0 +1,145 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.paimon.spark.sql
+
+import org.apache.paimon.spark.PaimonScan
+
+/** Tests for vector search table-valued function with global vector index. */
+class VectorSearchPushDownTest extends BaseVectorSearchPushDownTest {
+  test("vector search with global index") {
+    withTable("T") {
+      spark.sql("""
+                  |CREATE TABLE T (id INT, v ARRAY<FLOAT>)
+                  |TBLPROPERTIES (
+                  |  'bucket' = '-1',
+                  |  'global-index.row-count-per-shard' = '10000',
+                  |  'row-tracking.enabled' = 'true',
+                  |  'data-evolution.enabled' = 'true')
+                  |""".stripMargin)
+
+      // Insert 100 rows with predictable vectors
+      val values = (0 until 100)
+        .map(
+          i => s"($i, array(cast($i as float), cast(${i + 1} as float), cast(${i + 2} as float)))")
+        .mkString(",")
+      spark.sql(s"INSERT INTO T VALUES $values")
+
+      // Create vector index
+      val output = spark
+        .sql("CALL sys.create_global_index(table => 'test.T', index_column => 'v', index_type => 'lucene-vector-knn', options => 'vector.dim=3')")
+        .collect()
+        .head
+      assert(output.getBoolean(0))
+
+      // Test vector search with table-valued function syntax
+      val result = spark
+        .sql("""
+               |SELECT * FROM vector_search('T', 'v', array(50.0f, 51.0f, 52.0f), 5)
+               |""".stripMargin)
+        .collect()
+
+      // The result should contain 5 rows
+      assert(result.length == 5)
+
+      // Vector (50, 51, 52) should be most similar to the row with id=50
+      assert(result.map(_.getInt(0)).contains(50))
+    }
+  }
+
+  test("vector search pushdown is applied in plan") {
+    withTable("T") {
+      spark.sql("""
+                  |CREATE TABLE T (id INT, v ARRAY<FLOAT>)
+                  |TBLPROPERTIES (
+                  |  'bucket' = '-1',
+                  |  'global-index.row-count-per-shard' = '10000',
+                  |  'row-tracking.enabled' = 'true',
+                  |  'data-evolution.enabled' = 'true')
+                  |""".stripMargin)
+
+      val values = (0 until 10)
+        .map(
+          i => s"($i, array(cast($i as float), cast(${i + 1} as float), cast(${i + 2} as float)))")
+        .mkString(",")
+      spark.sql(s"INSERT INTO T VALUES $values")
+
+      // Create vector index
+      spark
+        .sql("CALL sys.create_global_index(table => 'test.T', index_column => 'v', index_type => 'lucene-vector-knn', options => 'vector.dim=3')")
+        .collect()
+
+      // Check that vector search is pushed down with table function syntax
+      val df = spark.sql("""
+                           |SELECT * FROM vector_search('T', 'v', array(50.0f, 51.0f, 52.0f), 5)
+                           |""".stripMargin)
+
+      // Get the scan from the executed plan (physical plan)
+      val executedPlan = df.queryExecution.executedPlan
+      val batchScans = executedPlan.collect {
+        case scan: org.apache.spark.sql.execution.datasources.v2.BatchScanExec => scan
+      }
+
+      assert(batchScans.nonEmpty, "Should have a BatchScanExec in executed plan")
+      val paimonScans = batchScans.filter(_.scan.isInstanceOf[PaimonScan])
+      assert(paimonScans.nonEmpty, "Should have a PaimonScan in executed plan")
+
+      val paimonScan = paimonScans.head.scan.asInstanceOf[PaimonScan]
+      assert(paimonScan.pushedVectorSearch.isDefined, "Vector search should be pushed down")
+      assert(paimonScan.pushedVectorSearch.get.fieldName() == "v", "Field name should be 'v'")
+      assert(paimonScan.pushedVectorSearch.get.limit() == 5, "Limit should be 5")
+    }
+  }
+
+  test("vector search topk returns correct results") {
+    withTable("T") {
+      spark.sql("""
+                  |CREATE TABLE T (id INT, v ARRAY<FLOAT>)
+                  |TBLPROPERTIES (
+                  |  'bucket' = '-1',
+                  |  'global-index.row-count-per-shard' = '10000',
+                  |  'row-tracking.enabled' = 'true',
+                  |  'data-evolution.enabled' = 'true')
+                  |""".stripMargin)
+
+      // Insert pairwise distinct unit vectors normalize(1, 1, i / 100)
+      val values = (1 to 100)
+        .map {
+          i =>
+            val z = i / 100.0f
+            val norm = math.sqrt(2.0 + z * z).toFloat
+            s"($i, array(${1 / norm}, ${1 / norm}, ${z / norm}))"
+        }
+        .mkString(",")
+      spark.sql(s"INSERT INTO T VALUES $values")
+
+      // Create vector index
+      spark.sql(
+        "CALL sys.create_global_index(table => 'test.T', index_column => 'v', index_type => 'lucene-vector-knn', options => 'vector.dim=3')")
+
+      // Query for top 10 similar to (1, 1, 1) normalized; row id 100 has exactly that direction
+      val result = spark
+        .sql("""
+               |SELECT * FROM vector_search('T', 'v', array(0.577f, 0.577f, 0.577f), 10)
+               |""".stripMargin)
+        .collect()
+      val ids = result.map(_.getInt(0))
+      assert(ids.length == 10 && ids.distinct.length == 10 && ids.contains(100))
+    }
+  }
+}
diff --git 
a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonBaseScanBuilder.scala
 
b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonBaseScanBuilder.scala
index 8179f504b3..47723171e4 100644
--- 
a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonBaseScanBuilder.scala
+++ 
b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonBaseScanBuilder.scala
@@ -21,7 +21,7 @@ package org.apache.paimon.spark
 import org.apache.paimon.CoreOptions
 import org.apache.paimon.partition.PartitionPredicate
 import 
org.apache.paimon.partition.PartitionPredicate.splitPartitionPredicatesAndDataPredicates
-import org.apache.paimon.predicate.{PartitionPredicateVisitor, Predicate, TopN}
+import org.apache.paimon.predicate.{PartitionPredicateVisitor, Predicate, 
TopN, VectorSearch}
 import org.apache.paimon.table.{SpecialFields, Table}
 import org.apache.paimon.types.RowType
 
@@ -52,6 +52,7 @@ abstract class PaimonBaseScanBuilder
   protected var pushedDataFilters: Array[Predicate] = Array.empty
   protected var pushedLimit: Option[Int] = None
   protected var pushedTopN: Option[TopN] = None
+  protected var pushedVectorSearch: Option[VectorSearch] = None
 
   protected var requiredSchema: StructType = 
SparkTypeUtils.fromPaimonRowType(table.rowType())
 
diff --git 
a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonScan.scala
 
b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonScan.scala
index 06a97ee8b4..c9f0e9506e 100644
--- 
a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonScan.scala
+++ 
b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonScan.scala
@@ -20,7 +20,7 @@ package org.apache.paimon.spark
 
 import org.apache.paimon.CoreOptions.BucketFunctionType
 import org.apache.paimon.partition.PartitionPredicate
-import org.apache.paimon.predicate.{Predicate, TopN}
+import org.apache.paimon.predicate.{Predicate, TopN, VectorSearch}
 import org.apache.paimon.spark.commands.BucketExpression.quote
 import org.apache.paimon.table.{BucketMode, FileStoreTable, InnerTable}
 import org.apache.paimon.table.source.{DataSplit, Split}
@@ -41,6 +41,7 @@ case class PaimonScan(
     pushedDataFilters: Seq[Predicate],
     override val pushedLimit: Option[Int],
     override val pushedTopN: Option[TopN],
+    override val pushedVectorSearch: Option[VectorSearch],
     bucketedScanDisabled: Boolean = false)
   extends PaimonBaseScan(table)
   with SupportsReportPartitioning
diff --git 
a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonScanBuilder.scala
 
b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonScanBuilder.scala
index 6eeaaf7b93..de75bb823d 100644
--- 
a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonScanBuilder.scala
+++ 
b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonScanBuilder.scala
@@ -128,13 +128,26 @@ class PaimonScanBuilder(val table: InnerTable)
     localScan match {
       case Some(scan) => scan
       case None =>
+        val (actualTable, vectorSearch) = table match {
+          case vst: org.apache.paimon.table.VectorSearchTable =>
+            val tableVectorSearch = Option(vst.vectorSearch())
+            val vs = (tableVectorSearch, pushedVectorSearch) match {
+              case (Some(_), _) => tableVectorSearch
+              case (None, Some(_)) => pushedVectorSearch
+              case (None, None) => None
+            }
+            (vst.origin(), vs)
+          case _ => (table, pushedVectorSearch)
+        }
+
         PaimonScan(
-          table,
+          actualTable,
           requiredSchema,
           pushedPartitionFilters,
           pushedDataFilters,
           pushedLimit,
-          pushedTopN)
+          pushedTopN,
+          vectorSearch)
     }
   }
 }
diff --git 
a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/catalyst/plans/logical/PaimonTableValuedFunctions.scala
 
b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/catalyst/plans/logical/PaimonTableValuedFunctions.scala
index e4f5e7856c..6bb6004db8 100644
--- 
a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/catalyst/plans/logical/PaimonTableValuedFunctions.scala
+++ 
b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/catalyst/plans/logical/PaimonTableValuedFunctions.scala
@@ -19,9 +19,10 @@
 package org.apache.paimon.spark.catalyst.plans.logical
 
 import org.apache.paimon.CoreOptions
+import org.apache.paimon.predicate.VectorSearch
 import org.apache.paimon.spark.SparkTable
 import 
org.apache.paimon.spark.catalyst.plans.logical.PaimonTableValuedFunctions._
-import org.apache.paimon.table.DataTable
+import org.apache.paimon.table.{DataTable, InnerTable, VectorSearchTable}
 import 
org.apache.paimon.table.source.snapshot.TimeTravelUtil.InconsistentTagBucketException
 
 import org.apache.spark.sql.PaimonUtils.createDataset
@@ -29,7 +30,7 @@ import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.catalyst.FunctionIdentifier
 import org.apache.spark.sql.catalyst.analysis.FunctionRegistryBase
 import 
org.apache.spark.sql.catalyst.analysis.TableFunctionRegistry.TableFunctionBuilder
-import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, 
ExpressionInfo}
+import org.apache.spark.sql.catalyst.expressions.{Attribute, CreateArray, 
Expression, ExpressionInfo, Literal}
 import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan}
 import org.apache.spark.sql.connector.catalog.{Identifier, Table, TableCatalog}
 import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation
@@ -42,9 +43,10 @@ object PaimonTableValuedFunctions {
   val INCREMENTAL_QUERY = "paimon_incremental_query"
   val INCREMENTAL_BETWEEN_TIMESTAMP = "paimon_incremental_between_timestamp"
   val INCREMENTAL_TO_AUTO_TAG = "paimon_incremental_to_auto_tag"
+  val VECTOR_SEARCH = "vector_search"
 
   val supportedFnNames: Seq[String] =
-    Seq(INCREMENTAL_QUERY, INCREMENTAL_BETWEEN_TIMESTAMP, 
INCREMENTAL_TO_AUTO_TAG)
+    Seq(INCREMENTAL_QUERY, INCREMENTAL_BETWEEN_TIMESTAMP, 
INCREMENTAL_TO_AUTO_TAG, VECTOR_SEARCH)
 
   private type TableFunctionDescription = (FunctionIdentifier, ExpressionInfo, 
TableFunctionBuilder)
 
@@ -56,6 +58,8 @@ object PaimonTableValuedFunctions {
         FunctionRegistryBase.build[IncrementalBetweenTimestamp](fnName, since 
= None)
       case INCREMENTAL_TO_AUTO_TAG =>
         FunctionRegistryBase.build[IncrementalToAutoTag](fnName, since = None)
+      case VECTOR_SEARCH =>
+        FunctionRegistryBase.build[VectorSearchQuery](fnName, since = None)
       case _ =>
         throw new Exception(s"Function $fnName isn't a supported table valued 
function.")
     }
@@ -85,17 +89,45 @@ object PaimonTableValuedFunctions {
     val sparkCatalog = 
catalogManager.catalog(catalogName).asInstanceOf[TableCatalog]
     val ident: Identifier = Identifier.of(Array(dbName), tableName)
     val sparkTable = sparkCatalog.loadTable(ident)
-    val options = tvf.parseArgs(args.tail)
 
-    usingSparkIncrementQuery(tvf, sparkTable, options) match {
-      case Some(snapshotIdPair: (Long, Long)) =>
-        sparkIncrementQuery(spark, sparkTable, sparkCatalog, ident, options, 
snapshotIdPair)
+    // Handle vector_search specially
+    tvf match {
+      case vsq: VectorSearchQuery =>
+        resolveVectorSearchQuery(sparkTable, sparkCatalog, ident, vsq, 
args.tail)
       case _ =>
+        val options = tvf.parseArgs(args.tail)
+        usingSparkIncrementQuery(tvf, sparkTable, options) match {
+          case Some(snapshotIdPair: (Long, Long)) =>
+            sparkIncrementQuery(spark, sparkTable, sparkCatalog, ident, 
options, snapshotIdPair)
+          case _ =>
+            DataSourceV2Relation.create(
+              sparkTable,
+              Some(sparkCatalog),
+              Some(ident),
+              new CaseInsensitiveStringMap(options.asJava))
+        }
+    }
+  }
+
+  private def resolveVectorSearchQuery(
+      sparkTable: Table,
+      sparkCatalog: TableCatalog,
+      ident: Identifier,
+      vsq: VectorSearchQuery,
+      argsWithoutTable: Seq[Expression]): LogicalPlan = {
+    sparkTable match {
+      case st @ SparkTable(innerTable: InnerTable) =>
+        val vectorSearch = vsq.createVectorSearch(innerTable, argsWithoutTable)
+        val vectorSearchTable = VectorSearchTable.create(innerTable, vectorSearch)
         DataSourceV2Relation.create(
-          sparkTable,
+          st.copy(table = vectorSearchTable),
           Some(sparkCatalog),
           Some(ident),
-          new CaseInsensitiveStringMap(options.asJava))
+          CaseInsensitiveStringMap.empty())
+      case _ =>
+        throw new RuntimeException(
+          "vector_search only supports Paimon SparkTable backed by InnerTable, " +
+            s"but got table implementation: ${sparkTable.getClass.getName}")
     }
   }
 
@@ -207,3 +239,70 @@ case class IncrementalToAutoTag(override val args: 
Seq[Expression])
     Map(CoreOptions.INCREMENTAL_TO_AUTO_TAG.key -> endTagName)
   }
 }
+
+/**
+ * Plan for the [[VECTOR_SEARCH]] table-valued function.
+ *
+ * Usage: vector_search(table_name, column_name, query_vector, limit)
+ *   - table_name: the Paimon table to search
+ *   - column_name: the vector column name, which must exist in the table
+ *   - query_vector: a constant array of numbers, converted to floats
+ *   - limit: the number of top results to return, a positive integer
+ *
+ * Example: SELECT * FROM vector_search('T', 'v', array(50.0f, 51.0f, 52.0f), 5)
+ */
+case class VectorSearchQuery(override val args: Seq[Expression])
+  extends PaimonTableValueFunction(VECTOR_SEARCH) {
+
+  // Unused: vector_search is resolved by resolveVectorSearchQuery, not through read options.
+  override def parseArgs(args: Seq[Expression]): Map[String, String] = Map.empty
+
+  /** Validates the arguments after table_name and builds the pushed-down [[VectorSearch]]. */
+  def createVectorSearch(
+      innerTable: InnerTable,
+      argsWithoutTable: Seq[Expression]): VectorSearch = {
+    if (argsWithoutTable.size != 3) {
+      throw new RuntimeException(
+        s"$VECTOR_SEARCH needs three parameters after table_name: column_name, query_vector, limit. " +
+          s"Got ${argsWithoutTable.size} parameters after table_name.")
+    }
+    val columnName = Option(argsWithoutTable.head.eval()).map(_.toString).getOrElse(
+      throw new RuntimeException(s"$VECTOR_SEARCH column_name must not be null"))
+    if (!innerTable.rowType().containsField(columnName)) {
+      throw new RuntimeException(s"Column $columnName does not exist in table ${innerTable.name()}")
+    }
+    val queryVector = extractQueryVector(argsWithoutTable(1))
+    val limit = argsWithoutTable(2).eval() match {
+      case i: Int => i
+      // Reject instead of truncating: Long.toInt silently wraps values outside the Int range.
+      case l: Long if l.isValidInt => l.toInt
+      case l: Long => throw new IllegalArgumentException(s"Limit $l is out of integer range")
+      case null => throw new RuntimeException("Invalid limit: null")
+      case other => throw new RuntimeException(s"Invalid limit type: ${other.getClass.getName}")
+    }
+    if (limit <= 0) {
+      throw new IllegalArgumentException(s"Limit must be a positive integer, but got: $limit")
+    }
+    new VectorSearch(queryVector, limit, columnName)
+  }
+
+  /** Reads a constant array literal, or `array(...)` of numeric literals, as a float vector. */
+  private def extractQueryVector(expr: Expression): Array[Float] = expr match {
+    case Literal(arr: org.apache.spark.sql.catalyst.util.ArrayData, dataType) =>
+      val elementType = dataType.asInstanceOf[org.apache.spark.sql.types.ArrayType].elementType
+      Array.tabulate(arr.numElements())(i => toQueryFloat(arr.get(i, elementType)))
+    case CreateArray(elements, _) =>
+      elements.map {
+        case Literal(v, _) => toQueryFloat(v)
+        case other => throw new RuntimeException(s"Cannot extract float from: $other")
+      }.toArray
+    case _ => throw new RuntimeException(s"Cannot extract query vector from expression: $expr")
+  }
+
+  /** Converts one literal value (a boxed Java number or a Spark Decimal) to a float. */
+  private def toQueryFloat(v: Any): Float = v match {
+    case n: java.lang.Number => n.floatValue()
+    case d: org.apache.spark.sql.types.Decimal => d.toFloat
+    case other => throw new RuntimeException(s"Cannot extract float from: $other")
+  }
+}
diff --git 
a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/scan/BaseScan.scala
 
b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/scan/BaseScan.scala
index dcd3dda67a..3e6a3f0319 100644
--- 
a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/scan/BaseScan.scala
+++ 
b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/scan/BaseScan.scala
@@ -20,7 +20,7 @@ package org.apache.paimon.spark.scan
 
 import org.apache.paimon.CoreOptions
 import org.apache.paimon.partition.PartitionPredicate
-import org.apache.paimon.predicate.{Predicate, TopN}
+import org.apache.paimon.predicate.{Predicate, TopN, VectorSearch}
 import org.apache.paimon.spark.{PaimonBatch, PaimonInputPartition, 
PaimonNumSplitMetric, PaimonPartitionSizeMetric, PaimonReadBatchTimeMetric, 
PaimonResultedTableFilesMetric, PaimonResultedTableFilesTaskMetric, 
SparkTypeUtils}
 import org.apache.paimon.spark.schema.PaimonMetadataColumn
 import org.apache.paimon.spark.schema.PaimonMetadataColumn._
@@ -49,6 +49,7 @@ trait BaseScan extends Scan with SupportsReportStatistics 
with Logging {
   def pushedDataFilters: Seq[Predicate]
   def pushedLimit: Option[Int] = None
   def pushedTopN: Option[TopN] = None
+  def pushedVectorSearch: Option[VectorSearch] = None
 
   // Input splits
   def inputSplits: Array[Split]
@@ -104,6 +105,7 @@ trait BaseScan extends Scan with SupportsReportStatistics 
with Logging {
     }
     pushedLimit.foreach(_readBuilder.withLimit)
     pushedTopN.foreach(_readBuilder.withTopN)
+    pushedVectorSearch.foreach(_readBuilder.withVectorSearch)
     _readBuilder.dropStats()
   }
 
@@ -173,6 +175,7 @@ trait BaseScan extends Scan with SupportsReportStatistics 
with Logging {
       pushedPartitionFiltersStr +
       pushedDataFiltersStr +
       pushedTopN.map(topN => s", TopN: [$topN]").getOrElse("") +
-      pushedLimit.map(limit => s", Limit: [$limit]").getOrElse("")
+      pushedLimit.map(limit => s", Limit: [$limit]").getOrElse("") +
+      pushedVectorSearch.map(vs => s", VectorSearch: [$vs]").getOrElse("")
   }
 }
diff --git 
a/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/sql/BaseVectorSearchPushDownTest.scala
 
b/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/sql/BaseVectorSearchPushDownTest.scala
new file mode 100644
index 0000000000..c283326cf3
--- /dev/null
+++ 
b/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/sql/BaseVectorSearchPushDownTest.scala
@@ -0,0 +1,100 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.paimon.spark.sql
+
+import org.apache.paimon.spark.PaimonSparkTestBase
+
+import org.apache.spark.sql.streaming.StreamTest
+
+/** Tests for vector search table-valued function. */
+class BaseVectorSearchPushDownTest extends PaimonSparkTestBase with StreamTest {
+
+  test("vector_search table function basic syntax") {
+    withTable("T") {
+      spark.sql("""
+                  |CREATE TABLE T (id INT, v ARRAY<FLOAT>)
+                  |TBLPROPERTIES (
+                  |  'bucket' = '-1',
+                  |  'row-tracking.enabled' = 'true',
+                  |  'data-evolution.enabled' = 'true')
+                  |""".stripMargin)
+
+      // Insert data with known vectors
+      spark.sql("""
+                  |INSERT INTO T VALUES
+                  |(1, array(1.0, 0.0, 0.0)),
+                  |(2, array(0.0, 1.0, 0.0)),
+                  |(3, array(0.0, 0.0, 1.0)),
+                  |(4, array(1.0, 1.0, 0.0)),
+                  |(5, array(1.0, 1.0, 1.0))
+                  |""".stripMargin)
+
+      // Test vector_search table function syntax
+      // Note: Without a global vector index, this will scan all rows
+      val result = spark
+        .sql("""
+               |SELECT * FROM vector_search('T', 'v', array(1.0f, 0.0f, 0.0f), 3)
+               |""".stripMargin)
+        .collect()
+
+      // Should return results (actual filtering depends on vector index)
+      assert(result.nonEmpty)
+
+      // Test invalid limit (negative)
+      val ex1 = intercept[Exception] {
+        spark
+          .sql("""
+                 |SELECT * FROM vector_search('T', 'v', array(1.0f, 0.0f, 0.0f), -3)
+                 |""".stripMargin)
+          .collect()
+      }
+      assert(ex1.getMessage.contains("Limit must be a positive integer"))
+
+      // Test invalid limit (zero)
+      val ex2 = intercept[Exception] {
+        spark
+          .sql("""
+                 |SELECT * FROM vector_search('T', 'v', array(1.0f, 0.0f, 0.0f), 0)
+                 |""".stripMargin)
+          .collect()
+      }
+      assert(ex2.getMessage.contains("Limit must be a positive integer"))
+
+      // Test missing parameters
+      val ex3 = intercept[Exception] {
+        spark
+          .sql("""
+                 |SELECT * FROM vector_search('T', 'v', array(1.0f, 0.0f, 0.0f))
+                 |""".stripMargin)
+          .collect()
+      }
+      assert(ex3.getMessage.contains("vector_search needs three parameters after table_name"))
+
+      // Test non-existent column
+      val ex4 = intercept[Exception] {
+        spark
+          .sql("""
+                 |SELECT * FROM vector_search('T', 'non_existent_col', array(1.0f, 0.0f, 0.0f), 3)
+                 |""".stripMargin)
+          .collect()
+      }
+      assert(ex4.getMessage.contains("Column non_existent_col does not exist"))
+    }
+  }
+}

Reply via email to