tra created this revision.
tra added reviewers: jlebar, rsmith.
tra added a subscriber: cfe-commits.

Some functions and function templates are treated as `__host__ __device__`
even when they have no explicitly specified target attributes.
Worse, this treatment may change depending on command-line options
(`-fno-cuda-host-device-constexpr`) or on `#pragma clang force_cuda_host_device`.

Combined with the strict function-target matching introduced by
https://reviews.llvm.org/D25809, this makes it hard to write explicit
instantiations or specializations of such functions that work regardless of
which pragmas or command-line options are in effect.

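This patch changes how the target attributes of the primary template are
matched against those written on an explicit instantiation or specialization,
so that only explicitly specified attributes are considered; implicit
`__host__ __device__` attributes (such as the ones `constexpr` functions
receive) are ignored during matching.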


https://reviews.llvm.org/D25845

Files:
  include/clang/Sema/Sema.h
  lib/Sema/SemaCUDA.cpp
  lib/Sema/SemaDeclAttr.cpp
  lib/Sema/SemaTemplate.cpp
  test/SemaCUDA/function-template-overload.cu

Index: test/SemaCUDA/function-template-overload.cu
===================================================================
--- test/SemaCUDA/function-template-overload.cu
+++ test/SemaCUDA/function-template-overload.cu
@@ -56,24 +56,51 @@
 template <typename T> __host__ __device__ HDType overload_h_d2(T a) { return HDType(); }
 template <typename T1, typename T2 = int> __device__ DType overload_h_d2(T1 a) { T1 x; T2 y; return DType(); }
 
+// constexpr functions are implicitly HD, but explicit
+// instantiation/specialization must use target attributes as written.
+template <typename T> constexpr T overload_ce_implicit_hd(T a) { return a+1; }
+// expected-note@-1 3 {{candidate template ignored: target attributes do not match}}
+
+// These will not match the template.
+template __host__ __device__ int overload_ce_implicit_hd(int a);
+// expected-error@-1 {{explicit instantiation of 'overload_ce_implicit_hd' does not refer to a function template, variable template, member function, member class, or static data member}}
+template <> __host__ __device__ long overload_ce_implicit_hd(long a);
+// expected-error@-1 {{no function template matches function template specialization 'overload_ce_implicit_hd'}}
+template <> __host__ __device__ constexpr long overload_ce_implicit_hd(long a);
+// expected-error@-1 {{no function template matches function template specialization 'overload_ce_implicit_hd'}}
+
+// These should work.
+template __host__ int overload_ce_implicit_hd(int a);
+template <> __host__ long overload_ce_implicit_hd(long a);
+
+template float overload_ce_implicit_hd(float a);
+template <> float* overload_ce_implicit_hd(float *a);
+template <> constexpr double overload_ce_implicit_hd(double a) { return a + 3.0; };
+
 __host__ void hf() {
   overload_hd(13);
+  overload_ce_implicit_hd('h');        // Implicitly instantiated
+  overload_ce_implicit_hd(1.0f);       // Explicitly instantiated
+  overload_ce_implicit_hd(2.0);        // Explicitly specialized
 
   HType h = overload_h_d(10);
   HType h2i = overload_h_d2<int>(11);
   HType h2ii = overload_h_d2<int>(12);
 
   // These should be implicitly instantiated from __host__ template returning HType.
-  DType d = overload_h_d(20); // expected-error {{no viable conversion from 'HType' to 'DType'}}
-  DType d2i = overload_h_d2<int>(21); // expected-error {{no viable conversion from 'HType' to 'DType'}}
+  DType d = overload_h_d(20);          // expected-error {{no viable conversion from 'HType' to 'DType'}}
+  DType d2i = overload_h_d2<int>(21);  // expected-error {{no viable conversion from 'HType' to 'DType'}}
   DType d2ii = overload_h_d2<int>(22); // expected-error {{no viable conversion from 'HType' to 'DType'}}
 }
 __device__ void df() {
   overload_hd(23);
+  overload_ce_implicit_hd('d');        // Implicitly instantiated
+  overload_ce_implicit_hd(1.0f);       // Explicitly instantiated
+  overload_ce_implicit_hd(2.0);        // Explicitly specialized
 
   // These should be implicitly instantiated from __device__ template returning DType.
-  HType h = overload_h_d(10); // expected-error {{no viable conversion from 'DType' to 'HType'}}
-  HType h2i = overload_h_d2<int>(11); // expected-error {{no viable conversion from 'DType' to 'HType'}}
+  HType h = overload_h_d(10);          // expected-error {{no viable conversion from 'DType' to 'HType'}}
+  HType h2i = overload_h_d2<int>(11);  // expected-error {{no viable conversion from 'DType' to 'HType'}}
   HType h2ii = overload_h_d2<int>(12); // expected-error {{no viable conversion from 'DType' to 'HType'}}
 
   DType d = overload_h_d(20);
Index: lib/Sema/SemaTemplate.cpp
===================================================================
--- lib/Sema/SemaTemplate.cpp
+++ lib/Sema/SemaTemplate.cpp
@@ -7043,13 +7043,13 @@
 
       // Filter out matches that have different target.
       if (LangOpts.CUDA &&
-          IdentifyCUDATarget(Specialization) != IdentifyCUDATarget(FD)) {
+          IdentifyCUDATarget(Specialization, true) !=
+              IdentifyCUDATarget(FD, true)) {
         FailedCandidates.addCandidate().set(
             I.getPair(), FunTmpl->getTemplatedDecl(),
             MakeDeductionFailureInfo(Context, TDK_CUDATargetMismatch, Info));
         continue;
       }
-
       // Record this candidate.
       if (ExplicitTemplateArgs)
         ConvertedTemplateArgs[Specialization] = std::move(Args);
@@ -7164,6 +7164,8 @@
   // the prior function template specialization.
   Previous.clear();
   Previous.addDecl(Specialization);
+  if (LangOpts.CUDA)
+    mergeCUDATargetAttributes(FD, Specialization);
   return false;
 }
 
@@ -8114,7 +8116,7 @@
 
     // Filter out matches that have different target.
     if (LangOpts.CUDA &&
-        IdentifyCUDATarget(Specialization) != IdentifyCUDATarget(Attr)) {
+        IdentifyCUDATarget(Specialization, true) != IdentifyCUDATarget(Attr)) {
       FailedCandidates.addCandidate().set(
           P.getPair(), FunTmpl->getTemplatedDecl(),
           MakeDeductionFailureInfo(Context, TDK_CUDATargetMismatch, Info));
Index: lib/Sema/SemaDeclAttr.cpp
===================================================================
--- lib/Sema/SemaDeclAttr.cpp
+++ lib/Sema/SemaDeclAttr.cpp
@@ -5616,15 +5616,18 @@
     handleFormatArgAttr(S, D, Attr);
     break;
   case AttributeList::AT_CUDAGlobal:
-    handleGlobalAttr(S, D, Attr);
+    if (!D->hasAttr<CUDAGlobalAttr>())
+      handleGlobalAttr(S, D, Attr);
     break;
   case AttributeList::AT_CUDADevice:
-    handleSimpleAttributeWithExclusions<CUDADeviceAttr, CUDAGlobalAttr>(S, D,
-                                                                        Attr);
+    if (!D->hasAttr<CUDADeviceAttr>())
+      handleSimpleAttributeWithExclusions<CUDADeviceAttr, CUDAGlobalAttr>(S, D,
+                                                                          Attr);
     break;
   case AttributeList::AT_CUDAHost:
-    handleSimpleAttributeWithExclusions<CUDAHostAttr, CUDAGlobalAttr>(S, D,
-                                                                      Attr);
+    if (!D->hasAttr<CUDAHostAttr>())
+      handleSimpleAttributeWithExclusions<CUDAHostAttr, CUDAGlobalAttr>(S, D,
+                                                                        Attr);
     break;
   case AttributeList::AT_GNUInline:
     handleGNUInlineAttr(S, D, Attr);
Index: lib/Sema/SemaCUDA.cpp
===================================================================
--- lib/Sema/SemaCUDA.cpp
+++ lib/Sema/SemaCUDA.cpp
@@ -84,17 +84,28 @@
   if (HasGlobalAttr)
     return CFT_Global;
 
-  if ((HasHostAttr && HasDeviceAttr) || ForceCUDAHostDeviceDepth > 0)
+  if (HasHostAttr && HasDeviceAttr)
     return CFT_HostDevice;
 
   if (HasDeviceAttr)
     return CFT_Device;
 
   return CFT_Host;
 }
 
+template <typename A>
+static bool getAttr(const FunctionDecl *D, bool IgnoreImplicitAttr) {
+  if (Attr *Attribute = D->getAttr<A>()) {
+    if (IgnoreImplicitAttr && Attribute->isImplicit())
+      return false;
+    return true;
+  }
+  return false;
+}
+
 /// IdentifyCUDATarget - Determine the CUDA compilation target for this function
-Sema::CUDAFunctionTarget Sema::IdentifyCUDATarget(const FunctionDecl *D) {
+Sema::CUDAFunctionTarget Sema::IdentifyCUDATarget(const FunctionDecl *D,
+                                                  bool IgnoreImplicitHDAttr) {
   // Code that lives outside a function is run on the host.
   if (D == nullptr)
     return CFT_Host;
@@ -105,13 +116,13 @@
   if (D->hasAttr<CUDAGlobalAttr>())
     return CFT_Global;
 
-  if (D->hasAttr<CUDADeviceAttr>()) {
-    if (D->hasAttr<CUDAHostAttr>())
+  if (getAttr<CUDADeviceAttr>(D, IgnoreImplicitHDAttr)) {
+    if (getAttr<CUDAHostAttr>(D, IgnoreImplicitHDAttr))
       return CFT_HostDevice;
     return CFT_Device;
-  } else if (D->hasAttr<CUDAHostAttr>()) {
+  } else if (getAttr<CUDAHostAttr>(D, IgnoreImplicitHDAttr)) {
     return CFT_Host;
-  } else if (D->isImplicit()) {
+  } else if (D->isImplicit() && !IgnoreImplicitHDAttr) {
     // Some implicit declarations (like intrinsic functions) are not marked.
     // Set the most lenient target on them for maximal flexibility.
     return CFT_HostDevice;
@@ -523,8 +534,10 @@
     return;
   }
 
-  NewD->addAttr(CUDAHostAttr::CreateImplicit(Context));
-  NewD->addAttr(CUDADeviceAttr::CreateImplicit(Context));
+  if (!NewD->hasAttr<CUDAHostAttr>())
+    NewD->addAttr(CUDAHostAttr::CreateImplicit(Context));
+  if (!NewD->hasAttr<CUDADeviceAttr>())
+    NewD->addAttr(CUDADeviceAttr::CreateImplicit(Context));
 }
 
 // In CUDA, there are some constructs which may appear in semantically-valid
@@ -867,3 +880,21 @@
     }
   }
 }
+
+void Sema::mergeCUDATargetAttributes(NamedDecl *New, Decl *Old) {
+  if (auto *OldAttr = Old->getMostRecentDecl()->getAttr<CUDADeviceAttr>()) {
+    auto *NewAttr = OldAttr->clone(Context);
+    NewAttr->setInherited(true);
+    New->addAttr(NewAttr);
+  }
+  if (auto *OldAttr = Old->getMostRecentDecl()->getAttr<CUDAHostAttr>()) {
+    auto *NewAttr = OldAttr->clone(Context);
+    NewAttr->setInherited(true);
+    New->addAttr(NewAttr);
+  }
+  if (auto *OldAttr = Old->getMostRecentDecl()->getAttr<CUDAGlobalAttr>()) {
+    auto *NewAttr = OldAttr->clone(Context);
+    NewAttr->setInherited(true);
+    New->addAttr(NewAttr);
+  }
+}
Index: include/clang/Sema/Sema.h
===================================================================
--- include/clang/Sema/Sema.h
+++ include/clang/Sema/Sema.h
@@ -9403,7 +9403,8 @@
   ///
   /// Use this rather than examining the function's attributes yourself -- you
   /// will get it wrong.  Returns CFT_Host if D is null.
-  CUDAFunctionTarget IdentifyCUDATarget(const FunctionDecl *D);
+  CUDAFunctionTarget IdentifyCUDATarget(const FunctionDecl *D,
+                                        bool IgnoreImplicitHDAttr = false);
   CUDAFunctionTarget IdentifyCUDATarget(const AttributeList *Attr);
 
   /// Gets the CUDA target for the current context.
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
http://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to