https://github.com/skatrak updated https://github.com/llvm/llvm-project/pull/150925

>From 5e10fa89545160f8574e2f5d04fde3f26470f987 Mon Sep 17 00:00:00 2001
From: Sergio Afonso <[email protected]>
Date: Thu, 3 Jul 2025 16:47:51 +0100
Subject: [PATCH 1/6] [OpenMP][OMPIRBuilder] Use device shared memory for arg
 structures

Argument structures are created when sections of the LLVM IR corresponding to
an OpenMP construct are outlined into their own function. These structures are
currently allocated on the stack.

This patch changes that behavior when compiling for a target device and
outlining `parallel`-related IR, so that device shared memory is used instead
of private stack space. This is needed so that the threads executing the
outlined `parallel` region can access these arguments, which would otherwise
live in the private stack of the thread that created them.
---
 .../llvm/Frontend/OpenMP/OMPIRBuilder.h       |  6 ++
 .../llvm/Transforms/Utils/CodeExtractor.h     | 34 ++++++-
 llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp     | 98 +++++++++++++++++--
 llvm/lib/Transforms/Utils/CodeExtractor.cpp   | 73 ++++++++++----
 .../LLVMIR/omptarget-parallel-llvm.mlir       | 10 +-
 5 files changed, 187 insertions(+), 34 deletions(-)

diff --git a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h 
b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
index 1ee8498551030..09a1fe7cebf4b 100644
--- a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
+++ b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
@@ -2440,7 +2440,13 @@ class OpenMPIRBuilder {
   /// during finalization.
   struct OutlineInfo {
     using PostOutlineCBTy = std::function<void(Function &)>;
+    using CustomArgAllocatorCBTy = std::function<Instruction *(
+        BasicBlock *, BasicBlock::iterator, Type *, const Twine &)>;
+    using CustomArgDeallocatorCBTy = std::function<Instruction *(
+        BasicBlock *, BasicBlock::iterator, Value *, Type *)>;
     PostOutlineCBTy PostOutlineCB;
+    CustomArgAllocatorCBTy CustomArgAllocatorCB;
+    CustomArgDeallocatorCBTy CustomArgDeallocatorCB;
     BasicBlock *EntryBB, *ExitBB, *OuterAllocaBB;
     SmallVector<Value *, 2> ExcludeArgsFromAggregate;
     SetVector<Value *> Inputs, Outputs;
diff --git a/llvm/include/llvm/Transforms/Utils/CodeExtractor.h 
b/llvm/include/llvm/Transforms/Utils/CodeExtractor.h
index 3e2c69b47bc48..dae74c412ba66 100644
--- a/llvm/include/llvm/Transforms/Utils/CodeExtractor.h
+++ b/llvm/include/llvm/Transforms/Utils/CodeExtractor.h
@@ -17,6 +17,7 @@
 #include "llvm/ADT/ArrayRef.h"
 #include "llvm/ADT/DenseMap.h"
 #include "llvm/ADT/SetVector.h"
+#include "llvm/IR/BasicBlock.h"
 #include "llvm/Support/Compiler.h"
 #include <limits>
 
@@ -24,7 +25,6 @@ namespace llvm {
 
 template <typename PtrType> class SmallPtrSetImpl;
 class AllocaInst;
-class BasicBlock;
 class BlockFrequency;
 class BlockFrequencyInfo;
 class BranchProbabilityInfo;
@@ -85,6 +85,10 @@ class CodeExtractorAnalysisCache {
   /// 3) Add allocas for any scalar outputs, adding all of the outputs' allocas
   ///    as arguments, and inserting stores to the arguments for any scalars.
   class CodeExtractor {
+    using CustomArgAllocatorCBTy = std::function<Instruction *(
+        BasicBlock *, BasicBlock::iterator, Type *, const Twine &)>;
+    using CustomArgDeallocatorCBTy = std::function<Instruction *(
+        BasicBlock *, BasicBlock::iterator, Value *, Type *)>;
     using ValueSet = SetVector<Value *>;
 
     // Various bits of state computed on construction.
@@ -133,6 +137,25 @@ class CodeExtractorAnalysisCache {
     // space.
     bool ArgsInZeroAddressSpace;
 
+    // If set, this callback will be used to allocate the arguments in the
+    // caller before passing it to the outlined function holding the extracted
+    // piece of code.
+    CustomArgAllocatorCBTy *CustomArgAllocatorCB;
+
+    // A block outside of the extraction set where previously introduced
+    // intermediate allocations can be deallocated. This is only used when an
+    // custom deallocator is specified.
+    BasicBlock *DeallocationBlock;
+
+    // If set, this callback will be used to deallocate the arguments in the
+    // caller after running the outlined function holding the extracted piece 
of
+    // code. It will not be called if a custom allocator isn't also present.
+    //
+    // By default, this will be done at the end of the basic block containing
+    // the call to the outlined function, except if a deallocation block is
+    // specified. In that case, that will take precedence.
+    CustomArgDeallocatorCBTy *CustomArgDeallocatorCB;
+
   public:
     /// Create a code extractor for a sequence of blocks.
     ///
@@ -149,7 +172,9 @@ class CodeExtractorAnalysisCache {
     /// the function from which the code is being extracted.
     /// If ArgsInZeroAddressSpace param is set to true, then the aggregate
     /// param pointer of the outlined function is declared in zero address
-    /// space.
+    /// space. If a CustomArgAllocatorCB callback is specified, it will be used
+    /// to allocate any structures or variable copies needed to pass arguments
+    /// to the outlined function, rather than using regular allocas.
     LLVM_ABI
     CodeExtractor(ArrayRef<BasicBlock *> BBs, DominatorTree *DT = nullptr,
                   bool AggregateArgs = false, BlockFrequencyInfo *BFI = 
nullptr,
@@ -157,7 +182,10 @@ class CodeExtractorAnalysisCache {
                   AssumptionCache *AC = nullptr, bool AllowVarArgs = false,
                   bool AllowAlloca = false,
                   BasicBlock *AllocationBlock = nullptr,
-                  std::string Suffix = "", bool ArgsInZeroAddressSpace = 
false);
+                  std::string Suffix = "", bool ArgsInZeroAddressSpace = false,
+                  CustomArgAllocatorCBTy *CustomArgAllocatorCB = nullptr,
+                  BasicBlock *DeallocationBlock = nullptr,
+                  CustomArgDeallocatorCBTy *CustomArgDeallocatorCB = nullptr);
 
     /// Perform the extraction, returning the new function.
     ///
diff --git a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp 
b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
index e0db4108ec508..53a32981fefae 100644
--- a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
+++ b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
@@ -289,6 +289,38 @@ computeOpenMPScheduleType(ScheduleKind ClauseKind, bool 
HasChunks,
   return Result;
 }
 
+/// Given a function, if it represents the entry point of a target kernel, this
+/// returns the execution mode flags associated to that kernel.
+static std::optional<omp::OMPTgtExecModeFlags>
+getTargetKernelExecMode(Function &Kernel) {
+  CallInst *TargetInitCall = nullptr;
+  for (Instruction &Inst : Kernel.getEntryBlock()) {
+    // Indirect calls have no statically known callee and must be skipped.
+    auto *Call = dyn_cast<CallInst>(&Inst);
+    Function *Callee = Call ? Call->getCalledFunction() : nullptr;
+    if (Callee && Callee->getName() == "__kmpc_target_init") {
+      TargetInitCall = Call;
+      break;
+    }
+  }
+  if (!TargetInitCall)
+    return std::nullopt;
+
+  // Get the kernel mode information from the global variable associated to the
+  // first argument to the call to __kmpc_target_init. Refer to
+  // createTargetInit() to see how this is initialized.
+  Value *InitOperand = TargetInitCall->getArgOperand(0);
+  GlobalVariable *KernelEnv = nullptr;
+  if (auto *Cast = dyn_cast<ConstantExpr>(InitOperand))
+    KernelEnv = cast<GlobalVariable>(Cast->getOperand(0));
+  else
+    KernelEnv = cast<GlobalVariable>(InitOperand);
+  auto *KernelEnvInit = cast<ConstantStruct>(KernelEnv->getInitializer());
+  auto *ConfigEnv = cast<ConstantStruct>(KernelEnvInit->getOperand(0));
+  auto *KernelMode = cast<ConstantInt>(ConfigEnv->getOperand(2));
+  return static_cast<OMPTgtExecModeFlags>(KernelMode->getZExtValue());
+}
+
 /// Make \p Source branch to \p Target.
 ///
 /// Handles two situations:
@@ -812,15 +844,19 @@ void OpenMPIRBuilder::finalize(Function *Fn) {
     // CodeExtractor generates correct code for extracted functions
     // which are used by OpenMP runtime.
     bool ArgsInZeroAddressSpace = Config.isTargetDevice();
-    CodeExtractor Extractor(Blocks, /* DominatorTree */ nullptr,
-                            /* AggregateArgs */ true,
-                            /* BlockFrequencyInfo */ nullptr,
-                            /* BranchProbabilityInfo */ nullptr,
-                            /* AssumptionCache */ nullptr,
-                            /* AllowVarArgs */ true,
-                            /* AllowAlloca */ true,
-                            /* AllocaBlock*/ OI.OuterAllocaBB,
-                            /* Suffix */ ".omp_par", ArgsInZeroAddressSpace);
+    CodeExtractor Extractor(
+        Blocks, /* DominatorTree */ nullptr,
+        /* AggregateArgs */ true,
+        /* BlockFrequencyInfo */ nullptr,
+        /* BranchProbabilityInfo */ nullptr,
+        /* AssumptionCache */ nullptr,
+        /* AllowVarArgs */ true,
+        /* AllowAlloca */ true,
+        /* AllocaBlock*/ OI.OuterAllocaBB,
+        /* Suffix */ ".omp_par", ArgsInZeroAddressSpace,
+        OI.CustomArgAllocatorCB ? &OI.CustomArgAllocatorCB : nullptr,
+        /* DeallocationBlock */ OI.ExitBB,
+        OI.CustomArgDeallocatorCB ? &OI.CustomArgDeallocatorCB : nullptr);
 
     LLVM_DEBUG(dbgs() << "Before     outlining: " << *OuterFn << "\n");
     LLVM_DEBUG(dbgs() << "Entry " << OI.EntryBB->getName()
@@ -1726,6 +1762,50 @@ OpenMPIRBuilder::InsertPointOrErrorTy 
OpenMPIRBuilder::createParallel(
                              IfCondition, NumThreads, PrivTID, PrivTIDAddr,
                              ThreadID, ToBeDeletedVec);
     };
+
+    std::optional<omp::OMPTgtExecModeFlags> ExecMode =
+        getTargetKernelExecMode(*OuterFn);
+
+    // If OuterFn is not a Generic kernel, skip custom allocation. This causes
+    // the CodeExtractor to follow its default behavior. Otherwise, we need to
+    // use device shared memory to allocate argument structures.
+    if (ExecMode && *ExecMode & OMP_TGT_EXEC_MODE_GENERIC) {
+      OI.CustomArgAllocatorCB = [this,
+                                 EntryBB](BasicBlock *, BasicBlock::iterator,
+                                          Type *ArgTy, const Twine &Name) {
+        // Instead of using the insertion point provided by the CodeExtractor,
+        // here we need to use the block that eventually calls the outlined
+        // function for the `parallel` construct.
+        //
+        // The reason is that the explicit deallocation call will be inserted
+        // within the outlined function, whereas the alloca insertion point
+        // might actually be located somewhere else in the caller. This becomes
+        // a problem when e.g. `parallel` is inside of a `distribute` 
construct,
+        // because the deallocation would be executed multiple times and the
+        // allocation just once (outside of the loop).
+        //
+        // TODO: Ideally, we'd want to do the allocation and deallocation
+        // outside of the `parallel` outlined function, hence using here the
+        // insertion point provided by the CodeExtractor. We can't do this at
+        // the moment because there is currently no way of passing an eligible
+        // insertion point for the explicit deallocation to the CodeExtractor,
+        // as that block is created (at least when nested inside of
+        // `distribute`) sometime after createParallel() completed, so it can't
+        // be stored in the OutlineInfo structure here.
+        //
+        // The current approach results in an explicit allocation and
+        // deallocation pair for each `distribute` loop iteration in that case,
+        // which is suboptimal.
+        return createOMPAllocShared(
+            InsertPointTy(EntryBB, EntryBB->getFirstInsertionPt()), ArgTy,
+            Name);
+      };
+      OI.CustomArgDeallocatorCB =
+          [this](BasicBlock *BB, BasicBlock::iterator AllocIP, Value *Arg,
+                 Type *ArgTy) -> Instruction * {
+        return createOMPFreeShared(InsertPointTy(BB, AllocIP), Arg, ArgTy);
+      };
+    }
     OI.FixUpNonEntryAllocas = true;
   } else {
     // Generate OpenMP host runtime call
diff --git a/llvm/lib/Transforms/Utils/CodeExtractor.cpp 
b/llvm/lib/Transforms/Utils/CodeExtractor.cpp
index a6ba42f5bec2a..b54876a6c7cf7 100644
--- a/llvm/lib/Transforms/Utils/CodeExtractor.cpp
+++ b/llvm/lib/Transforms/Utils/CodeExtractor.cpp
@@ -25,7 +25,6 @@
 #include "llvm/Analysis/BranchProbabilityInfo.h"
 #include "llvm/IR/Argument.h"
 #include "llvm/IR/Attributes.h"
-#include "llvm/IR/BasicBlock.h"
 #include "llvm/IR/CFG.h"
 #include "llvm/IR/Constant.h"
 #include "llvm/IR/Constants.h"
@@ -264,12 +263,18 @@ CodeExtractor::CodeExtractor(ArrayRef<BasicBlock *> BBs, 
DominatorTree *DT,
                              BranchProbabilityInfo *BPI, AssumptionCache *AC,
                              bool AllowVarArgs, bool AllowAlloca,
                              BasicBlock *AllocationBlock, std::string Suffix,
-                             bool ArgsInZeroAddressSpace)
+                             bool ArgsInZeroAddressSpace,
+                             CustomArgAllocatorCBTy *CustomArgAllocatorCB,
+                             BasicBlock *DeallocationBlock,
+                             CustomArgDeallocatorCBTy *CustomArgDeallocatorCB)
     : DT(DT), AggregateArgs(AggregateArgs || AggregateArgsOpt), BFI(BFI),
       BPI(BPI), AC(AC), AllocationBlock(AllocationBlock),
       AllowVarArgs(AllowVarArgs),
       Blocks(buildExtractionBlockSet(BBs, DT, AllowVarArgs, AllowAlloca)),
-      Suffix(Suffix), ArgsInZeroAddressSpace(ArgsInZeroAddressSpace) {}
+      Suffix(Suffix), ArgsInZeroAddressSpace(ArgsInZeroAddressSpace),
+      CustomArgAllocatorCB(CustomArgAllocatorCB),
+      DeallocationBlock(DeallocationBlock),
+      CustomArgDeallocatorCB(CustomArgDeallocatorCB) {}
 
 /// definedInRegion - Return true if the specified value is defined in the
 /// extracted region.
@@ -1851,24 +1856,38 @@ CallInst *CodeExtractor::emitReplacerCall(
     if (StructValues.contains(output))
       continue;
 
-    AllocaInst *alloca = new AllocaInst(
-        output->getType(), DL.getAllocaAddrSpace(), nullptr,
-        output->getName() + ".loc", AllocaBlock->getFirstInsertionPt());
-    params.push_back(alloca);
-    ReloadOutputs.push_back(alloca);
+    Value *OutAlloc;
+    if (CustomArgAllocatorCB)
+      OutAlloc = (*CustomArgAllocatorCB)(
+          AllocaBlock, AllocaBlock->getFirstInsertionPt(), output->getType(),
+          output->getName() + ".loc");
+    else
+      OutAlloc = new AllocaInst(output->getType(), DL.getAllocaAddrSpace(),
+                                nullptr, output->getName() + ".loc",
+                                AllocaBlock->getFirstInsertionPt());
+
+    params.push_back(OutAlloc);
+    ReloadOutputs.push_back(OutAlloc);
   }
 
-  AllocaInst *Struct = nullptr;
+  Instruction *Struct = nullptr;
   if (!StructValues.empty()) {
-    Struct = new AllocaInst(StructArgTy, DL.getAllocaAddrSpace(), nullptr,
-                            "structArg", AllocaBlock->getFirstInsertionPt());
-    if (ArgsInZeroAddressSpace && DL.getAllocaAddrSpace() != 0) {
-      auto *StructSpaceCast = new AddrSpaceCastInst(
-          Struct, PointerType ::get(Context, 0), "structArg.ascast");
-      StructSpaceCast->insertAfter(Struct->getIterator());
-      params.push_back(StructSpaceCast);
-    } else {
+    BasicBlock::iterator StructArgIP = AllocaBlock->getFirstInsertionPt();
+    if (CustomArgAllocatorCB) {
+      Struct = (*CustomArgAllocatorCB)(AllocaBlock, StructArgIP, StructArgTy,
+                                       "structArg");
       params.push_back(Struct);
+    } else {
+      Struct = new AllocaInst(StructArgTy, DL.getAllocaAddrSpace(), nullptr,
+                              "structArg", StructArgIP);
+      if (ArgsInZeroAddressSpace && DL.getAllocaAddrSpace() != 0) {
+        auto *StructSpaceCast = new AddrSpaceCastInst(
+            Struct, PointerType ::get(Context, 0), "structArg.ascast");
+        StructSpaceCast->insertAfter(Struct->getIterator());
+        params.push_back(StructSpaceCast);
+      } else {
+        params.push_back(Struct);
+      }
     }
 
     unsigned AggIdx = 0;
@@ -2012,6 +2031,26 @@ CallInst *CodeExtractor::emitReplacerCall(
   insertLifetimeMarkersSurroundingCall(oldFunction->getParent(), 
LifetimesStart,
                                        {}, call);
 
+  // Deallocate variables that used a custom allocator.
+  if (CustomArgAllocatorCB && CustomArgDeallocatorCB) {
+    BasicBlock *DeallocBlock = codeReplacer;
+    BasicBlock::iterator DeallocIP = codeReplacer->end();
+    if (DeallocationBlock) {
+      DeallocBlock = DeallocationBlock;
+      DeallocIP = DeallocationBlock->getFirstInsertionPt();
+    }
+
+    int Index = 0;
+    for (Value *Output : outputs) {
+      if (!StructValues.contains(Output))
+        (*CustomArgDeallocatorCB)(DeallocBlock, DeallocIP,
+                                  ReloadOutputs[Index++], Output->getType());
+    }
+
+    if (Struct)
+      (*CustomArgDeallocatorCB)(DeallocBlock, DeallocIP, Struct, StructArgTy);
+  }
+
   return call;
 }
 
diff --git a/mlir/test/Target/LLVMIR/omptarget-parallel-llvm.mlir 
b/mlir/test/Target/LLVMIR/omptarget-parallel-llvm.mlir
index cdb8dbbbc946c..6476632c58587 100644
--- a/mlir/test/Target/LLVMIR/omptarget-parallel-llvm.mlir
+++ b/mlir/test/Target/LLVMIR/omptarget-parallel-llvm.mlir
@@ -56,8 +56,6 @@ module attributes {dlti.dl_spec = 
#dlti.dl_spec<#dlti.dl_entry<"dlti.alloca_memo
 // CHECK-SAME: ptr %[[TMP:.*]], ptr %[[TMP0:.*]]) #{{[0-9]+}} {
 // CHECK:         %[[TMP1:.*]] = alloca [1 x ptr], align 8, addrspace(5)
 // CHECK:         %[[TMP2:.*]] = addrspacecast ptr addrspace(5) %[[TMP1]] to 
ptr
-// CHECK:         %[[STRUCTARG:.*]] = alloca { ptr }, align 8, addrspace(5)
-// CHECK:         %[[STRUCTARG_ASCAST:.*]] = addrspacecast ptr addrspace(5) 
%[[STRUCTARG]] to ptr
 // CHECK:         %[[TMP3:.*]] = alloca ptr, align 8, addrspace(5)
 // CHECK:         %[[TMP4:.*]] = addrspacecast ptr addrspace(5) %[[TMP3]] to 
ptr
 // CHECK:         store ptr %[[TMP0]], ptr %[[TMP4]], align 8
@@ -65,12 +63,14 @@ module attributes {dlti.dl_spec = 
#dlti.dl_spec<#dlti.dl_entry<"dlti.alloca_memo
 // CHECK:         %[[EXEC_USER_CODE:.*]] = icmp eq i32 %[[TMP5]], -1
 // CHECK:         br i1 %[[EXEC_USER_CODE]], label %[[USER_CODE_ENTRY:.*]], 
label %[[WORKER_EXIT:.*]]
 // CHECK:         %[[TMP6:.*]] = load ptr, ptr %[[TMP4]], align 8
+// CHECK:         %[[STRUCTARG:.*]] = call align 8 ptr 
@__kmpc_alloc_shared(i64 8)
 // CHECK:         %[[OMP_GLOBAL_THREAD_NUM:.*]] = call i32 
@__kmpc_global_thread_num(ptr addrspacecast (ptr addrspace(1) @[[GLOB1:[0-9]+]] 
to ptr))
-// CHECK:         %[[GEP_:.*]] = getelementptr { ptr }, ptr addrspace(5) 
%[[STRUCTARG]], i32 0, i32 0
-// CHECK:         store ptr %[[TMP6]], ptr addrspace(5) %[[GEP_]], align 8
+// CHECK:         %[[GEP_:.*]] = getelementptr { ptr }, ptr %[[STRUCTARG]], 
i32 0, i32 0
+// CHECK:         store ptr %[[TMP6]], ptr %[[GEP_]], align 8
 // CHECK:         %[[TMP7:.*]] = getelementptr inbounds [1 x ptr], ptr 
%[[TMP2]], i64 0, i64 0
-// CHECK:         store ptr %[[STRUCTARG_ASCAST]], ptr %[[TMP7]], align 8
+// CHECK:         store ptr %[[STRUCTARG]], ptr %[[TMP7]], align 8
 // CHECK:         call void @__kmpc_parallel_60(ptr addrspacecast (ptr 
addrspace(1) @[[GLOB1]] to ptr), i32 %[[OMP_GLOBAL_THREAD_NUM]], i32 1, i32 -1, 
i32 -1, ptr @[[FUNC1:.*]], ptr null, ptr %[[TMP2]], i64 1, i32 0)
+// CHECK:         call void @__kmpc_free_shared(ptr %[[STRUCTARG]], i64 8)
 // CHECK:         call void @__kmpc_target_deinit()
 
 // CHECK: define internal void @[[FUNC1]](

>From 532d20b958a52801ef8226d9baec754ba58fe377 Mon Sep 17 00:00:00 2001
From: Sergio Afonso <[email protected]>
Date: Tue, 5 Aug 2025 15:34:28 +0100
Subject: [PATCH 2/6] Address intermittent ICE triggered from the
 `OpenMPIRBuilder::finalize` method due to an invalid builder insertion point

---
 .../Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp    | 13 +++++++++++--
 1 file changed, 11 insertions(+), 2 deletions(-)

diff --git 
a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp 
b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index 016be6ca09667..716c054c700ca 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -6769,6 +6769,7 @@ convertOmpTarget(Operation &opInst, llvm::IRBuilderBase 
&builder,
 
 static LogicalResult
 convertDeclareTargetAttr(Operation *op, mlir::omp::DeclareTargetAttr attribute,
+                         llvm::OpenMPIRBuilder *ompBuilder,
                          LLVM::ModuleTranslation &moduleTranslation) {
   // Amend omp.declare_target by deleting the IR of the outlined functions
   // created for target regions. They cannot be filtered out from MLIR earlier
@@ -6791,6 +6792,11 @@ convertDeclareTargetAttr(Operation *op, 
mlir::omp::DeclareTargetAttr attribute,
             moduleTranslation.lookupFunction(funcOp.getName());
         llvmFunc->dropAllReferences();
         llvmFunc->eraseFromParent();
+
+        // Invalidate the builder's current insertion point, as it now points 
to
+        // a deleted block.
+        ompBuilder->Builder.ClearInsertionPoint();
+        ompBuilder->Builder.SetCurrentDebugLocation(llvm::DebugLoc());
       }
     }
     return success();
@@ -7334,9 +7340,12 @@ LogicalResult 
OpenMPDialectLLVMIRTranslationInterface::amendOperation(
       .Case("omp.declare_target",
             [&](Attribute attr) {
               if (auto declareTargetAttr =
-                      dyn_cast<omp::DeclareTargetAttr>(attr))
+                      dyn_cast<omp::DeclareTargetAttr>(attr)) {
+                llvm::OpenMPIRBuilder *ompBuilder =
+                    moduleTranslation.getOpenMPBuilder();
                 return convertDeclareTargetAttr(op, declareTargetAttr,
-                                                moduleTranslation);
+                                                ompBuilder, moduleTranslation);
+              }
               return failure();
             })
       .Case("omp.requires",

>From 1e82909bbc0b8fc9af51199d87f4553041d7be15 Mon Sep 17 00:00:00 2001
From: Sergio Afonso <[email protected]>
Date: Thu, 14 Aug 2025 12:40:09 +0100
Subject: [PATCH 3/6] Address nits

---
 .../llvm/Transforms/Utils/CodeExtractor.h     | 46 +++++++++----------
 llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp     |  2 +-
 llvm/lib/Transforms/Utils/CodeExtractor.cpp   |  5 +-
 3 files changed, 28 insertions(+), 25 deletions(-)

diff --git a/llvm/include/llvm/Transforms/Utils/CodeExtractor.h 
b/llvm/include/llvm/Transforms/Utils/CodeExtractor.h
index dae74c412ba66..72e1c296c2fd5 100644
--- a/llvm/include/llvm/Transforms/Utils/CodeExtractor.h
+++ b/llvm/include/llvm/Transforms/Utils/CodeExtractor.h
@@ -98,15 +98,15 @@ class CodeExtractorAnalysisCache {
     BranchProbabilityInfo *BPI;
     AssumptionCache *AC;
 
-    // A block outside of the extraction set where any intermediate
-    // allocations will be placed inside. If this is null, allocations
-    // will be placed in the entry block of the function.
+    /// A block outside of the extraction set where any intermediate
+    /// allocations will be placed inside. If this is null, allocations
+    /// will be placed in the entry block of the function.
     BasicBlock *AllocationBlock;
 
-    // If true, varargs functions can be extracted.
+    /// If true, varargs functions can be extracted.
     bool AllowVarArgs;
 
-    // Bits of intermediate state computed at various phases of extraction.
+    /// Bits of intermediate state computed at various phases of extraction.
     SetVector<BasicBlock *> Blocks;
 
     /// Lists of blocks that are branched from the code region to be extracted,
@@ -128,32 +128,32 @@ class CodeExtractorAnalysisCache {
     /// returns 1, etc.
     SmallVector<BasicBlock *> ExtractedFuncRetVals;
 
-    // Suffix to use when creating extracted function (appended to the original
-    // function name + "."). If empty, the default is to use the entry block
-    // label, if non-empty, otherwise "extracted".
+    /// Suffix to use when creating extracted function (appended to the 
original
+    /// function name + "."). If empty, the default is to use the entry block
+    /// label, if non-empty, otherwise "extracted".
     std::string Suffix;
 
-    // If true, the outlined function has aggregate argument in zero address
-    // space.
+    /// If true, the outlined function has aggregate argument in zero address
+    /// space.
     bool ArgsInZeroAddressSpace;
 
-    // If set, this callback will be used to allocate the arguments in the
-    // caller before passing it to the outlined function holding the extracted
-    // piece of code.
+    /// If set, this callback will be used to allocate the arguments in the
+    /// caller before passing it to the outlined function holding the extracted
+    /// piece of code.
     CustomArgAllocatorCBTy *CustomArgAllocatorCB;
 
-    // A block outside of the extraction set where previously introduced
-    // intermediate allocations can be deallocated. This is only used when an
-    // custom deallocator is specified.
+    /// A block outside of the extraction set where previously introduced
+    /// intermediate allocations can be deallocated. This is only used when a
+    /// custom deallocator is specified.
     BasicBlock *DeallocationBlock;
 
-    // If set, this callback will be used to deallocate the arguments in the
-    // caller after running the outlined function holding the extracted piece 
of
-    // code. It will not be called if a custom allocator isn't also present.
-    //
-    // By default, this will be done at the end of the basic block containing
-    // the call to the outlined function, except if a deallocation block is
-    // specified. In that case, that will take precedence.
+    /// If set, this callback will be used to deallocate the arguments in the
+    /// caller after running the outlined function holding the extracted piece
+    /// of code. It will not be called if a custom allocator isn't also 
present.
+    ///
+    /// By default, this will be done at the end of the basic block containing
+    /// the call to the outlined function, except if a deallocation block is
+    /// specified. In that case, that will take precedence.
     CustomArgDeallocatorCBTy *CustomArgDeallocatorCB;
 
   public:
diff --git a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp 
b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
index 53a32981fefae..ca58127e556c9 100644
--- a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
+++ b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
@@ -290,7 +290,7 @@ computeOpenMPScheduleType(ScheduleKind ClauseKind, bool 
HasChunks,
 }
 
 /// Given a function, if it represents the entry point of a target kernel, this
-/// returns the execution mode flags associated to that kernel.
+/// returns the execution mode flags associated with that kernel.
 static std::optional<omp::OMPTgtExecModeFlags>
 getTargetKernelExecMode(Function &Kernel) {
   CallInst *TargetInitCall = nullptr;
diff --git a/llvm/lib/Transforms/Utils/CodeExtractor.cpp 
b/llvm/lib/Transforms/Utils/CodeExtractor.cpp
index b54876a6c7cf7..c01c3d25952ec 100644
--- a/llvm/lib/Transforms/Utils/CodeExtractor.cpp
+++ b/llvm/lib/Transforms/Utils/CodeExtractor.cpp
@@ -274,7 +274,10 @@ CodeExtractor::CodeExtractor(ArrayRef<BasicBlock *> BBs, 
DominatorTree *DT,
       Suffix(Suffix), ArgsInZeroAddressSpace(ArgsInZeroAddressSpace),
       CustomArgAllocatorCB(CustomArgAllocatorCB),
       DeallocationBlock(DeallocationBlock),
-      CustomArgDeallocatorCB(CustomArgDeallocatorCB) {}
+      CustomArgDeallocatorCB(CustomArgDeallocatorCB) {
+  assert((!CustomArgDeallocatorCB || CustomArgAllocatorCB) &&
+         "custom deallocator only allowed if a custom allocator is provided");
+}
 
 /// definedInRegion - Return true if the specified value is defined in the
 /// extracted region.

>From 447c073b9f89232d6b9aa1e5e7e74231ed3cb928 Mon Sep 17 00:00:00 2001
From: Sergio Afonso <[email protected]>
Date: Fri, 15 Aug 2025 15:45:14 +0100
Subject: [PATCH 4/6] Replace CodeExtractor callbacks with subclasses and
 simplify their creation based on OutlineInfo structures

---
 .../llvm/Frontend/OpenMP/OMPIRBuilder.h       |  23 +-
 .../llvm/Transforms/Utils/CodeExtractor.h     |  59 ++--
 llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp     | 325 ++++++++++++------
 llvm/lib/Transforms/IPO/HotColdSplitting.cpp  |   1 +
 llvm/lib/Transforms/IPO/IROutliner.cpp        |   4 +-
 llvm/lib/Transforms/Utils/CodeExtractor.cpp   | 107 +++---
 .../Transforms/Utils/CodeExtractorTest.cpp    |   3 +-
 7 files changed, 312 insertions(+), 210 deletions(-)

diff --git a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h 
b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
index 09a1fe7cebf4b..237155066eb5e 100644
--- a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
+++ b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
@@ -33,6 +33,7 @@
 
 namespace llvm {
 class CanonicalLoopInfo;
+class CodeExtractor;
 class ScanInfo;
 struct TargetRegionEntryInfo;
 class OffloadEntriesInfoManager;
@@ -2440,30 +2441,34 @@ class OpenMPIRBuilder {
   /// during finalization.
   struct OutlineInfo {
     using PostOutlineCBTy = std::function<void(Function &)>;
-    using CustomArgAllocatorCBTy = std::function<Instruction *(
-        BasicBlock *, BasicBlock::iterator, Type *, const Twine &)>;
-    using CustomArgDeallocatorCBTy = std::function<Instruction *(
-        BasicBlock *, BasicBlock::iterator, Value *, Type *)>;
     PostOutlineCBTy PostOutlineCB;
-    CustomArgAllocatorCBTy CustomArgAllocatorCB;
-    CustomArgDeallocatorCBTy CustomArgDeallocatorCB;
     BasicBlock *EntryBB, *ExitBB, *OuterAllocaBB;
     SmallVector<Value *, 2> ExcludeArgsFromAggregate;
     SetVector<Value *> Inputs, Outputs;
     // TODO: this should be safe to enable by default
     bool FixUpNonEntryAllocas = false;
 
+    LLVM_ABI virtual ~OutlineInfo() = default;
+
     /// Collect all blocks in between EntryBB and ExitBB in both the given
     /// vector and set.
     LLVM_ABI void collectBlocks(SmallPtrSetImpl<BasicBlock *> &BlockSet,
                                 SmallVectorImpl<BasicBlock *> &BlockVector);
 
+    /// Create a CodeExtractor instance based on the information stored in this
+    /// structure, the list of collected blocks from a previous call to
+    /// \c collectBlocks and a flag stating whether arguments must be passed in
+    /// address space 0.
+    LLVM_ABI virtual std::unique_ptr<CodeExtractor>
+    createCodeExtractor(ArrayRef<BasicBlock *> Blocks,
+                        bool ArgsInZeroAddressSpace, Twine Suffix = Twine(""));
+
     /// Return the function that contains the region to be outlined.
     Function *getFunction() const { return EntryBB->getParent(); }
   };
 
   /// Collection of regions that need to be outlined during finalization.
-  SmallVector<OutlineInfo, 16> OutlineInfos;
+  SmallVector<std::unique_ptr<OutlineInfo>, 16> OutlineInfos;
 
   /// A collection of candidate target functions that's constant allocas will
   /// attempt to be raised on a call of finalize after all currently enqueued
@@ -2478,7 +2483,9 @@ class OpenMPIRBuilder {
   std::forward_list<ScanInfo> ScanInfos;
 
   /// Add a new region that will be outlined later.
-  void addOutlineInfo(OutlineInfo &&OI) { OutlineInfos.emplace_back(OI); }
+  void addOutlineInfo(std::unique_ptr<OutlineInfo> OI) {
+    OutlineInfos.push_back(std::move(OI));
+  }
 
   /// An ordered map of auto-generated variables to their unique names.
   /// It stores variables with the following names: 1) ".gomp_critical_user_" +
diff --git a/llvm/include/llvm/Transforms/Utils/CodeExtractor.h 
b/llvm/include/llvm/Transforms/Utils/CodeExtractor.h
index 72e1c296c2fd5..70132e0f31cad 100644
--- a/llvm/include/llvm/Transforms/Utils/CodeExtractor.h
+++ b/llvm/include/llvm/Transforms/Utils/CodeExtractor.h
@@ -24,6 +24,7 @@
 namespace llvm {
 
 template <typename PtrType> class SmallPtrSetImpl;
+class AddrSpaceCastInst;
 class AllocaInst;
 class BlockFrequency;
 class BlockFrequencyInfo;
@@ -85,10 +86,6 @@ class CodeExtractorAnalysisCache {
   /// 3) Add allocas for any scalar outputs, adding all of the outputs' allocas
   ///    as arguments, and inserting stores to the arguments for any scalars.
   class CodeExtractor {
-    using CustomArgAllocatorCBTy = std::function<Instruction *(
-        BasicBlock *, BasicBlock::iterator, Type *, const Twine &)>;
-    using CustomArgDeallocatorCBTy = std::function<Instruction *(
-        BasicBlock *, BasicBlock::iterator, Value *, Type *)>;
     using ValueSet = SetVector<Value *>;
 
     // Various bits of state computed on construction.
@@ -103,6 +100,14 @@ class CodeExtractorAnalysisCache {
     /// will be placed in the entry block of the function.
     BasicBlock *AllocationBlock;
 
+    /// A block outside of the extraction set where deallocations for
+    /// intermediate allocations can be placed inside. Not used for
+    /// automatically deallocated memory (e.g. `alloca`), which is the default.
+    ///
+    /// If it is null and needed, the end of the replacement basic block will 
be
+    /// used to place deallocations.
+    BasicBlock *DeallocationBlock;
+
     /// If true, varargs functions can be extracted.
     bool AllowVarArgs;
 
@@ -137,25 +142,6 @@ class CodeExtractorAnalysisCache {
     /// space.
     bool ArgsInZeroAddressSpace;
 
-    /// If set, this callback will be used to allocate the arguments in the
-    /// caller before passing it to the outlined function holding the extracted
-    /// piece of code.
-    CustomArgAllocatorCBTy *CustomArgAllocatorCB;
-
-    /// A block outside of the extraction set where previously introduced
-    /// intermediate allocations can be deallocated. This is only used when a
-    /// custom deallocator is specified.
-    BasicBlock *DeallocationBlock;
-
-    /// If set, this callback will be used to deallocate the arguments in the
-    /// caller after running the outlined function holding the extracted piece
-    /// of code. It will not be called if a custom allocator isn't also 
present.
-    ///
-    /// By default, this will be done at the end of the basic block containing
-    /// the call to the outlined function, except if a deallocation block is
-    /// specified. In that case, that will take precedence.
-    CustomArgDeallocatorCBTy *CustomArgDeallocatorCB;
-
   public:
     /// Create a code extractor for a sequence of blocks.
     ///
@@ -169,12 +155,12 @@ class CodeExtractorAnalysisCache {
     /// however code extractor won't validate whether extraction is legal.
     /// Any new allocations will be placed in the AllocationBlock, unless
     /// it is null, in which case it will be placed in the entry block of
-    /// the function from which the code is being extracted.
+    /// the function from which the code is being extracted. Explicit
+    /// deallocations for the aforementioned allocations will be placed in the
+    /// DeallocationBlock or the end of the replacement block, if needed.
     /// If ArgsInZeroAddressSpace param is set to true, then the aggregate
     /// param pointer of the outlined function is declared in zero address
-    /// space. If a CustomArgAllocatorCB callback is specified, it will be used
-    /// to allocate any structures or variable copies needed to pass arguments
-    /// to the outlined function, rather than using regular allocas.
+    /// space.
     LLVM_ABI
     CodeExtractor(ArrayRef<BasicBlock *> BBs, DominatorTree *DT = nullptr,
                   bool AggregateArgs = false, BlockFrequencyInfo *BFI = 
nullptr,
@@ -182,10 +168,10 @@ class CodeExtractorAnalysisCache {
                   AssumptionCache *AC = nullptr, bool AllowVarArgs = false,
                   bool AllowAlloca = false,
                   BasicBlock *AllocationBlock = nullptr,
-                  std::string Suffix = "", bool ArgsInZeroAddressSpace = false,
-                  CustomArgAllocatorCBTy *CustomArgAllocatorCB = nullptr,
                   BasicBlock *DeallocationBlock = nullptr,
-                  CustomArgDeallocatorCBTy *CustomArgDeallocatorCB = nullptr);
+                  std::string Suffix = "", bool ArgsInZeroAddressSpace = 
false);
+
+    LLVM_ABI virtual ~CodeExtractor() = default;
 
     /// Perform the extraction, returning the new function.
     ///
@@ -271,6 +257,19 @@ class CodeExtractorAnalysisCache {
     /// region, passing it instead as a scalar.
     LLVM_ABI void excludeArgFromAggregate(Value *Arg);
 
+  protected:
+    /// Allocate an intermediate variable at the specified point.
+    LLVM_ABI virtual Instruction *
+    allocateVar(BasicBlock *BB, BasicBlock::iterator AllocIP, Type *VarType,
+                const Twine &Name = Twine(""),
+                AddrSpaceCastInst **CastedAlloc = nullptr);
+
+    /// Deallocate a previously-allocated intermediate variable at the 
specified
+    /// point.
+    LLVM_ABI virtual Instruction *deallocateVar(BasicBlock *BB,
+                                                BasicBlock::iterator DeallocIP,
+                                                Value *Var, Type *VarType);
+
   private:
     struct LifetimeMarkerInfo {
       bool SinkLifeStart = false;
diff --git a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
index ca58127e556c9..abfba9ebe2302 100644
--- a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
+++ b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
@@ -488,6 +488,108 @@ enum OpenMPOffloadingRequiresDirFlags {
   LLVM_MARK_AS_BITMASK_ENUM(/*LargestValue=*/OMP_REQ_DYNAMIC_ALLOCATORS)
 };
 
+/// Base class for the code extractors created by the OpenMPIRBuilder. It
+/// keeps a reference to the builder so that subclasses can use it to emit the
+/// instructions they need.
+class OMPCodeExtractor : public CodeExtractor {
+public:
+  OMPCodeExtractor(OpenMPIRBuilder &OMPBuilder, ArrayRef<BasicBlock *> BBs,
+                   DominatorTree *DT = nullptr, bool AggregateArgs = false,
+                   BlockFrequencyInfo *BFI = nullptr,
+                   BranchProbabilityInfo *BPI = nullptr,
+                   AssumptionCache *AC = nullptr, bool AllowVarArgs = false,
+                   bool AllowAlloca = false,
+                   BasicBlock *AllocationBlock = nullptr,
+                   BasicBlock *DeallocationBlock = nullptr,
+                   std::string Suffix = "", bool ArgsInZeroAddressSpace = false)
+      : CodeExtractor(BBs, DT, AggregateArgs, BFI, BPI, AC, AllowVarArgs,
+                      AllowAlloca, AllocationBlock, DeallocationBlock,
+                      std::move(Suffix), ArgsInZeroAddressSpace),
+        OMPBuilder(OMPBuilder) {}
+
+  ~OMPCodeExtractor() override = default;
+
+protected:
+  OpenMPIRBuilder &OMPBuilder;
+};
+
+/// Code extractor that allocates the intermediate variables used to pass
+/// arguments to the outlined function in device shared memory, rather than in
+/// private stack space, so that they can be accessed by all threads.
+class DeviceSharedMemCodeExtractor : public OMPCodeExtractor {
+public:
+  DeviceSharedMemCodeExtractor(
+      OpenMPIRBuilder &OMPBuilder, BasicBlock *AllocBlockOverride,
+      ArrayRef<BasicBlock *> BBs, DominatorTree *DT = nullptr,
+      bool AggregateArgs = false, BlockFrequencyInfo *BFI = nullptr,
+      BranchProbabilityInfo *BPI = nullptr, AssumptionCache *AC = nullptr,
+      bool AllowVarArgs = false, bool AllowAlloca = false,
+      BasicBlock *AllocationBlock = nullptr,
+      BasicBlock *DeallocationBlock = nullptr, std::string Suffix = "",
+      bool ArgsInZeroAddressSpace = false)
+      : OMPCodeExtractor(OMPBuilder, BBs, DT, AggregateArgs, BFI, BPI, AC,
+                         AllowVarArgs, AllowAlloca, AllocationBlock,
+                         DeallocationBlock, std::move(Suffix),
+                         ArgsInZeroAddressSpace),
+        AllocBlockOverride(AllocBlockOverride) {}
+
+  ~DeviceSharedMemCodeExtractor() override = default;
+
+protected:
+  Instruction *allocateVar(BasicBlock *BB, BasicBlock::iterator AllocIP,
+                           Type *VarType, const Twine &Name = Twine(""),
+                           AddrSpaceCastInst **CastedAlloc = nullptr) override {
+    // Shared memory should not be casted to address space 0 to be passed
+    // around, so no cast is ever created. Still clear the out-parameter, if
+    // requested, so that callers never read an unset value.
+    if (CastedAlloc)
+      *CastedAlloc = nullptr;
+
+    // Without an override, use the insertion point provided by the
+    // CodeExtractor rather than dereferencing a null block.
+    if (!AllocBlockOverride)
+      return OMPBuilder.createOMPAllocShared(
+          OpenMPIRBuilder::InsertPointTy(BB, AllocIP), VarType, Name);
+
+    return OMPBuilder.createOMPAllocShared(
+        OpenMPIRBuilder::InsertPointTy(
+            AllocBlockOverride, AllocBlockOverride->getFirstInsertionPt()),
+        VarType, Name);
+  }
+
+  Instruction *deallocateVar(BasicBlock *BB, BasicBlock::iterator DeallocIP,
+                             Value *Var, Type *VarType) override {
+    return OMPBuilder.createOMPFreeShared(
+        OpenMPIRBuilder::InsertPointTy(BB, DeallocIP), Var, VarType);
+  }
+
+private:
+  /// If set, all shared memory allocations are inserted at the first insertion
+  /// point of this block instead of where the CodeExtractor requests.
+  // TODO: Remove the need for this override and instead get the CodeExtractor
+  // to provide a valid insert point for explicit deallocations by correctly
+  // populating its DeallocationBlock.
+  BasicBlock *AllocBlockOverride;
+};
+
+/// Helper storing information about regions to outline using device shared
+/// memory for intermediate allocations.
+struct DeviceSharedMemOutlineInfo : public OpenMPIRBuilder::OutlineInfo {
+  OpenMPIRBuilder &OMPBuilder;
+  /// Block in which to place all shared memory allocations. If null, the
+  /// insertion point provided by the CodeExtractor is used instead.
+  BasicBlock *AllocBlockOverride = nullptr;
+
+  explicit DeviceSharedMemOutlineInfo(OpenMPIRBuilder &OMPBuilder)
+      : OMPBuilder(OMPBuilder) {}
+  ~DeviceSharedMemOutlineInfo() override = default;
+
+  std::unique_ptr<CodeExtractor>
+  createCodeExtractor(ArrayRef<BasicBlock *> Blocks,
+                      bool ArgsInZeroAddressSpace,
+                      Twine Suffix = Twine("")) override;
+};
+
 } // anonymous namespace
 
 OpenMPIRBuilderConfig::OpenMPIRBuilderConfig()
@@ -822,20 +904,20 @@ static void hoistNonEntryAllocasToEntryBlock(llvm::BasicBlock &Block) {
 void OpenMPIRBuilder::finalize(Function *Fn) {
   SmallPtrSet<BasicBlock *, 32> ParallelRegionBlockSet;
   SmallVector<BasicBlock *, 32> Blocks;
-  SmallVector<OutlineInfo, 16> DeferredOutlines;
-  for (OutlineInfo &OI : OutlineInfos) {
+  SmallVector<std::unique_ptr<OutlineInfo>, 16> DeferredOutlines;
+  for (std::unique_ptr<OutlineInfo> &OI : OutlineInfos) {
     // Skip functions that have not finalized yet; may happen with nested
     // function generation.
-    if (Fn && OI.getFunction() != Fn) {
-      DeferredOutlines.push_back(OI);
+    if (Fn && OI->getFunction() != Fn) {
+      DeferredOutlines.push_back(std::move(OI));
       continue;
     }
 
     ParallelRegionBlockSet.clear();
     Blocks.clear();
-    OI.collectBlocks(ParallelRegionBlockSet, Blocks);
+    OI->collectBlocks(ParallelRegionBlockSet, Blocks);
 
-    Function *OuterFn = OI.getFunction();
+    Function *OuterFn = OI->getFunction();
     CodeExtractorAnalysisCache CEAC(*OuterFn);
     // If we generate code for the target device, we need to allocate
     // struct for aggregate params in the device default alloca address space.
@@ -844,31 +926,20 @@ void OpenMPIRBuilder::finalize(Function *Fn) {
     // CodeExtractor generates correct code for extracted functions
     // which are used by OpenMP runtime.
     bool ArgsInZeroAddressSpace = Config.isTargetDevice();
-    CodeExtractor Extractor(
-        Blocks, /* DominatorTree */ nullptr,
-        /* AggregateArgs */ true,
-        /* BlockFrequencyInfo */ nullptr,
-        /* BranchProbabilityInfo */ nullptr,
-        /* AssumptionCache */ nullptr,
-        /* AllowVarArgs */ true,
-        /* AllowAlloca */ true,
-        /* AllocaBlock*/ OI.OuterAllocaBB,
-        /* Suffix */ ".omp_par", ArgsInZeroAddressSpace,
-        OI.CustomArgAllocatorCB ? &OI.CustomArgAllocatorCB : nullptr,
-        /* DeallocationBlock */ OI.ExitBB,
-        OI.CustomArgDeallocatorCB ? &OI.CustomArgDeallocatorCB : nullptr);
+    std::unique_ptr<CodeExtractor> Extractor =
+        OI->createCodeExtractor(Blocks, ArgsInZeroAddressSpace, ".omp_par");
 
     LLVM_DEBUG(dbgs() << "Before     outlining: " << *OuterFn << "\n");
-    LLVM_DEBUG(dbgs() << "Entry " << OI.EntryBB->getName()
-                      << " Exit: " << OI.ExitBB->getName() << "\n");
-    assert(Extractor.isEligible() &&
+    LLVM_DEBUG(dbgs() << "Entry " << OI->EntryBB->getName()
+                      << " Exit: " << OI->ExitBB->getName() << "\n");
+    assert(Extractor->isEligible() &&
            "Expected OpenMP outlining to be possible!");
 
-    for (auto *V : OI.ExcludeArgsFromAggregate)
-      Extractor.excludeArgFromAggregate(V);
+    for (auto *V : OI->ExcludeArgsFromAggregate)
+      Extractor->excludeArgFromAggregate(V);
 
     Function *OutlinedFn =
-        Extractor.extractCodeRegion(CEAC, OI.Inputs, OI.Outputs);
+        Extractor->extractCodeRegion(CEAC, OI->Inputs, OI->Outputs);
 
     // Forward target-cpu, target-features attributes to the outlined function.
     auto TargetCpuAttr = OuterFn->getFnAttribute("target-cpu");
@@ -893,8 +964,8 @@ void OpenMPIRBuilder::finalize(Function *Fn) {
     // made our own entry block after all.
     {
       BasicBlock &ArtificialEntry = OutlinedFn->getEntryBlock();
-      assert(ArtificialEntry.getUniqueSuccessor() == OI.EntryBB);
-      assert(OI.EntryBB->getUniquePredecessor() == &ArtificialEntry);
+      assert(ArtificialEntry.getUniqueSuccessor() == OI->EntryBB);
+      assert(OI->EntryBB->getUniquePredecessor() == &ArtificialEntry);
       // Move instructions from the to-be-deleted ArtificialEntry to the entry
       // basic block of the parallel region. CodeExtractor generates
       // instructions to unwrap the aggregate argument and may sink
@@ -910,26 +981,27 @@ void OpenMPIRBuilder::finalize(Function *Fn) {
 
         if (I.isTerminator()) {
           // Absorb any debug value that terminator may have
-          if (OI.EntryBB->getTerminator())
-            OI.EntryBB->getTerminator()->adoptDbgRecords(
+          if (OI->EntryBB->getTerminator())
+            OI->EntryBB->getTerminator()->adoptDbgRecords(
                 &ArtificialEntry, I.getIterator(), false);
           continue;
         }
 
-        I.moveBeforePreserving(*OI.EntryBB, OI.EntryBB->getFirstInsertionPt());
+        I.moveBeforePreserving(*OI->EntryBB,
+                               OI->EntryBB->getFirstInsertionPt());
       }
 
-      OI.EntryBB->moveBefore(&ArtificialEntry);
+      OI->EntryBB->moveBefore(&ArtificialEntry);
       ArtificialEntry.eraseFromParent();
     }
-    assert(&OutlinedFn->getEntryBlock() == OI.EntryBB);
+    assert(&OutlinedFn->getEntryBlock() == OI->EntryBB);
     assert(OutlinedFn && OutlinedFn->hasNUses(1));
 
     // Run a user callback, e.g. to add attributes.
-    if (OI.PostOutlineCB)
-      OI.PostOutlineCB(*OutlinedFn);
+    if (OI->PostOutlineCB)
+      OI->PostOutlineCB(*OutlinedFn);
 
-    if (OI.FixUpNonEntryAllocas) {
+    if (OI->FixUpNonEntryAllocas) {
       PostDominatorTree PostDomTree(*OutlinedFn);
       for (llvm::BasicBlock &BB : *OutlinedFn)
         if (PostDomTree.properlyDominates(&BB, &OutlinedFn->getEntryBlock()))
@@ -1753,26 +1825,17 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createParallel(
 
   LLVM_DEBUG(dbgs() << "After  body codegen: " << *OuterFn << "\n");
 
-  OutlineInfo OI;
-  if (Config.isTargetDevice()) {
-    // Generate OpenMP target specific runtime call
-    OI.PostOutlineCB = [=, ToBeDeletedVec =
-                               std::move(ToBeDeleted)](Function &OutlinedFn) {
-      targetParallelCallback(this, OutlinedFn, OuterFn, OuterAllocaBlock, 
Ident,
-                             IfCondition, NumThreads, PrivTID, PrivTIDAddr,
-                             ThreadID, ToBeDeletedVec);
-    };
+  auto OI = [&]() -> std::unique_ptr<OutlineInfo> {
+    if (Config.isTargetDevice()) {
+      std::optional<omp::OMPTgtExecModeFlags> ExecMode =
+          getTargetKernelExecMode(*OuterFn);
 
-    std::optional<omp::OMPTgtExecModeFlags> ExecMode =
-        getTargetKernelExecMode(*OuterFn);
+      // If OuterFn is not a Generic kernel, skip custom allocation. This 
causes
+      // the CodeExtractor to follow its default behavior. Otherwise, we need 
to
+      // use device shared memory to allocate argument structures.
+      if (ExecMode && *ExecMode & OMP_TGT_EXEC_MODE_GENERIC) {
+        auto Info = std::make_unique<DeviceSharedMemOutlineInfo>(*this);
 
-    // If OuterFn is not a Generic kernel, skip custom allocation. This causes
-    // the CodeExtractor to follow its default behavior. Otherwise, we need to
-    // use device shared memory to allocate argument structures.
-    if (ExecMode && *ExecMode & OMP_TGT_EXEC_MODE_GENERIC) {
-      OI.CustomArgAllocatorCB = [this,
-                                 EntryBB](BasicBlock *, BasicBlock::iterator,
-                                          Type *ArgTy, const Twine &Name) {
         // Instead of using the insertion point provided by the CodeExtractor,
         // here we need to use the block that eventually calls the outlined
         // function for the `parallel` construct.
@@ -1796,34 +1859,38 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createParallel(
         // The current approach results in an explicit allocation and
         // deallocation pair for each `distribute` loop iteration in that case,
         // which is suboptimal.
-        return createOMPAllocShared(
-            InsertPointTy(EntryBB, EntryBB->getFirstInsertionPt()), ArgTy,
-            Name);
-      };
-      OI.CustomArgDeallocatorCB =
-          [this](BasicBlock *BB, BasicBlock::iterator AllocIP, Value *Arg,
-                 Type *ArgTy) -> Instruction * {
-        return createOMPFreeShared(InsertPointTy(BB, AllocIP), Arg, ArgTy);
-      };
+        Info->AllocBlockOverride = EntryBB;
+        return Info;
+      }
     }
-    OI.FixUpNonEntryAllocas = true;
+    return std::make_unique<OutlineInfo>();
+  }();
+
+  if (Config.isTargetDevice()) {
+    // Generate OpenMP target specific runtime call
+    OI->PostOutlineCB = [=, ToBeDeletedVec =
+                                std::move(ToBeDeleted)](Function &OutlinedFn) {
+      targetParallelCallback(this, OutlinedFn, OuterFn, OuterAllocaBlock, Ident,
+                             IfCondition, NumThreads, PrivTID, PrivTIDAddr,
+                             ThreadID, ToBeDeletedVec);
+    };
   } else {
     // Generate OpenMP host runtime call
-    OI.PostOutlineCB = [=, ToBeDeletedVec =
-                               std::move(ToBeDeleted)](Function &OutlinedFn) {
+    OI->PostOutlineCB = [=, ToBeDeletedVec =
+                                std::move(ToBeDeleted)](Function &OutlinedFn) {
       hostParallelCallback(this, OutlinedFn, OuterFn, Ident, IfCondition,
                            PrivTID, PrivTIDAddr, ToBeDeletedVec);
     };
-    OI.FixUpNonEntryAllocas = true;
   }
 
-  OI.OuterAllocaBB = OuterAllocaBlock;
-  OI.EntryBB = PRegEntryBB;
-  OI.ExitBB = PRegExitBB;
+  OI->FixUpNonEntryAllocas = true;
+  OI->OuterAllocaBB = OuterAllocaBlock;
+  OI->EntryBB = PRegEntryBB;
+  OI->ExitBB = PRegExitBB;
 
   SmallPtrSet<BasicBlock *, 32> ParallelRegionBlockSet;
   SmallVector<BasicBlock *, 32> Blocks;
-  OI.collectBlocks(ParallelRegionBlockSet, Blocks);
+  OI->collectBlocks(ParallelRegionBlockSet, Blocks);
 
   CodeExtractorAnalysisCache CEAC(*OuterFn);
   CodeExtractor Extractor(Blocks, /* DominatorTree */ nullptr,
@@ -1834,6 +1901,7 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createParallel(
                           /* AllowVarArgs */ true,
                           /* AllowAlloca */ true,
                           /* AllocationBlock */ OuterAllocaBlock,
+                          /* DeallocationBlock */ nullptr,
                           /* Suffix */ ".omp_par", ArgsInZeroAddressSpace);
 
   // Find inputs to, outputs from the code region.
@@ -1858,7 +1926,7 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createParallel(
 
   auto PrivHelper = [&](Value &V) -> Error {
     if (&V == TIDAddr || &V == ZeroAddr) {
-      OI.ExcludeArgsFromAggregate.push_back(&V);
+      OI->ExcludeArgsFromAggregate.push_back(&V);
       return Error::success();
     }
 
@@ -2522,19 +2590,19 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createTask(
   if (Error Err = BodyGenCB(TaskAllocaIP, TaskBodyIP))
     return Err;
 
-  OutlineInfo OI;
-  OI.EntryBB = TaskAllocaBB;
-  OI.OuterAllocaBB = AllocaIP.getBlock();
-  OI.ExitBB = TaskExitBB;
+  auto OI = std::make_unique<OutlineInfo>();
+  OI->EntryBB = TaskAllocaBB;
+  OI->OuterAllocaBB = AllocaIP.getBlock();
+  OI->ExitBB = TaskExitBB;
 
   // Add the thread ID argument.
   SmallVector<Instruction *, 4> ToBeDeleted;
-  OI.ExcludeArgsFromAggregate.push_back(createFakeIntVal(
+  OI->ExcludeArgsFromAggregate.push_back(createFakeIntVal(
       Builder, AllocaIP, ToBeDeleted, TaskAllocaIP, "global.tid", false));
 
-  OI.PostOutlineCB = [this, Ident, Tied, Final, IfCondition, Dependencies,
-                      Mergeable, Priority, EventHandle, TaskAllocaBB,
-                      ToBeDeleted](Function &OutlinedFn) mutable {
+  OI->PostOutlineCB = [this, Ident, Tied, Final, IfCondition, Dependencies,
+                       Mergeable, Priority, EventHandle, TaskAllocaBB,
+                       ToBeDeleted](Function &OutlinedFn) mutable {
     // Replace the Stale CI by appropriate RTL function call.
     assert(OutlinedFn.hasOneUse() &&
            "there must be a single user for the outlined function");
@@ -5900,19 +5968,19 @@ OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::applyWorkshareLoopTarget(
   Constant *SrcLocStr = getOrCreateSrcLocStr(DL, SrcLocStrSize);
   Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
 
-  OutlineInfo OI;
-  OI.OuterAllocaBB = CLI->getPreheader();
+  auto OI = std::make_unique<OutlineInfo>();
+  OI->OuterAllocaBB = CLI->getPreheader();
   Function *OuterFn = CLI->getPreheader()->getParent();
 
   // Instructions which need to be deleted at the end of code generation
   SmallVector<Instruction *, 4> ToBeDeleted;
 
-  OI.OuterAllocaBB = AllocaIP.getBlock();
+  OI->OuterAllocaBB = AllocaIP.getBlock();
 
   // Mark the body loop as region which needs to be extracted
-  OI.EntryBB = CLI->getBody();
-  OI.ExitBB = CLI->getLatch()->splitBasicBlock(CLI->getLatch()->begin(),
-                                               "omp.prelatch", true);
+  OI->EntryBB = CLI->getBody();
+  OI->ExitBB = CLI->getLatch()->splitBasicBlock(CLI->getLatch()->begin(),
+                                                "omp.prelatch", true);
 
   // Prepare loop body for extraction
   Builder.restoreIP({CLI->getPreheader(), CLI->getPreheader()->begin()});
@@ -5932,7 +6000,7 @@ OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::applyWorkshareLoopTarget(
   // loop body region.
   SmallPtrSet<BasicBlock *, 32> ParallelRegionBlockSet;
   SmallVector<BasicBlock *, 32> Blocks;
-  OI.collectBlocks(ParallelRegionBlockSet, Blocks);
+  OI->collectBlocks(ParallelRegionBlockSet, Blocks);
 
   CodeExtractorAnalysisCache CEAC(*OuterFn);
   CodeExtractor Extractor(Blocks,
@@ -5944,6 +6012,7 @@ OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::applyWorkshareLoopTarget(
                           /* AllowVarArgs */ true,
                           /* AllowAlloca */ true,
                           /* AllocationBlock */ CLI->getPreheader(),
+                          /* DeallocationBlock */ nullptr,
                           /* Suffix */ ".omp_wsloop",
                           /* AggrArgsIn0AddrSpace */ true);
 
@@ -5968,15 +6037,15 @@ OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::applyWorkshareLoopTarget(
   }
   // Make sure that loop counter variable is not merged into loop body
   // function argument structure and it is passed as separate variable
-  OI.ExcludeArgsFromAggregate.push_back(NewLoopCntLoad);
+  OI->ExcludeArgsFromAggregate.push_back(NewLoopCntLoad);
 
   // PostOutline CB is invoked when loop body function is outlined and
   // loop body is replaced by call to outlined function. We need to add
   // call to OpenMP device rtl inside loop preheader. OpenMP device rtl
   // function will handle loop control logic.
   //
-  OI.PostOutlineCB = [=, ToBeDeletedVec =
-                             std::move(ToBeDeleted)](Function &OutlinedFn) {
+  OI->PostOutlineCB = [=, ToBeDeletedVec =
+                              std::move(ToBeDeleted)](Function &OutlinedFn) {
     workshareLoopTargetCallback(this, CLI, Ident, OutlinedFn, ToBeDeletedVec,
                                 LoopType, NoLoop);
   };
@@ -8801,13 +8870,13 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::emitTargetTask(
                                    TargetTaskAllocaBB->begin());
   InsertPointTy TargetTaskBodyIP(TargetTaskBodyBB, TargetTaskBodyBB->begin());
 
-  OutlineInfo OI;
-  OI.EntryBB = TargetTaskAllocaBB;
-  OI.OuterAllocaBB = AllocaIP.getBlock();
+  auto OI = std::make_unique<OutlineInfo>();
+  OI->EntryBB = TargetTaskAllocaBB;
+  OI->OuterAllocaBB = AllocaIP.getBlock();
 
   // Add the thread ID argument.
   SmallVector<Instruction *, 4> ToBeDeleted;
-  OI.ExcludeArgsFromAggregate.push_back(createFakeIntVal(
+  OI->ExcludeArgsFromAggregate.push_back(createFakeIntVal(
       Builder, AllocaIP, ToBeDeleted, TargetTaskAllocaIP, "global.tid", 
false));
 
   // Generate the task body which will subsequently be outlined.
@@ -8825,8 +8894,8 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::emitTargetTask(
   // OI.ExitBlock is set to the single task body block and will get left out of
   // the outlining process. So, simply create a new empty block to which we
   // uncoditionally branch from where TaskBodyCB left off
-  OI.ExitBB = BasicBlock::Create(Builder.getContext(), "target.task.cont");
-  emitBlock(OI.ExitBB, Builder.GetInsertBlock()->getParent(),
+  OI->ExitBB = BasicBlock::Create(Builder.getContext(), "target.task.cont");
+  emitBlock(OI->ExitBB, Builder.GetInsertBlock()->getParent(),
             /*IsFinished=*/true);
 
   SmallVector<Value *, 2> OffloadingArraysToPrivatize;
@@ -8838,13 +8907,13 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::emitTargetTask(
           RTArgs.SizesArray}) {
       if (V && !isa<ConstantPointerNull, GlobalVariable>(V)) {
         OffloadingArraysToPrivatize.push_back(V);
-        OI.ExcludeArgsFromAggregate.push_back(V);
+        OI->ExcludeArgsFromAggregate.push_back(V);
       }
     }
   }
-  OI.PostOutlineCB = [this, ToBeDeleted, Dependencies, NeedsTargetTask,
-                      DeviceID, OffloadingArraysToPrivatize](
-                         Function &OutlinedFn) mutable {
+  OI->PostOutlineCB = [this, ToBeDeleted, Dependencies, NeedsTargetTask,
+                       DeviceID, OffloadingArraysToPrivatize](
+                          Function &OutlinedFn) mutable {
     assert(OutlinedFn.hasOneUse() &&
            "there must be a single user for the outlined function");
 
@@ -10824,17 +10893,17 @@ OpenMPIRBuilder::createTeams(const LocationDescription &Loc,
   if (Error Err = BodyGenCB(AllocaIP, CodeGenIP))
     return Err;
 
-  OutlineInfo OI;
-  OI.EntryBB = AllocaBB;
-  OI.ExitBB = ExitBB;
-  OI.OuterAllocaBB = &OuterAllocaBB;
+  auto OI = std::make_unique<OutlineInfo>();
+  OI->EntryBB = AllocaBB;
+  OI->ExitBB = ExitBB;
+  OI->OuterAllocaBB = &OuterAllocaBB;
 
   // Insert fake values for global tid and bound tid.
   SmallVector<Instruction *, 8> ToBeDeleted;
   InsertPointTy OuterAllocaIP(&OuterAllocaBB, OuterAllocaBB.begin());
-  OI.ExcludeArgsFromAggregate.push_back(createFakeIntVal(
+  OI->ExcludeArgsFromAggregate.push_back(createFakeIntVal(
       Builder, OuterAllocaIP, ToBeDeleted, AllocaIP, "gid", true));
-  OI.ExcludeArgsFromAggregate.push_back(createFakeIntVal(
+  OI->ExcludeArgsFromAggregate.push_back(createFakeIntVal(
       Builder, OuterAllocaIP, ToBeDeleted, AllocaIP, "tid", true));
 
   auto HostPostOutlineCB = [this, Ident,
@@ -10875,7 +10944,7 @@ OpenMPIRBuilder::createTeams(const LocationDescription &Loc,
   };
 
   if (!Config.isTargetDevice())
-    OI.PostOutlineCB = HostPostOutlineCB;
+    OI->PostOutlineCB = HostPostOutlineCB;
 
   addOutlineInfo(std::move(OI));
 
@@ -10914,11 +10983,10 @@ OpenMPIRBuilder::createDistribute(const LocationDescription &Loc,
   // When using target we use different runtime functions which require a
   // callback.
   if (Config.isTargetDevice()) {
-    OutlineInfo OI;
-    OI.OuterAllocaBB = OuterAllocaIP.getBlock();
-    OI.EntryBB = AllocaBB;
-    OI.ExitBB = ExitBB;
-
+    auto OI = std::make_unique<OutlineInfo>();
+    OI->OuterAllocaBB = OuterAllocaIP.getBlock();
+    OI->EntryBB = AllocaBB;
+    OI->ExitBB = ExitBB;
     addOutlineInfo(std::move(OI));
   }
   Builder.SetInsertPoint(ExitBB, ExitBB->begin());
@@ -10980,6 +11048,43 @@ void OpenMPIRBuilder::OutlineInfo::collectBlocks(
   }
 }
 
+// Default implementation: a plain CodeExtractor, which places its allocations
+// in OuterAllocaBB and does not need an explicit deallocation block.
+std::unique_ptr<CodeExtractor>
+OpenMPIRBuilder::OutlineInfo::createCodeExtractor(ArrayRef<BasicBlock *> Blocks,
+                                                  bool ArgsInZeroAddressSpace,
+                                                  Twine Suffix) {
+  return std::make_unique<CodeExtractor>(Blocks, /* DominatorTree */ nullptr,
+                                         /* AggregateArgs */ true,
+                                         /* BlockFrequencyInfo */ nullptr,
+                                         /* BranchProbabilityInfo */ nullptr,
+                                         /* AssumptionCache */ nullptr,
+                                         /* AllowVarArgs */ true,
+                                         /* AllowAlloca */ true,
+                                         /* AllocationBlock */ OuterAllocaBB,
+                                         /* DeallocationBlock */ nullptr,
+                                         /* Suffix */ Suffix.str(),
+                                         ArgsInZeroAddressSpace);
+}
+
+// Use device shared memory for intermediate allocations. Allocations are placed
+// according to AllocBlockOverride and explicit deallocations go into ExitBB.
+std::unique_ptr<CodeExtractor> DeviceSharedMemOutlineInfo::createCodeExtractor(
+    ArrayRef<BasicBlock *> Blocks, bool ArgsInZeroAddressSpace, Twine Suffix) {
+  // TODO: Initialize the DeallocationBlock with a proper pair to OuterAllocaBB.
+  return std::make_unique<DeviceSharedMemCodeExtractor>(
+      OMPBuilder, AllocBlockOverride, Blocks, /* DominatorTree */ nullptr,
+      /* AggregateArgs */ true,
+      /* BlockFrequencyInfo */ nullptr,
+      /* BranchProbabilityInfo */ nullptr,
+      /* AssumptionCache */ nullptr,
+      /* AllowVarArgs */ true,
+      /* AllowAlloca */ true,
+      /* AllocationBlock */ OuterAllocaBB,
+      /* DeallocationBlock */ ExitBB,
+      /* Suffix */ Suffix.str(), ArgsInZeroAddressSpace);
+}
+
 void OpenMPIRBuilder::createOffloadEntry(Constant *ID, Constant *Addr,
                                          uint64_t Size, int32_t Flags,
                                          GlobalValue::LinkageTypes,
diff --git a/llvm/lib/Transforms/IPO/HotColdSplitting.cpp b/llvm/lib/Transforms/IPO/HotColdSplitting.cpp
index 3d8b7cbb59630..57809017a75a4 100644
--- a/llvm/lib/Transforms/IPO/HotColdSplitting.cpp
+++ b/llvm/lib/Transforms/IPO/HotColdSplitting.cpp
@@ -721,6 +721,7 @@ bool HotColdSplitting::outlineColdRegions(Function &F, bool HasProfileSummary) {
             SubRegion, &*DT, /* AggregateArgs */ false, /* BFI */ nullptr,
             /* BPI */ nullptr, AC, /* AllowVarArgs */ false,
             /* AllowAlloca */ false, /* AllocaBlock */ nullptr,
+            /* DeallocationBlock */ nullptr,
             /* Suffix */ "cold." + std::to_string(OutlinedFunctionID));
 
         if (CE.isEligible() && isSplittingBeneficial(CE, SubRegion, TTI) &&
diff --git a/llvm/lib/Transforms/IPO/IROutliner.cpp b/llvm/lib/Transforms/IPO/IROutliner.cpp
index 6e1ca9c4cd2d6..97ad82a627861 100644
--- a/llvm/lib/Transforms/IPO/IROutliner.cpp
+++ b/llvm/lib/Transforms/IPO/IROutliner.cpp
@@ -2825,7 +2825,7 @@ unsigned IROutliner::doOutline(Module &M) {
       OS->Candidate->getBasicBlocks(BlocksInRegion, BE);
       OS->CE = new (ExtractorAllocator.Allocate())
           CodeExtractor(BE, nullptr, false, nullptr, nullptr, nullptr, false,
-                        false, nullptr, "outlined");
+                        false, nullptr, nullptr, "outlined");
       findAddInputsOutputs(M, *OS, NotSame);
       if (!OS->IgnoreRegion)
         OutlinedRegions.push_back(OS);
@@ -2936,7 +2936,7 @@ unsigned IROutliner::doOutline(Module &M) {
       OS->Candidate->getBasicBlocks(BlocksInRegion, BE);
       OS->CE = new (ExtractorAllocator.Allocate())
           CodeExtractor(BE, nullptr, false, nullptr, nullptr, nullptr, false,
-                        false, nullptr, "outlined");
+                        false, nullptr, nullptr, "outlined");
       bool FunctionOutlined = extractSection(*OS);
       if (FunctionOutlined) {
         unsigned StartIdx = OS->Candidate->getStartIdx();
diff --git a/llvm/lib/Transforms/Utils/CodeExtractor.cpp b/llvm/lib/Transforms/Utils/CodeExtractor.cpp
index c01c3d25952ec..08f8cfc1c55b2 100644
--- a/llvm/lib/Transforms/Utils/CodeExtractor.cpp
+++ b/llvm/lib/Transforms/Utils/CodeExtractor.cpp
@@ -262,22 +262,14 @@ CodeExtractor::CodeExtractor(ArrayRef<BasicBlock *> BBs, DominatorTree *DT,
                              bool AggregateArgs, BlockFrequencyInfo *BFI,
                              BranchProbabilityInfo *BPI, AssumptionCache *AC,
                              bool AllowVarArgs, bool AllowAlloca,
-                             BasicBlock *AllocationBlock, std::string Suffix,
-                             bool ArgsInZeroAddressSpace,
-                             CustomArgAllocatorCBTy *CustomArgAllocatorCB,
-                             BasicBlock *DeallocationBlock,
-                             CustomArgDeallocatorCBTy *CustomArgDeallocatorCB)
+                             BasicBlock *AllocationBlock,
+                             BasicBlock *DeallocationBlock, std::string Suffix,
+                             bool ArgsInZeroAddressSpace)
     : DT(DT), AggregateArgs(AggregateArgs || AggregateArgsOpt), BFI(BFI),
       BPI(BPI), AC(AC), AllocationBlock(AllocationBlock),
-      AllowVarArgs(AllowVarArgs),
+      DeallocationBlock(DeallocationBlock), AllowVarArgs(AllowVarArgs),
       Blocks(buildExtractionBlockSet(BBs, DT, AllowVarArgs, AllowAlloca)),
-      Suffix(Suffix), ArgsInZeroAddressSpace(ArgsInZeroAddressSpace),
-      CustomArgAllocatorCB(CustomArgAllocatorCB),
-      DeallocationBlock(DeallocationBlock),
-      CustomArgDeallocatorCB(CustomArgDeallocatorCB) {
-  assert((!CustomArgDeallocatorCB || CustomArgAllocatorCB) &&
-         "custom deallocator only allowed if a custom allocator is provided");
-}
+      Suffix(Suffix), ArgsInZeroAddressSpace(ArgsInZeroAddressSpace) {}
 
 /// definedInRegion - Return true if the specified value is defined in the
 /// extracted region.
@@ -451,6 +443,41 @@ CodeExtractor::findOrCreateBlockForHoisting(BasicBlock *CommonExitBlock) {
   return CommonExitBlock;
 }
 
+/// Default allocation hook. Emits an alloca of \p VarType named \p Name in
+/// the data layout's alloca address space, at \p AllocIP, and returns it.
+/// If \p CastedAlloc is non-null it is always written: it receives an
+/// addrspacecast of the alloca to address space 0 (named Name + ".ascast",
+/// inserted right after the alloca) when ArgsInZeroAddressSpace is set and
+/// the alloca address space is not 0, and nullptr otherwise.
+Instruction *CodeExtractor::allocateVar(BasicBlock *BB,
+                                        BasicBlock::iterator AllocIP,
+                                        Type *VarType, const Twine &Name,
+                                        AddrSpaceCastInst **CastedAlloc) {
+  // Give the optional out-parameter a defined value on every path instead of
+  // relying on each caller to pre-initialize it.
+  if (CastedAlloc)
+    *CastedAlloc = nullptr;
+
+  const DataLayout &DL = BB->getModule()->getDataLayout();
+  Instruction *Alloca =
+      new AllocaInst(VarType, DL.getAllocaAddrSpace(), nullptr, Name, AllocIP);
+
+  if (CastedAlloc && ArgsInZeroAddressSpace && DL.getAllocaAddrSpace() != 0) {
+    *CastedAlloc = new AddrSpaceCastInst(
+        Alloca, PointerType::get(BB->getContext(), 0), Name + ".ascast");
+    (*CastedAlloc)->insertAfter(Alloca->getIterator());
+  }
+  return Alloca;
+}
+
+/// Default deallocation hook. The default allocateVar only creates allocas,
+/// which need no explicit release, so nothing is emitted and nullptr is
+/// returned.
+Instruction *CodeExtractor::deallocateVar(BasicBlock *, BasicBlock::iterator,
+                                          Value *, Type *) {
+  return nullptr;
+}
+
 // Find the pair of life time markers for address 'Addr' that are either
 // defined inside the outline region or can legally be shrinkwrapped into the
 // outline region. If there are not other untracked uses of the address, return
@@ -1828,7 +1841,6 @@ CallInst *CodeExtractor::emitReplacerCall(
     std::vector<Value *> &Reloads) {
   LLVMContext &Context = oldFunction->getContext();
   Module *M = oldFunction->getParent();
-  const DataLayout &DL = M->getDataLayout();
 
   // This takes place of the original loop
   BasicBlock *codeReplacer =
@@ -1859,39 +1871,22 @@ CallInst *CodeExtractor::emitReplacerCall(
     if (StructValues.contains(output))
       continue;
 
-    Value *OutAlloc;
-    if (CustomArgAllocatorCB)
-      OutAlloc = (*CustomArgAllocatorCB)(
-          AllocaBlock, AllocaBlock->getFirstInsertionPt(), output->getType(),
-          output->getName() + ".loc");
-    else
-      OutAlloc = new AllocaInst(output->getType(), DL.getAllocaAddrSpace(),
-                                nullptr, output->getName() + ".loc",
-                                AllocaBlock->getFirstInsertionPt());
-
+    Value *OutAlloc =
+        allocateVar(AllocaBlock, AllocaBlock->getFirstInsertionPt(),
+                    output->getType(), output->getName() + ".loc");
     params.push_back(OutAlloc);
     ReloadOutputs.push_back(OutAlloc);
   }
 
   Instruction *Struct = nullptr;
   if (!StructValues.empty()) {
-    BasicBlock::iterator StructArgIP = AllocaBlock->getFirstInsertionPt();
-    if (CustomArgAllocatorCB) {
-      Struct = (*CustomArgAllocatorCB)(AllocaBlock, StructArgIP, StructArgTy,
-                                       "structArg");
+    AddrSpaceCastInst *StructSpaceCast = nullptr;
+    Struct = allocateVar(AllocaBlock, AllocaBlock->getFirstInsertionPt(),
+                         StructArgTy, "structArg", &StructSpaceCast);
+    if (StructSpaceCast)
+      params.push_back(StructSpaceCast);
+    else
       params.push_back(Struct);
-    } else {
-      Struct = new AllocaInst(StructArgTy, DL.getAllocaAddrSpace(), nullptr,
-                              "structArg", StructArgIP);
-      if (ArgsInZeroAddressSpace && DL.getAllocaAddrSpace() != 0) {
-        auto *StructSpaceCast = new AddrSpaceCastInst(
-            Struct, PointerType ::get(Context, 0), "structArg.ascast");
-        StructSpaceCast->insertAfter(Struct->getIterator());
-        params.push_back(StructSpaceCast);
-      } else {
-        params.push_back(Struct);
-      }
-    }
 
     unsigned AggIdx = 0;
     for (Value *input : inputs) {
@@ -2034,26 +2029,24 @@ CallInst *CodeExtractor::emitReplacerCall(
   insertLifetimeMarkersSurroundingCall(oldFunction->getParent(), LifetimesStart,
                                        {}, call);
 
-  // Deallocate variables that used a custom allocator.
-  if (CustomArgAllocatorCB && CustomArgDeallocatorCB) {
-    BasicBlock *DeallocBlock = codeReplacer;
-    BasicBlock::iterator DeallocIP = codeReplacer->end();
-    if (DeallocationBlock) {
-      DeallocBlock = DeallocationBlock;
-      DeallocIP = DeallocationBlock->getFirstInsertionPt();
-    }
-
-    int Index = 0;
-    for (Value *Output : outputs) {
-      if (!StructValues.contains(Output))
-        (*CustomArgDeallocatorCB)(DeallocBlock, DeallocIP,
-                                  ReloadOutputs[Index++], Output->getType());
-    }
+  // Deallocate intermediate variables if they need explicit deallocation.
+  BasicBlock *DeallocBlock = codeReplacer;
+  BasicBlock::iterator DeallocIP = codeReplacer->end();
+  if (DeallocationBlock) {
+    DeallocBlock = DeallocationBlock;
+    DeallocIP = DeallocationBlock->getFirstInsertionPt();
+  }
 
-    if (Struct)
-      (*CustomArgDeallocatorCB)(DeallocBlock, DeallocIP, Struct, StructArgTy);
+  int Index = 0;
+  for (Value *Output : outputs) {
+    if (!StructValues.contains(Output))
+      deallocateVar(DeallocBlock, DeallocIP, ReloadOutputs[Index++],
+                    Output->getType());
   }
 
+  if (Struct)
+    deallocateVar(DeallocBlock, DeallocIP, Struct, StructArgTy);
+
   return call;
 }
 
diff --git a/llvm/unittests/Transforms/Utils/CodeExtractorTest.cpp b/llvm/unittests/Transforms/Utils/CodeExtractorTest.cpp
index 90f06204ec9b3..8da41318dabf0 100644
--- a/llvm/unittests/Transforms/Utils/CodeExtractorTest.cpp
+++ b/llvm/unittests/Transforms/Utils/CodeExtractorTest.cpp
@@ -711,7 +711,8 @@ TEST(CodeExtractor, OpenMPAggregateArgs) {
                    /* AssumptionCache */ nullptr,
                    /* AllowVarArgs */ true,
                    /* AllowAlloca */ true,
-                   /* AllocaBlock*/ &Func->getEntryBlock(),
+                   /* AllocationBlock*/ &Func->getEntryBlock(),
+                   /* DeallocationBlock */ nullptr,
                    /* Suffix */ ".outlined",
                    /* ArgsInZeroAddressSpace */ true);
 

>From 978cfd5db0468336264c4f737ad042cfd0832585 Mon Sep 17 00:00:00 2001
From: Sergio Afonso <[email protected]>
Date: Thu, 22 Jan 2026 16:17:28 +0000
Subject: [PATCH 5/6] Update after rebase

---
 llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp | 28 +++++++++++------------
 1 file changed, 14 insertions(+), 14 deletions(-)

diff --git a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
index abfba9ebe2302..76ac4d0483416 100644
--- a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
+++ b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
@@ -2276,15 +2276,15 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createTaskloop(
   }
 
   llvm::CanonicalLoopInfo *CLI = result.get();
-  OutlineInfo OI;
-  OI.EntryBB = TaskloopAllocaBB;
-  OI.OuterAllocaBB = AllocaIP.getBlock();
-  OI.ExitBB = TaskloopExitBB;
+  auto OI = std::make_unique<OutlineInfo>();
+  OI->EntryBB = TaskloopAllocaBB;
+  OI->OuterAllocaBB = AllocaIP.getBlock();
+  OI->ExitBB = TaskloopExitBB;
 
   // Add the thread ID argument.
   SmallVector<Instruction *> ToBeDeleted;
   // dummy instruction to be used as a fake argument
-  OI.ExcludeArgsFromAggregate.push_back(createFakeIntVal(
+  OI->ExcludeArgsFromAggregate.push_back(createFakeIntVal(
       Builder, AllocaIP, ToBeDeleted, TaskloopAllocaIP, "global.tid", false));
   Value *FakeLB = createFakeIntVal(Builder, AllocaIP, ToBeDeleted,
                                    TaskloopAllocaIP, "lb", false, true);
@@ -2294,11 +2294,11 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createTaskloop(
                                      TaskloopAllocaIP, "step", false, true);
   // For Taskloop, we want to force the bounds being the first 3 inputs in the
   // aggregate struct
-  OI.Inputs.insert(FakeLB);
-  OI.Inputs.insert(FakeUB);
-  OI.Inputs.insert(FakeStep);
+  OI->Inputs.insert(FakeLB);
+  OI->Inputs.insert(FakeUB);
+  OI->Inputs.insert(FakeStep);
   if (TaskContextStructPtrVal)
-    OI.Inputs.insert(TaskContextStructPtrVal);
+    OI->Inputs.insert(TaskContextStructPtrVal);
   assert(
       (TaskContextStructPtrVal && DupCB) ||
       (!TaskContextStructPtrVal && !DupCB) &&
@@ -2321,11 +2321,11 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createTaskloop(
   }
   Value *TaskDupFn = *TaskDupFnOrErr;
 
-  OI.PostOutlineCB = [this, Ident, LBVal, UBVal, StepVal, Untied,
-                      TaskloopAllocaBB, CLI, Loc, TaskDupFn, ToBeDeleted,
-                      IfCond, GrainSize, NoGroup, Sched, FakeLB, FakeUB,
-                      FakeStep, Final, Mergeable,
-                      Priority](Function &OutlinedFn) mutable {
+  OI->PostOutlineCB = [this, Ident, LBVal, UBVal, StepVal, Untied,
+                       TaskloopAllocaBB, CLI, Loc, TaskDupFn, ToBeDeleted,
+                       IfCond, GrainSize, NoGroup, Sched, FakeLB, FakeUB,
+                       FakeStep, Final, Mergeable,
+                       Priority](Function &OutlinedFn) mutable {
     // Replace the Stale CI by appropriate RTL function call.
     assert(OutlinedFn.hasOneUse() &&
            "there must be a single user for the outlined function");

>From 5628add9d9b52e70b52590d1271152af20e1287f Mon Sep 17 00:00:00 2001
From: Sergio Afonso <[email protected]>
Date: Fri, 23 Jan 2026 14:23:54 +0000
Subject: [PATCH 6/6] Address formatting and ABI issues

---
 .../llvm/Transforms/Utils/CodeExtractor.h     | 47 +++++++++----------
 llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp     |  2 +-
 2 files changed, 23 insertions(+), 26 deletions(-)

diff --git a/llvm/include/llvm/Transforms/Utils/CodeExtractor.h b/llvm/include/llvm/Transforms/Utils/CodeExtractor.h
index 70132e0f31cad..a44305e081588 100644
--- a/llvm/include/llvm/Transforms/Utils/CodeExtractor.h
+++ b/llvm/include/llvm/Transforms/Utils/CodeExtractor.h
@@ -85,7 +85,7 @@ class CodeExtractorAnalysisCache {
   ///    function to arguments.
   /// 3) Add allocas for any scalar outputs, adding all of the outputs' allocas
   ///    as arguments, and inserting stores to the arguments for any scalars.
-  class CodeExtractor {
+  class LLVM_ABI CodeExtractor {
     using ValueSet = SetVector<Value *>;
 
     // Various bits of state computed on construction.
@@ -161,7 +161,6 @@ class CodeExtractorAnalysisCache {
     /// If ArgsInZeroAddressSpace param is set to true, then the aggregate
     /// param pointer of the outlined function is declared in zero address
     /// space.
-    LLVM_ABI
     CodeExtractor(ArrayRef<BasicBlock *> BBs, DominatorTree *DT = nullptr,
                   bool AggregateArgs = false, BlockFrequencyInfo *BFI = nullptr,
                   BranchProbabilityInfo *BPI = nullptr,
@@ -171,14 +170,13 @@ class CodeExtractorAnalysisCache {
                   BasicBlock *DeallocationBlock = nullptr,
                   std::string Suffix = "", bool ArgsInZeroAddressSpace = false);
 
-    LLVM_ABI virtual ~CodeExtractor() = default;
+    virtual ~CodeExtractor() = default;
 
     /// Perform the extraction, returning the new function.
     ///
     /// Returns zero when called on a CodeExtractor instance where isEligible
     /// returns false.
-    LLVM_ABI Function *
-    extractCodeRegion(const CodeExtractorAnalysisCache &CEAC);
+    Function *extractCodeRegion(const CodeExtractorAnalysisCache &CEAC);
 
     /// Perform the extraction, returning the new function and providing an
     /// interface to see what was categorized as inputs and outputs.
@@ -191,15 +189,15 @@ class CodeExtractorAnalysisCache {
     /// newly outlined function.
     /// \returns zero when called on a CodeExtractor instance where isEligible
     /// returns false.
-    LLVM_ABI Function *extractCodeRegion(const CodeExtractorAnalysisCache &CEAC,
-                                         ValueSet &Inputs, ValueSet &Outputs);
+    Function *extractCodeRegion(const CodeExtractorAnalysisCache &CEAC,
+                                ValueSet &Inputs, ValueSet &Outputs);
 
     /// Verify that assumption cache isn't stale after a region is extracted.
     /// Returns true when verifier finds errors. AssumptionCache is passed as
     /// parameter to make this function stateless.
-    LLVM_ABI static bool verifyAssumptionCache(const Function &OldFunc,
-                                               const Function &NewFunc,
-                                               AssumptionCache *AC);
+    static bool verifyAssumptionCache(const Function &OldFunc,
+                                      const Function &NewFunc,
+                                      AssumptionCache *AC);
 
     /// Test whether this code extractor is eligible.
     ///
@@ -208,7 +206,7 @@ class CodeExtractorAnalysisCache {
     ///
     /// Checks that varargs handling (with vastart and vaend) is only done in
     /// the outlined blocks.
-    LLVM_ABI bool isEligible() const;
+    bool isEligible() const;
 
     /// Compute the set of input values and output values for the code.
     ///
@@ -218,15 +216,15 @@ class CodeExtractorAnalysisCache {
     /// a code sequence, that sequence is modified, including changing these
     /// sets, before extraction occurs. These modifications won't have any
     /// significant impact on the cost however.
-    LLVM_ABI void findInputsOutputs(ValueSet &Inputs, ValueSet &Outputs,
-                                    const ValueSet &Allocas,
-                                    bool CollectGlobalInputs = false) const;
+    void findInputsOutputs(ValueSet &Inputs, ValueSet &Outputs,
+                           const ValueSet &Allocas,
+                           bool CollectGlobalInputs = false) const;
 
     /// Check if life time marker nodes can be hoisted/sunk into the outline
     /// region.
     ///
     /// Returns true if it is safe to do the code motion.
-    LLVM_ABI bool
+    bool
     isLegalToShrinkwrapLifetimeMarkers(const CodeExtractorAnalysisCache &CEAC,
                                        Instruction *AllocaAddr) const;
 
@@ -238,9 +236,9 @@ class CodeExtractorAnalysisCache {
     /// are used by the lifetime markers are also candidates for shrink-
     /// wrapping. The instructions that need to be sunk are collected in
     /// 'Allocas'.
-    LLVM_ABI void findAllocas(const CodeExtractorAnalysisCache &CEAC,
-                              ValueSet &SinkCands, ValueSet &HoistCands,
-                              BasicBlock *&ExitBlock) const;
+    void findAllocas(const CodeExtractorAnalysisCache &CEAC,
+                     ValueSet &SinkCands, ValueSet &HoistCands,
+                     BasicBlock *&ExitBlock) const;
 
     /// Find or create a block within the outline region for placing hoisted
     /// code.
@@ -250,25 +248,24 @@ class CodeExtractorAnalysisCache {
     /// inside the region that is the predecessor of CommonExitBlock, that block
     /// will be returned. Otherwise CommonExitBlock will be split and the
     /// original block will be added to the outline region.
-    LLVM_ABI BasicBlock *
-    findOrCreateBlockForHoisting(BasicBlock *CommonExitBlock);
+    BasicBlock *findOrCreateBlockForHoisting(BasicBlock *CommonExitBlock);
 
     /// Exclude a value from aggregate argument passing when extracting a code
     /// region, passing it instead as a scalar.
-    LLVM_ABI void excludeArgFromAggregate(Value *Arg);
+    void excludeArgFromAggregate(Value *Arg);
 
   protected:
     /// Allocate an intermediate variable at the specified point.
-    LLVM_ABI virtual Instruction *
+    virtual Instruction *
     allocateVar(BasicBlock *BB, BasicBlock::iterator AllocIP, Type *VarType,
                 const Twine &Name = Twine(""),
                 AddrSpaceCastInst **CastedAlloc = nullptr);
 
     /// Deallocate a previously-allocated intermediate variable at the specified
     /// point.
-    LLVM_ABI virtual Instruction *deallocateVar(BasicBlock *BB,
-                                                BasicBlock::iterator DeallocIP,
-                                                Value *Var, Type *VarType);
+    virtual Instruction *deallocateVar(BasicBlock *BB,
+                                       BasicBlock::iterator DeallocIP,
+                                       Value *Var, Type *VarType);
 
   private:
     struct LifetimeMarkerInfo {
diff --git a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
index 76ac4d0483416..8bd6d6b105cbb 100644
--- a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
+++ b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
@@ -1882,7 +1882,7 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createParallel(
                            PrivTID, PrivTIDAddr, ToBeDeletedVec);
     };
   }
-  
+
   OI->FixUpNonEntryAllocas = true;
   OI->OuterAllocaBB = OuterAllocaBlock;
   OI->EntryBB = PRegEntryBB;

_______________________________________________
llvm-branch-commits mailing list
[email protected]
https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits

Reply via email to