Author: Jan Leyonberg
Date: 2026-04-01T12:50:09-04:00
New Revision: 91adaeceb162357a33e2ea6155cb13a4198a981a

URL: 
https://github.com/llvm/llvm-project/commit/91adaeceb162357a33e2ea6155cb13a4198a981a
DIFF: 
https://github.com/llvm/llvm-project/commit/91adaeceb162357a33e2ea6155cb13a4198a981a.diff

LOG: [CIR][MLIR][OpenMP] Enable the MarkDeclareTarget pass for ClangIR (#189420)

This patch enables the MarkDeclareTarget pass for CIR by adding the pass to
the CIR-to-LLVM lowering pipeline and attaching the declare target
interface to cir::FuncOp. The MarkDeclareTarget pass is also generalized to
work on FunctionOpInterface instead of func::FuncOp, since it needs to be
able to handle cir::FuncOp as well.

Co-authored-by: Claude Opus 4.6 <[email protected]>

Added: 
    clang/include/clang/CIR/Dialect/OpenMP/RegisterOpenMPExtensions.h
    clang/lib/CIR/Dialect/OpenMP/CMakeLists.txt
    clang/lib/CIR/Dialect/OpenMP/RegisterOpenMPExtensions.cpp
    clang/test/CIR/Transforms/omp-mark-declare-target.cir

Modified: 
    clang/lib/CIR/CodeGen/CIRGenerator.cpp
    clang/lib/CIR/CodeGen/CMakeLists.txt
    clang/lib/CIR/Dialect/CMakeLists.txt
    clang/lib/CIR/Lowering/DirectToLLVM/CMakeLists.txt
    clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
    clang/tools/cir-opt/CMakeLists.txt
    clang/tools/cir-opt/cir-opt.cpp
    clang/tools/cir-translate/CMakeLists.txt
    clang/tools/cir-translate/cir-translate.cpp
    mlir/lib/Dialect/OpenMP/Transforms/CMakeLists.txt
    mlir/lib/Dialect/OpenMP/Transforms/MarkDeclareTarget.cpp

Removed: 
    


################################################################################
diff  --git a/clang/include/clang/CIR/Dialect/OpenMP/RegisterOpenMPExtensions.h 
b/clang/include/clang/CIR/Dialect/OpenMP/RegisterOpenMPExtensions.h
new file mode 100644
index 0000000000000..2247025a4433b
--- /dev/null
+++ b/clang/include/clang/CIR/Dialect/OpenMP/RegisterOpenMPExtensions.h
@@ -0,0 +1,22 @@
+//===----------------------------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef CLANG_CIR_DIALECT_OPENMP_REGISTEROPENMPEXTENSIONS_H
+#define CLANG_CIR_DIALECT_OPENMP_REGISTEROPENMPEXTENSIONS_H
+
+namespace mlir {
+class DialectRegistry;
+} // namespace mlir
+
+namespace cir::omp {
+/// Add to \p registry the extensions that attach OpenMP interfaces to CIR ops.
+void registerOpenMPExtensions(mlir::DialectRegistry &registry);
+
+} // namespace cir::omp
+
+#endif // CLANG_CIR_DIALECT_OPENMP_REGISTEROPENMPEXTENSIONS_H

diff  --git a/clang/lib/CIR/CodeGen/CIRGenerator.cpp 
b/clang/lib/CIR/CodeGen/CIRGenerator.cpp
index 80f85169b73cb..31d40c21ef6e1 100644
--- a/clang/lib/CIR/CodeGen/CIRGenerator.cpp
+++ b/clang/lib/CIR/CodeGen/CIRGenerator.cpp
@@ -21,6 +21,7 @@
 #include "clang/CIR/CIRGenerator.h"
 #include "clang/CIR/Dialect/IR/CIRDialect.h"
 #include "clang/CIR/Dialect/OpenACC/RegisterOpenACCExtensions.h"
+#include "clang/CIR/Dialect/OpenMP/RegisterOpenMPExtensions.h"
 #include "llvm/IR/DataLayout.h"
 
 using namespace cir;
@@ -56,9 +57,10 @@ void CIRGenerator::Initialize(ASTContext &astContext) {
   mlirContext->getOrLoadDialect<mlir::acc::OpenACCDialect>();
   mlirContext->getOrLoadDialect<mlir::omp::OpenMPDialect>();
 
-  // Register extensions to integrate CIR types with OpenACC.
+  // Register extensions to integrate CIR types with OpenACC and OpenMP.
   mlir::DialectRegistry registry;
   cir::acc::registerOpenACCExtensions(registry);
+  cir::omp::registerOpenMPExtensions(registry);
   mlirContext->appendDialectRegistry(registry);
 
   cgm = std::make_unique<clang::CIRGen::CIRGenModule>(

diff  --git a/clang/lib/CIR/CodeGen/CMakeLists.txt 
b/clang/lib/CIR/CodeGen/CMakeLists.txt
index 0afff8ad7f555..3a2616fcd2526 100644
--- a/clang/lib/CIR/CodeGen/CMakeLists.txt
+++ b/clang/lib/CIR/CodeGen/CMakeLists.txt
@@ -65,6 +65,7 @@ add_clang_library(clangCIR
   clangLex
   ${dialect_libs}
   CIROpenACCSupport
+  CIROpenMPSupport
   MLIRCIR
   MLIRCIRInterfaces
   MLIRTargetLLVMIRImport

diff  --git a/clang/lib/CIR/Dialect/CMakeLists.txt 
b/clang/lib/CIR/Dialect/CMakeLists.txt
index c825a61b2779b..e05c9becebbad 100644
--- a/clang/lib/CIR/Dialect/CMakeLists.txt
+++ b/clang/lib/CIR/Dialect/CMakeLists.txt
@@ -1,3 +1,4 @@
 add_subdirectory(IR)
 add_subdirectory(OpenACC)
+add_subdirectory(OpenMP)
 add_subdirectory(Transforms)

diff  --git a/clang/lib/CIR/Dialect/OpenMP/CMakeLists.txt 
b/clang/lib/CIR/Dialect/OpenMP/CMakeLists.txt
new file mode 100644
index 0000000000000..f6f4017f0f1f6
--- /dev/null
+++ b/clang/lib/CIR/Dialect/OpenMP/CMakeLists.txt
@@ -0,0 +1,11 @@
+add_clang_library(CIROpenMPSupport
+  RegisterOpenMPExtensions.cpp
+
+  DEPENDS
+  MLIRCIROpsIncGen
+
+  LINK_LIBS PUBLIC
+  MLIRIR
+  MLIRCIR
+  MLIROpenMPDialect
+  )

diff  --git a/clang/lib/CIR/Dialect/OpenMP/RegisterOpenMPExtensions.cpp 
b/clang/lib/CIR/Dialect/OpenMP/RegisterOpenMPExtensions.cpp
new file mode 100644
index 0000000000000..b5129202e66c4
--- /dev/null
+++ b/clang/lib/CIR/Dialect/OpenMP/RegisterOpenMPExtensions.cpp
@@ -0,0 +1,26 @@
+//===----------------------------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Registration for OpenMP extensions as applied to CIR dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#include "clang/CIR/Dialect/OpenMP/RegisterOpenMPExtensions.h"
+#include "mlir/Dialect/OpenMP/OpenMPInterfaces.h"
+#include "clang/CIR/Dialect/IR/CIRDialect.h"
+
+namespace cir::omp {
+// On CIRDialect load, attach the default DeclareTargetInterface to cir.func.
+void registerOpenMPExtensions(mlir::DialectRegistry &registry) {
+  registry.addExtension(+[](mlir::MLIRContext *ctx, cir::CIRDialect *) {
+    using FuncModel = mlir::omp::DeclareTargetDefaultModel<cir::FuncOp>;
+    cir::FuncOp::attachInterface<FuncModel>(*ctx);
+  });
+}
+
+} // namespace cir::omp

diff  --git a/clang/lib/CIR/Lowering/DirectToLLVM/CMakeLists.txt 
b/clang/lib/CIR/Lowering/DirectToLLVM/CMakeLists.txt
index c7467fe40ba30..021397fee992b 100644
--- a/clang/lib/CIR/Lowering/DirectToLLVM/CMakeLists.txt
+++ b/clang/lib/CIR/Lowering/DirectToLLVM/CMakeLists.txt
@@ -22,6 +22,7 @@ add_clang_library(clangCIRLoweringDirectToLLVM
   MLIRBuiltinToLLVMIRTranslation
   MLIRLLVMToLLVMIRTranslation
   MLIROpenMPToLLVMIRTranslation
+  MLIROpenMPTransforms
   MLIRIR
   )
 

diff  --git a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp 
b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
index ba89fbe3091bc..149cd90b813ec 100644
--- a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
+++ b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
@@ -21,6 +21,7 @@
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/Dialect/LLVMIR/LLVMTypes.h"
+#include "mlir/Dialect/OpenMP/Transforms/Passes.h"
 #include "mlir/Dialect/Ptr/IR/MemorySpaceInterfaces.h"
 #include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/BuiltinDialect.h"
@@ -4936,6 +4937,10 @@ std::unique_ptr<mlir::Pass> createConvertCIRToLLVMPass() {

 void populateCIRToLLVMPasses(mlir::OpPassManager &pm) {
   mlir::populateCIRPreLoweringPasses(pm);
+  // Propagate omp.declare_target to callees while functions are still
+  // cir.func (which implements omp::DeclareTargetInterface via
+  // cir::omp::registerOpenMPExtensions), i.e. before CIR->LLVM conversion.
+  pm.addPass(mlir::omp::createMarkDeclareTargetPass());
   pm.addPass(createConvertCIRToLLVMPass());
 }
 

diff  --git a/clang/test/CIR/Transforms/omp-mark-declare-target.cir 
b/clang/test/CIR/Transforms/omp-mark-declare-target.cir
new file mode 100644
index 0000000000000..914589ec65bcf
--- /dev/null
+++ b/clang/test/CIR/Transforms/omp-mark-declare-target.cir
@@ -0,0 +1,53 @@
+// RUN: cir-opt --omp-mark-declare-target %s -o - | FileCheck %s
+
+// Test that the MarkDeclareTarget pass propagates the declare_target
+// attribute from explicitly marked functions to functions they call,
+// and from omp.target regions to functions called within them.
+
+!s32i = !cir.int<s, 32>
+
+module {
+  // A helper function with no declare_target attribute initially.
+  // After the pass, it should be marked because @caller calls it.
+  // CHECK-LABEL: cir.func @helper
+  // CHECK-SAME: omp.declare_target = #omp.declaretarget<device_type = (host), capture_clause = (to)
+  cir.func @helper() {
+    cir.return
+  }
+
+  // Explicitly marked as declare_target; calls @helper.
+  // CHECK-LABEL: cir.func @caller
+  // CHECK-SAME: omp.declare_target = #omp.declaretarget<device_type = (host), capture_clause = (to)>
+  cir.func @caller() attributes {omp.declare_target = #omp.declaretarget<device_type = (host), capture_clause = (to)>} {
+    cir.call @helper() : () -> ()
+    cir.return
+  }
+
+  // Called from within an omp.target region; should be marked as nohost.
+  // CHECK-LABEL: cir.func @device_helper
+  // CHECK-SAME: omp.declare_target = #omp.declaretarget<device_type = (nohost), capture_clause = (to)
+  cir.func @device_helper() {
+    cir.return
+  }
+
+  // Contains an omp.target region that calls @device_helper.
+  // The function itself should NOT be marked as declare_target.
+  // CHECK-LABEL: cir.func @target_caller
+  // CHECK-NOT: omp.declare_target
+  // CHECK: omp.target
+  cir.func @target_caller() {
+    omp.target {
+      cir.call @device_helper() : () -> ()
+      omp.terminator
+    }
+    cir.return
+  }
+
+  // Not called by any declare_target function or target region.
+  // CHECK-LABEL: cir.func @unrelated
+  // CHECK-NOT: omp.declare_target
+  // CHECK: cir.return
+  cir.func @unrelated() {
+    cir.return
+  }
+}

diff  --git a/clang/tools/cir-opt/CMakeLists.txt 
b/clang/tools/cir-opt/CMakeLists.txt
index cae7de6f056a9..4e9553ed8a7e7 100644
--- a/clang/tools/cir-opt/CMakeLists.txt
+++ b/clang/tools/cir-opt/CMakeLists.txt
@@ -23,6 +23,7 @@ clang_target_link_libraries(cir-opt
   PRIVATE
   clangCIR
   clangCIRLoweringDirectToLLVM
+  CIROpenMPSupport
   MLIRCIR
   MLIRCIRTransforms
 )
@@ -35,6 +36,7 @@ target_link_libraries(cir-opt
   MLIRDialect
   MLIRIR
   MLIRMemRefDialect
+  MLIROpenMPTransforms
   MLIROptLib
   MLIRParser
   MLIRPass

diff  --git a/clang/tools/cir-opt/cir-opt.cpp b/clang/tools/cir-opt/cir-opt.cpp
index a24bf5d581af9..05e3b9ec7e964 100644
--- a/clang/tools/cir-opt/cir-opt.cpp
+++ b/clang/tools/cir-opt/cir-opt.cpp
@@ -18,6 +18,7 @@
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/OpenMP/OpenMPDialect.h"
+#include "mlir/Dialect/OpenMP/Transforms/Passes.h"
 #include "mlir/IR/BuiltinDialect.h"
 #include "mlir/Pass/PassManager.h"
 #include "mlir/Pass/PassOptions.h"
@@ -25,6 +26,7 @@
 #include "mlir/Tools/mlir-opt/MlirOptMain.h"
 #include "mlir/Transforms/Passes.h"
 #include "clang/CIR/Dialect/IR/CIRDialect.h"
+#include "clang/CIR/Dialect/OpenMP/RegisterOpenMPExtensions.h"
 #include "clang/CIR/Dialect/Passes.h"
 #include "clang/CIR/Passes.h"
 
@@ -37,6 +39,7 @@ int main(int argc, char **argv) {
   registry.insert<mlir::BuiltinDialect, cir::CIRDialect,
                   mlir::memref::MemRefDialect, mlir::LLVM::LLVMDialect,
                   mlir::DLTIDialect, mlir::omp::OpenMPDialect>();
+  cir::omp::registerOpenMPExtensions(registry);
 
   ::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> {
     return mlir::createCIRCanonicalizePass();
@@ -71,6 +74,7 @@ int main(int argc, char **argv) {
     return mlir::createCXXABILoweringPass();
   });
 
+  mlir::omp::registerOpenMPPasses();
   mlir::registerTransformsPasses();
 
   return mlir::asMainReturnCode(MlirOptMain(

diff  --git a/clang/tools/cir-translate/CMakeLists.txt 
b/clang/tools/cir-translate/CMakeLists.txt
index 21834799ea82f..53e60220b8736 100644
--- a/clang/tools/cir-translate/CMakeLists.txt
+++ b/clang/tools/cir-translate/CMakeLists.txt
@@ -13,6 +13,7 @@ clang_target_link_libraries(cir-translate
   PRIVATE
   clangCIR
   clangCIRLoweringDirectToLLVM
+  CIROpenMPSupport
   MLIRCIR
   MLIRCIRTransforms
 )

diff  --git a/clang/tools/cir-translate/cir-translate.cpp 
b/clang/tools/cir-translate/cir-translate.cpp
index 2b00d1bd62e4a..997d44dc5a62f 100644
--- a/clang/tools/cir-translate/cir-translate.cpp
+++ b/clang/tools/cir-translate/cir-translate.cpp
@@ -31,6 +31,7 @@
 #include "clang/Basic/DiagnosticOptions.h"
 #include "clang/Basic/TargetInfo.h"
 #include "clang/CIR/Dialect/IR/CIRDialect.h"
+#include "clang/CIR/Dialect/OpenMP/RegisterOpenMPExtensions.h"
 #include "clang/CIR/Dialect/Passes.h"
 #include "clang/CIR/LowerToLLVM.h"
 #include "clang/CIR/MissingFeatures.h"
@@ -169,6 +170,7 @@ void registerToLLVMTranslation() {
         registry.insert<mlir::DLTIDialect, mlir::func::FuncDialect>();
         mlir::registerAllToLLVMIRTranslations(registry);
         cir::direct::registerCIRDialectTranslation(registry);
+        cir::omp::registerOpenMPExtensions(registry);
       });
 }
 

diff  --git a/mlir/lib/Dialect/OpenMP/Transforms/CMakeLists.txt 
b/mlir/lib/Dialect/OpenMP/Transforms/CMakeLists.txt
index a46924cd9878e..9b11d4b87e8df 100644
--- a/mlir/lib/Dialect/OpenMP/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/OpenMP/Transforms/CMakeLists.txt
@@ -6,8 +6,8 @@ add_mlir_dialect_library(MLIROpenMPTransforms
   MLIROpenMPPassIncGen
 
   LINK_LIBS PUBLIC
+  MLIRFunctionInterfaces
   MLIRIR
-  MLIRFuncDialect
   MLIRLLVMDialect
   MLIROpenMPDialect
   MLIRPass

diff  --git a/mlir/lib/Dialect/OpenMP/Transforms/MarkDeclareTarget.cpp 
b/mlir/lib/Dialect/OpenMP/Transforms/MarkDeclareTarget.cpp
index 18a36f73edaf2..e3357e03d9c16 100644
--- a/mlir/lib/Dialect/OpenMP/Transforms/MarkDeclareTarget.cpp
+++ b/mlir/lib/Dialect/OpenMP/Transforms/MarkDeclareTarget.cpp
@@ -10,10 +10,10 @@
 //
 
//===----------------------------------------------------------------------===//
 
-#include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/OpenMP/OpenMPDialect.h"
 #include "mlir/IR/Operation.h"
 #include "mlir/IR/SymbolTable.h"
+#include "mlir/Interfaces/FunctionInterfaces.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Support/LLVM.h"
 #include "llvm/ADT/SmallPtrSet.h"
@@ -42,28 +42,30 @@ class MarkDeclareTargetPass
 
   void processSymbolRef(SymbolRefAttr symRef, ParentInfo parentInfo,
                         llvm::SmallPtrSet<Operation *, 16> visited) {
-    if (auto currFOp = getOperation().lookupSymbol<func::FuncOp>(symRef)) {
-      auto current =
-          llvm::dyn_cast<omp::DeclareTargetInterface>(currFOp.getOperation());
-
-      if (current.isDeclareTarget()) {
-        auto currentDt = current.getDeclareTargetDeviceType();
-
-        // Found the same function twice, with different device_types,
-        // mark as Any as it belongs to both
-        if (currentDt != parentInfo.devTy &&
-            currentDt != omp::DeclareTargetDeviceType::any) {
-          current.setDeclareTarget(omp::DeclareTargetDeviceType::any,
-                                   current.getDeclareTargetCaptureClause(),
-                                   current.getDeclareTargetAutomap());
-        }
-      } else {
-        current.setDeclareTarget(parentInfo.devTy, parentInfo.capClause,
-                                 parentInfo.automap);
-      }
+    // Symbols that cannot be resolved, or whose op does not implement
+    // DeclareTargetInterface, are skipped.
+    auto current = llvm::dyn_cast_if_present<omp::DeclareTargetInterface>(
+        getOperation().lookupSymbol(symRef));
+    if (!current)
+      return;
+
+    if (current.isDeclareTarget()) {
+      auto currentDt = current.getDeclareTargetDeviceType();

-      markNestedFuncs(parentInfo, currFOp, visited);
+      // Reached again with a different device_type than the one already
+      // recorded: widen it to `any`, since it is needed by both.
+      if (currentDt != parentInfo.devTy &&
+          currentDt != omp::DeclareTargetDeviceType::any) {
+        current.setDeclareTarget(omp::DeclareTargetDeviceType::any,
+                                 current.getDeclareTargetCaptureClause(),
+                                 current.getDeclareTargetAutomap());
+      }
+    } else {
+      current.setDeclareTarget(parentInfo.devTy, parentInfo.capClause,
+                               parentInfo.automap);
     }
+
+    markNestedFuncs(parentInfo, current.getOperation(), visited);
   }
 
   void processReductionRefs(std::optional<mlir::ArrayAttr> symRefs,
@@ -138,16 +140,16 @@ class MarkDeclareTargetPass
   // as implicitly declare target if they are called from within an explicitly
   // marked declare target function or a target region (TargetOp)
   void runOnOperation() override {
-    for (auto functionOp : getOperation().getOps<func::FuncOp>()) {
-      auto declareTargetOp = llvm::dyn_cast<omp::DeclareTargetInterface>(
-          functionOp.getOperation());
-      if (declareTargetOp.isDeclareTarget()) {
-        llvm::SmallPtrSet<Operation *, 16> visited;
-        ParentInfo parentInfo{declareTargetOp.getDeclareTargetDeviceType(),
-                              declareTargetOp.getDeclareTargetCaptureClause(),
-                              declareTargetOp.getDeclareTargetAutomap()};
-        markNestedFuncs(parentInfo, functionOp, visited);
-      }
+    for (auto funcOp : getOperation().getOps<FunctionOpInterface>()) {
+      auto declareTargetOp =
+          llvm::dyn_cast<omp::DeclareTargetInterface>(funcOp.getOperation());
+      if (!declareTargetOp || !declareTargetOp.isDeclareTarget())
+        continue;
+      llvm::SmallPtrSet<Operation *, 16> visited;
+      ParentInfo parentInfo{declareTargetOp.getDeclareTargetDeviceType(),
+                            declareTargetOp.getDeclareTargetCaptureClause(),
+                            declareTargetOp.getDeclareTargetAutomap()};
+      markNestedFuncs(parentInfo, funcOp, visited);
     }
 
     // TODO: Extend to work with reverse-offloading, this shouldn't


        
_______________________________________________
cfe-commits mailing list
[email protected]
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to