tra created this revision.
tra added reviewers: jlebar, rsmith.
tra added a subscriber: cfe-commits.

Current behavior:

- `__host__ __device__` (HD) functions are considered redeclarations of
`__host__` (H) or `__device__` (D) functions with the same signature.
- Target attributes are not taken into account when selecting a function
template for explicit instantiation or specialization.
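To make the current behavior concrete, here is a simplified, hypothetical C++ model of the target check that this patch removes from `Sema::IsOverload`. The enumerator names and order mirror `Sema::CUDAFunctionTarget`. The function name is made up for illustration, and only the case where both declarations already have the same signature is modeled.

```cpp
// Hypothetical model; enumerator names and order mirror Sema::CUDAFunctionTarget.
enum CUDAFunctionTarget {
  CFT_Device,
  CFT_Global,
  CFT_Host,
  CFT_HostDevice,
  CFT_InvalidTarget
};

// Pre-patch rule for two declarations with identical signatures.
// true  -> the new declaration is a separate overload.
// false -> the new declaration is a redeclaration of the old one, so its
//          target attributes end up merged with the old declaration's.
constexpr bool isTargetOverloadBeforePatch(CUDAFunctionTarget New,
                                           CUDAFunctionTarget Old) {
  if (New == CFT_InvalidTarget || New == CFT_Global)
    return false;
  // HD and __global__ functions never overload a same-signature function.
  if (New == CFT_HostDevice || Old == CFT_HostDevice ||
      New == CFT_Global || Old == CFT_Global)
    return false;
  // Plain H/D pairs are allowed to overload each other.
  return New != Old;
}

// H declared after D: two distinct functions.
static_assert(isTargetOverloadBeforePatch(CFT_Device, CFT_Host), "");
// H declared after HD: treated as a redeclaration, so H can pick up HD.
static_assert(!isTargetOverloadBeforePatch(CFT_Host, CFT_HostDevice), "");
```

The second check is the situation behind issue (a) below.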

Issues:
a) An H or D function can inherit HD attributes from an HD declaration. Such a
function is then silently treated as HD in the rest of the code, so clang
accepts code that it should diagnose as an error.
b) If both an HD function and an H or D function with the same signature are
defined, the compiler reports a redefinition of the same function, which is
misleading.
c) It is impossible to disambiguate between target-overloaded function templates
during explicit instantiation/specialization.

Changes in this patch:
a) Treat HD functions as overloads and add Sema checks that explicitly diagnose
attempts to overload HD functions with H or D ones (the same checks also cover
`__global__` functions).
b) Require matching target attributes for explicit function template
instantiation/specialization and narrow the list of candidates to templates
with the same target. Diagnose rejected candidates.
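As a rough sketch of how change (a) divides the work, assuming both declarations have the same signature: `IsOverload` now treats any target difference as an overload, and the new `CheckCUDATargetOverload` then rejects overloads that involve HD or `__global__` functions. The enum mirrors `Sema::CUDAFunctionTarget`; `DeclOutcome` and `classifyAfterPatch` are hypothetical names used only for illustration.

```cpp
// Enumerator names and order mirror Sema::CUDAFunctionTarget.
enum CUDAFunctionTarget {
  CFT_Device,
  CFT_Global,
  CFT_Host,
  CFT_HostDevice,
  CFT_InvalidTarget
};

// Hypothetical summary of what happens to a new declaration whose signature
// matches an existing one.
enum class DeclOutcome { Redeclaration, ValidOverload, DiagnosedOverload };

constexpr DeclOutcome classifyAfterPatch(CUDAFunctionTarget New,
                                         CUDAFunctionTarget Old) {
  // IsOverload: only an invalid or identical target yields a redeclaration.
  if (New == CFT_InvalidTarget || New == Old)
    return DeclOutcome::Redeclaration;
  // CheckCUDATargetOverload: HD and __global__ functions "exist" on both the
  // host and the device, so overloading them by target alone is an error
  // (err_cuda_ovl_target).
  if (New == CFT_HostDevice || Old == CFT_HostDevice ||
      New == CFT_Global || Old == CFT_Global)
    return DeclOutcome::DiagnosedOverload;
  // H and D may still overload each other.
  return DeclOutcome::ValidOverload;
}
```

For example, `classifyAfterPatch(CFT_Host, CFT_HostDevice)` yields `DiagnosedOverload`, which corresponds to the new "__host__ function 'hdh' cannot overload __host__ __device__ function 'hdh'" error in function-overload.cu.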

Changes (a) and (b) could be split into separate patches, but both must be
committed together: (a) alone further breaks function template
instantiation/specialization when target attributes are involved, and (b) is
half-broken until (a) is in place to prevent HD attributes from being merged
into functions with H or D attributes.

Open issues:

- It is not clear how to handle explicit specializations of constexpr function
  templates:
  - The implicit target attributes of constexpr functions and templates change
    depending on whether -fno-cuda-host-device-constexpr is in effect.
  - C++11 [dcl.constexpr]p1: An explicit specialization of a constexpr function
    template can differ from the template declaration with respect to the
    constexpr specifier.
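The first open issue can be illustrated with a deliberately simplified model. It assumes that a constexpr function without explicit target attributes is treated as implicitly HD unless -fno-cuda-host-device-constexpr is given, and it ignores any other conditions clang may check. `FunctionInfo` and `effectiveTarget` are hypothetical.

```cpp
// Enumerator names and order mirror Sema::CUDAFunctionTarget.
enum CUDAFunctionTarget {
  CFT_Device,
  CFT_Global,
  CFT_Host,
  CFT_HostDevice,
  CFT_InvalidTarget
};

// Hypothetical description of a declaration as written in the source.
struct FunctionInfo {
  bool HasHostAttr;
  bool HasDeviceAttr;
  bool IsConstexpr;
};

// HostDeviceConstexpr is true unless -fno-cuda-host-device-constexpr is used.
constexpr CUDAFunctionTarget effectiveTarget(FunctionInfo F,
                                             bool HostDeviceConstexpr) {
  if (F.HasHostAttr && F.HasDeviceAttr)
    return CFT_HostDevice;
  if (F.HasDeviceAttr)
    return CFT_Device;
  if (F.HasHostAttr)
    return CFT_Host;
  // No explicit target (assumption): constexpr functions become implicitly HD.
  if (F.IsConstexpr && HostDeviceConstexpr)
    return CFT_HostDevice;
  return CFT_Host;
}

// A constexpr template and a non-constexpr explicit specialization of it,
// neither with explicit target attributes ([dcl.constexpr]p1 allows this).
constexpr FunctionInfo ConstexprTemplate{false, false, true};
constexpr FunctionInfo PlainSpecialization{false, false, false};

// Whether their targets match depends on the command-line flag.
static_assert(effectiveTarget(ConstexprTemplate, true) !=
                  effectiveTarget(PlainSpecialization, true), "");
static_assert(effectiveTarget(ConstexprTemplate, false) ==
                  effectiveTarget(PlainSpecialization, false), "");
```

Under this model, with default flags the template and its specialization get different targets, so a strict target-matching rule would reject the specialization; with the flag, they match.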

One idea is to match only explicitly written target attributes when choosing
candidate templates. That would make it easier to tell which template is
instantiated based only on what is written in the source being compiled.
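A minimal sketch of that idea, assuming targets are computed only from explicitly written attributes. `targetFromWritten` follows the precedence of the `IdentifyCUDATarget(const AttributeList *)` overload added in this patch, but leaves out the invalid-target and `ForceCUDAHostDeviceDepth` handling. The attribute bitmask and `uniqueMatch` are hypothetical helpers.

```cpp
#include <cstddef>

// Enumerator names and order mirror Sema::CUDAFunctionTarget.
enum CUDAFunctionTarget {
  CFT_Device,
  CFT_Global,
  CFT_Host,
  CFT_HostDevice,
  CFT_InvalidTarget
};

// Hypothetical encoding of the target attributes written on a declaration.
constexpr unsigned AttrHost = 1u << 0;
constexpr unsigned AttrDevice = 1u << 1;
constexpr unsigned AttrGlobal = 1u << 2;

// Same precedence as the patch: __global__, then HD, then D, otherwise H.
constexpr CUDAFunctionTarget targetFromWritten(unsigned Attrs) {
  if (Attrs & AttrGlobal)
    return CFT_Global;
  if ((Attrs & AttrHost) && (Attrs & AttrDevice))
    return CFT_HostDevice;
  if (Attrs & AttrDevice)
    return CFT_Device;
  return CFT_Host;
}

// Returns the index of the only candidate template whose written target
// matches the declaration's written target, or -1 if there is no match or
// the match is ambiguous.
template <std::size_t N>
constexpr int uniqueMatch(const unsigned (&Candidates)[N], unsigned Decl) {
  int Found = -1;
  for (std::size_t I = 0; I < N; ++I) {
    // In the patch, such a candidate is recorded with TDK_CUDATargetMismatch.
    if (targetFromWritten(Candidates[I]) != targetFromWritten(Decl))
      continue;
    if (Found != -1)
      return -1;
    Found = static_cast<int>(I);
  }
  return Found;
}

// Written attributes of the __host__ and __device__ overload_h_d templates.
constexpr unsigned OverloadHAndD[] = {AttrHost, AttrDevice};
```

As in the new function-template-overload.cu test, a `__device__` declaration selects the `__device__` template (index 1), while an HD declaration finds no candidate.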


https://reviews.llvm.org/D25809

Files:
  include/clang/Basic/DiagnosticSemaKinds.td
  include/clang/Sema/Sema.h
  lib/Sema/SemaCUDA.cpp
  lib/Sema/SemaDecl.cpp
  lib/Sema/SemaOverload.cpp
  lib/Sema/SemaTemplate.cpp
  test/CodeGenCUDA/launch-bounds.cu
  test/SemaCUDA/function-overload.cu
  test/SemaCUDA/function-template-overload.cu
  test/SemaCUDA/target_attr_inheritance.cu

Index: test/SemaCUDA/target_attr_inheritance.cu
===================================================================
--- test/SemaCUDA/target_attr_inheritance.cu
+++ /dev/null
@@ -1,29 +0,0 @@
-// Verifies correct inheritance of target attributes during template
-// instantiation and specialization.
-
-// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -fsyntax-only -verify %s
-// RUN: %clang_cc1 -triple nvptx64-nvidia-cuda -fsyntax-only -fcuda-is-device -verify %s
-
-#include "Inputs/cuda.h"
-
-// Function must inherit target attributes during instantiation, but not during
-// specialization.
-template <typename T> __host__ __device__ T function_template(const T &a);
-
-// Specialized functions have their own attributes.
-// expected-note@+1 {{candidate function not viable: call to __host__ function from __device__ function}}
-template <> __host__ float function_template<float>(const float &from);
-
-// expected-note@+1 {{candidate function not viable: call to __device__ function from __host__ function}}
-template <> __device__ double function_template<double>(const double &from);
-
-__host__ void hf() {
-  function_template<float>(1.0f); // OK. Specialization is __host__.
-  function_template<double>(2.0); // expected-error {{no matching function for call to 'function_template'}}
-  function_template(1);           // OK. Instantiated function template is HD.
-}
-__device__ void df() {
-  function_template<float>(3.0f); // expected-error {{no matching function for call to 'function_template'}}
-  function_template<double>(4.0); // OK. Specialization is __device__.
-  function_template(1);           // OK. Instantiated function template is HD.
-}
Index: test/SemaCUDA/function-template-overload.cu
===================================================================
--- /dev/null
+++ test/SemaCUDA/function-template-overload.cu
@@ -0,0 +1,82 @@
+// RUN: %clang_cc1 -std=c++11 -triple x86_64-unknown-linux-gnu -fsyntax-only -verify %s
+// RUN: %clang_cc1 -std=c++11 -triple nvptx64-nvidia-cuda -fsyntax-only -fcuda-is-device -verify %s
+
+#include "Inputs/cuda.h"
+
+struct HType {}; // expected-note-re 6 {{candidate constructor {{.*}} not viable: no known conversion from 'DType'}}
+struct DType {}; // expected-note-re 6 {{candidate constructor {{.*}} not viable: no known conversion from 'HType'}}
+struct HDType {};
+
+template <typename T> __host__ HType overload_h_d(T a) { return HType(); }
+// expected-note@-1 2 {{candidate template ignored: could not match 'HType' against 'DType'}}
+// expected-note@-2 2 {{candidate template ignored: target attributes do not match}}
+template <typename T> __device__ DType overload_h_d(T a) { return DType(); }
+// expected-note@-1 2 {{candidate template ignored: could not match 'DType' against 'HType'}}
+// expected-note@-2 2 {{candidate template ignored: target attributes do not match}}
+
+// Check explicit instantiation.
+template  __device__ __host__ DType overload_h_d(int a); // There's no HD template...
+// expected-error@-1 {{explicit instantiation of 'overload_h_d' does not refer to a function template, variable template, member function, member class, or static data member}}
+template  __device__ __host__ HType overload_h_d(int a); // There's no HD template...
+// expected-error@-1 {{explicit instantiation of 'overload_h_d' does not refer to a function template, variable template, member function, member class, or static data member}}
+template  __device__ DType overload_h_d(int a); // OK. instantiates D
+template  __host__ HType overload_h_d(int a); // OK. instantiates H
+
+// Check explicit specialization.
+template  <> __device__ __host__ DType overload_h_d(long a); // There's no HD template...
+// expected-error@-1 {{no function template matches function template specialization 'overload_h_d'}}
+template  <> __device__ __host__ HType overload_h_d(long a); // There's no HD template...
+// expected-error@-1 {{no function template matches function template specialization 'overload_h_d'}}
+template  <> __device__ DType overload_h_d(long a); // OK. specializes D
+template  <> __host__ HType overload_h_d(long a); // OK. specializes H
+
+
+// Can't overload HD template with H or D template, though functions are OK.
+template <typename T> __host__ __device__ HDType overload_hd(T a) { return HDType(); }
+// expected-note@-1 {{previous declaration is here}}
+// expected-note@-2 2 {{candidate template ignored: could not match 'HDType' against 'HType'}}
+template <typename T> __device__ HDType overload_hd(T a);
+// expected-error@-1 {{__device__ function 'overload_hd' cannot overload __host__ __device__ function 'overload_hd'}}
+__device__ HDType overload_hd(int a); // OK.
+
+// Verify that target attributes are taken into account when we
+// explicitly specialize or instantiate function templates.
+template <> __host__ HType overload_hd(int a);
+// expected-error@-1 {{no function template matches function template specialization 'overload_hd'}}
+template __host__ HType overload_hd(long a);
+// expected-error@-1 {{explicit instantiation of 'overload_hd' does not refer to a function template, variable template, member function, member class, or static data member}}
+__host__ HType overload_hd(int a); // OK
+
+template <typename T> __host__ T overload_h(T a); // expected-note {{previous declaration is here}}
+template <typename T> __host__ __device__ T overload_h(T a);
+// expected-error@-1 {{__host__ __device__ function 'overload_h' cannot overload __host__ function 'overload_h'}}
+template <typename T> __device__ T overload_h(T a); // OK. D can overload H.
+
+template <typename T> __host__ HType overload_h_d2(T a) { return HType(); }
+template <typename T> __host__ __device__ HDType overload_h_d2(T a) { return HDType(); }
+template <typename T1, typename T2 = int> __device__ DType overload_h_d2(T1 a) { T1 x; T2 y; return DType(); }
+
+__host__ void hf() {
+  overload_hd(13);
+
+  HType h = overload_h_d(10);
+  HType h2i = overload_h_d2<int>(11);
+  HType h2ii = overload_h_d2<int>(12);
+
+  // These should be implicitly instantiated from __host__ template returning HType.
+  DType d = overload_h_d(20); // expected-error {{no viable conversion from 'HType' to 'DType'}}
+  DType d2i = overload_h_d2<int>(21); // expected-error {{no viable conversion from 'HType' to 'DType'}}
+  DType d2ii = overload_h_d2<int>(22); // expected-error {{no viable conversion from 'HType' to 'DType'}}
+}
+__device__ void df() {
+  overload_hd(23);
+
+  // These should be implicitly instantiated from __device__ template returning DType.
+  HType h = overload_h_d(10); // expected-error {{no viable conversion from 'DType' to 'HType'}}
+  HType h2i = overload_h_d2<int>(11); // expected-error {{no viable conversion from 'DType' to 'HType'}}
+  HType h2ii = overload_h_d2<int>(12); // expected-error {{no viable conversion from 'DType' to 'HType'}}
+
+  DType d = overload_h_d(20);
+  DType d2i = overload_h_d2<int>(21);
+  DType d2ii = overload_h_d2<int>(22);
+}
Index: test/SemaCUDA/function-overload.cu
===================================================================
--- test/SemaCUDA/function-overload.cu
+++ test/SemaCUDA/function-overload.cu
@@ -40,21 +40,21 @@
 __device__ DeviceReturnTy dh() { return DeviceReturnTy(); }
 
 // H/HD and D/HD are not allowed.
-__host__ __device__ int hdh() { return 0; } // expected-note {{previous definition is here}}
-__host__ int hdh() { return 0; }            // expected-error {{redefinition of 'hdh'}}
+__host__ __device__ int hdh() { return 0; } // expected-note {{previous declaration is here}}
+__host__ int hdh() { return 0; }
+// expected-error@-1 {{__host__ function 'hdh' cannot overload __host__ __device__ function 'hdh'}}
 
-__host__ int hhd() { return 0; }            // expected-note {{previous definition is here}}
-__host__ __device__ int hhd() { return 0; } // expected-error {{redefinition of 'hhd'}}
-// expected-warning@-1 {{attribute declaration must precede definition}}
-// expected-note@-3 {{previous definition is here}}
+__host__ int hhd() { return 0; }            // expected-note {{previous declaration is here}}
+__host__ __device__ int hhd() { return 0; }
+// expected-error@-1 {{__host__ __device__ function 'hhd' cannot overload __host__ function 'hhd'}}
 
-__host__ __device__ int hdd() { return 0; } // expected-note {{previous definition is here}}
-__device__ int hdd() { return 0; }          // expected-error {{redefinition of 'hdd'}}
+__host__ __device__ int hdd() { return 0; } // expected-note {{previous declaration is here}}
+__device__ int hdd() { return 0; }
+// expected-error@-1 {{__device__ function 'hdd' cannot overload __host__ __device__ function 'hdd'}}
 
-__device__ int dhd() { return 0; }          // expected-note {{previous definition is here}}
-__host__ __device__ int dhd() { return 0; } // expected-error {{redefinition of 'dhd'}}
-// expected-warning@-1 {{attribute declaration must precede definition}}
-// expected-note@-3 {{previous definition is here}}
+__device__ int dhd() { return 0; }          // expected-note {{previous declaration is here}}
+__host__ __device__ int dhd() { return 0; }
+// expected-error@-1 {{__host__ __device__ function 'dhd' cannot overload __device__ function 'dhd'}}
 
 // Same tests for extern "C" functions.
 extern "C" __host__ int chh() { return 0; } // expected-note {{previous definition is here}}
@@ -65,13 +65,13 @@
 extern "C" __host__ HostReturnTy cdh() { return HostReturnTy(); }
 
 // H/HD and D/HD overloading is not allowed.
-extern "C" __host__ __device__ int chhd1() { return 0; } // expected-note {{previous definition is here}}
-extern "C" __host__ int chhd1() { return 0; }            // expected-error {{redefinition of 'chhd1'}}
+extern "C" __host__ __device__ int chhd1() { return 0; } // expected-note {{previous declaration is here}}
+extern "C" __host__ int chhd1() { return 0; }
+// expected-error@-1 {{__host__ function 'chhd1' cannot overload __host__ __device__ function 'chhd1'}}
 
-extern "C" __host__ int chhd2() { return 0; }            // expected-note {{previous definition is here}}
-extern "C" __host__ __device__ int chhd2() { return 0; } // expected-error {{redefinition of 'chhd2'}}
-// expected-warning@-1 {{attribute declaration must precede definition}}
-// expected-note@-3 {{previous definition is here}}
+extern "C" __host__ int chhd2() { return 0; } // expected-note {{previous declaration is here}}
+extern "C" __host__ __device__ int chhd2() { return 0; }
+// expected-error@-1 {{__host__ __device__ function 'chhd2' cannot overload __host__ function 'chhd2'}}
 
 // Helper functions to verify calling restrictions.
 __device__ DeviceReturnTy d() { return DeviceReturnTy(); }
@@ -250,33 +250,39 @@
 
 struct m_hhd {
   __host__ void operator delete(void *ptr) {} // expected-note {{previous declaration is here}}
-  __host__ __device__ void operator delete(void *ptr) {} // expected-error {{class member cannot be redeclared}}
+  __host__ __device__ void operator delete(void *ptr) {}
+  // expected-error@-1 {{__host__ __device__ function 'operator delete' cannot overload __host__ function 'operator delete'}}
 };
 
 struct m_hdh {
   __host__ __device__ void operator delete(void *ptr) {} // expected-note {{previous declaration is here}}
-  __host__ void operator delete(void *ptr) {} // expected-error {{class member cannot be redeclared}}
+  __host__ void operator delete(void *ptr) {}
+  // expected-error@-1 {{__host__ function 'operator delete' cannot overload __host__ __device__ function 'operator delete'}}
 };
 
 struct m_dhd {
   __device__ void operator delete(void *ptr) {} // expected-note {{previous declaration is here}}
-  __host__ __device__ void operator delete(void *ptr) {} // expected-error {{class member cannot be redeclared}}
+  __host__ __device__ void operator delete(void *ptr) {}
+  // expected-error@-1 {{__host__ __device__ function 'operator delete' cannot overload __device__ function 'operator delete'}}
 };
 
 struct m_hdd {
   __host__ __device__ void operator delete(void *ptr) {} // expected-note {{previous declaration is here}}
-  __device__ void operator delete(void *ptr) {} // expected-error {{class member cannot be redeclared}}
+  __device__ void operator delete(void *ptr) {}
+  // expected-error@-1 {{__device__ function 'operator delete' cannot overload __host__ __device__ function 'operator delete'}}
 };
 
 // __global__ functions can't be overloaded based on attribute
 // difference.
 struct G {
-  friend void friend_of_g(G &arg);
+  friend void friend_of_g(G &arg); // expected-note {{previous declaration is here}}
 private:
-  int x;
+  int x; // expected-note {{declared private here}}
 };
-__global__ void friend_of_g(G &arg) { int x = arg.x; } // expected-note {{previous definition is here}}
-void friend_of_g(G &arg) { int x = arg.x; } // expected-error {{redefinition of 'friend_of_g'}}
+__global__ void friend_of_g(G &arg) { int x = arg.x; }
+// expected-error@-1 {{__global__ function 'friend_of_g' cannot overload __host__ function 'friend_of_g'}}
+// expected-error@-2 {{'x' is a private member of 'G'}}
+void friend_of_g(G &arg) { int x = arg.x; }
 
 // HD functions are sometimes allowed to call H or D functions -- this
 // is an artifact of the source-to-source splitting performed by nvcc
Index: test/CodeGenCUDA/launch-bounds.cu
===================================================================
--- test/CodeGenCUDA/launch-bounds.cu
+++ test/CodeGenCUDA/launch-bounds.cu
@@ -36,16 +36,16 @@
 {
 }
 
-template void Kernel3<MAX_THREADS_PER_BLOCK>();
+template __global__ void Kernel3<MAX_THREADS_PER_BLOCK>();
 // CHECK: !{{[0-9]+}} = !{void ()* @{{.*}}Kernel3{{.*}}, !"maxntidx", i32 256}
 
 template <int max_threads_per_block, int min_blocks_per_mp>
 __global__ void
 __launch_bounds__(max_threads_per_block, min_blocks_per_mp)
 Kernel4()
 {
 }
-template void Kernel4<MAX_THREADS_PER_BLOCK, MIN_BLOCKS_PER_MP>();
+template __global__ void Kernel4<MAX_THREADS_PER_BLOCK, MIN_BLOCKS_PER_MP>();
 
 // CHECK: !{{[0-9]+}} = !{void ()* @{{.*}}Kernel4{{.*}}, !"maxntidx", i32 256}
 // CHECK: !{{[0-9]+}} = !{void ()* @{{.*}}Kernel4{{.*}}, !"minctasm", i32 2}
@@ -58,7 +58,7 @@
 Kernel5()
 {
 }
-template void Kernel5<MAX_THREADS_PER_BLOCK, MIN_BLOCKS_PER_MP>();
+template __global__ void Kernel5<MAX_THREADS_PER_BLOCK, MIN_BLOCKS_PER_MP>();
 
 // CHECK: !{{[0-9]+}} = !{void ()* @{{.*}}Kernel5{{.*}}, !"maxntidx", i32 356}
 // CHECK: !{{[0-9]+}} = !{void ()* @{{.*}}Kernel5{{.*}}, !"minctasm", i32 258}
Index: lib/Sema/SemaTemplate.cpp
===================================================================
--- lib/Sema/SemaTemplate.cpp
+++ lib/Sema/SemaTemplate.cpp
@@ -7041,6 +7041,15 @@
         continue;
       }
 
+      // Filter out matches that have a different target.
+      if (LangOpts.CUDA &&
+          IdentifyCUDATarget(Specialization) != IdentifyCUDATarget(FD)) {
+        FailedCandidates.addCandidate().set(
+            I.getPair(), FunTmpl->getTemplatedDecl(),
+            MakeDeductionFailureInfo(Context, TDK_CUDATargetMismatch, Info));
+        continue;
+      }
+
       // Record this candidate.
       if (ExplicitTemplateArgs)
         ConvertedTemplateArgs[Specialization] = std::move(Args);
@@ -8066,6 +8075,7 @@
   //  instantiated from the member definition associated with its class
   //  template.
   UnresolvedSet<8> Matches;
+  AttributeList *Attr = D.getDeclSpec().getAttributes().getList();
   TemplateSpecCandidateSet FailedCandidates(D.getIdentifierLoc());
   for (LookupResult::iterator P = Previous.begin(), PEnd = Previous.end();
        P != PEnd; ++P) {
@@ -8102,6 +8112,15 @@
       continue;
     }
 
+    // Filter out matches that have a different target.
+    if (LangOpts.CUDA &&
+        IdentifyCUDATarget(Specialization) != IdentifyCUDATarget(Attr)) {
+      FailedCandidates.addCandidate().set(
+          P.getPair(), FunTmpl->getTemplatedDecl(),
+          MakeDeductionFailureInfo(Context, TDK_CUDATargetMismatch, Info));
+      continue;
+    }
+
     Matches.addDecl(Specialization, P.getAccess());
   }
 
@@ -8172,7 +8191,6 @@
   }
 
   Specialization->setTemplateSpecializationKind(TSK, D.getIdentifierLoc());
-  AttributeList *Attr = D.getDeclSpec().getAttributes().getList();
   if (Attr)
     ProcessDeclAttributeList(S, Specialization, Attr);
 
Index: lib/Sema/SemaOverload.cpp
===================================================================
--- lib/Sema/SemaOverload.cpp
+++ lib/Sema/SemaOverload.cpp
@@ -580,6 +580,7 @@
   case Sema::TDK_TooManyArguments:
   case Sema::TDK_TooFewArguments:
   case Sema::TDK_MiscellaneousDeductionFailure:
+  case Sema::TDK_CUDATargetMismatch:
     Result.Data = nullptr;
     break;
 
@@ -647,6 +648,7 @@
   case Sema::TDK_TooFewArguments:
   case Sema::TDK_InvalidExplicitArguments:
   case Sema::TDK_FailedOverloadResolution:
+  case Sema::TDK_CUDATargetMismatch:
     break;
 
   case Sema::TDK_Inconsistent:
@@ -689,6 +691,7 @@
   case Sema::TDK_DeducedMismatch:
   case Sema::TDK_NonDeducedMismatch:
   case Sema::TDK_FailedOverloadResolution:
+  case Sema::TDK_CUDATargetMismatch:
     return TemplateParameter();
 
   case Sema::TDK_Incomplete:
@@ -720,6 +723,7 @@
   case Sema::TDK_Underqualified:
   case Sema::TDK_NonDeducedMismatch:
   case Sema::TDK_FailedOverloadResolution:
+  case Sema::TDK_CUDATargetMismatch:
     return nullptr;
 
   case Sema::TDK_DeducedMismatch:
@@ -747,6 +751,7 @@
   case Sema::TDK_InvalidExplicitArguments:
   case Sema::TDK_SubstitutionFailure:
   case Sema::TDK_FailedOverloadResolution:
+  case Sema::TDK_CUDATargetMismatch:
     return nullptr;
 
   case Sema::TDK_Inconsistent:
@@ -774,6 +779,7 @@
   case Sema::TDK_InvalidExplicitArguments:
   case Sema::TDK_SubstitutionFailure:
   case Sema::TDK_FailedOverloadResolution:
+  case Sema::TDK_CUDATargetMismatch:
     return nullptr;
 
   case Sema::TDK_Inconsistent:
@@ -1138,20 +1144,11 @@
 
     CUDAFunctionTarget NewTarget = IdentifyCUDATarget(New),
                        OldTarget = IdentifyCUDATarget(Old);
-    if (NewTarget == CFT_InvalidTarget || NewTarget == CFT_Global)
+    if (NewTarget == CFT_InvalidTarget)
       return false;
 
     assert((OldTarget != CFT_InvalidTarget) && "Unexpected invalid target.");
 
-    // Don't allow HD and global functions to overload other functions with the
-    // same signature.  We allow overloading based on CUDA attributes so that
-    // functions can have different implementations on the host and device, but
-    // HD/global functions "exist" in some sense on both the host and device, so
-    // should have the same implementation on both sides.
-    if ((NewTarget == CFT_HostDevice) || (OldTarget == CFT_HostDevice) ||
-        (NewTarget == CFT_Global) || (OldTarget == CFT_Global))
-      return false;
-
     // Allow overloading of functions with same signature and different CUDA
     // target attributes.
     return NewTarget != OldTarget;
@@ -9702,6 +9699,10 @@
     S.Diag(Templated->getLocation(), diag::note_ovl_candidate_bad_deduction);
     MaybeEmitInheritedConstructorNote(S, Found);
     return;
+  case Sema::TDK_CUDATargetMismatch:
+    S.Diag(Templated->getLocation(),
+           diag::note_cuda_ovl_candidate_target_mismatch);
+    return;
   }
 }
 
@@ -9958,6 +9959,7 @@
   case Sema::TDK_DeducedMismatch:
   case Sema::TDK_NonDeducedMismatch:
   case Sema::TDK_MiscellaneousDeductionFailure:
+  case Sema::TDK_CUDATargetMismatch:
     return 3;
 
   case Sema::TDK_InstantiationDepth:
Index: lib/Sema/SemaDecl.cpp
===================================================================
--- lib/Sema/SemaDecl.cpp
+++ lib/Sema/SemaDecl.cpp
@@ -8998,6 +8998,9 @@
                !R->isObjCObjectPointerType())
         Diag(NewFD->getLocation(), diag::warn_return_value_udt) << NewFD << R;
     }
+
+    if (!Redeclaration && LangOpts.CUDA)
+      CheckCUDATargetOverload(NewFD, Previous);
   }
   return Redeclaration;
 }
Index: lib/Sema/SemaCUDA.cpp
===================================================================
--- lib/Sema/SemaCUDA.cpp
+++ lib/Sema/SemaCUDA.cpp
@@ -54,6 +54,45 @@
                        /*IsExecConfig=*/true);
 }
 
+Sema::CUDAFunctionTarget Sema::IdentifyCUDATarget(const AttributeList *Attr) {
+  bool HasHostAttr = false;
+  bool HasDeviceAttr = false;
+  bool HasGlobalAttr = false;
+  bool HasInvalidTargetAttr = false;
+  while (Attr) {
+    switch (Attr->getKind()) {
+    case AttributeList::AT_CUDAGlobal:
+      HasGlobalAttr = true;
+      break;
+    case AttributeList::AT_CUDAHost:
+      HasHostAttr = true;
+      break;
+    case AttributeList::AT_CUDADevice:
+      HasDeviceAttr = true;
+      break;
+    case AttributeList::AT_CUDAInvalidTarget:
+      HasInvalidTargetAttr = true;
+      break;
+    default:
+      break;
+    }
+    Attr = Attr->getNext();
+  }
+  if (HasInvalidTargetAttr)
+    return CFT_InvalidTarget;
+
+  if (HasGlobalAttr)
+    return CFT_Global;
+
+  if ((HasHostAttr && HasDeviceAttr) || ForceCUDAHostDeviceDepth > 0)
+    return CFT_HostDevice;
+
+  if (HasDeviceAttr)
+    return CFT_Device;
+
+  return CFT_Host;
+}
+
 /// IdentifyCUDATarget - Determine the CUDA compilation target for this function
 Sema::CUDAFunctionTarget Sema::IdentifyCUDATarget(const FunctionDecl *D) {
   // Code that lives outside a function is run on the host.
@@ -745,3 +784,30 @@
     Method->addAttr(CUDAHostAttr::CreateImplicit(Context));
   }
 }
+
+void Sema::CheckCUDATargetOverload(FunctionDecl *NewFD,
+                                   LookupResult &Previous) {
+  CUDAFunctionTarget NewTarget = IdentifyCUDATarget(NewFD);
+  for (auto OldND : Previous) {
+    FunctionDecl *OldFD = OldND->getAsFunction();
+    if (!OldFD || OldFD->isFunctionTemplateSpecialization())
+      continue;
+    CUDAFunctionTarget OldTarget = IdentifyCUDATarget(OldFD);
+    // Don't allow HD and global functions to overload other functions with the
+    // same signature.  We allow overloading based on CUDA attributes so that
+    // functions can have different implementations on the host and device, but
+    // HD/global functions "exist" in some sense on both the host and device, so
+    // should have the same implementation on both sides.
+    if (NewTarget != OldTarget &&
+        ((NewTarget == CFT_HostDevice) || (OldTarget == CFT_HostDevice) ||
+         (NewTarget == CFT_Global) || (OldTarget == CFT_Global)) &&
+        !IsOverload(NewFD, OldFD, /* UseMemberUsingDeclRules = */ false,
+                    /* ConsiderCudaAttrs = */ false)) {
+      Diag(NewFD->getLocation(), diag::err_cuda_ovl_target)
+          << NewTarget << NewFD->getDeclName() << OldTarget << OldFD;
+      Diag(OldFD->getLocation(), diag::note_previous_declaration);
+      NewFD->setInvalidDecl();
+      break;
+    }
+  }
+}
Index: include/clang/Sema/Sema.h
===================================================================
--- include/clang/Sema/Sema.h
+++ include/clang/Sema/Sema.h
@@ -6555,7 +6555,9 @@
     /// not be resolved to a suitable function.
     TDK_FailedOverloadResolution,
     /// \brief Deduction failed; that's all we know.
-    TDK_MiscellaneousDeductionFailure
+    TDK_MiscellaneousDeductionFailure,
+    /// \brief CUDA target attributes do not match.
+    TDK_CUDATargetMismatch
   };
 
   TemplateDeductionResult
@@ -9381,6 +9383,7 @@
   /// Use this rather than examining the function's attributes yourself -- you
   /// will get it wrong.  Returns CFT_Host if D is null.
   CUDAFunctionTarget IdentifyCUDATarget(const FunctionDecl *D);
+  CUDAFunctionTarget IdentifyCUDATarget(const AttributeList *Attr);
 
   /// Gets the CUDA target for the current context.
   CUDAFunctionTarget CurrentCUDATarget() {
@@ -9479,6 +9482,10 @@
   bool isEmptyCudaConstructor(SourceLocation Loc, CXXConstructorDecl *CD);
   bool isEmptyCudaDestructor(SourceLocation Loc, CXXDestructorDecl *CD);
 
+  /// Check whether NewFD is a valid overload for CUDA. Emits
+  /// diagnostics and invalidates NewFD if not.
+  void CheckCUDATargetOverload(FunctionDecl *NewFD, LookupResult &Previous);
+
   /// \name Code completion
   //@{
   /// \brief Describes the context in which code completion occurs.
Index: include/clang/Basic/DiagnosticSemaKinds.td
===================================================================
--- include/clang/Basic/DiagnosticSemaKinds.td
+++ include/clang/Basic/DiagnosticSemaKinds.td
@@ -6756,6 +6756,11 @@
     "__shared__ local variables not allowed in "
     "%select{__device__|__global__|__host__|__host__ __device__}0 functions">;
 def err_cuda_nonglobal_constant : Error<"__constant__ variables must be global">;
+def err_cuda_ovl_target : Error<
+  "%select{__device__|__global__|__host__|__host__ __device__}0 function %1 "
+  "cannot overload %select{__device__|__global__|__host__|__host__ __device__}2 function %3">;
+def note_cuda_ovl_candidate_target_mismatch : Note<
+    "candidate template ignored: target attributes do not match">;
 
 def warn_non_pod_vararg_with_format_string : Warning<
   "cannot pass %select{non-POD|non-trivial}0 object of type %1 to variadic "
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
http://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits
