This is an automated email from the ASF dual-hosted git repository.
panxiaolei pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/doris.git
The following commit(s) were added to refs/heads/master by this push:
new 99c5a805d08 [Feature](topn) support general expr with topn filter and
some refactor (#35405)
99c5a805d08 is described below
commit 99c5a805d081304bcd3666d5b8bb46a77a45318b
Author: Pxl <[email protected]>
AuthorDate: Mon May 27 11:07:46 2024 +0800
[Feature](topn) support general expr with topn filter and some refactor
(#35405)
support general expr with topn filter and some refactor
Co-authored-by: minghong <[email protected]>
---
be/src/olap/iterators.h | 3 +-
be/src/olap/rowset/beta_rowset_reader.cpp | 2 +-
be/src/olap/rowset/rowset_reader_context.h | 3 +-
be/src/olap/rowset/segment_v2/segment.cpp | 9 +-
be/src/olap/rowset/segment_v2/segment_iterator.cpp | 25 +--
be/src/olap/tablet_reader.cpp | 15 +-
be/src/olap/tablet_reader.h | 3 +-
be/src/pipeline/exec/file_scan_operator.cpp | 2 +-
be/src/pipeline/exec/olap_scan_operator.h | 4 +-
be/src/pipeline/exec/scan_operator.cpp | 39 ++--
be/src/pipeline/exec/scan_operator.h | 14 +-
be/src/pipeline/exec/sort_sink_operator.cpp | 33 +--
be/src/pipeline/exec/sort_sink_operator.h | 1 -
be/src/pipeline/pipeline_fragment_context.cpp | 6 +-
be/src/runtime/query_context.h | 17 +-
be/src/runtime/runtime_predicate.cpp | 125 ++++++++---
be/src/runtime/runtime_predicate.h | 137 ++++++------
be/src/runtime/runtime_state.cpp | 8 -
be/src/vec/exec/format/orc/vorc_reader.cpp | 1 -
.../exec/format/parquet/vparquet_group_reader.cpp | 1 -
be/src/vec/exec/scan/new_olap_scanner.cpp | 6 +-
be/src/vec/exec/vsort_node.cpp | 40 +---
be/src/vec/exec/vsort_node.h | 1 -
be/src/vec/exprs/vtopn_pred.h | 20 +-
be/src/vec/olap/vcollect_iterator.cpp | 2 +-
.../org/apache/doris/nereids/NereidsPlanner.java | 6 +
.../glue/translator/PhysicalPlanTranslator.java | 52 +----
.../doris/nereids/processor/post/TopNScanOpt.java | 110 ++++------
.../nereids/processor/post/TopnFilterContext.java | 84 ++++----
.../processor/post/TopnFilterPushDownVisitor.java | 233 +++++++++++++++++++++
.../nereids/trees/plans/physical/TopnFilter.java | 78 +++++++
.../java/org/apache/doris/planner/Planner.java | 5 +
.../main/java/org/apache/doris/qe/Coordinator.java | 22 +-
gensrc/thrift/Exprs.thrift | 2 +-
gensrc/thrift/PlanNodes.thrift | 8 +-
.../suites/nereids_tpch_p0/tpch/topn-filter.groovy | 8 +-
36 files changed, 695 insertions(+), 430 deletions(-)
diff --git a/be/src/olap/iterators.h b/be/src/olap/iterators.h
index deb14ff554f..330aa9e3475 100644
--- a/be/src/olap/iterators.h
+++ b/be/src/olap/iterators.h
@@ -99,9 +99,8 @@ public:
TabletSchemaSPtr tablet_schema = nullptr;
bool enable_unique_key_merge_on_write = false;
bool record_rowids = false;
- // flag for enable topn opt
- bool use_topn_opt = false;
std::vector<int> topn_filter_source_node_ids;
+ int topn_filter_target_node_id = -1;
// used for special optimization for query : ORDER BY key DESC LIMIT n
bool read_orderby_key_reverse = false;
// columns for orderby keys
diff --git a/be/src/olap/rowset/beta_rowset_reader.cpp
b/be/src/olap/rowset/beta_rowset_reader.cpp
index d8cced3b00a..e113d20ed88 100644
--- a/be/src/olap/rowset/beta_rowset_reader.cpp
+++ b/be/src/olap/rowset/beta_rowset_reader.cpp
@@ -220,8 +220,8 @@ Status
BetaRowsetReader::get_segment_iterators(RowsetReaderContext* read_context
_read_options.enable_unique_key_merge_on_write =
_read_context->enable_unique_key_merge_on_write;
_read_options.record_rowids = _read_context->record_rowids;
- _read_options.use_topn_opt = _read_context->use_topn_opt;
_read_options.topn_filter_source_node_ids =
_read_context->topn_filter_source_node_ids;
+ _read_options.topn_filter_target_node_id =
_read_context->topn_filter_target_node_id;
_read_options.read_orderby_key_reverse =
_read_context->read_orderby_key_reverse;
_read_options.read_orderby_key_columns =
_read_context->read_orderby_key_columns;
_read_options.io_ctx.reader_type = _read_context->reader_type;
diff --git a/be/src/olap/rowset/rowset_reader_context.h
b/be/src/olap/rowset/rowset_reader_context.h
index 44cf8556412..59abf85fb72 100644
--- a/be/src/olap/rowset/rowset_reader_context.h
+++ b/be/src/olap/rowset/rowset_reader_context.h
@@ -36,9 +36,8 @@ struct RowsetReaderContext {
ReaderType reader_type = ReaderType::READER_QUERY;
Version version {-1, -1};
TabletSchemaSPtr tablet_schema = nullptr;
- // flag for enable topn opt
- bool use_topn_opt = false;
std::vector<int> topn_filter_source_node_ids;
+ int topn_filter_target_node_id = -1;
// whether rowset should return ordered rows.
bool need_ordered_result = true;
// used for special optimization for query : ORDER BY key DESC LIMIT n
diff --git a/be/src/olap/rowset/segment_v2/segment.cpp
b/be/src/olap/rowset/segment_v2/segment.cpp
index 8a898a5db4d..555c676ef31 100644
--- a/be/src/olap/rowset/segment_v2/segment.cpp
+++ b/be/src/olap/rowset/segment_v2/segment.cpp
@@ -164,13 +164,12 @@ Status Segment::new_iterator(SchemaSPtr schema, const
StorageReadOptions& read_o
return Status::OK();
}
}
- if (read_options.use_topn_opt) {
+
+ if (!read_options.topn_filter_source_node_ids.empty()) {
auto* query_ctx = read_options.runtime_state->get_query_ctx();
for (int id : read_options.topn_filter_source_node_ids) {
- if (!query_ctx->get_runtime_predicate(id).need_update()) {
- continue;
- }
- auto runtime_predicate =
query_ctx->get_runtime_predicate(id).get_predicate();
+ auto runtime_predicate =
query_ctx->get_runtime_predicate(id).get_predicate(
+ read_options.topn_filter_target_node_id);
int32_t uid =
read_options.tablet_schema->column(runtime_predicate->column_id()).unique_id();
diff --git a/be/src/olap/rowset/segment_v2/segment_iterator.cpp
b/be/src/olap/rowset/segment_v2/segment_iterator.cpp
index 7c6853e9fb3..06f71c3cc11 100644
--- a/be/src/olap/rowset/segment_v2/segment_iterator.cpp
+++ b/be/src/olap/rowset/segment_v2/segment_iterator.cpp
@@ -523,7 +523,7 @@ Status
SegmentIterator::_get_row_ranges_by_column_conditions() {
RETURN_IF_ERROR(_apply_inverted_index());
if (!_row_bitmap.isEmpty() &&
- (_opts.use_topn_opt || !_opts.col_id_to_predicates.empty() ||
+ (!_opts.topn_filter_source_node_ids.empty() ||
!_opts.col_id_to_predicates.empty() ||
_opts.delete_condition_predicates->num_of_column_predicate() > 0)) {
RowRanges condition_row_ranges =
RowRanges::create_single(_segment->num_rows());
RETURN_IF_ERROR(_get_row_ranges_from_conditions(&condition_row_ranges));
@@ -572,7 +572,7 @@ Status
SegmentIterator::_get_row_ranges_from_conditions(RowRanges* condition_row
SCOPED_RAW_TIMER(&_opts.stats->block_conditions_filtered_zonemap_ns);
RowRanges zone_map_row_ranges = RowRanges::create_single(num_rows());
// second filter data by zone map
- for (auto& cid : cids) {
+ for (const auto& cid : cids) {
DCHECK(_opts.col_id_to_predicates.count(cid) > 0);
if (!_segment->can_apply_predicate_safely(cid,
_opts.col_id_to_predicates.at(cid).get(),
*_schema,
_opts.io_ctx.reader_type)) {
@@ -600,16 +600,12 @@ Status
SegmentIterator::_get_row_ranges_from_conditions(RowRanges* condition_row
RowRanges::ranges_intersection(*condition_row_ranges,
zone_map_row_ranges,
condition_row_ranges);
- if (_opts.use_topn_opt) {
-
SCOPED_RAW_TIMER(&_opts.stats->block_conditions_filtered_zonemap_ns);
+ if (!_opts.topn_filter_source_node_ids.empty()) {
auto* query_ctx = _opts.runtime_state->get_query_ctx();
for (int id : _opts.topn_filter_source_node_ids) {
- if (!query_ctx->get_runtime_predicate(id).need_update()) {
- continue;
- }
-
std::shared_ptr<doris::ColumnPredicate> runtime_predicate =
- query_ctx->get_runtime_predicate(id).get_predicate();
+ query_ctx->get_runtime_predicate(id).get_predicate(
+ _opts.topn_filter_target_node_id);
if
(_segment->can_apply_predicate_safely(runtime_predicate->column_id(),
runtime_predicate.get(), *_schema,
_opts.io_ctx.reader_type)) {
@@ -1579,7 +1575,7 @@ Status SegmentIterator::_vec_init_lazy_materialization() {
std::set<const ColumnPredicate*> delete_predicate_set {};
_opts.delete_condition_predicates->get_all_column_predicate(delete_predicate_set);
- for (const auto predicate : delete_predicate_set) {
+ for (const auto* const predicate : delete_predicate_set) {
if (PredicateTypeTraits::is_range(predicate->type())) {
_delete_range_column_ids.push_back(predicate->column_id());
} else if (PredicateTypeTraits::is_bloom_filter(predicate->type())) {
@@ -1593,16 +1589,13 @@ Status
SegmentIterator::_vec_init_lazy_materialization() {
// but runtime predicate will filter some rows and read more than N rows.
// should add add for order by none-key column, since none-key column is
not sorted and
// all rows should be read, so runtime predicate will reduce rows for
topn node
- if (_opts.use_topn_opt &&
+ if (!_opts.topn_filter_source_node_ids.empty() &&
(_opts.read_orderby_key_columns == nullptr ||
_opts.read_orderby_key_columns->empty())) {
for (int id : _opts.topn_filter_source_node_ids) {
- if
(!_opts.runtime_state->get_query_ctx()->get_runtime_predicate(id).need_update())
{
- continue;
- }
-
auto& runtime_predicate =
_opts.runtime_state->get_query_ctx()->get_runtime_predicate(id);
- _col_predicates.push_back(runtime_predicate.get_predicate().get());
+ _col_predicates.push_back(
+
runtime_predicate.get_predicate(_opts.topn_filter_target_node_id).get());
}
}
diff --git a/be/src/olap/tablet_reader.cpp b/be/src/olap/tablet_reader.cpp
index ae304f6f3e9..728e6e3d6d7 100644
--- a/be/src/olap/tablet_reader.cpp
+++ b/be/src/olap/tablet_reader.cpp
@@ -233,14 +233,14 @@ Status TabletReader::_capture_rs_readers(const
ReaderParams& read_params) {
_reader_context.version = read_params.version;
_reader_context.tablet_schema = _tablet_schema;
_reader_context.need_ordered_result = need_ordered_result;
- _reader_context.use_topn_opt = read_params.use_topn_opt;
_reader_context.topn_filter_source_node_ids =
read_params.topn_filter_source_node_ids;
+ _reader_context.topn_filter_target_node_id =
read_params.topn_filter_target_node_id;
_reader_context.read_orderby_key_reverse =
read_params.read_orderby_key_reverse;
_reader_context.read_orderby_key_limit =
read_params.read_orderby_key_limit;
_reader_context.filter_block_conjuncts =
read_params.filter_block_conjuncts;
_reader_context.return_columns = &_return_columns;
_reader_context.read_orderby_key_columns =
- _orderby_key_columns.size() > 0 ? &_orderby_key_columns : nullptr;
+ !_orderby_key_columns.empty() ? &_orderby_key_columns : nullptr;
_reader_context.predicates = &_col_predicates;
_reader_context.predicates_except_leafnode_of_andnode =
&_col_preds_except_leafnode_of_andnode;
_reader_context.value_predicates = &_value_col_predicates;
@@ -575,12 +575,11 @@ Status
TabletReader::_init_conditions_param_except_leafnode_of_andnode(
}
}
- if (read_params.use_topn_opt) {
- for (int id : read_params.topn_filter_source_node_ids) {
- auto& runtime_predicate =
-
read_params.runtime_state->get_query_ctx()->get_runtime_predicate(id);
-
RETURN_IF_ERROR(runtime_predicate.set_tablet_schema(_tablet_schema));
- }
+ for (int id : read_params.topn_filter_source_node_ids) {
+ auto& runtime_predicate =
+
read_params.runtime_state->get_query_ctx()->get_runtime_predicate(id);
+
RETURN_IF_ERROR(runtime_predicate.set_tablet_schema(read_params.topn_filter_target_node_id,
+ _tablet_schema));
}
return Status::OK();
}
diff --git a/be/src/olap/tablet_reader.h b/be/src/olap/tablet_reader.h
index 3bf83ec296c..a3cd3bd4a49 100644
--- a/be/src/olap/tablet_reader.h
+++ b/be/src/olap/tablet_reader.h
@@ -159,9 +159,8 @@ public:
// used for compaction to record row ids
bool record_rowids = false;
- // flag for enable topn opt
- bool use_topn_opt = false;
std::vector<int> topn_filter_source_node_ids;
+ int topn_filter_target_node_id = -1;
// used for special optimization for query : ORDER BY key LIMIT n
bool read_orderby_key = false;
// used for special optimization for query : ORDER BY key DESC LIMIT n
diff --git a/be/src/pipeline/exec/file_scan_operator.cpp
b/be/src/pipeline/exec/file_scan_operator.cpp
index 2484737af89..ae7a4f06d53 100644
--- a/be/src/pipeline/exec/file_scan_operator.cpp
+++ b/be/src/pipeline/exec/file_scan_operator.cpp
@@ -47,7 +47,7 @@ Status
FileScanLocalState::_init_scanners(std::list<vectorized::VScannerSPtr>* s
state(), this, p._limit_per_scanner, _split_source,
_scanner_profile.get(),
_kv_cache.get());
RETURN_IF_ERROR(
- scanner->prepare(_conjuncts, &_colname_to_value_range,
&_colname_to_slot_id));
+ scanner->prepare(_conjuncts, &_colname_to_value_range,
&p._colname_to_slot_id));
scanners->push_back(std::move(scanner));
}
return Status::OK();
diff --git a/be/src/pipeline/exec/olap_scan_operator.h
b/be/src/pipeline/exec/olap_scan_operator.h
index daff2167f7f..061e514b040 100644
--- a/be/src/pipeline/exec/olap_scan_operator.h
+++ b/be/src/pipeline/exec/olap_scan_operator.h
@@ -81,10 +81,10 @@ private:
bool _storage_no_merge() override;
bool _push_down_topn(const vectorized::RuntimePredicate& predicate)
override {
- if (!predicate.target_is_slot()) {
+ if (!predicate.target_is_slot(_parent->node_id())) {
return false;
}
- return _is_key_column(predicate.get_col_name()) || _storage_no_merge();
+ return _is_key_column(predicate.get_col_name(_parent->node_id())) ||
_storage_no_merge();
}
Status _init_scanners(std::list<vectorized::VScannerSPtr>* scanners)
override;
diff --git a/be/src/pipeline/exec/scan_operator.cpp
b/be/src/pipeline/exec/scan_operator.cpp
index d0e83937fd9..be7af24b684 100644
--- a/be/src/pipeline/exec/scan_operator.cpp
+++ b/be/src/pipeline/exec/scan_operator.cpp
@@ -171,23 +171,20 @@ Status
ScanLocalState<Derived>::_normalize_conjuncts(RuntimeState* state) {
}
};
- for (int slot_idx = 0; slot_idx < slots.size(); ++slot_idx) {
- _colname_to_slot_id[slots[slot_idx]->col_name()] =
slots[slot_idx]->id();
- _slot_id_to_slot_desc[slots[slot_idx]->id()] = slots[slot_idx];
-
- auto type = slots[slot_idx]->type().type;
- if (slots[slot_idx]->type().type == TYPE_ARRAY) {
- type = slots[slot_idx]->type().children[0].type;
+ for (auto& slot : slots) {
+ auto type = slot->type().type;
+ if (slot->type().type == TYPE_ARRAY) {
+ type = slot->type().children[0].type;
if (type == TYPE_ARRAY) {
continue;
}
}
- init_value_range(slots[slot_idx], slots[slot_idx]->type().type);
+ init_value_range(slot, slot->type().type);
}
get_cast_types_for_variants();
for (const auto& [colname, type] : _cast_types_for_variants) {
- init_value_range(_slot_id_to_slot_desc[_colname_to_slot_id[colname]],
type);
+
init_value_range(p._slot_id_to_slot_desc[p._colname_to_slot_id[colname]], type);
}
RETURN_IF_ERROR(_get_topn_filters(state));
@@ -1268,12 +1265,12 @@ Status ScanLocalState<Derived>::_init_profile() {
template <typename Derived>
Status ScanLocalState<Derived>::_get_topn_filters(RuntimeState* state) {
+ auto& p = _parent->cast<typename Derived::Parent>();
for (auto id : get_topn_filter_source_node_ids(state, false)) {
const auto& pred = state->get_query_ctx()->get_runtime_predicate(id);
- SlotDescriptor* slot_desc =
_slot_id_to_slot_desc[_colname_to_slot_id[pred.get_col_name()]];
-
vectorized::VExprSPtr topn_pred;
- RETURN_IF_ERROR(vectorized::VTopNPred::create_vtopn_pred(slot_desc,
id, topn_pred));
+
RETURN_IF_ERROR(vectorized::VTopNPred::create_vtopn_pred(pred.get_texpr(p.node_id()),
id,
+ topn_pred));
vectorized::VExprContextSPtr conjunct =
vectorized::VExprContext::create_shared(topn_pred);
RETURN_IF_ERROR(conjunct->prepare(
@@ -1288,6 +1285,7 @@ template <typename Derived>
void ScanLocalState<Derived>::_filter_and_collect_cast_type_for_variant(
const vectorized::VExpr* expr,
phmap::flat_hash_map<std::string, std::vector<PrimitiveType>>&
colname_to_cast_types) {
+ auto& p = _parent->cast<typename Derived::Parent>();
const auto* cast_expr = dynamic_cast<const vectorized::VCastExpr*>(expr);
if (cast_expr != nullptr) {
const auto* src_slot =
@@ -1298,7 +1296,7 @@ void
ScanLocalState<Derived>::_filter_and_collect_cast_type_for_variant(
return;
}
std::vector<SlotDescriptor*> slots = output_tuple_desc()->slots();
- SlotDescriptor* src_slot_desc =
_slot_id_to_slot_desc[src_slot->slot_id()];
+ SlotDescriptor* src_slot_desc =
p._slot_id_to_slot_desc[src_slot->slot_id()];
PrimitiveType cast_dst_type =
cast_expr->get_target_type()->get_type_as_type_descriptor().type;
if (src_slot_desc->type().is_variant_type()) {
@@ -1388,6 +1386,21 @@ Status ScanOperatorX<LocalStateType>::open(RuntimeState*
state) {
_output_tuple_desc =
state->desc_tbl().get_tuple_descriptor(_output_tuple_id);
RETURN_IF_ERROR(OperatorX<LocalStateType>::open(state));
+ const auto slots = _output_tuple_desc->slots();
+ for (auto* slot : slots) {
+ _colname_to_slot_id[slot->col_name()] = slot->id();
+ _slot_id_to_slot_desc[slot->id()] = slot;
+ }
+ for (auto id : topn_filter_source_node_ids) {
+ if (!state->get_query_ctx()->has_runtime_predicate(id)) {
+ // compatible with older versions fe
+ continue;
+ }
+
+
state->get_query_ctx()->get_runtime_predicate(id).init_target(node_id(),
+
_slot_id_to_slot_desc);
+ }
+
RETURN_IF_CANCELLED(state);
return Status::OK();
}
diff --git a/be/src/pipeline/exec/scan_operator.h
b/be/src/pipeline/exec/scan_operator.h
index 25a33de3f66..a1b85aac30a 100644
--- a/be/src/pipeline/exec/scan_operator.h
+++ b/be/src/pipeline/exec/scan_operator.h
@@ -161,8 +161,13 @@ class ScanLocalState : public ScanLocalStateBase {
std::vector<int> get_topn_filter_source_node_ids(RuntimeState* state, bool
push_down) {
std::vector<int> result;
for (int id : _parent->cast<typename
Derived::Parent>().topn_filter_source_node_ids) {
+ if (!state->get_query_ctx()->has_runtime_predicate(id)) {
+ // compatible with older versions fe
+ continue;
+ }
+
const auto& pred =
state->get_query_ctx()->get_runtime_predicate(id);
- if (!pred.inited()) {
+ if (!pred.enable()) {
continue;
}
if (_push_down_topn(pred) == push_down) {
@@ -334,9 +339,6 @@ protected:
// colname -> cast dst type
std::map<std::string, PrimitiveType> _cast_types_for_variants;
- // slot id -> SlotDescriptor
- phmap::flat_hash_map<int, SlotDescriptor*> _slot_id_to_slot_desc;
-
// slot id -> ColumnValueRange
// Parsed from conjuncts
phmap::flat_hash_map<int, std::pair<SlotDescriptor*, ColumnValueRangeType>>
@@ -345,7 +347,6 @@ protected:
// We use _colname_to_value_range to store a column and its conresponding
value ranges.
std::unordered_map<std::string, ColumnValueRangeType>
_colname_to_value_range;
- std::unordered_map<std::string, int> _colname_to_slot_id;
/**
* _colname_to_value_range only store the leaf of and in the conjunct expr
tree,
* we use _compound_value_ranges to store conresponding value ranges
@@ -426,6 +427,9 @@ protected:
const TupleDescriptor* _input_tuple_desc = nullptr;
const TupleDescriptor* _output_tuple_desc = nullptr;
+ phmap::flat_hash_map<int, SlotDescriptor*> _slot_id_to_slot_desc;
+ std::unordered_map<std::string, int> _colname_to_slot_id;
+
// These two values are from query_options
int _max_scan_key_num;
int _max_pushdown_conditions_per_column;
diff --git a/be/src/pipeline/exec/sort_sink_operator.cpp
b/be/src/pipeline/exec/sort_sink_operator.cpp
index 7c9f40d1b59..201c31d353b 100644
--- a/be/src/pipeline/exec/sort_sink_operator.cpp
+++ b/be/src/pipeline/exec/sort_sink_operator.cpp
@@ -79,7 +79,6 @@ SortSinkOperatorX::SortSinkOperatorX(ObjectPool* pool, int
operator_id, const TP
_pool(pool),
_reuse_mem(true),
_limit(tnode.limit),
- _use_topn_opt(tnode.sort_node.use_topn_opt),
_row_descriptor(descs, tnode.row_tuples, tnode.nullable_tuples),
_use_two_phase_read(tnode.sort_node.sort_info.use_two_phase_read),
_merge_by_exchange(tnode.sort_node.merge_by_exchange),
@@ -96,29 +95,10 @@ Status SortSinkOperatorX::init(const TPlanNode& tnode,
RuntimeState* state) {
_is_asc_order = tnode.sort_node.sort_info.is_asc_order;
_nulls_first = tnode.sort_node.sort_info.nulls_first;
+ auto* query_ctx = state->get_query_ctx();
// init runtime predicate
- if (_use_topn_opt) {
- auto* query_ctx = state->get_query_ctx();
- auto first_sort_expr_node =
tnode.sort_node.sort_info.ordering_exprs[0].nodes[0];
- if (first_sort_expr_node.node_type == TExprNodeType::SLOT_REF) {
- auto first_sort_slot = first_sort_expr_node.slot_ref;
- for (auto* tuple_desc : _row_descriptor.tuple_descriptors()) {
- if (tuple_desc->id() != first_sort_slot.tuple_id) {
- continue;
- }
- for (auto* slot : tuple_desc->slots()) {
- if (slot->id() == first_sort_slot.slot_id) {
-
RETURN_IF_ERROR(query_ctx->get_runtime_predicate(_node_id).init(
- slot->type().type, _nulls_first[0],
_is_asc_order[0],
- slot->col_name()));
- break;
- }
- }
- }
- }
- if (!query_ctx->get_runtime_predicate(_node_id).inited()) {
- return Status::InternalError("runtime predicate is not properly
initialized");
- }
+ if (query_ctx->has_runtime_predicate(_node_id)) {
+ query_ctx->get_runtime_predicate(_node_id).set_detected_source();
}
return Status::OK();
}
@@ -132,7 +112,8 @@ Status SortSinkOperatorX::prepare(RuntimeState* state) {
// exclude cases which incoming blocks has string column which is
sensitive to operations like
// `filter` and `memcpy`
if (_limit > 0 && _limit + _offset <
vectorized::HeapSorter::HEAP_SORT_THRESHOLD &&
- (_use_two_phase_read || _use_topn_opt ||
!row_desc.has_varlen_slots())) {
+ (_use_two_phase_read ||
state->get_query_ctx()->has_runtime_predicate(_node_id) ||
+ !row_desc.has_varlen_slots())) {
_algorithm = SortAlgorithm::HEAP_SORT;
_reuse_mem = false;
} else if (_limit > 0 && row_desc.has_varlen_slots() &&
@@ -158,9 +139,9 @@ Status SortSinkOperatorX::sink(doris::RuntimeState* state,
vectorized::Block* in
local_state._mem_tracker->set_consumption(local_state._shared_state->sorter->data_size());
RETURN_IF_CANCELLED(state);
- if (_use_topn_opt) {
+ if (state->get_query_ctx()->has_runtime_predicate(_node_id)) {
auto& predicate =
state->get_query_ctx()->get_runtime_predicate(_node_id);
- if (predicate.inited()) {
+ if (predicate.enable()) {
vectorized::Field new_top =
local_state._shared_state->sorter->get_top_value();
if (!new_top.is_null() && new_top != local_state.old_top) {
auto* query_ctx = state->get_query_ctx();
diff --git a/be/src/pipeline/exec/sort_sink_operator.h
b/be/src/pipeline/exec/sort_sink_operator.h
index 6ade3ada8bd..ba279a4aac4 100644
--- a/be/src/pipeline/exec/sort_sink_operator.h
+++ b/be/src/pipeline/exec/sort_sink_operator.h
@@ -100,7 +100,6 @@ private:
bool _reuse_mem;
const int64_t _limit;
- const bool _use_topn_opt;
SortAlgorithm _algorithm;
const RowDescriptor _row_descriptor;
diff --git a/be/src/pipeline/pipeline_fragment_context.cpp
b/be/src/pipeline/pipeline_fragment_context.cpp
index 8db4b1be587..fd4c903d5aa 100644
--- a/be/src/pipeline/pipeline_fragment_context.cpp
+++ b/be/src/pipeline/pipeline_fragment_context.cpp
@@ -291,10 +291,8 @@ Status PipelineFragmentContext::prepare(const
doris::TPipelineFragmentParams& re
_query_ctx->runtime_filter_mgr()->set_runtime_filter_params(
local_params.runtime_filter_params);
}
- if (local_params.__isset.topn_filter_source_node_ids) {
-
_query_ctx->init_runtime_predicates(local_params.topn_filter_source_node_ids);
- } else {
- _query_ctx->init_runtime_predicates({0});
+ if (local_params.__isset.topn_filter_descs) {
+
_query_ctx->init_runtime_predicates(local_params.topn_filter_descs);
}
_need_local_merge = request.__isset.parallel_instances;
diff --git a/be/src/runtime/query_context.h b/be/src/runtime/query_context.h
index 6fd35e00fd2..ff5d1a549ea 100644
--- a/be/src/runtime/query_context.h
+++ b/be/src/runtime/query_context.h
@@ -135,17 +135,18 @@ public:
return _shared_scanner_controller;
}
+ bool has_runtime_predicate(int source_node_id) {
+ return _runtime_predicates.contains(source_node_id);
+ }
+
vectorized::RuntimePredicate& get_runtime_predicate(int source_node_id) {
- DCHECK(_runtime_predicates.contains(source_node_id) ||
_runtime_predicates.contains(0));
- if (_runtime_predicates.contains(source_node_id)) {
- return _runtime_predicates[source_node_id];
- }
- return _runtime_predicates[0];
+ DCHECK(has_runtime_predicate(source_node_id));
+ return _runtime_predicates.find(source_node_id)->second;
}
- void init_runtime_predicates(std::vector<int> source_node_ids) {
- for (int id : source_node_ids) {
- _runtime_predicates.try_emplace(id);
+ void init_runtime_predicates(const std::vector<TTopnFilterDesc>&
topn_filter_descs) {
+ for (auto desc : topn_filter_descs) {
+ _runtime_predicates.try_emplace(desc.source_node_id, desc);
}
}
diff --git a/be/src/runtime/runtime_predicate.cpp
b/be/src/runtime/runtime_predicate.cpp
index 2655ff86680..f90a5743fdd 100644
--- a/be/src/runtime/runtime_predicate.cpp
+++ b/be/src/runtime/runtime_predicate.cpp
@@ -17,35 +17,99 @@
#include "runtime/runtime_predicate.h"
-#include <stdint.h>
-
#include <memory>
#include "common/compiler_util.h" // IWYU pragma: keep
+#include "common/exception.h"
+#include "common/status.h"
#include "olap/accept_null_predicate.h"
#include "olap/column_predicate.h"
#include "olap/predicate_creator.h"
namespace doris::vectorized {
-Status RuntimePredicate::init(PrimitiveType type, bool nulls_first, bool
is_asc,
- const std::string& col_name) {
- std::unique_lock<std::shared_mutex> wlock(_rwlock);
+RuntimePredicate::RuntimePredicate(const TTopnFilterDesc& desc)
+ : _nulls_first(desc.null_first), _is_asc(desc.is_asc) {
+ DCHECK(!desc.target_node_id_to_target_expr.empty());
+ for (auto p : desc.target_node_id_to_target_expr) {
+ _contexts[p.first].expr = p.second;
+ }
- if (_inited) {
- return Status::OK();
+ PrimitiveType type =
thrift_to_type(desc.target_node_id_to_target_expr.begin()
+ ->second.nodes[0]
+ .type.types[0]
+ .scalar_type.type);
+ if (!_init(type)) {
+ std::stringstream ss;
+
desc.target_node_id_to_target_expr.begin()->second.nodes[0].printTo(ss);
+ throw Exception(ErrorCode::INTERNAL_ERROR, "meet invalid type,
type={}, expr={}", int(type),
+ ss.str());
}
- _nulls_first = nulls_first;
- _is_asc = is_asc;
// For ASC sort, create runtime predicate col_name <= max_top_value
// since values that > min_top_value are large than any value in current
topn values
// For DESC sort, create runtime predicate col_name >= min_top_value
// since values that < min_top_value are less than any value in current
topn values
- _pred_constructor = is_asc ? create_comparison_predicate<PredicateType::LE>
- :
create_comparison_predicate<PredicateType::GE>;
- _col_name = col_name;
+ _pred_constructor = _is_asc ?
create_comparison_predicate<PredicateType::LE>
+ :
create_comparison_predicate<PredicateType::GE>;
+}
+
+void RuntimePredicate::init_target(
+ int32_t target_node_id, phmap::flat_hash_map<int, SlotDescriptor*>
slot_id_to_slot_desc) {
+ std::unique_lock<std::shared_mutex> wlock(_rwlock);
+ check_target_node_id(target_node_id);
+ if (target_is_slot(target_node_id)) {
+ _contexts[target_node_id].col_name =
+
slot_id_to_slot_desc[get_texpr(target_node_id).nodes[0].slot_ref.slot_id]
+ ->col_name();
+ }
+ _detected_target = true;
+}
+
+template <PrimitiveType type>
+std::string get_normal_value(const Field& field) {
+ using ValueType = typename PrimitiveTypeTraits<type>::CppType;
+ return cast_to_string<type, ValueType>(field.get<ValueType>(), 0);
+}
+
+std::string get_date_value(const Field& field) {
+ using ValueType = typename PrimitiveTypeTraits<TYPE_DATE>::CppType;
+ ValueType value;
+ Int64 v = field.get<Int64>();
+ auto* p = (VecDateTimeValue*)&v;
+ value.from_olap_date(p->to_olap_date());
+ value.cast_to_date();
+ return cast_to_string<TYPE_DATE, ValueType>(value, 0);
+}
+std::string get_datetime_value(const Field& field) {
+ using ValueType = typename PrimitiveTypeTraits<TYPE_DATETIME>::CppType;
+ ValueType value;
+ Int64 v = field.get<Int64>();
+ auto* p = (VecDateTimeValue*)&v;
+ value.from_olap_datetime(p->to_olap_datetime());
+ value.to_datetime();
+ return cast_to_string<TYPE_DATETIME, ValueType>(value, 0);
+}
+
+std::string get_decimalv2_value(const Field& field) {
+ // can NOT use PrimitiveTypeTraits<TYPE_DECIMALV2>::CppType since
+ // it is DecimalV2Value and Decimal128V2 can not convert to it implicitly
+ using ValueType = Decimal128V2::NativeType;
+ auto v = field.get<DecimalField<Decimal128V2>>();
+ // use TYPE_DECIMAL128I instead of TYPE_DECIMALV2 since v.get_scale()
+ // is always 9 for DECIMALV2
+ return cast_to_string<TYPE_DECIMAL128I, ValueType>(v.get_value(),
v.get_scale());
+}
+
+template <PrimitiveType type>
+std::string get_decimal_value(const Field& field) {
+ using ValueType = typename PrimitiveTypeTraits<type>::CppType;
+ auto v = field.get<DecimalField<ValueType>>();
+ return cast_to_string<type, ValueType>(v.get_value(), v.get_scale());
+}
+
+bool RuntimePredicate::_init(PrimitiveType type) {
// set get value function
switch (type) {
case PrimitiveType::TYPE_BOOLEAN: {
@@ -123,17 +187,16 @@ Status RuntimePredicate::init(PrimitiveType type, bool
nulls_first, bool is_asc,
break;
}
default:
- return Status::InvalidArgument("unsupported runtime predicate type
{}", type);
+ return false;
}
- _inited = true;
- return Status::OK();
+ return true;
}
Status RuntimePredicate::update(const Field& value) {
std::unique_lock<std::shared_mutex> wlock(_rwlock);
// skip null value
- if (value.is_null() || !_inited) {
+ if (value.is_null()) {
return Status::OK();
}
@@ -151,22 +214,28 @@ Status RuntimePredicate::update(const Field& value) {
_has_value = true;
- if (!updated || !_tablet_schema) {
+ if (!updated) {
return Status::OK();
}
- std::unique_ptr<ColumnPredicate> pred {
- _pred_constructor(_tablet_schema->column(_col_name),
_predicate->column_id(),
- _get_value_fn(_orderby_extrem), false,
&_predicate_arena)};
- // For NULLS FIRST, wrap a AcceptNullPredicate to return true for NULL
- // since ORDER BY ASC/DESC should get NULL first but pred returns NULL
- // and NULL in where predicate will be treated as FALSE
- if (_nulls_first) {
- pred = AcceptNullPredicate::create_unique(pred.release());
- }
-
- ((SharedPredicate*)_predicate.get())->set_nested(pred.release());
+ for (auto p : _contexts) {
+ auto ctx = p.second;
+ if (!ctx.tablet_schema) {
+ continue;
+ }
+ std::unique_ptr<ColumnPredicate> pred {_pred_constructor(
+ ctx.tablet_schema->column(ctx.col_name),
ctx.predicate->column_id(),
+ _get_value_fn(_orderby_extrem), false, &_predicate_arena)};
+
+ // For NULLS FIRST, wrap a AcceptNullPredicate to return true for NULL
+ // since ORDER BY ASC/DESC should get NULL first but pred returns NULL
+ // and NULL in where predicate will be treated as FALSE
+ if (_nulls_first) {
+ pred = AcceptNullPredicate::create_unique(pred.release());
+ }
+ ((SharedPredicate*)ctx.predicate.get())->set_nested(pred.release());
+ }
return Status::OK();
}
diff --git a/be/src/runtime/runtime_predicate.h
b/be/src/runtime/runtime_predicate.h
index 0305994e0fc..73ed657c0bb 100644
--- a/be/src/runtime/runtime_predicate.h
+++ b/be/src/runtime/runtime_predicate.h
@@ -42,36 +42,39 @@ namespace vectorized {
class RuntimePredicate {
public:
- RuntimePredicate() = default;
+ RuntimePredicate(const TTopnFilterDesc& desc);
- Status init(PrimitiveType type, bool nulls_first, bool is_asc, const
std::string& col_name);
+ void init_target(int32_t target_node_id,
+ phmap::flat_hash_map<int, SlotDescriptor*>
slot_id_to_slot_desc);
- bool inited() const {
- // when sort node and scan node are not in the same fragment,
predicate will not be initialized
+ bool enable() const {
+ // when sort node and scan node are not in the same fragment,
predicate will be disabled
std::shared_lock<std::shared_mutex> rlock(_rwlock);
- return _inited;
+ return _detected_source && _detected_target;
}
- bool need_update() const {
- std::shared_lock<std::shared_mutex> rlock(_rwlock);
- return _inited && _tablet_schema;
+ void set_detected_source() {
+ std::unique_lock<std::shared_mutex> wlock(_rwlock);
+ _detected_source = true;
}
- Status set_tablet_schema(TabletSchemaSPtr tablet_schema) {
+ Status set_tablet_schema(int32_t target_node_id, TabletSchemaSPtr
tablet_schema) {
std::unique_lock<std::shared_mutex> wlock(_rwlock);
- if (_tablet_schema || !_inited) {
+ check_target_node_id(target_node_id);
+ if (_contexts[target_node_id].tablet_schema) {
return Status::OK();
}
- RETURN_IF_ERROR(tablet_schema->have_column(_col_name));
- _tablet_schema = tablet_schema;
- _predicate = SharedPredicate::create_shared(
-
_tablet_schema->field_index(_tablet_schema->column(_col_name).unique_id()));
+
RETURN_IF_ERROR(tablet_schema->have_column(_contexts[target_node_id].col_name));
+ _contexts[target_node_id].tablet_schema = tablet_schema;
+ _contexts[target_node_id].predicate =
+
SharedPredicate::create_shared(_contexts[target_node_id].get_field_index());
return Status::OK();
}
- std::shared_ptr<ColumnPredicate> get_predicate() {
+ std::shared_ptr<ColumnPredicate> get_predicate(int32_t target_node_id) {
std::shared_lock<std::shared_mutex> rlock(_rwlock);
- return _predicate;
+ check_target_node_id(target_node_id);
+ return _contexts.find(target_node_id)->second.predicate;
}
Status update(const Field& value);
@@ -86,72 +89,72 @@ public:
return _orderby_extrem;
}
- std::string get_col_name() const { return _col_name; }
+ std::string get_col_name(int32_t target_node_id) const {
+ check_target_node_id(target_node_id);
+ return _contexts.find(target_node_id)->second.col_name;
+ }
bool is_asc() const { return _is_asc; }
bool nulls_first() const { return _nulls_first; }
- bool target_is_slot() const { return true; }
+ bool target_is_slot(int32_t target_node_id) const {
+ check_target_node_id(target_node_id);
+ return _contexts.find(target_node_id)->second.target_is_slot();
+ }
+
+ const TExpr& get_texpr(int32_t target_node_id) const {
+ check_target_node_id(target_node_id);
+ return _contexts.find(target_node_id)->second.expr;
+ }
private:
+ void check_target_node_id(int32_t target_node_id) const {
+ if (!_contexts.contains(target_node_id)) {
+ std::string msg = "context target node ids: [";
+ bool first = true;
+ for (auto p : _contexts) {
+ if (first) {
+ first = false;
+ } else {
+ msg += ',';
+ }
+ msg += std::to_string(p.first);
+ }
+ msg += "], input target node is: " +
std::to_string(target_node_id);
+ DCHECK(false) << msg;
+ }
+ }
+ struct TargetContext {
+ TExpr expr;
+ std::string col_name;
+ TabletSchemaSPtr tablet_schema;
+ std::shared_ptr<ColumnPredicate> predicate;
+
+ int32_t get_field_index() {
+ return
tablet_schema->field_index(tablet_schema->column(col_name).unique_id());
+ }
+
+ bool target_is_slot() const { return expr.nodes[0].node_type ==
TExprNodeType::SLOT_REF; }
+ };
+
+ bool _init(PrimitiveType type);
+
mutable std::shared_mutex _rwlock;
+
+ bool _nulls_first;
+ bool _is_asc;
+ std::map<int32_t, TargetContext> _contexts;
+
Field _orderby_extrem {Field::Types::Null};
- std::shared_ptr<ColumnPredicate> _predicate;
- TabletSchemaSPtr _tablet_schema = nullptr;
Arena _predicate_arena;
std::function<std::string(const Field&)> _get_value_fn;
- bool _nulls_first = true;
- bool _is_asc;
std::function<ColumnPredicate*(const TabletColumn&, int, const
std::string&, bool,
vectorized::Arena*)>
_pred_constructor;
- bool _inited = false;
- std::string _col_name;
+ bool _detected_source = false;
+ bool _detected_target = false;
bool _has_value = false;
-
- template <PrimitiveType type>
- static std::string get_normal_value(const Field& field) {
- using ValueType = typename PrimitiveTypeTraits<type>::CppType;
- return cast_to_string<type, ValueType>(field.get<ValueType>(), 0);
- }
-
- static std::string get_date_value(const Field& field) {
- using ValueType = typename PrimitiveTypeTraits<TYPE_DATE>::CppType;
- ValueType value;
- Int64 v = field.get<Int64>();
- auto* p = (VecDateTimeValue*)&v;
- value.from_olap_date(p->to_olap_date());
- value.cast_to_date();
- return cast_to_string<TYPE_DATE, ValueType>(value, 0);
- }
-
- static std::string get_datetime_value(const Field& field) {
- using ValueType = typename PrimitiveTypeTraits<TYPE_DATETIME>::CppType;
- ValueType value;
- Int64 v = field.get<Int64>();
- auto* p = (VecDateTimeValue*)&v;
- value.from_olap_datetime(p->to_olap_datetime());
- value.to_datetime();
- return cast_to_string<TYPE_DATETIME, ValueType>(value, 0);
- }
-
- static std::string get_decimalv2_value(const Field& field) {
- // can NOT use PrimitiveTypeTraits<TYPE_DECIMALV2>::CppType since
- // it is DecimalV2Value and Decimal128V2 can not convert to it
implicitly
- using ValueType = Decimal128V2::NativeType;
- auto v = field.get<DecimalField<Decimal128V2>>();
- // use TYPE_DECIMAL128I instead of TYPE_DECIMALV2 since v.get_scale()
- // is always 9 for DECIMALV2
- return cast_to_string<TYPE_DECIMAL128I, ValueType>(v.get_value(),
v.get_scale());
- }
-
- template <PrimitiveType type>
- static std::string get_decimal_value(const Field& field) {
- using ValueType = typename PrimitiveTypeTraits<type>::CppType;
- auto v = field.get<DecimalField<ValueType>>();
- return cast_to_string<type, ValueType>(v.get_value(), v.get_scale());
- }
};
} // namespace vectorized
diff --git a/be/src/runtime/runtime_state.cpp b/be/src/runtime/runtime_state.cpp
index ac560c2c7e1..504f6deabf3 100644
--- a/be/src/runtime/runtime_state.cpp
+++ b/be/src/runtime/runtime_state.cpp
@@ -93,14 +93,6 @@ RuntimeState::RuntimeState(const TPlanFragmentExecParams&
fragment_exec_params,
_query_ctx->runtime_filter_mgr()->set_runtime_filter_params(
fragment_exec_params.runtime_filter_params);
}
-
- if (_query_ctx) {
- if (fragment_exec_params.__isset.topn_filter_source_node_ids) {
-
_query_ctx->init_runtime_predicates(fragment_exec_params.topn_filter_source_node_ids);
- } else {
- _query_ctx->init_runtime_predicates({0});
- }
- }
}
RuntimeState::RuntimeState(const TUniqueId& instance_id, const TUniqueId&
query_id,
diff --git a/be/src/vec/exec/format/orc/vorc_reader.cpp
b/be/src/vec/exec/format/orc/vorc_reader.cpp
index f6a410cd81f..5c6a015d05c 100644
--- a/be/src/vec/exec/format/orc/vorc_reader.cpp
+++ b/be/src/vec/exec/format/orc/vorc_reader.cpp
@@ -2186,7 +2186,6 @@ Status
OrcReader::_rewrite_dict_conjuncts(std::vector<int32_t>& dict_codes, int
texpr_node.__set_node_type(TExprNodeType::BINARY_PRED);
texpr_node.__set_opcode(TExprOpcode::EQ);
texpr_node.__set_fn(fn);
- texpr_node.__set_child_type(TPrimitiveType::INT);
texpr_node.__set_num_children(2);
texpr_node.__set_is_nullable(is_nullable);
root = VectorizedFnCall::create_shared(texpr_node);
diff --git a/be/src/vec/exec/format/parquet/vparquet_group_reader.cpp
b/be/src/vec/exec/format/parquet/vparquet_group_reader.cpp
index 807f016cb43..426810ccbfc 100644
--- a/be/src/vec/exec/format/parquet/vparquet_group_reader.cpp
+++ b/be/src/vec/exec/format/parquet/vparquet_group_reader.cpp
@@ -911,7 +911,6 @@ Status
RowGroupReader::_rewrite_dict_conjuncts(std::vector<int32_t>& dict_codes,
texpr_node.__set_node_type(TExprNodeType::BINARY_PRED);
texpr_node.__set_opcode(TExprOpcode::EQ);
texpr_node.__set_fn(fn);
- texpr_node.__set_child_type(TPrimitiveType::INT);
texpr_node.__set_num_children(2);
texpr_node.__set_is_nullable(is_nullable);
root = VectorizedFnCall::create_shared(texpr_node);
diff --git a/be/src/vec/exec/scan/new_olap_scanner.cpp
b/be/src/vec/exec/scan/new_olap_scanner.cpp
index 70fafade3b7..4507afc2cb2 100644
--- a/be/src/vec/exec/scan/new_olap_scanner.cpp
+++ b/be/src/vec/exec/scan/new_olap_scanner.cpp
@@ -406,8 +406,10 @@ Status NewOlapScanner::_init_tablet_reader_params(
_tablet_reader_params.topn_filter_source_node_ids =
((pipeline::OlapScanLocalState*)_local_state)
->get_topn_filter_source_node_ids(_state, true);
- _tablet_reader_params.use_topn_opt =
- !_tablet_reader_params.topn_filter_source_node_ids.empty();
+ if (!_tablet_reader_params.topn_filter_source_node_ids.empty()) {
+ _tablet_reader_params.topn_filter_target_node_id =
+
((pipeline::OlapScanLocalState*)_local_state)->parent()->node_id();
+ }
}
}
diff --git a/be/src/vec/exec/vsort_node.cpp b/be/src/vec/exec/vsort_node.cpp
index 160690f7737..bd86776d50e 100644
--- a/be/src/vec/exec/vsort_node.cpp
+++ b/be/src/vec/exec/vsort_node.cpp
@@ -64,8 +64,7 @@ Status VSortNode::init(const TPlanNode& tnode, RuntimeState*
state) {
// exclude cases which incoming blocks has string column which is
sensitive to operations like
// `filter` and `memcpy`
if (_limit > 0 && _limit + _offset < HeapSorter::HEAP_SORT_THRESHOLD &&
- (tnode.sort_node.sort_info.use_two_phase_read ||
tnode.sort_node.use_topn_opt ||
- !row_desc.has_varlen_slots())) {
+ (tnode.sort_node.sort_info.use_two_phase_read ||
!row_desc.has_varlen_slots())) {
_sorter = HeapSorter::create_unique(_vsort_exec_exprs, _limit,
_offset, _pool,
_is_asc_order, _nulls_first,
row_desc);
_reuse_mem = false;
@@ -79,31 +78,6 @@ Status VSortNode::init(const TPlanNode& tnode, RuntimeState*
state) {
FullSorter::create_unique(_vsort_exec_exprs, _limit, _offset,
_pool, _is_asc_order,
_nulls_first, row_desc, state,
_runtime_profile.get());
}
- // init runtime predicate
- _use_topn_opt = tnode.sort_node.use_topn_opt;
- if (_use_topn_opt) {
- auto* query_ctx = state->get_query_ctx();
- auto first_sort_expr_node =
tnode.sort_node.sort_info.ordering_exprs[0].nodes[0];
- if (first_sort_expr_node.node_type == TExprNodeType::SLOT_REF) {
- auto first_sort_slot = first_sort_expr_node.slot_ref;
- for (auto* tuple_desc :
this->intermediate_row_desc().tuple_descriptors()) {
- if (tuple_desc->id() != first_sort_slot.tuple_id) {
- continue;
- }
- for (auto* slot : tuple_desc->slots()) {
- if (slot->id() == first_sort_slot.slot_id) {
-
RETURN_IF_ERROR(query_ctx->get_runtime_predicate(_id).init(
- slot->type().type, _nulls_first[0],
_is_asc_order[0],
- slot->col_name()));
- break;
- }
- }
- }
- }
- if (!query_ctx->get_runtime_predicate(_id).inited()) {
- return Status::InternalError("runtime predicate is not properly
initialized");
- }
- }
_sorter->init_profile(_runtime_profile.get());
return Status::OK();
@@ -140,18 +114,6 @@ Status VSortNode::sink(RuntimeState* state,
vectorized::Block* input_block, bool
if (input_block->rows() > 0) {
RETURN_IF_ERROR(_sorter->append_block(input_block));
RETURN_IF_CANCELLED(state);
-
- if (_use_topn_opt) {
- auto& predicate =
state->get_query_ctx()->get_runtime_predicate(_id);
- if (predicate.need_update()) {
- vectorized::Field new_top = _sorter->get_top_value();
- if (!new_top.is_null() && new_top != old_top) {
- auto* query_ctx = state->get_query_ctx();
-
RETURN_IF_ERROR(query_ctx->get_runtime_predicate(_id).update(new_top));
- old_top = std::move(new_top);
- }
- }
- }
if (!_reuse_mem) {
input_block->clear();
}
diff --git a/be/src/vec/exec/vsort_node.h b/be/src/vec/exec/vsort_node.h
index 5b13eb8fa93..4f009fc5c0d 100644
--- a/be/src/vec/exec/vsort_node.h
+++ b/be/src/vec/exec/vsort_node.h
@@ -89,7 +89,6 @@ private:
RuntimeProfile::Counter* _memory_usage_counter = nullptr;
RuntimeProfile::Counter* _sort_blocks_memory_usage = nullptr;
- bool _use_topn_opt = false;
// topn top value
Field old_top {Field::Types::Null};
diff --git a/be/src/vec/exprs/vtopn_pred.h b/be/src/vec/exprs/vtopn_pred.h
index 326ceaf0f2c..675c8fb293c 100644
--- a/be/src/vec/exprs/vtopn_pred.h
+++ b/be/src/vec/exprs/vtopn_pred.h
@@ -19,6 +19,8 @@
#include <gen_cpp/types.pb.h>
+#include <utility>
+
#include "runtime/query_context.h"
#include "runtime/runtime_predicate.h"
#include "runtime/runtime_state.h"
@@ -36,21 +38,24 @@ class VTopNPred : public VExpr {
ENABLE_FACTORY_CREATOR(VTopNPred);
public:
- VTopNPred(const TExprNode& node, int source_node_id)
+ VTopNPred(const TExprNode& node, int source_node_id, VExprContextSPtr
target_ctx)
: VExpr(node),
_source_node_id(source_node_id),
- _expr_name(fmt::format("VTopNPred(source_node_id={})",
_source_node_id)) {}
+ _expr_name(fmt::format("VTopNPred(source_node_id={})",
_source_node_id)),
+ _target_ctx(std::move(target_ctx)) {}
- // TODO: support general expr
- static Status create_vtopn_pred(SlotDescriptor* slot_desc, int
source_node_id,
+ static Status create_vtopn_pred(const TExpr& target_expr, int
source_node_id,
vectorized::VExprSPtr& expr) {
+ vectorized::VExprContextSPtr target_ctx;
+ RETURN_IF_ERROR(vectorized::VExpr::create_expr_tree(target_expr,
target_ctx));
+
TExprNode node;
node.__set_node_type(TExprNodeType::FUNCTION_CALL);
node.__set_type(create_type_desc(PrimitiveType::TYPE_BOOLEAN));
- node.__set_is_nullable(slot_desc->is_nullable());
- expr = vectorized::VTopNPred::create_shared(node, source_node_id);
+ node.__set_is_nullable(target_ctx->root()->is_nullable());
+ expr = vectorized::VTopNPred::create_shared(node, source_node_id,
target_ctx);
- expr->add_child(VSlotRef::create_shared(slot_desc));
+ expr->add_child(target_ctx->root());
return Status::OK();
}
@@ -112,5 +117,6 @@ private:
std::string _expr_name;
RuntimePredicate* _predicate = nullptr;
FunctionBasePtr _function;
+ VExprContextSPtr _target_ctx;
};
} // namespace doris::vectorized
diff --git a/be/src/vec/olap/vcollect_iterator.cpp
b/be/src/vec/olap/vcollect_iterator.cpp
index 3ce1869546c..61050979b84 100644
--- a/be/src/vec/olap/vcollect_iterator.cpp
+++ b/be/src/vec/olap/vcollect_iterator.cpp
@@ -414,7 +414,7 @@ Status VCollectIterator::_topn_next(Block* block) {
}
// update runtime_predicate
- if (_reader->_reader_context.use_topn_opt && changed &&
+ if (!_reader->_reader_context.topn_filter_source_node_ids.empty()
&& changed &&
sorted_row_pos.size() >= _topn_limit) {
// get field value from column
size_t last_sorted_row = *sorted_row_pos.rbegin();
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/NereidsPlanner.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/NereidsPlanner.java
index ec54cc4cc4c..546e73e990d 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/NereidsPlanner.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/NereidsPlanner.java
@@ -60,6 +60,7 @@ import
org.apache.doris.nereids.trees.plans.physical.PhysicalOneRowRelation;
import org.apache.doris.nereids.trees.plans.physical.PhysicalPlan;
import org.apache.doris.nereids.trees.plans.physical.PhysicalResultSink;
import org.apache.doris.nereids.trees.plans.physical.PhysicalSqlCache;
+import org.apache.doris.nereids.trees.plans.physical.TopnFilter;
import org.apache.doris.planner.PlanFragment;
import org.apache.doris.planner.Planner;
import org.apache.doris.planner.RuntimeFilter;
@@ -716,4 +717,9 @@ public class NereidsPlanner extends Planner {
task.run();
}
}
+
+ @Override
+ public List<TopnFilter> getTopnFilters() {
+ return cascadesContext.getTopnFilterContext().getTopnFilters();
+ }
}
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/PhysicalPlanTranslator.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/PhysicalPlanTranslator.java
index e40eead5340..ca1d2a53328 100644
---
a/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/PhysicalPlanTranslator.java
+++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/PhysicalPlanTranslator.java
@@ -215,7 +215,6 @@ import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
-import java.util.Optional;
import java.util.Set;
import java.util.TreeMap;
import java.util.stream.Collectors;
@@ -620,10 +619,7 @@ public class PhysicalPlanTranslator extends
DefaultPlanVisitor<PlanFragment, Pla
expr ->
runtimeFilterGenerator.translateRuntimeFilterTarget(expr, finalScanNode,
context)
)
);
- if (context.getTopnFilterContext().isTopnFilterTarget(fileScan)) {
- context.getTopnFilterContext().addLegacyTarget(fileScan, scanNode);
- }
-
+ context.getTopnFilterContext().translateTarget(fileScan, scanNode,
context);
Utils.execWithUncheckedException(scanNode::finalizeForNereids);
// Create PlanFragment
DataPartition dataPartition = DataPartition.RANDOM;
@@ -668,9 +664,7 @@ public class PhysicalPlanTranslator extends
DefaultPlanVisitor<PlanFragment, Pla
expr ->
runtimeFilterGenerator.translateRuntimeFilterTarget(expr, esScanNode, context)
)
);
- if (context.getTopnFilterContext().isTopnFilterTarget(esScan)) {
- context.getTopnFilterContext().addLegacyTarget(esScan, esScanNode);
- }
+ context.getTopnFilterContext().translateTarget(esScan, esScanNode,
context);
Utils.execWithUncheckedException(esScanNode::finalizeForNereids);
DataPartition dataPartition = DataPartition.RANDOM;
PlanFragment planFragment = new PlanFragment(context.nextFragmentId(),
esScanNode, dataPartition);
@@ -695,9 +689,7 @@ public class PhysicalPlanTranslator extends
DefaultPlanVisitor<PlanFragment, Pla
expr ->
runtimeFilterGenerator.translateRuntimeFilterTarget(expr, jdbcScanNode, context)
)
);
- if (context.getTopnFilterContext().isTopnFilterTarget(jdbcScan)) {
- context.getTopnFilterContext().addLegacyTarget(jdbcScan,
jdbcScanNode);
- }
+ context.getTopnFilterContext().translateTarget(jdbcScan, jdbcScanNode,
context);
Utils.execWithUncheckedException(jdbcScanNode::finalizeForNereids);
DataPartition dataPartition = DataPartition.RANDOM;
PlanFragment planFragment = new PlanFragment(context.nextFragmentId(),
jdbcScanNode, dataPartition);
@@ -722,10 +714,9 @@ public class PhysicalPlanTranslator extends
DefaultPlanVisitor<PlanFragment, Pla
expr ->
runtimeFilterGenerator.translateRuntimeFilterTarget(expr, odbcScanNode, context)
)
);
- if (context.getTopnFilterContext().isTopnFilterTarget(odbcScan)) {
- context.getTopnFilterContext().addLegacyTarget(odbcScan,
odbcScanNode);
- }
+ context.getTopnFilterContext().translateTarget(odbcScan, odbcScanNode,
context);
Utils.execWithUncheckedException(odbcScanNode::finalizeForNereids);
DataPartition dataPartition = DataPartition.RANDOM;
PlanFragment planFragment = new PlanFragment(context.nextFragmentId(),
odbcScanNode, dataPartition);
context.addPlanFragment(planFragment);
@@ -801,10 +792,8 @@ public class PhysicalPlanTranslator extends
DefaultPlanVisitor<PlanFragment, Pla
expr, olapScanNode, context)
)
);
+ context.getTopnFilterContext().translateTarget(olapScan, olapScanNode,
context);
olapScanNode.setPushDownAggNoGrouping(context.getRelationPushAggOp(olapScan.getRelationId()));
- if (context.getTopnFilterContext().isTopnFilterTarget(olapScan)) {
- context.getTopnFilterContext().addLegacyTarget(olapScan,
olapScanNode);
- }
// TODO: we need to remove all finalizeForNereids
olapScanNode.finalizeForNereids();
// Create PlanFragment
@@ -828,9 +817,6 @@ public class PhysicalPlanTranslator extends
DefaultPlanVisitor<PlanFragment, Pla
PhysicalDeferMaterializeOlapScan deferMaterializeOlapScan,
PlanTranslatorContext context) {
PlanFragment planFragment =
visitPhysicalOlapScan(deferMaterializeOlapScan.getPhysicalOlapScan(), context);
OlapScanNode olapScanNode = (OlapScanNode) planFragment.getPlanRoot();
- if
(context.getTopnFilterContext().isTopnFilterTarget(deferMaterializeOlapScan)) {
-
context.getTopnFilterContext().addLegacyTarget(deferMaterializeOlapScan,
olapScanNode);
- }
TupleDescriptor tupleDescriptor =
context.getTupleDesc(olapScanNode.getTupleId());
for (SlotDescriptor slotDescriptor : tupleDescriptor.getSlots()) {
if (deferMaterializeOlapScan.getDeferMaterializeSlotIds()
@@ -839,6 +825,7 @@ public class PhysicalPlanTranslator extends
DefaultPlanVisitor<PlanFragment, Pla
}
}
context.createSlotDesc(tupleDescriptor,
deferMaterializeOlapScan.getColumnIdSlot());
+
context.getTopnFilterContext().translateTarget(deferMaterializeOlapScan,
olapScanNode, context);
return planFragment;
}
@@ -2149,19 +2136,7 @@ public class PhysicalPlanTranslator extends
DefaultPlanVisitor<PlanFragment, Pla
sortNode.setOffset(topN.getOffset());
sortNode.setLimit(topN.getLimit());
if (context.getTopnFilterContext().isTopnFilterSource(topN)) {
- sortNode.setUseTopnOpt(true);
- context.getTopnFilterContext().getTargets(topN).forEach(
- relation -> {
- Optional<ScanNode> legacyScan =
-
context.getTopnFilterContext().getLegacyScanNode(relation);
- Preconditions.checkState(legacyScan.isPresent(),
- "cannot find ScanNode for topn filter:\n"
- + "relation: %s \n%s",
- relation.toString(),
- context.getTopnFilterContext().toString());
- legacyScan.get().addTopnFilterSortNode(sortNode);
- }
- );
+ context.getTopnFilterContext().translateSource(topN, sortNode);
}
// push sort to scan opt
if (sortNode.getChild(0) instanceof OlapScanNode) {
@@ -2212,16 +2187,7 @@ public class PhysicalPlanTranslator extends
DefaultPlanVisitor<PlanFragment, Pla
sortNode.setUseTwoPhaseReadOpt(true);
sortNode.getSortInfo().setUseTwoPhaseRead();
if (context.getTopnFilterContext().isTopnFilterSource(topN)) {
- sortNode.setUseTopnOpt(true);
- context.getTopnFilterContext().getTargets(topN).forEach(
- relation -> {
- Optional<ScanNode> legacyScan =
-
context.getTopnFilterContext().getLegacyScanNode(relation);
- Preconditions.checkState(legacyScan.isPresent(),
- "cannot find ScanNode for topn filter");
- legacyScan.get().addTopnFilterSortNode(sortNode);
- }
- );
+ context.getTopnFilterContext().translateSource(topN, sortNode);
}
TupleDescriptor tupleDescriptor =
sortNode.getSortInfo().getSortTupleDescriptor();
for (SlotDescriptor slotDescriptor : tupleDescriptor.getSlots()) {
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/processor/post/TopNScanOpt.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/processor/post/TopNScanOpt.java
index ec1e52d6426..fd3c794317d 100644
---
a/fe/fe-core/src/main/java/org/apache/doris/nereids/processor/post/TopNScanOpt.java
+++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/processor/post/TopNScanOpt.java
@@ -18,11 +18,11 @@
package org.apache.doris.nereids.processor.post;
import org.apache.doris.nereids.CascadesContext;
+import
org.apache.doris.nereids.processor.post.TopnFilterPushDownVisitor.PushDownContext;
import org.apache.doris.nereids.trees.expressions.Expression;
-import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.SortPhase;
-import org.apache.doris.nereids.trees.plans.algebra.Join;
+import org.apache.doris.nereids.trees.plans.algebra.TopN;
import
org.apache.doris.nereids.trees.plans.physical.PhysicalDeferMaterializeOlapScan;
import
org.apache.doris.nereids.trees.plans.physical.PhysicalDeferMaterializeTopN;
import org.apache.doris.nereids.trees.plans.physical.PhysicalEsScan;
@@ -32,10 +32,8 @@ import
org.apache.doris.nereids.trees.plans.physical.PhysicalOdbcScan;
import org.apache.doris.nereids.trees.plans.physical.PhysicalOlapScan;
import org.apache.doris.nereids.trees.plans.physical.PhysicalRelation;
import org.apache.doris.nereids.trees.plans.physical.PhysicalTopN;
-import org.apache.doris.nereids.trees.plans.physical.PhysicalWindow;
import org.apache.doris.qe.ConnectContext;
-import java.util.Optional;
/**
* topN opt
* refer to:
@@ -48,33 +46,39 @@ import java.util.Optional;
public class TopNScanOpt extends PlanPostProcessor {
@Override
public PhysicalTopN<? extends Plan> visitPhysicalTopN(PhysicalTopN<?
extends Plan> topN, CascadesContext ctx) {
- Optional<PhysicalRelation> scanOpt = findScanForTopnFilter(topN);
- scanOpt.ifPresent(scan ->
ctx.getTopnFilterContext().addTopnFilter(topN, scan));
topN.child().accept(this, ctx);
+ if (checkTopN(topN)) {
+ TopnFilterPushDownVisitor pusher = new
TopnFilterPushDownVisitor(ctx.getTopnFilterContext());
+ TopnFilterPushDownVisitor.PushDownContext pushdownContext = new
PushDownContext(topN,
+ topN.getOrderKeys().get(0).getExpr(),
+ topN.getOrderKeys().get(0).isNullFirst());
+ topN.accept(pusher, pushdownContext);
+ }
return topN;
}
- @Override
- public Plan
visitPhysicalDeferMaterializeTopN(PhysicalDeferMaterializeTopN<? extends Plan>
topN,
- CascadesContext context) {
- Optional<PhysicalRelation> scanOpt =
findScanForTopnFilter(topN.getPhysicalTopN());
- scanOpt.ifPresent(scan ->
context.getTopnFilterContext().addTopnFilter(topN, scan));
- topN.child().accept(this, context);
- return topN;
- }
-
- private Optional<PhysicalRelation> findScanForTopnFilter(PhysicalTopN<?
extends Plan> topN) {
- if (topN.getSortPhase() != SortPhase.LOCAL_SORT) {
- return Optional.empty();
+ boolean checkTopN(TopN topN) {
+ if (!(topN instanceof PhysicalTopN) && !(topN instanceof
PhysicalDeferMaterializeTopN)) {
+ return false;
+ }
+ if (topN instanceof PhysicalTopN
+ && ((PhysicalTopN) topN).getSortPhase() !=
SortPhase.LOCAL_SORT) {
+ return false;
+ } else {
+ if (topN instanceof PhysicalDeferMaterializeTopN
+ && ((PhysicalDeferMaterializeTopN) topN).getSortPhase() !=
SortPhase.LOCAL_SORT) {
+ return false;
+ }
}
+
if (topN.getOrderKeys().isEmpty()) {
- return Optional.empty();
+ return false;
}
// topn opt
long topNOptLimitThreshold = getTopNOptLimitThreshold();
if (topNOptLimitThreshold == -1 || topN.getLimit() >
topNOptLimitThreshold) {
- return Optional.empty();
+ return false;
}
// if firstKey's column is not present, it means the firstKey is not
an original column from scan node
// for example: "select cast(k1 as INT) as id from tbl1 order by id
limit 2;" the firstKey "id" is
@@ -84,64 +88,28 @@ public class TopNScanOpt extends PlanPostProcessor {
// see Alias::toSlot() method to get how column info is passed around
by alias of slotReference
Expression firstKey = topN.getOrderKeys().get(0).getExpr();
if (!firstKey.isColumnFromTable()) {
- return Optional.empty();
+ return false;
}
+
if (firstKey.getDataType().isFloatType()
|| firstKey.getDataType().isDoubleType()) {
- return Optional.empty();
+ return false;
}
-
- if (! (firstKey instanceof SlotReference)) {
- return Optional.empty();
- }
-
- boolean nullsFirst = topN.getOrderKeys().get(0).isNullFirst();
- return findScanNodeBySlotReference(topN, (SlotReference) firstKey,
nullsFirst);
+ return true;
}
- private Optional<PhysicalRelation> findScanNodeBySlotReference(Plan root,
SlotReference slot, boolean nullsFirst) {
- if (root instanceof PhysicalWindow) {
- return Optional.empty();
- }
-
- if (root instanceof PhysicalRelation) {
- if (root.getOutputSet().contains(slot) &&
supportPhysicalRelations((PhysicalRelation) root)) {
- return Optional.of((PhysicalRelation) root);
- } else {
- return Optional.empty();
- }
- }
-
- Optional<PhysicalRelation> target;
- if (root instanceof Join) {
- Join join = (Join) root;
- if (nullsFirst && join.getJoinType().isOuterJoin()) {
- // in fact, topn-filter can be pushed down to the left child
of leftOuterJoin
- // and to the right child of rightOuterJoin.
- // but we have rule to push topn down to the left/right side.
and topn-filter
- // will be generated according to the inferred topn node.
- return Optional.empty();
- }
- // try to push to both left and right child
- if (root.child(0).getOutputSet().contains(slot)) {
- target = findScanNodeBySlotReference(root.child(0), slot,
nullsFirst);
- } else {
- target = findScanNodeBySlotReference(root.child(1), slot,
nullsFirst);
- }
- return target;
- }
-
- if (!root.children().isEmpty()) {
- // TODO for set operator, topn-filter can be pushed down to all of
its children.
- Plan child = root.child(0);
- if (child.getOutputSet().contains(slot)) {
- target = findScanNodeBySlotReference(child, slot, nullsFirst);
- if (target.isPresent()) {
- return target;
- }
- }
+ @Override
+ public Plan
visitPhysicalDeferMaterializeTopN(PhysicalDeferMaterializeTopN<? extends Plan>
topN,
+ CascadesContext ctx) {
+ topN.child().accept(this, ctx);
+ if (checkTopN(topN)) {
+ TopnFilterPushDownVisitor pusher = new
TopnFilterPushDownVisitor(ctx.getTopnFilterContext());
+ TopnFilterPushDownVisitor.PushDownContext pushdownContext = new
PushDownContext(topN,
+ topN.getOrderKeys().get(0).getExpr(),
+ topN.getOrderKeys().get(0).isNullFirst());
+ topN.accept(pusher, pushdownContext);
}
- return Optional.empty();
+ return topN;
}
private long getTopNOptLimitThreshold() {
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/processor/post/TopnFilterContext.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/processor/post/TopnFilterContext.java
index 6a4fe3123df..fceec21ee7e 100644
---
a/fe/fe-core/src/main/java/org/apache/doris/nereids/processor/post/TopnFilterContext.java
+++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/processor/post/TopnFilterContext.java
@@ -17,78 +17,76 @@
package org.apache.doris.nereids.processor.post;
+import org.apache.doris.analysis.Expr;
+import org.apache.doris.nereids.glue.translator.ExpressionTranslator;
+import org.apache.doris.nereids.glue.translator.PlanTranslatorContext;
+import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.plans.algebra.TopN;
import org.apache.doris.nereids.trees.plans.physical.PhysicalRelation;
+import org.apache.doris.nereids.trees.plans.physical.TopnFilter;
import org.apache.doris.planner.ScanNode;
import org.apache.doris.planner.SortNode;
+import com.google.common.base.Preconditions;
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
-import com.google.common.collect.Sets;
import java.util.List;
import java.util.Map;
-import java.util.Optional;
-import java.util.Set;
/**
* topN runtime filter context
*/
public class TopnFilterContext {
- private final Map<TopN, List<PhysicalRelation>> filters =
Maps.newHashMap();
- private final Set<TopN> sources = Sets.newHashSet();
- private final Set<PhysicalRelation> targets = Sets.newHashSet();
- private final Map<PhysicalRelation, ScanNode> legacyTargetsMap =
Maps.newHashMap();
- private final Map<TopN, SortNode> legacySourceMap = Maps.newHashMap();
+ private final Map<TopN, TopnFilter> filters = Maps.newHashMap();
/**
* add topN filter
*/
- public void addTopnFilter(TopN topn, PhysicalRelation scan) {
- targets.add(scan);
- sources.add(topn);
-
- List<PhysicalRelation> targets = filters.get(topn);
- if (targets == null) {
- filters.put(topn, Lists.newArrayList(scan));
+ public void addTopnFilter(TopN topn, PhysicalRelation scan, Expression
expr) {
+ TopnFilter filter = filters.get(topn);
+ if (filter == null) {
+ filters.put(topn, new TopnFilter(topn, scan, expr));
} else {
- targets.add(scan);
+ filter.addTarget(scan, expr);
}
}
- /**
- * find the corresponding sortNode for topn filter
- */
- public Optional<ScanNode> getLegacyScanNode(PhysicalRelation scan) {
- return legacyTargetsMap.containsKey(scan)
- ? Optional.of(legacyTargetsMap.get(scan))
- : Optional.empty();
- }
-
- public Optional<SortNode> getLegacySortNode(TopN topn) {
- return legacyTargetsMap.containsKey(topn)
- ? Optional.of(legacySourceMap.get(topn))
- : Optional.empty();
- }
-
public boolean isTopnFilterSource(TopN topn) {
- return sources.contains(topn);
- }
-
- public boolean isTopnFilterTarget(PhysicalRelation relation) {
- return targets.contains(relation);
+ return filters.containsKey(topn);
}
- public void addLegacySource(TopN topn, SortNode sort) {
- legacySourceMap.put(topn, sort);
+ public List<TopnFilter> getTopnFilters() {
+ return Lists.newArrayList(filters.values());
}
- public void addLegacyTarget(PhysicalRelation relation, ScanNode legacy) {
- legacyTargetsMap.put(relation, legacy);
+ /**
+ * translate topn-filter
+ */
+ public void translateTarget(PhysicalRelation relation, ScanNode legacyScan,
+ PlanTranslatorContext translatorContext) {
+ for (TopnFilter filter : filters.values()) {
+ if (filter.hasTargetRelation(relation)) {
+ Expr expr =
ExpressionTranslator.translate(filter.targets.get(relation), translatorContext);
+ filter.legacyTargets.put(legacyScan, expr);
+ }
+ }
}
- public List<PhysicalRelation> getTargets(TopN topn) {
- return filters.get(topn);
+ /**
+ * translate topn-filter
+ */
+ public void translateSource(TopN topn, SortNode sortNode) {
+ TopnFilter filter = filters.get(topn);
+ if (filter == null) {
+ return;
+ }
+ filter.legacySortNode = sortNode;
+ sortNode.setUseTopnOpt(true);
+ Preconditions.checkArgument(!filter.legacyTargets.isEmpty(), "missing
targets: " + filter);
+ for (ScanNode scan : filter.legacyTargets.keySet()) {
+ scan.addTopnFilterSortNode(sortNode);
+ }
}
/**
@@ -100,10 +98,8 @@ public class TopnFilterContext {
String arrow = " -> ";
builder.append("filters:\n");
for (TopN topn : filters.keySet()) {
- builder.append(indent).append(topn.toString()).append("\n");
builder.append(indent).append(arrow).append(filters.get(topn)).append("\n");
}
return builder.toString();
-
}
}
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/processor/post/TopnFilterPushDownVisitor.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/processor/post/TopnFilterPushDownVisitor.java
new file mode 100644
index 00000000000..6962f3fddf2
--- /dev/null
+++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/processor/post/TopnFilterPushDownVisitor.java
@@ -0,0 +1,233 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package org.apache.doris.nereids.processor.post;
+
+import
org.apache.doris.nereids.processor.post.TopnFilterPushDownVisitor.PushDownContext;
+import org.apache.doris.nereids.trees.expressions.Alias;
+import org.apache.doris.nereids.trees.expressions.Expression;
+import org.apache.doris.nereids.trees.expressions.NamedExpression;
+import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitors;
+import org.apache.doris.nereids.trees.plans.Plan;
+import org.apache.doris.nereids.trees.plans.algebra.TopN;
+import org.apache.doris.nereids.trees.plans.algebra.Union;
+import org.apache.doris.nereids.trees.plans.physical.PhysicalCTEAnchor;
+import org.apache.doris.nereids.trees.plans.physical.PhysicalCTEProducer;
+import
org.apache.doris.nereids.trees.plans.physical.PhysicalDeferMaterializeOlapScan;
+import org.apache.doris.nereids.trees.plans.physical.PhysicalEsScan;
+import org.apache.doris.nereids.trees.plans.physical.PhysicalFileScan;
+import org.apache.doris.nereids.trees.plans.physical.PhysicalHashJoin;
+import org.apache.doris.nereids.trees.plans.physical.PhysicalJdbcScan;
+import org.apache.doris.nereids.trees.plans.physical.PhysicalNestedLoopJoin;
+import org.apache.doris.nereids.trees.plans.physical.PhysicalOdbcScan;
+import org.apache.doris.nereids.trees.plans.physical.PhysicalOlapScan;
+import org.apache.doris.nereids.trees.plans.physical.PhysicalProject;
+import org.apache.doris.nereids.trees.plans.physical.PhysicalRelation;
+import org.apache.doris.nereids.trees.plans.physical.PhysicalSetOperation;
+import org.apache.doris.nereids.trees.plans.physical.PhysicalTopN;
+import org.apache.doris.nereids.trees.plans.physical.PhysicalWindow;
+import org.apache.doris.nereids.trees.plans.visitor.PlanVisitor;
+
+import com.google.common.collect.Maps;
+
+import java.util.List;
+import java.util.Map;
+
+/**
+ * push down topn filter
+ */
+public class TopnFilterPushDownVisitor extends PlanVisitor<Boolean,
PushDownContext> {
+ private TopnFilterContext topnFilterContext;
+
+ public TopnFilterPushDownVisitor(TopnFilterContext topnFilterContext) {
+ this.topnFilterContext = topnFilterContext;
+ }
+
+ /**
+ * topn filter push-down context
+ */
+ public static class PushDownContext {
+ final Expression probeExpr;
+ final TopN topn;
+ final boolean nullsFirst;
+
+ public PushDownContext(TopN topn, Expression probeExpr, boolean
nullsFirst) {
+ this.topn = topn;
+ this.probeExpr = probeExpr;
+ this.nullsFirst = nullsFirst;
+ }
+
+ public PushDownContext withNewProbeExpression(Expression newProbe) {
+ return new PushDownContext(topn, newProbe, nullsFirst);
+ }
+
+ }
+
+ @Override
+ public Boolean visit(Plan plan, PushDownContext ctx) {
+ boolean pushed = false;
+ for (Plan child : plan.children()) {
+ pushed |= child.accept(this, ctx);
+ }
+ return pushed;
+ }
+
+ @Override
+ public Boolean visitPhysicalProject(
+ PhysicalProject<? extends Plan> project, PushDownContext ctx) {
+ // project ( A+1 as x)
+ // probeExpr: abs(x) => abs(A+1)
+ PushDownContext ctxProjectProbeExpr = ctx;
+ Map<Expression, Expression> replaceMap = Maps.newHashMap();
+ for (NamedExpression ne : project.getProjects()) {
+ if (ne instanceof Alias &&
ctx.probeExpr.getInputSlots().contains(ne.toSlot())) {
+
replaceMap.put(ctx.probeExpr.getInputSlots().iterator().next(), ((Alias)
ne).child());
+ }
+ }
+ if (! replaceMap.isEmpty()) {
+ Expression newProbeExpr =
ctx.probeExpr.accept(ExpressionVisitors.EXPRESSION_MAP_REPLACER, replaceMap);
+ ctxProjectProbeExpr = ctx.withNewProbeExpression(newProbeExpr);
+ }
+ return project.child().accept(this, ctxProjectProbeExpr);
+ }
+
+ @Override
+ public Boolean visitPhysicalSetOperation(
+ PhysicalSetOperation setOperation, PushDownContext ctx) {
+ boolean pushedDown = pushDownFilterToSetOperatorChild(setOperation,
ctx, 0);
+
+ if (setOperation instanceof Union) {
+ for (int i = 1; i < setOperation.children().size(); i++) {
+ // push down to the other children
+ pushedDown |= pushDownFilterToSetOperatorChild(setOperation,
ctx, i);
+ }
+ }
+ return pushedDown;
+ }
+
+ private Boolean pushDownFilterToSetOperatorChild(PhysicalSetOperation
setOperation,
+ PushDownContext ctx, int childIdx) {
+ Plan child = setOperation.child(childIdx);
+ Map<Expression, Expression> replaceMap = Maps.newHashMap();
+
+ List<NamedExpression> setOutputs = setOperation.getOutputs();
+ for (int i = 0; i < setOutputs.size(); i++) {
+ replaceMap.put(setOutputs.get(i).toSlot(),
+
setOperation.getRegularChildrenOutputs().get(childIdx).get(i));
+ }
+ if (!replaceMap.isEmpty()) {
+ Expression newProbeExpr =
ctx.probeExpr.accept(ExpressionVisitors.EXPRESSION_MAP_REPLACER,
+ replaceMap);
+ PushDownContext childPushDownContext =
ctx.withNewProbeExpression(newProbeExpr);
+ return child.accept(this, childPushDownContext);
+ }
+ return false;
+ }
+
+ @Override
+ public Boolean visitPhysicalCTEAnchor(PhysicalCTEAnchor<? extends Plan, ?
extends Plan> anchor,
+ PushDownContext ctx) {
+ return false;
+ }
+
+ @Override
+ public Boolean visitPhysicalCTEProducer(PhysicalCTEProducer<? extends
Plan> anchor,
+ PushDownContext ctx) {
+ return false;
+ }
+
+ @Override
+ public Boolean visitPhysicalTopN(PhysicalTopN<? extends Plan> topn,
PushDownContext ctx) {
+ if (topn.equals(ctx.topn)) {
+ return topn.child().accept(this, ctx);
+ }
+ return false;
+ }
+
+ @Override
+ public Boolean visitPhysicalWindow(PhysicalWindow<? extends Plan> window,
PushDownContext ctx) {
+ return false;
+ }
+
+ @Override
+ public Boolean visitPhysicalHashJoin(PhysicalHashJoin<? extends Plan, ?
extends Plan> join,
+ PushDownContext ctx) {
+ if (ctx.nullsFirst && join.getJoinType().isOuterJoin()) {
+ // topn-filter can be pushed down to the left child of
leftOuterJoin
+ // and to the right child of rightOuterJoin,
+ // but PushDownTopNThroughJoin rule already pushes topn to the
left and right side.
+ // the topn-filter will be generated according to the inferred
topn node.
+ return false;
+ }
+ if
(join.left().getOutputSet().containsAll(ctx.probeExpr.getInputSlots())) {
+ return join.left().accept(this, ctx);
+ }
+ if
(join.right().getOutputSet().containsAll(ctx.probeExpr.getInputSlots())) {
+ // expand expr to the other side of hash join condition:
+ // T1 join T2 on T1.x = T2.y order by T2.y limit 10
+ // we rewrite probeExpr from T2.y to T1.x and try to push T1.x to
left side
+ for (Expression conj : join.getHashJoinConjuncts()) {
+ if (ctx.probeExpr.equals(conj.child(1))) {
+ // push to left child. right child is blocking operator,
do not need topn-filter
+ PushDownContext newCtx =
ctx.withNewProbeExpression(conj.child(0));
+ return join.left().accept(this, newCtx);
+ }
+ }
+ }
+ // topn key is combination of left and right
+ // select * from (select T1.A+T2.B as x from T1 join T2) T3 order by x
limit 1;
+ return false;
+ }
+
+ @Override
+ public Boolean visitPhysicalNestedLoopJoin(PhysicalNestedLoopJoin<?
extends Plan, ? extends Plan> join,
+ PushDownContext ctx) {
+ if (ctx.nullsFirst && join.getJoinType().isOuterJoin()) {
+ // topn-filter can be pushed down to the left child of
leftOuterJoin
+ // and to the right child of rightOuterJoin,
+ // but PushDownTopNThroughJoin rule already pushes topn to the
left and right side.
+ // the topn-filter will be generated according to the inferred
topn node.
+ return false;
+ }
+ // push to left child. right child is blocking operator, do not need
topn-filter
+ if
(join.left().getOutputSet().containsAll(ctx.probeExpr.getInputSlots())) {
+ return join.left().accept(this, ctx);
+ }
+ return false;
+ }
+
+ @Override
+ public Boolean visitPhysicalRelation(PhysicalRelation relation,
PushDownContext ctx) {
+ if (supportPhysicalRelations(relation)
+ &&
relation.getOutputSet().containsAll(ctx.probeExpr.getInputSlots())) {
+ if
(relation.getOutputSet().containsAll(ctx.probeExpr.getInputSlots())) {
+ topnFilterContext.addTopnFilter(ctx.topn, relation,
ctx.probeExpr);
+ return true;
+ }
+ }
+ return false;
+ }
+
+ private boolean supportPhysicalRelations(PhysicalRelation relation) {
+ return relation instanceof PhysicalOlapScan
+ || relation instanceof PhysicalOdbcScan
+ || relation instanceof PhysicalEsScan
+ || relation instanceof PhysicalFileScan
+ || relation instanceof PhysicalJdbcScan
+ || relation instanceof PhysicalDeferMaterializeOlapScan;
+ }
+}
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/physical/TopnFilter.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/physical/TopnFilter.java
new file mode 100644
index 00000000000..6c8e2187b48
--- /dev/null
+++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/physical/TopnFilter.java
@@ -0,0 +1,78 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package org.apache.doris.nereids.trees.plans.physical;
+
+import org.apache.doris.analysis.Expr;
+import org.apache.doris.nereids.trees.expressions.Expression;
+import org.apache.doris.nereids.trees.plans.algebra.TopN;
+import org.apache.doris.planner.ScanNode;
+import org.apache.doris.planner.SortNode;
+import org.apache.doris.thrift.TTopnFilterDesc;
+
+import com.google.common.collect.Maps;
+
+import java.util.Map;
+
+/**
+ * topn filter
+ */
+public class TopnFilter {
+ public TopN topn;
+ public SortNode legacySortNode;
+ public Map<PhysicalRelation, Expression> targets = Maps.newHashMap();
+ public Map<ScanNode, Expr> legacyTargets = Maps.newHashMap();
+
+ public TopnFilter(TopN topn, PhysicalRelation rel, Expression expr) {
+ this.topn = topn;
+ targets.put(rel, expr);
+ }
+
+ public void addTarget(PhysicalRelation rel, Expression expr) {
+ targets.put(rel, expr);
+ }
+
+ public boolean hasTargetRelation(PhysicalRelation rel) {
+ return targets.containsKey(rel);
+ }
+
+ @Override
+ public String toString() {
+ StringBuilder builder = new StringBuilder();
+ builder.append(topn).append("->[ ");
+ for (PhysicalRelation rel : targets.keySet()) {
+
builder.append("(").append(rel).append(":").append(targets.get(rel)).append(")
");
+ }
+ builder.append("]");
+ return builder.toString();
+ }
+
+ /**
+ * to thrift
+ */
+ public TTopnFilterDesc toThrift() {
+ TTopnFilterDesc tFilter = new TTopnFilterDesc();
+ tFilter.setSourceNodeId(legacySortNode.getId().asInt());
+ tFilter.setIsAsc(topn.getOrderKeys().get(0).isAsc());
+ tFilter.setNullFirst(topn.getOrderKeys().get(0).isNullFirst());
+ for (ScanNode scan : legacyTargets.keySet()) {
+ tFilter.putToTargetNodeIdToTargetExpr(scan.getId().asInt(),
+ legacyTargets.get(scan).treeToThrift());
+ }
+ return tFilter;
+ }
+}
diff --git a/fe/fe-core/src/main/java/org/apache/doris/planner/Planner.java
b/fe/fe-core/src/main/java/org/apache/doris/planner/Planner.java
index e64a0c85818..ce47ab5ac67 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/planner/Planner.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/planner/Planner.java
@@ -24,6 +24,7 @@ import org.apache.doris.common.UserException;
import org.apache.doris.common.profile.PlanTreeBuilder;
import org.apache.doris.common.profile.PlanTreePrinter;
import org.apache.doris.nereids.PlannerHook;
+import org.apache.doris.nereids.trees.plans.physical.TopnFilter;
import org.apache.doris.qe.ResultSet;
import org.apache.doris.thrift.TQueryOptions;
@@ -131,4 +132,8 @@ public abstract class Planner {
public abstract Optional<ResultSet> handleQueryInFe(StatementBase
parsedStmt);
public abstract void addHook(PlannerHook hook);
+
+ public List<TopnFilter> getTopnFilters() {
+ return Lists.newArrayList();
+ }
}
diff --git a/fe/fe-core/src/main/java/org/apache/doris/qe/Coordinator.java
b/fe/fe-core/src/main/java/org/apache/doris/qe/Coordinator.java
index bab1c1decbb..ef071e78d39 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/qe/Coordinator.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/qe/Coordinator.java
@@ -42,6 +42,7 @@ import org.apache.doris.metric.MetricRepo;
import org.apache.doris.mysql.MysqlCommand;
import org.apache.doris.nereids.NereidsPlanner;
import org.apache.doris.nereids.stats.StatsErrorEstimator;
+import org.apache.doris.nereids.trees.plans.physical.TopnFilter;
import org.apache.doris.planner.DataPartition;
import org.apache.doris.planner.DataSink;
import org.apache.doris.planner.DataStreamSink;
@@ -63,7 +64,6 @@ import org.apache.doris.planner.RuntimeFilter;
import org.apache.doris.planner.RuntimeFilterId;
import org.apache.doris.planner.ScanNode;
import org.apache.doris.planner.SetOperationNode;
-import org.apache.doris.planner.SortNode;
import org.apache.doris.planner.UnionNode;
import org.apache.doris.proto.InternalService;
import org.apache.doris.proto.InternalService.PExecPlanFragmentResult;
@@ -117,6 +117,7 @@ import org.apache.doris.thrift.TScanRangeLocations;
import org.apache.doris.thrift.TScanRangeParams;
import org.apache.doris.thrift.TStatusCode;
import org.apache.doris.thrift.TTabletCommitInfo;
+import org.apache.doris.thrift.TTopnFilterDesc;
import org.apache.doris.thrift.TUniqueId;
import com.google.common.base.Preconditions;
@@ -276,6 +277,7 @@ public class Coordinator implements CoordInterface {
public Map<RuntimeFilterId, List<FRuntimeFilterTargetParam>>
ridToTargetParam = Maps.newHashMap();
// The runtime filter that expects the instance to be used
public List<RuntimeFilter> assignedRuntimeFilters = new ArrayList<>();
+ public List<TopnFilter> topnFilters = new ArrayList<>();
// Runtime filter ID to the builder instance number
public Map<RuntimeFilterId, Integer> ridToBuilderNum = Maps.newHashMap();
private ConnectContext context;
@@ -356,6 +358,7 @@ public class Coordinator implements CoordInterface {
nextInstanceId.setHi(queryId.hi);
nextInstanceId.setLo(queryId.lo + 1);
this.assignedRuntimeFilters = planner.getRuntimeFilters();
+ this.topnFilters = planner.getTopnFilters();
this.executionProfile = new ExecutionProfile(queryId, fragments);
}
@@ -3783,12 +3786,6 @@ public class Coordinator implements CoordInterface {
int rate =
Math.min(Config.query_colocate_join_memory_limit_penalty_factor,
instanceExecParams.size());
memLimit = queryOptions.getMemLimit() / rate;
}
- Set<Integer> topnFilterSources = Sets.newLinkedHashSet();
- for (ScanNode scanNode : scanNodes) {
- for (SortNode sortNode : scanNode.getTopnFilterSortNodes()) {
- topnFilterSources.add(sortNode.getId().asInt());
- }
- }
Map<TNetworkAddress, TPipelineFragmentParams> res = new HashMap();
Map<TNetworkAddress, Integer> instanceIdx = new HashMap();
@@ -3859,11 +3856,12 @@ public class Coordinator implements CoordInterface {
localParams.setBackendNum(backendNum++);
localParams.setRuntimeFilterParams(new TRuntimeFilterParams());
localParams.runtime_filter_params.setRuntimeFilterMergeAddr(runtimeFilterMergeAddr);
- if (!topnFilterSources.isEmpty()) {
- // topn_filter_source_node_ids is used by nereids not by
legacy planner.
- // if there is no topnFilterSources, do not set it.
- // topn_filter_source_node_ids=null means legacy planner
- localParams.topn_filter_source_node_ids =
Lists.newArrayList(topnFilterSources);
+ if (!topnFilters.isEmpty()) {
+ List<TTopnFilterDesc> filterDescs = new ArrayList<>();
+ for (TopnFilter filter : topnFilters) {
+ filterDescs.add(filter.toThrift());
+ }
+ localParams.setTopnFilterDescs(filterDescs);
}
if
(instanceExecParam.instanceId.equals(runtimeFilterMergeInstanceId)) {
Set<Integer> broadCastRf =
assignedRuntimeFilters.stream().filter(RuntimeFilter::isBroadcast)
diff --git a/gensrc/thrift/Exprs.thrift b/gensrc/thrift/Exprs.thrift
index 5d46e4b38db..3ef7c7acb07 100644
--- a/gensrc/thrift/Exprs.thrift
+++ b/gensrc/thrift/Exprs.thrift
@@ -253,7 +253,7 @@ struct TExprNode {
26: optional Types.TFunction fn
// If set, child[vararg_start_idx] is the first vararg child.
27: optional i32 vararg_start_idx
- 28: optional Types.TPrimitiveType child_type
+ 28: optional Types.TPrimitiveType child_type // Deprecated
// For vectorized engine
29: optional bool is_nullable
diff --git a/gensrc/thrift/PlanNodes.thrift b/gensrc/thrift/PlanNodes.thrift
index d9a24a2b8cc..43fcdf321a1 100644
--- a/gensrc/thrift/PlanNodes.thrift
+++ b/gensrc/thrift/PlanNodes.thrift
@@ -732,12 +732,12 @@ struct TOlapScanNode {
10: optional i64 sort_limit
11: optional bool enable_unique_key_merge_on_write
12: optional TPushAggOp push_down_agg_type_opt //Deprecated
- 13: optional bool use_topn_opt
+ 13: optional bool use_topn_opt // Deprecated
14: optional list<Descriptors.TOlapTableIndex> indexes_desc
15: optional set<i32> output_column_unique_ids
16: optional list<i32> distribute_column_ids
17: optional i32 schema_version
- 18: optional list<i32> topn_filter_source_node_ids
+    18: optional list<i32> topn_filter_source_node_ids // Deprecated, moved to TPlanNode.106
}
struct TEqJoinCondition {
@@ -935,7 +935,7 @@ struct TSortNode {
// Indicates whether the imposed limit comes DEFAULT_ORDER_BY_LIMIT.
6: optional bool is_default_limit
- 7: optional bool use_topn_opt
+ 7: optional bool use_topn_opt // Deprecated
8: optional bool merge_by_exchange
9: optional bool is_analytic_sort
10: optional bool is_colocate
@@ -1186,7 +1186,7 @@ struct TTopnFilterDesc {
2: required bool is_asc
3: required bool null_first
// scan node id -> expr on scan node
- 4: required map<Types.TPlanNodeId, Exprs.TExpr> targetNodeId_to_target_expr
+ 4: required map<Types.TPlanNodeId, Exprs.TExpr> target_node_id_to_target_expr
}
// Specification of a runtime filter.
diff --git a/regression-test/suites/nereids_tpch_p0/tpch/topn-filter.groovy
b/regression-test/suites/nereids_tpch_p0/tpch/topn-filter.groovy
index ab77d51b6e3..617794e1988 100644
--- a/regression-test/suites/nereids_tpch_p0/tpch/topn-filter.groovy
+++ b/regression-test/suites/nereids_tpch_p0/tpch/topn-filter.groovy
@@ -79,7 +79,7 @@ suite("topn-filter") {
qt_check_result2 "${multi_topn_desc}"
- // push down topn-filter to both join children
+ // push down topn-filter to join left child
explain {
sql """
select o_orderkey, c_custkey
@@ -87,7 +87,7 @@ suite("topn-filter") {
join customer on o_custkey = c_custkey
order by c_custkey limit 2;
"""
- contains "TOPN OPT:"
+ contains "TOPN OPT:4"
}
// push topn filter down through AGG
@@ -107,7 +107,7 @@ suite("topn-filter") {
join nation on s_nationkey = n_nationkey
order by s_nationkey limit 1;
"""
- contains "TOPN OPT:"
+ contains "TOPN OPT:7"
}
explain {
@@ -141,7 +141,7 @@ suite("topn-filter") {
// this topn-filter is generated by topn2, not topn1
explain {
sql "select * from nation left outer join region on r_regionkey =
n_regionkey order by n_regionkey nulls first limit 1; "
- contains "TOPN OPT:"
+ multiContains ("TOPN OPT:", 1)
}
// TODO: support latter, push topn to right outer join
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]