hliao created this revision.
hliao added reviewers: rsmith, rjmccall, tra, yaxunl.
Herald added a reviewer: martong.
Herald added a reviewer: shafik.
Herald added subscribers: cfe-commits, erik.pilkington.
Herald added a project: clang.

- On Windows, extended lambdas have extra issues because the numbering schemes
differ between the host compilation (Microsoft C++ ABI) and the device
compilation (Itanium C++ ABI). An additional device-side lambda number is
required per lambda so that the host compilation can correctly mangle the
device-side lambda name.
- A hybrid numbering context, `MSHIPNumberingContext`, is introduced to number
a lambda for both the host and the device compilations.


Repository:
  rG LLVM Github Monorepo

https://reviews.llvm.org/D69322

Files:
  clang/include/clang/AST/DeclCXX.h
  clang/include/clang/AST/Mangle.h
  clang/include/clang/AST/MangleNumberingContext.h
  clang/include/clang/Sema/Sema.h
  clang/lib/AST/ASTImporter.cpp
  clang/lib/AST/CXXABI.h
  clang/lib/AST/ItaniumCXXABI.cpp
  clang/lib/AST/ItaniumMangle.cpp
  clang/lib/AST/MicrosoftCXXABI.cpp
  clang/lib/CodeGen/CGCUDANV.cpp
  clang/lib/Sema/SemaLambda.cpp
  clang/lib/Sema/TreeTransform.h
  clang/lib/Serialization/ASTReaderDecl.cpp
  clang/lib/Serialization/ASTWriter.cpp
  clang/test/CodeGenCUDA/unnamed-types.cu

Index: clang/test/CodeGenCUDA/unnamed-types.cu
===================================================================
--- clang/test/CodeGenCUDA/unnamed-types.cu
+++ clang/test/CodeGenCUDA/unnamed-types.cu
@@ -1,12 +1,17 @@
 // RUN: %clang_cc1 -std=c++11 -x hip -triple x86_64-linux-gnu -aux-triple amdgcn-amd-amdhsa -emit-llvm %s -o - | FileCheck %s --check-prefix=HOST
+// RUN: %clang_cc1 -std=c++11 -x hip -triple x86_64-pc-windows-msvc -aux-triple amdgcn-amd-amdhsa -emit-llvm %s -o - | FileCheck %s --check-prefix=MSVC
 // RUN: %clang_cc1 -std=c++11 -x hip -triple amdgcn-amd-amdhsa -fcuda-is-device -emit-llvm %s -o - | FileCheck %s --check-prefix=DEVICE
 
 #include "Inputs/cuda.h"
 
 // HOST: @0 = private unnamed_addr constant [43 x i8] c"_Z2k0IZZ2f1PfENKUlS0_E_clES0_EUlfE_EvS0_T_\00", align 1
+// HOST: @1 = private unnamed_addr constant [60 x i8] c"_Z2k1IZ2f1PfEUlfE_Z2f1S0_EUlffE_Z2f1S0_EUlfE0_EvS0_T_T0_T1_\00", align 1
+// Check that, on MSVC, the same device kernel mangling name is generated.
+// MSVC: @0 = private unnamed_addr constant [43 x i8] c"_Z2k0IZZ2f1PfENKUlS0_E_clES0_EUlfE_EvS0_T_\00", align 1
+// MSVC: @1 = private unnamed_addr constant [60 x i8] c"_Z2k1IZ2f1PfEUlfE_Z2f1S0_EUlffE_Z2f1S0_EUlfE0_EvS0_T_T0_T1_\00", align 1
 
 __device__ float d0(float x) {
-  return [](float x) { return x + 2.f; }(x);
+  return [](float x) { return x + 1.f; }(x);
 }
 
 __device__ float d1(float x) {
@@ -14,11 +19,21 @@
 }
 
 // DEVICE: amdgpu_kernel void @_Z2k0IZZ2f1PfENKUlS0_E_clES0_EUlfE_EvS0_T_(
+// DEVICE: define internal float @_ZZZ2f1PfENKUlS_E_clES_ENKUlfE_clEf(
 template <typename F>
 __global__ void k0(float *p, F f) {
   p[0] = f(p[0]) + d0(p[1]) + d1(p[2]);
 }
 
+// DEVICE: amdgpu_kernel void @_Z2k1IZ2f1PfEUlfE_Z2f1S0_EUlffE_Z2f1S0_EUlfE0_EvS0_T_T0_T1_(
+// DEVICE: define internal float @_ZZ2f1PfENKUlfE_clEf(
+// DEVICE: define internal float @_ZZ2f1PfENKUlffE_clEff(
+// DEVICE: define internal float @_ZZ2f1PfENKUlfE0_clEf(
+template <typename F0, typename F1, typename F2>
+__global__ void k1(float *p, F0 f0, F1 f1, F2 f2) {
+  p[0] = f0(p[0]) + f1(p[1], p[2]) + f2(p[3]);
+}
+
 void f0(float *p) {
   [](float *p) {
     *p = 1.f;
@@ -29,11 +44,17 @@
 // linkages are still required to keep the original `internal` linkage.
 
 // HOST: define internal void @_ZZ2f1PfENKUlS_E_clES_(
-// DEVICE: define internal float @_ZZZ2f1PfENKUlS_E_clES_ENKUlfE_clEf(
 void f1(float *p) {
   [](float *p) {
-    k0<<<1,1>>>(p, [] __device__ (float x) { return x + 1.f; });
+    k0<<<1,1>>>(p, [] __device__ (float x) { return x + 3.f; });
   }(p);
+  k1<<<1,1>>>(p,
+              [] __device__ (float x) { return x + 4.f; },
+              [] __device__ (float x, float y) { return x * y; },
+              [] __device__ (float x) { return x + 5.f; });
 }
 // HOST: @__hip_register_globals
 // HOST: __hipRegisterFunction{{.*}}@_Z2k0IZZ2f1PfENKUlS0_E_clES0_EUlfE_EvS0_T_{{.*}}@0
+// HOST: __hipRegisterFunction{{.*}}@_Z2k1IZ2f1PfEUlfE_Z2f1S0_EUlffE_Z2f1S0_EUlfE0_EvS0_T_T0_T1_{{.*}}@1
+// MSVC: __hipRegisterFunction{{.*}}@"??$k0@V<lambda_1>@?0???R1?0??f1@@YAXPEAM@Z@QEBA@0@Z@@@YAXPEAMV<lambda_1>@?0???R0?0??f1@@YAX0@Z@QEBA@0@Z@@Z{{.*}}@0
+// MSVC: __hipRegisterFunction{{.*}}@"??$k1@V<lambda_2>@?0??f1@@YAXPEAM@Z@V<lambda_3>@?0??2@YAX0@Z@V<lambda_4>@?0??2@YAX0@Z@@@YAXPEAMV<lambda_2>@?0??f1@@YAX0@Z@V<lambda_3>@?0??1@YAX0@Z@V<lambda_4>@?0??1@YAX0@Z@@Z{{.*}}@1
Index: clang/lib/Serialization/ASTWriter.cpp
===================================================================
--- clang/lib/Serialization/ASTWriter.cpp
+++ clang/lib/Serialization/ASTWriter.cpp
@@ -6226,6 +6226,7 @@
     Record->push_back(Lambda.NumExplicitCaptures);
     Record->push_back(Lambda.HasKnownInternalLinkage);
     Record->push_back(Lambda.ManglingNumber);
+    Record->push_back(Lambda.DeviceManglingNumber);
     AddDeclRef(D->getLambdaContextDecl());
     AddTypeSourceInfo(Lambda.MethodTyInfo);
     for (unsigned I = 0, N = Lambda.NumCaptures; I != N; ++I) {
Index: clang/lib/Serialization/ASTReaderDecl.cpp
===================================================================
--- clang/lib/Serialization/ASTReaderDecl.cpp
+++ clang/lib/Serialization/ASTReaderDecl.cpp
@@ -1692,6 +1692,7 @@
     Lambda.NumExplicitCaptures = Record.readInt();
     Lambda.HasKnownInternalLinkage = Record.readInt();
     Lambda.ManglingNumber = Record.readInt();
+    Lambda.DeviceManglingNumber = Record.readInt();
     Lambda.ContextDecl = ReadDeclID();
     Lambda.Captures = (Capture *)Reader.getContext().Allocate(
         sizeof(Capture) * Lambda.NumCaptures);
Index: clang/lib/Sema/TreeTransform.h
===================================================================
--- clang/lib/Sema/TreeTransform.h
+++ clang/lib/Sema/TreeTransform.h
@@ -11497,10 +11497,11 @@
                                         E->getCaptureDefault());
   getDerived().transformedLocalDecl(OldClass, {Class});
 
-  Optional<std::tuple<unsigned, bool, Decl *>> Mangling;
+  Optional<std::tuple<bool, unsigned, unsigned, Decl *>> Mangling;
   if (getDerived().ReplacingOriginal())
-    Mangling = std::make_tuple(OldClass->getLambdaManglingNumber(),
-                               OldClass->hasKnownLambdaInternalLinkage(),
+    Mangling = std::make_tuple(OldClass->hasKnownLambdaInternalLinkage(),
+                               OldClass->getLambdaManglingNumber(),
+                               OldClass->getDeviceLambdaManglingNumber(),
                                OldClass->getLambdaContextDecl());
 
   // Build the call operator.
Index: clang/lib/Sema/SemaLambda.cpp
===================================================================
--- clang/lib/Sema/SemaLambda.cpp
+++ clang/lib/Sema/SemaLambda.cpp
@@ -431,15 +431,16 @@
 
 void Sema::handleLambdaNumbering(
     CXXRecordDecl *Class, CXXMethodDecl *Method,
-    Optional<std::tuple<unsigned, bool, Decl *>> Mangling) {
+    Optional<std::tuple<bool, unsigned, unsigned, Decl *>> Mangling) {
   if (Mangling) {
-    unsigned ManglingNumber;
     bool HasKnownInternalLinkage;
+    unsigned ManglingNumber, DeviceManglingNumber;
     Decl *ManglingContextDecl;
-    std::tie(ManglingNumber, HasKnownInternalLinkage, ManglingContextDecl) =
-        Mangling.getValue();
+    std::tie(HasKnownInternalLinkage, ManglingNumber, DeviceManglingNumber,
+             ManglingContextDecl) = Mangling.getValue();
     Class->setLambdaMangling(ManglingNumber, ManglingContextDecl,
                              HasKnownInternalLinkage);
+    Class->setDeviceLambdaManglingNumber(DeviceManglingNumber);
     return;
   }
 
@@ -475,6 +476,10 @@
     unsigned ManglingNumber = MCtx->getManglingNumber(Method);
     Class->setLambdaMangling(ManglingNumber, ManglingContextDecl,
                              HasKnownInternalLinkage);
+    if (MCtx->hasDeviceMangleNumberingContext()) {
+      Class->setDeviceLambdaManglingNumber(
+          MCtx->getDeviceManglingNumber(Method));
+    }
   }
 }
 
Index: clang/lib/CodeGen/CGCUDANV.cpp
===================================================================
--- clang/lib/CodeGen/CGCUDANV.cpp
+++ clang/lib/CodeGen/CGCUDANV.cpp
@@ -166,6 +166,10 @@
   CharPtrTy = llvm::PointerType::getUnqual(Types.ConvertType(Ctx.CharTy));
   VoidPtrTy = cast<llvm::PointerType>(Types.ConvertType(Ctx.VoidPtrTy));
   VoidPtrPtrTy = VoidPtrTy->getPointerTo();
+
+  DeviceMC->setDeviceMangleContext(
+      CGM.getContext().getTargetInfo().getCXXABI().isMicrosoft() &&
+      CGM.getContext().getAuxTargetInfo()->getCXXABI().isItaniumFamily());
 }
 
 llvm::FunctionCallee CGNVCUDARuntime::getSetupArgumentFn() const {
Index: clang/lib/AST/MicrosoftCXXABI.cpp
===================================================================
--- clang/lib/AST/MicrosoftCXXABI.cpp
+++ clang/lib/AST/MicrosoftCXXABI.cpp
@@ -15,6 +15,7 @@
 #include "clang/AST/ASTContext.h"
 #include "clang/AST/Attr.h"
 #include "clang/AST/DeclCXX.h"
+#include "clang/AST/Mangle.h"
 #include "clang/AST/MangleNumberingContext.h"
 #include "clang/AST/RecordLayout.h"
 #include "clang/AST/Type.h"
@@ -63,6 +64,44 @@
   }
 };
 
+class MSHIPNumberingContext : public MangleNumberingContext {
+  MicrosoftNumberingContext HostCtx;
+  std::unique_ptr<MangleNumberingContext> DeviceCtx;
+
+public:
+  MSHIPNumberingContext(MangleContext *Mangler) {
+    DeviceCtx = createItaniumNumberingContext(Mangler);
+  }
+
+  unsigned getManglingNumber(const CXXMethodDecl *CallOperator) override {
+    return HostCtx.getManglingNumber(CallOperator);
+  }
+
+  unsigned getManglingNumber(const BlockDecl *BD) override {
+    return HostCtx.getManglingNumber(BD);
+  }
+
+  unsigned getStaticLocalNumber(const VarDecl *VD) override {
+    return HostCtx.getStaticLocalNumber(VD);
+  }
+
+  unsigned getManglingNumber(const VarDecl *VD,
+                             unsigned MSLocalManglingNumber) override {
+    return HostCtx.getManglingNumber(VD, MSLocalManglingNumber);
+  }
+
+  unsigned getManglingNumber(const TagDecl *TD,
+                             unsigned MSLocalManglingNumber) override {
+    return HostCtx.getManglingNumber(TD, MSLocalManglingNumber);
+  }
+
+  bool hasDeviceMangleNumberingContext() const override { return true; }
+
+  unsigned getDeviceManglingNumber(const CXXMethodDecl *CallOperator) override {
+    return DeviceCtx->getManglingNumber(CallOperator);
+  }
+};
+
 class MicrosoftCXXABI : public CXXABI {
   ASTContext &Context;
   llvm::SmallDenseMap<CXXRecordDecl *, CXXConstructorDecl *> RecordToCopyCtor;
@@ -72,8 +111,17 @@
   llvm::SmallDenseMap<TagDecl *, TypedefNameDecl *>
       UnnamedTagDeclToTypedefNameDecl;
 
+  std::unique_ptr<MangleContext> Mangler;
+
 public:
-  MicrosoftCXXABI(ASTContext &Ctx) : Context(Ctx) { }
+  MicrosoftCXXABI(ASTContext &Ctx) : Context(Ctx) {
+    if (Context.getLangOpts().CUDA) {
+      assert(Context.getTargetInfo().getCXXABI().isMicrosoft() &&
+             Context.getAuxTargetInfo()->getCXXABI().isItaniumFamily() &&
+             "Unexpected C++ ABI combinations.");
+      Mangler.reset(Context.createMangleContext(Context.getAuxTargetInfo()));
+    }
+  }
 
   MemberPointerInfo
   getMemberPointerInfo(const MemberPointerType *MPT) const override;
@@ -132,6 +180,8 @@
 
   std::unique_ptr<MangleNumberingContext>
   createMangleNumberingContext() const override {
+    if (Context.getLangOpts().CUDA)
+      return std::make_unique<MSHIPNumberingContext>(Mangler.get());
     return std::make_unique<MicrosoftNumberingContext>();
   }
 };
@@ -260,4 +310,3 @@
 CXXABI *clang::CreateMicrosoftCXXABI(ASTContext &Ctx) {
   return new MicrosoftCXXABI(Ctx);
 }
-
Index: clang/lib/AST/ItaniumMangle.cpp
===================================================================
--- clang/lib/AST/ItaniumMangle.cpp
+++ clang/lib/AST/ItaniumMangle.cpp
@@ -122,6 +122,8 @@
   llvm::DenseMap<DiscriminatorKeyTy, unsigned> Discriminator;
   llvm::DenseMap<const NamedDecl*, unsigned> Uniquifier;
 
+  bool IsDevCtx = false;
+
 public:
   explicit ItaniumMangleContextImpl(ASTContext &Context,
                                     DiagnosticsEngine &Diags)
@@ -134,6 +136,10 @@
   bool shouldMangleStringLiteral(const StringLiteral *) override {
     return false;
   }
+
+  bool isDeviceMangleContext() const override { return IsDevCtx; }
+  void setDeviceMangleContext(bool IsDev) override { IsDevCtx = IsDev;}
+
   void mangleCXXName(const NamedDecl *D, raw_ostream &) override;
   void mangleThunk(const CXXMethodDecl *MD, const ThunkInfo &Thunk,
                    raw_ostream &) override;
@@ -1765,7 +1771,9 @@
   // (in lexical order) with that same <lambda-sig> and context.
   //
   // The AST keeps track of the number for us.
-  unsigned Number = Lambda->getLambdaManglingNumber();
+  unsigned Number = Context.isDeviceMangleContext()
+                        ? Lambda->getDeviceLambdaManglingNumber()
+                        : Lambda->getLambdaManglingNumber();
   assert(Number > 0 && "Lambda should be mangled as an unnamed class");
   if (Number > 1)
     mangleNumber(Number - 2);
Index: clang/lib/AST/ItaniumCXXABI.cpp
===================================================================
--- clang/lib/AST/ItaniumCXXABI.cpp
+++ clang/lib/AST/ItaniumCXXABI.cpp
@@ -258,3 +258,9 @@
 CXXABI *clang::CreateItaniumCXXABI(ASTContext &Ctx) {
   return new ItaniumCXXABI(Ctx);
 }
+
+std::unique_ptr<MangleNumberingContext>
+clang::createItaniumNumberingContext(MangleContext *Mangler) {
+  return std::make_unique<ItaniumNumberingContext>(
+      cast<ItaniumMangleContext>(Mangler));
+}
Index: clang/lib/AST/CXXABI.h
===================================================================
--- clang/lib/AST/CXXABI.h
+++ clang/lib/AST/CXXABI.h
@@ -22,8 +22,9 @@
 class CXXConstructorDecl;
 class DeclaratorDecl;
 class Expr;
-class MemberPointerType;
+class MangleContext;
 class MangleNumberingContext;
+class MemberPointerType;
 
 /// Implements C++ ABI-specific semantic analysis functions.
 class CXXABI {
@@ -75,6 +76,7 @@
 /// Creates an instance of a C++ ABI class.
 CXXABI *CreateItaniumCXXABI(ASTContext &Ctx);
 CXXABI *CreateMicrosoftCXXABI(ASTContext &Ctx);
+std::unique_ptr<MangleNumberingContext> createItaniumNumberingContext(MangleContext *);
 }
 
 #endif
Index: clang/lib/AST/ASTImporter.cpp
===================================================================
--- clang/lib/AST/ASTImporter.cpp
+++ clang/lib/AST/ASTImporter.cpp
@@ -2696,6 +2696,8 @@
         return CDeclOrErr.takeError();
       D2CXX->setLambdaMangling(DCXX->getLambdaManglingNumber(), *CDeclOrErr,
                                DCXX->hasKnownLambdaInternalLinkage());
+      D2CXX->setDeviceLambdaManglingNumber(
+          DCXX->getDeviceLambdaManglingNumber());
     } else if (DCXX->isInjectedClassName()) {
       // We have to be careful to do a similar dance to the one in
       // Sema::ActOnStartCXXMemberDeclarations
Index: clang/include/clang/Sema/Sema.h
===================================================================
--- clang/include/clang/Sema/Sema.h
+++ clang/include/clang/Sema/Sema.h
@@ -5936,7 +5936,7 @@
   /// Number lambda for linkage purposes if necessary.
   void handleLambdaNumbering(
       CXXRecordDecl *Class, CXXMethodDecl *Method,
-      Optional<std::tuple<unsigned, bool, Decl *>> Mangling = None);
+      Optional<std::tuple<bool, unsigned, unsigned, Decl *>> Mangling = None);
 
   /// Endow the lambda scope info with the relevant properties.
   void buildLambdaScope(sema::LambdaScopeInfo *LSI,
Index: clang/include/clang/AST/MangleNumberingContext.h
===================================================================
--- clang/include/clang/AST/MangleNumberingContext.h
+++ clang/include/clang/AST/MangleNumberingContext.h
@@ -16,6 +16,7 @@
 
 #include "clang/Basic/LLVM.h"
 #include "llvm/ADT/IntrusiveRefCntPtr.h"
+#include "llvm/Support/ErrorHandling.h"
 
 namespace clang {
 
@@ -52,6 +53,15 @@
   /// this context.
   virtual unsigned getManglingNumber(const TagDecl *TD,
                                      unsigned MSLocalManglingNumber) = 0;
+
+  /// Has device mangle numbering context.
+  virtual bool hasDeviceMangleNumberingContext() const { return false; }
+
+  /// Retrieve the mangling number of a new lambda expression with the
+  /// given call operator within the device context.
+  virtual unsigned getDeviceManglingNumber(const CXXMethodDecl *) {
+    llvm_unreachable("There's no device context associated!");
+  }
 };
 
 } // end namespace clang
Index: clang/include/clang/AST/Mangle.h
===================================================================
--- clang/include/clang/AST/Mangle.h
+++ clang/include/clang/AST/Mangle.h
@@ -95,6 +95,9 @@
   virtual bool shouldMangleCXXName(const NamedDecl *D) = 0;
   virtual bool shouldMangleStringLiteral(const StringLiteral *SL) = 0;
 
+  virtual bool isDeviceMangleContext() const { return false; }
+  virtual void setDeviceMangleContext(bool) {}
+
   // FIXME: consider replacing raw_ostream & with something like SmallString &.
   void mangleName(const NamedDecl *D, raw_ostream &);
   virtual void mangleCXXName(const NamedDecl *D, raw_ostream &) = 0;
Index: clang/include/clang/AST/DeclCXX.h
===================================================================
--- clang/include/clang/AST/DeclCXX.h
+++ clang/include/clang/AST/DeclCXX.h
@@ -396,6 +396,9 @@
     /// mangling in the Itanium C++ ABI.
     unsigned ManglingNumber : 31;
 
+    /// The device side mangling number.
+    unsigned DeviceManglingNumber = 0;
+
     /// The declaration that provides context for this lambda, if the
     /// actual DeclContext does not suffice. This is used for lambdas that
     /// occur within default arguments of function parameters within the class
@@ -1736,6 +1739,16 @@
     getLambdaData().HasKnownInternalLinkage = HasKnownInternalLinkage;
   }
 
+  /// Set the device side mangling number.
+  void setDeviceLambdaManglingNumber(unsigned Num) {
+    getLambdaData().DeviceManglingNumber = Num;
+  }
+
+  unsigned getDeviceLambdaManglingNumber() const {
+    assert(isLambda() && "Not a lambda closure type!");
+    return getLambdaData().DeviceManglingNumber;
+  }
+
   /// Returns the inheritance model used for this record.
   MSInheritanceAttr::Spelling getMSInheritanceModel() const;
 
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to