This is an automated email from the ASF dual-hosted git repository.

yiguolei pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/doris.git


The following commit(s) were added to refs/heads/master by this push:
     new 20d68fee949 [fix](ann-index) Fix ANN IVF/PQ recall, avoid init-time large ANN build-buffer reservation, and skip ANN index build for segments with insufficient rows. (#64082)
20d68fee949 is described below

commit 20d68fee949efbb260b0df83cd5fe4d00f29727c
Author: Qi Chen <[email protected]>
AuthorDate: Mon Jun 8 10:18:27 2026 +0800

    [fix](ann-index) Fix ANN IVF/PQ recall, avoid init-time large ANN build-buffer reservation, and skip ANN index build for segments with insufficient rows. (#64082)
---
 be/src/common/config.cpp                           |  10 +-
 be/src/common/config.h                             |   4 +-
 be/src/storage/index/ann/ann_index_writer.cpp      | 119 ++--
 be/src/storage/index/ann/ann_index_writer.h        |  20 +-
 be/src/storage/index/ann/faiss_ann_index.cpp       |   3 +-
 .../storage/index/ann/ann_index_writer_test.cpp    | 677 ++++++++-------------
 .../data/ann_index_p0/ivf_on_disk_index_test.out   |   4 +
 .../ivf_pq_full_buffer_train_recall.out            |   4 +
 .../data/ann_index_p0/ivf_pq_recall.out            |   9 +
 .../ann_index_build_min_segment_rows.groovy        |  66 ++
 .../ann_index_p0/ivf_on_disk_index_test.groovy     |  16 +-
 .../ivf_pq_full_buffer_train_recall.groovy         |  68 +++
 .../suites/ann_index_p0/ivf_pq_recall.groovy       |  85 +++
 13 files changed, 571 insertions(+), 514 deletions(-)

diff --git a/be/src/common/config.cpp b/be/src/common/config.cpp
index 42be005ac9c..0ad6ce6d28f 100644
--- a/be/src/common/config.cpp
+++ b/be/src/common/config.cpp
@@ -1746,12 +1746,10 @@ DEFINE_mBool(enable_prefill_all_dbm_agg_cache_after_compaction, "true");
 DEFINE_String(ann_index_ivf_list_cache_limit, "70%");
 // Stale sweep time for ANN index IVF list cache in seconds. 3600s is 1 hour.
 DEFINE_mInt32(ann_index_ivf_list_cache_stale_sweep_time_sec, "3600");
-
-// Chunk size for ANN/vector index building per training/adding batch
-// 1M By default.
-DEFINE_mInt64(ann_index_build_chunk_size, "1000000");
-DEFINE_Validator(ann_index_build_chunk_size,
-                 [](const int64_t config) -> bool { return config > 0; });
+// Minimum segment rows required to persist an ANN index. 0 keeps the default behavior.
+DEFINE_mInt64(ann_index_build_min_segment_rows, "0");
+DEFINE_Validator(ann_index_build_min_segment_rows,
+                 [](const int64_t config) -> bool { return config >= 0; });
 
 DEFINE_mBool(enable_wal_tde, "false");
 
diff --git a/be/src/common/config.h b/be/src/common/config.h
index 91172721d45..c13e5494c6b 100644
--- a/be/src/common/config.h
+++ b/be/src/common/config.h
@@ -1802,8 +1802,8 @@ DECLARE_mInt32(max_segment_partial_column_cache_size);
 DECLARE_String(ann_index_ivf_list_cache_limit);
 // Stale sweep time for ANN index IVF list cache in seconds.
 DECLARE_mInt32(ann_index_ivf_list_cache_stale_sweep_time_sec);
-// Chunk size for ANN/vector index building per training/adding batch
-DECLARE_mInt64(ann_index_build_chunk_size);
+// Minimum segment rows required to persist an ANN index.
+DECLARE_mInt64(ann_index_build_min_segment_rows);
 
 DECLARE_mBool(enable_prefill_output_dbm_agg_cache_after_compaction);
 DECLARE_mBool(enable_prefill_all_dbm_agg_cache_after_compaction);
diff --git a/be/src/storage/index/ann/ann_index_writer.cpp b/be/src/storage/index/ann/ann_index_writer.cpp
index 28d348cc319..21911417c4f 100644
--- a/be/src/storage/index/ann/ann_index_writer.cpp
+++ b/be/src/storage/index/ann/ann_index_writer.cpp
@@ -17,11 +17,13 @@
 
 #include "storage/index/ann/ann_index_writer.h"
 
+#include <algorithm>
 #include <cstddef>
 #include <memory>
 #include <string>
 
 #include "common/cast_set.h"
+#include "common/config.h"
 #include "storage/index/ann/faiss_ann_index.h"
 #include "storage/index/inverted/inverted_index_fs_directory.h"
 
@@ -39,7 +41,7 @@ AnnIndexColumnWriter::AnnIndexColumnWriter(IndexFileWriter* index_file_writer,
                                            const TabletIndex* index_meta)
         : _index_file_writer(index_file_writer), _index_meta(index_meta) {}
 
-AnnIndexColumnWriter::~AnnIndexColumnWriter() {}
+AnnIndexColumnWriter::~AnnIndexColumnWriter() = default;
 
 Status AnnIndexColumnWriter::init() {
     Result<std::shared_ptr<DorisFSDirectory>> compound_dir = _index_file_writer->open(_index_meta);
@@ -77,9 +79,6 @@ Status AnnIndexColumnWriter::init() {
             index_type, build_parameter.dim, metric_type, build_parameter.max_degree,
             build_parameter.ef_construction, quantizer);
 
-    size_t block_size = AnnIndexColumnWriter::chunk_size() * build_parameter.dim;
-    _float_array.reserve(block_size);
-
     return Status::OK();
 }
 
@@ -87,7 +86,10 @@ Status AnnIndexColumnWriter::add_values(const std::string fn, const void* values
     return Status::OK();
 }
 
-void AnnIndexColumnWriter::close_on_error() {}
+void AnnIndexColumnWriter::close_on_error() {
+    PODArray<float> empty_buffered_vectors;
+    _buffered_vectors.swap(empty_buffered_vectors);
+}
 
 Status AnnIndexColumnWriter::add_array_values(size_t field_size, const void* value_ptr,
                                               const uint8_t* null_map, const uint8_t* offsets_ptr,
@@ -109,26 +111,10 @@ Status AnnIndexColumnWriter::add_array_values(size_t field_size, const void* val
 
     const float* p = reinterpret_cast<const float*>(value_ptr);
 
-    const size_t full_elements = AnnIndexColumnWriter::chunk_size() * dim;
-    size_t remaining_elements = num_rows * dim;
-    size_t src_offset = 0;
-    while (remaining_elements > 0) {
-        size_t available_space = full_elements - _float_array.size();
-        size_t elements_to_add = std::min(remaining_elements, available_space);
-
-        _float_array.insert(_float_array.end(), p + src_offset, p + src_offset + elements_to_add);
-        src_offset += elements_to_add;
-        remaining_elements -= elements_to_add;
-
-        if (_float_array.size() == full_elements) {
-            RETURN_IF_ERROR(
-                    _vector_index->train(AnnIndexColumnWriter::chunk_size(), _float_array.data()));
-            RETURN_IF_ERROR(
-                    _vector_index->add(AnnIndexColumnWriter::chunk_size(), _float_array.data()));
-            _float_array.clear();
-            _need_save_index = true;
-        }
-    }
+    // The offsets check above guarantees every array row matches the ANN index dimension.
+    DCHECK(p != nullptr);
+    _buffered_vectors.insert(_buffered_vectors.end(), p, p + num_rows * dim);
+    _total_rows += cast_set<int64_t>(num_rows);
 
     return Status::OK();
 }
@@ -146,54 +132,41 @@ int64_t AnnIndexColumnWriter::size() const {
 }
 
 Status AnnIndexColumnWriter::finish() {
-    Int64 min_train_rows = _vector_index->get_min_train_rows();
-
-    // Check if we have enough rows to train the index
-    // train/add the remaining data
-    if (_float_array.empty()) {
-        if (_need_save_index) {
-            return _vector_index->save(_dir.get());
-        } else {
-            // No data was added at all. This can happen if the segment has 0 rows
-            // or all rows were filtered out. We need to delete the directory entry
-            // to avoid writing an empty/invalid index file.
-            LOG_INFO("No data to train/add for ANN index. Skipping index building.");
-            return _index_file_writer->delete_index(_index_meta);
-        }
-    } else {
-        DCHECK(_float_array.size() % _vector_index->get_dimension() == 0);
-
-        Int64 num_rows = _float_array.size() / _vector_index->get_dimension();
-
-        if (num_rows >= min_train_rows) {
-            RETURN_IF_ERROR(_vector_index->train(num_rows, _float_array.data()));
-            RETURN_IF_ERROR(_vector_index->add(num_rows, _float_array.data()));
-            _float_array.clear();
-            return _vector_index->save(_dir.get());
-        } else {
-            // It happens to have not enough data to train.
-            // If we have data to add before, we still need to save the index.
-            if (_need_save_index) {
-                // For IVF indexes, adding remaining vectors without training is acceptable
-                // because the quantizer was already trained on previous batches. These vectors
-                // are simply added to the nearest clusters without retraining.
-                RETURN_IF_ERROR(_vector_index->add(num_rows, _float_array.data()));
-                _float_array.clear();
-                return _vector_index->save(_dir.get());
-            } else {
-                // Not enough data to train and no data added before.
-                // Means this is a very small segment, we can skip the index 
building.
-                // We need to delete the directory entry from index_file_writer to avoid
-                // writing an empty/invalid index file which causes "IndexInput read past EOF" error.
-                LOG_INFO(
-                        "Remaining data size {} is less than minimum {} rows required for ANN "
-                        "index "
-                        "training. Skipping index building for this segment.",
-                        num_rows, min_train_rows);
-                _float_array.clear();
-                return _index_file_writer->delete_index(_index_meta);
-            }
-        }
+    if (_total_rows == 0) {
+        LOG_INFO("No data to train/add for ANN index. Skipping index building.");
+        return _index_file_writer->delete_index(_index_meta);
+    }
+
+    const Int64 min_train_rows = _vector_index->get_min_train_rows();
+    const Int64 effective_min_rows =
+            std::max(min_train_rows, cast_set<Int64>(config::ann_index_build_min_segment_rows));
+    if (_total_rows < effective_min_rows) {
+        LOG_INFO(
+                "Total data size {} is less than minimum {} rows required for ANN index build. "
+                "Skipping index building for this segment.",
+                _total_rows, effective_min_rows);
+        PODArray<float> empty_buffered_vectors;
+        _buffered_vectors.swap(empty_buffered_vectors);
+        return _index_file_writer->delete_index(_index_meta);
+    }
+
+    return _build_and_save(min_train_rows, effective_min_rows);
+}
+
+Status AnnIndexColumnWriter::_build_and_save(Int64 min_train_rows, Int64 effective_min_rows) {
+    const size_t dim = _vector_index->get_dimension();
+    DCHECK(_buffered_vectors.size() % dim == 0);
+    const Int64 train_rows = cast_set<Int64>(_buffered_vectors.size() / dim);
+    DORIS_CHECK(train_rows == _total_rows);
+    DORIS_CHECK(train_rows >= effective_min_rows);
+    if (min_train_rows > 0) {
+        RETURN_IF_ERROR(_vector_index->train(train_rows, _buffered_vectors.data()));
     }
+    RETURN_IF_ERROR(_vector_index->add(train_rows, _buffered_vectors.data()));
+    // PODArray::clear() keeps the allocated capacity. Swap with an empty array so the
+    // full-segment build buffer is released before saving the index.
+    PODArray<float> empty_buffered_vectors;
+    _buffered_vectors.swap(empty_buffered_vectors);
+    return _vector_index->save(_dir.get());
 }
 } // namespace doris::segment_v2
diff --git a/be/src/storage/index/ann/ann_index_writer.h b/be/src/storage/index/ann/ann_index_writer.h
index 7b7e63f8574..67061bef921 100644
--- a/be/src/storage/index/ann/ann_index_writer.h
+++ b/be/src/storage/index/ann/ann_index_writer.h
@@ -27,7 +27,6 @@
 #include <roaring/roaring.hh>
 #include <string>
 
-#include "common/config.h"
 #include "core/pod_array.h"
 #include "storage/index/ann/ann_index.h"
 #include "storage/index/index_file_writer.h"
@@ -38,13 +37,6 @@
 namespace doris::segment_v2 {
 class AnnIndexColumnWriter : public IndexColumnWriter {
 public:
-    static inline int64_t chunk_size() {
-#ifdef BE_TEST
-        return 10;
-#else
-        return config::ann_index_build_chunk_size;
-#endif
-    }
     static constexpr const char* INDEX_TYPE = "index_type";
     static constexpr const char* METRIC_TYPE = "metric_type";
     static constexpr const char* DIM = "dim";
@@ -71,16 +63,20 @@ public:
     Status finish() override;
 
 private:
+    Status _build_and_save(Int64 min_train_rows, Int64 effective_min_rows);
+
+#ifdef BE_TEST
+    friend class TestAnnIndexColumnWriter;
+#endif
+
     // VectorIndex shoule be managed by some cache.
     // VectorIndex should be weak shared by AnnIndexWriter and VectorIndexReader
     // This should be a weak_ptr
     std::shared_ptr<VectorIndex> _vector_index;
-    // _float_array is used to buffer the float data before training/adding to vector index
-    // if we dont do this, the performance(recall) will be very poor when adding small number of vectors one by one
-    PODArray<float> _float_array;
+    PODArray<float> _buffered_vectors;
+    int64_t _total_rows = 0;
     IndexFileWriter* _index_file_writer;
     const TabletIndex* _index_meta;
     std::shared_ptr<DorisFSDirectory> _dir;
-    bool _need_save_index = false;
 };
 } // namespace doris::segment_v2
diff --git a/be/src/storage/index/ann/faiss_ann_index.cpp b/be/src/storage/index/ann/faiss_ann_index.cpp
index f933f3c683f..68b06db2b90 100644
--- a/be/src/storage/index/ann/faiss_ann_index.cpp
+++ b/be/src/storage/index/ann/faiss_ann_index.cpp
@@ -501,7 +501,8 @@ Int64 FaissVectorIndex::get_min_train_rows() const {
     // For IVF indexes, the minimum number of training points should be at least
     // equal to the number of clusters (nlist). FAISS requires this for k-means clustering.
     Int64 ivf_min = 0;
-    if (_params.index_type == FaissBuildParameter::IndexType::IVF) {
+    if (_params.index_type == FaissBuildParameter::IndexType::IVF ||
+        _params.index_type == FaissBuildParameter::IndexType::IVF_ON_DISK) {
         ivf_min = _params.ivf_nlist;
     }
 
diff --git a/be/test/storage/index/ann/ann_index_writer_test.cpp b/be/test/storage/index/ann/ann_index_writer_test.cpp
index bb30f9e1979..20107c90779 100644
--- a/be/test/storage/index/ann/ann_index_writer_test.cpp
+++ b/be/test/storage/index/ann/ann_index_writer_test.cpp
@@ -26,10 +26,13 @@
 #include <string>
 #include <vector>
 
+#include "common/config.h"
+#include "storage/index/ann/faiss_ann_index.h"
 #include "storage/index/ann/vector_search_utils.h"
 #include "storage/index/index_file_writer.h"
 #include "storage/index/inverted/inverted_index_fs_directory.h"
 #include "storage/tablet/tablet_schema.h"
+#include "util/defer_op.h"
 
 using namespace doris::vector_search_utils;
 
@@ -60,7 +63,8 @@ public:
             : AnnIndexColumnWriter(index_file_writer, index_meta) {}
 
     void set_vector_index(std::shared_ptr<VectorIndex> index) { _vector_index = index; }
-    void set_need_save_index(bool value) { _need_save_index = value; }
+    size_t buffered_vector_capacity() const { return _buffered_vectors.capacity(); }
+    size_t buffered_vector_rows(size_t dim) const { return _buffered_vectors.size() / dim; }
 };
 
 class AnnIndexWriterTest : public ::testing::Test {
@@ -165,6 +169,18 @@ TEST_F(AnnIndexWriterTest, TestInitWithDifferentProperties) {
     }
 }
 
+TEST_F(AnnIndexWriterTest, TestInitDoesNotPreallocateBuildBuffer) {
+    auto writer = 
std::make_unique<TestAnnIndexColumnWriter>(_index_file_writer.get(),
+                                                             
_tablet_index.get());
+
+    auto fs_dir = std::make_shared<DorisRAMFSDirectory>();
+    fs_dir->init(doris::io::global_local_filesystem(), "./ut_dir/tmp_vector_search", nullptr);
+    EXPECT_CALL(*_index_file_writer, open(testing::_)).WillOnce(testing::Return(fs_dir));
+
+    ASSERT_TRUE(writer->init().ok());
+    EXPECT_EQ(writer->buffered_vector_capacity(), 0);
+}
+
 TEST_F(AnnIndexWriterTest, TestAddArrayValuesSuccess) {
     auto writer =
             std::make_unique<AnnIndexColumnWriter>(_index_file_writer.get(), _tablet_index.get());
@@ -415,7 +431,7 @@ TEST_F(AnnIndexWriterTest, TestInvalidMetricType) {
     EXPECT_THROW(writer->init(), doris::Exception);
 }
 
-TEST_F(AnnIndexWriterTest, TestAddMoreThanChunkSize) {
+TEST_F(AnnIndexWriterTest, TestNoTrainIndexAddsAtFinish) {
     auto mock_index = std::make_shared<MockVectorIndex>();
     auto writer = 
std::make_unique<TestAnnIndexColumnWriter>(_index_file_writer.get(),
                                                              
_tablet_index.get());
@@ -427,55 +443,133 @@ TEST_F(AnnIndexWriterTest, TestAddMoreThanChunkSize) {
     ASSERT_TRUE(writer->init().ok());
     writer->set_vector_index(mock_index);
 
-    EXPECT_CALL(*mock_index, train(10, testing::_))
-            .Times(1)
-            .WillOnce(testing::Return(Status::OK()));
-    EXPECT_CALL(*mock_index, add(10, testing::_)).Times(1).WillOnce(testing::Return(Status::OK()));
-    EXPECT_CALL(*mock_index, train(2, testing::_)).Times(1).WillOnce(testing::Return(Status::OK()));
-    EXPECT_CALL(*mock_index, add(2, testing::_)).Times(1).WillOnce(testing::Return(Status::OK()));
-    EXPECT_CALL(*mock_index, save(testing::_)).Times(1).WillOnce(testing::Return(Status::OK()));
+    EXPECT_CALL(*mock_index, get_min_train_rows()).WillRepeatedly(testing::Return(0));
+    EXPECT_CALL(*mock_index, train(testing::_, testing::_)).Times(0);
+    EXPECT_CALL(*mock_index, add(testing::_, testing::_)).Times(0);
+    EXPECT_CALL(*mock_index, save(testing::_)).Times(0);
 
-    // CHUNK_SIZE = 10
     const size_t dim = 4;
-
-    {
-        const size_t num_rows = 6;
-        std::vector<float> vectors = {
-                1.0f,  2.0f,  3.0f,  4.0f,  // Row 0
-                5.0f,  6.0f,  7.0f,  8.0f,  // Row 1
-                9.0f,  10.0f, 11.0f, 12.0f, // Row 2
-                13.0f, 14.0f, 15.0f, 16.0f, // Row 3
-                17.0f, 18.0f, 19.0f, 20.0f, // Row 4
-                21.0f, 22.0f, 23.0f, 24.0f  // Row 5
-        };
-        std::vector<size_t> offsets = {0, 4, 8, 12, 16, 20, 24};
+    constexpr size_t batch_rows = 6;
+    for (int batch = 0; batch < 2; ++batch) {
+        std::vector<float> vectors(batch_rows * dim);
+        for (size_t i = 0; i < vectors.size(); ++i) {
+            vectors[i] = static_cast<float>(batch * vectors.size() + i);
+        }
+        std::vector<size_t> offsets;
+        for (size_t row = 0; row <= batch_rows; ++row) {
+            offsets.push_back(row * dim);
+        }
 
         Status status = writer->add_array_values(sizeof(float), vectors.data(), nullptr,
                                                  reinterpret_cast<const uint8_t*>(offsets.data()),
-                                                 num_rows);
+                                                 batch_rows);
         EXPECT_TRUE(status.ok());
     }
+    EXPECT_EQ(writer->buffered_vector_rows(dim), 2 * batch_rows);
 
-    {
-        const size_t num_rows = 6;
-        std::vector<float> vectors = {
-                25.0f, 26.0f, 27.0f, 28.0f, // Row 6
-                29.0f, 30.0f, 31.0f, 32.0f, // Row 7
-                33.0f, 34.0f, 35.0f, 36.0f, // Row 8
-                37.0f, 38.0f, 39.0f, 40.0f, // Row 9
-                41.0f, 42.0f, 43.0f, 44.0f, // Row 10
-                45.0f, 46.0f, 47.0f, 48.0f  // Row 11
-        };
-        std::vector<size_t> offsets = {0, 4, 8, 12, 16, 20, 24};
+    EXPECT_TRUE(testing::Mock::VerifyAndClearExpectations(mock_index.get()));
 
-        Status status = writer->add_array_values(sizeof(float), vectors.data(), nullptr,
-                                                 reinterpret_cast<const uint8_t*>(offsets.data()),
-                                                 num_rows);
-        EXPECT_TRUE(status.ok());
+    EXPECT_CALL(*mock_index, 
get_min_train_rows()).WillRepeatedly(testing::Return(0));
+    EXPECT_CALL(*mock_index, train(testing::_, testing::_)).Times(0);
+    {
+        testing::InSequence sequence;
+        EXPECT_CALL(*mock_index, add(12, testing::_))
+                .Times(1)
+                .WillOnce(testing::Return(Status::OK()));
+        EXPECT_CALL(*mock_index, 
save(testing::_)).Times(1).WillOnce(testing::Return(Status::OK()));
     }
 
     Status status = writer->finish();
     EXPECT_TRUE(status.ok());
+    EXPECT_EQ(writer->buffered_vector_rows(dim), 0);
+}
+
+TEST_F(AnnIndexWriterTest, TestNoTrainIndexSkipsWhenRowsLessThanMinSegmentRows) {
+    const int64_t old_min_segment_rows = config::ann_index_build_min_segment_rows;
+    config::ann_index_build_min_segment_rows = 5;
+    doris::Defer restore_config {
+            [&] { config::ann_index_build_min_segment_rows = old_min_segment_rows; }};
+
+    auto mock_index = std::make_shared<MockVectorIndex>();
+    auto writer = 
std::make_unique<TestAnnIndexColumnWriter>(_index_file_writer.get(),
+                                                             
_tablet_index.get());
+
+    auto fs_dir = std::make_shared<DorisRAMFSDirectory>();
+    fs_dir->init(doris::io::global_local_filesystem(), "./ut_dir/tmp_vector_search", nullptr);
+    EXPECT_CALL(*_index_file_writer, open(testing::_)).WillOnce(testing::Return(fs_dir));
+
+    ASSERT_TRUE(writer->init().ok());
+    writer->set_vector_index(mock_index);
+
+    EXPECT_CALL(*mock_index, 
get_min_train_rows()).WillRepeatedly(testing::Return(0));
+    EXPECT_CALL(*mock_index, train(testing::_, testing::_)).Times(0);
+    EXPECT_CALL(*mock_index, add(testing::_, testing::_)).Times(0);
+    EXPECT_CALL(*mock_index, save(testing::_)).Times(0);
+
+    const size_t dim = 4;
+    const size_t num_rows = 3;
+    std::vector<float> vectors(num_rows * dim);
+    for (size_t i = 0; i < vectors.size(); ++i) {
+        vectors[i] = static_cast<float>(i);
+    }
+    std::vector<size_t> offsets;
+    for (size_t row = 0; row <= num_rows; ++row) {
+        offsets.push_back(row * dim);
+    }
+
+    Status status =
+            writer->add_array_values(sizeof(float), vectors.data(), nullptr,
+                                     reinterpret_cast<const uint8_t*>(offsets.data()), num_rows);
+    EXPECT_TRUE(status.ok());
+    EXPECT_EQ(writer->buffered_vector_rows(dim), num_rows);
+
+    status = writer->finish();
+    EXPECT_TRUE(status.ok());
+    EXPECT_EQ(writer->buffered_vector_rows(dim), 0);
+}
+
+TEST_F(AnnIndexWriterTest, TestTrainRequiredIndexUsesEffectiveMinSegmentRows) {
+    const int64_t old_min_segment_rows = config::ann_index_build_min_segment_rows;
+    config::ann_index_build_min_segment_rows = 10;
+    doris::Defer restore_config {
+            [&] { config::ann_index_build_min_segment_rows = old_min_segment_rows; }};
+
+    auto mock_index = std::make_shared<MockVectorIndex>();
+    auto writer = 
std::make_unique<TestAnnIndexColumnWriter>(_index_file_writer.get(),
+                                                             
_tablet_index.get());
+
+    auto fs_dir = std::make_shared<DorisRAMFSDirectory>();
+    fs_dir->init(doris::io::global_local_filesystem(), "./ut_dir/tmp_vector_search", nullptr);
+    EXPECT_CALL(*_index_file_writer, open(testing::_)).WillOnce(testing::Return(fs_dir));
+
+    ASSERT_TRUE(writer->init().ok());
+    writer->set_vector_index(mock_index);
+
+    EXPECT_CALL(*mock_index, 
get_min_train_rows()).WillRepeatedly(testing::Return(2));
+    EXPECT_CALL(*mock_index, train(testing::_, testing::_)).Times(0);
+    EXPECT_CALL(*mock_index, add(testing::_, testing::_)).Times(0);
+    EXPECT_CALL(*mock_index, save(testing::_)).Times(0);
+
+    const size_t dim = 4;
+    const size_t num_rows = 6;
+    std::vector<float> vectors(num_rows * dim);
+    for (size_t i = 0; i < vectors.size(); ++i) {
+        vectors[i] = static_cast<float>(i);
+    }
+    std::vector<size_t> offsets;
+    for (size_t row = 0; row <= num_rows; ++row) {
+        offsets.push_back(row * dim);
+    }
+
+    Status status =
+            writer->add_array_values(sizeof(float), vectors.data(), nullptr,
+                                     reinterpret_cast<const uint8_t*>(offsets.data()), num_rows);
+    EXPECT_TRUE(status.ok());
+    EXPECT_EQ(writer->buffered_vector_rows(dim), num_rows);
+
+    status = writer->finish();
+    EXPECT_TRUE(status.ok());
+    EXPECT_EQ(writer->buffered_vector_rows(dim), 0);
 }
 
 TEST_F(AnnIndexWriterTest, TestCreateFromIndexColumnWriter) {
@@ -566,7 +660,7 @@ TEST_F(AnnIndexWriterTest, TestAddArrayValuesIVF) {
     EXPECT_TRUE(status.ok());
 }
 
-TEST_F(AnnIndexWriterTest, TestAddMoreThanChunkSizeIVF) {
+TEST_F(AnnIndexWriterTest, TestSmallTrainRequiredIndexUsesMemoryBuffer) {
     auto mock_index = std::make_shared<MockVectorIndex>();
     auto properties = _properties;
     properties["index_type"] = "ivf";
@@ -587,133 +681,51 @@ TEST_F(AnnIndexWriterTest, TestAddMoreThanChunkSizeIVF) {
     ASSERT_TRUE(writer->init().ok());
     writer->set_vector_index(mock_index);
 
-    EXPECT_CALL(*mock_index, train(10, testing::_))
-            .Times(1)
-            .WillOnce(testing::Return(Status::OK()));
-    EXPECT_CALL(*mock_index, add(10, testing::_)).Times(1).WillOnce(testing::Return(Status::OK()));
-    EXPECT_CALL(*mock_index, train(2, testing::_)).Times(1).WillOnce(testing::Return(Status::OK()));
-    EXPECT_CALL(*mock_index, add(2, testing::_)).Times(1).WillOnce(testing::Return(Status::OK()));
-    EXPECT_CALL(*mock_index, save(testing::_)).Times(1).WillOnce(testing::Return(Status::OK()));
+    EXPECT_CALL(*mock_index, get_min_train_rows()).WillRepeatedly(testing::Return(2));
+    EXPECT_CALL(*mock_index, train(testing::_, testing::_)).Times(0);
+    EXPECT_CALL(*mock_index, add(testing::_, testing::_)).Times(0);
+    EXPECT_CALL(*mock_index, save(testing::_)).Times(0);
 
-    // CHUNK_SIZE = 10
     const size_t dim = 4;
-
-    {
-        const size_t num_rows = 6;
-        std::vector<float> vectors = {
-                1.0f,  2.0f,  3.0f,  4.0f,  // Row 0
-                5.0f,  6.0f,  7.0f,  8.0f,  // Row 1
-                9.0f,  10.0f, 11.0f, 12.0f, // Row 2
-                13.0f, 14.0f, 15.0f, 16.0f, // Row 3
-                17.0f, 18.0f, 19.0f, 20.0f, // Row 4
-                21.0f, 22.0f, 23.0f, 24.0f  // Row 5
-        };
-        std::vector<size_t> offsets = {0, 4, 8, 12, 16, 20, 24};
-
-        Status status = writer->add_array_values(sizeof(float), vectors.data(), nullptr,
-                                                 reinterpret_cast<const uint8_t*>(offsets.data()),
-                                                 num_rows);
-        EXPECT_TRUE(status.ok());
+    const size_t num_rows = 4;
+    std::vector<float> vectors(num_rows * dim);
+    for (size_t i = 0; i < vectors.size(); ++i) {
+        vectors[i] = static_cast<float>(i);
     }
-
-    {
-        const size_t num_rows = 6;
-        std::vector<float> vectors = {
-                25.0f, 26.0f, 27.0f, 28.0f, // Row 6
-                29.0f, 30.0f, 31.0f, 32.0f, // Row 7
-                33.0f, 34.0f, 35.0f, 36.0f, // Row 8
-                37.0f, 38.0f, 39.0f, 40.0f, // Row 9
-                41.0f, 42.0f, 43.0f, 44.0f, // Row 10
-                45.0f, 46.0f, 47.0f, 48.0f  // Row 11
-        };
-        std::vector<size_t> offsets = {0, 4, 8, 12, 16, 20, 24};
-
-        Status status = writer->add_array_values(sizeof(float), vectors.data(), nullptr,
-                                                 reinterpret_cast<const uint8_t*>(offsets.data()),
-                                                 num_rows);
-        EXPECT_TRUE(status.ok());
+    std::vector<size_t> offsets;
+    for (size_t row = 0; row <= num_rows; ++row) {
+        offsets.push_back(row * dim);
     }
 
-    Status status = writer->finish();
+    Status status =
+            writer->add_array_values(sizeof(float), vectors.data(), nullptr,
+                                     reinterpret_cast<const uint8_t*>(offsets.data()), num_rows);
     EXPECT_TRUE(status.ok());
-}
-
-TEST_F(AnnIndexWriterTest, TestSkipTrainWhenRemainderLessThanNlist) {
-    auto mock_index = std::make_shared<MockVectorIndex>();
-    auto properties = _properties;
-    properties["index_type"] = "ivf";
-    properties["nlist"] = "5"; // Set nlist to 5
-    properties["quantizer"] = "flat";
-
-    auto tablet_index = std::make_unique<TabletIndex>();
-    tablet_index->_properties = properties;
-    tablet_index->_index_id = 1;
-
-    auto writer = 
std::make_unique<TestAnnIndexColumnWriter>(_index_file_writer.get(),
-                                                             
tablet_index.get());
-
-    auto fs_dir = std::make_shared<DorisRAMFSDirectory>();
-    fs_dir->init(doris::io::global_local_filesystem(), "./ut_dir/tmp_vector_search", nullptr);
-    EXPECT_CALL(*_index_file_writer, open(testing::_)).WillOnce(testing::Return(fs_dir));
+    EXPECT_EQ(writer->buffered_vector_rows(dim), num_rows);
+    EXPECT_TRUE(testing::Mock::VerifyAndClearExpectations(mock_index.get()));
 
-    ASSERT_TRUE(writer->init().ok());
-    writer->set_vector_index(mock_index);
-
-    // CHUNK_SIZE = 10, nlist = 5
-    // Add 12 rows: first 10 will be trained/added in one batch, remaining 2 < 5
-    // Since we have trained data before (_need_save_index = true), we should add the remaining 2 rows and save
-    EXPECT_CALL(*mock_index, get_min_train_rows()).WillRepeatedly(testing::Return(5));
-    EXPECT_CALL(*mock_index, train(10, testing::_))
-            .Times(1)
-            .WillOnce(testing::Return(Status::OK()));
-    EXPECT_CALL(*mock_index, add(10, testing::_)).Times(1).WillOnce(testing::Return(Status::OK()));
-    EXPECT_CALL(*mock_index, add(2, testing::_)).Times(1).WillOnce(testing::Return(Status::OK()));
-    EXPECT_CALL(*mock_index, save(testing::_)).Times(1).WillOnce(testing::Return(Status::OK()));
-
-    const size_t dim = 4;
-
-    // Add 12 rows total
+    EXPECT_CALL(*mock_index, 
get_min_train_rows()).WillRepeatedly(testing::Return(2));
     {
-        const size_t num_rows = 10;
-        std::vector<float> vectors(10 * 4);
-        for (size_t i = 0; i < 10 * 4; ++i) {
-            vectors[i] = static_cast<float>(i);
-        }
-        std::vector<size_t> offsets;
-        for (size_t i = 0; i <= num_rows; ++i) {
-            offsets.push_back(i * 4);
-        }
-
-        Status status = writer->add_array_values(sizeof(float), vectors.data(), nullptr,
-                                                 reinterpret_cast<const uint8_t*>(offsets.data()),
-                                                 num_rows);
-        EXPECT_TRUE(status.ok());
-    }
-
-    // Add 2 more rows
-    {
-        const size_t num_rows = 2;
-        std::vector<float> vectors = {
-                40.0f, 41.0f, 42.0f, 43.0f, // Row 10
-                44.0f, 45.0f, 46.0f, 47.0f  // Row 11
-        };
-        std::vector<size_t> offsets = {0, 4, 8};
-
-        Status status = writer->add_array_values(sizeof(float), vectors.data(), nullptr,
-                                                 reinterpret_cast<const uint8_t*>(offsets.data()),
-                                                 num_rows);
-        EXPECT_TRUE(status.ok());
+        testing::InSequence sequence;
+        EXPECT_CALL(*mock_index, train(4, testing::_))
+                .Times(1)
+                .WillOnce(testing::Return(Status::OK()));
+        EXPECT_CALL(*mock_index, add(4, testing::_))
+                .Times(1)
+                .WillOnce(testing::Return(Status::OK()));
+        EXPECT_CALL(*mock_index, 
save(testing::_)).Times(1).WillOnce(testing::Return(Status::OK()));
     }
 
-    Status status = writer->finish();
+    status = writer->finish();
     EXPECT_TRUE(status.ok());
+    EXPECT_EQ(writer->buffered_vector_rows(dim), 0);
 }
 
-TEST_F(AnnIndexWriterTest, TestLargeDataVolumeWithRemainderSkip) {
+TEST_F(AnnIndexWriterTest, TestTrainRequiredIndexTrainsOnceAndAddsAllRows) {
     auto mock_index = std::make_shared<MockVectorIndex>();
     auto properties = _properties;
     properties["index_type"] = "ivf";
-    properties["nlist"] = "3"; // Set nlist to 3
+    properties["nlist"] = "2";
     properties["quantizer"] = "flat";
 
     auto tablet_index = std::make_unique<TabletIndex>();
@@ -730,32 +742,24 @@ TEST_F(AnnIndexWriterTest, 
TestLargeDataVolumeWithRemainderSkip) {
     ASSERT_TRUE(writer->init().ok());
     writer->set_vector_index(mock_index);
 
-    // CHUNK_SIZE = 10, nlist = 3
-    // Add 23 rows: 2 full chunks of 10, remaining 3 == nlist, so train 
remaining
-    EXPECT_CALL(*mock_index, 
get_min_train_rows()).WillRepeatedly(testing::Return(3));
-    EXPECT_CALL(*mock_index, train(10, testing::_))
-            .Times(2)
-            .WillRepeatedly(testing::Return(Status::OK()));
-    EXPECT_CALL(*mock_index, add(10, testing::_))
-            .Times(2)
-            .WillRepeatedly(testing::Return(Status::OK()));
-    EXPECT_CALL(*mock_index, train(3, 
testing::_)).Times(1).WillOnce(testing::Return(Status::OK()));
-    EXPECT_CALL(*mock_index, add(3, 
testing::_)).Times(1).WillOnce(testing::Return(Status::OK()));
-    EXPECT_CALL(*mock_index, 
save(testing::_)).Times(1).WillOnce(testing::Return(Status::OK()));
+    EXPECT_CALL(*mock_index, 
get_min_train_rows()).WillRepeatedly(testing::Return(2));
+    EXPECT_CALL(*mock_index, train(testing::_, testing::_)).Times(0);
+    EXPECT_CALL(*mock_index, add(testing::_, testing::_)).Times(0);
+    EXPECT_CALL(*mock_index, save(testing::_)).Times(0);
 
     const size_t dim = 4;
 
-    // Add 3 batches: 10 + 10 + 3 = 23 rows
-    for (int batch = 0; batch < 2; ++batch) {
-        const size_t num_rows = 10;
-        std::vector<float> vectors(10 * 4);
-        for (size_t i = 0; i < 10 * 4; ++i) {
-            vectors[i] = static_cast<float>(batch * 40 + i);
-        }
-        std::vector<size_t> offsets;
-        for (size_t i = 0; i <= num_rows; ++i) {
-            offsets.push_back(i * 4);
-        }
+    {
+        const size_t num_rows = 6;
+        std::vector<float> vectors = {
+                1.0f,  2.0f,  3.0f,  4.0f,  // Row 0
+                5.0f,  6.0f,  7.0f,  8.0f,  // Row 1
+                9.0f,  10.0f, 11.0f, 12.0f, // Row 2
+                13.0f, 14.0f, 15.0f, 16.0f, // Row 3
+                17.0f, 18.0f, 19.0f, 20.0f, // Row 4
+                21.0f, 22.0f, 23.0f, 24.0f  // Row 5
+        };
+        std::vector<size_t> offsets = {0, 4, 8, 12, 16, 20, 24};
 
         Status status = writer->add_array_values(sizeof(float), 
vectors.data(), nullptr,
                                                  reinterpret_cast<const 
uint8_t*>(offsets.data()),
@@ -763,15 +767,17 @@ TEST_F(AnnIndexWriterTest, 
TestLargeDataVolumeWithRemainderSkip) {
         EXPECT_TRUE(status.ok());
     }
 
-    // Add remaining 3 rows
     {
-        const size_t num_rows = 3;
+        const size_t num_rows = 6;
         std::vector<float> vectors = {
-                80.0f, 81.0f, 82.0f, 83.0f, // Row 20
-                84.0f, 85.0f, 86.0f, 87.0f, // Row 21
-                88.0f, 89.0f, 90.0f, 91.0f  // Row 22
+                25.0f, 26.0f, 27.0f, 28.0f, // Row 6
+                29.0f, 30.0f, 31.0f, 32.0f, // Row 7
+                33.0f, 34.0f, 35.0f, 36.0f, // Row 8
+                37.0f, 38.0f, 39.0f, 40.0f, // Row 9
+                41.0f, 42.0f, 43.0f, 44.0f, // Row 10
+                45.0f, 46.0f, 47.0f, 48.0f  // Row 11
         };
-        std::vector<size_t> offsets = {0, 4, 8, 12};
+        std::vector<size_t> offsets = {0, 4, 8, 12, 16, 20, 24};
 
         Status status = writer->add_array_values(sizeof(float), 
vectors.data(), nullptr,
                                                  reinterpret_cast<const 
uint8_t*>(offsets.data()),
@@ -779,96 +785,29 @@ TEST_F(AnnIndexWriterTest, 
TestLargeDataVolumeWithRemainderSkip) {
         EXPECT_TRUE(status.ok());
     }
 
-    Status status = writer->finish();
-    EXPECT_TRUE(status.ok());
-}
-
-TEST_F(AnnIndexWriterTest, TestLargeDataVolumeSkipRemainder) {
-    auto mock_index = std::make_shared<MockVectorIndex>();
-    auto properties = _properties;
-    properties["index_type"] = "ivf";
-    properties["nlist"] = "4"; // Set nlist to 4
-    properties["quantizer"] = "flat";
+    EXPECT_EQ(writer->buffered_vector_rows(dim), 12);
+    EXPECT_TRUE(testing::Mock::VerifyAndClearExpectations(mock_index.get()));
 
-    auto tablet_index = std::make_unique<TabletIndex>();
-    tablet_index->_properties = properties;
-    tablet_index->_index_id = 1;
-
-    auto writer = 
std::make_unique<TestAnnIndexColumnWriter>(_index_file_writer.get(),
-                                                             
tablet_index.get());
-
-    auto fs_dir = std::make_shared<DorisRAMFSDirectory>();
-    fs_dir->init(doris::io::global_local_filesystem(), 
"./ut_dir/tmp_vector_search", nullptr);
-    EXPECT_CALL(*_index_file_writer, 
open(testing::_)).WillOnce(testing::Return(fs_dir));
-
-    ASSERT_TRUE(writer->init().ok());
-    writer->set_vector_index(mock_index);
-
-    // CHUNK_SIZE = 10, nlist = 4
-    // Add 22 rows: 2 full chunks of 10, remaining 2 < 4
-    // Since we have trained data before (_need_save_index = true), we should 
add the remaining 2 rows and save
-    EXPECT_CALL(*mock_index, 
get_min_train_rows()).WillRepeatedly(testing::Return(4));
-    EXPECT_CALL(*mock_index, train(10, testing::_))
-            .Times(2)
-            .WillRepeatedly(testing::Return(Status::OK()));
-    EXPECT_CALL(*mock_index, add(10, testing::_))
-            .Times(2)
-            .WillRepeatedly(testing::Return(Status::OK()));
-    EXPECT_CALL(*mock_index, add(2, 
testing::_)).Times(1).WillOnce(testing::Return(Status::OK()));
-    EXPECT_CALL(*mock_index, 
save(testing::_)).Times(1).WillOnce(testing::Return(Status::OK()));
-
-    const size_t dim = 4;
-
-    // Add 2 batches of 10 rows
-    for (int batch = 0; batch < 2; ++batch) {
-        const size_t num_rows = 10;
-        std::vector<float> vectors(10 * 4);
-        for (size_t i = 0; i < 10 * 4; ++i) {
-            vectors[i] = static_cast<float>(batch * 40 + i);
-        }
-        std::vector<size_t> offsets;
-        for (size_t i = 0; i <= num_rows; ++i) {
-            offsets.push_back(i * 4);
-        }
-
-        Status status = writer->add_array_values(sizeof(float), 
vectors.data(), nullptr,
-                                                 reinterpret_cast<const 
uint8_t*>(offsets.data()),
-                                                 num_rows);
-        EXPECT_TRUE(status.ok());
-    }
-
-    // Add remaining 2 rows
+    EXPECT_CALL(*mock_index, 
get_min_train_rows()).WillRepeatedly(testing::Return(2));
     {
-        const size_t num_rows = 2;
-        std::vector<float> vectors = {
-                80.0f, 81.0f, 82.0f, 83.0f, // Row 20
-                84.0f, 85.0f, 86.0f, 87.0f  // Row 21
-        };
-        std::vector<size_t> offsets = {0, 4, 8};
-
-        Status status = writer->add_array_values(sizeof(float), 
vectors.data(), nullptr,
-                                                 reinterpret_cast<const 
uint8_t*>(offsets.data()),
-                                                 num_rows);
-        EXPECT_TRUE(status.ok());
+        testing::InSequence sequence;
+        EXPECT_CALL(*mock_index, train(12, testing::_))
+                .Times(1)
+                .WillOnce(testing::Return(Status::OK()));
+        EXPECT_CALL(*mock_index, add(12, testing::_))
+                .Times(1)
+                .WillOnce(testing::Return(Status::OK()));
+        EXPECT_CALL(*mock_index, 
save(testing::_)).Times(1).WillOnce(testing::Return(Status::OK()));
     }
 
     Status status = writer->finish();
     EXPECT_TRUE(status.ok());
 }
 
-TEST_F(AnnIndexWriterTest, TestSkipIndexWhenTotalRowsLessThanNlist) {
+TEST_F(AnnIndexWriterTest, TestTrainRequiredIndexTrainsWithAllBufferedRows) {
     auto mock_index = std::make_shared<MockVectorIndex>();
-    auto properties = _properties;
-    properties["index_type"] = "ivf";
-    properties["nlist"] = "5"; // Set nlist to 5
-    properties["quantizer"] = "flat";
-
-    auto tablet_index = std::make_unique<TabletIndex>();
-    tablet_index->_properties = properties;
-    tablet_index->_index_id = 1;
-
     auto writer = 
std::make_unique<TestAnnIndexColumnWriter>(_index_file_writer.get(),
-                                                             
tablet_index.get());
+                                                             
_tablet_index.get());
 
     auto fs_dir = std::make_shared<DorisRAMFSDirectory>();
     fs_dir->init(doris::io::global_local_filesystem(), 
"./ut_dir/tmp_vector_search", nullptr);
@@ -876,95 +815,56 @@ TEST_F(AnnIndexWriterTest, 
TestSkipIndexWhenTotalRowsLessThanNlist) {
 
     ASSERT_TRUE(writer->init().ok());
     writer->set_vector_index(mock_index);
-    writer->set_need_save_index(false); // No previous training, so should 
skip entirely
 
-    // Add only 3 rows, which is less than nlist (5)
-    // Since no data was trained before (_need_save_index = false), we should 
skip index building entirely
-    // No train, add, or save should be called
-    EXPECT_CALL(*mock_index, 
get_min_train_rows()).WillRepeatedly(testing::Return(5));
+    EXPECT_CALL(*mock_index, 
get_min_train_rows()).WillRepeatedly(testing::Return(2));
     EXPECT_CALL(*mock_index, train(testing::_, testing::_)).Times(0);
     EXPECT_CALL(*mock_index, add(testing::_, testing::_)).Times(0);
     EXPECT_CALL(*mock_index, save(testing::_)).Times(0);
 
     const size_t dim = 4;
-
-    // Add 3 rows
-    {
-        const size_t num_rows = 3;
-        std::vector<float> vectors = {
-                1.0f, 2.0f,  3.0f,  4.0f, // Row 0
-                5.0f, 6.0f,  7.0f,  8.0f, // Row 1
-                9.0f, 10.0f, 11.0f, 12.0f // Row 2
-        };
-        std::vector<size_t> offsets = {0, 4, 8, 12};
-
-        Status status = writer->add_array_values(sizeof(float), 
vectors.data(), nullptr,
-                                                 reinterpret_cast<const 
uint8_t*>(offsets.data()),
-                                                 num_rows);
-        EXPECT_TRUE(status.ok());
+    const size_t num_rows = 20;
+    std::vector<float> vectors(num_rows * dim);
+    for (size_t row = 0; row < num_rows; ++row) {
+        for (size_t col = 0; col < dim; ++col) {
+            vectors[row * dim + col] = static_cast<float>(row);
+        }
+    }
+    std::vector<size_t> offsets;
+    for (size_t row = 0; row <= num_rows; ++row) {
+        offsets.push_back(row * dim);
     }
 
-    Status status = writer->finish();
+    Status status =
+            writer->add_array_values(sizeof(float), vectors.data(), nullptr,
+                                     reinterpret_cast<const 
uint8_t*>(offsets.data()), num_rows);
     EXPECT_TRUE(status.ok());
-}
-
-TEST_F(AnnIndexWriterTest, TestPQMinTrainRows) {
-    // Test writer behavior under a large mocked min_train_rows threshold.
-
-    auto mock_index = std::make_shared<MockVectorIndex>();
-    auto writer = 
std::make_unique<TestAnnIndexColumnWriter>(_index_file_writer.get(),
-                                                             
_tablet_index.get());
-
-    auto fs_dir = std::make_shared<DorisRAMFSDirectory>();
-    fs_dir->init(doris::io::global_local_filesystem(), 
"./ut_dir/tmp_vector_search", nullptr);
-    EXPECT_CALL(*_index_file_writer, 
open(testing::_)).WillOnce(testing::Return(fs_dir));
-
-    ASSERT_TRUE(writer->init().ok());
-    writer->set_vector_index(mock_index);
-
-    // Set up expectations: mock a very large min_train_rows threshold.
-    // Since we only provide 1000 vectors, which is less than 131072, training 
will happen in batches
-    // but finish() will skip saving since remaining data is insufficient
-    EXPECT_CALL(*mock_index, 
get_min_train_rows()).WillRepeatedly(testing::Return(131072));
-    // 1000 vectors will be processed in 100 batches of 10 vectors each
-    EXPECT_CALL(*mock_index, train(10, testing::_))
-            .Times(100)
-            .WillRepeatedly(testing::Return(Status::OK()));
-    EXPECT_CALL(*mock_index, add(10, testing::_))
-            .Times(100)
-            .WillRepeatedly(testing::Return(Status::OK()));
-    // Since we have trained data in batches, the index will be saved even 
though total data is insufficient
-    EXPECT_CALL(*mock_index, 
save(testing::_)).Times(1).WillOnce(testing::Return(Status::OK()));
-
-    const size_t dim = 4;
+    EXPECT_EQ(writer->buffered_vector_rows(dim), num_rows);
+    EXPECT_TRUE(testing::Mock::VerifyAndClearExpectations(mock_index.get()));
 
-    // Add only 1000 rows, which is less than the required 131072
+    EXPECT_CALL(*mock_index, 
get_min_train_rows()).WillRepeatedly(testing::Return(2));
     {
-        const size_t num_rows = 1000;
-        std::vector<float> vectors(num_rows * dim);
-        for (size_t i = 0; i < num_rows * dim; ++i) {
-            vectors[i] = static_cast<float>(i % 100);
-        }
-        std::vector<size_t> offsets;
-        for (size_t i = 0; i <= num_rows; ++i) {
-            offsets.push_back(i * dim);
-        }
-
-        Status status = writer->add_array_values(sizeof(float), 
vectors.data(), nullptr,
-                                                 reinterpret_cast<const 
uint8_t*>(offsets.data()),
-                                                 num_rows);
-        EXPECT_TRUE(status.ok());
+        testing::InSequence sequence;
+        EXPECT_CALL(*mock_index, train(20, testing::_))
+                .Times(1)
+                .WillOnce(testing::Invoke([&](Int64 n, const float* vec) {
+                    EXPECT_EQ(n, num_rows);
+                    for (size_t row = 0; row < static_cast<size_t>(n); ++row) {
+                        const auto row_id = static_cast<size_t>(vec[row * 
dim]);
+                        EXPECT_EQ(row_id, row);
+                    }
+                    return Status::OK();
+                }));
+        EXPECT_CALL(*mock_index, add(20, testing::_))
+                .Times(1)
+                .WillOnce(testing::Return(Status::OK()));
+        EXPECT_CALL(*mock_index, 
save(testing::_)).Times(1).WillOnce(testing::Return(Status::OK()));
     }
 
-    // Finish should skip index building due to insufficient training data
-    Status status = writer->finish();
+    status = writer->finish();
     EXPECT_TRUE(status.ok());
 }
 
-TEST_F(AnnIndexWriterTest, TestSQMinTrainRows) {
-    // Test that SQ quantizer requires sufficient training data
-    // SQ requires at least nlist * 2 = 10 * 2 = 20 training vectors
-
+TEST_F(AnnIndexWriterTest, TestSkipIndexWhenTotalRowsLessThanMinTrainRows) {
     auto mock_index = std::make_shared<MockVectorIndex>();
     auto writer = 
std::make_unique<TestAnnIndexColumnWriter>(_index_file_writer.get(),
                                                              
_tablet_index.get());
@@ -976,94 +876,49 @@ TEST_F(AnnIndexWriterTest, TestSQMinTrainRows) {
     ASSERT_TRUE(writer->init().ok());
     writer->set_vector_index(mock_index);
 
-    // Set up expectations: SQ should require at least 20 training vectors
-    // Since we only provide 15 vectors, training will happen in batches but 
finish() will skip saving
-    EXPECT_CALL(*mock_index, 
get_min_train_rows()).WillRepeatedly(testing::Return(20));
-    // 15 vectors will be processed in 1 batch of 10 vectors and remaining 5 
vectors
-    EXPECT_CALL(*mock_index, train(10, testing::_))
-            .Times(1)
-            .WillOnce(testing::Return(Status::OK()));
-    EXPECT_CALL(*mock_index, add(10, 
testing::_)).Times(1).WillOnce(testing::Return(Status::OK()));
-    EXPECT_CALL(*mock_index, add(5, 
testing::_)).Times(1).WillOnce(testing::Return(Status::OK()));
-    // Since we have trained data, the index will be saved even though total 
data is insufficient
-    EXPECT_CALL(*mock_index, 
save(testing::_)).Times(1).WillOnce(testing::Return(Status::OK()));
+    EXPECT_CALL(*mock_index, 
get_min_train_rows()).WillRepeatedly(testing::Return(5));
+    EXPECT_CALL(*mock_index, train(testing::_, testing::_)).Times(0);
+    EXPECT_CALL(*mock_index, add(testing::_, testing::_)).Times(0);
+    EXPECT_CALL(*mock_index, save(testing::_)).Times(0);
 
     const size_t dim = 4;
-
-    // Add only 15 rows, which is less than the required 20
-    {
-        const size_t num_rows = 15;
-        std::vector<float> vectors(num_rows * dim);
-        for (size_t i = 0; i < num_rows * dim; ++i) {
-            vectors[i] = static_cast<float>(i % 100);
-        }
-        std::vector<size_t> offsets;
-        for (size_t i = 0; i <= num_rows; ++i) {
-            offsets.push_back(i * dim);
-        }
-
-        Status status = writer->add_array_values(sizeof(float), 
vectors.data(), nullptr,
-                                                 reinterpret_cast<const 
uint8_t*>(offsets.data()),
-                                                 num_rows);
-        EXPECT_TRUE(status.ok());
+    const size_t num_rows = 3;
+    std::vector<float> vectors(num_rows * dim);
+    for (size_t i = 0; i < vectors.size(); ++i) {
+        vectors[i] = static_cast<float>(i);
+    }
+    std::vector<size_t> offsets;
+    for (size_t row = 0; row <= num_rows; ++row) {
+        offsets.push_back(row * dim);
     }
 
-    // Finish should skip index building due to insufficient training data
-    Status status = writer->finish();
+    Status status =
+            writer->add_array_values(sizeof(float), vectors.data(), nullptr,
+                                     reinterpret_cast<const 
uint8_t*>(offsets.data()), num_rows);
     EXPECT_TRUE(status.ok());
-}
-
-TEST_F(AnnIndexWriterTest, TestPQWithSufficientData) {
-    // Test that PQ works when sufficient training data is provided
-
-    auto mock_index = std::make_shared<MockVectorIndex>();
-    auto writer = 
std::make_unique<TestAnnIndexColumnWriter>(_index_file_writer.get(),
-                                                             
_tablet_index.get());
-
-    auto fs_dir = std::make_shared<DorisRAMFSDirectory>();
-    fs_dir->init(doris::io::global_local_filesystem(), 
"./ut_dir/tmp_vector_search", nullptr);
-    EXPECT_CALL(*_index_file_writer, 
open(testing::_)).WillOnce(testing::Return(fs_dir));
-
-    ASSERT_TRUE(writer->init().ok());
-    writer->set_vector_index(mock_index);
-
-    // Mock min_train_rows to 131072 and provide exactly that amount.
-    EXPECT_CALL(*mock_index, 
get_min_train_rows()).WillRepeatedly(testing::Return(131072));
-    // Since we provide exactly 131072 vectors, they will be trained and added 
in chunks
-    // Each chunk is 10 vectors, so we expect 13107 train calls and 13107 add 
calls for full chunks
-    EXPECT_CALL(*mock_index, train(10, testing::_))
-            .Times(13107)
-            .WillRepeatedly(testing::Return(Status::OK()));
-    EXPECT_CALL(*mock_index, add(10, testing::_))
-            .Times(13107)
-            .WillRepeatedly(testing::Return(Status::OK()));
-    // The remaining 2 vectors will be added without training since 
min_train_rows > 2
-    EXPECT_CALL(*mock_index, add(2, 
testing::_)).Times(1).WillOnce(testing::Return(Status::OK()));
-    EXPECT_CALL(*mock_index, 
save(testing::_)).Times(1).WillOnce(testing::Return(Status::OK()));
 
-    const size_t dim = 4;
+    EXPECT_EQ(writer->buffered_vector_rows(dim), num_rows);
+    EXPECT_TRUE(testing::Mock::VerifyAndClearExpectations(mock_index.get()));
+    EXPECT_CALL(*mock_index, 
get_min_train_rows()).WillRepeatedly(testing::Return(5));
+    EXPECT_CALL(*mock_index, train(testing::_, testing::_)).Times(0);
+    EXPECT_CALL(*mock_index, add(testing::_, testing::_)).Times(0);
+    EXPECT_CALL(*mock_index, save(testing::_)).Times(0);
 
-    // Add exactly 131072 rows
-    {
-        const size_t num_rows = 131072;
-        std::vector<float> vectors(num_rows * dim);
-        for (size_t i = 0; i < num_rows * dim; ++i) {
-            vectors[i] = static_cast<float>(i % 100);
-        }
-        std::vector<size_t> offsets;
-        for (size_t i = 0; i <= num_rows; ++i) {
-            offsets.push_back(i * dim);
-        }
+    status = writer->finish();
+    EXPECT_TRUE(status.ok());
+    EXPECT_EQ(writer->buffered_vector_rows(dim), 0);
+}
 
-        Status status = writer->add_array_values(sizeof(float), 
vectors.data(), nullptr,
-                                                 reinterpret_cast<const 
uint8_t*>(offsets.data()),
-                                                 num_rows);
-        EXPECT_TRUE(status.ok());
-    }
+TEST_F(AnnIndexWriterTest, TestIVFOnDiskMinTrainRows) { // IVF_ON_DISK + FLAT quantizer: min train rows is expected to equal ivf_nlist (7).
+    FaissVectorIndex index;
+    FaissBuildParameter params;
+    params.index_type = FaissBuildParameter::IndexType::IVF_ON_DISK;
+    params.quantizer = FaissBuildParameter::Quantizer::FLAT;
+    params.dim = 4;
+    params.ivf_nlist = 7; // the value get_min_train_rows() is asserted against below

-    // Finish should successfully build the index
-    Status status = writer->finish();
-    EXPECT_TRUE(status.ok());
+    index.build(params); // NOTE(review): any return value of build() is not checked here; confirm its signature.
+    EXPECT_EQ(index.get_min_train_rows(), 7);
 }
 
 } // namespace doris::segment_v2
diff --git a/regression-test/data/ann_index_p0/ivf_on_disk_index_test.out 
b/regression-test/data/ann_index_p0/ivf_on_disk_index_test.out
index bcd94f4ac52..239c104321c 100644
--- a/regression-test/data/ann_index_p0/ivf_on_disk_index_test.out
+++ b/regression-test/data/ann_index_p0/ivf_on_disk_index_test.out
@@ -11,6 +11,10 @@
 1
 2
 
+-- !sql_l2_insufficient_train_rows --
+1
+2
+
 -- !sql --
 1      [1, 2, 3]
 2      [0.5, 2.1, 2.9]
diff --git 
a/regression-test/data/ann_index_p0/ivf_pq_full_buffer_train_recall.out 
b/regression-test/data/ann_index_p0/ivf_pq_full_buffer_train_recall.out
new file mode 100644
index 00000000000..6c3458e2735
--- /dev/null
+++ b/regression-test/data/ann_index_p0/ivf_pq_full_buffer_train_recall.out
@@ -0,0 +1,4 @@
+-- This file is automatically generated. You should know what you did if you 
want to edit this
+-- !target_in_top20 --
+1
+
diff --git a/regression-test/data/ann_index_p0/ivf_pq_recall.out 
b/regression-test/data/ann_index_p0/ivf_pq_recall.out
new file mode 100644
index 00000000000..14aab16eedc
--- /dev/null
+++ b/regression-test/data/ann_index_p0/ivf_pq_recall.out
@@ -0,0 +1,9 @@
+-- This file is automatically generated. You should know what you did if you 
want to edit this
+-- !row_count --
+800
+
+-- !first_cluster_recall --
+20
+
+-- !second_cluster_recall --
+20
diff --git 
a/regression-test/suites/ann_index_p0/ann_index_build_min_segment_rows.groovy 
b/regression-test/suites/ann_index_p0/ann_index_build_min_segment_rows.groovy
new file mode 100644
index 00000000000..01393444830
--- /dev/null
+++ 
b/regression-test/suites/ann_index_p0/ann_index_build_min_segment_rows.groovy
@@ -0,0 +1,66 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+suite("ann_index_build_min_segment_rows", "nonConcurrent") { // Segment below ann_index_build_min_segment_rows should be queried without an ANN index.
+    sql "unset variable all;"
+    sql "set enable_common_expr_pushdown=true;"
+    sql "set experimental_enable_virtual_slot_for_cse=true;"
+    sql "set enable_no_need_read_data_opt=true;"
+    sql "set parallel_pipeline_task_num=1;"
+    sql "set enable_sql_cache=false;"
+    sql "set enable_condition_cache=false;"
+
+    setBeConfigTemporary([ann_index_build_min_segment_rows: 100]) { // only 3 rows are inserted below, i.e. fewer than this threshold
+        sql "drop table if exists ann_index_build_min_segment_rows"
+        sql """
+            create table ann_index_build_min_segment_rows (
+                id int not null,
+                embedding array<float> not null,
+                index idx_embedding(`embedding`) using ann properties(
+                    "index_type"="hnsw",
+                    "metric_type"="l2_distance",
+                    "dim"="3"
+                )
+            ) duplicate key(id)
+            distributed by hash(id) buckets 1
+            properties("replication_num"="1");
+        """
+
+        sql """
+            insert into ann_index_build_min_segment_rows values
+            (1, [0.0, 0.0, 0.0]),
+            (2, [0.1, 0.0, 0.0]),
+            (3, [0.2, 0.0, 0.0]);
+        """
+
+        try { // the debug point is always disabled in the finally block below
+            GetDebugPoint().enableDebugPointForAllBEs(
+                    "segment_iterator._read_columns_by_index", [column_name: 
"embedding"])
+            test {
+                sql """
+                    select id
+                    from ann_index_build_min_segment_rows
+                    where l2_distance_approximate(embedding, [0.0, 0.0, 0.0]) 
< 1.0
+                    order by id;
+                """
+                exception "does not need to read data" // NOTE(review): presumably thrown by the debug point when the embedding column is read as data (no ANN index used); confirm.
+            }
+        } finally {
+            
GetDebugPoint().disableDebugPointForAllBEs("segment_iterator._read_columns_by_index")
+        }
+    }
+}
diff --git a/regression-test/suites/ann_index_p0/ivf_on_disk_index_test.groovy 
b/regression-test/suites/ann_index_p0/ivf_on_disk_index_test.groovy
index a9eed51d7a4..63fab34d072 100644
--- a/regression-test/suites/ann_index_p0/ivf_on_disk_index_test.groovy
+++ b/regression-test/suites/ann_index_p0/ivf_on_disk_index_test.groovy
@@ -68,7 +68,7 @@ suite ("ivf_on_disk_index_test") {
         exception """nlist of ann index must be specified for ivf/ivf_on_disk 
type"""
     }
 
-    // ========== Error: not enough training points ==========
+    // Not enough training points: should not throw exception anymore, just 
skip index building.
     sql """
     CREATE TABLE tbl_ivf_on_disk_l2 (
         id INT NOT NULL,
@@ -84,14 +84,12 @@ suite ("ivf_on_disk_index_test") {
     DISTRIBUTED BY HASH(id) BUCKETS 1
     PROPERTIES ("replication_num" = "1");
     """
-    test {
-        sql """
-        INSERT INTO tbl_ivf_on_disk_l2 VALUES
-        (1, [1.0, 2.0, 3.0]),
-        (2, [0.5, 2.1, 2.9]);
-        """
-        exception """exception occurred during training"""
-    }
+    sql """
+    INSERT INTO tbl_ivf_on_disk_l2 VALUES
+    (1, [1.0, 2.0, 3.0]),
+    (2, [0.5, 2.1, 2.9]);
+    """
+    qt_sql_l2_insufficient_train_rows "select id from tbl_ivf_on_disk_l2 order 
by l2_distance_approximate(embedding, [1.0,2.0,3.0]) limit 2;"
 
     // ========== IVF_ON_DISK with inner product ==========
     sql "drop table if exists tbl_ivf_on_disk_ip"
diff --git 
a/regression-test/suites/ann_index_p0/ivf_pq_full_buffer_train_recall.groovy 
b/regression-test/suites/ann_index_p0/ivf_pq_full_buffer_train_recall.groovy
new file mode 100644
index 00000000000..20cccfb28b6
--- /dev/null
+++ b/regression-test/suites/ann_index_p0/ivf_pq_full_buffer_train_recall.groovy
@@ -0,0 +1,68 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+suite("ivf_pq_full_buffer_train_recall", "nonConcurrent") { // Checks that row 250 (the query vector itself) is in the top-20 results of an IVF-PQ index built from one 400-row insert.
+    sql "set enable_common_expr_pushdown=true;"
+    sql "set enable_ann_index_result_cache=false;"
+    sql "set ivf_nprobe=8;" // equals the index's nlist (8), so all inverted lists should be probed
+
+    sql "drop table if exists tbl_ivf_pq_full_buffer_train_recall"
+    sql """
+    CREATE TABLE tbl_ivf_pq_full_buffer_train_recall (
+        id INT NOT NULL,
+        embedding ARRAY<FLOAT> NOT NULL,
+        INDEX idx_emb (`embedding`) USING ANN PROPERTIES(
+                "index_type"="ivf",
+                "metric_type"="l2_distance",
+                "nlist"="8",
+                "dim"="4",
+                "quantizer"="pq",
+                "pq_m"="2",
+                "pq_nbits"="1"
+        )
+    ) ENGINE=OLAP
+    DUPLICATE KEY(id)
+    DISTRIBUTED BY HASH(id) BUCKETS 1
+    PROPERTIES ("replication_num" = "1");
+    """
+
+    def insertData = [] // rows 1-200: first coordinate 1000.0 (far away); rows 201-400: small coordinates; row 250: all zeros (the query point)
+    for (int i = 1; i <= 400; i++) {
+        if (i == 250) {
+            insertData.add("(${i}, [0.0, 0.0, 0.0, 0.0])")
+        } else if (i <= 200) {
+            insertData.add("(${i}, [1000.0, ${i}.0, ${(i % 17)}.0, ${(i % 
19)}.0])")
+        } else {
+            insertData.add(
+                    "(${i}, [${(i - 250) / 50.0}, ${(250 - i) / 50.0}, "
+                            + "${(i % 7 - 3) / 10.0}, ${(i % 5 - 2) / 10.0}])")
+        }
+    }
+    sql "INSERT INTO tbl_ivf_pq_full_buffer_train_recall VALUES 
${insertData.join(', ')};"
+    sql "sync"
+
+    qt_target_in_top20 """
+        select count(*)
+        from (
+            select id
+            from tbl_ivf_pq_full_buffer_train_recall
+            order by l2_distance_approximate(embedding, [0.0, 0.0, 0.0, 0.0]), 
id
+            limit 20
+        ) t
+        where id = 250;
+    """
+}
diff --git a/regression-test/suites/ann_index_p0/ivf_pq_recall.groovy 
b/regression-test/suites/ann_index_p0/ivf_pq_recall.groovy
new file mode 100644
index 00000000000..c1c6a7b7651
--- /dev/null
+++ b/regression-test/suites/ann_index_p0/ivf_pq_recall.groovy
@@ -0,0 +1,85 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+suite("ivf_pq_recall", "nonConcurrent") {
+    sql "set enable_common_expr_pushdown=true;"
+    sql "set enable_ann_index_result_cache=false;"
+    sql "set ivf_nprobe=8;"  // same value as the index's nlist (8) below
+
+    sql "drop table if exists ivf_pq_recall"
+    sql """
+        create table ivf_pq_recall (
+            id int not null,
+            embedding array<float> not null,
+            index idx_embedding (`embedding`) using ann properties(
+                "index_type" = "ivf",
+                "metric_type" = "l2_distance",
+                "nlist" = "8",
+                "dim" = "4",
+                "quantizer" = "pq",
+                "pq_m" = "2",
+                "pq_nbits" = "2"
+            )
+        ) engine=olap
+        duplicate key(id)
+        distributed by hash(id) buckets 1
+        properties(
+            "replication_num" = "1",
+            "disable_auto_compaction" = "true"
+        );
+    """
+
+    def formatFloat = { double value ->  // Locale.ROOT keeps '.' as the decimal separator
+        String.format(java.util.Locale.ROOT, "%.3f", value)
+    }
+    def vector = { double x ->
+        "[${formatFloat(x)}, ${formatFloat(x * 2)}, ${formatFloat(x * 3)}, ${formatFloat(x * 4)}]"
+    }
+    def rows = []  // each row is [x, 2x, 3x, 4x]: ids 1-400 use x in [0, 0.399], ids 401-800 use x in [1000, 1000.399]
+    for (int i = 1; i <= 400; i++) {
+        double x = (i - 1) / 1000.0
+        rows.add("(${i}, ${vector(x)})")
+    }
+    for (int i = 401; i <= 800; i++) {
+        double x = 1000.0 + (i - 401) / 1000.0
+        rows.add("(${i}, ${vector(x)})")
+    }
+    sql "insert into ivf_pq_recall values ${rows.join(',')};"
+    sql "sync"  // each recall query counts how many of its top-20 rows fall in the matching id range
+
+    qt_row_count "select count(*) from ivf_pq_recall;"
+
+    qt_first_cluster_recall """
+        select count(*) from (
+            select id
+            from ivf_pq_recall
+            order by l2_distance_approximate(embedding, [0.0, 0.0, 0.0, 0.0])
+            limit 20
+        ) t
+        where id between 1 and 400;
+    """
+
+    qt_second_cluster_recall """
+        select count(*) from (
+            select id
+            from ivf_pq_recall
+            order by l2_distance_approximate(embedding, [1000.0, 2000.0, 3000.0, 4000.0])
+            limit 20
+        ) t
+        where id between 401 and 800;
+    """
+}


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to