This is an automated email from the ASF dual-hosted git repository.
lgbo pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git
The following commit(s) were added to refs/heads/main by this push:
new a1edfafcd4 fix #10741: sort keys contains partition keys for getting
window topk (#10746)
a1edfafcd4 is described below
commit a1edfafcd4025440caef8bba5a0d5a1c432c2480
Author: lgbo <[email protected]>
AuthorDate: Mon Sep 22 17:00:42 2025 +0800
fix #10741: sort keys contains partition keys for getting window topk
(#10746)
---
.../GlutenClickHouseTPCHSaltNullParquetSuite.scala | 11 +++++
cpp-ch/local-engine/Common/GlutenConfig.h | 4 +-
.../Parser/RelParsers/GroupLimitRelParser.cpp | 49 +++++++++++++++++++---
.../Parser/RelParsers/GroupLimitRelParser.h | 7 +++-
.../Parser/RelParsers/SortParsingUtils.cpp | 26 +++---------
.../Parser/RelParsers/SortParsingUtils.h | 4 +-
6 files changed, 70 insertions(+), 31 deletions(-)
diff --git
a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/tpch/GlutenClickHouseTPCHSaltNullParquetSuite.scala
b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/tpch/GlutenClickHouseTPCHSaltNullParquetSuite.scala
index 341a10cb94..c8d6da2b66 100644
---
a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/tpch/GlutenClickHouseTPCHSaltNullParquetSuite.scala
+++
b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/tpch/GlutenClickHouseTPCHSaltNullParquetSuite.scala
@@ -3012,6 +3012,17 @@ class GlutenClickHouseTPCHSaltNullParquetSuite
compareResult = true,
checkWindowGroupLimit
)
+
+ compareResultsAgainstVanillaSpark(
+ """
+ |select * from(
+ |select a, b, c, row_number() over (partition by a order by b, c, a)
as r
+ |from test_win_top)
+ |where r <= 1
+ |""".stripMargin,
+ compareResult = true,
+ checkWindowGroupLimit
+ )
spark.sql("drop table if exists test_win_top")
}
diff --git a/cpp-ch/local-engine/Common/GlutenConfig.h
b/cpp-ch/local-engine/Common/GlutenConfig.h
index 7b2a33a9b6..0f68151c03 100644
--- a/cpp-ch/local-engine/Common/GlutenConfig.h
+++ b/cpp-ch/local-engine/Common/GlutenConfig.h
@@ -164,8 +164,8 @@ struct WindowConfig
public:
inline static const String WINDOW_AGGREGATE_TOPK_SAMPLE_ROWS =
"window.aggregate_topk_sample_rows";
inline static const String
WINDOW_AGGREGATE_TOPK_HIGH_CARDINALITY_THRESHOLD =
"window.aggregate_topk_high_cardinality_threshold";
- size_t aggregate_topk_sample_rows = 5000;
- double aggregate_topk_high_cardinality_threshold = 0.6;
+ size_t aggregate_topk_sample_rows = 50000;
+ double aggregate_topk_high_cardinality_threshold = 0.4;
static WindowConfig loadFromContext(const DB::ContextPtr & context);
};
diff --git a/cpp-ch/local-engine/Parser/RelParsers/GroupLimitRelParser.cpp
b/cpp-ch/local-engine/Parser/RelParsers/GroupLimitRelParser.cpp
index 01c7784460..7dcdbe6430 100644
--- a/cpp-ch/local-engine/Parser/RelParsers/GroupLimitRelParser.cpp
+++ b/cpp-ch/local-engine/Parser/RelParsers/GroupLimitRelParser.cpp
@@ -16,6 +16,7 @@
*/
#include "GroupLimitRelParser.h"
+#include <algorithm>
#include <memory>
#include <unordered_set>
#include <utility>
@@ -46,12 +47,14 @@
#include <QueryPipeline/QueryPipelineBuilder.h>
#include <google/protobuf/repeated_field.h>
#include <google/protobuf/wrappers.pb.h>
+#include "Common/Logger.h"
#include <Common/AggregateUtil.h>
#include <Common/ArrayJoinHelper.h>
#include <Common/GlutenConfig.h>
#include <Common/PlanUtil.h>
#include <Common/QueryContext.h>
#include <Common/logger_useful.h>
+#include "cctz/civil_time_detail.h"
namespace DB::ErrorCodes
{
@@ -226,6 +229,7 @@ DB::QueryPlanPtr AggregateGroupLimitRelParser::parse(
// If all partition keys are low cardinality keys, use aggregattion to get
topk of each partition
auto aggregation_plan = BranchStepHelper::createSubPlan(branch_in_header,
1);
+ collectPartitionAndSortFields();
prePrejectionForAggregateArguments(*aggregation_plan);
addGroupLmitAggregationStep(*aggregation_plan);
postProjectionForExplodingArrays(*aggregation_plan);
@@ -262,15 +266,40 @@ String
AggregateGroupLimitRelParser::getAggregateFunctionName(const String & win
throw DB::Exception(DB::ErrorCodes::BAD_ARGUMENTS, "Unsupported window
function: {}", window_function_name);
}
+void AggregateGroupLimitRelParser::collectPartitionAndSortFields()
+{
+ partition_fields =
parsePartitionFields(win_rel_def->partition_expressions());
+ auto full_sort_fields = parseSortFields(win_rel_def->sorts());
+
+ std::set<size_t> partition_fields_set(partition_fields.begin(),
partition_fields.end());
+ std::set<size_t> full_sort_fields_set(full_sort_fields.begin(),
full_sort_fields.end());
+ std::set<size_t> selected_sort_fields_set;
+ // Remove partition keys from sort keys
+ std::set_difference(
+ full_sort_fields_set.begin(),
+ full_sort_fields_set.end(),
+ partition_fields_set.begin(),
+ partition_fields_set.end(),
+ std::inserter(selected_sort_fields_set,
selected_sort_fields_set.begin()));
+ if (selected_sort_fields_set.empty())
+ {
+ // FIXME: support empty sort keys.
+ sort_fields.push_back(*partition_fields_set.begin());
+ }
+ else
+ {
+ sort_fields = std::vector<size_t>(selected_sort_fields_set.begin(),
selected_sort_fields_set.end());
+ }
+}
+
// Build one tuple column as the aggregate function's arguments
void
AggregateGroupLimitRelParser::prePrejectionForAggregateArguments(DB::QueryPlan
& plan)
{
auto projection_actions =
std::make_shared<DB::ActionsDAG>(input_header->getColumnsWithTypeAndName());
- auto partition_fields =
parsePartitionFields(win_rel_def->partition_expressions());
- auto sort_fields = parseSortFields(win_rel_def->sorts());
std::set<size_t> unique_partition_fields(partition_fields.begin(),
partition_fields.end());
std::set<size_t> unique_sort_fields(sort_fields.begin(),
sort_fields.end());
+
DB::NameSet required_column_names;
auto build_tuple = [&](const DB::DataTypes & data_types,
const Strings & names,
@@ -296,12 +325,13 @@ void
AggregateGroupLimitRelParser::prePrejectionForAggregateArguments(DB::QueryP
for (size_t i = 0; i < input_header->columns(); ++i)
{
const auto & col = input_header->getByPosition(i);
- if (unique_partition_fields.count(i) && !unique_sort_fields.count(i))
+ if (unique_partition_fields.count(i))
{
required_column_names.insert(col.name);
aggregate_grouping_keys.push_back(col.name);
}
- else
+
+ if (!unique_partition_fields.count(i) || unique_sort_fields.count(i))
{
aggregate_data_tuple_types.push_back(col.type);
aggregate_data_tuple_names.push_back(col.name);
@@ -333,7 +363,15 @@ DB::AggregateDescription
AggregateGroupLimitRelParser::buildAggregateDescription
agg_desc.argument_names = {aggregate_tuple_column_name};
auto & parameters = agg_desc.parameters;
parameters.push_back(static_cast<UInt32>(limit));
- auto sort_directions = buildSQLLikeSortDescription(*input_header,
win_rel_def->sorts());
+ std::set<String> sort_field_names;
+ for (auto i : sort_fields)
+ sort_field_names.insert(input_header->getByPosition(i).name);
+ auto full_sort_desc = parseSortFields(*input_header, win_rel_def->sorts());
+ DB::SortDescription sort_desc;
+ for (const auto & sort_column : full_sort_desc)
+ if (sort_field_names.count(sort_column.column_name))
+ sort_desc.push_back(sort_column);
+ auto sort_directions = buildSQLLikeSortDescription(sort_desc);
parameters.push_back(sort_directions);
const auto & header = *plan.getCurrentHeader();
@@ -348,6 +386,7 @@ DB::AggregateDescription
AggregateGroupLimitRelParser::buildAggregateDescription
void AggregateGroupLimitRelParser::addGroupLmitAggregationStep(DB::QueryPlan &
plan)
{
const auto & settings = getContext()->getSettingsRef();
+
DB::AggregateDescriptions agg_descs = {buildAggregateDescription(plan)};
auto params = AggregatorParamsHelper::buildParams(
getContext(), aggregate_grouping_keys, agg_descs,
AggregatorParamsHelper::Mode::INIT_TO_COMPLETED);
diff --git a/cpp-ch/local-engine/Parser/RelParsers/GroupLimitRelParser.h
b/cpp-ch/local-engine/Parser/RelParsers/GroupLimitRelParser.h
index 44159b0190..f0643421cf 100644
--- a/cpp-ch/local-engine/Parser/RelParsers/GroupLimitRelParser.h
+++ b/cpp-ch/local-engine/Parser/RelParsers/GroupLimitRelParser.h
@@ -66,14 +66,17 @@ private:
String aggregate_function_name;
size_t limit = 0;
DB::SharedHeader input_header;
- // DB::Block output_header;
+ // Field indexes at the input header which are used as partition keys
+ std::vector<size_t> partition_fields;
+ // Field indexes at the input header which are used as sort keys
+ std::vector<size_t> sort_fields;
DB::Names aggregate_grouping_keys;
String aggregate_tuple_column_name;
String getAggregateFunctionName(const String & window_function_name);
+ void collectPartitionAndSortFields();
void prePrejectionForAggregateArguments(DB::QueryPlan & plan);
-
void addGroupLmitAggregationStep(DB::QueryPlan & plan);
String parseSortDirections(const
google::protobuf::RepeatedPtrField<substrait::SortField> & sort_fields);
DB::AggregateDescription buildAggregateDescription(DB::QueryPlan & plan);
diff --git a/cpp-ch/local-engine/Parser/RelParsers/SortParsingUtils.cpp
b/cpp-ch/local-engine/Parser/RelParsers/SortParsingUtils.cpp
index c45849d972..39c722fb09 100644
--- a/cpp-ch/local-engine/Parser/RelParsers/SortParsingUtils.cpp
+++ b/cpp-ch/local-engine/Parser/RelParsers/SortParsingUtils.cpp
@@ -76,32 +76,18 @@ DB::SortDescription parseSortFields(const DB::Block &
header, const google::prot
return sort_descr;
}
-std::string
-buildSQLLikeSortDescription(const DB::Block & header, const
google::protobuf::RepeatedPtrField<substrait::SortField> & sort_fields)
+std::string buildSQLLikeSortDescription(const DB::SortDescription &
sort_description)
{
- static const std::unordered_map<int, std::string> order_directions
- = {{1, " asc nulls first"}, {2, " asc nulls last"}, {3, " desc nulls
first"}, {4, " desc nulls last"}};
- size_t n = 0;
DB::WriteBufferFromOwnString ostr;
- for (const auto & sort_field : sort_fields)
+ size_t n = 0;
+ for (const auto & sort_column : sort_description)
{
- auto it = order_directions.find(sort_field.direction());
- if (it == order_directions.end())
- throw DB::Exception(DB::ErrorCodes::BAD_ARGUMENTS, "Unknow sort
direction: {}", sort_field.direction());
- auto field_index =
SubstraitParserUtils::getStructFieldIndex(sort_field.expr());
- if (!field_index)
- {
- throw DB::Exception(
- DB::ErrorCodes::BAD_ARGUMENTS, "Sort field must be a column
reference. but got {}", sort_field.DebugString());
- }
- const auto & col_name = header.getByPosition(*field_index).name;
if (n)
- ostr << String(",");
- // the col_name may contain '#' which can may ch fail to parse.
- ostr << "`" << col_name << "`" << it->second;
+ ostr << String(", ");
+ const auto & col_name = sort_column.column_name;
+ ostr << "`" << col_name << "` " << (sort_column.direction == 1 ? "ASC"
: "DESC") << " NULLS " << (sort_column.nulls_direction != sort_column.direction
? "FIRST" : "LAST");
n += 1;
}
- LOG_DEBUG(getLogger("AggregateGroupLimitRelParser"), "Order by clasue:
{}", ostr.str());
return ostr.str();
}
}
diff --git a/cpp-ch/local-engine/Parser/RelParsers/SortParsingUtils.h
b/cpp-ch/local-engine/Parser/RelParsers/SortParsingUtils.h
index c460fa758b..4f20675e5a 100644
--- a/cpp-ch/local-engine/Parser/RelParsers/SortParsingUtils.h
+++ b/cpp-ch/local-engine/Parser/RelParsers/SortParsingUtils.h
@@ -28,6 +28,6 @@ DB::SortDescription
parseSortFields(const DB::Block & header, const
google::protobuf::RepeatedPtrField<substrait::Expression> & expressions);
DB::SortDescription parseSortFields(const DB::Block & header, const
google::protobuf::RepeatedPtrField<substrait::SortField> & sort_fields);
-std::string
-buildSQLLikeSortDescription(const DB::Block & header, const
google::protobuf::RepeatedPtrField<substrait::SortField> & sort_fields);
+
+std::string buildSQLLikeSortDescription(const DB::SortDescription &
sort_description);
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]