This is an automated email from the ASF dual-hosted git repository.

amashenkov pushed a commit to branch ignite-21580
in repository https://gitbox.apache.org/repos/asf/ignite-3.git

commit f2df075a0e12b54d1a7064d150a489979f8502db
Author: amashenkov <andrey.mashen...@gmail.com>
AuthorDate: Wed Mar 27 21:09:35 2024 +0300

    Fix collation pass-through for sort-based map aggregate
---
 .../engine/rel/agg/IgniteSortAggregateBase.java    | 49 +++++++++----
 .../internal/sql/engine/planner/PlannerTest.java   | 82 ++++++++++++++++++++++
 2 files changed, 118 insertions(+), 13 deletions(-)

diff --git 
a/modules/sql-engine/src/main/java/org/apache/ignite/internal/sql/engine/rel/agg/IgniteSortAggregateBase.java
 
b/modules/sql-engine/src/main/java/org/apache/ignite/internal/sql/engine/rel/agg/IgniteSortAggregateBase.java
index d763c49b95..e5e88d9e18 100644
--- 
a/modules/sql-engine/src/main/java/org/apache/ignite/internal/sql/engine/rel/agg/IgniteSortAggregateBase.java
+++ 
b/modules/sql-engine/src/main/java/org/apache/ignite/internal/sql/engine/rel/agg/IgniteSortAggregateBase.java
@@ -17,6 +17,7 @@
 
 package org.apache.ignite.internal.sql.engine.rel.agg;
 
+import static java.util.function.Predicate.not;
 import static org.apache.ignite.internal.sql.engine.util.Commons.maxPrefix;
 
 import it.unimi.dsi.fastutil.ints.IntList;
@@ -29,9 +30,11 @@ import org.apache.calcite.rel.RelCollations;
 import org.apache.calcite.rel.RelFieldCollation;
 import org.apache.calcite.util.ImmutableBitSet;
 import org.apache.calcite.util.Pair;
+import org.apache.calcite.util.mapping.Mapping;
 import org.apache.ignite.internal.sql.engine.rel.IgniteRel;
 import org.apache.ignite.internal.sql.engine.trait.TraitUtils;
 import org.apache.ignite.internal.sql.engine.trait.TraitsAwareIgniteRel;
+import org.apache.ignite.internal.sql.engine.util.Commons;
 
 /**
  * Defines common methods for {@link IgniteRel relational nodes} that 
implement sort-base aggregates.
@@ -77,24 +80,35 @@ interface IgniteSortAggregateBase extends 
TraitsAwareIgniteRel {
             RelTraitSet nodeTraits, List<RelTraitSet> inputTraits
     ) {
         RelCollation required = TraitUtils.collation(nodeTraits);
-        ImmutableBitSet requiredKeys = ImmutableBitSet.of(required.getKeys());
-        RelCollation collation;
 
-        if (getGroupSet().contains(requiredKeys)) {
-            List<RelFieldCollation> newCollationFields = new 
ArrayList<>(getGroupSet().cardinality());
-            newCollationFields.addAll(required.getFieldCollations());
-
-            ImmutableBitSet keysLeft = getGroupSet().except(requiredKeys);
+        if (getGroupSet().isEmpty()) {
+            return passThroughCollation(nodeTraits, inputTraits, 
RelCollations.EMPTY); // Erase collation for a single group.
+        } else if (required.getFieldCollations().isEmpty()) {
+            return passThroughCollation(nodeTraits, inputTraits, 
TraitUtils.createCollation(getGroupSet().asList())); // No match.
+        }
 
-            keysLeft.forEach(fieldIdx -> 
newCollationFields.add(TraitUtils.createFieldCollation(fieldIdx)));
+        ImmutableBitSet groupingColumns = 
ImmutableBitSet.range(getGroupSet().cardinality());
+        IntList prefix = maxPrefix(required.getKeys(), 
groupingColumns.asSet());
 
-            collation = RelCollations.of(newCollationFields);
-        } else {
-            collation = TraitUtils.createCollation(getGroupSet().toList());
+        if (prefix.isEmpty()) {
+            return passThroughCollation(nodeTraits, inputTraits, 
TraitUtils.createCollation(getGroupSet().asList())); // No match.
         }
 
-        return Pair.of(nodeTraits.replace(collation),
-                List.of(inputTraits.get(0).replace(collation)));
+        // Rearrange grouping columns to satisfy required collation (as much 
as possible).
+        List<RelFieldCollation> newCollationColumns = new 
ArrayList<>(getGroupSet().cardinality());
+        Mapping mapping = 
Commons.trimmingMapping(groupingColumns.cardinality(), groupingColumns);
+
+        // Add required columns first.
+        prefix.intStream().map(mapping::getTarget)
+                .forEach(fieldIdx -> 
newCollationColumns.add(TraitUtils.createFieldCollation(fieldIdx)));
+
+        // Then add missed grouping columns.
+        groupingColumns.asList().stream()
+                .filter(not(prefix::contains))
+                .map(mapping::getTarget)
+                .forEach(fieldIdx -> 
newCollationColumns.add(TraitUtils.createFieldCollation(fieldIdx)));
+
+        return passThroughCollation(nodeTraits, inputTraits, 
RelCollations.of(newCollationColumns));
     }
 
     /** {@inheritDoc} */
@@ -118,4 +132,13 @@ interface IgniteSortAggregateBase extends 
TraitsAwareIgniteRel {
                 inputTraits
         ));
     }
+
+    private static Pair<RelTraitSet, List<RelTraitSet>> passThroughCollation(
+            RelTraitSet nodeTraits,
+            List<RelTraitSet> inputTraits,
+            RelCollation collation
+    ) {
+        return Pair.of(nodeTraits.replace(collation),
+                List.of(inputTraits.get(0).replace(collation)));
+    }
 }
diff --git 
a/modules/sql-engine/src/test/java/org/apache/ignite/internal/sql/engine/planner/PlannerTest.java
 
b/modules/sql-engine/src/test/java/org/apache/ignite/internal/sql/engine/planner/PlannerTest.java
index a43123a395..322b99fc41 100644
--- 
a/modules/sql-engine/src/test/java/org/apache/ignite/internal/sql/engine/planner/PlannerTest.java
+++ 
b/modules/sql-engine/src/test/java/org/apache/ignite/internal/sql/engine/planner/PlannerTest.java
@@ -17,6 +17,7 @@
 
 package org.apache.ignite.internal.sql.engine.planner;
 
+import static java.util.function.Predicate.not;
 import static org.apache.calcite.tools.Frameworks.createRootSchema;
 import static org.apache.calcite.tools.Frameworks.newConfigBuilder;
 import static 
org.apache.ignite.internal.sql.engine.planner.CorrelatedSubqueryPlannerTest.createTestTable;
@@ -38,6 +39,7 @@ import org.apache.calcite.plan.RelOptUtil;
 import org.apache.calcite.plan.RelTraitSet;
 import org.apache.calcite.rel.RelNode;
 import org.apache.calcite.rel.RelRoot;
+import org.apache.calcite.rel.core.Sort;
 import org.apache.calcite.rel.core.TableScan;
 import org.apache.calcite.rel.hint.HintStrategyTable;
 import org.apache.calcite.rex.RexNode;
@@ -52,9 +54,15 @@ import 
org.apache.ignite.internal.sql.engine.prepare.IgnitePlanner;
 import org.apache.ignite.internal.sql.engine.prepare.PlannerPhase;
 import org.apache.ignite.internal.sql.engine.prepare.PlanningContext;
 import org.apache.ignite.internal.sql.engine.rel.IgniteConvention;
+import org.apache.ignite.internal.sql.engine.rel.IgniteExchange;
 import org.apache.ignite.internal.sql.engine.rel.IgniteFilter;
+import org.apache.ignite.internal.sql.engine.rel.IgniteIndexScan;
 import org.apache.ignite.internal.sql.engine.rel.IgniteRel;
+import org.apache.ignite.internal.sql.engine.rel.IgniteSort;
 import org.apache.ignite.internal.sql.engine.rel.IgniteTableScan;
+import org.apache.ignite.internal.sql.engine.rel.agg.IgniteMapHashAggregate;
+import org.apache.ignite.internal.sql.engine.rel.agg.IgniteReduceHashAggregate;
+import org.apache.ignite.internal.sql.engine.schema.IgniteIndex.Collation;
 import org.apache.ignite.internal.sql.engine.schema.IgniteSchema;
 import org.apache.ignite.internal.sql.engine.trait.IgniteDistribution;
 import org.apache.ignite.internal.sql.engine.trait.IgniteDistributions;
@@ -115,6 +123,80 @@ public class PlannerTest extends AbstractPlannerTest {
                 )));
     }
 
+
+    @Test
+    public void tpchTest_q1() throws Exception {
+        IgniteSchema publicSchema = createSchema(TestBuilders.table()
+                .name("LINEITEM")
+                .addKeyColumn("L_ORDERKEY", NativeTypes.INT32)
+                .addColumn("L_PARTKEY", NativeTypes.INT32)
+                .addColumn("L_SUPPKEY", NativeTypes.INT32)
+                .addKeyColumn("L_LINENUMBER", NativeTypes.INT32)
+                .addColumn("L_QUANTITY", NativeTypes.decimalOf(15, 2))
+                .addColumn("L_EXTENDEDPRICE", NativeTypes.decimalOf(15, 2))
+                .addColumn("L_DISCOUNT", NativeTypes.decimalOf(15, 2))
+                .addColumn("L_TAX", NativeTypes.decimalOf(15, 2))
+                .addColumn("L_RETURNFLAG", NativeTypes.stringOf(1))
+                .addColumn("L_LINESTATUS", NativeTypes.stringOf(1))
+                .addColumn("L_SHIPDATE", NativeTypes.DATE)
+                .addColumn("L_COMMITDATE", NativeTypes.DATE)
+                .addColumn("L_RECEIPTDATE", NativeTypes.DATE)
+                .addColumn("L_SHIPINSTRUCT", NativeTypes.stringOf(25))
+                .addColumn("L_SHIPMODE", NativeTypes.stringOf(10))
+                .addColumn("L_COMMENT", NativeTypes.STRING)
+                .distribution(IgniteDistributions.hash(List.of(0, 3)))
+                .hashIndex()
+                .name("LINEITEM_PK")
+                .addColumn("L_ORDERKEY")
+                .addColumn("L_LINENUMBER")
+                .end()
+                .sortedIndex()
+                .name("IDX_SHIPDATE")
+                .addColumn("L_SHIPDATE", Collation.ASC_NULLS_LAST)
+                .end()
+                .build());
+
+        String sql = ""
+                + "SELECT\n"
+//                  + "/*+ DISABLE_RULE("
+//                + "'ColocatedHashAggregateConverterRule',"
+//                + "'ColocatedSortAggregateConverterRule',"
+//                + "'MapReduceSortAggregateConverterRule'"
+//                + ") */"
+                + "    l_returnflag,\n"
+                + "    l_linestatus,\n"
+                + "    sum(l_quantity)                                       
AS sum_qty,\n"
+                + "    sum(l_extendedprice)                                  
AS sum_base_price,\n"
+                + "    sum(l_extendedprice * (1 - l_discount))               
AS sum_disc_price,\n"
+                + "    sum(l_extendedprice * (1 - l_discount) * (1 + l_tax)) 
AS sum_charge,\n"
+                + "    avg(l_quantity)                                       
AS avg_qty,\n"
+                + "    avg(l_extendedprice)                                  
AS avg_price,\n"
+                + "    avg(l_discount)                                       
AS avg_disc,\n"
+                + "    count(*)                                              
AS count_order\n"
+                + "FROM\n"
+                + "    lineitem \n"
+                + "WHERE\n"
+                + "        l_shipdate <= DATE '1998-12-01' - INTERVAL '90' 
DAY\n"
+                + "GROUP BY\n"
+                + "    l_returnflag,\n"
+                + "    l_linestatus\n"
+                + "ORDER BY\n"
+                + "    l_returnflag,\n"
+                + "    l_linestatus\n";
+
+        assertPlan(sql, publicSchema,
+                nodeOrAnyChild(isInstanceOf(IgniteSort.class)
+                        
.and(nodeOrAnyChild(isInstanceOf(IgniteReduceHashAggregate.class)
+                                
.and(nodeOrAnyChild(isInstanceOf(IgniteExchange.class)
+                                        
.and(nodeOrAnyChild(isInstanceOf(IgniteMapHashAggregate.class)
+                                                
.and(nodeOrAnyChild(isInstanceOf(IgniteIndexScan.class)))
+                                                
.and(not(nodeOrAnyChild(isInstanceOf(Sort.class))))
+                                        ))
+                                ))
+                        ))
+                ));
+    }
+
     @Test
     @Disabled("https://issues.apache.org/jira/browse/IGNITE-21286";)
     public void testJoinPushExpressionRule() throws Exception {

Reply via email to