https://github.com/tblah created https://github.com/llvm/llvm-project/pull/144898
Idea suggested by @skatrak >From 280e55d4355f100b7d3066fce3c0515b369fecce Mon Sep 17 00:00:00 2001 From: Tom Eccles <tom.ecc...@arm.com> Date: Wed, 18 Jun 2025 21:01:13 +0000 Subject: [PATCH] [flang][OpenMP][NFC] remove globals with mlir::StateStack Idea suggested by @skatrak --- flang/include/flang/Lower/AbstractConverter.h | 3 + flang/lib/Lower/Bridge.cpp | 6 ++ flang/lib/Lower/OpenMP/OpenMP.cpp | 102 ++++++++++++------ mlir/include/mlir/Support/StateStack.h | 11 ++ 4 files changed, 91 insertions(+), 31 deletions(-) diff --git a/flang/include/flang/Lower/AbstractConverter.h b/flang/include/flang/Lower/AbstractConverter.h index 8ae68e143cd2f..de3e833f60699 100644 --- a/flang/include/flang/Lower/AbstractConverter.h +++ b/flang/include/flang/Lower/AbstractConverter.h @@ -26,6 +26,7 @@ namespace mlir { class SymbolTable; +class StateStack; } namespace fir { @@ -361,6 +362,8 @@ class AbstractConverter { /// functions in order to be in sync). virtual mlir::SymbolTable *getMLIRSymbolTable() = 0; + virtual mlir::StateStack &getStateStack() = 0; + private: /// Options controlling lowering behavior. const Fortran::lower::LoweringOptions &loweringOptions; diff --git a/flang/lib/Lower/Bridge.cpp b/flang/lib/Lower/Bridge.cpp index 64b16b3abe991..462ceb8dff736 100644 --- a/flang/lib/Lower/Bridge.cpp +++ b/flang/lib/Lower/Bridge.cpp @@ -78,6 +78,7 @@ #include "llvm/Support/FileSystem.h" #include "llvm/Support/Path.h" #include "llvm/Target/TargetMachine.h" +#include "mlir/Support/StateStack.h" #include <optional> #define DEBUG_TYPE "flang-lower-bridge" @@ -1237,6 +1238,8 @@ class FirConverter : public Fortran::lower::AbstractConverter { mlir::SymbolTable *getMLIRSymbolTable() override { return &mlirSymbolTable; } + mlir::StateStack &getStateStack() override { return stateStack; } + /// Add the symbol to the local map and return `true`. If the symbol is /// already in the map and \p forced is `false`, the map is not updated. /// Instead the value `false` is returned. 
@@ -6552,6 +6555,9 @@ class FirConverter : public Fortran::lower::AbstractConverter { /// attribute since mlirSymbolTable must pro-actively be maintained when /// new Symbol operations are created. mlir::SymbolTable mlirSymbolTable; + + /// Used to store context while recursing into regions during lowering. + mlir::StateStack stateStack; }; } // namespace diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp index 7ad8869597274..bff3321af2814 100644 --- a/flang/lib/Lower/OpenMP/OpenMP.cpp +++ b/flang/lib/Lower/OpenMP/OpenMP.cpp @@ -38,6 +38,7 @@ #include "flang/Support/OpenMP-utils.h" #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" #include "mlir/Dialect/OpenMP/OpenMPDialect.h" +#include "mlir/Support/StateStack.h" #include "mlir/Transforms/RegionUtils.h" #include "llvm/ADT/STLExtras.h" #include "llvm/Frontend/OpenMP/OMPConstants.h" @@ -200,9 +201,41 @@ class HostEvalInfo { /// the handling of the outer region by keeping a stack of information /// structures, but it will probably still require some further work to support /// reverse offloading. -static llvm::SmallVector<HostEvalInfo, 0> hostEvalInfo; -static llvm::SmallVector<const parser::OpenMPSectionsConstruct *, 0> - sectionsStack; +class HostEvalInfoStackFrame + : public mlir::StateStackFrameBase<HostEvalInfoStackFrame> { +public: + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(HostEvalInfoStackFrame) + + HostEvalInfo info; +}; + +static HostEvalInfo * +getHostEvalInfoStackTop(lower::AbstractConverter &converter) { + HostEvalInfoStackFrame *frame = + converter.getStateStack().getStackTop<HostEvalInfoStackFrame>(); + return frame ? &frame->info : nullptr; +} + +/// Stack frame for storing the OpenMPSectionsConstruct currently being +/// processed so that it can be referred to when lowering the construct. 
+class SectionsConstructStackFrame + : public mlir::StateStackFrameBase<SectionsConstructStackFrame> { +public: + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(SectionsConstructStackFrame) + + explicit SectionsConstructStackFrame( + const parser::OpenMPSectionsConstruct &sectionsConstruct) + : sectionsConstruct{sectionsConstruct} {} + + const parser::OpenMPSectionsConstruct &sectionsConstruct; +}; + +static const parser::OpenMPSectionsConstruct * +getSectionsConstructStackTop(lower::AbstractConverter &converter) { + SectionsConstructStackFrame *frame = + converter.getStateStack().getStackTop<SectionsConstructStackFrame>(); + return frame ? &frame->sectionsConstruct : nullptr; +} /// Bind symbols to their corresponding entry block arguments. /// @@ -537,31 +570,32 @@ static void processHostEvalClauses(lower::AbstractConverter &converter, if (!ompEval) return; - HostEvalInfo &hostInfo = hostEvalInfo.back(); + HostEvalInfo *hostInfo = getHostEvalInfoStackTop(converter); + assert(hostInfo && "expected HOST_EVAL info structure"); switch (extractOmpDirective(*ompEval)) { case OMPD_teams_distribute_parallel_do: case OMPD_teams_distribute_parallel_do_simd: - cp.processThreadLimit(stmtCtx, hostInfo.ops); + cp.processThreadLimit(stmtCtx, hostInfo->ops); [[fallthrough]]; case OMPD_target_teams_distribute_parallel_do: case OMPD_target_teams_distribute_parallel_do_simd: - cp.processNumTeams(stmtCtx, hostInfo.ops); + cp.processNumTeams(stmtCtx, hostInfo->ops); [[fallthrough]]; case OMPD_distribute_parallel_do: case OMPD_distribute_parallel_do_simd: - cp.processNumThreads(stmtCtx, hostInfo.ops); + cp.processNumThreads(stmtCtx, hostInfo->ops); [[fallthrough]]; case OMPD_distribute: case OMPD_distribute_simd: - cp.processCollapse(loc, eval, hostInfo.ops, hostInfo.iv); + cp.processCollapse(loc, eval, hostInfo->ops, hostInfo->iv); break; case OMPD_teams: - cp.processThreadLimit(stmtCtx, hostInfo.ops); + cp.processThreadLimit(stmtCtx, hostInfo->ops); [[fallthrough]]; case OMPD_target_teams: - 
cp.processNumTeams(stmtCtx, hostInfo.ops); + cp.processNumTeams(stmtCtx, hostInfo->ops); processSingleNestedIf([](Directive nestedDir) { return topDistributeSet.test(nestedDir) || topLoopSet.test(nestedDir); }); @@ -569,22 +603,22 @@ static void processHostEvalClauses(lower::AbstractConverter &converter, case OMPD_teams_distribute: case OMPD_teams_distribute_simd: - cp.processThreadLimit(stmtCtx, hostInfo.ops); + cp.processThreadLimit(stmtCtx, hostInfo->ops); [[fallthrough]]; case OMPD_target_teams_distribute: case OMPD_target_teams_distribute_simd: - cp.processCollapse(loc, eval, hostInfo.ops, hostInfo.iv); - cp.processNumTeams(stmtCtx, hostInfo.ops); + cp.processCollapse(loc, eval, hostInfo->ops, hostInfo->iv); + cp.processNumTeams(stmtCtx, hostInfo->ops); break; case OMPD_teams_loop: - cp.processThreadLimit(stmtCtx, hostInfo.ops); + cp.processThreadLimit(stmtCtx, hostInfo->ops); [[fallthrough]]; case OMPD_target_teams_loop: - cp.processNumTeams(stmtCtx, hostInfo.ops); + cp.processNumTeams(stmtCtx, hostInfo->ops); [[fallthrough]]; case OMPD_loop: - cp.processCollapse(loc, eval, hostInfo.ops, hostInfo.iv); + cp.processCollapse(loc, eval, hostInfo->ops, hostInfo->iv); break; // Standalone 'target' case. 
@@ -598,8 +632,6 @@ static void processHostEvalClauses(lower::AbstractConverter &converter, } }; - assert(!hostEvalInfo.empty() && "expected HOST_EVAL info structure"); - const auto *ompEval = eval.getIf<parser::OpenMPConstruct>(); assert(ompEval && llvm::omp::allTargetSet.test(extractOmpDirective(*ompEval)) && @@ -1468,8 +1500,8 @@ static void genBodyOfTargetOp( mlir::Region &region = targetOp.getRegion(); mlir::Block *entryBlock = genEntryBlock(firOpBuilder, args, region); bindEntryBlockArgs(converter, targetOp, args); - if (!hostEvalInfo.empty()) - hostEvalInfo.back().bindOperands(argIface.getHostEvalBlockArgs()); + if (HostEvalInfo *hostEvalInfo = getHostEvalInfoStackTop(converter)) + hostEvalInfo->bindOperands(argIface.getHostEvalBlockArgs()); // Check if cloning the bounds introduced any dependency on the outer region. // If so, then either clone them as well if they are MemoryEffectFree, or else @@ -1708,7 +1740,8 @@ genLoopNestClauses(lower::AbstractConverter &converter, llvm::SmallVectorImpl<const semantics::Symbol *> &iv) { ClauseProcessor cp(converter, semaCtx, clauses); - if (hostEvalInfo.empty() || !hostEvalInfo.back().apply(clauseOps, iv)) + HostEvalInfo *hostEvalInfo = getHostEvalInfoStackTop(converter); + if (!hostEvalInfo || !hostEvalInfo->apply(clauseOps, iv)) cp.processCollapse(loc, eval, clauseOps, iv); clauseOps.loopInclusive = converter.getFirOpBuilder().getUnitAttr(); @@ -1753,7 +1786,8 @@ static void genParallelClauses( cp.processAllocate(clauseOps); cp.processIf(llvm::omp::Directive::OMPD_parallel, clauseOps); - if (hostEvalInfo.empty() || !hostEvalInfo.back().apply(clauseOps)) + HostEvalInfo *hostEvalInfo = getHostEvalInfoStackTop(converter); + if (!hostEvalInfo || !hostEvalInfo->apply(clauseOps)) cp.processNumThreads(stmtCtx, clauseOps); cp.processProcBind(clauseOps); @@ -1818,16 +1852,17 @@ static void genTargetClauses( llvm::SmallVectorImpl<const semantics::Symbol *> &hasDeviceAddrSyms, llvm::SmallVectorImpl<const semantics::Symbol *> 
&isDevicePtrSyms, llvm::SmallVectorImpl<const semantics::Symbol *> &mapSyms) { + HostEvalInfo *hostEvalInfo = getHostEvalInfoStackTop(converter); ClauseProcessor cp(converter, semaCtx, clauses); cp.processBare(clauseOps); cp.processDefaultMap(stmtCtx, defaultMaps); cp.processDepend(symTable, stmtCtx, clauseOps); cp.processDevice(stmtCtx, clauseOps); cp.processHasDeviceAddr(stmtCtx, clauseOps, hasDeviceAddrSyms); - if (!hostEvalInfo.empty()) { + if (hostEvalInfo) { // Only process host_eval if compiling for the host device. processHostEvalClauses(converter, semaCtx, stmtCtx, eval, loc); - hostEvalInfo.back().collectValues(clauseOps.hostEvalVars); + hostEvalInfo->collectValues(clauseOps.hostEvalVars); } cp.processIf(llvm::omp::Directive::OMPD_target, clauseOps); cp.processIsDevicePtr(clauseOps, isDevicePtrSyms); @@ -1963,7 +1998,8 @@ static void genTeamsClauses( cp.processAllocate(clauseOps); cp.processIf(llvm::omp::Directive::OMPD_teams, clauseOps); - if (hostEvalInfo.empty() || !hostEvalInfo.back().apply(clauseOps)) { + HostEvalInfo *hostEvalInfo = getHostEvalInfoStackTop(converter); + if (!hostEvalInfo || !hostEvalInfo->apply(clauseOps)) { cp.processNumTeams(stmtCtx, clauseOps); cp.processThreadLimit(stmtCtx, clauseOps); } @@ -2224,10 +2260,13 @@ genSectionsOp(lower::AbstractConverter &converter, lower::SymMap &symTable, lower::pft::Evaluation &eval, mlir::Location loc, const ConstructQueue &queue, ConstructQueue::const_iterator item) { - assert(!sectionsStack.empty()); + const parser::OpenMPSectionsConstruct *sectionsConstruct = + getSectionsConstructStackTop(converter); + assert(sectionsConstruct); + const auto &sectionBlocks = - std::get<parser::OmpSectionBlocks>(sectionsStack.back()->t); - sectionsStack.pop_back(); + std::get<parser::OmpSectionBlocks>(sectionsConstruct->t); + converter.getStateStack().stackPop(); mlir::omp::SectionsOperands clauseOps; llvm::SmallVector<const semantics::Symbol *> reductionSyms; genSectionsClauses(converter, semaCtx, item->clauses, 
loc, clauseOps, @@ -2381,7 +2420,7 @@ genTargetOp(lower::AbstractConverter &converter, lower::SymMap &symTable, // Introduce a new host_eval information structure for this target region. if (!isTargetDevice) - hostEvalInfo.emplace_back(); + converter.getStateStack().stackPush<HostEvalInfoStackFrame>(); mlir::omp::TargetOperands clauseOps; DefaultMapsTy defaultMaps; @@ -2508,7 +2547,7 @@ genTargetOp(lower::AbstractConverter &converter, lower::SymMap &symTable, // Remove the host_eval information structure created for this target region. if (!isTargetDevice) - hostEvalInfo.pop_back(); + converter.getStateStack().stackPop(); return targetOp; } @@ -4235,7 +4274,8 @@ static void genOMP(lower::AbstractConverter &converter, lower::SymMap &symTable, buildConstructQueue(converter.getFirOpBuilder().getModule(), semaCtx, eval, source, directive, clauses)}; - sectionsStack.push_back(&sectionsConstruct); + converter.getStateStack().stackPush<SectionsConstructStackFrame>( + sectionsConstruct); genOMPDispatch(converter, symTable, semaCtx, eval, currentLocation, queue, queue.begin()); } diff --git a/mlir/include/mlir/Support/StateStack.h b/mlir/include/mlir/Support/StateStack.h index aca2375028246..9641a22c47776 100644 --- a/mlir/include/mlir/Support/StateStack.h +++ b/mlir/include/mlir/Support/StateStack.h @@ -83,6 +83,17 @@ class StateStack { return WalkResult::advance(); } + /// Get the top instance of frame type `T` or nullptr if none are found + template <typename T> + T *getStackTop() { + T *top = nullptr; + stackWalk<T>([&](T &frame) -> mlir::WalkResult { + top = &frame; + return mlir::WalkResult::interrupt(); + }); + return top; + } + private: SmallVector<std::unique_ptr<StateStackFrame>> stack; }; _______________________________________________ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits