TIFitis created this revision.
TIFitis added reviewers: jsjodin, jdoerfert, kiranktp.
Herald added subscribers: sunshaoce, guansong, hiraditya, yaxunl.
Herald added a project: All.
TIFitis requested review of this revision.
Herald added subscribers: llvm-commits, cfe-commits, jplehr, sstefan1.
Herald added projects: clang, LLVM.

This patch migrates the UseDevicePtr and UseDeviceAddr clause-related code for
handling privatization from Clang codegen to the OMPIRBuilder.

Depends on D150860 <https://reviews.llvm.org/D150860>


Repository:
  rG LLVM Github Monorepo

https://reviews.llvm.org/D152554

Files:
  clang/lib/CodeGen/CGOpenMPRuntime.cpp
  clang/lib/CodeGen/CGOpenMPRuntime.h
  clang/lib/CodeGen/CGStmtOpenMP.cpp
  clang/lib/CodeGen/CodeGenFunction.h
  clang/test/OpenMP/target_data_use_device_ptr_codegen.cpp
  llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
  llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp

Index: llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
===================================================================
--- llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
+++ llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
@@ -4086,7 +4086,7 @@
     omp::RuntimeFunction *MapperFunc,
     function_ref<InsertPointTy(InsertPointTy CodeGenIP, int BodyGenType)>
         BodyGenCB,
-    function_ref<void(unsigned int, Value *, Value *)> DeviceAddrCB,
+    function_ref<void(unsigned int, Value *)> DeviceAddrCB,
     function_ref<Value *(unsigned int)> CustomMapperCB) {
   if (!updateToLocation(Loc))
     return InsertPointTy();
@@ -4129,6 +4129,14 @@
 
       Builder.CreateCall(beginMapperFunc, OffloadingArgs);
 
+      for (auto DeviceMap : Info.DevicePtrInfoMap) {
+        if (isa<AllocaInst>(DeviceMap.second.second)) {
+          auto *LI =
+              Builder.CreateLoad(Builder.getPtrTy(), DeviceMap.second.first);
+          Builder.CreateStore(LI, DeviceMap.second.second);
+        }
+      }
+
       // If device pointer privatization is required, emit the body of the
       // region here. It will have to be duplicated: with and without
       // privatization.
@@ -4541,7 +4549,7 @@
 void OpenMPIRBuilder::emitOffloadingArrays(
     InsertPointTy AllocaIP, InsertPointTy CodeGenIP, MapInfosTy &CombinedInfo,
     TargetDataInfo &Info, bool IsNonContiguous,
-    function_ref<void(unsigned int, Value *, Value *)> DeviceAddrCB,
+    function_ref<void(unsigned int, Value *)> DeviceAddrCB,
     function_ref<Value *(unsigned int)> CustomMapperCB) {
 
   // Reset the array information.
@@ -4679,8 +4687,19 @@
           M.getDataLayout().getPrefTypeAlign(Builder.getInt8PtrTy()));
 
       if (Info.requiresDevicePointerInfo()) {
-        assert(DeviceAddrCB);
-        DeviceAddrCB(I, BP, BPVal);
+        if (CombinedInfo.DevicePointers[I] == DeviceInfoTy::Pointer) {
+          CodeGenIP = Builder.saveIP();
+          Builder.restoreIP(AllocaIP);
+          Info.DevicePtrInfoMap[BPVal] = {
+              BP, Builder.CreateAlloca(Builder.getPtrTy())};
+          Builder.restoreIP(CodeGenIP);
+          assert(DeviceAddrCB);
+          DeviceAddrCB(I, Info.DevicePtrInfoMap[BPVal].second);
+        } else if (CombinedInfo.DevicePointers[I] == DeviceInfoTy::Address) {
+          Info.DevicePtrInfoMap[BPVal] = {BP, BP};
+          assert(DeviceAddrCB);
+          DeviceAddrCB(I, BP);
+        }
       }
 
       Value *PVal = CombinedInfo.Pointers[I];
Index: llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
===================================================================
--- llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
+++ llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
@@ -1564,6 +1564,9 @@
   public:
     TargetDataRTArgs RTArgs;
 
+    SmallMapVector<const Value *, std::pair<Value *, Value *>, 4>
+        DevicePtrInfoMap;
+
     /// Indicate whether any user-defined mapper exists.
     bool HasMapper = false;
     /// The total number of pointers passed to the runtime library.
@@ -1590,7 +1593,9 @@
     bool separateBeginEndCalls() { return SeparateBeginEndCalls; }
   };
 
+  enum class DeviceInfoTy { None, Pointer, Address };
   using MapValuesArrayTy = SmallVector<Value *, 4>;
+  using MapDeviceInfoArrayTy = SmallVector<DeviceInfoTy, 4>;
   using MapFlagsArrayTy = SmallVector<omp::OpenMPOffloadMappingFlags, 4>;
   using MapNamesArrayTy = SmallVector<Constant *, 4>;
   using MapDimArrayTy = SmallVector<uint64_t, 4>;
@@ -1609,6 +1614,7 @@
     };
     MapValuesArrayTy BasePointers;
     MapValuesArrayTy Pointers;
+    MapDeviceInfoArrayTy DevicePointers;
     MapValuesArrayTy Sizes;
     MapFlagsArrayTy Types;
     MapNamesArrayTy Names;
@@ -1619,6 +1625,8 @@
       BasePointers.append(CurInfo.BasePointers.begin(),
                           CurInfo.BasePointers.end());
       Pointers.append(CurInfo.Pointers.begin(), CurInfo.Pointers.end());
+      DevicePointers.append(CurInfo.DevicePointers.begin(),
+                            CurInfo.DevicePointers.end());
       Sizes.append(CurInfo.Sizes.begin(), CurInfo.Sizes.end());
       Types.append(CurInfo.Types.begin(), CurInfo.Types.end());
       Names.append(CurInfo.Names.begin(), CurInfo.Names.end());
@@ -1655,7 +1663,7 @@
   void emitOffloadingArrays(
       InsertPointTy AllocaIP, InsertPointTy CodeGenIP, MapInfosTy &CombinedInfo,
       TargetDataInfo &Info, bool IsNonContiguous = false,
-      function_ref<void(unsigned int, Value *, Value *)> DeviceAddrCB = nullptr,
+      function_ref<void(unsigned int, Value *)> DeviceAddrCB = nullptr,
       function_ref<Value *(unsigned int)> CustomMapperCB = nullptr);
 
   /// Creates offloading entry for the provided entry ID \a ID, address \a
@@ -2029,7 +2037,7 @@
       omp::RuntimeFunction *MapperFunc = nullptr,
       function_ref<InsertPointTy(InsertPointTy CodeGenIP, int BodyGenType)>
           BodyGenCB = nullptr,
-      function_ref<void(unsigned int, Value *, Value *)> DeviceAddrCB = nullptr,
+      function_ref<void(unsigned int, Value *)> DeviceAddrCB = nullptr,
       function_ref<Value *(unsigned int)> CustomMapperCB = nullptr);
 
   using TargetBodyGenCallbackTy = function_ref<InsertPointTy(
Index: clang/test/OpenMP/target_data_use_device_ptr_codegen.cpp
===================================================================
--- clang/test/OpenMP/target_data_use_device_ptr_codegen.cpp
+++ clang/test/OpenMP/target_data_use_device_ptr_codegen.cpp
@@ -411,11 +411,11 @@
     // CK2:     [[BP2:%.+]] = getelementptr inbounds [3 x ptr], ptr %{{.+}}, i32 0, i32 2
     // CK2:     store ptr [[RVAL2:%.+]], ptr [[BP2]],
     // CK2:     call void @__tgt_target_data_begin{{.+}}[[MTYPE03]]
+    // CK2:     [[VAL1:%.+]] = load ptr, ptr [[BP1]],
+    // CK2:     store ptr [[VAL1]], ptr [[PVT1:%.+]],
     // CK2:     [[VAL2:%.+]] = load ptr, ptr [[BP2]],
     // CK2:     store ptr [[VAL2]], ptr [[PVT2:%.+]],
     // CK2:     store ptr [[PVT2]], ptr [[_PVT2:%.+]],
-    // CK2:     [[VAL1:%.+]] = load ptr, ptr [[BP1]],
-    // CK2:     store ptr [[VAL1]], ptr [[PVT1:%.+]],
     // CK2:     store ptr [[PVT1]], ptr [[_PVT1:%.+]],
     // CK2:     [[TT2:%.+]] = load ptr, ptr [[_PVT2]],
     // CK2:     [[_TT2:%.+]] = load ptr, ptr [[TT2]],
Index: clang/lib/CodeGen/CodeGenFunction.h
===================================================================
--- clang/lib/CodeGen/CodeGenFunction.h
+++ clang/lib/CodeGen/CodeGenFunction.h
@@ -3404,10 +3404,13 @@
                             OMPPrivateScope &PrivateScope);
   void EmitOMPUseDevicePtrClause(
       const OMPUseDevicePtrClause &C, OMPPrivateScope &PrivateScope,
-      const llvm::DenseMap<const ValueDecl *, Address> &CaptureDeviceAddrMap);
+      const llvm::SmallMapVector<const llvm::Value *,
+                                 std::pair<llvm::Value *, llvm::Value *>, 4>
+          &DevicePtrInfoMap);
   void EmitOMPUseDeviceAddrClause(
       const OMPUseDeviceAddrClause &C, OMPPrivateScope &PrivateScope,
-      const llvm::DenseMap<const ValueDecl *, Address> &CaptureDeviceAddrMap);
+      const llvm::DenseMap<const ValueDecl *, llvm::Value *>
+          CaptureDeviceAddrMap);
   /// Emit code for copyin clause in \a D directive. The next code is
   /// generated at the start of outlined functions for directives:
   /// \code
Index: clang/lib/CodeGen/CGStmtOpenMP.cpp
===================================================================
--- clang/lib/CodeGen/CGStmtOpenMP.cpp
+++ clang/lib/CodeGen/CGStmtOpenMP.cpp
@@ -7158,60 +7158,99 @@
 
 void CodeGenFunction::EmitOMPUseDevicePtrClause(
     const OMPUseDevicePtrClause &C, OMPPrivateScope &PrivateScope,
-    const llvm::DenseMap<const ValueDecl *, Address> &CaptureDeviceAddrMap) {
-  auto OrigVarIt = C.varlist_begin();
-  auto InitIt = C.inits().begin();
-  for (const Expr *PvtVarIt : C.private_copies()) {
-    const auto *OrigVD =
-        cast<VarDecl>(cast<DeclRefExpr>(*OrigVarIt)->getDecl());
-    const auto *InitVD = cast<VarDecl>(cast<DeclRefExpr>(*InitIt)->getDecl());
-    const auto *PvtVD = cast<VarDecl>(cast<DeclRefExpr>(PvtVarIt)->getDecl());
+    const llvm::SmallMapVector<const llvm::Value *,
+                               std::pair<llvm::Value *, llvm::Value *>, 4>
+        &DevicePtrInfoMap) {
+  llvm::SmallDenseSet<CanonicalDeclPtr<const Decl>, 4> Processed;
+  for (const Expr *Ref : C.varlists()) {
+    const auto *OrigVD = cast<VarDecl>(cast<DeclRefExpr>(Ref)->getDecl());
+    if (!Processed.insert(OrigVD).second)
+      continue;
+
+    const llvm::Value *DevPtr = nullptr;
 
     // In order to identify the right initializer we need to match the
     // declaration used by the mapping logic. In some cases we may get
     // OMPCapturedExprDecl that refers to the original declaration.
-    const ValueDecl *MatchingVD = OrigVD;
-    if (const auto *OED = dyn_cast<OMPCapturedExprDecl>(MatchingVD)) {
+    if (const auto *OED = dyn_cast<OMPCapturedExprDecl>(OrigVD)) {
       // OMPCapturedExprDecl are used to privative fields of the current
       // structure.
       const auto *ME = cast<MemberExpr>(OED->getInit());
       assert(isa<CXXThisExpr>(ME->getBase()->IgnoreImpCasts()) &&
              "Base should be the current struct!");
-      MatchingVD = ME->getMemberDecl();
+      auto *MI = cast<llvm::Instruction>(EmitLValue(ME).getPointer(*this));
+      while (auto *LI = dyn_cast<llvm::LoadInst>(MI)) {
+        auto *PO = LI->getPointerOperand();
+        if (!isa<llvm::AllocaInst>(PO) && isa<llvm::Instruction>(PO))
+          MI = cast<llvm::Instruction>(PO);
+        else
+          break;
+      }
+      for (const auto &DevMap : DevicePtrInfoMap) {
+        auto *DI = cast<llvm::Instruction>(DevMap.first);
+        while (auto *LI = dyn_cast<llvm::LoadInst>(DI)) {
+          auto *PO = LI->getPointerOperand();
+          if (!isa<llvm::AllocaInst>(PO) && isa<llvm::Instruction>(PO))
+            DI = cast<llvm::Instruction>(PO);
+          else
+            break;
+        }
+        if (DI->isIdenticalTo(MI)) {
+          DevPtr = DevMap.first;
+          break;
+        }
+      }
+      if (MI->users().empty())
+        MI->eraseFromParent();
+    } else {
+      if (Ref->isGLValue()) {
+        auto *MI = cast<llvm::Instruction>(EmitScalarExpr(Ref));
+        while (auto *LI = dyn_cast<llvm::LoadInst>(MI)) {
+          auto *PO = LI->getPointerOperand();
+          if (!isa<llvm::AllocaInst>(PO) && isa<llvm::Instruction>(PO))
+            MI = cast<llvm::Instruction>(PO);
+          else
+            break;
+        }
+        for (const auto &DevMap : DevicePtrInfoMap) {
+          auto *DI = cast<llvm::Instruction>(DevMap.first);
+          while (auto *LI = dyn_cast<llvm::LoadInst>(DI)) {
+            auto *PO = LI->getPointerOperand();
+            if (!isa<llvm::AllocaInst>(PO) && isa<llvm::Instruction>(PO))
+              DI = cast<llvm::Instruction>(PO);
+            else
+              break;
+          }
+          if (DI->isIdenticalTo(MI)) {
+            DevPtr = DevMap.first;
+            break;
+          }
+        }
+        if (MI->users().empty())
+          MI->eraseFromParent();
+      } else {
+        DevPtr = EmitLValue(Ref).getPointer(*this);
+      }
     }
 
     // If we don't have information about the current list item, move on to
     // the next one.
-    auto InitAddrIt = CaptureDeviceAddrMap.find(MatchingVD);
-    if (InitAddrIt == CaptureDeviceAddrMap.end())
+    if (!DevPtr)
+      continue;
+    auto InitAddrIt = DevicePtrInfoMap.find(DevPtr);
+    if (InitAddrIt == DevicePtrInfoMap.end())
       continue;
 
-    // Initialize the temporary initialization variable with the address
-    // we get from the runtime library. We have to cast the source address
-    // because it is always a void *. References are materialized in the
-    // privatization scope, so the initialization here disregards the fact
-    // the original variable is a reference.
     llvm::Type *Ty = ConvertTypeForMem(OrigVD->getType().getNonReferenceType());
-    Address InitAddr = Builder.CreateElementBitCast(InitAddrIt->second, Ty);
-    setAddrOfLocalVar(InitVD, InitAddr);
-
-    // Emit private declaration, it will be initialized by the value we
-    // declaration we just added to the local declarations map.
-    EmitDecl(*PvtVD);
-
-    // The initialization variables reached its purpose in the emission
-    // of the previous declaration, so we don't need it anymore.
-    LocalDeclMap.erase(InitVD);
 
     // Return the address of the private variable.
-    bool IsRegistered =
-        PrivateScope.addPrivate(OrigVD, GetAddrOfLocalVar(PvtVD));
+    bool IsRegistered = PrivateScope.addPrivate(
+        OrigVD,
+        Address(InitAddrIt->second.second, Ty,
+                getContext().getTypeAlignInChars(getContext().VoidPtrTy)));
     assert(IsRegistered && "firstprivate var already registered as private");
     // Silence the warning about unused variable.
     (void)IsRegistered;
-
-    ++OrigVarIt;
-    ++InitIt;
   }
 }
 
@@ -7226,7 +7265,8 @@
 
 void CodeGenFunction::EmitOMPUseDeviceAddrClause(
     const OMPUseDeviceAddrClause &C, OMPPrivateScope &PrivateScope,
-    const llvm::DenseMap<const ValueDecl *, Address> &CaptureDeviceAddrMap) {
+    const llvm::DenseMap<const ValueDecl *, llvm::Value *>
+        CaptureDeviceAddrMap) {
   llvm::SmallDenseSet<CanonicalDeclPtr<const Decl>, 4> Processed;
   for (const Expr *Ref : C.varlists()) {
     const VarDecl *OrigVD = getBaseDecl(Ref);
@@ -7251,7 +7291,11 @@
     if (InitAddrIt == CaptureDeviceAddrMap.end())
       continue;
 
-    Address PrivAddr = InitAddrIt->getSecond();
+    llvm::Type *Ty = ConvertTypeForMem(OrigVD->getType().getNonReferenceType());
+
+    Address PrivAddr =
+        Address(InitAddrIt->second, Ty,
+                getContext().getTypeAlignInChars(getContext().VoidPtrTy));
     // For declrefs and variable length array need to load the pointer for
     // correct mapping, since the pointer to the data was passed to the runtime.
     if (isa<DeclRefExpr>(Ref->IgnoreParenImpCasts()) ||
@@ -7308,7 +7352,7 @@
         // Emit all instances of the use_device_ptr clause.
         for (const auto *C : S.getClausesOfKind<OMPUseDevicePtrClause>())
           CGF.EmitOMPUseDevicePtrClause(*C, PrivateScope,
-                                        Info.CaptureDeviceAddrMap);
+                                        Info.DevicePtrInfoMap);
         for (const auto *C : S.getClausesOfKind<OMPUseDeviceAddrClause>())
           CGF.EmitOMPUseDeviceAddrClause(*C, PrivateScope,
                                          Info.CaptureDeviceAddrMap);
Index: clang/lib/CodeGen/CGOpenMPRuntime.h
===================================================================
--- clang/lib/CodeGen/CGOpenMPRuntime.h
+++ clang/lib/CodeGen/CGOpenMPRuntime.h
@@ -1439,9 +1439,9 @@
                             bool SeparateBeginEndCalls)
         : llvm::OpenMPIRBuilder::TargetDataInfo(RequiresDevicePointerInfo,
                                                 SeparateBeginEndCalls) {}
-    /// Map between the a declaration of a capture and the corresponding base
-    /// pointer address where the runtime returns the device pointers.
-    llvm::DenseMap<const ValueDecl *, Address> CaptureDeviceAddrMap;
+    /// Map between the a declaration of a capture and the corresponding new
+    /// alloca address where the runtime returns the device pointers.
+    llvm::DenseMap<const ValueDecl *, llvm::Value *> CaptureDeviceAddrMap;
   };
 
   /// Emit the target data mapping code associated with \a D.
Index: clang/lib/CodeGen/CGOpenMPRuntime.cpp
===================================================================
--- clang/lib/CodeGen/CGOpenMPRuntime.cpp
+++ clang/lib/CodeGen/CGOpenMPRuntime.cpp
@@ -6821,6 +6821,7 @@
     const Expr *getMapExpr() const { return MapExpr; }
   };
 
+  using DeviceInfoTy = llvm::OpenMPIRBuilder::DeviceInfoTy;
   using MapBaseValuesArrayTy = llvm::OpenMPIRBuilder::MapValuesArrayTy;
   using MapValuesArrayTy = llvm::OpenMPIRBuilder::MapValuesArrayTy;
   using MapFlagsArrayTy = llvm::OpenMPIRBuilder::MapFlagsArrayTy;
@@ -7592,6 +7593,7 @@
             CombinedInfo.Exprs.emplace_back(MapDecl, MapExpr);
             CombinedInfo.BasePointers.push_back(BP.getPointer());
             CombinedInfo.DevicePtrDecls.push_back(nullptr);
+            CombinedInfo.DevicePointers.push_back(DeviceInfoTy::None);
             CombinedInfo.Pointers.push_back(LB.getPointer());
             CombinedInfo.Sizes.push_back(CGF.Builder.CreateIntCast(
                 Size, CGF.Int64Ty, /*isSigned=*/true));
@@ -7604,6 +7606,7 @@
           CombinedInfo.Exprs.emplace_back(MapDecl, MapExpr);
           CombinedInfo.BasePointers.push_back(BP.getPointer());
           CombinedInfo.DevicePtrDecls.push_back(nullptr);
+          CombinedInfo.DevicePointers.push_back(DeviceInfoTy::None);
           CombinedInfo.Pointers.push_back(LB.getPointer());
           Size = CGF.Builder.CreatePtrDiff(
               CGF.Int8Ty, CGF.Builder.CreateConstGEP(HB, 1).getPointer(),
@@ -7622,6 +7625,7 @@
           CombinedInfo.Exprs.emplace_back(MapDecl, MapExpr);
           CombinedInfo.BasePointers.push_back(BP.getPointer());
           CombinedInfo.DevicePtrDecls.push_back(nullptr);
+          CombinedInfo.DevicePointers.push_back(DeviceInfoTy::None);
           CombinedInfo.Pointers.push_back(LB.getPointer());
           CombinedInfo.Sizes.push_back(
               CGF.Builder.CreateIntCast(Size, CGF.Int64Ty, /*isSigned=*/true));
@@ -8122,10 +8126,12 @@
 
     auto &&UseDeviceDataCombinedInfoGen =
         [&UseDeviceDataCombinedInfo](const ValueDecl *VD, llvm::Value *Ptr,
-                                     CodeGenFunction &CGF) {
+                                     CodeGenFunction &CGF, bool IsDevAddr) {
           UseDeviceDataCombinedInfo.Exprs.push_back(VD);
           UseDeviceDataCombinedInfo.BasePointers.emplace_back(Ptr);
           UseDeviceDataCombinedInfo.DevicePtrDecls.emplace_back(VD);
+          UseDeviceDataCombinedInfo.DevicePointers.emplace_back(
+              IsDevAddr ? DeviceInfoTy::Address : DeviceInfoTy::Pointer);
           UseDeviceDataCombinedInfo.Pointers.push_back(Ptr);
           UseDeviceDataCombinedInfo.Sizes.push_back(
               llvm::Constant::getNullValue(CGF.Int64Ty));
@@ -8165,7 +8171,7 @@
             } else {
               Ptr = CGF.EmitLoadOfScalar(CGF.EmitLValue(IE), IE->getExprLoc());
             }
-            UseDeviceDataCombinedInfoGen(VD, Ptr, CGF);
+            UseDeviceDataCombinedInfoGen(VD, Ptr, CGF, IsDevAddr);
           }
         };
 
@@ -8192,6 +8198,7 @@
           // item.
           if (CI != Data.end()) {
             if (IsDevAddr) {
+              CI->ForDeviceAddr = IsDevAddr;
               CI->ReturnDevicePointer = true;
               Found = true;
               break;
@@ -8204,6 +8211,7 @@
                   PrevCI == CI->Components.rend() ||
                   isa<MemberExpr>(PrevCI->getAssociatedExpression()) || !VarD ||
                   VarD->hasLocalStorage()) {
+                CI->ForDeviceAddr = IsDevAddr;
                 CI->ReturnDevicePointer = true;
                 Found = true;
                 break;
@@ -8295,6 +8303,8 @@
                    "No relevant declaration related with device pointer??");
 
             CurInfo.DevicePtrDecls[CurrentBasePointersIdx] = RelevantVD;
+            CurInfo.DevicePointers[CurrentBasePointersIdx] =
+                L.ForDeviceAddr ? DeviceInfoTy::Address : DeviceInfoTy::Pointer;
             CurInfo.Types[CurrentBasePointersIdx] |=
                 OpenMPOffloadMappingFlags::OMP_MAP_RETURN_PARAM;
           }
@@ -8335,6 +8345,8 @@
           CurInfo.Exprs.push_back(L.VD);
           CurInfo.BasePointers.emplace_back(BasePtr);
           CurInfo.DevicePtrDecls.emplace_back(L.VD);
+          CurInfo.DevicePointers.emplace_back(
+              L.ForDeviceAddr ? DeviceInfoTy::Address : DeviceInfoTy::Pointer);
           CurInfo.Pointers.push_back(Ptr);
           CurInfo.Sizes.push_back(
               llvm::Constant::getNullValue(this->CGF.Int64Ty));
@@ -8430,6 +8442,7 @@
     // Base is the base of the struct
     CombinedInfo.BasePointers.push_back(PartialStruct.Base.getPointer());
     CombinedInfo.DevicePtrDecls.push_back(nullptr);
+    CombinedInfo.DevicePointers.push_back(DeviceInfoTy::None);
     // Pointer is the address of the lowest element
     llvm::Value *LB = LBAddr.getPointer();
     const CXXMethodDecl *MD =
@@ -8552,6 +8565,7 @@
       CombinedInfo.Exprs.push_back(VD);
       CombinedInfo.BasePointers.push_back(ThisLVal.getPointer(CGF));
       CombinedInfo.DevicePtrDecls.push_back(nullptr);
+      CombinedInfo.DevicePointers.push_back(DeviceInfoTy::None);
       CombinedInfo.Pointers.push_back(ThisLValVal.getPointer(CGF));
       CombinedInfo.Sizes.push_back(
           CGF.Builder.CreateIntCast(CGF.getTypeSize(CGF.getContext().VoidPtrTy),
@@ -8579,6 +8593,7 @@
         CombinedInfo.Exprs.push_back(VD);
         CombinedInfo.BasePointers.push_back(VarLVal.getPointer(CGF));
         CombinedInfo.DevicePtrDecls.push_back(nullptr);
+        CombinedInfo.DevicePointers.push_back(DeviceInfoTy::None);
         CombinedInfo.Pointers.push_back(VarLValVal.getPointer(CGF));
         CombinedInfo.Sizes.push_back(CGF.Builder.CreateIntCast(
             CGF.getTypeSize(
@@ -8591,6 +8606,7 @@
         CombinedInfo.Exprs.push_back(VD);
         CombinedInfo.BasePointers.push_back(VarLVal.getPointer(CGF));
         CombinedInfo.DevicePtrDecls.push_back(nullptr);
+        CombinedInfo.DevicePointers.push_back(DeviceInfoTy::None);
         CombinedInfo.Pointers.push_back(VarRVal.getScalarVal());
         CombinedInfo.Sizes.push_back(llvm::ConstantInt::get(CGF.Int64Ty, 0));
       }
@@ -8659,6 +8675,7 @@
       CombinedInfo.Exprs.push_back(VD);
       CombinedInfo.BasePointers.emplace_back(Arg);
       CombinedInfo.DevicePtrDecls.emplace_back(VD);
+      CombinedInfo.DevicePointers.emplace_back(DeviceInfoTy::Pointer);
       CombinedInfo.Pointers.push_back(Arg);
       CombinedInfo.Sizes.push_back(CGF.Builder.CreateIntCast(
           CGF.getTypeSize(CGF.getContext().VoidPtrTy), CGF.Int64Ty,
@@ -8899,6 +8916,7 @@
       CombinedInfo.Exprs.push_back(nullptr);
       CombinedInfo.BasePointers.push_back(CV);
       CombinedInfo.DevicePtrDecls.push_back(nullptr);
+      CombinedInfo.DevicePointers.push_back(DeviceInfoTy::None);
       CombinedInfo.Pointers.push_back(CV);
       const auto *PtrTy = cast<PointerType>(RI.getType().getTypePtr());
       CombinedInfo.Sizes.push_back(
@@ -8912,6 +8930,7 @@
       CombinedInfo.Exprs.push_back(VD->getCanonicalDecl());
       CombinedInfo.BasePointers.push_back(CV);
       CombinedInfo.DevicePtrDecls.push_back(nullptr);
+      CombinedInfo.DevicePointers.push_back(DeviceInfoTy::None);
       CombinedInfo.Pointers.push_back(CV);
       if (!RI.getType()->isAnyPointerType()) {
         // We have to signal to the runtime captures passed by value that are
@@ -8944,6 +8963,7 @@
       CombinedInfo.Exprs.push_back(VD->getCanonicalDecl());
       CombinedInfo.BasePointers.push_back(CV);
       CombinedInfo.DevicePtrDecls.push_back(nullptr);
+      CombinedInfo.DevicePointers.push_back(DeviceInfoTy::None);
       if (I != FirstPrivateDecls.end() && ElementType->isAnyPointerType()) {
         Address PtrAddr = CGF.EmitLoadOfReference(CGF.MakeAddrLValue(
             CV, ElementType, CGF.getContext().getDeclAlign(VD),
@@ -9025,7 +9045,6 @@
     CGOpenMPRuntime::TargetDataInfo &Info, llvm::OpenMPIRBuilder &OMPBuilder,
     bool IsNonContiguous = false) {
   CodeGenModule &CGM = CGF.CGM;
-  ASTContext &Ctx = CGF.getContext();
 
   // Reset the array information.
   Info.clearArrayInfo();
@@ -9047,11 +9066,9 @@
                     fillInfoMap);
   }
 
-  auto DeviceAddrCB = [&](unsigned int I, llvm::Value *BP, llvm::Value *BPVal) {
+  auto DeviceAddrCB = [&](unsigned int I, llvm::Value *NewDecl) {
     if (const ValueDecl *DevVD = CombinedInfo.DevicePtrDecls[I]) {
-      Address BPAddr(BP, BPVal->getType(),
-                     Ctx.getTypeAlignInChars(Ctx.VoidPtrTy));
-      Info.CaptureDeviceAddrMap.try_emplace(DevVD, BPAddr);
+      Info.CaptureDeviceAddrMap.try_emplace(DevVD, NewDecl);
     }
   };
 
@@ -9775,6 +9792,8 @@
         CurInfo.Exprs.push_back(nullptr);
         CurInfo.BasePointers.push_back(*CV);
         CurInfo.DevicePtrDecls.push_back(nullptr);
+        CurInfo.DevicePointers.push_back(
+            MappableExprsHandler::DeviceInfoTy::None);
         CurInfo.Pointers.push_back(*CV);
         CurInfo.Sizes.push_back(CGF.Builder.CreateIntCast(
             CGF.getTypeSize(RI->getType()), CGF.Int64Ty, /*isSigned=*/true));
@@ -10458,12 +10477,9 @@
                          CGF.Builder.GetInsertPoint());
   };
 
-  auto DeviceAddrCB = [&](unsigned int I, llvm::Value *BP, llvm::Value *BPVal) {
+  auto DeviceAddrCB = [&](unsigned int I, llvm::Value *NewDecl) {
     if (const ValueDecl *DevVD = CombinedInfo.DevicePtrDecls[I]) {
-      ASTContext &Ctx = CGF.getContext();
-      Address BPAddr(BP, BPVal->getType(),
-                     Ctx.getTypeAlignInChars(Ctx.VoidPtrTy));
-      Info.CaptureDeviceAddrMap.try_emplace(DevVD, BPAddr);
+      Info.CaptureDeviceAddrMap.try_emplace(DevVD, NewDecl);
     }
   };
 
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to