This is an automated email from the ASF dual-hosted git repository.
yiguolei pushed a commit to branch branch-4.0
in repository https://gitbox.apache.org/repos/asf/doris.git
The following commit(s) were added to refs/heads/branch-4.0 by this push:
new a266025aaae branch-4.0: [feat](ann index) support ivf index #58130
(#58628)
a266025aaae is described below
commit a266025aaae1a2dce7e10b11ac851cca5d0ea9d0
Author: Jack <[email protected]>
AuthorDate: Wed Dec 3 09:46:51 2025 +0800
branch-4.0: [feat](ann index) support ivf index #58130 (#58628)
cherry pick from #58130 and #58588
---------
Co-authored-by: ivin <[email protected]>
Co-authored-by: Chen768959 <[email protected]>
---
.../olap/rowset/segment_v2/ann_index/ann_index.cpp | 6 +-
.../olap/rowset/segment_v2/ann_index/ann_index.h | 7 +-
.../segment_v2/ann_index/ann_index_reader.cpp | 18 +-
.../segment_v2/ann_index/ann_index_writer.cpp | 1 +
.../rowset/segment_v2/ann_index/ann_index_writer.h | 1 +
.../segment_v2/ann_index/ann_search_params.h | 6 +-
.../segment_v2/ann_index/faiss_ann_index.cpp | 145 ++++++--
.../rowset/segment_v2/ann_index/faiss_ann_index.h | 10 +-
be/src/runtime/runtime_state.h | 2 +-
be/src/vec/runtime/vector_search_user_params.cpp | 8 +-
be/src/vec/runtime/vector_search_user_params.h | 3 +-
.../olap/vector_search/ann_index_reader_test.cpp | 200 ++++++++++-
.../olap/vector_search/ann_index_writer_test.cpp | 179 +++++++++
.../olap/vector_search/faiss_vector_index_test.cpp | 398 +++++++++++++++++----
be/test/olap/vector_search/vector_search_utils.cpp | 8 +-
.../doris/analysis/AnnIndexPropertiesChecker.java | 23 +-
.../java/org/apache/doris/qe/SessionVariable.java | 7 +
gensrc/thrift/PaloInternalService.thrift | 2 +
.../data/ann_index_p0/ivf_index_test.out | 17 +
.../ann_index_p0/create_ann_index_test.groovy | 4 +-
.../create_tbl_with_ann_index_test.groovy | 6 +-
.../suites/ann_index_p0/ivf_index_test.groovy | 126 +++++++
22 files changed, 1062 insertions(+), 115 deletions(-)
diff --git a/be/src/olap/rowset/segment_v2/ann_index/ann_index.cpp
b/be/src/olap/rowset/segment_v2/ann_index/ann_index.cpp
index 19e19a76458..860e0d328bf 100644
--- a/be/src/olap/rowset/segment_v2/ann_index/ann_index.cpp
+++ b/be/src/olap/rowset/segment_v2/ann_index/ann_index.cpp
@@ -49,6 +49,8 @@ std::string ann_index_type_to_string(AnnIndexType type) {
return "unknown";
case AnnIndexType::HNSW:
return "hnsw";
+ case AnnIndexType::IVF:
+ return "ivf";
default:
return "unknown";
}
@@ -57,6 +59,8 @@ std::string ann_index_type_to_string(AnnIndexType type) {
AnnIndexType string_to_ann_index_type(const std::string& type) {
if (type == "hnsw") {
return AnnIndexType::HNSW;
+ } else if (type == "ivf") {
+ return AnnIndexType::IVF;
} else {
return AnnIndexType::UNKNOWN;
}
@@ -70,4 +74,4 @@ VectorIndex::~VectorIndex() {
DorisMetrics::instance()->ann_index_in_memory_cnt->increment(-1);
}
-} // namespace doris::segment_v2
\ No newline at end of file
+} // namespace doris::segment_v2
diff --git a/be/src/olap/rowset/segment_v2/ann_index/ann_index.h
b/be/src/olap/rowset/segment_v2/ann_index/ann_index.h
index 1718e95d162..7d785e9b3e1 100644
--- a/be/src/olap/rowset/segment_v2/ann_index/ann_index.h
+++ b/be/src/olap/rowset/segment_v2/ann_index/ann_index.h
@@ -55,7 +55,7 @@ std::string metric_to_string(AnnIndexMetric metric);
AnnIndexMetric string_to_metric(const std::string& metric);
-enum class AnnIndexType { UNKNOWN, HNSW };
+enum class AnnIndexType { UNKNOWN, HNSW, IVF };
std::string ann_index_type_to_string(AnnIndexType type);
@@ -119,10 +119,13 @@ public:
void set_metric(AnnIndexMetric metric) { _metric = metric; }
+ void set_type(AnnIndexType type) { _index_type = type; }
+
protected:
// When adding vectors to the index, use this variable to check the
dimension of the vectors.
size_t _dimension = 0;
- AnnIndexMetric _metric = AnnIndexMetric::L2; // Default metric is L2
distance
+ AnnIndexMetric _metric = AnnIndexMetric::L2; // Default metric is L2
distance
+ AnnIndexType _index_type = AnnIndexType::HNSW; // Default index type is
hnsw
};
#include "common/compile_check_end.h"
} // namespace doris::segment_v2
diff --git a/be/src/olap/rowset/segment_v2/ann_index/ann_index_reader.cpp
b/be/src/olap/rowset/segment_v2/ann_index/ann_index_reader.cpp
index f993804c072..ffdd9bf58e7 100644
--- a/be/src/olap/rowset/segment_v2/ann_index/ann_index_reader.cpp
+++ b/be/src/olap/rowset/segment_v2/ann_index/ann_index_reader.cpp
@@ -84,6 +84,7 @@ Status AnnIndexReader::load_index(io::IOContext* io_ctx) {
}
_vector_index = std::make_unique<FaissVectorIndex>();
_vector_index->set_metric(_metric_type);
+ _vector_index->set_type(_index_type);
RETURN_IF_ERROR(_vector_index->load(compound_dir->get()));
} catch (CLuceneError& err) {
return Status::Error<ErrorCode::INVERTED_INDEX_CLUCENE_ERROR>(
@@ -124,6 +125,17 @@ Status AnnIndexReader::query(io::IOContext* io_ctx,
AnnTopNParam* param, AnnInde
stats->engine_search_ns.update(index_search_result.engine_search_ns);
stats->engine_convert_ns.update(index_search_result.engine_convert_ns);
stats->engine_prepare_ns.update(index_search_result.engine_prepare_ns);
+ } else if (_index_type == AnnIndexType::IVF) {
+ IVFSearchParameters ivf_search_params;
+ ivf_search_params.roaring = param->roaring;
+ ivf_search_params.rows_of_segment = param->rows_of_segment;
+ ivf_search_params.nprobe = param->_user_params.ivf_nprobe;
+ RETURN_IF_ERROR(_vector_index->ann_topn_search(query_vec, limit,
ivf_search_params,
+
index_search_result));
+ // Accumulate detailed engine timings
+
stats->engine_search_ns.update(index_search_result.engine_search_ns);
+
stats->engine_convert_ns.update(index_search_result.engine_convert_ns);
+
stats->engine_prepare_ns.update(index_search_result.engine_prepare_ns);
} else {
throw Exception(Status::NotSupported("Unsupported index type: {}",
ann_index_type_to_string(_index_type)));
@@ -173,6 +185,10 @@ Status AnnIndexReader::range_search(const
AnnRangeSearchParams& params,
hnsw_param->check_relative_distance =
custom_params.hnsw_check_relative_distance;
hnsw_param->bounded_queue = custom_params.hnsw_bounded_queue;
search_param = std::move(hnsw_param);
+ } else if (_index_type == AnnIndexType::IVF) {
+ auto ivf_param =
std::make_unique<segment_v2::IVFSearchParameters>();
+ ivf_param->nprobe = custom_params.ivf_nprobe;
+ search_param = std::move(ivf_param);
} else {
throw Exception(Status::NotSupported("Unsupported index type: {}",
ann_index_type_to_string(_index_type)));
@@ -233,4 +249,4 @@ size_t AnnIndexReader::get_dimension() const {
return _dim;
}
-} // namespace doris::segment_v2
\ No newline at end of file
+} // namespace doris::segment_v2
diff --git a/be/src/olap/rowset/segment_v2/ann_index/ann_index_writer.cpp
b/be/src/olap/rowset/segment_v2/ann_index/ann_index_writer.cpp
index e8a56a891d0..562565d565b 100644
--- a/be/src/olap/rowset/segment_v2/ann_index/ann_index_writer.cpp
+++ b/be/src/olap/rowset/segment_v2/ann_index/ann_index_writer.cpp
@@ -63,6 +63,7 @@ Status AnnIndexColumnWriter::init() {
build_parameter.max_degree = std::stoi(get_or_default(properties,
MAX_DEGREE, "32"));
build_parameter.metric_type =
FaissBuildParameter::string_to_metric_type(metric_type);
build_parameter.ef_construction = std::stoi(get_or_default(properties,
EF_CONSTRUCTION, "40"));
+ build_parameter.ivf_nlist = std::stoi(get_or_default(properties, NLIST,
"1024"));
build_parameter.quantizer =
FaissBuildParameter::string_to_quantizer(quantizer);
build_parameter.pq_m = std::stoi(get_or_default(properties, PQ_M, "8"));
build_parameter.pq_nbits = std::stoi(get_or_default(properties, PQ_NBITS,
"8"));
diff --git a/be/src/olap/rowset/segment_v2/ann_index/ann_index_writer.h
b/be/src/olap/rowset/segment_v2/ann_index/ann_index_writer.h
index fbfa6e3c8b2..0fb7ef11706 100644
--- a/be/src/olap/rowset/segment_v2/ann_index/ann_index_writer.h
+++ b/be/src/olap/rowset/segment_v2/ann_index/ann_index_writer.h
@@ -50,6 +50,7 @@ public:
static constexpr const char* DIM = "dim";
static constexpr const char* MAX_DEGREE = "max_degree";
static constexpr const char* EF_CONSTRUCTION = "ef_construction";
+ static constexpr const char* NLIST = "nlist";
static constexpr const char* QUANTIZER = "quantizer";
static constexpr const char* PQ_M = "pq_m";
static constexpr const char* PQ_NBITS = "pq_nbits";
diff --git a/be/src/olap/rowset/segment_v2/ann_index/ann_search_params.h
b/be/src/olap/rowset/segment_v2/ann_index/ann_search_params.h
index b2d9c758659..671bb6dc4ae 100644
--- a/be/src/olap/rowset/segment_v2/ann_index/ann_search_params.h
+++ b/be/src/olap/rowset/segment_v2/ann_index/ann_search_params.h
@@ -142,5 +142,9 @@ struct HNSWSearchParameters : public IndexSearchParameters {
bool check_relative_distance = true;
bool bounded_queue = true;
};
+
+struct IVFSearchParameters : public IndexSearchParameters {
+ int nprobe = 1;
+};
#include "common/compile_check_end.h"
-} // namespace doris::segment_v2
\ No newline at end of file
+} // namespace doris::segment_v2
diff --git a/be/src/olap/rowset/segment_v2/ann_index/faiss_ann_index.cpp
b/be/src/olap/rowset/segment_v2/ann_index/faiss_ann_index.cpp
index 5c29384b876..f01a9c23c78 100644
--- a/be/src/olap/rowset/segment_v2/ann_index/faiss_ann_index.cpp
+++ b/be/src/olap/rowset/segment_v2/ann_index/faiss_ann_index.cpp
@@ -38,7 +38,12 @@
#include "common/logging.h"
#include "common/status.h"
#include "faiss/Index.h"
+#include "faiss/IndexFlat.h"
#include "faiss/IndexHNSW.h"
+#include "faiss/IndexIVF.h"
+#include "faiss/IndexIVFFlat.h"
+#include "faiss/IndexIVFPQ.h"
+#include "faiss/IndexScalarQuantizer.h"
#include "faiss/MetricType.h"
#include "faiss/impl/FaissException.h"
#include "faiss/impl/IDSelector.h"
@@ -294,6 +299,7 @@ void FaissVectorIndex::build(const FaissBuildParameter&
params) {
}
if (params.index_type == FaissBuildParameter::IndexType::HNSW) {
+ set_type(AnnIndexType::HNSW);
std::unique_ptr<faiss::IndexHNSW> hnsw_index;
if (params.quantizer == FaissBuildParameter::Quantizer::SQ4) {
if (params.metric_type == FaissBuildParameter::MetricType::L2) {
@@ -340,6 +346,62 @@ void FaissVectorIndex::build(const FaissBuildParameter&
params) {
hnsw_index->hnsw.efConstruction = params.ef_construction;
_index = std::move(hnsw_index);
+ } else if (params.index_type == FaissBuildParameter::IndexType::IVF) {
+ set_type(AnnIndexType::IVF);
+ std::unique_ptr<faiss::Index> ivf_index;
+ if (params.metric_type == FaissBuildParameter::MetricType::L2) {
+ _quantizer = std::make_unique<faiss::IndexFlat>(params.dim,
faiss::METRIC_L2);
+ } else {
+ _quantizer =
+ std::make_unique<faiss::IndexFlat>(params.dim,
faiss::METRIC_INNER_PRODUCT);
+ }
+
+ if (params.quantizer == FaissBuildParameter::Quantizer::FLAT) {
+ if (params.metric_type == FaissBuildParameter::MetricType::L2) {
+ ivf_index = std::make_unique<faiss::IndexIVFFlat>(
+ _quantizer.get(), params.dim, params.ivf_nlist,
faiss::METRIC_L2);
+ } else {
+ ivf_index =
std::make_unique<faiss::IndexIVFFlat>(_quantizer.get(), params.dim,
+
params.ivf_nlist,
+
faiss::METRIC_INNER_PRODUCT);
+ }
+ } else if (params.quantizer == FaissBuildParameter::Quantizer::SQ4) {
+ if (params.metric_type == FaissBuildParameter::MetricType::L2) {
+ ivf_index = std::make_unique<faiss::IndexIVFScalarQuantizer>(
+ _quantizer.get(), params.dim, params.ivf_nlist,
+ faiss::ScalarQuantizer::QT_4bit, faiss::METRIC_L2);
+ } else {
+ ivf_index = std::make_unique<faiss::IndexIVFScalarQuantizer>(
+ _quantizer.get(), params.dim, params.ivf_nlist,
+ faiss::ScalarQuantizer::QT_4bit,
faiss::METRIC_INNER_PRODUCT);
+ }
+ } else if (params.quantizer == FaissBuildParameter::Quantizer::SQ8) {
+ if (params.metric_type == FaissBuildParameter::MetricType::L2) {
+ ivf_index = std::make_unique<faiss::IndexIVFScalarQuantizer>(
+ _quantizer.get(), params.dim, params.ivf_nlist,
+ faiss::ScalarQuantizer::QT_8bit, faiss::METRIC_L2);
+ } else {
+ ivf_index = std::make_unique<faiss::IndexIVFScalarQuantizer>(
+ _quantizer.get(), params.dim, params.ivf_nlist,
+ faiss::ScalarQuantizer::QT_8bit,
faiss::METRIC_INNER_PRODUCT);
+ }
+ } else if (params.quantizer == FaissBuildParameter::Quantizer::PQ) {
+ if (params.metric_type == FaissBuildParameter::MetricType::L2) {
+ ivf_index =
std::make_unique<faiss::IndexIVFPQ>(_quantizer.get(), params.dim,
+
params.ivf_nlist, params.pq_m,
+
params.pq_nbits, faiss::METRIC_L2);
+ } else {
+ ivf_index = std::make_unique<faiss::IndexIVFPQ>(
+ _quantizer.get(), params.dim, params.ivf_nlist,
params.pq_m,
+ params.pq_nbits, faiss::METRIC_INNER_PRODUCT);
+ }
+ } else {
+ throw doris::Exception(doris::ErrorCode::INVALID_ARGUMENT,
+ "Unsupported quantizer for IVF: {}",
+ static_cast<int>(params.quantizer));
+ }
+
+ _index = std::move(ivf_index);
} else {
throw doris::Exception(doris::ErrorCode::INVALID_ARGUMENT,
"Unsupported index type: {}",
static_cast<int>(params.index_type));
@@ -365,29 +427,44 @@ doris::Status FaissVectorIndex::ann_topn_search(const
float* query_vec, int k,
DCHECK(params.roaring != nullptr)
<< "Roaring should not be null for topN search, please set roaring
in params";
- faiss::SearchParametersHNSW param;
- const HNSWSearchParameters* hnsw_params = dynamic_cast<const
HNSWSearchParameters*>(&params);
- if (hnsw_params == nullptr) {
- return doris::Status::InvalidArgument(
- "HNSW search parameters should not be null for HNSW index");
- }
- param.efSearch = hnsw_params->ef_search;
- param.check_relative_distance = hnsw_params->check_relative_distance;
- param.bounded_queue = hnsw_params->bounded_queue;
- param.sel = nullptr;
+ std::unique_ptr<faiss::SearchParameters> search_param;
std::unique_ptr<faiss::IDSelector> id_sel = nullptr;
// Costs of roaring to faiss selector is very high especially when the
cardinality is very high.
if (params.roaring->cardinality() != params.rows_of_segment) {
SCOPED_RAW_TIMER(&result.engine_prepare_ns);
id_sel = roaring_to_faiss_selector(*params.roaring);
- param.sel = id_sel.get();
+ }
+
+ if (_index_type == AnnIndexType::HNSW) {
+ const HNSWSearchParameters* hnsw_params =
+ dynamic_cast<const HNSWSearchParameters*>(&params);
+ if (hnsw_params == nullptr) {
+ return doris::Status::InvalidArgument(
+ "HNSW search parameters should not be null for HNSW
index");
+ }
+ faiss::SearchParametersHNSW* param = new faiss::SearchParametersHNSW();
+ param->efSearch = hnsw_params->ef_search;
+ param->check_relative_distance = hnsw_params->check_relative_distance;
+ param->bounded_queue = hnsw_params->bounded_queue;
+ param->sel = id_sel.get();
+ search_param.reset(param);
+ } else if (_index_type == AnnIndexType::IVF) {
+ const IVFSearchParameters* ivf_params = dynamic_cast<const
IVFSearchParameters*>(&params);
+ if (ivf_params == nullptr) {
+ return doris::Status::InvalidArgument(
+ "IVF search parameters should not be null for IVF index");
+ }
+ faiss::SearchParametersIVF* param = new faiss::SearchParametersIVF();
+ param->nprobe = ivf_params->nprobe;
+ param->sel = id_sel.get();
+ search_param.reset(param);
} else {
- param.sel = nullptr;
+ return doris::Status::InvalidArgument("Unsupported index type for
search");
}
{
SCOPED_RAW_TIMER(&result.engine_search_ns);
- _index->search(1, query_vec, k, distances, labels, &param);
+ _index->search(1, query_vec, k, distances, labels, search_param.get());
}
{
SCOPED_RAW_TIMER(&result.engine_convert_ns);
@@ -436,14 +513,29 @@ doris::Status FaissVectorIndex::range_search(const float*
query_vec, const float
DCHECK(query_vec != nullptr);
DCHECK(params.roaring != nullptr)
<< "Roaring should not be null for range search, please set
roaring in params";
- faiss::SearchParametersHNSW param;
+ std::unique_ptr<faiss::SearchParameters> search_param;
const HNSWSearchParameters* hnsw_params = dynamic_cast<const
HNSWSearchParameters*>(&params);
- {
- // Engine prepare: set search parameters and bind selector
- SCOPED_RAW_TIMER(&result.engine_prepare_ns);
- param.efSearch = hnsw_params->ef_search;
- param.check_relative_distance = hnsw_params->check_relative_distance;
- param.bounded_queue = hnsw_params->bounded_queue;
+ const IVFSearchParameters* ivf_params = dynamic_cast<const
IVFSearchParameters*>(&params);
+ if (hnsw_params != nullptr) {
+ faiss::SearchParametersHNSW* param = new faiss::SearchParametersHNSW();
+ {
+ // Engine prepare: set search parameters and bind selector
+ SCOPED_RAW_TIMER(&result.engine_prepare_ns);
+ param->efSearch = hnsw_params->ef_search;
+ param->check_relative_distance =
hnsw_params->check_relative_distance;
+ param->bounded_queue = hnsw_params->bounded_queue;
+ }
+ search_param.reset(param);
+ } else if (ivf_params != nullptr) {
+ faiss::SearchParametersIVF* param = new faiss::SearchParametersIVF();
+ {
+ // Engine prepare: set search parameters and bind selector
+ SCOPED_RAW_TIMER(&result.engine_prepare_ns);
+ param->nprobe = ivf_params->nprobe;
+ }
+ search_param.reset(param);
+ } else {
+ return doris::Status::InvalidArgument("Unsupported index type for
range search");
}
std::unique_ptr<faiss::IDSelector> sel;
{
@@ -451,26 +543,25 @@ doris::Status FaissVectorIndex::range_search(const float*
query_vec, const float
SCOPED_RAW_TIMER(&result.engine_prepare_ns);
if (params.roaring->cardinality() != params.rows_of_segment) {
sel = roaring_to_faiss_selector(*params.roaring);
- param.sel = sel.get();
+ search_param->sel = sel.get();
} else {
- param.sel = nullptr;
+ search_param->sel = nullptr;
}
}
faiss::RangeSearchResult native_search_result(1, true);
- // Currently only support HNSW index for range search.
- DCHECK(hnsw_params != nullptr) << "HNSW search parameters should not be
null for HNSW index";
{
// Engine search: FAISS range_search
SCOPED_RAW_TIMER(&result.engine_search_ns);
if (_metric == AnnIndexMetric::L2) {
if (radius <= 0) {
- _index->range_search(1, query_vec, 0.0f,
&native_search_result, &param);
+ _index->range_search(1, query_vec, 0.0f,
&native_search_result, search_param.get());
} else {
- _index->range_search(1, query_vec, radius * radius,
&native_search_result, &param);
+ _index->range_search(1, query_vec, radius * radius,
&native_search_result,
+ search_param.get());
}
} else if (_metric == AnnIndexMetric::IP) {
- _index->range_search(1, query_vec, radius, &native_search_result,
&param);
+ _index->range_search(1, query_vec, radius, &native_search_result,
search_param.get());
}
}
diff --git a/be/src/olap/rowset/segment_v2/ann_index/faiss_ann_index.h
b/be/src/olap/rowset/segment_v2/ann_index/faiss_ann_index.h
index c06402157ae..96b06de4218 100644
--- a/be/src/olap/rowset/segment_v2/ann_index/faiss_ann_index.h
+++ b/be/src/olap/rowset/segment_v2/ann_index/faiss_ann_index.h
@@ -49,7 +49,8 @@ struct FaissBuildParameter {
* @brief Supported vector index types.
*/
enum class IndexType {
- HNSW ///< Hierarchical Navigable Small World (HNSW) index for high
performance
+ HNSW, ///< Hierarchical Navigable Small World (HNSW) index for high
performance
+ IVF ///< Inverted File index
};
/**
@@ -76,6 +77,8 @@ struct FaissBuildParameter {
static IndexType string_to_index_type(const std::string& type) {
if (type == "hnsw") {
return IndexType::HNSW;
+ } else if (type == "ivf") {
+ return IndexType::IVF;
} else {
throw doris::Exception(doris::ErrorCode::INVALID_ARGUMENT,
"Unsupported index type: {}",
type);
@@ -124,6 +127,8 @@ struct FaissBuildParameter {
/// PQ specific parameters
int pq_m = 0; ///< Number of sub-quantizers for PQ
int pq_nbits = 8; ///< Number of bits per sub-quantizer for PQ
+ /// IVF specific parameters
+ int ivf_nlist = 1024; ///< Number of clusters for IVF
};
/**
@@ -271,7 +276,8 @@ public:
private:
std::unique_ptr<faiss::Index> _index = nullptr; ///< Underlying FAISS
index instance
- FaissBuildParameter _params; ///< Build parameters for
the index
+ std::unique_ptr<faiss::Index> _quantizer = nullptr;
+ FaissBuildParameter _params; ///< Build parameters for the index
};
#include "common/compile_check_end.h"
} // namespace doris::segment_v2
diff --git a/be/src/runtime/runtime_state.h b/be/src/runtime/runtime_state.h
index 24993123bf3..d4fed370806 100644
--- a/be/src/runtime/runtime_state.h
+++ b/be/src/runtime/runtime_state.h
@@ -705,7 +705,7 @@ public:
VectorSearchUserParams get_vector_search_params() const {
return VectorSearchUserParams(_query_options.hnsw_ef_search,
_query_options.hnsw_check_relative_distance,
- _query_options.hnsw_bounded_queue);
+ _query_options.hnsw_bounded_queue,
_query_options.ivf_nprobe);
}
private:
diff --git a/be/src/vec/runtime/vector_search_user_params.cpp
b/be/src/vec/runtime/vector_search_user_params.cpp
index 19fb4e1f082..9a62a8bf667 100644
--- a/be/src/vec/runtime/vector_search_user_params.cpp
+++ b/be/src/vec/runtime/vector_search_user_params.cpp
@@ -24,13 +24,13 @@ namespace doris {
bool VectorSearchUserParams::operator==(const VectorSearchUserParams& other)
const {
return hnsw_ef_search == other.hnsw_ef_search &&
hnsw_check_relative_distance == other.hnsw_check_relative_distance
&&
- hnsw_bounded_queue == other.hnsw_bounded_queue;
+ hnsw_bounded_queue == other.hnsw_bounded_queue && ivf_nprobe ==
other.ivf_nprobe;
}
std::string VectorSearchUserParams::to_string() const {
return fmt::format(
"hnsw_ef_search: {}, hnsw_check_relative_distance: {}, "
- "hnsw_bounded_queue: {}",
- hnsw_ef_search, hnsw_check_relative_distance, hnsw_bounded_queue);
+ "hnsw_bounded_queue: {}, ivf_nprobe: {}",
+ hnsw_ef_search, hnsw_check_relative_distance, hnsw_bounded_queue,
ivf_nprobe);
}
-} // namespace doris
\ No newline at end of file
+} // namespace doris
diff --git a/be/src/vec/runtime/vector_search_user_params.h
b/be/src/vec/runtime/vector_search_user_params.h
index 600716651c5..eb2a4f439cd 100644
--- a/be/src/vec/runtime/vector_search_user_params.h
+++ b/be/src/vec/runtime/vector_search_user_params.h
@@ -26,10 +26,11 @@ struct VectorSearchUserParams {
int hnsw_ef_search = 32;
bool hnsw_check_relative_distance = true;
bool hnsw_bounded_queue = true;
+ int ivf_nprobe = 1;
bool operator==(const VectorSearchUserParams& other) const;
std::string to_string() const;
};
#include "common/compile_check_end.h"
-} // namespace doris
\ No newline at end of file
+} // namespace doris
diff --git a/be/test/olap/vector_search/ann_index_reader_test.cpp
b/be/test/olap/vector_search/ann_index_reader_test.cpp
index 8af387ab230..b7eea905684 100644
--- a/be/test/olap/vector_search/ann_index_reader_test.cpp
+++ b/be/test/olap/vector_search/ann_index_reader_test.cpp
@@ -25,6 +25,7 @@
#include <memory>
#include <string>
+#include "olap/rowset/segment_v2/ann_index/ann_index.h"
#include "olap/rowset/segment_v2/ann_index/ann_index_iterator.h"
#include "olap/rowset/segment_v2/ann_index/ann_search_params.h"
#include "olap/rowset/segment_v2/ann_index/faiss_ann_index.h"
@@ -89,6 +90,26 @@ TEST_F(AnnIndexReaderTest,
TestConstructorWithDifferentMetrics) {
EXPECT_EQ(reader->get_index_id(), 2);
}
+TEST_F(AnnIndexReaderTest, TestConstructorWithIVF) {
+ // Test with IVF index type
+ auto properties = _properties;
+ properties["index_type"] = "ivf";
+ properties["nlist"] = "128";
+ properties["quantizer"] = "flat";
+
+ auto tablet_index = std::make_unique<TabletIndex>();
+ tablet_index->_properties = properties;
+ tablet_index->_index_id = 3;
+
+ auto reader =
std::make_unique<segment_v2::AnnIndexReader>(tablet_index.get(),
+
_mock_index_file_reader);
+ reader->_index_type = segment_v2::AnnIndexType::IVF;
+
+ EXPECT_EQ(reader->get_index_id(), 3);
+ EXPECT_EQ(reader->index_type(), IndexType::ANN);
+ EXPECT_EQ(reader->get_metric_type(), segment_v2::AnnIndexMetric::L2);
+}
+
TEST_F(AnnIndexReaderTest, TestNewIterator) {
// TODO: Fix if we using unique_ptr here.
auto reader =
std::make_shared<segment_v2::AnnIndexReader>(_tablet_index.get(),
@@ -115,8 +136,8 @@ TEST_F(AnnIndexReaderTest, TestLoadIndexSuccess) {
// For the open method that returns Result<unique_ptr<...>>, we need to
use a different approach
// since gmock has issues with non-copyable return types
- ON_CALL(*_mock_index_file_reader, open(testing::_, testing::_))
- .WillByDefault(testing::Invoke(
+ EXPECT_CALL(*_mock_index_file_reader, open(testing::_, testing::_))
+ .WillOnce(testing::Invoke(
[](const doris::TabletIndex*, const doris::io::IOContext*)
->
doris::Result<std::unique_ptr<doris::segment_v2::DorisCompoundReader,
doris::segment_v2::DirectoryDeleter>> {
@@ -151,8 +172,8 @@ TEST_F(AnnIndexReaderTest, TestLoadIndexFailureOpen) {
EXPECT_CALL(*_mock_index_file_reader, init(testing::_, testing::_))
.WillOnce(testing::Return(Status::OK()));
- ON_CALL(*_mock_index_file_reader, open(testing::_, testing::_))
- .WillByDefault(testing::Invoke(
+ EXPECT_CALL(*_mock_index_file_reader, open(testing::_, testing::_))
+ .WillOnce(testing::Invoke(
[](const doris::TabletIndex*, const doris::io::IOContext*)
->
doris::Result<std::unique_ptr<doris::segment_v2::DorisCompoundReader,
doris::segment_v2::DirectoryDeleter>> {
@@ -209,6 +230,59 @@ TEST_F(AnnIndexReaderTest, TestQueryWithoutLoadIndex) {
}
}
+TEST_F(AnnIndexReaderTest, TestQueryIVFWithoutLoadIndex) {
+ // Test IVF index query
+ auto properties = _properties;
+ properties["index_type"] = "ivf";
+ properties["nlist"] = "64";
+ properties["quantizer"] = "flat";
+
+ auto tablet_index = std::make_unique<TabletIndex>();
+ tablet_index->_properties = properties;
+
+ auto reader =
std::make_unique<segment_v2::AnnIndexReader>(tablet_index.get(),
+
_mock_index_file_reader);
+
+ // Set up _vector_index manually to bypass load_index for testing
+ auto doris_faiss_vector_index =
std::make_unique<doris::segment_v2::FaissVectorIndex>();
+
doris_faiss_vector_index->set_metric(doris::segment_v2::AnnIndexMetric::L2);
+
+ doris::segment_v2::FaissBuildParameter build_params;
+ build_params.dim = 4;
+ build_params.ivf_nlist = 64;
+ build_params.index_type =
doris::segment_v2::FaissBuildParameter::IndexType::IVF;
+ build_params.metric_type =
doris::segment_v2::FaissBuildParameter::MetricType::L2;
+ build_params.quantizer =
doris::segment_v2::FaissBuildParameter::Quantizer::FLAT;
+ doris_faiss_vector_index->build(build_params);
+
+ reader->_vector_index = std::move(doris_faiss_vector_index);
+ reader->_index_type = segment_v2::AnnIndexType::IVF;
+
+ // Create query parameters
+ const float query_data[] = {1.0f, 2.0f, 3.0f, 4.0f};
+ roaring::Roaring bitmap;
+ bitmap.add(1);
+ bitmap.add(2);
+
+ segment_v2::AnnTopNParam param {.query_value = query_data,
+ .query_value_size = 4,
+ .limit = 5,
+ ._user_params = VectorSearchUserParams {},
+ .roaring = &bitmap};
+
+ segment_v2::AnnIndexStats stats;
+ io::IOContext io_ctx;
+
+ Status status = reader->query(&io_ctx, &param, &stats);
+
+ // The query might succeed or fail depending on the internal index state,
+ // but it should not crash and should properly initialize distance and
row_ids
+ if (status.ok()) {
+ EXPECT_NE(param.distance, nullptr);
+ EXPECT_NE(param.row_ids, nullptr);
+ }
+}
+
TEST_F(AnnIndexReaderTest, TestRangeSearchWithoutLoadIndex) {
auto reader =
std::make_unique<segment_v2::AnnIndexReader>(_tablet_index.get(),
_mock_index_file_reader);
@@ -255,6 +329,61 @@ TEST_F(AnnIndexReaderTest,
TestRangeSearchWithoutLoadIndex) {
}
}
+TEST_F(AnnIndexReaderTest, TestRangeSearchIVFWithoutLoadIndex) {
+ // Test IVF index range search
+ auto properties = _properties;
+ properties["index_type"] = "ivf";
+ properties["nlist"] = "64";
+ properties["quantizer"] = "flat";
+
+ auto tablet_index = std::make_unique<TabletIndex>();
+ tablet_index->_properties = properties;
+
+ auto reader =
std::make_unique<segment_v2::AnnIndexReader>(tablet_index.get(),
+
_mock_index_file_reader);
+
+ // Set up _vector_index manually to bypass load_index for testing
+ auto doris_faiss_vector_index =
std::make_unique<doris::segment_v2::FaissVectorIndex>();
+
doris_faiss_vector_index->set_metric(doris::segment_v2::AnnIndexMetric::L2);
+
+ doris::segment_v2::FaissBuildParameter build_params;
+ build_params.dim = 4;
+ build_params.ivf_nlist = 64;
+ build_params.index_type =
doris::segment_v2::FaissBuildParameter::IndexType::IVF;
+ build_params.metric_type =
doris::segment_v2::FaissBuildParameter::MetricType::L2;
+ build_params.quantizer =
doris::segment_v2::FaissBuildParameter::Quantizer::FLAT;
+ doris_faiss_vector_index->build(build_params);
+
+ reader->_vector_index = std::move(doris_faiss_vector_index);
+ reader->_index_type = segment_v2::AnnIndexType::IVF;
+
+ // Create range search parameters
+ float query_data[] = {1.0f, 2.0f, 3.0f, 4.0f};
+ roaring::Roaring bitmap;
+ bitmap.add(1);
+ bitmap.add(2);
+
+ segment_v2::AnnRangeSearchParams params;
+ params.is_le_or_lt = true;
+ params.radius = 5.0f;
+ params.query_value = query_data;
+ params.roaring = &bitmap;
+
+ VectorSearchUserParams user_params;
+ user_params.ivf_nprobe = 2;
+
+ segment_v2::AnnRangeSearchResult result;
+ segment_v2::AnnIndexStats stats;
+
+ Status status = reader->range_search(params, user_params, &result, &stats);
+
+ // The range search might succeed or fail depending on the internal index
state,
+ // but it should not crash
+ if (status.ok()) {
+ EXPECT_NE(result.roaring, nullptr);
+ }
+}
+
TEST_F(AnnIndexReaderTest, TestUpdateResultStatic) {
// Test the static update_result method
segment_v2::IndexSearchResult search_result;
@@ -494,4 +623,65 @@ TEST_F(AnnIndexReaderTest, AnnIndexReaderRangeSearch) {
}
}
-} // namespace doris::vectorized
\ No newline at end of file
+TEST_F(AnnIndexReaderTest, AnnIndexReaderIVFRangeSearch) {
+ // Test IVF index range search functionality
+ std::map<std::string, std::string> index_properties;
+ index_properties["index_type"] = "ivf";
+ index_properties["metric_type"] = "l2_distance";
+ index_properties["dim"] = "32";
+ index_properties["nlist"] = "8"; // Small nlist for testing
+ index_properties["quantizer"] = "flat";
+ std::unique_ptr<doris::TabletIndex> index_meta =
std::make_unique<doris::TabletIndex>();
+ index_meta->_properties = index_properties;
+ auto mock_index_file_reader = std::make_shared<MockIndexFileReader>();
+ auto ann_index_reader =
+ std::make_unique<segment_v2::AnnIndexReader>(index_meta.get(),
mock_index_file_reader);
+ ann_index_reader->_index_type = segment_v2::AnnIndexType::IVF;
+
+ // Create and set up IVF index
+ auto doris_faiss_index =
std::make_unique<doris::segment_v2::FaissVectorIndex>();
+ doris::segment_v2::FaissBuildParameter build_params;
+ build_params.dim = 32;
+ build_params.ivf_nlist = 8;
+ build_params.index_type =
doris::segment_v2::FaissBuildParameter::IndexType::IVF;
+ build_params.metric_type =
doris::segment_v2::FaissBuildParameter::MetricType::L2;
+ build_params.quantizer =
doris::segment_v2::FaissBuildParameter::Quantizer::FLAT;
+ doris_faiss_index->build(build_params);
+
+ const size_t num_vectors = 1000;
+ auto vectors =
doris::vector_search_utils::generate_test_vectors_matrix(num_vectors, 32);
+ for (const auto& vec : vectors) {
+ doris_faiss_index->add(1, vec.data());
+ }
+
+ std::ignore = doris_faiss_index->save(this->_ram_dir.get());
+ std::vector<float> query_value = vectors[0];
+ const float radius =
doris::vector_search_utils::get_radius_from_matrix(query_value.data(), 32,
+
vectors, 0.3F);
+
+ // Make sure all rows are in the roaring
+ auto roaring = std::make_unique<roaring::Roaring>();
+ for (size_t i = 0; i < num_vectors; ++i) {
+ roaring->add(i);
+ }
+
+ doris::segment_v2::AnnRangeSearchParams params;
+ params.radius = radius;
+ params.query_value = query_value.data();
+ params.roaring = roaring.get();
+ doris::VectorSearchUserParams custom_params;
+ doris::segment_v2::AnnRangeSearchResult result;
+ auto stats = std::make_unique<doris::segment_v2::AnnIndexStats>();
+ auto doris_faiss_vector_index =
std::make_unique<doris::segment_v2::FaissVectorIndex>();
+ std::ignore = doris_faiss_vector_index->load(this->_ram_dir.get());
+ ann_index_reader->_vector_index = std::move(doris_faiss_vector_index);
+ Status status = ann_index_reader->range_search(params, custom_params,
&result, stats.get());
+
+ // IVF range search should work without crashing
+ if (status.ok()) {
+ EXPECT_NE(result.roaring, nullptr);
+ EXPECT_GE(result.roaring->cardinality(), 0);
+ }
+}
+
+} // namespace doris::vectorized
diff --git a/be/test/olap/vector_search/ann_index_writer_test.cpp
b/be/test/olap/vector_search/ann_index_writer_test.cpp
index 37d390379c2..ad4b6881c42 100644
--- a/be/test/olap/vector_search/ann_index_writer_test.cpp
+++ b/be/test/olap/vector_search/ann_index_writer_test.cpp
@@ -26,6 +26,7 @@
#include <string>
#include <vector>
+#include "olap/field.h"
#include "olap/rowset/segment_v2/index_file_writer.h"
#include "olap/rowset/segment_v2/inverted_index_fs_directory.h"
#include "olap/tablet_schema.h"
@@ -125,8 +126,26 @@ TEST_F(AnnIndexWriterTest,
TestInitWithDifferentProperties) {
{"metric_type", "l2_distance"},
{"dim", "128"},
{"max_degree", "64"}},
+ {{"index_type", "ivf"},
+ {"metric_type", "l2_distance"},
+ {"dim", "8"},
+ {"nlist", "128"},
+ {"quantizer", "flat"}},
+ {{"index_type", "ivf"},
+ {"metric_type", "inner_product"},
+ {"dim", "128"},
+ {"nlist", "512"},
+ {"quantizer", "sq4"}},
+ {{"index_type", "ivf"},
+ {"metric_type", "l2_distance"},
+ {"dim", "64"},
+ {"nlist", "256"},
+ {"quantizer", "pq"},
+ {"pq_m", "4"},
+ {"pq_nbits", "8"}},
// Test with default values (missing properties)
{{"index_type", "hnsw"}},
+ {{"index_type", "ivf"}},
{}};
for (const auto& props : test_properties) {
@@ -473,4 +492,164 @@ TEST_F(AnnIndexWriterTest, TestAddMoreThanChunkSize) {
EXPECT_TRUE(status.ok());
}
+TEST_F(AnnIndexWriterTest, TestCreateFromIndexColumnWriter) {
+ TabletSchemaSPtr tablet_schema = std::make_shared<TabletSchema>();
+ TabletSchemaPB tablet_schema_pb;
+ tablet_schema_pb.set_keys_type(DUP_KEYS);
+ tablet_schema->init_from_pb(tablet_schema_pb);
+
+ TabletColumn array_column;
+ array_column.set_name("arr1");
+ array_column.set_type(FieldType::OLAP_FIELD_TYPE_ARRAY);
+ array_column.set_length(0);
+ array_column.set_index_length(0);
+ array_column.set_is_nullable(false);
+
+ TabletColumn child_column;
+ child_column.set_name("arr_sub_float");
+ child_column.set_type(FieldType::OLAP_FIELD_TYPE_FLOAT);
+ child_column.set_length(INT_MAX);
+ array_column.add_sub_column(child_column);
+ tablet_schema->append_column(array_column);
+
+ // Get field for array column
+ std::unique_ptr<Field> field(FieldFactory::create(array_column));
+ ASSERT_NE(field.get(), nullptr);
+
+ auto fs_dir = std::make_shared<DorisRAMFSDirectory>();
+ fs_dir->init(doris::io::global_local_filesystem(),
"./ut_dir/tmp_vector_search", nullptr);
+ EXPECT_CALL(*_index_file_writer,
open(testing::_)).WillOnce(testing::Return(fs_dir));
+
+ // Create column writer
+ std::unique_ptr<IndexColumnWriter> column_writer;
+ auto status = IndexColumnWriter::create(field.get(), &column_writer,
_index_file_writer.get(),
+ _tablet_index.get());
+ EXPECT_TRUE(status.ok());
+
+ // Prepare test data
+ const size_t num_rows = 3;
+ std::vector<float> vectors = {
+ 1.0f, 2.0f, 3.0f, 4.0f, // Row 0
+ 5.0f, 6.0f, 7.0f, 8.0f, // Row 1
+ 9.0f, 10.0f, 11.0f, 12.0f // Row 2
+ };
+
+ std::vector<size_t> offsets = {0, 4, 8, 12}; // Each row has 4 elements
+
+ status = column_writer->add_array_values(sizeof(float), vectors.data(),
nullptr,
+ reinterpret_cast<const
uint8_t*>(offsets.data()),
+ num_rows);
+ EXPECT_TRUE(status.ok());
+
+ ASSERT_TRUE(column_writer->finish().ok());
+}
+
+TEST_F(AnnIndexWriterTest, TestAddArrayValuesIVF) {
+ auto properties = _properties;
+ properties["index_type"] = "ivf";
+ properties["nlist"] = "3";
+ properties["quantizer"] = "flat";
+
+ auto tablet_index = std::make_unique<TabletIndex>();
+ tablet_index->_properties = properties;
+ tablet_index->_index_id = 1;
+
+ auto writer =
+ std::make_unique<AnnIndexColumnWriter>(_index_file_writer.get(),
tablet_index.get());
+
+ auto fs_dir = std::make_shared<DorisRAMFSDirectory>();
+ fs_dir->init(doris::io::global_local_filesystem(),
"./ut_dir/tmp_vector_search", nullptr);
+ EXPECT_CALL(*_index_file_writer,
open(testing::_)).WillOnce(testing::Return(fs_dir));
+
+ ASSERT_TRUE(writer->init().ok());
+
+ // Prepare test data
+ const size_t dim = 4;
+ const size_t num_rows = 3;
+ std::vector<float> vectors = {
+ 1.0f, 2.0f, 3.0f, 4.0f, // Row 0
+ 5.0f, 6.0f, 7.0f, 8.0f, // Row 1
+ 9.0f, 10.0f, 11.0f, 12.0f // Row 2
+ };
+
+ std::vector<size_t> offsets = {0, 4, 8, 12}; // Each row has 4 elements
+
+ Status status =
+ writer->add_array_values(sizeof(float), vectors.data(), nullptr,
+ reinterpret_cast<const
uint8_t*>(offsets.data()), num_rows);
+ EXPECT_TRUE(status.ok());
+}
+
+TEST_F(AnnIndexWriterTest, TestAddMoreThanChunkSizeIVF) {
+ auto mock_index = std::make_shared<MockVectorIndex>();
+ auto properties = _properties;
+ properties["index_type"] = "ivf";
+ properties["nlist"] = "2";
+ properties["quantizer"] = "flat";
+
+ auto tablet_index = std::make_unique<TabletIndex>();
+ tablet_index->_properties = properties;
+ tablet_index->_index_id = 1;
+
+ auto writer =
std::make_unique<TestAnnIndexColumnWriter>(_index_file_writer.get(),
+
tablet_index.get());
+
+ auto fs_dir = std::make_shared<DorisRAMFSDirectory>();
+ fs_dir->init(doris::io::global_local_filesystem(),
"./ut_dir/tmp_vector_search", nullptr);
+ EXPECT_CALL(*_index_file_writer,
open(testing::_)).WillOnce(testing::Return(fs_dir));
+
+ ASSERT_TRUE(writer->init().ok());
+ writer->set_vector_index(mock_index);
+
+ EXPECT_CALL(*mock_index, train(10, testing::_))
+ .Times(1)
+ .WillOnce(testing::Return(Status::OK()));
+ EXPECT_CALL(*mock_index, add(10,
testing::_)).Times(1).WillOnce(testing::Return(Status::OK()));
+ EXPECT_CALL(*mock_index, train(2,
testing::_)).Times(1).WillOnce(testing::Return(Status::OK()));
+ EXPECT_CALL(*mock_index, add(2,
testing::_)).Times(1).WillOnce(testing::Return(Status::OK()));
+ EXPECT_CALL(*mock_index,
save(testing::_)).Times(1).WillOnce(testing::Return(Status::OK()));
+
+ // CHUNK_SIZE = 10
+ const size_t dim = 4;
+
+ {
+ const size_t num_rows = 6;
+ std::vector<float> vectors = {
+ 1.0f, 2.0f, 3.0f, 4.0f, // Row 0
+ 5.0f, 6.0f, 7.0f, 8.0f, // Row 1
+ 9.0f, 10.0f, 11.0f, 12.0f, // Row 2
+ 13.0f, 14.0f, 15.0f, 16.0f, // Row 3
+ 17.0f, 18.0f, 19.0f, 20.0f, // Row 4
+ 21.0f, 22.0f, 23.0f, 24.0f // Row 5
+ };
+ std::vector<size_t> offsets = {0, 4, 8, 12, 16, 20, 24};
+
+ Status status = writer->add_array_values(sizeof(float),
vectors.data(), nullptr,
+ reinterpret_cast<const
uint8_t*>(offsets.data()),
+ num_rows);
+ EXPECT_TRUE(status.ok());
+ }
+
+ {
+ const size_t num_rows = 6;
+ std::vector<float> vectors = {
+ 25.0f, 26.0f, 27.0f, 28.0f, // Row 6
+ 29.0f, 30.0f, 31.0f, 32.0f, // Row 7
+ 33.0f, 34.0f, 35.0f, 36.0f, // Row 8
+ 37.0f, 38.0f, 39.0f, 40.0f, // Row 9
+ 41.0f, 42.0f, 43.0f, 44.0f, // Row 10
+ 45.0f, 46.0f, 47.0f, 48.0f // Row 11
+ };
+ std::vector<size_t> offsets = {0, 4, 8, 12, 16, 20, 24};
+
+ Status status = writer->add_array_values(sizeof(float),
vectors.data(), nullptr,
+ reinterpret_cast<const
uint8_t*>(offsets.data()),
+ num_rows);
+ EXPECT_TRUE(status.ok());
+ }
+
+ Status status = writer->finish();
+ EXPECT_TRUE(status.ok());
+}
+
} // namespace doris::segment_v2
diff --git a/be/test/olap/vector_search/faiss_vector_index_test.cpp
b/be/test/olap/vector_search/faiss_vector_index_test.cpp
index 82869837c59..2d6b7ac95b2 100644
--- a/be/test/olap/vector_search/faiss_vector_index_test.cpp
+++ b/be/test/olap/vector_search/faiss_vector_index_test.cpp
@@ -16,6 +16,7 @@
// under the License.
#include <faiss/IndexHNSW.h>
+#include <faiss/IndexIVFFlat.h>
#include <gmock/gmock.h>
#include <gtest/gtest.h>
@@ -39,79 +40,178 @@ namespace doris::vectorized {
// Test saving and loading an index
TEST_F(VectorSearchTest, TestSaveAndLoad) {
- // Step 1: Create first index instance
- auto index1 = std::make_unique<FaissVectorIndex>();
+ std::vector<FaissBuildParameter::Quantizer> quantizers = {
+ FaissBuildParameter::Quantizer::FLAT,
FaissBuildParameter::Quantizer::SQ4,
+ FaissBuildParameter::Quantizer::SQ8,
FaissBuildParameter::Quantizer::PQ};
- // Step 2: Set build parameters
- FaissBuildParameter params;
- params.dim = 128; // Vector dimension
- params.max_degree = 16; // HNSW max connections
- params.index_type = FaissBuildParameter::IndexType::HNSW;
- index1->build(params);
-
- // Step 3: Add vectors to the index
- const int num_vectors = 100;
- std::vector<float> vectors;
- for (int i = 0; i < num_vectors; i++) {
- auto tmp = vector_search_utils::generate_random_vector(params.dim);
- vectors.insert(vectors.end(), tmp.begin(), tmp.end());
- }
+ for (auto quantizer : quantizers) {
+ // Step 1: Create first index instance
+ auto index1 = std::make_unique<FaissVectorIndex>();
- std::ignore = index1->add(num_vectors, vectors.data());
+ // Step 2: Set build parameters
+ FaissBuildParameter params;
+ params.dim = 128; // Vector dimension
+ params.max_degree = 16; // HNSW max connections
+ params.pq_m = 4;
+ params.pq_nbits = 8;
+ params.index_type = FaissBuildParameter::IndexType::HNSW;
+ params.quantizer = quantizer;
+ index1->build(params);
- // Step 4: Save the index
- auto save_status = index1->save(_ram_dir.get());
- ASSERT_TRUE(save_status.ok()) << "Failed to save index: " <<
save_status.to_string();
+ // Step 3: Add vectors to the index
+ const int num_vectors = 100;
+ std::vector<float> vectors;
+ for (int i = 0; i < num_vectors; i++) {
+ auto tmp = vector_search_utils::generate_random_vector(params.dim);
+ vectors.insert(vectors.end(), tmp.begin(), tmp.end());
+ }
- // Step 5: Create a new index instance
- auto index2 = std::make_unique<FaissVectorIndex>();
+ std::ignore = index1->train(num_vectors, vectors.data());
+ std::ignore = index1->add(num_vectors, vectors.data());
- // Step 6: Load the index
- auto load_status = index2->load(_ram_dir.get());
- ASSERT_TRUE(load_status.ok()) << "Failed to load index: " <<
load_status.to_string();
+ // Step 4: Save the index
+ auto save_status = index1->save(_ram_dir.get());
+ ASSERT_TRUE(save_status.ok()) << "Failed to save index: " <<
save_status.to_string();
- // Step 7: Verify the loaded index works by searching
- auto query_vec = vector_search_utils::generate_random_vector(params.dim);
- const int top_k = 10;
+ // Step 5: Create a new index instance
+ auto index2 = std::make_unique<FaissVectorIndex>();
- // TopN search requires a candidate roaring and rows_of_segment now.
- HNSWSearchParameters topn_params;
- auto topn_roaring = std::make_unique<roaring::Roaring>();
- for (int i = 0; i < num_vectors; ++i) topn_roaring->add(i);
- topn_params.roaring = topn_roaring.get();
- topn_params.rows_of_segment = num_vectors;
+ // Step 6: Load the index
+ auto load_status = index2->load(_ram_dir.get());
+ ASSERT_TRUE(load_status.ok()) << "Failed to load index: " <<
load_status.to_string();
- IndexSearchResult search_result1;
- IndexSearchResult search_result2;
+ // Step 7: Verify the loaded index works by searching
+ auto query_vec =
vector_search_utils::generate_random_vector(params.dim);
+ const int top_k = 10;
+
+ // TopN search requires a candidate roaring and rows_of_segment now.
+ HNSWSearchParameters topn_params;
+ auto topn_roaring = std::make_unique<roaring::Roaring>();
+ for (int i = 0; i < num_vectors; ++i) topn_roaring->add(i);
+ topn_params.roaring = topn_roaring.get();
+ topn_params.rows_of_segment = num_vectors;
- std::ignore = index1->ann_topn_search(query_vec.data(), top_k,
topn_params, search_result1);
+ IndexSearchResult search_result1;
+ IndexSearchResult search_result2;
- std::ignore = index2->ann_topn_search(query_vec.data(), top_k,
topn_params, search_result2);
+ std::ignore = index1->ann_topn_search(query_vec.data(), top_k,
topn_params, search_result1);
- // Compare the results
- EXPECT_EQ(search_result1.roaring->cardinality(),
search_result2.roaring->cardinality())
- << "Row ID cardinality mismatch";
- for (size_t i = 0; i < search_result1.roaring->cardinality(); ++i) {
- EXPECT_EQ(search_result1.distances[i], search_result2.distances[i])
- << "Distance mismatch at index " << i;
- }
+ std::ignore = index2->ann_topn_search(query_vec.data(), top_k,
topn_params, search_result2);
- HNSWSearchParameters hnsw_params;
- auto roaring_bitmap = std::make_unique<roaring::Roaring>();
- hnsw_params.roaring = roaring_bitmap.get();
- for (size_t i = 0; i < num_vectors; ++i) {
- hnsw_params.roaring->add(i);
+ // Compare the results
+ EXPECT_EQ(search_result1.roaring->cardinality(),
search_result2.roaring->cardinality())
+ << "Row ID cardinality mismatch";
+ for (size_t i = 0; i < search_result1.roaring->cardinality(); ++i) {
+ EXPECT_EQ(search_result1.distances[i], search_result2.distances[i])
+ << "Distance mismatch at index " << i;
+ }
+
+ HNSWSearchParameters hnsw_params;
+ auto roaring_bitmap = std::make_unique<roaring::Roaring>();
+ hnsw_params.roaring = roaring_bitmap.get();
+ for (size_t i = 0; i < num_vectors; ++i) {
+ hnsw_params.roaring->add(i);
+ }
+ IndexSearchResult range_search_result1;
+ std::ignore = index1->range_search(vectors.data(), 10, hnsw_params,
range_search_result1);
+ IndexSearchResult range_search_result2;
+ std::ignore = index2->range_search(vectors.data(), 10, hnsw_params,
range_search_result2);
+ EXPECT_EQ(range_search_result1.roaring->cardinality(),
+ range_search_result2.roaring->cardinality())
+ << "Row ID cardinality mismatch";
+ for (size_t i = 0; i < range_search_result1.roaring->cardinality();
++i) {
+ EXPECT_EQ(range_search_result1.distances[i],
range_search_result2.distances[i])
+ << "Distance mismatch at index " << i;
+ }
}
- IndexSearchResult range_search_result1;
- std::ignore = index1->range_search(vectors.data(), 10, hnsw_params,
range_search_result1);
- IndexSearchResult range_search_result2;
- std::ignore = index2->range_search(vectors.data(), 10, hnsw_params,
range_search_result2);
- EXPECT_EQ(range_search_result1.roaring->cardinality(),
- range_search_result2.roaring->cardinality())
- << "Row ID cardinality mismatch";
- for (size_t i = 0; i < range_search_result1.roaring->cardinality(); ++i) {
- EXPECT_EQ(range_search_result1.distances[i],
range_search_result2.distances[i])
- << "Distance mismatch at index " << i;
+}
+
+TEST_F(VectorSearchTest, TestSaveAndLoadIVF) {
+ std::vector<FaissBuildParameter::Quantizer> quantizers = {
+ FaissBuildParameter::Quantizer::FLAT,
FaissBuildParameter::Quantizer::SQ4,
+ FaissBuildParameter::Quantizer::SQ8,
FaissBuildParameter::Quantizer::PQ};
+
+ for (auto quantizer : quantizers) {
+ // Step 1: Create first index instance
+ auto index1 = std::make_unique<FaissVectorIndex>();
+
+ // Step 2: Set build parameters
+ FaissBuildParameter params;
+ params.dim = 128; // Vector dimension
+ params.ivf_nlist = 4;
+ params.pq_m = 4;
+ params.pq_nbits = 8;
+ params.quantizer = quantizer;
+ params.index_type = FaissBuildParameter::IndexType::IVF;
+ index1->build(params);
+
+ // Step 3: Add vectors to the index
+ const int num_vectors = 100;
+ std::vector<float> vectors;
+ for (int i = 0; i < num_vectors; i++) {
+ auto tmp = vector_search_utils::generate_random_vector(params.dim);
+ vectors.insert(vectors.end(), tmp.begin(), tmp.end());
+ }
+
+ std::ignore = index1->train(num_vectors, vectors.data());
+ std::ignore = index1->add(num_vectors, vectors.data());
+
+ // Step 4: Save the index
+ auto save_status = index1->save(_ram_dir.get());
+ ASSERT_TRUE(save_status.ok()) << "Failed to save index: " <<
save_status.to_string();
+
+ // Step 5: Create a new index instance
+ auto index2 = std::make_unique<FaissVectorIndex>();
+ index2->set_type(segment_v2::AnnIndexType::IVF);
+
+ // Step 6: Load the index
+ auto load_status = index2->load(_ram_dir.get());
+ ASSERT_TRUE(load_status.ok()) << "Failed to load index: " <<
load_status.to_string();
+
+ // Step 7: Verify the loaded index works by searching
+ auto query_vec =
vector_search_utils::generate_random_vector(params.dim);
+ const int top_k = 10;
+
+ // TopN search requires a candidate roaring and rows_of_segment now.
+ IVFSearchParameters topn_params;
+ auto topn_roaring = std::make_unique<roaring::Roaring>();
+ for (int i = 0; i < num_vectors; ++i) topn_roaring->add(i);
+ topn_params.roaring = topn_roaring.get();
+ topn_params.rows_of_segment = num_vectors;
+ topn_params.nprobe = 4;
+
+ IndexSearchResult search_result1;
+ IndexSearchResult search_result2;
+
+ std::ignore = index1->ann_topn_search(query_vec.data(), top_k,
topn_params, search_result1);
+
+ std::ignore = index2->ann_topn_search(query_vec.data(), top_k,
topn_params, search_result2);
+
+ // Compare the results
+ EXPECT_EQ(search_result1.roaring->cardinality(),
search_result2.roaring->cardinality())
+ << "Row ID cardinality mismatch";
+ for (size_t i = 0; i < search_result1.roaring->cardinality(); ++i) {
+ EXPECT_EQ(search_result1.distances[i], search_result2.distances[i])
+ << "Distance mismatch at index " << i;
+ }
+
+ IVFSearchParameters ivf_params;
+ auto roaring_bitmap = std::make_unique<roaring::Roaring>();
+ ivf_params.roaring = roaring_bitmap.get();
+ for (size_t i = 0; i < num_vectors; ++i) {
+ ivf_params.roaring->add(i);
+ }
+ IndexSearchResult range_search_result1;
+ std::ignore = index1->range_search(vectors.data(), 10, ivf_params,
range_search_result1);
+ IndexSearchResult range_search_result2;
+ std::ignore = index2->range_search(vectors.data(), 10, ivf_params,
range_search_result2);
+ EXPECT_EQ(range_search_result1.roaring->cardinality(),
+ range_search_result2.roaring->cardinality())
+ << "Row ID cardinality mismatch";
+ for (size_t i = 0; i < range_search_result1.roaring->cardinality();
++i) {
+ EXPECT_EQ(range_search_result1.distances[i],
range_search_result2.distances[i])
+ << "Distance mismatch at index " << i;
+ }
}
}
@@ -425,6 +525,117 @@ TEST_F(VectorSearchTest, CompRangeSearch) {
}
}
+TEST_F(VectorSearchTest, CompRangeSearchIVF) {
+ size_t iterations = 5;
+ std::vector<faiss::MetricType> metrics = {faiss::METRIC_L2,
faiss::METRIC_INNER_PRODUCT};
+ for (size_t i = 0; i < iterations; ++i) {
+ for (auto metric : metrics) {
+ // Random parameters for each test iteration
+ std::random_device rd;
+ std::mt19937 gen(rd());
+ size_t random_d = std::uniform_int_distribution<>(1, 512)(gen);
+ size_t random_n = std::uniform_int_distribution<>(10, 200)(gen);
+
+ // Step 1: Create and build index
+ auto doris_index = std::make_unique<FaissVectorIndex>();
+ FaissBuildParameter params;
+ params.dim = random_d;
+ params.ivf_nlist = 4;
+ params.index_type = FaissBuildParameter::IndexType::IVF;
+ if (metric == faiss::METRIC_L2) {
+ params.metric_type = FaissBuildParameter::MetricType::L2;
+ } else if (metric == faiss::METRIC_INNER_PRODUCT) {
+ params.metric_type = FaissBuildParameter::MetricType::IP;
+ } else {
+ throw std::runtime_error(fmt::format("Unsupported metric type:
{}", metric));
+ }
+ doris_index->build(params);
+
+ const int num_vectors = random_n;
+ std::vector<float> flat_vector;
+ std::vector<float> query_vec;
+ for (int i = 0; i < num_vectors; i++) {
+ auto vec =
vector_search_utils::generate_random_vector(params.dim);
+ if (i == 0) {
+ query_vec = vec;
+ }
+ flat_vector.insert(flat_vector.end(), vec.begin(), vec.end());
+ }
+
+ std::unique_ptr<faiss::Index> native_index;
+ std::unique_ptr<faiss::IndexFlat> quantizer = nullptr;
+ if (metric == faiss::METRIC_L2) {
+ quantizer = std::make_unique<faiss::IndexFlat>(params.dim,
faiss::METRIC_L2);
+ native_index = std::make_unique<faiss::IndexIVFFlat>(
+ quantizer.get(), params.dim, params.ivf_nlist,
faiss::METRIC_L2);
+ } else if (metric == faiss::METRIC_INNER_PRODUCT) {
+ quantizer =
+ std::make_unique<faiss::IndexFlat>(params.dim,
faiss::METRIC_INNER_PRODUCT);
+ native_index = std::make_unique<faiss::IndexIVFFlat>(
+ quantizer.get(), params.dim, params.ivf_nlist,
faiss::METRIC_INNER_PRODUCT);
+ } else {
+ throw std::runtime_error(fmt::format("Unsupported metric type:
{}", metric));
+ }
+
+ doris::vector_search_utils::add_vectors_to_indexes_batch_mode(
+ doris_index.get(), native_index.get(), num_vectors,
flat_vector);
+
+ float radius = 0;
+ radius = doris::vector_search_utils::get_radius_from_flatten(
+ query_vec.data(), params.dim, flat_vector, 0.4f);
+
+ IVFSearchParameters ivf_params;
+ ivf_params.nprobe = 4;
+ // Search on all rows;
+ auto roaring = std::make_unique<roaring::Roaring>();
+ ivf_params.roaring = roaring.get();
+ for (size_t i = 0; i < num_vectors; i++) {
+ ivf_params.roaring->add(i);
+ }
+ ivf_params.is_le_or_lt = metric == faiss::METRIC_L2;
+ IndexSearchResult doris_result;
+ std::ignore =
+ doris_index->range_search(query_vec.data(), radius,
ivf_params, doris_result);
+
+ faiss::SearchParametersIVF search_params_native;
+ search_params_native.nprobe = ivf_params.nprobe;
+ faiss::RangeSearchResult search_result_native(1, true);
+ // For L2 the radius must be squared (Faiss compares squared L2 distances); for IP it is used as-is
+ float faiss_radius = (metric == faiss::METRIC_L2) ? radius *
radius : radius;
+ native_index->range_search(1, query_vec.data(), faiss_radius,
&search_result_native,
+ &search_params_native);
+
+ std::vector<std::pair<int, float>> native_results;
+ size_t begin = search_result_native.lims[0];
+ size_t end = search_result_native.lims[1];
+ for (size_t i = begin; i < end; i++) {
+ native_results.push_back(
+ {search_result_native.labels[i],
search_result_native.distances[i]});
+ }
+
+ // Make sure result is same
+ ASSERT_NEAR(doris_result.roaring->cardinality(),
native_results.size(), 1)
+ << fmt::format("\nd: {}, n: {}, metric: {}", random_d,
random_n, metric);
+ ASSERT_EQ(doris_result.distances != nullptr, true);
+ if (doris_result.roaring->cardinality() == native_results.size()) {
+ for (size_t i = 0; i < native_results.size(); i++) {
+ const size_t rowid = native_results[i].first;
+ const float dis = native_results[i].second;
+ ASSERT_EQ(doris_result.roaring->contains(rowid), true)
+ << "Row ID mismatch at rank " << i;
+ if (metric == faiss::METRIC_L2) {
+ ASSERT_FLOAT_EQ(doris_result.distances[i], sqrt(dis))
+ << "Distance mismatch at rank " << i;
+ } else {
+ ASSERT_FLOAT_EQ(doris_result.distances[i], dis)
+ << "Distance mismatch at rank " << i;
+ }
+ }
+ }
+ }
+ }
+}
+
TEST_F(VectorSearchTest, RangeSearchAllRowsAsCandidates) {
size_t iterations = 5;
// Random parameters for each test iteration
@@ -824,6 +1035,71 @@ TEST_F(VectorSearchTest, InnerProductRangeSearchBasic) {
}
}
+TEST_F(VectorSearchTest, InnerProductRangeSearchBasicIVF) {
+ const size_t iterations = 3;
+
+ for (size_t iter = 0; iter < iterations; ++iter) {
+ const int dim = 64;
+ const int n = 500;
+ const int nlist = 4;
+
+ // Create Doris IVF index
+ auto doris_index = std::make_unique<FaissVectorIndex>();
+ FaissBuildParameter params;
+ params.dim = dim;
+ params.index_type = FaissBuildParameter::IndexType::IVF;
+ params.ivf_nlist = nlist;
+ params.quantizer = FaissBuildParameter::Quantizer::FLAT;
+ params.metric_type = FaissBuildParameter::MetricType::IP;
+ doris_index->build(params);
+
+ // Generate vectors
+ std::vector<std::vector<float>> vectors;
+ std::vector<float> flat_vectors;
+ for (int i = 0; i < n; ++i) {
+ auto vec = doris::vector_search_utils::generate_random_vector(dim);
+ vectors.push_back(vec);
+ flat_vectors.insert(flat_vectors.end(), vec.begin(), vec.end());
+ }
+
+ // Add vectors to index
+ doris_index->train(n, flat_vectors.data());
+ doris_index->add(n, flat_vectors.data());
+
+ // Use first vector as query
+ std::vector<float> query_vec = vectors[0];
+
+ // Calculate radius based on inner product distribution
+ float radius = doris::vector_search_utils::get_radius_from_matrix(
+ query_vec.data(), dim, vectors, 0.5f,
faiss::METRIC_INNER_PRODUCT);
+
+ // Perform Doris range search
+ IVFSearchParameters doris_params;
+ doris_params.nprobe = 4; // Search in 4 clusters
+ doris_params.is_le_or_lt = false; // For inner product, we want values >= radius
+ auto roaring = std::make_unique<roaring::Roaring>();
+ for (int i = 0; i < n; ++i) {
+ roaring->add(i);
+ }
+ doris_params.roaring = roaring.get();
+
+ IndexSearchResult doris_result;
+ auto status =
+ doris_index->range_search(query_vec.data(), radius,
doris_params, doris_result);
+ ASSERT_TRUE(status.ok()) << "Doris IVF range search failed";
+
+ // Verify basic properties
+ ASSERT_GT(doris_result.roaring->cardinality(), 0u) << "Should find
some results";
+ ASSERT_NE(doris_result.distances, nullptr) << "Distances should be
provided";
+
+ // Verify all returned distances are >= radius
+ for (size_t i = 0; i < doris_result.roaring->cardinality(); ++i) {
+ ASSERT_GE(doris_result.distances[i], radius - 1e-6)
+ << "Distance should be >= radius for IVF inner product
range search";
+ }
+ }
+}
+
TEST_F(VectorSearchTest, InnerProductVsL2Comparison) {
const int dim = 32;
const int n = 100;
@@ -1074,4 +1350,4 @@ TEST_F(VectorSearchTest,
InnerProductRangeSearchZeroAndNegativeRadius) {
ASSERT_GE(res_gen.roaring->cardinality(), static_cast<size_t>(n * 0.9));
}
-} // namespace doris::vectorized
\ No newline at end of file
+} // namespace doris::vectorized
diff --git a/be/test/olap/vector_search/vector_search_utils.cpp
b/be/test/olap/vector_search/vector_search_utils.cpp
index cb02b464d6a..9c2b885c724 100644
--- a/be/test/olap/vector_search/vector_search_utils.cpp
+++ b/be/test/olap/vector_search/vector_search_utils.cpp
@@ -131,12 +131,16 @@ void
add_vectors_to_indexes_batch_mode(segment_v2::VectorIndex* doris_index,
faiss::Index* native_index, size_t
num_vectors,
const std::vector<float>&
flatten_vectors) {
if (doris_index) {
- auto status = doris_index->add(num_vectors, flatten_vectors.data());
+ auto status = doris_index->train(num_vectors, flatten_vectors.data());
+ ASSERT_TRUE(status.ok()) << "Failed to train vectors to Doris index: "
+ << status.to_string();
+ status = doris_index->add(num_vectors, flatten_vectors.data());
ASSERT_TRUE(status.ok()) << "Failed to add vectors to Doris index: "
<< status.to_string();
}
if (native_index) {
// Add vectors to native Faiss index
+ native_index->train(num_vectors, flatten_vectors.data());
native_index->add(num_vectors, flatten_vectors.data());
}
}
@@ -282,4 +286,4 @@ create_tmp_ann_index_reader(std::map<std::string,
std::string> properties) {
mock_index_file_reader);
return std::make_pair(std::move(mock_tablet_index), ann_reader);
}
-} // namespace doris::vector_search_utils
\ No newline at end of file
+} // namespace doris::vector_search_utils
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/analysis/AnnIndexPropertiesChecker.java
b/fe/fe-core/src/main/java/org/apache/doris/analysis/AnnIndexPropertiesChecker.java
index 2fcb6d307df..f03d04a7098 100644
---
a/fe/fe-core/src/main/java/org/apache/doris/analysis/AnnIndexPropertiesChecker.java
+++
b/fe/fe-core/src/main/java/org/apache/doris/analysis/AnnIndexPropertiesChecker.java
@@ -29,12 +29,13 @@ public class AnnIndexPropertiesChecker {
String quantizer = null;
int dimension = 0;
int numSubQuantizers = 0;
+ int nlist = 0;
for (String key : properties.keySet()) {
switch (key) {
case "index_type":
type = properties.get(key);
- if (!type.equals("hnsw")) {
- throw new AnalysisException("only support ann index
with type hnsw, got: " + type);
+ if (!type.equals("hnsw") && !type.equals("ivf")) {
+ throw new AnalysisException("only support ann index
with type hnsw or ivf, got: " + type);
}
break;
case "metric_type":
@@ -120,11 +121,29 @@ public class AnnIndexPropertiesChecker {
"pq_nbits of ann index must be a positive
integer, got: " + pqNbits);
}
break;
+ case "nlist":
+ String nlistStr = properties.get(key);
+ try {
+ nlist = Integer.parseInt(nlistStr);
+ if (nlist <= 0) {
+ throw new AnalysisException(
+ "nlist of ann index must be a positive
integer, got: " + nlistStr);
+ }
+ } catch (NumberFormatException e) {
+ throw new AnalysisException("nlist of ann index must
be a positive integer, got: " + nlistStr);
+ }
+ break;
default:
throw new AnalysisException("unknown ann index property: "
+ key);
}
}
+ if (type != null && type.equals("ivf")) {
+ if (nlist == 0) {
+ throw new AnalysisException("nlist of ann index must be
specified for ivf type");
+ }
+ }
+
if (type == null) {
throw new AnalysisException("index_type of ann index be
specified.");
}
diff --git a/fe/fe-core/src/main/java/org/apache/doris/qe/SessionVariable.java
b/fe/fe-core/src/main/java/org/apache/doris/qe/SessionVariable.java
index 91e18dcb42a..79ee8d1d06a 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/qe/SessionVariable.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/qe/SessionVariable.java
@@ -880,6 +880,7 @@ public class SessionVariable implements Serializable,
Writable {
public static final String HNSW_EF_SEARCH = "hnsw_ef_search";
public static final String HNSW_CHECK_RELATIVE_DISTANCE =
"hnsw_check_relative_distance";
public static final String HNSW_BOUNDED_QUEUE = "hnsw_bounded_queue";
+ public static final String IVF_NPROBE = "ivf_nprobe";
public static final String DEFAULT_VARIANT_MAX_SUBCOLUMNS_COUNT =
"default_variant_max_subcolumns_count";
@@ -3041,6 +3042,11 @@ public class SessionVariable implements Serializable,
Writable {
"Whether to use a bounded priority queue to optimize HNSW
search performance"})
public boolean hnswBoundedQueue = true;
+ @VariableMgr.VarAttr(name = IVF_NPROBE, needForward = true,
+ description = {"IVF 索引的 nprobe 参数,控制搜索时访问的聚类数量",
+ "IVF index nprobe parameter, controls the number of
clusters to search"})
+ public int ivfNprobe = 1;
+
@VariableMgr.VarAttr(
name = DEFAULT_VARIANT_MAX_SUBCOLUMNS_COUNT,
needForward = true,
@@ -4830,6 +4836,7 @@ public class SessionVariable implements Serializable,
Writable {
tResult.setHnswEfSearch(hnswEFSearch);
tResult.setHnswCheckRelativeDistance(hnswCheckRelativeDistance);
tResult.setHnswBoundedQueue(hnswBoundedQueue);
+ tResult.setIvfNprobe(ivfNprobe);
tResult.setMergeReadSliceSize(mergeReadSliceSizeBytes);
tResult.setEnableExtendedRegex(enableExtendedRegex);
diff --git a/gensrc/thrift/PaloInternalService.thrift
b/gensrc/thrift/PaloInternalService.thrift
index 7d93887a8ca..d78f6e04ff3 100644
--- a/gensrc/thrift/PaloInternalService.thrift
+++ b/gensrc/thrift/PaloInternalService.thrift
@@ -416,6 +416,8 @@ struct TQueryOptions {
// Default 0 means use config::iceberg_sink_max_file_size
178: optional i64 iceberg_write_target_file_size_bytes = 0;
+ 182: optional i32 ivf_nprobe = 1;
+
// For cloud, to control if the content would be written into file cache
// In write path, to control if the content would be written into file cache.
// In read path, read from file cache or remote storage when execute query.
diff --git a/regression-test/data/ann_index_p0/ivf_index_test.out
b/regression-test/data/ann_index_p0/ivf_index_test.out
new file mode 100644
index 00000000000..91c7483c806
--- /dev/null
+++ b/regression-test/data/ann_index_p0/ivf_index_test.out
@@ -0,0 +1,17 @@
+-- This file is automatically generated. You should know what you did if you want to edit this
+-- !sql --
+1 [1, 2, 3]
+2 [0.5, 2.1, 2.9]
+3 [10, 10, 10]
+4 [20, 20, 20]
+5 [50, 20, 20]
+6 [60, 20, 20]
+
+-- !sql --
+1 [1, 2, 3]
+2 [0.5, 2.1, 2.9]
+3 [10, 10, 10]
+4 [20, 20, 20]
+5 [50, 20, 20]
+6 [60, 20, 20]
+
diff --git a/regression-test/suites/ann_index_p0/create_ann_index_test.groovy
b/regression-test/suites/ann_index_p0/create_ann_index_test.groovy
index 7313b452bcc..e2fd6f75a25 100644
--- a/regression-test/suites/ann_index_p0/create_ann_index_test.groovy
+++ b/regression-test/suites/ann_index_p0/create_ann_index_test.groovy
@@ -182,7 +182,7 @@ suite("create_ann_index_test") {
id INT NOT NULL COMMENT "",
embedding ARRAY<FLOAT> NOT NULL COMMENT "",
INDEX idx_test_ann (`embedding`) USING ANN PROPERTIES(
- "index_type"="ivf",
+ "index_type"="unknown",
"metric_type"="l2_distance",
"dim"="1"
)
@@ -193,7 +193,7 @@ suite("create_ann_index_test") {
"replication_num" = "1"
);
"""
- exception "only support ann index with type hnsw"
+ exception "only support ann index with type hnsw or ivf"
}
// metric_type is incorrect
diff --git
a/regression-test/suites/ann_index_p0/create_tbl_with_ann_index_test.groovy
b/regression-test/suites/ann_index_p0/create_tbl_with_ann_index_test.groovy
index 99925c354d8..5677dc69789 100644
--- a/regression-test/suites/ann_index_p0/create_tbl_with_ann_index_test.groovy
+++ b/regression-test/suites/ann_index_p0/create_tbl_with_ann_index_test.groovy
@@ -46,7 +46,8 @@ suite("create_tbl_with_ann_index_test") {
INDEX ann_idx2 (vec) USING ANN PROPERTIES(
"index_type" = "ivf",
"metric_type" = "l2_distance",
- "dim" = "128"
+ "dim" = "128",
+ "nlist" = "128"
)
) ENGINE=OLAP
DUPLICATE KEY(id) COMMENT "OLAP"
@@ -55,7 +56,6 @@ suite("create_tbl_with_ann_index_test") {
"replication_num" = "1"
);
"""
- exception "only support ann index with type hnsw"
}
// metric_type 错误
@@ -294,4 +294,4 @@ suite("create_tbl_with_ann_index_test") {
);
"""
-}
\ No newline at end of file
+}
diff --git a/regression-test/suites/ann_index_p0/ivf_index_test.groovy b/regression-test/suites/ann_index_p0/ivf_index_test.groovy
new file mode 100644
index 00000000000..231e728068a
--- /dev/null
+++ b/regression-test/suites/ann_index_p0/ivf_index_test.groovy
@@ -0,0 +1,126 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+suite ("ivf_index_test") {
+ sql "set enable_common_expr_pushdown=true;"
+
+ // IVF index
+ sql "drop table if exists tbl_ann_l2"
+ sql """
+ CREATE TABLE tbl_ann_l2 (
+ id INT NOT NULL,
+ embedding ARRAY<FLOAT> NOT NULL,
+ INDEX idx_emb (`embedding`) USING ANN PROPERTIES(
+ "index_type"="ivf",
+ "metric_type"="l2_distance",
+ "nlist"="3",
+ "dim"="3"
+ )
+ ) ENGINE=OLAP
+ DUPLICATE KEY(id)
+ DISTRIBUTED BY HASH(id) BUCKETS 1
+ PROPERTIES ("replication_num" = "1");
+ """
+
+ sql """
+ INSERT INTO tbl_ann_l2 VALUES
+ (1, [1.0, 2.0, 3.0]),
+ (2, [0.5, 2.1, 2.9]),
+ (3, [10.0, 10.0, 10.0]),
+ (4, [20.0, 20.0, 20.0]),
+ (5, [50.0, 20.0, 20.0]),
+ (6, [60.0, 20.0, 20.0]);
+ """
+ qt_sql "select * from tbl_ann_l2;"
+ // just approximate search
+ sql "select id, l2_distance_approximate(embedding, [1.0,2.0,3.0]) as dist from tbl_ann_l2 order by dist limit 2;"
+
+ sql """drop table if exists tbl_ann_l2"""
+ test {
+ // missing nlist
+ sql """
+ CREATE TABLE tbl_ann_l2 (
+ id INT NOT NULL,
+ embedding ARRAY<FLOAT> NOT NULL,
+ INDEX idx_emb (`embedding`) USING ANN PROPERTIES(
+ "index_type"="ivf",
+ "metric_type"="l2_distance",
+ "dim"="3"
+ )
+ ) ENGINE=OLAP
+ DUPLICATE KEY(id)
+ DISTRIBUTED BY HASH(id) BUCKETS 1
+ PROPERTIES ("replication_num" = "1");
+ """
+ exception """nlist of ann index must be specified for ivf type"""
+ }
+
+ sql """
+ CREATE TABLE tbl_ann_l2 (
+ id INT NOT NULL,
+ embedding ARRAY<FLOAT> NOT NULL,
+ INDEX idx_emb (`embedding`) USING ANN PROPERTIES(
+ "index_type"="ivf",
+ "metric_type"="l2_distance",
+ "nlist"="3",
+ "dim"="3"
+ )
+ ) ENGINE=OLAP
+ DUPLICATE KEY(id)
+ DISTRIBUTED BY HASH(id) BUCKETS 1
+ PROPERTIES ("replication_num" = "1");
+ """
+ test {
+ // not enough training points
+ sql """
+ INSERT INTO tbl_ann_l2 VALUES
+ (1, [1.0, 2.0, 3.0]),
+ (2, [0.5, 2.1, 2.9]);
+ """
+ exception """exception occurred during training"""
+ }
+
+ sql "drop table if exists tbl_ann_ip"
+ sql """
+ CREATE TABLE tbl_ann_ip (
+ id INT NOT NULL,
+ embedding ARRAY<FLOAT> NOT NULL,
+ INDEX idx_emb (`embedding`) USING ANN PROPERTIES(
+ "index_type"="ivf",
+ "metric_type"="inner_product",
+ "nlist"="3",
+ "dim"="3"
+ )
+ ) ENGINE=OLAP
+ DUPLICATE KEY(id)
+ DISTRIBUTED BY HASH(id) BUCKETS 1
+ PROPERTIES ("replication_num" = "1");
+ """
+
+ sql """
+ INSERT INTO tbl_ann_ip VALUES
+ (1, [1.0, 2.0, 3.0]),
+ (2, [0.5, 2.1, 2.9]),
+ (3, [10.0, 10.0, 10.0]),
+ (4, [20.0, 20.0, 20.0]),
+ (5, [50.0, 20.0, 20.0]),
+ (6, [60.0, 20.0, 20.0]);
+ """
+ qt_sql "select * from tbl_ann_ip;"
+ // just approximate search
+ sql "select id, inner_product_approximate(embedding, [1.0,2.0,3.0]) as dist from tbl_ann_ip order by dist desc limit 2;"
+}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]