This is an automated email from the ASF dual-hosted git repository.

huajianlan pushed a commit to branch fe_local_shuffle
in repository https://gitbox.apache.org/repos/asf/doris.git

commit 261947f9119085398b3f10fed0c9a9ebd8340fc5
Author: 924060929 <[email protected]>
AuthorDate: Sun Mar 29 10:46:34 2026 +0800

    [fix](local shuffle) restore num_tasks raise + add RQG regression cases + 
optimize test speed
    
    1. BE: After building all pipelines, raise num_tasks to _num_instances for 
any
       pipeline reduced below _num_instances by a serial non-scan operator 
(e.g.,
       UNPARTITIONED Exchange).  Use dynamic_cast<ExchangeSourceOperatorX*> to
       distinguish serial Exchange (must be raised) from serial scan (pooling 
scan,
       keep num_tasks=1).  This fixes "must set shared state" errors for
       AGGREGATION_OPERATOR, SORT_OPERATOR, UNION_OPERATOR, INTERSECT_OPERATOR,
       and EXCEPT_OPERATOR in RQG tests.
    
    2. Test: Add 3 new RQG-inspired regression cases covering pooling scan with
       NLJ+Agg, GROUPING SETS, and window+GROUPING SETS.
    
    3. Test speed: Restructure from serial (execute+wait per case) to two-phase
       (batch-execute all queries, then batch-fetch profiles).
    
    Co-Authored-By: Claude Opus 4.6 (1M context) <[email protected]>
---
 be/src/exec/pipeline/pipeline_fragment_context.cpp |  53 +++---
 .../test_local_shuffle_fe_be_consistency.groovy    | 197 ++++++++++++++-------
 2 files changed, 159 insertions(+), 91 deletions(-)

diff --git a/be/src/exec/pipeline/pipeline_fragment_context.cpp 
b/be/src/exec/pipeline/pipeline_fragment_context.cpp
index 148422b61ef..e59e255f688 100644
--- a/be/src/exec/pipeline/pipeline_fragment_context.cpp
+++ b/be/src/exec/pipeline/pipeline_fragment_context.cpp
@@ -264,6 +264,34 @@ Status 
PipelineFragmentContext::_build_and_prepare_full_pipeline(ThreadPool* thr
         // Create deferred local exchangers now that all pipelines have final 
num_tasks.
         RETURN_IF_ERROR(_create_deferred_local_exchangers());
 
+        // Raise num_tasks for pipelines whose serial non-scan operators (e.g.,
+        // UNPARTITIONED Exchange) reduced num_tasks below _num_instances.
+        // Without this, fragment instances 1+ have no task for these pipelines
+        // and downstream operators fail with "must set shared state".
+        //
+        // This applies to ALL pipelines (not just deferred exchanger 
upstreams):
+        // fragments with UNION/INTERSECT/EXCEPT + serial Exchange in child
+        // pipelines also need the raise, even without FE-planned local 
exchange.
+        //
+        // Exception: serial scan sources (pooling scan) keep num_tasks=1 — the
+        // PassthroughExchanger(1, N) handles the fan-out correctly.
+        for (auto& pipeline : _pipelines) {
+            if (pipeline->num_tasks() < _num_instances) {
+                // Only skip the raise for pipelines whose source is a serial
+                // SCAN operator (pooling scan).  Serial Exchange sources must
+                // still be raised — they are NOT scans but receive remote 
data,
+                // and downstream operators need _num_instances tasks for 
shared
+                // state injection.
+                auto* source_op = pipeline->operators().front().get();
+                bool is_serial_scan_source =
+                        source_op->is_serial_operator() && 
source_op->is_source() &&
+                        !dynamic_cast<ExchangeSourceOperatorX*>(source_op);
+                if (!is_serial_scan_source) {
+                    pipeline->set_num_tasks(_num_instances);
+                }
+            }
+        }
+
         // 3. Create sink operator
         if (!_params.fragment.__isset.output_sink) {
             return Status::InternalError("No output sink in this fragment!");
@@ -668,29 +696,8 @@ Status 
PipelineFragmentContext::_build_pipelines(ObjectPool* pool, const Descrip
 
 Status PipelineFragmentContext::_create_deferred_local_exchangers() {
     for (auto& info : _deferred_exchangers) {
-        // Raise upstream pipeline's num_tasks to _num_instances when a serial
-        // non-scan operator (e.g., UNPARTITIONED Exchange) reduced it below
-        // _num_instances.  This is needed because downstream pipelines
-        // (join build, agg, sort, union, etc.) need _num_instances tasks so
-        // every fragment instance can create/inject shared state.  Without the
-        // raise, instances 1+ have no upstream task and operators fail with
-        // "must set shared state".
-        //
-        // Exception: do NOT raise when the serial operator is a scan source
-        // (pooling scan).  For pooling scan, 1 sender is correct —
-        // PassthroughExchanger(1, N) handles the 1→N fan-out properly.
-        if (info.upstream_pipe->num_tasks() < _num_instances) {
-            bool has_serial_scan = false;
-            for (auto& op : info.upstream_pipe->operators()) {
-                if (op->is_serial_operator() && op->is_source()) {
-                    has_serial_scan = true;
-                    break;
-                }
-            }
-            if (!has_serial_scan) {
-                info.upstream_pipe->set_num_tasks(_num_instances);
-            }
-        }
+        // num_tasks raise is handled globally in 
_build_and_prepare_full_pipeline
+        // after this function returns.  No per-exchanger adjustment needed 
here.
         const int sender_count = info.upstream_pipe->num_tasks();
         switch (info.partition_type) {
         case TLocalPartitionType::LOCAL_EXECUTION_HASH_SHUFFLE:
diff --git 
a/regression-test/suites/nereids_p0/local_shuffle/test_local_shuffle_fe_be_consistency.groovy
 
b/regression-test/suites/nereids_p0/local_shuffle/test_local_shuffle_fe_be_consistency.groovy
index 7fa918652bd..2ca0d09f880 100644
--- 
a/regression-test/suites/nereids_p0/local_shuffle/test_local_shuffle_fe_be_consistency.groovy
+++ 
b/regression-test/suites/nereids_p0/local_shuffle/test_local_shuffle_fe_be_consistency.groovy
@@ -84,12 +84,11 @@ suite("test_local_shuffle_fe_be_consistency") {
     // This field is written by SummaryProfile.queryFinished() after 
waitForFragmentsDone(),
     // guaranteeing all BE operator metrics have been merged into the profile.
     def waitForProfile = { String queryId ->
-        // Wait for BE to report profile back to FE before polling.
-        // Without this initial delay, early polls may get incomplete profiles
-        // (missing LOCAL_EXCHANGE_SINK entries), causing flaky MISMATCH 
results.
-        Thread.sleep(2000)
+        // Poll until the profile is fully collected.  The "Is Profile 
Collection
+        // Completed: true" marker is written after all BE fragments have 
reported,
+        // so once it appears the LOCAL_EXCHANGE_SINK counts are final.
         def maxAttempts = 60
-        def sleepMs = 500
+        def sleepMs = 200
         for (int i = 0; i < maxAttempts; i++) {
             Thread.sleep(sleepMs)
             try {
@@ -102,90 +101,109 @@ suite("test_local_shuffle_fe_be_consistency") {
         return getProfile(queryId)
     }
 
-    def runAndGetSinkCounts = { String testSql, boolean enableFePlanner ->
+    // Execute SQL and return query_id (no profile wait — fast)
+    def runAndGetQueryId = { String testSql, boolean enableFePlanner ->
         sql "set enable_profile=true"
         sql "set enable_local_shuffle_planner=${enableFePlanner}"
         sql "set enable_sql_cache=false"
-
-        // Use GetQueryIdAction to reliably get the query_id of the test SQL,
-        // avoiding timing issues with last_query_id() after SET statements.
-        def result = sql "${testSql}"
-
+        sql "${testSql}"
         def queryIdResult = sql "select last_query_id()"
-        def queryId = queryIdResult[0][0].toString()
-
-        // Wait a moment for profile to be reported back from BE
-        Thread.sleep(1000)
-        def profileText = waitForProfile(queryId)
-        def counts = extractSinkExchangeCounts(profileText)
-        logger.info("enable_local_shuffle_planner=${enableFePlanner}, 
query_id=${queryId}, LE sink counts=${counts}")
-        return [queryId: queryId, counts: counts, profile: profileText]
+        return queryIdResult[0][0].toString()
     }
 
     // ============================================================
-    //  Helper: check FE vs BE consistency and result equivalence
-    //  FE mode:  enable_local_shuffle_planner=true  (FE plans exchanges via 
AddLocalExchange)
-    //  BE mode:  enable_local_shuffle_planner=false (BE plans exchanges 
natively in pipeline)
-    //  knownDiff: if true, log mismatch as INFO (expected design difference, 
not a bug)
+    //  Two-phase approach for speed:
+    //  Phase 1: Execute all queries (BE + FE) sequentially, collect query IDs
+    //  Phase 2: Fetch all profiles in parallel (profiles collected in 
background)
+    //  This overlaps profile collection with query execution, avoiding serial 
waits.
     // ============================================================
-    def mismatches = []
+    def mismatches = Collections.synchronizedList([])
+    def pendingChecks = Collections.synchronizedList([])
 
     def setVarBase = 
"disable_join_reorder=true,disable_colocate_plan=true,ignore_storage_data_distribution=false,parallel_pipeline_task_num=4,auto_broadcast_join_threshold=-1,broadcast_row_count_limit=0"
 
     def checkConsistencyWithSql = { String tag, String testSql, boolean 
knownDiff = false ->
-        logger.info("=== Checking: ${tag} ===")
+        // Phase 1: execute queries and collect query IDs (fast, no profile 
wait)
+        String beQueryId = null
+        String feQueryId = null
+        boolean beFailed = false
+        boolean feFailed = false
+        String beError = null
+        String feError = null
 
-        // Run with BE-native planning (enable_local_shuffle_planner=false)
-        def beResult
         try {
-            beResult = runAndGetSinkCounts(testSql, false)
+            beQueryId = runAndGetQueryId(testSql, false)
         } catch (Throwable t) {
-            def errMsg = "[${tag}] BE run FAILED: ${t.message}"
-            logger.warn(errMsg)
-            mismatches << errMsg
-            return [be: [:], fe: [:], match: false]
+            beFailed = true
+            beError = t.message
         }
-        // Run with FE-planned exchanges (enable_local_shuffle_planner=true)
-        def feResult
         try {
-            feResult = runAndGetSinkCounts(testSql, true)
+            feQueryId = runAndGetQueryId(testSql, true)
         } catch (Throwable t) {
-            def errMsg = "[${tag}] FE run FAILED: ${t.message}"
-            logger.warn(errMsg)
-            mismatches << errMsg
-            return [be: beResult.counts, fe: [:], match: false]
-        }
-
-        boolean match = (beResult.counts == feResult.counts)
-        if (match) {
-            logger.info("[${tag}] MATCH: ${beResult.counts}")
-        } else {
-            def msg = "[${tag}] ${knownDiff ? 'KNOWN-DIFF' : 'MISMATCH'}: 
BE=${beResult.counts}, FE=${feResult.counts}"
-            if (knownDiff) {
-                logger.info(msg)
-            } else {
-                logger.warn(msg)
-                mismatches << msg
-            }
+            feFailed = true
+            feError = t.message
         }
 
-        // Verify result correctness: both modes must return identical rows
+        // Build SET_VAR versions for check_sql_equal
         def sqlOn  = 
testSql.replaceFirst(/(?i)\/\*\+SET_VAR\(([^)]*)\)\s*\*\//, 
"/*+SET_VAR(enable_local_shuffle_planner=true,\$1)*/")
         def sqlOff = 
testSql.replaceFirst(/(?i)\/\*\+SET_VAR\(([^)]*)\)\s*\*\//, 
"/*+SET_VAR(enable_local_shuffle_planner=false,\$1)*/")
         if (!testSql.contains("/*+SET_VAR")) {
-            // inject SET_VAR after first SELECT keyword
             sqlOn  = testSql.replaceFirst(/(?i)^\s*(SELECT)\s+/, "SELECT 
/*+SET_VAR(enable_local_shuffle_planner=true,${setVarBase})*/ ")
             sqlOff = testSql.replaceFirst(/(?i)^\s*(SELECT)\s+/, "SELECT 
/*+SET_VAR(enable_local_shuffle_planner=false,${setVarBase})*/ ")
         }
-        try {
-            check_sql_equal(sqlOn, sqlOff)
-        } catch (Throwable t) {
-            def errMsg = "[${tag}] check_sql_equal FAILED: ${t.message}"
-            logger.warn(errMsg)
-            mismatches << errMsg
-        }
 
-        return [be: beResult.counts, fe: feResult.counts, match: match]
+        // Save for phase 2 (deferred profile fetch + comparison)
+        pendingChecks << [tag: tag, beQueryId: beQueryId, feQueryId: feQueryId,
+                          beFailed: beFailed, feFailed: feFailed,
+                          beError: beError, feError: feError,
+                          knownDiff: knownDiff, sqlOn: sqlOn, sqlOff: sqlOff]
+    }
+
+    // Phase 2: fetch profiles and compare (called after all queries are 
submitted)
+    def resolveAllChecks = {
+        for (def check : pendingChecks) {
+            def tag = check.tag
+            if (check.beFailed) {
+                def errMsg = "[${tag}] BE run FAILED: ${check.beError}"
+                logger.warn(errMsg)
+                if (!check.knownDiff) { mismatches << errMsg }
+                continue
+            }
+            if (check.feFailed) {
+                def errMsg = "[${tag}] FE run FAILED: ${check.feError}"
+                logger.warn(errMsg)
+                if (!check.knownDiff) { mismatches << errMsg }
+                continue
+            }
+
+            // Fetch profiles (likely already collected by now)
+            def beProfile = waitForProfile(check.beQueryId)
+            def feProfile = waitForProfile(check.feQueryId)
+            def beCounts = extractSinkExchangeCounts(beProfile)
+            def feCounts = extractSinkExchangeCounts(feProfile)
+
+            boolean match = (beCounts == feCounts)
+            if (match) {
+                logger.info("[${tag}] MATCH: ${beCounts}")
+            } else {
+                def msg = "[${tag}] ${check.knownDiff ? 'KNOWN-DIFF' : 
'MISMATCH'}: BE=${beCounts}, FE=${feCounts}"
+                if (check.knownDiff) {
+                    logger.info(msg)
+                } else {
+                    logger.warn(msg)
+                    mismatches << msg
+                }
+            }
+
+            // Verify result correctness
+            try {
+                check_sql_equal(check.sqlOn, check.sqlOff)
+            } catch (Throwable t) {
+                def errMsg = "[${tag}] check_sql_equal FAILED: ${t.message}"
+                logger.warn(errMsg)
+                mismatches << errMsg
+            }
+        }
     }
 
     // ============================================================
@@ -281,12 +299,6 @@ suite("test_local_shuffle_fe_be_consistency") {
             (1, 10, 2), (2, 20, 4), (3, 30, 5), (4, 40, 6)
     """
 
-    // Wait for table creation and data loading to fully settle (tablet 
reports,
-    // replica sync, etc.) before running profile-based comparisons.  Without
-    // this, early queries may hit incomplete tablets or stale metadata, 
causing
-    // profile collection to return empty/partial results (flaky MISMATCH).
-    Thread.sleep(10000)
-
     // SET_VAR prefix used in most test SQLs (disables plan reorder/colocate 
for deterministic plans)
     def sv = 
"/*+SET_VAR(disable_join_reorder=true,disable_colocate_plan=true,ignore_storage_data_distribution=false,parallel_pipeline_task_num=4,auto_broadcast_join_threshold=-1,broadcast_row_count_limit=0)*/"
     // Same as sv but forces serial source path (default in many environments)
@@ -808,6 +820,55 @@ suite("test_local_shuffle_fe_be_consistency") {
            GROUP BY a.v1
            ORDER BY cnt, a.v1""")
 
+    // ================================================================
+    // Section 13: Pooling scan + operators requiring shared state
+    // Regression cases from RQG build 183677 — serial Exchange on build
+    // side of various operators (Agg, Sort, Union/Repeat) reduced pipeline
+    // num_tasks, causing "must set shared state" errors.
+    // Fixed by restoring the num_tasks raise globally in
+    // _build_and_prepare_full_pipeline (after
+    // _create_deferred_local_exchangers) for non-scan serial operators.
+    // ================================================================
+
+    // 13-1: NLJ + AGG with pooling scan.
+    //       NLJ creates pipeline boundary; serial Exchange on build side
+    //       needs raise to _num_instances for AGG shared state injection.
+    //       knownDiff=true: pooling scan + NLJ has FE/BE exchange count
+    //       differences (same root cause as nested_nlj_pooling_scan).
+    checkConsistencyWithSql("agg_after_nlj_pooling_scan",
+        """SELECT ${svSerialSource} a.v1, MAX(a.k1) AS mx
+           FROM ls_serial a LEFT JOIN ls_serial b ON b.k2 < b.k2
+           WHERE a.k1 IS NOT NULL
+           GROUP BY a.v1
+           ORDER BY a.v1, mx""", true)
+
+    // 13-2: GROUPING SETS with pooling scan — generates REPEAT (union-like)
+    //       operator internally.  Serial Exchange reduces num_tasks, causing
+    //       "must set shared state, in UNION_OPERATOR / SORT_OPERATOR".
+    checkConsistencyWithSql("grouping_sets_pooling_scan",
+        """SELECT ${svSerialSource} k1, k2, SUM(v1) AS sv
+           FROM ls_serial
+           GROUP BY GROUPING SETS ((k1, k2), (k1), ())
+           ORDER BY k1, k2, sv""")
+
+    // 13-3: Window function + GROUPING SETS with pooling scan.
+    //       Combines analytic (Sort shared state) and GROUPING SETS 
(Repeat/Union)
+    //       — both need correct num_tasks for shared state injection.
+    checkConsistencyWithSql("window_grouping_sets_pooling_scan",
+        """SELECT ${svSerialSource} k1, SUM(v1),
+                  ROW_NUMBER() OVER (ORDER BY k1) AS rn
+           FROM ls_serial
+           GROUP BY GROUPING SETS ((k1), ())
+           ORDER BY k1, rn""")
+
+    // ================================================================
+    //  Phase 2: Fetch all profiles and compare results
+    //  By now, all queries have been executed and profiles are being
+    //  collected in the background.  Fetching them here overlaps the
+    //  collection time with query execution, significantly reducing
+    //  total test duration.
+    // ================================================================
+    resolveAllChecks()
+
     // ================================================================
     //  Summary
     // ================================================================


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to