This is an automated email from the ASF dual-hosted git repository.
abenedetti pushed a commit to branch branch_10x
in repository https://gitbox.apache.org/repos/asf/solr.git
The following commit(s) were added to refs/heads/branch_10x by this push:
new 91906c2bff0 SOLR-16667: LTR Add feature vector caching for ranking
(#3433)
91906c2bff0 is described below
commit 91906c2bff0c0e7c592bdc5d784e594d2bcd7e0c
Author: Anna Ruggero <[email protected]>
AuthorDate: Tue Oct 21 10:36:57 2025 +0200
SOLR-16667: LTR Add feature vector caching for ranking (#3433)
by Anna and Alessandro
(cherry picked from commit aeb9063585869f2bce12e7a189cad9f57f1fd86a)
---
solr/CHANGES.txt | 4 +-
.../src/java/org/apache/solr/core/SolrConfig.java | 11 +-
.../org/apache/solr/search/SolrIndexSearcher.java | 11 +
solr/modules/ltr/build.gradle | 2 +
solr/modules/ltr/gradle.lockfile | 2 +-
.../java/org/apache/solr/ltr/CSVFeatureLogger.java | 11 +-
.../ltr/src/java/org/apache/solr/ltr/DocInfo.java | 10 +
.../java/org/apache/solr/ltr/FeatureLogger.java | 63 +---
.../src/java/org/apache/solr/ltr/LTRRescorer.java | 118 ++-----
.../java/org/apache/solr/ltr/LTRScoringQuery.java | 377 +++------------------
.../ltr/feature/extraction/FeatureExtractor.java | 129 +++++++
.../feature/extraction/MultiFeaturesExtractor.java | 65 ++++
.../feature/extraction/SingleFeatureExtractor.java | 66 ++++
.../extraction/package-info.java} | 26 +-
.../ltr/interleaving/LTRInterleavingRescorer.java | 36 +-
.../LTRFeatureLoggerTransformerFactory.java | 65 ++--
.../solr/ltr/scoring/FeatureTraversalScorer.java | 71 ++++
.../solr/ltr/scoring/MultiFeaturesScorer.java | 211 ++++++++++++
.../solr/ltr/scoring/SingleFeatureScorer.java | 143 ++++++++
.../{DocInfo.java => scoring/package-info.java} | 26 +-
.../featurevectorcache_features.json | 57 ++++
.../featurevectorcache_linear_model.json | 26 ++
.../conf/solrconfig-ltr-featurevectorcache.xml | 76 +++++
.../solr/collection1/conf/solrconfig-ltr.xml | 7 +-
.../collection1/conf/solrconfig-ltr_Th10_10.xml | 3 -
.../solr/collection1/conf/solrconfig-multiseg.xml | 3 -
.../apache/solr/ltr/TestFeatureVectorCache.java | 366 ++++++++++++++++++++
.../org/apache/solr/ltr/TestLTRScoringQuery.java | 2 +-
.../test/org/apache/solr/ltr/TestRerankBase.java | 6 +
.../solr/ltr/TestSelectiveWeightCreation.java | 9 +-
.../TestFeatureExtractionFromMultipleSegments.java | 3 +-
.../conf/solrconfig.xml | 8 +-
.../query-guide/pages/learning-to-rank.adoc | 14 +-
33 files changed, 1437 insertions(+), 590 deletions(-)
diff --git a/solr/CHANGES.txt b/solr/CHANGES.txt
index 803018864b1..cb631b65b0e 100644
--- a/solr/CHANGES.txt
+++ b/solr/CHANGES.txt
@@ -58,6 +58,8 @@ New Features
* SOLR-17813: Add support for SeededKnnVectorQuery (Ilaria Petreti via
Alessandro Benedetti)
+* SOLR-16667: LTR Add feature vector caching for ranking. (Anna Ruggero,
Alessandro Benedetti)
+
Improvements
---------------------
@@ -547,7 +549,7 @@ Bug Fixes
* SOLR-17726: MoreLikeThis to support copy-fields (Ilaria Petreti via
Alessandro Benedetti)
-* SOLR-16667: Fixed dense/sparse representation in LTR module. (Anna Ruggero,
Alessandro Benedetti)
+* SOLR-17760: Fixed dense/sparse representation in LTR module. (Anna Ruggero,
Alessandro Benedetti)
* SOLR-17800: Security Manager should handle symlink on /tmp (Kevin Risden)
diff --git a/solr/core/src/java/org/apache/solr/core/SolrConfig.java
b/solr/core/src/java/org/apache/solr/core/SolrConfig.java
index c9482838d4a..4be0350efd3 100644
--- a/solr/core/src/java/org/apache/solr/core/SolrConfig.java
+++ b/solr/core/src/java/org/apache/solr/core/SolrConfig.java
@@ -301,6 +301,9 @@ public class SolrConfig implements MapSerializable {
queryResultCacheConfig =
CacheConfig.getConfig(
this, get("query").get("queryResultCache"),
"query/queryResultCache");
+ featureVectorCacheConfig =
+ CacheConfig.getConfig(
+ this, get("query").get("featureVectorCache"),
"query/featureVectorCache");
documentCacheConfig =
CacheConfig.getConfig(this, get("query").get("documentCache"),
"query/documentCache");
CacheConfig conf =
@@ -662,6 +665,7 @@ public class SolrConfig implements MapSerializable {
public final CacheConfig queryResultCacheConfig;
public final CacheConfig documentCacheConfig;
public final CacheConfig fieldValueCacheConfig;
+ public final CacheConfig featureVectorCacheConfig;
public final Map<String, CacheConfig> userCacheConfigs;
// SolrIndexSearcher - more...
public final boolean useFilterForSortedQuery;
@@ -998,7 +1002,12 @@ public class SolrConfig implements MapSerializable {
}
addCacheConfig(
- m, filterCacheConfig, queryResultCacheConfig, documentCacheConfig,
fieldValueCacheConfig);
+ m,
+ filterCacheConfig,
+ queryResultCacheConfig,
+ documentCacheConfig,
+ fieldValueCacheConfig,
+ featureVectorCacheConfig);
m = new LinkedHashMap<>();
result.put("requestDispatcher", m);
if (httpCachingConfig != null) m.put("httpCaching", httpCachingConfig);
diff --git a/solr/core/src/java/org/apache/solr/search/SolrIndexSearcher.java
b/solr/core/src/java/org/apache/solr/search/SolrIndexSearcher.java
index 590cca8916e..15cf7ecc431 100644
--- a/solr/core/src/java/org/apache/solr/search/SolrIndexSearcher.java
+++ b/solr/core/src/java/org/apache/solr/search/SolrIndexSearcher.java
@@ -163,6 +163,7 @@ public class SolrIndexSearcher extends IndexSearcher
implements Closeable, SolrI
private final SolrCache<Query, DocSet> filterCache;
private final SolrCache<QueryResultKey, DocList> queryResultCache;
private final SolrCache<String, UnInvertedField> fieldValueCache;
+ private final SolrCache<Integer, float[]> featureVectorCache;
private final LongAdder fullSortCount = new LongAdder();
private final LongAdder skipSortCount = new LongAdder();
private final LongAdder liveDocsNaiveCacheHitCount = new LongAdder();
@@ -448,6 +449,11 @@ public class SolrIndexSearcher extends IndexSearcher
implements Closeable, SolrI
? null
: solrConfig.queryResultCacheConfig.newInstance();
if (queryResultCache != null) clist.add(queryResultCache);
+ featureVectorCache =
+ solrConfig.featureVectorCacheConfig == null
+ ? null
+ : solrConfig.featureVectorCacheConfig.newInstance();
+ if (featureVectorCache != null) clist.add(featureVectorCache);
SolrCache<Integer, Document> documentCache =
docFetcher.getDocumentCache();
if (documentCache != null) clist.add(documentCache);
@@ -469,6 +475,7 @@ public class SolrIndexSearcher extends IndexSearcher
implements Closeable, SolrI
this.filterCache = null;
this.queryResultCache = null;
this.fieldValueCache = null;
+ this.featureVectorCache = null;
this.cacheMap = NO_GENERIC_CACHES;
this.cacheList = NO_CACHES;
}
@@ -689,6 +696,10 @@ public class SolrIndexSearcher extends IndexSearcher
implements Closeable, SolrI
return filterCache;
}
+ public SolrCache<Integer, float[]> getFeatureVectorCache() {
+ return featureVectorCache;
+ }
+
//
// Set default regenerators on filter and query caches if they don't have any
//
diff --git a/solr/modules/ltr/build.gradle b/solr/modules/ltr/build.gradle
index 61e02bb645d..20582b1d06d 100644
--- a/solr/modules/ltr/build.gradle
+++ b/solr/modules/ltr/build.gradle
@@ -55,6 +55,8 @@ dependencies {
testImplementation libs.junit.junit
testImplementation libs.hamcrest.hamcrest
+ testImplementation libs.prometheus.metrics.model
+
testImplementation libs.commonsio.commonsio
}
diff --git a/solr/modules/ltr/gradle.lockfile b/solr/modules/ltr/gradle.lockfile
index de00f483d4d..d4364c51501 100644
--- a/solr/modules/ltr/gradle.lockfile
+++ b/solr/modules/ltr/gradle.lockfile
@@ -65,7 +65,7 @@
io.opentelemetry:opentelemetry-sdk-metrics:1.53.0=jarValidation,runtimeClasspath
io.opentelemetry:opentelemetry-sdk-trace:1.53.0=jarValidation,runtimeClasspath,runtimeLibs,solrPlatformLibs,testRuntimeClasspath
io.opentelemetry:opentelemetry-sdk:1.53.0=jarValidation,runtimeClasspath,runtimeLibs,solrPlatformLibs,testRuntimeClasspath
io.prometheus:prometheus-metrics-exposition-formats:1.1.0=jarValidation,runtimeClasspath,runtimeLibs,solrPlatformLibs,testRuntimeClasspath
-io.prometheus:prometheus-metrics-model:1.1.0=jarValidation,runtimeClasspath,runtimeLibs,solrPlatformLibs,testRuntimeClasspath
+io.prometheus:prometheus-metrics-model:1.1.0=jarValidation,runtimeClasspath,runtimeLibs,solrPlatformLibs,testCompileClasspath,testRuntimeClasspath
io.sgr:s2-geometry-library-java:1.0.0=jarValidation,runtimeClasspath,runtimeLibs,solrPlatformLibs,testRuntimeClasspath
io.swagger.core.v3:swagger-annotations-jakarta:2.2.22=compileClasspath,jarValidation,runtimeClasspath,runtimeLibs,solrPlatformLibs,testCompileClasspath,testRuntimeClasspath
jakarta.annotation:jakarta.annotation-api:2.1.1=jarValidation,runtimeClasspath,runtimeLibs,solrPlatformLibs,testRuntimeClasspath
diff --git
a/solr/modules/ltr/src/java/org/apache/solr/ltr/CSVFeatureLogger.java
b/solr/modules/ltr/src/java/org/apache/solr/ltr/CSVFeatureLogger.java
index 22ddcb8724a..57a86a10e1c 100644
--- a/solr/modules/ltr/src/java/org/apache/solr/ltr/CSVFeatureLogger.java
+++ b/solr/modules/ltr/src/java/org/apache/solr/ltr/CSVFeatureLogger.java
@@ -23,21 +23,20 @@ public class CSVFeatureLogger extends FeatureLogger {
private final char keyValueSep;
private final char featureSep;
- public CSVFeatureLogger(String fvCacheName, FeatureFormat f, Boolean logAll)
{
- super(fvCacheName, f, logAll);
+ public CSVFeatureLogger(FeatureFormat f, Boolean logAll) {
+ super(f, logAll);
this.keyValueSep = DEFAULT_KEY_VALUE_SEPARATOR;
this.featureSep = DEFAULT_FEATURE_SEPARATOR;
}
- public CSVFeatureLogger(
- String fvCacheName, FeatureFormat f, Boolean logAll, char keyValueSep,
char featureSep) {
- super(fvCacheName, f, logAll);
+ public CSVFeatureLogger(FeatureFormat f, Boolean logAll, char keyValueSep,
char featureSep) {
+ super(f, logAll);
this.keyValueSep = keyValueSep;
this.featureSep = featureSep;
}
@Override
- public String makeFeatureVector(LTRScoringQuery.FeatureInfo[] featuresInfo) {
+ public String printFeatureVector(LTRScoringQuery.FeatureInfo[] featuresInfo)
{
// Allocate the buffer to a size based on the number of features instead
of the
// default 16. You need space for the name, value, and two separators per
feature,
// but not all the features are expected to fire, so this is just a naive
estimate.
diff --git a/solr/modules/ltr/src/java/org/apache/solr/ltr/DocInfo.java
b/solr/modules/ltr/src/java/org/apache/solr/ltr/DocInfo.java
index e454d90acc2..ee82bb41df7 100644
--- a/solr/modules/ltr/src/java/org/apache/solr/ltr/DocInfo.java
+++ b/solr/modules/ltr/src/java/org/apache/solr/ltr/DocInfo.java
@@ -22,6 +22,8 @@ public class DocInfo extends HashMap<String, Object> {
// Name of key used to store the original score of a doc
private static final String ORIGINAL_DOC_SCORE = "ORIGINAL_DOC_SCORE";
+ // Name of key used to store the original id of a doc
+ private static final String ORIGINAL_DOC_ID = "ORIGINAL_DOC_ID";
public DocInfo() {
super();
@@ -38,4 +40,12 @@ public class DocInfo extends HashMap<String, Object> {
public boolean hasOriginalDocScore() {
return containsKey(ORIGINAL_DOC_SCORE);
}
+
+ public void setOriginalDocId(int docId) {
+ put(ORIGINAL_DOC_ID, docId);
+ }
+
+ public int getOriginalDocId() {
+ return (int) get(ORIGINAL_DOC_ID);
+ }
}
diff --git a/solr/modules/ltr/src/java/org/apache/solr/ltr/FeatureLogger.java
b/solr/modules/ltr/src/java/org/apache/solr/ltr/FeatureLogger.java
index 9be531c1ef3..54d308b665e 100644
--- a/solr/modules/ltr/src/java/org/apache/solr/ltr/FeatureLogger.java
+++ b/solr/modules/ltr/src/java/org/apache/solr/ltr/FeatureLogger.java
@@ -16,16 +16,10 @@
*/
package org.apache.solr.ltr;
-import org.apache.solr.search.SolrIndexSearcher;
-
/**
* FeatureLogger can be registered in a model and provide a strategy for
logging the feature values.
*/
public abstract class FeatureLogger {
-
- /** the name of the cache using for storing the feature value */
- private final String fvCacheName;
-
public enum FeatureFormat {
DENSE,
SPARSE
@@ -35,54 +29,15 @@ public abstract class FeatureLogger {
protected Boolean logAll;
- protected FeatureLogger(String fvCacheName, FeatureFormat f, Boolean logAll)
{
- this.fvCacheName = fvCacheName;
+ protected boolean logFeatures;
+
+ protected FeatureLogger(FeatureFormat f, Boolean logAll) {
this.featureFormat = f;
this.logAll = logAll;
+ this.logFeatures = false;
}
- /**
- * Log will be called every time that the model generates the feature values
for a document and a
- * query.
- *
- * @param docid Solr document id whose features we are saving
- * @param featuresInfo List of all the {@link LTRScoringQuery.FeatureInfo}
objects which contain
- * name and value for all the features triggered by the result set
- * @return true if the logger successfully logged the features, false
otherwise.
- */
- public boolean log(
- int docid,
- LTRScoringQuery scoringQuery,
- SolrIndexSearcher searcher,
- LTRScoringQuery.FeatureInfo[] featuresInfo) {
- final String featureVector = makeFeatureVector(featuresInfo);
- if (featureVector == null) {
- return false;
- }
-
- if (null == searcher.cacheInsert(fvCacheName, fvCacheKey(scoringQuery,
docid), featureVector)) {
- return false;
- }
-
- return true;
- }
-
- public abstract String makeFeatureVector(LTRScoringQuery.FeatureInfo[]
featuresInfo);
-
- private static int fvCacheKey(LTRScoringQuery scoringQuery, int docid) {
- return scoringQuery.hashCode() + (31 * docid);
- }
-
- /**
- * populate the document with its feature vector
- *
- * @param docid Solr document id
- * @return String representation of the list of features calculated for docid
- */
- public String getFeatureVector(
- int docid, LTRScoringQuery scoringQuery, SolrIndexSearcher searcher) {
- return (String) searcher.cacheLookup(fvCacheName, fvCacheKey(scoringQuery,
docid));
- }
+ public abstract String printFeatureVector(LTRScoringQuery.FeatureInfo[]
featuresInfo);
public Boolean isLoggingAll() {
return logAll;
@@ -91,4 +46,12 @@ public abstract class FeatureLogger {
public void setLogAll(Boolean logAll) {
this.logAll = logAll;
}
+
+ public void setLogFeatures(boolean logFeatures) {
+ this.logFeatures = logFeatures;
+ }
+
+ public boolean isLogFeatures() {
+ return logFeatures;
+ }
}
diff --git a/solr/modules/ltr/src/java/org/apache/solr/ltr/LTRRescorer.java
b/solr/modules/ltr/src/java/org/apache/solr/ltr/LTRRescorer.java
index a21c107438c..0cd0258eb52 100644
--- a/solr/modules/ltr/src/java/org/apache/solr/ltr/LTRRescorer.java
+++ b/solr/modules/ltr/src/java/org/apache/solr/ltr/LTRRescorer.java
@@ -33,7 +33,6 @@ import org.apache.lucene.search.Weight;
import org.apache.solr.ltr.interleaving.OriginalRankingLTRScoringQuery;
import org.apache.solr.search.IncompleteRerankingException;
import org.apache.solr.search.QueryLimits;
-import org.apache.solr.search.SolrIndexSearcher;
/**
* Implements the rescoring logic. The top documents returned by solr with
their original scores,
@@ -114,31 +113,31 @@ public class LTRRescorer extends Rescorer {
*
* @param searcher current IndexSearcher
* @param firstPassTopDocs documents to rerank;
- * @param topN documents to return;
+ * @param docsToRerank number of documents to rerank and return;
*/
@Override
- public TopDocs rescore(IndexSearcher searcher, TopDocs firstPassTopDocs, int
topN)
+ public TopDocs rescore(IndexSearcher searcher, TopDocs firstPassTopDocs, int
docsToRerank)
throws IOException {
- if ((topN == 0) || (firstPassTopDocs.scoreDocs.length == 0)) {
+ if ((docsToRerank == 0) || (firstPassTopDocs.scoreDocs.length == 0)) {
return firstPassTopDocs;
}
final ScoreDoc[] firstPassResults =
getFirstPassDocsRanked(firstPassTopDocs);
- topN = Math.toIntExact(Math.min(topN, firstPassTopDocs.totalHits.value()));
+ docsToRerank = Math.toIntExact(Math.min(docsToRerank,
firstPassTopDocs.totalHits.value()));
- final ScoreDoc[] reranked = rerank(searcher, topN, firstPassResults);
+ final ScoreDoc[] reranked = rerank(searcher, docsToRerank,
firstPassResults);
return new TopDocs(firstPassTopDocs.totalHits, reranked);
}
- private ScoreDoc[] rerank(IndexSearcher searcher, int topN, ScoreDoc[]
firstPassResults)
+ private ScoreDoc[] rerank(IndexSearcher searcher, int docsToRerank,
ScoreDoc[] firstPassResults)
throws IOException {
- final ScoreDoc[] reranked = new ScoreDoc[topN];
+ final ScoreDoc[] reranked = new ScoreDoc[docsToRerank];
final List<LeafReaderContext> leaves = searcher.getIndexReader().leaves();
final LTRScoringQuery.ModelWeight modelWeight =
(LTRScoringQuery.ModelWeight)
searcher.createWeight(searcher.rewrite(scoringQuery),
ScoreMode.COMPLETE, 1);
- scoreFeatures(searcher, topN, modelWeight, firstPassResults, leaves,
reranked);
+ scoreFeatures(docsToRerank, modelWeight, firstPassResults, leaves,
reranked);
// Must sort all documents that we reranked, and then select the top
Arrays.sort(reranked, scoreComparator);
return reranked;
@@ -153,8 +152,7 @@ public class LTRRescorer extends Rescorer {
}
public void scoreFeatures(
- IndexSearcher indexSearcher,
- int topN,
+ int docsToRerank,
LTRScoringQuery.ModelWeight modelWeight,
ScoreDoc[] hits,
List<LeafReaderContext> leaves,
@@ -166,13 +164,12 @@ public class LTRRescorer extends Rescorer {
int docBase = 0;
LTRScoringQuery.ModelWeight.ModelScorer scorer = null;
- int hitUpto = 0;
+ int hitPosition = 0;
- while (hitUpto < hits.length) {
- final ScoreDoc hit = hits[hitUpto];
- final int docID = hit.doc;
+ while (hitPosition < hits.length) {
+ final ScoreDoc hit = hits[hitPosition];
LeafReaderContext readerContext = null;
- while (docID >= endDoc) {
+ while (hit.doc >= endDoc) {
readerUpto++;
readerContext = leaves.get(readerUpto);
endDoc = readerContext.docBase + readerContext.reader().maxDoc();
@@ -182,41 +179,17 @@ public class LTRRescorer extends Rescorer {
docBase = readerContext.docBase;
scorer = modelWeight.modelScorer(readerContext);
}
- if (scoreSingleHit(topN, docBase, hitUpto, hit, docID, scorer,
reranked)) {
- logSingleHit(indexSearcher, modelWeight, hit.doc, scoringQuery);
- }
- hitUpto++;
+ scoreSingleHit(docsToRerank, docBase, hitPosition, hit, scorer,
reranked);
+ hitPosition++;
}
}
- /**
- * Call this method if the {@link #scoreSingleHit(int, int, int, ScoreDoc,
int,
- * org.apache.solr.ltr.LTRScoringQuery.ModelWeight.ModelScorer, ScoreDoc[])}
method indicated that
- * the document's feature info should be logged.
- */
- protected static void logSingleHit(
- IndexSearcher indexSearcher,
- LTRScoringQuery.ModelWeight modelWeight,
- int docid,
- LTRScoringQuery scoringQuery) {
- final FeatureLogger featureLogger = scoringQuery.getFeatureLogger();
- if (featureLogger != null && indexSearcher instanceof SolrIndexSearcher) {
- featureLogger.log(
- docid, scoringQuery, (SolrIndexSearcher) indexSearcher,
modelWeight.getFeaturesInfo());
- }
- }
-
- /**
- * Scores a single document and returns true if the document's feature info
should be logged via
- * the {@link #logSingleHit(IndexSearcher,
org.apache.solr.ltr.LTRScoringQuery.ModelWeight, int,
- * LTRScoringQuery)} method. Feature info logging is only necessary for the
topN documents.
- */
- protected static boolean scoreSingleHit(
- int topN,
+ /** Scores a single document. */
+ protected void scoreSingleHit(
+ int docsToRerank,
int docBase,
- int hitUpto,
+ int hitPosition,
ScoreDoc hit,
- int docID,
LTRScoringQuery.ModelWeight.ModelScorer scorer,
ScoreDoc[] reranked)
throws IOException {
@@ -228,13 +201,12 @@ public class LTRRescorer extends Rescorer {
* needs to compute a potentially non-zero score from blank features.
*/
assert (scorer != null);
- final int targetDoc = docID - docBase;
+ final int targetDoc = hit.doc - docBase;
scorer.docID();
scorer.iterator().advance(targetDoc);
- boolean logHit = false;
-
scorer.getDocInfo().setOriginalDocScore(hit.score);
+ scorer.getDocInfo().setOriginalDocId(hit.doc);
hit.score = scorer.score();
if (QueryLimits.getCurrentLimits()
.maybeExitWithPartialResults(
@@ -243,28 +215,21 @@ public class LTRRescorer extends Rescorer {
+ " If partial results are tolerated the reranking got
reverted and all documents preserved their original score and ranking.")) {
throw new IncompleteRerankingException();
}
- if (hitUpto < topN) {
- reranked[hitUpto] = hit;
- // if the heap is not full, maybe I want to log the features for this
- // document
- logHit = true;
- } else if (hitUpto == topN) {
+ if (hitPosition < docsToRerank) {
+ reranked[hitPosition] = hit;
+ } else if (hitPosition == docsToRerank) {
// collected docsToRerank documents, create the heap
- heapify(reranked, topN);
+ heapify(reranked, docsToRerank);
}
- if (hitUpto >= topN) {
- // once that heap is ready, if the score of this document is lower that
- // the minimum
- // i don't want to log the feature. Otherwise I replace it with the
- // minimum and fix the
- // heap.
+ if (hitPosition >= docsToRerank) {
+ // once the heap is ready, if the score of this document is greater than
+ // the minimum, I replace it with the
+ // minimum and fix the heap.
if (hit.score > reranked[0].score) {
reranked[0] = hit;
- heapAdjust(reranked, topN, 0);
- logHit = true;
+ heapAdjust(reranked, docsToRerank, 0);
}
}
- return logHit;
}
@Override
@@ -289,27 +254,4 @@ public class LTRRescorer extends Rescorer {
}
return rankingWeight.explain(context, deBasedDoc);
}
-
- public static LTRScoringQuery.FeatureInfo[] extractFeaturesInfo(
- LTRScoringQuery.ModelWeight modelWeight,
- int docid,
- Float originalDocScore,
- List<LeafReaderContext> leafContexts)
- throws IOException {
- final int n = ReaderUtil.subIndex(docid, leafContexts);
- final LeafReaderContext atomicContext = leafContexts.get(n);
- final int deBasedDoc = docid - atomicContext.docBase;
- final LTRScoringQuery.ModelWeight.ModelScorer r =
modelWeight.modelScorer(atomicContext);
- if ((r == null) || (r.iterator().advance(deBasedDoc) != deBasedDoc)) {
- return new LTRScoringQuery.FeatureInfo[0];
- } else {
- if (originalDocScore != null) {
- // If results have not been reranked, the score passed in is the
original query's
- // score, which some features can use instead of recalculating it
- r.getDocInfo().setOriginalDocScore(originalDocScore);
- }
- r.score();
- return modelWeight.getFeaturesInfo();
- }
- }
}
diff --git a/solr/modules/ltr/src/java/org/apache/solr/ltr/LTRScoringQuery.java
b/solr/modules/ltr/src/java/org/apache/solr/ltr/LTRScoringQuery.java
index 85b33fc3ebd..22a9a8e5dfb 100644
--- a/solr/modules/ltr/src/java/org/apache/solr/ltr/LTRScoringQuery.java
+++ b/solr/modules/ltr/src/java/org/apache/solr/ltr/LTRScoringQuery.java
@@ -30,8 +30,6 @@ import java.util.concurrent.FutureTask;
import java.util.concurrent.RunnableFuture;
import java.util.concurrent.Semaphore;
import org.apache.lucene.index.LeafReaderContext;
-import org.apache.lucene.search.DisiPriorityQueue;
-import org.apache.lucene.search.DisiWrapper;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.Explanation;
import org.apache.lucene.search.IndexSearcher;
@@ -45,6 +43,9 @@ import org.apache.lucene.util.Accountable;
import org.apache.lucene.util.RamUsageEstimator;
import org.apache.solr.ltr.feature.Feature;
import org.apache.solr.ltr.model.LTRScoringModel;
+import org.apache.solr.ltr.scoring.FeatureTraversalScorer;
+import org.apache.solr.ltr.scoring.MultiFeaturesScorer;
+import org.apache.solr.ltr.scoring.SingleFeatureScorer;
import org.apache.solr.request.SolrQueryRequest;
import org.apache.solr.util.SolrDefaultScorerSupplier;
import org.slf4j.Logger;
@@ -70,12 +71,14 @@ public class LTRScoringQuery extends Query implements
Accountable {
private FeatureLogger logger;
// Map of external parameters, such as query intent, that can be used by
// features
- private final Map<String, String[]> efi;
+ private Map<String, String[]> efi;
// Original solr query used to fetch matching documents
private Query originalQuery;
// Original solr request
private SolrQueryRequest request;
+ private Feature.FeatureWeight[] extractedFeatureWeights;
+
public LTRScoringQuery(LTRScoringModel ltrScoringModel) {
this(ltrScoringModel, Collections.<String, String[]>emptyMap(), null);
}
@@ -122,6 +125,10 @@ public class LTRScoringQuery extends Query implements
Accountable {
return efi;
}
+ public void setExternalFeatureInfo(Map<String, String[]> efi) {
+ this.efi = efi;
+ }
+
public void setRequest(SolrQueryRequest request) {
this.request = request;
}
@@ -130,6 +137,10 @@ public class LTRScoringQuery extends Query implements
Accountable {
return request;
}
+ public Feature.FeatureWeight[] getExtractedFeatureWeights() {
+ return extractedFeatureWeights;
+ }
+
@Override
public int hashCode() {
final int prime = 31;
@@ -207,8 +218,7 @@ public class LTRScoringQuery extends Query implements
Accountable {
} else {
features = modelFeatures;
}
- final Feature.FeatureWeight[] extractedFeatureWeights =
- new Feature.FeatureWeight[features.size()];
+ this.extractedFeatureWeights = new Feature.FeatureWeight[features.size()];
final Feature.FeatureWeight[] modelFeaturesWeights = new
Feature.FeatureWeight[modelFeatSize];
List<Feature.FeatureWeight> featureWeights = new
ArrayList<>(features.size());
@@ -232,7 +242,7 @@ public class LTRScoringQuery extends Query implements
Accountable {
modelFeaturesWeights[j++] = fw;
}
}
- return new ModelWeight(modelFeaturesWeights, extractedFeatureWeights,
allFeatures.size());
+ return new ModelWeight(modelFeaturesWeights, allFeatures.size());
}
private void createWeights(
@@ -367,20 +377,9 @@ public class LTRScoringQuery extends Query implements
Accountable {
// features used for logging.
private final Feature.FeatureWeight[] modelFeatureWeights;
private final float[] modelFeatureValuesNormalized;
- private final Feature.FeatureWeight[] extractedFeatureWeights;
- // List of all the feature names, values - used for both scoring and
logging
- /*
- * What is the advantage of using a hashmap here instead of an array of
objects?
- * A set of arrays was used earlier and the elements were accessed
using the featureId.
- * With the updated logic to create weights selectively,
- * the number of elements in the array can be fewer than the total
number of features.
- * When [features] are not requested, only the model features are
extracted.
- * In this case, the indexing by featureId, fails. For this reason,
- * we need a map which holds just the features that were triggered by
the documents in the result set.
- *
- */
- private final FeatureInfo[] featuresInfo;
+ // Array of all the features in the feature store of reference
+ private final FeatureInfo[] allFeaturesInStore;
/*
* @param modelFeatureWeights
@@ -392,29 +391,25 @@ public class LTRScoringQuery extends Query implements
Accountable {
* @param allFeaturesSize
* - total number of feature in the feature store used by this model
*/
- public ModelWeight(
- Feature.FeatureWeight[] modelFeatureWeights,
- Feature.FeatureWeight[] extractedFeatureWeights,
- int allFeaturesSize) {
+ public ModelWeight(Feature.FeatureWeight[] modelFeatureWeights, int
allFeaturesSize) {
super(LTRScoringQuery.this);
- this.extractedFeatureWeights = extractedFeatureWeights;
this.modelFeatureWeights = modelFeatureWeights;
this.modelFeatureValuesNormalized = new
float[modelFeatureWeights.length];
- this.featuresInfo = new FeatureInfo[allFeaturesSize];
- setFeaturesInfo();
+ this.allFeaturesInStore = new FeatureInfo[allFeaturesSize];
+ setFeaturesInfo(extractedFeatureWeights);
}
- private void setFeaturesInfo() {
+ private void setFeaturesInfo(Feature.FeatureWeight[]
extractedFeatureWeights) {
for (int i = 0; i < extractedFeatureWeights.length; ++i) {
String featName = extractedFeatureWeights[i].getName();
int featId = extractedFeatureWeights[i].getIndex();
float value = extractedFeatureWeights[i].getDefaultValue();
- featuresInfo[featId] = new FeatureInfo(featName, value, true);
+ allFeaturesInStore[featId] = new FeatureInfo(featName, value, true);
}
}
- public FeatureInfo[] getFeaturesInfo() {
- return featuresInfo;
+ public FeatureInfo[] getAllFeaturesInStore() {
+ return allFeaturesInStore;
}
// for test use
@@ -423,35 +418,29 @@ public class LTRScoringQuery extends Query implements
Accountable {
}
// for test use
- float[] getModelFeatureValuesNormalized() {
+ public float[] getModelFeatureValuesNormalized() {
return modelFeatureValuesNormalized;
}
- // for test use
- Feature.FeatureWeight[] getExtractedFeatureWeights() {
- return extractedFeatureWeights;
- }
-
/**
* Goes through all the stored feature values, and calculates the
normalized values for all the
* features that will be used for scoring. Then calculate and return the
model's score.
*/
- private float makeNormalizedFeaturesAndScore() {
+ public void normalizeFeatures() {
int pos = 0;
for (final Feature.FeatureWeight feature : modelFeatureWeights) {
final int featureId = feature.getIndex();
- FeatureInfo fInfo = featuresInfo[featureId];
+ FeatureInfo fInfo = allFeaturesInStore[featureId];
modelFeatureValuesNormalized[pos] = fInfo.getValue();
pos++;
}
ltrScoringModel.normalizeFeaturesInPlace(modelFeatureValuesNormalized);
- return ltrScoringModel.score(modelFeatureValuesNormalized);
}
@Override
public Explanation explain(LeafReaderContext context, int doc) throws
IOException {
- final Explanation[] explanations = new
Explanation[this.featuresInfo.length];
+ final Explanation[] explanations = new
Explanation[this.allFeaturesInStore.length];
for (final Feature.FeatureWeight feature : extractedFeatureWeights) {
explanations[feature.getIndex()] = feature.explain(context, doc);
}
@@ -469,17 +458,6 @@ public class LTRScoringQuery extends Query implements
Accountable {
return ltrScoringModel.explain(context, doc, finalScore,
featureExplanations);
}
- protected void reset() {
- for (int i = 0; i < extractedFeatureWeights.length; ++i) {
- int featId = extractedFeatureWeights[i].getIndex();
- float value = extractedFeatureWeights[i].getDefaultValue();
- // need to set default value everytime as the default value is used in
'dense'
- // mode even if used=false
- featuresInfo[featId].setValue(value);
- featuresInfo[featId].setIsDefaultValue(true);
- }
- }
-
@Override
public ScorerSupplier scorerSupplier(LeafReaderContext context) throws
IOException {
return new SolrDefaultScorerSupplier(modelScorer(context));
@@ -500,7 +478,7 @@ public class LTRScoringQuery extends Query implements
Accountable {
// score on the model for every document, since 0 features matching could
// return a
// non 0 score for a given model.
- ModelScorer mscorer = new ModelScorer(this, featureScorers);
+ ModelScorer mscorer = new ModelScorer(featureScorers);
return mscorer;
}
@@ -511,22 +489,40 @@ public class LTRScoringQuery extends Query implements
Accountable {
public class ModelScorer extends Scorer {
private final DocInfo docInfo;
- private final Scorer featureTraversalScorer;
+ private final FeatureTraversalScorer featureTraversalScorer;
public DocInfo getDocInfo() {
return docInfo;
}
- public ModelScorer(Weight weight,
List<Feature.FeatureWeight.FeatureScorer> featureScorers) {
+ public ModelScorer(List<Feature.FeatureWeight.FeatureScorer>
featureScorers) {
docInfo = new DocInfo();
for (final Feature.FeatureWeight.FeatureScorer subScorer :
featureScorers) {
subScorer.setDocInfo(docInfo);
}
if (featureScorers.size() <= 1) {
// future enhancement: allow the use of dense features in other cases
- featureTraversalScorer = new DenseModelScorer(weight,
featureScorers);
+ featureTraversalScorer =
+ new SingleFeatureScorer(
+ ModelWeight.this,
+ request,
+ extractedFeatureWeights,
+ allFeaturesInStore,
+ ltrScoringModel,
+ efi,
+ featureScorers,
+ docInfo);
} else {
- featureTraversalScorer = new SparseModelScorer(weight,
featureScorers);
+ featureTraversalScorer =
+ new MultiFeaturesScorer(
+ ModelWeight.this,
+ request,
+ extractedFeatureWeights,
+ allFeaturesInStore,
+ ltrScoringModel,
+ efi,
+ featureScorers,
+ docInfo);
}
}
@@ -555,273 +551,8 @@ public class LTRScoringQuery extends Query implements
Accountable {
return featureTraversalScorer.iterator();
}
- private class SparseModelScorer extends Scorer {
- private final DisiPriorityQueue subScorers;
- private final List<DisiWrapper> wrappers;
- private final ScoringQuerySparseIterator itr;
-
- private int targetDoc = -1;
- private int activeDoc = -1;
-
- private SparseModelScorer(
- Weight unusedWeight, List<Feature.FeatureWeight.FeatureScorer>
featureScorers) {
- if (featureScorers.size() <= 1) {
- throw new IllegalArgumentException("There must be at least 2
subScorers");
- }
- subScorers = DisiPriorityQueue.ofMaxSize(featureScorers.size());
- wrappers = new ArrayList<>();
- for (final Scorer scorer : featureScorers) {
- final DisiWrapper w = new DisiWrapper(scorer, false /* impacts */);
- subScorers.add(w);
- wrappers.add(w);
- }
-
- itr = new ScoringQuerySparseIterator(wrappers);
- }
-
- @Override
- public int docID() {
- return itr.docID();
- }
-
- @Override
- public float score() throws IOException {
- final DisiWrapper topList = subScorers.topList();
- // If target doc we wanted to advance to match the actual doc
- // the underlying features advanced to, perform the feature
- // calculations,
- // otherwise just continue with the model's scoring process with
empty
- // features.
- reset();
- if (activeDoc == targetDoc) {
- for (DisiWrapper w = topList; w != null; w = w.next) {
- final Feature.FeatureWeight.FeatureScorer subScorer =
- (Feature.FeatureWeight.FeatureScorer) w.scorer;
- Feature.FeatureWeight scFW = subScorer.getWeight();
- final int featureId = scFW.getIndex();
- featuresInfo[featureId].setValue(subScorer.score());
- if (featuresInfo[featureId].getValue() !=
scFW.getDefaultValue()) {
- featuresInfo[featureId].setIsDefaultValue(false);
- }
- }
- }
- return makeNormalizedFeaturesAndScore();
- }
-
- @Override
- public float getMaxScore(int upTo) throws IOException {
- return Float.POSITIVE_INFINITY;
- }
-
- @Override
- public DocIdSetIterator iterator() {
- return itr;
- }
-
- @Override
- public final Collection<ChildScorable> getChildren() {
- final ArrayList<ChildScorable> children = new ArrayList<>();
- for (final DisiWrapper scorer : subScorers) {
- children.add(new ChildScorable(scorer.scorer, "SHOULD"));
- }
- return children;
- }
-
- private class ScoringQuerySparseIterator extends DocIdSetIterator {
-
- public ScoringQuerySparseIterator(Collection<DisiWrapper> wrappers) {
- // Initialize all wrappers to start at -1
- for (DisiWrapper wrapper : wrappers) {
- wrapper.doc = -1;
- }
- }
-
- @Override
- public int docID() {
- // Return the target document ID (mimicking
DisjunctionDISIApproximation behavior)
- return targetDoc;
- }
-
- @Override
- public final int nextDoc() throws IOException {
- // Mimic DisjunctionDISIApproximation behavior
- if (targetDoc == -1) {
- // First call - initialize all iterators
- DisiWrapper top = subScorers.top();
- if (top != null && top.doc == -1) {
- // Need to advance all iterators to their first document
- DisiWrapper current = subScorers.top();
- while (current != null) {
- current.doc = current.iterator.nextDoc();
- current = subScorers.updateTop();
- }
- top = subScorers.top();
- activeDoc = top == null ? NO_MORE_DOCS : top.doc;
- }
- targetDoc = activeDoc;
- return targetDoc;
- }
-
- if (activeDoc == targetDoc) {
- // Advance the underlying disjunction
- DisiWrapper top = subScorers.top();
- if (top == null) {
- activeDoc = NO_MORE_DOCS;
- } else {
- // Advance the top iterator and rebalance the queue
- top.doc = top.iterator.nextDoc();
- top = subScorers.updateTop();
- activeDoc = top == null ? NO_MORE_DOCS : top.doc;
- }
- } else if (activeDoc < targetDoc) {
- // Need to catch up to targetDoc + 1
- activeDoc = advanceInternal(targetDoc + 1);
- }
- return ++targetDoc;
- }
-
- @Override
- public final int advance(int target) throws IOException {
- // Mimic DisjunctionDISIApproximation behavior
- if (activeDoc < target) {
- activeDoc = advanceInternal(target);
- }
- targetDoc = target;
- return targetDoc;
- }
-
- private int advanceInternal(int target) throws IOException {
- // Advance the underlying disjunction to the target
- DisiWrapper top;
- do {
- top = subScorers.top();
- if (top == null) {
- return NO_MORE_DOCS;
- }
- if (top.doc >= target) {
- return top.doc;
- }
- top.doc = top.iterator.advance(target);
- top = subScorers.updateTop();
- if (top == null) {
- return NO_MORE_DOCS;
- }
- } while (top.doc < target);
- return top.doc;
- }
-
- @Override
- public long cost() {
- // Calculate cost from all wrappers
- long cost = 0;
- for (DisiWrapper wrapper : wrappers) {
- cost += wrapper.iterator.cost();
- }
- return cost;
- }
- }
- }
-
- private class DenseModelScorer extends Scorer {
- private int activeDoc = -1; // The doc that our scorer's are actually
at
- private int targetDoc = -1; // The doc we were most recently told to
go to
- private int freq = -1;
- private final List<Feature.FeatureWeight.FeatureScorer> featureScorers;
-
- private DenseModelScorer(
- Weight unusedWeight, List<Feature.FeatureWeight.FeatureScorer>
featureScorers) {
- this.featureScorers = featureScorers;
- }
-
- @Override
- public int docID() {
- return targetDoc;
- }
-
- @Override
- public float score() throws IOException {
- reset();
- freq = 0;
- if (targetDoc == activeDoc) {
- for (final Scorer scorer : featureScorers) {
- if (scorer.docID() == activeDoc) {
- freq++;
- Feature.FeatureWeight.FeatureScorer featureScorer =
- (Feature.FeatureWeight.FeatureScorer) scorer;
- Feature.FeatureWeight scFW = featureScorer.getWeight();
- final int featureId = scFW.getIndex();
- featuresInfo[featureId].setValue(scorer.score());
- if (featuresInfo[featureId].getValue() !=
scFW.getDefaultValue()) {
- featuresInfo[featureId].setIsDefaultValue(false);
- }
- }
- }
- }
- return makeNormalizedFeaturesAndScore();
- }
-
- @Override
- public float getMaxScore(int upTo) throws IOException {
- return Float.POSITIVE_INFINITY;
- }
-
- @Override
- public final Collection<ChildScorable> getChildren() {
- final ArrayList<ChildScorable> children = new ArrayList<>();
- for (final Scorer scorer : featureScorers) {
- children.add(new ChildScorable(scorer, "SHOULD"));
- }
- return children;
- }
-
- @Override
- public DocIdSetIterator iterator() {
- return new DenseIterator();
- }
-
- private class DenseIterator extends DocIdSetIterator {
-
- @Override
- public int docID() {
- return targetDoc;
- }
-
- @Override
- public int nextDoc() throws IOException {
- if (activeDoc <= targetDoc) {
- activeDoc = NO_MORE_DOCS;
- for (final Scorer scorer : featureScorers) {
- if (scorer.docID() != NO_MORE_DOCS) {
- activeDoc = Math.min(activeDoc, scorer.iterator().nextDoc());
- }
- }
- }
- return ++targetDoc;
- }
-
- @Override
- public int advance(int target) throws IOException {
- if (activeDoc < target) {
- activeDoc = NO_MORE_DOCS;
- for (final Scorer scorer : featureScorers) {
- if (scorer.docID() != NO_MORE_DOCS) {
- activeDoc = Math.min(activeDoc,
scorer.iterator().advance(target));
- }
- }
- }
- targetDoc = target;
- return target;
- }
-
- @Override
- public long cost() {
- long sum = 0;
- for (int i = 0; i < featureScorers.size(); i++) {
- sum += featureScorers.get(i).iterator().cost();
- }
- return sum;
- }
- }
+ public void fillFeaturesInfo() throws IOException {
+ featureTraversalScorer.fillFeaturesInfo();
}
}
}
diff --git
a/solr/modules/ltr/src/java/org/apache/solr/ltr/feature/extraction/FeatureExtractor.java
b/solr/modules/ltr/src/java/org/apache/solr/ltr/feature/extraction/FeatureExtractor.java
new file mode 100644
index 00000000000..03911b8bbc4
--- /dev/null
+++
b/solr/modules/ltr/src/java/org/apache/solr/ltr/feature/extraction/FeatureExtractor.java
@@ -0,0 +1,129 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.solr.ltr.feature.extraction;
+
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.Map;
+import java.util.Objects;
+import java.util.TreeMap;
+import org.apache.solr.ltr.FeatureLogger;
+import org.apache.solr.ltr.LTRScoringQuery;
+import org.apache.solr.ltr.feature.Feature;
+import org.apache.solr.ltr.model.LTRScoringModel;
+import
org.apache.solr.ltr.response.transform.LTRFeatureLoggerTransformerFactory;
+import org.apache.solr.ltr.scoring.FeatureTraversalScorer;
+import org.apache.solr.request.SolrQueryRequest;
+import org.apache.solr.search.SolrCache;
+
+/** The class used to extract features for LTR feature logging. */
+public abstract class FeatureExtractor {
+ protected final FeatureTraversalScorer traversalScorer;
+ SolrQueryRequest request;
+ Feature.FeatureWeight[] extractedFeatureWeights;
+ LTRScoringQuery.FeatureInfo[] allFeaturesInStore;
+ LTRScoringModel ltrScoringModel;
+ FeatureLogger logger;
+ Map<String, String[]> efi;
+
+ FeatureExtractor(
+ FeatureTraversalScorer traversalScorer,
+ SolrQueryRequest request,
+ Feature.FeatureWeight[] extractedFeatureWeights,
+ LTRScoringQuery.FeatureInfo[] allFeaturesInStore,
+ LTRScoringModel ltrScoringModel,
+ Map<String, String[]> efi) {
+ this.traversalScorer = traversalScorer;
+ this.request = request;
+ this.extractedFeatureWeights = extractedFeatureWeights;
+ this.allFeaturesInStore = allFeaturesInStore;
+ this.ltrScoringModel = ltrScoringModel;
+ this.efi = efi;
+ }
+
+ protected float[] initFeatureVector(LTRScoringQuery.FeatureInfo[]
featuresInfos) {
+ float[] featureVector = new float[featuresInfos.length];
+ for (int i = 0; i < featuresInfos.length; i++) {
+ if (featuresInfos[i] != null) {
+ featureVector[i] = featuresInfos[i].getValue();
+ }
+ }
+ return featureVector;
+ }
+
+ protected abstract float[] extractFeatureVector() throws IOException;
+
+ public void fillFeaturesInfo() throws IOException {
+ if (traversalScorer.getActiveDoc() == traversalScorer.getTargetDoc()) {
+ SolrCache<Integer, float[]> featureVectorCache = null;
+ float[] featureVector;
+
+ if (request != null) {
+ featureVectorCache = request.getSearcher().getFeatureVectorCache();
+ }
+ if (featureVectorCache != null) {
+ int fvCacheKey =
+
computeFeatureVectorCacheKey(traversalScorer.getDocInfo().getOriginalDocId());
+ featureVector = featureVectorCache.get(fvCacheKey);
+ if (featureVector == null) {
+ featureVector = extractFeatureVector();
+ featureVectorCache.put(fvCacheKey, featureVector);
+ }
+ } else {
+ featureVector = extractFeatureVector();
+ }
+
+ for (int i = 0; i < extractedFeatureWeights.length; i++) {
+ int featureId = extractedFeatureWeights[i].getIndex();
+ float featureValue = featureVector[featureId];
+ if (!Float.isNaN(featureValue)
+ && featureValue != extractedFeatureWeights[i].getDefaultValue()) {
+ allFeaturesInStore[featureId].setValue(featureValue);
+ allFeaturesInStore[featureId].setIsDefaultValue(false);
+ }
+ }
+ }
+ }
+
+ private int computeFeatureVectorCacheKey(int docId) {
+ int prime = 31;
+ int result = docId;
+ if (Objects.equals(
+ ltrScoringModel.getName(),
+ LTRFeatureLoggerTransformerFactory.DEFAULT_LOGGING_MODEL_NAME)
+ || (logger != null && logger.isLogFeatures() &&
logger.isLoggingAll())) {
+ result = (prime * result) +
ltrScoringModel.getFeatureStoreName().hashCode();
+ } else {
+ result = (prime * result) + ltrScoringModel.getName().hashCode();
+ }
+ result = (prime * result) + addEfisHash(result, prime, efi);
+ return result;
+ }
+
+ private int addEfisHash(int result, int prime, Map<String, String[]> efi) {
+ if (efi != null) {
+ TreeMap<String, String[]> sorted = new TreeMap<>(efi);
+ for (final Map.Entry<String, String[]> entry : sorted.entrySet()) {
+ final String key = entry.getKey();
+ final String[] values = entry.getValue();
+ result = (prime * result) + key.hashCode();
+ result = (prime * result) + Arrays.hashCode(values);
+ }
+ }
+ return result;
+ }
+}
diff --git
a/solr/modules/ltr/src/java/org/apache/solr/ltr/feature/extraction/MultiFeaturesExtractor.java
b/solr/modules/ltr/src/java/org/apache/solr/ltr/feature/extraction/MultiFeaturesExtractor.java
new file mode 100644
index 00000000000..1db6d161ede
--- /dev/null
+++
b/solr/modules/ltr/src/java/org/apache/solr/ltr/feature/extraction/MultiFeaturesExtractor.java
@@ -0,0 +1,65 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.solr.ltr.feature.extraction;
+
+import java.io.IOException;
+import java.util.Map;
+import org.apache.lucene.search.DisiPriorityQueue;
+import org.apache.lucene.search.DisiWrapper;
+import org.apache.solr.ltr.LTRScoringQuery;
+import org.apache.solr.ltr.feature.Feature;
+import org.apache.solr.ltr.model.LTRScoringModel;
+import org.apache.solr.ltr.scoring.FeatureTraversalScorer;
+import org.apache.solr.request.SolrQueryRequest;
+
+/** The class used to extract more than one feature for LTR feature logging. */
+public class MultiFeaturesExtractor extends FeatureExtractor {
+ DisiPriorityQueue subScorers;
+
+ public MultiFeaturesExtractor(
+ FeatureTraversalScorer multiFeaturesScorer,
+ SolrQueryRequest request,
+ Feature.FeatureWeight[] extractedFeatureWeights,
+ LTRScoringQuery.FeatureInfo[] allFeaturesInStore,
+ LTRScoringModel ltrScoringModel,
+ Map<String, String[]> efi,
+ DisiPriorityQueue subScorers) {
+ super(
+ multiFeaturesScorer,
+ request,
+ extractedFeatureWeights,
+ allFeaturesInStore,
+ ltrScoringModel,
+ efi);
+ this.subScorers = subScorers;
+ }
+
+ @Override
+ protected float[] extractFeatureVector() throws IOException {
+ final DisiWrapper topList = subScorers.topList();
+ float[] featureVector = initFeatureVector(allFeaturesInStore);
+ for (DisiWrapper w = topList; w != null; w = w.next) {
+ final Feature.FeatureWeight.FeatureScorer subScorer =
+ (Feature.FeatureWeight.FeatureScorer) w.scorer;
+ Feature.FeatureWeight feature = subScorer.getWeight();
+ final int featureId = feature.getIndex();
+ float featureValue = subScorer.score();
+ featureVector[featureId] = featureValue;
+ }
+ return featureVector;
+ }
+}
diff --git
a/solr/modules/ltr/src/java/org/apache/solr/ltr/feature/extraction/SingleFeatureExtractor.java
b/solr/modules/ltr/src/java/org/apache/solr/ltr/feature/extraction/SingleFeatureExtractor.java
new file mode 100644
index 00000000000..5d3cf648afe
--- /dev/null
+++
b/solr/modules/ltr/src/java/org/apache/solr/ltr/feature/extraction/SingleFeatureExtractor.java
@@ -0,0 +1,66 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.solr.ltr.feature.extraction;
+
+import java.io.IOException;
+import java.util.List;
+import java.util.Map;
+import org.apache.lucene.search.Scorer;
+import org.apache.solr.ltr.LTRScoringQuery;
+import org.apache.solr.ltr.feature.Feature;
+import org.apache.solr.ltr.model.LTRScoringModel;
+import org.apache.solr.ltr.scoring.FeatureTraversalScorer;
+import org.apache.solr.request.SolrQueryRequest;
+
+/** The class used to extract a single feature for LTR feature logging. */
+public class SingleFeatureExtractor extends FeatureExtractor {
+ List<Feature.FeatureWeight.FeatureScorer> featureScorers;
+
+ public SingleFeatureExtractor(
+ FeatureTraversalScorer singleFeatureScorer,
+ SolrQueryRequest request,
+ Feature.FeatureWeight[] extractedFeatureWeights,
+ LTRScoringQuery.FeatureInfo[] allFeaturesInStore,
+ LTRScoringModel ltrScoringModel,
+ Map<String, String[]> efi,
+ List<Feature.FeatureWeight.FeatureScorer> featureScorers) {
+ super(
+ singleFeatureScorer,
+ request,
+ extractedFeatureWeights,
+ allFeaturesInStore,
+ ltrScoringModel,
+ efi);
+ this.featureScorers = featureScorers;
+ }
+
+ @Override
+ protected float[] extractFeatureVector() throws IOException {
+ float[] featureVector = initFeatureVector(allFeaturesInStore);
+ for (final Scorer scorer : featureScorers) {
+ if (scorer.docID() == traversalScorer.getActiveDoc()) {
+ Feature.FeatureWeight.FeatureScorer featureScorer =
+ (Feature.FeatureWeight.FeatureScorer) scorer;
+ Feature.FeatureWeight scFW = featureScorer.getWeight();
+ final int featureId = scFW.getIndex();
+ float featureValue = scorer.score();
+ featureVector[featureId] = featureValue;
+ }
+ }
+ return featureVector;
+ }
+}
diff --git a/solr/modules/ltr/src/java/org/apache/solr/ltr/DocInfo.java
b/solr/modules/ltr/src/java/org/apache/solr/ltr/feature/extraction/package-info.java
similarity index 59%
copy from solr/modules/ltr/src/java/org/apache/solr/ltr/DocInfo.java
copy to
solr/modules/ltr/src/java/org/apache/solr/ltr/feature/extraction/package-info.java
index e454d90acc2..844fc0dcfe2 100644
--- a/solr/modules/ltr/src/java/org/apache/solr/ltr/DocInfo.java
+++
b/solr/modules/ltr/src/java/org/apache/solr/ltr/feature/extraction/package-info.java
@@ -14,28 +14,6 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-package org.apache.solr.ltr;
-import java.util.HashMap;
-
-public class DocInfo extends HashMap<String, Object> {
-
- // Name of key used to store the original score of a doc
- private static final String ORIGINAL_DOC_SCORE = "ORIGINAL_DOC_SCORE";
-
- public DocInfo() {
- super();
- }
-
- public void setOriginalDocScore(Float score) {
- put(ORIGINAL_DOC_SCORE, score);
- }
-
- public Float getOriginalDocScore() {
- return (Float) get(ORIGINAL_DOC_SCORE);
- }
-
- public boolean hasOriginalDocScore() {
- return containsKey(ORIGINAL_DOC_SCORE);
- }
-}
+/** Contains the logic to extract features. */
+package org.apache.solr.ltr.feature.extraction;
diff --git
a/solr/modules/ltr/src/java/org/apache/solr/ltr/interleaving/LTRInterleavingRescorer.java
b/solr/modules/ltr/src/java/org/apache/solr/ltr/interleaving/LTRInterleavingRescorer.java
index 78803afd933..8d1227056a3 100644
---
a/solr/modules/ltr/src/java/org/apache/solr/ltr/interleaving/LTRInterleavingRescorer.java
+++
b/solr/modules/ltr/src/java/org/apache/solr/ltr/interleaving/LTRInterleavingRescorer.java
@@ -57,12 +57,12 @@ public class LTRInterleavingRescorer extends LTRRescorer {
*
* @param searcher current IndexSearcher
* @param firstPassTopDocs documents to rerank;
- * @param topN documents to return;
+ * @param docsToRerank number of top documents to rerank;
*/
@Override
- public TopDocs rescore(IndexSearcher searcher, TopDocs firstPassTopDocs, int
topN)
+ public TopDocs rescore(IndexSearcher searcher, TopDocs firstPassTopDocs, int
docsToRerank)
throws IOException {
- if ((topN == 0) || (firstPassTopDocs.scoreDocs.length == 0)) {
+ if ((docsToRerank == 0) || (firstPassTopDocs.scoreDocs.length == 0)) {
return firstPassTopDocs;
}
@@ -72,10 +72,10 @@ public class LTRInterleavingRescorer extends LTRRescorer {
System.arraycopy(
firstPassTopDocs.scoreDocs, 0, firstPassResults, 0,
firstPassTopDocs.scoreDocs.length);
}
- topN = Math.toIntExact(Math.min(topN, firstPassTopDocs.totalHits.value()));
+ docsToRerank = Math.toIntExact(Math.min(docsToRerank,
firstPassTopDocs.totalHits.value()));
ScoreDoc[][] reRankedPerModel =
- rerank(searcher, topN, getFirstPassDocsRanked(firstPassTopDocs));
+ rerank(searcher, docsToRerank,
getFirstPassDocsRanked(firstPassTopDocs));
if (originalRankingIndex != null) {
reRankedPerModel[originalRankingIndex] = firstPassResults;
}
@@ -90,9 +90,9 @@ public class LTRInterleavingRescorer extends LTRRescorer {
return new TopDocs(firstPassTopDocs.totalHits, interleavedResults);
}
- private ScoreDoc[][] rerank(IndexSearcher searcher, int topN, ScoreDoc[]
firstPassResults)
+ private ScoreDoc[][] rerank(IndexSearcher searcher, int docsToRerank,
ScoreDoc[] firstPassResults)
throws IOException {
- ScoreDoc[][] reRankedPerModel = new
ScoreDoc[rerankingQueries.length][topN];
+ ScoreDoc[][] reRankedPerModel = new
ScoreDoc[rerankingQueries.length][docsToRerank];
final List<LeafReaderContext> leaves = searcher.getIndexReader().leaves();
LTRScoringQuery.ModelWeight[] modelWeights =
new LTRScoringQuery.ModelWeight[rerankingQueries.length];
@@ -103,7 +103,7 @@ public class LTRInterleavingRescorer extends LTRRescorer {
searcher.createWeight(searcher.rewrite(rerankingQueries[i]),
ScoreMode.COMPLETE, 1);
}
}
- scoreFeatures(searcher, topN, modelWeights, firstPassResults, leaves,
reRankedPerModel);
+ scoreFeatures(docsToRerank, modelWeights, firstPassResults, leaves,
reRankedPerModel);
for (int i = 0; i < rerankingQueries.length; i++) {
if (originalRankingIndex == null || originalRankingIndex != i) {
@@ -115,8 +115,7 @@ public class LTRInterleavingRescorer extends LTRRescorer {
}
public void scoreFeatures(
- IndexSearcher indexSearcher,
- int topN,
+ int docsToRerank,
LTRScoringQuery.ModelWeight[] modelWeights,
ScoreDoc[] hits,
List<LeafReaderContext> leaves,
@@ -126,14 +125,13 @@ public class LTRInterleavingRescorer extends LTRRescorer {
int readerUpto = -1;
int endDoc = 0;
int docBase = 0;
- int hitUpto = 0;
+ int hitPosition = 0;
LTRScoringQuery.ModelWeight.ModelScorer[] scorers =
new LTRScoringQuery.ModelWeight.ModelScorer[rerankingQueries.length];
- while (hitUpto < hits.length) {
- final ScoreDoc hit = hits[hitUpto];
- final int docID = hit.doc;
+ while (hitPosition < hits.length) {
+ final ScoreDoc hit = hits[hitPosition];
LeafReaderContext readerContext = null;
- while (docID >= endDoc) {
+ while (hit.doc >= endDoc) {
readerUpto++;
readerContext = leaves.get(readerUpto);
endDoc = readerContext.docBase + readerContext.reader().maxDoc();
@@ -151,13 +149,11 @@ public class LTRInterleavingRescorer extends LTRRescorer {
for (int i = 0; i < rerankingQueries.length; i++) {
if (modelWeights[i] != null) {
final ScoreDoc hit_i = new ScoreDoc(hit.doc, hit.score,
hit.shardIndex);
- if (scoreSingleHit(
- topN, docBase, hitUpto, hit_i, docID, scorers[i],
rerankedPerModel[i])) {
- logSingleHit(indexSearcher, modelWeights[i], hit_i.doc,
rerankingQueries[i]);
- }
+ scoreSingleHit(
+ docsToRerank, docBase, hitPosition, hit_i, scorers[i],
rerankedPerModel[i]);
}
}
- hitUpto++;
+ hitPosition++;
}
}
diff --git
a/solr/modules/ltr/src/java/org/apache/solr/ltr/response/transform/LTRFeatureLoggerTransformerFactory.java
b/solr/modules/ltr/src/java/org/apache/solr/ltr/response/transform/LTRFeatureLoggerTransformerFactory.java
index 21ead475609..a85597bde08 100644
---
a/solr/modules/ltr/src/java/org/apache/solr/ltr/response/transform/LTRFeatureLoggerTransformerFactory.java
+++
b/solr/modules/ltr/src/java/org/apache/solr/ltr/response/transform/LTRFeatureLoggerTransformerFactory.java
@@ -22,6 +22,7 @@ import java.util.List;
import java.util.Locale;
import java.util.Map;
import org.apache.lucene.index.LeafReaderContext;
+import org.apache.lucene.index.ReaderUtil;
import org.apache.lucene.search.Explanation;
import org.apache.lucene.search.ScoreMode;
import org.apache.solr.common.SolrDocument;
@@ -30,7 +31,6 @@ import org.apache.solr.common.params.SolrParams;
import org.apache.solr.common.util.NamedList;
import org.apache.solr.ltr.CSVFeatureLogger;
import org.apache.solr.ltr.FeatureLogger;
-import org.apache.solr.ltr.LTRRescorer;
import org.apache.solr.ltr.LTRScoringQuery;
import org.apache.solr.ltr.LTRThreadModule;
import org.apache.solr.ltr.SolrQueryRequestContextUtils;
@@ -75,11 +75,10 @@ public class LTRFeatureLoggerTransformerFactory extends
TransformerFactory {
// used inside fl to specify to log (all|model only) features
private static final String FV_LOG_ALL = "logAll";
- private static final String DEFAULT_LOGGING_MODEL_NAME = "logging-model";
+ public static final String DEFAULT_LOGGING_MODEL_NAME = "logging-model";
private static final boolean DEFAULT_NO_RERANKING_LOGGING_ALL = true;
- private String fvCacheName;
private String loggingModelName = DEFAULT_LOGGING_MODEL_NAME;
private String defaultStore;
private FeatureLogger.FeatureFormat defaultFormat =
FeatureLogger.FeatureFormat.DENSE;
@@ -88,10 +87,6 @@ public class LTRFeatureLoggerTransformerFactory extends
TransformerFactory {
private LTRThreadModule threadManager = null;
- public void setFvCacheName(String fvCacheName) {
- this.fvCacheName = fvCacheName;
- }
-
public void setLoggingModelName(String loggingModelName) {
this.loggingModelName = loggingModelName;
}
@@ -161,11 +156,7 @@ public class LTRFeatureLoggerTransformerFactory extends
TransformerFactory {
} else {
format = this.defaultFormat;
}
- if (fvCacheName == null) {
- throw new IllegalArgumentException("a fvCacheName must be configured");
- }
- return new CSVFeatureLogger(
- fvCacheName, format, logAll, csvKeyValueDelimiter,
csvFeatureSeparator);
+ return new CSVFeatureLogger(format, logAll, csvKeyValueDelimiter,
csvFeatureSeparator);
}
class FeatureTransformer extends DocTransformer {
@@ -373,6 +364,12 @@ public class LTRFeatureLoggerTransformerFactory extends
TransformerFactory {
: rerankingQueries[i].getExternalFeatureInfo()),
threadManager);
}
+ } else {
+ for (int i = 0; i < rerankingQueries.length; i++) {
+ if (!transformerExternalFeatureInfo.isEmpty()) {
+
rerankingQueries[i].setExternalFeatureInfo(transformerExternalFeatureInfo);
+ }
+ }
}
}
}
@@ -410,6 +407,32 @@ public class LTRFeatureLoggerTransformerFactory extends
TransformerFactory {
implTransform(doc, docid, docInfo);
}
+    private static LTRScoringQuery.FeatureInfo[] extractFeatures(
+        FeatureLogger logger,
+        LTRScoringQuery.ModelWeight modelWeight,
+        int docid,
+        Float originalDocScore,
+        List<LeafReaderContext> leafContexts)
+        throws IOException {
+      // Locate the leaf segment containing docid and translate to a segment-local id.
+      final int n = ReaderUtil.subIndex(docid, leafContexts);
+      final LeafReaderContext atomicContext = leafContexts.get(n);
+      final int deBasedDoc = docid - atomicContext.docBase;
+      final LTRScoringQuery.ModelWeight.ModelScorer r = modelWeight.modelScorer(atomicContext);
+      // Guard BEFORE dereferencing r: the null check must precede any r.getDocInfo()
+      // call, otherwise it is dead code and an NPE can occur where the guard intends
+      // to protect. Also bail out if the scorer cannot position on this document.
+      if ((r == null) || (r.iterator().advance(deBasedDoc) != deBasedDoc)) {
+        return new LTRScoringQuery.FeatureInfo[0];
+      } else {
+        // The original (global) doc id is used downstream, e.g. as part of the
+        // feature-vector cache key.
+        r.getDocInfo().setOriginalDocId(docid);
+        if (originalDocScore != null) {
+          // If results have not been reranked, the score passed in is the original query's
+          // score, which some features can use instead of recalculating it
+          r.getDocInfo().setOriginalDocScore(originalDocScore);
+        }
+        r.fillFeaturesInfo();
+        logger.setLogFeatures(true);
+        return modelWeight.getAllFeaturesInStore();
+      }
+    }
+
private void implTransform(SolrDocument doc, int docid, DocIterationInfo
docInfo)
throws IOException {
LTRScoringQuery rerankingQuery = rerankingQueries[0];
@@ -423,16 +446,14 @@ public class LTRFeatureLoggerTransformerFactory extends
TransformerFactory {
}
}
if (!(rerankingQuery instanceof OriginalRankingLTRScoringQuery) ||
hasExplicitFeatureStore) {
- Object featureVector = featureLogger.getFeatureVector(docid,
rerankingQuery, searcher);
- if (featureVector == null) { // FV for this document was not in the
cache
- featureVector =
- featureLogger.makeFeatureVector(
- LTRRescorer.extractFeaturesInfo(
- rerankingModelWeight,
- docid,
- (!docsWereReranked && docsHaveScores) ? docInfo.score()
: null,
- leafContexts));
- }
+ LTRScoringQuery.FeatureInfo[] featuresInfo =
+ extractFeatures(
+ featureLogger,
+ rerankingModelWeight,
+ docid,
+ (!docsWereReranked && docsHaveScores) ? docInfo.score() : null,
+ leafContexts);
+ String featureVector = featureLogger.printFeatureVector(featuresInfo);
doc.addField(name, featureVector);
}
}
diff --git
a/solr/modules/ltr/src/java/org/apache/solr/ltr/scoring/FeatureTraversalScorer.java
b/solr/modules/ltr/src/java/org/apache/solr/ltr/scoring/FeatureTraversalScorer.java
new file mode 100644
index 00000000000..2c92ff9e15e
--- /dev/null
+++
b/solr/modules/ltr/src/java/org/apache/solr/ltr/scoring/FeatureTraversalScorer.java
@@ -0,0 +1,71 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.solr.ltr.scoring;
+
+import java.io.IOException;
+import org.apache.lucene.search.Scorer;
+import org.apache.solr.ltr.DocInfo;
+import org.apache.solr.ltr.LTRScoringQuery;
+import org.apache.solr.ltr.feature.Feature;
+import org.apache.solr.ltr.feature.extraction.FeatureExtractor;
+import org.apache.solr.ltr.model.LTRScoringModel;
+
+/** This class is responsible for extracting features and using them to score
the document. */
+public abstract class FeatureTraversalScorer extends Scorer {
+ protected FeatureExtractor featureExtractor;
+ protected LTRScoringQuery.FeatureInfo[] allFeaturesInStore;
+ protected LTRScoringModel ltrScoringModel;
+ protected Feature.FeatureWeight[] extractedFeatureWeights;
+ protected LTRScoringQuery.ModelWeight modelWeight;
+
+ public abstract int getActiveDoc();
+
+ public abstract int getTargetDoc();
+
+ public abstract DocInfo getDocInfo();
+
+ public void reset() {
+ for (int i = 0; i < extractedFeatureWeights.length; ++i) {
+ int featId = extractedFeatureWeights[i].getIndex();
+ float value = extractedFeatureWeights[i].getDefaultValue();
+    // need to set the default value every time, as the default value is used in
+    // 'dense' mode even if used=false
+ allFeaturesInStore[featId].setValue(value);
+ allFeaturesInStore[featId].setIsDefaultValue(true);
+ }
+ }
+
+ public void fillFeaturesInfo() throws IOException {
+ // Initialize features to their default values and set isDefaultValue to
true.
+ reset();
+ featureExtractor.fillFeaturesInfo();
+ }
+
+ @Override
+ public float score() throws IOException {
+ // Initialize features to their default values and set isDefaultValue to
true.
+ reset();
+ featureExtractor.fillFeaturesInfo();
+ modelWeight.normalizeFeatures();
+ return
ltrScoringModel.score(modelWeight.getModelFeatureValuesNormalized());
+ }
+
+ @Override
+ public float getMaxScore(int upTo) {
+ return Float.POSITIVE_INFINITY;
+ }
+}
diff --git
a/solr/modules/ltr/src/java/org/apache/solr/ltr/scoring/MultiFeaturesScorer.java
b/solr/modules/ltr/src/java/org/apache/solr/ltr/scoring/MultiFeaturesScorer.java
new file mode 100644
index 00000000000..2ac44c19be8
--- /dev/null
+++
b/solr/modules/ltr/src/java/org/apache/solr/ltr/scoring/MultiFeaturesScorer.java
@@ -0,0 +1,211 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.solr.ltr.scoring;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.List;
+import java.util.Map;
+import org.apache.lucene.search.DisiPriorityQueue;
+import org.apache.lucene.search.DisiWrapper;
+import org.apache.lucene.search.DocIdSetIterator;
+import org.apache.lucene.search.Scorer;
+import org.apache.solr.ltr.DocInfo;
+import org.apache.solr.ltr.LTRScoringQuery;
+import org.apache.solr.ltr.feature.Feature;
+import org.apache.solr.ltr.feature.extraction.MultiFeaturesExtractor;
+import org.apache.solr.ltr.model.LTRScoringModel;
+import org.apache.solr.request.SolrQueryRequest;
+
+/**
+ * This class is responsible for extracting more than one feature and using
them to score the
+ * document.
+ */
+public class MultiFeaturesScorer extends FeatureTraversalScorer {
+ private int targetDoc = -1;
+ private int activeDoc = -1;
+ protected DocInfo docInfo;
+ private final DisiPriorityQueue subScorers;
+ private final List<DisiWrapper> wrappers;
+ private final MultiFeaturesIterator multiFeaturesIteratorIterator;
+
+ public MultiFeaturesScorer(
+ LTRScoringQuery.ModelWeight modelWeight,
+ SolrQueryRequest request,
+ Feature.FeatureWeight[] extractedFeatureWeights,
+ LTRScoringQuery.FeatureInfo[] allFeaturesInStore,
+ LTRScoringModel ltrScoringModel,
+ Map<String, String[]> efi,
+ List<Feature.FeatureWeight.FeatureScorer> featureScorers,
+ DocInfo docInfo) {
+ if (featureScorers.size() <= 1) {
+ throw new IllegalArgumentException("There must be at least 2
subScorers");
+ }
+ subScorers = DisiPriorityQueue.ofMaxSize(featureScorers.size());
+ wrappers = new ArrayList<>();
+ for (final Scorer scorer : featureScorers) {
+ final DisiWrapper w = new DisiWrapper(scorer, false /* impacts */);
+ subScorers.add(w);
+ wrappers.add(w);
+ }
+
+ multiFeaturesIteratorIterator = new MultiFeaturesIterator(wrappers);
+ this.extractedFeatureWeights = extractedFeatureWeights;
+ this.allFeaturesInStore = allFeaturesInStore;
+ this.ltrScoringModel = ltrScoringModel;
+ this.featureExtractor =
+ new MultiFeaturesExtractor(
+ this,
+ request,
+ extractedFeatureWeights,
+ allFeaturesInStore,
+ ltrScoringModel,
+ efi,
+ subScorers);
+ this.modelWeight = modelWeight;
+ this.docInfo = docInfo;
+ }
+
+ @Override
+ public int getActiveDoc() {
+ return activeDoc;
+ }
+
+ @Override
+ public int getTargetDoc() {
+ return targetDoc;
+ }
+
+ @Override
+ public DocInfo getDocInfo() {
+ return docInfo;
+ }
+
+ @Override
+ public int docID() {
+ return multiFeaturesIteratorIterator.docID();
+ }
+
+ @Override
+ public DocIdSetIterator iterator() {
+ return multiFeaturesIteratorIterator;
+ }
+
+ @Override
+ public final Collection<ChildScorable> getChildren() {
+ final ArrayList<ChildScorable> children = new ArrayList<>();
+ for (final DisiWrapper scorer : subScorers) {
+ children.add(new ChildScorable(scorer.scorer, "SHOULD"));
+ }
+ return children;
+ }
+
+ private class MultiFeaturesIterator extends DocIdSetIterator {
+
+ public MultiFeaturesIterator(Collection<DisiWrapper> wrappers) {
+ // Initialize all wrappers to start at -1
+ for (DisiWrapper wrapper : wrappers) {
+ wrapper.doc = -1;
+ }
+ }
+
+ @Override
+ public int docID() {
+ // Return the target document ID (mimicking DisjunctionDISIApproximation
behavior)
+ return targetDoc;
+ }
+
+ @Override
+ public final int nextDoc() throws IOException {
+ // Mimic DisjunctionDISIApproximation behavior
+ if (targetDoc == -1) {
+ // First call - initialize all iterators
+ DisiWrapper top = subScorers.top();
+ if (top != null && top.doc == -1) {
+ // Need to advance all iterators to their first document
+ DisiWrapper current = subScorers.top();
+ while (current != null) {
+ current.doc = current.iterator.nextDoc();
+ current = subScorers.updateTop();
+ }
+ top = subScorers.top();
+ activeDoc = top == null ? NO_MORE_DOCS : top.doc;
+ }
+ targetDoc = activeDoc;
+ return targetDoc;
+ }
+
+ if (activeDoc == targetDoc) {
+ // Advance the underlying disjunction
+ DisiWrapper top = subScorers.top();
+ if (top == null) {
+ activeDoc = NO_MORE_DOCS;
+ } else {
+ // Advance the top iterator and rebalance the queue
+ top.doc = top.iterator.nextDoc();
+ top = subScorers.updateTop();
+ activeDoc = top == null ? NO_MORE_DOCS : top.doc;
+ }
+ } else if (activeDoc < targetDoc) {
+ // Need to catch up to targetDoc + 1
+ activeDoc = advanceInternal(targetDoc + 1);
+ }
+ return ++targetDoc;
+ }
+
+ @Override
+ public final int advance(int target) throws IOException {
+ // Mimic DisjunctionDISIApproximation behavior
+ if (activeDoc < target) {
+ activeDoc = advanceInternal(target);
+ }
+ targetDoc = target;
+ return targetDoc;
+ }
+
+ private int advanceInternal(int target) throws IOException {
+ // Advance the underlying disjunction to the target
+ DisiWrapper top;
+ do {
+ top = subScorers.top();
+ if (top == null) {
+ return NO_MORE_DOCS;
+ }
+ if (top.doc >= target) {
+ return top.doc;
+ }
+ top.doc = top.iterator.advance(target);
+ top = subScorers.updateTop();
+ if (top == null) {
+ return NO_MORE_DOCS;
+ }
+ } while (top.doc < target);
+ return top.doc;
+ }
+
+ @Override
+ public long cost() {
+ // Calculate cost from all wrappers
+ long cost = 0;
+ for (DisiWrapper wrapper : wrappers) {
+ cost += wrapper.iterator.cost();
+ }
+ return cost;
+ }
+ }
+}
diff --git
a/solr/modules/ltr/src/java/org/apache/solr/ltr/scoring/SingleFeatureScorer.java
b/solr/modules/ltr/src/java/org/apache/solr/ltr/scoring/SingleFeatureScorer.java
new file mode 100644
index 00000000000..6619856901a
--- /dev/null
+++
b/solr/modules/ltr/src/java/org/apache/solr/ltr/scoring/SingleFeatureScorer.java
@@ -0,0 +1,143 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.solr.ltr.scoring;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.List;
+import java.util.Map;
+import org.apache.lucene.search.DocIdSetIterator;
+import org.apache.lucene.search.Scorer;
+import org.apache.solr.ltr.DocInfo;
+import org.apache.solr.ltr.LTRScoringQuery;
+import org.apache.solr.ltr.feature.Feature;
+import org.apache.solr.ltr.feature.extraction.SingleFeatureExtractor;
+import org.apache.solr.ltr.model.LTRScoringModel;
+import org.apache.solr.request.SolrQueryRequest;
+
+/** This class is responsible for extracting a single feature and using it to
score the document. */
+public class SingleFeatureScorer extends FeatureTraversalScorer {
+ private int targetDoc = -1;
+ private int activeDoc = -1;
+ protected DocInfo docInfo;
+ private final List<Feature.FeatureWeight.FeatureScorer> featureScorers;
+
+ public SingleFeatureScorer(
+ LTRScoringQuery.ModelWeight modelWeight,
+ SolrQueryRequest request,
+ Feature.FeatureWeight[] extractedFeatureWeights,
+ LTRScoringQuery.FeatureInfo[] allFeaturesInStore,
+ LTRScoringModel ltrScoringModel,
+ Map<String, String[]> efi,
+ List<Feature.FeatureWeight.FeatureScorer> featureScorers,
+ DocInfo docInfo) {
+ this.featureScorers = featureScorers;
+ this.extractedFeatureWeights = extractedFeatureWeights;
+ this.allFeaturesInStore = allFeaturesInStore;
+ this.ltrScoringModel = ltrScoringModel;
+ this.featureExtractor =
+ new SingleFeatureExtractor(
+ this,
+ request,
+ extractedFeatureWeights,
+ allFeaturesInStore,
+ ltrScoringModel,
+ efi,
+ featureScorers);
+ this.modelWeight = modelWeight;
+ this.docInfo = docInfo;
+ }
+
+ @Override
+ public int getActiveDoc() {
+ return activeDoc;
+ }
+
+ @Override
+ public int getTargetDoc() {
+ return targetDoc;
+ }
+
+ @Override
+ public DocInfo getDocInfo() {
+ return docInfo;
+ }
+
+ @Override
+ public int docID() {
+ return targetDoc;
+ }
+
+ @Override
+ public final Collection<ChildScorable> getChildren() {
+ final ArrayList<ChildScorable> children = new ArrayList<>();
+ for (final Scorer scorer : featureScorers) {
+ children.add(new ChildScorable(scorer, "SHOULD"));
+ }
+ return children;
+ }
+
+ @Override
+ public DocIdSetIterator iterator() {
+ return new SingleFeatureIterator();
+ }
+
+ private class SingleFeatureIterator extends DocIdSetIterator {
+
+ @Override
+ public int docID() {
+ return targetDoc;
+ }
+
+ @Override
+ public int nextDoc() throws IOException {
+ if (activeDoc <= targetDoc) {
+ activeDoc = NO_MORE_DOCS;
+ for (final Scorer scorer : featureScorers) {
+ if (scorer.docID() != NO_MORE_DOCS) {
+ activeDoc = Math.min(activeDoc, scorer.iterator().nextDoc());
+ }
+ }
+ }
+ return ++targetDoc;
+ }
+
+ @Override
+ public int advance(int target) throws IOException {
+ if (activeDoc < target) {
+ activeDoc = NO_MORE_DOCS;
+ for (final Scorer scorer : featureScorers) {
+ if (scorer.docID() != NO_MORE_DOCS) {
+ activeDoc = Math.min(activeDoc, scorer.iterator().advance(target));
+ }
+ }
+ }
+ targetDoc = target;
+ return target;
+ }
+
+ @Override
+ public long cost() {
+ long sum = 0;
+ for (int i = 0; i < featureScorers.size(); i++) {
+ sum += featureScorers.get(i).iterator().cost();
+ }
+ return sum;
+ }
+ }
+}
diff --git a/solr/modules/ltr/src/java/org/apache/solr/ltr/DocInfo.java
b/solr/modules/ltr/src/java/org/apache/solr/ltr/scoring/package-info.java
similarity index 59%
copy from solr/modules/ltr/src/java/org/apache/solr/ltr/DocInfo.java
copy to solr/modules/ltr/src/java/org/apache/solr/ltr/scoring/package-info.java
index e454d90acc2..54e25ba080b 100644
--- a/solr/modules/ltr/src/java/org/apache/solr/ltr/DocInfo.java
+++ b/solr/modules/ltr/src/java/org/apache/solr/ltr/scoring/package-info.java
@@ -14,28 +14,6 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-package org.apache.solr.ltr;
-import java.util.HashMap;
-
-public class DocInfo extends HashMap<String, Object> {
-
- // Name of key used to store the original score of a doc
- private static final String ORIGINAL_DOC_SCORE = "ORIGINAL_DOC_SCORE";
-
- public DocInfo() {
- super();
- }
-
- public void setOriginalDocScore(Float score) {
- put(ORIGINAL_DOC_SCORE, score);
- }
-
- public Float getOriginalDocScore() {
- return (Float) get(ORIGINAL_DOC_SCORE);
- }
-
- public boolean hasOriginalDocScore() {
- return containsKey(ORIGINAL_DOC_SCORE);
- }
-}
+/** Contains the scorers that traverse feature values and use them to score documents. */
+package org.apache.solr.ltr.scoring;
diff --git
a/solr/modules/ltr/src/test-files/featureExamples/featurevectorcache_features.json
b/solr/modules/ltr/src/test-files/featureExamples/featurevectorcache_features.json
new file mode 100644
index 00000000000..f65b0108bda
--- /dev/null
+++
b/solr/modules/ltr/src/test-files/featureExamples/featurevectorcache_features.json
@@ -0,0 +1,57 @@
+[
+ {
+ "name": "value_feature_1",
+ "class": "org.apache.solr.ltr.feature.ValueFeature",
+ "params": {
+ "value": 1
+ }
+ },
+ {
+ "name": "value_feature_3",
+ "class": "org.apache.solr.ltr.feature.ValueFeature",
+ "params": {
+ "value": 3
+ }
+ },
+ {
+ "name" : "efi_feature",
+ "class":"org.apache.solr.ltr.feature.ValueFeature",
+ "params" : {
+ "value": "${efi_feature}"
+ }
+ },
+ {
+ "name" : "match_w1_title",
+ "class":"org.apache.solr.ltr.feature.SolrFeature",
+ "params" : {
+ "fq": [
+ "{!terms f=title}w1"
+ ]
+ }
+ },
+ {
+ "name": "popularity_value",
+ "class": "org.apache.solr.ltr.feature.FieldValueFeature",
+ "params": {
+ "field": "popularity"
+ }
+ },
+ {
+ "name" : "match_w1_title",
+ "class":"org.apache.solr.ltr.feature.SolrFeature",
+ "store": "store1",
+ "params" : {
+ "fq": [
+ "{!terms f=title}w1"
+ ]
+ }
+ },
+ {
+ "name": "value_feature_2",
+ "class": "org.apache.solr.ltr.feature.ValueFeature",
+ "store": "store1",
+ "params": {
+ "value": 2
+ }
+ }
+]
\ No newline at end of file
diff --git
a/solr/modules/ltr/src/test-files/modelExamples/featurevectorcache_linear_model.json
b/solr/modules/ltr/src/test-files/modelExamples/featurevectorcache_linear_model.json
new file mode 100644
index 00000000000..c05acac90e3
--- /dev/null
+++
b/solr/modules/ltr/src/test-files/modelExamples/featurevectorcache_linear_model.json
@@ -0,0 +1,26 @@
+{
+ "class": "org.apache.solr.ltr.model.LinearModel",
+ "name": "featurevectorcache_linear_model",
+ "features": [
+ {
+ "name": "value_feature_1"
+ },
+ {
+ "name": "efi_feature"
+ },
+ {
+ "name": "match_w1_title"
+ },
+ {
+ "name": "popularity_value"
+ }
+ ],
+ "params": {
+ "weights": {
+ "value_feature_1": 1,
+ "efi_feature": 0.2,
+ "match_w1_title": 0.5,
+ "popularity_value": 0.8
+ }
+ }
+}
\ No newline at end of file
diff --git
a/solr/modules/ltr/src/test-files/solr/collection1/conf/solrconfig-ltr-featurevectorcache.xml
b/solr/modules/ltr/src/test-files/solr/collection1/conf/solrconfig-ltr-featurevectorcache.xml
new file mode 100644
index 00000000000..f78d5996d3d
--- /dev/null
+++
b/solr/modules/ltr/src/test-files/solr/collection1/conf/solrconfig-ltr-featurevectorcache.xml
@@ -0,0 +1,76 @@
+<?xml version="1.0" ?>
+<!-- Licensed to the Apache Software Foundation (ASF) under one or more
contributor
+ license agreements. See the NOTICE file distributed with this work for
additional
+ information regarding copyright ownership. The ASF licenses this file to
+ You under the Apache License, Version 2.0 (the "License"); you may not use
+ this file except in compliance with the License. You may obtain a copy of
+ the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required
+ by applicable law or agreed to in writing, software distributed under the
+ License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS
+ OF ANY KIND, either express or implied. See the License for the specific
+ language governing permissions and limitations under the License. -->
+
+<config>
+ <luceneMatchVersion>${tests.luceneMatchVersion:LATEST}</luceneMatchVersion>
+ <dataDir>${solr.data.dir:}</dataDir>
+ <directoryFactory name="DirectoryFactory"
+
class="${solr.directoryFactory:solr.MockDirectoryFactory}" />
+
+ <schemaFactory class="ClassicIndexSchemaFactory" />
+
+ <requestDispatcher>
+ <requestParsers />
+ </requestDispatcher>
+
+ <!-- Query parser used to rerank top docs with a provided model -->
+ <queryParser name="ltr"
+ class="org.apache.solr.ltr.search.LTRQParserPlugin" />
+
+ <query>
+ <filterCache class="solr.CaffeineCache" size="4096" initialSize="2048"
autowarmCount="0" />
+ <featureVectorCache class="solr.CaffeineCache" size="4096"
initialSize="2048" autowarmCount="0" />
+ </query>
+
+ <!-- add a transformer that will encode the document features in the
response.
+ For each document the transformer will add the features as an extra field
+     in the response. The name of the field will be the name of the transformer
+ enclosed between brackets (in this case [fv]). In order to get the feature
+ vector you will have to specify that you want the field (e.g.,
fl="*,[fv]) -->
+ <transformer name="fv"
class="org.apache.solr.ltr.response.transform.LTRFeatureLoggerTransformerFactory">
+ <str
name="defaultFormat">${solr.ltr.transformer.fv.defaultFormat:dense}</str>
+ </transformer>
+
+ <!-- add a transformer that will encode the model the interleaving process
chose the search result from.
+ For each document the transformer will add an extra field in the response
with the model picked.
+     The name of the field will be the name of the transformer
+ In order to get the model chosen for the search result
+ you will have to specify that you want the field (e.g.,
fl="*,[interleaving]) -->
+ <transformer name="interleaving"
class="org.apache.solr.ltr.response.transform.LTRInterleavingTransformerFactory">
+ </transformer>
+
+ <updateHandler class="solr.DirectUpdateHandler2">
+ <autoCommit>
+ <maxTime>15000</maxTime>
+ <openSearcher>false</openSearcher>
+ </autoCommit>
+ <autoSoftCommit>
+ <maxTime>1000</maxTime>
+ </autoSoftCommit>
+ <updateLog>
+ <str name="dir">${solr.data.dir:}</str>
+ </updateLog>
+ </updateHandler>
+
+ <requestHandler name="/update" class="solr.UpdateRequestHandler" />
+ <!-- Query request handler managing models and features -->
+ <requestHandler name="/query" class="solr.SearchHandler">
+ <lst name="defaults">
+ <str name="echoParams">explicit</str>
+ <str name="wt">json</str>
+ <str name="indent">true</str>
+ <str name="df">id</str>
+ </lst>
+ </requestHandler>
+
+</config>
\ No newline at end of file
diff --git
a/solr/modules/ltr/src/test-files/solr/collection1/conf/solrconfig-ltr.xml
b/solr/modules/ltr/src/test-files/solr/collection1/conf/solrconfig-ltr.xml
index c20ee2026f6..496755c9765 100644
--- a/solr/modules/ltr/src/test-files/solr/collection1/conf/solrconfig-ltr.xml
+++ b/solr/modules/ltr/src/test-files/solr/collection1/conf/solrconfig-ltr.xml
@@ -29,8 +29,6 @@
<query>
<filterCache class="solr.CaffeineCache" size="4096"
initialSize="2048" autowarmCount="0" />
- <cache name="QUERY_DOC_FV" class="solr.search.CaffeineCache" size="4096"
- initialSize="2048" autowarmCount="4096"
regenerator="solr.search.NoOpRegenerator" />
</query>
<!-- add a transformer that will encode the document features in the response.
@@ -40,13 +38,12 @@
vector you will have to specify that you want the field (e.g., fl="*,[fv])
-->
<transformer name="fv"
class="org.apache.solr.ltr.response.transform.LTRFeatureLoggerTransformerFactory">
<str
name="defaultFormat">${solr.ltr.transformer.fv.defaultFormat:dense}</str>
- <str name="fvCacheName">QUERY_DOC_FV</str>
</transformer>
<!-- add a transformer that will encode the model the interleaving process
chose the search result from.
- For each document the transformer will add an extra field in the response
with the model picked.
+ For each document the transformer will add an extra field in the response
with the model picked.
   The name of the field will be the name of the transformer
- enclosed between brackets (in this case [interleaving]).
+ enclosed between brackets (in this case [interleaving]).
In order to get the model chosen for the search result
you will have to specify that you want the field (e.g.,
fl="*,[interleaving]) -->
<transformer name="interleaving"
class="org.apache.solr.ltr.response.transform.LTRInterleavingTransformerFactory">
diff --git
a/solr/modules/ltr/src/test-files/solr/collection1/conf/solrconfig-ltr_Th10_10.xml
b/solr/modules/ltr/src/test-files/solr/collection1/conf/solrconfig-ltr_Th10_10.xml
index 37ae68a2580..32f8751a9f8 100644
---
a/solr/modules/ltr/src/test-files/solr/collection1/conf/solrconfig-ltr_Th10_10.xml
+++
b/solr/modules/ltr/src/test-files/solr/collection1/conf/solrconfig-ltr_Th10_10.xml
@@ -33,8 +33,6 @@
<query>
<filterCache class="solr.CaffeineCache" size="4096"
initialSize="2048" autowarmCount="0" />
- <cache name="QUERY_DOC_FV" class="solr.search.CaffeineCache" size="4096"
- initialSize="2048" autowarmCount="4096"
regenerator="solr.search.NoOpRegenerator" />
</query>
<!-- add a transformer that will encode the document features in the response.
@@ -43,7 +41,6 @@
enclosed between brackets (in this case [fv]). In order to get the feature
vector you will have to specify that you want the field (e.g., fl="*,[fv])
-->
<transformer name="fv"
class="org.apache.solr.ltr.response.transform.LTRFeatureLoggerTransformerFactory">
- <str name="fvCacheName">QUERY_DOC_FV</str>
</transformer>
<updateHandler class="solr.DirectUpdateHandler2">
diff --git
a/solr/modules/ltr/src/test-files/solr/collection1/conf/solrconfig-multiseg.xml
b/solr/modules/ltr/src/test-files/solr/collection1/conf/solrconfig-multiseg.xml
index 911db9a9f55..04ee42b3834 100644
---
a/solr/modules/ltr/src/test-files/solr/collection1/conf/solrconfig-multiseg.xml
+++
b/solr/modules/ltr/src/test-files/solr/collection1/conf/solrconfig-multiseg.xml
@@ -28,8 +28,6 @@
<query>
<filterCache class="solr.CaffeineCache" size="4096"
initialSize="2048" autowarmCount="0" />
- <cache name="QUERY_DOC_FV" class="solr.search.CaffeineCache" size="4096"
- initialSize="2048" autowarmCount="4096"
regenerator="solr.search.NoOpRegenerator" />
</query>
<maxBufferedDocs>1</maxBufferedDocs>
@@ -43,7 +41,6 @@
enclosed between brackets (in this case [fv]). In order to get the feature
vector you will have to specify that you want the field (e.g., fl="*,[fv])
-->
<transformer name="features"
class="org.apache.solr.ltr.response.transform.LTRFeatureLoggerTransformerFactory">
- <str name="fvCacheName">QUERY_DOC_FV</str>
</transformer>
<updateHandler class="solr.DirectUpdateHandler2">
diff --git
a/solr/modules/ltr/src/test/org/apache/solr/ltr/TestFeatureVectorCache.java
b/solr/modules/ltr/src/test/org/apache/solr/ltr/TestFeatureVectorCache.java
new file mode 100644
index 00000000000..2674f567edd
--- /dev/null
+++ b/solr/modules/ltr/src/test/org/apache/solr/ltr/TestFeatureVectorCache.java
@@ -0,0 +1,366 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.solr.ltr;
+
+import io.prometheus.metrics.model.snapshots.CounterSnapshot;
+import java.util.ArrayList;
+import java.util.List;
+import org.apache.solr.client.solrj.SolrQuery;
+import org.apache.solr.core.SolrCore;
+import org.apache.solr.util.SolrMetricTestUtils;
+import org.junit.After;
+import org.junit.Before;
+import org.junit.Test;
+
+public class TestFeatureVectorCache extends TestRerankBase {
+ SolrCore core = null;
+ List<String> docs;
+
+ @Before
+ public void before() throws Exception {
+ setupFeatureVectorCacheTest(false);
+
+ this.docs = new ArrayList<>();
+ docs.add(adoc("id", "1", "title", "w2", "description", "w2", "popularity",
"2"));
+ docs.add(adoc("id", "2", "title", "w1", "description", "w1", "popularity",
"0"));
+ for (String doc : docs) {
+ assertU(doc);
+ }
+ assertU(commit());
+
+ loadFeatures("featurevectorcache_features.json");
+ loadModels("featurevectorcache_linear_model.json");
+
+ core =
solrClientTestRule.getCoreContainer().getCore(DEFAULT_TEST_CORENAME);
+ }
+
+ @After
+ public void after() throws Exception {
+ core.close();
+ aftertest();
+ }
+
+ private static CounterSnapshot.CounterDataPointSnapshot
getFeatureVectorCacheInserts(
+ SolrCore core) {
+ return SolrMetricTestUtils.getCacheSearcherOpsInserts(core,
"featureVectorCache");
+ }
+
+ private static double getFeatureVectorCacheLookups(SolrCore core) {
+ return SolrMetricTestUtils.getCacheSearcherTotalLookups(core,
"featureVectorCache");
+ }
+
+ private static CounterSnapshot.CounterDataPointSnapshot
getFeatureVectorCacheHits(SolrCore core) {
+ return SolrMetricTestUtils.getCacheSearcherOpsHits(core,
"featureVectorCache");
+ }
+
+ @Test
+ public void testFeatureVectorCache_loggingDefaultStoreNoReranking() throws
Exception {
+ final String docs0fv_dense_csv =
+ FeatureLoggerTestUtils.toFeatureVector(
+ "value_feature_1",
+ "1.0",
+ "value_feature_3",
+ "3.0",
+ "efi_feature",
+ "3.0",
+ "match_w1_title",
+ "0.0",
+ "popularity_value",
+ "2.0");
+ final String docs0fv_sparse_csv =
+ FeatureLoggerTestUtils.toFeatureVector(
+ "value_feature_1",
+ "1.0",
+ "value_feature_3",
+ "3.0",
+ "efi_feature",
+ "3.0",
+ "popularity_value",
+ "2.0");
+
+ final String docs0fv_default_csv =
+ chooseDefaultFeatureVector(docs0fv_dense_csv, docs0fv_sparse_csv);
+
+ final SolrQuery query = new SolrQuery();
+ query.setQuery("*:*");
+ query.add("rows", "3");
+ query.add("fl", "[fv efi.efi_feature=3]");
+
+ // No caching, we want to see lookups, insertions and no hits
+ assertJQ(
+ "/query" + query.toQueryString(),
+ "/response/docs/[0]/=={'[fv]':'" + docs0fv_default_csv + "'}");
+ assertEquals(docs.size(), getFeatureVectorCacheInserts(core).getValue(),
0);
+ assertEquals(docs.size(), getFeatureVectorCacheLookups(core), 0);
+ assertEquals(0, getFeatureVectorCacheHits(core).getValue(), 0);
+
+ query.add("sort", "popularity desc");
+ // Caching, we want to see hits
+ assertJQ(
+ "/query" + query.toQueryString(),
+ "/response/docs/[0]/=={'[fv]':'" + docs0fv_default_csv + "'}");
+ assertEquals(docs.size(), getFeatureVectorCacheInserts(core).getValue(),
0);
+ assertEquals(docs.size() * 2, getFeatureVectorCacheLookups(core), 0);
+ assertEquals(docs.size(), getFeatureVectorCacheHits(core).getValue(), 0);
+ }
+
+ @Test
+ public void testFeatureVectorCache_loggingExplicitStoreNoReranking() throws
Exception {
+ final String docs0fv_dense_csv =
+ FeatureLoggerTestUtils.toFeatureVector("match_w1_title", "0.0",
"value_feature_2", "2.0");
+ final String docs0fv_sparse_csv =
+ FeatureLoggerTestUtils.toFeatureVector("value_feature_2", "2.0");
+
+ final String docs0fv_default_csv =
+ chooseDefaultFeatureVector(docs0fv_dense_csv, docs0fv_sparse_csv);
+
+ final SolrQuery query = new SolrQuery();
+ query.setQuery("*:*");
+ query.add("rows", "3");
+ query.add("fl", "[fv store=store1 efi.efi_feature=3]");
+
+ // No caching, we want to see lookups, insertions and no hits
+ assertJQ(
+ "/query" + query.toQueryString(),
+ "/response/docs/[0]/=={'[fv]':'" + docs0fv_default_csv + "'}");
+ assertEquals(docs.size(), getFeatureVectorCacheInserts(core).getValue(),
0);
+ assertEquals(docs.size(), getFeatureVectorCacheLookups(core), 0);
+ assertEquals(0, getFeatureVectorCacheHits(core).getValue(), 0);
+
+ query.add("sort", "popularity desc");
+ // Caching, we want to see hits
+ assertJQ(
+ "/query" + query.toQueryString(),
+ "/response/docs/[0]/=={'[fv]':'" + docs0fv_default_csv + "'}");
+ assertEquals(docs.size(), getFeatureVectorCacheInserts(core).getValue(),
0);
+ assertEquals(docs.size() * 2, getFeatureVectorCacheLookups(core), 0);
+ assertEquals(docs.size(), getFeatureVectorCacheHits(core).getValue(), 0);
+ }
+
+ @Test
+ public void
testFeatureVectorCache_loggingModelStoreAndRerankingWithDifferentEfi()
+ throws Exception {
+ final String docs0fv_dense_csv =
+ FeatureLoggerTestUtils.toFeatureVector(
+ "value_feature_1",
+ "1.0",
+ "efi_feature",
+ "3.0",
+ "match_w1_title",
+ "0.0",
+ "popularity_value",
+ "2.0");
+ final String docs0fv_sparse_csv =
+ FeatureLoggerTestUtils.toFeatureVector(
+ "value_feature_1", "1.0", "efi_feature", "3.0",
"popularity_value", "2.0");
+
+ final String docs0fv_default_csv =
+ chooseDefaultFeatureVector(docs0fv_dense_csv, docs0fv_sparse_csv);
+
+ final SolrQuery query = new SolrQuery();
+ query.setQuery("*:*");
+ query.add("rows", "3");
+ query.add("fl", "id,score,fv:[fv efi.efi_feature=3]");
+ query.add("rq", "{!ltr reRankDocs=3 model=featurevectorcache_linear_model
efi.efi_feature=4}");
+
+ // No caching, we want to see lookups, insertions and no hits since the
efis are different
+ assertJQ(
+ "/query" + query.toQueryString(),
+ "/response/docs/[0]/=={"
+ + "'id':'1',"
+ + "'score':3.4,"
+ + "'fv':'"
+ + docs0fv_default_csv
+ + "'}");
+ assertEquals(docs.size() * 2,
getFeatureVectorCacheInserts(core).getValue(), 0);
+ assertEquals(docs.size() * 2, getFeatureVectorCacheLookups(core), 0);
+ assertEquals(0, getFeatureVectorCacheHits(core).getValue(), 0);
+
+ query.add("sort", "popularity desc");
+ // Caching, we want to see hits and same scores as before
+ assertJQ(
+ "/query" + query.toQueryString(),
+ "/response/docs/[0]/=={"
+ + "'id':'1',"
+ + "'score':3.4,"
+ + "'fv':'"
+ + docs0fv_default_csv
+ + "'}");
+ assertEquals(docs.size() * 2,
getFeatureVectorCacheInserts(core).getValue(), 0);
+ assertEquals(docs.size() * 4, getFeatureVectorCacheLookups(core), 0);
+ assertEquals(docs.size() * 2, getFeatureVectorCacheHits(core).getValue(),
0);
+ }
+
+ @Test
+ public void
testFeatureVectorCache_loggingModelStoreAndRerankingWithSameEfi() throws
Exception {
+ final String docs0fv_dense_csv =
+ FeatureLoggerTestUtils.toFeatureVector(
+ "value_feature_1",
+ "1.0",
+ "efi_feature",
+ "4.0",
+ "match_w1_title",
+ "0.0",
+ "popularity_value",
+ "2.0");
+ final String docs0fv_sparse_csv =
+ FeatureLoggerTestUtils.toFeatureVector(
+ "value_feature_1", "1.0", "efi_feature", "4.0",
"popularity_value", "2.0");
+
+ final String docs0fv_default_csv =
+ chooseDefaultFeatureVector(docs0fv_dense_csv, docs0fv_sparse_csv);
+
+ final SolrQuery query = new SolrQuery();
+ query.setQuery("*:*");
+ query.add("rows", "3");
+ query.add("fl", "id,score,fv:[fv efi.efi_feature=4]");
+ query.add("rq", "{!ltr reRankDocs=3 model=featurevectorcache_linear_model
efi.efi_feature=4}");
+
+ // No caching for reranking but logging should hit since we have the same
feature store and efis
+ assertJQ(
+ "/query" + query.toQueryString(),
+ "/response/docs/[0]/=={"
+ + "'id':'1',"
+ + "'score':3.4,"
+ + "'fv':'"
+ + docs0fv_default_csv
+ + "'}");
+ assertEquals(docs.size(), getFeatureVectorCacheInserts(core).getValue(),
0);
+ assertEquals(docs.size() * 2, getFeatureVectorCacheLookups(core), 0);
+ assertEquals(docs.size(), getFeatureVectorCacheHits(core).getValue(), 0);
+
+ query.add("sort", "popularity desc");
+ // Caching, we want to see hits and same scores
+ assertJQ(
+ "/query" + query.toQueryString(),
+ "/response/docs/[0]/=={"
+ + "'id':'1',"
+ + "'score':3.4,"
+ + "'fv':'"
+ + docs0fv_default_csv
+ + "'}");
+ assertEquals(docs.size(), getFeatureVectorCacheInserts(core).getValue(),
0);
+ assertEquals(docs.size() * 4, getFeatureVectorCacheLookups(core), 0);
+ assertEquals(docs.size() * 3, getFeatureVectorCacheHits(core).getValue(),
0);
+ }
+
+ @Test
+ public void testFeatureVectorCache_loggingAllFeatureStoreAndReranking()
throws Exception {
+ final String docs0fv_dense_csv =
+ FeatureLoggerTestUtils.toFeatureVector(
+ "value_feature_1",
+ "1.0",
+ "value_feature_3",
+ "3.0",
+ "efi_feature",
+ "3.0",
+ "match_w1_title",
+ "0.0",
+ "popularity_value",
+ "2.0");
+ final String docs0fv_sparse_csv =
+ FeatureLoggerTestUtils.toFeatureVector(
+ "value_feature_1",
+ "1.0",
+ "value_feature_3",
+ "3.0",
+ "efi_feature",
+ "3.0",
+ "popularity_value",
+ "2.0");
+
+ final String docs0fv_default_csv =
+ chooseDefaultFeatureVector(docs0fv_dense_csv, docs0fv_sparse_csv);
+
+ final SolrQuery query = new SolrQuery();
+ query.setQuery("*:*");
+ query.add("rows", "3");
+ query.add("fl", "id,score,fv:[fv logAll=true efi.efi_feature=3]");
+ query.add("rq", "{!ltr reRankDocs=3 model=featurevectorcache_linear_model
efi.efi_feature=4}");
+
+ // No caching, we want to see lookups, insertions and no hits since the
efis are different
+ assertJQ(
+ "/query" + query.toQueryString(),
+ "/response/docs/[0]/=={"
+ + "'id':'1',"
+ + "'score':3.4,"
+ + "'fv':'"
+ + docs0fv_default_csv
+ + "'}");
+ assertEquals(docs.size() * 2,
getFeatureVectorCacheInserts(core).getValue(), 0);
+ assertEquals(docs.size() * 2, getFeatureVectorCacheLookups(core), 0);
+ assertEquals(0, getFeatureVectorCacheHits(core).getValue(), 0);
+
+ query.add("sort", "popularity desc");
+ // Caching, we want to see hits and same scores
+ assertJQ(
+ "/query" + query.toQueryString(),
+ "/response/docs/[0]/=={"
+ + "'id':'1',"
+ + "'score':3.4,"
+ + "'fv':'"
+ + docs0fv_default_csv
+ + "'}");
+ assertEquals(docs.size() * 2,
getFeatureVectorCacheInserts(core).getValue(), 0);
+ assertEquals(docs.size() * 4, getFeatureVectorCacheLookups(core), 0);
+ assertEquals(docs.size() * 2, getFeatureVectorCacheHits(core).getValue(),
0);
+ }
+
+ @Test
+ public void testFeatureVectorCache_loggingExplicitStoreAndReranking() throws
Exception {
+ final String docs0fv_dense_csv =
+ FeatureLoggerTestUtils.toFeatureVector("match_w1_title", "0.0",
"value_feature_2", "2.0");
+ final String docs0fv_sparse_csv =
+ FeatureLoggerTestUtils.toFeatureVector("value_feature_2", "2.0");
+
+ final String docs0fv_default_csv =
+ chooseDefaultFeatureVector(docs0fv_dense_csv, docs0fv_sparse_csv);
+
+ final SolrQuery query = new SolrQuery();
+ query.setQuery("*:*");
+ query.add("rows", "3");
+ query.add("fl", "id,score,fv:[fv store=store1 efi.efi_feature=3]");
+ query.add("rq", "{!ltr reRankDocs=3 model=featurevectorcache_linear_model
efi.efi_feature=4}");
+
+ // No caching, we want to see lookups, insertions and no hits since the
efis are different
+ assertJQ(
+ "/query" + query.toQueryString(),
+ "/response/docs/[0]/=={"
+ + "'id':'1',"
+ + "'score':3.4,"
+ + "'fv':'"
+ + docs0fv_default_csv
+ + "'}");
+ assertEquals(docs.size() * 2,
getFeatureVectorCacheInserts(core).getValue(), 0);
+ assertEquals(docs.size() * 2, getFeatureVectorCacheLookups(core), 0);
+ assertEquals(0, getFeatureVectorCacheHits(core).getValue(), 0);
+
+ query.add("sort", "popularity desc");
+ // Caching, we want to see hits and same scores
+ assertJQ(
+ "/query" + query.toQueryString(),
+ "/response/docs/[0]/=={"
+ + "'id':'1',"
+ + "'score':3.4,"
+ + "'fv':'"
+ + docs0fv_default_csv
+ + "'}");
+ assertEquals(docs.size() * 2,
getFeatureVectorCacheInserts(core).getValue(), 0);
+ assertEquals(docs.size() * 4, getFeatureVectorCacheLookups(core), 0);
+ assertEquals(docs.size() * 2, getFeatureVectorCacheHits(core).getValue(),
0);
+ }
+}
diff --git
a/solr/modules/ltr/src/test/org/apache/solr/ltr/TestLTRScoringQuery.java
b/solr/modules/ltr/src/test/org/apache/solr/ltr/TestLTRScoringQuery.java
index 7ad5b34f49b..583ae22f742 100644
--- a/solr/modules/ltr/src/test/org/apache/solr/ltr/TestLTRScoringQuery.java
+++ b/solr/modules/ltr/src/test/org/apache/solr/ltr/TestLTRScoringQuery.java
@@ -219,7 +219,7 @@ public class TestLTRScoringQuery extends SolrTestCase {
}
int[] posVals = new int[] {0, 1, 2};
int pos = 0;
- for (LTRScoringQuery.FeatureInfo fInfo : modelWeight.getFeaturesInfo()) {
+ for (LTRScoringQuery.FeatureInfo fInfo :
modelWeight.getAllFeaturesInStore()) {
if (fInfo == null) {
continue;
}
diff --git a/solr/modules/ltr/src/test/org/apache/solr/ltr/TestRerankBase.java
b/solr/modules/ltr/src/test/org/apache/solr/ltr/TestRerankBase.java
index e3b30b24043..2ee0c1ebdd0 100644
--- a/solr/modules/ltr/src/test/org/apache/solr/ltr/TestRerankBase.java
+++ b/solr/modules/ltr/src/test/org/apache/solr/ltr/TestRerankBase.java
@@ -120,6 +120,12 @@ public class TestRerankBase extends RestTestBase {
if (bulkIndex) bulkIndex();
}
+ protected static void setupFeatureVectorCacheTest(boolean bulkIndex) throws
Exception {
+ chooseDefaultFeatureFormat();
+ setuptest("solrconfig-ltr-featurevectorcache.xml", "schema.xml");
+ if (bulkIndex) bulkIndex();
+ }
+
public static ManagedFeatureStore getManagedFeatureStore() {
try (SolrCore core =
solrClientTestRule.getCoreContainer().getCore(DEFAULT_TEST_CORENAME)) {
return ManagedFeatureStore.getManagedFeatureStore(core);
diff --git
a/solr/modules/ltr/src/test/org/apache/solr/ltr/TestSelectiveWeightCreation.java
b/solr/modules/ltr/src/test/org/apache/solr/ltr/TestSelectiveWeightCreation.java
index be1313b47bc..b5148f66b80 100644
---
a/solr/modules/ltr/src/test/org/apache/solr/ltr/TestSelectiveWeightCreation.java
+++
b/solr/modules/ltr/src/test/org/apache/solr/ltr/TestSelectiveWeightCreation.java
@@ -167,7 +167,7 @@ public class TestSelectiveWeightCreation extends
TestRerankBase {
searcher,
hits.scoreDocs[0].doc,
new LTRScoringQuery(ltrScoringModel1)); // features not requested
in response
- LTRScoringQuery.FeatureInfo[] featuresInfo = modelWeight.getFeaturesInfo();
+ LTRScoringQuery.FeatureInfo[] featuresInfo =
modelWeight.getAllFeaturesInStore();
assertEquals(features.size(),
modelWeight.getModelFeatureValuesNormalized().length);
int nonDefaultFeatures = 0;
@@ -189,13 +189,12 @@ public class TestSelectiveWeightCreation extends
TestRerankBase {
TestLinearModel.makeFeatureWeights(features));
LTRScoringQuery ltrQuery2 = new LTRScoringQuery(ltrScoringModel2);
// features requested in response
- ltrQuery2.setFeatureLogger(
- new CSVFeatureLogger("test", FeatureLogger.FeatureFormat.DENSE, true));
+ ltrQuery2.setFeatureLogger(new
CSVFeatureLogger(FeatureLogger.FeatureFormat.DENSE, true));
modelWeight = performQuery(hits, searcher, hits.scoreDocs[0].doc,
ltrQuery2);
- featuresInfo = modelWeight.getFeaturesInfo();
+ featuresInfo = modelWeight.getAllFeaturesInStore();
assertEquals(features.size(),
modelWeight.getModelFeatureValuesNormalized().length);
- assertEquals(allFeatures.size(),
modelWeight.getExtractedFeatureWeights().length);
+ assertEquals(allFeatures.size(),
ltrQuery2.getExtractedFeatureWeights().length);
nonDefaultFeatures = 0;
for (int i = 0; i < featuresInfo.length; ++i) {
diff --git
a/solr/modules/ltr/src/test/org/apache/solr/ltr/feature/TestFeatureExtractionFromMultipleSegments.java
b/solr/modules/ltr/src/test/org/apache/solr/ltr/feature/TestFeatureExtractionFromMultipleSegments.java
index 3e6e986e591..f9fb8099cf3 100644
---
a/solr/modules/ltr/src/test/org/apache/solr/ltr/feature/TestFeatureExtractionFromMultipleSegments.java
+++
b/solr/modules/ltr/src/test/org/apache/solr/ltr/feature/TestFeatureExtractionFromMultipleSegments.java
@@ -236,7 +236,8 @@ public class TestFeatureExtractionFromMultipleSegments
extends TestRerankBase {
query.setQuery(
"{!edismax qf='description^1' boost='sum(product(pow(normHits, 0.7),
1600), .1)' v='apple'}");
// request 100 rows, if any rows are fetched from the second or subsequent
segments the tests
- // should succeed if LTRRescorer::extractFeaturesInfo() advances the doc
iterator properly
+ // should succeed if LTRFeatureLoggerTransformerFactory::extractFeatures()
advances the doc
+ // iterator properly
int numRows = 100;
query.add("rows", Integer.toString(numRows));
query.add("wt", "json");
diff --git
a/solr/server/solr/configsets/sample_techproducts_configs/conf/solrconfig.xml
b/solr/server/solr/configsets/sample_techproducts_configs/conf/solrconfig.xml
index 4192759f158..d6635c05e9b 100644
---
a/solr/server/solr/configsets/sample_techproducts_configs/conf/solrconfig.xml
+++
b/solr/server/solr/configsets/sample_techproducts_configs/conf/solrconfig.xml
@@ -391,7 +391,7 @@
/>
-->
- <!-- Feature Values Cache
+ <!-- Feature Vector Cache
Cache used by the Learning To Rank (LTR) module.
@@ -401,12 +401,11 @@
https://solr.apache.org/guide/solr/latest/query-guide/learning-to-rank.html
-->
- <cache enable="${solr.ltr.enabled:false}" name="QUERY_DOC_FV"
+ <featureVectorCache enable="${solr.ltr.enabled:false}"
class="solr.CaffeineCache"
size="4096"
initialSize="2048"
- autowarmCount="4096"
- regenerator="solr.search.NoOpRegenerator" />
+ autowarmCount="0" />
<!-- Custom Cache
@@ -1273,6 +1272,5 @@ via parameters. The below configuration supports
hl.method=original and fastVec
https://solr.apache.org/guide/solr/latest/query-guide/learning-to-rank.html
-->
<transformer enable="${solr.ltr.enabled:false}" name="features"
class="org.apache.solr.ltr.response.transform.LTRFeatureLoggerTransformerFactory">
- <str name="fvCacheName">QUERY_DOC_FV</str>
</transformer>
</config>
diff --git
a/solr/solr-ref-guide/modules/query-guide/pages/learning-to-rank.adoc
b/solr/solr-ref-guide/modules/query-guide/pages/learning-to-rank.adoc
index 22e3bcde95e..cf13add48d8 100644
--- a/solr/solr-ref-guide/modules/query-guide/pages/learning-to-rank.adoc
+++ b/solr/solr-ref-guide/modules/query-guide/pages/learning-to-rank.adoc
@@ -131,18 +131,18 @@ See xref:configuration-guide:solr-modules.adoc[Solr
Module] for more details.
<queryParser name="ltr" class="org.apache.solr.ltr.search.LTRQParserPlugin"/>
----
-* Configuration of the feature values cache.
+* Configuration of the feature vector cache, which is used for both reranking
and feature logging.
+This needs to be added in the `<query>` section as follows.
+
[source,xml]
----
-<cache name="QUERY_DOC_FV"
- class="solr.search.CaffeineCache"
- size="4096"
- initialSize="2048"
- autowarmCount="4096"
- regenerator="solr.search.NoOpRegenerator" />
+<featureVectorCache class="solr.CaffeineCache" size="4096" initialSize="2048"
autowarmCount="0"/>
----
+[NOTE]
+The `featureVectorCache` key is computed using the Lucene Document ID
(necessary for document-level features).
+Since these IDs are transient, this cache does not support auto-warming.
+
* Declaration of the `[features]` transformer.
+
[source,xml]