llvmbot wrote:

<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-clangir

@llvm/pr-subscribers-clang

Author: David Rivera (RiverDave)

<details>
<summary>Changes</summary>

Related: https://github.com/llvm/llvm-project/issues/179278 and
https://github.com/llvm/llvm-project/issues/175871.

Add NVPTX target lowering info that maps CIR language address spaces to NVPTX
target address spaces. Also handle lowering of `#cir.poison` attributes used as
global variable initializers (as emitted for `__shared__` variables). I've added
checks for this lowering pattern to the existing address-space test, and I've
also adjusted the CUDA address-space tests for clarity.

---
Full diff: https://github.com/llvm/llvm-project/pull/186562.diff


6 Files Affected:

- (modified) clang/lib/CIR/Dialect/Transforms/TargetLowering/CMakeLists.txt 
(+1) 
- (modified) clang/lib/CIR/Dialect/Transforms/TargetLowering/LowerModule.cpp 
(+3) 
- (modified) 
clang/lib/CIR/Dialect/Transforms/TargetLowering/TargetLoweringInfo.h (+2) 
- (added) clang/lib/CIR/Dialect/Transforms/TargetLowering/Targets/NVPTX.cpp 
(+39) 
- (modified) clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp (+15-8) 
- (modified) clang/test/CIR/CodeGenCUDA/address-spaces.cu (+36-19) 


``````````diff
diff --git a/clang/lib/CIR/Dialect/Transforms/TargetLowering/CMakeLists.txt 
b/clang/lib/CIR/Dialect/Transforms/TargetLowering/CMakeLists.txt
index 07e3a67f97859..86502b7f5dd4e 100644
--- a/clang/lib/CIR/Dialect/Transforms/TargetLowering/CMakeLists.txt
+++ b/clang/lib/CIR/Dialect/Transforms/TargetLowering/CMakeLists.txt
@@ -4,6 +4,7 @@ add_clang_library(MLIRCIRTargetLowering
   LowerItaniumCXXABI.cpp
   TargetLoweringInfo.cpp
   Targets/AMDGPU.cpp
+  Targets/NVPTX.cpp
 
   DEPENDS
   clangBasic
diff --git a/clang/lib/CIR/Dialect/Transforms/TargetLowering/LowerModule.cpp 
b/clang/lib/CIR/Dialect/Transforms/TargetLowering/LowerModule.cpp
index 26e63b3b676ae..6b6eec473ec89 100644
--- a/clang/lib/CIR/Dialect/Transforms/TargetLowering/LowerModule.cpp
+++ b/clang/lib/CIR/Dialect/Transforms/TargetLowering/LowerModule.cpp
@@ -50,6 +50,9 @@ createTargetLoweringInfo(LowerModule &lm) {
   switch (triple.getArch()) {
   case llvm::Triple::amdgcn:
     return createAMDGPUTargetLoweringInfo();
+  case llvm::Triple::nvptx:
+  case llvm::Triple::nvptx64:
+    return createNVPTXTargetLoweringInfo();
   default:
     assert(!cir::MissingFeatures::targetLoweringInfo());
     return std::make_unique<TargetLoweringInfo>();
diff --git 
a/clang/lib/CIR/Dialect/Transforms/TargetLowering/TargetLoweringInfo.h 
b/clang/lib/CIR/Dialect/Transforms/TargetLowering/TargetLoweringInfo.h
index a307bcb373dec..2f778d8302f02 100644
--- a/clang/lib/CIR/Dialect/Transforms/TargetLowering/TargetLoweringInfo.h
+++ b/clang/lib/CIR/Dialect/Transforms/TargetLowering/TargetLoweringInfo.h
@@ -36,6 +36,8 @@ class TargetLoweringInfo {
 // Target-specific factory functions.
 std::unique_ptr<TargetLoweringInfo> createAMDGPUTargetLoweringInfo();
 
+std::unique_ptr<TargetLoweringInfo> createNVPTXTargetLoweringInfo();
+
 } // namespace cir
 
 #endif
diff --git a/clang/lib/CIR/Dialect/Transforms/TargetLowering/Targets/NVPTX.cpp b/clang/lib/CIR/Dialect/Transforms/TargetLowering/Targets/NVPTX.cpp
new file mode 100644
index 0000000000000..f38d2b8bfa32d
--- /dev/null
+++ b/clang/lib/CIR/Dialect/Transforms/TargetLowering/Targets/NVPTX.cpp
@@ -0,0 +1,57 @@
+//===- NVPTX.cpp ----------------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// NVPTX-specific CIR target lowering info.
+//
+//===----------------------------------------------------------------------===//
+
+#include "../TargetLoweringInfo.h"
+#include "clang/CIR/Dialect/IR/CIROpsEnums.h"
+#include "llvm/Support/ErrorHandling.h"
+#include "llvm/Support/NVPTXAddrSpace.h"
+
+#include <memory>
+
+namespace cir {
+
+namespace {
+
+class NVPTXTargetLoweringInfo : public TargetLoweringInfo {
+public:
+  /// Map a CIR language address space to its NVPTX target address space.
+  ///
+  /// An explicit switch is used instead of a table indexed by the
+  /// enumerator's underlying value, so the mapping does not depend on the
+  /// declaration order of cir::LangAddressSpace and -Wswitch flags any
+  /// enumerator added later without an NVPTX mapping.
+  unsigned getTargetAddrSpaceFromCIRAddrSpace(
+      cir::LangAddressSpace addrSpace) const override {
+    switch (addrSpace) {
+    // Default, private and generic objects all use the generic space.
+    case cir::LangAddressSpace::Default:
+    case cir::LangAddressSpace::OffloadPrivate:
+    case cir::LangAddressSpace::OffloadGeneric:
+      return llvm::NVPTXAS::ADDRESS_SPACE_GENERIC;
+    case cir::LangAddressSpace::OffloadLocal:
+      return llvm::NVPTXAS::ADDRESS_SPACE_SHARED;
+    case cir::LangAddressSpace::OffloadGlobal:
+      return llvm::NVPTXAS::ADDRESS_SPACE_GLOBAL;
+    case cir::LangAddressSpace::OffloadConstant:
+      return llvm::NVPTXAS::ADDRESS_SPACE_CONST;
+    }
+    llvm_unreachable("Unknown CIR address space for NVPTX target");
+  }
+};
+
+} // namespace
+
+std::unique_ptr<TargetLoweringInfo> createNVPTXTargetLoweringInfo() {
+  return std::make_unique<NVPTXTargetLoweringInfo>();
+}
+
+} // namespace cir
diff --git a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp 
b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
index b622fa1ef3205..c60f1276cf5f0 100644
--- a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
+++ b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
@@ -391,7 +391,7 @@ class CIRAttrToValue {
         .Case<cir::BoolAttr, cir::IntAttr, cir::FPAttr, cir::ConstComplexAttr,
               cir::ConstArrayAttr, cir::ConstRecordAttr, cir::ConstVectorAttr,
               cir::ConstPtrAttr, cir::GlobalViewAttr, cir::TypeInfoAttr,
-              cir::UndefAttr, cir::VTableAttr, cir::ZeroAttr>(
+              cir::UndefAttr, cir::PoisonAttr, cir::VTableAttr, cir::ZeroAttr>(
             [&](auto attrT) { return visitCirAttr(attrT); })
         .Default([&](auto attrT) { return mlir::Value(); });
   }
@@ -407,6 +407,7 @@ class CIRAttrToValue {
   mlir::Value visitCirAttr(cir::GlobalViewAttr attr);
   mlir::Value visitCirAttr(cir::TypeInfoAttr attr);
   mlir::Value visitCirAttr(cir::UndefAttr attr);
+  mlir::Value visitCirAttr(cir::PoisonAttr attr);
   mlir::Value visitCirAttr(cir::VTableAttr attr);
   mlir::Value visitCirAttr(cir::ZeroAttr attr);
 
@@ -768,6 +769,13 @@ mlir::Value CIRAttrToValue::visitCirAttr(cir::UndefAttr 
undefAttr) {
       rewriter, loc, converter->convertType(undefAttr.getType()));
 }
 
+/// PoisonAttr visitor: materializes a poison value of the converted type.
+mlir::Value CIRAttrToValue::visitCirAttr(cir::PoisonAttr poisonAttr) {
+  mlir::Location loc = parentOp->getLoc();
+  return mlir::LLVM::PoisonOp::create(
+      rewriter, loc, converter->convertType(poisonAttr.getType()));
+}
+
 // VTableAttr visitor.
 mlir::Value CIRAttrToValue::visitCirAttr(cir::VTableAttr vtableArr) {
   mlir::Type llvmTy = converter->convertType(vtableArr.getType());
@@ -2626,11 +2634,10 @@ 
CIRToLLVMGlobalOpLowering::matchAndRewriteRegionInitializedGlobal(
     cir::GlobalOp op, mlir::Attribute init,
     mlir::ConversionPatternRewriter &rewriter) const {
   // TODO: Generalize this handling when more types are needed here.
-  assert(
-      (isa<cir::ConstArrayAttr, cir::ConstRecordAttr, cir::ConstVectorAttr,
-           cir::ConstPtrAttr, cir::ConstComplexAttr, cir::GlobalViewAttr,
-           cir::TypeInfoAttr, cir::UndefAttr, cir::VTableAttr, cir::ZeroAttr>(
-          init)));
+  assert((isa<cir::ConstArrayAttr, cir::ConstRecordAttr, cir::ConstVectorAttr,
+              cir::ConstPtrAttr, cir::ConstComplexAttr, cir::GlobalViewAttr,
+              cir::TypeInfoAttr, cir::UndefAttr, cir::PoisonAttr,
+              cir::VTableAttr, cir::ZeroAttr>(init)));
 
   // TODO(cir): once LLVM's dialect has proper equivalent attributes this
   // should be updated. For now, we use a custom op to initialize globals
@@ -2691,8 +2698,8 @@ mlir::LogicalResult 
CIRToLLVMGlobalOpLowering::matchAndRewrite(
     } else if (mlir::isa<cir::ConstArrayAttr, cir::ConstVectorAttr,
                          cir::ConstRecordAttr, cir::ConstPtrAttr,
                          cir::ConstComplexAttr, cir::GlobalViewAttr,
-                         cir::TypeInfoAttr, cir::UndefAttr, cir::VTableAttr,
-                         cir::ZeroAttr>(init.value())) {
+                         cir::TypeInfoAttr, cir::UndefAttr, cir::PoisonAttr,
+                         cir::VTableAttr, cir::ZeroAttr>(init.value())) {
       // TODO(cir): once LLVM's dialect has proper equivalent attributes this
       // should be updated. For now, we use a custom op to initialize globals
       // to the appropriate value.
diff --git a/clang/test/CIR/CodeGenCUDA/address-spaces.cu 
b/clang/test/CIR/CodeGenCUDA/address-spaces.cu
index a47a2867e7111..cc1791a8f2244 100644
--- a/clang/test/CIR/CodeGenCUDA/address-spaces.cu
+++ b/clang/test/CIR/CodeGenCUDA/address-spaces.cu
@@ -10,8 +10,9 @@
 // RUN:   -mmlir -mlir-print-ir-before=cir-target-lowering %s -o %t.cir 2> 
%t-pre.cir
 // RUN: FileCheck --check-prefix=CIR-PRE --input-file=%t-pre.cir %s
 
-// TODO: Add CIR (post target lowering) and LLVM checks once NVPTX 
TargetLoweringInfo
-// is implemented.
+// RUN: %clang_cc1 -triple nvptx64-nvidia-cuda -x cuda \
+// RUN:   -fcuda-is-device -fclangir -emit-cir %s -o %t.cir
+// RUN: FileCheck --check-prefix=CIR-POST --input-file=%t.cir %s
 
 // RUN: %clang_cc1 -triple nvptx64-nvidia-cuda -fclangir \
 // RUN:            -fcuda-is-device -emit-llvm -target-sdk-version=12.3 \
@@ -32,26 +33,42 @@
 // CIR-DEVICE: cir.global "private" internal dso_local @_ZZ2fnvE1j = 
#cir.undef : !s32i {alignment = 4 : i64}
 // LLVM-DEVICE: @_ZZ2fnvE1j = internal global i32 undef, align 4
 
-__device__ int a;
-// CIR-PRE: cir.global external lang_address_space(offload_global) @a = 
#cir.int<0> : !s32i {alignment = 4 : i64, cu.externally_initialized = 
#cir.cu.externally_initialized}
-// LLVM-DEVICE: @[[DEV_LD:.*]] = externally_initialized global i32 0, align 4
-// OGCG-DAG: @a = addrspace(1) externally_initialized global i32 0, align 4
-// OGCG-DEVICE: @[[DEV_OD:.*]] = addrspace(1) externally_initialized global 
i32 0, align 4
-
-__constant__ int c;
-// CIR-PRE: cir.global constant external lang_address_space(offload_constant) 
@c = #cir.int<0> : !s32i {alignment = 4 : i64, cu.externally_initialized = 
#cir.cu.externally_initialized}
-// LLVM-DEVICE: @[[CONST_LL:.*]] = externally_initialized constant i32 0, 
align 4
-// OGCG-DAG: @c = addrspace(4) externally_initialized constant i32 0, align 4
-// OGCG-DEVICE: @[[CONST_OD:.*]] = addrspace(4) externally_initialized 
constant i32 0, align 4
-
-// OGCG-DEVICE: @_ZZ2fnvE1j = internal addrspace(3) global i32 undef, align 4
+// CIR-PRE: cir.global external  lang_address_space(offload_global) @i = 
#cir.int<0> : !s32i
+// CIR-POST: cir.global external  target_address_space(1) @i = #cir.int<0> : 
!s32i
+// LLVM-DEVICE-DAG: @i = addrspace(1) {{.*}}global i32 0, align 4
+// OGCG-DAG: @i = addrspace(1) externally_initialized global i32 0, align 4
+__device__ int i;
+
+// CIR-PRE: cir.global constant external  lang_address_space(offload_constant) 
@j = #cir.int<0> : !s32i
+// CIR-POST: cir.global constant external  target_address_space(4) @j = 
#cir.int<0> : !s32i
+// LLVM-DEVICE-DAG: @j = addrspace(4) {{.*}}constant i32 0, align 4
+// OGCG-DAG: @j = addrspace(4) externally_initialized constant i32 0, align 4
+__constant__ int j;
+
+// CIR-PRE: cir.global external  lang_address_space(offload_local) @k = 
#cir.poison : !s32i
+// CIR-POST: cir.global external  target_address_space(3) @k = #cir.poison : 
!s32i
+// LLVM-DEVICE-DAG: @k = addrspace(3) global i32 {{undef|poison}}, align 4
+// OGCG-DAG: @k = addrspace(3) global i32 undef, align 4
+__shared__ int k;
+
+// CIR-PRE: cir.global external  lang_address_space(offload_local) @b = 
#cir.poison : !cir.float
+// CIR-POST: cir.global external  target_address_space(3) @b = #cir.poison : 
!cir.float
+// LLVM-DEVICE-DAG: @b = addrspace(3) global float {{undef|poison}}, align 4
+// OGCG-DAG: @b = addrspace(3) global float undef, align 4
+__shared__ float b;
 
 __device__ void foo() {
-  // CIR-PRE: cir.get_global @a : !cir.ptr<!s32i, 
lang_address_space(offload_global)>
-  a++;
+  // CIR-PRE: cir.get_global @i : !cir.ptr<!s32i, 
lang_address_space(offload_global)>
+  // CIR-POST: cir.get_global @i : !cir.ptr<!s32i, target_address_space(1)>
+  i++;
+
+  // CIR-PRE: cir.get_global @j : !cir.ptr<!s32i, 
lang_address_space(offload_constant)>
+  // CIR-POST: cir.get_global @j : !cir.ptr<!s32i, target_address_space(4)>
+  j++;
 
-  // CIR-PRE: cir.get_global @c : !cir.ptr<!s32i, 
lang_address_space(offload_constant)>
-  c++;
+  // CIR-PRE: cir.get_global @k : !cir.ptr<!s32i, 
lang_address_space(offload_local)>
+  // CIR-POST: cir.get_global @k : !cir.ptr<!s32i, target_address_space(3)>
+  k++;
 }
 
 __global__ void fn() {

``````````

</details>


https://github.com/llvm/llvm-project/pull/186562
_______________________________________________
cfe-commits mailing list
[email protected]
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to