This is an automated email from the ASF dual-hosted git repository.

shengkai pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git


The following commit(s) were added to refs/heads/master by this push:
     new c95979dec32 [FLINK-38428][table] Support to run vector search with 
constant value input (#27130)
c95979dec32 is described below

commit c95979dec32ab6a80a90d10f7006796339d34e68
Author: Shengkai <[email protected]>
AuthorDate: Thu Oct 23 15:44:47 2025 +0800

    [FLINK-38428][table] Support to run vector search with constant value input 
(#27130)
---
 .../ConstantVectorSearchCallToCorrelateRule.java   | 114 +++++++++++++++++++++
 .../planner/plan/rules/FlinkStreamRuleSets.scala   |   2 +
 .../stream/sql/VectorSearchTableFunctionTest.java  |  14 +--
 .../stream/table/AsyncVectorSearchITCase.java      |  14 +++
 .../runtime/stream/table/VectorSearchITCase.java   |  14 +++
 .../stream/sql/VectorSearchTableFunctionTest.xml   |  42 ++++++++
 6 files changed, 188 insertions(+), 12 deletions(-)

diff --git 
a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/logical/ConstantVectorSearchCallToCorrelateRule.java
 
b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/logical/ConstantVectorSearchCallToCorrelateRule.java
new file mode 100644
index 00000000000..13fbe0e3fd3
--- /dev/null
+++ 
b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/logical/ConstantVectorSearchCallToCorrelateRule.java
@@ -0,0 +1,114 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.planner.plan.rules.logical;
+
+import org.apache.flink.table.planner.functions.sql.ml.SqlVectorSearchTableFunction;
+
+import org.apache.calcite.plan.RelOptCluster;
+import org.apache.calcite.plan.RelOptRuleCall;
+import org.apache.calcite.plan.RelRule;
+import org.apache.calcite.rel.core.CorrelationId;
+import org.apache.calcite.rel.core.JoinRelType;
+import org.apache.calcite.rel.logical.LogicalTableFunctionScan;
+import org.apache.calcite.rel.logical.LogicalValues;
+import org.apache.calcite.rex.RexCall;
+import org.apache.calcite.rex.RexNode;
+import org.apache.calcite.rex.RexUtil;
+import org.apache.calcite.tools.RelBuilder;
+import org.immutables.value.Value;
+
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.List;
+
+/**
+ * Rewrites a {@code VECTOR_SEARCH} table function scan whose query argument is a constant into a
+ * correlated {@code VECTOR_SEARCH} call.
+ *
+ * <p>The left side of the rewritten plan is a single-row {@link LogicalValues} projected to the
+ * constant query value. The right side is the original function scan over the same input, with
+ * its query argument replaced by a field access on a correlation variable that refers to the left
+ * side. Both sides are combined into a correlate, and the helper constant column is projected
+ * away again.
+ *
+ * <p>Example input: {@code VECTOR_SEARCH(TABLE t, DESCRIPTOR(col), ARRAY[1.5, 2.0], 10)}.
+ */
+public class ConstantVectorSearchCallToCorrelateRule
+        extends RelRule<
+                ConstantVectorSearchCallToCorrelateRule
+                        .ConstantVectorSearchCallToCorrelateRuleConfig> {
+
+    /** Shared instance built from the default configuration. */
+    public static final ConstantVectorSearchCallToCorrelateRule INSTANCE =
+            ConstantVectorSearchCallToCorrelateRuleConfig.DEFAULT.toRule();
+
+    /**
+     * Index of the query-value argument in {@code VECTOR_SEARCH(TABLE t, DESCRIPTOR(col), query,
+     * topK)}.
+     */
+    private static final int QUERY_OPERAND_INDEX = 2;
+
+    private ConstantVectorSearchCallToCorrelateRule(
+            ConstantVectorSearchCallToCorrelateRuleConfig config) {
+        super(config);
+    }
+
+    /**
+     * Accepts a {@link LogicalTableFunctionScan} with a single input whose call is {@code
+     * VECTOR_SEARCH} and whose query argument is constant.
+     */
+    @Override
+    public boolean matches(RelOptRuleCall call) {
+        LogicalTableFunctionScan scan = call.rel(0);
+        // onMatch re-attaches exactly one input to the rewritten function scan.
+        if (scan.getInputs().size() != 1) {
+            return false;
+        }
+        RexNode rexNode = scan.getCall();
+        if (!(rexNode instanceof RexCall)) {
+            return false;
+        }
+        RexCall rexCall = (RexCall) rexNode;
+        if (!(rexCall.getOperator() instanceof SqlVectorSearchTableFunction)
+                || rexCall.getOperands().size() <= QUERY_OPERAND_INDEX) {
+            return false;
+        }
+        // After the rewrite the query argument is a field access on a correlation variable,
+        // which is expected not to count as constant, so the rule should not re-fire on its own
+        // output. NOTE(review): relies on Calcite's RexUtil.isConstant semantics - confirm.
+        return RexUtil.isConstant(rexCall.getOperands().get(QUERY_OPERAND_INDEX));
+    }
+
+    /** Replaces the matched scan with a correlate that feeds the constant query from the left. */
+    @Override
+    public void onMatch(RelOptRuleCall call) {
+        LogicalTableFunctionScan scan = call.rel(0);
+        RexCall functionCall = (RexCall) scan.getCall();
+        RexNode constantQuery = functionCall.getOperands().get(QUERY_OPERAND_INDEX);
+        RelOptCluster cluster = scan.getCluster();
+        RelBuilder builder = call.builder();
+
+        // Left side: a single-row Values projected to the constant query value.
+        builder.push(LogicalValues.createOneRow(cluster));
+        builder.project(constantQuery);
+
+        // Right side: the same VECTOR_SEARCH over the original input, but with the query argument
+        // replaced by field 0 of the left side, accessed through a correlation variable.
+        CorrelationId correlId = cluster.createCorrel();
+        RexNode correlRex =
+                cluster.getRexBuilder().makeCorrel(builder.peek().getRowType(), correlId);
+        RexNode correlatedQuery = cluster.getRexBuilder().makeFieldAccess(correlRex, 0);
+        List<RexNode> operands = new ArrayList<>(functionCall.getOperands());
+        operands.set(QUERY_OPERAND_INDEX, correlatedQuery);
+        builder.push(scan.getInput(0));
+        builder.functionScan(functionCall.getOperator(), 1, operands);
+
+        // Combine both sides into a correlate: inner join on TRUE bound to the correlation id.
+        builder.join(
+                JoinRelType.INNER,
+                cluster.getRexBuilder().makeLiteral(true),
+                Collections.singleton(correlId));
+
+        // Drop the helper constant column (field 0) so only the function's output columns remain.
+        builder.projectExcept(builder.field(0));
+        call.transformTo(builder.build());
+    }
+
+    /**
+     * Configuration of {@link ConstantVectorSearchCallToCorrelateRule}. The operand matches any
+     * {@link LogicalTableFunctionScan}; restricting it to {@code VECTOR_SEARCH} calls is done by
+     * the rule's {@code matches} method.
+     */
+    @Value.Immutable
+    public interface ConstantVectorSearchCallToCorrelateRuleConfig extends RelRule.Config {
+
+        ConstantVectorSearchCallToCorrelateRuleConfig DEFAULT =
+                ImmutableConstantVectorSearchCallToCorrelateRuleConfig.builder()
+                        .build()
+                        .withOperandSupplier(
+                                b0 -> b0.operand(LogicalTableFunctionScan.class).anyInputs())
+                        .withDescription("ConstantVectorSearchCallToCorrelateRule");
+
+        @Override
+        default ConstantVectorSearchCallToCorrelateRule toRule() {
+            return new ConstantVectorSearchCallToCorrelateRule(this);
+        }
+    }
+}
diff --git 
a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/rules/FlinkStreamRuleSets.scala
 
b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/rules/FlinkStreamRuleSets.scala
index 26b98c32a7f..bc9e3885856 100644
--- 
a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/rules/FlinkStreamRuleSets.scala
+++ 
b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/rules/FlinkStreamRuleSets.scala
@@ -128,6 +128,8 @@ object FlinkStreamRuleSets {
           // unnest rule
           LogicalUnnestRule.INSTANCE,
           UncollectToTableFunctionScanRule.INSTANCE,
+          // rewrite VECTOR_SEARCH with a constant query value to correlate
+          ConstantVectorSearchCallToCorrelateRule.INSTANCE,
           // rewrite constant table function scan to correlate
           JoinTableFunctionScanToCorrelateRule.INSTANCE,
           // Wrap arguments for JSON aggregate functions
diff --git 
a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/plan/stream/sql/VectorSearchTableFunctionTest.java
 
b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/plan/stream/sql/VectorSearchTableFunctionTest.java
index ef949677394..0021d2e920d 100644
--- 
a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/plan/stream/sql/VectorSearchTableFunctionTest.java
+++ 
b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/plan/stream/sql/VectorSearchTableFunctionTest.java
@@ -115,24 +115,14 @@ public class VectorSearchTableFunctionTest extends 
TableTestBase {
     void testLiteralValue() {
         String sql =
                 "SELECT * FROM LATERAL TABLE(VECTOR_SEARCH(TABLE VectorTable, 
DESCRIPTOR(`g`), ARRAY[1.5, 2.0], 10))";
-        assertThatThrownBy(() -> util.verifyRelPlan(sql))
-                .satisfies(
-                        FlinkAssertions.anyCauseMatches(
-                                TableException.class,
-                                
"FlinkLogicalTableFunctionScan(invocation=[VECTOR_SEARCH(TABLE(#0), 
DESCRIPTOR(_UTF-16LE'g'), ARRAY(1.5:DECIMAL(2, 1), 2.0:DECIMAL(2, 1)), 10)], 
rowType=[RecordType(INTEGER e, BIGINT f, FLOAT ARRAY g, DOUBLE score)])\n"
-                                        + "+- 
FlinkLogicalTableSourceScan(table=[[default_catalog, default_database, 
VectorTable]], fields=[e, f, g])"));
+        util.verifyRelPlan(sql); // previously failed with TableException; plan kept in .xml
     }
 
     @Test
     void testLiteralValueWithoutLateralKeyword() {
         String sql =
                 "SELECT * FROM TABLE(VECTOR_SEARCH(TABLE VectorTable, 
DESCRIPTOR(`g`), ARRAY[1.5, 2.0], 10))";
-        assertThatThrownBy(() -> util.verifyRelPlan(sql))
-                .satisfies(
-                        FlinkAssertions.anyCauseMatches(
-                                TableException.class,
-                                
"FlinkLogicalTableFunctionScan(invocation=[VECTOR_SEARCH(TABLE(#0), 
DESCRIPTOR(_UTF-16LE'g'), ARRAY(1.5:DECIMAL(2, 1), 2.0:DECIMAL(2, 1)), 10)], 
rowType=[RecordType(INTEGER e, BIGINT f, FLOAT ARRAY g, DOUBLE score)])\n"
-                                        + "+- 
FlinkLogicalTableSourceScan(table=[[default_catalog, default_database, 
VectorTable]], fields=[e, f, g])"));
+        util.verifyRelPlan(sql); // same plan as testLiteralValue, without the LATERAL keyword
     }
 
     @Test
diff --git 
a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/runtime/stream/table/AsyncVectorSearchITCase.java
 
b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/runtime/stream/table/AsyncVectorSearchITCase.java
index 1b67730805e..408c26685ba 100644
--- 
a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/runtime/stream/table/AsyncVectorSearchITCase.java
+++ 
b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/runtime/stream/table/AsyncVectorSearchITCase.java
@@ -37,6 +37,7 @@ import java.util.Collection;
 import java.util.List;
 import java.util.concurrent.TimeoutException;
 
+import static org.assertj.core.api.Assertions.assertThat;
 import static org.assertj.core.api.Assertions.assertThatThrownBy;
 import static org.assertj.core.api.AssertionsForInterfaceTypes.assertThatList;
 
@@ -153,6 +154,19 @@ public class AsyncVectorSearchITCase extends 
StreamingWithStateTestBase {
                                 TimeoutException.class, "Async function call 
has timed out."));
     }
 
+    /** Runs VECTOR_SEARCH with a literal (constant) query vector instead of a column. */
+    @TestTemplate
+    void testConstantValue() {
+        String sql =
+                "SELECT * FROM TABLE(VECTOR_SEARCH(TABLE vector, DESCRIPTOR(`vector`), ARRAY[5, 12, 13], 2))";
+        List<Row> rows = CollectionUtil.iteratorToList(tEnv().executeSql(sql).collect());
+        // topK = 2, so exactly these two rows (with their scores) are expected.
+        assertThat(rows)
+                .containsExactlyInAnyOrder(
+                        Row.of(1L, new Float[] {5.0f, 12.0f, 13.0f}, 1.0),
+                        Row.of(3L, new Float[] {8f, 15f, 17f}, 0.9977375565610862));
+    }
+
     @TestTemplate
     void testVectorSearchWithCalc() {
         assertThatThrownBy(
diff --git 
a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/runtime/stream/table/VectorSearchITCase.java
 
b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/runtime/stream/table/VectorSearchITCase.java
index 14bd8c39f2a..18ce3e8f999 100644
--- 
a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/runtime/stream/table/VectorSearchITCase.java
+++ 
b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/runtime/stream/table/VectorSearchITCase.java
@@ -30,6 +30,7 @@ import org.junit.jupiter.api.Test;
 import java.util.Arrays;
 import java.util.List;
 
+import static org.assertj.core.api.Assertions.assertThat;
 import static org.assertj.core.api.Assertions.assertThatThrownBy;
 import static org.assertj.core.api.AssertionsForInterfaceTypes.assertThatList;
 
@@ -123,6 +124,19 @@ public class VectorSearchITCase extends StreamingTestBase {
                         Row.of(4L, null, null, null, null));
     }
 
+    /** Runs VECTOR_SEARCH whose query argument is a literal array rather than a column. */
+    @Test
+    void testConstantValue() {
+        String sql =
+                "SELECT * FROM TABLE(VECTOR_SEARCH(TABLE vector, DESCRIPTOR(`vector`), ARRAY[5, 12, 13], 2))";
+        List<Row> collected = CollectionUtil.iteratorToList(tEnv().executeSql(sql).collect());
+        // With topK = 2 the result must be exactly these two rows, in any order.
+        assertThat(collected)
+                .containsExactlyInAnyOrder(
+                        Row.of(1L, new Float[] {5.0f, 12.0f, 13.0f}, 1.0),
+                        Row.of(3L, new Float[] {8f, 15f, 17f}, 0.9977375565610862));
+    }
+
     @Test
     void testVectorSearchWithCalc() {
         assertThatThrownBy(
diff --git 
a/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/stream/sql/VectorSearchTableFunctionTest.xml
 
b/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/stream/sql/VectorSearchTableFunctionTest.xml
index 0933534e171..2e2c21785bc 100644
--- 
a/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/stream/sql/VectorSearchTableFunctionTest.xml
+++ 
b/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/stream/sql/VectorSearchTableFunctionTest.xml
@@ -16,6 +16,48 @@ See the License for the specific language governing 
permissions and
 limitations under the License.
 -->
 <Root>
+  <TestCase name="testLiteralValue">
+    <Resource name="sql">
+      <![CDATA[SELECT * FROM LATERAL TABLE(VECTOR_SEARCH(TABLE VectorTable, 
DESCRIPTOR(`g`), ARRAY[1.5, 2.0], 10))]]>
+    </Resource>
+    <Resource name="ast">
+      <![CDATA[
+LogicalProject(e=[$0], f=[$1], g=[$2], score=[$3])
++- LogicalTableFunctionScan(invocation=[VECTOR_SEARCH(TABLE(#0), 
DESCRIPTOR(_UTF-16LE'g'), ARRAY(1.5:DECIMAL(2, 1), 2.0:DECIMAL(2, 1)), 10)], 
rowType=[RecordType(INTEGER e, BIGINT f, FLOAT ARRAY g, DOUBLE score)])
+   +- LogicalProject(e=[$0], f=[$1], g=[$2])
+      +- LogicalTableScan(table=[[default_catalog, default_database, 
VectorTable]])
+]]>
+    </Resource>
+    <Resource name="optimized rel plan">
+      <![CDATA[
+Calc(select=[e, f, g, score])
++- 
VectorSearchTableFunction(table=[default_catalog.default_database.VectorTable], 
joinType=[InnerJoin], columnToSearch=[g], columnToQuery=[$f0], topK=[10], 
select=[$f0, e, f, g, score])
+   +- Calc(select=[ARRAY(1.5, 2.0) AS $f0])
+      +- Values(tuples=[[{ 0 }]])
+]]>
+    </Resource>
+  </TestCase>
+  <TestCase name="testLiteralValueWithoutLateralKeyword">
+    <Resource name="sql">
+      <![CDATA[SELECT * FROM TABLE(VECTOR_SEARCH(TABLE VectorTable, 
DESCRIPTOR(`g`), ARRAY[1.5, 2.0], 10))]]>
+    </Resource>
+    <Resource name="ast">
+      <![CDATA[
+LogicalProject(e=[$0], f=[$1], g=[$2], score=[$3])
++- LogicalTableFunctionScan(invocation=[VECTOR_SEARCH(TABLE(#0), 
DESCRIPTOR(_UTF-16LE'g'), ARRAY(1.5:DECIMAL(2, 1), 2.0:DECIMAL(2, 1)), 10)], 
rowType=[RecordType(INTEGER e, BIGINT f, FLOAT ARRAY g, DOUBLE score)])
+   +- LogicalProject(e=[$0], f=[$1], g=[$2])
+      +- LogicalTableScan(table=[[default_catalog, default_database, 
VectorTable]])
+]]>
+    </Resource>
+    <Resource name="optimized rel plan">
+      <![CDATA[
+Calc(select=[e, f, g, score])
++- 
VectorSearchTableFunction(table=[default_catalog.default_database.VectorTable], 
joinType=[InnerJoin], columnToSearch=[g], columnToQuery=[$f0], topK=[10], 
select=[$f0, e, f, g, score])
+   +- Calc(select=[ARRAY(1.5, 2.0) AS $f0])
+      +- Values(tuples=[[{ 0 }]])
+]]>
+    </Resource>
+  </TestCase>
   <TestCase name="testNameConflicts">
     <Resource name="sql">
       <![CDATA[SELECT * FROM QueryTable, LATERAL TABLE(

Reply via email to