https://github.com/NewSigma updated 
https://github.com/llvm/llvm-project/pull/200662

>From 65e8bdeb7207df30e0df1f182e4be2067004b516 Mon Sep 17 00:00:00 2001
From: NewSigma <[email protected]>
Date: Mon, 8 Jun 2026 11:00:49 +0800
Subject: [PATCH] [clang][Sema][CUDA] Restrict immediate template resolution to
 host-device functions

---
 clang/docs/ReleaseNotes.rst         |  1 +
 clang/include/clang/Sema/Overload.h | 18 +------
 clang/lib/Sema/SemaOverload.cpp     | 30 ++++++++++--
 clang/test/SemaCUDA/pr200545.cu     | 73 +++++++++++++++++++++++++++++
 4 files changed, 101 insertions(+), 21 deletions(-)
 create mode 100644 clang/test/SemaCUDA/pr200545.cu

diff --git a/clang/docs/ReleaseNotes.rst b/clang/docs/ReleaseNotes.rst
index f97e90634396a..4ac3f6db2c732 100644
--- a/clang/docs/ReleaseNotes.rst
+++ b/clang/docs/ReleaseNotes.rst
@@ -665,6 +665,7 @@ Bug Fixes in This Version
   an array via an element-at-a-time copy loop (#GH192026)
 - Fixed an issue where certain designated initializers would be rejected for constexpr variables. (#GH193373)
 - Fixed a crash when ``#embed`` is used with C++ modules (#GH195350)
+- Fixed a bug where compiling as CUDA (``-x cuda``) made Clang eagerly deduce function template candidates during overload resolution, even outside ``__host__ __device__`` functions and when a non-template candidate was a perfect match. (#GH200545)
 - Fixed an issue where ``__typeof_unqual`` and ``__typeof_unqual__`` were rejected as a declaration specifier in block scope in C++.
 - Fixed crash when checking for overflow for unary operator that can't overflow (#GH170072)
 
diff --git a/clang/include/clang/Sema/Overload.h b/clang/include/clang/Sema/Overload.h
index d42963e325b58..1e412ff6fc9e2 100644
--- a/clang/include/clang/Sema/Overload.h
+++ b/clang/include/clang/Sema/Overload.h
@@ -1353,7 +1353,7 @@ class Sema;
     bool shouldDeferDiags(Sema &S, ArrayRef<Expr *> Args, SourceLocation 
OpLoc);
 
     // Whether the resolution of template candidates should be deferred
-    bool shouldDeferTemplateArgumentDeduction(const LangOptions &Opts) const;
+    bool shouldDeferTemplateArgumentDeduction(const Sema &S) const;
 
     /// Determine when this overload candidate will be new to the
     /// overload set.
@@ -1545,22 +1545,6 @@ class Sema;
   // good candidate as we can get, despite the fact that it takes one less
   // parameter.
   bool shouldEnforceArgLimit(bool PartialOverloading, FunctionDecl *Function);
-
-  inline bool OverloadCandidateSet::shouldDeferTemplateArgumentDeduction(
-      const LangOptions &Opts) const {
-    return
-        // For user defined conversion we need to check against different
-        // combination of CV qualifiers and look at any explicit specifier, so
-        // always deduce template candidates.
-        Kind != CSK_InitByUserDefinedConversion
-        // When doing code completion, we want to see all the
-        // viable candidates.
-        && Kind != CSK_CodeCompletion
-        // CUDA may prefer template candidates even when a non-candidate
-        // is a perfect match
-        && !Opts.CUDA;
-  }
-
 } // namespace clang
 
 #endif // LLVM_CLANG_SEMA_OVERLOAD_H
diff --git a/clang/lib/Sema/SemaOverload.cpp b/clang/lib/Sema/SemaOverload.cpp
index d53fd726e9f0b..0ad938568bce2 100644
--- a/clang/lib/Sema/SemaOverload.cpp
+++ b/clang/lib/Sema/SemaOverload.cpp
@@ -8192,7 +8192,7 @@ void Sema::AddMethodTemplateCandidate(
     return;
 
   if (ExplicitTemplateArgs ||
-      !CandidateSet.shouldDeferTemplateArgumentDeduction(getLangOpts())) {
+      !CandidateSet.shouldDeferTemplateArgumentDeduction(*this)) {
     AddMethodTemplateCandidateImmediately(
         *this, CandidateSet, MethodTmpl, FoundDecl, ActingContext,
         ExplicitTemplateArgs, ObjectType, ObjectClassification, Args,
@@ -8322,7 +8322,7 @@ void Sema::AddTemplateOverloadCandidate(
   bool DependentExplicitSpecifier = hasDependentExplicit(FunctionTemplate);
 
   if (ExplicitTemplateArgs ||
-      !CandidateSet.shouldDeferTemplateArgumentDeduction(getLangOpts()) ||
+      !CandidateSet.shouldDeferTemplateArgumentDeduction(*this) ||
       (isa<CXXConstructorDecl>(FunctionTemplate->getTemplatedDecl()) &&
        DependentExplicitSpecifier)) {
 
@@ -8760,7 +8760,7 @@ void Sema::AddTemplateConversionCandidate(
   if (!CandidateSet.isNewCandidate(FunctionTemplate))
     return;
 
-  if (!CandidateSet.shouldDeferTemplateArgumentDeduction(getLangOpts()) ||
+  if (!CandidateSet.shouldDeferTemplateArgumentDeduction(*this) ||
       CandidateSet.getKind() ==
           OverloadCandidateSet::CSK_InitByUserDefinedConversion ||
       CandidateSet.getKind() == OverloadCandidateSet::CSK_InitByConstructor) {
@@ -11581,7 +11581,7 @@ OverloadingResult OverloadCandidateSet::BestViableFunction(Sema &S,
                                                            SourceLocation Loc,
                                                            iterator &Best) {
 
-  assert((shouldDeferTemplateArgumentDeduction(S.getLangOpts()) ||
+  assert((shouldDeferTemplateArgumentDeduction(S) ||
           DeferredCandidatesCount == 0) &&
          "Unexpected deferred template candidates");
 
@@ -13533,6 +13533,40 @@ void OverloadCandidateSet::NoteCandidates(Sema &S, ArrayRef<Expr *> Args,
   }
 }
 
+bool OverloadCandidateSet::shouldDeferTemplateArgumentDeduction(
+    const Sema &S) const {
+  // For user defined conversion we need to check against different
+  // combination of CV qualifiers and look at any explicit specifier, so
+  // always deduce template candidates.
+  if (Kind == CSK_InitByUserDefinedConversion)
+    return false;
+
+  // When doing code completion, we want to see all the viable
+  // candidates.
+  if (Kind == CSK_CodeCompletion)
+    return false;
+
+  if (!S.getLangOpts().CUDA)
+    return true;
+
+  const FunctionDecl *Caller = S.getCurFunctionDecl(/*AllowLambda=*/true);
+
+  // Outside of any function (e.g. a namespace-scope variable
+  // initializer) there is no caller whose attributes can be inspected,
+  // yet CUDA target-based overload ranking may still apply. Keep the
+  // previous CUDA behavior there and deduce template candidates
+  // immediately.
+  if (!Caller)
+    return false;
+
+  // Overloading based on __host__ and __device__ attributes takes higher
+  // priority than the usual ranking: in HD functions a template candidate
+  // may be preferred even when a non-template candidate would be a perfect
+  // match, so template candidates must be deduced immediately there.
+  return !(Caller->hasAttr<CUDAHostAttr>() &&
+           Caller->hasAttr<CUDADeviceAttr>());
+}
+
 static SourceLocation
 GetLocationForCandidate(const TemplateSpecCandidate *Cand) {
   return Cand->Specialization ? Cand->Specialization->getLocation()
diff --git a/clang/test/SemaCUDA/pr200545.cu b/clang/test/SemaCUDA/pr200545.cu
new file mode 100644
index 0000000000000..97401d1fefccf
--- /dev/null
+++ b/clang/test/SemaCUDA/pr200545.cu
@@ -0,0 +1,97 @@
+// Test when template argument deduction for function template candidates is
+// deferred in CUDA mode.
+//
+// Each namespace declares a DoNotDeduce concept whose static_assert fires as
+// soon as the constraint is checked, i.e. as soon as the constrained template
+// candidate is deduced. When deduction is deferred, a non-template candidate
+// that is a perfect match is selected without deducing the template, so only
+// the nullptr calls are diagnosed. In __host__ __device__ functions template
+// candidates are deduced immediately, so every call is diagnosed.
+//
+// RUN: %clang_cc1 -std=c++20 -fsyntax-only -verify -verify-ignore-unexpected=note %s
+
+#include "Inputs/cuda.h"
+
+namespace h_free_call {
+  template<class T>
+  concept DoNotDeduce = []() {
+    static_assert(sizeof(T) == 0); // #h-free-call-assert
+    return true;
+  }();
+
+  void fn(int) {}
+  void fn(DoNotDeduce auto) {}
+
+  void call() {
+    fn(0); // Perfect match fn(int) is selected; the template is not deduced.
+    fn(nullptr); // expected-error@#h-free-call-assert {{static assertion failed due to requirement 'sizeof(std::nullptr_t) == 0'}}
+  }
+}
+
+namespace h_member_call {
+  template<class T>
+  concept DoNotDeduce = []() {
+    static_assert(sizeof(T) == 0); // #h-member-call-assert
+    return true;
+  }();
+
+  struct A {
+    void operator=(int) {}
+    void operator=(DoNotDeduce auto) {}
+  };
+
+  void call(A a) {
+    a.operator=(0); // Perfect match operator=(int) is selected.
+    a.operator=(nullptr); // expected-error@#h-member-call-assert {{static assertion failed due to requirement 'sizeof(std::nullptr_t) == 0'}}
+  }
+}
+
+namespace d_free_call {
+  template<class T>
+  concept DoNotDeduce = []() {
+    static_assert(sizeof(T) == 0); // #d-free-call-assert
+    return true;
+  }();
+
+  __device__ void fn(int) {}
+  __device__ void fn(DoNotDeduce auto) {}
+
+  __device__ void call() {
+    fn(0); // Device-only caller: deduction is still deferred.
+    fn(nullptr); // expected-error@#d-free-call-assert {{static assertion failed due to requirement 'sizeof(std::nullptr_t) == 0'}}
+  }
+}
+
+namespace hd_free_call {
+  template<class T>
+  concept DoNotDeduce = []() {
+    static_assert(sizeof(T) == 0); // #hd-free-call-assert
+    return true;
+  }();
+
+  __host__ __device__ void fn(int) {}
+  __host__ __device__ void fn(DoNotDeduce auto) {}
+
+  __host__ __device__ void call() {
+    fn(0); // expected-error@#hd-free-call-assert {{static assertion failed due to requirement 'sizeof(int) == 0'}}
+    fn(nullptr); // expected-error@#hd-free-call-assert {{static assertion failed due to requirement 'sizeof(std::nullptr_t) == 0'}}
+  }
+}
+
+namespace hd_member_call {
+  template<class T>
+  concept DoNotDeduce = []() {
+    static_assert(sizeof(T) == 0); // #hd-member-call-assert
+    return true;
+  }();
+
+  struct A {
+    __host__ __device__ void operator=(int) {}
+    __host__ __device__ void operator=(DoNotDeduce auto) {}
+  };
+
+  __host__ __device__ void call(A a) {
+    a.operator=(0); // expected-error@#hd-member-call-assert {{static assertion failed due to requirement 'sizeof(int) == 0'}}
+    a.operator=(nullptr); // expected-error@#hd-member-call-assert {{static assertion failed due to requirement 'sizeof(std::nullptr_t) == 0'}}
+  }
+}

_______________________________________________
cfe-commits mailing list
[email protected]
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to