This is an automated email from the ASF dual-hosted git repository.

lgbo pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git


The following commit(s) were added to refs/heads/main by this push:
     new a1edfafcd4 fix #10741: sort keys contains partition keys for getting window topk (#10746)
a1edfafcd4 is described below

commit a1edfafcd4025440caef8bba5a0d5a1c432c2480
Author: lgbo <[email protected]>
AuthorDate: Mon Sep 22 17:00:42 2025 +0800

    fix #10741: sort keys contains partition keys for getting window topk (#10746)
---
 .../GlutenClickHouseTPCHSaltNullParquetSuite.scala | 11 +++++
 cpp-ch/local-engine/Common/GlutenConfig.h          |  4 +-
 .../Parser/RelParsers/GroupLimitRelParser.cpp      | 49 +++++++++++++++++++---
 .../Parser/RelParsers/GroupLimitRelParser.h        |  7 +++-
 .../Parser/RelParsers/SortParsingUtils.cpp         | 26 +++---------
 .../Parser/RelParsers/SortParsingUtils.h           |  4 +-
 6 files changed, 70 insertions(+), 31 deletions(-)

diff --git a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/tpch/GlutenClickHouseTPCHSaltNullParquetSuite.scala b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/tpch/GlutenClickHouseTPCHSaltNullParquetSuite.scala
index 341a10cb94..c8d6da2b66 100644
--- a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/tpch/GlutenClickHouseTPCHSaltNullParquetSuite.scala
+++ b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/tpch/GlutenClickHouseTPCHSaltNullParquetSuite.scala
@@ -3012,6 +3012,17 @@ class GlutenClickHouseTPCHSaltNullParquetSuite
         compareResult = true,
         checkWindowGroupLimit
       )
+
+      compareResultsAgainstVanillaSpark(
+        """
+          |select * from(
+          |select a, b, c, row_number() over (partition by a order by b, c, a) as r
+          |from test_win_top)
+          |where r <= 1
+          |""".stripMargin,
+        compareResult = true,
+        checkWindowGroupLimit
+      )
       spark.sql("drop table if exists test_win_top")
     }
 
diff --git a/cpp-ch/local-engine/Common/GlutenConfig.h b/cpp-ch/local-engine/Common/GlutenConfig.h
index 7b2a33a9b6..0f68151c03 100644
--- a/cpp-ch/local-engine/Common/GlutenConfig.h
+++ b/cpp-ch/local-engine/Common/GlutenConfig.h
@@ -164,8 +164,8 @@ struct WindowConfig
 public:
     inline static const String WINDOW_AGGREGATE_TOPK_SAMPLE_ROWS = "window.aggregate_topk_sample_rows";
     inline static const String WINDOW_AGGREGATE_TOPK_HIGH_CARDINALITY_THRESHOLD = "window.aggregate_topk_high_cardinality_threshold";
-    size_t aggregate_topk_sample_rows = 5000;
-    double aggregate_topk_high_cardinality_threshold = 0.6;
+    size_t aggregate_topk_sample_rows = 50000;
+    double aggregate_topk_high_cardinality_threshold = 0.4;
     static WindowConfig loadFromContext(const DB::ContextPtr & context);
 };
 
diff --git a/cpp-ch/local-engine/Parser/RelParsers/GroupLimitRelParser.cpp b/cpp-ch/local-engine/Parser/RelParsers/GroupLimitRelParser.cpp
index 01c7784460..7dcdbe6430 100644
--- a/cpp-ch/local-engine/Parser/RelParsers/GroupLimitRelParser.cpp
+++ b/cpp-ch/local-engine/Parser/RelParsers/GroupLimitRelParser.cpp
@@ -16,6 +16,7 @@
  */
 
 #include "GroupLimitRelParser.h"
+#include <algorithm>
 #include <memory>
 #include <unordered_set>
 #include <utility>
@@ -46,12 +47,14 @@
 #include <QueryPipeline/QueryPipelineBuilder.h>
 #include <google/protobuf/repeated_field.h>
 #include <google/protobuf/wrappers.pb.h>
+#include "Common/Logger.h"
 #include <Common/AggregateUtil.h>
 #include <Common/ArrayJoinHelper.h>
 #include <Common/GlutenConfig.h>
 #include <Common/PlanUtil.h>
 #include <Common/QueryContext.h>
 #include <Common/logger_useful.h>
+#include "cctz/civil_time_detail.h"
 
 namespace DB::ErrorCodes
 {
@@ -226,6 +229,7 @@ DB::QueryPlanPtr AggregateGroupLimitRelParser::parse(
 
     // If all partition keys are low cardinality keys, use aggregattion to get topk of each partition
     auto aggregation_plan = BranchStepHelper::createSubPlan(branch_in_header, 1);
+    collectPartitionAndSortFields();
     prePrejectionForAggregateArguments(*aggregation_plan);
     addGroupLmitAggregationStep(*aggregation_plan);
     postProjectionForExplodingArrays(*aggregation_plan);
@@ -262,15 +266,40 @@ String AggregateGroupLimitRelParser::getAggregateFunctionName(const String & win
         throw DB::Exception(DB::ErrorCodes::BAD_ARGUMENTS, "Unsupported window function: {}", window_function_name);
 }
 
+void AggregateGroupLimitRelParser::collectPartitionAndSortFields()
+{
+    partition_fields = parsePartitionFields(win_rel_def->partition_expressions());
+    auto full_sort_fields = parseSortFields(win_rel_def->sorts());
+
+    std::set<size_t> partition_fields_set(partition_fields.begin(), partition_fields.end());
+    std::set<size_t> full_sort_fields_set(full_sort_fields.begin(), full_sort_fields.end());
+    std::set<size_t> selected_sort_fields_set;
+    // Remove partition keys from sort keys
+    std::set_difference(
+        full_sort_fields_set.begin(),
+        full_sort_fields_set.end(),
+        partition_fields_set.begin(),
+        partition_fields_set.end(),
+        std::inserter(selected_sort_fields_set, selected_sort_fields_set.begin()));
+    if (selected_sort_fields_set.empty())
+    {
+        // FIXME: support empty sort keys.
+        sort_fields.push_back(*partition_fields_set.begin());
+    }
+    else
+    {
+        sort_fields = std::vector<size_t>(selected_sort_fields_set.begin(), selected_sort_fields_set.end());
+    }
+}
+
 // Build one tuple column as the aggregate function's arguments
 void AggregateGroupLimitRelParser::prePrejectionForAggregateArguments(DB::QueryPlan & plan)
 {
     auto projection_actions = std::make_shared<DB::ActionsDAG>(input_header->getColumnsWithTypeAndName());
 
-    auto partition_fields = parsePartitionFields(win_rel_def->partition_expressions());
-    auto sort_fields = parseSortFields(win_rel_def->sorts());
     std::set<size_t> unique_partition_fields(partition_fields.begin(), partition_fields.end());
     std::set<size_t> unique_sort_fields(sort_fields.begin(), sort_fields.end());
+
     DB::NameSet required_column_names;
     auto build_tuple = [&](const DB::DataTypes & data_types,
                            const Strings & names,
@@ -296,12 +325,13 @@ void AggregateGroupLimitRelParser::prePrejectionForAggregateArguments(DB::QueryP
     for (size_t i = 0; i < input_header->columns(); ++i)
     {
         const auto & col = input_header->getByPosition(i);
-        if (unique_partition_fields.count(i) && !unique_sort_fields.count(i))
+        if (unique_partition_fields.count(i))
         {
             required_column_names.insert(col.name);
             aggregate_grouping_keys.push_back(col.name);
         }
-        else
+        
+        if (!unique_partition_fields.count(i) || unique_sort_fields.count(i))
         {
             aggregate_data_tuple_types.push_back(col.type);
             aggregate_data_tuple_names.push_back(col.name);
@@ -333,7 +363,15 @@ DB::AggregateDescription AggregateGroupLimitRelParser::buildAggregateDescription
     agg_desc.argument_names = {aggregate_tuple_column_name};
     auto & parameters = agg_desc.parameters;
     parameters.push_back(static_cast<UInt32>(limit));
-    auto sort_directions = buildSQLLikeSortDescription(*input_header, win_rel_def->sorts());
+    std::set<String> sort_field_names;
+    for (auto i : sort_fields)
+        sort_field_names.insert(input_header->getByPosition(i).name);
+    auto full_sort_desc = parseSortFields(*input_header, win_rel_def->sorts());
+    DB::SortDescription sort_desc;
+    for (const auto & sort_column : full_sort_desc)
+        if (sort_field_names.count(sort_column.column_name))
+            sort_desc.push_back(sort_column);
+    auto sort_directions = buildSQLLikeSortDescription(sort_desc);
     parameters.push_back(sort_directions);
 
     const auto & header = *plan.getCurrentHeader();
@@ -348,6 +386,7 @@ DB::AggregateDescription AggregateGroupLimitRelParser::buildAggregateDescription
 void AggregateGroupLimitRelParser::addGroupLmitAggregationStep(DB::QueryPlan & plan)
 {
     const auto & settings = getContext()->getSettingsRef();
+
     DB::AggregateDescriptions agg_descs = {buildAggregateDescription(plan)};
     auto params = AggregatorParamsHelper::buildParams(
         getContext(), aggregate_grouping_keys, agg_descs, AggregatorParamsHelper::Mode::INIT_TO_COMPLETED);
diff --git a/cpp-ch/local-engine/Parser/RelParsers/GroupLimitRelParser.h b/cpp-ch/local-engine/Parser/RelParsers/GroupLimitRelParser.h
index 44159b0190..f0643421cf 100644
--- a/cpp-ch/local-engine/Parser/RelParsers/GroupLimitRelParser.h
+++ b/cpp-ch/local-engine/Parser/RelParsers/GroupLimitRelParser.h
@@ -66,14 +66,17 @@ private:
     String aggregate_function_name;
     size_t limit = 0;
     DB::SharedHeader input_header;
-    // DB::Block output_header;
+    // Field indexes at the input header which are used as partition keys
+    std::vector<size_t> partition_fields;
+    // Field indexes at the input header which are used as sort keys
+    std::vector<size_t> sort_fields;
     DB::Names aggregate_grouping_keys;
     String aggregate_tuple_column_name;
 
     String getAggregateFunctionName(const String & window_function_name);
 
+    void collectPartitionAndSortFields();
     void prePrejectionForAggregateArguments(DB::QueryPlan & plan);
-
     void addGroupLmitAggregationStep(DB::QueryPlan & plan);
     String parseSortDirections(const google::protobuf::RepeatedPtrField<substrait::SortField> & sort_fields);
     DB::AggregateDescription buildAggregateDescription(DB::QueryPlan & plan);
diff --git a/cpp-ch/local-engine/Parser/RelParsers/SortParsingUtils.cpp b/cpp-ch/local-engine/Parser/RelParsers/SortParsingUtils.cpp
index c45849d972..39c722fb09 100644
--- a/cpp-ch/local-engine/Parser/RelParsers/SortParsingUtils.cpp
+++ b/cpp-ch/local-engine/Parser/RelParsers/SortParsingUtils.cpp
@@ -76,32 +76,18 @@ DB::SortDescription parseSortFields(const DB::Block & header, const google::prot
     return sort_descr;
 }
 
-std::string
-buildSQLLikeSortDescription(const DB::Block & header, const google::protobuf::RepeatedPtrField<substrait::SortField> & sort_fields)
+std::string buildSQLLikeSortDescription(const DB::SortDescription & sort_description)
 {
-    static const std::unordered_map<int, std::string> order_directions
-        = {{1, " asc nulls first"}, {2, " asc nulls last"}, {3, " desc nulls first"}, {4, " desc nulls last"}};
-    size_t n = 0;
     DB::WriteBufferFromOwnString ostr;
-    for (const auto & sort_field : sort_fields)
+    size_t n = 0;
+    for (const auto & sort_column : sort_description)
     {
-        auto it = order_directions.find(sort_field.direction());
-        if (it == order_directions.end())
-            throw DB::Exception(DB::ErrorCodes::BAD_ARGUMENTS, "Unknow sort direction: {}", sort_field.direction());
-        auto field_index = SubstraitParserUtils::getStructFieldIndex(sort_field.expr());
-        if (!field_index)
-        {
-            throw DB::Exception(
-                DB::ErrorCodes::BAD_ARGUMENTS, "Sort field must be a column reference. but got {}", sort_field.DebugString());
-        }
-        const auto & col_name = header.getByPosition(*field_index).name;
         if (n)
-            ostr << String(",");
-        // the col_name may contain '#' which can may ch fail to parse.
-        ostr << "`" << col_name << "`" << it->second;
+            ostr << String(", ");
+        const auto & col_name = sort_column.column_name;
+        ostr << "`" << col_name << "` " << (sort_column.direction == 1 ? "ASC" : "DESC") << " NULLS " << (sort_column.nulls_direction != sort_column.direction ? "FIRST" : "LAST");
         n += 1;
     }
-    LOG_DEBUG(getLogger("AggregateGroupLimitRelParser"), "Order by clasue: {}", ostr.str());
     return ostr.str();
 }
 }
diff --git a/cpp-ch/local-engine/Parser/RelParsers/SortParsingUtils.h b/cpp-ch/local-engine/Parser/RelParsers/SortParsingUtils.h
index c460fa758b..4f20675e5a 100644
--- a/cpp-ch/local-engine/Parser/RelParsers/SortParsingUtils.h
+++ b/cpp-ch/local-engine/Parser/RelParsers/SortParsingUtils.h
@@ -28,6 +28,6 @@ DB::SortDescription
 parseSortFields(const DB::Block & header, const google::protobuf::RepeatedPtrField<substrait::Expression> & expressions);
 DB::SortDescription parseSortFields(const DB::Block & header, const google::protobuf::RepeatedPtrField<substrait::SortField> & sort_fields);
 
-std::string
-buildSQLLikeSortDescription(const DB::Block & header, const google::protobuf::RepeatedPtrField<substrait::SortField> & sort_fields);
+
+std::string buildSQLLikeSortDescription(const DB::SortDescription & sort_description);
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to