This is an automated email from the ASF dual-hosted git repository. yiguolei pushed a commit to branch vector-index-dev in repository https://gitbox.apache.org/repos/asf/doris.git
commit 305f9587bdd2fe1cd0b27a6916731eba2e9246d6 Author: hezhiqiang <hezhiqi...@selectdb.com> AuthorDate: Mon Jun 9 10:13:45 2025 +0800 Fix multi-threads range search --- .../ann_search_params.h} | 65 +++++++-------- .../ann_index/range_search_runtime_info.cpp | 26 ++++++ .../ann_index/range_search_runtime_info.h | 96 ++++++++++++++++++++++ .../olap/rowset/segment_v2/ann_index_iterator.cpp | 8 +- be/src/olap/rowset/segment_v2/ann_index_iterator.h | 38 ++------- be/src/olap/rowset/segment_v2/ann_index_reader.cpp | 21 ++--- be/src/olap/rowset/segment_v2/ann_index_reader.h | 21 +++-- be/src/olap/rowset/segment_v2/column_writer.cpp | 30 +++---- be/src/olap/rowset/segment_v2/column_writer.h | 6 +- be/src/olap/rowset/segment_v2/index_iterator.h | 7 +- .../segment_v2/inverted_index/util/term_iterator.h | 6 +- be/src/olap/rowset/segment_v2/segment_iterator.cpp | 28 +++++-- be/src/olap/rowset/segment_v2/segment_iterator.h | 8 ++ .../rowset/segment_v2/virtual_column_iterator.cpp | 16 ++-- .../rowset/segment_v2/virtual_column_iterator.h | 2 - be/src/olap/tablet_reader.cpp | 2 +- be/src/olap/tablet_reader.h | 2 +- be/src/pipeline/exec/olap_scan_operator.cpp | 8 +- be/src/pipeline/exec/olap_scan_operator.h | 2 +- be/src/vec/exec/scan/olap_scanner.cpp | 16 ++-- be/src/vec/exec/scan/scanner.cpp | 1 + be/src/vec/exprs/ann_range_search_params.h | 57 ------------- be/src/vec/exprs/ann_topn_runtime.cpp | 9 +- be/src/vec/exprs/ann_topn_runtime.h | 5 +- be/src/vec/exprs/vectorized_fn_call.cpp | 62 ++++++++------ be/src/vec/exprs/vectorized_fn_call.h | 9 +- be/src/vec/exprs/vexpr.cpp | 14 +++- be/src/vec/exprs/vexpr.h | 12 ++- be/src/vec/exprs/vexpr_context.cpp | 22 ++++- be/src/vec/exprs/vexpr_context.h | 15 ++++ be/src/vec/exprs/virtual_slot_ref.cpp | 4 +- be/src/vec/exprs/virtual_slot_ref.h | 1 + .../array/function_array_distance_approximate.h | 9 +- be/src/vec/runtime/vector_search_user_params.h | 2 + be/src/vector/CMakeLists.txt | 1 + be/src/vector/faiss_vector_index.cpp | 30 ++++--- be/src/vector/faiss_vector_index.h | 13 ++- .../metric.cpp} | 36 +++++--- .../metric.h} | 18 ++-- be/src/vector/vector_index.h | 62 +++----------- .../rewrite/PushDownVectorTopNIntoOlapScan.java | 3 + .../trees/copier/LogicalPlanDeepCopier.java | 11 +-- 42 files changed, 453 insertions(+), 351 deletions(-) diff --git a/be/src/olap/rowset/segment_v2/ann_index_iterator.h b/be/src/olap/rowset/segment_v2/ann_index/ann_search_params.h similarity index 57% copy from be/src/olap/rowset/segment_v2/ann_index_iterator.h copy to be/src/olap/rowset/segment_v2/ann_index/ann_search_params.h index 82a4113cacb..04b3f5ddc82 100644 --- a/be/src/olap/rowset/segment_v2/ann_index_iterator.h +++ b/be/src/olap/rowset/segment_v2/ann_index/ann_search_params.h @@ -17,16 +17,14 @@ #pragma once -#include <cstdint> -#include <memory> +#include <gen_cpp/Opcodes_types.h> -#include "gutil/integral_types.h" -#include "olap/rowset/segment_v2/ann_index_reader.h" -#include "olap/rowset/segment_v2/index_iterator.h" -#include "runtime/runtime_state.h" +#include <roaring/roaring.hh> +#include <string> -namespace doris::segment_v2 { +#include "runtime/runtime_state.h" +namespace doris::vectorized { struct AnnIndexParam { const float* query_value; const size_t query_value_size; @@ -56,35 +54,30 @@ struct RangeSearchResult { std::unique_ptr<float[]> distance; }; -// IndexIterator 与 IndexReader 的角色似乎有点重复,未来可以重构后删除一层概念 -class AnnIndexIterator : public IndexIterator { -public: - AnnIndexIterator(const io::IOContext& io_ctx, OlapReaderStatistics* stats, - RuntimeState* runtime_state, const IndexReaderPtr& reader); - ~AnnIndexIterator() override = default; - - IndexType type() override { return IndexType::ANN; } - - IndexReaderPtr get_reader() override { - return std::static_pointer_cast<IndexReader>(_ann_reader); - } - - MOCK_FUNCTION Status read_from_index(const IndexParam& param) override; - - Status read_null_bitmap(InvertedIndexQueryCacheHandle* cache_handle) override { - return Status::OK(); - } - - bool has_null() override { return true; } - - MOCK_FUNCTION Status range_search(const RangeSearchParams& params, - const VectorSearchUserParams& custom_params, - RangeSearchResult* result); - -private: - std::shared_ptr<AnnIndexReader> _ann_reader; +/* +This struct is used to wrap the search result of a vector index. +roaring is a bitmap that contains the row ids that satisfy the search condition. +row_ids is a vector of row ids that are returned by the search, it could be used by virtual_column_iterator to do column filter. +distances is a vector of distances that are returned by the search. +For range search, is condition is not le_or_lt, the row_ids and distances will be nullptr. +*/ +struct IndexSearchResult { + IndexSearchResult() = default; + + std::unique_ptr<float[]> distances = nullptr; + std::unique_ptr<std::vector<uint64_t>> row_ids = nullptr; + std::shared_ptr<roaring::Roaring> roaring = nullptr; +}; - ENABLE_FACTORY_CREATOR(AnnIndexIterator); +struct IndexSearchParameters { + roaring::Roaring* roaring = nullptr; + bool is_le_or_lt = true; + virtual ~IndexSearchParameters() = default; }; -} // namespace doris::segment_v2 \ No newline at end of file +struct HNSWSearchParameters : public IndexSearchParameters { + int ef_search = 16; + bool check_relative_distance = true; + bool bounded_queue = true; +}; +} // namespace doris::vectorized diff --git a/be/src/olap/rowset/segment_v2/ann_index/range_search_runtime_info.cpp b/be/src/olap/rowset/segment_v2/ann_index/range_search_runtime_info.cpp new file mode 100644 index 00000000000..0dcdb754c6a --- /dev/null +++ b/be/src/olap/rowset/segment_v2/ann_index/range_search_runtime_info.cpp @@ -0,0 +1,26 @@ +#include "olap/rowset/segment_v2/ann_index/range_search_runtime_info.h" + +#include <fmt/format.h> + +#include "olap/rowset/segment_v2/ann_index/ann_search_params.h" + +namespace doris::vectorized { +RangeSearchParams RangeSearchRuntimeInfo::to_range_search_params() const { + RangeSearchParams params; + params.query_value = query_value.get(); + params.radius = static_cast<float>(radius); + params.roaring = nullptr; + params.is_le_or_lt = is_le_or_lt; + return params; +} + +std::string RangeSearchRuntimeInfo::to_string() const { + return fmt::format( + "is_ann_range_search: {}, is_le_or_lt: {}, src_col_idx: {}, " + "dst_col_idx: {}, metric_type {}, radius: {}, user params: {}, query_vector is null: " + "{}", + is_ann_range_search, is_le_or_lt, src_col_idx, dst_col_idx, + segment_v2::metric_to_string(metric_type), radius, user_params.to_string(), + query_value == nullptr); +} +} // namespace doris::vectorized \ No newline at end of file diff --git a/be/src/olap/rowset/segment_v2/ann_index/range_search_runtime_info.h b/be/src/olap/rowset/segment_v2/ann_index/range_search_runtime_info.h new file mode 100644 index 00000000000..de145434b1f --- /dev/null +++ b/be/src/olap/rowset/segment_v2/ann_index/range_search_runtime_info.h @@ -0,0 +1,96 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include <gen_cpp/Opcodes_types.h> + +#include <string> + +#include "vec/runtime/vector_search_user_params.h" +#include "vector/metric.h" + +namespace doris::vectorized { +struct RangeSearchParams; + +struct RangeSearchRuntimeInfo { + // DefaultConstructor + RangeSearchRuntimeInfo() + : is_ann_range_search(false), + is_le_or_lt(true), + src_col_idx(0), + dst_col_idx(-1), + radius(0.0), + metric_type(segment_v2::Metric::UNKNOWN) { + query_value = nullptr; + } + + // CopyConstructor + RangeSearchRuntimeInfo(const RangeSearchRuntimeInfo& other) + : is_ann_range_search(other.is_ann_range_search), + is_le_or_lt(other.is_le_or_lt), + src_col_idx(other.src_col_idx), + dim(other.dim), + dst_col_idx(other.dst_col_idx), + radius(other.radius), + metric_type(other.metric_type), + user_params(other.user_params) { + // Do deep copy to query_value. + if (other.query_value) { + query_value = std::make_unique<float[]>(other.dim); + std::copy(other.query_value.get(), other.query_value.get() + other.dim, + query_value.get()); + } else { + query_value = nullptr; + } + } + + RangeSearchRuntimeInfo& operator=(const RangeSearchRuntimeInfo& other) { + is_ann_range_search = other.is_ann_range_search; + is_le_or_lt = other.is_le_or_lt; + src_col_idx = other.src_col_idx; + dst_col_idx = other.dst_col_idx; + radius = other.radius; + metric_type = other.metric_type; + user_params = other.user_params; + dim = other.dim; + // Do deep copy to query_value. + if (other.query_value) { + query_value = std::make_unique<float[]>(other.dim); + std::copy(other.query_value.get(), other.query_value.get() + other.dim, + query_value.get()); + } else { + query_value = nullptr; + } + return *this; + } + + RangeSearchParams to_range_search_params() const; + + std::string to_string() const; + + bool is_ann_range_search = false; + bool is_le_or_lt = true; + size_t src_col_idx = 0; + size_t dim = 0; + int64_t dst_col_idx = -1; + double radius = 0.0; + segment_v2::Metric metric_type; + doris::VectorSearchUserParams user_params; + std::unique_ptr<float[]> query_value; +}; +} // namespace doris::vectorized \ No newline at end of file diff --git a/be/src/olap/rowset/segment_v2/ann_index_iterator.cpp b/be/src/olap/rowset/segment_v2/ann_index_iterator.cpp index 3b37e3cabcb..fce222745dd 100644 --- a/be/src/olap/rowset/segment_v2/ann_index_iterator.cpp +++ b/be/src/olap/rowset/segment_v2/ann_index_iterator.cpp @@ -19,6 +19,8 @@ #include <memory> +#include "olap/rowset/segment_v2/ann_index/ann_search_params.h" + namespace doris::segment_v2 { AnnIndexIterator::AnnIndexIterator(const io::IOContext& io_ctx, OlapReaderStatistics* stats, @@ -28,7 +30,7 @@ AnnIndexIterator::AnnIndexIterator(const io::IOContext& io_ctx, OlapReaderStatis } Status AnnIndexIterator::read_from_index(const IndexParam& param) { - auto* a_param = std::get<AnnIndexParam*>(param); + auto* a_param = std::get<vectorized::AnnIndexParam*>(param); if (a_param == nullptr) { return Status::Error<ErrorCode::INDEX_INVALID_PARAMETERS>("a_param is null"); } @@ -36,9 +38,9 @@ Status AnnIndexIterator::read_from_index(const IndexParam& param) { return _ann_reader->query(&_io_ctx, a_param); } -Status AnnIndexIterator::range_search(const RangeSearchParams& params, +Status AnnIndexIterator::range_search(const vectorized::RangeSearchParams& params, const VectorSearchUserParams& custom_params, - RangeSearchResult* result) { + vectorized::RangeSearchResult* result) { if (_ann_reader == nullptr) { return Status::Error<ErrorCode::INDEX_INVALID_PARAMETERS>("_ann_reader is null"); } diff --git a/be/src/olap/rowset/segment_v2/ann_index_iterator.h b/be/src/olap/rowset/segment_v2/ann_index_iterator.h index 82a4113cacb..7bff4c334e9 100644 --- a/be/src/olap/rowset/segment_v2/ann_index_iterator.h +++ b/be/src/olap/rowset/segment_v2/ann_index_iterator.h @@ -25,36 +25,12 @@ #include "olap/rowset/segment_v2/index_iterator.h" #include "runtime/runtime_state.h" -namespace doris::segment_v2 { - -struct AnnIndexParam { - const float* query_value; - const size_t query_value_size; - size_t limit; - doris::VectorSearchUserParams _user_params; - roaring::Roaring* roaring; - std::unique_ptr<std::vector<float>> distance = nullptr; - std::unique_ptr<std::vector<uint64_t>> row_ids = nullptr; -}; - -struct RangeSearchParams { - bool is_le_or_lt = true; - float* query_value = nullptr; - float radius = -1; - roaring::Roaring* roaring; // roaring from segment_iterator - std::string to_string() const { - DCHECK(roaring != nullptr); - return fmt::format("is_le_or_lt: {}, radius: {}, input rows {}", is_le_or_lt, radius, - roaring->cardinality()); - } - virtual ~RangeSearchParams() = default; -}; +namespace doris::vectorized { +struct RangeSearchParams; +struct RangeSearchResult; +} // namespace doris::vectorized -struct RangeSearchResult { - std::shared_ptr<roaring::Roaring> roaring; - std::unique_ptr<std::vector<uint64_t>> row_ids; - std::unique_ptr<float[]> distance; -}; +namespace doris::segment_v2 { // IndexIterator 与 IndexReader 的角色似乎有点重复,未来可以重构后删除一层概念 class AnnIndexIterator : public IndexIterator { @@ -77,9 +53,9 @@ public: bool has_null() override { return true; } - MOCK_FUNCTION Status range_search(const RangeSearchParams& params, + MOCK_FUNCTION Status range_search(const vectorized::RangeSearchParams& params, const VectorSearchUserParams& custom_params, - RangeSearchResult* result); + vectorized::RangeSearchResult* result); private: std::shared_ptr<AnnIndexReader> _ann_reader; diff --git a/be/src/olap/rowset/segment_v2/ann_index_reader.cpp b/be/src/olap/rowset/segment_v2/ann_index_reader.cpp index d9d90dfbf04..b08528b2b78 100644 --- a/be/src/olap/rowset/segment_v2/ann_index_reader.cpp +++ b/be/src/olap/rowset/segment_v2/ann_index_reader.cpp @@ -23,6 +23,7 @@ #include "ann_index_iterator.h" #include "common/config.h" #include "io/io_common.h" +#include "olap/rowset/segment_v2/ann_index/ann_search_params.h" #include "olap/rowset/segment_v2/index_file_reader.h" #include "olap/rowset/segment_v2/inverted_index_compound_reader.h" #include "runtime/runtime_state.h" @@ -32,7 +33,7 @@ namespace doris::segment_v2 { -void AnnIndexReader::update_result(const IndexSearchResult& search_result, +void AnnIndexReader::update_result(const vectorized::IndexSearchResult& search_result, std::vector<float>& distance, roaring::Roaring& roaring) { DCHECK(search_result.distances != nullptr); DCHECK(search_result.roaring != nullptr); @@ -54,7 +55,7 @@ AnnIndexReader::AnnIndexReader(const TabletIndex* index_meta, _index_type = it->second; it = index_properties.find("metric_type"); DCHECK(it != index_properties.end()); - _metric_type = VectorIndex::string_to_metric(it->second); + _metric_type = string_to_metric(it->second); } Status AnnIndexReader::new_iterator(const io::IOContext& io_ctx, OlapReaderStatistics* stats, @@ -81,16 +82,16 @@ Status AnnIndexReader::load_index(io::IOContext* io_ctx) { }); } -Status AnnIndexReader::query(io::IOContext* io_ctx, AnnIndexParam* param) { +Status AnnIndexReader::query(io::IOContext* io_ctx, vectorized::AnnIndexParam* param) { #ifndef BE_TEST RETURN_IF_ERROR(load_index(io_ctx)); #endif DCHECK(_vector_index != nullptr); const float* query_vec = param->query_value; const int limit = param->limit; - IndexSearchResult index_search_result; + vectorized::IndexSearchResult index_search_result; if (_index_type == "hnsw") { - HNSWSearchParameters hnsw_search_params; + vectorized::HNSWSearchParameters hnsw_search_params; hnsw_search_params.roaring = param->roaring; hnsw_search_params.ef_search = param->_user_params.hnsw_ef_search; hnsw_search_params.check_relative_distance = @@ -112,18 +113,18 @@ Status AnnIndexReader::query(io::IOContext* io_ctx, AnnIndexParam* param) { return Status::OK(); } -Status AnnIndexReader::range_search(const RangeSearchParams& params, +Status AnnIndexReader::range_search(const vectorized::RangeSearchParams& params, const VectorSearchUserParams& custom_params, - RangeSearchResult* result, io::IOContext* io_ctx) { + vectorized::RangeSearchResult* result, io::IOContext* io_ctx) { #ifndef BE_TEST RETURN_IF_ERROR(load_index(io_ctx)); #endif DCHECK(_vector_index != nullptr); - IndexSearchResult search_result; - std::unique_ptr<IndexSearchParameters> search_param = nullptr; + vectorized::IndexSearchResult search_result; + std::unique_ptr<vectorized::IndexSearchParameters> search_param = nullptr; if (_index_type == "hnsw") { - auto hnsw_param = std::make_unique<HNSWSearchParameters>(); + auto hnsw_param = std::make_unique<vectorized::HNSWSearchParameters>(); hnsw_param->ef_search = custom_params.hnsw_ef_search; hnsw_param->check_relative_distance = custom_params.hnsw_check_relative_distance; hnsw_param->bounded_queue = custom_params.hnsw_bounded_queue; diff --git a/be/src/olap/rowset/segment_v2/ann_index_reader.h b/be/src/olap/rowset/segment_v2/ann_index_reader.h index e557d9c96b3..6d326a748f2 100644 --- a/be/src/olap/rowset/segment_v2/ann_index_reader.h +++ b/be/src/olap/rowset/segment_v2/ann_index_reader.h @@ -23,11 +23,14 @@ #include "util/once.h" #include "vector/vector_index.h" -namespace doris::segment_v2 { - +namespace doris::vectorized { struct AnnIndexParam; struct RangeSearchParams; struct RangeSearchResult; +struct IndexSearchResult; +} // namespace doris::vectorized + +namespace doris::segment_v2 { class IndexFileReader; class IndexIterator; @@ -38,16 +41,16 @@ public: std::shared_ptr<IndexFileReader> index_file_reader); ~AnnIndexReader() override = default; - static void update_result(const IndexSearchResult&, std::vector<float>& distance, + static void update_result(const vectorized::IndexSearchResult&, std::vector<float>& distance, roaring::Roaring& row_id); Status load_index(io::IOContext* io_ctx); - Status query(io::IOContext* io_ctx, AnnIndexParam* param); + Status query(io::IOContext* io_ctx, vectorized::AnnIndexParam* param); - Status range_search(const RangeSearchParams& params, - const VectorSearchUserParams& custom_params, RangeSearchResult* result, - io::IOContext* io_ctx = nullptr); + Status range_search(const vectorized::RangeSearchParams& params, + const VectorSearchUserParams& custom_params, + vectorized::RangeSearchResult* result, io::IOContext* io_ctx = nullptr); uint64_t get_index_id() const override { return _index_meta.index_id(); } @@ -55,7 +58,7 @@ public: RuntimeState* runtime_state, std::unique_ptr<IndexIterator>* iterator) override; - VectorIndex::Metric get_metric_type() const { return _metric_type; } + Metric get_metric_type() const { return _metric_type; } private: TabletIndex _index_meta; @@ -63,7 +66,7 @@ private: std::unique_ptr<VectorIndex> _vector_index; // TODO: Use integer. std::string _index_type; - VectorIndex::Metric _metric_type; + Metric _metric_type; DorisCallOnce<Status> _load_index_once; }; diff --git a/be/src/olap/rowset/segment_v2/column_writer.cpp b/be/src/olap/rowset/segment_v2/column_writer.cpp index 3d7d4244bcc..2220924ffb4 100644 --- a/be/src/olap/rowset/segment_v2/column_writer.cpp +++ b/be/src/olap/rowset/segment_v2/column_writer.cpp @@ -517,9 +517,8 @@ Status ScalarColumnWriter::append_nulls(size_t num_rows) { return Status::OK(); } -// append data to page builder. this function will make sure that -// num_rows must be written before return. And ptr will be modified -// to next data should be written +// Appends data to the page builder, ensuring all num_rows are written. +// Advances ptr to point to the next data to be written after completion. Status ScalarColumnWriter::append_data(const uint8_t** ptr, size_t num_rows) { size_t remaining = num_rows; while (remaining > 0) { @@ -535,14 +534,6 @@ Status ScalarColumnWriter::append_data(const uint8_t** ptr, size_t num_rows) { return Status::OK(); } -Status ScalarColumnWriter::append_data_in_current_page(const uint8_t** data, size_t* num_written) { - RETURN_IF_CATCH_EXCEPTION( - { return _internal_append_data_in_current_page(*data, num_written); }); - - *data += get_field()->size() * (*num_written); - return Status::OK(); -} - Status ScalarColumnWriter::_internal_append_data_in_current_page(const uint8_t* data, size_t* num_written) { RETURN_IF_ERROR(_page_builder->add(data, num_written)); @@ -933,18 +924,19 @@ Status ArrayColumnWriter::append_data(const uint8_t** ptr, size_t num_rows) { // [size, offset_ptr, item_data_ptr, item_nullmap_ptr] auto data_ptr = reinterpret_cast<const uint64_t*>(*ptr); // total number length - size_t element_cnt = (*data_ptr); + size_t element_cnt = size_t((unsigned long)(*data_ptr)); auto offset_data = *(data_ptr + 1); const uint8_t* offsets_ptr = (const uint8_t*)offset_data; auto data = *(data_ptr + 2); auto nested_null_map = *(data_ptr + 3); + LOG_INFO("ArrayColumnWriter, element_cnt{}", element_cnt); if (element_cnt > 0) { RETURN_IF_ERROR(_item_writer->append(reinterpret_cast<const uint8_t*>(nested_null_map), reinterpret_cast<const void*>(data), element_cnt)); } if (_opts.need_inverted_index) { auto* writer = dynamic_cast<ScalarColumnWriter*>(_item_writer.get()); - // now only support nested type is scala + // Only support scalar as nested type if (writer != nullptr) { //NOTE: use array field name as index field, but item_writer size should be used when moving item_data_ptr RETURN_IF_ERROR(_inverted_index_builder->add_array_values( @@ -955,12 +947,16 @@ Status ArrayColumnWriter::append_data(const uint8_t** ptr, size_t num_rows) { if (_opts.need_ann_index) { auto* writer = dynamic_cast<ScalarColumnWriter*>(_item_writer.get()); - // now only support nested type is scala + // Only support scalar as nested type if (writer != nullptr) { //NOTE: use array field name as index field, but item_writer size should be used when moving item_data_ptr RETURN_IF_ERROR(_ann_index_builder->add_array_values( _item_writer->get_field()->size(), reinterpret_cast<const void*>(data), reinterpret_cast<const uint8_t*>(nested_null_map), offsets_ptr, num_rows)); + } else { + return Status::NotSupported( + "Ann index can only be build on array with scalar type. but got {} as nested", + _item_writer->get_field()->type()); } } @@ -968,6 +964,12 @@ Status ArrayColumnWriter::append_data(const uint8_t** ptr, size_t num_rows) { return Status::OK(); } +Status ScalarColumnWriter::append_data_in_current_page(const uint8_t** data, size_t* num_written) { + RETURN_IF_ERROR(append_data_in_current_page(*data, num_written)); + *data += get_field()->size() * (*num_written); + return Status::OK(); +} + uint64_t ArrayColumnWriter::estimate_buffer_size() { return _offset_writer->estimate_buffer_size() + (is_nullable() ? _null_writer->estimate_buffer_size() : 0) + diff --git a/be/src/olap/rowset/segment_v2/column_writer.h b/be/src/olap/rowset/segment_v2/column_writer.h index d88c9bfcb3c..6e5bec9f59b 100644 --- a/be/src/olap/rowset/segment_v2/column_writer.h +++ b/be/src/olap/rowset/segment_v2/column_writer.h @@ -165,7 +165,7 @@ public: virtual ordinal_t get_next_rowid() const = 0; - // used for append not null data. + // Append non-null data. virtual Status append_data(const uint8_t** ptr, size_t num_rows) = 0; bool is_nullable() const { return _is_nullable; } @@ -221,6 +221,10 @@ public: // used for append not null data. When page is full, will append data not reach num_rows. Status append_data_in_current_page(const uint8_t** ptr, size_t* num_written); + Status append_data_in_current_page(const uint8_t* ptr, size_t* num_written) { + RETURN_IF_CATCH_EXCEPTION( + { return _internal_append_data_in_current_page(ptr, num_written); }); + } friend class ArrayColumnWriter; friend class OffsetColumnWriter; diff --git a/be/src/olap/rowset/segment_v2/index_iterator.h b/be/src/olap/rowset/segment_v2/index_iterator.h index 0fef498340e..13c68eef161 100644 --- a/be/src/olap/rowset/segment_v2/index_iterator.h +++ b/be/src/olap/rowset/segment_v2/index_iterator.h @@ -27,13 +27,16 @@ #include "olap/rowset/segment_v2/inverted_index_query_type.h" #include "runtime/runtime_state.h" +namespace doris::vectorized { +struct AnnIndexParam; +} + namespace doris::segment_v2 { class InvertedIndexQueryCacheHandle; struct InvertedIndexParam; -struct AnnIndexParam; -using IndexParam = std::variant<InvertedIndexParam*, AnnIndexParam*>; +using IndexParam = std::variant<InvertedIndexParam*, vectorized::AnnIndexParam*>; class IndexIterator { public: diff --git a/be/src/olap/rowset/segment_v2/inverted_index/util/term_iterator.h b/be/src/olap/rowset/segment_v2/inverted_index/util/term_iterator.h index f79dd6492b8..e3ab9335350 100644 --- a/be/src/olap/rowset/segment_v2/inverted_index/util/term_iterator.h +++ b/be/src/olap/rowset/segment_v2/inverted_index/util/term_iterator.h @@ -64,13 +64,15 @@ public: bool read_range(DocRange* docRange) const { return term_docs_->readRange(docRange); } - static TermDocs* ensure_term_doc(const io::IOContext* io_ctx, lucene::index::IndexReader* reader, + static TermDocs* ensure_term_doc(const io::IOContext* io_ctx, + lucene::index::IndexReader* reader, const std::wstring& field_name, const std::string& term) { std::wstring ws_term = StringUtil::string_to_wstring(term); return ensure_term_doc(io_ctx, reader, field_name, ws_term); } - static TermDocs* ensure_term_doc(const io::IOContext* io_ctx, lucene::index::IndexReader* reader, + static TermDocs* ensure_term_doc(const io::IOContext* io_ctx, + lucene::index::IndexReader* reader, const std::wstring& field_name, const std::wstring& ws_term) { auto* t = _CLNEW Term(field_name.c_str(), ws_term.c_str()); auto* term_pos = reader->termDocs(t, io_ctx); diff --git a/be/src/olap/rowset/segment_v2/segment_iterator.cpp b/be/src/olap/rowset/segment_v2/segment_iterator.cpp index e4b59f28875..36866f9704d 100644 --- a/be/src/olap/rowset/segment_v2/segment_iterator.cpp +++ b/be/src/olap/rowset/segment_v2/segment_iterator.cpp @@ -51,6 +51,7 @@ #include "olap/match_predicate.h" #include "olap/olap_common.h" #include "olap/primary_key_index.h" +#include "olap/rowset/segment_v2/ann_index_reader.h" #include "olap/rowset/segment_v2/bitmap_index_reader.h" #include "olap/rowset/segment_v2/column_reader.h" #include "olap/rowset/segment_v2/index_file_reader.h" @@ -566,9 +567,11 @@ Status SegmentIterator::_get_row_ranges_by_column_conditions() { } } _opts.stats->rows_inverted_index_filtered += (input_rows - _row_bitmap.cardinality()); + for (auto cid : _schema->column_ids()) { bool result_true = _check_all_conditions_passed_inverted_index_for_column(cid); - + LOG_INFO("Check all conditions passed in inverted index for column {}, result: {}", + cid, result_true); if (result_true) { _need_read_data_indices[cid] = false; } @@ -639,7 +642,7 @@ Status SegmentIterator::_apply_ann_topn_predicate() { auto index_reader = ann_index_iterator->get_reader(); auto ann_index_reader = dynamic_cast<AnnIndexReader*>(index_reader.get()); DCHECK(ann_index_reader != nullptr); - if (ann_index_reader->get_metric_type() == VectorIndex::Metric::INNER_PRODUCT) { + if (ann_index_reader->get_metric_type() == Metric::INNER_PRODUCT) { if (_ann_topn_runtime->is_asc()) { LOG_INFO("Asc topn for inner product can not be evaluated by ann index"); return Status::OK(); @@ -655,8 +658,8 @@ Status SegmentIterator::_apply_ann_topn_predicate() { LOG_INFO( "Ann topn metric type {} not match index metric type {}, can not be evaluated by " "ann index", - VectorIndex::metric_to_string(_ann_topn_runtime->get_metric_type()), - VectorIndex::metric_to_string(ann_index_reader->get_metric_type())); + metric_to_string(_ann_topn_runtime->get_metric_type()), + metric_to_string(ann_index_reader->get_metric_type())); return Status::OK(); } @@ -688,6 +691,7 @@ Status SegmentIterator::_apply_ann_topn_predicate() { result_row_ids->size()); virtual_column_iter->prepare_materialization(std::move(result_column), std::move(result_row_ids)); + return Status::OK(); } @@ -965,8 +969,8 @@ Status SegmentIterator::_apply_index_expr() { } for (const auto& expr_ctx : _common_expr_ctxs_push_down) { - RETURN_IF_ERROR(expr_ctx->root()->evaluate_ann_range_search( - _index_iterators, idx_to_cids, _column_iterators, _row_bitmap)); + RETURN_IF_ERROR(expr_ctx->evaluate_ann_range_search(_index_iterators, idx_to_cids, + _column_iterators, _row_bitmap)); } for (auto it = _common_expr_ctxs_push_down.begin(); it != _common_expr_ctxs_push_down.end();) { @@ -1081,14 +1085,16 @@ bool SegmentIterator::_need_read_data(ColumnId cid) { _opts.enable_unique_key_merge_on_write)))) { return true; } + // this is a virtual column, we always need to read data if (this->_vir_cid_to_idx_in_block.contains(cid)) { return true; } - // if there is delete predicate, we always need to read data + // if there is a delete predicate, we always need to read data if (_has_delete_predicate(cid)) { return true; } + if (_output_columns.count(-1)) { // if _output_columns contains -1, it means that the light // weight schema change may not be enabled or other reasons @@ -1104,6 +1110,7 @@ bool SegmentIterator::_need_read_data(ColumnId cid) { // If any of the above conditions are met, log a debug message indicating that there's no need to read data for the indexed column. // Then, return false. int32_t unique_id = _opts.tablet_schema->column(cid).unique_id(); + LOG_INFO("Output columns contains {} is {}", cid, _output_columns.contains(unique_id)); if ((_need_read_data_indices.contains(cid) && !_need_read_data_indices[cid] && !_output_columns.contains(unique_id)) || (_need_read_data_indices.contains(cid) && !_need_read_data_indices[cid] && @@ -1916,9 +1923,11 @@ Status SegmentIterator::_read_columns_by_index(uint32_t nrows_read_limit, uint32 auto& column = _current_return_columns[cid]; if (!_virtual_column_exprs.contains(cid)) { if (_no_need_read_key_data(cid, column, nrows_read)) { + LOG_INFO("Column {} no need to read.", cid); continue; } if (_prune_column(cid, column, true, nrows_read)) { + LOG_INFO("Column {} is pruned. No need to read data.", cid); continue; } @@ -2409,8 +2418,8 @@ Status SegmentIterator::_next_batch_internal(vectorized::Block* block) { RETURN_IF_ERROR(_convert_to_expected_type(_cols_read_by_column_predicate)); RETURN_IF_ERROR(_convert_to_expected_type(_cols_not_included_by_any_predicates)); LOG_INFO( - "No need to evaluate any predicates or filter, output non-predicate columns, " - "block rows {}, selected size {}", + "No need to evaluate any predicates or filter block rows {}, " + "_current_batch_rows_read {}", block->rows(), _current_batch_rows_read); _output_non_pred_columns(block); } else { @@ -2772,6 +2781,7 @@ Status SegmentIterator::_construct_compound_expr_context() { _common_expr_inverted_index_status); for (const auto& expr_ctx : _opts.common_expr_ctxs_push_down) { vectorized::VExprContextSPtr context; + // _ann_range_search_runtime will do deep copy. RETURN_IF_ERROR(expr_ctx->clone(_opts.runtime_state, context)); context->set_inverted_index_context(inverted_index_context); _common_expr_ctxs_push_down.emplace_back(context); diff --git a/be/src/olap/rowset/segment_v2/segment_iterator.h b/be/src/olap/rowset/segment_v2/segment_iterator.h index 8d9336cbf18..a638605d18a 100644 --- a/be/src/olap/rowset/segment_v2/segment_iterator.h +++ b/be/src/olap/rowset/segment_v2/segment_iterator.h @@ -470,9 +470,17 @@ private: std::vector<uint8_t> _ret_flags; + /* + * column and column_predicates on it. + * a boolean value to indicate whether the column has been read by the index. + */ std::unordered_map<ColumnId, std::unordered_map<ColumnPredicate*, bool>> _column_predicate_inverted_index_status; + /* + * column and common expr on it. + * a boolean value to indicate whether the column has been read by the index. + */ std::unordered_map<ColumnId, std::unordered_map<const vectorized::VExpr*, bool>> _common_expr_inverted_index_status; diff --git a/be/src/olap/rowset/segment_v2/virtual_column_iterator.cpp b/be/src/olap/rowset/segment_v2/virtual_column_iterator.cpp index 624cc484717..c0f1bf078c7 100644 --- a/be/src/olap/rowset/segment_v2/virtual_column_iterator.cpp +++ b/be/src/olap/rowset/segment_v2/virtual_column_iterator.cpp @@ -35,6 +35,7 @@ Status VirtualColumnIterator::init(const ColumnIteratorOptions& opts) { return Status::OK(); } +// TODO(zhiqiang): What if input is empty? void VirtualColumnIterator::prepare_materialization(vectorized::IColumn::Ptr column, std::unique_ptr<std::vector<uint64_t>> labels) { DCHECK(labels->size() == column->size()) << "labels size: " << labels->size() @@ -45,6 +46,11 @@ void VirtualColumnIterator::prepare_materialization(vectorized::IColumn::Ptr col const std::vector<uint64_t>& labels_ref = *labels; const size_t n = labels_ref.size(); LOG_INFO("Input labels {}", fmt::join(labels_ref, ", ")); + if (n == 0) { + _size = 0; + _max_ordinal = 0; + return; + } std::vector<size_t> order(n); // global_row_id_to_idx: // {5:0, 4:1, 1:2, 10:3, 7:4, 2:5} @@ -84,16 +90,6 @@ void VirtualColumnIterator::prepare_materialization(vectorized::IColumn::Ptr col _filter = doris::vectorized::IColumn::Filter(_size, 0); } -Status VirtualColumnIterator::seek_to_first() { - if (_size < 0) { - // _materialized_column is not set. do nothing. - return Status::OK(); - } - _current_ordinal = 0; - - return Status::OK(); -} - Status VirtualColumnIterator::seek_to_ordinal(ordinal_t ord_idx) { if (_size < 0 || vectorized::check_and_get_column<vectorized::ColumnNothing>(*_materialized_column_ptr)) { diff --git a/be/src/olap/rowset/segment_v2/virtual_column_iterator.h b/be/src/olap/rowset/segment_v2/virtual_column_iterator.h index cfdd59745d8..41f4e76fd24 100644 --- a/be/src/olap/rowset/segment_v2/virtual_column_iterator.h +++ b/be/src/olap/rowset/segment_v2/virtual_column_iterator.h @@ -38,8 +38,6 @@ public: Status init(const ColumnIteratorOptions& opts) override; - Status seek_to_first() override; - Status seek_to_ordinal(ordinal_t ord_idx) override; Status next_batch(size_t* n, vectorized::MutableColumnPtr& dst, bool* has_null) override; diff --git a/be/src/olap/tablet_reader.cpp b/be/src/olap/tablet_reader.cpp index 2d427c63ad1..2542b2a8c8a 100644 --- a/be/src/olap/tablet_reader.cpp +++ b/be/src/olap/tablet_reader.cpp @@ -264,7 +264,7 @@ Status TabletReader::_capture_rs_readers(const ReaderParams& read_params) { _reader_context.ann_topn_runtime = read_params.ann_topn_runtime; _reader_context.vir_cid_to_idx_in_block = read_params.vir_cid_to_idx_in_block; _reader_context.vir_col_idx_to_type = read_params.vir_col_idx_to_type; - _reader_context.output_columns = &read_params.output_columns; + _reader_context.output_columns = &read_params.output_column_unique_ids; _reader_context.push_down_agg_type_opt = read_params.push_down_agg_type_opt; _reader_context.ttl_seconds = _tablet->ttl_seconds(); diff --git a/be/src/olap/tablet_reader.h b/be/src/olap/tablet_reader.h index 4faa1b67fb4..83a84c2e701 100644 --- a/be/src/olap/tablet_reader.h +++ b/be/src/olap/tablet_reader.h @@ -153,7 +153,7 @@ public: // return_columns is init from query schema std::vector<ColumnId> return_columns; // output_columns only contain columns in OrderByExprs and outputExprs - std::set<int32_t> output_columns; + std::set<int32_t> output_column_unique_ids; RuntimeProfile* profile = nullptr; RuntimeState* runtime_state = nullptr; diff --git a/be/src/pipeline/exec/olap_scan_operator.cpp b/be/src/pipeline/exec/olap_scan_operator.cpp index 721a340f52d..b42ef0c9fc1 100644 --- a/be/src/pipeline/exec/olap_scan_operator.cpp +++ b/be/src/pipeline/exec/olap_scan_operator.cpp @@ -317,10 +317,6 @@ bool OlapScanLocalState::_storage_no_merge() { } Status OlapScanLocalState::_init_scanners(std::list<vectorized::ScannerSPtr>* scanners) { - // auto& p = _parent->cast<OlapScanOperatorX>(); - // if (p._olap_scan_node.keyType != TKeysType::DUP_KEYS) { - // return Status::NotSupported("Now only dup keys table is supported"); - // } if (_scan_ranges.empty()) { _eos = true; _scan_dependency->set_ready(); @@ -343,9 +339,11 @@ Status OlapScanLocalState::_init_scanners(std::list<vectorized::ScannerSPtr>* sc auto& p = _parent->cast<OlapScanOperatorX>(); for (auto uid : p._olap_scan_node.output_column_unique_ids) { - _maybe_read_column_ids.emplace(uid); + _output_column_unique_ids.emplace(uid); } + LOG_INFO("Output column unique ids: {}", fmt::join(_output_column_unique_ids, ", ")); + // ranges constructed from scan keys RETURN_IF_ERROR(_scan_keys.get_key_range(&_cond_ranges)); // if we can't get ranges from conditions, we give it a total range diff --git a/be/src/pipeline/exec/olap_scan_operator.h b/be/src/pipeline/exec/olap_scan_operator.h index 7ef19f19879..77b91c0f625 100644 --- a/be/src/pipeline/exec/olap_scan_operator.h +++ b/be/src/pipeline/exec/olap_scan_operator.h @@ -97,7 +97,7 @@ private: OlapScanKeys _scan_keys; std::vector<FilterOlapParam<TCondition>> _olap_filters; // If column id in this set, indicate that we need to read data after index filtering - std::set<int32_t> _maybe_read_column_ids; + std::set<int32_t> _output_column_unique_ids; std::unique_ptr<RuntimeProfile> _segment_profile; std::unique_ptr<RuntimeProfile> _index_filter_profile; diff --git a/be/src/vec/exec/scan/olap_scanner.cpp b/be/src/vec/exec/scan/olap_scanner.cpp index 77657a45206..dc077197539 100644 --- a/be/src/vec/exec/scan/olap_scanner.cpp +++ b/be/src/vec/exec/scan/olap_scanner.cpp @@ -92,7 +92,7 @@ OlapScanner::OlapScanner(pipeline::ScanLocalStateBase* parent, OlapScanner::Para .target_cast_type_for_variants {}, .rs_splits {}, .return_columns {}, - .output_columns {}, + .output_column_unique_ids {}, .remaining_conjunct_roots {}, .common_expr_ctxs_push_down {}, .topn_filter_source_node_ids {}, @@ -133,14 +133,6 @@ static std::string read_columns_to_string(TabletSchemaSPtr tablet_schema, } Status OlapScanner::init() { - const TOlapScanNode& olap_scan_node = - _local_state->cast<pipeline::OlapScanLocalState>().olap_scan_node(); - if (olap_scan_node.__isset.keyType) { - if (olap_scan_node.keyType != TKeysType::DUP_KEYS) { - return Status::InternalError<false>("Currently only support DUP_KEYS, but got {}", - olap_scan_node.keyType); - } - } _is_init = true; auto* local_state = static_cast<pipeline::OlapScanLocalState*>(_local_state); auto& tablet = _tablet_reader_params.tablet; @@ -150,7 +142,9 @@ Status OlapScanner::init() { VExprContextSPtr context; RETURN_IF_ERROR(ctx->clone(_state, context)); _common_expr_ctxs_push_down.emplace_back(context); + LOG_INFO("Prepare ann range search."); RETURN_IF_ERROR(context->prepare_ann_range_search(_vector_search_params)); + LOG_INFO("Finish prepare ann range search, query_id={}", print_id(_state->query_id())); } for (auto pair : local_state->_slot_id_to_virtual_column_expr) { @@ -325,8 +319,8 @@ Status OlapScanner::_init_tablet_reader_params( _tablet_reader_params.ann_topn_runtime = _ann_topn_runtime; _tablet_reader_params.vir_cid_to_idx_in_block = _vir_cid_to_idx_in_block; _tablet_reader_params.vir_col_idx_to_type = _vir_col_idx_to_type; - _tablet_reader_params.output_columns = - ((pipeline::OlapScanLocalState*)_local_state)->_maybe_read_column_ids; + _tablet_reader_params.output_column_unique_ids = + ((pipeline::OlapScanLocalState*)_local_state)->_output_column_unique_ids; for (const auto& ele : ((pipeline::OlapScanLocalState*)_local_state)->_cast_types_for_variants) { _tablet_reader_params.target_cast_type_for_variants[ele.first] = diff --git a/be/src/vec/exec/scan/scanner.cpp b/be/src/vec/exec/scan/scanner.cpp index 8235be56ef5..8748461b6c0 100644 --- a/be/src/vec/exec/scan/scanner.cpp +++ b/be/src/vec/exec/scan/scanner.cpp @@ -184,6 +184,7 @@ Status Scanner::_do_projections(vectorized::Block* origin_block, vectorized::Blo for (int i = 0; i < projections.size(); i++) { RETURN_IF_ERROR(projections[i]->execute(&input_block, &result_column_ids[i])); } + input_block.shuffle_columns(result_column_ids); } diff --git a/be/src/vec/exprs/ann_range_search_params.h b/be/src/vec/exprs/ann_range_search_params.h deleted file mode 100644 index 410c4dc14c4..00000000000 --- a/be/src/vec/exprs/ann_range_search_params.h +++ /dev/null @@ -1,57 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -#pragma once - -#include <gen_cpp/Opcodes_types.h> - -#include <string> - -#include "olap/rowset/segment_v2/ann_index_iterator.h" -#include "runtime/runtime_state.h" -#include "vector/vector_index.h" - -namespace doris::vectorized { -struct RangeSearchRuntimeInfo { - bool is_ann_range_search = false; - bool is_le_or_lt = true; - size_t src_col_idx = 0; - int64_t dst_col_idx = -1; - double radius = 0.0; - segment_v2::VectorIndex::Metric metric_type; - doris::VectorSearchUserParams user_params; - std::unique_ptr<float[]> query_value; - - segment_v2::RangeSearchParams to_range_search_params() { - segment_v2::RangeSearchParams params; - params.query_value = query_value.get(); - params.radius = static_cast<float>(radius); - params.roaring = nullptr; - params.is_le_or_lt = is_le_or_lt; - return params; - } - - std::string to_string() const { - return fmt::format( - "is_ann_range_search: {}, is_le_or_lt: {}, src_col_idx: {}, " - "dst_col_idx: {}, metric_type {}, radius: {}, user params: {}", - is_ann_range_search, is_le_or_lt, src_col_idx, dst_col_idx, - segment_v2::VectorIndex::metric_to_string(metric_type), radius, - user_params.to_string()); - } -}; -} // namespace doris::vectorized diff --git a/be/src/vec/exprs/ann_topn_runtime.cpp b/be/src/vec/exprs/ann_topn_runtime.cpp index 5f86a2a9241..8b229f0f859 100644 --- a/be/src/vec/exprs/ann_topn_runtime.cpp +++ b/be/src/vec/exprs/ann_topn_runtime.cpp @@ -22,13 +22,13 @@ #include <string> #include "common/logging.h" +#include "olap/rowset/segment_v2/ann_index/ann_search_params.h" #include "olap/rowset/segment_v2/ann_index_iterator.h" #include "runtime/runtime_state.h" #include "vec/columns/column.h" #include "vec/columns/column_array.h" #include "vec/columns/column_const.h" #include "vec/columns/column_nullable.h" -#include "vec/columns/columns_number.h" #include "vec/common/assert_cast.h" #include "vec/exprs/varray_literal.h" #include "vec/exprs/vexpr_context.h" @@ -114,7 +114,7 @@ Status AnnTopNRuntime::prepare(RuntimeState* state, const RowDescriptor& row_des // Strip the "_approximate" suffix metric_name = metric_name.substr(0, metric_name.size() - 12); - _metric_type = segment_v2::VectorIndex::string_to_metric(metric_name); + _metric_type = segment_v2::string_to_metric(metric_name); VLOG_DEBUG << "AnnTopNRuntime: {}" << this->debug_string(); return Status::OK(); @@ -147,7 +147,7 @@ Status AnnTopNRuntime::evaluate_vector_ann_search(segment_v2::IndexIterator* ann query_value_f32[i] = static_cast<float>(query_value[i]); } - segment_v2::AnnIndexParam ann_query_params { + vectorized::AnnIndexParam ann_query_params { .query_value = query_value_f32.get(), .query_value_size = query_value_size, .limit = _limit, @@ -178,7 +178,6 @@ std::string AnnTopNRuntime::debug_string() const { "AnnTopNRuntime: limit={}, src_col_idx={}, dest_col_idx={}, asc={}, user_params={}, " "metric_type={}, order_by_expr={}", _limit, _src_column_idx, _dest_column_idx, _asc, _user_params.to_string(), - segment_v2::VectorIndex::metric_to_string(_metric_type), - _order_by_expr_ctx->root()->debug_string()); + segment_v2::metric_to_string(_metric_type), _order_by_expr_ctx->root()->debug_string()); } } // namespace doris::vectorized \ No newline at end of file diff --git a/be/src/vec/exprs/ann_topn_runtime.h b/be/src/vec/exprs/ann_topn_runtime.h index f270799cad2..bd3dd494907 100644 --- a/be/src/vec/exprs/ann_topn_runtime.h +++ b/be/src/vec/exprs/ann_topn_runtime.h @@ -26,6 +26,7 @@ #include "vec/exprs/vexpr_context.h" #include "vec/exprs/vexpr_fwd.h" #include "vec/exprs/vslot_ref.h" +#include "vector/metric.h" namespace doris::vectorized { @@ -44,7 +45,7 @@ public: roaring::Roaring& row_bitmap, vectorized::IColumn::MutablePtr& result_column, std::unique_ptr<std::vector<uint64_t>>& row_ids); - segment_v2::VectorIndex::Metric get_metric_type() const { return _metric_type; } + segment_v2::Metric get_metric_type() const { return _metric_type; } std::string debug_string() const; size_t get_src_column_idx() const { return _src_column_idx; } @@ -62,7 +63,7 @@ private: std::string _name = "ann_topn_runtime"; size_t _src_column_idx = -1; size_t _dest_column_idx = -1; - segment_v2::VectorIndex::Metric _metric_type; + segment_v2::Metric _metric_type; IColumn::Ptr _query_array; doris::VectorSearchUserParams _user_params; }; diff --git a/be/src/vec/exprs/vectorized_fn_call.cpp b/be/src/vec/exprs/vectorized_fn_call.cpp index 49bfef7bc55..34f47aafb2b 100644 --- a/be/src/vec/exprs/vectorized_fn_call.cpp +++ b/be/src/vec/exprs/vectorized_fn_call.cpp @@ -29,6 +29,7 @@ #include "common/logging.h" #include "common/status.h" #include "common/utils.h" +#include "olap/rowset/segment_v2/ann_index/ann_search_params.h" #include "olap/rowset/segment_v2/ann_index_iterator.h" #include "olap/rowset/segment_v2/column_reader.h" #include "olap/rowset/segment_v2/index_reader.h" @@ -39,7 +40,6 @@ #include "vec/columns/column.h" #include "vec/columns/column_array.h" #include "vec/columns/column_nullable.h" -#include "vec/columns/columns_number.h" #include "vec/core/block.h" #include "vec/core/column_numbers.h" #include "vec/core/types.h" @@ -333,16 +333,21 @@ bool VectorizedFnCall::equals(const VExpr& other) { SlotRef */ -Status VectorizedFnCall::prepare_ann_range_search( - const doris::VectorSearchUserParams& user_params) { +Status VectorizedFnCall::prepare_ann_range_search(const doris::VectorSearchUserParams& user_params, + RangeSearchRuntimeInfo& range_search_runtime, + bool& suitable_for_ann_index) { + if (!suitable_for_ann_index) { + return Status::OK(); + } std::set<TExprOpcode::type> ops = {TExprOpcode::GE, TExprOpcode::LE, TExprOpcode::LE, TExprOpcode::GT, TExprOpcode::LT}; if (ops.find(this->op()) == ops.end()) { + suitable_for_ann_index = false; LOG_INFO("Not a range search function."); return Status::OK(); } - _ann_range_search_params.is_le_or_lt = + range_search_runtime.is_le_or_lt = (this->op() == TExprOpcode::LE || this->op() == TExprOpcode::LT); DCHECK(_children.size() == 2); @@ -353,6 +358,7 @@ Status VectorizedFnCall::prepare_ann_range_search( // Return type of L2Distance is always double. auto right_literal = std::dynamic_pointer_cast<VLiteral>(right_child); if (right_literal == nullptr) { + suitable_for_ann_index = false; LOG_INFO("Right child is not a literal."); return Status::OK(); } @@ -360,12 +366,13 @@ Status VectorizedFnCall::prepare_ann_range_search( auto right_col = right_literal->get_column_ptr()->convert_to_full_column_if_const(); auto right_type = right_literal->get_data_type(); if (right_type->get_primitive_type() != PrimitiveType::TYPE_DOUBLE) { + suitable_for_ann_index = false; LOG_INFO("Right child is not a Float64Literal."); return Status::OK(); } const ColumnFloat64* cf64_right = assert_cast<const ColumnFloat64*>(right_col.get()); - _ann_range_search_params.radius = cf64_right->get_data()[0]; + range_search_runtime.radius = cf64_right->get_data()[0]; std::shared_ptr<VectorizedFnCall> function_call; auto vir_slot_ref = std::dynamic_pointer_cast<VirtualSlotRef>(left_child); @@ -378,6 +385,7 @@ Status VectorizedFnCall::prepare_ann_range_search( } if (function_call == nullptr) { + suitable_for_ann_index = false; LOG_INFO("Left child is not a function call."); return Status::OK(); } @@ -388,16 +396,17 @@ Status VectorizedFnCall::prepare_ann_range_search( if (distance_functions.find(function_call->_function_name) == distance_functions.end()) { LOG_INFO("Left child is not a approximate distance function. Got {}", function_call->_function_name); + suitable_for_ann_index = false; return Status::OK(); } else { // Strip the _approximate suffix. std::string metric_name = function_call->_function_name; metric_name = metric_name.substr(0, metric_name.size() - 12); - _ann_range_search_params.metric_type = - segment_v2::VectorIndex::string_to_metric(metric_name); + range_search_runtime.metric_type = segment_v2::string_to_metric(metric_name); } if (function_call->get_num_children() != 2) { + suitable_for_ann_index = false; return Status::OK(); } @@ -418,6 +427,7 @@ Status VectorizedFnCall::prepare_ann_range_search( function_call->get_child(idx_of_array_literal)); if (cast_to_array_expr == nullptr || array_literal == nullptr) { + suitable_for_ann_index = false; LOG_INFO("Cast to array expr or array literal is null."); return Status::OK(); } @@ -426,44 +436,47 @@ Status VectorizedFnCall::prepare_ann_range_search( std::shared_ptr<VSlotRef> slot_ref = std::dynamic_pointer_cast<VSlotRef>(cast_to_array_expr->get_child(0)); if (slot_ref == nullptr) { + suitable_for_ann_index = false; LOG_INFO("Cast to array expr's child is not a slot ref."); return Status::OK(); } - _ann_range_search_params.src_col_idx = slot_ref->column_id(); - _ann_range_search_params.dst_col_idx = vir_slot_ref == nullptr ? -1 : vir_slot_ref->column_id(); + range_search_runtime.src_col_idx = slot_ref->column_id(); + range_search_runtime.dst_col_idx = vir_slot_ref == nullptr ? -1 : vir_slot_ref->column_id(); auto col_const = array_literal->get_column_ptr(); auto col_array = col_const->convert_to_full_column_if_const(); const ColumnArray* array_col = assert_cast<const ColumnArray*>(col_array.get()); DCHECK(array_col->size() == 1); size_t dim = array_col->get_offsets()[0]; - _ann_range_search_params.query_value = std::make_unique<float[]>(dim); + range_search_runtime.dim = dim; + range_search_runtime.query_value = std::make_unique<float[]>(dim); const ColumnNullable* cn = assert_cast<const ColumnNullable*>(array_col->get_data_ptr().get()); const ColumnFloat64* cf64 = assert_cast<const ColumnFloat64*>(cn->get_nested_column_ptr().get()); for (size_t i = 0; i < dim; ++i) { - _ann_range_search_params.query_value[i] = static_cast<Float32>(cf64->get_data()[i]); + range_search_runtime.query_value[i] = static_cast<Float32>(cf64->get_data()[i]); } - _ann_range_search_params.is_ann_range_search = true; - _ann_range_search_params.user_params = user_params; - LOG_INFO("Ann range search params: {}", _ann_range_search_params.to_string()); + range_search_runtime.is_ann_range_search = true; + range_search_runtime.user_params = user_params; + LOG_INFO("Ann range search params: {}", range_search_runtime.to_string()); return Status::OK(); } Status VectorizedFnCall::evaluate_ann_range_search( + const RangeSearchRuntimeInfo& range_search_runtime, const std::vector<std::unique_ptr<segment_v2::IndexIterator>>& cid_to_index_iterators, const std::vector<ColumnId>& idx_to_cid, const std::vector<std::unique_ptr<segment_v2::ColumnIterator>>& column_iterators, roaring::Roaring& row_bitmap) { - if (_ann_range_search_params.is_ann_range_search == false) { + if (range_search_runtime.is_ann_range_search == false) { return Status::OK(); } LOG_INFO("Try apply ann range search. Local search params: {}", - _ann_range_search_params.to_string()); + range_search_runtime.to_string()); size_t origin_num = row_bitmap.cardinality(); - int idx_in_block = static_cast<int>(_ann_range_search_params.src_col_idx); + int idx_in_block = static_cast<int>(range_search_runtime.src_col_idx); DCHECK(idx_in_block < idx_to_cid.size()) << "idx_in_block: " << idx_in_block << ", idx_to_cid.size(): " << idx_to_cid.size(); @@ -488,21 +501,22 @@ Status VectorizedFnCall::evaluate_ann_range_search( DCHECK(ann_index_reader != nullptr) << "Ann index reader should not be null. Column cid: " << src_col_cid; // Check if metrics type is match. - if (ann_index_reader->get_metric_type() != _ann_range_search_params.metric_type) { + if (ann_index_reader->get_metric_type() != range_search_runtime.metric_type) { LOG_INFO("Metric type not match, can not execute range search by index."); return Status::OK(); } - RangeSearchParams params = _ann_range_search_params.to_range_search_params(); + RangeSearchParams params = range_search_runtime.to_range_search_params(); params.roaring = &row_bitmap; DCHECK(params.roaring != nullptr); + DCHECK(params.query_value != nullptr); RangeSearchResult result; - RETURN_IF_ERROR(ann_index_iterator->range_search(params, _ann_range_search_params.user_params, - &result)); + RETURN_IF_ERROR( + ann_index_iterator->range_search(params, range_search_runtime.user_params, &result)); #ifndef NDEBUG - if (this->_ann_range_search_params.is_le_or_lt == false) { + if (range_search_runtime.is_le_or_lt == false) { DCHECK(result.distance == nullptr) << "Should not have distance"; } #endif @@ -516,12 +530,12 @@ Status VectorizedFnCall::evaluate_ann_range_search( } // Process virtual column - if (_ann_range_search_params.dst_col_idx >= 0) { + if (range_search_runtime.dst_col_idx >= 0) { // Prepare materialization if we can use result from index. // Typical situation: range search and operator is LE or LT. if (result.distance != nullptr) { DCHECK(result.row_ids != nullptr); - ColumnId dst_col_cid = idx_to_cid[_ann_range_search_params.dst_col_idx]; + ColumnId dst_col_cid = idx_to_cid[range_search_runtime.dst_col_idx]; DCHECK(dst_col_cid < column_iterators.size()); DCHECK(column_iterators[dst_col_cid] != nullptr); segment_v2::ColumnIterator* column_iterator = column_iterators[dst_col_cid].get(); diff --git a/be/src/vec/exprs/vectorized_fn_call.h b/be/src/vec/exprs/vectorized_fn_call.h index 14d86964ac9..59826dda047 100644 --- a/be/src/vec/exprs/vectorized_fn_call.h +++ b/be/src/vec/exprs/vectorized_fn_call.h @@ -22,11 +22,12 @@ #include <vector> #include "common/status.h" +#include "olap/rowset/segment_v2/ann_index/range_search_runtime_info.h" #include "runtime/runtime_state.h" #include "udf/udf.h" #include "vec/core/column_numbers.h" -#include "vec/exprs/ann_range_search_params.h" #include "vec/exprs/vexpr.h" +#include "vec/exprs/vexpr_context.h" #include "vec/exprs/vliteral.h" #include "vec/exprs/vslot_ref.h" #include "vec/functions/function.h" @@ -79,18 +80,20 @@ public: size_t estimate_memory(const size_t rows) override; Status evaluate_ann_range_search( + const RangeSearchRuntimeInfo& runtime, const std::vector<std::unique_ptr<segment_v2::IndexIterator>>& cid_to_index_iterators, const std::vector<ColumnId>& idx_to_cid, const std::vector<std::unique_ptr<segment_v2::ColumnIterator>>& column_iterators, roaring::Roaring& row_bitmap) override; - Status prepare_ann_range_search(const doris::VectorSearchUserParams& params) override; + Status prepare_ann_range_search(const doris::VectorSearchUserParams& params, + RangeSearchRuntimeInfo& runtime, + bool& suitable_for_ann_index) override; protected: FunctionBasePtr _function; std::string _expr_name; std::string _function_name; - RangeSearchRuntimeInfo _ann_range_search_params; private: Status _do_execute(doris::vectorized::VExprContext* context, doris::vectorized::Block* block, diff --git a/be/src/vec/exprs/vexpr.cpp b/be/src/vec/exprs/vexpr.cpp index 5d7004140ff..ab6a9e1eac0 100644 --- a/be/src/vec/exprs/vexpr.cpp +++ b/be/src/vec/exprs/vexpr.cpp @@ -804,6 +804,7 @@ bool VExpr::equals(const VExpr& other) { } Status VExpr::evaluate_ann_range_search( + const RangeSearchRuntimeInfo& runtime, const std::vector<std::unique_ptr<segment_v2::IndexIterator>>& index_iterators, const std::vector<ColumnId>& idx_to_cid, const std::vector<std::unique_ptr<segment_v2::ColumnIterator>>& column_iterators, @@ -811,9 +812,18 @@ Status VExpr::evaluate_ann_range_search( return Status::OK(); } -Status VExpr::prepare_ann_range_search(const doris::VectorSearchUserParams& params) { +Status VExpr::prepare_ann_range_search(const doris::VectorSearchUserParams& params, + RangeSearchRuntimeInfo& range_search_runtime, + bool& suitable_for_ann_index) { + if (!suitable_for_ann_index) { + return Status::OK(); + } for (auto& child : _children) { - RETURN_IF_ERROR(child->prepare_ann_range_search(params)); + RETURN_IF_ERROR(child->prepare_ann_range_search(params, range_search_runtime, + suitable_for_ann_index)); + if (!suitable_for_ann_index) { + return Status::OK(); + } } return Status::OK(); } diff --git a/be/src/vec/exprs/vexpr.h b/be/src/vec/exprs/vexpr.h index cb02ccb6e83..ce2d3aba083 100644 --- a/be/src/vec/exprs/vexpr.h +++ b/be/src/vec/exprs/vexpr.h @@ -47,6 +47,7 @@ #include "vec/core/wide_integer.h" #include "vec/data_types/data_type.h" #include "vec/data_types/data_type_ipv6.h" +#include "vec/exprs/vexpr_context.h" #include "vec/exprs/vexpr_fwd.h" #include "vec/functions/function.h" @@ -65,7 +66,7 @@ class ColumnIterator; namespace vectorized { #include "common/compile_check_begin.h" - +struct RangeSearchRuntimeInfo; #define RETURN_IF_ERROR_OR_PREPARED(stmt) \ if (_prepared) { \ return Status::OK(); \ @@ -279,12 +280,19 @@ public: void set_node_type(TExprNodeType::type node_type) { _node_type = node_type; } #endif virtual Status evaluate_ann_range_search( + const RangeSearchRuntimeInfo& runtime, const std::vector<std::unique_ptr<segment_v2::IndexIterator>>& cid_to_index_iterators, const std::vector<ColumnId>& idx_to_cid, const std::vector<std::unique_ptr<segment_v2::ColumnIterator>>& column_iterators, roaring::Roaring& row_bitmap); - virtual Status prepare_ann_range_search(const doris::VectorSearchUserParams& params); + // Prepare the runtime for ANN range search. + // RangeSearchRuntimeInfo is used to store the runtime information of ann range search. + // suitable_for_ann_index is used to indicate whether the current expr can be used for ANN range search. + // If suitable_for_ann_index is false, the we will do exhausted search. + virtual Status prepare_ann_range_search(const doris::VectorSearchUserParams& params, + RangeSearchRuntimeInfo& range_search_runtime, + bool& suitable_for_ann_index); bool has_been_executed(); diff --git a/be/src/vec/exprs/vexpr_context.cpp b/be/src/vec/exprs/vexpr_context.cpp index 886dea256c5..fc027538494 100644 --- a/be/src/vec/exprs/vexpr_context.cpp +++ b/be/src/vec/exprs/vexpr_context.cpp @@ -117,6 +117,9 @@ Status VExprContext::clone(RuntimeState* state, VExprContextSPtr& new_ctx) { new_ctx->_is_clone = true; new_ctx->_prepared = true; new_ctx->_opened = true; + // RangeSearchRuntimeInfo should be cloned as well. + // The object of RangeSearchRuntimeInfo is not shared by threads. + new_ctx->_ann_range_search_runtime = this->_ann_range_search_runtime; return _root->open(state, new_ctx.get(), FunctionContext::THREAD_LOCAL); } @@ -436,7 +439,24 @@ Status VExprContext::prepare_ann_range_search(const doris::VectorSearchUserParam if (_root == nullptr) { return Status::OK(); } - return _root->prepare_ann_range_search(params); + + RETURN_IF_ERROR(_root->prepare_ann_range_search(params, _ann_range_search_runtime, + _suitable_for_ann_index)); + LOG_INFO("Prepare ann range search result {}, _suitable_for_ann_index {}", + this->_ann_range_search_runtime.to_string(), this->_suitable_for_ann_index); + return Status::OK(); +} + +Status VExprContext::evaluate_ann_range_search( + const std::vector<std::unique_ptr<segment_v2::IndexIterator>>& cid_to_index_iterators, + const std::vector<ColumnId>& idx_to_cid, + const std::vector<std::unique_ptr<segment_v2::ColumnIterator>>& column_iterators, + roaring::Roaring& row_bitmap) { + if (_root != nullptr) { + return _root->evaluate_ann_range_search(_ann_range_search_runtime, cid_to_index_iterators, + idx_to_cid, column_iterators, row_bitmap); + } + return Status::OK(); } #include "common/compile_check_end.h" diff --git a/be/src/vec/exprs/vexpr_context.h b/be/src/vec/exprs/vexpr_context.h index d43012353e5..abd1eb62d18 100644 --- a/be/src/vec/exprs/vexpr_context.h +++ b/be/src/vec/exprs/vexpr_context.h @@ -27,6 +27,8 @@ #include "common/factory_creator.h" #include "common/status.h" +#include "olap/rowset/segment_v2/ann_index/range_search_runtime_info.h" +#include "olap/rowset/segment_v2/column_reader.h" #include "olap/rowset/segment_v2/inverted_index_reader.h" #include "runtime/runtime_state.h" #include "runtime/types.h" @@ -39,6 +41,10 @@ class RowDescriptor; class RuntimeState; } // namespace doris +namespace doris::segment_v2 { +class ColumnIterator; +} // namespace doris::segment_v2 + namespace doris::vectorized { class InvertedIndexContext { @@ -282,6 +288,12 @@ public: Status prepare_ann_range_search(const doris::VectorSearchUserParams& params); + Status evaluate_ann_range_search( + const std::vector<std::unique_ptr<segment_v2::IndexIterator>>& cid_to_index_iterators, + const std::vector<ColumnId>& idx_to_cid, + const std::vector<std::unique_ptr<segment_v2::ColumnIterator>>& column_iterators, + roaring::Roaring& row_bitmap); + private: // Close method is called in vexpr context dector, not need call expicility void close(); @@ -315,5 +327,8 @@ private: std::shared_ptr<InvertedIndexContext> _inverted_index_context; size_t _memory_usage = 0; + + RangeSearchRuntimeInfo _ann_range_search_runtime; + bool _suitable_for_ann_index = true; }; } // namespace doris::vectorized diff --git a/be/src/vec/exprs/virtual_slot_ref.cpp b/be/src/vec/exprs/virtual_slot_ref.cpp index 844da622f31..6d320b2b96f 100644 --- a/be/src/vec/exprs/virtual_slot_ref.cpp +++ b/be/src/vec/exprs/virtual_slot_ref.cpp @@ -207,12 +207,14 @@ bool VirtualSlotRef::equals(const VExpr& other) { } Status VirtualSlotRef::evaluate_ann_range_search( + const RangeSearchRuntimeInfo& range_search_runtime, const std::vector<std::unique_ptr<segment_v2::IndexIterator>>& cid_to_index_iterators, const std::vector<ColumnId>& idx_to_cid, const std::vector<std::unique_ptr<segment_v2::ColumnIterator>>& column_iterators, roaring::Roaring& row_bitmap) { if (_virtual_column_expr != nullptr) { - return _virtual_column_expr->evaluate_ann_range_search(cid_to_index_iterators, idx_to_cid, + return _virtual_column_expr->evaluate_ann_range_search(range_search_runtime, + cid_to_index_iterators, idx_to_cid, column_iterators, row_bitmap); } return Status::OK(); diff --git a/be/src/vec/exprs/virtual_slot_ref.h b/be/src/vec/exprs/virtual_slot_ref.h index 5a45267082b..57f601d8c18 100644 --- a/be/src/vec/exprs/virtual_slot_ref.h +++ b/be/src/vec/exprs/virtual_slot_ref.h @@ -74,6 +74,7 @@ public: SlotRef ArrayLiteral */ Status evaluate_ann_range_search( + const RangeSearchRuntimeInfo& range_search_runtime, const std::vector<std::unique_ptr<segment_v2::IndexIterator>>& cid_to_index_iterators, const std::vector<ColumnId>& idx_to_cid, const std::vector<std::unique_ptr<segment_v2::ColumnIterator>>& column_iterators, diff --git a/be/src/vec/functions/array/function_array_distance_approximate.h b/be/src/vec/functions/array/function_array_distance_approximate.h index 5e4415f0243..85b2b8d4fc8 100644 --- a/be/src/vec/functions/array/function_array_distance_approximate.h +++ b/be/src/vec/functions/array/function_array_distance_approximate.h @@ -19,7 +19,6 @@ #include "vec/columns/column.h" #include "vec/columns/column_array.h" -#include "vec/columns/columns_number.h" #include "vec/common/assert_cast.h" #include "vec/core/types.h" #include "vec/data_types/data_type.h" @@ -37,13 +36,9 @@ public: static constexpr auto name = "l2_distance_approximate"; struct State { double sum = 0; - size_t count = 0; }; - static void accumulate(State& state, double x, double y) { - state.sum += (x - y) * (x - y); - state.count++; - } - static double finalize(const State& state) { return sqrt(state.sum / state.count); } + static void accumulate(State& state, double x, double y) { state.sum += (x - y) * (x - y); } + static double finalize(const State& state) { return sqrt(state.sum); } }; class InnerProductApproximate { diff --git a/be/src/vec/runtime/vector_search_user_params.h b/be/src/vec/runtime/vector_search_user_params.h index 5f886405e06..d07cd90458d 100644 --- a/be/src/vec/runtime/vector_search_user_params.h +++ b/be/src/vec/runtime/vector_search_user_params.h @@ -15,6 +15,8 @@ // specific language governing permissions and limitations // under the License. +#pragma once + #include <string> namespace doris { diff --git a/be/src/vector/CMakeLists.txt b/be/src/vector/CMakeLists.txt index 646cc874d9a..dae87fedf7a 100644 --- a/be/src/vector/CMakeLists.txt +++ b/be/src/vector/CMakeLists.txt @@ -29,6 +29,7 @@ if (BUILD_FAISS) list(APPEND VECTOR_LIB_SRC faiss_vector_index.h faiss_vector_index.cpp + metric.cpp ) list(APPEND VECTOR_LIB_DEPENDENCIES faiss) endif() diff --git a/be/src/vector/faiss_vector_index.cpp b/be/src/vector/faiss_vector_index.cpp index f48c71334c5..d9eec3d78d1 100644 --- a/be/src/vector/faiss_vector_index.cpp +++ b/be/src/vector/faiss_vector_index.cpp @@ -31,6 +31,7 @@ #include "common/status.h" #include "faiss/IndexHNSW.h" #include "faiss/impl/io.h" +#include "olap/rowset/segment_v2/ann_index/ann_search_params.h" #include "vector/vector_index.h" namespace doris::segment_v2 { @@ -179,8 +180,8 @@ void FaissVectorIndex::set_build_params(const FaissBuildParameter& params) { // TODO: Support batch search doris::Status FaissVectorIndex::ann_topn_search(const float* query_vec, int k, - const IndexSearchParameters& params, - IndexSearchResult& result) { + const vectorized::IndexSearchParameters& params, + vectorized::IndexSearchResult& result) { std::unique_ptr<float[]> distances_ptr = std::make_unique<float[]>(k); float* distances = distances_ptr.get(); @@ -197,8 +198,8 @@ doris::Status FaissVectorIndex::ann_topn_search(const float* query_vec, int k, std::unique_ptr<faiss::IDSelector> id_sel = nullptr; id_sel = roaring_to_faiss_selector(*params.roaring); faiss::SearchParametersHNSW param; - const HNSWSearchParameters* hnsw_params = - dynamic_cast<const HNSWSearchParameters*>(¶ms); + const vectorized::HNSWSearchParameters* hnsw_params = + dynamic_cast<const vectorized::HNSWSearchParameters*>(¶ms); if (hnsw_params == nullptr) { return doris::Status::InvalidArgument( "HNSW search parameters should not be null for HNSW index"); @@ -213,13 +214,13 @@ doris::Status FaissVectorIndex::ann_topn_search(const float* query_vec, int k, result.roaring = std::make_shared<roaring::Roaring>(); update_roaring(labels, k, *result.roaring); - result.distances = std::move(distances_ptr); - + size_t roaring_cardinality = result.roaring->cardinality(); + result.distances = std::make_unique<float[]>(roaring_cardinality); result.row_ids = std::make_unique<std::vector<uint64_t>>(); - for (size_t i = 0; i < k; ++i) { - if (labels[i] >= 0) { - result.row_ids->push_back(labels[i]); - } + + for (size_t i = 0; i < roaring_cardinality; ++i) { + result.row_ids->push_back(labels[i]); + result.distances[i] = std::sqrt(distances[i]); // Convert squared distance to actual distance } DCHECK(result.row_ids->size() == result.roaring->cardinality()) @@ -229,15 +230,17 @@ doris::Status FaissVectorIndex::ann_topn_search(const float* query_vec, int k, } doris::Status FaissVectorIndex::range_search(const float* query_vec, const float& radius, - const IndexSearchParameters& params, - IndexSearchResult& result) { + const vectorized::IndexSearchParameters& params, + vectorized::IndexSearchResult& result) { DCHECK(_index != nullptr); + DCHECK(query_vec != nullptr); std::unique_ptr<faiss::IDSelector> sel = nullptr; if (params.roaring != nullptr) { sel = roaring_to_faiss_selector(*params.roaring); } faiss::RangeSearchResult native_search_result(1, true); - const HNSWSearchParameters* hnsw_params = dynamic_cast<const HNSWSearchParameters*>(¶ms); + const vectorized::HNSWSearchParameters* hnsw_params = + dynamic_cast<const vectorized::HNSWSearchParameters*>(¶ms); if (hnsw_params != nullptr) { faiss::SearchParametersHNSW param; param.efSearch = hnsw_params->ef_search; @@ -263,6 +266,7 @@ doris::Status FaissVectorIndex::range_search(const float* query_vec, const float for (size_t i = begin; i < end; ++i) { (*row_ids)[i] = native_search_result.labels[i]; roaring->add(native_search_result.labels[i]); + // TODO: l2_distance and inner_product is different. distances[i] = sqrt(native_search_result.distances[i]); } diff --git a/be/src/vector/faiss_vector_index.h b/be/src/vector/faiss_vector_index.h index e7251cf52c3..7bf2831a912 100644 --- a/be/src/vector/faiss_vector_index.h +++ b/be/src/vector/faiss_vector_index.h @@ -28,6 +28,11 @@ #include "common/status.h" #include "vector_index.h" +namespace doris::vectorized { +struct IndexSearchParameters; +struct IndexSearchResult; +} // namespace doris::vectorized + namespace doris::segment_v2 { struct FaissBuildParameter { enum class IndexType { BruteForce, IVF, HNSW }; @@ -93,12 +98,12 @@ public: void set_build_params(const FaissBuildParameter& params); doris::Status ann_topn_search(const float* query_vec, int k, - const IndexSearchParameters& params, - IndexSearchResult& result) override; + const vectorized::IndexSearchParameters& params, + vectorized::IndexSearchResult& result) override; doris::Status range_search(const float* query_vec, const float& radius, - const IndexSearchParameters& params, - IndexSearchResult& result) override; + const vectorized::IndexSearchParameters& params, + vectorized::IndexSearchResult& result) override; doris::Status save(lucene::store::Directory*) override; diff --git a/be/src/vec/runtime/vector_search_user_params.h b/be/src/vector/metric.cpp similarity index 54% copy from be/src/vec/runtime/vector_search_user_params.h copy to be/src/vector/metric.cpp index 5f886405e06..9d6a415111e 100644 --- a/be/src/vec/runtime/vector_search_user_params.h +++ b/be/src/vector/metric.cpp @@ -15,17 +15,33 @@ // specific language governing permissions and limitations // under the License. +#include "vector/metric.h" + #include <string> -namespace doris { -// Constructed from session variables. -struct VectorSearchUserParams { - int hnsw_ef_search = 16; - bool hnsw_check_relative_distance = true; - bool hnsw_bounded_queue = true; +#include "vec/functions/array/function_array_distance.h" + +namespace doris::segment_v2 { + +std::string metric_to_string(Metric metric) { + switch (metric) { + case Metric::L2: + return vectorized::L2Distance::name; + case Metric::INNER_PRODUCT: + return vectorized::InnerProduct::name; + default: + return "UNKNOWN"; + } +} - bool operator==(const VectorSearchUserParams& other) const; +Metric string_to_metric(const std::string& metric) { + if (metric == vectorized::L2Distance::name) { + return Metric::L2; + } else if (metric == vectorized::InnerProduct::name) { + return Metric::INNER_PRODUCT; + } else { + return Metric::UNKNOWN; + } +} - std::string to_string() const; -}; -} // namespace doris \ No newline at end of file +} // namespace doris::segment_v2 \ No newline at end of file diff --git a/be/src/vec/runtime/vector_search_user_params.h b/be/src/vector/metric.h similarity index 71% copy from be/src/vec/runtime/vector_search_user_params.h copy to be/src/vector/metric.h index 5f886405e06..b6c95bed675 100644 --- a/be/src/vec/runtime/vector_search_user_params.h +++ b/be/src/vector/metric.h @@ -15,17 +15,15 @@ // specific language governing permissions and limitations // under the License. +#pragma once + #include <string> -namespace doris { -// Constructed from session variables. -struct VectorSearchUserParams { - int hnsw_ef_search = 16; - bool hnsw_check_relative_distance = true; - bool hnsw_bounded_queue = true; +namespace doris::segment_v2 { +enum class Metric { L2, INNER_PRODUCT, UNKNOWN }; + +std::string metric_to_string(Metric metric); - bool operator==(const VectorSearchUserParams& other) const; +Metric string_to_metric(const std::string& metric); - std::string to_string() const; -}; -} // namespace doris \ No newline at end of file +} // namespace doris::segment_v2 \ No newline at end of file diff --git a/be/src/vector/vector_index.h b/be/src/vector/vector_index.h index cc2a2c3b565..00e0904fdeb 100644 --- a/be/src/vector/vector_index.h +++ b/be/src/vector/vector_index.h @@ -22,62 +22,20 @@ #include "common/status.h" #include "vec/functions/array/function_array_distance.h" +#include "vector/metric.h" namespace lucene::store { class Directory; } -namespace doris::segment_v2 { -/* -This struct is used to wrap the search result of a vector index. -roaring is a bitmap that contains the row ids that satisfy the search condition. -row_ids is a vector of row ids that are returned by the search, it could be used by virtual_column_iterator to do column filter. -distances is a vector of distances that are returned by the search. -For range search, is condition is not le_or_lt, the row_ids and distances will be nullptr. -*/ -struct IndexSearchResult { - IndexSearchResult() = default; - - std::unique_ptr<float[]> distances = nullptr; - std::unique_ptr<std::vector<uint64_t>> row_ids = nullptr; - std::shared_ptr<roaring::Roaring> roaring = nullptr; -}; - -struct IndexSearchParameters { - roaring::Roaring* roaring = nullptr; - bool is_le_or_lt = true; - virtual ~IndexSearchParameters() = default; -}; -struct HNSWSearchParameters : public IndexSearchParameters { - int ef_search = 16; - bool check_relative_distance = true; - bool bounded_queue = true; -}; +namespace doris::vectorized { +struct IndexSearchParameters; +struct IndexSearchResult; +} // namespace doris::vectorized +namespace doris::segment_v2 { class VectorIndex { public: - enum class Metric { L2, INNER_PRODUCT, UNKNOWN }; - - static std::string metric_to_string(Metric metric) { - switch (metric) { - case Metric::L2: - return vectorized::L2Distance::name; - case Metric::INNER_PRODUCT: - return vectorized::InnerProduct::name; - default: - return "UNKNOWN"; - } - } - static Metric string_to_metric(const std::string& metric) { - if (metric == vectorized::L2Distance::name) { - return Metric::L2; - } else if (metric == vectorized::InnerProduct::name) { - return Metric::INNER_PRODUCT; - } else { - return Metric::UNKNOWN; - } - } - virtual ~VectorIndex() = default; /** Add n vectors of dimension d vectors to the index. @@ -99,8 +57,8 @@ public: * @return status of the operation */ virtual doris::Status ann_topn_search(const float* query_vec, int k, - const IndexSearchParameters& params, - IndexSearchResult& result) = 0; + const vectorized::IndexSearchParameters& params, + vectorized::IndexSearchResult& result) = 0; /** * Search for the nearest neighbors of a query vector within a given radius. * @param query_vec input vector, size d @@ -109,8 +67,8 @@ public: * @return status of the operation */ virtual doris::Status range_search(const float* query_vec, const float& radius, - const IndexSearchParameters& params, - IndexSearchResult& result) = 0; + const vectorized::IndexSearchParameters& params, + vectorized::IndexSearchResult& result) = 0; virtual doris::Status save(lucene::store::Directory*) = 0; diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownVectorTopNIntoOlapScan.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownVectorTopNIntoOlapScan.java index 7573eb8ade2..adb7f0b5f2b 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownVectorTopNIntoOlapScan.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownVectorTopNIntoOlapScan.java @@ -71,13 +71,16 @@ public class PushDownVectorTopNIntoOlapScan implements RewriteRuleFactory { LogicalProject<?> project, LogicalOlapScan scan, Optional<LogicalFilter<?>> optionalFilter) { + // Retrives the expression used for ordering in the TopN. Expression orderKey = topN.getOrderKeys().get(0).getExpr(); + // The order key must be a SlotReference corresponding to an expr. if (!(orderKey instanceof SlotReference)) { return null; } SlotReference keySlot = (SlotReference) orderKey; Expression orderKeyExpr = null; Alias orderKeyAlias = null; + // Find the corresponding expression in the project that matches the keySlot. for (NamedExpression projection : project.getProjects()) { if (projection.toSlot().equals(keySlot) && projection instanceof Alias) { orderKeyExpr = ((Alias) projection).child(); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/copier/LogicalPlanDeepCopier.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/copier/LogicalPlanDeepCopier.java index d74851dec3c..d203134e7c9 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/copier/LogicalPlanDeepCopier.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/copier/LogicalPlanDeepCopier.java @@ -124,16 +124,7 @@ public class LogicalPlanDeepCopier extends DefaultPlanRewriter<DeepCopierContext .collect(ImmutableList.toImmutableList()); newRelation = newRelation.withVirtualColumns(virtualColumns); context.putRelation(catalogRelation.getRelationId(), newRelation); - return newRelation; - } - - @Override - public Plan visitLogicalCatalogRelation(LogicalCatalogRelation relation, DeepCopierContext context) { - if (context.getRelationReplaceMap().containsKey(relation.getRelationId())) { - return context.getRelationReplaceMap().get(relation.getRelationId()); - } - LogicalCatalogRelation newRelation = (LogicalCatalogRelation) visitLogicalRelation(relation, context); - return updateOperativeSlots(relation, newRelation); + return updateOperativeSlots(catalogRelation, newRelation); } @Override --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@doris.apache.org For additional commands, e-mail: commits-h...@doris.apache.org