Repository: calcite Updated Branches: refs/heads/master 03bb2cea5 -> 898c2d66a
[CALCITE-1853] Push Count distinct into Druid when approximate results are acceptable (Zain Humayun) Close apache/calcite#478 Project: http://git-wip-us.apache.org/repos/asf/calcite/repo Commit: http://git-wip-us.apache.org/repos/asf/calcite/commit/898c2d66 Tree: http://git-wip-us.apache.org/repos/asf/calcite/tree/898c2d66 Diff: http://git-wip-us.apache.org/repos/asf/calcite/diff/898c2d66 Branch: refs/heads/master Commit: 898c2d66a18af9f97f8061803a19f211d3d6ed85 Parents: 03bb2ce Author: Zain Humayun <[email protected]> Authored: Tue Jun 20 13:10:01 2017 -0700 Committer: Jesus Camacho Rodriguez <[email protected]> Committed: Thu Jun 22 17:09:23 2017 +0100 ---------------------------------------------------------------------- .../calcite/adapter/druid/DruidQuery.java | 18 ++- .../calcite/adapter/druid/DruidRules.java | 38 ++++--- .../calcite/adapter/druid/DruidTable.java | 4 + .../org/apache/calcite/test/DruidAdapterIT.java | 112 +++++++++++++++++++ 4 files changed, 153 insertions(+), 19 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/calcite/blob/898c2d66/druid/src/main/java/org/apache/calcite/adapter/druid/DruidQuery.java ---------------------------------------------------------------------- diff --git a/druid/src/main/java/org/apache/calcite/adapter/druid/DruidQuery.java b/druid/src/main/java/org/apache/calcite/adapter/druid/DruidQuery.java index 39ce4a8..1c39fc6 100644 --- a/druid/src/main/java/org/apache/calcite/adapter/druid/DruidQuery.java +++ b/druid/src/main/java/org/apache/calcite/adapter/druid/DruidQuery.java @@ -480,13 +480,15 @@ public class DruidQuery extends AbstractRelNode implements BindableRel { return getQuerySpec().queryString; } + protected CalciteConnectionConfig getConnectionConfig() { + return getCluster().getPlanner().getContext().unwrap(CalciteConnectionConfig.class); + } + protected QuerySpec getQuery(RelDataType rowType, RexNode filter, List<RexNode> projects, 
ImmutableBitSet groupSet, List<AggregateCall> aggCalls, List<String> aggNames, List<Integer> collationIndexes, List<Direction> collationDirections, ImmutableBitSet numericCollationIndexes, Integer fetch) { - final CalciteConnectionConfig config = - getCluster().getPlanner().getContext() - .unwrap(CalciteConnectionConfig.class); + final CalciteConnectionConfig config = getConnectionConfig(); QueryType queryType = QueryType.SELECT; final Translator translator = new Translator(druidTable, rowType); List<String> fieldNames = rowType.getFieldNames(); @@ -805,10 +807,18 @@ public class DruidQuery extends AbstractRelNode implements BindableRel { // Cannot handle this aggregate function type throw new AssertionError("unknown aggregate type " + type); } + CalciteConnectionConfig config = getConnectionConfig(); switch (aggCall.getAggregation().getKind()) { case COUNT: if (aggCall.isDistinct()) { - return new JsonCardinalityAggregation("cardinality", name, list); + if (config.approximateDistinctCount()) { + return new JsonCardinalityAggregation("cardinality", name, list); + } else { + // Gets thrown if one of the rules allows a count(distinct ...) through + // when approximate results were not declared to be acceptable. 
+ throw new UnsupportedOperationException("Cannot push " + aggCall + + " because an approximate count distinct is not acceptable."); + } } return new JsonAggregation("count", name, only); case SUM: http://git-wip-us.apache.org/repos/asf/calcite/blob/898c2d66/druid/src/main/java/org/apache/calcite/adapter/druid/DruidRules.java ---------------------------------------------------------------------- diff --git a/druid/src/main/java/org/apache/calcite/adapter/druid/DruidRules.java b/druid/src/main/java/org/apache/calcite/adapter/druid/DruidRules.java index ef5d6f2..de65a3a 100644 --- a/druid/src/main/java/org/apache/calcite/adapter/druid/DruidRules.java +++ b/druid/src/main/java/org/apache/calcite/adapter/druid/DruidRules.java @@ -113,20 +113,30 @@ public class DruidRules { SORT_PROJECT_TRANSPOSE); /** Predicate that returns whether Druid can not handle an aggregate. */ - private static final Predicate<Aggregate> BAD_AGG = - new PredicateImpl<Aggregate>() { - public boolean test(Aggregate aggregate) { - final CalciteConnectionConfig config = - aggregate.getCluster().getPlanner().getContext() - .unwrap(CalciteConnectionConfig.class); + private static final Predicate<Triple<Aggregate, RelNode, DruidQuery>> BAD_AGG = + new PredicateImpl<Triple<Aggregate, RelNode, DruidQuery>>() { + public boolean test(Triple<Aggregate, RelNode, DruidQuery> triple) { + final Aggregate aggregate = triple.getLeft(); + final RelNode node = triple.getMiddle(); + final DruidQuery query = triple.getRight(); + + final CalciteConnectionConfig config = query.getConnectionConfig(); for (AggregateCall aggregateCall : aggregate.getAggCallList()) { switch (aggregateCall.getAggregation().getKind()) { case COUNT: - if (!aggregateCall.getArgList().isEmpty()) { - // Cannot handle this aggregate function + // Druid can handle 2 scenarios: + // 1. count(distinct col) when approximate results + // are acceptable and col is not a metric + // 2. 
count(*) + if (checkAggregateOnMetric(ImmutableBitSet.of(aggregateCall.getArgList()), + node, query)) { return true; } - break; + if ((config.approximateDistinctCount() && aggregateCall.isDistinct()) + || aggregateCall.getArgList().isEmpty()) { + continue; + } + return true; case SUM: case SUM0: case MIN: @@ -264,8 +274,7 @@ public class DruidRules { } else { boolean filterOnMetrics = false; for (Integer i : visitor.inputPosReferenced) { - if (input.druidTable.metricFieldNames.contains( - input.getRowType().getFieldList().get(i).getName())) { + if (input.druidTable.isMetric(input.getRowType().getFieldList().get(i).getName())) { // Filter on metrics, not supported in Druid filterOnMetrics = true; break; @@ -398,7 +407,7 @@ public class DruidRules { } if (aggregate.indicator || aggregate.getGroupSets().size() != 1 - || BAD_AGG.apply(aggregate) + || BAD_AGG.apply(ImmutableTriple.of(aggregate, (RelNode) aggregate, query)) || !validAggregate(aggregate, query)) { return; } @@ -445,7 +454,7 @@ public class DruidRules { } if (aggregate.indicator || aggregate.getGroupSets().size() != 1 - || BAD_AGG.apply(aggregate) + || BAD_AGG.apply(ImmutableTriple.of(aggregate, (RelNode) project, query)) || !validAggregate(aggregate, timestampIdx)) { return; } @@ -698,8 +707,7 @@ public class DruidRules { set = newSet.build(); } for (int index : set) { - if (query.druidTable.metricFieldNames - .contains(query.getTopNode().getRowType().getFieldNames().get(index))) { + if (query.druidTable.isMetric(query.getTopNode().getRowType().getFieldNames().get(index))) { return true; } } http://git-wip-us.apache.org/repos/asf/calcite/blob/898c2d66/druid/src/main/java/org/apache/calcite/adapter/druid/DruidTable.java ---------------------------------------------------------------------- diff --git a/druid/src/main/java/org/apache/calcite/adapter/druid/DruidTable.java b/druid/src/main/java/org/apache/calcite/adapter/druid/DruidTable.java index e92fd4b..656f20f 100644 --- 
a/druid/src/main/java/org/apache/calcite/adapter/druid/DruidTable.java +++ b/druid/src/main/java/org/apache/calcite/adapter/druid/DruidTable.java @@ -122,6 +122,10 @@ public class DruidTable extends AbstractTable implements TranslatableTable { ImmutableList.<RelNode>of(scan)); } + public boolean isMetric(String name) { + return metricFieldNames.contains(name); + } + /** Creates a {@link RelDataType} from a map of * field names and types. */ private static class MapRelProtoDataType implements RelProtoDataType { http://git-wip-us.apache.org/repos/asf/calcite/blob/898c2d66/druid/src/test/java/org/apache/calcite/test/DruidAdapterIT.java ---------------------------------------------------------------------- diff --git a/druid/src/test/java/org/apache/calcite/test/DruidAdapterIT.java b/druid/src/test/java/org/apache/calcite/test/DruidAdapterIT.java index cb2024d..e04aba1 100644 --- a/druid/src/test/java/org/apache/calcite/test/DruidAdapterIT.java +++ b/druid/src/test/java/org/apache/calcite/test/DruidAdapterIT.java @@ -2197,6 +2197,118 @@ public class DruidAdapterIT { .queryContains(druidChecker(druidFilter)) .returnsOrdered("EXPR$0=11"); } + + /** + * Test to ensure that count(distinct ...) gets pushed to Druid when approximate results are + * acceptable + * */ + @Test public void testDistinctCountWhenApproxResultsAccepted() { + String sql = "select count(distinct \"customer_id\") from \"foodmart\""; + String expectedSubExplain = "DruidQuery(table=[[foodmart, foodmart]], intervals=[[1900-01-09T00" + + ":00:00.000/2992-01-10T00:00:00.000]], groups=[{}], aggs=[[COUNT(DISTINCT $20)]])"; + String expectedAggregate = "{'type':'cardinality','name':" + + "'EXPR$0','fieldNames':['customer_id']}"; + + testCountWithApproxDistinct(true, sql, expectedSubExplain, expectedAggregate); + } + + /** + * Test to ensure that count(distinct ...) 
doesn't get pushed to Druid when approximate results + * are not acceptable + */ + @Test public void testDistinctCountWhenApproxResultsNotAccepted() { + String sql = "select count(distinct \"customer_id\") from \"foodmart\""; + String expectedSubExplain = " BindableAggregate(group=[{}], EXPR$0=[COUNT($0)])\n" + + " DruidQuery(table=[[foodmart, foodmart]], " + + "intervals=[[1900-01-09T00:00:00.000/2992-01-10T00:00:00.000]], " + + "groups=[{20}], aggs=[[]])"; + + testCountWithApproxDistinct(false, sql, expectedSubExplain); + } + + /** + * Test to ensure that a count distinct on metric does not get pushed into Druid + */ + @Test public void testDistinctCountOnMetric() { + String sql = "select count(distinct \"store_sales\") from \"foodmart\" " + + "where \"store_state\" = 'WA'"; + String expectedSubExplain = " BindableAggregate(group=[{}], EXPR$0=[COUNT($0)])\n" + + " BindableAggregate(group=[{1}])"; + + testCountWithApproxDistinct(true, sql, expectedSubExplain); + testCountWithApproxDistinct(false, sql, expectedSubExplain); + } + + /** + * Test to ensure that a count on a metric does not get pushed into Druid + */ + @Test public void testCountOnMetric() { + String sql = "select \"brand_name\", count(\"store_sales\") from \"foodmart\" " + + "group by \"brand_name\""; + String expectedSubExplain = " BindableAggregate(group=[{0}], EXPR$1=[COUNT($1)])\n" + + " DruidQuery(table=[[foodmart, foodmart]], intervals=[[1900-01-09T00:00:00.000/" + + "2992-01-10T00:00:00.000]], projects=[[$2, $90]])"; + + testCountWithApproxDistinct(true, sql, expectedSubExplain); + testCountWithApproxDistinct(false, sql, expectedSubExplain); + } + + /** + * Test to ensure that count(*) is pushed into Druid + */ + @Test public void testCountStar() { + String sql = "select count(*) from \"foodmart\""; + String expectedSubExplain = " DruidQuery(table=[[foodmart, foodmart]], " + + "intervals=[[1900-01-09T00:00:00.000/2992-01-10T00:00:00.000]], " + + "projects=[[]], groups=[{}], aggs=[[COUNT()]])"; 
+ + sql(sql).explainContains(expectedSubExplain); + } + + /** + * Test to ensure that count() aggregates with metric columns are not pushed into Druid + * even when the metric column has been renamed + */ + @Test public void testCountOnMetricRenamed() { + String sql = "select \"B\", count(\"A\") from " + + "(select \"unit_sales\" as \"A\", \"customer_id\" as \"B\" from \"foodmart\") " + + "group by \"B\""; + String expectedSubExplain = " BindableAggregate(group=[{0}], EXPR$1=[COUNT($1)])\n" + + " DruidQuery(table=[[foodmart, foodmart]], intervals=[[1900-01-09T00:00:00.000" + + "/2992-01-10T00:00:00.000]], projects=[[$20, $89]])\n"; + + testCountWithApproxDistinct(true, sql, expectedSubExplain); + testCountWithApproxDistinct(false, sql, expectedSubExplain); + } + + @Test public void testDistinctCountOnMetricRenamed() { + String sql = "select \"B\", count(distinct \"A\") from " + + "(select \"unit_sales\" as \"A\", \"customer_id\" as \"B\" from \"foodmart\") " + + "group by \"B\""; + String expectedSubExplain = " BindableAggregate(group=[{0}], EXPR$1=[COUNT($1)])\n" + + " DruidQuery(table=[[foodmart, foodmart]], intervals=[[1900-01-09T00:00:" + + "00.000/2992-01-10T00:00:00.000]], projects=[[$20, $89]], groups=[{0, 1}], " + + "aggs=[[]])"; + + testCountWithApproxDistinct(true, sql, expectedSubExplain); + testCountWithApproxDistinct(false, sql, expectedSubExplain); + } + + private void testCountWithApproxDistinct(boolean approx, String sql, String expectedExplain) { + testCountWithApproxDistinct(approx, sql, expectedExplain, ""); + } + + private void testCountWithApproxDistinct(boolean approx, String sql, String expectedExplain, + String expectedDruidQuery) { + CalciteAssert.that() + .enable(enabled()) + .with(ImmutableMap.of("model", FOODMART.getPath())) + .with(CalciteConnectionProperty.APPROXIMATE_DISTINCT_COUNT.camelName(), approx) + .query(sql) + .runs() + .explainContains(expectedExplain) + .queryContains(druidChecker(expectedDruidQuery)); + } } // End 
DruidAdapterIT.java
