This is an automated email from the ASF dual-hosted git repository. dataroaring pushed a commit to branch branch-3.0 in repository https://gitbox.apache.org/repos/asf/doris.git
commit 632ac70bb8190bee98926d9d953f56419d5b39d4 Author: Pxl <[email protected]> AuthorDate: Mon May 27 11:07:46 2024 +0800 [Feature](topn) support general expr with topn filter and some refactor (#35405) support general expr with topn filter and some refactor Co-authored-by: minghong <[email protected]> --- be/src/olap/iterators.h | 3 +- be/src/olap/rowset/beta_rowset_reader.cpp | 2 +- be/src/olap/rowset/rowset_reader_context.h | 3 +- be/src/olap/rowset/segment_v2/segment.cpp | 9 +- be/src/olap/rowset/segment_v2/segment_iterator.cpp | 25 +-- be/src/olap/tablet_reader.cpp | 15 +- be/src/olap/tablet_reader.h | 3 +- be/src/pipeline/exec/file_scan_operator.cpp | 2 +- be/src/pipeline/exec/olap_scan_operator.h | 4 +- be/src/pipeline/exec/scan_operator.cpp | 39 ++-- be/src/pipeline/exec/scan_operator.h | 14 +- be/src/pipeline/exec/sort_sink_operator.cpp | 33 +-- be/src/pipeline/exec/sort_sink_operator.h | 1 - be/src/pipeline/pipeline_fragment_context.cpp | 6 +- be/src/runtime/query_context.h | 17 +- be/src/runtime/runtime_predicate.cpp | 125 ++++++++--- be/src/runtime/runtime_predicate.h | 137 ++++++------ be/src/runtime/runtime_state.cpp | 8 - be/src/vec/exec/format/orc/vorc_reader.cpp | 1 - .../exec/format/parquet/vparquet_group_reader.cpp | 1 - be/src/vec/exec/scan/new_olap_scanner.cpp | 6 +- be/src/vec/exec/vsort_node.cpp | 40 +--- be/src/vec/exec/vsort_node.h | 1 - be/src/vec/exprs/vtopn_pred.h | 20 +- be/src/vec/olap/vcollect_iterator.cpp | 2 +- .../org/apache/doris/nereids/NereidsPlanner.java | 6 + .../glue/translator/PhysicalPlanTranslator.java | 52 +---- .../doris/nereids/processor/post/TopNScanOpt.java | 110 ++++------ .../nereids/processor/post/TopnFilterContext.java | 84 ++++---- .../processor/post/TopnFilterPushDownVisitor.java | 233 +++++++++++++++++++++ .../nereids/trees/plans/physical/TopnFilter.java | 78 +++++++ .../java/org/apache/doris/planner/Planner.java | 5 + .../main/java/org/apache/doris/qe/Coordinator.java | 22 +- gensrc/thrift/Exprs.thrift | 2 +- 
gensrc/thrift/PlanNodes.thrift | 8 +- .../suites/nereids_tpch_p0/tpch/topn-filter.groovy | 8 +- 36 files changed, 695 insertions(+), 430 deletions(-) diff --git a/be/src/olap/iterators.h b/be/src/olap/iterators.h index deb14ff554f..330aa9e3475 100644 --- a/be/src/olap/iterators.h +++ b/be/src/olap/iterators.h @@ -99,9 +99,8 @@ public: TabletSchemaSPtr tablet_schema = nullptr; bool enable_unique_key_merge_on_write = false; bool record_rowids = false; - // flag for enable topn opt - bool use_topn_opt = false; std::vector<int> topn_filter_source_node_ids; + int topn_filter_target_node_id = -1; // used for special optimization for query : ORDER BY key DESC LIMIT n bool read_orderby_key_reverse = false; // columns for orderby keys diff --git a/be/src/olap/rowset/beta_rowset_reader.cpp b/be/src/olap/rowset/beta_rowset_reader.cpp index d8cced3b00a..e113d20ed88 100644 --- a/be/src/olap/rowset/beta_rowset_reader.cpp +++ b/be/src/olap/rowset/beta_rowset_reader.cpp @@ -220,8 +220,8 @@ Status BetaRowsetReader::get_segment_iterators(RowsetReaderContext* read_context _read_options.enable_unique_key_merge_on_write = _read_context->enable_unique_key_merge_on_write; _read_options.record_rowids = _read_context->record_rowids; - _read_options.use_topn_opt = _read_context->use_topn_opt; _read_options.topn_filter_source_node_ids = _read_context->topn_filter_source_node_ids; + _read_options.topn_filter_target_node_id = _read_context->topn_filter_target_node_id; _read_options.read_orderby_key_reverse = _read_context->read_orderby_key_reverse; _read_options.read_orderby_key_columns = _read_context->read_orderby_key_columns; _read_options.io_ctx.reader_type = _read_context->reader_type; diff --git a/be/src/olap/rowset/rowset_reader_context.h b/be/src/olap/rowset/rowset_reader_context.h index 44cf8556412..59abf85fb72 100644 --- a/be/src/olap/rowset/rowset_reader_context.h +++ b/be/src/olap/rowset/rowset_reader_context.h @@ -36,9 +36,8 @@ struct RowsetReaderContext { ReaderType reader_type = 
ReaderType::READER_QUERY; Version version {-1, -1}; TabletSchemaSPtr tablet_schema = nullptr; - // flag for enable topn opt - bool use_topn_opt = false; std::vector<int> topn_filter_source_node_ids; + int topn_filter_target_node_id = -1; // whether rowset should return ordered rows. bool need_ordered_result = true; // used for special optimization for query : ORDER BY key DESC LIMIT n diff --git a/be/src/olap/rowset/segment_v2/segment.cpp b/be/src/olap/rowset/segment_v2/segment.cpp index 8a898a5db4d..555c676ef31 100644 --- a/be/src/olap/rowset/segment_v2/segment.cpp +++ b/be/src/olap/rowset/segment_v2/segment.cpp @@ -164,13 +164,12 @@ Status Segment::new_iterator(SchemaSPtr schema, const StorageReadOptions& read_o return Status::OK(); } } - if (read_options.use_topn_opt) { + + if (!read_options.topn_filter_source_node_ids.empty()) { auto* query_ctx = read_options.runtime_state->get_query_ctx(); for (int id : read_options.topn_filter_source_node_ids) { - if (!query_ctx->get_runtime_predicate(id).need_update()) { - continue; - } - auto runtime_predicate = query_ctx->get_runtime_predicate(id).get_predicate(); + auto runtime_predicate = query_ctx->get_runtime_predicate(id).get_predicate( + read_options.topn_filter_target_node_id); int32_t uid = read_options.tablet_schema->column(runtime_predicate->column_id()).unique_id(); diff --git a/be/src/olap/rowset/segment_v2/segment_iterator.cpp b/be/src/olap/rowset/segment_v2/segment_iterator.cpp index 7c6853e9fb3..06f71c3cc11 100644 --- a/be/src/olap/rowset/segment_v2/segment_iterator.cpp +++ b/be/src/olap/rowset/segment_v2/segment_iterator.cpp @@ -523,7 +523,7 @@ Status SegmentIterator::_get_row_ranges_by_column_conditions() { RETURN_IF_ERROR(_apply_inverted_index()); if (!_row_bitmap.isEmpty() && - (_opts.use_topn_opt || !_opts.col_id_to_predicates.empty() || + (!_opts.topn_filter_source_node_ids.empty() || !_opts.col_id_to_predicates.empty() || _opts.delete_condition_predicates->num_of_column_predicate() > 0)) { RowRanges 
condition_row_ranges = RowRanges::create_single(_segment->num_rows()); RETURN_IF_ERROR(_get_row_ranges_from_conditions(&condition_row_ranges)); @@ -572,7 +572,7 @@ Status SegmentIterator::_get_row_ranges_from_conditions(RowRanges* condition_row SCOPED_RAW_TIMER(&_opts.stats->block_conditions_filtered_zonemap_ns); RowRanges zone_map_row_ranges = RowRanges::create_single(num_rows()); // second filter data by zone map - for (auto& cid : cids) { + for (const auto& cid : cids) { DCHECK(_opts.col_id_to_predicates.count(cid) > 0); if (!_segment->can_apply_predicate_safely(cid, _opts.col_id_to_predicates.at(cid).get(), *_schema, _opts.io_ctx.reader_type)) { @@ -600,16 +600,12 @@ Status SegmentIterator::_get_row_ranges_from_conditions(RowRanges* condition_row RowRanges::ranges_intersection(*condition_row_ranges, zone_map_row_ranges, condition_row_ranges); - if (_opts.use_topn_opt) { - SCOPED_RAW_TIMER(&_opts.stats->block_conditions_filtered_zonemap_ns); + if (!_opts.topn_filter_source_node_ids.empty()) { auto* query_ctx = _opts.runtime_state->get_query_ctx(); for (int id : _opts.topn_filter_source_node_ids) { - if (!query_ctx->get_runtime_predicate(id).need_update()) { - continue; - } - std::shared_ptr<doris::ColumnPredicate> runtime_predicate = - query_ctx->get_runtime_predicate(id).get_predicate(); + query_ctx->get_runtime_predicate(id).get_predicate( + _opts.topn_filter_target_node_id); if (_segment->can_apply_predicate_safely(runtime_predicate->column_id(), runtime_predicate.get(), *_schema, _opts.io_ctx.reader_type)) { @@ -1579,7 +1575,7 @@ Status SegmentIterator::_vec_init_lazy_materialization() { std::set<const ColumnPredicate*> delete_predicate_set {}; _opts.delete_condition_predicates->get_all_column_predicate(delete_predicate_set); - for (const auto predicate : delete_predicate_set) { + for (const auto* const predicate : delete_predicate_set) { if (PredicateTypeTraits::is_range(predicate->type())) { _delete_range_column_ids.push_back(predicate->column_id()); } 
else if (PredicateTypeTraits::is_bloom_filter(predicate->type())) { @@ -1593,16 +1589,13 @@ Status SegmentIterator::_vec_init_lazy_materialization() { // but runtime predicate will filter some rows and read more than N rows. // should add add for order by none-key column, since none-key column is not sorted and // all rows should be read, so runtime predicate will reduce rows for topn node - if (_opts.use_topn_opt && + if (!_opts.topn_filter_source_node_ids.empty() && (_opts.read_orderby_key_columns == nullptr || _opts.read_orderby_key_columns->empty())) { for (int id : _opts.topn_filter_source_node_ids) { - if (!_opts.runtime_state->get_query_ctx()->get_runtime_predicate(id).need_update()) { - continue; - } - auto& runtime_predicate = _opts.runtime_state->get_query_ctx()->get_runtime_predicate(id); - _col_predicates.push_back(runtime_predicate.get_predicate().get()); + _col_predicates.push_back( + runtime_predicate.get_predicate(_opts.topn_filter_target_node_id).get()); } } diff --git a/be/src/olap/tablet_reader.cpp b/be/src/olap/tablet_reader.cpp index ae304f6f3e9..728e6e3d6d7 100644 --- a/be/src/olap/tablet_reader.cpp +++ b/be/src/olap/tablet_reader.cpp @@ -233,14 +233,14 @@ Status TabletReader::_capture_rs_readers(const ReaderParams& read_params) { _reader_context.version = read_params.version; _reader_context.tablet_schema = _tablet_schema; _reader_context.need_ordered_result = need_ordered_result; - _reader_context.use_topn_opt = read_params.use_topn_opt; _reader_context.topn_filter_source_node_ids = read_params.topn_filter_source_node_ids; + _reader_context.topn_filter_target_node_id = read_params.topn_filter_target_node_id; _reader_context.read_orderby_key_reverse = read_params.read_orderby_key_reverse; _reader_context.read_orderby_key_limit = read_params.read_orderby_key_limit; _reader_context.filter_block_conjuncts = read_params.filter_block_conjuncts; _reader_context.return_columns = &_return_columns; _reader_context.read_orderby_key_columns = - 
_orderby_key_columns.size() > 0 ? &_orderby_key_columns : nullptr; + !_orderby_key_columns.empty() ? &_orderby_key_columns : nullptr; _reader_context.predicates = &_col_predicates; _reader_context.predicates_except_leafnode_of_andnode = &_col_preds_except_leafnode_of_andnode; _reader_context.value_predicates = &_value_col_predicates; @@ -575,12 +575,11 @@ Status TabletReader::_init_conditions_param_except_leafnode_of_andnode( } } - if (read_params.use_topn_opt) { - for (int id : read_params.topn_filter_source_node_ids) { - auto& runtime_predicate = - read_params.runtime_state->get_query_ctx()->get_runtime_predicate(id); - RETURN_IF_ERROR(runtime_predicate.set_tablet_schema(_tablet_schema)); - } + for (int id : read_params.topn_filter_source_node_ids) { + auto& runtime_predicate = + read_params.runtime_state->get_query_ctx()->get_runtime_predicate(id); + RETURN_IF_ERROR(runtime_predicate.set_tablet_schema(read_params.topn_filter_target_node_id, + _tablet_schema)); } return Status::OK(); } diff --git a/be/src/olap/tablet_reader.h b/be/src/olap/tablet_reader.h index 3bf83ec296c..a3cd3bd4a49 100644 --- a/be/src/olap/tablet_reader.h +++ b/be/src/olap/tablet_reader.h @@ -159,9 +159,8 @@ public: // used for compaction to record row ids bool record_rowids = false; - // flag for enable topn opt - bool use_topn_opt = false; std::vector<int> topn_filter_source_node_ids; + int topn_filter_target_node_id = -1; // used for special optimization for query : ORDER BY key LIMIT n bool read_orderby_key = false; // used for special optimization for query : ORDER BY key DESC LIMIT n diff --git a/be/src/pipeline/exec/file_scan_operator.cpp b/be/src/pipeline/exec/file_scan_operator.cpp index 2484737af89..ae7a4f06d53 100644 --- a/be/src/pipeline/exec/file_scan_operator.cpp +++ b/be/src/pipeline/exec/file_scan_operator.cpp @@ -47,7 +47,7 @@ Status FileScanLocalState::_init_scanners(std::list<vectorized::VScannerSPtr>* s state(), this, p._limit_per_scanner, _split_source, 
_scanner_profile.get(), _kv_cache.get()); RETURN_IF_ERROR( - scanner->prepare(_conjuncts, &_colname_to_value_range, &_colname_to_slot_id)); + scanner->prepare(_conjuncts, &_colname_to_value_range, &p._colname_to_slot_id)); scanners->push_back(std::move(scanner)); } return Status::OK(); diff --git a/be/src/pipeline/exec/olap_scan_operator.h b/be/src/pipeline/exec/olap_scan_operator.h index daff2167f7f..061e514b040 100644 --- a/be/src/pipeline/exec/olap_scan_operator.h +++ b/be/src/pipeline/exec/olap_scan_operator.h @@ -81,10 +81,10 @@ private: bool _storage_no_merge() override; bool _push_down_topn(const vectorized::RuntimePredicate& predicate) override { - if (!predicate.target_is_slot()) { + if (!predicate.target_is_slot(_parent->node_id())) { return false; } - return _is_key_column(predicate.get_col_name()) || _storage_no_merge(); + return _is_key_column(predicate.get_col_name(_parent->node_id())) || _storage_no_merge(); } Status _init_scanners(std::list<vectorized::VScannerSPtr>* scanners) override; diff --git a/be/src/pipeline/exec/scan_operator.cpp b/be/src/pipeline/exec/scan_operator.cpp index d0e83937fd9..be7af24b684 100644 --- a/be/src/pipeline/exec/scan_operator.cpp +++ b/be/src/pipeline/exec/scan_operator.cpp @@ -171,23 +171,20 @@ Status ScanLocalState<Derived>::_normalize_conjuncts(RuntimeState* state) { } }; - for (int slot_idx = 0; slot_idx < slots.size(); ++slot_idx) { - _colname_to_slot_id[slots[slot_idx]->col_name()] = slots[slot_idx]->id(); - _slot_id_to_slot_desc[slots[slot_idx]->id()] = slots[slot_idx]; - - auto type = slots[slot_idx]->type().type; - if (slots[slot_idx]->type().type == TYPE_ARRAY) { - type = slots[slot_idx]->type().children[0].type; + for (auto& slot : slots) { + auto type = slot->type().type; + if (slot->type().type == TYPE_ARRAY) { + type = slot->type().children[0].type; if (type == TYPE_ARRAY) { continue; } } - init_value_range(slots[slot_idx], slots[slot_idx]->type().type); + init_value_range(slot, slot->type().type); } 
get_cast_types_for_variants(); for (const auto& [colname, type] : _cast_types_for_variants) { - init_value_range(_slot_id_to_slot_desc[_colname_to_slot_id[colname]], type); + init_value_range(p._slot_id_to_slot_desc[p._colname_to_slot_id[colname]], type); } RETURN_IF_ERROR(_get_topn_filters(state)); @@ -1268,12 +1265,12 @@ Status ScanLocalState<Derived>::_init_profile() { template <typename Derived> Status ScanLocalState<Derived>::_get_topn_filters(RuntimeState* state) { + auto& p = _parent->cast<typename Derived::Parent>(); for (auto id : get_topn_filter_source_node_ids(state, false)) { const auto& pred = state->get_query_ctx()->get_runtime_predicate(id); - SlotDescriptor* slot_desc = _slot_id_to_slot_desc[_colname_to_slot_id[pred.get_col_name()]]; - vectorized::VExprSPtr topn_pred; - RETURN_IF_ERROR(vectorized::VTopNPred::create_vtopn_pred(slot_desc, id, topn_pred)); + RETURN_IF_ERROR(vectorized::VTopNPred::create_vtopn_pred(pred.get_texpr(p.node_id()), id, + topn_pred)); vectorized::VExprContextSPtr conjunct = vectorized::VExprContext::create_shared(topn_pred); RETURN_IF_ERROR(conjunct->prepare( @@ -1288,6 +1285,7 @@ template <typename Derived> void ScanLocalState<Derived>::_filter_and_collect_cast_type_for_variant( const vectorized::VExpr* expr, phmap::flat_hash_map<std::string, std::vector<PrimitiveType>>& colname_to_cast_types) { + auto& p = _parent->cast<typename Derived::Parent>(); const auto* cast_expr = dynamic_cast<const vectorized::VCastExpr*>(expr); if (cast_expr != nullptr) { const auto* src_slot = @@ -1298,7 +1296,7 @@ void ScanLocalState<Derived>::_filter_and_collect_cast_type_for_variant( return; } std::vector<SlotDescriptor*> slots = output_tuple_desc()->slots(); - SlotDescriptor* src_slot_desc = _slot_id_to_slot_desc[src_slot->slot_id()]; + SlotDescriptor* src_slot_desc = p._slot_id_to_slot_desc[src_slot->slot_id()]; PrimitiveType cast_dst_type = cast_expr->get_target_type()->get_type_as_type_descriptor().type; if 
(src_slot_desc->type().is_variant_type()) { @@ -1388,6 +1386,21 @@ Status ScanOperatorX<LocalStateType>::open(RuntimeState* state) { _output_tuple_desc = state->desc_tbl().get_tuple_descriptor(_output_tuple_id); RETURN_IF_ERROR(OperatorX<LocalStateType>::open(state)); + const auto slots = _output_tuple_desc->slots(); + for (auto* slot : slots) { + _colname_to_slot_id[slot->col_name()] = slot->id(); + _slot_id_to_slot_desc[slot->id()] = slot; + } + for (auto id : topn_filter_source_node_ids) { + if (!state->get_query_ctx()->has_runtime_predicate(id)) { + // compatible with older versions fe + continue; + } + + state->get_query_ctx()->get_runtime_predicate(id).init_target(node_id(), + _slot_id_to_slot_desc); + } + RETURN_IF_CANCELLED(state); return Status::OK(); } diff --git a/be/src/pipeline/exec/scan_operator.h b/be/src/pipeline/exec/scan_operator.h index 25a33de3f66..a1b85aac30a 100644 --- a/be/src/pipeline/exec/scan_operator.h +++ b/be/src/pipeline/exec/scan_operator.h @@ -161,8 +161,13 @@ class ScanLocalState : public ScanLocalStateBase { std::vector<int> get_topn_filter_source_node_ids(RuntimeState* state, bool push_down) { std::vector<int> result; for (int id : _parent->cast<typename Derived::Parent>().topn_filter_source_node_ids) { + if (!state->get_query_ctx()->has_runtime_predicate(id)) { + // compatible with older versions fe + continue; + } + const auto& pred = state->get_query_ctx()->get_runtime_predicate(id); - if (!pred.inited()) { + if (!pred.enable()) { continue; } if (_push_down_topn(pred) == push_down) { @@ -334,9 +339,6 @@ protected: // colname -> cast dst type std::map<std::string, PrimitiveType> _cast_types_for_variants; - // slot id -> SlotDescriptor - phmap::flat_hash_map<int, SlotDescriptor*> _slot_id_to_slot_desc; - // slot id -> ColumnValueRange // Parsed from conjuncts phmap::flat_hash_map<int, std::pair<SlotDescriptor*, ColumnValueRangeType>> @@ -345,7 +347,6 @@ protected: // We use _colname_to_value_range to store a column and its 
conresponding value ranges. std::unordered_map<std::string, ColumnValueRangeType> _colname_to_value_range; - std::unordered_map<std::string, int> _colname_to_slot_id; /** * _colname_to_value_range only store the leaf of and in the conjunct expr tree, * we use _compound_value_ranges to store conresponding value ranges @@ -426,6 +427,9 @@ protected: const TupleDescriptor* _input_tuple_desc = nullptr; const TupleDescriptor* _output_tuple_desc = nullptr; + phmap::flat_hash_map<int, SlotDescriptor*> _slot_id_to_slot_desc; + std::unordered_map<std::string, int> _colname_to_slot_id; + // These two values are from query_options int _max_scan_key_num; int _max_pushdown_conditions_per_column; diff --git a/be/src/pipeline/exec/sort_sink_operator.cpp b/be/src/pipeline/exec/sort_sink_operator.cpp index 7c9f40d1b59..201c31d353b 100644 --- a/be/src/pipeline/exec/sort_sink_operator.cpp +++ b/be/src/pipeline/exec/sort_sink_operator.cpp @@ -79,7 +79,6 @@ SortSinkOperatorX::SortSinkOperatorX(ObjectPool* pool, int operator_id, const TP _pool(pool), _reuse_mem(true), _limit(tnode.limit), - _use_topn_opt(tnode.sort_node.use_topn_opt), _row_descriptor(descs, tnode.row_tuples, tnode.nullable_tuples), _use_two_phase_read(tnode.sort_node.sort_info.use_two_phase_read), _merge_by_exchange(tnode.sort_node.merge_by_exchange), @@ -96,29 +95,10 @@ Status SortSinkOperatorX::init(const TPlanNode& tnode, RuntimeState* state) { _is_asc_order = tnode.sort_node.sort_info.is_asc_order; _nulls_first = tnode.sort_node.sort_info.nulls_first; + auto* query_ctx = state->get_query_ctx(); // init runtime predicate - if (_use_topn_opt) { - auto* query_ctx = state->get_query_ctx(); - auto first_sort_expr_node = tnode.sort_node.sort_info.ordering_exprs[0].nodes[0]; - if (first_sort_expr_node.node_type == TExprNodeType::SLOT_REF) { - auto first_sort_slot = first_sort_expr_node.slot_ref; - for (auto* tuple_desc : _row_descriptor.tuple_descriptors()) { - if (tuple_desc->id() != first_sort_slot.tuple_id) { - 
continue; - } - for (auto* slot : tuple_desc->slots()) { - if (slot->id() == first_sort_slot.slot_id) { - RETURN_IF_ERROR(query_ctx->get_runtime_predicate(_node_id).init( - slot->type().type, _nulls_first[0], _is_asc_order[0], - slot->col_name())); - break; - } - } - } - } - if (!query_ctx->get_runtime_predicate(_node_id).inited()) { - return Status::InternalError("runtime predicate is not properly initialized"); - } + if (query_ctx->has_runtime_predicate(_node_id)) { + query_ctx->get_runtime_predicate(_node_id).set_detected_source(); } return Status::OK(); } @@ -132,7 +112,8 @@ Status SortSinkOperatorX::prepare(RuntimeState* state) { // exclude cases which incoming blocks has string column which is sensitive to operations like // `filter` and `memcpy` if (_limit > 0 && _limit + _offset < vectorized::HeapSorter::HEAP_SORT_THRESHOLD && - (_use_two_phase_read || _use_topn_opt || !row_desc.has_varlen_slots())) { + (_use_two_phase_read || state->get_query_ctx()->has_runtime_predicate(_node_id) || + !row_desc.has_varlen_slots())) { _algorithm = SortAlgorithm::HEAP_SORT; _reuse_mem = false; } else if (_limit > 0 && row_desc.has_varlen_slots() && @@ -158,9 +139,9 @@ Status SortSinkOperatorX::sink(doris::RuntimeState* state, vectorized::Block* in local_state._mem_tracker->set_consumption(local_state._shared_state->sorter->data_size()); RETURN_IF_CANCELLED(state); - if (_use_topn_opt) { + if (state->get_query_ctx()->has_runtime_predicate(_node_id)) { auto& predicate = state->get_query_ctx()->get_runtime_predicate(_node_id); - if (predicate.inited()) { + if (predicate.enable()) { vectorized::Field new_top = local_state._shared_state->sorter->get_top_value(); if (!new_top.is_null() && new_top != local_state.old_top) { auto* query_ctx = state->get_query_ctx(); diff --git a/be/src/pipeline/exec/sort_sink_operator.h b/be/src/pipeline/exec/sort_sink_operator.h index 6ade3ada8bd..ba279a4aac4 100644 --- a/be/src/pipeline/exec/sort_sink_operator.h +++ 
b/be/src/pipeline/exec/sort_sink_operator.h @@ -100,7 +100,6 @@ private: bool _reuse_mem; const int64_t _limit; - const bool _use_topn_opt; SortAlgorithm _algorithm; const RowDescriptor _row_descriptor; diff --git a/be/src/pipeline/pipeline_fragment_context.cpp b/be/src/pipeline/pipeline_fragment_context.cpp index 8db4b1be587..fd4c903d5aa 100644 --- a/be/src/pipeline/pipeline_fragment_context.cpp +++ b/be/src/pipeline/pipeline_fragment_context.cpp @@ -291,10 +291,8 @@ Status PipelineFragmentContext::prepare(const doris::TPipelineFragmentParams& re _query_ctx->runtime_filter_mgr()->set_runtime_filter_params( local_params.runtime_filter_params); } - if (local_params.__isset.topn_filter_source_node_ids) { - _query_ctx->init_runtime_predicates(local_params.topn_filter_source_node_ids); - } else { - _query_ctx->init_runtime_predicates({0}); + if (local_params.__isset.topn_filter_descs) { + _query_ctx->init_runtime_predicates(local_params.topn_filter_descs); } _need_local_merge = request.__isset.parallel_instances; diff --git a/be/src/runtime/query_context.h b/be/src/runtime/query_context.h index 6fd35e00fd2..ff5d1a549ea 100644 --- a/be/src/runtime/query_context.h +++ b/be/src/runtime/query_context.h @@ -135,17 +135,18 @@ public: return _shared_scanner_controller; } + bool has_runtime_predicate(int source_node_id) { + return _runtime_predicates.contains(source_node_id); + } + vectorized::RuntimePredicate& get_runtime_predicate(int source_node_id) { - DCHECK(_runtime_predicates.contains(source_node_id) || _runtime_predicates.contains(0)); - if (_runtime_predicates.contains(source_node_id)) { - return _runtime_predicates[source_node_id]; - } - return _runtime_predicates[0]; + DCHECK(has_runtime_predicate(source_node_id)); + return _runtime_predicates.find(source_node_id)->second; } - void init_runtime_predicates(std::vector<int> source_node_ids) { - for (int id : source_node_ids) { - _runtime_predicates.try_emplace(id); + void init_runtime_predicates(const 
std::vector<TTopnFilterDesc>& topn_filter_descs) { + for (auto desc : topn_filter_descs) { + _runtime_predicates.try_emplace(desc.source_node_id, desc); } } diff --git a/be/src/runtime/runtime_predicate.cpp b/be/src/runtime/runtime_predicate.cpp index 2655ff86680..f90a5743fdd 100644 --- a/be/src/runtime/runtime_predicate.cpp +++ b/be/src/runtime/runtime_predicate.cpp @@ -17,35 +17,99 @@ #include "runtime/runtime_predicate.h" -#include <stdint.h> - #include <memory> #include "common/compiler_util.h" // IWYU pragma: keep +#include "common/exception.h" +#include "common/status.h" #include "olap/accept_null_predicate.h" #include "olap/column_predicate.h" #include "olap/predicate_creator.h" namespace doris::vectorized { -Status RuntimePredicate::init(PrimitiveType type, bool nulls_first, bool is_asc, - const std::string& col_name) { - std::unique_lock<std::shared_mutex> wlock(_rwlock); +RuntimePredicate::RuntimePredicate(const TTopnFilterDesc& desc) + : _nulls_first(desc.null_first), _is_asc(desc.is_asc) { + DCHECK(!desc.target_node_id_to_target_expr.empty()); + for (auto p : desc.target_node_id_to_target_expr) { + _contexts[p.first].expr = p.second; + } - if (_inited) { - return Status::OK(); + PrimitiveType type = thrift_to_type(desc.target_node_id_to_target_expr.begin() + ->second.nodes[0] + .type.types[0] + .scalar_type.type); + if (!_init(type)) { + std::stringstream ss; + desc.target_node_id_to_target_expr.begin()->second.nodes[0].printTo(ss); + throw Exception(ErrorCode::INTERNAL_ERROR, "meet invalid type, type={}, expr={}", int(type), + ss.str()); } - _nulls_first = nulls_first; - _is_asc = is_asc; // For ASC sort, create runtime predicate col_name <= max_top_value // since values that > min_top_value are large than any value in current topn values // For DESC sort, create runtime predicate col_name >= min_top_value // since values that < min_top_value are less than any value in current topn values - _pred_constructor = is_asc ? 
create_comparison_predicate<PredicateType::LE> - : create_comparison_predicate<PredicateType::GE>; - _col_name = col_name; + _pred_constructor = _is_asc ? create_comparison_predicate<PredicateType::LE> + : create_comparison_predicate<PredicateType::GE>; +} + +void RuntimePredicate::init_target( + int32_t target_node_id, phmap::flat_hash_map<int, SlotDescriptor*> slot_id_to_slot_desc) { + std::unique_lock<std::shared_mutex> wlock(_rwlock); + check_target_node_id(target_node_id); + if (target_is_slot(target_node_id)) { + _contexts[target_node_id].col_name = + slot_id_to_slot_desc[get_texpr(target_node_id).nodes[0].slot_ref.slot_id] + ->col_name(); + } + _detected_target = true; +} + +template <PrimitiveType type> +std::string get_normal_value(const Field& field) { + using ValueType = typename PrimitiveTypeTraits<type>::CppType; + return cast_to_string<type, ValueType>(field.get<ValueType>(), 0); +} + +std::string get_date_value(const Field& field) { + using ValueType = typename PrimitiveTypeTraits<TYPE_DATE>::CppType; + ValueType value; + Int64 v = field.get<Int64>(); + auto* p = (VecDateTimeValue*)&v; + value.from_olap_date(p->to_olap_date()); + value.cast_to_date(); + return cast_to_string<TYPE_DATE, ValueType>(value, 0); +} +std::string get_datetime_value(const Field& field) { + using ValueType = typename PrimitiveTypeTraits<TYPE_DATETIME>::CppType; + ValueType value; + Int64 v = field.get<Int64>(); + auto* p = (VecDateTimeValue*)&v; + value.from_olap_datetime(p->to_olap_datetime()); + value.to_datetime(); + return cast_to_string<TYPE_DATETIME, ValueType>(value, 0); +} + +std::string get_decimalv2_value(const Field& field) { + // can NOT use PrimitiveTypeTraits<TYPE_DECIMALV2>::CppType since + // it is DecimalV2Value and Decimal128V2 can not convert to it implicitly + using ValueType = Decimal128V2::NativeType; + auto v = field.get<DecimalField<Decimal128V2>>(); + // use TYPE_DECIMAL128I instead of TYPE_DECIMALV2 since v.get_scale() + // is always 9 for DECIMALV2 
+ return cast_to_string<TYPE_DECIMAL128I, ValueType>(v.get_value(), v.get_scale()); +} + +template <PrimitiveType type> +std::string get_decimal_value(const Field& field) { + using ValueType = typename PrimitiveTypeTraits<type>::CppType; + auto v = field.get<DecimalField<ValueType>>(); + return cast_to_string<type, ValueType>(v.get_value(), v.get_scale()); +} + +bool RuntimePredicate::_init(PrimitiveType type) { // set get value function switch (type) { case PrimitiveType::TYPE_BOOLEAN: { @@ -123,17 +187,16 @@ Status RuntimePredicate::init(PrimitiveType type, bool nulls_first, bool is_asc, break; } default: - return Status::InvalidArgument("unsupported runtime predicate type {}", type); + return false; } - _inited = true; - return Status::OK(); + return true; } Status RuntimePredicate::update(const Field& value) { std::unique_lock<std::shared_mutex> wlock(_rwlock); // skip null value - if (value.is_null() || !_inited) { + if (value.is_null()) { return Status::OK(); } @@ -151,22 +214,28 @@ Status RuntimePredicate::update(const Field& value) { _has_value = true; - if (!updated || !_tablet_schema) { + if (!updated) { return Status::OK(); } - std::unique_ptr<ColumnPredicate> pred { - _pred_constructor(_tablet_schema->column(_col_name), _predicate->column_id(), - _get_value_fn(_orderby_extrem), false, &_predicate_arena)}; - // For NULLS FIRST, wrap a AcceptNullPredicate to return true for NULL - // since ORDER BY ASC/DESC should get NULL first but pred returns NULL - // and NULL in where predicate will be treated as FALSE - if (_nulls_first) { - pred = AcceptNullPredicate::create_unique(pred.release()); - } - - ((SharedPredicate*)_predicate.get())->set_nested(pred.release()); + for (auto p : _contexts) { + auto ctx = p.second; + if (!ctx.tablet_schema) { + continue; + } + std::unique_ptr<ColumnPredicate> pred {_pred_constructor( + ctx.tablet_schema->column(ctx.col_name), ctx.predicate->column_id(), + _get_value_fn(_orderby_extrem), false, &_predicate_arena)}; + + // For 
NULLS FIRST, wrap a AcceptNullPredicate to return true for NULL + // since ORDER BY ASC/DESC should get NULL first but pred returns NULL + // and NULL in where predicate will be treated as FALSE + if (_nulls_first) { + pred = AcceptNullPredicate::create_unique(pred.release()); + } + ((SharedPredicate*)ctx.predicate.get())->set_nested(pred.release()); + } return Status::OK(); } diff --git a/be/src/runtime/runtime_predicate.h b/be/src/runtime/runtime_predicate.h index 0305994e0fc..73ed657c0bb 100644 --- a/be/src/runtime/runtime_predicate.h +++ b/be/src/runtime/runtime_predicate.h @@ -42,36 +42,39 @@ namespace vectorized { class RuntimePredicate { public: - RuntimePredicate() = default; + RuntimePredicate(const TTopnFilterDesc& desc); - Status init(PrimitiveType type, bool nulls_first, bool is_asc, const std::string& col_name); + void init_target(int32_t target_node_id, + phmap::flat_hash_map<int, SlotDescriptor*> slot_id_to_slot_desc); - bool inited() const { - // when sort node and scan node are not in the same fragment, predicate will not be initialized + bool enable() const { + // when sort node and scan node are not in the same fragment, predicate will be disabled std::shared_lock<std::shared_mutex> rlock(_rwlock); - return _inited; + return _detected_source && _detected_target; } - bool need_update() const { - std::shared_lock<std::shared_mutex> rlock(_rwlock); - return _inited && _tablet_schema; + void set_detected_source() { + std::unique_lock<std::shared_mutex> wlock(_rwlock); + _detected_source = true; } - Status set_tablet_schema(TabletSchemaSPtr tablet_schema) { + Status set_tablet_schema(int32_t target_node_id, TabletSchemaSPtr tablet_schema) { std::unique_lock<std::shared_mutex> wlock(_rwlock); - if (_tablet_schema || !_inited) { + check_target_node_id(target_node_id); + if (_contexts[target_node_id].tablet_schema) { return Status::OK(); } - RETURN_IF_ERROR(tablet_schema->have_column(_col_name)); - _tablet_schema = tablet_schema; - _predicate = 
SharedPredicate::create_shared( - _tablet_schema->field_index(_tablet_schema->column(_col_name).unique_id())); + RETURN_IF_ERROR(tablet_schema->have_column(_contexts[target_node_id].col_name)); + _contexts[target_node_id].tablet_schema = tablet_schema; + _contexts[target_node_id].predicate = + SharedPredicate::create_shared(_contexts[target_node_id].get_field_index()); return Status::OK(); } - std::shared_ptr<ColumnPredicate> get_predicate() { + std::shared_ptr<ColumnPredicate> get_predicate(int32_t target_node_id) { std::shared_lock<std::shared_mutex> rlock(_rwlock); - return _predicate; + check_target_node_id(target_node_id); + return _contexts.find(target_node_id)->second.predicate; } Status update(const Field& value); @@ -86,72 +89,72 @@ public: return _orderby_extrem; } - std::string get_col_name() const { return _col_name; } + std::string get_col_name(int32_t target_node_id) const { + check_target_node_id(target_node_id); + return _contexts.find(target_node_id)->second.col_name; + } bool is_asc() const { return _is_asc; } bool nulls_first() const { return _nulls_first; } - bool target_is_slot() const { return true; } + bool target_is_slot(int32_t target_node_id) const { + check_target_node_id(target_node_id); + return _contexts.find(target_node_id)->second.target_is_slot(); + } + + const TExpr& get_texpr(int32_t target_node_id) const { + check_target_node_id(target_node_id); + return _contexts.find(target_node_id)->second.expr; + } private: + void check_target_node_id(int32_t target_node_id) const { + if (!_contexts.contains(target_node_id)) { + std::string msg = "context target node ids: ["; + bool first = true; + for (auto p : _contexts) { + if (first) { + first = false; + } else { + msg += ','; + } + msg += std::to_string(p.first); + } + msg += "], input target node is: " + std::to_string(target_node_id); + DCHECK(false) << msg; + } + } + struct TargetContext { + TExpr expr; + std::string col_name; + TabletSchemaSPtr tablet_schema; + 
std::shared_ptr<ColumnPredicate> predicate; + + int32_t get_field_index() { + return tablet_schema->field_index(tablet_schema->column(col_name).unique_id()); + } + + bool target_is_slot() const { return expr.nodes[0].node_type == TExprNodeType::SLOT_REF; } + }; + + bool _init(PrimitiveType type); + mutable std::shared_mutex _rwlock; + + bool _nulls_first; + bool _is_asc; + std::map<int32_t, TargetContext> _contexts; + Field _orderby_extrem {Field::Types::Null}; - std::shared_ptr<ColumnPredicate> _predicate; - TabletSchemaSPtr _tablet_schema = nullptr; Arena _predicate_arena; std::function<std::string(const Field&)> _get_value_fn; - bool _nulls_first = true; - bool _is_asc; std::function<ColumnPredicate*(const TabletColumn&, int, const std::string&, bool, vectorized::Arena*)> _pred_constructor; - bool _inited = false; - std::string _col_name; + bool _detected_source = false; + bool _detected_target = false; bool _has_value = false; - - template <PrimitiveType type> - static std::string get_normal_value(const Field& field) { - using ValueType = typename PrimitiveTypeTraits<type>::CppType; - return cast_to_string<type, ValueType>(field.get<ValueType>(), 0); - } - - static std::string get_date_value(const Field& field) { - using ValueType = typename PrimitiveTypeTraits<TYPE_DATE>::CppType; - ValueType value; - Int64 v = field.get<Int64>(); - auto* p = (VecDateTimeValue*)&v; - value.from_olap_date(p->to_olap_date()); - value.cast_to_date(); - return cast_to_string<TYPE_DATE, ValueType>(value, 0); - } - - static std::string get_datetime_value(const Field& field) { - using ValueType = typename PrimitiveTypeTraits<TYPE_DATETIME>::CppType; - ValueType value; - Int64 v = field.get<Int64>(); - auto* p = (VecDateTimeValue*)&v; - value.from_olap_datetime(p->to_olap_datetime()); - value.to_datetime(); - return cast_to_string<TYPE_DATETIME, ValueType>(value, 0); - } - - static std::string get_decimalv2_value(const Field& field) { - // can NOT use 
PrimitiveTypeTraits<TYPE_DECIMALV2>::CppType since - // it is DecimalV2Value and Decimal128V2 can not convert to it implicitly - using ValueType = Decimal128V2::NativeType; - auto v = field.get<DecimalField<Decimal128V2>>(); - // use TYPE_DECIMAL128I instead of TYPE_DECIMALV2 since v.get_scale() - // is always 9 for DECIMALV2 - return cast_to_string<TYPE_DECIMAL128I, ValueType>(v.get_value(), v.get_scale()); - } - - template <PrimitiveType type> - static std::string get_decimal_value(const Field& field) { - using ValueType = typename PrimitiveTypeTraits<type>::CppType; - auto v = field.get<DecimalField<ValueType>>(); - return cast_to_string<type, ValueType>(v.get_value(), v.get_scale()); - } }; } // namespace vectorized diff --git a/be/src/runtime/runtime_state.cpp b/be/src/runtime/runtime_state.cpp index ac560c2c7e1..504f6deabf3 100644 --- a/be/src/runtime/runtime_state.cpp +++ b/be/src/runtime/runtime_state.cpp @@ -93,14 +93,6 @@ RuntimeState::RuntimeState(const TPlanFragmentExecParams& fragment_exec_params, _query_ctx->runtime_filter_mgr()->set_runtime_filter_params( fragment_exec_params.runtime_filter_params); } - - if (_query_ctx) { - if (fragment_exec_params.__isset.topn_filter_source_node_ids) { - _query_ctx->init_runtime_predicates(fragment_exec_params.topn_filter_source_node_ids); - } else { - _query_ctx->init_runtime_predicates({0}); - } - } } RuntimeState::RuntimeState(const TUniqueId& instance_id, const TUniqueId& query_id, diff --git a/be/src/vec/exec/format/orc/vorc_reader.cpp b/be/src/vec/exec/format/orc/vorc_reader.cpp index f6a410cd81f..5c6a015d05c 100644 --- a/be/src/vec/exec/format/orc/vorc_reader.cpp +++ b/be/src/vec/exec/format/orc/vorc_reader.cpp @@ -2186,7 +2186,6 @@ Status OrcReader::_rewrite_dict_conjuncts(std::vector<int32_t>& dict_codes, int texpr_node.__set_node_type(TExprNodeType::BINARY_PRED); texpr_node.__set_opcode(TExprOpcode::EQ); texpr_node.__set_fn(fn); - texpr_node.__set_child_type(TPrimitiveType::INT); 
texpr_node.__set_num_children(2); texpr_node.__set_is_nullable(is_nullable); root = VectorizedFnCall::create_shared(texpr_node); diff --git a/be/src/vec/exec/format/parquet/vparquet_group_reader.cpp b/be/src/vec/exec/format/parquet/vparquet_group_reader.cpp index 807f016cb43..426810ccbfc 100644 --- a/be/src/vec/exec/format/parquet/vparquet_group_reader.cpp +++ b/be/src/vec/exec/format/parquet/vparquet_group_reader.cpp @@ -911,7 +911,6 @@ Status RowGroupReader::_rewrite_dict_conjuncts(std::vector<int32_t>& dict_codes, texpr_node.__set_node_type(TExprNodeType::BINARY_PRED); texpr_node.__set_opcode(TExprOpcode::EQ); texpr_node.__set_fn(fn); - texpr_node.__set_child_type(TPrimitiveType::INT); texpr_node.__set_num_children(2); texpr_node.__set_is_nullable(is_nullable); root = VectorizedFnCall::create_shared(texpr_node); diff --git a/be/src/vec/exec/scan/new_olap_scanner.cpp b/be/src/vec/exec/scan/new_olap_scanner.cpp index 70fafade3b7..4507afc2cb2 100644 --- a/be/src/vec/exec/scan/new_olap_scanner.cpp +++ b/be/src/vec/exec/scan/new_olap_scanner.cpp @@ -406,8 +406,10 @@ Status NewOlapScanner::_init_tablet_reader_params( _tablet_reader_params.topn_filter_source_node_ids = ((pipeline::OlapScanLocalState*)_local_state) ->get_topn_filter_source_node_ids(_state, true); - _tablet_reader_params.use_topn_opt = - !_tablet_reader_params.topn_filter_source_node_ids.empty(); + if (!_tablet_reader_params.topn_filter_source_node_ids.empty()) { + _tablet_reader_params.topn_filter_target_node_id = + ((pipeline::OlapScanLocalState*)_local_state)->parent()->node_id(); + } } } diff --git a/be/src/vec/exec/vsort_node.cpp b/be/src/vec/exec/vsort_node.cpp index 160690f7737..bd86776d50e 100644 --- a/be/src/vec/exec/vsort_node.cpp +++ b/be/src/vec/exec/vsort_node.cpp @@ -64,8 +64,7 @@ Status VSortNode::init(const TPlanNode& tnode, RuntimeState* state) { // exclude cases which incoming blocks has string column which is sensitive to operations like // `filter` and `memcpy` if (_limit > 0 && 
_limit + _offset < HeapSorter::HEAP_SORT_THRESHOLD && - (tnode.sort_node.sort_info.use_two_phase_read || tnode.sort_node.use_topn_opt || - !row_desc.has_varlen_slots())) { + (tnode.sort_node.sort_info.use_two_phase_read || !row_desc.has_varlen_slots())) { _sorter = HeapSorter::create_unique(_vsort_exec_exprs, _limit, _offset, _pool, _is_asc_order, _nulls_first, row_desc); _reuse_mem = false; @@ -79,31 +78,6 @@ Status VSortNode::init(const TPlanNode& tnode, RuntimeState* state) { FullSorter::create_unique(_vsort_exec_exprs, _limit, _offset, _pool, _is_asc_order, _nulls_first, row_desc, state, _runtime_profile.get()); } - // init runtime predicate - _use_topn_opt = tnode.sort_node.use_topn_opt; - if (_use_topn_opt) { - auto* query_ctx = state->get_query_ctx(); - auto first_sort_expr_node = tnode.sort_node.sort_info.ordering_exprs[0].nodes[0]; - if (first_sort_expr_node.node_type == TExprNodeType::SLOT_REF) { - auto first_sort_slot = first_sort_expr_node.slot_ref; - for (auto* tuple_desc : this->intermediate_row_desc().tuple_descriptors()) { - if (tuple_desc->id() != first_sort_slot.tuple_id) { - continue; - } - for (auto* slot : tuple_desc->slots()) { - if (slot->id() == first_sort_slot.slot_id) { - RETURN_IF_ERROR(query_ctx->get_runtime_predicate(_id).init( - slot->type().type, _nulls_first[0], _is_asc_order[0], - slot->col_name())); - break; - } - } - } - } - if (!query_ctx->get_runtime_predicate(_id).inited()) { - return Status::InternalError("runtime predicate is not properly initialized"); - } - } _sorter->init_profile(_runtime_profile.get()); return Status::OK(); @@ -140,18 +114,6 @@ Status VSortNode::sink(RuntimeState* state, vectorized::Block* input_block, bool if (input_block->rows() > 0) { RETURN_IF_ERROR(_sorter->append_block(input_block)); RETURN_IF_CANCELLED(state); - - if (_use_topn_opt) { - auto& predicate = state->get_query_ctx()->get_runtime_predicate(_id); - if (predicate.need_update()) { - vectorized::Field new_top = _sorter->get_top_value(); - if 
(!new_top.is_null() && new_top != old_top) { - auto* query_ctx = state->get_query_ctx(); - RETURN_IF_ERROR(query_ctx->get_runtime_predicate(_id).update(new_top)); - old_top = std::move(new_top); - } - } - } if (!_reuse_mem) { input_block->clear(); } diff --git a/be/src/vec/exec/vsort_node.h b/be/src/vec/exec/vsort_node.h index 5b13eb8fa93..4f009fc5c0d 100644 --- a/be/src/vec/exec/vsort_node.h +++ b/be/src/vec/exec/vsort_node.h @@ -89,7 +89,6 @@ private: RuntimeProfile::Counter* _memory_usage_counter = nullptr; RuntimeProfile::Counter* _sort_blocks_memory_usage = nullptr; - bool _use_topn_opt = false; // topn top value Field old_top {Field::Types::Null}; diff --git a/be/src/vec/exprs/vtopn_pred.h b/be/src/vec/exprs/vtopn_pred.h index 326ceaf0f2c..675c8fb293c 100644 --- a/be/src/vec/exprs/vtopn_pred.h +++ b/be/src/vec/exprs/vtopn_pred.h @@ -19,6 +19,8 @@ #include <gen_cpp/types.pb.h> +#include <utility> + #include "runtime/query_context.h" #include "runtime/runtime_predicate.h" #include "runtime/runtime_state.h" @@ -36,21 +38,24 @@ class VTopNPred : public VExpr { ENABLE_FACTORY_CREATOR(VTopNPred); public: - VTopNPred(const TExprNode& node, int source_node_id) + VTopNPred(const TExprNode& node, int source_node_id, VExprContextSPtr target_ctx) : VExpr(node), _source_node_id(source_node_id), - _expr_name(fmt::format("VTopNPred(source_node_id={})", _source_node_id)) {} + _expr_name(fmt::format("VTopNPred(source_node_id={})", _source_node_id)), + _target_ctx(std::move(target_ctx)) {} - // TODO: support general expr - static Status create_vtopn_pred(SlotDescriptor* slot_desc, int source_node_id, + static Status create_vtopn_pred(const TExpr& target_expr, int source_node_id, vectorized::VExprSPtr& expr) { + vectorized::VExprContextSPtr target_ctx; + RETURN_IF_ERROR(vectorized::VExpr::create_expr_tree(target_expr, target_ctx)); + TExprNode node; node.__set_node_type(TExprNodeType::FUNCTION_CALL); node.__set_type(create_type_desc(PrimitiveType::TYPE_BOOLEAN)); - 
node.__set_is_nullable(slot_desc->is_nullable()); - expr = vectorized::VTopNPred::create_shared(node, source_node_id); + node.__set_is_nullable(target_ctx->root()->is_nullable()); + expr = vectorized::VTopNPred::create_shared(node, source_node_id, target_ctx); - expr->add_child(VSlotRef::create_shared(slot_desc)); + expr->add_child(target_ctx->root()); return Status::OK(); } @@ -112,5 +117,6 @@ private: std::string _expr_name; RuntimePredicate* _predicate = nullptr; FunctionBasePtr _function; + VExprContextSPtr _target_ctx; }; } // namespace doris::vectorized diff --git a/be/src/vec/olap/vcollect_iterator.cpp b/be/src/vec/olap/vcollect_iterator.cpp index 3ce1869546c..61050979b84 100644 --- a/be/src/vec/olap/vcollect_iterator.cpp +++ b/be/src/vec/olap/vcollect_iterator.cpp @@ -414,7 +414,7 @@ Status VCollectIterator::_topn_next(Block* block) { } // update runtime_predicate - if (_reader->_reader_context.use_topn_opt && changed && + if (!_reader->_reader_context.topn_filter_source_node_ids.empty() && changed && sorted_row_pos.size() >= _topn_limit) { // get field value from column size_t last_sorted_row = *sorted_row_pos.rbegin(); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/NereidsPlanner.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/NereidsPlanner.java index ec54cc4cc4c..546e73e990d 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/NereidsPlanner.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/NereidsPlanner.java @@ -60,6 +60,7 @@ import org.apache.doris.nereids.trees.plans.physical.PhysicalOneRowRelation; import org.apache.doris.nereids.trees.plans.physical.PhysicalPlan; import org.apache.doris.nereids.trees.plans.physical.PhysicalResultSink; import org.apache.doris.nereids.trees.plans.physical.PhysicalSqlCache; +import org.apache.doris.nereids.trees.plans.physical.TopnFilter; import org.apache.doris.planner.PlanFragment; import org.apache.doris.planner.Planner; import org.apache.doris.planner.RuntimeFilter; 
@@ -716,4 +717,9 @@ public class NereidsPlanner extends Planner { task.run(); } } + + @Override + public List<TopnFilter> getTopnFilters() { + return cascadesContext.getTopnFilterContext().getTopnFilters(); + } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/PhysicalPlanTranslator.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/PhysicalPlanTranslator.java index e40eead5340..ca1d2a53328 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/PhysicalPlanTranslator.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/PhysicalPlanTranslator.java @@ -215,7 +215,6 @@ import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.Objects; -import java.util.Optional; import java.util.Set; import java.util.TreeMap; import java.util.stream.Collectors; @@ -620,10 +619,7 @@ public class PhysicalPlanTranslator extends DefaultPlanVisitor<PlanFragment, Pla expr -> runtimeFilterGenerator.translateRuntimeFilterTarget(expr, finalScanNode, context) ) ); - if (context.getTopnFilterContext().isTopnFilterTarget(fileScan)) { - context.getTopnFilterContext().addLegacyTarget(fileScan, scanNode); - } - + context.getTopnFilterContext().translateTarget(fileScan, scanNode, context); Utils.execWithUncheckedException(scanNode::finalizeForNereids); // Create PlanFragment DataPartition dataPartition = DataPartition.RANDOM; @@ -668,9 +664,7 @@ public class PhysicalPlanTranslator extends DefaultPlanVisitor<PlanFragment, Pla expr -> runtimeFilterGenerator.translateRuntimeFilterTarget(expr, esScanNode, context) ) ); - if (context.getTopnFilterContext().isTopnFilterTarget(esScan)) { - context.getTopnFilterContext().addLegacyTarget(esScan, esScanNode); - } + context.getTopnFilterContext().translateTarget(esScan, esScanNode, context); Utils.execWithUncheckedException(esScanNode::finalizeForNereids); DataPartition dataPartition = DataPartition.RANDOM; PlanFragment planFragment 
= new PlanFragment(context.nextFragmentId(), esScanNode, dataPartition); @@ -695,9 +689,7 @@ public class PhysicalPlanTranslator extends DefaultPlanVisitor<PlanFragment, Pla expr -> runtimeFilterGenerator.translateRuntimeFilterTarget(expr, jdbcScanNode, context) ) ); - if (context.getTopnFilterContext().isTopnFilterTarget(jdbcScan)) { - context.getTopnFilterContext().addLegacyTarget(jdbcScan, jdbcScanNode); - } + context.getTopnFilterContext().translateTarget(jdbcScan, jdbcScanNode, context); Utils.execWithUncheckedException(jdbcScanNode::finalizeForNereids); DataPartition dataPartition = DataPartition.RANDOM; PlanFragment planFragment = new PlanFragment(context.nextFragmentId(), jdbcScanNode, dataPartition); @@ -722,10 +714,9 @@ public class PhysicalPlanTranslator extends DefaultPlanVisitor<PlanFragment, Pla expr -> runtimeFilterGenerator.translateRuntimeFilterTarget(expr, odbcScanNode, context) ) ); - if (context.getTopnFilterContext().isTopnFilterTarget(odbcScan)) { - context.getTopnFilterContext().addLegacyTarget(odbcScan, odbcScanNode); - } + context.getTopnFilterContext().translateTarget(odbcScan, odbcScanNode, context); Utils.execWithUncheckedException(odbcScanNode::finalizeForNereids); + context.getTopnFilterContext().translateTarget(odbcScan, odbcScanNode, context); DataPartition dataPartition = DataPartition.RANDOM; PlanFragment planFragment = new PlanFragment(context.nextFragmentId(), odbcScanNode, dataPartition); context.addPlanFragment(planFragment); @@ -801,10 +792,8 @@ public class PhysicalPlanTranslator extends DefaultPlanVisitor<PlanFragment, Pla expr, olapScanNode, context) ) ); + context.getTopnFilterContext().translateTarget(olapScan, olapScanNode, context); olapScanNode.setPushDownAggNoGrouping(context.getRelationPushAggOp(olapScan.getRelationId())); - if (context.getTopnFilterContext().isTopnFilterTarget(olapScan)) { - context.getTopnFilterContext().addLegacyTarget(olapScan, olapScanNode); - } // TODO: we need to remove all finalizeForNereids 
olapScanNode.finalizeForNereids(); // Create PlanFragment @@ -828,9 +817,6 @@ public class PhysicalPlanTranslator extends DefaultPlanVisitor<PlanFragment, Pla PhysicalDeferMaterializeOlapScan deferMaterializeOlapScan, PlanTranslatorContext context) { PlanFragment planFragment = visitPhysicalOlapScan(deferMaterializeOlapScan.getPhysicalOlapScan(), context); OlapScanNode olapScanNode = (OlapScanNode) planFragment.getPlanRoot(); - if (context.getTopnFilterContext().isTopnFilterTarget(deferMaterializeOlapScan)) { - context.getTopnFilterContext().addLegacyTarget(deferMaterializeOlapScan, olapScanNode); - } TupleDescriptor tupleDescriptor = context.getTupleDesc(olapScanNode.getTupleId()); for (SlotDescriptor slotDescriptor : tupleDescriptor.getSlots()) { if (deferMaterializeOlapScan.getDeferMaterializeSlotIds() @@ -839,6 +825,7 @@ public class PhysicalPlanTranslator extends DefaultPlanVisitor<PlanFragment, Pla } } context.createSlotDesc(tupleDescriptor, deferMaterializeOlapScan.getColumnIdSlot()); + context.getTopnFilterContext().translateTarget(deferMaterializeOlapScan, olapScanNode, context); return planFragment; } @@ -2149,19 +2136,7 @@ public class PhysicalPlanTranslator extends DefaultPlanVisitor<PlanFragment, Pla sortNode.setOffset(topN.getOffset()); sortNode.setLimit(topN.getLimit()); if (context.getTopnFilterContext().isTopnFilterSource(topN)) { - sortNode.setUseTopnOpt(true); - context.getTopnFilterContext().getTargets(topN).forEach( - relation -> { - Optional<ScanNode> legacyScan = - context.getTopnFilterContext().getLegacyScanNode(relation); - Preconditions.checkState(legacyScan.isPresent(), - "cannot find ScanNode for topn filter:\n" - + "relation: %s \n%s", - relation.toString(), - context.getTopnFilterContext().toString()); - legacyScan.get().addTopnFilterSortNode(sortNode); - } - ); + context.getTopnFilterContext().translateSource(topN, sortNode); } // push sort to scan opt if (sortNode.getChild(0) instanceof OlapScanNode) { @@ -2212,16 +2187,7 @@ public 
class PhysicalPlanTranslator extends DefaultPlanVisitor<PlanFragment, Pla sortNode.setUseTwoPhaseReadOpt(true); sortNode.getSortInfo().setUseTwoPhaseRead(); if (context.getTopnFilterContext().isTopnFilterSource(topN)) { - sortNode.setUseTopnOpt(true); - context.getTopnFilterContext().getTargets(topN).forEach( - relation -> { - Optional<ScanNode> legacyScan = - context.getTopnFilterContext().getLegacyScanNode(relation); - Preconditions.checkState(legacyScan.isPresent(), - "cannot find ScanNode for topn filter"); - legacyScan.get().addTopnFilterSortNode(sortNode); - } - ); + context.getTopnFilterContext().translateSource(topN, sortNode); } TupleDescriptor tupleDescriptor = sortNode.getSortInfo().getSortTupleDescriptor(); for (SlotDescriptor slotDescriptor : tupleDescriptor.getSlots()) { diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/processor/post/TopNScanOpt.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/processor/post/TopNScanOpt.java index ec1e52d6426..fd3c794317d 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/processor/post/TopNScanOpt.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/processor/post/TopNScanOpt.java @@ -18,11 +18,11 @@ package org.apache.doris.nereids.processor.post; import org.apache.doris.nereids.CascadesContext; +import org.apache.doris.nereids.processor.post.TopnFilterPushDownVisitor.PushDownContext; import org.apache.doris.nereids.trees.expressions.Expression; -import org.apache.doris.nereids.trees.expressions.SlotReference; import org.apache.doris.nereids.trees.plans.Plan; import org.apache.doris.nereids.trees.plans.SortPhase; -import org.apache.doris.nereids.trees.plans.algebra.Join; +import org.apache.doris.nereids.trees.plans.algebra.TopN; import org.apache.doris.nereids.trees.plans.physical.PhysicalDeferMaterializeOlapScan; import org.apache.doris.nereids.trees.plans.physical.PhysicalDeferMaterializeTopN; import org.apache.doris.nereids.trees.plans.physical.PhysicalEsScan; @@ -32,10 
+32,8 @@ import org.apache.doris.nereids.trees.plans.physical.PhysicalOdbcScan; import org.apache.doris.nereids.trees.plans.physical.PhysicalOlapScan; import org.apache.doris.nereids.trees.plans.physical.PhysicalRelation; import org.apache.doris.nereids.trees.plans.physical.PhysicalTopN; -import org.apache.doris.nereids.trees.plans.physical.PhysicalWindow; import org.apache.doris.qe.ConnectContext; -import java.util.Optional; /** * topN opt * refer to: @@ -48,33 +46,39 @@ import java.util.Optional; public class TopNScanOpt extends PlanPostProcessor { @Override public PhysicalTopN<? extends Plan> visitPhysicalTopN(PhysicalTopN<? extends Plan> topN, CascadesContext ctx) { - Optional<PhysicalRelation> scanOpt = findScanForTopnFilter(topN); - scanOpt.ifPresent(scan -> ctx.getTopnFilterContext().addTopnFilter(topN, scan)); topN.child().accept(this, ctx); + if (checkTopN(topN)) { + TopnFilterPushDownVisitor pusher = new TopnFilterPushDownVisitor(ctx.getTopnFilterContext()); + TopnFilterPushDownVisitor.PushDownContext pushdownContext = new PushDownContext(topN, + topN.getOrderKeys().get(0).getExpr(), + topN.getOrderKeys().get(0).isNullFirst()); + topN.accept(pusher, pushdownContext); + } return topN; } - @Override - public Plan visitPhysicalDeferMaterializeTopN(PhysicalDeferMaterializeTopN<? extends Plan> topN, - CascadesContext context) { - Optional<PhysicalRelation> scanOpt = findScanForTopnFilter(topN.getPhysicalTopN()); - scanOpt.ifPresent(scan -> context.getTopnFilterContext().addTopnFilter(topN, scan)); - topN.child().accept(this, context); - return topN; - } - - private Optional<PhysicalRelation> findScanForTopnFilter(PhysicalTopN<? 
extends Plan> topN) { - if (topN.getSortPhase() != SortPhase.LOCAL_SORT) { - return Optional.empty(); + boolean checkTopN(TopN topN) { + if (!(topN instanceof PhysicalTopN) && !(topN instanceof PhysicalDeferMaterializeTopN)) { + return false; + } + if (topN instanceof PhysicalTopN + && ((PhysicalTopN) topN).getSortPhase() != SortPhase.LOCAL_SORT) { + return false; + } else { + if (topN instanceof PhysicalDeferMaterializeTopN + && ((PhysicalDeferMaterializeTopN) topN).getSortPhase() != SortPhase.LOCAL_SORT) { + return false; + } } + if (topN.getOrderKeys().isEmpty()) { - return Optional.empty(); + return false; } // topn opt long topNOptLimitThreshold = getTopNOptLimitThreshold(); if (topNOptLimitThreshold == -1 || topN.getLimit() > topNOptLimitThreshold) { - return Optional.empty(); + return false; } // if firstKey's column is not present, it means the firstKey is not an original column from scan node // for example: "select cast(k1 as INT) as id from tbl1 order by id limit 2;" the firstKey "id" is @@ -84,64 +88,28 @@ public class TopNScanOpt extends PlanPostProcessor { // see Alias::toSlot() method to get how column info is passed around by alias of slotReference Expression firstKey = topN.getOrderKeys().get(0).getExpr(); if (!firstKey.isColumnFromTable()) { - return Optional.empty(); + return false; } + if (firstKey.getDataType().isFloatType() || firstKey.getDataType().isDoubleType()) { - return Optional.empty(); + return false; } - - if (! 
(firstKey instanceof SlotReference)) { - return Optional.empty(); - } - - boolean nullsFirst = topN.getOrderKeys().get(0).isNullFirst(); - return findScanNodeBySlotReference(topN, (SlotReference) firstKey, nullsFirst); + return true; } - private Optional<PhysicalRelation> findScanNodeBySlotReference(Plan root, SlotReference slot, boolean nullsFirst) { - if (root instanceof PhysicalWindow) { - return Optional.empty(); - } - - if (root instanceof PhysicalRelation) { - if (root.getOutputSet().contains(slot) && supportPhysicalRelations((PhysicalRelation) root)) { - return Optional.of((PhysicalRelation) root); - } else { - return Optional.empty(); - } - } - - Optional<PhysicalRelation> target; - if (root instanceof Join) { - Join join = (Join) root; - if (nullsFirst && join.getJoinType().isOuterJoin()) { - // in fact, topn-filter can be pushed down to the left child of leftOuterJoin - // and to the right child of rightOuterJoin. - // but we have rule to push topn down to the left/right side. and topn-filter - // will be generated according to the inferred topn node. - return Optional.empty(); - } - // try to push to both left and right child - if (root.child(0).getOutputSet().contains(slot)) { - target = findScanNodeBySlotReference(root.child(0), slot, nullsFirst); - } else { - target = findScanNodeBySlotReference(root.child(1), slot, nullsFirst); - } - return target; - } - - if (!root.children().isEmpty()) { - // TODO for set operator, topn-filter can be pushed down to all of its children. - Plan child = root.child(0); - if (child.getOutputSet().contains(slot)) { - target = findScanNodeBySlotReference(child, slot, nullsFirst); - if (target.isPresent()) { - return target; - } - } + @Override + public Plan visitPhysicalDeferMaterializeTopN(PhysicalDeferMaterializeTopN<? 
extends Plan> topN, + CascadesContext ctx) { + topN.child().accept(this, ctx); + if (checkTopN(topN)) { + TopnFilterPushDownVisitor pusher = new TopnFilterPushDownVisitor(ctx.getTopnFilterContext()); + TopnFilterPushDownVisitor.PushDownContext pushdownContext = new PushDownContext(topN, + topN.getOrderKeys().get(0).getExpr(), + topN.getOrderKeys().get(0).isNullFirst()); + topN.accept(pusher, pushdownContext); } - return Optional.empty(); + return topN; } private long getTopNOptLimitThreshold() { diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/processor/post/TopnFilterContext.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/processor/post/TopnFilterContext.java index 6a4fe3123df..fceec21ee7e 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/processor/post/TopnFilterContext.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/processor/post/TopnFilterContext.java @@ -17,78 +17,76 @@ package org.apache.doris.nereids.processor.post; +import org.apache.doris.analysis.Expr; +import org.apache.doris.nereids.glue.translator.ExpressionTranslator; +import org.apache.doris.nereids.glue.translator.PlanTranslatorContext; +import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.plans.algebra.TopN; import org.apache.doris.nereids.trees.plans.physical.PhysicalRelation; +import org.apache.doris.nereids.trees.plans.physical.TopnFilter; import org.apache.doris.planner.ScanNode; import org.apache.doris.planner.SortNode; +import com.google.common.base.Preconditions; import com.google.common.collect.Lists; import com.google.common.collect.Maps; -import com.google.common.collect.Sets; import java.util.List; import java.util.Map; -import java.util.Optional; -import java.util.Set; /** * topN runtime filter context */ public class TopnFilterContext { - private final Map<TopN, List<PhysicalRelation>> filters = Maps.newHashMap(); - private final Set<TopN> sources = Sets.newHashSet(); - private final 
Set<PhysicalRelation> targets = Sets.newHashSet(); - private final Map<PhysicalRelation, ScanNode> legacyTargetsMap = Maps.newHashMap(); - private final Map<TopN, SortNode> legacySourceMap = Maps.newHashMap(); + private final Map<TopN, TopnFilter> filters = Maps.newHashMap(); /** * add topN filter */ - public void addTopnFilter(TopN topn, PhysicalRelation scan) { - targets.add(scan); - sources.add(topn); - - List<PhysicalRelation> targets = filters.get(topn); - if (targets == null) { - filters.put(topn, Lists.newArrayList(scan)); + public void addTopnFilter(TopN topn, PhysicalRelation scan, Expression expr) { + TopnFilter filter = filters.get(topn); + if (filter == null) { + filters.put(topn, new TopnFilter(topn, scan, expr)); } else { - targets.add(scan); + filter.addTarget(scan, expr); } } - /** - * find the corresponding sortNode for topn filter - */ - public Optional<ScanNode> getLegacyScanNode(PhysicalRelation scan) { - return legacyTargetsMap.containsKey(scan) - ? Optional.of(legacyTargetsMap.get(scan)) - : Optional.empty(); - } - - public Optional<SortNode> getLegacySortNode(TopN topn) { - return legacyTargetsMap.containsKey(topn) - ? 
Optional.of(legacySourceMap.get(topn)) - : Optional.empty(); - } - public boolean isTopnFilterSource(TopN topn) { - return sources.contains(topn); - } - - public boolean isTopnFilterTarget(PhysicalRelation relation) { - return targets.contains(relation); + return filters.containsKey(topn); } - public void addLegacySource(TopN topn, SortNode sort) { - legacySourceMap.put(topn, sort); + public List<TopnFilter> getTopnFilters() { + return Lists.newArrayList(filters.values()); } - public void addLegacyTarget(PhysicalRelation relation, ScanNode legacy) { - legacyTargetsMap.put(relation, legacy); + /** + * translate topn-filter + */ + public void translateTarget(PhysicalRelation relation, ScanNode legacyScan, + PlanTranslatorContext translatorContext) { + for (TopnFilter filter : filters.values()) { + if (filter.hasTargetRelation(relation)) { + Expr expr = ExpressionTranslator.translate(filter.targets.get(relation), translatorContext); + filter.legacyTargets.put(legacyScan, expr); + } + } } - public List<PhysicalRelation> getTargets(TopN topn) { - return filters.get(topn); + /** + * translate topn-filter + */ + public void translateSource(TopN topn, SortNode sortNode) { + TopnFilter filter = filters.get(topn); + if (filter == null) { + return; + } + filter.legacySortNode = sortNode; + sortNode.setUseTopnOpt(true); + Preconditions.checkArgument(!filter.legacyTargets.isEmpty(), "missing targets: " + filter); + for (ScanNode scan : filter.legacyTargets.keySet()) { + scan.addTopnFilterSortNode(sortNode); + } } /** @@ -100,10 +98,8 @@ public class TopnFilterContext { String arrow = " -> "; builder.append("filters:\n"); for (TopN topn : filters.keySet()) { - builder.append(indent).append(topn.toString()).append("\n"); builder.append(indent).append(arrow).append(filters.get(topn)).append("\n"); } return builder.toString(); - } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/processor/post/TopnFilterPushDownVisitor.java 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/processor/post/TopnFilterPushDownVisitor.java new file mode 100644 index 00000000000..6962f3fddf2 --- /dev/null +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/processor/post/TopnFilterPushDownVisitor.java @@ -0,0 +1,233 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +package org.apache.doris.nereids.processor.post; + +import org.apache.doris.nereids.processor.post.TopnFilterPushDownVisitor.PushDownContext; +import org.apache.doris.nereids.trees.expressions.Alias; +import org.apache.doris.nereids.trees.expressions.Expression; +import org.apache.doris.nereids.trees.expressions.NamedExpression; +import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitors; +import org.apache.doris.nereids.trees.plans.Plan; +import org.apache.doris.nereids.trees.plans.algebra.TopN; +import org.apache.doris.nereids.trees.plans.algebra.Union; +import org.apache.doris.nereids.trees.plans.physical.PhysicalCTEAnchor; +import org.apache.doris.nereids.trees.plans.physical.PhysicalCTEProducer; +import org.apache.doris.nereids.trees.plans.physical.PhysicalDeferMaterializeOlapScan; +import org.apache.doris.nereids.trees.plans.physical.PhysicalEsScan; +import org.apache.doris.nereids.trees.plans.physical.PhysicalFileScan; +import org.apache.doris.nereids.trees.plans.physical.PhysicalHashJoin; +import org.apache.doris.nereids.trees.plans.physical.PhysicalJdbcScan; +import org.apache.doris.nereids.trees.plans.physical.PhysicalNestedLoopJoin; +import org.apache.doris.nereids.trees.plans.physical.PhysicalOdbcScan; +import org.apache.doris.nereids.trees.plans.physical.PhysicalOlapScan; +import org.apache.doris.nereids.trees.plans.physical.PhysicalProject; +import org.apache.doris.nereids.trees.plans.physical.PhysicalRelation; +import org.apache.doris.nereids.trees.plans.physical.PhysicalSetOperation; +import org.apache.doris.nereids.trees.plans.physical.PhysicalTopN; +import org.apache.doris.nereids.trees.plans.physical.PhysicalWindow; +import org.apache.doris.nereids.trees.plans.visitor.PlanVisitor; + +import com.google.common.collect.Maps; + +import java.util.List; +import java.util.Map; + +/** + * push down topn filter + */ +public class TopnFilterPushDownVisitor extends PlanVisitor<Boolean, PushDownContext> { + private TopnFilterContext 
topnFilterContext; + + public TopnFilterPushDownVisitor(TopnFilterContext topnFilterContext) { + this.topnFilterContext = topnFilterContext; + } + + /** + * topn filter push-down context + */ + public static class PushDownContext { + final Expression probeExpr; + final TopN topn; + final boolean nullsFirst; + + public PushDownContext(TopN topn, Expression probeExpr, boolean nullsFirst) { + this.topn = topn; + this.probeExpr = probeExpr; + this.nullsFirst = nullsFirst; + } + + public PushDownContext withNewProbeExpression(Expression newProbe) { + return new PushDownContext(topn, newProbe, nullsFirst); + } + + } + + @Override + public Boolean visit(Plan plan, PushDownContext ctx) { + boolean pushed = false; + for (Plan child : plan.children()) { + pushed |= child.accept(this, ctx); + } + return pushed; + } + + @Override + public Boolean visitPhysicalProject( + PhysicalProject<? extends Plan> project, PushDownContext ctx) { + // project ( A+1 as x): every alias slot read by probeExpr is keyed by its own slot in replaceMap + // probeExpr: abs(x) => abs(A+1) + PushDownContext ctxProjectProbeExpr = ctx; + Map<Expression, Expression> replaceMap = Maps.newHashMap(); + for (NamedExpression ne : project.getProjects()) { + if (ne instanceof Alias && ctx.probeExpr.getInputSlots().contains(ne.toSlot())) { + replaceMap.put(ne.toSlot(), ((Alias) ne).child()); + } + } + if (! 
replaceMap.isEmpty()) { + Expression newProbeExpr = ctx.probeExpr.accept(ExpressionVisitors.EXPRESSION_MAP_REPLACER, replaceMap); + ctxProjectProbeExpr = ctx.withNewProbeExpression(newProbeExpr); + } + return project.child().accept(this, ctxProjectProbeExpr); + } + + @Override + public Boolean visitPhysicalSetOperation( + PhysicalSetOperation setOperation, PushDownContext ctx) { + boolean pushedDown = pushDownFilterToSetOperatorChild(setOperation, ctx, 0); + + if (setOperation instanceof Union) { + for (int i = 1; i < setOperation.children().size(); i++) { + // push down to the other children + pushedDown |= pushDownFilterToSetOperatorChild(setOperation, ctx, i); + } + } + return pushedDown; + } + + private Boolean pushDownFilterToSetOperatorChild(PhysicalSetOperation setOperation, + PushDownContext ctx, int childIdx) { + Plan child = setOperation.child(childIdx); + Map<Expression, Expression> replaceMap = Maps.newHashMap(); + + List<NamedExpression> setOutputs = setOperation.getOutputs(); + for (int i = 0; i < setOutputs.size(); i++) { + replaceMap.put(setOutputs.get(i).toSlot(), + setOperation.getRegularChildrenOutputs().get(childIdx).get(i)); + } + if (!replaceMap.isEmpty()) { + Expression newProbeExpr = ctx.probeExpr.accept(ExpressionVisitors.EXPRESSION_MAP_REPLACER, + replaceMap); + PushDownContext childPushDownContext = ctx.withNewProbeExpression(newProbeExpr); + return child.accept(this, childPushDownContext); + } + return false; + } + + @Override + public Boolean visitPhysicalCTEAnchor(PhysicalCTEAnchor<? extends Plan, ? extends Plan> anchor, + PushDownContext ctx) { + return false; + } + + @Override + public Boolean visitPhysicalCTEProducer(PhysicalCTEProducer<? extends Plan> anchor, + PushDownContext ctx) { + return false; + } + + @Override + public Boolean visitPhysicalTopN(PhysicalTopN<? 
extends Plan> topn, PushDownContext ctx) { + if (topn.equals(ctx.topn)) { + return topn.child().accept(this, ctx); + } + return false; + } + + @Override + public Boolean visitPhysicalWindow(PhysicalWindow<? extends Plan> window, PushDownContext ctx) { + return false; + } + + @Override + public Boolean visitPhysicalHashJoin(PhysicalHashJoin<? extends Plan, ? extends Plan> join, + PushDownContext ctx) { + if (ctx.nullsFirst && join.getJoinType().isOuterJoin()) { + // topn-filter can be pushed down to the left child of leftOuterJoin + // and to the right child of rightOuterJoin, + // but PushDownTopNThroughJoin rule already pushes topn to the left and right side. + // the topn-filter will be generated according to the inferred topn node. + return false; + } + if (join.left().getOutputSet().containsAll(ctx.probeExpr.getInputSlots())) { + return join.left().accept(this, ctx); + } + if (join.right().getOutputSet().containsAll(ctx.probeExpr.getInputSlots())) { + // expand expr to the other side of hash join condition: + // T1 join T2 on T1.x = T2.y order by T2.y limit 10 + // we rewrite probeExpr from T2.y to T1.x and try to push T1.x to left side + for (Expression conj : join.getHashJoinConjuncts()) { + if (ctx.probeExpr.equals(conj.child(1))) { + // push to left child. right child is blocking operator, do not need topn-filter + PushDownContext newCtx = ctx.withNewProbeExpression(conj.child(0)); + return join.left().accept(this, newCtx); + } + } + } + // topn key is combination of left and right + // select * from (select T1.A+T2.B as x from T1 join T2) T3 order by x limit 1; + return false; + } + + @Override + public Boolean visitPhysicalNestedLoopJoin(PhysicalNestedLoopJoin<? extends Plan, ? 
extends Plan> join, + PushDownContext ctx) { + if (ctx.nullsFirst && join.getJoinType().isOuterJoin()) { + // topn-filter can be pushed down to the left child of leftOuterJoin + // and to the right child of rightOuterJoin, + // but PushDownTopNThroughJoin rule already pushes topn to the left and right side. + // the topn-filter will be generated according to the inferred topn node. + return false; + } + // push to left child. right child is blocking operator, do not need topn-filter + if (join.left().getOutputSet().containsAll(ctx.probeExpr.getInputSlots())) { + return join.left().accept(this, ctx); + } + return false; + } + + @Override + public Boolean visitPhysicalRelation(PhysicalRelation relation, PushDownContext ctx) { + if (supportPhysicalRelations(relation) + && relation.getOutputSet().containsAll(ctx.probeExpr.getInputSlots())) { + if (relation.getOutputSet().containsAll(ctx.probeExpr.getInputSlots())) { + topnFilterContext.addTopnFilter(ctx.topn, relation, ctx.probeExpr); + return true; + } + } + return false; + } + + private boolean supportPhysicalRelations(PhysicalRelation relation) { + return relation instanceof PhysicalOlapScan + || relation instanceof PhysicalOdbcScan + || relation instanceof PhysicalEsScan + || relation instanceof PhysicalFileScan + || relation instanceof PhysicalJdbcScan + || relation instanceof PhysicalDeferMaterializeOlapScan; + } +} diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/physical/TopnFilter.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/physical/TopnFilter.java new file mode 100644 index 00000000000..6c8e2187b48 --- /dev/null +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/physical/TopnFilter.java @@ -0,0 +1,78 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. 
The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.nereids.trees.plans.physical; + +import org.apache.doris.analysis.Expr; +import org.apache.doris.nereids.trees.expressions.Expression; +import org.apache.doris.nereids.trees.plans.algebra.TopN; +import org.apache.doris.planner.ScanNode; +import org.apache.doris.planner.SortNode; +import org.apache.doris.thrift.TTopnFilterDesc; + +import com.google.common.collect.Maps; + +import java.util.Map; + +/** + * topn filter + */ +public class TopnFilter { + public TopN topn; + public SortNode legacySortNode; + public Map<PhysicalRelation, Expression> targets = Maps.newHashMap(); + public Map<ScanNode, Expr> legacyTargets = Maps.newHashMap(); + + public TopnFilter(TopN topn, PhysicalRelation rel, Expression expr) { + this.topn = topn; + targets.put(rel, expr); + } + + public void addTarget(PhysicalRelation rel, Expression expr) { + targets.put(rel, expr); + } + + public boolean hasTargetRelation(PhysicalRelation rel) { + return targets.containsKey(rel); + } + + @Override + public String toString() { + StringBuilder builder = new StringBuilder(); + builder.append(topn).append("->[ "); + for (PhysicalRelation rel : targets.keySet()) { + builder.append("(").append(rel).append(":").append(targets.get(rel)).append(") "); + } + builder.append("]"); + return builder.toString(); + } + + /** + * to thrift + */ + public TTopnFilterDesc toThrift() { + TTopnFilterDesc tFilter = 
new TTopnFilterDesc(); + tFilter.setSourceNodeId(legacySortNode.getId().asInt()); + tFilter.setIsAsc(topn.getOrderKeys().get(0).isAsc()); + tFilter.setNullFirst(topn.getOrderKeys().get(0).isNullFirst()); + for (ScanNode scan : legacyTargets.keySet()) { + tFilter.putToTargetNodeIdToTargetExpr(scan.getId().asInt(), + legacyTargets.get(scan).treeToThrift()); + } + return tFilter; + } +} diff --git a/fe/fe-core/src/main/java/org/apache/doris/planner/Planner.java b/fe/fe-core/src/main/java/org/apache/doris/planner/Planner.java index e64a0c85818..ce47ab5ac67 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/planner/Planner.java +++ b/fe/fe-core/src/main/java/org/apache/doris/planner/Planner.java @@ -24,6 +24,7 @@ import org.apache.doris.common.UserException; import org.apache.doris.common.profile.PlanTreeBuilder; import org.apache.doris.common.profile.PlanTreePrinter; import org.apache.doris.nereids.PlannerHook; +import org.apache.doris.nereids.trees.plans.physical.TopnFilter; import org.apache.doris.qe.ResultSet; import org.apache.doris.thrift.TQueryOptions; @@ -131,4 +132,8 @@ public abstract class Planner { public abstract Optional<ResultSet> handleQueryInFe(StatementBase parsedStmt); public abstract void addHook(PlannerHook hook); + + public List<TopnFilter> getTopnFilters() { + return Lists.newArrayList(); + } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/qe/Coordinator.java b/fe/fe-core/src/main/java/org/apache/doris/qe/Coordinator.java index bab1c1decbb..ef071e78d39 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/qe/Coordinator.java +++ b/fe/fe-core/src/main/java/org/apache/doris/qe/Coordinator.java @@ -42,6 +42,7 @@ import org.apache.doris.metric.MetricRepo; import org.apache.doris.mysql.MysqlCommand; import org.apache.doris.nereids.NereidsPlanner; import org.apache.doris.nereids.stats.StatsErrorEstimator; +import org.apache.doris.nereids.trees.plans.physical.TopnFilter; import org.apache.doris.planner.DataPartition; import 
org.apache.doris.planner.DataSink; import org.apache.doris.planner.DataStreamSink; @@ -63,7 +64,6 @@ import org.apache.doris.planner.RuntimeFilter; import org.apache.doris.planner.RuntimeFilterId; import org.apache.doris.planner.ScanNode; import org.apache.doris.planner.SetOperationNode; -import org.apache.doris.planner.SortNode; import org.apache.doris.planner.UnionNode; import org.apache.doris.proto.InternalService; import org.apache.doris.proto.InternalService.PExecPlanFragmentResult; @@ -117,6 +117,7 @@ import org.apache.doris.thrift.TScanRangeLocations; import org.apache.doris.thrift.TScanRangeParams; import org.apache.doris.thrift.TStatusCode; import org.apache.doris.thrift.TTabletCommitInfo; +import org.apache.doris.thrift.TTopnFilterDesc; import org.apache.doris.thrift.TUniqueId; import com.google.common.base.Preconditions; @@ -276,6 +277,7 @@ public class Coordinator implements CoordInterface { public Map<RuntimeFilterId, List<FRuntimeFilterTargetParam>> ridToTargetParam = Maps.newHashMap(); // The runtime filter that expects the instance to be used public List<RuntimeFilter> assignedRuntimeFilters = new ArrayList<>(); + public List<TopnFilter> topnFilters = new ArrayList<>(); // Runtime filter ID to the builder instance number public Map<RuntimeFilterId, Integer> ridToBuilderNum = Maps.newHashMap(); private ConnectContext context; @@ -356,6 +358,7 @@ public class Coordinator implements CoordInterface { nextInstanceId.setHi(queryId.hi); nextInstanceId.setLo(queryId.lo + 1); this.assignedRuntimeFilters = planner.getRuntimeFilters(); + this.topnFilters = planner.getTopnFilters(); this.executionProfile = new ExecutionProfile(queryId, fragments); } @@ -3783,12 +3786,6 @@ public class Coordinator implements CoordInterface { int rate = Math.min(Config.query_colocate_join_memory_limit_penalty_factor, instanceExecParams.size()); memLimit = queryOptions.getMemLimit() / rate; } - Set<Integer> topnFilterSources = Sets.newLinkedHashSet(); - for (ScanNode scanNode : 
scanNodes) { - for (SortNode sortNode : scanNode.getTopnFilterSortNodes()) { - topnFilterSources.add(sortNode.getId().asInt()); - } - } Map<TNetworkAddress, TPipelineFragmentParams> res = new HashMap(); Map<TNetworkAddress, Integer> instanceIdx = new HashMap(); @@ -3859,11 +3856,12 @@ public class Coordinator implements CoordInterface { localParams.setBackendNum(backendNum++); localParams.setRuntimeFilterParams(new TRuntimeFilterParams()); localParams.runtime_filter_params.setRuntimeFilterMergeAddr(runtimeFilterMergeAddr); - if (!topnFilterSources.isEmpty()) { - // topn_filter_source_node_ids is used by nereids not by legacy planner. - // if there is no topnFilterSources, do not set it. - // topn_filter_source_node_ids=null means legacy planner - localParams.topn_filter_source_node_ids = Lists.newArrayList(topnFilterSources); + if (!topnFilters.isEmpty()) { + List<TTopnFilterDesc> filterDescs = new ArrayList<>(); + for (TopnFilter filter : topnFilters) { + filterDescs.add(filter.toThrift()); + } + localParams.setTopnFilterDescs(filterDescs); } if (instanceExecParam.instanceId.equals(runtimeFilterMergeInstanceId)) { Set<Integer> broadCastRf = assignedRuntimeFilters.stream().filter(RuntimeFilter::isBroadcast) diff --git a/gensrc/thrift/Exprs.thrift b/gensrc/thrift/Exprs.thrift index 5d46e4b38db..3ef7c7acb07 100644 --- a/gensrc/thrift/Exprs.thrift +++ b/gensrc/thrift/Exprs.thrift @@ -253,7 +253,7 @@ struct TExprNode { 26: optional Types.TFunction fn // If set, child[vararg_start_idx] is the first vararg child. 
27: optional i32 vararg_start_idx - 28: optional Types.TPrimitiveType child_type + 28: optional Types.TPrimitiveType child_type // Deprecated // For vectorized engine 29: optional bool is_nullable diff --git a/gensrc/thrift/PlanNodes.thrift b/gensrc/thrift/PlanNodes.thrift index d9a24a2b8cc..43fcdf321a1 100644 --- a/gensrc/thrift/PlanNodes.thrift +++ b/gensrc/thrift/PlanNodes.thrift @@ -732,12 +732,12 @@ struct TOlapScanNode { 10: optional i64 sort_limit 11: optional bool enable_unique_key_merge_on_write 12: optional TPushAggOp push_down_agg_type_opt //Deprecated - 13: optional bool use_topn_opt + 13: optional bool use_topn_opt // Deprecated 14: optional list<Descriptors.TOlapTableIndex> indexes_desc 15: optional set<i32> output_column_unique_ids 16: optional list<i32> distribute_column_ids 17: optional i32 schema_version - 18: optional list<i32> topn_filter_source_node_ids + 18: optional list<i32> topn_filter_source_node_ids //deprecated, move to TPlanNode.106 } struct TEqJoinCondition { @@ -935,7 +935,7 @@ struct TSortNode { // Indicates whether the imposed limit comes DEFAULT_ORDER_BY_LIMIT. 6: optional bool is_default_limit - 7: optional bool use_topn_opt + 7: optional bool use_topn_opt // Deprecated 8: optional bool merge_by_exchange 9: optional bool is_analytic_sort 10: optional bool is_colocate @@ -1186,7 +1186,7 @@ struct TTopnFilterDesc { 2: required bool is_asc 3: required bool null_first // scan node id -> expr on scan node - 4: required map<Types.TPlanNodeId, Exprs.TExpr> targetNodeId_to_target_expr + 4: required map<Types.TPlanNodeId, Exprs.TExpr> target_node_id_to_target_expr } // Specification of a runtime filter. 
diff --git a/regression-test/suites/nereids_tpch_p0/tpch/topn-filter.groovy b/regression-test/suites/nereids_tpch_p0/tpch/topn-filter.groovy index ab77d51b6e3..617794e1988 100644 --- a/regression-test/suites/nereids_tpch_p0/tpch/topn-filter.groovy +++ b/regression-test/suites/nereids_tpch_p0/tpch/topn-filter.groovy @@ -79,7 +79,7 @@ suite("topn-filter") { qt_check_result2 "${multi_topn_desc}" - // push down topn-filter to both join children + // push down topn-filter to join left child explain { sql """ select o_orderkey, c_custkey @@ -87,7 +87,7 @@ suite("topn-filter") { join customer on o_custkey = c_custkey order by c_custkey limit 2; """ - contains "TOPN OPT:" + contains "TOPN OPT:4" } // push topn filter down through AGG @@ -107,7 +107,7 @@ suite("topn-filter") { join nation on s_nationkey = n_nationkey order by s_nationkey limit 1; """ - contains "TOPN OPT:" + contains "TOPN OPT:7" } explain { @@ -141,7 +141,7 @@ suite("topn-filter") { // this topn-filter is generated by topn2, not topn1 explain { sql "select * from nation left outer join region on r_regionkey = n_regionkey order by n_regionkey nulls first limit 1; " - contains "TOPN OPT:" + multiContains ("TOPN OPT:", 1) } // TODO: support latter, push topn to right outer join --------------------------------------------------------------------- To unsubscribe, e-mail: [email protected] For additional commands, e-mail: [email protected]
