llvmorg-github-actions[bot] wrote:

<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-clang

Author: Mariya Podchishchaeva (Fznamznon)

<details>
<summary>Changes</summary>

WIP

---

Patch is 60.82 KiB, truncated to 20.00 KiB below, full version: 
https://github.com/llvm/llvm-project/pull/196128.diff


29 Files Affected:

- (modified) clang/include/clang/AST/ASTNodeTraverser.h (+3-1) 
- (modified) clang/include/clang/AST/RecursiveASTVisitor.h (+1) 
- (modified) clang/include/clang/AST/StmtSYCL.h (+19-7) 
- (modified) clang/include/clang/Basic/Attr.td (+7) 
- (modified) clang/include/clang/Sema/ScopeInfo.h (+4) 
- (modified) clang/include/clang/Sema/SemaSYCL.h (+6-3) 
- (modified) clang/lib/Sema/SemaDecl.cpp (+14-4) 
- (modified) clang/lib/Sema/SemaDeclAttr.cpp (+3) 
- (modified) clang/lib/Sema/SemaSYCL.cpp (+451-34) 
- (modified) clang/lib/Sema/TreeTransform.h (+6-1) 
- (modified) clang/test/ASTSYCL/ast-dump-sycl-kernel-call-stmt.cpp (+3) 
- (added) clang/test/ASTSYCL/ast-dump-sycl-kernel-decomposition.cpp (+141) 
- (modified) clang/test/ASTSYCL/ast-dump-sycl-kernel-entry-point.cpp (+3) 
- (modified) clang/test/ASTSYCL/ast-print-sycl-kernel-call.cpp (+3) 
- (modified) clang/test/CodeGenSYCL/function-attrs.cpp (+3) 
- (added) clang/test/CodeGenSYCL/kernel-arg-decomposition.cpp (+96) 
- (modified) clang/test/CodeGenSYCL/kernel-caller-entry-point.cpp (+3) 
- (modified) clang/test/CodeGenSYCL/sycl-kernel-entry-point-exceptions.cpp (+3) 
- (modified) clang/test/CodeGenSYCL/unique_stable_name_windows_diff.cpp (+3) 
- (modified) clang/test/SemaSYCL/sycl-kernel-entry-point-attr-appertainment.cpp (+3) 
- (modified) clang/test/SemaSYCL/sycl-kernel-entry-point-attr-device-odr-use.cpp (+3) 
- (modified) clang/test/SemaSYCL/sycl-kernel-entry-point-attr-grammar.cpp (+3) 
- (modified) clang/test/SemaSYCL/sycl-kernel-entry-point-attr-kernel-name-module.cpp (+3) 
- (modified) clang/test/SemaSYCL/sycl-kernel-entry-point-attr-kernel-name-pch.cpp (+3) 
- (modified) clang/test/SemaSYCL/sycl-kernel-entry-point-attr-kernel-name.cpp (+3) 
- (modified) clang/test/SemaSYCL/sycl-kernel-entry-point-attr-sfinae.cpp (+3) 
- (modified) clang/test/SemaSYCL/sycl-kernel-entry-point-attr-this.cpp (+3) 
- (modified) clang/test/SemaSYCL/sycl-kernel-launch-ms-compat.cpp (+3) 
- (modified) clang/test/SemaSYCL/sycl-kernel-launch.cpp (+3) 


``````````diff
diff --git a/clang/include/clang/AST/ASTNodeTraverser.h b/clang/include/clang/AST/ASTNodeTraverser.h
index 5e9463d54747d..c4bafc2017609 100644
--- a/clang/include/clang/AST/ASTNodeTraverser.h
+++ b/clang/include/clang/AST/ASTNodeTraverser.h
@@ -858,8 +858,10 @@ class ASTNodeTraverser
   void
   VisitUnresolvedSYCLKernelCallStmt(const UnresolvedSYCLKernelCallStmt *Node) {
     Visit(Node->getOriginalStmt());
-    if (Traversal != TK_IgnoreUnlessSpelledInSource)
+    if (Traversal != TK_IgnoreUnlessSpelledInSource) {
       Visit(Node->getKernelLaunchIdExpr());
+      Visit(Node->getSpecArgsIdExpr());
+    }
   }
 
   void VisitOMPExecutableDirective(const OMPExecutableDirective *Node) {
diff --git a/clang/include/clang/AST/RecursiveASTVisitor.h b/clang/include/clang/AST/RecursiveASTVisitor.h
index b5be0910194bd..fdb828a7cb680 100644
--- a/clang/include/clang/AST/RecursiveASTVisitor.h
+++ b/clang/include/clang/AST/RecursiveASTVisitor.h
@@ -3034,6 +3034,7 @@ DEF_TRAVERSE_STMT(UnresolvedSYCLKernelCallStmt, {
   if (getDerived().shouldVisitImplicitCode()) {
     TRY_TO(TraverseStmt(S->getOriginalStmt()));
     TRY_TO(TraverseStmt(S->getKernelLaunchIdExpr()));
+    TRY_TO(TraverseStmt(S->getSpecArgsIdExpr()));
     ShouldVisitChildren = false;
   }
 })
diff --git a/clang/include/clang/AST/StmtSYCL.h b/clang/include/clang/AST/StmtSYCL.h
index 79ac88532e143..cd682f4cea594 100644
--- a/clang/include/clang/AST/StmtSYCL.h
+++ b/clang/include/clang/AST/StmtSYCL.h
@@ -105,12 +105,19 @@ class UnresolvedSYCLKernelCallStmt : public Stmt {
   Stmt *OriginalStmt = nullptr;
   // KernelLaunchIdExpr stores an UnresolvedLookupExpr or UnresolvedMemberExpr
   // corresponding to the SYCL kernel launch function for which a call
-  // will be synthesized during template instantiation.
+  // will be synthesized during template instantiation of the host code.
   Expr *KernelLaunchIdExpr = nullptr;
-
-  UnresolvedSYCLKernelCallStmt(CompoundStmt *CS, Expr *IdExpr)
+  // Similar to KernelLaunchIdExpr HandleSYCLSpecialParamsIdExpr stores an
+  // UnresolvedLookupExpr or UnresolvedMemberExpr corresponding to the fuction
+  // handling of special SYCL kernel parameters for which a call will be
+  // synthesized during template instantiation of the device code.
+  Expr *HandleSYCLSpecialParamsIdExpr = nullptr;
+
+  UnresolvedSYCLKernelCallStmt(CompoundStmt *CS, Expr *IdExpr,
+                               Expr *HandleSYCLSpecialParamsIdExpr)
       : Stmt(UnresolvedSYCLKernelCallStmtClass), OriginalStmt(CS),
-        KernelLaunchIdExpr(IdExpr) {}
+        KernelLaunchIdExpr(IdExpr),
+        HandleSYCLSpecialParamsIdExpr(HandleSYCLSpecialParamsIdExpr) {}
 
   void setOriginalStmt(CompoundStmt *CS) { OriginalStmt = CS; }
 
@@ -118,12 +125,13 @@ class UnresolvedSYCLKernelCallStmt : public Stmt {
 
 public:
   static UnresolvedSYCLKernelCallStmt *Create(const ASTContext &C,
-                                              CompoundStmt *CS, Expr *IdExpr) {
-    return new (C) UnresolvedSYCLKernelCallStmt(CS, IdExpr);
+                                              CompoundStmt *CS, Expr *IdExpr,
+                                              Expr *SpecArgsExpr) {
+    return new (C) UnresolvedSYCLKernelCallStmt(CS, IdExpr, SpecArgsExpr);
   }
 
   static UnresolvedSYCLKernelCallStmt *CreateEmpty(const ASTContext &C) {
-    return new (C) UnresolvedSYCLKernelCallStmt(nullptr, nullptr);
+    return new (C) UnresolvedSYCLKernelCallStmt(nullptr, nullptr, nullptr);
   }
 
   CompoundStmt *getOriginalStmt() { return cast<CompoundStmt>(OriginalStmt); }
@@ -133,6 +141,10 @@ class UnresolvedSYCLKernelCallStmt : public Stmt {
 
   Expr *getKernelLaunchIdExpr() { return KernelLaunchIdExpr; }
   const Expr *getKernelLaunchIdExpr() const { return KernelLaunchIdExpr; }
+  Expr *getSpecArgsIdExpr() { return HandleSYCLSpecialParamsIdExpr; }
+  const Expr *getSpecArgsIdExpr() const {
+    return HandleSYCLSpecialParamsIdExpr;
+  }
 
   SourceLocation getBeginLoc() const LLVM_READONLY {
     return getOriginalStmt()->getBeginLoc();
diff --git a/clang/include/clang/Basic/Attr.td b/clang/include/clang/Basic/Attr.td
index 70b5773f95b08..5f991590638f1 100644
--- a/clang/include/clang/Basic/Attr.td
+++ b/clang/include/clang/Basic/Attr.td
@@ -1744,6 +1744,13 @@ def SYCLSpecialClass: InheritableAttr {
   let Documentation = [SYCLSpecialClassDocs];
 }
 
+// Attribute for C++ record types. SemaSYCL's isSyclSpecialType() checks for
+// it to decide whether a type is a special SYCL kernel parameter.
+// TODO(review): replace [Undocumented] with a real documentation entry
+// before this lands.
+def SYCLSpecialKernelParameter : InheritableAttr {
+  let Spellings = [CXX11<"clang", "sycl_special_kernel_parameter">];
+  let Subjects = SubjectList<[CXXRecord]>;
+  let LangOpts = [SYCLHost, SYCLDevice];
+  let Documentation = [Undocumented];
+}
+
 def C11NoReturn : InheritableAttr {
   let Spellings = [CustomKeyword<"_Noreturn">];
   let Subjects = SubjectList<[Function], ErrorDiag>;
diff --git a/clang/include/clang/Sema/ScopeInfo.h b/clang/include/clang/Sema/ScopeInfo.h
index f334f58ebd0a7..1e76e3c676385 100644
--- a/clang/include/clang/Sema/ScopeInfo.h
+++ b/clang/include/clang/Sema/ScopeInfo.h
@@ -249,6 +249,10 @@ class FunctionScopeInfo {
   /// to a SYCL kernel launch function in a dependent context.
   Expr *SYCLKernelLaunchIdExpr = nullptr;
 
+  /// An unresolved identifier lookup expression for an implicit call
+  /// to a handling function for SYCL kernel special parameters.
+  Expr *HandleSYCLSpecialParamsIdExpr = nullptr;
+
 public:
   /// Represents a simple identification of a weak object.
   ///
diff --git a/clang/include/clang/Sema/SemaSYCL.h b/clang/include/clang/Sema/SemaSYCL.h
index 4980aa44c3012..268f31d8947cb 100644
--- a/clang/include/clang/Sema/SemaSYCL.h
+++ b/clang/include/clang/Sema/SemaSYCL.h
@@ -83,19 +83,22 @@ class SemaSYCL : public SemaBase {
   /// passed as the 'LaunchIdExpr' argument in a call to either
   /// BuildSYCLKernelCallStmt() or BuildUnresolvedSYCLKernelCallStmt() after
   /// the function body has been parsed.
-  ExprResult BuildSYCLKernelLaunchIdExpr(FunctionDecl *FD, QualType 
KernelName);
+  ExprResult BuildSYCLKernelLaunchIdExpr(FunctionDecl *FD, QualType KernelName,
+                                         StringRef FuncName);
 
   /// Builds a SYCLKernelCallStmt to wrap 'Body' and to be used as the body of
   /// 'FD'. 'LaunchIdExpr' specifies the lookup result returned by a previous
   /// call to BuildSYCLKernelLaunchIdExpr().
   StmtResult BuildSYCLKernelCallStmt(FunctionDecl *FD, CompoundStmt *Body,
-                                     Expr *LaunchIdExpr);
+                                     Expr *LaunchIdExpr,
+                                     Expr *HandleSpecParamsExpr);
 
   /// Builds an UnresolvedSYCLKernelCallStmt to wrap 'Body'. 'LaunchIdExpr'
   /// specifies the lookup result returned by a previous call to
   /// BuildSYCLKernelLaunchIdExpr().
   StmtResult BuildUnresolvedSYCLKernelCallStmt(CompoundStmt *Body,
-                                               Expr *LaunchIdExpr);
+                                               Expr *LaunchIdExpr,
+                                               Expr *HandleSpecParamsExpr);
 };
 
 } // namespace clang
diff --git a/clang/lib/Sema/SemaDecl.cpp b/clang/lib/Sema/SemaDecl.cpp
index eb5b6d65b4d58..744b120a775a0 100644
--- a/clang/lib/Sema/SemaDecl.cpp
+++ b/clang/lib/Sema/SemaDecl.cpp
@@ -16466,7 +16466,7 @@ Decl *Sema::ActOnStartOfFunctionDef(Scope *FnBodyScope, Decl *D,
     const auto *SKEPAttr = FD->getAttr<SYCLKernelEntryPointAttr>();
     if (!SKEPAttr->isInvalidAttr()) {
       ExprResult LaunchIdExpr =
-          SYCL().BuildSYCLKernelLaunchIdExpr(FD, SKEPAttr->getKernelName());
+          SYCL().BuildSYCLKernelLaunchIdExpr(FD, SKEPAttr->getKernelName(), 
"sycl_kernel_launch");
       // Do not mark 'FD' as invalid if construction of `LaunchIDExpr` produces
       // an invalid result. Name lookup failure for 'sycl_kernel_launch' is
       // treated as an error in the definition of 'FD'; treating it as an error
@@ -16475,6 +16475,13 @@ Decl *Sema::ActOnStartOfFunctionDef(Scope *FnBodyScope, Decl *D,
       // 'LaunchIDExpr' failed, then 'SYCLKernelLaunchIdExpr' will be assigned
       // a null pointer value below; that is expected.
       getCurFunction()->SYCLKernelLaunchIdExpr = LaunchIdExpr.get();
+      if (!LaunchIdExpr.isInvalid() &&
+          !LaunchIdExpr.get()->getType()->isVoidType()) {
+        ExprResult HSPSPIdExpr = SYCL().BuildSYCLKernelLaunchIdExpr(
+            FD, SKEPAttr->getKernelName(),
+            "sycl_handle_special_kernel_parameters");
+        getCurFunction()->HandleSYCLSpecialParamsIdExpr = HSPSPIdExpr.get();
+      }
     }
   }
 
@@ -16690,7 +16697,8 @@ Decl *Sema::ActOnFinishFunctionBody(Decl *dcl, Stmt *Body, bool IsInstantiation,
         // The function body should already be a SYCLKernelCallStmt in this
         // case, but might not be if there were previous errors.
         SR = Body;
-      } else if (!getCurFunction()->SYCLKernelLaunchIdExpr) {
+      } else if (!getCurFunction()->SYCLKernelLaunchIdExpr ||
+                 !getCurFunction()->HandleSYCLSpecialParamsIdExpr) {
         // If name lookup for a template named sycl_kernel_launch failed
         // earlier, don't try to build a SYCL kernel call statement as that
         // would cause additional errors to be issued; just proceed with the
@@ -16698,11 +16706,13 @@ Decl *Sema::ActOnFinishFunctionBody(Decl *dcl, Stmt *Body, bool IsInstantiation,
         SR = Body;
       } else if (FD->isTemplated()) {
         SR = SYCL().BuildUnresolvedSYCLKernelCallStmt(
-            cast<CompoundStmt>(Body), 
getCurFunction()->SYCLKernelLaunchIdExpr);
+            cast<CompoundStmt>(Body), getCurFunction()->SYCLKernelLaunchIdExpr,
+            getCurFunction()->HandleSYCLSpecialParamsIdExpr);
       } else {
         SR = SYCL().BuildSYCLKernelCallStmt(
             FD, cast<CompoundStmt>(Body),
-            getCurFunction()->SYCLKernelLaunchIdExpr);
+            getCurFunction()->SYCLKernelLaunchIdExpr,
+            getCurFunction()->HandleSYCLSpecialParamsIdExpr);
       }
       // If construction of the replacement body fails, just continue with the
       // original function body. An early error return here is not valid; the
diff --git a/clang/lib/Sema/SemaDeclAttr.cpp b/clang/lib/Sema/SemaDeclAttr.cpp
index 386651fa691e0..03f86ec1bf480 100644
--- a/clang/lib/Sema/SemaDeclAttr.cpp
+++ b/clang/lib/Sema/SemaDeclAttr.cpp
@@ -7753,6 +7753,9 @@ ProcessDeclAttribute(Sema &S, Scope *scope, Decl *D, const ParsedAttr &AL,
   case ParsedAttr::AT_SYCLSpecialClass:
     handleSimpleAttribute<SYCLSpecialClassAttr>(S, D, AL);
     break;
+  case ParsedAttr::AT_SYCLSpecialKernelParameter:
+    handleSimpleAttribute<SYCLSpecialKernelParameterAttr>(S, D, AL);
+    break;
   case ParsedAttr::AT_Format:
     handleFormatAttr(S, D, AL);
     break;
diff --git a/clang/lib/Sema/SemaSYCL.cpp b/clang/lib/Sema/SemaSYCL.cpp
index 112a6e4416df2..b3de40cfdd68c 100644
--- a/clang/lib/Sema/SemaSYCL.cpp
+++ b/clang/lib/Sema/SemaSYCL.cpp
@@ -425,7 +425,8 @@ void SemaSYCL::CheckSYCLEntryPointFunctionDecl(FunctionDecl *FD) {
 }
 
 ExprResult SemaSYCL::BuildSYCLKernelLaunchIdExpr(FunctionDecl *FD,
-                                                 QualType KNT) {
+                                                 QualType KNT,
+                                                 StringRef FuncName) {
   // The current context must be the function definition context to ensure
   // that name lookup is performed within the correct scope.
   assert(SemaRef.CurContext == FD && "The current declaration context does not 
"
@@ -440,12 +441,13 @@ ExprResult SemaSYCL::BuildSYCLKernelLaunchIdExpr(FunctionDecl *FD,
 
   ASTContext &Ctx = SemaRef.getASTContext();
   IdentifierInfo &SYCLKernelLaunchID =
-      Ctx.Idents.get("sycl_kernel_launch", tok::TokenKind::identifier);
+      Ctx.Idents.get(FuncName, tok::TokenKind::identifier);
 
   // Establish a code synthesis context for the implicit name lookup of
   // a template named 'sycl_kernel_launch'. In the event of an error, this
   // ensures an appropriate diagnostic note is issued to explain why the
   // lookup was performed.
+  // FIXME: Extend diagnostics for handle special parameters function
   Sema::CodeSynthesisContext CSC;
   CSC.Kind = Sema::CodeSynthesisContext::SYCLKernelLaunchLookup;
   CSC.Entity = FD;
@@ -492,16 +494,316 @@ ExprResult SemaSYCL::BuildSYCLKernelLaunchIdExpr(FunctionDecl *FD,
   return IdExpr;
 }
 
+/// Returns true if \p Ty names a record whose most recent declaration carries
+/// the SYCLSpecialKernelParameter attribute; false for non-record types.
+static bool isSyclSpecialType(QualType Ty) {
+  const RecordDecl *RD = Ty->getAsRecordDecl();
+  return RD &&
+         RD->getMostRecentDecl()->hasAttr<SYCLSpecialKernelParameterAttr>();
+}
+
 namespace {
+/// A special visitor to visit subobjects within a type, i.e. fields of a
+/// class or elements of an array. Useful for SYCL because SYCL kernels are
+/// defined via lambda expressions or named callable objects and kernel
+/// parameters are fields of these. These visitors will be used for diagnosing
+/// invalid kernel arguments as well as for functional transformations.
+class SubobjectVisitor {
+  ASTContext &Ctx;
+
+  // These enable handler execution only when previous Handlers succeed.
+  template <typename... Tn>
+  bool handleField(FieldDecl *FD, QualType FDTy, Tn &&...tn) {
+    bool result = true;
+    (void)std::initializer_list<int>{(result = result && tn(FD, FDTy), 0)...};
+    return result;
+  }
+  // NOTE(review): nothing in this class calls this overload; KF_FOR_EACH
+  // below only forwards a FieldDecl *. TODO: wire it up for bases or drop it.
+  template <typename... Tn>
+  bool handleField(const CXXBaseSpecifier &BD, QualType BDTy, Tn &&...tn) {
+    bool result = true;
+    // Explicitly discard the temporary, as the FieldDecl overload does.
+    (void)std::initializer_list<int>{(result = result && tn(BD, BDTy), 0)...};
+    return result;
+  }
+
+#define KF_FOR_EACH(FUNC, Item, Qt)                                            \
+  handleField(Item, Qt, ([&](FieldDecl *FD, QualType FDTy) {                   \
+                return Handlers.FUNC(FD, FDTy);                                \
+              })...)
+
+  // Parent contains the FieldDecl or CXXBaseSpecifier that was used to enter
+  // the Wrapper structure that we're currently visiting. Owner is the parent
+  // type (which doesn't exist in cases where it is a FieldDecl in the
+  // 'root'), and Wrapper is the current struct being unwrapped.
+  template <typename ParentTy, typename... HandlerTys>
+  void visitComplexRecord(const CXXRecordDecl *Owner, ParentTy &Parent,
+                          const CXXRecordDecl *Wrapper, QualType RecordTy,
+                          HandlerTys &...Handlers) {
+    (void)std::initializer_list<int>{
+        (Handlers.enterStruct(Owner, Parent, RecordTy), 0)...};
+    visitRecordHelper(Wrapper, Wrapper->bases(), Handlers...);
+    visitRecordHelper(Wrapper, Wrapper->fields(), Handlers...);
+    (void)std::initializer_list<int>{
+        (Handlers.leaveStruct(Owner, Parent, RecordTy), 0)...};
+  }
+
+  template <typename... HandlerTys>
+  void visitArray(const CXXRecordDecl *Owner, FieldDecl *Field,
+                  QualType ArrayTy, HandlerTys &...Handlers) {
+    // TODO add support for simple array visiting, i.e. without entering array
+    // elements.
+    visitComplexArray(Owner, Field, ArrayTy, Handlers...);
+  }
+
+  template <typename ParentTy, typename... HandlerTys>
+  void visitRecord(const CXXRecordDecl *Owner, ParentTy &Parent,
+                   const CXXRecordDecl *Wrapper, QualType RecordTy,
+                   HandlerTys &...Handlers) {
+    // TODO add support for simple record visiting, i.e. without entering record
+    // fields.
+    visitComplexRecord(Owner, Parent, Wrapper, RecordTy, Handlers...);
+  }
+
+  template <typename... HandlerTys>
+  void visitRecordHelper(const CXXRecordDecl *Owner,
+                         clang::CXXRecordDecl::base_class_const_range Range,
+                         HandlerTys &...Handlers) {
+    for (const auto &Base : Range) {
+      QualType BaseTy = Base.getType();
+      visitRecord(Owner, Base, BaseTy->getAsCXXRecordDecl(), BaseTy,
+                  Handlers...);
+    }
+  }
+
+  template <typename... HandlerTys>
+  void visitRecordHelper(const CXXRecordDecl *Owner, RecordDecl::field_range,
+                         HandlerTys &...Handlers) {
+    visitRecordFields(Owner, Handlers...);
+  }
+
+  // NOTE(review): Index is currently unused; every element is visited through
+  // the array's own FieldDecl.
+  template <typename... HandlerTys>
+  void visitArrayElementImpl(const CXXRecordDecl *Owner, FieldDecl *ArrayField,
+                             QualType ElementTy, uint64_t Index,
+                             HandlerTys &...Handlers) {
+    visitField(Owner, ArrayField, ElementTy, Handlers...);
+  }
+
+  template <typename... HandlerTys>
+  void visitNthArrayElement(const CXXRecordDecl *Owner, FieldDecl *ArrayField,
+                            QualType ElementTy, uint64_t Index,
+                            HandlerTys &...Handlers) {
+    visitArrayElementImpl(Owner, ArrayField, ElementTy, Index, Handlers...);
+  }
+
+  template <typename... HandlerTys>
+  void visitComplexArray(const CXXRecordDecl *Owner, FieldDecl *Field,
+                         QualType ArrayTy, HandlerTys &...Handlers) {
+    // Array workflow is:
+    // handleArrayType
+    // enterArray
+    // visitField (same as before, note that the FieldDecl is that of the
+    // array itself, not the element)
+    // ... repeat per element, opt-out for duplicates.
+    // leaveArray
+
+    if (!KF_FOR_EACH(handleArrayType, Field, ArrayTy))
+      return;
+
+    const ConstantArrayType *CAT = Ctx.getAsConstantArrayType(ArrayTy);
+    assert(CAT && "Should only be called on constant-size array.");
+    QualType ET = CAT->getElementType();
+    uint64_t ElemCount = CAT->getSize().getZExtValue();
+
+    (void)std::initializer_list<int>{
+        (Handlers.enterArray(Field, ArrayTy, ET), 0)...};
+
+    for (uint64_t Index = 0; Index < ElemCount; ++Index)
+      visitNthArrayElement(Owner, Field, ET, Index, Handlers...);
+
+    (void)std::initializer_list<int>{
+        (Handlers.leaveArray(Field, ArrayTy, ET), 0)...};
+  }
+
+  template <typename... HandlerTys>
+  void visitField(const CXXRecordDecl *Owner, FieldDecl *Field,
+                  QualType FieldTy, HandlerTys &...Handlers) {
+    if (FieldTy->isStructureOrClassType()) {
+      if (KF_FOR_EACH(handleStructType, Field, FieldTy)) {
+        CXXRecordDecl *RD = FieldTy->getAsCXXRecordDecl();
+        visitRecord(Owner, Field, RD, FieldTy, Handlers...);
+      }
+    } else if (FieldTy->isUnionType())
+      KF_FOR_EACH(handleUnionType, Field, FieldTy);
+    else if (FieldTy->isReferenceType())
+      KF_FOR_EACH(handleReferenceType, Field, FieldTy);
+    else if (FieldTy->isPointerType())
+      KF_FOR_EACH(handlePointerType, Field, FieldTy);
+    else if (FieldTy->isArrayType())
+      visitArray(Owner, Field, FieldTy, Handlers...);
+    else if (FieldTy->isScalarType() || FieldTy->isVectorType())
+      KF_FOR_EACH(handleScalarType, Field, FieldTy);
+    else
+      KF_FOR_EACH(handleOtherType, Field, FieldTy);
+  }
+
+public:
+  SubobjectVisitor(ASTContext &C) : Ctx(C) {}
+
+  template <typename... HandlerTys>
+  void visitRecordBases(const CXXRecordDecl *KernelFunctor,
+                        HandlerTys &...Handlers) {
+    visitRecordHelper(KernelFunctor, KernelFunctor->bases(), Handlers...);
+  }
+
+  template <typename... HandlerTys>
+  void visitRecordFields(const CXXRecordDecl *Owner, HandlerTys &...Handlers) {
+    for (FieldDecl *Field : Owner->fields())
+      visitField(Owner, Field, Field->getType(), Handlers...);
+  }
+
+#undef KF_FOR_EACH
+};
+
+class SyclKernelFieldHandlerBase {
+public:
+  virtual bool handleStructType(FieldDecl *, QualType) { return true; }
+  virtual bool handleUnionType(FieldDecl *, QualType) { return true; }
+  virtual bool handleReferenceType(FieldDecl *, QualType) { return true; }
+  virtual bool handlePointerType(FieldDecl *, QualType) { return true; }
+  virtual bool handleArrayType(FieldDecl *, QualType) { return true; }
+  virtual bool handleScalarType(FieldDecl *, QualType) { return true; }
+  // Most handlers shouldn't be handling this, just the field checker.
+  virtual bool handleOtherType(FieldDecl *, QualType) { return true; }
+
+  virtual bool enterStruct(const CXXRecordDecl *, FieldDecl *, QualType) {
+    return true;
+  }
+  virtual bool leaveStruct(const CXXRecordDecl *, FieldDecl *, QualType) {
+    return true;
+  }
+  virtual bool enterStruct(const CXXRecordDecl *, const CXXBaseSpecifier &,
+                           QualType) {
+    return true;
+  }
+  virtual bool leaveStruct(const CXXRecordDecl *, const CXXBaseSpecifier &,
+                           QualType) {
+    return true;
+  }
+  // The following are used for stepping th...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/196128
_______________________________________________
cfe-commits mailing list
[email protected]
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to