================
@@ -391,29 +414,236 @@ class MapInfoFinalizationPass
   /// of the base address index.
   void adjustMemberIndices(
       llvm::SmallVectorImpl<llvm::SmallVector<int64_t>> &memberIndices,
-      size_t memberIndex) {
-    llvm::SmallVector<int64_t> baseAddrIndex = memberIndices[memberIndex];
+      ParentAndPlacement parentAndPlacement) {
+    llvm::SmallVector<int64_t> baseAddrIndex =
+        memberIndices[parentAndPlacement.index];
+    auto &expansionIndices = expandedBaseAddr[parentAndPlacement.parent];
 
     // If we find another member that is "derived/a member of" the descriptor
     // that is not the descriptor itself, we must insert a 0 for the new base
     // address we have just added for the descriptor into the list at the
     // appropriate position to maintain correctness of the positional/index data
     // for that member.
-    for (llvm::SmallVector<int64_t> &member : memberIndices)
+    for (auto [i, member] : llvm::enumerate(memberIndices)) {
+      if (std::find(expansionIndices.begin(), expansionIndices.end(), i) !=
+          expansionIndices.end())
+        if (member.size() == baseAddrIndex.size() + 1 &&
+            member[baseAddrIndex.size()] == 0)
+          continue;
+
       if (member.size() > baseAddrIndex.size() &&
           std::equal(baseAddrIndex.begin(), baseAddrIndex.end(),
                      member.begin()))
         member.insert(std::next(member.begin(), baseAddrIndex.size()), 0);
+    }
 
     // Add the base address index to the main base address member data
     baseAddrIndex.push_back(0);
 
-    // Insert our newly created baseAddrIndex into the larger list of indices at
-    // the correct location.
-    memberIndices.insert(std::next(memberIndices.begin(), memberIndex + 1),
+    uint64_t newIdxInsert = parentAndPlacement.index + 1;
+    expansionIndices.push_back(newIdxInsert);
+
+    // Insert our newly created baseAddrIndex into the larger list of
+    // indices at the correct location.
+    memberIndices.insert(std::next(memberIndices.begin(), newIdxInsert),
                          baseAddrIndex);
   }
 
+  // This function takes a Map clause owning target operation (e.g. TargetOp or
+  // TargetDataOp) and a lambda function, the lambda function is invoked on the
+  // various map clause ranges of the target operation that was passed in (e.g.
+  // the use_device_ptr/addr and regular maps count as map clause ranges for the
+  // purpose of this function) with the intent of inserting new maps into the
+  // range in a manner that is consistent with the target that was passed in.
+  //
+  // The lambda function should take 3 parameters a range that represents the
+  // map range, an operation representing the target and an unsigned integer
+  // representing the start index for the map range in terms of the targets
+  // block argument list. The insertion behaviour of the function is left to the
+  // lambda.
+  void
+  insertIntoMapClauseInterface(mlir::Operation *target,
+                               std::function<void(mlir::MutableOperandRange &,
+                                                  mlir::Operation *, unsigned)>
+                                   addOperands) {
+    auto argIface =
+        llvm::dyn_cast<mlir::omp::BlockArgOpenMPOpInterface>(target);
+
+    if (auto mapClauseOwner =
+            llvm::dyn_cast<mlir::omp::MapClauseOwningOpInterface>(target)) {
+      mlir::MutableOperandRange mapVarsArr = 
mapClauseOwner.getMapVarsMutable();
+      unsigned blockArgInsertIndex =
+          argIface
+              ? argIface.getMapBlockArgsStart() + argIface.numMapBlockArgs()
+              : 0;
+      addOperands(mapVarsArr,
+                  llvm::dyn_cast_if_present<mlir::omp::TargetOp>(target),
+                  blockArgInsertIndex);
+    }
+
+    if (auto targetDataOp = llvm::dyn_cast<mlir::omp::TargetDataOp>(target)) {
+      mlir::MutableOperandRange useDevAddrMutableOpRange =
+          targetDataOp.getUseDeviceAddrVarsMutable();
+      addOperands(useDevAddrMutableOpRange, target,
+                  argIface.getUseDeviceAddrBlockArgsStart() +
+                      argIface.numUseDeviceAddrBlockArgs());
+
+      mlir::MutableOperandRange useDevPtrMutableOpRange =
+          targetDataOp.getUseDevicePtrVarsMutable();
+      addOperands(useDevPtrMutableOpRange, target,
+                  argIface.getUseDevicePtrBlockArgsStart() +
+                      argIface.numUseDevicePtrBlockArgs());
+    } else if (auto targetOp = llvm::dyn_cast<mlir::omp::TargetOp>(target)) {
+      mlir::MutableOperandRange hasDevAddrMutableOpRange =
+          targetOp.getHasDeviceAddrVarsMutable();
+      addOperands(hasDevAddrMutableOpRange, target,
+                  argIface.getHasDeviceAddrBlockArgsStart() +
+                      argIface.numHasDeviceAddrBlockArgs());
+    }
+  }
+
+  // This function aims to insert new maps derived from existing maps into the
+  // corresponding clause list, interlinking it correctly with block arguments
+  // where required.
+  void addDerivedMemberToTarget(
+      mlir::omp::MapInfoOp owner, mlir::omp::MapInfoOp derived,
+      llvm::SmallVectorImpl<ParentAndPlacement> &mapMemberUsers,
+      fir::FirOpBuilder &builder, mlir::Operation *target) {
+    auto addOperands = [&](mlir::MutableOperandRange &mapVarsArr,
+                           mlir::Operation *directiveOp,
+                           unsigned blockArgInsertIndex = 0) {
+      // Check we're inserting into the correct MapInfoOp list.
+      if (!llvm::is_contained(mapVarsArr.getAsOperandRange(),
+                              mapMemberUsers.empty()
+                                  ? owner.getResult()
+                                  : mapMemberUsers[0].parent.getResult()))
+        return;
+
+      // Check we're not inserting a duplicate map.
+      if (llvm::is_contained(mapVarsArr.getAsOperandRange(),
+                             derived.getResult()))
+        return;
+
+      llvm::SmallVector<mlir::Value> newMapOps;
+      newMapOps.reserve(mapVarsArr.size());
+      llvm::copy(mapVarsArr.getAsOperandRange(), 
std::back_inserter(newMapOps));
+
+      newMapOps.push_back(derived);
+      if (directiveOp) {
+        directiveOp->getRegion(0).insertArgument(
+            blockArgInsertIndex, derived.getType(), derived.getLoc());
+        blockArgInsertIndex++;
+      }
+
+      mapVarsArr.assign(newMapOps);
+    };
+
+    insertIntoMapClauseInterface(target, addOperands);
+  }
+
+  // We add all mapped record members not directly used in the target region
+  // to the block arguments in front of their parent and we place them into
+  // the map operands list for consistency.
+  //
+  // These indirect uses (via accesses to their parent) will still be
+  // mapped individually in most cases, and a parent mapping doesn't
+  // guarantee the parent will be mapped in its totality, partial
+  // mapping is common.
+  //
+  // For example:
+  //    map(tofrom: x%y)
+  //
+  // Will generate a mapping for "x" (the parent) and "y" (the member).
+  // The parent "x" will not be mapped, but the member "y" will.
+  // However, we must have the parent as a BlockArg and MapOperand
+  // in these cases, to maintain the correct uses within the region and
+  // to help tracking that the member is part of a larger object.
+  //
+  // In the case of:
+  //    map(tofrom: x%y, x%z)
+  //
+  // The parent member becomes more critical, as we perform a partial
+  // structure mapping where we link the mapping of the members y
+  // and z together via the parent x. We do this at a kernel argument
+  // level in LLVM IR and not just MLIR, which is important to maintain
+  // similarity to Clang and for the runtime to do the correct thing.
+  // However, we still do not map the structure in its totality but
+  // rather we generate an un-sized "binding" map entry for it.
+  //
+  // In the case of:
+  //    map(tofrom: x, x%y, x%z)
+  //
+  // We do actually map the entirety of "x", so the explicit mapping of
+  // x%y, x%z becomes unnecessary, except in cases where y or z are
+  // pointers.
+  void addImplicitMembersToTarget(mlir::omp::MapInfoOp op,
+                                  fir::FirOpBuilder &builder,
+                                  mlir::Operation *target) {
+    // TargetDataOp is technically a MapClauseOwningOpInterface, so we
+    // do not need to explicitly check for the extra cases here for use_device
+    // addr/ptr.
+    if (!llvm::isa_and_present<mlir::omp::MapClauseOwningOpInterface>(target))
+      return;
+
+    auto addOperands = [&](mlir::MutableOperandRange &mapVarsArr,
+                           mlir::Operation *directiveOp,
+                           unsigned blockArgInsertIndex = 0) {
+      if (!llvm::is_contained(mapVarsArr.getAsOperandRange(), op.getResult()))
+        return;
+
+      // There doesn't appear to be a simple way to convert MutableOperandRange
+      // to a vector currently, so we instead use a for_each to populate our
+      // vector.
+      llvm::SmallVector<mlir::Value> newMapOps;
+      newMapOps.reserve(mapVarsArr.size());
+      llvm::for_each(mapVarsArr.getAsOperandRange(),
+                     [&newMapOps](mlir::Value oper) {
+                       if (oper)
----------------
bhandarkar-pranav wrote:

If we can get rid of the null check, we could do what
`addDerivedMemberToTarget` (the sibling function of
`addImplicitMembersToTarget`) does - a simple `llvm::copy`:
`llvm::copy(mapVarsArr.getAsOperandRange(), std::back_inserter(newMapOps));`


https://github.com/llvm/llvm-project/pull/177715
_______________________________________________
llvm-branch-commits mailing list
[email protected]
https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits

Reply via email to