This is an automated email from the ASF dual-hosted git repository. hossman pushed a commit to branch jira/SOLR-17975 in repository https://gitbox.apache.org/repos/asf/solr.git
commit a51f2a3d3e8f88dec04390d9ff24f70bfe72ba9b Author: Chris Hostetter <[email protected]> AuthorDate: Wed Jan 14 11:10:40 2026 -0700 flesh out some nocommits --- .../solr/schema/LateInteractionVectorField.java | 74 +++++++++++++++------- .../org/apache/solr/search/ValueSourceParser.java | 5 +- .../schema/TestLateInteractionVectorFieldInit.java | 4 +- .../solr/search/TestLateInteractionVectors.java | 38 +++++++---- 4 files changed, 82 insertions(+), 39 deletions(-) diff --git a/solr/core/src/java/org/apache/solr/schema/LateInteractionVectorField.java b/solr/core/src/java/org/apache/solr/schema/LateInteractionVectorField.java index 10ea9f1939a..d0aa0ee98e3 100644 --- a/solr/core/src/java/org/apache/solr/schema/LateInteractionVectorField.java +++ b/solr/core/src/java/org/apache/solr/schema/LateInteractionVectorField.java @@ -32,6 +32,7 @@ import org.apache.lucene.index.VectorSimilarityFunction; import org.apache.lucene.queries.function.ValueSource; import org.apache.lucene.search.DoubleValuesSource; import org.apache.lucene.search.LateInteractionFloatValuesSource; +import org.apache.lucene.search.LateInteractionFloatValuesSource.ScoreFunction; import org.apache.lucene.search.Query; import org.apache.lucene.search.SortField; import org.apache.lucene.util.BytesRef; @@ -52,6 +53,8 @@ public class LateInteractionVectorField extends FieldType { public static final String SIMILARITY_FUNCTION = "similarityFunction"; public static final VectorSimilarityFunction DEFAULT_SIMILARITY = VectorSimilarityFunction.EUCLIDEAN; + public static final String SCORE_FUNCTION = "scoreFunction"; + public static final ScoreFunction DEFAULT_SCORE_FUNCTION = ScoreFunction.SUM_MAX_SIM; private static final int MUST_BE_TRUE = DOC_VALUES; private static final int MUST_BE_FALSE = MULTIVALUED | TOKENIZED | INDEXED | UNINVERTIBLE; @@ -63,10 +66,7 @@ public class LateInteractionVectorField extends FieldType { private int dimension; private VectorSimilarityFunction similarityFunction; - - // nocommit: pre-emptively add ScoreFunction opt? - // nocommit: if we don't add it now, write a test to fail if/when new options added to - // ScoreFunction enum + private ScoreFunction scoreFunction; public LateInteractionVectorField() { super(); @@ -74,31 +74,32 @@ public class LateInteractionVectorField extends FieldType { @Override public void init(IndexSchema schema, Map<String, String> args) { + this.dimension = - ofNullable(args.get(VECTOR_DIMENSION)) + ofNullable(args.remove(VECTOR_DIMENSION)) .map(Integer::parseInt) .orElseThrow( () -> new SolrException( SolrException.ErrorCode.SERVER_ERROR, VECTOR_DIMENSION + " is a mandatory parameter")); - args.remove(VECTOR_DIMENSION); - try { - this.similarityFunction = - ofNullable(args.get(SIMILARITY_FUNCTION)) - .map(value -> VectorSimilarityFunction.valueOf(value.toUpperCase(Locale.ROOT))) - .orElse(DEFAULT_SIMILARITY); - } catch (IllegalArgumentException e) { - throw new SolrException( - SolrException.ErrorCode.SERVER_ERROR, - SIMILARITY_FUNCTION + " not recognized: " + args.get(SIMILARITY_FUNCTION)); - } - args.remove(SIMILARITY_FUNCTION); + this.similarityFunction = + optionalEnumArg( + SIMILARITY_FUNCTION, + args.remove(SIMILARITY_FUNCTION), + VectorSimilarityFunction.class, + DEFAULT_SIMILARITY); + this.scoreFunction = + optionalEnumArg( + SCORE_FUNCTION, + args.remove(SCORE_FUNCTION), + ScoreFunction.class, + DEFAULT_SCORE_FUNCTION); // By the time this method is called, FieldType.setArgs has already set "typical" defaults, // and parsed the users explicit options. - // We need to override those defaults, and error if the user asked for nonesense + // We need to override those defaults, and error if the user asked for nonsense this.properties |= MUST_BE_TRUE; this.properties &= ~MUST_BE_FALSE; @@ -122,11 +123,17 @@ public class LateInteractionVectorField extends FieldType { return similarityFunction; } + public ScoreFunction getScoreFunction() { + return scoreFunction; + } + public DoubleValuesSource getMultiVecSimilarityValueSource( final SchemaField f, final String vecStr) throws SyntaxError { - // nocommit: use ScoreFunction here if we add it return new LateInteractionFloatValuesSource( - f.getName(), stringToMultiFloatVector(dimension, vecStr), getSimilarityFunction()); + f.getName(), + stringToMultiFloatVector(dimension, vecStr), + getSimilarityFunction(), + getScoreFunction()); } @Override @@ -290,7 +297,7 @@ public class LateInteractionVectorField extends FieldType { public void write(TextResponseWriter writer, String name, IndexableField f) throws IOException { writer.writeStr(name, toExternal(f), false); } - + @Override public Object toObject(SchemaField sf, BytesRef term) { return multiFloatVectorToString(LateInteractionField.decode(term)); @@ -309,7 +316,8 @@ public class LateInteractionVectorField extends FieldType { public ValueSource getValueSource(SchemaField field, QParser parser) { throw new SolrException( SolrException.ErrorCode.BAD_REQUEST, - getClass().getSimpleName() + " not supported for function queries, use lateVector() function."); + getClass().getSimpleName() + + " not supported for function queries, use lateVector() function."); } /** Not supported */ @@ -317,7 +325,8 @@ public class LateInteractionVectorField extends FieldType { public Query getFieldQuery(QParser parser, SchemaField field, String externalVal) { throw new SolrException( SolrException.ErrorCode.BAD_REQUEST, - getClass().getSimpleName() + " not supported for field queries, use lateVector() function."); + getClass().getSimpleName() + + " not supported for field queries, use lateVector() function."); } /** Not Supported */ @@ -349,4 +358,23 @@ public class LateInteractionVectorField extends FieldType { SolrException.ErrorCode.BAD_REQUEST, getClass().getSimpleName() + " not supported for sorting."); } + + /** + * @param key Config option name, used in exception messages + * @param value Value specified in configuration, may be <code>null</code> + * @param clazz Enum class specifying the return type + * @param defaultValue default to use if value is <code>null</code> + */ + private static final <E extends Enum<E>> E optionalEnumArg( + final String key, final String value, final Class<E> clazz, final E defaultValue) + throws SolrException { + try { + return ofNullable(value) + .map(v -> Enum.valueOf(clazz, v.toUpperCase(Locale.ROOT))) + .orElse(defaultValue); + } catch (IllegalArgumentException e) { + throw new SolrException( + SolrException.ErrorCode.SERVER_ERROR, key + " not recognized: " + value); + } + } } diff --git a/solr/core/src/java/org/apache/solr/search/ValueSourceParser.java b/solr/core/src/java/org/apache/solr/search/ValueSourceParser.java index fa03a6b2931..bbd9387e0b0 100644 --- a/solr/core/src/java/org/apache/solr/search/ValueSourceParser.java +++ b/solr/core/src/java/org/apache/solr/search/ValueSourceParser.java @@ -1377,10 +1377,9 @@ public abstract class ValueSourceParser implements NamedListInitializedPlugin { "Invalid number of arguments. Please provide both a field name, and a (String) multi-vector."); } final SchemaField sf = fp.getReq().getSchema().getField(fieldName); - if (sf.getType() instanceof LateInteractionVectorField) { + if (sf.getType() instanceof LateInteractionVectorField lif) { return ValueSource.fromDoubleValuesSource( - ((LateInteractionVectorField) sf.getType()) - .getMultiVecSimilarityValueSource(sf, vecStr)); + lif.getMultiVecSimilarityValueSource(sf, vecStr)); } throw new SolrException( SolrException.ErrorCode.BAD_REQUEST, diff --git a/solr/core/src/test/org/apache/solr/schema/TestLateInteractionVectorFieldInit.java b/solr/core/src/test/org/apache/solr/schema/TestLateInteractionVectorFieldInit.java index cda86969dc0..4f154ef5b79 100644 --- a/solr/core/src/test/org/apache/solr/schema/TestLateInteractionVectorFieldInit.java +++ b/solr/core/src/test/org/apache/solr/schema/TestLateInteractionVectorFieldInit.java @@ -48,7 +48,9 @@ public class TestLateInteractionVectorFieldInit extends AbstractBadConfigTestBas assertConfigs( "solrconfig-basic.xml", "bad-schema-late-vec-field-indexed.xml", "indexed: bad_field"); assertConfigs( - "solrconfig-basic.xml", "bad-schema-late-vec-field-multivalued.xml", "multiValued: bad_field"); + "solrconfig-basic.xml", + "bad-schema-late-vec-field-multivalued.xml", + "multiValued: bad_field"); } public void test_SchemaFields() throws Exception { diff --git a/solr/core/src/test/org/apache/solr/search/TestLateInteractionVectors.java b/solr/core/src/test/org/apache/solr/search/TestLateInteractionVectors.java index e9a2dc4a958..5d1e3b7c308 100644 --- a/solr/core/src/test/org/apache/solr/search/TestLateInteractionVectors.java +++ b/solr/core/src/test/org/apache/solr/search/TestLateInteractionVectors.java @@ -22,16 +22,16 @@ import static org.apache.solr.schema.LateInteractionVectorField.stringToMultiFlo import static org.hamcrest.Matchers.startsWith; import java.util.Arrays; +import java.util.EnumSet; import java.util.List; import java.util.Map; import org.apache.lucene.document.LateInteractionField; import org.apache.lucene.document.StoredField; import org.apache.lucene.index.IndexableField; import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.search.LateInteractionFloatValuesSource.ScoreFunction; import org.apache.solr.SolrTestCaseJ4; import org.apache.solr.request.SolrQueryRequest; -import org.apache.solr.schema.FieldType; -import org.apache.solr.schema.LateInteractionVectorField; import org.apache.solr.schema.SchemaField; import org.junit.After; import org.junit.Before; @@ -50,6 +50,18 @@ public class TestLateInteractionVectors extends SolrTestCaseJ4 { deleteCore(); } + public void testFutureProofAgainstNewScoreFunctions() throws Exception { + // if this assert fails, it means there are new value(s) in the ScoreFunction enum, + // and we need to add fieldType declarations using those new ScoreFunctions to our test + // configs, and confirm the correct score function is used in various tests + // + // then remove this test method + assertEquals( + "The ScoreFunction enum in Lucene now has more then value, test needs updated", + EnumSet.of(ScoreFunction.SUM_MAX_SIM), + EnumSet.allOf(ScoreFunction.class)); + } + public void testStringEncodingAndDecoding() throws Exception { final int DIMENSIONS = 4; @@ -143,10 +155,12 @@ public class TestLateInteractionVectors extends SolrTestCaseJ4 { /** Low level test of createFields */ public void createFields() throws Exception { - final Map<String,float[][]> data = Map.of("[[1,2,3,4]]", - new float[][] { { 1F, 2F, 3F, 4F }}, - "[[1,2,3,4],[5,6,7,8]]", - new float[][] { { 1F, 2F, 3F, 4F }, { 5F, 6F, 7F, 8F }}); + final Map<String, float[][]> data = + Map.of( + "[[1,2,3,4]]", + new float[][] {{1F, 2F, 3F, 4F}}, + "[[1,2,3,4],[5,6,7,8]]", + new float[][] {{1F, 2F, 3F, 4F}, {5F, 6F, 7F, 8F}}); try (SolrQueryRequest r = req()) { // defaults with stored + doc values @@ -155,7 +169,7 @@ public class TestLateInteractionVectors extends SolrTestCaseJ4 { final float[][] expected = data.get(input); final List<IndexableField> actual = f.getType().createFields(f, input); assertEquals(2, actual.size()); - + if (actual.get(0) instanceof LateInteractionField lif) { assertEquals(expected, lif.getValue()); } else { @@ -167,14 +181,14 @@ public class TestLateInteractionVectors extends SolrTestCaseJ4 { fail("second Field isn't stored: " + actual.get(1).getClass()); } } - + // stored=false, only doc values for (String input : data.keySet()) { final SchemaField f = r.getSchema().getField("lv_4_nostored"); final float[][] expected = data.get(input); final List<IndexableField> actual = f.getType().createFields(f, input); assertEquals(1, actual.size()); - + if (actual.get(0) instanceof LateInteractionField lif) { assertEquals(expected, lif.getValue()); } else { @@ -183,7 +197,7 @@ public class TestLateInteractionVectors extends SolrTestCaseJ4 { } } } - + public void testSimpleIndexAndRetrieval() throws Exception { // for simplicity, use a single doc, with identical values in several fields @@ -244,8 +258,8 @@ public class TestLateInteractionVectors extends SolrTestCaseJ4 { "//str[@name='lv_4_cosine'][.='" + d4s + "']", // dv only non-stored fields - "//str[@name='lv_3_nostored'][.='"+d3s+"']", - "//str[@name='lv_4_nostored'][.='"+d4s+"']", + "//str[@name='lv_3_nostored'][.='" + d3s + "']", + "//str[@name='lv_4_nostored'][.='" + d4s + "']", // function computations "//float[@name='euclid_3_def'][.=" + euclid3 + "]",
