This is an automated email from the ASF dual-hosted git repository.

FrankChen021 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/druid.git


The following commit(s) were added to refs/heads/master by this push:
     new 454c1c6d191 fix: Optimize nullable numeric aggregation for Pooled TopN (#19390)
454c1c6d191 is described below

commit 454c1c6d191273514c46d652d7b26f2ce95b5afc
Author: Frank Chen <[email protected]>
AuthorDate: Fri May 8 09:56:39 2026 +0800

    fix: Optimize nullable numeric aggregation for Pooled TopN (#19390)
    
    * Optimize nullable numeric aggregators for no-null columns
    
    * fix: Preserve nullable aggregate semantics outside TopN
    
    * fix: Apply nullable numeric TopN sizing consistently
    
    * Revert stylistic changes
    
    * Add TopNQueryEngine test to cover TopN algorithm selection
    
    * Fix forbidden apis
    
    ---------
    
    Co-authored-by: GWphua <[email protected]>
---
 .../NullableNumericAggregatorFactory.java          |  49 +++++
 .../aggregation/SimpleDoubleAggregatorFactory.java |   7 +
 .../aggregation/SimpleFloatAggregatorFactory.java  |   7 +
 .../aggregation/SimpleLongAggregatorFactory.java   |   7 +
 .../druid/query/topn/PooledTopNAlgorithm.java      |  42 +++-
 .../apache/druid/query/topn/TopNQueryEngine.java   |   9 +-
 .../query/aggregation/DoubleSumAggregatorTest.java | 126 +++++++++++
 .../druid/query/topn/TopNQueryEngineTest.java      | 234 +++++++++++++++++++++
 8 files changed, 477 insertions(+), 4 deletions(-)

diff --git a/processing/src/main/java/org/apache/druid/query/aggregation/NullableNumericAggregatorFactory.java b/processing/src/main/java/org/apache/druid/query/aggregation/NullableNumericAggregatorFactory.java
index 49dadfd3fea..29b6be9065f 100644
--- a/processing/src/main/java/org/apache/druid/query/aggregation/NullableNumericAggregatorFactory.java
+++ b/processing/src/main/java/org/apache/druid/query/aggregation/NullableNumericAggregatorFactory.java
@@ -24,12 +24,17 @@ import com.google.common.base.Preconditions;
 import org.apache.druid.guice.annotations.ExtensionPoint;
 import org.apache.druid.segment.BaseNullableColumnValueSelector;
 import org.apache.druid.segment.BaseObjectColumnValueSelector;
+import org.apache.druid.segment.ColumnInspector;
 import org.apache.druid.segment.ColumnSelectorFactory;
 import org.apache.druid.segment.ColumnValueSelector;
+import org.apache.druid.segment.column.ColumnCapabilities;
 import org.apache.druid.segment.column.ColumnType;
+import org.apache.druid.segment.column.Types;
 import org.apache.druid.segment.vector.VectorColumnSelectorFactory;
 import org.apache.druid.segment.vector.VectorValueSelector;
 
+import javax.annotation.Nullable;
+
 /**
  * Abstract superclass for null-aware numeric aggregators.
  *
@@ -92,6 +97,29 @@ public abstract class NullableNumericAggregatorFactory<T extends BaseNullableCol
     return new NullableNumericVectorAggregator(aggregator, selector);
   }
 
+  /**
+   * Factorizes a buffer aggregator for Pooled TopN. Unlike general aggregation engines, Pooled TopN creates aggregate
+   * state only after a dimension value is seen. If the input column is known to be numeric and non-null, the nullable
+   * wrapper cannot affect the result, so TopN may skip it to keep hot-loop specialization effective.
+   */
+  public final BufferAggregator factorizeBufferedForPooledTopN(ColumnSelectorFactory columnSelectorFactory)
+  {
+    T selector = selector(columnSelectorFactory);
+    BufferAggregator aggregator = factorizeBuffered(columnSelectorFactory, selector);
+    if (!useNullableNumericAggregatorsForPooledTopN(columnSelectorFactory)) {
+      return aggregator;
+    }
+    return new NullableNumericBufferAggregator(aggregator, makeNullSelector(selector, columnSelectorFactory));
+  }
+
+  public final int getMaxIntermediateSizeWithNullsForPooledTopN(ColumnInspector columnInspector)
+  {
+    if (!useNullableNumericAggregatorsForPooledTopN(columnInspector)) {
+      return getMaxIntermediateSize();
+    }
+    return getMaxIntermediateSizeWithNulls();
+  }
+
   @Override
   public final AggregateCombiner makeNullableAggregateCombiner()
   {
@@ -111,6 +139,27 @@ public abstract class NullableNumericAggregatorFactory<T extends BaseNullableCol
     return getMaxIntermediateSize() + Byte.BYTES;
   }
 
+  private boolean useNullableNumericAggregatorsForPooledTopN(ColumnInspector columnInspector)
+  {
+    if (this.forceNotNullable()) {
+      return false;
+    }
+
+    final String inputColumn = getInputColumn();
+    if (inputColumn == null) {
+      return true;
+    }
+
+    final ColumnCapabilities capabilities = columnInspector.getColumnCapabilities(inputColumn);
+    return !(Types.isNumeric(capabilities) && capabilities.hasNulls().isFalse());
+  }
+
+  @Nullable
+  protected String getInputColumn()
+  {
+    return null;
+  }
+
   /**
    * Returns the selector that should be used by {@link NullableNumericAggregator} and
    * {@link NullableNumericBufferAggregator} to determine if the current value is null.
diff --git a/processing/src/main/java/org/apache/druid/query/aggregation/SimpleDoubleAggregatorFactory.java b/processing/src/main/java/org/apache/druid/query/aggregation/SimpleDoubleAggregatorFactory.java
index 0fa96e226ea..b8e5f824903 100644
--- a/processing/src/main/java/org/apache/druid/query/aggregation/SimpleDoubleAggregatorFactory.java
+++ b/processing/src/main/java/org/apache/druid/query/aggregation/SimpleDoubleAggregatorFactory.java
@@ -134,6 +134,13 @@ public abstract class SimpleDoubleAggregatorFactory extends NullableNumericAggre
     return AggregatorUtil.shouldUseObjectColumnAggregatorWrapper(fieldName, columnSelectorFactory);
   }
 
+  @Override
+  @Nullable
+  protected String getInputColumn()
+  {
+    return expression == null ? fieldName : null;
+  }
+
   @Override
   public Object deserialize(Object object)
   {
diff --git a/processing/src/main/java/org/apache/druid/query/aggregation/SimpleFloatAggregatorFactory.java b/processing/src/main/java/org/apache/druid/query/aggregation/SimpleFloatAggregatorFactory.java
index 5268c454ce1..cb2aab35e8c 100644
--- a/processing/src/main/java/org/apache/druid/query/aggregation/SimpleFloatAggregatorFactory.java
+++ b/processing/src/main/java/org/apache/druid/query/aggregation/SimpleFloatAggregatorFactory.java
@@ -124,6 +124,13 @@ public abstract class SimpleFloatAggregatorFactory extends NullableNumericAggreg
     return AggregatorUtil.shouldUseObjectColumnAggregatorWrapper(fieldName, columnSelectorFactory);
   }
 
+  @Override
+  @Nullable
+  protected String getInputColumn()
+  {
+    return expression == null ? fieldName : null;
+  }
+
   @Override
   public Object deserialize(Object object)
   {
diff --git a/processing/src/main/java/org/apache/druid/query/aggregation/SimpleLongAggregatorFactory.java b/processing/src/main/java/org/apache/druid/query/aggregation/SimpleLongAggregatorFactory.java
index c4bc5307ed4..7fed8a5da2d 100644
--- a/processing/src/main/java/org/apache/druid/query/aggregation/SimpleLongAggregatorFactory.java
+++ b/processing/src/main/java/org/apache/druid/query/aggregation/SimpleLongAggregatorFactory.java
@@ -130,6 +130,13 @@ public abstract class SimpleLongAggregatorFactory extends NullableNumericAggrega
     return AggregatorUtil.shouldUseObjectColumnAggregatorWrapper(fieldName, columnSelectorFactory);
   }
 
+  @Override
+  @Nullable
+  protected String getInputColumn()
+  {
+    return expression == null ? fieldName : null;
+  }
+
   @Override
   public Object deserialize(Object object)
   {
diff --git a/processing/src/main/java/org/apache/druid/query/topn/PooledTopNAlgorithm.java b/processing/src/main/java/org/apache/druid/query/topn/PooledTopNAlgorithm.java
index e79ad59b0fc..ca6db31add4 100644
--- a/processing/src/main/java/org/apache/druid/query/topn/PooledTopNAlgorithm.java
+++ b/processing/src/main/java/org/apache/druid/query/topn/PooledTopNAlgorithm.java
@@ -27,11 +27,14 @@ import org.apache.druid.java.util.common.Pair;
 import org.apache.druid.query.BaseQuery;
 import org.apache.druid.query.ColumnSelectorPlus;
 import org.apache.druid.query.CursorGranularizer;
+import org.apache.druid.query.aggregation.AggregatorFactory;
 import org.apache.druid.query.aggregation.BufferAggregator;
+import org.apache.druid.query.aggregation.NullableNumericAggregatorFactory;
 import org.apache.druid.query.aggregation.SimpleDoubleBufferAggregator;
 import org.apache.druid.query.monomorphicprocessing.SpecializationService;
 import org.apache.druid.query.monomorphicprocessing.SpecializationState;
 import org.apache.druid.query.monomorphicprocessing.StringRuntimeShape;
+import org.apache.druid.segment.ColumnInspector;
 import org.apache.druid.segment.Cursor;
 import org.apache.druid.segment.DimensionSelector;
 import org.apache.druid.segment.FilteredOffset;
@@ -259,7 +262,10 @@ public class PooledTopNAlgorithm
       int numBytesPerRecord = 0;
 
       for (int i = 0; i < query.getAggregatorSpecs().size(); ++i) {
-        aggregatorSizes[i] = query.getAggregatorSpecs().get(i).getMaxIntermediateSizeWithNulls();
+        aggregatorSizes[i] = getMaxIntermediateSizeWithNullsForPooledTopN(
+            cursor.getColumnSelectorFactory(),
+            query.getAggregatorSpecs().get(i)
+        );
         numBytesPerRecord += aggregatorSizes[i];
       }
 
@@ -329,7 +335,39 @@ public class PooledTopNAlgorithm
   @Override
   protected BufferAggregator[] makeDimValAggregateStore(PooledTopNParams params)
   {
-    return makeBufferAggregators(params.getCursor(), query.getAggregatorSpecs());
+    return makeBufferAggregatorsForPooledTopN(params.getCursor(), query.getAggregatorSpecs());
+  }
+
+  static int getMaxIntermediateSizeWithNullsForPooledTopN(
+      ColumnInspector columnInspector,
+      AggregatorFactory aggregatorFactory
+  )
+  {
+    if (aggregatorFactory instanceof NullableNumericAggregatorFactory) {
+      return ((NullableNumericAggregatorFactory<?>) aggregatorFactory).getMaxIntermediateSizeWithNullsForPooledTopN(
+          columnInspector
+      );
+    }
+    return aggregatorFactory.getMaxIntermediateSizeWithNulls();
+  }
+
+  private static BufferAggregator[] makeBufferAggregatorsForPooledTopN(
+      Cursor cursor,
+      List<AggregatorFactory> aggregatorSpecs
+  )
+  {
+    BufferAggregator[] aggregators = new BufferAggregator[aggregatorSpecs.size()];
+    int aggregatorIndex = 0;
+    for (AggregatorFactory spec : aggregatorSpecs) {
+      if (spec instanceof NullableNumericAggregatorFactory) {
+        aggregators[aggregatorIndex] =
+            ((NullableNumericAggregatorFactory<?>) spec).factorizeBufferedForPooledTopN(cursor.getColumnSelectorFactory());
+      } else {
+        aggregators[aggregatorIndex] = spec.factorizeBuffered(cursor.getColumnSelectorFactory());
+      }
+      ++aggregatorIndex;
+    }
+    return aggregators;
   }
 
   @Override
diff --git a/processing/src/main/java/org/apache/druid/query/topn/TopNQueryEngine.java b/processing/src/main/java/org/apache/druid/query/topn/TopNQueryEngine.java
index f8c7762c004..4a269469c17 100644
--- a/processing/src/main/java/org/apache/druid/query/topn/TopNQueryEngine.java
+++ b/processing/src/main/java/org/apache/druid/query/topn/TopNQueryEngine.java
@@ -19,6 +19,7 @@
 
 package org.apache.druid.query.topn;
 
+import com.google.common.annotations.VisibleForTesting;
 import com.google.common.base.Preconditions;
 import com.google.common.base.Predicates;
 import org.apache.druid.collections.NonBlockingPool;
@@ -157,7 +158,8 @@ public class TopNQueryEngine
   /**
    * Choose the best {@link TopNAlgorithm} for the given query.
    */
-  private TopNMapFn getMapFn(
+  @VisibleForTesting
+  TopNMapFn getMapFn(
       final TopNQuery query,
       final TopNCursorInspector cursorInspector,
       final @Nullable TopNQueryMetrics queryMetrics
@@ -172,7 +174,10 @@ public class TopNQueryEngine
 
     int numBytesPerRecord = 0;
     for (AggregatorFactory aggregatorFactory : query.getAggregatorSpecs()) {
-      numBytesPerRecord += aggregatorFactory.getMaxIntermediateSizeWithNulls();
+      numBytesPerRecord += PooledTopNAlgorithm.getMaxIntermediateSizeWithNullsForPooledTopN(
+          cursorInspector.getColumnInspector(),
+          aggregatorFactory
+      );
     }
 
     final TopNAlgorithmSelector selector = new TopNAlgorithmSelector(
diff --git a/processing/src/test/java/org/apache/druid/query/aggregation/DoubleSumAggregatorTest.java b/processing/src/test/java/org/apache/druid/query/aggregation/DoubleSumAggregatorTest.java
index c29d6fceb96..b9276e71f61 100644
--- a/processing/src/test/java/org/apache/druid/query/aggregation/DoubleSumAggregatorTest.java
+++ b/processing/src/test/java/org/apache/druid/query/aggregation/DoubleSumAggregatorTest.java
@@ -19,9 +19,14 @@
 
 package org.apache.druid.query.aggregation;
 
+import org.apache.druid.segment.ColumnSelectorFactory;
+import org.apache.druid.segment.column.ColumnCapabilitiesImpl;
+import org.apache.druid.segment.column.ColumnType;
+import org.easymock.EasyMock;
 import org.junit.jupiter.api.Assertions;
 import org.junit.jupiter.api.Test;
 
+import java.nio.ByteBuffer;
 import java.util.Comparator;
 
 /**
@@ -73,4 +78,125 @@ public class DoubleSumAggregatorTest
     Assertions.assertEquals(0, comp.compare(agg.get(), agg.get()));
     Assertions.assertEquals(1, comp.compare(agg.get(), first));
   }
+
+  @Test
+  public void testUsesNullableBufferAggregatorWhenInputHasNoNulls()
+  {
+    ColumnSelectorFactory selectorFactory = EasyMock.createMock(ColumnSelectorFactory.class);
+    TestDoubleColumnSelectorImpl selector = new TestDoubleColumnSelectorImpl(new double[]{1.0d});
+    ColumnCapabilitiesImpl capabilities =
+        ColumnCapabilitiesImpl.createSimpleNumericColumnCapabilities(ColumnType.DOUBLE).setHasNulls(false);
+    EasyMock.expect(selectorFactory.makeColumnValueSelector("metric")).andReturn(selector);
+    EasyMock.expect(selectorFactory.getColumnCapabilities("metric")).andReturn(capabilities).times(2);
+    EasyMock.replay(selectorFactory);
+
+    BufferAggregator aggregator = new DoubleSumAggregatorFactory("sum", "metric").factorizeBuffered(selectorFactory);
+
+    Assertions.assertTrue(aggregator instanceof NullableNumericBufferAggregator);
+    ByteBuffer buffer = ByteBuffer.allocate(Double.BYTES + Byte.BYTES);
+    aggregator.init(buffer, 0);
+    Assertions.assertNull(aggregator.get(buffer, 0));
+    aggregator.aggregate(buffer, 0);
+    Assertions.assertEquals(1.0d, aggregator.getDouble(buffer, 0), 0.0d);
+    EasyMock.verify(selectorFactory);
+  }
+
+  @Test
+  public void testUsesNullableAggregatorWhenInputHasNoNulls()
+  {
+    ColumnSelectorFactory selectorFactory = EasyMock.createMock(ColumnSelectorFactory.class);
+    TestDoubleColumnSelectorImpl selector = new TestDoubleColumnSelectorImpl(new double[]{1.0d});
+    ColumnCapabilitiesImpl capabilities =
+        ColumnCapabilitiesImpl.createSimpleNumericColumnCapabilities(ColumnType.DOUBLE).setHasNulls(false);
+    EasyMock.expect(selectorFactory.makeColumnValueSelector("metric")).andReturn(selector);
+    EasyMock.expect(selectorFactory.getColumnCapabilities("metric")).andReturn(capabilities).times(2);
+    EasyMock.replay(selectorFactory);
+
+    Aggregator aggregator = new DoubleSumAggregatorFactory("sum", "metric").factorize(selectorFactory);
+
+    Assertions.assertTrue(aggregator instanceof NullableNumericAggregator);
+    Assertions.assertNull(aggregator.get());
+    aggregator.aggregate();
+    Assertions.assertEquals(1.0d, aggregator.getDouble(), 0.0d);
+    EasyMock.verify(selectorFactory);
+  }
+
+  @Test
+  public void testPooledTopNSkipsNullableBufferAggregatorWhenInputHasNoNulls()
+  {
+    ColumnSelectorFactory selectorFactory = EasyMock.createMock(ColumnSelectorFactory.class);
+    TestDoubleColumnSelectorImpl selector = new TestDoubleColumnSelectorImpl(new double[]{1.0d});
+    ColumnCapabilitiesImpl capabilities =
+        ColumnCapabilitiesImpl.createSimpleNumericColumnCapabilities(ColumnType.DOUBLE).setHasNulls(false);
+    DoubleSumAggregatorFactory factory = new DoubleSumAggregatorFactory("sum", "metric");
+    EasyMock.expect(selectorFactory.makeColumnValueSelector("metric")).andReturn(selector);
+    EasyMock.expect(selectorFactory.getColumnCapabilities("metric")).andReturn(capabilities).times(3);
+    EasyMock.replay(selectorFactory);
+
+    BufferAggregator aggregator = factory.factorizeBufferedForPooledTopN(selectorFactory);
+
+    Assertions.assertTrue(aggregator instanceof DoubleSumBufferAggregator);
+    Assertions.assertEquals(Double.BYTES, factory.getMaxIntermediateSizeWithNullsForPooledTopN(selectorFactory));
+    ByteBuffer buffer = ByteBuffer.allocate(Double.BYTES);
+    aggregator.init(buffer, 0);
+    aggregator.aggregate(buffer, 0);
+    Assertions.assertEquals(1.0d, aggregator.getDouble(buffer, 0), 0.0d);
+    EasyMock.verify(selectorFactory);
+  }
+
+  @Test
+  public void testUsesNullableBufferAggregatorWhenInputNullsAreUnknown()
+  {
+    ColumnSelectorFactory selectorFactory = EasyMock.createMock(ColumnSelectorFactory.class);
+    TestDoubleColumnSelectorImpl selector = new TestDoubleColumnSelectorImpl(new double[]{1.0d});
+    ColumnCapabilitiesImpl capabilities =
+        ColumnCapabilitiesImpl.createSimpleNumericColumnCapabilities(ColumnType.DOUBLE);
+    EasyMock.expect(selectorFactory.makeColumnValueSelector("metric")).andReturn(selector);
+    EasyMock.expect(selectorFactory.getColumnCapabilities("metric")).andReturn(capabilities).times(2);
+    EasyMock.replay(selectorFactory);
+
+    BufferAggregator aggregator = new DoubleSumAggregatorFactory("sum", "metric").factorizeBuffered(selectorFactory);
+
+    Assertions.assertTrue(aggregator instanceof NullableNumericBufferAggregator);
+    EasyMock.verify(selectorFactory);
+  }
+
+  @Test
+  public void testPooledTopNUsesNullableBufferAggregatorWhenInputNullsAreUnknown()
+  {
+    ColumnSelectorFactory selectorFactory = EasyMock.createMock(ColumnSelectorFactory.class);
+    TestDoubleColumnSelectorImpl selector = new TestDoubleColumnSelectorImpl(new double[]{1.0d});
+    ColumnCapabilitiesImpl capabilities =
+        ColumnCapabilitiesImpl.createSimpleNumericColumnCapabilities(ColumnType.DOUBLE);
+    DoubleSumAggregatorFactory factory = new DoubleSumAggregatorFactory("sum", "metric");
+    EasyMock.expect(selectorFactory.makeColumnValueSelector("metric")).andReturn(selector);
+    EasyMock.expect(selectorFactory.getColumnCapabilities("metric")).andReturn(capabilities).times(4);
+    EasyMock.replay(selectorFactory);
+
+    BufferAggregator aggregator = factory.factorizeBufferedForPooledTopN(selectorFactory);
+
+    Assertions.assertTrue(aggregator instanceof NullableNumericBufferAggregator);
+    Assertions.assertEquals(
+        Double.BYTES + Byte.BYTES,
+        factory.getMaxIntermediateSizeWithNullsForPooledTopN(selectorFactory)
+    );
+    EasyMock.verify(selectorFactory);
+  }
+
+  @Test
+  public void testUsesNullableAggregatorWhenInputNullsAreUnknown()
+  {
+    ColumnSelectorFactory selectorFactory = EasyMock.createMock(ColumnSelectorFactory.class);
+    TestDoubleColumnSelectorImpl selector = new TestDoubleColumnSelectorImpl(new double[]{1.0d});
+    ColumnCapabilitiesImpl capabilities =
+        ColumnCapabilitiesImpl.createSimpleNumericColumnCapabilities(ColumnType.DOUBLE);
+    EasyMock.expect(selectorFactory.makeColumnValueSelector("metric")).andReturn(selector);
+    EasyMock.expect(selectorFactory.getColumnCapabilities("metric")).andReturn(capabilities).times(2);
+    EasyMock.replay(selectorFactory);
+
+    Aggregator aggregator = new DoubleSumAggregatorFactory("sum", "metric").factorize(selectorFactory);
+
+    Assertions.assertTrue(aggregator instanceof NullableNumericAggregator);
+    EasyMock.verify(selectorFactory);
+  }
 }
diff --git a/processing/src/test/java/org/apache/druid/query/topn/TopNQueryEngineTest.java b/processing/src/test/java/org/apache/druid/query/topn/TopNQueryEngineTest.java
new file mode 100644
index 00000000000..61edd6287bf
--- /dev/null
+++ b/processing/src/test/java/org/apache/druid/query/topn/TopNQueryEngineTest.java
@@ -0,0 +1,234 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.druid.query.topn;
+
+import com.google.common.collect.ImmutableList;
+import org.apache.druid.collections.NonBlockingPool;
+import org.apache.druid.collections.ResourceHolder;
+import org.apache.druid.java.util.common.Intervals;
+import org.apache.druid.java.util.common.granularity.Granularities;
+import org.apache.druid.java.util.common.granularity.Granularity;
+import org.apache.druid.query.aggregation.AggregatorFactory;
+import org.apache.druid.query.aggregation.DoubleSumAggregatorFactory;
+import org.apache.druid.query.dimension.DefaultDimensionSpec;
+import org.apache.druid.query.spec.LegacySegmentSpec;
+import org.apache.druid.segment.ColumnInspector;
+import org.apache.druid.segment.column.ColumnCapabilities;
+import org.apache.druid.segment.column.ColumnCapabilitiesImpl;
+import org.apache.druid.segment.column.ColumnType;
+import org.joda.time.Interval;
+import org.junit.jupiter.api.Assertions;
+import org.junit.jupiter.api.Test;
+
+import java.nio.ByteBuffer;
+import java.util.List;
+
+public class TopNQueryEngineTest
+{
+  private static final String DIMENSION = "dim";
+  private static final String METRIC = "metric";
+  private static final Interval INTERVAL = Intervals.of("2000/2001");
+
+  @Test
+  public void testUsesReducedPooledTopNSizeForPooledVsHeapSelection()
+  {
+    final TopNAlgorithm algorithm = selectAlgorithm(
+        10,
+        80,
+        metricCapabilities(false),
+        ImmutableList.of(new DoubleSumAggregatorFactory("sum", METRIC)),
+        new NumericTopNMetricSpec("sum"),
+        Granularities.DAY
+    );
+
+    Assertions.assertInstanceOf(PooledTopNAlgorithm.class, algorithm);
+  }
+
+  @Test
+  public void testUsesNullableSizeForUnknownNullabilityInPooledVsHeapSelection()
+  {
+    final TopNAlgorithm algorithm = selectAlgorithm(
+        10,
+        80,
+        metricCapabilitiesUnknownNullability(),
+        ImmutableList.of(new DoubleSumAggregatorFactory("sum", METRIC)),
+        new NumericTopNMetricSpec("sum"),
+        Granularities.DAY
+    );
+
+    Assertions.assertInstanceOf(HeapBasedTopNAlgorithm.class, algorithm);
+  }
+
+  @Test
+  public void testReducedPooledTopNSizeAvoidsAggregateMetricFirstThreshold()
+  {
+    final TopNAlgorithm algorithm = selectAlgorithm(
+        400001,
+        96,
+        metricCapabilities(false),
+        doubleSumAggregators(12),
+        new NumericTopNMetricSpec("sum0"),
+        Granularities.ALL
+    );
+
+    Assertions.assertEquals(PooledTopNAlgorithm.class, algorithm.getClass());
+  }
+
+  @Test
+  public void testUnknownNullabilityStillUsesAggregateMetricFirstThreshold()
+  {
+    final TopNAlgorithm algorithm = selectAlgorithm(
+        400001,
+        108,
+        metricCapabilitiesUnknownNullability(),
+        doubleSumAggregators(12),
+        new NumericTopNMetricSpec("sum0"),
+        Granularities.ALL
+    );
+
+    Assertions.assertInstanceOf(AggregateTopNMetricFirstAlgorithm.class, algorithm);
+  }
+
+  private static TopNAlgorithm selectAlgorithm(
+      final int dimensionCardinality,
+      final int bufferSize,
+      final ColumnCapabilities metricCapabilities,
+      final List<AggregatorFactory> aggregatorFactories,
+      final TopNMetricSpec metricSpec,
+      final Granularity granularity
+  )
+  {
+    final TopNQuery query = new TopNQueryBuilder()
+        .dataSource("test")
+        .dimension(new DefaultDimensionSpec(DIMENSION, DIMENSION))
+        .metric(metricSpec)
+        .threshold(10)
+        .intervals(new LegacySegmentSpec(INTERVAL))
+        .granularity(granularity)
+        .aggregators(aggregatorFactories)
+        .build();
+    final CapturingTopNQueryMetrics queryMetrics = new CapturingTopNQueryMetrics();
+    final TopNQueryEngine queryEngine = new TopNQueryEngine(new TestBufferPool(bufferSize));
+
+    queryEngine.getMapFn(
+        query,
+        new TopNCursorInspector(
+            new TestColumnInspector(metricCapabilities),
+            null,
+            INTERVAL,
+            dimensionCardinality
+        ),
+        queryMetrics
+    );
+
+    return queryMetrics.algorithm;
+  }
+
+  private static List<AggregatorFactory> doubleSumAggregators(final int count)
+  {
+    final ImmutableList.Builder<AggregatorFactory> aggregators = ImmutableList.builder();
+    for (int i = 0; i < count; i++) {
+      aggregators.add(new DoubleSumAggregatorFactory("sum" + i, METRIC));
+    }
+    return aggregators.build();
+  }
+
+  private static ColumnCapabilities metricCapabilities(final boolean hasNulls)
+  {
+    return ColumnCapabilitiesImpl.createSimpleNumericColumnCapabilities(ColumnType.DOUBLE).setHasNulls(hasNulls);
+  }
+
+  private static ColumnCapabilities metricCapabilitiesUnknownNullability()
+  {
+    return ColumnCapabilitiesImpl.createSimpleNumericColumnCapabilities(ColumnType.DOUBLE);
+  }
+
+  private static ColumnCapabilities dimensionCapabilities()
+  {
+    return new ColumnCapabilitiesImpl()
+        .setType(ColumnType.STRING)
+        .setHasMultipleValues(false)
+        .setHasBitmapIndexes(false)
+        .setDictionaryEncoded(true)
+        .setDictionaryValuesUnique(true);
+  }
+
+  private static class TestColumnInspector implements ColumnInspector
+  {
+    private final ColumnCapabilities metricCapabilities;
+
+    private TestColumnInspector(final ColumnCapabilities metricCapabilities)
+    {
+      this.metricCapabilities = metricCapabilities;
+    }
+
+    @Override
+    public ColumnCapabilities getColumnCapabilities(final String column)
+    {
+      if (DIMENSION.equals(column)) {
+        return dimensionCapabilities();
+      }
+      if (METRIC.equals(column)) {
+        return metricCapabilities;
+      }
+      return null;
+    }
+  }
+
+  private static class TestBufferPool implements NonBlockingPool<ByteBuffer>
+  {
+    private final int bufferSize;
+
+    private TestBufferPool(final int bufferSize)
+    {
+      this.bufferSize = bufferSize;
+    }
+
+    @Override
+    public ResourceHolder<ByteBuffer> take()
+    {
+      return new ResourceHolder<>()
+      {
+        @Override
+        public ByteBuffer get()
+        {
+          return ByteBuffer.allocate(bufferSize);
+        }
+
+        @Override
+        public void close()
+        {
+          // Nothing to release.
+        }
+      };
+    }
+  }
+
+  private static class CapturingTopNQueryMetrics extends DefaultTopNQueryMetrics
+  {
+    private TopNAlgorithm algorithm;
+
+    @Override
+    public void algorithm(final TopNAlgorithm algorithm)
+    {
+      this.algorithm = algorithm;
+    }
+  }
+}


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to