This is an automated email from the ASF dual-hosted git repository.

zclllyybb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/doris.git


The following commit(s) were added to refs/heads/master by this push:
     new 892eae27bba [Fix](ai_agg) isolate AI_AGG query_ctx per aggregate state (#63080)
892eae27bba is described below

commit 892eae27bbad64703065f981787c3302391c8df0
Author: linrrarity <[email protected]>
AuthorDate: Mon May 11 10:52:55 2026 +0800

    [Fix](ai_agg) isolate AI_AGG query_ctx per aggregate state (#63080)
    
    Problem Summary:
    
    `AI_AGG` previously stored `QueryContext` in a process-level static
    pointer. Concurrent AI_AGG queries could overwrite that pointer, causing
    one query to read another query's AI resource metadata, query options,
    or timeout state.
    
    The aggregate function instance now receives the query context from
    `AggFnEvaluator` and binds it to each `AggregateFunctionAIAggData` state
    when the state is created, reset, or deserialized.
---
 .../exprs/aggregate/aggregate_function_ai_agg.cpp  |  1 -
 be/src/exprs/aggregate/aggregate_function_ai_agg.h | 22 ++++++---
 be/test/ai/aggregate_function_ai_agg_test.cpp      | 52 +++++++++++++++++++++-
 3 files changed, 68 insertions(+), 7 deletions(-)

diff --git a/be/src/exprs/aggregate/aggregate_function_ai_agg.cpp 
b/be/src/exprs/aggregate/aggregate_function_ai_agg.cpp
index 44cbff4301b..5b7e9efcb67 100644
--- a/be/src/exprs/aggregate/aggregate_function_ai_agg.cpp
+++ b/be/src/exprs/aggregate/aggregate_function_ai_agg.cpp
@@ -21,7 +21,6 @@
 #include "exprs/aggregate/helpers.h"
 
 namespace doris {
-QueryContext* AggregateFunctionAIAggData::_ctx = nullptr;
 
 void register_aggregate_function_ai_agg(AggregateFunctionSimpleFactory& 
factory) {
     factory.register_function_both("ai_agg",
diff --git a/be/src/exprs/aggregate/aggregate_function_ai_agg.h 
b/be/src/exprs/aggregate/aggregate_function_ai_agg.h
index ae58216b451..fd532c49d74 100644
--- a/be/src/exprs/aggregate/aggregate_function_ai_agg.h
+++ b/be/src/exprs/aggregate/aggregate_function_ai_agg.h
@@ -146,7 +146,7 @@ public:
         }
     }
 
-    static void set_query_context(QueryContext* context) { _ctx = context; }
+    void set_query_context(QueryContext* context) { _ctx = context; }
 
     const std::string& get_task() const { return _task; }
 
@@ -197,7 +197,7 @@ private:
         process_current_context();
     }
 
-    static size_t get_ai_context_window_size() {
+    size_t get_ai_context_window_size() const {
         DORIS_CHECK(_ctx);
 
         return 
static_cast<size_t>(_ctx->query_options().ai_context_window_size);
@@ -247,7 +247,7 @@ private:
         inited = !data.empty();
     }
 
-    static QueryContext* _ctx;
+    QueryContext* _ctx = nullptr;
     AIResource _ai_config;
     std::shared_ptr<AIAdapter> _ai_adapter;
     std::string _task;
@@ -264,7 +264,7 @@ public:
 
     void set_query_context(QueryContext* context) override {
         if (context) {
-            AggregateFunctionAIAggData::set_query_context(context);
+            _ctx = context;
         }
     }
 
@@ -274,6 +274,11 @@ public:
 
     bool is_blockable() const override { return true; }
 
+    void create(AggregateDataPtr __restrict place) const override {
+        new (place) AggregateFunctionAIAggData;
+        data(place).set_query_context(_ctx);
+    }
+
     void add(AggregateDataPtr __restrict place, const IColumn** columns, 
ssize_t row_num,
              Arena&) const override {
         data(place).prepare(
@@ -303,7 +308,10 @@ public:
         }
     }
 
-    void reset(AggregateDataPtr place) const override { data(place).reset(); }
+    void reset(AggregateDataPtr place) const override {
+        data(place).reset();
+        data(place).set_query_context(_ctx);
+    }
 
     void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs,
                Arena&) const override {
@@ -317,6 +325,7 @@ public:
     void deserialize(AggregateDataPtr __restrict place, BufferReadable& buf,
                      Arena&) const override {
         data(place).read(buf);
+        data(place).set_query_context(_ctx);
     }
 
     void insert_result_into(ConstAggregateDataPtr __restrict place, IColumn& 
to) const override {
@@ -324,6 +333,9 @@ public:
         DCHECK(!result.empty()) << "AI returns an empty result";
         assert_cast<ColumnString&>(to).insert_data(result.data(), 
result.size());
     }
+
+private:
+    QueryContext* _ctx = nullptr;
 };
 
 } // namespace doris
diff --git a/be/test/ai/aggregate_function_ai_agg_test.cpp 
b/be/test/ai/aggregate_function_ai_agg_test.cpp
index a5ebbd8fb79..4fb9c7969d0 100644
--- a/be/test/ai/aggregate_function_ai_agg_test.cpp
+++ b/be/test/ai/aggregate_function_ai_agg_test.cpp
@@ -66,7 +66,7 @@ public:
         _agg_function->set_query_context(_query_ctx.get());
     }
 
-    void TearDown() override { AggregateFunctionAIAggData::_ctx = nullptr; }
+    void TearDown() override {}
 
 protected:
     std::unique_ptr<MockRuntimeState> _runtime_state;
@@ -424,6 +424,56 @@ TEST_F(AggregateFunctionAIAggTest, 
ai_context_window_size_session_variable_test)
     _agg_function->destroy(place);
 }
 
+TEST_F(AggregateFunctionAIAggTest, 
query_context_is_isolated_between_function_instances_test) {
+    TQueryOptions first_query_options = create_fake_query_options();
+    first_query_options.__set_ai_context_window_size(8);
+    auto first_query_ctx =
+            MockQueryContext::create(TUniqueId(), ExecEnv::GetInstance(), 
first_query_options);
+    first_query_ctx->set_mock_ai_resource();
+
+    TQueryOptions second_query_options = create_fake_query_options();
+    second_query_options.__set_ai_context_window_size(1024);
+    auto second_query_ctx =
+            MockQueryContext::create(TUniqueId(), ExecEnv::GetInstance(), 
second_query_options);
+    second_query_ctx->set_mock_ai_resource();
+
+    AggregateFunctionSimpleFactory factory;
+    register_aggregate_function_ai_agg(factory);
+    auto first_agg_function = factory.get("ai_agg", _data_types, nullptr, 
false, -1);
+    auto second_agg_function = factory.get("ai_agg", _data_types, nullptr, 
false, -1);
+    ASSERT_TRUE(first_agg_function != nullptr);
+    ASSERT_TRUE(second_agg_function != nullptr);
+
+    first_agg_function->set_query_context(first_query_ctx.get());
+    second_agg_function->set_query_context(second_query_ctx.get());
+
+    auto resource_col = ColumnString::create();
+    auto text_col = ColumnString::create();
+    auto task_col = ColumnString::create();
+
+    resource_col->insert_data("mock_resource", 13);
+    text_col->insert_data("abcd", 4);
+    task_col->insert_data("summarize", 9);
+
+    resource_col->insert_data("mock_resource", 13);
+    text_col->insert_data("efgh", 4);
+    task_col->insert_data("summarize", 9);
+
+    std::unique_ptr<char[]> memory(new 
char[first_agg_function->size_of_data()]);
+    AggregateDataPtr place = memory.get();
+    first_agg_function->create(place);
+
+    const IColumn* columns[3] = {resource_col.get(), text_col.get(), 
task_col.get()};
+    first_agg_function->add(place, columns, 0, _arena);
+    first_agg_function->add(place, columns, 1, _arena);
+
+    const auto& data = *reinterpret_cast<const 
AggregateFunctionAIAggData*>(place);
+    std::string actual(reinterpret_cast<const char*>(data.data.data()), 
data.data.size());
+    EXPECT_EQ(actual, "this is a mock response\nefgh");
+
+    first_agg_function->destroy(place);
+}
+
 TEST_F(AggregateFunctionAIAggTest, 
gemini_endpoint_normalize_to_generate_content_test) {
     AIResource resource;
     resource.provider_type = "GEMINI";


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to