================
@@ -2193,80 +2197,141 @@ llvm::Value *getSizeInBytes(DataLayout &dl, const mlir::Type &type,
   return builder.getInt64(dl.getTypeSizeInBits(type) / 8);
 }
 
-void collectMapDataFromMapVars(MapInfoData &mapData,
-                               llvm::SmallVectorImpl<Value> &mapVars,
-                               LLVM::ModuleTranslation &moduleTranslation,
-                               DataLayout &dl, llvm::IRBuilderBase &builder) {
-  for (mlir::Value mapValue : mapVars) {
-    if (auto mapOp = mlir::dyn_cast_if_present<mlir::omp::MapInfoOp>(
-            mapValue.getDefiningOp())) {
-      mlir::Value offloadPtr =
-          mapOp.getVarPtrPtr() ? mapOp.getVarPtrPtr() : mapOp.getVarPtr();
-      mapData.OriginalValue.push_back(
-          moduleTranslation.lookupValue(offloadPtr));
-      mapData.Pointers.push_back(mapData.OriginalValue.back());
-
-      if (llvm::Value *refPtr =
-              getRefPtrIfDeclareTarget(offloadPtr,
-                                       moduleTranslation)) { // declare target
-        mapData.IsDeclareTarget.push_back(true);
-        mapData.BasePointers.push_back(refPtr);
-      } else { // regular mapped variable
-        mapData.IsDeclareTarget.push_back(false);
-        mapData.BasePointers.push_back(mapData.OriginalValue.back());
-      }
+static void collectMapDataFromMapOperands(
+    MapInfoData &mapData, SmallVectorImpl<Value> &mapVars,
+    LLVM::ModuleTranslation &moduleTranslation, DataLayout &dl,
+    llvm::IRBuilderBase &builder, const ArrayRef<Value> &useDevPtrOperands = {},
+    const ArrayRef<Value> &useDevAddrOperands = {}) {
+  // Process MapOperands
+  for (Value mapValue : mapVars) {
+    auto mapOp = cast<omp::MapInfoOp>(mapValue.getDefiningOp());
+    Value offloadPtr =
+        mapOp.getVarPtrPtr() ? mapOp.getVarPtrPtr() : mapOp.getVarPtr();
+    mapData.OriginalValue.push_back(moduleTranslation.lookupValue(offloadPtr));
+    mapData.Pointers.push_back(mapData.OriginalValue.back());
+
+    if (llvm::Value *refPtr =
+            getRefPtrIfDeclareTarget(offloadPtr,
+                                     moduleTranslation)) { // declare target
+      mapData.IsDeclareTarget.push_back(true);
+      mapData.BasePointers.push_back(refPtr);
+    } else { // regular mapped variable
+      mapData.IsDeclareTarget.push_back(false);
+      mapData.BasePointers.push_back(mapData.OriginalValue.back());
+    }
 
-      mapData.BaseType.push_back(
-          moduleTranslation.convertType(mapOp.getVarType()));
-      mapData.Sizes.push_back(
-          getSizeInBytes(dl, mapOp.getVarType(), mapOp, mapData.Pointers.back(),
-                         mapData.BaseType.back(), builder, moduleTranslation));
-      mapData.MapClause.push_back(mapOp.getOperation());
-      mapData.Types.push_back(
-          llvm::omp::OpenMPOffloadMappingFlags(mapOp.getMapType().value()));
-      mapData.Names.push_back(LLVM::createMappingInformation(
-          mapOp.getLoc(), *moduleTranslation.getOpenMPBuilder()));
-      mapData.DevicePointers.push_back(
-          llvm::OpenMPIRBuilder::DeviceInfoTy::None);
-
-      // Check if this is a member mapping and correctly assign that it is, if
-      // it is a member of a larger object.
-      // TODO: Need better handling of members, and distinguishing of members
-      // that are implicitly allocated on device vs explicitly passed in as
-      // arguments.
-      // TODO: May require some further additions to support nested record
-      // types, i.e. member maps that can have member maps.
-      mapData.IsAMember.push_back(false);
-      for (mlir::Value mapValue : mapVars) {
-        if (auto map = mlir::dyn_cast_if_present<mlir::omp::MapInfoOp>(
-                mapValue.getDefiningOp())) {
-          for (auto member : map.getMembers()) {
-            if (member == mapOp) {
-              mapData.IsAMember.back() = true;
-            }
+    mapData.BaseType.push_back(
+        moduleTranslation.convertType(mapOp.getVarType()));
+    mapData.Sizes.push_back(
+        getSizeInBytes(dl, mapOp.getVarType(), mapOp, mapData.Pointers.back(),
+                       mapData.BaseType.back(), builder, moduleTranslation));
+    mapData.MapClause.push_back(mapOp.getOperation());
+    mapData.Types.push_back(
+        llvm::omp::OpenMPOffloadMappingFlags(mapOp.getMapType().value()));
+    mapData.Names.push_back(LLVM::createMappingInformation(
+        mapOp.getLoc(), *moduleTranslation.getOpenMPBuilder()));
+    mapData.DevicePointers.push_back(llvm::OpenMPIRBuilder::DeviceInfoTy::None);
+    mapData.IsAMapping.push_back(true);
+
+    // Check if this is a member mapping and correctly assign that it is, if
+    // it is a member of a larger object.
+    // TODO: Need better handling of members, and distinguishing of members
+    // that are implicitly allocated on device vs explicitly passed in as
+    // arguments.
+    // TODO: May require some further additions to support nested record
+    // types, i.e. member maps that can have member maps.
+    mapData.IsAMember.push_back(false);
+    for (Value mapValue : mapVars) {
+      if (auto map =
+              dyn_cast_if_present<omp::MapInfoOp>(mapValue.getDefiningOp())) {
+        for (auto member : map.getMembers()) {
+          if (member == mapOp) {
+            mapData.IsAMember.back() = true;
           }
         }
       }
     }
   }
+
+  auto findMapInfo = [&mapData](llvm::Value *val,
+                                llvm::OpenMPIRBuilder::DeviceInfoTy devInfoTy) {
+    unsigned index = 0;
+    bool found = false;
+    for (llvm::Value *basePtr : mapData.OriginalValue) {
+      if (basePtr == val && mapData.IsAMapping[index]) {
+        found = true;
+        mapData.Types[index] |=
+            llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_RETURN_PARAM;
+        mapData.DevicePointers[index] = devInfoTy;
+      }
+      index++;
+    }
+    return found;
+  };
+
+  // Process useDevPtr(Addr)Operands
+  auto addDevInfos = [&](const llvm::ArrayRef<Value> &useDevOperands,
----------------
skatrak wrote:

I guess we can leave it as is for now, since everything is contained in the 
same function. Hopefully when support for this matures a bit more we'll be able 
to see what we can share.

https://github.com/llvm/llvm-project/pull/101707
_______________________________________________
llvm-branch-commits mailing list
llvm-branch-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits

Reply via email to