cpoerschke commented on code in PR #3316:
URL: https://github.com/apache/solr/pull/3316#discussion_r2039495987
##########
solr/core/src/java/org/apache/solr/search/join/BlockJoinParentQParser.java:
##########
@@ -78,10 +93,127 @@ protected Query noClausesQuery() throws SyntaxError {
return new BitSetProducerQuery(getBitSetProducer(parseParentFilter()));
}
- protected Query createQuery(final Query parentList, Query query, String
scoreMode)
+ protected Query createQuery(final Query allParents, BooleanQuery
childrenQuery, String scoreMode)
throws SyntaxError {
- return new AllParentsAware(
- query, getBitSetProducer(parentList),
ScoreModeParser.parse(scoreMode), parentList);
+ try {
+ List<BooleanClause> childrenClauses = childrenQuery.clauses();
+ if (isByteKnnQuery(childrenClauses)) {
+ BitSetProducer allParentsBitSet = getBitSetProducer(allParents);
+ BooleanQuery parentsFilter = getParentsFilter();
+
+ KnnByteVectorQuery knnChildrenQuery = (KnnByteVectorQuery)
childrenClauses.get(0).getQuery();
+ String vectorField = knnChildrenQuery.getField();
+ byte[] queryVector = knnChildrenQuery.getTargetCopy();
+ int topK = knnChildrenQuery.getK();
+
+ Query acceptedChildren =
+ getChildrenFilter(knnChildrenQuery.getFilter(), parentsFilter,
allParentsBitSet);
+
+ Query knnChildren =
+ new DiversifyingChildrenByteKnnVectorQuery(
+ vectorField, queryVector, acceptedChildren, topK,
allParentsBitSet);
+ knnChildren = knnChildren.rewrite(req.getSearcher());
+ this.setAppropriateChildrenListingTransformer(req,knnChildren);
+
+ return new ToParentBlockJoinQuery(
+ knnChildren, allParentsBitSet,
ScoreModeParser.parse(scoreMode));
+ } else if (isFloatKnnQuery(childrenClauses)) {
+ BitSetProducer allParentsBitSet = getBitSetProducer(allParents);
+ BooleanQuery parentsFilter = getParentsFilter();
+
+ KnnFloatVectorQuery knnChildrenQuery =
+ (KnnFloatVectorQuery) childrenClauses.get(0).getQuery();
+ String vectorField = knnChildrenQuery.getField();
+ float[] queryVector = knnChildrenQuery.getTargetCopy();
+ int topK = knnChildrenQuery.getK();
+
+ Query childrenFilter =
+ getChildrenFilter(knnChildrenQuery.getFilter(), parentsFilter,
allParentsBitSet);
+
+ Query knnChildren =
+ new DiversifyingChildrenFloatKnnVectorQuery(
+ vectorField, queryVector, childrenFilter, topK,
allParentsBitSet);
+ knnChildren = knnChildren.rewrite(req.getSearcher());
+ this.setAppropriateChildrenListingTransformer(req,knnChildren);
+
+ return new ToParentBlockJoinQuery(
+ knnChildren, allParentsBitSet,
ScoreModeParser.parse(scoreMode));
+ } else {
+ return new AllParentsAware(
+ childrenQuery,
+ getBitSetProducer(allParents),
+ ScoreModeParser.parse(scoreMode),
+ allParents);
+ }
+ } catch (IOException e) {
+ throw new SolrException(SolrException.ErrorCode.SERVER_ERROR, e);
+ }
+ }
+
+ private void setAppropriateChildrenListingTransformer(SolrQueryRequest
request, Query knnOnVectorField) throws IOException {
+ QueryLimits currentLimits = QueryLimits.getCurrentLimits();
+ ReturnFields returnFields = currentLimits.getRsp().getReturnFields();
+ DocTransformer originalTransformer = returnFields.getTransformer();
+
+ if (originalTransformer instanceof DocTransformers) {
+ DocTransformers transformers = (DocTransformers) originalTransformer;
+ boolean noChildTransformer = true;
+ for (int i = 0; i < transformers.size() && noChildTransformer; i++) {
+ DocTransformer t = transformers.getTransformer(i);
+ if (t instanceof ChildDocTransformer) {
+ ChildDocTransformer childTransformer = (ChildDocTransformer) t;
+ if (childTransformer.getChildDocSet() == null) {
+
childTransformer.setChildDocSet(request.getSearcher().getDocSet(knnOnVectorField));
+ }
+ noChildTransformer = false;
+ }
+ }
+ } else {
+ if ((originalTransformer instanceof ChildDocTransformer)) {
+ ChildDocTransformer childTransformer = (ChildDocTransformer)
originalTransformer;
+ if (childTransformer.getChildDocSet() == null) {
+
childTransformer.setChildDocSet(request.getSearcher().getDocSet(knnOnVectorField));
+ }
+ }
+ }
+ }
+
+ private boolean isFloatKnnQuery(List<BooleanClause> childrenClauses) {
+ return childrenClauses.size() == 1
+ &&
childrenClauses.get(0).getQuery().getClass().equals(KnnFloatVectorQuery.class);
+ }
+
+ private boolean isByteKnnQuery(List<BooleanClause> childrenClauses) {
+ return childrenClauses.size() == 1
+ &&
childrenClauses.get(0).getQuery().getClass().equals(KnnByteVectorQuery.class);
+ }
Review Comment:
wondering if `childrenClauses.get(0).getQuery() instanceof
KnnFloatVectorQuery` and `childrenClauses.get(0).getQuery() instanceof
KnnByteVectorQuery` checks would work too?
##########
solr/core/src/java/org/apache/solr/search/join/BlockJoinParentQParser.java:
##########
@@ -78,10 +93,127 @@ protected Query noClausesQuery() throws SyntaxError {
return new BitSetProducerQuery(getBitSetProducer(parseParentFilter()));
}
- protected Query createQuery(final Query parentList, Query query, String
scoreMode)
+ protected Query createQuery(final Query allParents, BooleanQuery
childrenQuery, String scoreMode)
throws SyntaxError {
- return new AllParentsAware(
- query, getBitSetProducer(parentList),
ScoreModeParser.parse(scoreMode), parentList);
+ try {
+ List<BooleanClause> childrenClauses = childrenQuery.clauses();
+ if (isByteKnnQuery(childrenClauses)) {
+ BitSetProducer allParentsBitSet = getBitSetProducer(allParents);
Review Comment:
1/3: maybe move this outside the `if` and use it in all three code paths
```suggestion
BitSetProducer allParentsBitSet = getBitSetProducer(allParents);
if (isByteKnnQuery(childrenClauses)) {
```
##########
solr/core/src/java/org/apache/solr/search/join/BlockJoinParentQParser.java:
##########
@@ -78,10 +93,127 @@ protected Query noClausesQuery() throws SyntaxError {
return new BitSetProducerQuery(getBitSetProducer(parseParentFilter()));
}
- protected Query createQuery(final Query parentList, Query query, String
scoreMode)
+ protected Query createQuery(final Query allParents, BooleanQuery
childrenQuery, String scoreMode)
throws SyntaxError {
- return new AllParentsAware(
- query, getBitSetProducer(parentList),
ScoreModeParser.parse(scoreMode), parentList);
+ try {
+ List<BooleanClause> childrenClauses = childrenQuery.clauses();
+ if (isByteKnnQuery(childrenClauses)) {
+ BitSetProducer allParentsBitSet = getBitSetProducer(allParents);
+ BooleanQuery parentsFilter = getParentsFilter();
+
+ KnnByteVectorQuery knnChildrenQuery = (KnnByteVectorQuery)
childrenClauses.get(0).getQuery();
+ String vectorField = knnChildrenQuery.getField();
+ byte[] queryVector = knnChildrenQuery.getTargetCopy();
+ int topK = knnChildrenQuery.getK();
+
+ Query acceptedChildren =
+ getChildrenFilter(knnChildrenQuery.getFilter(), parentsFilter,
allParentsBitSet);
+
+ Query knnChildren =
+ new DiversifyingChildrenByteKnnVectorQuery(
+ vectorField, queryVector, acceptedChildren, topK,
allParentsBitSet);
+ knnChildren = knnChildren.rewrite(req.getSearcher());
+ this.setAppropriateChildrenListingTransformer(req,knnChildren);
+
+ return new ToParentBlockJoinQuery(
+ knnChildren, allParentsBitSet,
ScoreModeParser.parse(scoreMode));
+ } else if (isFloatKnnQuery(childrenClauses)) {
+ BitSetProducer allParentsBitSet = getBitSetProducer(allParents);
+ BooleanQuery parentsFilter = getParentsFilter();
+
+ KnnFloatVectorQuery knnChildrenQuery =
+ (KnnFloatVectorQuery) childrenClauses.get(0).getQuery();
+ String vectorField = knnChildrenQuery.getField();
+ float[] queryVector = knnChildrenQuery.getTargetCopy();
+ int topK = knnChildrenQuery.getK();
+
+ Query childrenFilter =
+ getChildrenFilter(knnChildrenQuery.getFilter(), parentsFilter,
allParentsBitSet);
+
+ Query knnChildren =
+ new DiversifyingChildrenFloatKnnVectorQuery(
+ vectorField, queryVector, childrenFilter, topK,
allParentsBitSet);
+ knnChildren = knnChildren.rewrite(req.getSearcher());
+ this.setAppropriateChildrenListingTransformer(req,knnChildren);
+
+ return new ToParentBlockJoinQuery(
+ knnChildren, allParentsBitSet,
ScoreModeParser.parse(scoreMode));
+ } else {
+ return new AllParentsAware(
+ childrenQuery,
+ getBitSetProducer(allParents),
+ ScoreModeParser.parse(scoreMode),
+ allParents);
+ }
+ } catch (IOException e) {
+ throw new SolrException(SolrException.ErrorCode.SERVER_ERROR, e);
+ }
Review Comment:
curious about the addition of a try/catch block here; perhaps we could add
a comment about what it might catch?
##########
solr/core/src/java/org/apache/solr/search/join/BlockJoinParentQParser.java:
##########
@@ -78,10 +93,127 @@ protected Query noClausesQuery() throws SyntaxError {
return new BitSetProducerQuery(getBitSetProducer(parseParentFilter()));
}
- protected Query createQuery(final Query parentList, Query query, String
scoreMode)
+ protected Query createQuery(final Query allParents, BooleanQuery
childrenQuery, String scoreMode)
throws SyntaxError {
- return new AllParentsAware(
- query, getBitSetProducer(parentList),
ScoreModeParser.parse(scoreMode), parentList);
+ try {
+ List<BooleanClause> childrenClauses = childrenQuery.clauses();
+ if (isByteKnnQuery(childrenClauses)) {
+ BitSetProducer allParentsBitSet = getBitSetProducer(allParents);
+ BooleanQuery parentsFilter = getParentsFilter();
+
+ KnnByteVectorQuery knnChildrenQuery = (KnnByteVectorQuery)
childrenClauses.get(0).getQuery();
+ String vectorField = knnChildrenQuery.getField();
+ byte[] queryVector = knnChildrenQuery.getTargetCopy();
+ int topK = knnChildrenQuery.getK();
+
+ Query acceptedChildren =
+ getChildrenFilter(knnChildrenQuery.getFilter(), parentsFilter,
allParentsBitSet);
+
+ Query knnChildren =
+ new DiversifyingChildrenByteKnnVectorQuery(
+ vectorField, queryVector, acceptedChildren, topK,
allParentsBitSet);
+ knnChildren = knnChildren.rewrite(req.getSearcher());
+ this.setAppropriateChildrenListingTransformer(req,knnChildren);
+
+ return new ToParentBlockJoinQuery(
+ knnChildren, allParentsBitSet,
ScoreModeParser.parse(scoreMode));
+ } else if (isFloatKnnQuery(childrenClauses)) {
+ BitSetProducer allParentsBitSet = getBitSetProducer(allParents);
Review Comment:
2/3
```suggestion
```
##########
solr/core/src/java/org/apache/solr/search/join/BlockJoinParentQParser.java:
##########
@@ -78,10 +93,127 @@ protected Query noClausesQuery() throws SyntaxError {
return new BitSetProducerQuery(getBitSetProducer(parseParentFilter()));
}
- protected Query createQuery(final Query parentList, Query query, String
scoreMode)
+ protected Query createQuery(final Query allParents, BooleanQuery
childrenQuery, String scoreMode)
throws SyntaxError {
- return new AllParentsAware(
- query, getBitSetProducer(parentList),
ScoreModeParser.parse(scoreMode), parentList);
+ try {
+ List<BooleanClause> childrenClauses = childrenQuery.clauses();
+ if (isByteKnnQuery(childrenClauses)) {
+ BitSetProducer allParentsBitSet = getBitSetProducer(allParents);
+ BooleanQuery parentsFilter = getParentsFilter();
+
+ KnnByteVectorQuery knnChildrenQuery = (KnnByteVectorQuery)
childrenClauses.get(0).getQuery();
+ String vectorField = knnChildrenQuery.getField();
+ byte[] queryVector = knnChildrenQuery.getTargetCopy();
+ int topK = knnChildrenQuery.getK();
+
+ Query acceptedChildren =
+ getChildrenFilter(knnChildrenQuery.getFilter(), parentsFilter,
allParentsBitSet);
+
+ Query knnChildren =
+ new DiversifyingChildrenByteKnnVectorQuery(
+ vectorField, queryVector, acceptedChildren, topK,
allParentsBitSet);
+ knnChildren = knnChildren.rewrite(req.getSearcher());
+ this.setAppropriateChildrenListingTransformer(req,knnChildren);
+
+ return new ToParentBlockJoinQuery(
+ knnChildren, allParentsBitSet,
ScoreModeParser.parse(scoreMode));
+ } else if (isFloatKnnQuery(childrenClauses)) {
+ BitSetProducer allParentsBitSet = getBitSetProducer(allParents);
+ BooleanQuery parentsFilter = getParentsFilter();
+
+ KnnFloatVectorQuery knnChildrenQuery =
+ (KnnFloatVectorQuery) childrenClauses.get(0).getQuery();
+ String vectorField = knnChildrenQuery.getField();
+ float[] queryVector = knnChildrenQuery.getTargetCopy();
+ int topK = knnChildrenQuery.getK();
+
+ Query childrenFilter =
+ getChildrenFilter(knnChildrenQuery.getFilter(), parentsFilter,
allParentsBitSet);
+
+ Query knnChildren =
+ new DiversifyingChildrenFloatKnnVectorQuery(
+ vectorField, queryVector, childrenFilter, topK,
allParentsBitSet);
+ knnChildren = knnChildren.rewrite(req.getSearcher());
+ this.setAppropriateChildrenListingTransformer(req,knnChildren);
+
+ return new ToParentBlockJoinQuery(
+ knnChildren, allParentsBitSet,
ScoreModeParser.parse(scoreMode));
+ } else {
+ return new AllParentsAware(
+ childrenQuery,
+ getBitSetProducer(allParents),
Review Comment:
3/3
```suggestion
allParentsBitSet,
```
##########
solr/core/src/java/org/apache/solr/search/join/BlockJoinParentQParser.java:
##########
@@ -78,10 +93,127 @@ protected Query noClausesQuery() throws SyntaxError {
return new BitSetProducerQuery(getBitSetProducer(parseParentFilter()));
}
- protected Query createQuery(final Query parentList, Query query, String
scoreMode)
+ protected Query createQuery(final Query allParents, BooleanQuery
childrenQuery, String scoreMode)
throws SyntaxError {
- return new AllParentsAware(
- query, getBitSetProducer(parentList),
ScoreModeParser.parse(scoreMode), parentList);
+ try {
+ List<BooleanClause> childrenClauses = childrenQuery.clauses();
+ if (isByteKnnQuery(childrenClauses)) {
+ BitSetProducer allParentsBitSet = getBitSetProducer(allParents);
+ BooleanQuery parentsFilter = getParentsFilter();
+
+ KnnByteVectorQuery knnChildrenQuery = (KnnByteVectorQuery)
childrenClauses.get(0).getQuery();
+ String vectorField = knnChildrenQuery.getField();
+ byte[] queryVector = knnChildrenQuery.getTargetCopy();
+ int topK = knnChildrenQuery.getK();
+
+ Query acceptedChildren =
+ getChildrenFilter(knnChildrenQuery.getFilter(), parentsFilter,
allParentsBitSet);
+
+ Query knnChildren =
+ new DiversifyingChildrenByteKnnVectorQuery(
+ vectorField, queryVector, acceptedChildren, topK,
allParentsBitSet);
+ knnChildren = knnChildren.rewrite(req.getSearcher());
+ this.setAppropriateChildrenListingTransformer(req,knnChildren);
+
+ return new ToParentBlockJoinQuery(
+ knnChildren, allParentsBitSet,
ScoreModeParser.parse(scoreMode));
+ } else if (isFloatKnnQuery(childrenClauses)) {
+ BitSetProducer allParentsBitSet = getBitSetProducer(allParents);
+ BooleanQuery parentsFilter = getParentsFilter();
+
+ KnnFloatVectorQuery knnChildrenQuery =
+ (KnnFloatVectorQuery) childrenClauses.get(0).getQuery();
+ String vectorField = knnChildrenQuery.getField();
+ float[] queryVector = knnChildrenQuery.getTargetCopy();
+ int topK = knnChildrenQuery.getK();
+
+ Query childrenFilter =
+ getChildrenFilter(knnChildrenQuery.getFilter(), parentsFilter,
allParentsBitSet);
+
+ Query knnChildren =
+ new DiversifyingChildrenFloatKnnVectorQuery(
+ vectorField, queryVector, childrenFilter, topK,
allParentsBitSet);
+ knnChildren = knnChildren.rewrite(req.getSearcher());
+ this.setAppropriateChildrenListingTransformer(req,knnChildren);
+
+ return new ToParentBlockJoinQuery(
+ knnChildren, allParentsBitSet,
ScoreModeParser.parse(scoreMode));
+ } else {
+ return new AllParentsAware(
+ childrenQuery,
+ getBitSetProducer(allParents),
+ ScoreModeParser.parse(scoreMode),
+ allParents);
+ }
+ } catch (IOException e) {
+ throw new SolrException(SolrException.ErrorCode.SERVER_ERROR, e);
+ }
+ }
+
+ private void setAppropriateChildrenListingTransformer(SolrQueryRequest
request, Query knnOnVectorField) throws IOException {
+ QueryLimits currentLimits = QueryLimits.getCurrentLimits();
+ ReturnFields returnFields = currentLimits.getRsp().getReturnFields();
+ DocTransformer originalTransformer = returnFields.getTransformer();
+
+ if (originalTransformer instanceof DocTransformers) {
+ DocTransformers transformers = (DocTransformers) originalTransformer;
+ boolean noChildTransformer = true;
+ for (int i = 0; i < transformers.size() && noChildTransformer; i++) {
+ DocTransformer t = transformers.getTransformer(i);
+ if (t instanceof ChildDocTransformer) {
+ ChildDocTransformer childTransformer = (ChildDocTransformer) t;
+ if (childTransformer.getChildDocSet() == null) {
+
childTransformer.setChildDocSet(request.getSearcher().getDocSet(knnOnVectorField));
+ }
+ noChildTransformer = false;
+ }
+ }
+ } else {
+ if ((originalTransformer instanceof ChildDocTransformer)) {
+ ChildDocTransformer childTransformer = (ChildDocTransformer)
originalTransformer;
+ if (childTransformer.getChildDocSet() == null) {
+
childTransformer.setChildDocSet(request.getSearcher().getDocSet(knnOnVectorField));
+ }
+ }
+ }
+ }
+
+ private boolean isFloatKnnQuery(List<BooleanClause> childrenClauses) {
+ return childrenClauses.size() == 1
+ &&
childrenClauses.get(0).getQuery().getClass().equals(KnnFloatVectorQuery.class);
+ }
+
+ private boolean isByteKnnQuery(List<BooleanClause> childrenClauses) {
+ return childrenClauses.size() == 1
+ &&
childrenClauses.get(0).getQuery().getClass().equals(KnnByteVectorQuery.class);
+ }
+
+ private Query getChildrenFilter(
+ Query childrenKnnPreFilter, BooleanQuery parentsFilter, BitSetProducer
allParentsBitSet) {
+ Query childrenFilter = childrenKnnPreFilter;
+
+ if (parentsFilter.clauses().size() > 0) {
Review Comment:
```suggestion
if (!parentsFilter.clauses().isEmpty()) {
```
##########
solr/core/src/java/org/apache/solr/search/join/BlockJoinParentQParser.java:
##########
@@ -78,10 +93,127 @@ protected Query noClausesQuery() throws SyntaxError {
return new BitSetProducerQuery(getBitSetProducer(parseParentFilter()));
}
- protected Query createQuery(final Query parentList, Query query, String
scoreMode)
+ protected Query createQuery(final Query allParents, BooleanQuery
childrenQuery, String scoreMode)
throws SyntaxError {
- return new AllParentsAware(
- query, getBitSetProducer(parentList),
ScoreModeParser.parse(scoreMode), parentList);
+ try {
+ List<BooleanClause> childrenClauses = childrenQuery.clauses();
+ if (isByteKnnQuery(childrenClauses)) {
+ BitSetProducer allParentsBitSet = getBitSetProducer(allParents);
+ BooleanQuery parentsFilter = getParentsFilter();
+
+ KnnByteVectorQuery knnChildrenQuery = (KnnByteVectorQuery)
childrenClauses.get(0).getQuery();
+ String vectorField = knnChildrenQuery.getField();
+ byte[] queryVector = knnChildrenQuery.getTargetCopy();
+ int topK = knnChildrenQuery.getK();
+
+ Query acceptedChildren =
+ getChildrenFilter(knnChildrenQuery.getFilter(), parentsFilter,
allParentsBitSet);
+
+ Query knnChildren =
+ new DiversifyingChildrenByteKnnVectorQuery(
+ vectorField, queryVector, acceptedChildren, topK,
allParentsBitSet);
+ knnChildren = knnChildren.rewrite(req.getSearcher());
+ this.setAppropriateChildrenListingTransformer(req,knnChildren);
+
+ return new ToParentBlockJoinQuery(
+ knnChildren, allParentsBitSet,
ScoreModeParser.parse(scoreMode));
+ } else if (isFloatKnnQuery(childrenClauses)) {
+ BitSetProducer allParentsBitSet = getBitSetProducer(allParents);
+ BooleanQuery parentsFilter = getParentsFilter();
+
+ KnnFloatVectorQuery knnChildrenQuery =
+ (KnnFloatVectorQuery) childrenClauses.get(0).getQuery();
+ String vectorField = knnChildrenQuery.getField();
+ float[] queryVector = knnChildrenQuery.getTargetCopy();
+ int topK = knnChildrenQuery.getK();
+
+ Query childrenFilter =
+ getChildrenFilter(knnChildrenQuery.getFilter(), parentsFilter,
allParentsBitSet);
+
+ Query knnChildren =
+ new DiversifyingChildrenFloatKnnVectorQuery(
+ vectorField, queryVector, childrenFilter, topK,
allParentsBitSet);
+ knnChildren = knnChildren.rewrite(req.getSearcher());
+ this.setAppropriateChildrenListingTransformer(req,knnChildren);
+
+ return new ToParentBlockJoinQuery(
+ knnChildren, allParentsBitSet,
ScoreModeParser.parse(scoreMode));
+ } else {
+ return new AllParentsAware(
+ childrenQuery,
+ getBitSetProducer(allParents),
+ ScoreModeParser.parse(scoreMode),
+ allParents);
+ }
+ } catch (IOException e) {
+ throw new SolrException(SolrException.ErrorCode.SERVER_ERROR, e);
+ }
+ }
+
+ private void setAppropriateChildrenListingTransformer(SolrQueryRequest
request, Query knnOnVectorField) throws IOException {
+ QueryLimits currentLimits = QueryLimits.getCurrentLimits();
+ ReturnFields returnFields = currentLimits.getRsp().getReturnFields();
+ DocTransformer originalTransformer = returnFields.getTransformer();
+
+ if (originalTransformer instanceof DocTransformers) {
+ DocTransformers transformers = (DocTransformers) originalTransformer;
+ boolean noChildTransformer = true;
+ for (int i = 0; i < transformers.size() && noChildTransformer; i++) {
+ DocTransformer t = transformers.getTransformer(i);
+ if (t instanceof ChildDocTransformer) {
+ ChildDocTransformer childTransformer = (ChildDocTransformer) t;
+ if (childTransformer.getChildDocSet() == null) {
+
childTransformer.setChildDocSet(request.getSearcher().getDocSet(knnOnVectorField));
+ }
+ noChildTransformer = false;
+ }
+ }
+ } else {
+ if ((originalTransformer instanceof ChildDocTransformer)) {
+ ChildDocTransformer childTransformer = (ChildDocTransformer)
originalTransformer;
+ if (childTransformer.getChildDocSet() == null) {
+
childTransformer.setChildDocSet(request.getSearcher().getDocSet(knnOnVectorField));
+ }
+ }
+ }
+ }
+
+ private boolean isFloatKnnQuery(List<BooleanClause> childrenClauses) {
+ return childrenClauses.size() == 1
+ &&
childrenClauses.get(0).getQuery().getClass().equals(KnnFloatVectorQuery.class);
+ }
+
+ private boolean isByteKnnQuery(List<BooleanClause> childrenClauses) {
+ return childrenClauses.size() == 1
+ &&
childrenClauses.get(0).getQuery().getClass().equals(KnnByteVectorQuery.class);
+ }
Review Comment:
if `true` is returned, the caller later repeats
`childrenClauses.get(0).getQuery()` and type-casts the result without a check -- perhaps
this could be encapsulated, e.g.:
```
private KnnFloatVectorQuery getFloatKnnQuery(List<BooleanClause>
childrenClauses) {
if (childrenClauses.size() == 1)
{
Query query = childrenClauses.get(0).getQuery();
if (query instanceof KnnFloatVectorQuery)
{
return (KnnFloatVectorQuery) query;
}
}
return null;
}
```
##########
solr/core/src/java/org/apache/solr/search/join/BlockJoinParentQParser.java:
##########
@@ -78,10 +93,127 @@ protected Query noClausesQuery() throws SyntaxError {
return new BitSetProducerQuery(getBitSetProducer(parseParentFilter()));
}
- protected Query createQuery(final Query parentList, Query query, String
scoreMode)
+ protected Query createQuery(final Query allParents, BooleanQuery
childrenQuery, String scoreMode)
throws SyntaxError {
- return new AllParentsAware(
- query, getBitSetProducer(parentList),
ScoreModeParser.parse(scoreMode), parentList);
+ try {
+ List<BooleanClause> childrenClauses = childrenQuery.clauses();
+ if (isByteKnnQuery(childrenClauses)) {
+ BitSetProducer allParentsBitSet = getBitSetProducer(allParents);
+ BooleanQuery parentsFilter = getParentsFilter();
+
+ KnnByteVectorQuery knnChildrenQuery = (KnnByteVectorQuery)
childrenClauses.get(0).getQuery();
+ String vectorField = knnChildrenQuery.getField();
+ byte[] queryVector = knnChildrenQuery.getTargetCopy();
+ int topK = knnChildrenQuery.getK();
+
+ Query acceptedChildren =
+ getChildrenFilter(knnChildrenQuery.getFilter(), parentsFilter,
allParentsBitSet);
+
+ Query knnChildren =
+ new DiversifyingChildrenByteKnnVectorQuery(
+ vectorField, queryVector, acceptedChildren, topK,
allParentsBitSet);
+ knnChildren = knnChildren.rewrite(req.getSearcher());
+ this.setAppropriateChildrenListingTransformer(req,knnChildren);
+
+ return new ToParentBlockJoinQuery(
+ knnChildren, allParentsBitSet,
ScoreModeParser.parse(scoreMode));
+ } else if (isFloatKnnQuery(childrenClauses)) {
+ BitSetProducer allParentsBitSet = getBitSetProducer(allParents);
+ BooleanQuery parentsFilter = getParentsFilter();
+
+ KnnFloatVectorQuery knnChildrenQuery =
+ (KnnFloatVectorQuery) childrenClauses.get(0).getQuery();
+ String vectorField = knnChildrenQuery.getField();
+ float[] queryVector = knnChildrenQuery.getTargetCopy();
+ int topK = knnChildrenQuery.getK();
+
+ Query childrenFilter =
+ getChildrenFilter(knnChildrenQuery.getFilter(), parentsFilter,
allParentsBitSet);
+
+ Query knnChildren =
+ new DiversifyingChildrenFloatKnnVectorQuery(
+ vectorField, queryVector, childrenFilter, topK,
allParentsBitSet);
+ knnChildren = knnChildren.rewrite(req.getSearcher());
+ this.setAppropriateChildrenListingTransformer(req,knnChildren);
+
+ return new ToParentBlockJoinQuery(
+ knnChildren, allParentsBitSet,
ScoreModeParser.parse(scoreMode));
+ } else {
+ return new AllParentsAware(
+ childrenQuery,
+ getBitSetProducer(allParents),
+ ScoreModeParser.parse(scoreMode),
+ allParents);
+ }
+ } catch (IOException e) {
+ throw new SolrException(SolrException.ErrorCode.SERVER_ERROR, e);
+ }
+ }
+
+ private void setAppropriateChildrenListingTransformer(SolrQueryRequest
request, Query knnOnVectorField) throws IOException {
+ QueryLimits currentLimits = QueryLimits.getCurrentLimits();
+ ReturnFields returnFields = currentLimits.getRsp().getReturnFields();
+ DocTransformer originalTransformer = returnFields.getTransformer();
+
+ if (originalTransformer instanceof DocTransformers) {
+ DocTransformers transformers = (DocTransformers) originalTransformer;
+ boolean noChildTransformer = true;
+ for (int i = 0; i < transformers.size() && noChildTransformer; i++) {
+ DocTransformer t = transformers.getTransformer(i);
+ if (t instanceof ChildDocTransformer) {
+ ChildDocTransformer childTransformer = (ChildDocTransformer) t;
+ if (childTransformer.getChildDocSet() == null) {
+
childTransformer.setChildDocSet(request.getSearcher().getDocSet(knnOnVectorField));
+ }
+ noChildTransformer = false;
+ }
+ }
+ } else {
+ if ((originalTransformer instanceof ChildDocTransformer)) {
+ ChildDocTransformer childTransformer = (ChildDocTransformer)
originalTransformer;
+ if (childTransformer.getChildDocSet() == null) {
+
childTransformer.setChildDocSet(request.getSearcher().getDocSet(knnOnVectorField));
+ }
+ }
+ }
Review Comment:
minor
```suggestion
} else if ((originalTransformer instanceof ChildDocTransformer)) {
ChildDocTransformer childTransformer = (ChildDocTransformer)
originalTransformer;
if (childTransformer.getChildDocSet() == null) {
childTransformer.setChildDocSet(request.getSearcher().getDocSet(knnOnVectorField));
}
}
```
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]