tra updated this revision to Diff 75482.
tra added a comment.

Added a comment explaining the expected constexpr function template matching
behavior.


https://reviews.llvm.org/D25845

Files:
  include/clang/Sema/Sema.h
  lib/Sema/SemaCUDA.cpp
  lib/Sema/SemaDeclAttr.cpp
  lib/Sema/SemaTemplate.cpp
  test/SemaCUDA/function-template-overload.cu

Index: test/SemaCUDA/function-template-overload.cu
===================================================================
--- test/SemaCUDA/function-template-overload.cu
+++ test/SemaCUDA/function-template-overload.cu
@@ -56,24 +56,54 @@
 template <typename T> __host__ __device__ HDType overload_h_d2(T a) { return HDType(); }
 template <typename T1, typename T2 = int> __device__ DType overload_h_d2(T1 a) { T1 x; T2 y; return DType(); }
 
+// constexpr functions are implicitly HD, but explicit
+// instantiation/specialization must use target attributes as written.
+template <typename T> constexpr T overload_ce_implicit_hd(T a) { return a+1; }
+// expected-note@-1 3 {{candidate template ignored: target attributes do not match}}
+
+// These will not match the template.
+template __host__ __device__ int overload_ce_implicit_hd(int a);
+// expected-error@-1 {{explicit instantiation of 'overload_ce_implicit_hd' does not refer to a function template, variable template, member function, member class, or static data member}}
+template <> __host__ __device__ long overload_ce_implicit_hd(long a);
+// expected-error@-1 {{no function template matches function template specialization 'overload_ce_implicit_hd'}}
+template <> __host__ __device__ constexpr long overload_ce_implicit_hd(long a);
+// expected-error@-1 {{no function template matches function template specialization 'overload_ce_implicit_hd'}}
+
+// These should work because template matching ignores the implicit HD
+// attributes the compiler gives to constexpr functions/templates, so the
+// 'overload_ce_implicit_hd' template will match __host__ functions
+// only.
+template __host__ int overload_ce_implicit_hd(int a);
+template <> __host__ long overload_ce_implicit_hd(long a);
+
+template float overload_ce_implicit_hd(float a);
+template <> float* overload_ce_implicit_hd(float *a);
+template <> constexpr double overload_ce_implicit_hd(double a) { return a + 3.0; };
+
 __host__ void hf() {
   overload_hd(13);
+  overload_ce_implicit_hd('h');        // Implicitly instantiated
+  overload_ce_implicit_hd(1.0f);       // Explicitly instantiated
+  overload_ce_implicit_hd(2.0);        // Explicitly specialized
 
   HType h = overload_h_d(10);
   HType h2i = overload_h_d2<int>(11);
   HType h2ii = overload_h_d2<int>(12);
 
   // These should be implicitly instantiated from __host__ template returning HType.
-  DType d = overload_h_d(20); // expected-error {{no viable conversion from 'HType' to 'DType'}}
-  DType d2i = overload_h_d2<int>(21); // expected-error {{no viable conversion from 'HType' to 'DType'}}
+  DType d = overload_h_d(20);          // expected-error {{no viable conversion from 'HType' to 'DType'}}
+  DType d2i = overload_h_d2<int>(21);  // expected-error {{no viable conversion from 'HType' to 'DType'}}
   DType d2ii = overload_h_d2<int>(22); // expected-error {{no viable conversion from 'HType' to 'DType'}}
 }
 __device__ void df() {
   overload_hd(23);
+  overload_ce_implicit_hd('d');        // Implicitly instantiated
+  overload_ce_implicit_hd(1.0f);       // Explicitly instantiated
+  overload_ce_implicit_hd(2.0);        // Explicitly specialized
 
   // These should be implicitly instantiated from __device__ template returning DType.
-  HType h = overload_h_d(10); // expected-error {{no viable conversion from 'DType' to 'HType'}}
-  HType h2i = overload_h_d2<int>(11); // expected-error {{no viable conversion from 'DType' to 'HType'}}
+  HType h = overload_h_d(10);          // expected-error {{no viable conversion from 'DType' to 'HType'}}
+  HType h2i = overload_h_d2<int>(11);  // expected-error {{no viable conversion from 'DType' to 'HType'}}
   HType h2ii = overload_h_d2<int>(12); // expected-error {{no viable conversion from 'DType' to 'HType'}}
 
   DType d = overload_h_d(20);
Index: lib/Sema/SemaTemplate.cpp
===================================================================
--- lib/Sema/SemaTemplate.cpp
+++ lib/Sema/SemaTemplate.cpp
@@ -7043,13 +7043,13 @@
 
       // Filter out matches that have different target.
       if (LangOpts.CUDA &&
-          IdentifyCUDATarget(Specialization) != IdentifyCUDATarget(FD)) {
+          IdentifyCUDATarget(Specialization, true) !=
+              IdentifyCUDATarget(FD, true)) {
         FailedCandidates.addCandidate().set(
             I.getPair(), FunTmpl->getTemplatedDecl(),
             MakeDeductionFailureInfo(Context, TDK_CUDATargetMismatch, Info));
         continue;
       }
-
       // Record this candidate.
       if (ExplicitTemplateArgs)
         ConvertedTemplateArgs[Specialization] = std::move(Args);
@@ -7164,6 +7164,8 @@
   // the prior function template specialization.
   Previous.clear();
   Previous.addDecl(Specialization);
+  if (LangOpts.CUDA)
+    mergeCUDATargetAttributes(FD, Specialization);
   return false;
 }
 
@@ -8114,7 +8116,7 @@
 
     // Filter out matches that have different target.
     if (LangOpts.CUDA &&
-        IdentifyCUDATarget(Specialization) != IdentifyCUDATarget(Attr)) {
+        IdentifyCUDATarget(Specialization, true) != IdentifyCUDATarget(Attr)) {
       FailedCandidates.addCandidate().set(
           P.getPair(), FunTmpl->getTemplatedDecl(),
           MakeDeductionFailureInfo(Context, TDK_CUDATargetMismatch, Info));
Index: lib/Sema/SemaDeclAttr.cpp
===================================================================
--- lib/Sema/SemaDeclAttr.cpp
+++ lib/Sema/SemaDeclAttr.cpp
@@ -5616,15 +5616,18 @@
     handleFormatArgAttr(S, D, Attr);
     break;
   case AttributeList::AT_CUDAGlobal:
-    handleGlobalAttr(S, D, Attr);
+    if (!D->hasAttr<CUDAGlobalAttr>())
+      handleGlobalAttr(S, D, Attr);
     break;
   case AttributeList::AT_CUDADevice:
-    handleSimpleAttributeWithExclusions<CUDADeviceAttr, CUDAGlobalAttr>(S, D,
-                                                                        Attr);
+    if (!D->hasAttr<CUDADeviceAttr>())
+      handleSimpleAttributeWithExclusions<CUDADeviceAttr, CUDAGlobalAttr>(S, D,
+                                                                          Attr);
     break;
   case AttributeList::AT_CUDAHost:
-    handleSimpleAttributeWithExclusions<CUDAHostAttr, CUDAGlobalAttr>(S, D,
-                                                                      Attr);
+    if (!D->hasAttr<CUDAHostAttr>())
+      handleSimpleAttributeWithExclusions<CUDAHostAttr, CUDAGlobalAttr>(S, D,
+                                                                        Attr);
     break;
   case AttributeList::AT_GNUInline:
     handleGNUInlineAttr(S, D, Attr);
Index: lib/Sema/SemaCUDA.cpp
===================================================================
--- lib/Sema/SemaCUDA.cpp
+++ lib/Sema/SemaCUDA.cpp
@@ -84,17 +84,28 @@
   if (HasGlobalAttr)
     return CFT_Global;
 
-  if ((HasHostAttr && HasDeviceAttr) || ForceCUDAHostDeviceDepth > 0)
+  if (HasHostAttr && HasDeviceAttr)
     return CFT_HostDevice;
 
   if (HasDeviceAttr)
     return CFT_Device;
 
   return CFT_Host;
 }
 
+template <typename A>
+static bool getAttr(const FunctionDecl *D, bool IgnoreImplicitAttr) {
+  if (Attr *Attribute = D->getAttr<A>()) {
+    if (IgnoreImplicitAttr && Attribute->isImplicit())
+      return false;
+    return true;
+  }
+  return false;
+}
+
 /// IdentifyCUDATarget - Determine the CUDA compilation target for this function
-Sema::CUDAFunctionTarget Sema::IdentifyCUDATarget(const FunctionDecl *D) {
+Sema::CUDAFunctionTarget Sema::IdentifyCUDATarget(const FunctionDecl *D,
+                                                  bool IgnoreImplicitHDAttr) {
   // Code that lives outside a function is run on the host.
   if (D == nullptr)
     return CFT_Host;
@@ -105,13 +116,13 @@
   if (D->hasAttr<CUDAGlobalAttr>())
     return CFT_Global;
 
-  if (D->hasAttr<CUDADeviceAttr>()) {
-    if (D->hasAttr<CUDAHostAttr>())
+  if (getAttr<CUDADeviceAttr>(D, IgnoreImplicitHDAttr)) {
+    if (getAttr<CUDAHostAttr>(D, IgnoreImplicitHDAttr))
       return CFT_HostDevice;
     return CFT_Device;
-  } else if (D->hasAttr<CUDAHostAttr>()) {
+  } else if (getAttr<CUDAHostAttr>(D, IgnoreImplicitHDAttr)) {
     return CFT_Host;
-  } else if (D->isImplicit()) {
+  } else if (D->isImplicit() && !IgnoreImplicitHDAttr) {
     // Some implicit declarations (like intrinsic functions) are not marked.
     // Set the most lenient target on them for maximal flexibility.
     return CFT_HostDevice;
@@ -523,8 +534,10 @@
     return;
   }
 
-  NewD->addAttr(CUDAHostAttr::CreateImplicit(Context));
-  NewD->addAttr(CUDADeviceAttr::CreateImplicit(Context));
+  if (!NewD->hasAttr<CUDAHostAttr>())
+    NewD->addAttr(CUDAHostAttr::CreateImplicit(Context));
+  if (!NewD->hasAttr<CUDADeviceAttr>())
+    NewD->addAttr(CUDADeviceAttr::CreateImplicit(Context));
 }
 
 // In CUDA, there are some constructs which may appear in semantically-valid
@@ -867,3 +880,21 @@
     }
   }
 }
+
+void Sema::mergeCUDATargetAttributes(NamedDecl *New, Decl *Old) {
+  if (auto *OldAttr = Old->getMostRecentDecl()->getAttr<CUDADeviceAttr>()) {
+    auto *NewAttr = OldAttr->clone(Context);
+    NewAttr->setInherited(true);
+    New->addAttr(NewAttr);
+  }
+  if (auto *OldAttr = Old->getMostRecentDecl()->getAttr<CUDAHostAttr>()) {
+    auto *NewAttr = OldAttr->clone(Context);
+    NewAttr->setInherited(true);
+    New->addAttr(NewAttr);
+  }
+  if (auto *OldAttr = Old->getMostRecentDecl()->getAttr<CUDAGlobalAttr>()) {
+    auto *NewAttr = OldAttr->clone(Context);
+    NewAttr->setInherited(true);
+    New->addAttr(NewAttr);
+  }
+}
Index: include/clang/Sema/Sema.h
===================================================================
--- include/clang/Sema/Sema.h
+++ include/clang/Sema/Sema.h
@@ -9392,7 +9392,8 @@
   ///
   /// Use this rather than examining the function's attributes yourself -- you
   /// will get it wrong.  Returns CFT_Host if D is null.
-  CUDAFunctionTarget IdentifyCUDATarget(const FunctionDecl *D);
+  CUDAFunctionTarget IdentifyCUDATarget(const FunctionDecl *D,
+                                        bool IgnoreImplicitHDAttr = false);
   CUDAFunctionTarget IdentifyCUDATarget(const AttributeList *Attr);
 
   /// Gets the CUDA target for the current context.
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
http://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to