cpoerschke commented on code in PR #3316:
URL: https://github.com/apache/solr/pull/3316#discussion_r2039495987


##########
solr/core/src/java/org/apache/solr/search/join/BlockJoinParentQParser.java:
##########
@@ -78,10 +93,127 @@ protected Query noClausesQuery() throws SyntaxError {
     return new BitSetProducerQuery(getBitSetProducer(parseParentFilter()));
   }
 
-  protected Query createQuery(final Query parentList, Query query, String 
scoreMode)
+  protected Query createQuery(final Query allParents, BooleanQuery 
childrenQuery, String scoreMode)
       throws SyntaxError {
-    return new AllParentsAware(
-        query, getBitSetProducer(parentList), 
ScoreModeParser.parse(scoreMode), parentList);
+      try {
+          List<BooleanClause> childrenClauses = childrenQuery.clauses();
+          if (isByteKnnQuery(childrenClauses)) {
+            BitSetProducer allParentsBitSet = getBitSetProducer(allParents);
+            BooleanQuery parentsFilter = getParentsFilter();
+      
+            KnnByteVectorQuery knnChildrenQuery = (KnnByteVectorQuery) 
childrenClauses.get(0).getQuery();
+            String vectorField = knnChildrenQuery.getField();
+            byte[] queryVector = knnChildrenQuery.getTargetCopy();
+            int topK = knnChildrenQuery.getK();
+      
+            Query acceptedChildren =
+                getChildrenFilter(knnChildrenQuery.getFilter(), parentsFilter, 
allParentsBitSet);
+      
+            Query knnChildren =
+                new DiversifyingChildrenByteKnnVectorQuery(
+                    vectorField, queryVector, acceptedChildren, topK, 
allParentsBitSet);
+            knnChildren = knnChildren.rewrite(req.getSearcher());
+            this.setAppropriateChildrenListingTransformer(req,knnChildren);
+            
+            return new ToParentBlockJoinQuery(
+                knnChildren, allParentsBitSet, 
ScoreModeParser.parse(scoreMode));
+          } else if (isFloatKnnQuery(childrenClauses)) {
+            BitSetProducer allParentsBitSet = getBitSetProducer(allParents);
+            BooleanQuery parentsFilter = getParentsFilter();
+      
+            KnnFloatVectorQuery knnChildrenQuery =
+                (KnnFloatVectorQuery) childrenClauses.get(0).getQuery();
+            String vectorField = knnChildrenQuery.getField();
+            float[] queryVector = knnChildrenQuery.getTargetCopy();
+            int topK = knnChildrenQuery.getK();
+      
+            Query childrenFilter =
+                getChildrenFilter(knnChildrenQuery.getFilter(), parentsFilter, 
allParentsBitSet);
+      
+            Query knnChildren =
+                new DiversifyingChildrenFloatKnnVectorQuery(
+                    vectorField, queryVector, childrenFilter, topK, 
allParentsBitSet);
+            knnChildren = knnChildren.rewrite(req.getSearcher());
+            this.setAppropriateChildrenListingTransformer(req,knnChildren);
+      
+            return new ToParentBlockJoinQuery(
+                knnChildren, allParentsBitSet, 
ScoreModeParser.parse(scoreMode));
+          } else {
+            return new AllParentsAware(
+                childrenQuery,
+                getBitSetProducer(allParents),
+                ScoreModeParser.parse(scoreMode),
+                allParents);
+          }
+      } catch (IOException e) {
+        throw new SolrException(SolrException.ErrorCode.SERVER_ERROR, e);
+      }
+  }
+
+  private void setAppropriateChildrenListingTransformer(SolrQueryRequest 
request, Query knnOnVectorField) throws IOException {
+    QueryLimits currentLimits = QueryLimits.getCurrentLimits();
+    ReturnFields returnFields = currentLimits.getRsp().getReturnFields();
+    DocTransformer originalTransformer = returnFields.getTransformer();
+
+    if (originalTransformer instanceof DocTransformers) {
+      DocTransformers transformers = (DocTransformers) originalTransformer;
+      boolean noChildTransformer = true;
+      for (int i = 0; i < transformers.size() && noChildTransformer; i++) {
+        DocTransformer t = transformers.getTransformer(i);
+        if (t instanceof ChildDocTransformer) {
+          ChildDocTransformer childTransformer = (ChildDocTransformer) t;
+          if (childTransformer.getChildDocSet() == null) {
+            
childTransformer.setChildDocSet(request.getSearcher().getDocSet(knnOnVectorField));
+          }
+          noChildTransformer = false;
+        }
+      }
+    } else {
+      if ((originalTransformer instanceof ChildDocTransformer)) {
+        ChildDocTransformer childTransformer = (ChildDocTransformer) 
originalTransformer;
+        if (childTransformer.getChildDocSet() == null) {
+          
childTransformer.setChildDocSet(request.getSearcher().getDocSet(knnOnVectorField));
+        }
+      }
+    }
+  }
+
+  private boolean isFloatKnnQuery(List<BooleanClause> childrenClauses) {
+    return childrenClauses.size() == 1
+        && 
childrenClauses.get(0).getQuery().getClass().equals(KnnFloatVectorQuery.class);
+  }
+
+  private boolean isByteKnnQuery(List<BooleanClause> childrenClauses) {
+    return childrenClauses.size() == 1
+        && 
childrenClauses.get(0).getQuery().getClass().equals(KnnByteVectorQuery.class);
+  }

Review Comment:
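   wondering if `childrenClauses.get(0).getQuery() instanceof 
KnnFloatVectorQuery` and `childrenClauses.get(0).getQuery() instanceof 
KnnByteVectorQuery` would work too? (Unlike the exact `getClass().equals(...)` 
check, `instanceof` would also accept subclasses of these query types.)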



##########
solr/core/src/java/org/apache/solr/search/join/BlockJoinParentQParser.java:
##########
@@ -78,10 +93,127 @@ protected Query noClausesQuery() throws SyntaxError {
     return new BitSetProducerQuery(getBitSetProducer(parseParentFilter()));
   }
 
-  protected Query createQuery(final Query parentList, Query query, String 
scoreMode)
+  protected Query createQuery(final Query allParents, BooleanQuery 
childrenQuery, String scoreMode)
       throws SyntaxError {
-    return new AllParentsAware(
-        query, getBitSetProducer(parentList), 
ScoreModeParser.parse(scoreMode), parentList);
+      try {
+          List<BooleanClause> childrenClauses = childrenQuery.clauses();
+          if (isByteKnnQuery(childrenClauses)) {
+            BitSetProducer allParentsBitSet = getBitSetProducer(allParents);

Review Comment:
   1/3 maybe move this outside the `if` and reuse it in all three code paths
   ```suggestion
             BitSetProducer allParentsBitSet = getBitSetProducer(allParents);
             if (isByteKnnQuery(childrenClauses)) {
   ```



##########
solr/core/src/java/org/apache/solr/search/join/BlockJoinParentQParser.java:
##########
@@ -78,10 +93,127 @@ protected Query noClausesQuery() throws SyntaxError {
     return new BitSetProducerQuery(getBitSetProducer(parseParentFilter()));
   }
 
-  protected Query createQuery(final Query parentList, Query query, String 
scoreMode)
+  protected Query createQuery(final Query allParents, BooleanQuery 
childrenQuery, String scoreMode)
       throws SyntaxError {
-    return new AllParentsAware(
-        query, getBitSetProducer(parentList), 
ScoreModeParser.parse(scoreMode), parentList);
+      try {
+          List<BooleanClause> childrenClauses = childrenQuery.clauses();
+          if (isByteKnnQuery(childrenClauses)) {
+            BitSetProducer allParentsBitSet = getBitSetProducer(allParents);
+            BooleanQuery parentsFilter = getParentsFilter();
+      
+            KnnByteVectorQuery knnChildrenQuery = (KnnByteVectorQuery) 
childrenClauses.get(0).getQuery();
+            String vectorField = knnChildrenQuery.getField();
+            byte[] queryVector = knnChildrenQuery.getTargetCopy();
+            int topK = knnChildrenQuery.getK();
+      
+            Query acceptedChildren =
+                getChildrenFilter(knnChildrenQuery.getFilter(), parentsFilter, 
allParentsBitSet);
+      
+            Query knnChildren =
+                new DiversifyingChildrenByteKnnVectorQuery(
+                    vectorField, queryVector, acceptedChildren, topK, 
allParentsBitSet);
+            knnChildren = knnChildren.rewrite(req.getSearcher());
+            this.setAppropriateChildrenListingTransformer(req,knnChildren);
+            
+            return new ToParentBlockJoinQuery(
+                knnChildren, allParentsBitSet, 
ScoreModeParser.parse(scoreMode));
+          } else if (isFloatKnnQuery(childrenClauses)) {
+            BitSetProducer allParentsBitSet = getBitSetProducer(allParents);
+            BooleanQuery parentsFilter = getParentsFilter();
+      
+            KnnFloatVectorQuery knnChildrenQuery =
+                (KnnFloatVectorQuery) childrenClauses.get(0).getQuery();
+            String vectorField = knnChildrenQuery.getField();
+            float[] queryVector = knnChildrenQuery.getTargetCopy();
+            int topK = knnChildrenQuery.getK();
+      
+            Query childrenFilter =
+                getChildrenFilter(knnChildrenQuery.getFilter(), parentsFilter, 
allParentsBitSet);
+      
+            Query knnChildren =
+                new DiversifyingChildrenFloatKnnVectorQuery(
+                    vectorField, queryVector, childrenFilter, topK, 
allParentsBitSet);
+            knnChildren = knnChildren.rewrite(req.getSearcher());
+            this.setAppropriateChildrenListingTransformer(req,knnChildren);
+      
+            return new ToParentBlockJoinQuery(
+                knnChildren, allParentsBitSet, 
ScoreModeParser.parse(scoreMode));
+          } else {
+            return new AllParentsAware(
+                childrenQuery,
+                getBitSetProducer(allParents),
+                ScoreModeParser.parse(scoreMode),
+                allParents);
+          }
+      } catch (IOException e) {
+        throw new SolrException(SolrException.ErrorCode.SERVER_ERROR, e);
+      }

Review Comment:
   curious about the addition of a try/catch block here; perhaps we could add 
a comment explaining what it is expected to catch?



##########
solr/core/src/java/org/apache/solr/search/join/BlockJoinParentQParser.java:
##########
@@ -78,10 +93,127 @@ protected Query noClausesQuery() throws SyntaxError {
     return new BitSetProducerQuery(getBitSetProducer(parseParentFilter()));
   }
 
-  protected Query createQuery(final Query parentList, Query query, String 
scoreMode)
+  protected Query createQuery(final Query allParents, BooleanQuery 
childrenQuery, String scoreMode)
       throws SyntaxError {
-    return new AllParentsAware(
-        query, getBitSetProducer(parentList), 
ScoreModeParser.parse(scoreMode), parentList);
+      try {
+          List<BooleanClause> childrenClauses = childrenQuery.clauses();
+          if (isByteKnnQuery(childrenClauses)) {
+            BitSetProducer allParentsBitSet = getBitSetProducer(allParents);
+            BooleanQuery parentsFilter = getParentsFilter();
+      
+            KnnByteVectorQuery knnChildrenQuery = (KnnByteVectorQuery) 
childrenClauses.get(0).getQuery();
+            String vectorField = knnChildrenQuery.getField();
+            byte[] queryVector = knnChildrenQuery.getTargetCopy();
+            int topK = knnChildrenQuery.getK();
+      
+            Query acceptedChildren =
+                getChildrenFilter(knnChildrenQuery.getFilter(), parentsFilter, 
allParentsBitSet);
+      
+            Query knnChildren =
+                new DiversifyingChildrenByteKnnVectorQuery(
+                    vectorField, queryVector, acceptedChildren, topK, 
allParentsBitSet);
+            knnChildren = knnChildren.rewrite(req.getSearcher());
+            this.setAppropriateChildrenListingTransformer(req,knnChildren);
+            
+            return new ToParentBlockJoinQuery(
+                knnChildren, allParentsBitSet, 
ScoreModeParser.parse(scoreMode));
+          } else if (isFloatKnnQuery(childrenClauses)) {
+            BitSetProducer allParentsBitSet = getBitSetProducer(allParents);

Review Comment:
   2/3
   ```suggestion
   ```



##########
solr/core/src/java/org/apache/solr/search/join/BlockJoinParentQParser.java:
##########
@@ -78,10 +93,127 @@ protected Query noClausesQuery() throws SyntaxError {
     return new BitSetProducerQuery(getBitSetProducer(parseParentFilter()));
   }
 
-  protected Query createQuery(final Query parentList, Query query, String 
scoreMode)
+  protected Query createQuery(final Query allParents, BooleanQuery 
childrenQuery, String scoreMode)
       throws SyntaxError {
-    return new AllParentsAware(
-        query, getBitSetProducer(parentList), 
ScoreModeParser.parse(scoreMode), parentList);
+      try {
+          List<BooleanClause> childrenClauses = childrenQuery.clauses();
+          if (isByteKnnQuery(childrenClauses)) {
+            BitSetProducer allParentsBitSet = getBitSetProducer(allParents);
+            BooleanQuery parentsFilter = getParentsFilter();
+      
+            KnnByteVectorQuery knnChildrenQuery = (KnnByteVectorQuery) 
childrenClauses.get(0).getQuery();
+            String vectorField = knnChildrenQuery.getField();
+            byte[] queryVector = knnChildrenQuery.getTargetCopy();
+            int topK = knnChildrenQuery.getK();
+      
+            Query acceptedChildren =
+                getChildrenFilter(knnChildrenQuery.getFilter(), parentsFilter, 
allParentsBitSet);
+      
+            Query knnChildren =
+                new DiversifyingChildrenByteKnnVectorQuery(
+                    vectorField, queryVector, acceptedChildren, topK, 
allParentsBitSet);
+            knnChildren = knnChildren.rewrite(req.getSearcher());
+            this.setAppropriateChildrenListingTransformer(req,knnChildren);
+            
+            return new ToParentBlockJoinQuery(
+                knnChildren, allParentsBitSet, 
ScoreModeParser.parse(scoreMode));
+          } else if (isFloatKnnQuery(childrenClauses)) {
+            BitSetProducer allParentsBitSet = getBitSetProducer(allParents);
+            BooleanQuery parentsFilter = getParentsFilter();
+      
+            KnnFloatVectorQuery knnChildrenQuery =
+                (KnnFloatVectorQuery) childrenClauses.get(0).getQuery();
+            String vectorField = knnChildrenQuery.getField();
+            float[] queryVector = knnChildrenQuery.getTargetCopy();
+            int topK = knnChildrenQuery.getK();
+      
+            Query childrenFilter =
+                getChildrenFilter(knnChildrenQuery.getFilter(), parentsFilter, 
allParentsBitSet);
+      
+            Query knnChildren =
+                new DiversifyingChildrenFloatKnnVectorQuery(
+                    vectorField, queryVector, childrenFilter, topK, 
allParentsBitSet);
+            knnChildren = knnChildren.rewrite(req.getSearcher());
+            this.setAppropriateChildrenListingTransformer(req,knnChildren);
+      
+            return new ToParentBlockJoinQuery(
+                knnChildren, allParentsBitSet, 
ScoreModeParser.parse(scoreMode));
+          } else {
+            return new AllParentsAware(
+                childrenQuery,
+                getBitSetProducer(allParents),

Review Comment:
   3/3
   ```suggestion
                   allParentsBitSet,
   ```



##########
solr/core/src/java/org/apache/solr/search/join/BlockJoinParentQParser.java:
##########
@@ -78,10 +93,127 @@ protected Query noClausesQuery() throws SyntaxError {
     return new BitSetProducerQuery(getBitSetProducer(parseParentFilter()));
   }
 
-  protected Query createQuery(final Query parentList, Query query, String 
scoreMode)
+  protected Query createQuery(final Query allParents, BooleanQuery 
childrenQuery, String scoreMode)
       throws SyntaxError {
-    return new AllParentsAware(
-        query, getBitSetProducer(parentList), 
ScoreModeParser.parse(scoreMode), parentList);
+      try {
+          List<BooleanClause> childrenClauses = childrenQuery.clauses();
+          if (isByteKnnQuery(childrenClauses)) {
+            BitSetProducer allParentsBitSet = getBitSetProducer(allParents);
+            BooleanQuery parentsFilter = getParentsFilter();
+      
+            KnnByteVectorQuery knnChildrenQuery = (KnnByteVectorQuery) 
childrenClauses.get(0).getQuery();
+            String vectorField = knnChildrenQuery.getField();
+            byte[] queryVector = knnChildrenQuery.getTargetCopy();
+            int topK = knnChildrenQuery.getK();
+      
+            Query acceptedChildren =
+                getChildrenFilter(knnChildrenQuery.getFilter(), parentsFilter, 
allParentsBitSet);
+      
+            Query knnChildren =
+                new DiversifyingChildrenByteKnnVectorQuery(
+                    vectorField, queryVector, acceptedChildren, topK, 
allParentsBitSet);
+            knnChildren = knnChildren.rewrite(req.getSearcher());
+            this.setAppropriateChildrenListingTransformer(req,knnChildren);
+            
+            return new ToParentBlockJoinQuery(
+                knnChildren, allParentsBitSet, 
ScoreModeParser.parse(scoreMode));
+          } else if (isFloatKnnQuery(childrenClauses)) {
+            BitSetProducer allParentsBitSet = getBitSetProducer(allParents);
+            BooleanQuery parentsFilter = getParentsFilter();
+      
+            KnnFloatVectorQuery knnChildrenQuery =
+                (KnnFloatVectorQuery) childrenClauses.get(0).getQuery();
+            String vectorField = knnChildrenQuery.getField();
+            float[] queryVector = knnChildrenQuery.getTargetCopy();
+            int topK = knnChildrenQuery.getK();
+      
+            Query childrenFilter =
+                getChildrenFilter(knnChildrenQuery.getFilter(), parentsFilter, 
allParentsBitSet);
+      
+            Query knnChildren =
+                new DiversifyingChildrenFloatKnnVectorQuery(
+                    vectorField, queryVector, childrenFilter, topK, 
allParentsBitSet);
+            knnChildren = knnChildren.rewrite(req.getSearcher());
+            this.setAppropriateChildrenListingTransformer(req,knnChildren);
+      
+            return new ToParentBlockJoinQuery(
+                knnChildren, allParentsBitSet, 
ScoreModeParser.parse(scoreMode));
+          } else {
+            return new AllParentsAware(
+                childrenQuery,
+                getBitSetProducer(allParents),
+                ScoreModeParser.parse(scoreMode),
+                allParents);
+          }
+      } catch (IOException e) {
+        throw new SolrException(SolrException.ErrorCode.SERVER_ERROR, e);
+      }
+  }
+
+  private void setAppropriateChildrenListingTransformer(SolrQueryRequest 
request, Query knnOnVectorField) throws IOException {
+    QueryLimits currentLimits = QueryLimits.getCurrentLimits();
+    ReturnFields returnFields = currentLimits.getRsp().getReturnFields();
+    DocTransformer originalTransformer = returnFields.getTransformer();
+
+    if (originalTransformer instanceof DocTransformers) {
+      DocTransformers transformers = (DocTransformers) originalTransformer;
+      boolean noChildTransformer = true;
+      for (int i = 0; i < transformers.size() && noChildTransformer; i++) {
+        DocTransformer t = transformers.getTransformer(i);
+        if (t instanceof ChildDocTransformer) {
+          ChildDocTransformer childTransformer = (ChildDocTransformer) t;
+          if (childTransformer.getChildDocSet() == null) {
+            
childTransformer.setChildDocSet(request.getSearcher().getDocSet(knnOnVectorField));
+          }
+          noChildTransformer = false;
+        }
+      }
+    } else {
+      if ((originalTransformer instanceof ChildDocTransformer)) {
+        ChildDocTransformer childTransformer = (ChildDocTransformer) 
originalTransformer;
+        if (childTransformer.getChildDocSet() == null) {
+          
childTransformer.setChildDocSet(request.getSearcher().getDocSet(knnOnVectorField));
+        }
+      }
+    }
+  }
+
+  private boolean isFloatKnnQuery(List<BooleanClause> childrenClauses) {
+    return childrenClauses.size() == 1
+        && 
childrenClauses.get(0).getQuery().getClass().equals(KnnFloatVectorQuery.class);
+  }
+
+  private boolean isByteKnnQuery(List<BooleanClause> childrenClauses) {
+    return childrenClauses.size() == 1
+        && 
childrenClauses.get(0).getQuery().getClass().equals(KnnByteVectorQuery.class);
+  }
+
+  private Query getChildrenFilter(
+      Query childrenKnnPreFilter, BooleanQuery parentsFilter, BitSetProducer 
allParentsBitSet) {
+    Query childrenFilter = childrenKnnPreFilter;
+
+    if (parentsFilter.clauses().size() > 0) {

Review Comment:
   ```suggestion
       if (!parentsFilter.clauses().isEmpty()) {
   ```



##########
solr/core/src/java/org/apache/solr/search/join/BlockJoinParentQParser.java:
##########
@@ -78,10 +93,127 @@ protected Query noClausesQuery() throws SyntaxError {
     return new BitSetProducerQuery(getBitSetProducer(parseParentFilter()));
   }
 
-  protected Query createQuery(final Query parentList, Query query, String 
scoreMode)
+  protected Query createQuery(final Query allParents, BooleanQuery 
childrenQuery, String scoreMode)
       throws SyntaxError {
-    return new AllParentsAware(
-        query, getBitSetProducer(parentList), 
ScoreModeParser.parse(scoreMode), parentList);
+      try {
+          List<BooleanClause> childrenClauses = childrenQuery.clauses();
+          if (isByteKnnQuery(childrenClauses)) {
+            BitSetProducer allParentsBitSet = getBitSetProducer(allParents);
+            BooleanQuery parentsFilter = getParentsFilter();
+      
+            KnnByteVectorQuery knnChildrenQuery = (KnnByteVectorQuery) 
childrenClauses.get(0).getQuery();
+            String vectorField = knnChildrenQuery.getField();
+            byte[] queryVector = knnChildrenQuery.getTargetCopy();
+            int topK = knnChildrenQuery.getK();
+      
+            Query acceptedChildren =
+                getChildrenFilter(knnChildrenQuery.getFilter(), parentsFilter, 
allParentsBitSet);
+      
+            Query knnChildren =
+                new DiversifyingChildrenByteKnnVectorQuery(
+                    vectorField, queryVector, acceptedChildren, topK, 
allParentsBitSet);
+            knnChildren = knnChildren.rewrite(req.getSearcher());
+            this.setAppropriateChildrenListingTransformer(req,knnChildren);
+            
+            return new ToParentBlockJoinQuery(
+                knnChildren, allParentsBitSet, 
ScoreModeParser.parse(scoreMode));
+          } else if (isFloatKnnQuery(childrenClauses)) {
+            BitSetProducer allParentsBitSet = getBitSetProducer(allParents);
+            BooleanQuery parentsFilter = getParentsFilter();
+      
+            KnnFloatVectorQuery knnChildrenQuery =
+                (KnnFloatVectorQuery) childrenClauses.get(0).getQuery();
+            String vectorField = knnChildrenQuery.getField();
+            float[] queryVector = knnChildrenQuery.getTargetCopy();
+            int topK = knnChildrenQuery.getK();
+      
+            Query childrenFilter =
+                getChildrenFilter(knnChildrenQuery.getFilter(), parentsFilter, 
allParentsBitSet);
+      
+            Query knnChildren =
+                new DiversifyingChildrenFloatKnnVectorQuery(
+                    vectorField, queryVector, childrenFilter, topK, 
allParentsBitSet);
+            knnChildren = knnChildren.rewrite(req.getSearcher());
+            this.setAppropriateChildrenListingTransformer(req,knnChildren);
+      
+            return new ToParentBlockJoinQuery(
+                knnChildren, allParentsBitSet, 
ScoreModeParser.parse(scoreMode));
+          } else {
+            return new AllParentsAware(
+                childrenQuery,
+                getBitSetProducer(allParents),
+                ScoreModeParser.parse(scoreMode),
+                allParents);
+          }
+      } catch (IOException e) {
+        throw new SolrException(SolrException.ErrorCode.SERVER_ERROR, e);
+      }
+  }
+
+  private void setAppropriateChildrenListingTransformer(SolrQueryRequest 
request, Query knnOnVectorField) throws IOException {
+    QueryLimits currentLimits = QueryLimits.getCurrentLimits();
+    ReturnFields returnFields = currentLimits.getRsp().getReturnFields();
+    DocTransformer originalTransformer = returnFields.getTransformer();
+
+    if (originalTransformer instanceof DocTransformers) {
+      DocTransformers transformers = (DocTransformers) originalTransformer;
+      boolean noChildTransformer = true;
+      for (int i = 0; i < transformers.size() && noChildTransformer; i++) {
+        DocTransformer t = transformers.getTransformer(i);
+        if (t instanceof ChildDocTransformer) {
+          ChildDocTransformer childTransformer = (ChildDocTransformer) t;
+          if (childTransformer.getChildDocSet() == null) {
+            
childTransformer.setChildDocSet(request.getSearcher().getDocSet(knnOnVectorField));
+          }
+          noChildTransformer = false;
+        }
+      }
+    } else {
+      if ((originalTransformer instanceof ChildDocTransformer)) {
+        ChildDocTransformer childTransformer = (ChildDocTransformer) 
originalTransformer;
+        if (childTransformer.getChildDocSet() == null) {
+          
childTransformer.setChildDocSet(request.getSearcher().getDocSet(knnOnVectorField));
+        }
+      }
+    }
+  }
+
+  private boolean isFloatKnnQuery(List<BooleanClause> childrenClauses) {
+    return childrenClauses.size() == 1
+        && 
childrenClauses.get(0).getQuery().getClass().equals(KnnFloatVectorQuery.class);
+  }
+
+  private boolean isByteKnnQuery(List<BooleanClause> childrenClauses) {
+    return childrenClauses.size() == 1
+        && 
childrenClauses.get(0).getQuery().getClass().equals(KnnByteVectorQuery.class);
+  }

Review Comment:
   if `true` is returned, the caller then repeats 
`childrenClauses.get(0).getQuery()` and casts the result without checking its 
type again -- perhaps the check and the cast could be encapsulated in one helper, e.g.
   
   ```
   private KnnFloatVectorQuery getFloatKnnQuery(List<BooleanClause> 
childrenClauses) {
       if (childrenClauses.size() == 1)
       {
           Query query = childrenClauses.get(0).getQuery();
           if (query instanceof KnnFloatVectorQuery)
           {
               return (KnnFloatVectorQuery) query;
           }
       }
       return null;
   }
   ```



##########
solr/core/src/java/org/apache/solr/search/join/BlockJoinParentQParser.java:
##########
@@ -78,10 +93,127 @@ protected Query noClausesQuery() throws SyntaxError {
     return new BitSetProducerQuery(getBitSetProducer(parseParentFilter()));
   }
 
-  protected Query createQuery(final Query parentList, Query query, String 
scoreMode)
+  protected Query createQuery(final Query allParents, BooleanQuery 
childrenQuery, String scoreMode)
       throws SyntaxError {
-    return new AllParentsAware(
-        query, getBitSetProducer(parentList), 
ScoreModeParser.parse(scoreMode), parentList);
+      try {
+          List<BooleanClause> childrenClauses = childrenQuery.clauses();
+          if (isByteKnnQuery(childrenClauses)) {
+            BitSetProducer allParentsBitSet = getBitSetProducer(allParents);
+            BooleanQuery parentsFilter = getParentsFilter();
+      
+            KnnByteVectorQuery knnChildrenQuery = (KnnByteVectorQuery) 
childrenClauses.get(0).getQuery();
+            String vectorField = knnChildrenQuery.getField();
+            byte[] queryVector = knnChildrenQuery.getTargetCopy();
+            int topK = knnChildrenQuery.getK();
+      
+            Query acceptedChildren =
+                getChildrenFilter(knnChildrenQuery.getFilter(), parentsFilter, 
allParentsBitSet);
+      
+            Query knnChildren =
+                new DiversifyingChildrenByteKnnVectorQuery(
+                    vectorField, queryVector, acceptedChildren, topK, 
allParentsBitSet);
+            knnChildren = knnChildren.rewrite(req.getSearcher());
+            this.setAppropriateChildrenListingTransformer(req,knnChildren);
+            
+            return new ToParentBlockJoinQuery(
+                knnChildren, allParentsBitSet, 
ScoreModeParser.parse(scoreMode));
+          } else if (isFloatKnnQuery(childrenClauses)) {
+            BitSetProducer allParentsBitSet = getBitSetProducer(allParents);
+            BooleanQuery parentsFilter = getParentsFilter();
+      
+            KnnFloatVectorQuery knnChildrenQuery =
+                (KnnFloatVectorQuery) childrenClauses.get(0).getQuery();
+            String vectorField = knnChildrenQuery.getField();
+            float[] queryVector = knnChildrenQuery.getTargetCopy();
+            int topK = knnChildrenQuery.getK();
+      
+            Query childrenFilter =
+                getChildrenFilter(knnChildrenQuery.getFilter(), parentsFilter, 
allParentsBitSet);
+      
+            Query knnChildren =
+                new DiversifyingChildrenFloatKnnVectorQuery(
+                    vectorField, queryVector, childrenFilter, topK, 
allParentsBitSet);
+            knnChildren = knnChildren.rewrite(req.getSearcher());
+            this.setAppropriateChildrenListingTransformer(req,knnChildren);
+      
+            return new ToParentBlockJoinQuery(
+                knnChildren, allParentsBitSet, 
ScoreModeParser.parse(scoreMode));
+          } else {
+            return new AllParentsAware(
+                childrenQuery,
+                getBitSetProducer(allParents),
+                ScoreModeParser.parse(scoreMode),
+                allParents);
+          }
+      } catch (IOException e) {
+        throw new SolrException(SolrException.ErrorCode.SERVER_ERROR, e);
+      }
+  }
+
+  private void setAppropriateChildrenListingTransformer(SolrQueryRequest 
request, Query knnOnVectorField) throws IOException {
+    QueryLimits currentLimits = QueryLimits.getCurrentLimits();
+    ReturnFields returnFields = currentLimits.getRsp().getReturnFields();
+    DocTransformer originalTransformer = returnFields.getTransformer();
+
+    if (originalTransformer instanceof DocTransformers) {
+      DocTransformers transformers = (DocTransformers) originalTransformer;
+      boolean noChildTransformer = true;
+      for (int i = 0; i < transformers.size() && noChildTransformer; i++) {
+        DocTransformer t = transformers.getTransformer(i);
+        if (t instanceof ChildDocTransformer) {
+          ChildDocTransformer childTransformer = (ChildDocTransformer) t;
+          if (childTransformer.getChildDocSet() == null) {
+            
childTransformer.setChildDocSet(request.getSearcher().getDocSet(knnOnVectorField));
+          }
+          noChildTransformer = false;
+        }
+      }
+    } else {
+      if ((originalTransformer instanceof ChildDocTransformer)) {
+        ChildDocTransformer childTransformer = (ChildDocTransformer) 
originalTransformer;
+        if (childTransformer.getChildDocSet() == null) {
+          
childTransformer.setChildDocSet(request.getSearcher().getDocSet(knnOnVectorField));
+        }
+      }
+    }

Review Comment:
   minor
   ```suggestion
       } else if ((originalTransformer instanceof ChildDocTransformer)) {
           ChildDocTransformer childTransformer = (ChildDocTransformer) 
originalTransformer;
           if (childTransformer.getChildDocSet() == null) {
             
childTransformer.setChildDocSet(request.getSearcher().getDocSet(knnOnVectorField));
           }
       }
   ```



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to