================
@@ -665,6 +666,99 @@ OutlinedFunctionDecl *BuildSYCLKernelEntryPointOutline(Sema &SemaRef,
   return OFD;
 }
 
+class KernelParamsChecker : public ConstSubobjectVisitor<KernelParamsChecker> {
+  SemaSYCL &SemaSYCLRef;
+  bool IsValid = true;
+  using ObjectAccess =
+      llvm::PointerUnion<const ParmVarDecl *, const CXXBaseSpecifier *,
+                         const FieldDecl *>;
+  SmallVector<ObjectAccess, 4> ObjectAccessPath;
+
+  void emitObjectAccessPathNotes() {
+    for (auto Parent : ObjectAccessPath) {
+      if (auto *FD = Parent.dyn_cast<const FieldDecl *>()) {
+        SemaSYCLRef.Diag(FD->getParent()->getLocation(),
+                         diag::note_within_field_of_type)
+            << FD->getParent();
+      } else if (auto *BS = Parent.dyn_cast<const CXXBaseSpecifier *>()) {
+        CXXRecordDecl *RD = BS->getType()->getAsCXXRecordDecl();
+        assert(RD);
+        SemaSYCLRef.Diag(BS->getBeginLoc(), diag::note_within_base_of_type)
+            << RD;
+      } else {
+        // Nothing to emit for ParmVarDecl since its location just points to
+        // skep-attributed function template.
+        assert(isa<const ParmVarDecl *>(Parent));
+      }
+    }
+  }
+
+public:
+  KernelParamsChecker(SemaSYCL &SR, SourceLocation Loc)
+      : ConstSubobjectVisitor<KernelParamsChecker>(SR.getASTContext()),
+        SemaSYCLRef(SR) {}
+
+  void checkParameter(const ParmVarDecl *PVD) {
+    ObjectAccessPath.push_back(PVD);
+    // Check the immediate type of the parameter.
+    if (checkType(PVD->getType())) {
+      // If type checking wasn't short circuited, visit subobjects to check
+      // them.
+      visit(PVD->getType());
+    }
+    ObjectAccessPath.pop_back();
+    assert(ObjectAccessPath.empty());
+  }
+
+  bool visitBaseSpecifierPre(const CXXBaseSpecifier *BS) {
+    ObjectAccessPath.push_back(BS);
+    return checkType(BS->getType());
+  }
+
+  bool visitFieldDeclPre(const FieldDecl *FD) {
+    ObjectAccessPath.push_back(FD);
+    return checkType(FD->getType());
+  }
+
+  // Returns true if subobjects should be visited and false otherwise.
+  bool checkType(QualType Ty) {
+    if (Ty->isReferenceType()) {
+      auto DirectParent = ObjectAccessPath.back();
+      // Reference cannot be a base, so just assume we came via a FieldDecl.
+      if (isa<const ParmVarDecl *>(DirectParent)) {
+        // If reference is a kernel parameter, there is nothing to do. We allow
+        // references in direct kernel parameters for better performance of the
+        // host code and we eliminate them when building actual kernel.
+        return true;
+      }
+
+      auto *DirectFieldParent = cast<const FieldDecl *>(DirectParent);
+      SemaSYCLRef.Diag(DirectFieldParent->getLocation(),
+                       diag::err_bad_kernel_param_type)
+          << DirectFieldParent->getType();
+      emitObjectAccessPathNotes();
+
+      IsValid = false;
+      return false;
----------------
tahonermann wrote:

I think skipping further visitation is the right choice here, but perhaps that 
is worth a comment.

```suggestion
      // Don't visit the type of the reference since any further invalid
      // kernel parameter types contained within the referenced type
      // might not be relevant once the programmer addresses the
      // invalid use of a reference.
      IsValid = false;
      return false;
```

https://github.com/llvm/llvm-project/pull/192957
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to