ajohnson-uoregon updated this revision to Diff 412875.
ajohnson-uoregon added a comment.

Still trying to get the right commits into this diff, apologies XD


Repository:
  rG LLVM Github Monorepo

CHANGES SINCE LAST ACTION
  https://reviews.llvm.org/D120949/new/

https://reviews.llvm.org/D120949

Files:
  clang-tools-extra/clang-rewrite/ClangRewrite.cpp
  clang-tools-extra/clang-rewrite/MatcherGenCallback.h
  clang-tools-extra/clang-rewrite/NewCodeCallback.h
  clang-tools-extra/clang-rewrite/RewriteCallback.h
  clang-tools-extra/clang-rewrite/tests/new.cpp
  clang/include/clang/ASTMatchers/ASTMatchers.h
  clang/lib/ASTMatchers/Dynamic/Registry.cpp

Index: clang/lib/ASTMatchers/Dynamic/Registry.cpp
===================================================================
--- clang/lib/ASTMatchers/Dynamic/Registry.cpp
+++ clang/lib/ASTMatchers/Dynamic/Registry.cpp
@@ -356,6 +356,7 @@
   REGISTER_MATCHER(hasSpecializedTemplate);
   REGISTER_MATCHER(hasStaticStorageDuration);
   REGISTER_MATCHER(hasStructuredBlock);
+  REGISTER_MATCHER(hasSubStmt);
   REGISTER_MATCHER(hasSyntacticForm);
   REGISTER_MATCHER(hasTargetDecl);
   REGISTER_MATCHER(hasTemplateArgument);
Index: clang/include/clang/ASTMatchers/ASTMatchers.h
===================================================================
--- clang/include/clang/ASTMatchers/ASTMatchers.h
+++ clang/include/clang/ASTMatchers/ASTMatchers.h
@@ -2737,6 +2737,25 @@
   return false;
 }
 
+/// Matches the statement an attribute is attached to.
+///
+/// Given
+/// \code
+///   if (n > 0) [[likely]]
+///     return n;
+///   else [[unlikely]]
+///     return 1;
+/// \endcode
+/// attributedStmt(hasSubStmt(returnStmt()))
+///   matches both attributed statements, since each one wraps a return
+///   statement.
+AST_MATCHER_P(AttributedStmt, hasSubStmt, internal::Matcher<Stmt>,
+              InnerMatcher) {
+  const Stmt *const Statement = Node.getSubStmt();
+  return (Statement != nullptr &&
+          InnerMatcher.matches(*Statement, Finder, Builder));
+}
+
 /// Matches \c QualTypes in the clang AST.
 extern const internal::VariadicAllOfMatcher<QualType> qualType;
 
Index: clang-tools-extra/clang-rewrite/tests/new.cpp
===================================================================
--- clang-tools-extra/clang-rewrite/tests/new.cpp
+++ clang-tools-extra/clang-rewrite/tests/new.cpp
@@ -6,14 +6,32 @@
   for (int i = 0; i < 3; i++) {
     x = i;
   }
+  // Attributed block whose attribute is not [[clang::matcher_block]]: the
+  // matcher_block matchers should skip it.
+  [[likely]]
+  {
+    printf("not a matcher\n");
+  }
+  // Tagged block: the matcher_block matchers should pick this one up
+  // (provided the enclosing function carries a rewrite attribute).
+  [[clang::matcher_block]]
   {
     return x;
   }
+}
 
+// [[likely]]/[[unlikely]] attached to non-compound statements (if/else
+// branches); none of these are matcher blocks.
+constexpr double pow(double x, long long n) noexcept {
+    if (n > 0) [[likely]]
+        return x * pow(x, n - 1);
+    else [[unlikely]]
+        return 1;
 }
 
 // [[clang::matcher("cuda_kernel")]]
 // auto kern() {
+//   [[clang::matcher_block]]
 //   {
 //     kernel<<<numblocks, numthreads>>>(arg1, arg2, ...);
 //   }
@@ -21,25 +39,37 @@
 //
 // [[clang::replace("cuda_kernel")]]
 // auto hip() {
-//   {
-//     hip_launch(kernkel, numblocks, numthreads, 0, 0, arg1, arg2, ...);
+//   if (kernel == "gaussian") {
+//     [[clang::matcher_block]]
+//     {
+//       hip_launch(kernel, numblocks, numthreads, 0, 0, arg1, arg2, ...);
+//     }
 //   }
 // }
 
 
 [[clang::replace("returns")]]
 auto return42() {
-  return 42;
+  [[clang::matcher_block]]
+  {
+    return 42;
+  }
 }
 
 [[clang::insert_before("returns", "thencode")]]
 auto foobar() {
-  printf("returning\n");;
+  [[clang::matcher_block]]
+  {
+    printf("returning\n");
+  }
 }
 
 [[clang::insert_after("thencode")]]
 auto helloworld() {
-  printf("hello world\n");
+  [[clang::matcher_block]]
+  {
+    printf("hello world\n");
+  }
 }
 
 int main() {
Index: clang-tools-extra/clang-rewrite/RewriteCallback.h
===================================================================
--- clang-tools-extra/clang-rewrite/RewriteCallback.h
+++ clang-tools-extra/clang-rewrite/RewriteCallback.h
@@ -137,11 +137,12 @@
     unsigned int end_line = end.getSpellingLineNumber();
     unsigned int end_col = end.getSpellingColumnNumber();
 
     if (verbose) {
-      printf("FOUND match for %s at %d:%d - %d:%d\n",
+      // Spelling line/column numbers are unsigned int; print them with %u.
+      printf("FOUND match for %s at %u:%u - %u:%u\n",
              matcher->getName().c_str(), begin_line, begin_col, end_line,
              end_col);
     }
     if (rw.isRewritable(match->getBeginLoc()) &&
         rw.isRewritable(match->getEndLoc())) {
 
Index: clang-tools-extra/clang-rewrite/NewCodeCallback.h
===================================================================
--- clang-tools-extra/clang-rewrite/NewCodeCallback.h
+++ clang-tools-extra/clang-rewrite/NewCodeCallback.h
@@ -57,13 +57,23 @@
 //         .bind("replace");
 
+// Matches a function body that directly contains a
+// `[[clang::matcher_block]] { ... }` statement; binds that block as "body".
+inline StatementMatcher matcher_block_body() {
+  return compoundStmt(hasAnySubstatement(attributedStmt(
+      isAttr(attr::MatcherBlock), hasSubStmt(compoundStmt().bind("body")))));
+}
+
 DeclarationMatcher insert_before_match =
-  functionDecl(hasAttr(attr::InsertCodeBefore)).bind("insert_before_match");
+    functionDecl(hasAttr(attr::InsertCodeBefore), hasBody(matcher_block_body()))
+        .bind("insert_before_match");
 
 DeclarationMatcher insert_after_match =
-  functionDecl(hasAttr(attr::InsertCodeAfter)).bind("insert_after_match");
+    functionDecl(hasAttr(attr::InsertCodeAfter), hasBody(matcher_block_body()))
+        .bind("insert_after_match");
 
 DeclarationMatcher replace_match =
-  functionDecl(hasAttr(attr::ReplaceCode)).bind("replace");
+    functionDecl(hasAttr(attr::ReplaceCode), hasBody(matcher_block_body()))
+        .bind("replace");
 
 std::vector<CodeAction *> all_actions;
 
@@ -129,70 +139,81 @@
     }
 
     // grab function body as new code
-    Stmt* new_code = nullptr;
-    if (func->hasBody()) {
-      new_code = func->getBody();
-      printf("function body!!!\n");
-      new_code->dump();
+    const CompoundStmt *body = result.Nodes.getNodeAs<CompoundStmt>("body");
+    if (!body ||
+        !context->getSourceManager().isWrittenInMainFile(body->getBeginLoc())) {
+      printf("ERROR: invalid body\n");
+      return;
+    }
+    printf("function body\n");
+    body->dump();
+
+    FullSourceLoc body_begin;
+    FullSourceLoc body_end;
+    if (!body->body_empty()) {
+      body_begin = context->getFullLoc(body->body_front()->getBeginLoc());
+
+      // A statement's source range stops short of its trailing ';', so lex
+      // forward from the start of the last statement to the next ';'. Stop
+      // at the block's closing '}' so we never run past the block.
+      // TODO: this is a hack and we should be smarter about semicolons
+      Optional<Token> tok = Lexer::findNextToken(
+          Lexer::getLocForEndOfToken(body->body_back()->getBeginLoc(), 0,
+                                     context->getSourceManager(),
+                                     context->getLangOpts()),
+          context->getSourceManager(), context->getLangOpts());
+      while (tok.hasValue() && tok->isNot(clang::tok::semi) &&
+             tok->getLocation() != body->getRBracLoc()) {
+        tok = Lexer::findNextToken(tok->getLocation(),
+                                   context->getSourceManager(),
+                                   context->getLangOpts());
+      }
+      if (!tok.hasValue()) {
+        printf("ERROR: could not find end of body\n");
+        return;
+      }
+      // Insertions copy the ';' as well; replacements stop just before it.
+      SourceLocation eol = tok->getLocation();
+      if (kind != Replace && tok->is(clang::tok::semi)) {
+        eol = tok->getEndLoc(); // grab semicolon
+      }
+      body_end = context->getFullLoc(eol);
     }
-    else {
-      printf("WARNING: code modification empty\n");
+    else { // empty body: nothing between the braces to copy
+      body_begin = context->getFullLoc(body->getRBracLoc());
+      body_end = body_begin;
     }
 
-    // FullSourceLoc body_begin;
-    // FullSourceLoc body_end;
-    // if (!new_code->body_empty()) {
-    //   body_begin = context->getFullLoc(new_code->body_front()->getBeginLoc());
-    //
-    //   // go to end of line; stmts don't work, gotta lex to the end of the line
-    //   SourceLocation eol = Lexer::getLocForEndOfToken(
-    //       new_code->body_back()->getBeginLoc(), 0, context->getSourceManager(),
-    //       context->getLangOpts());
-    //   Optional<Token> tok = Lexer::findNextToken(
-    //       eol, context->getSourceManager(), context->getLangOpts());
-    //   while (tok.hasValue() && tok->isNot(clang::tok::semi)) {
-    //     tok = Lexer::findNextToken(eol, context->getSourceManager(),
-    //                                context->getLangOpts());
-    //     eol = tok->getLocation();
-    //   }
-    //   // TODO: this is a hack and we should be smarter about semicolons
-    //   if (kind != Replace) {
-    //     eol = tok->getEndLoc(); // grab semicolon
-    //   }
-    //   body_end = context->getFullLoc(eol);
-    // } else { // empty body just use brackets
-    //   body_begin = context->getFullLoc(new_code->getLBracLoc());
-    //   body_end = context->getFullLoc(new_code->getRBracLoc());
-    // }
-    //
-    // FileID fid = body_begin.getFileID();
-    // unsigned int begin_offset = body_begin.getFileOffset();
-    // unsigned int end_offset = body_end.getFileOffset();
-    //
-    // printf("begin offset %u\n", begin_offset);
-    // printf("end offset   %u\n", end_offset);
-    // printf("array length %u\n", end_offset - begin_offset);
-    //
-    // llvm::Optional<llvm::MemoryBufferRef> buff =
-    //     context->getSourceManager().getBufferOrNone(fid);
-    //
-    // char *code = new char[end_offset - begin_offset + 1];
-    // if (buff.hasValue()) {
-    //   memcpy(code, &(buff->getBufferStart()[begin_offset]),
-    //          (end_offset - begin_offset + 1) * sizeof(char));
-    //   code[end_offset - begin_offset] =
-    //       '\0'; // force null terminated for Reasons
-    //   printf("code??? %s\n", code);
-    // } else {
-    //   printf("no buffer :<\n");
-    // }
-    //
-    // // make action, put in vector of actions
-    // CodeAction *act =
-    //     new CodeAction(kind, matcher_names, std::string(code), action_name);
-    // all_actions.push_back(act);
-    //
-    // delete[] code;
+    FileID fid = body_begin.getFileID();
+    unsigned int begin_offset = body_begin.getFileOffset();
+    unsigned int end_offset = body_end.getFileOffset();
+
+    printf("begin offset %u\n", begin_offset);
+    printf("end offset   %u\n", end_offset);
+    printf("array length %u\n", end_offset - begin_offset);
+
+    // Both ends must be in the same buffer and in order, or the copy below
+    // would read the wrong bytes.
+    if (body_end.getFileID() != fid || end_offset < begin_offset) {
+      printf("ERROR: invalid body range\n");
+      return;
+    }
+
+    llvm::Optional<llvm::MemoryBufferRef> buff =
+        context->getSourceManager().getBufferOrNone(fid);
+    if (!buff.hasValue()) {
+      printf("no buffer :<\n");
+      return;
+    }
+
+    // Copy exactly the bytes in [begin_offset, end_offset) of the file.
+    std::string code =
+        buff->getBuffer().substr(begin_offset, end_offset - begin_offset).str();
+    printf("code??? %s\n", code.c_str());
+
+    // make action, put in vector of actions
+    CodeAction *act = new CodeAction(kind, matcher_names, code, action_name);
+    all_actions.push_back(act);
   }
 
 private:
@@ -233,12 +254,4 @@
   }
 };
 
-class ReplaceCallback2 : public NewCodeCallback {
-public:
-  ReplaceCallback2() {
-    kind = Replace;
-    kind_name = "replace";
-  }
-};
-
 #endif
Index: clang-tools-extra/clang-rewrite/MatcherGenCallback.h
===================================================================
--- clang-tools-extra/clang-rewrite/MatcherGenCallback.h
+++ clang-tools-extra/clang-rewrite/MatcherGenCallback.h
@@ -99,7 +99,7 @@
     if (root == nullptr) {
       root = temp;
       current = root;
-      // bind_to("match");
+      bind_to("match");
     }
     else {
       current->add_child(current, temp);
@@ -127,7 +127,12 @@
   functionDecl(allOf(
     hasAttr(attr::Matcher),
     hasBody(compoundStmt(
-      hasAnySubstatement(compoundStmt(anything()).bind("body"))
+      // Only a [[clang::matcher_block]] { ... } statement directly in the
+      // body is used; plain or otherwise-attributed blocks are not matched.
+      hasAnySubstatement(attributedStmt(allOf(
+        isAttr(attr::MatcherBlock),
+        hasSubStmt(compoundStmt(anything()).bind("body"))
+      )))
     ))
   )).bind("matcher");
 
Index: clang-tools-extra/clang-rewrite/ClangRewrite.cpp
===================================================================
--- clang-tools-extra/clang-rewrite/ClangRewrite.cpp
+++ clang-tools-extra/clang-rewrite/ClangRewrite.cpp
@@ -107,13 +107,12 @@
     InsertPostmatchCallback postmatch_callback;
     ReplaceCallback replace_callback;
     MatcherGenCallback matcher_callback;
-    // ReplaceCallback2 r2d2;
 
+    // Pair each matcher with the callback that handles its matches.
     inst_finder.addMatcher(insert_before_match, &prematch_callback);
     inst_finder.addMatcher(insert_after_match, &postmatch_callback);
     inst_finder.addMatcher(replace_match, &replace_callback);
     inst_finder.addMatcher(matcher, &matcher_callback);
-    // inst_finder.addMatcher(replace2, &r2d2);
 
     // MatcherWrapper<DynTypedMatcher>* m = new MatcherWrapper<DynTypedMatcher>(rettest, "returns_test",
     //   "test",
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to