yaxunl updated this revision to Diff 248927.
yaxunl added a comment.

update patch


CHANGES SINCE LAST ACTION
  https://reviews.llvm.org/D68578/new/

https://reviews.llvm.org/D68578

Files:
  clang/include/clang/AST/GlobalDecl.h
  clang/lib/AST/Expr.cpp
  clang/lib/AST/ItaniumMangle.cpp
  clang/lib/AST/Mangle.cpp
  clang/lib/CodeGen/CGCUDANV.cpp
  clang/lib/CodeGen/CGCUDARuntime.h
  clang/lib/CodeGen/CGDecl.cpp
  clang/lib/CodeGen/CGExpr.cpp
  clang/lib/CodeGen/CodeGenModule.cpp
  clang/lib/CodeGen/CodeGenModule.h
  clang/test/CodeGenCUDA/amdgpu-kernel-arg-pointer-type.cu
  clang/test/CodeGenCUDA/kernel-stub-name.cu
  clang/test/CodeGenCUDA/unnamed-types.cu

Index: clang/test/CodeGenCUDA/unnamed-types.cu
===================================================================
--- clang/test/CodeGenCUDA/unnamed-types.cu
+++ clang/test/CodeGenCUDA/unnamed-types.cu
@@ -36,4 +36,4 @@
   }(p);
 }
 // HOST: @__hip_register_globals
-// HOST: __hipRegisterFunction{{.*}}@_Z2k0IZZ2f1PfENKUlS0_E_clES0_EUlfE_EvS0_T_{{.*}}@0
+// HOST: __hipRegisterFunction{{.*}}@_Z17__device_stub__k0IZZ2f1PfENKUlS0_E_clES0_EUlfE_EvS0_T_{{.*}}@0
Index: clang/test/CodeGenCUDA/kernel-stub-name.cu
===================================================================
--- clang/test/CodeGenCUDA/kernel-stub-name.cu
+++ clang/test/CodeGenCUDA/kernel-stub-name.cu
@@ -6,15 +6,50 @@
 
 #include "Inputs/cuda.h"
 
+extern "C" __global__ void ckernel() {}
+
+namespace ns {
+__global__ void nskernel() {}
+} // namespace ns
+
 template<class T>
 __global__ void kernelfunc() {}
 
+__global__ void kernel_decl();
+
+// Device side kernel names
+
+// CHECK: @[[CKERN:[0-9]*]] = {{.*}} c"ckernel\00"
+// CHECK: @[[NSKERN:[0-9]*]] = {{.*}} c"_ZN2ns8nskernelEv\00"
+// CHECK: @[[TKERN:[0-9]*]] = {{.*}} c"_Z10kernelfuncIiEvv\00"
+
+// Non-template kernel stub functions
+
+// CHECK: define{{.*}}@[[CSTUB:__device_stub__ckernel]]
+// CHECK: call{{.*}}@hipLaunchByPtr{{.*}}@[[CSTUB]]
+// CHECK: define{{.*}}@[[NSSTUB:_ZN2ns23__device_stub__nskernelEv]]
+// CHECK: call{{.*}}@hipLaunchByPtr{{.*}}@[[NSSTUB]]
+
 // CHECK-LABEL: define{{.*}}@_Z8hostfuncv()
-// CHECK: call void @[[STUB:_Z10kernelfuncIiEvv.stub]]()
-void hostfunc(void) { kernelfunc<int><<<1, 1>>>(); }
+// CHECK: call void @[[CSTUB]]()
+// CHECK: call void @[[NSSTUB]]()
+// CHECK: call void @[[TSTUB:_Z25__device_stub__kernelfuncIiEvv]]()
+// CHECK: call void @[[DSTUB:_Z26__device_stub__kernel_declv]]()
+void hostfunc(void) {
+  ckernel<<<1, 1>>>();
+  ns::nskernel<<<1, 1>>>();
+  kernelfunc<int><<<1, 1>>>();
+  kernel_decl<<<1, 1>>>();
+}
+
+// Template kernel stub functions
+
+// CHECK: define{{.*}}@[[TSTUB]]
+// CHECK: call{{.*}}@hipLaunchByPtr{{.*}}@[[TSTUB]]
 
-// CHECK: define{{.*}}@[[STUB]]
-// CHECK: call{{.*}}@hipLaunchByPtr{{.*}}@[[STUB]]
+// CHECK: declare{{.*}}@[[DSTUB]]
 
 // CHECK-LABEL: define{{.*}}@__hip_register_globals
-// CHECK: call{{.*}}@__hipRegisterFunction{{.*}}@[[STUB]]
+// CHECK: call{{.*}}@__hipRegisterFunction{{.*}}@[[CSTUB]]{{.*}}@[[CKERN]]
+// CHECK: call{{.*}}@__hipRegisterFunction{{.*}}@[[NSSTUB]]{{.*}}@[[NSKERN]]
+// CHECK: call{{.*}}@__hipRegisterFunction{{.*}}@[[TSTUB]]{{.*}}@[[TKERN]]
Index: clang/test/CodeGenCUDA/amdgpu-kernel-arg-pointer-type.cu
===================================================================
--- clang/test/CodeGenCUDA/amdgpu-kernel-arg-pointer-type.cu
+++ clang/test/CodeGenCUDA/amdgpu-kernel-arg-pointer-type.cu
@@ -13,19 +13,19 @@
 // HOST-NOT: %struct.T.coerce
 
 // CHECK: define amdgpu_kernel void  @_Z7kernel1Pi(i32 addrspace(1)* %x.coerce)
-// HOST: define void @_Z7kernel1Pi.stub(i32* %x)
+// HOST: define void @_Z22__device_stub__kernel1Pi(i32* %x)
 __global__ void kernel1(int *x) {
   x[0]++;
 }
 
 // CHECK: define amdgpu_kernel void  @_Z7kernel2Ri(i32 addrspace(1)* dereferenceable(4) %x.coerce)
-// HOST: define void @_Z7kernel2Ri.stub(i32* dereferenceable(4) %x)
+// HOST: define void @_Z22__device_stub__kernel2Ri(i32* dereferenceable(4) %x)
 __global__ void kernel2(int &x) {
   x++;
 }
 
 // CHECK: define amdgpu_kernel void  @_Z7kernel3PU3AS2iPU3AS1i(i32 addrspace(2)* %x, i32 addrspace(1)* %y)
-// HOST: define void @_Z7kernel3PU3AS2iPU3AS1i.stub(i32 addrspace(2)* %x, i32 addrspace(1)* %y)
+// HOST: define void @_Z22__device_stub__kernel3PU3AS2iPU3AS1i(i32 addrspace(2)* %x, i32 addrspace(1)* %y)
 __global__ void kernel3(__attribute__((address_space(2))) int *x,
                         __attribute__((address_space(1))) int *y) {
   y[0] = x[0];
@@ -43,7 +43,7 @@
 // `by-val` struct will be coerced into a similar struct with all generic
 // pointers lowerd into global ones.
 // CHECK: define amdgpu_kernel void @_Z7kernel41S(%struct.S.coerce %s.coerce)
-// HOST: define void @_Z7kernel41S.stub(i32* %s.coerce0, float* %s.coerce1)
+// HOST: define void @_Z22__device_stub__kernel41S(i32* %s.coerce0, float* %s.coerce1)
 __global__ void kernel4(struct S s) {
   s.x[0]++;
   s.y[0] += 1.f;
@@ -51,7 +51,7 @@
 
 // If a pointer to struct is passed, only the pointer itself is coerced into the global one.
 // CHECK: define amdgpu_kernel void @_Z7kernel5P1S(%struct.S addrspace(1)* %s.coerce)
-// HOST: define void @_Z7kernel5P1S.stub(%struct.S* %s)
+// HOST: define void @_Z22__device_stub__kernel5P1S(%struct.S* %s)
 __global__ void kernel5(struct S *s) {
   s->x[0]++;
   s->y[0] += 1.f;
@@ -62,7 +62,7 @@
 };
 // `by-val` array is also coerced.
 // CHECK: define amdgpu_kernel void @_Z7kernel61T(%struct.T.coerce %t.coerce)
-// HOST: define void @_Z7kernel61T.stub(float* %t.coerce0, float* %t.coerce1)
+// HOST: define void @_Z22__device_stub__kernel61T(float* %t.coerce0, float* %t.coerce1)
 __global__ void kernel6(struct T t) {
   t.x[0][0] += 1.f;
   t.x[1][0] += 2.f;
Index: clang/lib/CodeGen/CodeGenModule.h
===================================================================
--- clang/lib/CodeGen/CodeGenModule.h
+++ clang/lib/CodeGen/CodeGenModule.h
@@ -710,6 +710,9 @@
   CtorList &getGlobalCtors() { return GlobalCtors; }
   CtorList &getGlobalDtors() { return GlobalDtors; }
 
+  /// Get the GlobalDecl for a function that is not a constructor/destructor.
+  GlobalDecl getGlobalDecl(const FunctionDecl *FD);
+
   /// getTBAATypeInfo - Get metadata used to describe accesses to objects of
   /// the given type.
   llvm::MDNode *getTBAATypeInfo(QualType QTy);
Index: clang/lib/CodeGen/CodeGenModule.cpp
===================================================================
--- clang/lib/CodeGen/CodeGenModule.cpp
+++ clang/lib/CodeGen/CodeGenModule.cpp
@@ -1029,6 +1029,9 @@
     if (FD &&
         FD->getType()->castAs<FunctionType>()->getCallConv() == CC_X86RegCall) {
       Out << "__regcall3__" << II->getName();
+    } else if (FD && FD->hasAttr<CUDAGlobalAttr>() &&
+               GD.getKernelReferenceKind() == KernelReferenceKind::Stub) {
+      Out << "__device_stub__" << II->getName();
     } else {
       Out << II->getName();
     }
@@ -1116,11 +1119,25 @@
   const auto *ND = cast<NamedDecl>(GD.getDecl());
   std::string MangledName = getMangledNameImpl(*this, GD, ND);
 
-  // Adjust kernel stub mangling as we may need to be able to differentiate
-  // them from the kernel itself (e.g., for HIP).
-  if (auto *FD = dyn_cast<FunctionDecl>(GD.getDecl()))
-    if (!getLangOpts().CUDAIsDevice && FD->hasAttr<CUDAGlobalAttr>())
-      MangledName = getCUDARuntime().getDeviceStubName(MangledName);
+  // Ensure either we have different ABIs between host and device compilations,
+  // say, host compilation following MSVC ABI but device compilation following
+  // Itanium C++ ABI, or, if they follow the same ABI, kernel names after
+  // mangling should be the same after name stubbing. The latter check is
+  // very important as the device kernel name being mangled in host-compilation
+  // is used to resolve the device binaries to be executed. Inconsistent naming
+  // results in undefined behavior. Even though we cannot check that naming
+  // directly between host- and device-compilations, the host- and
+  // device-mangling in host compilation could help catch certain ones.
+  assert(!isa<FunctionDecl>(ND) || !ND->hasAttr<CUDAGlobalAttr>() ||
+         getLangOpts().CUDAIsDevice ||
+         (getContext().getAuxTargetInfo() &&
+          (getContext().getAuxTargetInfo()->getCXXABI() !=
+           getContext().getTargetInfo().getCXXABI())) ||
+         getCUDARuntime().getDeviceSideName(cast<FunctionDecl>(ND)) ==
+             getMangledNameImpl(
+                 *this,
+                 GD.getWithKernelReferenceKind(KernelReferenceKind::Kernel),
+                 ND));
 
   auto Result = Manglings.insert(std::make_pair(MangledName, GD));
   return MangledDeclNames[CanonicalGD] = Result.first->first();
@@ -5271,7 +5288,7 @@
   case Decl::CXXConversion:
   case Decl::CXXMethod:
   case Decl::Function:
-    EmitGlobal(cast<FunctionDecl>(D));
+    EmitGlobal(getGlobalDecl(cast<FunctionDecl>(D)));
     // Always provide some coverage mapping
     // even for the functions that aren't emitted.
     AddDeferredUnusedCoverageMapping(D);
@@ -5922,3 +5939,10 @@
                                 "__translate_sampler_initializer"),
                                 {C});
 }
+
+GlobalDecl CodeGenModule::getGlobalDecl(const FunctionDecl *FD) {
+  if (FD->hasAttr<CUDAGlobalAttr>())
+    return GlobalDecl::getDefaultKernelReference(FD);
+  else
+    return GlobalDecl(FD);
+}
Index: clang/lib/CodeGen/CGExpr.cpp
===================================================================
--- clang/lib/CodeGen/CGExpr.cpp
+++ clang/lib/CodeGen/CGExpr.cpp
@@ -4669,12 +4669,12 @@
   // Resolve direct calls.
   } else if (auto DRE = dyn_cast<DeclRefExpr>(E)) {
     if (auto FD = dyn_cast<FunctionDecl>(DRE->getDecl())) {
-      return EmitDirectCallee(*this, FD);
+      return EmitDirectCallee(*this, CGM.getGlobalDecl(FD));
     }
   } else if (auto ME = dyn_cast<MemberExpr>(E)) {
     if (auto FD = dyn_cast<FunctionDecl>(ME->getMemberDecl())) {
       EmitIgnoredExpr(ME->getBase());
-      return EmitDirectCallee(*this, FD);
+      return EmitDirectCallee(*this, CGM.getGlobalDecl(FD));
     }
 
   // Look through template substitutions.
Index: clang/lib/CodeGen/CGDecl.cpp
===================================================================
--- clang/lib/CodeGen/CGDecl.cpp
+++ clang/lib/CodeGen/CGDecl.cpp
@@ -297,7 +297,7 @@
   else if (const auto *DD = dyn_cast<CXXDestructorDecl>(DC))
     GD = GlobalDecl(DD, Dtor_Base);
   else if (const auto *FD = dyn_cast<FunctionDecl>(DC))
-    GD = GlobalDecl(FD);
+    GD = getGlobalDecl(FD);
   else {
     // Don't do anything for Obj-C method decls or global closures. We should
     // never defer them.
Index: clang/lib/CodeGen/CGCUDARuntime.h
===================================================================
--- clang/lib/CodeGen/CGCUDARuntime.h
+++ clang/lib/CodeGen/CGCUDARuntime.h
@@ -25,6 +25,7 @@
 namespace clang {
 
 class CUDAKernelCallExpr;
+class Decl;
 class VarDecl;
 
 namespace CodeGen {
@@ -65,9 +66,7 @@
   /// Returns a module cleanup function or nullptr if it's not needed.
   /// Must be called after ModuleCtorFunction
   virtual llvm::Function *makeModuleDtorFunction() = 0;
-
-  /// Construct and return the stub name of a kernel.
-  virtual std::string getDeviceStubName(llvm::StringRef Name) const = 0;
+  virtual std::string getDeviceSideName(const Decl *ND) = 0;
 };
 
 /// Creates an instance of a CUDA runtime class.
Index: clang/lib/CodeGen/CGCUDANV.cpp
===================================================================
--- clang/lib/CodeGen/CGCUDANV.cpp
+++ clang/lib/CodeGen/CGCUDANV.cpp
@@ -117,7 +117,7 @@
 
   void emitDeviceStubBodyLegacy(CodeGenFunction &CGF, FunctionArgList &Args);
   void emitDeviceStubBodyNew(CodeGenFunction &CGF, FunctionArgList &Args);
-  std::string getDeviceSideName(const Decl *ND);
+  std::string getDeviceSideName(const Decl *ND) override;
 
 public:
   CGNVCUDARuntime(CodeGenModule &CGM);
@@ -132,8 +132,6 @@
   llvm::Function *makeModuleCtorFunction() override;
   /// Creates module destructor function
   llvm::Function *makeModuleDtorFunction() override;
-  /// Construct and return the stub name of a kernel.
-  std::string getDeviceStubName(llvm::StringRef Name) const override;
 };
 
 }
@@ -206,11 +204,17 @@
 
 std::string CGNVCUDARuntime::getDeviceSideName(const Decl *D) {
   auto *ND = cast<const NamedDecl>(D);
+  GlobalDecl GD;
+  // D could be either a kernel or a variable.
+  if (auto *FD = dyn_cast<FunctionDecl>(D))
+    GD = GlobalDecl(FD, KernelReferenceKind::Kernel);
+  else
+    GD = GlobalDecl(ND);
   std::string DeviceSideName;
   if (DeviceMC->shouldMangleDeclName(ND)) {
     SmallString<256> Buffer;
     llvm::raw_svector_ostream Out(Buffer);
-    DeviceMC->mangleName(ND, Out);
+    DeviceMC->mangleName(GD, Out);
     DeviceSideName = std::string(Out.str());
   } else
     DeviceSideName = std::string(ND->getIdentifier()->getName());
@@ -219,21 +223,6 @@
 
 void CGNVCUDARuntime::emitDeviceStub(CodeGenFunction &CGF,
                                      FunctionArgList &Args) {
-  // Ensure either we have different ABIs between host and device compilations,
-  // says host compilation following MSVC ABI but device compilation follows
-  // Itanium C++ ABI or, if they follow the same ABI, kernel names after
-  // mangling should be the same after name stubbing. The later checking is
-  // very important as the device kernel name being mangled in host-compilation
-  // is used to resolve the device binaries to be executed. Inconsistent naming
-  // result in undefined behavior. Even though we cannot check that naming
-  // directly between host- and device-compilations, the host- and
-  // device-mangling in host compilation could help catching certain ones.
-  assert((CGF.CGM.getContext().getAuxTargetInfo() &&
-          (CGF.CGM.getContext().getAuxTargetInfo()->getCXXABI() !=
-           CGF.CGM.getContext().getTargetInfo().getCXXABI())) ||
-         getDeviceStubName(getDeviceSideName(CGF.CurFuncDecl)) ==
-             CGF.CurFn->getName());
-
   EmittedKernels.push_back({CGF.CurFn, CGF.CurFuncDecl});
   if (CudaFeatureEnabled(CGM.getTarget().getSDKVersion(),
                          CudaFeature::CUDA_USES_NEW_LAUNCH) ||
@@ -797,12 +786,6 @@
   return ModuleDtorFunc;
 }
 
-std::string CGNVCUDARuntime::getDeviceStubName(llvm::StringRef Name) const {
-  if (!CGM.getLangOpts().HIP)
-    return std::string(Name);
-  return (Name + ".stub").str();
-}
-
 CGCUDARuntime *CodeGen::CreateNVCUDARuntime(CodeGenModule &CGM) {
   return new CGNVCUDARuntime(CGM);
 }
Index: clang/lib/AST/Mangle.cpp
===================================================================
--- clang/lib/AST/Mangle.cpp
+++ clang/lib/AST/Mangle.cpp
@@ -426,6 +426,8 @@
         GD = GlobalDecl(CtorD, Ctor_Complete);
       else if (const auto *DtorD = dyn_cast<CXXDestructorDecl>(D))
         GD = GlobalDecl(DtorD, Dtor_Complete);
+      else if (D->hasAttr<CUDAGlobalAttr>())
+        GD = GlobalDecl::getDefaultKernelReference(cast<FunctionDecl>(D));
       else
         GD = GlobalDecl(D);
       MC->mangleName(GD, OS);
Index: clang/lib/AST/ItaniumMangle.cpp
===================================================================
--- clang/lib/AST/ItaniumMangle.cpp
+++ clang/lib/AST/ItaniumMangle.cpp
@@ -480,6 +480,7 @@
                                   const AbiTagList *AdditionalAbiTags);
   void mangleSourceName(const IdentifierInfo *II);
   void mangleRegCallName(const IdentifierInfo *II);
+  void mangleDeviceStubName(const IdentifierInfo *II);
   void mangleSourceNameWithAbiTags(
       const NamedDecl *ND, const AbiTagList *AdditionalAbiTags = nullptr);
   void mangleLocalName(GlobalDecl GD,
@@ -1307,7 +1308,12 @@
       bool IsRegCall = FD &&
                        FD->getType()->castAs<FunctionType>()->getCallConv() ==
                            clang::CC_X86RegCall;
-      if (IsRegCall)
+      bool IsDeviceStub =
+          FD && FD->hasAttr<CUDAGlobalAttr>() &&
+          GD.getKernelReferenceKind() == KernelReferenceKind::Stub;
+      if (IsDeviceStub)
+        mangleDeviceStubName(II);
+      else if (IsRegCall)
         mangleRegCallName(II);
       else
         mangleSourceName(II);
@@ -1496,6 +1502,14 @@
       << II->getName();
 }
 
+void CXXNameMangler::mangleDeviceStubName(const IdentifierInfo *II) {
+  // <source-name> ::= <positive length number> __device_stub__ <identifier>
+  // <number> ::= [n] <non-negative decimal integer>
+  // <identifier> ::= <unqualified source code identifier>
+  Out << II->getLength() + sizeof("__device_stub__") - 1 << "__device_stub__"
+      << II->getName();
+}
+
 void CXXNameMangler::mangleSourceName(const IdentifierInfo *II) {
   // <source-name> ::= <positive length number> <identifier>
   // <number> ::= [n] <non-negative decimal integer>
@@ -1559,8 +1573,14 @@
     GD = GlobalDecl(CD, Ctor_Complete);
   else if (auto *DD = dyn_cast<CXXDestructorDecl>(DC))
     GD = GlobalDecl(DD, Dtor_Complete);
-  else
-    GD = GlobalDecl(cast<FunctionDecl>(DC));
+  else {
+    auto *FD = cast<FunctionDecl>(DC);
+    // Local variables can only exist in real kernels.
+    if (FD->hasAttr<CUDAGlobalAttr>())
+      GD = GlobalDecl(FD, KernelReferenceKind::Kernel);
+    else
+      GD = GlobalDecl(FD);
+  }
   return GD;
 }
 
Index: clang/lib/AST/Expr.cpp
===================================================================
--- clang/lib/AST/Expr.cpp
+++ clang/lib/AST/Expr.cpp
@@ -689,6 +689,8 @@
           GD = GlobalDecl(CD, Ctor_Base);
         else if (const CXXDestructorDecl *DD = dyn_cast<CXXDestructorDecl>(ND))
           GD = GlobalDecl(DD, Dtor_Base);
+        else if (ND->hasAttr<CUDAGlobalAttr>())
+          GD = GlobalDecl::getDefaultKernelReference(cast<FunctionDecl>(ND));
         else
           GD = GlobalDecl(ND);
         MC->mangleName(GD, Out);
Index: clang/include/clang/AST/GlobalDecl.h
===================================================================
--- clang/include/clang/AST/GlobalDecl.h
+++ clang/include/clang/AST/GlobalDecl.h
@@ -14,6 +14,7 @@
 #ifndef LLVM_CLANG_AST_GLOBALDECL_H
 #define LLVM_CLANG_AST_GLOBALDECL_H
 
+#include "clang/AST/Attr.h"
 #include "clang/AST/DeclCXX.h"
 #include "clang/AST/DeclObjC.h"
 #include "clang/AST/DeclOpenMP.h"
@@ -33,6 +34,11 @@
   AtExit,
 };
 
+enum class KernelReferenceKind : unsigned {
+  Kernel = 0,
+  Stub = 1,
+};
+
 /// GlobalDecl - represents a global declaration. This can either be a
 /// CXXConstructorDecl and the constructor type (Base, Complete).
 /// a CXXDestructorDecl and the destructor type (Base, Complete),
@@ -52,6 +58,7 @@
   void Init(const Decl *D) {
     assert(!isa<CXXConstructorDecl>(D) && "Use other ctor with ctor decls!");
     assert(!isa<CXXDestructorDecl>(D) && "Use other ctor with dtor decls!");
+    assert(!D->hasAttr<CUDAGlobalAttr>() && "Use other ctor with HIP kernels!");
 
     Value.setPointer(D);
   }
@@ -73,6 +80,10 @@
   GlobalDecl(const CXXDestructorDecl *D, CXXDtorType Type) : Value(D, Type) {}
   GlobalDecl(const VarDecl *D, DynamicInitKind StubKind)
       : Value(D, unsigned(StubKind)) {}
+  GlobalDecl(const FunctionDecl *D, KernelReferenceKind Kind)
+      : Value(D, unsigned(Kind)) {
+    assert(D->hasAttr<CUDAGlobalAttr>() && "Decl is not a HIP kernel!");
+  }
 
   GlobalDecl getCanonicalDecl() const {
     GlobalDecl CanonGD;
@@ -103,13 +114,22 @@
   }
 
   unsigned getMultiVersionIndex() const {
-    assert(isa<FunctionDecl>(getDecl()) &&
+    assert(isa<FunctionDecl>(
+               getDecl()) &&
+               !cast<FunctionDecl>(getDecl())->hasAttr<CUDAGlobalAttr>() &&
            !isa<CXXConstructorDecl>(getDecl()) &&
            !isa<CXXDestructorDecl>(getDecl()) &&
            "Decl is not a plain FunctionDecl!");
     return MultiVersionIndex;
   }
 
+  KernelReferenceKind getKernelReferenceKind() const {
+    assert(isa<FunctionDecl>(getDecl()) &&
+           cast<FunctionDecl>(getDecl())->hasAttr<CUDAGlobalAttr>() &&
+           "Decl is not a HIP kernel!");
+    return static_cast<KernelReferenceKind>(Value.getInt());
+  }
+
   friend bool operator==(const GlobalDecl &LHS, const GlobalDecl &RHS) {
     return LHS.Value == RHS.Value &&
            LHS.MultiVersionIndex == RHS.MultiVersionIndex;
@@ -125,6 +145,12 @@
     return GD;
   }
 
+  static GlobalDecl getDefaultKernelReference(const FunctionDecl *D) {
+    return GlobalDecl(D, D->getASTContext().getLangOpts().CUDAIsDevice
+                             ? KernelReferenceKind::Kernel
+                             : KernelReferenceKind::Stub);
+  }
+
   GlobalDecl getWithDecl(const Decl *D) {
     GlobalDecl Result(*this);
     Result.Value.setPointer(D);
@@ -147,6 +173,7 @@
 
   GlobalDecl getWithMultiVersionIndex(unsigned Index) {
     assert(isa<FunctionDecl>(getDecl()) &&
+           !cast<FunctionDecl>(getDecl())->hasAttr<CUDAGlobalAttr>() &&
            !isa<CXXConstructorDecl>(getDecl()) &&
            !isa<CXXDestructorDecl>(getDecl()) &&
            "Decl is not a plain FunctionDecl!");
@@ -154,6 +181,15 @@
     Result.MultiVersionIndex = Index;
     return Result;
   }
+
+  GlobalDecl getWithKernelReferenceKind(KernelReferenceKind Kind) {
+    assert(isa<FunctionDecl>(getDecl()) &&
+           cast<FunctionDecl>(getDecl())->hasAttr<CUDAGlobalAttr>() &&
+           "Decl is not a HIP kernel!");
+    GlobalDecl Result(*this);
+    Result.Value.setInt(unsigned(Kind));
+    return Result;
+  }
 };
 
 } // namespace clang
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to