yaxunl updated this revision to Diff 180888.
yaxunl added a comment.

Pass the template decl via ExpressionEvaluationContextRecord.


CHANGES SINCE LAST ACTION
  https://reviews.llvm.org/D56411/new/

https://reviews.llvm.org/D56411

Files:
  include/clang/Sema/Sema.h
  lib/Sema/SemaCUDA.cpp
  lib/Sema/SemaTemplate.cpp
  test/SemaCUDA/kernel-template-with-func-arg.cu

Index: test/SemaCUDA/kernel-template-with-func-arg.cu
===================================================================
--- /dev/null
+++ test/SemaCUDA/kernel-template-with-func-arg.cu
@@ -0,0 +1,57 @@
+// RUN: %clang_cc1 -fsyntax-only -verify %s
+
+#include "Inputs/cuda.h"
+
+struct C { // Members with each kind of CUDA target attribute.
+  __device__ void devfun() {}
+  void hostfun() {}
+  template<class T> __device__ void devtempfun() {}
+  __device__ __host__ void devhostfun() {}
+};
+
+__device__ void devfun() {} // Non-member counterparts of C's members.
+__host__ void hostfun() {}
+template<class T> __device__ void devtempfun() {}
+__device__ __host__ void devhostfun() {}
+
+template <void (*devF)()> __global__ void kernel() { devF();}
+template <typename T, void(T::*devF)()> __global__ void kernel2(T *p) { (p->*devF)(); }
+
+template<> __global__ void kernel<devfun>();
+template<> __global__ void kernel<hostfun>(); // expected-error {{no function template matches function template specialization 'kernel'}}
+                                              // expected-note@-5 {{candidate template ignored: invalid explicitly-specified argument for template parameter 'devF'}}
+template<> __global__ void kernel<devtempfun<int> >();
+template<> __global__ void kernel<devhostfun>();
+
+template<> __global__ void kernel<&devfun>();
+template<> __global__ void kernel<&hostfun>(); // expected-error {{no function template matches function template specialization 'kernel'}}
+                                               // expected-note@-11 {{candidate template ignored: invalid explicitly-specified argument for template parameter 'devF'}}
+template<> __global__ void kernel<&devtempfun<int> >();
+template<> __global__ void kernel<&devhostfun>();
+
+template<> __global__ void kernel2<C, &C::devfun>(C *p);
+template<> __global__ void kernel2<C, &C::hostfun>(C *p); // expected-error {{no function template matches function template specialization 'kernel2'}}
+                                                          // expected-note@-16 {{candidate template ignored: invalid explicitly-specified argument for template parameter 'devF'}}
+template<> __global__ void kernel2<C, &C::devtempfun<int> >(C *p);
+template<> __global__ void kernel2<C, &C::devhostfun>(C *p);
+
+void fun() { // No target attribute: an implicit host function.
+  kernel<&devfun><<<1,1>>>();
+  kernel<&hostfun><<<1,1>>>(); // expected-error {{no matching function for call to 'kernel'}}
+                               // expected-note@-24 {{candidate template ignored: invalid explicitly-specified argument for template parameter 'devF'}}
+  kernel<&devtempfun<int> ><<<1,1>>>();
+  kernel<&devhostfun><<<1,1>>>();
+
+  kernel<devfun><<<1,1>>>();
+  kernel<hostfun><<<1,1>>>(); // expected-error {{no matching function for call to 'kernel'}}
+                              // expected-note@-30 {{candidate template ignored: invalid explicitly-specified argument for template parameter 'devF'}}
+  kernel<devtempfun<int> ><<<1,1>>>();
+  kernel<devhostfun><<<1,1>>>();
+
+  C a;
+  kernel2<C, &C::devfun><<<1,1>>>(&a);
+  kernel2<C, &C::hostfun><<<1,1>>>(&a); // expected-error {{no matching function for call to 'kernel2'}}
+                                        // expected-note@-36 {{candidate template ignored: invalid explicitly-specified argument for template parameter 'devF'}}
+  kernel2<C, &C::devtempfun<int> ><<<1,1>>>(&a);
+  kernel2<C, &C::devhostfun><<<1,1>>>(&a);
+}
Index: lib/Sema/SemaTemplate.cpp
===================================================================
--- lib/Sema/SemaTemplate.cpp
+++ lib/Sema/SemaTemplate.cpp
@@ -4534,6 +4534,7 @@
 
   EnterExpressionEvaluationContext ConstantEvaluated(
       SemaRef, Sema::ExpressionEvaluationContext::ConstantEvaluated);
+  SemaRef.ExprEvalContexts.back().Template = Template; // For CheckCUDACall.
   return SemaRef.SubstExpr(Param->getDefaultArgument(), TemplateArgLists);
 }
 
@@ -4784,8 +4785,8 @@
       TemplateArgument Result;
       unsigned CurSFINAEErrors = NumSFINAEErrors;
       ExprResult Res =
-        CheckTemplateArgument(NTTP, NTTPType, Arg.getArgument().getAsExpr(),
-                              Result, CTAK);
+          CheckTemplateArgument(NTTP, NTTPType, Arg.getArgument().getAsExpr(),
+                                Result, CTAK, dyn_cast<TemplateDecl>(Template));
       if (Res.isInvalid())
         return true;
       // If the current template argument causes an error, give up now.
@@ -6154,6 +6155,22 @@
   return true;
 }
 
+/// Return the function named by a template argument of the form `f` or `&f`
+/// (looking through parentheses and implicit casts), or null if there is
+/// none. Used to run CheckCUDACall on function template arguments.
+static FunctionDecl *GetFunctionDecl(Expr *Arg) {
+  Expr *E = Arg->IgnoreParenImpCasts();
+  // Only look through address-of; any other unary operator does not yield
+  // a plain reference to a function.
+  if (auto *UO = dyn_cast<UnaryOperator>(E)) {
+    if (UO->getOpcode() != UO_AddrOf)
+      return nullptr;
+    E = UO->getSubExpr()->IgnoreParenImpCasts();
+  }
+  auto *DRE = dyn_cast<DeclRefExpr>(E);
+  return DRE ? dyn_cast<FunctionDecl>(DRE->getDecl()) : nullptr;
+}
+
 /// Check a template argument against its corresponding
 /// non-type template parameter.
 ///
@@ -6164,7 +6181,8 @@
 ExprResult Sema::CheckTemplateArgument(NonTypeTemplateParmDecl *Param,
                                        QualType ParamType, Expr *Arg,
                                        TemplateArgument &Converted,
-                                       CheckTemplateArgumentKind CTAK) {
+                                       CheckTemplateArgumentKind CTAK,
+                                       TemplateDecl *Template) {
   SourceLocation StartLoc = Arg->getBeginLoc();
 
   // If the parameter type somehow involves auto, deduce the type now.
@@ -6251,6 +6269,7 @@
   // a constant-evaluated context.
   EnterExpressionEvaluationContext ConstantEvaluated(
       *this, Sema::ExpressionEvaluationContext::ConstantEvaluated);
+  ExprEvalContexts.back().Template = Template; // Read by CheckCUDACall.
 
   if (getLangOpts().CPlusPlus17) {
     // C++17 [temp.arg.nontype]p1:
@@ -6570,6 +6589,12 @@
         return ExprError();
     }
 
+    // CheckCUDACall asserts CUDA mode, so only run it for CUDA compilations.
+    if (getLangOpts().CUDA)
+      if (auto *FD = GetFunctionDecl(Arg))
+        if (!CheckCUDACall(Arg->getBeginLoc(), FD))
+          return ExprError();
+
     if (!ParamType->isMemberPointerType()) {
       if (CheckTemplateArgumentAddressOfObjectOrFunction(*this, Param,
                                                          ParamType,
Index: lib/Sema/SemaCUDA.cpp
===================================================================
--- lib/Sema/SemaCUDA.cpp
+++ lib/Sema/SemaCUDA.cpp
@@ -836,9 +836,24 @@
 bool Sema::CheckCUDACall(SourceLocation Loc, FunctionDecl *Callee) {
   assert(getLangOpts().CUDA && "Should only be called during CUDA compilation");
   assert(Callee && "Callee may not be null.");
+
+  // Functions named in an EK_TemplateArgument context are not checked here;
+  // CheckTemplateArgument re-checks them with the template as the caller.
+  // NOTE(review): that re-check may only be reached before C++17; confirm.
+  if (ExprEvalContexts.back().ExprContext ==
+      ExpressionEvaluationContextRecord::ExpressionKind::EK_TemplateArgument)
+    return true;
+
+  // While checking a template's arguments, that template's templated function
+  // is the caller. getTemplatedDecl() can be null (e.g. for a template
+  // template parameter), so dyn_cast_or_null is required here.
+  FunctionDecl *Caller = nullptr;
+  if (auto *Template = ExprEvalContexts.back().Template)
+    Caller = dyn_cast_or_null<FunctionDecl>(Template->getTemplatedDecl());
   // FIXME: Is bailing out early correct here?  Should we instead assume that
   // the caller is a global initializer?
-  FunctionDecl *Caller = dyn_cast<FunctionDecl>(CurContext);
+  if (!Caller)
+    Caller = dyn_cast<FunctionDecl>(CurContext);
   if (!Caller)
     return true;
 
Index: include/clang/Sema/Sema.h
===================================================================
--- include/clang/Sema/Sema.h
+++ include/clang/Sema/Sema.h
@@ -1009,6 +1009,10 @@
       EK_Decltype, EK_TemplateArgument, EK_Other
     } ExprContext;
 
+    /// The template whose arguments are being checked, or null. CheckCUDACall
+    /// treats that template's templated function as the caller.
+    TemplateDecl *Template;
+
     ExpressionEvaluationContextRecord(ExpressionEvaluationContext Context,
                                       unsigned NumCleanupObjects,
                                       CleanupInfo ParentCleanup,
@@ -1017,7 +1021,7 @@
         : Context(Context), ParentCleanup(ParentCleanup),
           NumCleanupObjects(NumCleanupObjects), NumTypos(0),
           ManglingContextDecl(ManglingContextDecl), MangleNumbering(),
-          ExprContext(ExprContext) {}
+          ExprContext(ExprContext), Template(nullptr) {}
 
     /// Retrieve the mangling numbering context, used to consistently
     /// number constructs like lambdas for mangling.
@@ -6453,10 +6457,12 @@
 
   bool CheckTemplateArgument(TemplateTypeParmDecl *Param,
                              TypeSourceInfo *Arg);
-  ExprResult CheckTemplateArgument(NonTypeTemplateParmDecl *Param,
-                                   QualType InstantiatedParamType, Expr *Arg,
-                                   TemplateArgument &Converted,
-                               CheckTemplateArgumentKind CTAK = CTAK_Specified);
+  ExprResult
+  CheckTemplateArgument(NonTypeTemplateParmDecl *Param,
+                        QualType InstantiatedParamType, Expr *Arg,
+                        TemplateArgument &Converted,
+                        CheckTemplateArgumentKind CTAK = CTAK_Specified,
+                        TemplateDecl *Template = nullptr);
   bool CheckTemplateTemplateArgument(TemplateParameterList *Params,
                                      TemplateArgumentLoc &Arg);
 
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
http://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to