Author: erichkeane
Date: 2024-05-10T07:03:48-07:00
New Revision: 63177422a834f4b81d59b827b5f2a1c5a9083749

URL: 
https://github.com/llvm/llvm-project/commit/63177422a834f4b81d59b827b5f2a1c5a9083749
DIFF: 
https://github.com/llvm/llvm-project/commit/63177422a834f4b81d59b827b5f2a1c5a9083749.diff

LOG: [OpenACC][NFC] Fix isa behavior for OpenACC types

I discovered while working on a different patch that I'd not implemented
the 'classof' for any of the Clauses, which resulted in 'isa' always
returning 'true' for all of the types.  This patch goes through all the
existing clauses and adds 'classof' such that it will work correctly.

Additionally, in doing this, I found a bug where I was doing a cast to
the wrong type in the ASTWriter, so this fixes that problem as well.

Added: 
    

Modified: 
    clang/include/clang/AST/OpenACCClause.h
    clang/include/clang/AST/StmtOpenACC.h
    clang/lib/AST/OpenACCClause.cpp
    clang/lib/Serialization/ASTWriter.cpp

Removed: 
    


################################################################################
diff  --git a/clang/include/clang/AST/OpenACCClause.h b/clang/include/clang/AST/OpenACCClause.h
index d8332816a4999..3d0b1ab9d31e0 100644
--- a/clang/include/clang/AST/OpenACCClause.h
+++ b/clang/include/clang/AST/OpenACCClause.h
@@ -36,7 +36,7 @@ class OpenACCClause {
   SourceLocation getBeginLoc() const { return Location.getBegin(); }
   SourceLocation getEndLoc() const { return Location.getEnd(); }
 
-  static bool classof(const OpenACCClause *) { return true; }
+  static bool classof(const OpenACCClause *) { return false; }
 
   using child_iterator = StmtIterator;
   using const_child_iterator = ConstStmtIterator;
@@ -63,6 +63,8 @@ class OpenACCClauseWithParams : public OpenACCClause {
       : OpenACCClause(K, BeginLoc, EndLoc), LParenLoc(LParenLoc) {}
 
 public:
+  static bool classof(const OpenACCClause *C);
+
   SourceLocation getLParenLoc() const { return LParenLoc; }
 
   child_range children() {
@@ -92,6 +94,9 @@ class OpenACCDefaultClause : public OpenACCClauseWithParams {
   }
 
 public:
+  static bool classof(const OpenACCClause *C) {
+    return C->getClauseKind() == OpenACCClauseKind::Default;
+  }
   OpenACCDefaultClauseKind getDefaultClauseKind() const {
     return DefaultClauseKind;
   }
@@ -116,6 +121,8 @@ class OpenACCClauseWithCondition : public OpenACCClauseWithParams {
         ConditionExpr(ConditionExpr) {}
 
 public:
+  static bool classof(const OpenACCClause *C);
+
   bool hasConditionExpr() const { return ConditionExpr; }
   const Expr *getConditionExpr() const { return ConditionExpr; }
   Expr *getConditionExpr() { return ConditionExpr; }
@@ -143,6 +150,9 @@ class OpenACCIfClause : public OpenACCClauseWithCondition {
                   Expr *ConditionExpr, SourceLocation EndLoc);
 
 public:
+  static bool classof(const OpenACCClause *C) {
+    return C->getClauseKind() == OpenACCClauseKind::If;
+  }
   static OpenACCIfClause *Create(const ASTContext &C, SourceLocation BeginLoc,
                                  SourceLocation LParenLoc, Expr *ConditionExpr,
                                  SourceLocation EndLoc);
@@ -154,6 +164,9 @@ class OpenACCSelfClause : public OpenACCClauseWithCondition {
                     Expr *ConditionExpr, SourceLocation EndLoc);
 
 public:
+  static bool classof(const OpenACCClause *C) {
+    return C->getClauseKind() == OpenACCClauseKind::Self;
+  }
   static OpenACCSelfClause *Create(const ASTContext &C, SourceLocation 
BeginLoc,
                                    SourceLocation LParenLoc,
                                    Expr *ConditionExpr, SourceLocation EndLoc);
@@ -180,6 +193,7 @@ class OpenACCClauseWithExprs : public OpenACCClauseWithParams {
   llvm::ArrayRef<Expr *> getExprs() const { return Exprs; }
 
 public:
+  static bool classof(const OpenACCClause *C);
   child_range children() {
     return child_range(reinterpret_cast<Stmt **>(Exprs.begin()),
                        reinterpret_cast<Stmt **>(Exprs.end()));
@@ -214,6 +228,9 @@ class OpenACCWaitClause final
   }
 
 public:
+  static bool classof(const OpenACCClause *C) {
+    return C->getClauseKind() == OpenACCClauseKind::Wait;
+  }
   static OpenACCWaitClause *Create(const ASTContext &C, SourceLocation 
BeginLoc,
                                    SourceLocation LParenLoc, Expr *DevNumExpr,
                                    SourceLocation QueuesLoc,
@@ -246,6 +263,9 @@ class OpenACCNumGangsClause final
   }
 
 public:
+  static bool classof(const OpenACCClause *C) {
+    return C->getClauseKind() == OpenACCClauseKind::NumGangs;
+  }
   static OpenACCNumGangsClause *
   Create(const ASTContext &C, SourceLocation BeginLoc, SourceLocation 
LParenLoc,
          ArrayRef<Expr *> IntExprs, SourceLocation EndLoc);
@@ -275,6 +295,7 @@ class OpenACCClauseWithSingleIntExpr : public OpenACCClauseWithExprs {
   }
 
 public:
+  static bool classof(const OpenACCClause *C);
   bool hasIntExpr() const { return !getExprs().empty(); }
   const Expr *getIntExpr() const {
     return hasIntExpr() ? getExprs()[0] : nullptr;
@@ -288,6 +309,9 @@ class OpenACCNumWorkersClause : public OpenACCClauseWithSingleIntExpr {
                           Expr *IntExpr, SourceLocation EndLoc);
 
 public:
+  static bool classof(const OpenACCClause *C) {
+    return C->getClauseKind() == OpenACCClauseKind::NumWorkers;
+  }
   static OpenACCNumWorkersClause *Create(const ASTContext &C,
                                          SourceLocation BeginLoc,
                                          SourceLocation LParenLoc,
@@ -299,6 +323,9 @@ class OpenACCVectorLengthClause : public OpenACCClauseWithSingleIntExpr {
                             Expr *IntExpr, SourceLocation EndLoc);
 
 public:
+  static bool classof(const OpenACCClause *C) {
+    return C->getClauseKind() == OpenACCClauseKind::VectorLength;
+  }
   static OpenACCVectorLengthClause *
   Create(const ASTContext &C, SourceLocation BeginLoc, SourceLocation 
LParenLoc,
          Expr *IntExpr, SourceLocation EndLoc);
@@ -309,6 +336,9 @@ class OpenACCAsyncClause : public OpenACCClauseWithSingleIntExpr {
                      Expr *IntExpr, SourceLocation EndLoc);
 
 public:
+  static bool classof(const OpenACCClause *C) {
+    return C->getClauseKind() == OpenACCClauseKind::Async;
+  }
   static OpenACCAsyncClause *Create(const ASTContext &C,
                                     SourceLocation BeginLoc,
                                     SourceLocation LParenLoc, Expr *IntExpr,
@@ -326,6 +356,7 @@ class OpenACCClauseWithVarList : public OpenACCClauseWithExprs {
       : OpenACCClauseWithExprs(K, BeginLoc, LParenLoc, EndLoc) {}
 
 public:
+  static bool classof(const OpenACCClause *C);
   ArrayRef<Expr *> getVarList() { return getExprs(); }
   ArrayRef<Expr *> getVarList() const { return getExprs(); }
 };
@@ -344,6 +375,9 @@ class OpenACCPrivateClause final
   }
 
 public:
+  static bool classof(const OpenACCClause *C) {
+    return C->getClauseKind() == OpenACCClauseKind::Private;
+  }
   static OpenACCPrivateClause *
   Create(const ASTContext &C, SourceLocation BeginLoc, SourceLocation 
LParenLoc,
          ArrayRef<Expr *> VarList, SourceLocation EndLoc);
@@ -363,6 +397,9 @@ class OpenACCFirstPrivateClause final
   }
 
 public:
+  static bool classof(const OpenACCClause *C) {
+    return C->getClauseKind() == OpenACCClauseKind::FirstPrivate;
+  }
   static OpenACCFirstPrivateClause *
   Create(const ASTContext &C, SourceLocation BeginLoc, SourceLocation 
LParenLoc,
          ArrayRef<Expr *> VarList, SourceLocation EndLoc);
@@ -382,6 +419,9 @@ class OpenACCDevicePtrClause final
   }
 
 public:
+  static bool classof(const OpenACCClause *C) {
+    return C->getClauseKind() == OpenACCClauseKind::DevicePtr;
+  }
   static OpenACCDevicePtrClause *
   Create(const ASTContext &C, SourceLocation BeginLoc, SourceLocation 
LParenLoc,
          ArrayRef<Expr *> VarList, SourceLocation EndLoc);
@@ -401,6 +441,9 @@ class OpenACCAttachClause final
   }
 
 public:
+  static bool classof(const OpenACCClause *C) {
+    return C->getClauseKind() == OpenACCClauseKind::Attach;
+  }
   static OpenACCAttachClause *
   Create(const ASTContext &C, SourceLocation BeginLoc, SourceLocation 
LParenLoc,
          ArrayRef<Expr *> VarList, SourceLocation EndLoc);
@@ -420,6 +463,9 @@ class OpenACCNoCreateClause final
   }
 
 public:
+  static bool classof(const OpenACCClause *C) {
+    return C->getClauseKind() == OpenACCClauseKind::NoCreate;
+  }
   static OpenACCNoCreateClause *
   Create(const ASTContext &C, SourceLocation BeginLoc, SourceLocation 
LParenLoc,
          ArrayRef<Expr *> VarList, SourceLocation EndLoc);
@@ -439,6 +485,9 @@ class OpenACCPresentClause final
   }
 
 public:
+  static bool classof(const OpenACCClause *C) {
+    return C->getClauseKind() == OpenACCClauseKind::Present;
+  }
   static OpenACCPresentClause *
   Create(const ASTContext &C, SourceLocation BeginLoc, SourceLocation 
LParenLoc,
          ArrayRef<Expr *> VarList, SourceLocation EndLoc);
@@ -462,6 +511,11 @@ class OpenACCCopyClause final
   }
 
 public:
+  static bool classof(const OpenACCClause *C) {
+    return C->getClauseKind() == OpenACCClauseKind::Copy ||
+           C->getClauseKind() == OpenACCClauseKind::PCopy ||
+           C->getClauseKind() == OpenACCClauseKind::PresentOrCopy;
+  }
   static OpenACCCopyClause *
   Create(const ASTContext &C, OpenACCClauseKind Spelling,
          SourceLocation BeginLoc, SourceLocation LParenLoc,
@@ -488,6 +542,11 @@ class OpenACCCopyInClause final
   }
 
 public:
+  static bool classof(const OpenACCClause *C) {
+    return C->getClauseKind() == OpenACCClauseKind::CopyIn ||
+           C->getClauseKind() == OpenACCClauseKind::PCopyIn ||
+           C->getClauseKind() == OpenACCClauseKind::PresentOrCopyIn;
+  }
   bool isReadOnly() const { return IsReadOnly; }
   static OpenACCCopyInClause *
   Create(const ASTContext &C, OpenACCClauseKind Spelling,
@@ -515,6 +574,11 @@ class OpenACCCopyOutClause final
   }
 
 public:
+  static bool classof(const OpenACCClause *C) {
+    return C->getClauseKind() == OpenACCClauseKind::CopyOut ||
+           C->getClauseKind() == OpenACCClauseKind::PCopyOut ||
+           C->getClauseKind() == OpenACCClauseKind::PresentOrCopyOut;
+  }
   bool isZero() const { return IsZero; }
   static OpenACCCopyOutClause *
   Create(const ASTContext &C, OpenACCClauseKind Spelling,
@@ -542,6 +606,11 @@ class OpenACCCreateClause final
   }
 
 public:
+  static bool classof(const OpenACCClause *C) {
+    return C->getClauseKind() == OpenACCClauseKind::Create ||
+           C->getClauseKind() == OpenACCClauseKind::PCreate ||
+           C->getClauseKind() == OpenACCClauseKind::PresentOrCreate;
+  }
   bool isZero() const { return IsZero; }
   static OpenACCCreateClause *
   Create(const ASTContext &C, OpenACCClauseKind Spelling,

diff  --git a/clang/include/clang/AST/StmtOpenACC.h b/clang/include/clang/AST/StmtOpenACC.h
index 66f8f844e0b29..b706864798baa 100644
--- a/clang/include/clang/AST/StmtOpenACC.h
+++ b/clang/include/clang/AST/StmtOpenACC.h
@@ -93,6 +93,10 @@ class OpenACCAssociatedStmtConstruct : public OpenACCConstructStmt {
   }
 
 public:
+  static bool classof(const Stmt *T) {
+    return false;
+  }
+
   child_range children() {
     if (getAssociatedStmt())
       return child_range(&AssociatedStmt, &AssociatedStmt + 1);

diff  --git a/clang/lib/AST/OpenACCClause.cpp b/clang/lib/AST/OpenACCClause.cpp
index be079556a87a8..ee13437b97b48 100644
--- a/clang/lib/AST/OpenACCClause.cpp
+++ b/clang/lib/AST/OpenACCClause.cpp
@@ -17,6 +17,43 @@
 
 using namespace clang;
 
+// classof for the intermediate clause classes: each one accepts exactly the
+// clause kinds accepted by the classes that derive from it.
+
+bool OpenACCClauseWithParams::classof(const OpenACCClause *C) {
+  // OpenACCDefaultClause derives directly from OpenACCClauseWithParams, so it
+  // must be listed here; it is not covered by the two checks below.
+  return OpenACCDefaultClause::classof(C) ||
+         OpenACCClauseWithCondition::classof(C) ||
+         OpenACCClauseWithExprs::classof(C);
+}
+
+bool OpenACCClauseWithExprs::classof(const OpenACCClause *C) {
+  return OpenACCWaitClause::classof(C) || OpenACCNumGangsClause::classof(C) ||
+         OpenACCClauseWithSingleIntExpr::classof(C) ||
+         OpenACCClauseWithVarList::classof(C);
+}
+
+bool OpenACCClauseWithVarList::classof(const OpenACCClause *C) {
+  return OpenACCPrivateClause::classof(C) ||
+         OpenACCFirstPrivateClause::classof(C) ||
+         OpenACCDevicePtrClause::classof(C) ||
+         OpenACCAttachClause::classof(C) || OpenACCNoCreateClause::classof(C) ||
+         OpenACCPresentClause::classof(C) || OpenACCCopyClause::classof(C) ||
+         OpenACCCopyInClause::classof(C) || OpenACCCopyOutClause::classof(C) ||
+         OpenACCCreateClause::classof(C);
+}
+
+bool OpenACCClauseWithCondition::classof(const OpenACCClause *C) {
+  return OpenACCIfClause::classof(C) || OpenACCSelfClause::classof(C);
+}
+
+bool OpenACCClauseWithSingleIntExpr::classof(const OpenACCClause *C) {
+  return OpenACCNumWorkersClause::classof(C) ||
+         OpenACCVectorLengthClause::classof(C) ||
+         OpenACCAsyncClause::classof(C);
+}
+
 OpenACCDefaultClause *OpenACCDefaultClause::Create(const ASTContext &C,
                                                    OpenACCDefaultClauseKind K,
                                                    SourceLocation BeginLoc,

diff  --git a/clang/lib/Serialization/ASTWriter.cpp b/clang/lib/Serialization/ASTWriter.cpp
index e7b9050165bb7..ab07cf3efa45d 100644
--- a/clang/lib/Serialization/ASTWriter.cpp
+++ b/clang/lib/Serialization/ASTWriter.cpp
@@ -7816,7 +7816,7 @@ void ASTRecordWriter::writeOpenACCClause(const OpenACCClause *C) {
     return;
   }
   case OpenACCClauseKind::Self: {
-    const auto *SC = cast<OpenACCIfClause>(C);
+    const auto *SC = cast<OpenACCSelfClause>(C);
     writeSourceLocation(SC->getLParenLoc());
     writeBool(SC->hasConditionExpr());
     if (SC->hasConditionExpr())


        
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to