This is an automated email from the ASF dual-hosted git repository.
cgivre pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/drill.git
The following commit(s) were added to refs/heads/master by this push:
new e9afb95f16 DRILL-8381: Add support for filtered aggregate calls (#2734)
e9afb95f16 is described below
commit e9afb95f161f1fec527a04458adf78add91567bc
Author: Volodymyr Vysotskyi <[email protected]>
AuthorDate: Mon Jan 9 16:52:35 2023 +0200
DRILL-8381: Add support for filtered aggregate calls (#2734)
---
.../apache/drill/exec/planner/physical/AggPrelBase.java | 11 ++++++++++-
.../drill/exec/fn/impl/TestAggregateFunctions.java | 17 +++++++++++++++++
2 files changed, 27 insertions(+), 1 deletion(-)
diff --git a/exec/java-exec/src/main/java/org/apache/drill/exec/planner/physical/AggPrelBase.java b/exec/java-exec/src/main/java/org/apache/drill/exec/planner/physical/AggPrelBase.java
index 00ab5394fd..f9a7d0e099 100644
--- a/exec/java-exec/src/main/java/org/apache/drill/exec/planner/physical/AggPrelBase.java
+++ b/exec/java-exec/src/main/java/org/apache/drill/exec/planner/physical/AggPrelBase.java
@@ -17,6 +17,8 @@
*/
package org.apache.drill.exec.planner.physical;
+import org.apache.drill.common.expression.IfExpression;
+import org.apache.drill.common.expression.NullExpression;
import org.apache.drill.shaded.guava.com.google.common.collect.Lists;
import org.apache.calcite.linq4j.Ord;
import org.apache.calcite.util.BitSets;
@@ -199,7 +201,14 @@ public abstract class AggPrelBase extends DrillAggregateRelBase implements Prel
protected LogicalExpression toDrill(AggregateCall call, List<String> fn) {
List<LogicalExpression> args = Lists.newArrayList();
for (Integer i : call.getArgList()) {
- args.add(FieldReference.getWithQuotedRef(fn.get(i)));
+ LogicalExpression expr = FieldReference.getWithQuotedRef(fn.get(i));
+ if (call.hasFilter()) {
+ expr = IfExpression.newBuilder()
+ .setIfCondition(new IfExpression.IfCondition(FieldReference.getWithQuotedRef(fn.get(call.filterArg)), expr))
+ .setElse(NullExpression.INSTANCE)
+ .build();
+ }
+ args.add(expr);
}
if (SqlKind.COUNT.name().equals(call.getAggregation().getName()) && args.isEmpty()) {
diff --git a/exec/java-exec/src/test/java/org/apache/drill/exec/fn/impl/TestAggregateFunctions.java b/exec/java-exec/src/test/java/org/apache/drill/exec/fn/impl/TestAggregateFunctions.java
index 11be3bc3c6..97f3c254b2 100644
--- a/exec/java-exec/src/test/java/org/apache/drill/exec/fn/impl/TestAggregateFunctions.java
+++ b/exec/java-exec/src/test/java/org/apache/drill/exec/fn/impl/TestAggregateFunctions.java
@@ -1254,4 +1254,21 @@ public class TestAggregateFunctions extends ClusterTest {
.baselineValues("USA", 2.0816659994661326, 8L)
.go();
}
+
+ @Test
+ public void testAggregateWithFilterCall() throws Exception {
+ testBuilder()
+ .sqlQuery(
+ "SELECT count(n_name) FILTER(WHERE n_regionkey = 1) AS nations_count_in_1_region," +
+ "count(n_name) FILTER(WHERE n_regionkey = 2) AS nations_count_in_2_region," +
+ "count(n_name) FILTER(WHERE n_regionkey = 3) AS nations_count_in_3_region," +
+ "count(n_name) FILTER(WHERE n_regionkey = 4) AS nations_count_in_4_region," +
+ "count(n_name) FILTER(WHERE n_regionkey = 0) AS nations_count_in_0_region\n" +
+ "FROM cp.`tpch/nation.parquet`")
+ .unOrdered()
+ .baselineColumns("nations_count_in_1_region", "nations_count_in_2_region",
+ "nations_count_in_3_region", "nations_count_in_4_region", "nations_count_in_0_region")
+ .baselineValues(5L, 5L, 5L, 5L, 5L)
+ .go();
+ }
}