This is an automated email from the ASF dual-hosted git repository.
shengkai pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git
The following commit(s) were added to refs/heads/master by this push:
new c95979dec32 [FLINK-38428][table] Support to run vector search with
constant value input (#27130)
c95979dec32 is described below
commit c95979dec32ab6a80a90d10f7006796339d34e68
Author: Shengkai <[email protected]>
AuthorDate: Thu Oct 23 15:44:47 2025 +0800
[FLINK-38428][table] Support to run vector search with constant value input
(#27130)
---
.../ConstantVectorSearchCallToCorrelateRule.java | 114 +++++++++++++++++++++
.../planner/plan/rules/FlinkStreamRuleSets.scala | 2 +
.../stream/sql/VectorSearchTableFunctionTest.java | 14 +--
.../stream/table/AsyncVectorSearchITCase.java | 14 +++
.../runtime/stream/table/VectorSearchITCase.java | 14 +++
.../stream/sql/VectorSearchTableFunctionTest.xml | 42 ++++++++
6 files changed, 188 insertions(+), 12 deletions(-)
diff --git
a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/logical/ConstantVectorSearchCallToCorrelateRule.java
b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/logical/ConstantVectorSearchCallToCorrelateRule.java
new file mode 100644
index 00000000000..13fbe0e3fd3
--- /dev/null
+++
b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/logical/ConstantVectorSearchCallToCorrelateRule.java
@@ -0,0 +1,114 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.planner.plan.rules.logical;
+
+import org.apache.flink.table.planner.functions.sql.ml.SqlVectorSearchTableFunction;
+
+import org.apache.calcite.plan.RelOptCluster;
+import org.apache.calcite.plan.RelOptRuleCall;
+import org.apache.calcite.plan.RelRule;
+import org.apache.calcite.rel.core.CorrelationId;
+import org.apache.calcite.rel.core.JoinRelType;
+import org.apache.calcite.rel.logical.LogicalTableFunctionScan;
+import org.apache.calcite.rel.logical.LogicalValues;
+import org.apache.calcite.rex.RexBuilder;
+import org.apache.calcite.rex.RexCall;
+import org.apache.calcite.rex.RexNode;
+import org.apache.calcite.rex.RexUtil;
+import org.apache.calcite.tools.RelBuilder;
+import org.immutables.value.Value;
+
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.List;
+
+/** Rule to convert VECTOR_SEARCH call with literal value to a correlated
VECTOR_SEARCH call. */
+public class ConstantVectorSearchCallToCorrelateRule
+ extends RelRule<
+ ConstantVectorSearchCallToCorrelateRule
+ .ConstantVectorSearchCallToCorrelateRuleConfig> {
+
+ public static final ConstantVectorSearchCallToCorrelateRule INSTANCE =
+ ConstantVectorSearchCallToCorrelateRuleConfig.DEFAULT.toRule();
+
+ private ConstantVectorSearchCallToCorrelateRule(
+ ConstantVectorSearchCallToCorrelateRuleConfig config) {
+ super(config);
+ }
+
+ @Override
+ public boolean matches(RelOptRuleCall call) {
+ LogicalTableFunctionScan scan = call.rel(0);
+ RexNode rexNode = scan.getCall();
+ if (!(rexNode instanceof RexCall)) {
+ return false;
+ }
+ RexCall rexCall = (RexCall) rexNode;
+ return rexCall.getOperator() instanceof SqlVectorSearchTableFunction
+ && RexUtil.isConstant(rexCall.getOperands().get(2));
+ }
+
+ @Override
+ public void onMatch(RelOptRuleCall call) {
+ LogicalTableFunctionScan scan = call.rel(0);
+ RexCall functionCall = (RexCall) scan.getCall();
+ RexNode constantCall = functionCall.getOperands().get(2);
+ RelOptCluster cluster = scan.getCluster();
+ RelBuilder builder = call.builder();
+
+ // left side
+ LogicalValues values = LogicalValues.createOneRow(cluster);
+ builder.push(values);
+ builder.project(constantCall);
+
+ // right side
+ CorrelationId correlId = cluster.createCorrel();
+ RexNode correlRex =
+
cluster.getRexBuilder().makeCorrel(builder.peek().getRowType(), correlId);
+ RexNode correlatedConstant =
cluster.getRexBuilder().makeFieldAccess(correlRex, 0);
+ builder.push(scan.getInput(0));
+ ArrayList<RexNode> operands = new ArrayList<>(functionCall.operands);
+ operands.set(2, correlatedConstant);
+ builder.functionScan(functionCall.getOperator(), 1, operands);
+
+ // add correlate node
+ builder.join(
+ JoinRelType.INNER,
+ cluster.getRexBuilder().makeLiteral(true),
+ Collections.singleton(correlId));
+
+ // prune useless value input
+ builder.projectExcept(builder.field(0));
+ call.transformTo(builder.build());
+ }
+
+ @Value.Immutable
+ public interface ConstantVectorSearchCallToCorrelateRuleConfig extends
RelRule.Config {
+
+ ConstantVectorSearchCallToCorrelateRuleConfig DEFAULT =
+
ImmutableConstantVectorSearchCallToCorrelateRuleConfig.builder()
+ .build()
+ .withOperandSupplier(
+ b0 ->
b0.operand(LogicalTableFunctionScan.class).anyInputs())
+
.withDescription("ConstantVectorSearchCallToCorrelateRule");
+
+ @Override
+ default ConstantVectorSearchCallToCorrelateRule toRule() {
+ return new ConstantVectorSearchCallToCorrelateRule(this);
+ }
+ }
+}
diff --git
a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/rules/FlinkStreamRuleSets.scala
b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/rules/FlinkStreamRuleSets.scala
index 26b98c32a7f..bc9e3885856 100644
---
a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/rules/FlinkStreamRuleSets.scala
+++
b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/rules/FlinkStreamRuleSets.scala
@@ -128,6 +128,8 @@ object FlinkStreamRuleSets {
// unnest rule
LogicalUnnestRule.INSTANCE,
UncollectToTableFunctionScanRule.INSTANCE,
+ // vector search rule.
+ ConstantVectorSearchCallToCorrelateRule.INSTANCE,
// rewrite constant table function scan to correlate
JoinTableFunctionScanToCorrelateRule.INSTANCE,
// Wrap arguments for JSON aggregate functions
diff --git
a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/plan/stream/sql/VectorSearchTableFunctionTest.java
b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/plan/stream/sql/VectorSearchTableFunctionTest.java
index ef949677394..0021d2e920d 100644
---
a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/plan/stream/sql/VectorSearchTableFunctionTest.java
+++
b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/plan/stream/sql/VectorSearchTableFunctionTest.java
@@ -115,24 +115,14 @@ public class VectorSearchTableFunctionTest extends
TableTestBase {
void testLiteralValue() {
String sql =
"SELECT * FROM LATERAL TABLE(VECTOR_SEARCH(TABLE VectorTable,
DESCRIPTOR(`g`), ARRAY[1.5, 2.0], 10))";
- assertThatThrownBy(() -> util.verifyRelPlan(sql))
- .satisfies(
- FlinkAssertions.anyCauseMatches(
- TableException.class,
-
"FlinkLogicalTableFunctionScan(invocation=[VECTOR_SEARCH(TABLE(#0),
DESCRIPTOR(_UTF-16LE'g'), ARRAY(1.5:DECIMAL(2, 1), 2.0:DECIMAL(2, 1)), 10)],
rowType=[RecordType(INTEGER e, BIGINT f, FLOAT ARRAY g, DOUBLE score)])\n"
- + "+-
FlinkLogicalTableSourceScan(table=[[default_catalog, default_database,
VectorTable]], fields=[e, f, g])"));
+ // A constant query vector (ARRAY[1.5, 2.0]) is rewritten into a correlate by
+ // ConstantVectorSearchCallToCorrelateRule, so the plan is verified instead of
+ // expecting a TableException.
+ util.verifyRelPlan(sql);
}
@Test
void testLiteralValueWithoutLateralKeyword() {
String sql =
"SELECT * FROM TABLE(VECTOR_SEARCH(TABLE VectorTable,
DESCRIPTOR(`g`), ARRAY[1.5, 2.0], 10))";
- assertThatThrownBy(() -> util.verifyRelPlan(sql))
- .satisfies(
- FlinkAssertions.anyCauseMatches(
- TableException.class,
-
"FlinkLogicalTableFunctionScan(invocation=[VECTOR_SEARCH(TABLE(#0),
DESCRIPTOR(_UTF-16LE'g'), ARRAY(1.5:DECIMAL(2, 1), 2.0:DECIMAL(2, 1)), 10)],
rowType=[RecordType(INTEGER e, BIGINT f, FLOAT ARRAY g, DOUBLE score)])\n"
- + "+-
FlinkLogicalTableSourceScan(table=[[default_catalog, default_database,
VectorTable]], fields=[e, f, g])"));
+ // Same as the LATERAL variant: the constant query vector is rewritten into a
+ // correlate by ConstantVectorSearchCallToCorrelateRule, so the plan is verified
+ // instead of expecting a TableException.
+ util.verifyRelPlan(sql);
}
@Test
diff --git
a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/runtime/stream/table/AsyncVectorSearchITCase.java
b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/runtime/stream/table/AsyncVectorSearchITCase.java
index 1b67730805e..408c26685ba 100644
---
a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/runtime/stream/table/AsyncVectorSearchITCase.java
+++
b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/runtime/stream/table/AsyncVectorSearchITCase.java
@@ -37,6 +37,7 @@ import java.util.Collection;
import java.util.List;
import java.util.concurrent.TimeoutException;
+import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;
import static org.assertj.core.api.AssertionsForInterfaceTypes.assertThatList;
@@ -153,6 +154,19 @@ public class AsyncVectorSearchITCase extends
StreamingWithStateTestBase {
TimeoutException.class, "Async function call
has timed out."));
}
+    /**
+     * Runs VECTOR_SEARCH with a constant query vector ({@code ARRAY[5, 12, 13]}) instead of a
+     * column reference, with top-k 2, and checks the returned rows.
+     */
+    @TestTemplate
+    void testConstantValue() {
+        String sql =
+                "SELECT * FROM TABLE(VECTOR_SEARCH(TABLE vector, DESCRIPTOR(`vector`), ARRAY[5, 12, 13], 2))";
+        List<Row> actual = CollectionUtil.iteratorToList(tEnv().executeSql(sql).collect());
+        // 0.9977375565610862 == 441 / 442, the cosine similarity of {5, 12, 13} and {8, 15, 17}
+        // (dot product 441, norm product sqrt(338 * 578) = 442).
+        Row identicalVector = Row.of(1L, new Float[] {5.0f, 12.0f, 13.0f}, 1.0);
+        Row nearestOther = Row.of(3L, new Float[] {8f, 15f, 17f}, 0.9977375565610862);
+        assertThat(actual).containsExactlyInAnyOrder(identicalVector, nearestOther);
+    }
+
@TestTemplate
void testVectorSearchWithCalc() {
assertThatThrownBy(
diff --git
a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/runtime/stream/table/VectorSearchITCase.java
b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/runtime/stream/table/VectorSearchITCase.java
index 14bd8c39f2a..18ce3e8f999 100644
---
a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/runtime/stream/table/VectorSearchITCase.java
+++
b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/runtime/stream/table/VectorSearchITCase.java
@@ -30,6 +30,7 @@ import org.junit.jupiter.api.Test;
import java.util.Arrays;
import java.util.List;
+import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;
import static org.assertj.core.api.AssertionsForInterfaceTypes.assertThatList;
@@ -123,6 +124,19 @@ public class VectorSearchITCase extends StreamingTestBase {
Row.of(4L, null, null, null, null));
}
+    /**
+     * Runs VECTOR_SEARCH with a constant query vector ({@code ARRAY[5, 12, 13]}) instead of a
+     * column reference, with top-k 2, and checks the returned rows.
+     */
+    @Test
+    void testConstantValue() {
+        String sql =
+                "SELECT * FROM TABLE(VECTOR_SEARCH(TABLE vector, DESCRIPTOR(`vector`), ARRAY[5, 12, 13], 2))";
+        List<Row> actual = CollectionUtil.iteratorToList(tEnv().executeSql(sql).collect());
+        // 0.9977375565610862 == 441 / 442, the cosine similarity of {5, 12, 13} and {8, 15, 17}
+        // (dot product 441, norm product sqrt(338 * 578) = 442).
+        Row identicalVector = Row.of(1L, new Float[] {5.0f, 12.0f, 13.0f}, 1.0);
+        Row nearestOther = Row.of(3L, new Float[] {8f, 15f, 17f}, 0.9977375565610862);
+        assertThat(actual).containsExactlyInAnyOrder(identicalVector, nearestOther);
+    }
+
@Test
void testVectorSearchWithCalc() {
assertThatThrownBy(
diff --git
a/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/stream/sql/VectorSearchTableFunctionTest.xml
b/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/stream/sql/VectorSearchTableFunctionTest.xml
index 0933534e171..2e2c21785bc 100644
---
a/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/stream/sql/VectorSearchTableFunctionTest.xml
+++
b/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/stream/sql/VectorSearchTableFunctionTest.xml
@@ -16,6 +16,48 @@ See the License for the specific language governing
permissions and
limitations under the License.
-->
<Root>
+ <TestCase name="testLiteralValue">
+ <Resource name="sql">
+ <![CDATA[SELECT * FROM LATERAL TABLE(VECTOR_SEARCH(TABLE VectorTable,
DESCRIPTOR(`g`), ARRAY[1.5, 2.0], 10))]]>
+ </Resource>
+ <Resource name="ast">
+ <![CDATA[
+LogicalProject(e=[$0], f=[$1], g=[$2], score=[$3])
++- LogicalTableFunctionScan(invocation=[VECTOR_SEARCH(TABLE(#0),
DESCRIPTOR(_UTF-16LE'g'), ARRAY(1.5:DECIMAL(2, 1), 2.0:DECIMAL(2, 1)), 10)],
rowType=[RecordType(INTEGER e, BIGINT f, FLOAT ARRAY g, DOUBLE score)])
+ +- LogicalProject(e=[$0], f=[$1], g=[$2])
+ +- LogicalTableScan(table=[[default_catalog, default_database,
VectorTable]])
+]]>
+ </Resource>
+ <Resource name="optimized rel plan">
+ <![CDATA[
+Calc(select=[e, f, g, score])
++-
VectorSearchTableFunction(table=[default_catalog.default_database.VectorTable],
joinType=[InnerJoin], columnToSearch=[g], columnToQuery=[$f0], topK=[10],
select=[$f0, e, f, g, score])
+ +- Calc(select=[ARRAY(1.5, 2.0) AS $f0])
+ +- Values(tuples=[[{ 0 }]])
+]]>
+ </Resource>
+ </TestCase>
+ <TestCase name="testLiteralValueWithoutLateralKeyword">
+ <Resource name="sql">
+ <![CDATA[SELECT * FROM TABLE(VECTOR_SEARCH(TABLE VectorTable,
DESCRIPTOR(`g`), ARRAY[1.5, 2.0], 10))]]>
+ </Resource>
+ <Resource name="ast">
+ <![CDATA[
+LogicalProject(e=[$0], f=[$1], g=[$2], score=[$3])
++- LogicalTableFunctionScan(invocation=[VECTOR_SEARCH(TABLE(#0),
DESCRIPTOR(_UTF-16LE'g'), ARRAY(1.5:DECIMAL(2, 1), 2.0:DECIMAL(2, 1)), 10)],
rowType=[RecordType(INTEGER e, BIGINT f, FLOAT ARRAY g, DOUBLE score)])
+ +- LogicalProject(e=[$0], f=[$1], g=[$2])
+ +- LogicalTableScan(table=[[default_catalog, default_database,
VectorTable]])
+]]>
+ </Resource>
+ <Resource name="optimized rel plan">
+ <![CDATA[
+Calc(select=[e, f, g, score])
++-
VectorSearchTableFunction(table=[default_catalog.default_database.VectorTable],
joinType=[InnerJoin], columnToSearch=[g], columnToQuery=[$f0], topK=[10],
select=[$f0, e, f, g, score])
+ +- Calc(select=[ARRAY(1.5, 2.0) AS $f0])
+ +- Values(tuples=[[{ 0 }]])
+]]>
+ </Resource>
+ </TestCase>
<TestCase name="testNameConflicts">
<Resource name="sql">
<![CDATA[SELECT * FROM QueryTable, LATERAL TABLE(