https://github.com/jtb20 created 
https://github.com/llvm/llvm-project/pull/101305

This patch fixes a couple of cases where Clang aborts with loop nests that are 
being collapsed (via the relevant OpenMP clause) into a new, combined loop.

The problematic cases happen when a variable declared within the loop nest is 
used in the (init, condition, iter) statement of a more deeply-nested loop.  I 
don't think these cases (generally?) fall under the non-rectangular loop nest 
rules as defined in OpenMP 5.0+, but I could be wrong (and anyway, emitting an 
error is better than crashing).
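
A minimal sketch of the pattern in question, modelled loosely on the test cases added by this patch. The function name `count_iterations` and the hoisted form are assumptions made for illustration and are not part of the patch. Hoisting is only possible when the bound does not depend on an outer iteration variable, which is not true of every case in the tests.

```cpp
#include <cassert>

// Hypothetical helper, not part of the patch.
long count_iterations(int n, int m, int nb) {
  assert(n >= 0 && m >= 0);
  long total = 0;
  // Form that this patch diagnoses under collapse(2): 'cx' is declared inside
  // the loop nest and then used in the inner loop's condition.
  //
  //   for (int a = 0; a < n; a++) {
  //     int cx = n < nb ? n : m;
  //     for (int c = 0; c < cx; c++) { ... }
  //   }
  //
  // Here the bound is loop-invariant, so it can be declared before the nest.
  const int cx = n < nb ? n : m;
#pragma omp parallel for collapse(2) reduction(+ : total)
  for (int a = 0; a < n; a++) {
    for (int c = 0; c < cx; c++) {
      total += 1;
    }
  }
  return total;
}
```

Without `-fopenmp` the pragma is ignored and the function simply runs serially, so the sketch can be tried with or without OpenMP enabled.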

In terms of implementation: the crash happens because (to a first 
approximation) all the loop bounds calculations are pulled out to the start of 
the new, combined loop, but variables declared in the loop nest "haven't been 
seen yet".  I believe there is special handling for iteration variables 
declared in "for" init statements, but not for variables declared elsewhere in 
the "imperfect" parts of a loop nest.
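
To make the "pulled out to the start" point concrete, here is a hand-written collapse of a rectangular nest. This is only an illustration of the general idea under simplifying assumptions, not the code Clang generates; `collapsed_visit` is a hypothetical name.

```cpp
#include <cassert>
#include <utility>
#include <vector>

// Hypothetical illustration: collapse an n-by-m nest into one loop by hand.
std::vector<std::pair<int, int>> collapsed_visit(int n, int m) {
  assert(n >= 0 && m >= 0);
  std::vector<std::pair<int, int>> visited;
  // Both trip counts are needed here, before any iteration has run. A bound
  // that refers to a variable declared inside the outer loop body would have
  // no value available at this point.
  const long total = static_cast<long>(n) * m;
  for (long k = 0; k < total; ++k) {
    const int i = static_cast<int>(k / m); // recover the outer index
    const int j = static_cast<int>(k % m); // recover the inner index
    visited.emplace_back(i, j);
  }
  return visited;
}
```

The combined loop visits the same (i, j) pairs, in the same order, as the original two-level nest would.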

So, this patch tries to diagnose the troublesome cases before they can cause a 
crash.  This is slightly awkward because at the point where we want to do the 
diagnosis (SemaOpenMP.cpp), we don't have scope information readily available.  
Instead we "manually" scan through the AST of the loop nest looking for var 
decls (ForVarDeclFinder), then we ensure we're not using any of those in loop 
control subexprs (ForSubExprChecker). All that is only done when we have a 
"collapse" clause.
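
The two-pass idea can be sketched on a toy model of a loop nest. None of the types or functions below exist in Clang; `Level`, `collect_decls` and `find_forbidden_use` are hypothetical stand-ins, and the real visitors work on `VarDecl` and `DeclRefExpr` nodes rather than on names.

```cpp
#include <cassert>
#include <cstddef>
#include <set>
#include <string>
#include <vector>

// Toy model: one entry per loop, outermost first.
struct Level {
  std::vector<std::string> control_uses; // names used in init/cond/incr
  std::vector<std::string> body_decls;   // names declared in this loop's body
};

// Pass 1 (the role ForVarDeclFinder plays): gather names declared in the nest.
std::set<std::string> collect_decls(const std::vector<Level> &nest) {
  std::set<std::string> decls;
  for (const Level &level : nest)
    decls.insert(level.body_decls.begin(), level.body_decls.end());
  return decls;
}

// Pass 2 (the role ForSubExprChecker plays): return the first such name used
// in the control part of one of the first 'collapse' loops, or "" if none.
std::string find_forbidden_use(const std::vector<Level> &nest,
                               std::size_t collapse) {
  assert(collapse <= nest.size());
  const std::set<std::string> decls = collect_decls(nest);
  for (std::size_t depth = 0; depth < collapse; ++depth)
    for (const std::string &name : nest[depth].control_uses)
      if (decls.count(name) != 0)
        return name;
  return "";
}
```

Only the loops covered by the collapse count are checked, which is why a nest can be accepted with `collapse(2)` and rejected with `collapse(3)`, as in the second and third cases of `loop_collapse_1.c`.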

Range-for loops can also cause crashes without this patch, so they are 
handled too.

From 56d5d7797929d8bc81bf394a46c97b9bf645744e Mon Sep 17 00:00:00 2001
From: Julian Brown <julian.br...@amd.com>
Date: Wed, 26 Jun 2024 11:21:01 -0500
Subject: [PATCH] [clang][OpenMP] Diagnose badly-formed collapsed imperfect
 loop nests (#60678)

This patch fixes a couple of cases where Clang aborts with loop nests
that are being collapsed (via the relevant OpenMP clause) into a new,
combined loop.

The problematic cases happen when a variable declared within the
loop nest is used in the (init, condition, iter) statement of a more
deeply-nested loop.  I don't think these cases (generally?) fall under
the non-rectangular loop nest rules as defined in OpenMP 5.0+, but I
could be wrong (and anyway, emitting an error is better than crashing).

In terms of implementation: the crash happens because (to a first
approximation) all the loop bounds calculations are pulled out to the
start of the new, combined loop, but variables declared in the loop nest
"haven't been seen yet".  I believe there is special handling for
iteration variables declared in "for" init statements, but not for
variables declared elsewhere in the "imperfect" parts of a loop nest.

So, this patch tries to diagnose the troublesome cases before they can
cause a crash.  This is slightly awkward because at the point where we
want to do the diagnosis (SemaOpenMP.cpp), we don't have scope information
readily available.  Instead we "manually" scan through the AST of the
loop nest looking for var decls (ForVarDeclFinder), then we ensure we're
not using any of those in loop control subexprs (ForSubExprChecker).
All that is only done when we have a "collapse" clause.

Range-for loops can also cause crashes at present without this patch,
so are handled too.
---
 .../clang/Basic/DiagnosticSemaKinds.td        |   2 +
 clang/lib/AST/StmtOpenMP.cpp                  |   1 +
 clang/lib/Sema/SemaOpenMP.cpp                 | 132 ++++++++++++++++--
 clang/test/OpenMP/loop_collapse_1.c           |  40 ++++++
 clang/test/OpenMP/loop_collapse_2.cpp         |  80 +++++++++++
 5 files changed, 247 insertions(+), 8 deletions(-)
 create mode 100644 clang/test/OpenMP/loop_collapse_1.c
 create mode 100644 clang/test/OpenMP/loop_collapse_2.cpp

diff --git a/clang/include/clang/Basic/DiagnosticSemaKinds.td b/clang/include/clang/Basic/DiagnosticSemaKinds.td
index 581434d33c5c9..beb78eb0a4ef4 100644
--- a/clang/include/clang/Basic/DiagnosticSemaKinds.td
+++ b/clang/include/clang/Basic/DiagnosticSemaKinds.td
@@ -11152,6 +11152,8 @@ def err_omp_loop_diff_cxx : Error<
   "upper and lower loop bounds">;
 def err_omp_loop_cannot_use_stmt : Error<
   "'%0' statement cannot be used in OpenMP for loop">;
+def err_omp_loop_bad_collapse_var : Error<
+  "cannot use variable %1 in collapsed imperfectly-nested loop %select{init|condition|increment}0 statement">;
 def err_omp_simd_region_cannot_use_stmt : Error<
   "'%0' statement cannot be used in OpenMP simd region">;
 def warn_omp_loop_64_bit_var : Warning<
diff --git a/clang/lib/AST/StmtOpenMP.cpp b/clang/lib/AST/StmtOpenMP.cpp
index 451a9fe9fe3d2..75ea55c99dfc5 100644
--- a/clang/lib/AST/StmtOpenMP.cpp
+++ b/clang/lib/AST/StmtOpenMP.cpp
@@ -12,6 +12,7 @@
 
 #include "clang/AST/ASTContext.h"
 #include "clang/AST/StmtOpenMP.h"
+#include "clang/AST/RecursiveASTVisitor.h"
 
 using namespace clang;
 using namespace llvm::omp;
diff --git a/clang/lib/Sema/SemaOpenMP.cpp b/clang/lib/Sema/SemaOpenMP.cpp
index 4f50efda155fb..e78af5cc7ab0a 100644
--- a/clang/lib/Sema/SemaOpenMP.cpp
+++ b/clang/lib/Sema/SemaOpenMP.cpp
@@ -21,6 +21,7 @@
 #include "clang/AST/DeclCXX.h"
 #include "clang/AST/DeclOpenMP.h"
 #include "clang/AST/OpenMPClause.h"
+#include "clang/AST/RecursiveASTVisitor.h"
 #include "clang/AST/StmtCXX.h"
 #include "clang/AST/StmtOpenMP.h"
 #include "clang/AST/StmtVisitor.h"
@@ -7668,6 +7669,47 @@ struct LoopIterationSpace final {
   Expr *FinalCondition = nullptr;
 };
 
+class ForSubExprChecker : public RecursiveASTVisitor<ForSubExprChecker> {
+  const llvm::SmallSet<Decl *, 4> *CollapsedLoopVarDecls;
+  VarDecl *ForbiddenVar;
+  SourceRange ErrLoc;
+
+public:
+  explicit ForSubExprChecker(
+      const llvm::SmallSet<Decl *, 4> *CollapsedLoopVarDecls)
+      : CollapsedLoopVarDecls(CollapsedLoopVarDecls), ForbiddenVar(nullptr) {}
+
+  bool shouldVisitImplicitCode() const { return true; }
+
+  bool VisitDeclRefExpr(DeclRefExpr *E) {
+    ValueDecl *VD = E->getDecl();
+    if (!isa<VarDecl, BindingDecl>(VD))
+      return true;
+    VarDecl *V = VD->getPotentiallyDecomposedVarDecl();
+    if (V->getType()->isReferenceType()) {
+      VarDecl *VD = V->getDefinition();
+      if (VD->hasInit()) {
+        Expr *I = VD->getInit();
+        DeclRefExpr *DRE = dyn_cast<DeclRefExpr>(I);
+        if (!DRE)
+          return true;
+        V = DRE->getDecl()->getPotentiallyDecomposedVarDecl();
+      }
+    }
+    Decl *Canon = V->getCanonicalDecl();
+    if (CollapsedLoopVarDecls->contains(Canon)) {
+      ForbiddenVar = V;
+      ErrLoc = E->getSourceRange();
+      return false;
+    }
+
+    return true;
+  }
+
+  VarDecl *getForbiddenVar() { return ForbiddenVar; }
+  SourceRange &getErrRange() { return ErrLoc; }
+};
+
 /// Helper class for checking canonical form of the OpenMP loops and
 /// extracting iteration space of each loop in the loop nest, that will be used
 /// for IR generation.
@@ -7682,6 +7724,8 @@ class OpenMPIterationSpaceChecker {
   SourceLocation DefaultLoc;
   /// A location for diagnostics (when increment is not compatible).
   SourceLocation ConditionLoc;
+  /// The set of variables declared within the (to be collapsed) loop nest.
+  const llvm::SmallSet<Decl *, 4> *CollapsedLoopVarDecls;
   /// A source location for referring to loop init later.
   SourceRange InitSrcRange;
   /// A source location for referring to condition later.
@@ -7725,10 +7769,13 @@ class OpenMPIterationSpaceChecker {
   Expr *Condition = nullptr;
 
 public:
-  OpenMPIterationSpaceChecker(Sema &SemaRef, bool SupportsNonRectangular,
-                              DSAStackTy &Stack, SourceLocation DefaultLoc)
+  OpenMPIterationSpaceChecker(
+      Sema &SemaRef, bool SupportsNonRectangular, DSAStackTy &Stack,
+      SourceLocation DefaultLoc,
+      const llvm::SmallSet<Decl *, 4> *CollapsedLoopDecls)
       : SemaRef(SemaRef), SupportsNonRectangular(SupportsNonRectangular),
-        Stack(Stack), DefaultLoc(DefaultLoc), ConditionLoc(DefaultLoc) {}
+        Stack(Stack), DefaultLoc(DefaultLoc), ConditionLoc(DefaultLoc),
+        CollapsedLoopVarDecls(CollapsedLoopDecls) {}
   /// Check init-expr for canonical loop form and save loop counter
   /// variable - #Var and its initialization value - #LB.
   bool checkAndSetInit(Stmt *S, bool EmitDiags = true);
@@ -8049,6 +8096,16 @@ bool OpenMPIterationSpaceChecker::checkAndSetInit(Stmt *S, bool EmitDiags) {
     if (!ExprTemp->cleanupsHaveSideEffects())
       S = ExprTemp->getSubExpr();
 
+  if (CollapsedLoopVarDecls) {
+    ForSubExprChecker FSEC{CollapsedLoopVarDecls};
+    if (!FSEC.TraverseStmt(S)) {
+      SourceRange &Range = FSEC.getErrRange();
+      SemaRef.Diag(Range.getBegin(), diag::err_omp_loop_bad_collapse_var)
+          << Range.getEnd() << 0 << FSEC.getForbiddenVar();
+      return true;
+    }
+  }
+
   InitSrcRange = S->getSourceRange();
   if (Expr *E = dyn_cast<Expr>(S))
     S = E->IgnoreParens();
@@ -8152,6 +8209,17 @@ bool OpenMPIterationSpaceChecker::checkAndSetCond(Expr *S) {
   }
   Condition = S;
   S = getExprAsWritten(S);
+
+  if (CollapsedLoopVarDecls) {
+    ForSubExprChecker FSEC{CollapsedLoopVarDecls};
+    if (!FSEC.TraverseStmt(S)) {
+      SourceRange &Range = FSEC.getErrRange();
+      SemaRef.Diag(Range.getBegin(), diag::err_omp_loop_bad_collapse_var)
+          << Range.getEnd() << 1 << FSEC.getForbiddenVar();
+      return true;
+    }
+  }
+
   SourceLocation CondLoc = S->getBeginLoc();
   auto &&CheckAndSetCond =
       [this, IneqCondIsCanonical](BinaryOperatorKind Opcode, const Expr *LHS,
@@ -8250,6 +8318,16 @@ bool OpenMPIterationSpaceChecker::checkAndSetInc(Expr *S) {
     if (!ExprTemp->cleanupsHaveSideEffects())
       S = ExprTemp->getSubExpr();
 
+  if (CollapsedLoopVarDecls) {
+    ForSubExprChecker FSEC{CollapsedLoopVarDecls};
+    if (!FSEC.TraverseStmt(S)) {
+      SourceRange &Range = FSEC.getErrRange();
+      SemaRef.Diag(Range.getBegin(), diag::err_omp_loop_bad_collapse_var)
+          << Range.getEnd() << 2 << FSEC.getForbiddenVar();
+      return true;
+    }
+  }
+
   IncrementSrcRange = S->getSourceRange();
   S = S->IgnoreParens();
   if (auto *UO = dyn_cast<UnaryOperator>(S)) {
@@ -8972,7 +9050,7 @@ void SemaOpenMP::ActOnOpenMPLoopInitialization(SourceLocation ForLoc,
 
   DSAStack->loopStart();
   OpenMPIterationSpaceChecker ISC(SemaRef, /*SupportsNonRectangular=*/true,
-                                  *DSAStack, ForLoc);
+                                  *DSAStack, ForLoc, nullptr);
   if (!ISC.checkAndSetInit(Init, /*EmitDiags=*/false)) {
     if (ValueDecl *D = ISC.getLoopDecl()) {
       auto *VD = dyn_cast<VarDecl>(D);
@@ -9069,7 +9147,8 @@ static bool checkOpenMPIterationSpace(
     Expr *OrderedLoopCountExpr,
     SemaOpenMP::VarsWithInheritedDSAType &VarsWithImplicitDSA,
     llvm::MutableArrayRef<LoopIterationSpace> ResultIterSpaces,
-    llvm::MapVector<const Expr *, DeclRefExpr *> &Captures) {
+    llvm::MapVector<const Expr *, DeclRefExpr *> &Captures,
+    const llvm::SmallSet<Decl *, 4> *CollapsedLoopVarDecls) {
   bool SupportsNonRectangular = !isOpenMPLoopTransformationDirective(DKind);
   // OpenMP [2.9.1, Canonical Loop Form]
   //   for (init-expr; test-expr; incr-expr) structured-block
@@ -9108,7 +9187,8 @@ static bool checkOpenMPIterationSpace(
     return false;
 
   OpenMPIterationSpaceChecker ISC(SemaRef, SupportsNonRectangular, DSA,
-                                  For ? For->getForLoc() : CXXFor->getForLoc());
+                                  For ? For->getForLoc() : CXXFor->getForLoc(),
+                                  CollapsedLoopVarDecls);
 
   // Check init.
   Stmt *Init = For ? For->getInit() : CXXFor->getBeginStmt();
@@ -9475,6 +9555,36 @@ static Expr *buildPostUpdate(Sema &S, ArrayRef<Expr *> PostUpdates) {
   return PostUpdate;
 }
 
+class ForVarDeclFinder : public RecursiveASTVisitor<ForVarDeclFinder> {
+  int NestingDepth;
+  llvm::SmallSet<Decl *, 4> &VarDecls;
+
+public:
+  explicit ForVarDeclFinder(llvm::SmallSet<Decl *, 4> &VD)
+      : NestingDepth(0), VarDecls(VD) {}
+
+  bool VisitForStmt(ForStmt *F) {
+    ++NestingDepth;
+    TraverseStmt(F->getBody());
+    --NestingDepth;
+    return false;
+  }
+
+  bool VisitCXXForRangeStmt(CXXForRangeStmt *RF) {
+    ++NestingDepth;
+    TraverseStmt(RF->getBody());
+    --NestingDepth;
+    return false;
+  }
+
+  bool VisitVarDecl(VarDecl *D) {
+    Decl *C = D->getCanonicalDecl();
+    if (NestingDepth > 0)
+      VarDecls.insert(C);
+    return true;
+  }
+};
+
 /// Called on a for stmt to check itself and nested loops (if any).
 /// \return Returns 0 if one of the collapsed stmts is not canonical for loop,
 /// number of collapsed loops otherwise.
@@ -9487,6 +9597,7 @@ checkOpenMPLoop(OpenMPDirectiveKind DKind, Expr *CollapseLoopCountExpr,
   unsigned NestedLoopCount = 1;
   bool SupportsNonPerfectlyNested = (SemaRef.LangOpts.OpenMP >= 50) &&
                                     !isOpenMPLoopTransformationDirective(DKind);
+  llvm::SmallSet<Decl *, 4> CollapsedLoopVarDecls{};
 
   if (CollapseLoopCountExpr) {
     // Found 'collapse' clause - calculate collapse number.
@@ -9494,6 +9605,9 @@ checkOpenMPLoop(OpenMPDirectiveKind DKind, Expr *CollapseLoopCountExpr,
     if (!CollapseLoopCountExpr->isValueDependent() &&
         CollapseLoopCountExpr->EvaluateAsInt(Result, SemaRef.getASTContext())) {
       NestedLoopCount = Result.Val.getInt().getLimitedValue();
+
+      ForVarDeclFinder FVDF{CollapsedLoopVarDecls};
+      FVDF.TraverseStmt(AStmt);
     } else {
       Built.clear(/*Size=*/1);
       return 1;
@@ -9531,11 +9645,13 @@ checkOpenMPLoop(OpenMPDirectiveKind DKind, Expr *CollapseLoopCountExpr,
           SupportsNonPerfectlyNested, NumLoops,
           [DKind, &SemaRef, &DSA, NumLoops, NestedLoopCount,
            CollapseLoopCountExpr, OrderedLoopCountExpr, &VarsWithImplicitDSA,
-           &IterSpaces, &Captures](unsigned Cnt, Stmt *CurStmt) {
+           &IterSpaces, &Captures,
+           CollapsedLoopVarDecls](unsigned Cnt, Stmt *CurStmt) {
             if (checkOpenMPIterationSpace(
                     DKind, CurStmt, SemaRef, DSA, Cnt, NestedLoopCount,
                     NumLoops, CollapseLoopCountExpr, OrderedLoopCountExpr,
-                    VarsWithImplicitDSA, IterSpaces, Captures))
+                    VarsWithImplicitDSA, IterSpaces, Captures,
+                    &CollapsedLoopVarDecls))
               return true;
             if (Cnt > 0 && Cnt >= NestedLoopCount &&
                 IterSpaces[Cnt].CounterVar) {
diff --git a/clang/test/OpenMP/loop_collapse_1.c b/clang/test/OpenMP/loop_collapse_1.c
new file mode 100644
index 0000000000000..c9877419223dd
--- /dev/null
+++ b/clang/test/OpenMP/loop_collapse_1.c
@@ -0,0 +1,40 @@
+// RUN: %clang_cc1 -fopenmp -fopenmp-version=50 -verify %s
+
+void func( double *A, int N, int M, int NB ) {
+#pragma omp parallel
+  {
+    int nblks = (N-1)/NB;
+    int lnb = ((N-1)/NB)*NB;
+
+#pragma omp for collapse(2)
+    for (int jblk = 0 ; jblk < nblks ; jblk++ ) {
+      int jb = (jblk == nblks - 1 ? lnb : NB);
+      for (int jk = 0; jk < N; jk+=jb) {  // expected-error{{cannot use variable 'jb' in collapsed imperfectly-nested loop increment statement}}
+      }
+    }
+
+#pragma omp for collapse(2)
+    for (int a = 0; a < N; a++) {
+      for (int b = 0; b < M; b++) {
+        int cx = a+b < NB ? a : b;
+        for (int c = 0; c < cx; c++) {
+        }
+      }
+    }
+
+#pragma omp for collapse(3)
+    for (int a = 0; a < N; a++) {
+      for (int b = 0; b < M; b++) {
+        int cx = a+b < NB ? a : b;
+        for (int c = 0; c < cx; c++) {  // expected-error{{cannot use variable 'cx' in collapsed imperfectly-nested loop condition statement}}
+        }
+      }
+    }
+  }
+}
+
+int main(void) {
+  double arr[256];
+  func (arr, 16, 16, 16);
+  return 0;
+}
diff --git a/clang/test/OpenMP/loop_collapse_2.cpp b/clang/test/OpenMP/loop_collapse_2.cpp
new file mode 100644
index 0000000000000..59deddf65e37b
--- /dev/null
+++ b/clang/test/OpenMP/loop_collapse_2.cpp
@@ -0,0 +1,80 @@
+// RUN: %clang_cc1 -fopenmp -fopenmp-version=50 -verify %s
+
+// We just want to try out a range for statement... this seems a bit OTT.
+template<typename T>
+class fakevector {
+  T *contents;
+  long size;
+public:
+  fakevector(long sz) : size(sz) {
+    contents = new T[sz];
+  }
+  ~fakevector() {
+    delete[] contents;
+  }
+  T& operator[](long x) { return contents[x]; }
+  typedef T *iterator;
+  fakevector<T>::iterator begin() {
+    return &contents[0];
+  }
+  fakevector<T>::iterator end() {
+    return &contents[size];
+  }
+};
+
+void func( double *A, int N, int M, int NB ) {
+#pragma omp parallel
+  {
+    int nblks = (N-1)/NB;
+    int lnb = ((N-1)/NB)*NB;
+#pragma omp for collapse(2)
+    for (int jblk = 0 ; jblk < nblks ; jblk++ ) {
+      int jb = (jblk == nblks - 1 ? lnb : NB);
+      for (int jk = 0; jk < N; jk+=jb) {  // expected-error{{cannot use variable 'jb' in collapsed imperfectly-nested loop increment statement}}
+      }
+    }
+
+#pragma omp for collapse(2)
+    for (int a = 0; a < N; a++) {
+        for (int b = 0; b < M; b++) {
+          int cx = a+b < NB ? a : b;
+          for (int c = 0; c < cx; c++) {
+          }
+        }
+    }
+
+    fakevector<float> myvec{N};
+#pragma omp for collapse(2)
+    for (auto &a : myvec) {
+      fakevector<float> myvec3{M};
+      for (auto &b : myvec3) {  // expected-error{{cannot use variable 'myvec3' in collapsed imperfectly-nested loop init statement}}
+      }
+    }
+
+    fakevector<float> myvec2{M};
+
+#pragma omp for collapse(3)
+    for (auto &a : myvec) {
+      for (auto &b : myvec2) {
+        int cx = a < b ? N : M;
+        for (int c = 0; c < cx; c++) {  // expected-error {{cannot use variable 'cx' in collapsed imperfectly-nested loop condition statement}}
+        }
+      }
+    }
+
+#pragma omp for collapse(3)
+    for (auto &a : myvec) {
+      int cx = a < 5 ? M : N;
+      for (auto &b : myvec2) {
+        for (int c = 0; c < cx; c++) {  // expected-error{{cannot use variable 'cx' in collapsed imperfectly-nested loop condition statement}}
+        }
+      }
+    }
+  }
+}
+
+int main(void) {
+  double arr[256];
+  func (arr, 16, 16, 16);
+  return 0;
+}

_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits
