This is an automated email from the ASF dual-hosted git repository.
lihaopeng pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/doris.git
The following commit(s) were added to refs/heads/master by this push:
new 255d80cf981 [Feature](exec) Support group by limit opt in BE code (#29641)
255d80cf981 is described below
commit 255d80cf981a573fe3c6753c5f2da089f6ba2479
Author: HappenLee <[email protected]>
AuthorDate: Mon Jun 3 20:19:37 2024 +0800
[Feature](exec) Support group by limit opt in BE code (#29641)
## Proposed changes
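Implement the group-by limit optimization in BE: when the aggregation's group-by keys carry an ORDER BY ... LIMIT (agg_sort_info_by_group_key), keep a top-N heap of group keys during aggregation and filter out incoming rows whose keys cannot enter the top N.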
---
be/src/pipeline/dependency.cpp | 76 ++++++
be/src/pipeline/dependency.h | 68 +++++
be/src/pipeline/exec/aggregation_sink_operator.cpp | 288 ++++++++++++++++-----
be/src/pipeline/exec/aggregation_sink_operator.h | 11 +-
.../pipeline/exec/aggregation_source_operator.cpp | 12 +-
be/src/pipeline/exec/aggregation_source_operator.h | 1 +
be/src/pipeline/exec/operator.cpp | 3 +-
be/src/vec/columns/column_nullable.cpp | 1 +
be/src/vec/columns/column_string.cpp | 10 +-
be/src/vec/columns/column_string.h | 1 +
be/src/vec/columns/column_vector.cpp | 12 +-
be/src/vec/core/block.cpp | 13 +
be/src/vec/core/block.h | 3 +-
be/src/vec/exec/vaggregation_node.cpp | 243 ++++++++++++++++-
be/src/vec/exec/vaggregation_node.h | 205 +++++++++------
15 files changed, 773 insertions(+), 174 deletions(-)
diff --git a/be/src/pipeline/dependency.cpp b/be/src/pipeline/dependency.cpp
index 8cf025274af..e7159b2df35 100644
--- a/be/src/pipeline/dependency.cpp
+++ b/be/src/pipeline/dependency.cpp
@@ -196,6 +196,82 @@ LocalExchangeSharedState::LocalExchangeSharedState(int
num_instances) {
mem_trackers.resize(num_instances, nullptr);
}
+vectorized::MutableColumns AggSharedState::_get_keys_hash_table() {
+ return std::visit(
+ vectorized::Overload {
+ [&](std::monostate& arg) {
+ throw doris::Exception(ErrorCode::INTERNAL_ERROR,
"uninited hash table");
+ return vectorized::MutableColumns();
+ },
+ [&](auto&& agg_method) -> vectorized::MutableColumns {
+ vectorized::MutableColumns key_columns;
+ for (int i = 0; i < probe_expr_ctxs.size(); ++i) {
+ key_columns.emplace_back(
+
probe_expr_ctxs[i]->root()->data_type()->create_column());
+ }
+ auto& data = *agg_method.hash_table;
+ bool has_null_key = data.has_null_key_data();
+ const auto size = data.size() - has_null_key;
+ using KeyType =
std::decay_t<decltype(agg_method.iterator->get_first())>;
+ std::vector<KeyType> keys(size);
+
+ size_t num_rows = 0;
+ auto iter = aggregate_data_container->begin();
+ {
+ while (iter != aggregate_data_container->end()) {
+ keys[num_rows] = iter.get_key<KeyType>();
+ ++iter;
+ ++num_rows;
+ }
+ }
+ agg_method.insert_keys_into_columns(keys, key_columns,
num_rows);
+ if (has_null_key) {
+ key_columns[0]->insert_data(nullptr, 0);
+ }
+ return key_columns;
+ }},
+ agg_data->method_variant);
+}
+
+void AggSharedState::build_limit_heap(size_t hash_table_size) {
+ limit_columns = _get_keys_hash_table();
+ for (size_t i = 0; i < hash_table_size; ++i) {
+ limit_heap.emplace(i, limit_columns, order_directions,
null_directions);
+ }
+ while (hash_table_size > limit) {
+ limit_heap.pop();
+ hash_table_size--;
+ }
+ limit_columns_min = limit_heap.top()._row_id;
+}
+
+bool AggSharedState::do_limit_filter(vectorized::Block* block, size_t
num_rows) {
+ if (num_rows) {
+ cmp_res.resize(num_rows);
+ need_computes.resize(num_rows);
+ memset(need_computes.data(), 0, need_computes.size());
+ memset(cmp_res.data(), 0, cmp_res.size());
+
+ const auto key_size = null_directions.size();
+ for (int i = 0; i < key_size; i++) {
+ block->get_by_position(i).column->compare_internal(
+ limit_columns_min, *limit_columns[i], null_directions[i],
order_directions[i],
+ cmp_res, need_computes.data());
+ }
+
+ auto set_computes_arr = [](auto* __restrict res, auto* __restrict
computes, int rows) {
+ for (int i = 0; i < rows; ++i) {
+ computes[i] = computes[i] == res[i];
+ }
+ };
+ set_computes_arr(cmp_res.data(), need_computes.data(), num_rows);
+
+ return std::find(need_computes.begin(), need_computes.end(), 0) !=
need_computes.end();
+ }
+
+ return false;
+}
+
Status AggSharedState::reset_hash_table() {
return std::visit(
vectorized::Overload {
diff --git a/be/src/pipeline/dependency.h b/be/src/pipeline/dependency.h
index d7084f85d5d..e32f5a1c0d6 100644
--- a/be/src/pipeline/dependency.h
+++ b/be/src/pipeline/dependency.h
@@ -311,6 +311,9 @@ public:
Status reset_hash_table();
+ bool do_limit_filter(vectorized::Block* block, size_t num_rows);
+ void build_limit_heap(size_t hash_table_size);
+
// We should call this function only at 1st phase.
// 1st phase: is_merge=true, only have one SlotRef.
// 2nd phase: is_merge=false, maybe have multiple exprs.
@@ -346,8 +349,73 @@ public:
MemoryRecord mem_usage_record;
bool agg_data_created_without_key = false;
bool enable_spill = false;
+ bool reach_limit = false;
+
+ int64_t limit = -1;
+ bool do_sort_limit = false;
+ vectorized::MutableColumns limit_columns;
+ int limit_columns_min = -1;
+ vectorized::PaddedPODArray<uint8_t> need_computes;
+ std::vector<uint8_t> cmp_res;
+ std::vector<int> order_directions;
+ std::vector<int> null_directions;
+
+ struct HeapLimitCursor {
+ HeapLimitCursor(int row_id, vectorized::MutableColumns& limit_columns,
+ std::vector<int>& order_directions, std::vector<int>&
null_directions)
+ : _row_id(row_id),
+ _limit_columns(limit_columns),
+ _order_directions(order_directions),
+ _null_directions(null_directions) {}
+
+ HeapLimitCursor(const HeapLimitCursor& other) noexcept
+ : _row_id(other._row_id),
+ _limit_columns(other._limit_columns),
+ _order_directions(other._order_directions),
+ _null_directions(other._null_directions) {}
+
+ HeapLimitCursor(HeapLimitCursor&& other) noexcept
+ : _row_id(other._row_id),
+ _limit_columns(other._limit_columns),
+ _order_directions(other._order_directions),
+ _null_directions(other._null_directions) {}
+
+ HeapLimitCursor& operator=(const HeapLimitCursor& other) noexcept {
+ _row_id = other._row_id;
+ return *this;
+ }
+
+ HeapLimitCursor& operator=(HeapLimitCursor&& other) noexcept {
+ _row_id = other._row_id;
+ return *this;
+ }
+
+ bool operator<(const HeapLimitCursor& rhs) const {
+ for (int i = 0; i < _limit_columns.size(); ++i) {
+ const auto& _limit_column = _limit_columns[i];
+ auto res = _limit_column->compare_at(_row_id, rhs._row_id,
*_limit_column,
+ _null_directions[i]) *
+ _order_directions[i];
+ if (res < 0) {
+ return true;
+ } else if (res > 0) {
+ return false;
+ }
+ }
+ return false;
+ }
+
+ int _row_id;
+ vectorized::MutableColumns& _limit_columns;
+ std::vector<int>& _order_directions;
+ std::vector<int>& _null_directions;
+ };
+
+ std::priority_queue<HeapLimitCursor> limit_heap;
private:
+ vectorized::MutableColumns _get_keys_hash_table();
+
void _close_with_serialized_key() {
std::visit(vectorized::Overload {[&](std::monostate& arg) -> void {
// Do nothing
diff --git a/be/src/pipeline/exec/aggregation_sink_operator.cpp
b/be/src/pipeline/exec/aggregation_sink_operator.cpp
index 79f5b5af083..a3ac73a5d85 100644
--- a/be/src/pipeline/exec/aggregation_sink_operator.cpp
+++ b/be/src/pipeline/exec/aggregation_sink_operator.cpp
@@ -69,6 +69,7 @@ Status AggSinkLocalState::init(RuntimeState* state,
LocalSinkStateInfo& info) {
_serialize_data_timer = ADD_TIMER(Base::profile(), "SerializeDataTime");
_deserialize_data_timer = ADD_TIMER(Base::profile(),
"DeserializeAndMergeTime");
_hash_table_compute_timer = ADD_TIMER(Base::profile(),
"HashTableComputeTime");
+ _hash_table_limit_compute_timer = ADD_TIMER(Base::profile(),
"DoLimitComputeTime");
_hash_table_emplace_timer = ADD_TIMER(Base::profile(),
"HashTableEmplaceTime");
_hash_table_input_counter = ADD_COUNTER(Base::profile(),
"HashTableInputCount", TUnit::UNIT);
_max_row_size_counter = ADD_COUNTER(Base::profile(), "MaxRowSizeInBytes",
TUnit::UNIT);
@@ -86,6 +87,11 @@ Status AggSinkLocalState::open(RuntimeState* state) {
Base::_shared_state->offsets_of_aggregate_states =
p._offsets_of_aggregate_states;
Base::_shared_state->make_nullable_keys = p._make_nullable_keys;
Base::_shared_state->probe_expr_ctxs.resize(p._probe_expr_ctxs.size());
+
+ Base::_shared_state->limit = p._limit;
+ Base::_shared_state->do_sort_limit = p._do_sort_limit;
+ Base::_shared_state->null_directions = p._null_directions;
+ Base::_shared_state->order_directions = p._order_directions;
for (size_t i = 0; i < Base::_shared_state->probe_expr_ctxs.size(); i++) {
RETURN_IF_ERROR(
p._probe_expr_ctxs[i]->clone(state,
Base::_shared_state->probe_expr_ctxs[i]));
@@ -132,7 +138,6 @@ Status AggSinkLocalState::open(RuntimeState* state) {
_should_limit_output = p._limit != -1 && // has limit
(!p._have_conjuncts) && // no having conjunct
- p._needs_finalize && // agg's finalize step
!Base::_shared_state->enable_spill;
}
for (auto& evaluator : p._aggregate_evaluators) {
@@ -183,7 +188,7 @@ Status
AggSinkLocalState::_execute_without_key(vectorized::Block* block) {
}
Status AggSinkLocalState::_merge_with_serialized_key(vectorized::Block* block)
{
- if (_reach_limit) {
+ if (_shared_state->reach_limit) {
return _merge_with_serialized_key_helper<true, false>(block);
} else {
return _merge_with_serialized_key_helper<false, false>(block);
@@ -260,12 +265,14 @@ Status
AggSinkLocalState::_merge_with_serialized_key_helper(vectorized::Block* b
size_t key_size = Base::_shared_state->probe_expr_ctxs.size();
vectorized::ColumnRawPtrs key_columns(key_size);
+ std::vector<int> key_locs(key_size);
for (size_t i = 0; i < key_size; ++i) {
if constexpr (for_spill) {
key_columns[i] = block->get_by_position(i).column.get();
+ key_locs[i] = i;
} else {
- int result_column_id = -1;
+ int& result_column_id = key_locs[i];
RETURN_IF_ERROR(
Base::_shared_state->probe_expr_ctxs[i]->execute(block,
&result_column_id));
block->replace_by_position_if_const(result_column_id);
@@ -278,7 +285,7 @@ Status
AggSinkLocalState::_merge_with_serialized_key_helper(vectorized::Block* b
_places.resize(rows);
}
- if constexpr (limit) {
+ if (limit && !_shared_state->do_sort_limit) {
_find_in_hash_table(_places.data(), key_columns, rows);
for (int i = 0; i < Base::_shared_state->aggregate_evaluators.size();
++i) {
@@ -318,52 +325,66 @@ Status
AggSinkLocalState::_merge_with_serialized_key_helper(vectorized::Block* b
}
}
} else {
- _emplace_into_hash_table(_places.data(), key_columns, rows);
+ bool need_do_agg = true;
+ if (limit) {
+ need_do_agg = _emplace_into_hash_table_limit(_places.data(),
block, key_locs,
+ key_columns, rows);
+ } else {
+ _emplace_into_hash_table(_places.data(), key_columns, rows);
+ }
- for (int i = 0; i < Base::_shared_state->aggregate_evaluators.size();
++i) {
- if (Base::_shared_state->aggregate_evaluators[i]->is_merge() ||
for_spill) {
- int col_id = 0;
- if constexpr (for_spill) {
- col_id = Base::_shared_state->probe_expr_ctxs.size() + i;
+ if (need_do_agg) {
+ for (int i = 0; i <
Base::_shared_state->aggregate_evaluators.size(); ++i) {
+ if (Base::_shared_state->aggregate_evaluators[i]->is_merge()
|| for_spill) {
+ int col_id = 0;
+ if constexpr (for_spill) {
+ col_id = Base::_shared_state->probe_expr_ctxs.size() +
i;
+ } else {
+ col_id = AggSharedState::get_slot_column_id(
+ Base::_shared_state->aggregate_evaluators[i]);
+ }
+ auto column = block->get_by_position(col_id).column;
+ if (column->is_nullable()) {
+ column = ((vectorized::ColumnNullable*)column.get())
+ ->get_nested_column_ptr();
+ }
+
+ size_t buffer_size =
Base::_shared_state->aggregate_evaluators[i]
+ ->function()
+ ->size_of_data() *
+ rows;
+ if (_deserialize_buffer.size() < buffer_size) {
+ _deserialize_buffer.resize(buffer_size);
+ }
+
+ {
+ SCOPED_TIMER(_deserialize_data_timer);
+ Base::_shared_state->aggregate_evaluators[i]
+ ->function()
+ ->deserialize_and_merge_vec(
+ _places.data(),
+ Base::_parent->template
cast<AggSinkOperatorX>()
+
._offsets_of_aggregate_states[i],
+ _deserialize_buffer.data(),
column.get(), _agg_arena_pool,
+ rows);
+ }
} else {
- col_id = AggSharedState::get_slot_column_id(
- Base::_shared_state->aggregate_evaluators[i]);
- }
- auto column = block->get_by_position(col_id).column;
- if (column->is_nullable()) {
- column =
((vectorized::ColumnNullable*)column.get())->get_nested_column_ptr();
- }
-
- size_t buffer_size =
-
Base::_shared_state->aggregate_evaluators[i]->function()->size_of_data() *
- rows;
- if (_deserialize_buffer.size() < buffer_size) {
- _deserialize_buffer.resize(buffer_size);
- }
-
- {
- SCOPED_TIMER(_deserialize_data_timer);
- Base::_shared_state->aggregate_evaluators[i]
- ->function()
- ->deserialize_and_merge_vec(
- _places.data(),
- Base::_parent->template
cast<AggSinkOperatorX>()
- ._offsets_of_aggregate_states[i],
- _deserialize_buffer.data(), column.get(),
_agg_arena_pool,
- rows);
+
RETURN_IF_ERROR(Base::_shared_state->aggregate_evaluators[i]->execute_batch_add(
+ block,
+ Base::_parent->template cast<AggSinkOperatorX>()
+ ._offsets_of_aggregate_states[i],
+ _places.data(), _agg_arena_pool));
}
- } else {
-
RETURN_IF_ERROR(Base::_shared_state->aggregate_evaluators[i]->execute_batch_add(
- block,
- Base::_parent->template cast<AggSinkOperatorX>()
- ._offsets_of_aggregate_states[i],
- _places.data(), _agg_arena_pool));
}
}
- if (_should_limit_output) {
- _reach_limit = _get_hash_table_size() >=
- Base::_parent->template
cast<AggSinkOperatorX>()._limit;
+ if (!limit && _should_limit_output) {
+ const size_t hash_table_size = _get_hash_table_size();
+ _shared_state->reach_limit =
+ hash_table_size >= Base::_parent->template
cast<AggSinkOperatorX>()._limit;
+ if (_shared_state->do_sort_limit && _shared_state->reach_limit) {
+ _shared_state->build_limit_heap(hash_table_size);
+ }
}
}
@@ -410,7 +431,7 @@ void AggSinkLocalState::_update_memusage_without_key() {
}
Status AggSinkLocalState::_execute_with_serialized_key(vectorized::Block*
block) {
- if (_reach_limit) {
+ if (_shared_state->reach_limit) {
return _execute_with_serialized_key_helper<true>(block);
} else {
return _execute_with_serialized_key_helper<false>(block);
@@ -424,10 +445,11 @@ Status
AggSinkLocalState::_execute_with_serialized_key_helper(vectorized::Block*
size_t key_size = Base::_shared_state->probe_expr_ctxs.size();
vectorized::ColumnRawPtrs key_columns(key_size);
+ std::vector<int> key_locs(key_size);
{
SCOPED_TIMER(_expr_timer);
for (size_t i = 0; i < key_size; ++i) {
- int result_column_id = -1;
+ int& result_column_id = key_locs[i];
RETURN_IF_ERROR(
Base::_shared_state->probe_expr_ctxs[i]->execute(block,
&result_column_id));
block->get_by_position(result_column_id).column =
@@ -442,7 +464,7 @@ Status
AggSinkLocalState::_execute_with_serialized_key_helper(vectorized::Block*
_places.resize(rows);
}
- if constexpr (limit) {
+ if (limit && !_shared_state->do_sort_limit) {
_find_in_hash_table(_places.data(), key_columns, rows);
for (int i = 0; i < Base::_shared_state->aggregate_evaluators.size();
++i) {
@@ -454,27 +476,48 @@ Status
AggSinkLocalState::_execute_with_serialized_key_helper(vectorized::Block*
_places.data(), _agg_arena_pool));
}
} else {
- _emplace_into_hash_table(_places.data(), key_columns, rows);
-
- for (int i = 0; i < Base::_shared_state->aggregate_evaluators.size();
++i) {
-
RETURN_IF_ERROR(Base::_shared_state->aggregate_evaluators[i]->execute_batch_add(
- block,
- Base::_parent->template cast<AggSinkOperatorX>()
- ._offsets_of_aggregate_states[i],
- _places.data(), _agg_arena_pool));
- }
+ auto do_aggregate_evaluators = [&] {
+ for (int i = 0; i <
Base::_shared_state->aggregate_evaluators.size(); ++i) {
+
RETURN_IF_ERROR(Base::_shared_state->aggregate_evaluators[i]->execute_batch_add(
+ block,
+ Base::_parent->template cast<AggSinkOperatorX>()
+ ._offsets_of_aggregate_states[i],
+ _places.data(), _agg_arena_pool));
+ }
+ return Status::OK();
+ };
- if (_should_limit_output) {
- _reach_limit = _get_hash_table_size() >=
- Base::_parent->template
cast<AggSinkOperatorX>()._limit;
- if (_reach_limit &&
- Base::_parent->template
cast<AggSinkOperatorX>()._can_short_circuit) {
- Base::_dependency->set_ready_to_read();
- return Status::Error<ErrorCode::END_OF_FILE>("");
+ if constexpr (limit) {
+ if (_emplace_into_hash_table_limit(_places.data(), block,
key_locs, key_columns,
+ rows)) {
+ RETURN_IF_ERROR(do_aggregate_evaluators());
+ }
+ } else {
+ _emplace_into_hash_table(_places.data(), key_columns, rows);
+ RETURN_IF_ERROR(do_aggregate_evaluators());
+
+ if (_should_limit_output && !Base::_shared_state->enable_spill) {
+ const size_t hash_table_size = _get_hash_table_size();
+ if (Base::_parent->template
cast<AggSinkOperatorX>()._can_short_circuit) {
+ _shared_state->reach_limit =
+ hash_table_size >=
+ Base::_parent->template
cast<AggSinkOperatorX>()._limit;
+ if (_shared_state->reach_limit) {
+ Base::_dependency->set_ready_to_read();
+ return Status::Error<ErrorCode::END_OF_FILE>("");
+ }
+ } else {
+ _shared_state->reach_limit =
+ hash_table_size >= _shared_state->do_sort_limit
+ ? Base::_parent->template
cast<AggSinkOperatorX>()._limit * 5
+ : Base::_parent->template
cast<AggSinkOperatorX>()._limit;
+ if (_shared_state->reach_limit &&
_shared_state->do_sort_limit) {
+ _shared_state->build_limit_heap(hash_table_size);
+ }
+ }
}
}
}
-
return Status::OK();
}
@@ -535,6 +578,108 @@ void
AggSinkLocalState::_emplace_into_hash_table(vectorized::AggregateDataPtr* p
_agg_data->method_variant);
}
+bool
AggSinkLocalState::_emplace_into_hash_table_limit(vectorized::AggregateDataPtr*
places,
+ vectorized::Block*
block,
+ const std::vector<int>&
key_locs,
+
vectorized::ColumnRawPtrs& key_columns,
+ size_t num_rows) {
+ return std::visit(
+ vectorized::Overload {
+ [&](std::monostate& arg) {
+ throw doris::Exception(ErrorCode::INTERNAL_ERROR,
"uninited hash table");
+ return true;
+ },
+ [&](auto&& agg_method) -> bool {
+ SCOPED_TIMER(_hash_table_compute_timer);
+ using HashMethodType =
std::decay_t<decltype(agg_method)>;
+ using AggState = typename HashMethodType::State;
+
+ bool need_filter = false;
+ {
+ SCOPED_TIMER(_hash_table_limit_compute_timer);
+ need_filter =
_shared_state->do_limit_filter(block, num_rows);
+ }
+
+ auto& need_computes = _shared_state->need_computes;
+ if (auto need_agg =
+ std::find(need_computes.begin(),
need_computes.end(), 1);
+ need_agg != need_computes.end()) {
+ if (need_filter) {
+
vectorized::Block::filter_block_internal(block, need_computes);
+ for (int i = 0; i < key_locs.size(); ++i) {
+ key_columns[i] =
+
block->get_by_position(key_locs[i]).column.get();
+ }
+ num_rows = block->rows();
+ }
+
+ AggState state(key_columns);
+ agg_method.init_serialized_keys(key_columns,
num_rows);
+ size_t i = 0;
+
+ auto refresh_top_limit = [&, this]() {
+ _shared_state->limit_heap.pop();
+ for (int j = 0; j < key_columns.size(); ++j) {
+
_shared_state->limit_columns[j]->insert_from(*key_columns[j],
+
i);
+ }
+ _shared_state->limit_heap.emplace(
+
_shared_state->limit_columns[0]->size() - 1,
+ _shared_state->limit_columns,
+ _shared_state->order_directions,
+ _shared_state->null_directions);
+ _shared_state->limit_columns_min =
+
_shared_state->limit_heap.top()._row_id;
+ };
+
+ auto creator = [this, refresh_top_limit](const
auto& ctor, auto& key,
+ auto&
origin) {
+ try {
+
HashMethodType::try_presis_key_and_origin(key, origin,
+
*_agg_arena_pool);
+ auto mapped =
+
_shared_state->aggregate_data_container->append_data(
+ origin);
+ auto st = _create_agg_status(mapped);
+ if (!st) {
+ throw Exception(st.code(),
st.to_string());
+ }
+ ctor(key, mapped);
+ refresh_top_limit();
+ } catch (...) {
+ // Exception-safety - if it can not
allocate memory or create status,
+ // the destructors will not be called.
+ ctor(key, nullptr);
+ throw;
+ }
+ };
+
+ auto creator_for_null_key = [this,
refresh_top_limit](auto& mapped) {
+ mapped = _agg_arena_pool->aligned_alloc(
+ Base::_parent->template
cast<AggSinkOperatorX>()
+
._total_size_of_aggregate_states,
+ Base::_parent->template
cast<AggSinkOperatorX>()
+ ._align_aggregate_states);
+ auto st = _create_agg_status(mapped);
+ if (!st) {
+ throw Exception(st.code(), st.to_string());
+ }
+ refresh_top_limit();
+ };
+
+ SCOPED_TIMER(_hash_table_emplace_timer);
+ for (i = 0; i < num_rows; ++i) {
+ places[i] = agg_method.lazy_emplace(state, i,
creator,
+
creator_for_null_key);
+ }
+ COUNTER_UPDATE(_hash_table_input_counter,
num_rows);
+ return true;
+ }
+ return false;
+ }},
+ _agg_data->method_variant);
+}
+
void AggSinkLocalState::_find_in_hash_table(vectorized::AggregateDataPtr*
places,
vectorized::ColumnRawPtrs&
key_columns,
size_t num_rows) {
@@ -616,6 +761,21 @@ Status AggSinkOperatorX::init(const TPlanNode& tnode,
RuntimeState* state) {
_is_merge = std::any_of(agg_functions.cbegin(), agg_functions.cend(),
[](const auto& e) { return
e.nodes[0].agg_expr.is_merge_agg; });
+ if (tnode.agg_node.__isset.agg_sort_info_by_group_key) {
+ _do_sort_limit = true;
+ const auto& agg_sort_info = tnode.agg_node.agg_sort_info_by_group_key;
+ DCHECK_EQ(agg_sort_info.nulls_first.size(),
agg_sort_info.is_asc_order.size());
+
+ const int order_by_key_size = agg_sort_info.is_asc_order.size();
+ _order_directions.resize(order_by_key_size);
+ _null_directions.resize(order_by_key_size);
+ for (int i = 0; i < order_by_key_size; ++i) {
+ _order_directions[i] = agg_sort_info.is_asc_order[i] ? 1 : -1;
+ _null_directions[i] =
+ agg_sort_info.nulls_first[i] ? -_order_directions[i] :
_order_directions[i];
+ }
+ }
+
return Status::OK();
}
diff --git a/be/src/pipeline/exec/aggregation_sink_operator.h
b/be/src/pipeline/exec/aggregation_sink_operator.h
index d48debc2d83..39fee1707e4 100644
--- a/be/src/pipeline/exec/aggregation_sink_operator.h
+++ b/be/src/pipeline/exec/aggregation_sink_operator.h
@@ -85,6 +85,9 @@ protected:
vectorized::ColumnRawPtrs& key_columns, size_t
num_rows);
void _emplace_into_hash_table(vectorized::AggregateDataPtr* places,
vectorized::ColumnRawPtrs& key_columns,
size_t num_rows);
+ bool _emplace_into_hash_table_limit(vectorized::AggregateDataPtr* places,
+ vectorized::Block* block, const
std::vector<int>& key_locs,
+ vectorized::ColumnRawPtrs&
key_columns, size_t num_rows);
size_t _get_hash_table_size() const;
template <bool limit, bool for_spill = false>
@@ -96,6 +99,7 @@ protected:
RuntimeProfile::Counter* _hash_table_compute_timer = nullptr;
RuntimeProfile::Counter* _hash_table_emplace_timer = nullptr;
+ RuntimeProfile::Counter* _hash_table_limit_compute_timer = nullptr;
RuntimeProfile::Counter* _hash_table_input_counter = nullptr;
RuntimeProfile::Counter* _build_timer = nullptr;
RuntimeProfile::Counter* _expr_timer = nullptr;
@@ -109,7 +113,6 @@ protected:
RuntimeProfile::HighWaterMarkCounter* _serialize_key_arena_memory_usage =
nullptr;
bool _should_limit_output = false;
- bool _reach_limit = false;
vectorized::PODArray<vectorized::AggregateDataPtr> _places;
std::vector<char> _deserialize_buffer;
@@ -191,8 +194,12 @@ protected:
ObjectPool* _pool = nullptr;
std::vector<size_t> _make_nullable_keys;
int64_t _limit; // -1: no limit
- bool _have_conjuncts;
+ // do sort limit and directions
+ bool _do_sort_limit = false;
+ std::vector<int> _order_directions;
+ std::vector<int> _null_directions;
+ bool _have_conjuncts;
const std::vector<TExpr> _partition_exprs;
const bool _is_colocate;
diff --git a/be/src/pipeline/exec/aggregation_source_operator.cpp
b/be/src/pipeline/exec/aggregation_source_operator.cpp
index b94f076bdbf..cca9fefbdb2 100644
--- a/be/src/pipeline/exec/aggregation_source_operator.cpp
+++ b/be/src/pipeline/exec/aggregation_source_operator.cpp
@@ -22,7 +22,6 @@
#include "common/exception.h"
#include "pipeline/exec/operator.h"
-#include "vec//utils/util.hpp"
namespace doris::pipeline {
@@ -444,10 +443,19 @@ Status AggSourceOperatorX::get_block(RuntimeState* state,
vectorized::Block* blo
local_state.make_nullable_output_key(block);
// dispose the having clause, should not be execute in prestreaming agg
RETURN_IF_ERROR(vectorized::VExprContext::filter_block(_conjuncts, block,
block->columns()));
- local_state.reached_limit(block, eos);
+ local_state.do_agg_limit(block, eos);
return Status::OK();
}
+void AggLocalState::do_agg_limit(vectorized::Block* block, bool* eos) {
+ if (_shared_state->reach_limit) {
+ if (_shared_state->do_sort_limit &&
_shared_state->do_limit_filter(block, block->rows())) {
+ vectorized::Block::filter_block_internal(block,
_shared_state->need_computes);
+ }
+ reached_limit(block, eos);
+ }
+}
+
void AggLocalState::make_nullable_output_key(vectorized::Block* block) {
if (block->rows() != 0) {
for (auto cid : _shared_state->make_nullable_keys) {
diff --git a/be/src/pipeline/exec/aggregation_source_operator.h
b/be/src/pipeline/exec/aggregation_source_operator.h
index c4ea6c6ccde..a3824a381eb 100644
--- a/be/src/pipeline/exec/aggregation_source_operator.h
+++ b/be/src/pipeline/exec/aggregation_source_operator.h
@@ -41,6 +41,7 @@ public:
void make_nullable_output_key(vectorized::Block* block);
template <bool limit>
Status merge_with_serialized_key_helper(vectorized::Block* block);
+ void do_agg_limit(vectorized::Block* block, bool* eos);
protected:
friend class AggSourceOperatorX;
diff --git a/be/src/pipeline/exec/operator.cpp
b/be/src/pipeline/exec/operator.cpp
index 938eb22f253..455f11fa9f1 100644
--- a/be/src/pipeline/exec/operator.cpp
+++ b/be/src/pipeline/exec/operator.cpp
@@ -389,8 +389,7 @@ std::shared_ptr<BasicSharedState>
DataSinkOperatorX<LocalStateType>::create_shar
LOG(FATAL) << "should not reach here!";
return nullptr;
} else {
- std::shared_ptr<BasicSharedState> ss = nullptr;
- ss = LocalStateType::SharedStateType::create_shared();
+ auto ss = LocalStateType::SharedStateType::create_shared();
ss->id = operator_id();
for (auto& dest : dests_id()) {
ss->related_op_ids.insert(dest);
diff --git a/be/src/vec/columns/column_nullable.cpp
b/be/src/vec/columns/column_nullable.cpp
index 6efa690d7db..c516b96b72f 100644
--- a/be/src/vec/columns/column_nullable.cpp
+++ b/be/src/vec/columns/column_nullable.cpp
@@ -422,6 +422,7 @@ int ColumnNullable::compare_at(size_t n, size_t m, const
IColumn& rhs_,
return get_nested_column().compare_at(n, m,
nullable_rhs.get_nested_column(),
null_direction_hint);
}
+
void ColumnNullable::compare_internal(size_t rhs_row_id, const IColumn& rhs,
int nan_direction_hint,
int direction, std::vector<uint8>&
cmp_res,
uint8* __restrict filter) const {
diff --git a/be/src/vec/columns/column_string.cpp
b/be/src/vec/columns/column_string.cpp
index 446fd283b1c..919854a42d9 100644
--- a/be/src/vec/columns/column_string.cpp
+++ b/be/src/vec/columns/column_string.cpp
@@ -544,7 +544,7 @@ template <typename T>
void ColumnStr<T>::compare_internal(size_t rhs_row_id, const IColumn& rhs, int
nan_direction_hint,
int direction, std::vector<uint8>& cmp_res,
uint8* __restrict filter) const {
- auto sz = this->size();
+ auto sz = offsets.size();
DCHECK(cmp_res.size() == sz);
const auto& cmp_base = assert_cast<const
ColumnStr<T>&>(rhs).get_data_at(rhs_row_id);
size_t begin = simd::find_zero(cmp_res, 0);
@@ -554,12 +554,8 @@ void ColumnStr<T>::compare_internal(size_t rhs_row_id,
const IColumn& rhs, int n
auto value_a = get_data_at(row_id);
int res = memcmp_small_allow_overflow15(value_a.data,
value_a.size, cmp_base.data,
cmp_base.size);
- if (res * direction < 0) {
- filter[row_id] = 1;
- cmp_res[row_id] = 1;
- } else if (res * direction > 0) {
- cmp_res[row_id] = 1;
- }
+ cmp_res[row_id] = res != 0;
+ filter[row_id] = res * direction < 0;
}
begin = simd::find_zero(cmp_res, end + 1);
}
diff --git a/be/src/vec/columns/column_string.h
b/be/src/vec/columns/column_string.h
index d0994607a46..22dcd612d3a 100644
--- a/be/src/vec/columns/column_string.h
+++ b/be/src/vec/columns/column_string.h
@@ -549,6 +549,7 @@ public:
void compare_internal(size_t rhs_row_id, const IColumn& rhs, int
nan_direction_hint,
int direction, std::vector<uint8>& cmp_res,
uint8* __restrict filter) const override;
+
MutableColumnPtr get_shinked_column() const {
auto shrinked_column = ColumnStr<T>::create();
for (int i = 0; i < size(); i++) {
diff --git a/be/src/vec/columns/column_vector.cpp
b/be/src/vec/columns/column_vector.cpp
index 60a75420405..14d52045943 100644
--- a/be/src/vec/columns/column_vector.cpp
+++ b/be/src/vec/columns/column_vector.cpp
@@ -141,21 +141,17 @@ void ColumnVector<T>::compare_internal(size_t rhs_row_id,
const IColumn& rhs,
int nan_direction_hint, int direction,
std::vector<uint8>& cmp_res,
uint8* __restrict filter) const {
- auto sz = this->size();
+ const auto sz = data.size();
DCHECK(cmp_res.size() == sz);
const auto& cmp_base = assert_cast<const
ColumnVector<T>&>(rhs).get_data()[rhs_row_id];
size_t begin = simd::find_zero(cmp_res, 0);
while (begin < sz) {
size_t end = simd::find_one(cmp_res, begin + 1);
for (size_t row_id = begin; row_id < end; row_id++) {
- auto value_a = get_data()[row_id];
+ auto value_a = data[row_id];
int res = value_a > cmp_base ? 1 : (value_a < cmp_base ? -1 : 0);
- if (res * direction < 0) {
- filter[row_id] = 1;
- cmp_res[row_id] = 1;
- } else if (res * direction > 0) {
- cmp_res[row_id] = 1;
- }
+ cmp_res[row_id] = (res != 0);
+ filter[row_id] = (res * direction < 0);
}
begin = simd::find_zero(cmp_res, end + 1);
}
diff --git a/be/src/vec/core/block.cpp b/be/src/vec/core/block.cpp
index 7595ffb6620..95af060dfc7 100644
--- a/be/src/vec/core/block.cpp
+++ b/be/src/vec/core/block.cpp
@@ -797,6 +797,19 @@ void Block::filter_block_internal(Block* block, const
IColumn::Filter& filter,
filter_block_internal(block, columns_to_filter, filter);
}
+void Block::filter_block_internal(Block* block, const IColumn::Filter& filter)
{
+ const size_t count =
+ filter.size() - simd::count_zero_num((int8_t*)filter.data(),
filter.size());
+ for (int i = 0; i < block->columns(); ++i) {
+ auto& column = block->get_by_position(i).column;
+ if (column->is_exclusive()) {
+ column->assume_mutable()->filter(filter);
+ } else {
+ column = column->filter(filter, count);
+ }
+ }
+}
+
Block Block::copy_block(const std::vector<int>& column_offset) const {
ColumnsWithTypeAndName columns_with_type_and_name;
for (auto offset : column_offset) {
diff --git a/be/src/vec/core/block.h b/be/src/vec/core/block.h
index 593d37f7ff2..3611252ea59 100644
--- a/be/src/vec/core/block.h
+++ b/be/src/vec/core/block.h
@@ -281,10 +281,11 @@ public:
// need exception safety
static void filter_block_internal(Block* block, const
std::vector<uint32_t>& columns_to_filter,
const IColumn::Filter& filter);
-
// need exception safety
static void filter_block_internal(Block* block, const IColumn::Filter&
filter,
uint32_t column_to_keep);
+ // need exception safety
+ static void filter_block_internal(Block* block, const IColumn::Filter&
filter);
static Status filter_block(Block* block, const std::vector<uint32_t>&
columns_to_filter,
int filter_column_id, int column_to_keep);
diff --git a/be/src/vec/exec/vaggregation_node.cpp
b/be/src/vec/exec/vaggregation_node.cpp
index 1845382a2b4..f009802d5dd 100644
--- a/be/src/vec/exec/vaggregation_node.cpp
+++ b/be/src/vec/exec/vaggregation_node.cpp
@@ -158,6 +158,21 @@ Status AggregationNode::init(const TPlanNode& tnode,
RuntimeState* state) {
_is_merge = std::any_of(agg_functions.cbegin(), agg_functions.cend(),
[](const auto& e) { return
e.nodes[0].agg_expr.is_merge_agg; });
+
+ if (tnode.agg_node.__isset.agg_sort_info_by_group_key) {
+ _do_sort_limit = true;
+ const auto& agg_sort_info = tnode.agg_node.agg_sort_info_by_group_key;
+ DCHECK_EQ(agg_sort_info.nulls_first.size(),
agg_sort_info.is_asc_order.size());
+
+ const int order_by_key_size = agg_sort_info.is_asc_order.size();
+ _order_directions.resize(order_by_key_size);
+ _null_directions.resize(order_by_key_size);
+ for (int i = 0; i < order_by_key_size; ++i) {
+ _order_directions[i] = agg_sort_info.is_asc_order[i] ? 1 : -1;
+ _null_directions[i] =
+ agg_sort_info.nulls_first[i] ? -_order_directions[i] :
_order_directions[i];
+ }
+ }
return Status::OK();
}
@@ -183,6 +198,7 @@ Status AggregationNode::prepare_profile(RuntimeState*
state) {
_deserialize_data_timer = ADD_TIMER(runtime_profile(),
"DeserializeAndMergeTime");
_hash_table_compute_timer = ADD_TIMER(runtime_profile(),
"HashTableComputeTime");
_hash_table_emplace_timer = ADD_TIMER(runtime_profile(),
"HashTableEmplaceTime");
+ _hash_table_limit_compute_timer = ADD_TIMER(runtime_profile(),
"DoLimitComputeTime");
_hash_table_iterate_timer = ADD_TIMER(runtime_profile(),
"HashTableIterateTime");
_insert_keys_to_column_timer = ADD_TIMER(runtime_profile(),
"InsertKeysToColumnTime");
_streaming_agg_timer = ADD_TIMER(runtime_profile(), "StreamingAggTime");
@@ -315,9 +331,8 @@ Status AggregationNode::prepare_profile(RuntimeState*
state) {
std::bind<void>(&AggregationNode::_update_memusage_with_serialized_key, this);
_executor.close =
std::bind<void>(&AggregationNode::_close_with_serialized_key, this);
- _should_limit_output = _limit != -1 && // has limit
- _conjuncts.empty() && // no having conjunct
- _needs_finalize; // agg's finalize step
+ _should_limit_output = _limit != -1 && // has limit
+ _conjuncts.empty();
}
fmt::memory_buffer msg;
@@ -436,8 +451,12 @@ Status AggregationNode::pull(doris::RuntimeState* state,
vectorized::Block* bloc
_make_nullable_output_key(block);
// dispose the having clause, should not be execute in prestreaming agg
RETURN_IF_ERROR(VExprContext::filter_block(_conjuncts, block,
block->columns()));
- reached_limit(block, eos);
-
+ if (_reach_limit) {
+ if (_do_sort_limit && _do_limit_filter(block,
_order_directions.size(), block->rows())) {
+ Block::filter_block_internal(block, _need_computes);
+ }
+ reached_limit(block, eos);
+ }
return Status::OK();
}
@@ -775,6 +794,158 @@ size_t AggregationNode::_get_hash_table_size() {
_agg_data->method_variant);
}
+// Update-phase (non-merge) aggregation of one input block with serialized group keys.
+//
+// `limit` is true once _reach_limit has been set:
+//  - with a pushed-down ORDER BY on the group keys (_do_sort_limit), rows are routed through
+//    _emplace_into_hash_table_limit, which only admits keys that can still enter the top-N;
+//  - without it (plain GROUP BY ... LIMIT) the set of groups is frozen and only rows whose key
+//    already exists are aggregated, exactly as before the sort-limit change and as the merge
+//    path does. This branch must not touch _limit_columns / _order_directions /
+//    _null_directions: they are only populated when agg_sort_info_by_group_key is set (init()).
+template <bool limit>
+Status AggregationNode::_execute_with_serialized_key_helper(Block* block) {
+    DCHECK(!_probe_expr_ctxs.empty());
+
+    size_t key_size = _probe_expr_ctxs.size();
+    ColumnRawPtrs key_columns(key_size);
+    // Block positions of the evaluated key columns; needed to re-read them after the
+    // sort-limit filter has shrunk the block.
+    std::vector<int> key_locs(key_size);
+    {
+        SCOPED_TIMER(_expr_timer);
+        for (size_t i = 0; i < key_size; ++i) {
+            auto& result_column_id = key_locs[i];
+            RETURN_IF_ERROR(_probe_expr_ctxs[i]->execute(block, &result_column_id));
+            block->get_by_position(result_column_id).column =
+                    block->get_by_position(result_column_id)
+                            .column->convert_to_full_column_if_const();
+            key_columns[i] = block->get_by_position(result_column_id).column.get();
+        }
+    }
+
+    int rows = block->rows();
+    if (_places.size() < rows) {
+        _places.resize(rows);
+    }
+
+    if constexpr (limit) {
+        if (!_do_sort_limit) {
+            // Plain LIMIT: no new groups are created, only existing ones are updated.
+            _find_in_hash_table(_places.data(), key_columns, rows);
+
+            for (int i = 0; i < _aggregate_evaluators.size(); ++i) {
+                RETURN_IF_ERROR(_aggregate_evaluators[i]->execute_batch_add_selected(
+                        block, _offsets_of_aggregate_states[i], _places.data(),
+                        _agg_arena_pool.get()));
+            }
+        } else if (_emplace_into_hash_table_limit(_places.data(), block, key_locs, key_columns,
+                                                  rows)) {
+            // Sort limit: `block` may have been filtered in place; _places matches its rows.
+            for (int i = 0; i < _aggregate_evaluators.size(); ++i) {
+                RETURN_IF_ERROR(_aggregate_evaluators[i]->execute_batch_add(
+                        block, _offsets_of_aggregate_states[i], _places.data(),
+                        _agg_arena_pool.get()));
+            }
+        }
+    } else {
+        _emplace_into_hash_table(_places.data(), key_columns, rows);
+
+        for (int i = 0; i < _aggregate_evaluators.size(); ++i) {
+            RETURN_IF_ERROR(_aggregate_evaluators[i]->execute_batch_add(
+                    block, _offsets_of_aggregate_states[i], _places.data(),
+                    _agg_arena_pool.get()));
+        }
+
+        if (_should_limit_output && !_reach_limit) {
+            const size_t size = _get_hash_table_size();
+            // Sort limit: keep collecting 5x the limit before building the top-N heap (the
+            // heuristic introduced with this optimization). Plain limit: stop admitting groups
+            // as soon as `_limit` groups exist, matching the merge path.
+            // NOTE(review): with the 5x threshold, pull() skips reached_limit() while the table
+            // holds fewer than 5x `_limit` groups — confirm the planner enforces LIMIT above.
+            _reach_limit = _do_sort_limit ? size >= _limit * 5 : size >= _limit;
+            if (_reach_limit && _do_sort_limit) {
+                _build_limit_heap(size);
+            }
+        }
+    }
+
+    return Status::OK();
+}
+
+// Builds the top-N heap the first time the hash table reaches the sort-limit threshold.
+// All keys currently in the hash table are materialized into _limit_columns; cursor i refers
+// to row i of those columns. HeapLimitCursor::operator< orders rows by the ORDER BY keys, so
+// after trimming to `_limit` entries the top of this max-heap is the worst row still kept; its
+// row id becomes the filter boundary _limit_columns_min.
+void AggregationNode::_build_limit_heap(size_t hash_table_size) {
+    _limit_columns = _get_keys_hash_table();
+    // Start from an empty heap: any earlier cursor would index the replaced _limit_columns.
+    _limit_heap = decltype(_limit_heap)();
+    for (size_t i = 0; i < hash_table_size; ++i) {
+        _limit_heap.emplace(i, _limit_columns, _order_directions, _null_directions);
+    }
+    while (hash_table_size > _limit) {
+        _limit_heap.pop();
+        hash_table_size--;
+    }
+    // priority_queue::top() on an empty heap is undefined behaviour; that happens when nothing
+    // is kept (e.g. LIMIT 0). Leave _limit_columns_min at its previous value in that case.
+    if (!_limit_heap.empty()) {
+        _limit_columns_min = _limit_heap.top()._row_id;
+    }
+}
+
+// Sort-limit counterpart of _emplace_into_hash_table, used once the top-N heap exists.
+//
+// 1. Rows whose group key cannot beat the current heap boundary (_limit_columns_min) are
+//    filtered out of `block` in place; `key_columns` and `num_rows` are refreshed to match.
+// 2. The remaining rows are emplaced into the hash table. Every newly created group (including
+//    the null-key group) is appended to _limit_columns and replaces the worst heap entry.
+//
+// Returns false when there is nothing to aggregate (empty block, or every row filtered out);
+// `places` is then left untouched and the caller must skip aggregation for this block.
+bool AggregationNode::_emplace_into_hash_table_limit(AggregateDataPtr* places, Block* block,
+                                                     const std::vector<int>& key_locs,
+                                                     ColumnRawPtrs& key_columns, size_t num_rows) {
+    return std::visit(
+            Overload {[&](std::monostate& arg) {
+                          throw doris::Exception(ErrorCode::INTERNAL_ERROR, "uninited hash table");
+                          return true;
+                      },
+                      [&](auto&& agg_method) -> bool {
+                          SCOPED_TIMER(_hash_table_compute_timer);
+                          using HashMethodType = std::decay_t<decltype(agg_method)>;
+                          using AggState = typename HashMethodType::State;
+
+                          if (num_rows == 0) {
+                              return false;
+                          }
+
+                          bool need_filter = false;
+                          {
+                              SCOPED_TIMER(_hash_table_limit_compute_timer);
+                              _cmp_res.resize(num_rows);
+                              _need_computes.resize(num_rows);
+                              memset(_need_computes.data(), 0, _need_computes.size());
+                              memset(_cmp_res.data(), 0, _cmp_res.size());
+                              // Compare the evaluated key columns themselves. They sit at
+                              // key_locs[k] in `block`, which need not be position k, so the
+                              // positional lookup of _do_limit_filter cannot be used here.
+                              for (size_t k = 0; k < key_columns.size(); ++k) {
+                                  key_columns[k]->compare_internal(
+                                          _limit_columns_min, *_limit_columns[k],
+                                          _null_directions[k], _order_directions[k], _cmp_res,
+                                          _need_computes.data());
+                              }
+                              for (size_t r = 0; r < num_rows; ++r) {
+                                  _need_computes[r] = _need_computes[r] == _cmp_res[r];
+                              }
+                              need_filter = std::find(_need_computes.begin(), _need_computes.end(),
+                                                      0) != _need_computes.end();
+                          }
+
+                          if (std::find(_need_computes.begin(), _need_computes.end(), 1) ==
+                              _need_computes.end()) {
+                              // No row of this block can enter the current top-N.
+                              return false;
+                          }
+
+                          if (need_filter) {
+                              Block::filter_block_internal(block, _need_computes);
+                              for (size_t k = 0; k < key_locs.size(); ++k) {
+                                  key_columns[k] = block->get_by_position(key_locs[k]).column.get();
+                              }
+                              num_rows = block->rows();
+                          }
+
+                          AggState state(key_columns);
+                          agg_method.init_serialized_keys(key_columns, num_rows);
+                          size_t i = 0;
+
+                          // Appends the key of row `i` to _limit_columns and swaps it into the
+                          // heap in place of the current worst entry.
+                          auto refresh_top_limit = [&, this] {
+                              _limit_heap.pop();
+                              for (size_t j = 0; j < key_columns.size(); ++j) {
+                                  _limit_columns[j]->insert_from(*key_columns[j], i);
+                              }
+                              _limit_heap.emplace(_limit_columns[0]->size() - 1, _limit_columns,
+                                                  _order_directions, _null_directions);
+                              _limit_columns_min = _limit_heap.top()._row_id;
+                          };
+
+                          auto creator = [this, refresh_top_limit](const auto& ctor, auto& key,
+                                                                   auto& origin) {
+                              try {
+                                  HashMethodType::try_presis_key_and_origin(key, origin,
+                                                                            *_agg_arena_pool);
+                                  auto mapped = _aggregate_data_container->append_data(origin);
+                                  auto st = _create_agg_status(mapped);
+                                  if (!st) {
+                                      throw Exception(st.code(), st.to_string());
+                                  }
+                                  ctor(key, mapped);
+                              } catch (...) {
+                                  // Exception-safety - if it can not allocate memory or create
+                                  // status, the destructors will not be called.
+                                  ctor(key, nullptr);
+                                  throw;
+                              }
+                              // Deliberately outside the try: once ctor(key, mapped) succeeded,
+                              // a throwing heap update must not re-construct the slot as nullptr.
+                              refresh_top_limit();
+                          };
+
+                          auto creator_for_null_key = [this, refresh_top_limit](auto& mapped) {
+                              mapped = _agg_arena_pool->aligned_alloc(
+                                      _total_size_of_aggregate_states, _align_aggregate_states);
+                              auto st = _create_agg_status(mapped);
+                              if (!st) {
+                                  throw Exception(st.code(), st.to_string());
+                              }
+                              refresh_top_limit();
+                          };
+
+                          SCOPED_TIMER(_hash_table_emplace_timer);
+                          for (i = 0; i < num_rows; ++i) {
+                              places[i] = agg_method.lazy_emplace(state, i, creator,
+                                                                  creator_for_null_key);
+                          }
+                          COUNTER_UPDATE(_hash_table_input_counter, num_rows);
+                          return true;
+                      }},
+            _agg_data->method_variant);
+}
+
void AggregationNode::_emplace_into_hash_table(AggregateDataPtr* places,
ColumnRawPtrs& key_columns,
const size_t num_rows) {
std::visit(Overload {[&](std::monostate& arg) {
@@ -785,6 +956,7 @@ void
AggregationNode::_emplace_into_hash_table(AggregateDataPtr* places, ColumnR
SCOPED_TIMER(_hash_table_compute_timer);
using HashMethodType =
std::decay_t<decltype(agg_method)>;
using AggState = typename HashMethodType::State;
+
AggState state(key_columns);
agg_method.init_serialized_keys(key_columns,
num_rows);
@@ -850,6 +1022,42 @@ void
AggregationNode::_find_in_hash_table(AggregateDataPtr* places, ColumnRawPtr
_agg_data->method_variant);
}
+// Materializes every group key currently stored in the hash table into freshly created key
+// columns, one per probe expression (typed by that expression's data type). Non-null keys are
+// read by walking _aggregate_data_container, so row order follows that container's iteration
+// order; a null-key entry, if present, is appended last as a null in key_columns[0].
+MutableColumns AggregationNode::_get_keys_hash_table() {
+ return std::visit(
+ Overload {[&](std::monostate& arg) {
+ throw doris::Exception(ErrorCode::INTERNAL_ERROR,
"uninited hash table");
+ return MutableColumns();
+ },
+ [&](auto&& agg_method) -> MutableColumns {
+ MutableColumns key_columns;
+ for (int i = 0; i < _probe_expr_ctxs.size(); ++i) {
+ key_columns.emplace_back(
+
_probe_expr_ctxs[i]->root()->data_type()->create_column());
+ }
+ // The null key is kept outside the regular entries, so it is excluded from `size`.
+ auto& data = *agg_method.hash_table;
+ bool has_null_key = data.has_null_key_data();
+ const auto size = data.size() - has_null_key;
+ using KeyType =
std::decay_t<decltype(agg_method.iterator->get_first())>;
+ std::vector<KeyType> keys(size);
+
+ // NOTE(review): keys[num_rows] is written without a bounds check; this assumes the
+ // data container holds exactly `size` entries — confirm (a DCHECK would document it).
+ size_t num_rows = 0;
+ auto iter = _aggregate_data_container->begin();
+ {
+ while (iter != _aggregate_data_container->end())
{
+ keys[num_rows] = iter.get_key<KeyType>();
+ ++iter;
+ ++num_rows;
+ }
+ }
+ agg_method.insert_keys_into_columns(keys,
key_columns, num_rows);
+ // Append the null key as the last row.
+ if (has_null_key) {
+ key_columns[0]->insert_data(nullptr, 0);
+ }
+ return key_columns;
+ }},
+ _agg_data->method_variant);
+}
+
Status AggregationNode::_pre_agg_with_serialized_key(doris::vectorized::Block*
in_block,
doris::vectorized::Block*
out_block) {
DCHECK(!_probe_expr_ctxs.empty());
@@ -1455,14 +1663,37 @@ Status
AggregationNode::_serialize_with_serialized_key_result_non_spill(RuntimeS
_probe_expr_ctxs[i]->root()->data_type(),
_probe_expr_ctxs[i]->root()->expr_name());
}
+
for (int i = 0; i < agg_size; ++i) {
columns_with_schema.emplace_back(std::move(value_columns[i]),
value_data_types[i], "");
+ *block = Block(columns_with_schema);
}
- *block = Block(columns_with_schema);
}
+
return Status::OK();
}
+// Marks which rows of `block` may still belong to the top-N selected by the sort limit.
+// Columns 0..key_size-1 of `block` are compared, row by row, against row _limit_columns_min of
+// _limit_columns (the current heap boundary), using the per-key null/order directions.
+// On return _need_computes holds one 0/1 entry per row; callers pass it to
+// Block::filter_block_internal as the filter (1 = keep). Returns true when at least one row is
+// dropped, and false when every row is kept or num_rows is 0.
+// NOTE(review): the lookup is positional, so it assumes the key columns occupy the leading
+// positions of `block` (true for the output block in pull()); the exact per-row contract of
+// _cmp_res and the filter argument is defined by IColumn::compare_internal — confirm there.
+bool AggregationNode::_do_limit_filter(Block* block, int key_size, size_t
num_rows) {
+ if (num_rows) {
+ _cmp_res.resize(num_rows);
+ _need_computes.resize(num_rows);
+ memset(_need_computes.data(), 0, _need_computes.size());
+ memset(_cmp_res.data(), 0, _cmp_res.size());
+
+ for (int i = 0; i < key_size; i++) {
+ block->get_by_position(i).column->compare_internal(
+ _limit_columns_min, *_limit_columns[i],
_null_directions[i],
+ _order_directions[i], _cmp_res, _need_computes.data());
+ }
+
+ // Turn the accumulated comparison state into the final 0/1 keep flag per row.
+ for (int i = 0; i < num_rows; ++i) {
+ _need_computes[i] = _need_computes[i] == _cmp_res[i];
+ }
+ return std::find(_need_computes.begin(), _need_computes.end(), 0) !=
_need_computes.end();
+ }
+ return false;
+}
+
Status AggregationNode::_merge_with_serialized_key(Block* block) {
if (_reach_limit) {
return _merge_with_serialized_key_helper<true, false>(block);
diff --git a/be/src/vec/exec/vaggregation_node.h
b/be/src/vec/exec/vaggregation_node.h
index 70222f9ccf7..cd7aedebead 100644
--- a/be/src/vec/exec/vaggregation_node.h
+++ b/be/src/vec/exec/vaggregation_node.h
@@ -438,6 +438,7 @@ protected:
// nullable diff. so we need make nullable of it.
std::vector<size_t> _make_nullable_keys;
RuntimeProfile::Counter* _hash_table_compute_timer = nullptr;
+ RuntimeProfile::Counter* _hash_table_limit_compute_timer = nullptr;
RuntimeProfile::Counter* _hash_table_emplace_timer = nullptr;
RuntimeProfile::Counter* _hash_table_input_counter = nullptr;
RuntimeProfile::Counter* _expr_timer = nullptr;
@@ -491,8 +492,70 @@ private:
RuntimeProfile::HighWaterMarkCounter* _serialize_key_arena_memory_usage =
nullptr;
bool _should_expand_hash_table = true;
+
bool _should_limit_output = false;
bool _reach_limit = false;
+ bool _do_sort_limit = false;
+ MutableColumns _limit_columns;
+ int _limit_columns_min = -1;
+ PaddedPODArray<uint8_t> _need_computes;
+ std::vector<uint8_t> _cmp_res;
+ std::vector<int> _order_directions;
+ std::vector<int> _null_directions;
+
+    // Element of _limit_heap for the group-by sort-limit optimization: a row id into the
+    // shared _limit_columns plus references to the shared order/null direction vectors.
+    // Every cursor is constructed from the same AggregationNode members, so copy/move
+    // assignment only needs to copy _row_id (reference members cannot be rebound).
+ struct HeapLimitCursor {
+ HeapLimitCursor(int row_id, MutableColumns& limit_columns,
+ std::vector<int>& order_directions, std::vector<int>&
null_directions)
+ : _row_id(row_id),
+ _limit_columns(limit_columns),
+ _order_directions(order_directions),
+ _null_directions(null_directions) {}
+
+ HeapLimitCursor(const HeapLimitCursor& other) noexcept
+ : _row_id(other._row_id),
+ _limit_columns(other._limit_columns),
+ _order_directions(other._order_directions),
+ _null_directions(other._null_directions) {}
+
+ HeapLimitCursor(HeapLimitCursor&& other) noexcept
+ : _row_id(other._row_id),
+ _limit_columns(other._limit_columns),
+ _order_directions(other._order_directions),
+ _null_directions(other._null_directions) {}
+
+        // Only the row id is reassigned; the references already point at the shared members.
+ HeapLimitCursor& operator=(const HeapLimitCursor& other) noexcept {
+ _row_id = other._row_id;
+ return *this;
+ }
+
+ HeapLimitCursor& operator=(HeapLimitCursor&& other) noexcept {
+ _row_id = other._row_id;
+ return *this;
+ }
+
+        // Lexicographic comparison over all key columns. Each column's compare_at result is
+        // multiplied by its order direction (+1 ascending, -1 descending, as set in init()), so
+        // a row is "less" when it sorts earlier in the requested order. std::priority_queue is
+        // a max-heap, so its top is the row that sorts last among those kept.
+ bool operator<(const HeapLimitCursor& rhs) const {
+ for (int i = 0; i < _limit_columns.size(); ++i) {
+ const auto& _limit_column = _limit_columns[i];
+ auto res = _limit_column->compare_at(_row_id, rhs._row_id,
*_limit_column,
+ _null_directions[i]) *
+ _order_directions[i];
+ if (res < 0) {
+ return true;
+ } else if (res > 0) {
+ return false;
+ }
+ }
+ return false;
+ }
+
+        // Row index into every column of _limit_columns.
+ int _row_id;
+ MutableColumns& _limit_columns;
+ std::vector<int>& _order_directions;
+ std::vector<int>& _null_directions;
+ };
+
+ std::priority_queue<HeapLimitCursor> _limit_heap;
+
bool _agg_data_created_without_key = false;
PODArray<AggregateDataPtr> _places;
@@ -535,52 +598,7 @@ private:
Status _init_hash_method(const VExprContextSPtrs& probe_exprs);
template <bool limit>
- Status _execute_with_serialized_key_helper(Block* block) {
- DCHECK(!_probe_expr_ctxs.empty());
-
- size_t key_size = _probe_expr_ctxs.size();
- ColumnRawPtrs key_columns(key_size);
- {
- SCOPED_TIMER(_expr_timer);
- for (size_t i = 0; i < key_size; ++i) {
- int result_column_id = -1;
- RETURN_IF_ERROR(_probe_expr_ctxs[i]->execute(block,
&result_column_id));
- block->get_by_position(result_column_id).column =
- block->get_by_position(result_column_id)
- .column->convert_to_full_column_if_const();
- key_columns[i] =
block->get_by_position(result_column_id).column.get();
- }
- }
-
- int rows = block->rows();
- if (_places.size() < rows) {
- _places.resize(rows);
- }
-
- if constexpr (limit) {
- _find_in_hash_table(_places.data(), key_columns, rows);
-
- for (int i = 0; i < _aggregate_evaluators.size(); ++i) {
-
RETURN_IF_ERROR(_aggregate_evaluators[i]->execute_batch_add_selected(
- block, _offsets_of_aggregate_states[i], _places.data(),
- _agg_arena_pool.get()));
- }
- } else {
- _emplace_into_hash_table(_places.data(), key_columns, rows);
-
- for (int i = 0; i < _aggregate_evaluators.size(); ++i) {
- RETURN_IF_ERROR(_aggregate_evaluators[i]->execute_batch_add(
- block, _offsets_of_aggregate_states[i], _places.data(),
- _agg_arena_pool.get()));
- }
-
- if (_should_limit_output) {
- _reach_limit = _get_hash_table_size() >= _limit;
- }
- }
-
- return Status::OK();
- }
+ Status _execute_with_serialized_key_helper(Block* block);
// We should call this function only at 1st phase.
// 1st phase: is_merge=true, only have one SlotRef.
@@ -599,15 +617,16 @@ private:
size_t key_size = _probe_expr_ctxs.size();
ColumnRawPtrs key_columns(key_size);
+ std::vector<int> key_locs(key_size);
for (size_t i = 0; i < key_size; ++i) {
if constexpr (for_spill) {
key_columns[i] = block->get_by_position(i).column.get();
+ key_locs[i] = i;
} else {
- int result_column_id = -1;
- RETURN_IF_ERROR(_probe_expr_ctxs[i]->execute(block,
&result_column_id));
- block->replace_by_position_if_const(result_column_id);
- key_columns[i] =
block->get_by_position(result_column_id).column.get();
+ RETURN_IF_ERROR(_probe_expr_ctxs[i]->execute(block,
&key_locs[i]));
+ block->replace_by_position_if_const(key_locs[i]);
+ key_columns[i] =
block->get_by_position(key_locs[i]).column.get();
}
}
@@ -616,7 +635,7 @@ private:
_places.resize(rows);
}
- if constexpr (limit) {
+ if (limit && !_do_sort_limit) {
_find_in_hash_table(_places.data(), key_columns, rows);
for (int i = 0; i < _aggregate_evaluators.size(); ++i) {
@@ -647,43 +666,55 @@ private:
}
}
} else {
- _emplace_into_hash_table(_places.data(), key_columns, rows);
+ bool need_do_agg = true;
+ if (limit) {
+ need_do_agg = _emplace_into_hash_table_limit(_places.data(),
block, key_locs,
+ key_columns,
rows);
+ } else {
+ _emplace_into_hash_table(_places.data(), key_columns, rows);
+ }
- for (int i = 0; i < _aggregate_evaluators.size(); ++i) {
- if (_aggregate_evaluators[i]->is_merge() || for_spill) {
- int col_id;
- if constexpr (for_spill) {
- col_id = _probe_expr_ctxs.size() + i;
+ if (need_do_agg) {
+ for (int i = 0; i < _aggregate_evaluators.size(); ++i) {
+ if (_aggregate_evaluators[i]->is_merge() || for_spill) {
+ int col_id;
+ if constexpr (for_spill) {
+ col_id = _probe_expr_ctxs.size() + i;
+ } else {
+ col_id =
_get_slot_column_id(_aggregate_evaluators[i]);
+ }
+ auto column = block->get_by_position(col_id).column;
+ if (column->is_nullable()) {
+ column =
((ColumnNullable*)column.get())->get_nested_column_ptr();
+ }
+
+ size_t buffer_size =
+
_aggregate_evaluators[i]->function()->size_of_data() * rows;
+ if (_deserialize_buffer.size() < buffer_size) {
+ _deserialize_buffer.resize(buffer_size);
+ }
+
+ {
+ SCOPED_TIMER(_deserialize_data_timer);
+
_aggregate_evaluators[i]->function()->deserialize_and_merge_vec(
+ _places.data(),
_offsets_of_aggregate_states[i],
+ _deserialize_buffer.data(), column.get(),
_agg_arena_pool.get(),
+ rows);
+ }
} else {
- col_id = _get_slot_column_id(_aggregate_evaluators[i]);
- }
- auto column = block->get_by_position(col_id).column;
- if (column->is_nullable()) {
- column =
((ColumnNullable*)column.get())->get_nested_column_ptr();
- }
-
- size_t buffer_size =
-
_aggregate_evaluators[i]->function()->size_of_data() * rows;
- if (_deserialize_buffer.size() < buffer_size) {
- _deserialize_buffer.resize(buffer_size);
- }
-
- {
- SCOPED_TIMER(_deserialize_data_timer);
-
_aggregate_evaluators[i]->function()->deserialize_and_merge_vec(
- _places.data(),
_offsets_of_aggregate_states[i],
- _deserialize_buffer.data(), column.get(),
_agg_arena_pool.get(),
- rows);
+
RETURN_IF_ERROR(_aggregate_evaluators[i]->execute_batch_add(
+ block, _offsets_of_aggregate_states[i],
_places.data(),
+ _agg_arena_pool.get()));
}
- } else {
-
RETURN_IF_ERROR(_aggregate_evaluators[i]->execute_batch_add(
- block, _offsets_of_aggregate_states[i],
_places.data(),
- _agg_arena_pool.get()));
}
}
- if (_should_limit_output) {
- _reach_limit = _get_hash_table_size() >= _limit;
+ if (!limit && _should_limit_output) {
+ const size_t hash_table_size = _get_hash_table_size();
+ _reach_limit = hash_table_size >= _limit;
+ if (_do_sort_limit && _reach_limit) {
+ _build_limit_heap(hash_table_size);
+ }
}
}
@@ -693,6 +724,10 @@ private:
void _emplace_into_hash_table(AggregateDataPtr* places, ColumnRawPtrs&
key_columns,
const size_t num_rows);
+ bool _emplace_into_hash_table_limit(AggregateDataPtr* places, Block* block,
+ const std::vector<int>& key_locs,
+ ColumnRawPtrs& key_columns, size_t
num_rows);
+
size_t _memory_usage() const;
Status _reset_hash_table();
@@ -736,6 +771,12 @@ private:
};
MemoryRecord _mem_usage_record;
+
+ MutableColumns _get_keys_hash_table();
+
+ bool _do_limit_filter(Block* block, int key_size, size_t num_rows);
+
+ void _build_limit_heap(size_t hash_table_size);
};
} // namespace vectorized
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]