This is an automated email from the ASF dual-hosted git repository. abenedetti pushed a commit to branch branch_10_0 in repository https://gitbox.apache.org/repos/asf/solr.git
commit 91a7015433802ee28734139429890c7765690e6b Author: Anna Ruggero <[email protected]> AuthorDate: Tue Oct 21 10:36:57 2025 +0200 SOLR-16667: LTR Add feature vector caching for ranking (#3433) by Anna and Alessandro (cherry picked from commit aeb9063585869f2bce12e7a189cad9f57f1fd86a) --- solr/CHANGES.txt | 4 +- .../src/java/org/apache/solr/core/SolrConfig.java | 11 +- .../org/apache/solr/search/SolrIndexSearcher.java | 11 + solr/modules/ltr/build.gradle | 2 + solr/modules/ltr/gradle.lockfile | 2 +- .../java/org/apache/solr/ltr/CSVFeatureLogger.java | 11 +- .../ltr/src/java/org/apache/solr/ltr/DocInfo.java | 10 + .../java/org/apache/solr/ltr/FeatureLogger.java | 63 +--- .../src/java/org/apache/solr/ltr/LTRRescorer.java | 118 ++----- .../java/org/apache/solr/ltr/LTRScoringQuery.java | 377 +++------------------ .../ltr/feature/extraction/FeatureExtractor.java | 129 +++++++ .../feature/extraction/MultiFeaturesExtractor.java | 65 ++++ .../feature/extraction/SingleFeatureExtractor.java | 66 ++++ .../extraction/package-info.java} | 26 +- .../ltr/interleaving/LTRInterleavingRescorer.java | 36 +- .../LTRFeatureLoggerTransformerFactory.java | 65 ++-- .../solr/ltr/scoring/FeatureTraversalScorer.java | 71 ++++ .../solr/ltr/scoring/MultiFeaturesScorer.java | 211 ++++++++++++ .../solr/ltr/scoring/SingleFeatureScorer.java | 143 ++++++++ .../{DocInfo.java => scoring/package-info.java} | 26 +- .../featurevectorcache_features.json | 57 ++++ .../featurevectorcache_linear_model.json | 26 ++ .../conf/solrconfig-ltr-featurevectorcache.xml | 76 +++++ .../solr/collection1/conf/solrconfig-ltr.xml | 7 +- .../collection1/conf/solrconfig-ltr_Th10_10.xml | 3 - .../solr/collection1/conf/solrconfig-multiseg.xml | 3 - .../apache/solr/ltr/TestFeatureVectorCache.java | 366 ++++++++++++++++++++ .../org/apache/solr/ltr/TestLTRScoringQuery.java | 2 +- .../test/org/apache/solr/ltr/TestRerankBase.java | 6 + .../solr/ltr/TestSelectiveWeightCreation.java | 9 +- 
.../TestFeatureExtractionFromMultipleSegments.java | 3 +- .../conf/solrconfig.xml | 8 +- .../query-guide/pages/learning-to-rank.adoc | 14 +- 33 files changed, 1437 insertions(+), 590 deletions(-) diff --git a/solr/CHANGES.txt b/solr/CHANGES.txt index bb9d4448590..1a2f38bc84f 100644 --- a/solr/CHANGES.txt +++ b/solr/CHANGES.txt @@ -33,6 +33,8 @@ New Features * SOLR-17813: Add support for SeededKnnVectorQuery (Ilaria Petreti via Alessandro Benedetti) +* SOLR-16667: LTR Add feature vector caching for ranking. (Anna Ruggero, Alessandro Benedetti) + Improvements --------------------- @@ -752,7 +754,7 @@ Bug Fixes * SOLR-17726: MoreLikeThis to support copy-fields (Ilaria Petreti via Alessandro Benedetti) -* SOLR-16667: Fixed dense/sparse representation in LTR module. (Anna Ruggero, Alessandro Benedetti) +* SOLR-17760: Fixed dense/sparse representation in LTR module. (Anna Ruggero, Alessandro Benedetti) * SOLR-17800: Security Manager should handle symlink on /tmp (Kevin Risden) diff --git a/solr/core/src/java/org/apache/solr/core/SolrConfig.java b/solr/core/src/java/org/apache/solr/core/SolrConfig.java index c9482838d4a..4be0350efd3 100644 --- a/solr/core/src/java/org/apache/solr/core/SolrConfig.java +++ b/solr/core/src/java/org/apache/solr/core/SolrConfig.java @@ -301,6 +301,9 @@ public class SolrConfig implements MapSerializable { queryResultCacheConfig = CacheConfig.getConfig( this, get("query").get("queryResultCache"), "query/queryResultCache"); + featureVectorCacheConfig = + CacheConfig.getConfig( + this, get("query").get("featureVectorCache"), "query/featureVectorCache"); documentCacheConfig = CacheConfig.getConfig(this, get("query").get("documentCache"), "query/documentCache"); CacheConfig conf = @@ -662,6 +665,7 @@ public class SolrConfig implements MapSerializable { public final CacheConfig queryResultCacheConfig; public final CacheConfig documentCacheConfig; public final CacheConfig fieldValueCacheConfig; + public final CacheConfig featureVectorCacheConfig; 
public final Map<String, CacheConfig> userCacheConfigs; // SolrIndexSearcher - more... public final boolean useFilterForSortedQuery; @@ -998,7 +1002,12 @@ public class SolrConfig implements MapSerializable { } addCacheConfig( - m, filterCacheConfig, queryResultCacheConfig, documentCacheConfig, fieldValueCacheConfig); + m, + filterCacheConfig, + queryResultCacheConfig, + documentCacheConfig, + fieldValueCacheConfig, + featureVectorCacheConfig); m = new LinkedHashMap<>(); result.put("requestDispatcher", m); if (httpCachingConfig != null) m.put("httpCaching", httpCachingConfig); diff --git a/solr/core/src/java/org/apache/solr/search/SolrIndexSearcher.java b/solr/core/src/java/org/apache/solr/search/SolrIndexSearcher.java index 590cca8916e..15cf7ecc431 100644 --- a/solr/core/src/java/org/apache/solr/search/SolrIndexSearcher.java +++ b/solr/core/src/java/org/apache/solr/search/SolrIndexSearcher.java @@ -163,6 +163,7 @@ public class SolrIndexSearcher extends IndexSearcher implements Closeable, SolrI private final SolrCache<Query, DocSet> filterCache; private final SolrCache<QueryResultKey, DocList> queryResultCache; private final SolrCache<String, UnInvertedField> fieldValueCache; + private final SolrCache<Integer, float[]> featureVectorCache; private final LongAdder fullSortCount = new LongAdder(); private final LongAdder skipSortCount = new LongAdder(); private final LongAdder liveDocsNaiveCacheHitCount = new LongAdder(); @@ -448,6 +449,11 @@ public class SolrIndexSearcher extends IndexSearcher implements Closeable, SolrI ? null : solrConfig.queryResultCacheConfig.newInstance(); if (queryResultCache != null) clist.add(queryResultCache); + featureVectorCache = + solrConfig.featureVectorCacheConfig == null + ? 
null + : solrConfig.featureVectorCacheConfig.newInstance(); + if (featureVectorCache != null) clist.add(featureVectorCache); SolrCache<Integer, Document> documentCache = docFetcher.getDocumentCache(); if (documentCache != null) clist.add(documentCache); @@ -469,6 +475,7 @@ public class SolrIndexSearcher extends IndexSearcher implements Closeable, SolrI this.filterCache = null; this.queryResultCache = null; this.fieldValueCache = null; + this.featureVectorCache = null; this.cacheMap = NO_GENERIC_CACHES; this.cacheList = NO_CACHES; } @@ -689,6 +696,10 @@ public class SolrIndexSearcher extends IndexSearcher implements Closeable, SolrI return filterCache; } + public SolrCache<Integer, float[]> getFeatureVectorCache() { + return featureVectorCache; + } + // // Set default regenerators on filter and query caches if they don't have any // diff --git a/solr/modules/ltr/build.gradle b/solr/modules/ltr/build.gradle index 61e02bb645d..20582b1d06d 100644 --- a/solr/modules/ltr/build.gradle +++ b/solr/modules/ltr/build.gradle @@ -55,6 +55,8 @@ dependencies { testImplementation libs.junit.junit testImplementation libs.hamcrest.hamcrest + testImplementation libs.prometheus.metrics.model + testImplementation libs.commonsio.commonsio } diff --git a/solr/modules/ltr/gradle.lockfile b/solr/modules/ltr/gradle.lockfile index de00f483d4d..d4364c51501 100644 --- a/solr/modules/ltr/gradle.lockfile +++ b/solr/modules/ltr/gradle.lockfile @@ -65,7 +65,7 @@ io.opentelemetry:opentelemetry-sdk-metrics:1.53.0=jarValidation,runtimeClasspath io.opentelemetry:opentelemetry-sdk-trace:1.53.0=jarValidation,runtimeClasspath,runtimeLibs,solrPlatformLibs,testRuntimeClasspath io.opentelemetry:opentelemetry-sdk:1.53.0=jarValidation,runtimeClasspath,runtimeLibs,solrPlatformLibs,testRuntimeClasspath io.prometheus:prometheus-metrics-exposition-formats:1.1.0=jarValidation,runtimeClasspath,runtimeLibs,solrPlatformLibs,testRuntimeClasspath 
-io.prometheus:prometheus-metrics-model:1.1.0=jarValidation,runtimeClasspath,runtimeLibs,solrPlatformLibs,testRuntimeClasspath +io.prometheus:prometheus-metrics-model:1.1.0=jarValidation,runtimeClasspath,runtimeLibs,solrPlatformLibs,testCompileClasspath,testRuntimeClasspath io.sgr:s2-geometry-library-java:1.0.0=jarValidation,runtimeClasspath,runtimeLibs,solrPlatformLibs,testRuntimeClasspath io.swagger.core.v3:swagger-annotations-jakarta:2.2.22=compileClasspath,jarValidation,runtimeClasspath,runtimeLibs,solrPlatformLibs,testCompileClasspath,testRuntimeClasspath jakarta.annotation:jakarta.annotation-api:2.1.1=jarValidation,runtimeClasspath,runtimeLibs,solrPlatformLibs,testRuntimeClasspath diff --git a/solr/modules/ltr/src/java/org/apache/solr/ltr/CSVFeatureLogger.java b/solr/modules/ltr/src/java/org/apache/solr/ltr/CSVFeatureLogger.java index 22ddcb8724a..57a86a10e1c 100644 --- a/solr/modules/ltr/src/java/org/apache/solr/ltr/CSVFeatureLogger.java +++ b/solr/modules/ltr/src/java/org/apache/solr/ltr/CSVFeatureLogger.java @@ -23,21 +23,20 @@ public class CSVFeatureLogger extends FeatureLogger { private final char keyValueSep; private final char featureSep; - public CSVFeatureLogger(String fvCacheName, FeatureFormat f, Boolean logAll) { - super(fvCacheName, f, logAll); + public CSVFeatureLogger(FeatureFormat f, Boolean logAll) { + super(f, logAll); this.keyValueSep = DEFAULT_KEY_VALUE_SEPARATOR; this.featureSep = DEFAULT_FEATURE_SEPARATOR; } - public CSVFeatureLogger( - String fvCacheName, FeatureFormat f, Boolean logAll, char keyValueSep, char featureSep) { - super(fvCacheName, f, logAll); + public CSVFeatureLogger(FeatureFormat f, Boolean logAll, char keyValueSep, char featureSep) { + super(f, logAll); this.keyValueSep = keyValueSep; this.featureSep = featureSep; } @Override - public String makeFeatureVector(LTRScoringQuery.FeatureInfo[] featuresInfo) { + public String printFeatureVector(LTRScoringQuery.FeatureInfo[] featuresInfo) { // Allocate the buffer to a size 
based on the number of features instead of the // default 16. You need space for the name, value, and two separators per feature, // but not all the features are expected to fire, so this is just a naive estimate. diff --git a/solr/modules/ltr/src/java/org/apache/solr/ltr/DocInfo.java b/solr/modules/ltr/src/java/org/apache/solr/ltr/DocInfo.java index e454d90acc2..ee82bb41df7 100644 --- a/solr/modules/ltr/src/java/org/apache/solr/ltr/DocInfo.java +++ b/solr/modules/ltr/src/java/org/apache/solr/ltr/DocInfo.java @@ -22,6 +22,8 @@ public class DocInfo extends HashMap<String, Object> { // Name of key used to store the original score of a doc private static final String ORIGINAL_DOC_SCORE = "ORIGINAL_DOC_SCORE"; + // Name of key used to store the original id of a doc + private static final String ORIGINAL_DOC_ID = "ORIGINAL_DOC_ID"; public DocInfo() { super(); @@ -38,4 +40,12 @@ public class DocInfo extends HashMap<String, Object> { public boolean hasOriginalDocScore() { return containsKey(ORIGINAL_DOC_SCORE); } + + public void setOriginalDocId(int docId) { + put(ORIGINAL_DOC_ID, docId); + } + + public int getOriginalDocId() { + return (int) get(ORIGINAL_DOC_ID); + } } diff --git a/solr/modules/ltr/src/java/org/apache/solr/ltr/FeatureLogger.java b/solr/modules/ltr/src/java/org/apache/solr/ltr/FeatureLogger.java index 9be531c1ef3..54d308b665e 100644 --- a/solr/modules/ltr/src/java/org/apache/solr/ltr/FeatureLogger.java +++ b/solr/modules/ltr/src/java/org/apache/solr/ltr/FeatureLogger.java @@ -16,16 +16,10 @@ */ package org.apache.solr.ltr; -import org.apache.solr.search.SolrIndexSearcher; - /** * FeatureLogger can be registered in a model and provide a strategy for logging the feature values. 
*/ public abstract class FeatureLogger { - - /** the name of the cache using for storing the feature value */ - private final String fvCacheName; - public enum FeatureFormat { DENSE, SPARSE @@ -35,54 +29,15 @@ public abstract class FeatureLogger { protected Boolean logAll; - protected FeatureLogger(String fvCacheName, FeatureFormat f, Boolean logAll) { - this.fvCacheName = fvCacheName; + protected boolean logFeatures; + + protected FeatureLogger(FeatureFormat f, Boolean logAll) { this.featureFormat = f; this.logAll = logAll; + this.logFeatures = false; } - /** - * Log will be called every time that the model generates the feature values for a document and a - * query. - * - * @param docid Solr document id whose features we are saving - * @param featuresInfo List of all the {@link LTRScoringQuery.FeatureInfo} objects which contain - * name and value for all the features triggered by the result set - * @return true if the logger successfully logged the features, false otherwise. - */ - public boolean log( - int docid, - LTRScoringQuery scoringQuery, - SolrIndexSearcher searcher, - LTRScoringQuery.FeatureInfo[] featuresInfo) { - final String featureVector = makeFeatureVector(featuresInfo); - if (featureVector == null) { - return false; - } - - if (null == searcher.cacheInsert(fvCacheName, fvCacheKey(scoringQuery, docid), featureVector)) { - return false; - } - - return true; - } - - public abstract String makeFeatureVector(LTRScoringQuery.FeatureInfo[] featuresInfo); - - private static int fvCacheKey(LTRScoringQuery scoringQuery, int docid) { - return scoringQuery.hashCode() + (31 * docid); - } - - /** - * populate the document with its feature vector - * - * @param docid Solr document id - * @return String representation of the list of features calculated for docid - */ - public String getFeatureVector( - int docid, LTRScoringQuery scoringQuery, SolrIndexSearcher searcher) { - return (String) searcher.cacheLookup(fvCacheName, fvCacheKey(scoringQuery, docid)); - } + 
public abstract String printFeatureVector(LTRScoringQuery.FeatureInfo[] featuresInfo); public Boolean isLoggingAll() { return logAll; @@ -91,4 +46,12 @@ public abstract class FeatureLogger { public void setLogAll(Boolean logAll) { this.logAll = logAll; } + + public void setLogFeatures(boolean logFeatures) { + this.logFeatures = logFeatures; + } + + public boolean isLogFeatures() { + return logFeatures; + } } diff --git a/solr/modules/ltr/src/java/org/apache/solr/ltr/LTRRescorer.java b/solr/modules/ltr/src/java/org/apache/solr/ltr/LTRRescorer.java index a21c107438c..0cd0258eb52 100644 --- a/solr/modules/ltr/src/java/org/apache/solr/ltr/LTRRescorer.java +++ b/solr/modules/ltr/src/java/org/apache/solr/ltr/LTRRescorer.java @@ -33,7 +33,6 @@ import org.apache.lucene.search.Weight; import org.apache.solr.ltr.interleaving.OriginalRankingLTRScoringQuery; import org.apache.solr.search.IncompleteRerankingException; import org.apache.solr.search.QueryLimits; -import org.apache.solr.search.SolrIndexSearcher; /** * Implements the rescoring logic. 
The top documents returned by solr with their original scores, @@ -114,31 +113,31 @@ public class LTRRescorer extends Rescorer { * * @param searcher current IndexSearcher * @param firstPassTopDocs documents to rerank; - * @param topN documents to return; + * @param docsToRerank documents to return; */ @Override - public TopDocs rescore(IndexSearcher searcher, TopDocs firstPassTopDocs, int topN) + public TopDocs rescore(IndexSearcher searcher, TopDocs firstPassTopDocs, int docsToRerank) throws IOException { - if ((topN == 0) || (firstPassTopDocs.scoreDocs.length == 0)) { + if ((docsToRerank == 0) || (firstPassTopDocs.scoreDocs.length == 0)) { return firstPassTopDocs; } final ScoreDoc[] firstPassResults = getFirstPassDocsRanked(firstPassTopDocs); - topN = Math.toIntExact(Math.min(topN, firstPassTopDocs.totalHits.value())); + docsToRerank = Math.toIntExact(Math.min(docsToRerank, firstPassTopDocs.totalHits.value())); - final ScoreDoc[] reranked = rerank(searcher, topN, firstPassResults); + final ScoreDoc[] reranked = rerank(searcher, docsToRerank, firstPassResults); return new TopDocs(firstPassTopDocs.totalHits, reranked); } - private ScoreDoc[] rerank(IndexSearcher searcher, int topN, ScoreDoc[] firstPassResults) + private ScoreDoc[] rerank(IndexSearcher searcher, int docsToRerank, ScoreDoc[] firstPassResults) throws IOException { - final ScoreDoc[] reranked = new ScoreDoc[topN]; + final ScoreDoc[] reranked = new ScoreDoc[docsToRerank]; final List<LeafReaderContext> leaves = searcher.getIndexReader().leaves(); final LTRScoringQuery.ModelWeight modelWeight = (LTRScoringQuery.ModelWeight) searcher.createWeight(searcher.rewrite(scoringQuery), ScoreMode.COMPLETE, 1); - scoreFeatures(searcher, topN, modelWeight, firstPassResults, leaves, reranked); + scoreFeatures(docsToRerank, modelWeight, firstPassResults, leaves, reranked); // Must sort all documents that we reranked, and then select the top Arrays.sort(reranked, scoreComparator); return reranked; @@ -153,8 +152,7 @@ 
public class LTRRescorer extends Rescorer { } public void scoreFeatures( - IndexSearcher indexSearcher, - int topN, + int docsToRerank, LTRScoringQuery.ModelWeight modelWeight, ScoreDoc[] hits, List<LeafReaderContext> leaves, @@ -166,13 +164,12 @@ public class LTRRescorer extends Rescorer { int docBase = 0; LTRScoringQuery.ModelWeight.ModelScorer scorer = null; - int hitUpto = 0; + int hitPosition = 0; - while (hitUpto < hits.length) { - final ScoreDoc hit = hits[hitUpto]; - final int docID = hit.doc; + while (hitPosition < hits.length) { + final ScoreDoc hit = hits[hitPosition]; LeafReaderContext readerContext = null; - while (docID >= endDoc) { + while (hit.doc >= endDoc) { readerUpto++; readerContext = leaves.get(readerUpto); endDoc = readerContext.docBase + readerContext.reader().maxDoc(); @@ -182,41 +179,17 @@ public class LTRRescorer extends Rescorer { docBase = readerContext.docBase; scorer = modelWeight.modelScorer(readerContext); } - if (scoreSingleHit(topN, docBase, hitUpto, hit, docID, scorer, reranked)) { - logSingleHit(indexSearcher, modelWeight, hit.doc, scoringQuery); - } - hitUpto++; + scoreSingleHit(docsToRerank, docBase, hitPosition, hit, scorer, reranked); + hitPosition++; } } - /** - * Call this method if the {@link #scoreSingleHit(int, int, int, ScoreDoc, int, - * org.apache.solr.ltr.LTRScoringQuery.ModelWeight.ModelScorer, ScoreDoc[])} method indicated that - * the document's feature info should be logged. 
- */ - protected static void logSingleHit( - IndexSearcher indexSearcher, - LTRScoringQuery.ModelWeight modelWeight, - int docid, - LTRScoringQuery scoringQuery) { - final FeatureLogger featureLogger = scoringQuery.getFeatureLogger(); - if (featureLogger != null && indexSearcher instanceof SolrIndexSearcher) { - featureLogger.log( - docid, scoringQuery, (SolrIndexSearcher) indexSearcher, modelWeight.getFeaturesInfo()); - } - } - - /** - * Scores a single document and returns true if the document's feature info should be logged via - * the {@link #logSingleHit(IndexSearcher, org.apache.solr.ltr.LTRScoringQuery.ModelWeight, int, - * LTRScoringQuery)} method. Feature info logging is only necessary for the topN documents. - */ - protected static boolean scoreSingleHit( - int topN, + /** Scores a single document. */ + protected void scoreSingleHit( + int docsToRerank, int docBase, - int hitUpto, + int hitPosition, ScoreDoc hit, - int docID, LTRScoringQuery.ModelWeight.ModelScorer scorer, ScoreDoc[] reranked) throws IOException { @@ -228,13 +201,12 @@ public class LTRRescorer extends Rescorer { * needs to compute a potentially non-zero score from blank features. 
*/ assert (scorer != null); - final int targetDoc = docID - docBase; + final int targetDoc = hit.doc - docBase; scorer.docID(); scorer.iterator().advance(targetDoc); - boolean logHit = false; - scorer.getDocInfo().setOriginalDocScore(hit.score); + scorer.getDocInfo().setOriginalDocId(hit.doc); hit.score = scorer.score(); if (QueryLimits.getCurrentLimits() .maybeExitWithPartialResults( @@ -243,28 +215,21 @@ public class LTRRescorer extends Rescorer { + " If partial results are tolerated the reranking got reverted and all documents preserved their original score and ranking.")) { throw new IncompleteRerankingException(); } - if (hitUpto < topN) { - reranked[hitUpto] = hit; - // if the heap is not full, maybe I want to log the features for this - // document - logHit = true; - } else if (hitUpto == topN) { + if (hitPosition < docsToRerank) { + reranked[hitPosition] = hit; + } else if (hitPosition == docsToRerank) { // collected topN document, I create the heap - heapify(reranked, topN); + heapify(reranked, docsToRerank); } - if (hitUpto >= topN) { - // once that heap is ready, if the score of this document is lower that - // the minimum - // i don't want to log the feature. Otherwise I replace it with the - // minimum and fix the - // heap. + if (hitPosition >= docsToRerank) { + // once the heap is ready, if the score of this document is greater than + // the minimum, replace the minimum with it and + // fix the heap.
if (hit.score > reranked[0].score) { reranked[0] = hit; - heapAdjust(reranked, topN, 0); - logHit = true; + heapAdjust(reranked, docsToRerank, 0); } } - return logHit; } @Override @@ -289,27 +254,4 @@ public class LTRRescorer extends Rescorer { } return rankingWeight.explain(context, deBasedDoc); } - - public static LTRScoringQuery.FeatureInfo[] extractFeaturesInfo( - LTRScoringQuery.ModelWeight modelWeight, - int docid, - Float originalDocScore, - List<LeafReaderContext> leafContexts) - throws IOException { - final int n = ReaderUtil.subIndex(docid, leafContexts); - final LeafReaderContext atomicContext = leafContexts.get(n); - final int deBasedDoc = docid - atomicContext.docBase; - final LTRScoringQuery.ModelWeight.ModelScorer r = modelWeight.modelScorer(atomicContext); - if ((r == null) || (r.iterator().advance(deBasedDoc) != deBasedDoc)) { - return new LTRScoringQuery.FeatureInfo[0]; - } else { - if (originalDocScore != null) { - // If results have not been reranked, the score passed in is the original query's - // score, which some features can use instead of recalculating it - r.getDocInfo().setOriginalDocScore(originalDocScore); - } - r.score(); - return modelWeight.getFeaturesInfo(); - } - } } diff --git a/solr/modules/ltr/src/java/org/apache/solr/ltr/LTRScoringQuery.java b/solr/modules/ltr/src/java/org/apache/solr/ltr/LTRScoringQuery.java index 85b33fc3ebd..22a9a8e5dfb 100644 --- a/solr/modules/ltr/src/java/org/apache/solr/ltr/LTRScoringQuery.java +++ b/solr/modules/ltr/src/java/org/apache/solr/ltr/LTRScoringQuery.java @@ -30,8 +30,6 @@ import java.util.concurrent.FutureTask; import java.util.concurrent.RunnableFuture; import java.util.concurrent.Semaphore; import org.apache.lucene.index.LeafReaderContext; -import org.apache.lucene.search.DisiPriorityQueue; -import org.apache.lucene.search.DisiWrapper; import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.search.Explanation; import org.apache.lucene.search.IndexSearcher; @@ -45,6 +43,9 
@@ import org.apache.lucene.util.Accountable; import org.apache.lucene.util.RamUsageEstimator; import org.apache.solr.ltr.feature.Feature; import org.apache.solr.ltr.model.LTRScoringModel; +import org.apache.solr.ltr.scoring.FeatureTraversalScorer; +import org.apache.solr.ltr.scoring.MultiFeaturesScorer; +import org.apache.solr.ltr.scoring.SingleFeatureScorer; import org.apache.solr.request.SolrQueryRequest; import org.apache.solr.util.SolrDefaultScorerSupplier; import org.slf4j.Logger; @@ -70,12 +71,14 @@ public class LTRScoringQuery extends Query implements Accountable { private FeatureLogger logger; // Map of external parameters, such as query intent, that can be used by // features - private final Map<String, String[]> efi; + private Map<String, String[]> efi; // Original solr query used to fetch matching documents private Query originalQuery; // Original solr request private SolrQueryRequest request; + private Feature.FeatureWeight[] extractedFeatureWeights; + public LTRScoringQuery(LTRScoringModel ltrScoringModel) { this(ltrScoringModel, Collections.<String, String[]>emptyMap(), null); } @@ -122,6 +125,10 @@ public class LTRScoringQuery extends Query implements Accountable { return efi; } + public void setExternalFeatureInfo(Map<String, String[]> efi) { + this.efi = efi; + } + public void setRequest(SolrQueryRequest request) { this.request = request; } @@ -130,6 +137,10 @@ public class LTRScoringQuery extends Query implements Accountable { return request; } + public Feature.FeatureWeight[] getExtractedFeatureWeights() { + return extractedFeatureWeights; + } + @Override public int hashCode() { final int prime = 31; @@ -207,8 +218,7 @@ public class LTRScoringQuery extends Query implements Accountable { } else { features = modelFeatures; } - final Feature.FeatureWeight[] extractedFeatureWeights = - new Feature.FeatureWeight[features.size()]; + this.extractedFeatureWeights = new Feature.FeatureWeight[features.size()]; final Feature.FeatureWeight[] 
modelFeaturesWeights = new Feature.FeatureWeight[modelFeatSize]; List<Feature.FeatureWeight> featureWeights = new ArrayList<>(features.size()); @@ -232,7 +242,7 @@ public class LTRScoringQuery extends Query implements Accountable { modelFeaturesWeights[j++] = fw; } } - return new ModelWeight(modelFeaturesWeights, extractedFeatureWeights, allFeatures.size()); + return new ModelWeight(modelFeaturesWeights, allFeatures.size()); } private void createWeights( @@ -367,20 +377,9 @@ public class LTRScoringQuery extends Query implements Accountable { // features used for logging. private final Feature.FeatureWeight[] modelFeatureWeights; private final float[] modelFeatureValuesNormalized; - private final Feature.FeatureWeight[] extractedFeatureWeights; - // List of all the feature names, values - used for both scoring and logging - /* - * What is the advantage of using a hashmap here instead of an array of objects? - * A set of arrays was used earlier and the elements were accessed using the featureId. - * With the updated logic to create weights selectively, - * the number of elements in the array can be fewer than the total number of features. - * When [features] are not requested, only the model features are extracted. - * In this case, the indexing by featureId, fails. For this reason, - * we need a map which holds just the features that were triggered by the documents in the result set. 
- * - */ - private final FeatureInfo[] featuresInfo; + // Array of all the features in the feature store of reference + private final FeatureInfo[] allFeaturesInStore; /* * @param modelFeatureWeights @@ -392,29 +391,25 @@ public class LTRScoringQuery extends Query implements Accountable { * @param allFeaturesSize * - total number of feature in the feature store used by this model */ - public ModelWeight( - Feature.FeatureWeight[] modelFeatureWeights, - Feature.FeatureWeight[] extractedFeatureWeights, - int allFeaturesSize) { + public ModelWeight(Feature.FeatureWeight[] modelFeatureWeights, int allFeaturesSize) { super(LTRScoringQuery.this); - this.extractedFeatureWeights = extractedFeatureWeights; this.modelFeatureWeights = modelFeatureWeights; this.modelFeatureValuesNormalized = new float[modelFeatureWeights.length]; - this.featuresInfo = new FeatureInfo[allFeaturesSize]; - setFeaturesInfo(); + this.allFeaturesInStore = new FeatureInfo[allFeaturesSize]; + setFeaturesInfo(extractedFeatureWeights); } - private void setFeaturesInfo() { + private void setFeaturesInfo(Feature.FeatureWeight[] extractedFeatureWeights) { for (int i = 0; i < extractedFeatureWeights.length; ++i) { String featName = extractedFeatureWeights[i].getName(); int featId = extractedFeatureWeights[i].getIndex(); float value = extractedFeatureWeights[i].getDefaultValue(); - featuresInfo[featId] = new FeatureInfo(featName, value, true); + allFeaturesInStore[featId] = new FeatureInfo(featName, value, true); } } - public FeatureInfo[] getFeaturesInfo() { - return featuresInfo; + public FeatureInfo[] getAllFeaturesInStore() { + return allFeaturesInStore; } // for test use @@ -423,35 +418,29 @@ public class LTRScoringQuery extends Query implements Accountable { } // for test use - float[] getModelFeatureValuesNormalized() { + public float[] getModelFeatureValuesNormalized() { return modelFeatureValuesNormalized; } - // for test use - Feature.FeatureWeight[] getExtractedFeatureWeights() { - return 
extractedFeatureWeights; - } - /** * Goes through all the stored feature values, and calculates the normalized values for all the * features that will be used for scoring. Then calculate and return the model's score. */ - private float makeNormalizedFeaturesAndScore() { + public void normalizeFeatures() { int pos = 0; for (final Feature.FeatureWeight feature : modelFeatureWeights) { final int featureId = feature.getIndex(); - FeatureInfo fInfo = featuresInfo[featureId]; + FeatureInfo fInfo = allFeaturesInStore[featureId]; modelFeatureValuesNormalized[pos] = fInfo.getValue(); pos++; } ltrScoringModel.normalizeFeaturesInPlace(modelFeatureValuesNormalized); - return ltrScoringModel.score(modelFeatureValuesNormalized); } @Override public Explanation explain(LeafReaderContext context, int doc) throws IOException { - final Explanation[] explanations = new Explanation[this.featuresInfo.length]; + final Explanation[] explanations = new Explanation[this.allFeaturesInStore.length]; for (final Feature.FeatureWeight feature : extractedFeatureWeights) { explanations[feature.getIndex()] = feature.explain(context, doc); } @@ -469,17 +458,6 @@ public class LTRScoringQuery extends Query implements Accountable { return ltrScoringModel.explain(context, doc, finalScore, featureExplanations); } - protected void reset() { - for (int i = 0; i < extractedFeatureWeights.length; ++i) { - int featId = extractedFeatureWeights[i].getIndex(); - float value = extractedFeatureWeights[i].getDefaultValue(); - // need to set default value everytime as the default value is used in 'dense' - // mode even if used=false - featuresInfo[featId].setValue(value); - featuresInfo[featId].setIsDefaultValue(true); - } - } - @Override public ScorerSupplier scorerSupplier(LeafReaderContext context) throws IOException { return new SolrDefaultScorerSupplier(modelScorer(context)); @@ -500,7 +478,7 @@ public class LTRScoringQuery extends Query implements Accountable { // score on the model for every document, since 0 
features matching could // return a // non 0 score for a given model. - ModelScorer mscorer = new ModelScorer(this, featureScorers); + ModelScorer mscorer = new ModelScorer(featureScorers); return mscorer; } @@ -511,22 +489,40 @@ public class LTRScoringQuery extends Query implements Accountable { public class ModelScorer extends Scorer { private final DocInfo docInfo; - private final Scorer featureTraversalScorer; + private final FeatureTraversalScorer featureTraversalScorer; public DocInfo getDocInfo() { return docInfo; } - public ModelScorer(Weight weight, List<Feature.FeatureWeight.FeatureScorer> featureScorers) { + public ModelScorer(List<Feature.FeatureWeight.FeatureScorer> featureScorers) { docInfo = new DocInfo(); for (final Feature.FeatureWeight.FeatureScorer subScorer : featureScorers) { subScorer.setDocInfo(docInfo); } if (featureScorers.size() <= 1) { // future enhancement: allow the use of dense features in other cases - featureTraversalScorer = new DenseModelScorer(weight, featureScorers); + featureTraversalScorer = + new SingleFeatureScorer( + ModelWeight.this, + request, + extractedFeatureWeights, + allFeaturesInStore, + ltrScoringModel, + efi, + featureScorers, + docInfo); } else { - featureTraversalScorer = new SparseModelScorer(weight, featureScorers); + featureTraversalScorer = + new MultiFeaturesScorer( + ModelWeight.this, + request, + extractedFeatureWeights, + allFeaturesInStore, + ltrScoringModel, + efi, + featureScorers, + docInfo); } } @@ -555,273 +551,8 @@ public class LTRScoringQuery extends Query implements Accountable { return featureTraversalScorer.iterator(); } - private class SparseModelScorer extends Scorer { - private final DisiPriorityQueue subScorers; - private final List<DisiWrapper> wrappers; - private final ScoringQuerySparseIterator itr; - - private int targetDoc = -1; - private int activeDoc = -1; - - private SparseModelScorer( - Weight unusedWeight, List<Feature.FeatureWeight.FeatureScorer> featureScorers) { - if 
(featureScorers.size() <= 1) { - throw new IllegalArgumentException("There must be at least 2 subScorers"); - } - subScorers = DisiPriorityQueue.ofMaxSize(featureScorers.size()); - wrappers = new ArrayList<>(); - for (final Scorer scorer : featureScorers) { - final DisiWrapper w = new DisiWrapper(scorer, false /* impacts */); - subScorers.add(w); - wrappers.add(w); - } - - itr = new ScoringQuerySparseIterator(wrappers); - } - - @Override - public int docID() { - return itr.docID(); - } - - @Override - public float score() throws IOException { - final DisiWrapper topList = subScorers.topList(); - // If target doc we wanted to advance to match the actual doc - // the underlying features advanced to, perform the feature - // calculations, - // otherwise just continue with the model's scoring process with empty - // features. - reset(); - if (activeDoc == targetDoc) { - for (DisiWrapper w = topList; w != null; w = w.next) { - final Feature.FeatureWeight.FeatureScorer subScorer = - (Feature.FeatureWeight.FeatureScorer) w.scorer; - Feature.FeatureWeight scFW = subScorer.getWeight(); - final int featureId = scFW.getIndex(); - featuresInfo[featureId].setValue(subScorer.score()); - if (featuresInfo[featureId].getValue() != scFW.getDefaultValue()) { - featuresInfo[featureId].setIsDefaultValue(false); - } - } - } - return makeNormalizedFeaturesAndScore(); - } - - @Override - public float getMaxScore(int upTo) throws IOException { - return Float.POSITIVE_INFINITY; - } - - @Override - public DocIdSetIterator iterator() { - return itr; - } - - @Override - public final Collection<ChildScorable> getChildren() { - final ArrayList<ChildScorable> children = new ArrayList<>(); - for (final DisiWrapper scorer : subScorers) { - children.add(new ChildScorable(scorer.scorer, "SHOULD")); - } - return children; - } - - private class ScoringQuerySparseIterator extends DocIdSetIterator { - - public ScoringQuerySparseIterator(Collection<DisiWrapper> wrappers) { - // Initialize all wrappers to 
start at -1 - for (DisiWrapper wrapper : wrappers) { - wrapper.doc = -1; - } - } - - @Override - public int docID() { - // Return the target document ID (mimicking DisjunctionDISIApproximation behavior) - return targetDoc; - } - - @Override - public final int nextDoc() throws IOException { - // Mimic DisjunctionDISIApproximation behavior - if (targetDoc == -1) { - // First call - initialize all iterators - DisiWrapper top = subScorers.top(); - if (top != null && top.doc == -1) { - // Need to advance all iterators to their first document - DisiWrapper current = subScorers.top(); - while (current != null) { - current.doc = current.iterator.nextDoc(); - current = subScorers.updateTop(); - } - top = subScorers.top(); - activeDoc = top == null ? NO_MORE_DOCS : top.doc; - } - targetDoc = activeDoc; - return targetDoc; - } - - if (activeDoc == targetDoc) { - // Advance the underlying disjunction - DisiWrapper top = subScorers.top(); - if (top == null) { - activeDoc = NO_MORE_DOCS; - } else { - // Advance the top iterator and rebalance the queue - top.doc = top.iterator.nextDoc(); - top = subScorers.updateTop(); - activeDoc = top == null ? 
NO_MORE_DOCS : top.doc; - } - } else if (activeDoc < targetDoc) { - // Need to catch up to targetDoc + 1 - activeDoc = advanceInternal(targetDoc + 1); - } - return ++targetDoc; - } - - @Override - public final int advance(int target) throws IOException { - // Mimic DisjunctionDISIApproximation behavior - if (activeDoc < target) { - activeDoc = advanceInternal(target); - } - targetDoc = target; - return targetDoc; - } - - private int advanceInternal(int target) throws IOException { - // Advance the underlying disjunction to the target - DisiWrapper top; - do { - top = subScorers.top(); - if (top == null) { - return NO_MORE_DOCS; - } - if (top.doc >= target) { - return top.doc; - } - top.doc = top.iterator.advance(target); - top = subScorers.updateTop(); - if (top == null) { - return NO_MORE_DOCS; - } - } while (top.doc < target); - return top.doc; - } - - @Override - public long cost() { - // Calculate cost from all wrappers - long cost = 0; - for (DisiWrapper wrapper : wrappers) { - cost += wrapper.iterator.cost(); - } - return cost; - } - } - } - - private class DenseModelScorer extends Scorer { - private int activeDoc = -1; // The doc that our scorer's are actually at - private int targetDoc = -1; // The doc we were most recently told to go to - private int freq = -1; - private final List<Feature.FeatureWeight.FeatureScorer> featureScorers; - - private DenseModelScorer( - Weight unusedWeight, List<Feature.FeatureWeight.FeatureScorer> featureScorers) { - this.featureScorers = featureScorers; - } - - @Override - public int docID() { - return targetDoc; - } - - @Override - public float score() throws IOException { - reset(); - freq = 0; - if (targetDoc == activeDoc) { - for (final Scorer scorer : featureScorers) { - if (scorer.docID() == activeDoc) { - freq++; - Feature.FeatureWeight.FeatureScorer featureScorer = - (Feature.FeatureWeight.FeatureScorer) scorer; - Feature.FeatureWeight scFW = featureScorer.getWeight(); - final int featureId = scFW.getIndex(); - 
featuresInfo[featureId].setValue(scorer.score()); - if (featuresInfo[featureId].getValue() != scFW.getDefaultValue()) { - featuresInfo[featureId].setIsDefaultValue(false); - } - } - } - } - return makeNormalizedFeaturesAndScore(); - } - - @Override - public float getMaxScore(int upTo) throws IOException { - return Float.POSITIVE_INFINITY; - } - - @Override - public final Collection<ChildScorable> getChildren() { - final ArrayList<ChildScorable> children = new ArrayList<>(); - for (final Scorer scorer : featureScorers) { - children.add(new ChildScorable(scorer, "SHOULD")); - } - return children; - } - - @Override - public DocIdSetIterator iterator() { - return new DenseIterator(); - } - - private class DenseIterator extends DocIdSetIterator { - - @Override - public int docID() { - return targetDoc; - } - - @Override - public int nextDoc() throws IOException { - if (activeDoc <= targetDoc) { - activeDoc = NO_MORE_DOCS; - for (final Scorer scorer : featureScorers) { - if (scorer.docID() != NO_MORE_DOCS) { - activeDoc = Math.min(activeDoc, scorer.iterator().nextDoc()); - } - } - } - return ++targetDoc; - } - - @Override - public int advance(int target) throws IOException { - if (activeDoc < target) { - activeDoc = NO_MORE_DOCS; - for (final Scorer scorer : featureScorers) { - if (scorer.docID() != NO_MORE_DOCS) { - activeDoc = Math.min(activeDoc, scorer.iterator().advance(target)); - } - } - } - targetDoc = target; - return target; - } - - @Override - public long cost() { - long sum = 0; - for (int i = 0; i < featureScorers.size(); i++) { - sum += featureScorers.get(i).iterator().cost(); - } - return sum; - } - } + public void fillFeaturesInfo() throws IOException { + featureTraversalScorer.fillFeaturesInfo(); } } } diff --git a/solr/modules/ltr/src/java/org/apache/solr/ltr/feature/extraction/FeatureExtractor.java b/solr/modules/ltr/src/java/org/apache/solr/ltr/feature/extraction/FeatureExtractor.java new file mode 100644 index 00000000000..03911b8bbc4 --- /dev/null 
+++ b/solr/modules/ltr/src/java/org/apache/solr/ltr/feature/extraction/FeatureExtractor.java @@ -0,0 +1,129 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.solr.ltr.feature.extraction; + +import java.io.IOException; +import java.util.Arrays; +import java.util.Map; +import java.util.Objects; +import java.util.TreeMap; +import org.apache.solr.ltr.FeatureLogger; +import org.apache.solr.ltr.LTRScoringQuery; +import org.apache.solr.ltr.feature.Feature; +import org.apache.solr.ltr.model.LTRScoringModel; +import org.apache.solr.ltr.response.transform.LTRFeatureLoggerTransformerFactory; +import org.apache.solr.ltr.scoring.FeatureTraversalScorer; +import org.apache.solr.request.SolrQueryRequest; +import org.apache.solr.search.SolrCache; + +/** The class used to extract features for LTR feature logging. 
*/ +public abstract class FeatureExtractor { + protected final FeatureTraversalScorer traversalScorer; + SolrQueryRequest request; + Feature.FeatureWeight[] extractedFeatureWeights; + LTRScoringQuery.FeatureInfo[] allFeaturesInStore; + LTRScoringModel ltrScoringModel; + FeatureLogger logger; + Map<String, String[]> efi; + + FeatureExtractor( + FeatureTraversalScorer traversalScorer, + SolrQueryRequest request, + Feature.FeatureWeight[] extractedFeatureWeights, + LTRScoringQuery.FeatureInfo[] allFeaturesInStore, + LTRScoringModel ltrScoringModel, + Map<String, String[]> efi) { + this.traversalScorer = traversalScorer; + this.request = request; + this.extractedFeatureWeights = extractedFeatureWeights; + this.allFeaturesInStore = allFeaturesInStore; + this.ltrScoringModel = ltrScoringModel; + this.efi = efi; + } + + protected float[] initFeatureVector(LTRScoringQuery.FeatureInfo[] featuresInfos) { + float[] featureVector = new float[featuresInfos.length]; + for (int i = 0; i < featuresInfos.length; i++) { + if (featuresInfos[i] != null) { + featureVector[i] = featuresInfos[i].getValue(); + } + } + return featureVector; + } + + protected abstract float[] extractFeatureVector() throws IOException; + + public void fillFeaturesInfo() throws IOException { + if (traversalScorer.getActiveDoc() == traversalScorer.getTargetDoc()) { + SolrCache<Integer, float[]> featureVectorCache = null; + float[] featureVector; + + if (request != null) { + featureVectorCache = request.getSearcher().getFeatureVectorCache(); + } + if (featureVectorCache != null) { + int fvCacheKey = + computeFeatureVectorCacheKey(traversalScorer.getDocInfo().getOriginalDocId()); + featureVector = featureVectorCache.get(fvCacheKey); + if (featureVector == null) { + featureVector = extractFeatureVector(); + featureVectorCache.put(fvCacheKey, featureVector); + } + } else { + featureVector = extractFeatureVector(); + } + + for (int i = 0; i < extractedFeatureWeights.length; i++) { + int featureId = 
extractedFeatureWeights[i].getIndex(); + float featureValue = featureVector[featureId]; + if (!Float.isNaN(featureValue) + && featureValue != extractedFeatureWeights[i].getDefaultValue()) { + allFeaturesInStore[featureId].setValue(featureValue); + allFeaturesInStore[featureId].setIsDefaultValue(false); + } + } + } + } + + private int computeFeatureVectorCacheKey(int docId) { + int prime = 31; + int result = docId; + if (Objects.equals( + ltrScoringModel.getName(), + LTRFeatureLoggerTransformerFactory.DEFAULT_LOGGING_MODEL_NAME) + || (logger != null && logger.isLogFeatures() && logger.isLoggingAll())) { + result = (prime * result) + ltrScoringModel.getFeatureStoreName().hashCode(); + } else { + result = (prime * result) + ltrScoringModel.getName().hashCode(); + } + result = (prime * result) + addEfisHash(result, prime, efi); + return result; + } + + private int addEfisHash(int result, int prime, Map<String, String[]> efi) { + if (efi != null) { + TreeMap<String, String[]> sorted = new TreeMap<>(efi); + for (final Map.Entry<String, String[]> entry : sorted.entrySet()) { + final String key = entry.getKey(); + final String[] values = entry.getValue(); + result = (prime * result) + key.hashCode(); + result = (prime * result) + Arrays.hashCode(values); + } + } + return result; + } +} diff --git a/solr/modules/ltr/src/java/org/apache/solr/ltr/feature/extraction/MultiFeaturesExtractor.java b/solr/modules/ltr/src/java/org/apache/solr/ltr/feature/extraction/MultiFeaturesExtractor.java new file mode 100644 index 00000000000..1db6d161ede --- /dev/null +++ b/solr/modules/ltr/src/java/org/apache/solr/ltr/feature/extraction/MultiFeaturesExtractor.java @@ -0,0 +1,65 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.solr.ltr.feature.extraction; + +import java.io.IOException; +import java.util.Map; +import org.apache.lucene.search.DisiPriorityQueue; +import org.apache.lucene.search.DisiWrapper; +import org.apache.solr.ltr.LTRScoringQuery; +import org.apache.solr.ltr.feature.Feature; +import org.apache.solr.ltr.model.LTRScoringModel; +import org.apache.solr.ltr.scoring.FeatureTraversalScorer; +import org.apache.solr.request.SolrQueryRequest; + +/** The class used to extract more than one feature for LTR feature logging. 
*/ +public class MultiFeaturesExtractor extends FeatureExtractor { + DisiPriorityQueue subScorers; + + public MultiFeaturesExtractor( + FeatureTraversalScorer multiFeaturesScorer, + SolrQueryRequest request, + Feature.FeatureWeight[] extractedFeatureWeights, + LTRScoringQuery.FeatureInfo[] allFeaturesInStore, + LTRScoringModel ltrScoringModel, + Map<String, String[]> efi, + DisiPriorityQueue subScorers) { + super( + multiFeaturesScorer, + request, + extractedFeatureWeights, + allFeaturesInStore, + ltrScoringModel, + efi); + this.subScorers = subScorers; + } + + @Override + protected float[] extractFeatureVector() throws IOException { + final DisiWrapper topList = subScorers.topList(); + float[] featureVector = initFeatureVector(allFeaturesInStore); + for (DisiWrapper w = topList; w != null; w = w.next) { + final Feature.FeatureWeight.FeatureScorer subScorer = + (Feature.FeatureWeight.FeatureScorer) w.scorer; + Feature.FeatureWeight feature = subScorer.getWeight(); + final int featureId = feature.getIndex(); + float featureValue = subScorer.score(); + featureVector[featureId] = featureValue; + } + return featureVector; + } +} diff --git a/solr/modules/ltr/src/java/org/apache/solr/ltr/feature/extraction/SingleFeatureExtractor.java b/solr/modules/ltr/src/java/org/apache/solr/ltr/feature/extraction/SingleFeatureExtractor.java new file mode 100644 index 00000000000..5d3cf648afe --- /dev/null +++ b/solr/modules/ltr/src/java/org/apache/solr/ltr/feature/extraction/SingleFeatureExtractor.java @@ -0,0 +1,66 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.solr.ltr.feature.extraction; + +import java.io.IOException; +import java.util.List; +import java.util.Map; +import org.apache.lucene.search.Scorer; +import org.apache.solr.ltr.LTRScoringQuery; +import org.apache.solr.ltr.feature.Feature; +import org.apache.solr.ltr.model.LTRScoringModel; +import org.apache.solr.ltr.scoring.FeatureTraversalScorer; +import org.apache.solr.request.SolrQueryRequest; + +/** The class used to extract a single feature for LTR feature logging. */ +public class SingleFeatureExtractor extends FeatureExtractor { + List<Feature.FeatureWeight.FeatureScorer> featureScorers; + + public SingleFeatureExtractor( + FeatureTraversalScorer singleFeatureScorer, + SolrQueryRequest request, + Feature.FeatureWeight[] extractedFeatureWeights, + LTRScoringQuery.FeatureInfo[] allFeaturesInStore, + LTRScoringModel ltrScoringModel, + Map<String, String[]> efi, + List<Feature.FeatureWeight.FeatureScorer> featureScorers) { + super( + singleFeatureScorer, + request, + extractedFeatureWeights, + allFeaturesInStore, + ltrScoringModel, + efi); + this.featureScorers = featureScorers; + } + + @Override + protected float[] extractFeatureVector() throws IOException { + float[] featureVector = initFeatureVector(allFeaturesInStore); + for (final Scorer scorer : featureScorers) { + if (scorer.docID() == traversalScorer.getActiveDoc()) { + Feature.FeatureWeight.FeatureScorer featureScorer = + (Feature.FeatureWeight.FeatureScorer) scorer; + Feature.FeatureWeight scFW = featureScorer.getWeight(); + final int featureId = scFW.getIndex(); + 
float featureValue = scorer.score(); + featureVector[featureId] = featureValue; + } + } + return featureVector; + } +} diff --git a/solr/modules/ltr/src/java/org/apache/solr/ltr/DocInfo.java b/solr/modules/ltr/src/java/org/apache/solr/ltr/feature/extraction/package-info.java similarity index 59% copy from solr/modules/ltr/src/java/org/apache/solr/ltr/DocInfo.java copy to solr/modules/ltr/src/java/org/apache/solr/ltr/feature/extraction/package-info.java index e454d90acc2..844fc0dcfe2 100644 --- a/solr/modules/ltr/src/java/org/apache/solr/ltr/DocInfo.java +++ b/solr/modules/ltr/src/java/org/apache/solr/ltr/feature/extraction/package-info.java @@ -14,28 +14,6 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.apache.solr.ltr; -import java.util.HashMap; - -public class DocInfo extends HashMap<String, Object> { - - // Name of key used to store the original score of a doc - private static final String ORIGINAL_DOC_SCORE = "ORIGINAL_DOC_SCORE"; - - public DocInfo() { - super(); - } - - public void setOriginalDocScore(Float score) { - put(ORIGINAL_DOC_SCORE, score); - } - - public Float getOriginalDocScore() { - return (Float) get(ORIGINAL_DOC_SCORE); - } - - public boolean hasOriginalDocScore() { - return containsKey(ORIGINAL_DOC_SCORE); - } -} +/** Contains the logic to extract features. 
*/ +package org.apache.solr.ltr.feature.extraction; diff --git a/solr/modules/ltr/src/java/org/apache/solr/ltr/interleaving/LTRInterleavingRescorer.java b/solr/modules/ltr/src/java/org/apache/solr/ltr/interleaving/LTRInterleavingRescorer.java index 78803afd933..8d1227056a3 100644 --- a/solr/modules/ltr/src/java/org/apache/solr/ltr/interleaving/LTRInterleavingRescorer.java +++ b/solr/modules/ltr/src/java/org/apache/solr/ltr/interleaving/LTRInterleavingRescorer.java @@ -57,12 +57,12 @@ public class LTRInterleavingRescorer extends LTRRescorer { * * @param searcher current IndexSearcher * @param firstPassTopDocs documents to rerank; - * @param topN documents to return; + * @param docsToRerank documents to return; */ @Override - public TopDocs rescore(IndexSearcher searcher, TopDocs firstPassTopDocs, int topN) + public TopDocs rescore(IndexSearcher searcher, TopDocs firstPassTopDocs, int docsToRerank) throws IOException { - if ((topN == 0) || (firstPassTopDocs.scoreDocs.length == 0)) { + if ((docsToRerank == 0) || (firstPassTopDocs.scoreDocs.length == 0)) { return firstPassTopDocs; } @@ -72,10 +72,10 @@ public class LTRInterleavingRescorer extends LTRRescorer { System.arraycopy( firstPassTopDocs.scoreDocs, 0, firstPassResults, 0, firstPassTopDocs.scoreDocs.length); } - topN = Math.toIntExact(Math.min(topN, firstPassTopDocs.totalHits.value())); + docsToRerank = Math.toIntExact(Math.min(docsToRerank, firstPassTopDocs.totalHits.value())); ScoreDoc[][] reRankedPerModel = - rerank(searcher, topN, getFirstPassDocsRanked(firstPassTopDocs)); + rerank(searcher, docsToRerank, getFirstPassDocsRanked(firstPassTopDocs)); if (originalRankingIndex != null) { reRankedPerModel[originalRankingIndex] = firstPassResults; } @@ -90,9 +90,9 @@ public class LTRInterleavingRescorer extends LTRRescorer { return new TopDocs(firstPassTopDocs.totalHits, interleavedResults); } - private ScoreDoc[][] rerank(IndexSearcher searcher, int topN, ScoreDoc[] firstPassResults) + private ScoreDoc[][] 
rerank(IndexSearcher searcher, int docsToRerank, ScoreDoc[] firstPassResults) throws IOException { - ScoreDoc[][] reRankedPerModel = new ScoreDoc[rerankingQueries.length][topN]; + ScoreDoc[][] reRankedPerModel = new ScoreDoc[rerankingQueries.length][docsToRerank]; final List<LeafReaderContext> leaves = searcher.getIndexReader().leaves(); LTRScoringQuery.ModelWeight[] modelWeights = new LTRScoringQuery.ModelWeight[rerankingQueries.length]; @@ -103,7 +103,7 @@ public class LTRInterleavingRescorer extends LTRRescorer { searcher.createWeight(searcher.rewrite(rerankingQueries[i]), ScoreMode.COMPLETE, 1); } } - scoreFeatures(searcher, topN, modelWeights, firstPassResults, leaves, reRankedPerModel); + scoreFeatures(docsToRerank, modelWeights, firstPassResults, leaves, reRankedPerModel); for (int i = 0; i < rerankingQueries.length; i++) { if (originalRankingIndex == null || originalRankingIndex != i) { @@ -115,8 +115,7 @@ public class LTRInterleavingRescorer extends LTRRescorer { } public void scoreFeatures( - IndexSearcher indexSearcher, - int topN, + int docsToRerank, LTRScoringQuery.ModelWeight[] modelWeights, ScoreDoc[] hits, List<LeafReaderContext> leaves, @@ -126,14 +125,13 @@ public class LTRInterleavingRescorer extends LTRRescorer { int readerUpto = -1; int endDoc = 0; int docBase = 0; - int hitUpto = 0; + int hitPosition = 0; LTRScoringQuery.ModelWeight.ModelScorer[] scorers = new LTRScoringQuery.ModelWeight.ModelScorer[rerankingQueries.length]; - while (hitUpto < hits.length) { - final ScoreDoc hit = hits[hitUpto]; - final int docID = hit.doc; + while (hitPosition < hits.length) { + final ScoreDoc hit = hits[hitPosition]; LeafReaderContext readerContext = null; - while (docID >= endDoc) { + while (hit.doc >= endDoc) { readerUpto++; readerContext = leaves.get(readerUpto); endDoc = readerContext.docBase + readerContext.reader().maxDoc(); @@ -151,13 +149,11 @@ public class LTRInterleavingRescorer extends LTRRescorer { for (int i = 0; i < rerankingQueries.length; 
i++) { if (modelWeights[i] != null) { final ScoreDoc hit_i = new ScoreDoc(hit.doc, hit.score, hit.shardIndex); - if (scoreSingleHit( - topN, docBase, hitUpto, hit_i, docID, scorers[i], rerankedPerModel[i])) { - logSingleHit(indexSearcher, modelWeights[i], hit_i.doc, rerankingQueries[i]); - } + scoreSingleHit( + docsToRerank, docBase, hitPosition, hit_i, scorers[i], rerankedPerModel[i]); } } - hitUpto++; + hitPosition++; } } diff --git a/solr/modules/ltr/src/java/org/apache/solr/ltr/response/transform/LTRFeatureLoggerTransformerFactory.java b/solr/modules/ltr/src/java/org/apache/solr/ltr/response/transform/LTRFeatureLoggerTransformerFactory.java index 21ead475609..a85597bde08 100644 --- a/solr/modules/ltr/src/java/org/apache/solr/ltr/response/transform/LTRFeatureLoggerTransformerFactory.java +++ b/solr/modules/ltr/src/java/org/apache/solr/ltr/response/transform/LTRFeatureLoggerTransformerFactory.java @@ -22,6 +22,7 @@ import java.util.List; import java.util.Locale; import java.util.Map; import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.index.ReaderUtil; import org.apache.lucene.search.Explanation; import org.apache.lucene.search.ScoreMode; import org.apache.solr.common.SolrDocument; @@ -30,7 +31,6 @@ import org.apache.solr.common.params.SolrParams; import org.apache.solr.common.util.NamedList; import org.apache.solr.ltr.CSVFeatureLogger; import org.apache.solr.ltr.FeatureLogger; -import org.apache.solr.ltr.LTRRescorer; import org.apache.solr.ltr.LTRScoringQuery; import org.apache.solr.ltr.LTRThreadModule; import org.apache.solr.ltr.SolrQueryRequestContextUtils; @@ -75,11 +75,10 @@ public class LTRFeatureLoggerTransformerFactory extends TransformerFactory { // used inside fl to specify to log (all|model only) features private static final String FV_LOG_ALL = "logAll"; - private static final String DEFAULT_LOGGING_MODEL_NAME = "logging-model"; + public static final String DEFAULT_LOGGING_MODEL_NAME = "logging-model"; private static final 
boolean DEFAULT_NO_RERANKING_LOGGING_ALL = true; - private String fvCacheName; private String loggingModelName = DEFAULT_LOGGING_MODEL_NAME; private String defaultStore; private FeatureLogger.FeatureFormat defaultFormat = FeatureLogger.FeatureFormat.DENSE; @@ -88,10 +87,6 @@ public class LTRFeatureLoggerTransformerFactory extends TransformerFactory { private LTRThreadModule threadManager = null; - public void setFvCacheName(String fvCacheName) { - this.fvCacheName = fvCacheName; - } - public void setLoggingModelName(String loggingModelName) { this.loggingModelName = loggingModelName; } @@ -161,11 +156,7 @@ public class LTRFeatureLoggerTransformerFactory extends TransformerFactory { } else { format = this.defaultFormat; } - if (fvCacheName == null) { - throw new IllegalArgumentException("a fvCacheName must be configured"); - } - return new CSVFeatureLogger( - fvCacheName, format, logAll, csvKeyValueDelimiter, csvFeatureSeparator); + return new CSVFeatureLogger(format, logAll, csvKeyValueDelimiter, csvFeatureSeparator); } class FeatureTransformer extends DocTransformer { @@ -373,6 +364,12 @@ public class LTRFeatureLoggerTransformerFactory extends TransformerFactory { : rerankingQueries[i].getExternalFeatureInfo()), threadManager); } + } else { + for (int i = 0; i < rerankingQueries.length; i++) { + if (!transformerExternalFeatureInfo.isEmpty()) { + rerankingQueries[i].setExternalFeatureInfo(transformerExternalFeatureInfo); + } + } } } } @@ -410,6 +407,32 @@ public class LTRFeatureLoggerTransformerFactory extends TransformerFactory { implTransform(doc, docid, docInfo); } + private static LTRScoringQuery.FeatureInfo[] extractFeatures( + FeatureLogger logger, + LTRScoringQuery.ModelWeight modelWeight, + int docid, + Float originalDocScore, + List<LeafReaderContext> leafContexts) + throws IOException { + final int n = ReaderUtil.subIndex(docid, leafContexts); + final LeafReaderContext atomicContext = leafContexts.get(n); + final int deBasedDoc = docid - 
atomicContext.docBase; + final LTRScoringQuery.ModelWeight.ModelScorer r = modelWeight.modelScorer(atomicContext); + r.getDocInfo().setOriginalDocId(docid); + if ((r == null) || (r.iterator().advance(deBasedDoc) != deBasedDoc)) { + return new LTRScoringQuery.FeatureInfo[0]; + } else { + if (originalDocScore != null) { + // If results have not been reranked, the score passed in is the original query's + // score, which some features can use instead of recalculating it + r.getDocInfo().setOriginalDocScore(originalDocScore); + } + r.fillFeaturesInfo(); + logger.setLogFeatures(true); + return modelWeight.getAllFeaturesInStore(); + } + } + private void implTransform(SolrDocument doc, int docid, DocIterationInfo docInfo) throws IOException { LTRScoringQuery rerankingQuery = rerankingQueries[0]; @@ -423,16 +446,14 @@ public class LTRFeatureLoggerTransformerFactory extends TransformerFactory { } } if (!(rerankingQuery instanceof OriginalRankingLTRScoringQuery) || hasExplicitFeatureStore) { - Object featureVector = featureLogger.getFeatureVector(docid, rerankingQuery, searcher); - if (featureVector == null) { // FV for this document was not in the cache - featureVector = - featureLogger.makeFeatureVector( - LTRRescorer.extractFeaturesInfo( - rerankingModelWeight, - docid, - (!docsWereReranked && docsHaveScores) ? docInfo.score() : null, - leafContexts)); - } + LTRScoringQuery.FeatureInfo[] featuresInfo = + extractFeatures( + featureLogger, + rerankingModelWeight, + docid, + (!docsWereReranked && docsHaveScores) ? 
docInfo.score() : null, + leafContexts); + String featureVector = featureLogger.printFeatureVector(featuresInfo); doc.addField(name, featureVector); } } diff --git a/solr/modules/ltr/src/java/org/apache/solr/ltr/scoring/FeatureTraversalScorer.java b/solr/modules/ltr/src/java/org/apache/solr/ltr/scoring/FeatureTraversalScorer.java new file mode 100644 index 00000000000..2c92ff9e15e --- /dev/null +++ b/solr/modules/ltr/src/java/org/apache/solr/ltr/scoring/FeatureTraversalScorer.java @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.solr.ltr.scoring; + +import java.io.IOException; +import org.apache.lucene.search.Scorer; +import org.apache.solr.ltr.DocInfo; +import org.apache.solr.ltr.LTRScoringQuery; +import org.apache.solr.ltr.feature.Feature; +import org.apache.solr.ltr.feature.extraction.FeatureExtractor; +import org.apache.solr.ltr.model.LTRScoringModel; + +/** This class is responsible for extracting features and using them to score the document. 
*/ +public abstract class FeatureTraversalScorer extends Scorer { + protected FeatureExtractor featureExtractor; + protected LTRScoringQuery.FeatureInfo[] allFeaturesInStore; + protected LTRScoringModel ltrScoringModel; + protected Feature.FeatureWeight[] extractedFeatureWeights; + protected LTRScoringQuery.ModelWeight modelWeight; + + public abstract int getActiveDoc(); + + public abstract int getTargetDoc(); + + public abstract DocInfo getDocInfo(); + + public void reset() { + for (int i = 0; i < extractedFeatureWeights.length; ++i) { + int featId = extractedFeatureWeights[i].getIndex(); + float value = extractedFeatureWeights[i].getDefaultValue(); + // need to set default value everytime as the default value is used in 'dense' + // mode even if used=false + allFeaturesInStore[featId].setValue(value); + allFeaturesInStore[featId].setIsDefaultValue(true); + } + } + + public void fillFeaturesInfo() throws IOException { + // Initialize features to their default values and set isDefaultValue to true. + reset(); + featureExtractor.fillFeaturesInfo(); + } + + @Override + public float score() throws IOException { + // Initialize features to their default values and set isDefaultValue to true. + reset(); + featureExtractor.fillFeaturesInfo(); + modelWeight.normalizeFeatures(); + return ltrScoringModel.score(modelWeight.getModelFeatureValuesNormalized()); + } + + @Override + public float getMaxScore(int upTo) { + return Float.POSITIVE_INFINITY; + } +} diff --git a/solr/modules/ltr/src/java/org/apache/solr/ltr/scoring/MultiFeaturesScorer.java b/solr/modules/ltr/src/java/org/apache/solr/ltr/scoring/MultiFeaturesScorer.java new file mode 100644 index 00000000000..2ac44c19be8 --- /dev/null +++ b/solr/modules/ltr/src/java/org/apache/solr/ltr/scoring/MultiFeaturesScorer.java @@ -0,0 +1,211 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.solr.ltr.scoring; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Collection; +import java.util.List; +import java.util.Map; +import org.apache.lucene.search.DisiPriorityQueue; +import org.apache.lucene.search.DisiWrapper; +import org.apache.lucene.search.DocIdSetIterator; +import org.apache.lucene.search.Scorer; +import org.apache.solr.ltr.DocInfo; +import org.apache.solr.ltr.LTRScoringQuery; +import org.apache.solr.ltr.feature.Feature; +import org.apache.solr.ltr.feature.extraction.MultiFeaturesExtractor; +import org.apache.solr.ltr.model.LTRScoringModel; +import org.apache.solr.request.SolrQueryRequest; + +/** + * This class is responsible for extracting more than one feature and using them to score the + * document. 
+ */ +public class MultiFeaturesScorer extends FeatureTraversalScorer { + private int targetDoc = -1; + private int activeDoc = -1; + protected DocInfo docInfo; + private final DisiPriorityQueue subScorers; + private final List<DisiWrapper> wrappers; + private final MultiFeaturesIterator multiFeaturesIteratorIterator; + + public MultiFeaturesScorer( + LTRScoringQuery.ModelWeight modelWeight, + SolrQueryRequest request, + Feature.FeatureWeight[] extractedFeatureWeights, + LTRScoringQuery.FeatureInfo[] allFeaturesInStore, + LTRScoringModel ltrScoringModel, + Map<String, String[]> efi, + List<Feature.FeatureWeight.FeatureScorer> featureScorers, + DocInfo docInfo) { + if (featureScorers.size() <= 1) { + throw new IllegalArgumentException("There must be at least 2 subScorers"); + } + subScorers = DisiPriorityQueue.ofMaxSize(featureScorers.size()); + wrappers = new ArrayList<>(); + for (final Scorer scorer : featureScorers) { + final DisiWrapper w = new DisiWrapper(scorer, false /* impacts */); + subScorers.add(w); + wrappers.add(w); + } + + multiFeaturesIteratorIterator = new MultiFeaturesIterator(wrappers); + this.extractedFeatureWeights = extractedFeatureWeights; + this.allFeaturesInStore = allFeaturesInStore; + this.ltrScoringModel = ltrScoringModel; + this.featureExtractor = + new MultiFeaturesExtractor( + this, + request, + extractedFeatureWeights, + allFeaturesInStore, + ltrScoringModel, + efi, + subScorers); + this.modelWeight = modelWeight; + this.docInfo = docInfo; + } + + @Override + public int getActiveDoc() { + return activeDoc; + } + + @Override + public int getTargetDoc() { + return targetDoc; + } + + @Override + public DocInfo getDocInfo() { + return docInfo; + } + + @Override + public int docID() { + return multiFeaturesIteratorIterator.docID(); + } + + @Override + public DocIdSetIterator iterator() { + return multiFeaturesIteratorIterator; + } + + @Override + public final Collection<ChildScorable> getChildren() { + final ArrayList<ChildScorable> 
children = new ArrayList<>(); + for (final DisiWrapper scorer : subScorers) { + children.add(new ChildScorable(scorer.scorer, "SHOULD")); + } + return children; + } + + private class MultiFeaturesIterator extends DocIdSetIterator { + + public MultiFeaturesIterator(Collection<DisiWrapper> wrappers) { + // Initialize all wrappers to start at -1 + for (DisiWrapper wrapper : wrappers) { + wrapper.doc = -1; + } + } + + @Override + public int docID() { + // Return the target document ID (mimicking DisjunctionDISIApproximation behavior) + return targetDoc; + } + + @Override + public final int nextDoc() throws IOException { + // Mimic DisjunctionDISIApproximation behavior + if (targetDoc == -1) { + // First call - initialize all iterators + DisiWrapper top = subScorers.top(); + if (top != null && top.doc == -1) { + // Need to advance all iterators to their first document + DisiWrapper current = subScorers.top(); + // Only advance iterators that have not started yet: updateTop() never returns null for a + // non-empty queue, so without the doc == -1 guard this loop would keep advancing the top + // iterator until every sub-iterator is exhausted (and then past NO_MORE_DOCS). + while (current != null && current.doc == -1) { + current.doc = current.iterator.nextDoc(); + current = subScorers.updateTop(); + } + top = subScorers.top(); + activeDoc = top == null ? NO_MORE_DOCS : top.doc; + } + targetDoc = activeDoc; + return targetDoc; + } + + if (activeDoc == targetDoc) { + // Advance the underlying disjunction + DisiWrapper top = subScorers.top(); + if (top == null) { + activeDoc = NO_MORE_DOCS; + } else { + // Advance the top iterator and rebalance the queue + top.doc = top.iterator.nextDoc(); + top = subScorers.updateTop(); + activeDoc = top == null ?
NO_MORE_DOCS : top.doc; + } + } else if (activeDoc < targetDoc) { + // Need to catch up to targetDoc + 1 + activeDoc = advanceInternal(targetDoc + 1); + } + return ++targetDoc; + } + + @Override + public final int advance(int target) throws IOException { + // Mimic DisjunctionDISIApproximation behavior + if (activeDoc < target) { + activeDoc = advanceInternal(target); + } + targetDoc = target; + return targetDoc; + } + + private int advanceInternal(int target) throws IOException { + // Advance the underlying disjunction to the target + DisiWrapper top; + do { + top = subScorers.top(); + if (top == null) { + return NO_MORE_DOCS; + } + if (top.doc >= target) { + return top.doc; + } + top.doc = top.iterator.advance(target); + top = subScorers.updateTop(); + if (top == null) { + return NO_MORE_DOCS; + } + } while (top.doc < target); + return top.doc; + } + + @Override + public long cost() { + // Calculate cost from all wrappers + long cost = 0; + for (DisiWrapper wrapper : wrappers) { + cost += wrapper.iterator.cost(); + } + return cost; + } + } +} diff --git a/solr/modules/ltr/src/java/org/apache/solr/ltr/scoring/SingleFeatureScorer.java b/solr/modules/ltr/src/java/org/apache/solr/ltr/scoring/SingleFeatureScorer.java new file mode 100644 index 00000000000..6619856901a --- /dev/null +++ b/solr/modules/ltr/src/java/org/apache/solr/ltr/scoring/SingleFeatureScorer.java @@ -0,0 +1,143 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.solr.ltr.scoring; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Collection; +import java.util.List; +import java.util.Map; +import org.apache.lucene.search.DocIdSetIterator; +import org.apache.lucene.search.Scorer; +import org.apache.solr.ltr.DocInfo; +import org.apache.solr.ltr.LTRScoringQuery; +import org.apache.solr.ltr.feature.Feature; +import org.apache.solr.ltr.feature.extraction.SingleFeatureExtractor; +import org.apache.solr.ltr.model.LTRScoringModel; +import org.apache.solr.request.SolrQueryRequest; + +/** This class is responsible for extracting a single feature and using it to score the document. 
*/ +public class SingleFeatureScorer extends FeatureTraversalScorer { + private int targetDoc = -1; + private int activeDoc = -1; + protected DocInfo docInfo; + private final List<Feature.FeatureWeight.FeatureScorer> featureScorers; + + public SingleFeatureScorer( + LTRScoringQuery.ModelWeight modelWeight, + SolrQueryRequest request, + Feature.FeatureWeight[] extractedFeatureWeights, + LTRScoringQuery.FeatureInfo[] allFeaturesInStore, + LTRScoringModel ltrScoringModel, + Map<String, String[]> efi, + List<Feature.FeatureWeight.FeatureScorer> featureScorers, + DocInfo docInfo) { + this.featureScorers = featureScorers; + this.extractedFeatureWeights = extractedFeatureWeights; + this.allFeaturesInStore = allFeaturesInStore; + this.ltrScoringModel = ltrScoringModel; + this.featureExtractor = + new SingleFeatureExtractor( + this, + request, + extractedFeatureWeights, + allFeaturesInStore, + ltrScoringModel, + efi, + featureScorers); + this.modelWeight = modelWeight; + this.docInfo = docInfo; + } + + @Override + public int getActiveDoc() { + return activeDoc; + } + + @Override + public int getTargetDoc() { + return targetDoc; + } + + @Override + public DocInfo getDocInfo() { + return docInfo; + } + + @Override + public int docID() { + return targetDoc; + } + + @Override + public final Collection<ChildScorable> getChildren() { + final ArrayList<ChildScorable> children = new ArrayList<>(); + for (final Scorer scorer : featureScorers) { + children.add(new ChildScorable(scorer, "SHOULD")); + } + return children; + } + + @Override + public DocIdSetIterator iterator() { + return new SingleFeatureIterator(); + } + + private class SingleFeatureIterator extends DocIdSetIterator { + + @Override + public int docID() { + return targetDoc; + } + + @Override + public int nextDoc() throws IOException { + if (activeDoc <= targetDoc) { + activeDoc = NO_MORE_DOCS; + for (final Scorer scorer : featureScorers) { + if (scorer.docID() != NO_MORE_DOCS) { + activeDoc = Math.min(activeDoc, 
scorer.iterator().nextDoc()); + } + } + } + return ++targetDoc; + } + + @Override + public int advance(int target) throws IOException { + if (activeDoc < target) { + activeDoc = NO_MORE_DOCS; + for (final Scorer scorer : featureScorers) { + if (scorer.docID() != NO_MORE_DOCS) { + activeDoc = Math.min(activeDoc, scorer.iterator().advance(target)); + } + } + } + targetDoc = target; + return target; + } + + @Override + public long cost() { + long sum = 0; + for (int i = 0; i < featureScorers.size(); i++) { + sum += featureScorers.get(i).iterator().cost(); + } + return sum; + } + } +} diff --git a/solr/modules/ltr/src/java/org/apache/solr/ltr/DocInfo.java b/solr/modules/ltr/src/java/org/apache/solr/ltr/scoring/package-info.java similarity index 59% copy from solr/modules/ltr/src/java/org/apache/solr/ltr/DocInfo.java copy to solr/modules/ltr/src/java/org/apache/solr/ltr/scoring/package-info.java index e454d90acc2..54e25ba080b 100644 --- a/solr/modules/ltr/src/java/org/apache/solr/ltr/DocInfo.java +++ b/solr/modules/ltr/src/java/org/apache/solr/ltr/scoring/package-info.java @@ -14,28 +14,6 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.apache.solr.ltr; -import java.util.HashMap; - -public class DocInfo extends HashMap<String, Object> { - - // Name of key used to store the original score of a doc - private static final String ORIGINAL_DOC_SCORE = "ORIGINAL_DOC_SCORE"; - - public DocInfo() { - super(); - } - - public void setOriginalDocScore(Float score) { - put(ORIGINAL_DOC_SCORE, score); - } - - public Float getOriginalDocScore() { - return (Float) get(ORIGINAL_DOC_SCORE); - } - - public boolean hasOriginalDocScore() { - return containsKey(ORIGINAL_DOC_SCORE); - } -} +/** Contains the logic to extract features for scoring. 
*/ +package org.apache.solr.ltr.scoring; diff --git a/solr/modules/ltr/src/test-files/featureExamples/featurevectorcache_features.json b/solr/modules/ltr/src/test-files/featureExamples/featurevectorcache_features.json new file mode 100644 index 00000000000..f65b0108bda --- /dev/null +++ b/solr/modules/ltr/src/test-files/featureExamples/featurevectorcache_features.json @@ -0,0 +1,57 @@ +[ + { + "name": "value_feature_1", + "class": "org.apache.solr.ltr.feature.ValueFeature", + "params": { + "value": 1 + } + }, + { + "name": "value_feature_3", + "class": "org.apache.solr.ltr.feature.ValueFeature", + "params": { + "value": 3 + } + }, + { + "name" : "efi_feature", + "class":"org.apache.solr.ltr.feature.ValueFeature", + "params" : { + "value": "${efi_feature}" + } + }, + { + "name" : "match_w1_title", + "class":"org.apache.solr.ltr.feature.SolrFeature", + "params" : { + "fq": [ + "{!terms f=title}w1" + ] + } + }, + { + "name": "popularity_value", + "class": "org.apache.solr.ltr.feature.FieldValueFeature", + "params": { + "field": "popularity" + } + }, + { + "name" : "match_w1_title", + "class":"org.apache.solr.ltr.feature.SolrFeature", + "store": "store1", + "params" : { + "fq": [ + "{!terms f=title}w1" + ] + } + }, + { + "name": "value_feature_2", + "class": "org.apache.solr.ltr.feature.ValueFeature", + "store": "store1", + "params": { + "value": 2 + } + } +] \ No newline at end of file diff --git a/solr/modules/ltr/src/test-files/modelExamples/featurevectorcache_linear_model.json b/solr/modules/ltr/src/test-files/modelExamples/featurevectorcache_linear_model.json new file mode 100644 index 00000000000..c05acac90e3 --- /dev/null +++ b/solr/modules/ltr/src/test-files/modelExamples/featurevectorcache_linear_model.json @@ -0,0 +1,26 @@ +{ + "class": "org.apache.solr.ltr.model.LinearModel", + "name": "featurevectorcache_linear_model", + "features": [ + { + "name": "value_feature_1" + }, + { + "name": "efi_feature" + }, + { + "name": "match_w1_title" + }, + { + "name": 
"popularity_value" + } + ], + "params": { + "weights": { + "value_feature_1": 1, + "efi_feature": 0.2, + "match_w1_title": 0.5, + "popularity_value": 0.8 + } + } +} \ No newline at end of file diff --git a/solr/modules/ltr/src/test-files/solr/collection1/conf/solrconfig-ltr-featurevectorcache.xml b/solr/modules/ltr/src/test-files/solr/collection1/conf/solrconfig-ltr-featurevectorcache.xml new file mode 100644 index 00000000000..f78d5996d3d --- /dev/null +++ b/solr/modules/ltr/src/test-files/solr/collection1/conf/solrconfig-ltr-featurevectorcache.xml @@ -0,0 +1,76 @@ +<?xml version="1.0" ?> +<!-- Licensed to the Apache Software Foundation (ASF) under one or more contributor + license agreements. See the NOTICE file distributed with this work for additional + information regarding copyright ownership. The ASF licenses this file to + You under the Apache License, Version 2.0 (the "License"); you may not use + this file except in compliance with the License. You may obtain a copy of + the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required + by applicable law or agreed to in writing, software distributed under the + License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS + OF ANY KIND, either express or implied. See the License for the specific + language governing permissions and limitations under the License. 
--> + +<config> + <luceneMatchVersion>${tests.luceneMatchVersion:LATEST}</luceneMatchVersion> + <dataDir>${solr.data.dir:}</dataDir> + <directoryFactory name="DirectoryFactory" + class="${solr.directoryFactory:solr.MockDirectoryFactory}" /> + + <schemaFactory class="ClassicIndexSchemaFactory" /> + + <requestDispatcher> + <requestParsers /> + </requestDispatcher> + + <!-- Query parser used to rerank top docs with a provided model --> + <queryParser name="ltr" + class="org.apache.solr.ltr.search.LTRQParserPlugin" /> + + <query> + <filterCache class="solr.CaffeineCache" size="4096" initialSize="2048" autowarmCount="0" /> + <featureVectorCache class="solr.CaffeineCache" size="4096" initialSize="2048" autowarmCount="0" /> + </query> + + <!-- add a transformer that will encode the document features in the response. + For each document the transformer will add the features as an extra field + in the response. The name of the field will be the name of the transformer + enclosed between brackets (in this case [fv]). In order to get the feature + vector you will have to specify that you want the field (e.g., fl="*,[fv]) --> + <transformer name="fv" class="org.apache.solr.ltr.response.transform.LTRFeatureLoggerTransformerFactory"> + <str name="defaultFormat">${solr.ltr.transformer.fv.defaultFormat:dense}</str> + </transformer> + + <!-- add a transformer that will encode the model the interleaving process chose the search result from. + For each document the transformer will add an extra field in the response with the model picked. + The name of the field will be the name of the transformer + enclosed between brackets (in this case [interleaving]).
+ In order to get the model chosen for the search result + you will have to specify that you want the field (e.g., fl="*,[interleaving]) --> + <transformer name="interleaving" class="org.apache.solr.ltr.response.transform.LTRInterleavingTransformerFactory"> + </transformer> + + <updateHandler class="solr.DirectUpdateHandler2"> + <autoCommit> + <maxTime>15000</maxTime> + <openSearcher>false</openSearcher> + </autoCommit> + <autoSoftCommit> + <maxTime>1000</maxTime> + </autoSoftCommit> + <updateLog> + <str name="dir">${solr.data.dir:}</str> + </updateLog> + </updateHandler> + + <requestHandler name="/update" class="solr.UpdateRequestHandler" /> + <!-- Query request handler managing models and features --> + <requestHandler name="/query" class="solr.SearchHandler"> + <lst name="defaults"> + <str name="echoParams">explicit</str> + <str name="wt">json</str> + <str name="indent">true</str> + <str name="df">id</str> + </lst> + </requestHandler> + +</config> \ No newline at end of file diff --git a/solr/modules/ltr/src/test-files/solr/collection1/conf/solrconfig-ltr.xml b/solr/modules/ltr/src/test-files/solr/collection1/conf/solrconfig-ltr.xml index c20ee2026f6..496755c9765 100644 --- a/solr/modules/ltr/src/test-files/solr/collection1/conf/solrconfig-ltr.xml +++ b/solr/modules/ltr/src/test-files/solr/collection1/conf/solrconfig-ltr.xml @@ -29,8 +29,6 @@ <query> <filterCache class="solr.CaffeineCache" size="4096" initialSize="2048" autowarmCount="0" /> - <cache name="QUERY_DOC_FV" class="solr.search.CaffeineCache" size="4096" - initialSize="2048" autowarmCount="4096" regenerator="solr.search.NoOpRegenerator" /> </query> <!-- add a transformer that will encode the document features in the response. 
@@ -40,13 +38,12 @@ vector you will have to specify that you want the field (e.g., fl="*,[fv]) --> <transformer name="fv" class="org.apache.solr.ltr.response.transform.LTRFeatureLoggerTransformerFactory"> <str name="defaultFormat">${solr.ltr.transformer.fv.defaultFormat:dense}</str> - <str name="fvCacheName">QUERY_DOC_FV</str> </transformer> <!-- add a transformer that will encode the model the interleaving process chose the search result from. - For each document the transformer will add an extra field in the response with the model picked. + For each document the transformer will add an extra field in the response with the model picked. The name of the field will be the the name of the transformer - enclosed between brackets (in this case [interleaving]). + enclosed between brackets (in this case [interleaving]). In order to get the model chosen for the search result you will have to specify that you want the field (e.g., fl="*,[interleaving]) --> <transformer name="interleaving" class="org.apache.solr.ltr.response.transform.LTRInterleavingTransformerFactory"> diff --git a/solr/modules/ltr/src/test-files/solr/collection1/conf/solrconfig-ltr_Th10_10.xml b/solr/modules/ltr/src/test-files/solr/collection1/conf/solrconfig-ltr_Th10_10.xml index 37ae68a2580..32f8751a9f8 100644 --- a/solr/modules/ltr/src/test-files/solr/collection1/conf/solrconfig-ltr_Th10_10.xml +++ b/solr/modules/ltr/src/test-files/solr/collection1/conf/solrconfig-ltr_Th10_10.xml @@ -33,8 +33,6 @@ <query> <filterCache class="solr.CaffeineCache" size="4096" initialSize="2048" autowarmCount="0" /> - <cache name="QUERY_DOC_FV" class="solr.search.CaffeineCache" size="4096" - initialSize="2048" autowarmCount="4096" regenerator="solr.search.NoOpRegenerator" /> </query> <!-- add a transformer that will encode the document features in the response. @@ -43,7 +41,6 @@ enclosed between brackets (in this case [fv]). 
In order to get the feature vector you will have to specify that you want the field (e.g., fl="*,[fv]) --> <transformer name="fv" class="org.apache.solr.ltr.response.transform.LTRFeatureLoggerTransformerFactory"> - <str name="fvCacheName">QUERY_DOC_FV</str> </transformer> <updateHandler class="solr.DirectUpdateHandler2"> diff --git a/solr/modules/ltr/src/test-files/solr/collection1/conf/solrconfig-multiseg.xml b/solr/modules/ltr/src/test-files/solr/collection1/conf/solrconfig-multiseg.xml index 911db9a9f55..04ee42b3834 100644 --- a/solr/modules/ltr/src/test-files/solr/collection1/conf/solrconfig-multiseg.xml +++ b/solr/modules/ltr/src/test-files/solr/collection1/conf/solrconfig-multiseg.xml @@ -28,8 +28,6 @@ <query> <filterCache class="solr.CaffeineCache" size="4096" initialSize="2048" autowarmCount="0" /> - <cache name="QUERY_DOC_FV" class="solr.search.CaffeineCache" size="4096" - initialSize="2048" autowarmCount="4096" regenerator="solr.search.NoOpRegenerator" /> </query> <maxBufferedDocs>1</maxBufferedDocs> @@ -43,7 +41,6 @@ enclosed between brackets (in this case [fv]). In order to get the feature vector you will have to specify that you want the field (e.g., fl="*,[fv]) --> <transformer name="features" class="org.apache.solr.ltr.response.transform.LTRFeatureLoggerTransformerFactory"> - <str name="fvCacheName">QUERY_DOC_FV</str> </transformer> <updateHandler class="solr.DirectUpdateHandler2"> diff --git a/solr/modules/ltr/src/test/org/apache/solr/ltr/TestFeatureVectorCache.java b/solr/modules/ltr/src/test/org/apache/solr/ltr/TestFeatureVectorCache.java new file mode 100644 index 00000000000..2674f567edd --- /dev/null +++ b/solr/modules/ltr/src/test/org/apache/solr/ltr/TestFeatureVectorCache.java @@ -0,0 +1,366 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.solr.ltr; + +import io.prometheus.metrics.model.snapshots.CounterSnapshot; +import java.util.ArrayList; +import java.util.List; +import org.apache.solr.client.solrj.SolrQuery; +import org.apache.solr.core.SolrCore; +import org.apache.solr.util.SolrMetricTestUtils; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; + +public class TestFeatureVectorCache extends TestRerankBase { + SolrCore core = null; + List<String> docs; + + @Before + public void before() throws Exception { + setupFeatureVectorCacheTest(false); + + this.docs = new ArrayList<>(); + docs.add(adoc("id", "1", "title", "w2", "description", "w2", "popularity", "2")); + docs.add(adoc("id", "2", "title", "w1", "description", "w1", "popularity", "0")); + for (String doc : docs) { + assertU(doc); + } + assertU(commit()); + + loadFeatures("featurevectorcache_features.json"); + loadModels("featurevectorcache_linear_model.json"); + + core = solrClientTestRule.getCoreContainer().getCore(DEFAULT_TEST_CORENAME); + } + + @After + public void after() throws Exception { + core.close(); + aftertest(); + } + + private static CounterSnapshot.CounterDataPointSnapshot getFeatureVectorCacheInserts( + SolrCore core) { + return SolrMetricTestUtils.getCacheSearcherOpsInserts(core, "featureVectorCache"); + } + + private static double getFeatureVectorCacheLookups(SolrCore core) { + return 
SolrMetricTestUtils.getCacheSearcherTotalLookups(core, "featureVectorCache"); + } + + private static CounterSnapshot.CounterDataPointSnapshot getFeatureVectorCacheHits(SolrCore core) { + return SolrMetricTestUtils.getCacheSearcherOpsHits(core, "featureVectorCache"); + } + + @Test + public void testFeatureVectorCache_loggingDefaultStoreNoReranking() throws Exception { + final String docs0fv_dense_csv = + FeatureLoggerTestUtils.toFeatureVector( + "value_feature_1", + "1.0", + "value_feature_3", + "3.0", + "efi_feature", + "3.0", + "match_w1_title", + "0.0", + "popularity_value", + "2.0"); + final String docs0fv_sparse_csv = + FeatureLoggerTestUtils.toFeatureVector( + "value_feature_1", + "1.0", + "value_feature_3", + "3.0", + "efi_feature", + "3.0", + "popularity_value", + "2.0"); + + final String docs0fv_default_csv = + chooseDefaultFeatureVector(docs0fv_dense_csv, docs0fv_sparse_csv); + + final SolrQuery query = new SolrQuery(); + query.setQuery("*:*"); + query.add("rows", "3"); + query.add("fl", "[fv efi.efi_feature=3]"); + + // No caching, we want to see lookups, insertions and no hits + assertJQ( + "/query" + query.toQueryString(), + "/response/docs/[0]/=={'[fv]':'" + docs0fv_default_csv + "'}"); + assertEquals(docs.size(), getFeatureVectorCacheInserts(core).getValue(), 0); + assertEquals(docs.size(), getFeatureVectorCacheLookups(core), 0); + assertEquals(0, getFeatureVectorCacheHits(core).getValue(), 0); + + query.add("sort", "popularity desc"); + // Caching, we want to see hits + assertJQ( + "/query" + query.toQueryString(), + "/response/docs/[0]/=={'[fv]':'" + docs0fv_default_csv + "'}"); + assertEquals(docs.size(), getFeatureVectorCacheInserts(core).getValue(), 0); + assertEquals(docs.size() * 2, getFeatureVectorCacheLookups(core), 0); + assertEquals(docs.size(), getFeatureVectorCacheHits(core).getValue(), 0); + } + + @Test + public void testFeatureVectorCache_loggingExplicitStoreNoReranking() throws Exception { + final String docs0fv_dense_csv = + 
FeatureLoggerTestUtils.toFeatureVector("match_w1_title", "0.0", "value_feature_2", "2.0"); + final String docs0fv_sparse_csv = + FeatureLoggerTestUtils.toFeatureVector("value_feature_2", "2.0"); + + final String docs0fv_default_csv = + chooseDefaultFeatureVector(docs0fv_dense_csv, docs0fv_sparse_csv); + + final SolrQuery query = new SolrQuery(); + query.setQuery("*:*"); + query.add("rows", "3"); + query.add("fl", "[fv store=store1 efi.efi_feature=3]"); + + // No caching, we want to see lookups, insertions and no hits + assertJQ( + "/query" + query.toQueryString(), + "/response/docs/[0]/=={'[fv]':'" + docs0fv_default_csv + "'}"); + assertEquals(docs.size(), getFeatureVectorCacheInserts(core).getValue(), 0); + assertEquals(docs.size(), getFeatureVectorCacheLookups(core), 0); + assertEquals(0, getFeatureVectorCacheHits(core).getValue(), 0); + + query.add("sort", "popularity desc"); + // Caching, we want to see hits + assertJQ( + "/query" + query.toQueryString(), + "/response/docs/[0]/=={'[fv]':'" + docs0fv_default_csv + "'}"); + assertEquals(docs.size(), getFeatureVectorCacheInserts(core).getValue(), 0); + assertEquals(docs.size() * 2, getFeatureVectorCacheLookups(core), 0); + assertEquals(docs.size(), getFeatureVectorCacheHits(core).getValue(), 0); + } + + @Test + public void testFeatureVectorCache_loggingModelStoreAndRerankingWithDifferentEfi() + throws Exception { + final String docs0fv_dense_csv = + FeatureLoggerTestUtils.toFeatureVector( + "value_feature_1", + "1.0", + "efi_feature", + "3.0", + "match_w1_title", + "0.0", + "popularity_value", + "2.0"); + final String docs0fv_sparse_csv = + FeatureLoggerTestUtils.toFeatureVector( + "value_feature_1", "1.0", "efi_feature", "3.0", "popularity_value", "2.0"); + + final String docs0fv_default_csv = + chooseDefaultFeatureVector(docs0fv_dense_csv, docs0fv_sparse_csv); + + final SolrQuery query = new SolrQuery(); + query.setQuery("*:*"); + query.add("rows", "3"); + query.add("fl", "id,score,fv:[fv efi.efi_feature=3]"); + 
query.add("rq", "{!ltr reRankDocs=3 model=featurevectorcache_linear_model efi.efi_feature=4}"); + + // No caching, we want to see lookups, insertions and no hits since the efis are different + assertJQ( + "/query" + query.toQueryString(), + "/response/docs/[0]/=={" + + "'id':'1'," + + "'score':3.4," + + "'fv':'" + + docs0fv_default_csv + + "'}"); + assertEquals(docs.size() * 2, getFeatureVectorCacheInserts(core).getValue(), 0); + assertEquals(docs.size() * 2, getFeatureVectorCacheLookups(core), 0); + assertEquals(0, getFeatureVectorCacheHits(core).getValue(), 0); + + query.add("sort", "popularity desc"); + // Caching, we want to see hits and same scores as before + assertJQ( + "/query" + query.toQueryString(), + "/response/docs/[0]/=={" + + "'id':'1'," + + "'score':3.4," + + "'fv':'" + + docs0fv_default_csv + + "'}"); + assertEquals(docs.size() * 2, getFeatureVectorCacheInserts(core).getValue(), 0); + assertEquals(docs.size() * 4, getFeatureVectorCacheLookups(core), 0); + assertEquals(docs.size() * 2, getFeatureVectorCacheHits(core).getValue(), 0); + } + + @Test + public void testFeatureVectorCache_loggingModelStoreAndRerankingWithSameEfi() throws Exception { + final String docs0fv_dense_csv = + FeatureLoggerTestUtils.toFeatureVector( + "value_feature_1", + "1.0", + "efi_feature", + "4.0", + "match_w1_title", + "0.0", + "popularity_value", + "2.0"); + final String docs0fv_sparse_csv = + FeatureLoggerTestUtils.toFeatureVector( + "value_feature_1", "1.0", "efi_feature", "4.0", "popularity_value", "2.0"); + + final String docs0fv_default_csv = + chooseDefaultFeatureVector(docs0fv_dense_csv, docs0fv_sparse_csv); + + final SolrQuery query = new SolrQuery(); + query.setQuery("*:*"); + query.add("rows", "3"); + query.add("fl", "id,score,fv:[fv efi.efi_feature=4]"); + query.add("rq", "{!ltr reRankDocs=3 model=featurevectorcache_linear_model efi.efi_feature=4}"); + + // No caching for reranking but logging should hit since we have the same feature store and efis + assertJQ( 
+ "/query" + query.toQueryString(), + "/response/docs/[0]/=={" + + "'id':'1'," + + "'score':3.4," + + "'fv':'" + + docs0fv_default_csv + + "'}"); + assertEquals(docs.size(), getFeatureVectorCacheInserts(core).getValue(), 0); + assertEquals(docs.size() * 2, getFeatureVectorCacheLookups(core), 0); + assertEquals(docs.size(), getFeatureVectorCacheHits(core).getValue(), 0); + + query.add("sort", "popularity desc"); + // Caching, we want to see hits and same scores + assertJQ( + "/query" + query.toQueryString(), + "/response/docs/[0]/=={" + + "'id':'1'," + + "'score':3.4," + + "'fv':'" + + docs0fv_default_csv + + "'}"); + assertEquals(docs.size(), getFeatureVectorCacheInserts(core).getValue(), 0); + assertEquals(docs.size() * 4, getFeatureVectorCacheLookups(core), 0); + assertEquals(docs.size() * 3, getFeatureVectorCacheHits(core).getValue(), 0); + } + + @Test + public void testFeatureVectorCache_loggingAllFeatureStoreAndReranking() throws Exception { + final String docs0fv_dense_csv = + FeatureLoggerTestUtils.toFeatureVector( + "value_feature_1", + "1.0", + "value_feature_3", + "3.0", + "efi_feature", + "3.0", + "match_w1_title", + "0.0", + "popularity_value", + "2.0"); + final String docs0fv_sparse_csv = + FeatureLoggerTestUtils.toFeatureVector( + "value_feature_1", + "1.0", + "value_feature_3", + "3.0", + "efi_feature", + "3.0", + "popularity_value", + "2.0"); + + final String docs0fv_default_csv = + chooseDefaultFeatureVector(docs0fv_dense_csv, docs0fv_sparse_csv); + + final SolrQuery query = new SolrQuery(); + query.setQuery("*:*"); + query.add("rows", "3"); + query.add("fl", "id,score,fv:[fv logAll=true efi.efi_feature=3]"); + query.add("rq", "{!ltr reRankDocs=3 model=featurevectorcache_linear_model efi.efi_feature=4}"); + + // No caching, we want to see lookups, insertions and no hits since the efis are different + assertJQ( + "/query" + query.toQueryString(), + "/response/docs/[0]/=={" + + "'id':'1'," + + "'score':3.4," + + "'fv':'" + + docs0fv_default_csv + + 
"'}"); + assertEquals(docs.size() * 2, getFeatureVectorCacheInserts(core).getValue(), 0); + assertEquals(docs.size() * 2, getFeatureVectorCacheLookups(core), 0); + assertEquals(0, getFeatureVectorCacheHits(core).getValue(), 0); + + query.add("sort", "popularity desc"); + // Caching, we want to see hits and same scores + assertJQ( + "/query" + query.toQueryString(), + "/response/docs/[0]/=={" + + "'id':'1'," + + "'score':3.4," + + "'fv':'" + + docs0fv_default_csv + + "'}"); + assertEquals(docs.size() * 2, getFeatureVectorCacheInserts(core).getValue(), 0); + assertEquals(docs.size() * 4, getFeatureVectorCacheLookups(core), 0); + assertEquals(docs.size() * 2, getFeatureVectorCacheHits(core).getValue(), 0); + } + + @Test + public void testFeatureVectorCache_loggingExplicitStoreAndReranking() throws Exception { + final String docs0fv_dense_csv = + FeatureLoggerTestUtils.toFeatureVector("match_w1_title", "0.0", "value_feature_2", "2.0"); + final String docs0fv_sparse_csv = + FeatureLoggerTestUtils.toFeatureVector("value_feature_2", "2.0"); + + final String docs0fv_default_csv = + chooseDefaultFeatureVector(docs0fv_dense_csv, docs0fv_sparse_csv); + + final SolrQuery query = new SolrQuery(); + query.setQuery("*:*"); + query.add("rows", "3"); + query.add("fl", "id,score,fv:[fv store=store1 efi.efi_feature=3]"); + query.add("rq", "{!ltr reRankDocs=3 model=featurevectorcache_linear_model efi.efi_feature=4}"); + + // No caching, we want to see lookups, insertions and no hits since the efis are different + assertJQ( + "/query" + query.toQueryString(), + "/response/docs/[0]/=={" + + "'id':'1'," + + "'score':3.4," + + "'fv':'" + + docs0fv_default_csv + + "'}"); + assertEquals(docs.size() * 2, getFeatureVectorCacheInserts(core).getValue(), 0); + assertEquals(docs.size() * 2, getFeatureVectorCacheLookups(core), 0); + assertEquals(0, getFeatureVectorCacheHits(core).getValue(), 0); + + query.add("sort", "popularity desc"); + // Caching, we want to see hits and same scores + assertJQ( 
+ "/query" + query.toQueryString(), + "/response/docs/[0]/=={" + + "'id':'1'," + + "'score':3.4," + + "'fv':'" + + docs0fv_default_csv + + "'}"); + assertEquals(docs.size() * 2, getFeatureVectorCacheInserts(core).getValue(), 0); + assertEquals(docs.size() * 4, getFeatureVectorCacheLookups(core), 0); + assertEquals(docs.size() * 2, getFeatureVectorCacheHits(core).getValue(), 0); + } +} diff --git a/solr/modules/ltr/src/test/org/apache/solr/ltr/TestLTRScoringQuery.java b/solr/modules/ltr/src/test/org/apache/solr/ltr/TestLTRScoringQuery.java index 7ad5b34f49b..583ae22f742 100644 --- a/solr/modules/ltr/src/test/org/apache/solr/ltr/TestLTRScoringQuery.java +++ b/solr/modules/ltr/src/test/org/apache/solr/ltr/TestLTRScoringQuery.java @@ -219,7 +219,7 @@ public class TestLTRScoringQuery extends SolrTestCase { } int[] posVals = new int[] {0, 1, 2}; int pos = 0; - for (LTRScoringQuery.FeatureInfo fInfo : modelWeight.getFeaturesInfo()) { + for (LTRScoringQuery.FeatureInfo fInfo : modelWeight.getAllFeaturesInStore()) { if (fInfo == null) { continue; } diff --git a/solr/modules/ltr/src/test/org/apache/solr/ltr/TestRerankBase.java b/solr/modules/ltr/src/test/org/apache/solr/ltr/TestRerankBase.java index e3b30b24043..2ee0c1ebdd0 100644 --- a/solr/modules/ltr/src/test/org/apache/solr/ltr/TestRerankBase.java +++ b/solr/modules/ltr/src/test/org/apache/solr/ltr/TestRerankBase.java @@ -120,6 +120,12 @@ public class TestRerankBase extends RestTestBase { if (bulkIndex) bulkIndex(); } + protected static void setupFeatureVectorCacheTest(boolean bulkIndex) throws Exception { + chooseDefaultFeatureFormat(); + setuptest("solrconfig-ltr-featurevectorcache.xml", "schema.xml"); + if (bulkIndex) bulkIndex(); + } + public static ManagedFeatureStore getManagedFeatureStore() { try (SolrCore core = solrClientTestRule.getCoreContainer().getCore(DEFAULT_TEST_CORENAME)) { return ManagedFeatureStore.getManagedFeatureStore(core); diff --git 
a/solr/modules/ltr/src/test/org/apache/solr/ltr/TestSelectiveWeightCreation.java b/solr/modules/ltr/src/test/org/apache/solr/ltr/TestSelectiveWeightCreation.java index be1313b47bc..b5148f66b80 100644 --- a/solr/modules/ltr/src/test/org/apache/solr/ltr/TestSelectiveWeightCreation.java +++ b/solr/modules/ltr/src/test/org/apache/solr/ltr/TestSelectiveWeightCreation.java @@ -167,7 +167,7 @@ public class TestSelectiveWeightCreation extends TestRerankBase { searcher, hits.scoreDocs[0].doc, new LTRScoringQuery(ltrScoringModel1)); // features not requested in response - LTRScoringQuery.FeatureInfo[] featuresInfo = modelWeight.getFeaturesInfo(); + LTRScoringQuery.FeatureInfo[] featuresInfo = modelWeight.getAllFeaturesInStore(); assertEquals(features.size(), modelWeight.getModelFeatureValuesNormalized().length); int nonDefaultFeatures = 0; @@ -189,13 +189,12 @@ public class TestSelectiveWeightCreation extends TestRerankBase { TestLinearModel.makeFeatureWeights(features)); LTRScoringQuery ltrQuery2 = new LTRScoringQuery(ltrScoringModel2); // features requested in response - ltrQuery2.setFeatureLogger( - new CSVFeatureLogger("test", FeatureLogger.FeatureFormat.DENSE, true)); + ltrQuery2.setFeatureLogger(new CSVFeatureLogger(FeatureLogger.FeatureFormat.DENSE, true)); modelWeight = performQuery(hits, searcher, hits.scoreDocs[0].doc, ltrQuery2); - featuresInfo = modelWeight.getFeaturesInfo(); + featuresInfo = modelWeight.getAllFeaturesInStore(); assertEquals(features.size(), modelWeight.getModelFeatureValuesNormalized().length); - assertEquals(allFeatures.size(), modelWeight.getExtractedFeatureWeights().length); + assertEquals(allFeatures.size(), ltrQuery2.getExtractedFeatureWeights().length); nonDefaultFeatures = 0; for (int i = 0; i < featuresInfo.length; ++i) { diff --git a/solr/modules/ltr/src/test/org/apache/solr/ltr/feature/TestFeatureExtractionFromMultipleSegments.java b/solr/modules/ltr/src/test/org/apache/solr/ltr/feature/TestFeatureExtractionFromMultipleSegments.java 
index 3e6e986e591..f9fb8099cf3 100644 --- a/solr/modules/ltr/src/test/org/apache/solr/ltr/feature/TestFeatureExtractionFromMultipleSegments.java +++ b/solr/modules/ltr/src/test/org/apache/solr/ltr/feature/TestFeatureExtractionFromMultipleSegments.java @@ -236,7 +236,8 @@ public class TestFeatureExtractionFromMultipleSegments extends TestRerankBase { query.setQuery( "{!edismax qf='description^1' boost='sum(product(pow(normHits, 0.7), 1600), .1)' v='apple'}"); // request 100 rows, if any rows are fetched from the second or subsequent segments the tests - // should succeed if LTRRescorer::extractFeaturesInfo() advances the doc iterator properly + // should succeed if LTRFeatureLoggerTransformerFactory::extractFeatures() advances the doc + // iterator properly int numRows = 100; query.add("rows", Integer.toString(numRows)); query.add("wt", "json"); diff --git a/solr/server/solr/configsets/sample_techproducts_configs/conf/solrconfig.xml b/solr/server/solr/configsets/sample_techproducts_configs/conf/solrconfig.xml index 4192759f158..d6635c05e9b 100644 --- a/solr/server/solr/configsets/sample_techproducts_configs/conf/solrconfig.xml +++ b/solr/server/solr/configsets/sample_techproducts_configs/conf/solrconfig.xml @@ -391,7 +391,7 @@ /> --> - <!-- Feature Values Cache + <!-- Feature Vector Cache Cache used by the Learning To Rank (LTR) module. @@ -401,12 +401,11 @@ https://solr.apache.org/guide/solr/latest/query-guide/learning-to-rank.html --> - <cache enable="${solr.ltr.enabled:false}" name="QUERY_DOC_FV" + <featureVectorCache enable="${solr.ltr.enabled:false}" class="solr.CaffeineCache" size="4096" initialSize="2048" - autowarmCount="4096" - regenerator="solr.search.NoOpRegenerator" /> + autowarmCount="0" /> <!-- Custom Cache @@ -1273,6 +1272,5 @@ via parameters. 
The below configuration supports hl.method=original and fastVec https://solr.apache.org/guide/solr/latest/query-guide/learning-to-rank.html --> <transformer enable="${solr.ltr.enabled:false}" name="features" class="org.apache.solr.ltr.response.transform.LTRFeatureLoggerTransformerFactory"> - <str name="fvCacheName">QUERY_DOC_FV</str> </transformer> </config> diff --git a/solr/solr-ref-guide/modules/query-guide/pages/learning-to-rank.adoc b/solr/solr-ref-guide/modules/query-guide/pages/learning-to-rank.adoc index 22e3bcde95e..cf13add48d8 100644 --- a/solr/solr-ref-guide/modules/query-guide/pages/learning-to-rank.adoc +++ b/solr/solr-ref-guide/modules/query-guide/pages/learning-to-rank.adoc @@ -131,18 +131,18 @@ See xref:configuration-guide:solr-modules.adoc[Solr Module] for more details. <queryParser name="ltr" class="org.apache.solr.ltr.search.LTRQParserPlugin"/> ---- -* Configuration of the feature values cache. +* Configuration of the feature vector cache which is used for both reranking and feature logging. +This needs to be added in the `<query>` section as follows. + [source,xml] ---- -<cache name="QUERY_DOC_FV" - class="solr.search.CaffeineCache" - size="4096" - initialSize="2048" - autowarmCount="4096" - regenerator="solr.search.NoOpRegenerator" /> +<featureVectorCache class="solr.CaffeineCache" size="4096" initialSize="2048" autowarmCount="0"/> ---- +[NOTE] +The `featureVectorCache` key is computed using the Lucene Document ID (necessary for document-level features). +Since these IDs are transient, this cache does not support auto-warming. + * Declaration of the `[features]` transformer. + [source,xml]
