This is an automated email from the ASF dual-hosted git repository.
gianm pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/druid.git
The following commit(s) were added to refs/heads/master by this push:
new 06ef24c1e09 feat: ingest support for numeric typed
ExpressionLambdaAggregatorFactory (#19508)
06ef24c1e09 is described below
commit 06ef24c1e09a8000c91dd0767e09278b34066252
Author: Clint Wylie <[email protected]>
AuthorDate: Sat May 23 14:40:06 2026 -0700
feat: ingest support for numeric typed ExpressionLambdaAggregatorFactory
(#19508)
---
docs/querying/aggregations.md | 6 +-
.../embedded/compact/CompactionTaskTest.java | 90 ++++++
.../ExpressionLambdaAggregatorFactory.java | 82 ++++++
.../ExpressionLambdaAggregationTest.java | 215 +++++++++++++++
.../ExpressionLambdaAggregatorFactoryTest.java | 304 +++++++++++++++++++++
5 files changed, 696 insertions(+), 1 deletion(-)
diff --git a/docs/querying/aggregations.md b/docs/querying/aggregations.md
index c7b7d4e4efc..3add7863c46 100644
--- a/docs/querying/aggregations.md
+++ b/docs/querying/aggregations.md
@@ -471,7 +471,11 @@ For these reasons, we have deprecated this aggregator and
recommend using the Da
### Expression aggregator
-Aggregator applicable only at query time. Aggregates results using [Druid
expressions](./math-expr.md) functions to facilitate building custom functions.
+Aggregates results using [Druid expressions](./math-expr.md) functions to
facilitate building custom functions.
+
+The expression aggregator can be used at query time with any intermediate
type. It can also be used at ingest time, but
+only when the type of `initialValue` is a primitive numeric type (`LONG` or
`DOUBLE`) and matches the type of
+`initialCombineValue`. Other intermediate types, such as strings, arrays, and
complex types, are query-time only.
| Property | Description | Required |
| --- | --- | --- |
diff --git
a/embedded-tests/src/test/java/org/apache/druid/testing/embedded/compact/CompactionTaskTest.java
b/embedded-tests/src/test/java/org/apache/druid/testing/embedded/compact/CompactionTaskTest.java
index 84ee947c846..4692ec0715f 100644
---
a/embedded-tests/src/test/java/org/apache/druid/testing/embedded/compact/CompactionTaskTest.java
+++
b/embedded-tests/src/test/java/org/apache/druid/testing/embedded/compact/CompactionTaskTest.java
@@ -33,9 +33,12 @@ import
org.apache.druid.java.util.common.granularity.Granularities;
import org.apache.druid.java.util.common.granularity.Granularity;
import org.apache.druid.java.util.common.jackson.JacksonUtils;
import org.apache.druid.query.Druids;
+import org.apache.druid.query.aggregation.CountAggregatorFactory;
+import org.apache.druid.query.aggregation.ExpressionLambdaAggregatorFactory;
import org.apache.druid.query.aggregation.datasketches.hll.HllSketchModule;
import
org.apache.druid.query.aggregation.datasketches.quantiles.DoublesSketchModule;
import org.apache.druid.query.aggregation.datasketches.theta.SketchModule;
+import org.apache.druid.query.expression.TestExprMacroTable;
import org.apache.druid.query.metadata.metadata.SegmentMetadataQuery;
import org.apache.druid.segment.TestHelper;
import org.apache.druid.testing.embedded.EmbeddedClusterApis;
@@ -55,6 +58,7 @@ import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
+import java.util.Set;
import java.util.function.Supplier;
import java.util.stream.Collectors;
@@ -107,6 +111,65 @@ public class CompactionTaskTest extends CompactionTestBase
"namespace", "continent", "country", "region", "city", "timestamp"
);
+  /**
+   * Index task with the same shape as {@link MoreResources.Task#INDEX_TASK_WITH_AGGREGATORS}, except that its
+   * metrics are a row count plus two {@link ExpressionLambdaAggregatorFactory} metrics computed over the
+   * {@code added} long field: one adds the values up, the other ORs them together bitwise.
+   * {@link #testCompactionWithExpressionLambdaAggregator} ingests with this task to verify that an expression
+   * aggregator works correctly.
+   */
+  private static final Supplier<TaskBuilder.Index> INDEX_TASK_WITH_EXPR_AGG = () -> {
+    // Both metrics read the same input field and start from the same "0" seed; only the fold expression differs.
+    final ExpressionLambdaAggregatorFactory addedSum = new ExpressionLambdaAggregatorFactory(
+        "added_sum_expr", Set.of("added"), null, "0", null, null, false, false,
+        "__acc + added", null, null, null, null, TestExprMacroTable.INSTANCE
+    );
+    final ExpressionLambdaAggregatorFactory addedOr = new ExpressionLambdaAggregatorFactory(
+        "added_or_expr", Set.of("added"), null, "0", null, null, false, false,
+        "bitwiseOr(\"__acc\", \"added\")", null, null, null, null, TestExprMacroTable.INSTANCE
+    );
+
+    return TaskBuilder
+        .ofTypeIndex()
+        .jsonInputFormat()
+        .localInputSourceWithFiles(
+            Resources.DataFile.tinyWiki1Json(),
+            Resources.DataFile.tinyWiki2Json(),
+            Resources.DataFile.tinyWiki3Json()
+        )
+        .timestampColumn("timestamp")
+        .dimensions(
+            "page", "language", "tags", "user", "unpatrolled", "newPage", "robot",
+            "anonymous", "namespace", "continent", "country", "region", "city"
+        )
+        .metricAggregates(new CountAggregatorFactory("ingested_events"), addedSum, addedOr)
+        .dynamicPartitionWithMaxRows(3)
+        .granularitySpec("DAY", "SECOND", true)
+        .appendToExisting(false);
+  };
+
private String fullDatasourceName;
@BeforeEach
@@ -259,6 +322,33 @@ public class CompactionTaskTest extends CompactionTestBase
loadDataAndCompact(INDEX_TASK_WITH_TIMESTAMP.get(), COMPACTION_TASK.get(),
null);
}
+  /**
+   * Ingests with {@link #INDEX_TASK_WITH_EXPR_AGG}, compacts the resulting segments, and checks that the totals of
+   * both expression metrics are the same before and after compaction.
+   */
+  @Test
+  public void testCompactionWithExpressionLambdaAggregator() throws Exception
+  {
+    // The same query is issued on both sides of the compaction so the two results can be compared verbatim.
+    final String metricTotalsSql = "SELECT SUM(added_sum_expr), SUM(added_or_expr) FROM %s";
+
+    try (final Closeable ignored = unloader(fullDatasourceName)) {
+      runTask(INDEX_TASK_WITH_EXPR_AGG.get());
+      verifySegmentsCount(4);
+
+      // Snapshot of the metric values prior to compaction.
+      final String totalsBeforeCompaction = cluster.runSql(metricTotalsSql, fullDatasourceName);
+
+      // Compacting 4 segments into 2 performs cross-segment rollup, which drives RowCombiningTimeAndDimsIterator
+      // into ExpressionLambdaAggregatorFactory.makeAggregateCombiner().
+      compactData(COMPACTION_TASK.get(), null, null);
+      verifySegmentsCount(2);
+
+      // Metric values must round-trip through compaction unchanged.
+      final String totalsAfterCompaction = cluster.runSql(metricTotalsSql, fullDatasourceName);
+      Assertions.assertEquals(totalsBeforeCompaction, totalsAfterCompaction);
+    }
+  }
+
private void loadDataAndCompact(
TaskBuilder.Index indexTask,
TaskBuilder.Compact compactionResource,
diff --git
a/processing/src/main/java/org/apache/druid/query/aggregation/ExpressionLambdaAggregatorFactory.java
b/processing/src/main/java/org/apache/druid/query/aggregation/ExpressionLambdaAggregatorFactory.java
index 3235d709eee..c901b52962f 100644
---
a/processing/src/main/java/org/apache/druid/query/aggregation/ExpressionLambdaAggregatorFactory.java
+++
b/processing/src/main/java/org/apache/druid/query/aggregation/ExpressionLambdaAggregatorFactory.java
@@ -40,9 +40,11 @@ import org.apache.druid.math.expr.SettableObjectBinding;
import org.apache.druid.query.cache.CacheKeyBuilder;
import org.apache.druid.segment.ColumnInspector;
import org.apache.druid.segment.ColumnSelectorFactory;
+import org.apache.druid.segment.ColumnValueSelector;
import org.apache.druid.segment.column.ColumnCapabilities;
import org.apache.druid.segment.column.ColumnCapabilitiesImpl;
import org.apache.druid.segment.column.ColumnType;
+import org.apache.druid.segment.column.ValueType;
import org.apache.druid.segment.virtual.ExpressionPlan;
import org.apache.druid.segment.virtual.ExpressionPlanner;
import org.apache.druid.segment.virtual.ExpressionSelectors;
@@ -347,6 +349,86 @@ public class ExpressionLambdaAggregatorFactory extends
AggregatorFactory
).value();
}
+  /**
+   * Creates the combiner used to merge already-aggregated values of this metric, for example during the
+   * cross-segment rollup performed by compaction.
+   *
+   * A primitive combiner is only provided when the intermediate type is {@code LONG} or {@code DOUBLE} and is the
+   * same as the type of {@code initialCombineValue}. In every other case this defers to the superclass
+   * implementation, which does not support combining.
+   *
+   * @return a {@link LongAggregateCombiner} or {@link DoubleAggregateCombiner} that delegates each fold to
+   *         {@link #combine(Object, Object)}
+   */
+  @Override
+  public AggregateCombiner makeAggregateCombiner()
+  {
+    final ColumnType intermediateType = getIntermediateType();
+    // The combiner delegates to combine(), which feeds inputs into combineExpression typed against
+    // initialCombineValue. If the fold-side intermediate type (what's stored in the segment column) differs from
+    // the combine-side type, the primitive selector would silently feed wrong-typed values into the expression.
+    // Fall through to UOE.
+    if (!intermediateType.equals(ExpressionType.toColumnType(initialCombineValue.get().type()))) {
+      return super.makeAggregateCombiner();
+    }
+    if (intermediateType.is(ValueType.LONG)) {
+      return new LongAggregateCombiner()
+      {
+        private long state;
+        private boolean isNull;
+
+        @Override
+        public void reset(ColumnValueSelector selector)
+        {
+          // Consult isNull() before the primitive getter: for a null row getLong() only yields a placeholder, so
+          // it must not be read. A null row leaves a well-defined zero behind instead.
+          isNull = selector.isNull();
+          state = isNull ? 0L : selector.getLong();
+        }
+
+        @Override
+        public void fold(ColumnValueSelector selector)
+        {
+          // combine() works on boxed values, so a null accumulator is passed as null rather than as 0.
+          final Object combined = combine(isNull ? null : state, selector.getObject());
+          isNull = combined == null;
+          state = combined == null ? 0L : ((Number) combined).longValue();
+        }
+
+        @Override
+        public long getLong()
+        {
+          return state;
+        }
+
+        @Override
+        public boolean isNull()
+        {
+          return isNull;
+        }
+      };
+    } else if (intermediateType.is(ValueType.DOUBLE)) {
+      return new DoubleAggregateCombiner()
+      {
+        private double state;
+        private boolean isNull;
+
+        @Override
+        public void reset(ColumnValueSelector selector)
+        {
+          // Same ordering as the long combiner: never read the primitive getter for a null row.
+          isNull = selector.isNull();
+          state = isNull ? 0.0 : selector.getDouble();
+        }
+
+        @Override
+        public void fold(ColumnValueSelector selector)
+        {
+          final Object combined = combine(isNull ? null : state, selector.getObject());
+          isNull = combined == null;
+          state = combined == null ? 0.0 : ((Number) combined).doubleValue();
+        }
+
+        @Override
+        public double getDouble()
+        {
+          return state;
+        }
+
+        @Override
+        public boolean isNull()
+        {
+          return isNull;
+        }
+      };
+    }
+    return super.makeAggregateCombiner();
+  }
+
@Override
public Object deserialize(Object object)
{
diff --git
a/processing/src/test/java/org/apache/druid/query/aggregation/ExpressionLambdaAggregationTest.java
b/processing/src/test/java/org/apache/druid/query/aggregation/ExpressionLambdaAggregationTest.java
new file mode 100644
index 00000000000..baef2de3a0a
--- /dev/null
+++
b/processing/src/test/java/org/apache/druid/query/aggregation/ExpressionLambdaAggregationTest.java
@@ -0,0 +1,215 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.druid.query.aggregation;
+
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.ImmutableMap;
+import com.google.common.collect.ImmutableSet;
+import com.google.common.collect.Iterables;
+import org.apache.druid.data.input.InputRow;
+import org.apache.druid.data.input.MapBasedInputRow;
+import org.apache.druid.data.input.impl.DimensionsSpec;
+import org.apache.druid.data.input.impl.StringDimensionSchema;
+import org.apache.druid.java.util.common.DateTimes;
+import org.apache.druid.java.util.common.granularity.Granularities;
+import org.apache.druid.java.util.common.guava.Sequence;
+import org.apache.druid.query.Druids;
+import org.apache.druid.query.Result;
+import org.apache.druid.query.expression.TestExprMacroTable;
+import org.apache.druid.query.timeseries.TimeseriesQuery;
+import org.apache.druid.query.timeseries.TimeseriesResultValue;
+import org.apache.druid.segment.IndexBuilder;
+import org.apache.druid.segment.QueryableIndex;
+import org.apache.druid.segment.QueryableIndexSegment;
+import org.apache.druid.segment.Segment;
+import org.apache.druid.segment.incremental.IncrementalIndexSchema;
+import org.apache.druid.testing.InitializedNullHandlingTest;
+import org.apache.druid.timeline.SegmentId;
+import org.apache.druid.utils.CloseableUtils;
+import org.joda.time.DateTime;
+import org.junit.After;
+import org.junit.Assert;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.TemporaryFolder;
+
+import java.util.Collections;
+import java.util.List;
+import java.util.Set;
+
+/**
+ * Checks that {@link ExpressionLambdaAggregatorFactory} can serve as an ingest-time metric when its intermediate
+ * type is a primitive numeric type.
+ */
+public class ExpressionLambdaAggregationTest extends InitializedNullHandlingTest
+{
+  private static final String DIM = "groupKey";
+  private static final String LONG_FIELD = "longField";
+  private static final String DOUBLE_FIELD = "doubleField";
+  private static final DateTime TIMESTAMP = DateTimes.of("2020-01-01");
+
+  @Rule
+  public final TemporaryFolder tempFolder = new TemporaryFolder();
+
+  private QueryableIndex mergedIndex;
+  private Segment segment;
+
+  /**
+   * Closes the segment and then the merged index, skipping whichever of the two the test never created.
+   */
+  @After
+  public void tearDown()
+  {
+    if (segment != null) {
+      CloseableUtils.closeAndWrapExceptions(segment);
+    }
+    if (mergedIndex != null) {
+      CloseableUtils.closeAndWrapExceptions(mergedIndex);
+    }
+  }
+
+  @Test
+  public void testNumericExpressionLambdaIngestRollupViaMerge() throws Exception
+  {
+    // All three rows share the same (timestamp, dim), so they roll up into a single output row during merge.
+    // longField values: 1 (0b001), 2 (0b010), 4 (0b100) -> sum=7, bitwiseOr=7
+    // doubleField values: 1.5, 2.0, 0.25 -> sum=3.75
+    final List<InputRow> rows = List.of(
+        row(1L, 1.5),
+        row(2L, 2.0),
+        row(4L, 0.25)
+    );
+
+    final ExpressionLambdaAggregatorFactory longSum =
+        foldOnlyAggregator("long_sum", Set.of(LONG_FIELD), "0", "__acc + " + LONG_FIELD);
+
+    // BitwiseSqlAggregator-style: same single-field, op("__acc", field) fold
+    final ExpressionLambdaAggregatorFactory bitwiseOr = foldOnlyAggregator(
+        "bitwise_or",
+        ImmutableSet.of(LONG_FIELD),
+        "0",
+        "bitwiseOr(\"__acc\", \"" + LONG_FIELD + "\")"
+    );
+
+    final ExpressionLambdaAggregatorFactory doubleSum =
+        foldOnlyAggregator("double_sum", ImmutableSet.of(DOUBLE_FIELD), "0.0", "__acc + " + DOUBLE_FIELD);
+
+    final IncrementalIndexSchema schema =
+        IncrementalIndexSchema.builder()
+                              .withQueryGranularity(Granularities.NONE)
+                              .withRollup(true)
+                              .withDimensionsSpec(
+                                  DimensionsSpec.builder()
+                                                .setDimensions(ImmutableList.of(new StringDimensionSchema(DIM)))
+                                                .build()
+                              )
+                              .withMetrics(new CountAggregatorFactory("count"), longSum, bitwiseOr, doubleSum)
+                              .build();
+
+    mergedIndex = IndexBuilder.create()
+                              .tmpDir(tempFolder.newFolder())
+                              .schema(schema)
+                              .intermediaryPersistSize(1)
+                              .rows(rows)
+                              .buildMMappedMergedIndex();
+
+    segment = new QueryableIndexSegment(mergedIndex, SegmentId.dummy("test"));
+
+    // Query the stored metrics back through each factory's combining form.
+    final TimeseriesQuery query =
+        Druids.newTimeseriesQueryBuilder()
+              .dataSource("test")
+              .granularity(Granularities.ALL)
+              .intervals("1970/2050")
+              .aggregators(
+                  new LongSumAggregatorFactory("count", "count"),
+                  longSum.getCombiningFactory(),
+                  bitwiseOr.getCombiningFactory(),
+                  doubleSum.getCombiningFactory()
+              )
+              .build();
+
+    try (final AggregationTestHelper helper =
+             AggregationTestHelper.createTimeseriesQueryAggregationTestHelper(Collections.emptyList(), tempFolder)) {
+
+      final Sequence<Result<TimeseriesResultValue>> resultSequence =
+          helper.runQueryOnSegmentsObjs(ImmutableList.of(segment), query);
+      final TimeseriesResultValue totals = Iterables.getOnlyElement(resultSequence.toList()).getValue();
+
+      // Three input rows were rolled up into one; the count of 3 shows that the rollup happened.
+      Assert.assertEquals(3L, totals.getLongMetric("count").longValue());
+      Assert.assertEquals(7L, totals.getLongMetric("long_sum").longValue());
+      Assert.assertEquals(7L, totals.getLongMetric("bitwise_or").longValue());
+      Assert.assertEquals(3.75, totals.getDoubleMetric("double_sum").doubleValue(), 0.0);
+    }
+  }
+
+  /**
+   * Builds an expression aggregator that specifies only a name, its input fields, an initial value and a fold
+   * expression; every other optional constructor argument is left as {@code null} or {@code false}.
+   */
+  private static ExpressionLambdaAggregatorFactory foldOnlyAggregator(
+      String name,
+      Set<String> fields,
+      String initialValue,
+      String foldExpression
+  )
+  {
+    return new ExpressionLambdaAggregatorFactory(
+        name,
+        fields,
+        null,
+        initialValue,
+        null,
+        null,
+        false,
+        false,
+        foldExpression,
+        null,
+        null,
+        null,
+        null,
+        TestExprMacroTable.INSTANCE
+    );
+  }
+
+  /**
+   * Creates an input row at the shared timestamp and dimension value, carrying the given long and double fields.
+   */
+  private static InputRow row(long longVal, double doubleVal)
+  {
+    return new MapBasedInputRow(
+        TIMESTAMP,
+        ImmutableList.of(DIM),
+        ImmutableMap.of(DIM, "a", LONG_FIELD, longVal, DOUBLE_FIELD, doubleVal)
+    );
+  }
+}
diff --git
a/processing/src/test/java/org/apache/druid/query/aggregation/ExpressionLambdaAggregatorFactoryTest.java
b/processing/src/test/java/org/apache/druid/query/aggregation/ExpressionLambdaAggregatorFactoryTest.java
index 499bcef08fe..29bf850d3d4 100644
---
a/processing/src/test/java/org/apache/druid/query/aggregation/ExpressionLambdaAggregatorFactoryTest.java
+++
b/processing/src/test/java/org/apache/druid/query/aggregation/ExpressionLambdaAggregatorFactoryTest.java
@@ -24,24 +24,31 @@ import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import nl.jqno.equalsverifier.EqualsVerifier;
import org.apache.druid.java.util.common.HumanReadableBytes;
+import org.apache.druid.java.util.common.UOE;
import org.apache.druid.java.util.common.granularity.Granularities;
import org.apache.druid.query.Druids;
import
org.apache.druid.query.aggregation.hyperloglog.HyperUniquesAggregatorFactory;
import org.apache.druid.query.aggregation.post.FieldAccessPostAggregator;
import
org.apache.druid.query.aggregation.post.FinalizingFieldAccessPostAggregator;
import org.apache.druid.query.expression.TestExprMacroTable;
+import org.apache.druid.query.monomorphicprocessing.RuntimeShapeInspector;
import org.apache.druid.query.timeseries.TimeseriesQuery;
import org.apache.druid.query.timeseries.TimeseriesQueryQueryToolChest;
+import org.apache.druid.segment.ColumnValueSelector;
import org.apache.druid.segment.TestHelper;
import org.apache.druid.segment.column.ColumnType;
import org.apache.druid.segment.column.RowSignature;
+import org.apache.druid.segment.selector.TestColumnValueSelector;
import org.apache.druid.testing.InitializedNullHandlingTest;
import org.junit.Assert;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.ExpectedException;
+import javax.annotation.Nullable;
import java.io.IOException;
+import java.util.Arrays;
+import java.util.List;
public class ExpressionLambdaAggregatorFactoryTest extends
InitializedNullHandlingTest
{
@@ -545,6 +552,303 @@ public class ExpressionLambdaAggregatorFactoryTest
extends InitializedNullHandli
Assert.assertEquals(ColumnType.DOUBLE, agg.getResultType());
}
+  @Test
+  public void testLongAggregateCombiner()
+  {
+    final ExpressionLambdaAggregatorFactory factory = new ExpressionLambdaAggregatorFactory(
+        "expr_agg_name", ImmutableSet.of("x"), null, "0", null, true, false, false,
+        "__acc + x", null, null, null, null, TestExprMacroTable.INSTANCE
+    );
+    final AggregateCombiner longCombiner = factory.makeAggregateCombiner();
+    final TestColumnValueSelector<Long> input =
+        TestColumnValueSelector.of(Long.class, Arrays.asList(1L, 2L, 3L));
+
+    // The first value seeds the combiner through reset(); each later value is folded into the running sum.
+    final long[] expectedRunningSums = {1L, 3L, 6L};
+    for (int step = 0; step < expectedRunningSums.length; step++) {
+      input.advance();
+      if (step == 0) {
+        longCombiner.reset(input);
+      } else {
+        longCombiner.fold(input);
+      }
+      Assert.assertEquals(expectedRunningSums[step], longCombiner.getLong());
+    }
+  }
+
+  @Test
+  public void testDoubleAggregateCombiner()
+  {
+    final ExpressionLambdaAggregatorFactory factory = new ExpressionLambdaAggregatorFactory(
+        "expr_agg_name", ImmutableSet.of("x"), null, "0.0", null, true, false, false,
+        "__acc + x", null, null, null, null, TestExprMacroTable.INSTANCE
+    );
+    final AggregateCombiner doubleCombiner = factory.makeAggregateCombiner();
+    final TestColumnValueSelector<Double> input =
+        TestColumnValueSelector.of(Double.class, Arrays.asList(1.5, 2.25, 0.25));
+
+    // The first value seeds the combiner through reset(); each later value is folded into the running sum.
+    final double[] expectedRunningSums = {1.5, 3.75, 4.0};
+    for (int step = 0; step < expectedRunningSums.length; step++) {
+      input.advance();
+      if (step == 0) {
+        doubleCombiner.reset(input);
+      } else {
+        doubleCombiner.fold(input);
+      }
+      Assert.assertEquals(expectedRunningSums[step], doubleCombiner.getDouble(), 0.0);
+    }
+  }
+
+  @Test
+  public void testNullableAggregateCombinerSkipsNulls()
+  {
+    final ExpressionLambdaAggregatorFactory factory = new ExpressionLambdaAggregatorFactory(
+        "expr_agg_name", ImmutableSet.of("x"), null, "0", null, true, false, false,
+        "__acc + x", null, null, null, null, TestExprMacroTable.INSTANCE
+    );
+    final AggregateCombiner nullableCombiner = factory.makeNullableAggregateCombiner();
+    final NullableLongSelector input = new NullableLongSelector(Arrays.asList(null, 5L, null, 7L));
+
+    // Row 0 is null: resetting on it leaves the combiner null.
+    input.advance();
+    nullableCombiner.reset(input);
+    Assert.assertTrue(nullableCombiner.isNull());
+
+    // Row 1 is the first real value.
+    input.advance();
+    nullableCombiner.fold(input);
+    Assert.assertFalse(nullableCombiner.isNull());
+    Assert.assertEquals(5L, nullableCombiner.getLong());
+
+    // Row 2 is null again and must leave the accumulated value untouched.
+    input.advance();
+    nullableCombiner.fold(input);
+    Assert.assertEquals(5L, nullableCombiner.getLong());
+
+    // Row 3 is added on top of the earlier value.
+    input.advance();
+    nullableCombiner.fold(input);
+    Assert.assertEquals(12L, nullableCombiner.getLong());
+  }
+
+  @Test
+  public void testNullableAggregateCombinerWhenCombineAggregatesNullsExpressionSeesNulls()
+  {
+    // With shouldCombineAggregateNullInputs=true the combine expression receives null inputs directly and is
+    // itself responsible for handling them. Here `nvl` coalesces nulls to 0, so the accumulator keeps advancing.
+    final ExpressionLambdaAggregatorFactory factory = new ExpressionLambdaAggregatorFactory(
+        "expr_agg_name", ImmutableSet.of("x"), null, "0", null, true, true, true,
+        "nvl(__acc, 0) + nvl(x, 0)", null, null, null, null, TestExprMacroTable.INSTANCE
+    );
+    final AggregateCombiner nullableCombiner = factory.makeNullableAggregateCombiner();
+    final NullableLongSelector input = new NullableLongSelector(Arrays.asList(1L, null, 3L));
+
+    input.advance();
+    nullableCombiner.reset(input);
+    Assert.assertEquals(1L, nullableCombiner.getLong());
+
+    // The null is passed through to the expression, which coalesces it to 0.
+    input.advance();
+    nullableCombiner.fold(input);
+    Assert.assertEquals(1L, nullableCombiner.getLong());
+
+    input.advance();
+    nullableCombiner.fold(input);
+    Assert.assertEquals(4L, nullableCombiner.getLong());
+  }
+
+  @Test
+  public void testNullableAggregateCombinerNullExpressionResultPropagates()
+  {
+    // shouldCombineAggregateNullInputs=true combined with an expression that does not handle nulls:
+    // `__acc + null` evaluates to null in Druid expression semantics, and the combiner reports isNull accordingly.
+    final ExpressionLambdaAggregatorFactory factory = new ExpressionLambdaAggregatorFactory(
+        "expr_agg_name", ImmutableSet.of("x"), null, "0", null, true, true, true,
+        "__acc + x", null, null, null, null, TestExprMacroTable.INSTANCE
+    );
+    final AggregateCombiner nullableCombiner = factory.makeNullableAggregateCombiner();
+    final NullableLongSelector input = new NullableLongSelector(Arrays.asList(1L, null));
+
+    input.advance();
+    nullableCombiner.reset(input);
+    Assert.assertFalse(nullableCombiner.isNull());
+    Assert.assertEquals(1L, nullableCombiner.getLong());
+
+    // Folding the null row turns the whole result null.
+    input.advance();
+    nullableCombiner.fold(input);
+    Assert.assertTrue(nullableCombiner.isNull());
+  }
+
+
+  /**
+   * Test selector backed by a list of boxed longs in which {@code null} entries stand for null rows. It starts
+   * before the first entry; {@link #advance()} moves it to the next one.
+   */
+  private static final class NullableLongSelector implements ColumnValueSelector<Long>
+  {
+    private final List<Long> rows;
+    private int cursor = -1;
+
+    NullableLongSelector(List<Long> values)
+    {
+      this.rows = values;
+    }
+
+    void advance()
+    {
+      cursor++;
+    }
+
+    @Override
+    public long getLong()
+    {
+      // A null row reads as 0 through the primitive getter.
+      final Long current = rows.get(cursor);
+      if (current == null) {
+        return 0L;
+      }
+      return current;
+    }
+
+    @Override
+    public double getDouble()
+    {
+      return getLong();
+    }
+
+    @Override
+    public float getFloat()
+    {
+      return getLong();
+    }
+
+    @Override
+    public boolean isNull()
+    {
+      return rows.get(cursor) == null;
+    }
+
+    @Nullable
+    @Override
+    public Long getObject()
+    {
+      return rows.get(cursor);
+    }
+
+    @Override
+    public Class<Long> classOfObject()
+    {
+      return Long.class;
+    }
+
+    @Override
+    public void inspectRuntimeShape(RuntimeShapeInspector inspector)
+    {
+      // Nothing to report for this test selector.
+    }
+  }
+
+  @Test(expected = UOE.class)
+  public void testAggregateCombinerNotSupportedForNonNumericTypes()
+  {
+    // Both seeds are empty strings and both expressions concatenate, so no primitive numeric combiner applies.
+    final ExpressionLambdaAggregatorFactory stringFactory = new ExpressionLambdaAggregatorFactory(
+        "expr_agg_name", ImmutableSet.of("x"), null, "''", "''", true, true, true,
+        "concat(__acc, x)", "concat(__acc, expr_agg_name)", null, null,
+        new HumanReadableBytes(2048), TestExprMacroTable.INSTANCE
+    );
+
+    stringFactory.makeAggregateCombiner();
+  }
+
+  @Test(expected = UOE.class)
+  public void testAggregateCombinerNotSupportedWhenFoldAndCombineTypesDiffer()
+  {
+    // The fold seed is LONG (the intermediate column type) while the combine seed is LONG_ARRAY. Combining a long
+    // segment column with an expression that expects arrays would silently produce wrong values, so the combiner
+    // refuses to handle it.
+    final ExpressionLambdaAggregatorFactory mismatchedFactory = new ExpressionLambdaAggregatorFactory(
+        "expr_agg_name", ImmutableSet.of("x"), null, "0", "ARRAY<LONG>[]", null, false, false,
+        "__acc + x", "array_set_add(__acc, expr_agg_name)", null, null,
+        new HumanReadableBytes(2048), TestExprMacroTable.INSTANCE
+    );
+
+    Assert.assertEquals(ColumnType.LONG, mismatchedFactory.getIntermediateType());
+    mismatchedFactory.makeAggregateCombiner();
+  }
+
@Test
public void testResultArraySignature()
{
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]