ajohnson-uoregon created this revision.
ajohnson-uoregon added reviewers: klimek, jdoerfert.
Herald added a subscriber: yaxunl.
Herald added a project: All.
ajohnson-uoregon requested review of this revision.
Herald added a project: clang.
Herald added a subscriber: cfe-commits.

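Add AST matchers for every launch parameter of a CUDA kernel call:
hasKernelConfig for the whole <<<...>>> execution configuration, plus
cudaGridDim, cudaBlockDim, cudaSharedMemPerBlock and cudaStream for its
individual arguments.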


Repository:
  rG LLVM Github Monorepo

https://reviews.llvm.org/D120952

Files:
  clang/include/clang/ASTMatchers/ASTMatchers.h
  clang/lib/ASTMatchers/Dynamic/Registry.cpp

Index: clang/lib/ASTMatchers/Dynamic/Registry.cpp
===================================================================
--- clang/lib/ASTMatchers/Dynamic/Registry.cpp
+++ clang/lib/ASTMatchers/Dynamic/Registry.cpp
@@ -177,7 +177,11 @@
   REGISTER_MATCHER(continueStmt);
   REGISTER_MATCHER(coreturnStmt);
   REGISTER_MATCHER(coyieldExpr);
+  REGISTER_MATCHER(cudaBlockDim);
+  REGISTER_MATCHER(cudaGridDim);
   REGISTER_MATCHER(cudaKernelCallExpr);
+  REGISTER_MATCHER(cudaSharedMemPerBlock);
+  REGISTER_MATCHER(cudaStream);
   REGISTER_MATCHER(cxxBaseSpecifier);
   REGISTER_MATCHER(cxxBindTemporaryExpr);
   REGISTER_MATCHER(cxxBoolLiteral);
@@ -320,6 +324,7 @@
   REGISTER_MATCHER(hasInit);
   REGISTER_MATCHER(hasInitializer);
   REGISTER_MATCHER(hasInitStatement);
+  REGISTER_MATCHER(hasKernelConfig);
   REGISTER_MATCHER(hasKeywordSelector);
   REGISTER_MATCHER(hasLHS);
   REGISTER_MATCHER(hasLocalQualifiers);
Index: clang/include/clang/ASTMatchers/ASTMatchers.h
===================================================================
--- clang/include/clang/ASTMatchers/ASTMatchers.h
+++ clang/include/clang/ASTMatchers/ASTMatchers.h
@@ -7848,6 +7848,124 @@
 extern const internal::VariadicDynCastAllOfMatcher<Stmt, CUDAKernelCallExpr>
     cudaKernelCallExpr;
 
+/// Matches the execution configuration of a CUDA kernel call: the call
+/// expression built from the arguments written between <<< and >>>.
+///
+/// Given
+/// \code
+///   kernel<<<i, j>>>();
+/// \endcode
+/// cudaKernelCallExpr(hasKernelConfig(callExpr()))
+///   matches the kernel call, with the inner matcher applied to its
+///   execution-configuration call.
+AST_MATCHER_P(CUDAKernelCallExpr, hasKernelConfig, internal::Matcher<CallExpr>,
+              InnerMatcher) {
+  const CallExpr *Config = Node.getConfig();
+  return Config && InnerMatcher.matches(*Config, Finder, Builder);
+}
+
+/// Matches the grid dimension of a CUDA kernel call, i.e. the first
+/// execution-configuration argument written between <<< and >>>.
+///
+/// Given
+/// \code
+///   kernel<<<i, j>>>();
+/// \endcode
+/// cudaKernelCallExpr(cudaGridDim(expr()))
+///   matches the kernel call, with the inner matcher applied to the argument
+///   passed for \c i.
+///
+/// As in \c hasArgument, parentheses and implicit casts around the argument
+/// are ignored; the argument may still be a non-cast conversion of \c i, such
+/// as a constructor call.
+AST_MATCHER_P(CUDAKernelCallExpr, cudaGridDim, internal::Matcher<Expr>,
+              InnerMatcher) {
+  const CallExpr *Config = Node.getConfig();
+  if (!Config || Config->getNumArgs() < 1)
+    return false;
+  const Expr *Arg = Config->getArg(0);
+  if (Finder->isTraversalIgnoringImplicitNodes() && isa<CXXDefaultArgExpr>(Arg))
+    return false;
+  return InnerMatcher.matches(*Arg->IgnoreParenImpCasts(), Finder, Builder);
+}
+
+/// Matches the block dimension of a CUDA kernel call, i.e. the second
+/// execution-configuration argument written between <<< and >>>.
+///
+/// Given
+/// \code
+///   kernel<<<i, j>>>();
+/// \endcode
+/// cudaKernelCallExpr(cudaBlockDim(expr()))
+///   matches the kernel call, with the inner matcher applied to the argument
+///   passed for \c j.
+///
+/// As in \c hasArgument, parentheses and implicit casts around the argument
+/// are ignored; the argument may still be a non-cast conversion of \c j, such
+/// as a constructor call.
+AST_MATCHER_P(CUDAKernelCallExpr, cudaBlockDim, internal::Matcher<Expr>,
+              InnerMatcher) {
+  const CallExpr *Config = Node.getConfig();
+  if (!Config || Config->getNumArgs() < 2)
+    return false;
+  const Expr *Arg = Config->getArg(1);
+  if (Finder->isTraversalIgnoringImplicitNodes() && isa<CXXDefaultArgExpr>(Arg))
+    return false;
+  return InnerMatcher.matches(*Arg->IgnoreParenImpCasts(), Finder, Builder);
+}
+
+/// Matches the shared memory size of a CUDA kernel call, i.e. the third
+/// execution-configuration argument written between <<< and >>>.
+///
+/// Given
+/// \code
+///   kernel<<<i, j, mem, 0>>>();
+/// \endcode
+/// cudaKernelCallExpr(cudaSharedMemPerBlock(expr()))
+///   matches the kernel call, with the inner matcher applied to the argument
+///   passed for \c mem.
+///
+/// As in \c hasArgument, parentheses and implicit casts around the argument
+/// are ignored. If the launch omits this argument, the configuration call may
+/// still hold a default argument for it; that \c CXXDefaultArgExpr is only
+/// skipped when the traversal mode ignores implicit nodes.
+AST_MATCHER_P(CUDAKernelCallExpr, cudaSharedMemPerBlock,
+              internal::Matcher<Expr>, InnerMatcher) {
+  const CallExpr *Config = Node.getConfig();
+  if (!Config || Config->getNumArgs() < 3)
+    return false;
+  const Expr *Arg = Config->getArg(2);
+  if (Finder->isTraversalIgnoringImplicitNodes() && isa<CXXDefaultArgExpr>(Arg))
+    return false;
+  return InnerMatcher.matches(*Arg->IgnoreParenImpCasts(), Finder, Builder);
+}
+
+/// Matches the stream of a CUDA kernel call, i.e. the fourth
+/// execution-configuration argument written between <<< and >>>.
+///
+/// Given
+/// \code
+///   kernel<<<i, j, mem, 0>>>();
+/// \endcode
+/// cudaKernelCallExpr(cudaStream(expr()))
+///   matches the kernel call, with the inner matcher applied to the argument
+///   passed for \c 0.
+///
+/// As in \c hasArgument, parentheses and implicit casts around the argument
+/// are ignored. If the launch omits this argument, the configuration call may
+/// still hold a default argument for it; that \c CXXDefaultArgExpr is only
+/// skipped when the traversal mode ignores implicit nodes.
+AST_MATCHER_P(CUDAKernelCallExpr, cudaStream, internal::Matcher<Expr>,
+              InnerMatcher) {
+  const CallExpr *Config = Node.getConfig();
+  if (!Config || Config->getNumArgs() < 4)
+    return false;
+  const Expr *Arg = Config->getArg(3);
+  if (Finder->isTraversalIgnoringImplicitNodes() && isa<CXXDefaultArgExpr>(Arg))
+    return false;
+  return InnerMatcher.matches(*Arg->IgnoreParenImpCasts(), Finder, Builder);
+}
+
 /// Matches expressions that resolve to a null pointer constant, such as
 /// GNU's __null, C++11's nullptr, or C's NULL macro.
 ///
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to