tra updated this revision to Diff 75812.
tra added a comment.

- Instead of relying on the first attribute we find, check all matching ones.
- Specializations inherit their target attributes from their base template
only. Their effective target always matches that of the template and is no
longer affected by whether the specialization differs from the template in
its constexpr-ness.


https://reviews.llvm.org/D25845

Files:
  include/clang/Sema/Sema.h
  lib/Sema/SemaCUDA.cpp
  lib/Sema/SemaDecl.cpp
  lib/Sema/SemaTemplate.cpp
  test/SemaCUDA/function-template-overload.cu

Index: test/SemaCUDA/function-template-overload.cu
===================================================================
--- test/SemaCUDA/function-template-overload.cu
+++ test/SemaCUDA/function-template-overload.cu
@@ -31,7 +31,8 @@
 template  <> __host__ HType overload_h_d(long a); // OK. instantiates H
 
 
-// Can't overload HD template with H or D template, though functions are OK.
+// Can't overload HD template with H or D template, though
+// non-template functions are OK.
 template <typename T> __host__ __device__ HDType overload_hd(T a) { return HDType(); }
 // expected-note@-1 {{previous declaration is here}}
 // expected-note@-2 2 {{candidate template ignored: could not match 'HDType' against 'HType'}}
@@ -56,24 +57,54 @@
 template <typename T> __host__ __device__ HDType overload_h_d2(T a) { return HDType(); }
 template <typename T1, typename T2 = int> __device__ DType overload_h_d2(T1 a) { T1 x; T2 y; return DType(); }
 
+// constexpr functions are implicitly HD, but explicit
+// instantiation/specialization must use target attributes as written.
+template <typename T> constexpr T overload_ce_implicit_hd(T a) { return a+1; }
+// expected-note@-1 3 {{candidate template ignored: target attributes do not match}}
+
+// These will not match the template.
+template __host__ __device__ int overload_ce_implicit_hd(int a);
+// expected-error@-1 {{explicit instantiation of 'overload_ce_implicit_hd' does not refer to a function template, variable template, member function, member class, or static data member}}
+template <> __host__ __device__ long overload_ce_implicit_hd(long a);
+// expected-error@-1 {{no function template matches function template specialization 'overload_ce_implicit_hd'}}
+template <> __host__ __device__ constexpr long overload_ce_implicit_hd(long a);
+// expected-error@-1 {{no function template matches function template specialization 'overload_ce_implicit_hd'}}
+
+// These should work, because template matching ignores implicit HD
+// attributes compiler gives to constexpr functions/templates so
+// 'overload_ce_implicit_hd' template will match __host__ functions
+// only.
+template __host__ int overload_ce_implicit_hd(int a);
+template <> __host__ long overload_ce_implicit_hd(long a);
+
+template float overload_ce_implicit_hd(float a);
+template <> float* overload_ce_implicit_hd(float *a);
+template <> constexpr double overload_ce_implicit_hd(double a) { return a + 3.0; };
+
 __host__ void hf() {
   overload_hd(13);
+  overload_ce_implicit_hd('h');        // Implicitly instantiated
+  overload_ce_implicit_hd(1.0f);       // Explicitly instantiated
+  overload_ce_implicit_hd(2.0);        // Explicitly specialized
 
   HType h = overload_h_d(10);
   HType h2i = overload_h_d2<int>(11);
   HType h2ii = overload_h_d2<int>(12);
 
   // These should be implicitly instantiated from __host__ template returning HType.
-  DType d = overload_h_d(20); // expected-error {{no viable conversion from 'HType' to 'DType'}}
-  DType d2i = overload_h_d2<int>(21); // expected-error {{no viable conversion from 'HType' to 'DType'}}
+  DType d = overload_h_d(20);          // expected-error {{no viable conversion from 'HType' to 'DType'}}
+  DType d2i = overload_h_d2<int>(21);  // expected-error {{no viable conversion from 'HType' to 'DType'}}
   DType d2ii = overload_h_d2<int>(22); // expected-error {{no viable conversion from 'HType' to 'DType'}}
 }
 __device__ void df() {
   overload_hd(23);
+  overload_ce_implicit_hd('d');        // Implicitly instantiated
+  overload_ce_implicit_hd(1.0f);       // Explicitly instantiated
+  overload_ce_implicit_hd(2.0);        // Explicitly specialized
 
   // These should be implicitly instantiated from __device__ template returning DType.
-  HType h = overload_h_d(10); // expected-error {{no viable conversion from 'DType' to 'HType'}}
-  HType h2i = overload_h_d2<int>(11); // expected-error {{no viable conversion from 'DType' to 'HType'}}
+  HType h = overload_h_d(10);          // expected-error {{no viable conversion from 'DType' to 'HType'}}
+  HType h2i = overload_h_d2<int>(11);  // expected-error {{no viable conversion from 'DType' to 'HType'}}
   HType h2ii = overload_h_d2<int>(12); // expected-error {{no viable conversion from 'DType' to 'HType'}}
 
   DType d = overload_h_d(20);
Index: lib/Sema/SemaTemplate.cpp
===================================================================
--- lib/Sema/SemaTemplate.cpp
+++ lib/Sema/SemaTemplate.cpp
@@ -7047,7 +7047,9 @@
       // target attributes into account, we perform target match check
       // here and reject candidates that have different target.
       if (LangOpts.CUDA &&
-          IdentifyCUDATarget(Specialization) != IdentifyCUDATarget(FD)) {
+          IdentifyCUDATarget(Specialization,
+                             /* IgnoreImplicitHDAttributes = */ true) !=
+              IdentifyCUDATarget(FD, /* IgnoreImplicitHDAttributes = */ true)) {
         FailedCandidates.addCandidate().set(
             I.getPair(), FunTmpl->getTemplatedDecl(),
             MakeDeductionFailureInfo(Context, TDK_CUDATargetMismatch, Info));
@@ -7164,6 +7166,12 @@
       SpecInfo->getTemplateSpecializationKind(),
       ExplicitTemplateArgs ? &ConvertedTemplateArgs[Specialization] : nullptr);
 
+  // Template may have implicit target attributes that specialization
+  // must inherit in order to have the same effective target as its
+  // template.
+  if (LangOpts.CUDA)
+    inheritCUDATargetAttrs(FD, Specialization);
+
   // The "previous declaration" for this function template specialization is
   // the prior function template specialization.
   Previous.clear();
@@ -8121,19 +8129,14 @@
     // target. Given that regular template deduction does not take it
     // into account, we perform target match check here and reject
     // candidates that have different target.
-    if (LangOpts.CUDA) {
-      CUDAFunctionTarget DeclaratorTarget = IdentifyCUDATarget(Attr);
-      // We need to adjust target when HD is forced by
-      // #pragma clang force_cuda_host_device
-      if (ForceCUDAHostDeviceDepth > 0 &&
-          (DeclaratorTarget == CFT_Device || DeclaratorTarget == CFT_Host))
-        DeclaratorTarget = CFT_HostDevice;
-      if (IdentifyCUDATarget(Specialization) != DeclaratorTarget) {
-        FailedCandidates.addCandidate().set(
-            P.getPair(), FunTmpl->getTemplatedDecl(),
-            MakeDeductionFailureInfo(Context, TDK_CUDATargetMismatch, Info));
-        continue;
-      }
+    if (LangOpts.CUDA &&
+        IdentifyCUDATarget(Specialization,
+                           /* IgnoreImplicitHDAttributes = */ true) !=
+            IdentifyCUDATarget(Attr)) {
+      FailedCandidates.addCandidate().set(
+          P.getPair(), FunTmpl->getTemplatedDecl(),
+          MakeDeductionFailureInfo(Context, TDK_CUDATargetMismatch, Info));
+      continue;
     }
 
     Matches.addDecl(Specialization, P.getAccess());
Index: lib/Sema/SemaDecl.cpp
===================================================================
--- lib/Sema/SemaDecl.cpp
+++ lib/Sema/SemaDecl.cpp
@@ -8263,9 +8263,6 @@
   // Handle attributes.
   ProcessDeclAttributes(S, NewFD, D);
 
-  if (getLangOpts().CUDA)
-    maybeAddCUDAHostDeviceAttrs(NewFD, Previous);
-
   if (getLangOpts().OpenCL) {
     // OpenCL v1.1 s6.5: Using an address space qualifier in a function return
     // type declaration will generate a compilation error.
@@ -8368,6 +8365,15 @@
       TemplateArgs.setRAngleLoc(D.getIdentifierLoc());
     }
 
+    // We do not add HD attributes to specializations here because
+    // they may have different constexpr-ness compared to their
+    // templates and, after maybeAddCUDAHostDeviceAttrs() is applied,
+    // may end up with different effective targets. Instead,
+    // specializations inherit target attributes from template in
+    // CheckFunctionTemplateSpecialization() call below.
+    if (getLangOpts().CUDA & !isFunctionTemplateSpecialization)
+      maybeAddCUDAHostDeviceAttrs(NewFD, Previous);
+
     // If it's a friend (and only if it's a friend), it's possible
     // that either the specialized function type or the specialized
     // template is dependent, and therefore matching will fail.  In
Index: lib/Sema/SemaCUDA.cpp
===================================================================
--- lib/Sema/SemaCUDA.cpp
+++ lib/Sema/SemaCUDA.cpp
@@ -93,8 +93,23 @@
   return CFT_Host;
 }
 
+template <typename A> // A is the attribute class to search for on D.
+static bool hasAttr(const FunctionDecl *D, bool IgnoreImplicitAttr) {
+  if (!D->hasAttrs()) // D carries no attributes, so nothing can match.
+    return false;
+  for (Attr *Attribute : D->getAttrs()) {
+    if (!isa<A>(Attribute)) // Some other attribute kind; keep scanning.
+      continue;
+    if (IgnoreImplicitAttr && Attribute->isImplicit()) // Skip implicit ones.
+      continue;
+    return true; // An earlier implicit A does not hide a later explicit one.
+  }
+  return false; // No attribute of kind A passed the implicit-ness filter.
+}
+
 /// IdentifyCUDATarget - Determine the CUDA compilation target for this function
-Sema::CUDAFunctionTarget Sema::IdentifyCUDATarget(const FunctionDecl *D) {
+Sema::CUDAFunctionTarget Sema::IdentifyCUDATarget(const FunctionDecl *D,
+                                                  bool IgnoreImplicitHDAttr) {
   // Code that lives outside a function is run on the host.
   if (D == nullptr)
     return CFT_Host;
@@ -105,13 +120,13 @@
   if (D->hasAttr<CUDAGlobalAttr>())
     return CFT_Global;
 
-  if (D->hasAttr<CUDADeviceAttr>()) {
-    if (D->hasAttr<CUDAHostAttr>())
+  if (hasAttr<CUDADeviceAttr>(D, IgnoreImplicitHDAttr)) {
+    if (hasAttr<CUDAHostAttr>(D, IgnoreImplicitHDAttr))
       return CFT_HostDevice;
     return CFT_Device;
-  } else if (D->hasAttr<CUDAHostAttr>()) {
+  } else if (hasAttr<CUDAHostAttr>(D, IgnoreImplicitHDAttr)) {
     return CFT_Host;
-  } else if (D->isImplicit()) {
+  } else if (D->isImplicit() && !IgnoreImplicitHDAttr) {
     // Some implicit declarations (like intrinsic functions) are not marked.
     // Set the most lenient target on them for maximal flexibility.
     return CFT_HostDevice;
@@ -874,3 +889,25 @@
     }
   }
 }
+
+void Sema::inheritCUDATargetAttrs(FunctionDecl *NewFD, FunctionDecl *OldFD) {
+  // Copy OldFD's CUDA target attributes (CUDAGlobalAttr, or otherwise
+  // CUDAHostAttr and/or CUDADeviceAttr) onto NewFD as inherited clones.
+  // The intent is for NewFD to get the same effective target as OldFD.
+  if (CUDAGlobalAttr *Attr = OldFD->getAttr<CUDAGlobalAttr>()) {
+    CUDAGlobalAttr *Clone = Attr->clone(Context);
+    Clone->setInherited(true);
+    NewFD->addAttr(Clone);
+  } else { // No global attribute: host and device are copied independently.
+    if (CUDAHostAttr *Attr = OldFD->getAttr<CUDAHostAttr>()) {
+      CUDAHostAttr *Clone = Attr->clone(Context);
+      Clone->setInherited(true);
+      NewFD->addAttr(Clone);
+    }
+    if (CUDADeviceAttr *Attr = OldFD->getAttr<CUDADeviceAttr>()) {
+      CUDADeviceAttr *Clone = Attr->clone(Context);
+      Clone->setInherited(true);
+      NewFD->addAttr(Clone);
+    }
+  }
+}
Index: include/clang/Sema/Sema.h
===================================================================
--- include/clang/Sema/Sema.h
+++ include/clang/Sema/Sema.h
@@ -9390,7 +9390,8 @@
   ///
   /// Use this rather than examining the function's attributes yourself -- you
   /// will get it wrong.  Returns CFT_Host if D is null.
-  CUDAFunctionTarget IdentifyCUDATarget(const FunctionDecl *D);
+  CUDAFunctionTarget IdentifyCUDATarget(const FunctionDecl *D,
+                                        bool IgnoreImplicitHDAttr = false);
   CUDAFunctionTarget IdentifyCUDATarget(const AttributeList *Attr);
 
   /// Gets the CUDA target for the current context.
@@ -9493,6 +9494,8 @@
   /// Check whether NewFD is a valid overload for CUDA. Emits
   /// diagnostics and invalidates NewFD if not.
   void checkCUDATargetOverload(FunctionDecl *NewFD, LookupResult &Previous);
+  /// Copies target attributes from OldFD to NewFD.
+  void inheritCUDATargetAttrs(FunctionDecl *NewFD, FunctionDecl *OldFD);
 
   /// \name Code completion
   //@{
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
http://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to