This is an automated email from the ASF dual-hosted git repository. abenedetti pushed a commit to branch branch_10_0 in repository https://gitbox.apache.org/repos/asf/solr.git
commit ed54b21db3efe5820075f1f8d36a44c13bd8c1a9 Author: Anna Ruggero <[email protected]> AuthorDate: Tue Oct 21 17:03:25 2025 +0200 SOLR-17815: Add parameter to regulate for ACORN-based filtering in vector search (#3680) (cherry picked from commit 512e02ae637b2831aa489902338207072c4c9684) --- solr/CHANGES.txt | 2 + .../org/apache/solr/schema/DenseVectorField.java | 30 +- .../org/apache/solr/search/neural/KnnQParser.java | 5 +- .../apache/solr/schema/DenseVectorFieldTest.java | 287 ++++++++++ .../apache/solr/search/neural/KnnQParserTest.java | 605 +++++++++++++++++++++ .../query-guide/pages/dense-vector-search.adoc | 22 + 6 files changed, 945 insertions(+), 6 deletions(-) diff --git a/solr/CHANGES.txt b/solr/CHANGES.txt index 1a2f38bc84f..9a2ca472e99 100644 --- a/solr/CHANGES.txt +++ b/solr/CHANGES.txt @@ -35,6 +35,8 @@ New Features * SOLR-16667: LTR Add feature vector caching for ranking. (Anna Ruggero, Alessandro Benedetti) +* SOLR-17815: Add parameter to regulate for ACORN-based filtering in vector search. 
(Anna Ruggero, Alessandro Benedetti) + Improvements --------------------- diff --git a/solr/core/src/java/org/apache/solr/schema/DenseVectorField.java b/solr/core/src/java/org/apache/solr/schema/DenseVectorField.java index 771d11c5635..0500d7fcbbe 100644 --- a/solr/core/src/java/org/apache/solr/schema/DenseVectorField.java +++ b/solr/core/src/java/org/apache/solr/schema/DenseVectorField.java @@ -43,6 +43,7 @@ import org.apache.lucene.search.PatienceKnnVectorQuery; import org.apache.lucene.search.Query; import org.apache.lucene.search.SeededKnnVectorQuery; import org.apache.lucene.search.SortField; +import org.apache.lucene.search.knn.KnnSearchStrategy; import org.apache.lucene.util.BytesRef; import org.apache.lucene.util.hnsw.HnswGraph; import org.apache.solr.common.SolrException; @@ -379,17 +380,36 @@ public class DenseVectorField extends FloatPointField { int topK, Query filterQuery, Query seedQuery, - EarlyTerminationParams earlyTermination) { + EarlyTerminationParams earlyTermination, + Integer filteredSearchThreshold) { DenseVectorParser vectorBuilder = getVectorBuilder(vectorToSearch, DenseVectorParser.BuilderPhase.QUERY); final Query knnQuery = switch (vectorEncoding) { - case FLOAT32 -> new KnnFloatVectorQuery( - fieldName, vectorBuilder.getFloatVector(), topK, filterQuery); - case BYTE -> new KnnByteVectorQuery( - fieldName, vectorBuilder.getByteVector(), topK, filterQuery); + case FLOAT32 -> { + if (filteredSearchThreshold != null) { + KnnSearchStrategy knnSearchStrategy = + new KnnSearchStrategy.Hnsw(filteredSearchThreshold); + yield new KnnFloatVectorQuery( + fieldName, vectorBuilder.getFloatVector(), topK, filterQuery, knnSearchStrategy); + } else { + yield new KnnFloatVectorQuery( + fieldName, vectorBuilder.getFloatVector(), topK, filterQuery); + } + } + case BYTE -> { + if (filteredSearchThreshold != null) { + KnnSearchStrategy knnSearchStrategy = + new KnnSearchStrategy.Hnsw(filteredSearchThreshold); + yield new KnnByteVectorQuery( + fieldName, 
vectorBuilder.getByteVector(), topK, filterQuery, knnSearchStrategy); + } else { + yield new KnnByteVectorQuery( + fieldName, vectorBuilder.getByteVector(), topK, filterQuery); + } + } }; final boolean seedEnabled = (seedQuery != null); diff --git a/solr/core/src/java/org/apache/solr/search/neural/KnnQParser.java b/solr/core/src/java/org/apache/solr/search/neural/KnnQParser.java index 664e6f341c9..db355e0b84e 100644 --- a/solr/core/src/java/org/apache/solr/search/neural/KnnQParser.java +++ b/solr/core/src/java/org/apache/solr/search/neural/KnnQParser.java @@ -32,6 +32,7 @@ public class KnnQParser extends AbstractVectorQParserBase { protected static final String TOP_K = "topK"; protected static final int DEFAULT_TOP_K = 10; protected static final String SEED_QUERY = "seedQuery"; + protected static final String FILTERED_SEARCH_THRESHOLD = "filteredSearchThreshold"; // parameters for PatienceKnnVectorQuery, a version of knn vector query that exits early when HNSW // queue saturates over a {@code #saturationThreshold} for more than {@code #patience} times. 
@@ -107,6 +108,7 @@ public class KnnQParser extends AbstractVectorQParserBase { final DenseVectorField denseVectorType = getCheckedFieldType(schemaField); final String vectorToSearch = getVectorToSearch(); final int topK = localParams.getInt(TOP_K, DEFAULT_TOP_K); + final Integer filteredSearchThreshold = localParams.getInt(FILTERED_SEARCH_THRESHOLD); return denseVectorType.getKnnVectorQuery( schemaField.getName(), @@ -114,6 +116,7 @@ public class KnnQParser extends AbstractVectorQParserBase { topK, getFilterQuery(), getSeedQuery(), - getEarlyTerminationParams()); + getEarlyTerminationParams(), + filteredSearchThreshold); } } diff --git a/solr/core/src/test/org/apache/solr/schema/DenseVectorFieldTest.java b/solr/core/src/test/org/apache/solr/schema/DenseVectorFieldTest.java index 6b3c63ca331..d3790dbbefd 100644 --- a/solr/core/src/test/org/apache/solr/schema/DenseVectorFieldTest.java +++ b/solr/core/src/test/org/apache/solr/schema/DenseVectorFieldTest.java @@ -25,6 +25,13 @@ import java.util.List; import java.util.Map; import org.apache.lucene.index.VectorEncoding; import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.search.BooleanQuery; +import org.apache.lucene.search.KnnByteVectorQuery; +import org.apache.lucene.search.KnnFloatVectorQuery; +import org.apache.lucene.search.PatienceKnnVectorQuery; +import org.apache.lucene.search.Query; +import org.apache.lucene.search.SeededKnnVectorQuery; +import org.apache.lucene.search.knn.KnnSearchStrategy; import org.apache.solr.client.solrj.request.JavaBinUpdateRequestCodec; import org.apache.solr.client.solrj.request.UpdateRequest; import org.apache.solr.common.SolrException; @@ -35,6 +42,7 @@ import org.apache.solr.core.AbstractBadConfigTestBase; import org.apache.solr.handler.loader.JavabinLoader; import org.apache.solr.request.SolrQueryRequest; import org.apache.solr.response.SolrQueryResponse; +import org.apache.solr.search.neural.KnnQParser; import 
org.apache.solr.update.CommitUpdateCommand; import org.apache.solr.update.processor.UpdateRequestProcessor; import org.apache.solr.update.processor.UpdateRequestProcessorChain; @@ -838,4 +846,283 @@ public class DenseVectorFieldTest extends AbstractBadConfigTestBase { deleteCore(); } } + + @Test + public void testFilteredSearchThreshold_floatNoThresholdInInput_shouldSetDefaultThreshold() + throws Exception { + try { + Integer expectedThreshold = KnnSearchStrategy.DEFAULT_FILTERED_SEARCH_THRESHOLD; + + initCore("solrconfig-basic.xml", "schema-densevector.xml"); + IndexSchema schema = h.getCore().getLatestSchema(); + SchemaField vectorField = schema.getField("vector"); + assertNotNull(vectorField); + DenseVectorField type = (DenseVectorField) vectorField.getType(); + KnnFloatVectorQuery vectorQuery = + (KnnFloatVectorQuery) + type.getKnnVectorQuery("vector", "[2, 1, 3, 4]", 3, null, null, null, null); + KnnSearchStrategy.Hnsw strategy = (KnnSearchStrategy.Hnsw) vectorQuery.getSearchStrategy(); + Integer threshold = strategy.filteredSearchThreshold(); + + assertEquals(expectedThreshold, threshold); + } finally { + deleteCore(); + } + } + + @Test + public void testFilteredSearchThreshold_floatThresholdInInput_shouldSetCustomThreshold() + throws Exception { + try { + Integer expectedThreshold = 30; + + initCore("solrconfig-basic.xml", "schema-densevector.xml"); + IndexSchema schema = h.getCore().getLatestSchema(); + SchemaField vectorField = schema.getField("vector"); + assertNotNull(vectorField); + DenseVectorField type = (DenseVectorField) vectorField.getType(); + KnnFloatVectorQuery vectorQuery = + (KnnFloatVectorQuery) + type.getKnnVectorQuery( + "vector", "[2, 1, 3, 4]", 3, null, null, null, expectedThreshold); + KnnSearchStrategy.Hnsw strategy = (KnnSearchStrategy.Hnsw) vectorQuery.getSearchStrategy(); + Integer threshold = strategy.filteredSearchThreshold(); + + assertEquals(expectedThreshold, threshold); + } finally { + deleteCore(); + } + } + + @Test + public 
void testFilteredSearchThreshold_seededFloatThresholdInInput_shouldSetCustomThreshold() + throws Exception { + try { + Query seedQuery = new BooleanQuery.Builder().build(); + Integer expectedThreshold = 30; + + initCore("solrconfig-basic.xml", "schema-densevector.xml"); + IndexSchema schema = h.getCore().getLatestSchema(); + SchemaField vectorField = schema.getField("vector"); + assertNotNull(vectorField); + DenseVectorField type = (DenseVectorField) vectorField.getType(); + SeededKnnVectorQuery vectorQuery = + (SeededKnnVectorQuery) + type.getKnnVectorQuery( + "vector", "[2, 1, 3, 4]", 3, null, seedQuery, null, expectedThreshold); + KnnSearchStrategy.Hnsw strategy = (KnnSearchStrategy.Hnsw) vectorQuery.getSearchStrategy(); + Integer threshold = strategy.filteredSearchThreshold(); + + assertEquals(expectedThreshold, threshold); + } finally { + deleteCore(); + } + } + + @Test + public void + testFilteredSearchThreshold_earlyTerminationFloatThresholdInInput_shouldSetCustomThreshold() + throws Exception { + try { + KnnQParser.EarlyTerminationParams earlyTermination = + new KnnQParser.EarlyTerminationParams(true, 0.995, 7); + Integer expectedThreshold = 30; + + initCore("solrconfig-basic.xml", "schema-densevector.xml"); + IndexSchema schema = h.getCore().getLatestSchema(); + SchemaField vectorField = schema.getField("vector"); + assertNotNull(vectorField); + DenseVectorField type = (DenseVectorField) vectorField.getType(); + PatienceKnnVectorQuery vectorQuery = + (PatienceKnnVectorQuery) + type.getKnnVectorQuery( + "vector", "[2, 1, 3, 4]", 3, null, null, earlyTermination, expectedThreshold); + KnnSearchStrategy.Hnsw strategy = (KnnSearchStrategy.Hnsw) vectorQuery.getSearchStrategy(); + Integer threshold = strategy.filteredSearchThreshold(); + + assertEquals(expectedThreshold, threshold); + } finally { + deleteCore(); + } + } + + @Test + public void + testFilteredSearchThreshold_seededAndEarlyTerminationFloatThresholdInInput_shouldSetCustomThreshold() + throws 
Exception { + try { + Query seedQuery = new BooleanQuery.Builder().build(); + KnnQParser.EarlyTerminationParams earlyTermination = + new KnnQParser.EarlyTerminationParams(true, 0.995, 7); + Integer expectedThreshold = 30; + + initCore("solrconfig-basic.xml", "schema-densevector.xml"); + IndexSchema schema = h.getCore().getLatestSchema(); + SchemaField vectorField = schema.getField("vector"); + assertNotNull(vectorField); + DenseVectorField type = (DenseVectorField) vectorField.getType(); + PatienceKnnVectorQuery vectorQuery = + (PatienceKnnVectorQuery) + type.getKnnVectorQuery( + "vector", + "[2, 1, 3, 4]", + 3, + null, + seedQuery, + earlyTermination, + expectedThreshold); + KnnSearchStrategy.Hnsw strategy = (KnnSearchStrategy.Hnsw) vectorQuery.getSearchStrategy(); + Integer threshold = strategy.filteredSearchThreshold(); + + assertEquals(expectedThreshold, threshold); + } finally { + deleteCore(); + } + } + + @Test + public void testFilteredSearchThreshold_byteNoThresholdInInput_shouldSetDefaultThreshold() + throws Exception { + try { + Integer expectedThreshold = KnnSearchStrategy.DEFAULT_FILTERED_SEARCH_THRESHOLD; + + initCore("solrconfig-basic.xml", "schema-densevector.xml"); + IndexSchema schema = h.getCore().getLatestSchema(); + SchemaField vectorField = schema.getField("vector_byte_encoding"); + assertNotNull(vectorField); + DenseVectorField type = (DenseVectorField) vectorField.getType(); + KnnByteVectorQuery vectorQuery = + (KnnByteVectorQuery) + type.getKnnVectorQuery( + "vector_byte_encoding", "[2, 1, 3, 4]", 3, null, null, null, null); + KnnSearchStrategy.Hnsw strategy = (KnnSearchStrategy.Hnsw) vectorQuery.getSearchStrategy(); + Integer threshold = strategy.filteredSearchThreshold(); + + assertEquals(expectedThreshold, threshold); + } finally { + deleteCore(); + } + } + + @Test + public void testFilteredSearchThreshold_byteThresholdInInput_shouldSetCustomThreshold() + throws Exception { + try { + Integer expectedThreshold = 30; + + 
initCore("solrconfig-basic.xml", "schema-densevector.xml"); + IndexSchema schema = h.getCore().getLatestSchema(); + SchemaField vectorField = schema.getField("vector_byte_encoding"); + assertNotNull(vectorField); + DenseVectorField type = (DenseVectorField) vectorField.getType(); + KnnByteVectorQuery vectorQuery = + (KnnByteVectorQuery) + type.getKnnVectorQuery( + "vector_byte_encoding", "[2, 1, 3, 4]", 3, null, null, null, expectedThreshold); + KnnSearchStrategy.Hnsw strategy = (KnnSearchStrategy.Hnsw) vectorQuery.getSearchStrategy(); + Integer threshold = strategy.filteredSearchThreshold(); + + assertEquals(expectedThreshold, threshold); + } finally { + deleteCore(); + } + } + + @Test + public void testFilteredSearchThreshold_seededByteThresholdInInput_shouldSetCustomThreshold() + throws Exception { + try { + Query seedQuery = new BooleanQuery.Builder().build(); + Integer expectedThreshold = 30; + + initCore("solrconfig-basic.xml", "schema-densevector.xml"); + IndexSchema schema = h.getCore().getLatestSchema(); + SchemaField vectorField = schema.getField("vector_byte_encoding"); + assertNotNull(vectorField); + DenseVectorField type = (DenseVectorField) vectorField.getType(); + SeededKnnVectorQuery vectorQuery = + (SeededKnnVectorQuery) + type.getKnnVectorQuery( + "vector_byte_encoding", + "[2, 1, 3, 4]", + 3, + null, + seedQuery, + null, + expectedThreshold); + KnnSearchStrategy.Hnsw strategy = (KnnSearchStrategy.Hnsw) vectorQuery.getSearchStrategy(); + Integer threshold = strategy.filteredSearchThreshold(); + + assertEquals(expectedThreshold, threshold); + } finally { + deleteCore(); + } + } + + @Test + public void + testFilteredSearchThreshold_earlyTerminationByteThresholdInInput_shouldSetCustomThreshold() + throws Exception { + try { + KnnQParser.EarlyTerminationParams earlyTermination = + new KnnQParser.EarlyTerminationParams(true, 0.995, 7); + Integer expectedThreshold = 30; + + initCore("solrconfig-basic.xml", "schema-densevector.xml"); + IndexSchema schema 
= h.getCore().getLatestSchema(); + SchemaField vectorField = schema.getField("vector_byte_encoding"); + assertNotNull(vectorField); + DenseVectorField type = (DenseVectorField) vectorField.getType(); + PatienceKnnVectorQuery vectorQuery = + (PatienceKnnVectorQuery) + type.getKnnVectorQuery( + "vector_byte_encoding", + "[2, 1, 3, 4]", + 3, + null, + null, + earlyTermination, + expectedThreshold); + KnnSearchStrategy.Hnsw strategy = (KnnSearchStrategy.Hnsw) vectorQuery.getSearchStrategy(); + Integer threshold = strategy.filteredSearchThreshold(); + + assertEquals(expectedThreshold, threshold); + } finally { + deleteCore(); + } + } + + @Test + public void + testFilteredSearchThreshold_seededAndEarlyTerminationByteThresholdInInput_shouldSetCustomThreshold() + throws Exception { + try { + Query seedQuery = new BooleanQuery.Builder().build(); + KnnQParser.EarlyTerminationParams earlyTermination = + new KnnQParser.EarlyTerminationParams(true, 0.995, 7); + Integer expectedThreshold = 30; + + initCore("solrconfig-basic.xml", "schema-densevector.xml"); + IndexSchema schema = h.getCore().getLatestSchema(); + SchemaField vectorField = schema.getField("vector_byte_encoding"); + assertNotNull(vectorField); + DenseVectorField type = (DenseVectorField) vectorField.getType(); + PatienceKnnVectorQuery vectorQuery = + (PatienceKnnVectorQuery) + type.getKnnVectorQuery( + "vector_byte_encoding", + "[2, 1, 3, 4]", + 3, + null, + seedQuery, + earlyTermination, + expectedThreshold); + KnnSearchStrategy.Hnsw strategy = (KnnSearchStrategy.Hnsw) vectorQuery.getSearchStrategy(); + Integer threshold = strategy.filteredSearchThreshold(); + + assertEquals(expectedThreshold, threshold); + } finally { + deleteCore(); + } + } } diff --git a/solr/core/src/test/org/apache/solr/search/neural/KnnQParserTest.java b/solr/core/src/test/org/apache/solr/search/neural/KnnQParserTest.java index cfa5d91da69..398b9215044 100644 --- a/solr/core/src/test/org/apache/solr/search/neural/KnnQParserTest.java +++ 
b/solr/core/src/test/org/apache/solr/search/neural/KnnQParserTest.java @@ -23,6 +23,11 @@ import java.util.Arrays; import java.util.Collections; import java.util.List; import java.util.Locale; +import org.apache.lucene.search.KnnByteVectorQuery; +import org.apache.lucene.search.KnnFloatVectorQuery; +import org.apache.lucene.search.PatienceKnnVectorQuery; +import org.apache.lucene.search.SeededKnnVectorQuery; +import org.apache.lucene.search.knn.KnnSearchStrategy; import org.apache.solr.SolrTestCaseJ4; import org.apache.solr.common.SolrException; import org.apache.solr.common.SolrInputDocument; @@ -1314,4 +1319,604 @@ public class KnnQParserTest extends SolrTestCaseJ4 { // value "//str[@name='parsedquery'][contains(.,'delegate=KnnFloatVectorQuery:vector[1.0,...][4]')]"); } + + @Test + public void testFilteredSearchThreshold_parsingFloatEncoding_shouldSetDefaultThreshold() + throws Exception { + Integer expectedThreshold = 0; + String vectorToSearch = "[1.0, 2.0, 3.0, 4.0]"; + String topK = "4"; + + SolrParams params = + params( + CommonParams.Q, + "{!knn f=" + + vectorField + + " topK=" + + topK + + " filteredSearchThreshold=" + + expectedThreshold + + "}" + + vectorToSearch); + SolrParams localParams = + params( + "type", + "knn", + "f", + vectorField, + "topK", + topK, + "v", + vectorToSearch, + "filteredSearchThreshold", + expectedThreshold.toString()); + SolrQueryRequest req = + req( + CommonParams.Q, + "{!knn f=" + + vectorField + + " topK=" + + topK + + " filteredSearchThreshold=" + + expectedThreshold + + "}" + + vectorToSearch); + + KnnQParser qparser = new KnnQParser(vectorToSearch, localParams, params, req); + try { + KnnFloatVectorQuery vectorQuery = (KnnFloatVectorQuery) qparser.parse(); + KnnSearchStrategy.Hnsw strategy = (KnnSearchStrategy.Hnsw) vectorQuery.getSearchStrategy(); + Integer threshold = strategy.filteredSearchThreshold(); + + assertEquals(expectedThreshold, threshold); + } finally { + req.close(); + } + } + + @Test + public void 
testFilteredSearchThreshold_parsingFloatEncoding_shouldSetCustomThreshold() + throws Exception { + Integer expectedThreshold = 30; + String vectorToSearch = "[1.0, 2.0, 3.0, 4.0]"; + String topK = "4"; + + SolrParams params = + params( + CommonParams.Q, + "{!knn f=" + + vectorField + + " topK=" + + topK + + " filteredSearchThreshold=" + + expectedThreshold + + "}" + + vectorToSearch); + SolrParams localParams = + params( + "type", + "knn", + "f", + vectorField, + "topK", + topK, + "v", + vectorToSearch, + "filteredSearchThreshold", + expectedThreshold.toString()); + SolrQueryRequest req = + req( + CommonParams.Q, + "{!knn f=" + + vectorField + + " topK=" + + topK + + " filteredSearchThreshold=" + + expectedThreshold + + "}" + + vectorToSearch); + + KnnQParser qparser = new KnnQParser(vectorToSearch, localParams, params, req); + try { + KnnFloatVectorQuery vectorQuery = (KnnFloatVectorQuery) qparser.parse(); + KnnSearchStrategy.Hnsw strategy = (KnnSearchStrategy.Hnsw) vectorQuery.getSearchStrategy(); + Integer threshold = strategy.filteredSearchThreshold(); + + assertEquals(expectedThreshold, threshold); + } finally { + req.close(); + } + } + + @Test + public void testFilteredSearchThreshold_seededParsingFloatEncoding_shouldSetCustomThreshold() + throws Exception { + Integer expectedThreshold = 30; + String vectorToSearch = "[1.0, 2.0, 3.0, 4.0]"; + String seedQuery = "id:(1 4 7 8 9)"; + String topK = "4"; + + SolrParams params = + params( + CommonParams.Q, + "{!knn f=" + + vectorField + + " topK=" + + topK + + " seedQuery=" + + seedQuery + + " filteredSearchThreshold=" + + expectedThreshold + + "}" + + vectorToSearch); + SolrParams localParams = + params( + "type", + "knn", + "f", + vectorField, + "topK", + topK, + "v", + vectorToSearch, + "seedQuery", + seedQuery, + "filteredSearchThreshold", + expectedThreshold.toString()); + SolrQueryRequest req = + req( + CommonParams.Q, + "{!knn f=" + + vectorField + + " topK=" + + topK + + " seedQuery=" + + seedQuery + + " 
filteredSearchThreshold=" + + expectedThreshold + + "}" + + vectorToSearch); + + KnnQParser qparser = new KnnQParser(vectorToSearch, localParams, params, req); + try { + SeededKnnVectorQuery vectorQuery = (SeededKnnVectorQuery) qparser.parse(); + KnnSearchStrategy.Hnsw strategy = (KnnSearchStrategy.Hnsw) vectorQuery.getSearchStrategy(); + Integer threshold = strategy.filteredSearchThreshold(); + + assertEquals(expectedThreshold, threshold); + } finally { + req.close(); + } + } + + @Test + public void + testFilteredSearchThreshold_earlyTerminationParsingFloatEncoding_shouldSetCustomThreshold() + throws Exception { + Integer expectedThreshold = 30; + String vectorToSearch = "[1.0, 2.0, 3.0, 4.0]"; + String earlyTermination = "true"; + String topK = "4"; + + SolrParams params = + params( + CommonParams.Q, + "{!knn f=" + + vectorField + + " topK=" + + topK + + " earlyTermination=" + + earlyTermination + + " filteredSearchThreshold=" + + expectedThreshold + + "}" + + vectorToSearch); + SolrParams localParams = + params( + "type", + "knn", + "f", + vectorField, + "topK", + topK, + "v", + vectorToSearch, + "earlyTermination", + earlyTermination, + "filteredSearchThreshold", + expectedThreshold.toString()); + SolrQueryRequest req = + req( + CommonParams.Q, + "{!knn f=" + + vectorField + + " topK=" + + topK + + " earlyTermination=" + + earlyTermination + + " filteredSearchThreshold=" + + expectedThreshold + + "}" + + vectorToSearch); + + KnnQParser qparser = new KnnQParser(vectorToSearch, localParams, params, req); + try { + PatienceKnnVectorQuery vectorQuery = (PatienceKnnVectorQuery) qparser.parse(); + KnnSearchStrategy.Hnsw strategy = (KnnSearchStrategy.Hnsw) vectorQuery.getSearchStrategy(); + Integer threshold = strategy.filteredSearchThreshold(); + + assertEquals(expectedThreshold, threshold); + } finally { + req.close(); + } + } + + @Test + public void + testFilteredSearchThreshold_seededAndEarlyTerminationParsingFloatEncoding_shouldSetCustomThreshold() + throws 
Exception { + Integer expectedThreshold = 30; + String vectorToSearch = "[1.0, 2.0, 3.0, 4.0]"; + String seedQuery = "id:(1 4 7 8 9)"; + String earlyTermination = "true"; + String topK = "4"; + + SolrParams params = + params( + CommonParams.Q, + "{!knn f=" + + vectorField + + " topK=" + + topK + + " seedQuery=" + + seedQuery + + " earlyTermination=" + + earlyTermination + + " filteredSearchThreshold=" + + expectedThreshold + + "}" + + vectorToSearch); + SolrParams localParams = + params( + "type", + "knn", + "f", + vectorField, + "topK", + topK, + "v", + vectorToSearch, + "seedQuery", + seedQuery, + "earlyTermination", + earlyTermination, + "filteredSearchThreshold", + expectedThreshold.toString()); + SolrQueryRequest req = + req( + CommonParams.Q, + "{!knn f=" + + vectorField + + " topK=" + + topK + + " seedQuery=" + + seedQuery + + " earlyTermination=" + + earlyTermination + + " filteredSearchThreshold=" + + expectedThreshold + + "}" + + vectorToSearch); + + KnnQParser qparser = new KnnQParser(vectorToSearch, localParams, params, req); + try { + PatienceKnnVectorQuery vectorQuery = (PatienceKnnVectorQuery) qparser.parse(); + KnnSearchStrategy.Hnsw strategy = (KnnSearchStrategy.Hnsw) vectorQuery.getSearchStrategy(); + Integer threshold = strategy.filteredSearchThreshold(); + + assertEquals(expectedThreshold, threshold); + } finally { + req.close(); + } + } + + @Test + public void testFilteredSearchThreshold_parsingByteEncoding_shouldSetDefaultThreshold() + throws Exception { + Integer expectedThreshold = 0; + String vectorToSearch = "[1, 2, 3, 4]"; + String topK = "4"; + + SolrParams params = + params( + CommonParams.Q, + "{!knn f=" + + vectorFieldByteEncoding + + " topK=" + + topK + + " filteredSearchThreshold=" + + expectedThreshold + + "}" + + vectorToSearch); + SolrParams localParams = + params( + "type", + "knn", + "f", + vectorFieldByteEncoding, + "topK", + topK, + "v", + vectorToSearch, + "filteredSearchThreshold", + expectedThreshold.toString()); + 
SolrQueryRequest req = + req( + CommonParams.Q, + "{!knn f=" + + vectorFieldByteEncoding + + " topK=" + + topK + + " filteredSearchThreshold=" + + expectedThreshold + + "}" + + vectorToSearch); + + KnnQParser qparser = new KnnQParser(vectorToSearch, localParams, params, req); + try { + KnnByteVectorQuery vectorQuery = (KnnByteVectorQuery) qparser.parse(); + KnnSearchStrategy.Hnsw strategy = (KnnSearchStrategy.Hnsw) vectorQuery.getSearchStrategy(); + Integer threshold = strategy.filteredSearchThreshold(); + + assertEquals(expectedThreshold, threshold); + } finally { + req.close(); + } + } + + @Test + public void testFilteredSearchThreshold_parsingByteEncoding_shouldSetCustomThreshold() + throws Exception { + Integer expectedThreshold = 30; + String vectorToSearch = "[1, 2, 3, 4]"; + String topK = "4"; + + SolrParams params = + params( + CommonParams.Q, + "{!knn f=" + + vectorFieldByteEncoding + + " topK=" + + topK + + " filteredSearchThreshold=" + + expectedThreshold + + "}" + + vectorToSearch); + SolrParams localParams = + params( + "type", + "knn", + "f", + vectorFieldByteEncoding, + "topK", + topK, + "v", + vectorToSearch, + "filteredSearchThreshold", + expectedThreshold.toString()); + SolrQueryRequest req = + req( + CommonParams.Q, + "{!knn f=" + + vectorFieldByteEncoding + + " topK=" + + topK + + " filteredSearchThreshold=" + + expectedThreshold + + "}" + + vectorToSearch); + + KnnQParser qparser = new KnnQParser(vectorToSearch, localParams, params, req); + try { + KnnByteVectorQuery vectorQuery = (KnnByteVectorQuery) qparser.parse(); + KnnSearchStrategy.Hnsw strategy = (KnnSearchStrategy.Hnsw) vectorQuery.getSearchStrategy(); + Integer threshold = strategy.filteredSearchThreshold(); + + assertEquals(expectedThreshold, threshold); + } finally { + req.close(); + } + } + + @Test + public void testFilteredSearchThreshold_seededParsingByteEncoding_shouldSetCustomThreshold() + throws Exception { + Integer expectedThreshold = 30; + String vectorToSearch = "[1, 2, 3, 
4]"; + String seedQuery = "id:(1 4 7 8 9)"; + String topK = "4"; + + SolrParams params = + params( + CommonParams.Q, + "{!knn f=" + + vectorFieldByteEncoding + + " topK=" + + topK + + " seedQuery=" + + seedQuery + + " filteredSearchThreshold=" + + expectedThreshold + + "}" + + vectorToSearch); + SolrParams localParams = + params( + "type", + "knn", + "f", + vectorFieldByteEncoding, + "topK", + topK, + "v", + vectorToSearch, + "seedQuery", + seedQuery, + "filteredSearchThreshold", + expectedThreshold.toString()); + SolrQueryRequest req = + req( + CommonParams.Q, + "{!knn f=" + + vectorFieldByteEncoding + + " topK=" + + topK + + " seedQuery=" + + seedQuery + + " filteredSearchThreshold=" + + expectedThreshold + + "}" + + vectorToSearch); + + KnnQParser qparser = new KnnQParser(vectorToSearch, localParams, params, req); + try { + SeededKnnVectorQuery vectorQuery = (SeededKnnVectorQuery) qparser.parse(); + KnnSearchStrategy.Hnsw strategy = (KnnSearchStrategy.Hnsw) vectorQuery.getSearchStrategy(); + Integer threshold = strategy.filteredSearchThreshold(); + + assertEquals(expectedThreshold, threshold); + } finally { + req.close(); + } + } + + @Test + public void + testFilteredSearchThreshold_earlyTerminationParsingByteEncoding_shouldSetCustomThreshold() + throws Exception { + Integer expectedThreshold = 30; + String vectorToSearch = "[1, 2, 3, 4]"; + String earlyTermination = "true"; + String topK = "4"; + + SolrParams params = + params( + CommonParams.Q, + "{!knn f=" + + vectorFieldByteEncoding + + " topK=" + + topK + + " earlyTermination=" + + earlyTermination + + " filteredSearchThreshold=" + + expectedThreshold + + "}" + + vectorToSearch); + SolrParams localParams = + params( + "type", + "knn", + "f", + vectorFieldByteEncoding, + "topK", + topK, + "v", + vectorToSearch, + "earlyTermination", + earlyTermination, + "filteredSearchThreshold", + expectedThreshold.toString()); + SolrQueryRequest req = + req( + CommonParams.Q, + "{!knn f=" + + vectorFieldByteEncoding + + " 
topK=" + + topK + + " earlyTermination=" + + earlyTermination + + " filteredSearchThreshold=" + + expectedThreshold + + "}" + + vectorToSearch); + + KnnQParser qparser = new KnnQParser(vectorToSearch, localParams, params, req); + try { + PatienceKnnVectorQuery vectorQuery = (PatienceKnnVectorQuery) qparser.parse(); + KnnSearchStrategy.Hnsw strategy = (KnnSearchStrategy.Hnsw) vectorQuery.getSearchStrategy(); + Integer threshold = strategy.filteredSearchThreshold(); + + assertEquals(expectedThreshold, threshold); + } finally { + req.close(); + } + } + + @Test + public void + testFilteredSearchThreshold_seededAndEarlyTerminationParsingByteEncoding_shouldSetCustomThreshold() + throws Exception { + Integer expectedThreshold = 30; + String vectorToSearch = "[1, 2, 3, 4]"; + String seedQuery = "id:(1 4 7 8 9)"; + String earlyTermination = "true"; + String topK = "4"; + + SolrParams params = + params( + CommonParams.Q, + "{!knn f=" + + vectorFieldByteEncoding + + " topK=" + + topK + + " seedQuery=" + + seedQuery + + " earlyTermination=" + + earlyTermination + + " filteredSearchThreshold=" + + expectedThreshold + + "}" + + vectorToSearch); + SolrParams localParams = + params( + "type", + "knn", + "f", + vectorFieldByteEncoding, + "topK", + topK, + "v", + vectorToSearch, + "seedQuery", + seedQuery, + "earlyTermination", + earlyTermination, + "filteredSearchThreshold", + expectedThreshold.toString()); + SolrQueryRequest req = + req( + CommonParams.Q, + "{!knn f=" + + vectorFieldByteEncoding + + " topK=" + + topK + + " seedQuery=" + + seedQuery + + " earlyTermination=" + + earlyTermination + + " filteredSearchThreshold=" + + expectedThreshold + + "}" + + vectorToSearch); + + KnnQParser qparser = new KnnQParser(vectorToSearch, localParams, params, req); + try { + PatienceKnnVectorQuery vectorQuery = (PatienceKnnVectorQuery) qparser.parse(); + KnnSearchStrategy.Hnsw strategy = (KnnSearchStrategy.Hnsw) vectorQuery.getSearchStrategy(); + Integer threshold = 
strategy.filteredSearchThreshold(); + + assertEquals(expectedThreshold, threshold); + } finally { + req.close(); + } + } } diff --git a/solr/solr-ref-guide/modules/query-guide/pages/dense-vector-search.adoc b/solr/solr-ref-guide/modules/query-guide/pages/dense-vector-search.adoc index 82ae6648a05..aa1c316f51d 100644 --- a/solr/solr-ref-guide/modules/query-guide/pages/dense-vector-search.adoc +++ b/solr/solr-ref-guide/modules/query-guide/pages/dense-vector-search.adoc @@ -477,6 +477,28 @@ Here is an example of a `knn` search using a `seedQuery`: The search results retrieved are the k=10 nearest documents to the vector in input `[1.0, 2.0, 3.0, 4.0]`. Documents matching the query `id:(1 4 10)` are used as entry points for the ANN search. If no documents match the seed, Solr falls back to a regular knn search without seeding, starting instead from random entry points. +`filteredSearchThreshold`:: ++ +[%autowidth,frame=none] +|=== +|Optional |Default: {lucene-javadocs}/core/constant-values.html#org.apache.lucene.search.knn.KnnSearchStrategy.DEFAULT_FILTERED_SEARCH_THRESHOLD[Lucene default] |An integer value from 0 to 100 +|=== ++ +ACORN is an algorithm designed to make hybrid searches consisting of a filter and a vector search more efficient. +This approach addresses the performance limitations of both pre-filtering and post-filtering. +It modifies the construction of the HNSW graph and the search on it. Based on https://arxiv.org/abs/2403.04871[ACORN: Performant and Predicate-Agnostic Search Over Vector Embeddings and Structured Data (2024)]. ++ +Solr relies on Lucene's implementation of the `filteredSearchThreshold` in the {lucene-javadocs}/core/org/apache/lucene/search/knn/KnnSearchStrategy.html[KnnSearchStrategy]. ++ +A suggested value is 60, based on the benchmark described in this GitHub https://github.com/apache/lucene/pull/14160#issue-2805145799[comment]. ++ +The `filteredSearchThreshold` parameter determines when ACORN is used.
If the percentage of documents that satisfy the filter is less than the threshold, ACORN will be used. + +Here is an example of a `knn` search using a `filteredSearchThreshold`: + +[source,text] +?q={!knn f=vector topK=10 filteredSearchThreshold=60}[1.0, 2.0, 3.0, 4.0] + === knn_text_to_vector Query Parser The `knn_text_to_vector` query parser encode a textual query to a vector using a dedicated Large Language Model(fine tuned for the task of encoding text to vector for sentence similarity) and matches k-nearest neighbours documents to such query vector.
