jhuber6 created this revision.
jhuber6 added reviewers: jdoerfert, JonChesterfield, tra, yaxunl.
Herald added a subscriber: carlosgalvezp.
Herald added a project: All.
jhuber6 requested review of this revision.
Herald added a project: clang.
Herald added a subscriber: cfe-commits.

This patch adds the necessary code generation to create the wrapper code
that registers all the globals in CUDA. We create the necessary
functions and iterate through the entries between
`__start_cuda_offloading_entries` and `__stop_cuda_offloading_entries`
to find which globals must be registered. This is very similar to the
code generation done currently in Clang for non-rdc builds, but here we
are registering a fully linked fatbinary and finding the globals via the
above sections.

With this we should be able to fully support basic RDC / LTO building of CUDA
code.

It's also worth noting that this does not include the necessary PTX to JIT the
image, so to use this support, the offloading architecture must match the
system's architecture.

Depends on D123810 <https://reviews.llvm.org/D123810>


Repository:
  rG LLVM Github Monorepo

https://reviews.llvm.org/D123812

Files:
  clang/tools/clang-linker-wrapper/ClangLinkerWrapper.cpp
  clang/tools/clang-linker-wrapper/OffloadWrapper.cpp

Index: clang/tools/clang-linker-wrapper/OffloadWrapper.cpp
===================================================================
--- clang/tools/clang-linker-wrapper/OffloadWrapper.cpp
+++ clang/tools/clang-linker-wrapper/OffloadWrapper.cpp
@@ -20,6 +20,8 @@
 using namespace llvm;
 
 namespace {
+/// Magic number that begins the section containing the CUDA fatbinary.
+constexpr unsigned CudaFatMagic = 0x466243b1;
 
 IntegerType *getSizeTTy(Module &M) {
   LLVMContext &C = M.getContext();
@@ -255,6 +257,265 @@
   appendToGlobalDtors(M, Func, /*Priority*/ 1);
 }
 
+// struct fatbin_wrapper {
+//  int32_t magic;
+//  int32_t version;
+//  void *image;
+//  void *reserved;
+//};
+StructType *getFatbinWrapperTy(Module &M) {
+  LLVMContext &C = M.getContext();
+  StructType *FatbinTy = StructType::getTypeByName(C, "fatbin_wrapper");
+  if (!FatbinTy)
+    FatbinTy = StructType::create("fatbin_wrapper", Type::getInt32Ty(C),
+                                  Type::getInt32Ty(C), Type::getInt8PtrTy(C),
+                                  Type::getInt8PtrTy(C));
+  return FatbinTy;
+}
+
+/// Embed the image \p Image into the module \p M so it can be found by the
+/// runtime.
+GlobalVariable *createFatbinDesc(Module &M, ArrayRef<char> Image) {
+  LLVMContext &C = M.getContext();
+  llvm::Type *Int8PtrTy = Type::getInt8PtrTy(C);
+  llvm::Triple Triple = llvm::Triple(M.getTargetTriple());
+
+  // Create the global string containing the fatbinary.
+  StringRef FatbinConstantSection =
+      Triple.isMacOSX() ? "__NV_CUDA,__nv_fatbin" : ".nv_fatbin";
+  auto *Data = ConstantDataArray::get(C, Image);
+  auto *Fatbin = new GlobalVariable(M, Data->getType(), /*isConstant*/ true,
+                                    GlobalVariable::InternalLinkage, Data,
+                                    ".fatbin_image");
+  Fatbin->setSection(FatbinConstantSection);
+
+  // Create the fatbinary wrapper
+  StringRef FatbinWrapperSection =
+      Triple.isMacOSX() ? "__NV_CUDA,__fatbin" : ".nvFatBinSegment";
+  Constant *FatbinWrapper[] = {
+      ConstantInt::get(Type::getInt32Ty(C), CudaFatMagic),
+      ConstantInt::get(Type::getInt32Ty(C), 1),
+      ConstantExpr::getPointerBitCastOrAddrSpaceCast(Fatbin, Int8PtrTy),
+      ConstantPointerNull::get(Type::getInt8PtrTy(C))};
+
+  Constant *FatbinInitializer =
+      ConstantStruct::get(getFatbinWrapperTy(M), FatbinWrapper);
+
+  auto *FatbinDesc =
+      new GlobalVariable(M, getFatbinWrapperTy(M),
+                         /*isConstant*/ true, GlobalValue::InternalLinkage,
+                         FatbinInitializer, ".fatbin_wrapper");
+  FatbinDesc->setSection(FatbinWrapperSection);
+  FatbinDesc->setAlignment(Align(8));
+
+  // We create a dummy entry to ensure the linker will define the begin / end
+  // symbols. The CUDA runtime should ignore the null address if we attempt to
+  // register it.
+  auto *DummyInit =
+      ConstantAggregateZero::get(ArrayType::get(getEntryTy(M), 0u));
+  auto *DummyEntry = new GlobalVariable(
+      M, DummyInit->getType(), true, GlobalVariable::ExternalLinkage, DummyInit,
+      "__dummy.cuda_offloading.entry");
+  DummyEntry->setSection("cuda_offloading_entries");
+  DummyEntry->setVisibility(GlobalValue::HiddenVisibility);
+
+  return FatbinDesc;
+}
+
+/// Create the register globals function. We will iterate all of the offloading
+/// entries stored at the begin / end symbols and register them according to
+/// their type. This creates the following function in IR:
+///
+/// extern struct __tgt_offload_entry __start_cuda_offloading_entries;
+/// extern struct __tgt_offload_entry __stop_cuda_offloading_entries;
+///
+/// extern void __cudaRegisterFunction(void **, void *, void *, void *, int,
+///                                    void *, void *, void *, void *, int *);
+/// extern void __cudaRegisterVar(void **, void *, void *, void *, int32_t,
+///                               int64_t, int32_t, int32_t);
+///
+/// void __cudaRegisterTest(void **fatbinHandle) {
+///   for (struct __tgt_offload_entry *entry = &__start_cuda_offloading_entries;
+///        entry != &__stop_cuda_offloading_entries; ++entry) {
+///     if (!entry->size)
+///       __cudaRegisterFunction(fatbinHandle, entry->addr, entry->name,
+///                              entry->name, -1, 0, 0, 0, 0, 0);
+///     else
+///       __cudaRegisterVar(fatbinHandle, entry->addr, entry->name, entry->name,
+///                         0, entry->size, 0, 0);
+///   }
+/// }
+///
+/// TODO: This only registers functions and variables. Additional support is
+///       required for texture / surface / managed variables.
+Function *createRegisterGlobalsFunction(Module &M) {
+  LLVMContext &C = M.getContext();
+  // Get the __cudaRegisterFunction function declaration.
+  auto *RegFuncTy = FunctionType::get(
+      Type::getInt32Ty(C),
+      {Type::getInt8PtrTy(C)->getPointerTo(), Type::getInt8PtrTy(C),
+       Type::getInt8PtrTy(C), Type::getInt8PtrTy(C), Type::getInt32Ty(C),
+       Type::getInt8PtrTy(C), Type::getInt8PtrTy(C), Type::getInt8PtrTy(C),
+       Type::getInt8PtrTy(C), Type::getInt32PtrTy(C)},
+      /*isVarArg*/ false);
+  FunctionCallee RegFunc =
+      M.getOrInsertFunction("__cudaRegisterFunction", RegFuncTy);
+
+  // Get the __cudaRegisterVar function declaration.
+  auto *RegVarTy = FunctionType::get(
+      Type::getInt32Ty(C),
+      {Type::getInt8PtrTy(C)->getPointerTo(), Type::getInt8PtrTy(C),
+       Type::getInt8PtrTy(C), Type::getInt8PtrTy(C), Type::getInt32Ty(C),
+       getSizeTTy(M), Type::getInt32Ty(C), Type::getInt32Ty(C)},
+      /*isVarArg*/ false);
+  FunctionCallee RegVar = M.getOrInsertFunction("__cudaRegisterVar", RegVarTy);
+
+  // Create the references to the start / stop symbols defined by the linker.
+  auto *EntriesB = new GlobalVariable(
+      M, getEntryTy(M), /*isConstant*/ true, GlobalValue::ExternalLinkage,
+      /*Initializer*/ nullptr, "__start_cuda_offloading_entries");
+  EntriesB->setVisibility(GlobalValue::HiddenVisibility);
+  auto *EntriesE = new GlobalVariable(
+      M, getEntryTy(M), /*isConstant*/ true, GlobalValue::ExternalLinkage,
+      /*Initializer*/ nullptr, "__stop_cuda_offloading_entries");
+  EntriesE->setVisibility(GlobalValue::HiddenVisibility);
+
+  auto *RegGlobalsTy = FunctionType::get(Type::getVoidTy(C),
+                                         Type::getInt8PtrTy(C)->getPointerTo(),
+                                         /*isVarArg*/ false);
+  auto *RegGlobalsFn = Function::Create(
+      RegGlobalsTy, GlobalValue::InternalLinkage, ".cuda.globals_reg", &M);
+  RegGlobalsFn->setSection(".text.startup");
+
+  // Create the loop to register all the entries.
+  IRBuilder<> Builder(BasicBlock::Create(C, "entry", RegGlobalsFn));
+  auto *EntryBB = BasicBlock::Create(C, "while.entry", RegGlobalsFn);
+  auto *IfThenBB = BasicBlock::Create(C, "if.then", RegGlobalsFn);
+  auto *IfElseBB = BasicBlock::Create(C, "if.else", RegGlobalsFn);
+  auto *IfEndBB = BasicBlock::Create(C, "if.end", RegGlobalsFn);
+  auto *ExitBB = BasicBlock::Create(C, "while.end", RegGlobalsFn);
+
+  Builder.CreateBr(EntryBB);
+  Builder.SetInsertPoint(EntryBB);
+  auto *Entry = Builder.CreatePHI(getEntryPtrTy(M), 2, "entry");
+  auto *AddrPtr =
+      Builder.CreateInBoundsGEP(getEntryTy(M), Entry,
+                                {ConstantInt::get(getSizeTTy(M), 0),
+                                 ConstantInt::get(Type::getInt32Ty(C), 0)});
+  auto *Addr = Builder.CreateLoad(Type::getInt8PtrTy(C), AddrPtr, "addr");
+  auto *NamePtr =
+      Builder.CreateInBoundsGEP(getEntryTy(M), Entry,
+                                {ConstantInt::get(getSizeTTy(M), 0),
+                                 ConstantInt::get(Type::getInt32Ty(C), 1)});
+  auto *Name = Builder.CreateLoad(Type::getInt8PtrTy(C), NamePtr, "name");
+  auto *SizePtr =
+      Builder.CreateInBoundsGEP(getEntryTy(M), Entry,
+                                {ConstantInt::get(getSizeTTy(M), 0),
+                                 ConstantInt::get(Type::getInt32Ty(C), 2)});
+  auto *Size = Builder.CreateLoad(getSizeTTy(M), SizePtr, "size");
+  auto *FnCond =
+      Builder.CreateICmpEQ(Size, ConstantInt::getNullValue(getSizeTTy(M)));
+  Builder.CreateCondBr(FnCond, IfThenBB, IfElseBB);
+  Builder.SetInsertPoint(IfThenBB);
+  Builder.CreateCall(RegFunc,
+                     {RegGlobalsFn->arg_begin(), Addr, Name, Name,
+                      ConstantInt::get(Type::getInt32Ty(C), -1),
+                      ConstantPointerNull::get(Type::getInt8PtrTy(C)),
+                      ConstantPointerNull::get(Type::getInt8PtrTy(C)),
+                      ConstantPointerNull::get(Type::getInt8PtrTy(C)),
+                      ConstantPointerNull::get(Type::getInt8PtrTy(C)),
+                      ConstantPointerNull::get(Type::getInt32PtrTy(C))});
+  Builder.CreateBr(IfEndBB);
+  Builder.SetInsertPoint(IfElseBB);
+  Builder.CreateCall(RegVar, {RegGlobalsFn->arg_begin(), Addr, Name, Name,
+                              ConstantInt::get(Type::getInt32Ty(C), 0), Size,
+                              ConstantInt::get(Type::getInt32Ty(C), 0),
+                              ConstantInt::get(Type::getInt32Ty(C), 0)});
+  Builder.CreateBr(IfEndBB);
+  Builder.SetInsertPoint(IfEndBB);
+  auto *NewEntry = Builder.CreateInBoundsGEP(
+      getEntryTy(M), Entry, ConstantInt::get(getSizeTTy(M), 1));
+  auto *Cmp = Builder.CreateICmpEQ(NewEntry, EntriesE);
+  Entry->addIncoming(EntriesB, &RegGlobalsFn->getEntryBlock());
+  Entry->addIncoming(NewEntry, IfEndBB);
+  Builder.CreateCondBr(Cmp, ExitBB, EntryBB);
+  Builder.SetInsertPoint(ExitBB);
+  Builder.CreateRetVoid();
+
+  return RegGlobalsFn;
+}
+
+// Create the constructor and destructor to register the fatbinary with the CUDA
+// runtime.
+void createRegisterFatbinFunction(Module &M, GlobalVariable *FatbinDesc) {
+  LLVMContext &C = M.getContext();
+  auto *CtorFuncTy = FunctionType::get(Type::getVoidTy(C), /*isVarArg*/ false);
+  auto *CtorFunc = Function::Create(CtorFuncTy, GlobalValue::InternalLinkage,
+                                    ".cuda.fatbin_reg", &M);
+  CtorFunc->setSection(".text.startup");
+
+  auto *DtorFuncTy = FunctionType::get(Type::getVoidTy(C), /*isVarArg*/ false);
+  auto *DtorFunc = Function::Create(DtorFuncTy, GlobalValue::InternalLinkage,
+                                    ".cuda.fatbin_unreg", &M);
+  DtorFunc->setSection(".text.startup");
+
+  // Get the __cudaRegisterFatBinary function declaration.
+  auto *RegFatTy = FunctionType::get(Type::getInt8PtrTy(C)->getPointerTo(),
+                                     Type::getInt8PtrTy(C),
+                                     /*isVarArg*/ false);
+  FunctionCallee RegFatbin =
+      M.getOrInsertFunction("__cudaRegisterFatBinary", RegFatTy);
+  // Get the __cudaRegisterFatBinaryEnd function declaration.
+  auto *RegFatEndTy = FunctionType::get(Type::getVoidTy(C),
+                                        Type::getInt8PtrTy(C)->getPointerTo(),
+                                        /*isVarArg*/ false);
+  FunctionCallee RegFatbinEnd =
+      M.getOrInsertFunction("__cudaRegisterFatBinaryEnd", RegFatEndTy);
+  // Get the __cudaUnregisterFatBinary function declaration.
+  auto *UnregFatTy = FunctionType::get(Type::getVoidTy(C),
+                                       Type::getInt8PtrTy(C)->getPointerTo(),
+                                       /*isVarArg*/ false);
+  FunctionCallee UnregFatbin =
+      M.getOrInsertFunction("__cudaUnregisterFatBinary", UnregFatTy);
+
+  auto *AtExitTy =
+      FunctionType::get(Type::getInt32Ty(C), DtorFuncTy->getPointerTo(),
+                        /*isVarArg*/ false);
+  FunctionCallee AtExit = M.getOrInsertFunction("atexit", AtExitTy);
+
+  auto *BinaryHandleGlobal = new llvm::GlobalVariable(
+      M, Type::getInt8PtrTy(C)->getPointerTo(), false,
+      llvm::GlobalValue::InternalLinkage,
+      llvm::ConstantPointerNull::get(Type::getInt8PtrTy(C)->getPointerTo()),
+      ".cuda.binary_handle");
+
+  // Create the constructor to register this image with the runtime.
+  IRBuilder<> CtorBuilder(BasicBlock::Create(C, "entry", CtorFunc));
+  CallInst *Handle = CtorBuilder.CreateCall(
+      RegFatbin, ConstantExpr::getPointerBitCastOrAddrSpaceCast(
+                     FatbinDesc, Type::getInt8PtrTy(C)));
+  CtorBuilder.CreateAlignedStore(
+      Handle, BinaryHandleGlobal,
+      Align(M.getDataLayout().getPointerTypeSize(Type::getInt8PtrTy(C))));
+  CtorBuilder.CreateCall(createRegisterGlobalsFunction(M), Handle);
+  CtorBuilder.CreateCall(RegFatbinEnd, Handle);
+  CtorBuilder.CreateCall(AtExit, DtorFunc);
+  CtorBuilder.CreateRetVoid();
+
+  // Create the destructor to unregister the image with the runtime. We cannot
+  // use a standard global destructor after CUDA 9.2 so this must be called by
+  // `atexit()` instead.
+  IRBuilder<> DtorBuilder(BasicBlock::Create(C, "entry", DtorFunc));
+  LoadInst *BinaryHandle = DtorBuilder.CreateAlignedLoad(
+      Type::getInt8PtrTy(C)->getPointerTo(), BinaryHandleGlobal,
+      Align(M.getDataLayout().getPointerTypeSize(Type::getInt8PtrTy(C))));
+  DtorBuilder.CreateCall(UnregFatbin, BinaryHandle);
+  DtorBuilder.CreateRetVoid();
+
+  // Add this function to constructors.
+  appendToGlobalCtors(M, CtorFunc, /*Priority*/ 1);
+}
+
 } // namespace
 
 Error wrapOpenMPBinaries(Module &M, ArrayRef<ArrayRef<char>> Images) {
@@ -267,7 +528,12 @@
   return Error::success();
 }
 
-llvm::Error wrapCudaBinary(llvm::Module &M, llvm::ArrayRef<char> Images) {
-  return createStringError(inconvertibleErrorCode(),
-                           "Cuda wrapping is not yet supported.");
+Error wrapCudaBinary(Module &M, ArrayRef<char> Image) {
+  GlobalVariable *Desc = createFatbinDesc(M, Image);
+  if (!Desc)
+    return createStringError(inconvertibleErrorCode(),
+                             "No fatinbary section created.");
+
+  createRegisterFatbinFunction(M, Desc);
+  return Error::success();
 }
Index: clang/tools/clang-linker-wrapper/ClangLinkerWrapper.cpp
===================================================================
--- clang/tools/clang-linker-wrapper/ClangLinkerWrapper.cpp
+++ clang/tools/clang-linker-wrapper/ClangLinkerWrapper.cpp
@@ -1479,7 +1479,14 @@
   auto FileOrErr = wrapDeviceImages(LinkedImages);
   if (!FileOrErr)
     return reportError(FileOrErr.takeError());
-  LinkerArgs.append(*FileOrErr);
+
+  // We need to insert the new files next to the old ones to make sure they're
+  // linked with the same libraries / arguments.
+  auto FirstInput = std::next(llvm::find_if(LinkerArgs, [](StringRef Str) {
+    return sys::fs::exists(Str) && !sys::fs::is_directory(Str) &&
+           Str != ExecutableName;
+  }));
+  LinkerArgs.insert(FirstInput, FileOrErr->begin(), FileOrErr->end());
 
   // Run the host linking job.
   if (Error Err = runLinker(LinkerUserPath, LinkerArgs))
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to