llvmbot wrote:

<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-clang

Author: Amit Tiwari (amitamd7)

<details>
<summary>Changes</summary>

This patch handles strided updates in the
`#pragma omp target update from(data[a:b:c])` directive, where `c` is the
stride. Such a section makes the update of `data` non-contiguous when control
returns from the device to the host through the `from` clause.

Issue: when Clang CodeGen generated the map information for a given `MapType`
(to, from, etc.), it failed to detect the strided access. As a result, the
`MapType` bits passed to the runtime were incorrect, which led to incorrect
(contiguous) execution in the libomptarget runtime.

Added a minimal test case that verifies the fix.

---
Full diff: https://github.com/llvm/llvm-project/pull/144635.diff


2 Files Affected:

- (modified) clang/lib/CodeGen/CGOpenMPRuntime.cpp (+34-1) 
- (added) offload/test/offloading/strided_update.c (+51) 


``````````diff
diff --git a/clang/lib/CodeGen/CGOpenMPRuntime.cpp b/clang/lib/CodeGen/CGOpenMPRuntime.cpp
index 4173355491fd4..81a2dd0fae5c9 100644
--- a/clang/lib/CodeGen/CGOpenMPRuntime.cpp
+++ b/clang/lib/CodeGen/CGOpenMPRuntime.cpp
@@ -7384,7 +7384,40 @@ class MappableExprsHandler {
     // dimension.
     uint64_t DimSize = 1;
 
-    bool IsNonContiguous = CombinedInfo.NonContigInfo.IsNonContiguous;
+    // Detects non-contiguous updates due to strided accesses.
+    // Sets the 'IsNonContiguous' flag so that the 'MapType' bits are set
+    // correctly when generating information to be passed to the runtime. The
+    // flag is set to true if any array section has a stride not equal to 1, or
+    // if the stride is not a constant expression (conservatively assumed
+    // non-contiguous).
+    bool IsNonContiguous = false;
+    for (const auto &Component : Components) {
+      const auto *OASE =
+          dyn_cast<ArraySectionExpr>(Component.getAssociatedExpression());
+      if (OASE) {
+        const Expr *StrideExpr = OASE->getStride();
+        if (StrideExpr) {
+          // Check if the stride is a constant integer expression
+          if (StrideExpr->isIntegerConstantExpr(CGF.getContext())) {
+            if (auto Constant =
+                    StrideExpr->getIntegerConstantExpr(CGF.getContext())) {
+              int64_t StrideVal = Constant->getExtValue();
+              if (StrideVal != 1) {
+                // Set flag if stride is not 1 (i.e., non-contiguous update)
+                IsNonContiguous = true;
+                break;
+              }
+            }
+          } else {
+            // If stride is not a constant, conservatively treat as
+            // non-contiguous
+            IsNonContiguous = true;
+            break;
+          }
+        }
+      }
+    }
+
     bool IsPrevMemberReference = false;
 
     bool IsPartialMapped =
diff --git a/offload/test/offloading/strided_update.c b/offload/test/offloading/strided_update.c
new file mode 100644
index 0000000000000..fc47216fb5684
--- /dev/null
+++ b/offload/test/offloading/strided_update.c
@@ -0,0 +1,51 @@
+// Checks that "target update from" supports non-contiguous (strided) updates.
+// RUN: %libomptarget-compile-run-and-check-generic
+#include <omp.h>  
+#include <stdio.h>  
+  
+int main() {
+  int len = 8;
+  double data[len];
+  #pragma omp target map(tofrom: len, data[0:len])
+  {
+    for (int i = 0; i < len; i++) {
+      data[i] = i;
+    }
+  }
+  // initial values
+  printf("original host array values:\n");
+  for (int i = 0; i < len; i++)
+    printf("%f\n", data[i]);
+  printf("\n");
+
+  #pragma omp target data map(to: len, data[0:len])
+  {
+    #pragma omp target
+    for (int i = 0; i < len; i++) {
+      data[i] += i;
+    }
+
+    // data[0:4:2] selects 4 elements with stride 2 (indices 0, 2, 4, 6); the
+    // section length counts elements, so data[0:8:2] would run past data[7].
+    #pragma omp target update from(data[0:4:2])
+  }
+  // Even indices now hold the device value 2*i; odd indices keep i. Anchor on
+  // the header and use CHECK-NEXT so the initial printout cannot match.
+  // CHECK: from target array results:
+  // CHECK-NEXT: 0.000000
+  // CHECK-NEXT: 1.000000
+  // CHECK-NEXT: 4.000000
+  // CHECK-NEXT: 3.000000
+  // CHECK-NEXT: 8.000000
+  // CHECK-NEXT: 5.000000
+  // CHECK-NEXT: 12.000000
+  // CHECK-NEXT: 7.000000
+
+  printf("from target array results:\n");
+  for (int i = 0; i < len; i++)
+    printf("%f\n", data[i]);
+  printf("\n");
+
+  return 0;
+}
+

``````````

</details>


https://github.com/llvm/llvm-project/pull/144635
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to