jdoerfert updated this revision to Diff 232902.
jdoerfert added a comment.

Add one more test (sin(long double)) and fix some rebase issues.


Repository:
  rG LLVM Github Monorepo

CHANGES SINCE LAST ACTION
  https://reviews.llvm.org/D71179/new/

https://reviews.llvm.org/D71179

Files:
  clang/include/clang/AST/Decl.h
  clang/include/clang/AST/StmtOpenMP.h
  clang/include/clang/Basic/OpenMPKinds.def
  clang/include/clang/Parse/Parser.h
  clang/include/clang/Sema/Sema.h
  clang/lib/AST/Decl.cpp
  clang/lib/AST/StmtOpenMP.cpp
  clang/lib/CodeGen/CGOpenMPRuntime.cpp
  clang/lib/CodeGen/CodeGenModule.cpp
  clang/lib/Headers/__clang_cuda_cmath.h
  clang/lib/Headers/__clang_cuda_device_functions.h
  clang/lib/Headers/__clang_cuda_math_forward_declares.h
  clang/lib/Headers/openmp_wrappers/__clang_openmp_math.h
  clang/lib/Headers/openmp_wrappers/__clang_openmp_math_declares.h
  clang/lib/Headers/openmp_wrappers/cmath
  clang/lib/Headers/openmp_wrappers/math.h
  clang/lib/Parse/ParseOpenMP.cpp
  clang/lib/Sema/SemaDecl.cpp
  clang/lib/Sema/SemaOpenMP.cpp
  clang/lib/Sema/SemaOverload.cpp
  clang/lib/Sema/SemaTemplate.cpp
  clang/lib/Sema/SemaTemplateInstantiateDecl.cpp
  clang/test/AST/ast-dump-openmp-begin-declare-variant.c
  clang/test/OpenMP/begin_declare_variant_codegen.cpp
  clang/test/OpenMP/math_codegen.cpp
  clang/test/OpenMP/math_fp_macro.cpp

Index: clang/test/OpenMP/math_fp_macro.cpp
===================================================================
--- /dev/null
+++ clang/test/OpenMP/math_fp_macro.cpp
@@ -0,0 +1,9 @@
+// RUN: %clang_cc1 -verify -fopenmp -fopenmp-targets=nvptx64-nvidia-cuda -x c++ -emit-llvm %s -triple %itanium_abi_triple -fexceptions -fcxx-exceptions -o /dev/null
+// expected-no-diagnostics
+// Checks that <cmath> still provides std::fpclassify and FP_ZERO under -fopenmp.
+#include <cmath>
+
+int main() {
+  double a(0);
+  return (std::fpclassify(a) != FP_ZERO);
+}
Index: clang/test/OpenMP/math_codegen.cpp
===================================================================
--- /dev/null
+++ clang/test/OpenMP/math_codegen.cpp
@@ -0,0 +1,18 @@
+// RUN: %clang_cc1 -fopenmp -fopenmp-targets=nvptx64-nvidia-cuda -x c++ -emit-llvm %s -triple %itanium_abi_triple -fexceptions -fcxx-exceptions -o /dev/null
+// TODO: add CHECK lines for the sin overload chosen per argument type.
+
+#include <cmath>
+
+void math(short s, int i, float f, double d) {
+  sin(s);
+  sin(i);
+  sin(f);
+  sin(d);
+}
+
+void foo(short s, int i, float f, double d, long double ld) {
+  sin(ld);
+  math(s, i, f, d);
+#pragma omp target
+  { math(s, i, f, d); }
+}
Index: clang/test/OpenMP/begin_declare_variant_codegen.cpp
===================================================================
--- /dev/null
+++ clang/test/OpenMP/begin_declare_variant_codegen.cpp
@@ -0,0 +1,134 @@
+// RUN: %clang_cc1 -verify -fopenmp -x c++ -emit-llvm %s -triple x86_64-unknown-unknown -fexceptions -fcxx-exceptions -o - | FileCheck %s
+// expected-no-diagnostics
+
+int bar(void) {
+  return 0;
+}
+
+template <typename T>
+T baz(void) { return 0; }
+
+#pragma omp begin declare variant match(device={kind(cpu)})
+int foo(void) {
+  return 1;
+}
+int bar(void) {
+  return 1;
+}
+template <typename T>
+T baz(void) { return 1; }
+
+template <typename T>
+T biz(void) { return 1; }
+
+template <typename T>
+T buz(void) { return 3; }
+
+template <>
+char buz(void) { return 1; }
+
+template <typename T>
+T bez(void) { return 3; }
+#pragma omp end declare variant
+
+#pragma omp begin declare variant match(device={kind(gpu)})
+int foo(void) {
+  return 2;
+}
+int bar(void) {
+  return 2;
+}
+#pragma omp end declare variant
+
+
+#pragma omp begin declare variant match(device={kind(fpga)})
+
+This text is never parsed!
+
+#pragma omp end declare variant
+
+int foo(void) {
+  return 0;
+}
+
+template <typename T>
+T biz(void) { return 0; }
+
+template <>
+char buz(void) { return 0; }
+
+template <>
+long bez(void) { return 0; }
+
+#pragma omp begin declare variant match(device = {kind(cpu)})
+template <>
+long bez(void) { return 1; }
+#pragma omp end declare variant
+
+int test() {
+  return foo() + bar() + baz<int>() + biz<short>() + buz<char>() + bez<long>();
+}
+
+// Make sure all .ompvariant (kind(cpu)) functions return 1 and all base functions return 0.
+
+// CHECK:       ; Function Attrs:
+// CHECK-NEXT:  define i32 @_Z3barv()
+// CHECK-NEXT:  entry:
+// CHECK-NEXT:    ret i32 0
+// CHECK-NEXT:  }
+// CHECK:       ; Function Attrs:
+// CHECK-NEXT:  define i32 @_Z3foov.ompvariant()
+// CHECK-NEXT:  entry:
+// CHECK-NEXT:    ret i32 1
+// CHECK-NEXT:  }
+// CHECK:       ; Function Attrs:
+// CHECK-NEXT:  define i32 @_Z3barv.ompvariant()
+// CHECK-NEXT:  entry:
+// CHECK-NEXT:    ret i32 1
+// CHECK-NEXT:  }
+// CHECK:       ; Function Attrs:
+// CHECK-NEXT:  define signext i8 @_Z3buzIcET_v.ompvariant()
+// CHECK-NEXT:  entry:
+// CHECK-NEXT:    ret i8 1
+// CHECK-NEXT:  }
+// CHECK:       ; Function Attrs:
+// CHECK-NEXT:  define i32 @_Z3foov()
+// CHECK-NEXT:  entry:
+// CHECK-NEXT:    ret i32 0
+// CHECK-NEXT:  }
+// CHECK:       ; Function Attrs:
+// CHECK-NEXT:  define signext i8 @_Z3buzIcET_v()
+// CHECK-NEXT:  entry:
+// CHECK-NEXT:    ret i8 0
+// CHECK-NEXT:  }
+// CHECK:       ; Function Attrs:
+// CHECK-NEXT:  define i64 @_Z3bezIlET_v()
+// CHECK-NEXT:  entry:
+// CHECK-NEXT:    ret i64 0
+// CHECK-NEXT:  }
+// CHECK:       ; Function Attrs:
+// CHECK-NEXT:  define i64 @_Z3bezIlET_v.ompvariant()
+// CHECK-NEXT:  entry:
+// CHECK-NEXT:    ret i64 1
+// CHECK-NEXT:  }
+
+// Make sure test() calls only the .ompvariant functions.
+
+// CHECK:  define i32 @_Z4testv()
+// CHECK:    %call = call i32 @_Z3foov.ompvariant()
+// CHECK:    %call1 = call i32 @_Z3barv.ompvariant()
+// CHECK:    %call2 = call i32 @_Z3bazIiET_v.ompvariant()
+// CHECK:    %call4 = call signext i16 @_Z3bizIsET_v.ompvariant()
+// CHECK:    %call6 = call signext i8 @_Z3buzIcET_v.ompvariant()
+// CHECK:    %call10 = call i64 @_Z3bezIlET_v.ompvariant()
+
+// CHECK:       ; Function Attrs:
+// CHECK-NEXT:  define linkonce_odr i32 @_Z3bazIiET_v.ompvariant()
+// CHECK-NEXT:  entry:
+// CHECK-NEXT:    ret i32 1
+// CHECK-NEXT:  }
+// CHECK:       ; Function Attrs:
+// CHECK-NEXT:  define linkonce_odr signext i16 @_Z3bizIsET_v.ompvariant()
+// CHECK-NEXT:  entry:
+// CHECK-NEXT:    ret i16 1
+// CHECK-NEXT:  }
Index: clang/test/AST/ast-dump-openmp-begin-declare-variant.c
===================================================================
--- /dev/null
+++ clang/test/AST/ast-dump-openmp-begin-declare-variant.c
@@ -0,0 +1,83 @@
+// RUN: %clang_cc1 -triple x86_64-unknown-unknown -fopenmp -ast-dump %s | FileCheck %s
+
+int also_before(void) {
+  return 0;
+}
+
+#pragma omp begin declare variant match(device={kind(cpu)})
+int also_after(void) {
+  return 1;
+}
+int also_before(void) {
+  return 1;
+}
+#pragma omp end declare variant
+
+#pragma omp begin declare variant match(device={kind(gpu)})
+int also_after(void) {
+  return 2;
+}
+int also_before(void) {
+  return 2;
+}
+#pragma omp end declare variant
+
+
+#pragma omp begin declare variant match(device={kind(fpga)})
+
+This text is never parsed!
+
+#pragma omp end declare variant
+
+int also_after(void) {
+  return 0;
+}
+
+int test() {
+  return also_after() + also_before();
+}
+
+// Make sure that:
+// 1) we pick the right versions, i.e., test() references the kind(cpu) variants;
+// 2) no AST nodes appear for the kind(gpu) variants;
+// 3) we do not choke on the unparsed text in the kind(fpga) guarded region.
+
+// CHECK:       -FunctionDecl {{.*}} <{{.*}}3:1, line:{{.*}}:1> line:{{.*}}:5 also_before 'int (void)'
+// CHECK-NEXT:  | |-CompoundStmt {{.*}} <col:23, line:{{.*}}:1>
+// CHECK-NEXT:  | | `-ReturnStmt {{.*}} <line:{{.*}}:3, col:10>
+// CHECK-NEXT:  | |   `-IntegerLiteral {{.*}} <col:10> 'int' 0
+// CHECK-NEXT:  | `-OMPDeclareVariantAttr {{.*}} <line:{{.*}}:1, col:60> Inherited Implicit 1 1 cpu
+// CHECK-NEXT:  |   |-<<<NULL>>>
+// CHECK-NEXT:  |   `-IntegerLiteral {{.*}} <<invalid sloc>> 'int' 1
+// CHECK-NEXT:  |-FunctionDecl [[GOOD_ALSO_AFTER:0x[a-z0-9]*]] <line:{{.*}}:1, line:{{.*}}:1> line:{{.*}}:5 used also_after 'int (void)'
+// CHECK-NEXT:  | |-CompoundStmt {{.*}} <col:22, line:{{.*}}:1>
+// CHECK-NEXT:  | | `-ReturnStmt {{.*}} <line:{{.*}}:3, col:10>
+// CHECK-NEXT:  | |   `-IntegerLiteral {{.*}} <col:10> 'int' 1
+// CHECK-NEXT:  | `-OMPDeclareVariantAttr {{.*}} <line:{{.*}}:1, col:60> Implicit 1 1 cpu
+// CHECK-NEXT:  |   |-<<<NULL>>>
+// CHECK-NEXT:  |   `-IntegerLiteral {{.*}} <<invalid sloc>> 'int' 1
+// CHECK-NEXT:  |-FunctionDecl [[GOOD_ALSO_BEFORE:0x[a-z0-9]*]] <line:{{.*}}:1, line:{{.*}}:1> line:{{.*}}:5 used also_before 'int (void)'
+// CHECK-NEXT:  | |-CompoundStmt {{.*}} <col:23, line:{{.*}}:1>
+// CHECK-NEXT:  | | `-ReturnStmt {{.*}} <line:{{.*}}:3, col:10>
+// CHECK-NEXT:  | |   `-IntegerLiteral {{.*}} <col:10> 'int' 1
+// CHECK-NEXT:  | `-OMPDeclareVariantAttr {{.*}} <line:{{.*}}:1, col:60> Implicit 1 1 cpu
+// CHECK-NEXT:  |   |-<<<NULL>>>
+// CHECK-NEXT:  |   `-IntegerLiteral {{.*}} <<invalid sloc>> 'int' 1
+// CHECK-NEXT:  |-FunctionDecl {{.*}} <line:{{.*}}:1, line:{{.*}}:1> line:{{.*}}:5 also_after 'int (void)'
+// CHECK-NEXT:  | |-CompoundStmt {{.*}} <col:22, line:{{.*}}:1>
+// CHECK-NEXT:  | | `-ReturnStmt {{.*}} <line:{{.*}}:3, col:10>
+// CHECK-NEXT:  | |   `-IntegerLiteral {{.*}} <col:10> 'int' 0
+// CHECK-NEXT:  | `-OMPDeclareVariantAttr {{.*}} <line:{{.*}}:1, col:60> Inherited Implicit 1 1 cpu
+// CHECK-NEXT:  |   |-<<<NULL>>>
+// CHECK-NEXT:  |   `-IntegerLiteral {{.*}} <<invalid sloc>> 'int' 1
+// CHECK-NEXT:  `-FunctionDecl {{.*}} <line:{{.*}}:1, line:{{.*}}:1> line:{{.*}}:5 test 'int ()'
+// CHECK-NEXT:    `-CompoundStmt {{.*}} <col:12, line:{{.*}}:1>
+// CHECK-NEXT:      `-ReturnStmt {{.*}} <line:{{.*}}:3, col:37>
+// CHECK-NEXT:        `-BinaryOperator {{.*}} <col:10, col:37> 'int' '+'
+// CHECK-NEXT:          |-CallExpr {{.*}} <col:10, col:21> 'int'
+// CHECK-NEXT:          | `-ImplicitCastExpr {{.*}} <col:10> 'int (*)(void)' <FunctionToPointerDecay>
+// CHECK-NEXT:          |   `-DeclRefExpr {{.*}} <col:10> 'int (void)' lvalue Function [[GOOD_ALSO_AFTER]] 'also_after' 'int (void)'
+// CHECK-NEXT:          `-CallExpr {{.*}} <col:25, col:37> 'int'
+// CHECK-NEXT:            `-ImplicitCastExpr {{.*}} <col:25> 'int (*)(void)' <FunctionToPointerDecay>
+// CHECK-NEXT:              `-DeclRefExpr {{.*}} <col:25> 'int (void)' lvalue Function [[GOOD_ALSO_BEFORE]] 'also_before' 'int (void)'
+
Index: clang/lib/Sema/SemaTemplateInstantiateDecl.cpp
===================================================================
--- clang/lib/Sema/SemaTemplateInstantiateDecl.cpp
+++ clang/lib/Sema/SemaTemplateInstantiateDecl.cpp
@@ -429,9 +429,17 @@
       llvm_unreachable("Unexpected context selector set kind.");
     }
   }
-  S.ActOnOpenMPDeclareVariantDirective(DeclVarData.getValue().first,
-                                       DeclVarData.getValue().second,
-                                       Attr.getRange(), Data);
+  FunctionDecl *VariantFD =
+      DeclVarData.hasValue() ? DeclVarData.getValue().first : FD;
+  Expr *VariantExpr =
+      DeclVarData.hasValue() ? DeclVarData.getValue().second : nullptr;
+  S.ActOnOpenMPDeclareVariantDirective(VariantFD, VariantExpr, Attr.getRange(),
+                                       Data);
+  // The new attribute is inherited if the template attribute was. The
+  // directive may bail out without attaching one, so check for null.
+  if (Attr.isInherited())
+    if (auto *NewAttr = VariantFD->getAttr<OMPDeclareVariantAttr>())
+      NewAttr->setInherited(true);
 }
 
 static void instantiateDependentAMDGPUFlatWorkGroupSizeAttr(
Index: clang/lib/Sema/SemaTemplate.cpp
===================================================================
--- clang/lib/Sema/SemaTemplate.cpp
+++ clang/lib/Sema/SemaTemplate.cpp
@@ -8596,6 +8596,9 @@
   if (LangOpts.CUDA)
     inheritCUDATargetAttrs(FD, *Specialization->getPrimaryTemplate());
 
+  if (LangOpts.OpenMP)
+    inheritOpenMPVariantAttrs(FD, *Specialization->getPrimaryTemplate());
+
   // The "previous declaration" for this function template specialization is
   // the prior function template specialization.
   Previous.clear();
Index: clang/lib/Sema/SemaOverload.cpp
===================================================================
--- clang/lib/Sema/SemaOverload.cpp
+++ clang/lib/Sema/SemaOverload.cpp
@@ -17,6 +17,7 @@
 #include "clang/AST/Expr.h"
 #include "clang/AST/ExprCXX.h"
 #include "clang/AST/ExprObjC.h"
+#include "clang/AST/StmtOpenMP.h"
 #include "clang/AST/TypeOrdering.h"
 #include "clang/Basic/Diagnostic.h"
 #include "clang/Basic/DiagnosticOptions.h"
@@ -9264,7 +9265,8 @@
   return Comparison::Equal;
 }
 
-static bool isBetterMultiversionCandidate(const OverloadCandidate &Cand1,
+static bool isBetterMultiversionCandidate(ASTContext &ASTCtx,
+                                          const OverloadCandidate &Cand1,
                                           const OverloadCandidate &Cand2) {
   if (!Cand1.Function || !Cand1.Function->isMultiVersion() || !Cand2.Function ||
       !Cand2.Function->isMultiVersion())
@@ -9275,6 +9277,18 @@
   if (Cand1.Function->isInvalidDecl()) return false;
   if (Cand2.Function->isInvalidDecl()) return true;
 
+  // If either candidate carries an OpenMP declare variant attribute, those
+  // attributes alone decide the order: Cand1 is better iff
+  // getBetterOpenMPContextMatch returns Cand1's attribute (a null result
+  // counts for Cand1 when Cand1 has no attribute).
+  auto *OMPVariantAttr1 = Cand1.Function->getAttr<OMPDeclareVariantAttr>();
+  auto *OMPVariantAttr2 = Cand2.Function->getAttr<OMPDeclareVariantAttr>();
+  if (OMPVariantAttr1 || OMPVariantAttr2) {
+    auto *OMPVariantAttrBest =
+        getBetterOpenMPContextMatch(ASTCtx, OMPVariantAttr1, OMPVariantAttr2);
+    return OMPVariantAttrBest == OMPVariantAttr1;
+  }
+
   // If this is a cpu_dispatch/cpu_specific multiversion situation, prefer
   // cpu_dispatch, else arbitrarily based on the identifiers.
   bool Cand1CPUDisp = Cand1.Function->hasAttr<CPUDispatchAttr>();
@@ -9543,7 +9557,7 @@
   if (HasPS1 != HasPS2 && HasPS1)
     return true;
 
-  return isBetterMultiversionCandidate(Cand1, Cand2);
+  return isBetterMultiversionCandidate(S.getASTContext(), Cand1, Cand2);
 }
 
 /// Determine whether two declarations are "equivalent" for the purposes of
@@ -9657,6 +9671,20 @@
     }
   }
 
+  // [OpenMP] Similar to the CUDA code above, OpenMP declare variants might not
+  // be eligible at all so we need to filter them out early.
+  if (S.getLangOpts().OpenMP) {
+    // TODO use context information
+    auto IsNonMatchVariant = [&](OverloadCandidate *Cand) {
+      if (!Cand->Viable || !Cand->Function)
+        return false;
+      auto *OMPVariantAttr = Cand->Function->getAttr<OMPDeclareVariantAttr>();
+      return OMPVariantAttr &&
+             !isOpenMPContextMatch(S.getASTContext(), OMPVariantAttr);
+    };
+    llvm::erase_if(Candidates, IsNonMatchVariant);
+  }
+
   // Find the best viable function.
   Best = end();
   for (auto *Cand : Candidates) {
Index: clang/lib/Sema/SemaOpenMP.cpp
===================================================================
--- clang/lib/Sema/SemaOpenMP.cpp
+++ clang/lib/Sema/SemaOpenMP.cpp
@@ -14,6 +14,7 @@
 #include "TreeTransform.h"
 #include "clang/AST/ASTContext.h"
 #include "clang/AST/ASTMutationListener.h"
+#include "clang/AST/Attr.h"
 #include "clang/AST/CXXInheritance.h"
 #include "clang/AST/Decl.h"
 #include "clang/AST/DeclCXX.h"
@@ -5219,8 +5220,8 @@
 
   // The VariantRef must point to function.
   if (!VariantRef) {
-    Diag(SR.getBegin(), diag::err_omp_function_expected) << VariantId;
-    return None;
+    // Diag(SR.getBegin(), diag::err_omp_function_expected) << VariantId;
+    return std::make_pair(FD, VariantRef);
   }
 
   // Do not check templates, wait until instantiation.
@@ -5317,6 +5318,7 @@
     return None;
   }
 
+  // TODO check these for missing VariantRef as well
   enum DoesntSupport {
     VirtFuncs = 1,
     Constructors = 3,
@@ -5377,16 +5379,19 @@
                               PDiag(diag::err_omp_declare_variant_diff)
                                   << FD->getLocation()),
           /*TemplatesSupported=*/true, /*ConstexprSupported=*/false,
-          /*CLinkageMayDiffer=*/true))
+          /*CLinkageMayDiffer=*/true,
+          /*StorageClassMayDiffer=*/true,
+          /*ConstexprSpecMayDiffer=*/true,
+          /*InlineSpecificationMayDiffer=*/true))
     return None;
   return std::make_pair(FD, cast<Expr>(DRE));
 }
 
-void Sema::ActOnOpenMPDeclareVariantDirective(
+bool Sema::ActOnOpenMPDeclareVariantDirective(
     FunctionDecl *FD, Expr *VariantRef, SourceRange SR,
     ArrayRef<OMPCtxSelectorData> Data) {
   if (Data.empty())
-    return;
+    return false;
   SmallVector<Expr *, 4> CtxScores;
   SmallVector<unsigned, 4> CtxSets;
   SmallVector<unsigned, 4> Ctxs;
@@ -5396,7 +5401,7 @@
     OpenMPContextSelectorSetKind CtxSet = D.CtxSet;
     OpenMPContextSelectorKind Ctx = D.Ctx;
     if (CtxSet == OMP_CTX_SET_unknown || Ctx == OMP_CTX_unknown)
-      return;
+      return false;
     Expr *Score = nullptr;
     if (D.Score.isUsable()) {
       Score = D.Score.get();
@@ -5449,8 +5454,16 @@
         CtxSets.begin(), CtxSets.size(), Ctxs.begin(), Ctxs.size(),
         ImplVendors.begin(), ImplVendors.size(), DeviceKinds.begin(),
         DeviceKinds.size(), SR);
-    FD->addAttr(NewAttr);
+    if (FD) {
+      FD->addAttr(NewAttr);
+    } else {
+      assert(!DeclareVariantScopeAttr &&
+             "TODO nested begin/end declare variant");
+      DeclareVariantScopeAttr = NewAttr;
+      return !isOpenMPContextMatch(getASTContext(), DeclareVariantScopeAttr);
+    }
   }
+  return false;
 }
 
 void Sema::markOpenMPDeclareVariantFuncsReferenced(SourceLocation Loc,
@@ -5463,11 +5476,15 @@
          Func->specific_attrs<OMPDeclareVariantAttr>()) {
       // TODO: add checks for active OpenMP context where possible.
       Expr *VariantRef = A->getVariantFuncRef();
-      auto *DRE = cast<DeclRefExpr>(VariantRef->IgnoreParenImpCasts());
-      auto *F = cast<FunctionDecl>(DRE->getDecl());
+      FunctionDecl *F = Func;
+      if (VariantRef) {
+        auto *DRE = cast<DeclRefExpr>(VariantRef->IgnoreParenImpCasts());
+        F = cast<FunctionDecl>(DRE->getDecl());
+      }
       if (!F->isDefined() && F->isTemplateInstantiation())
         InstantiateFunctionDefinition(Loc, F->getFirstDecl());
-      MarkFunctionReferenced(Loc, F, MightBeOdrUse);
+      if (F != Func)
+        MarkFunctionReferenced(Loc, F, MightBeOdrUse);
     }
   }
 }
@@ -17034,3 +17051,20 @@
   return OMPAllocateClause::Create(Context, StartLoc, LParenLoc, Allocator,
                                    ColonLoc, EndLoc, Vars);
 }
+
+/// Copy \p AttrTy from \p TemplateFD to \p FD as inherited, unless FD has one.
+template <typename AttrTy>
+static void copyAttrIfPresent(Sema &S, FunctionDecl *FD,
+                              const FunctionDecl &TemplateFD) {
+  AttrTy *Src = FD->hasAttr<AttrTy>() ? nullptr : TemplateFD.getAttr<AttrTy>();
+  if (!Src)
+    return;
+  AttrTy *Clone = Src->clone(S.Context);
+  Clone->setInherited(true);
+  FD->addAttr(Clone);
+}
+
+void Sema::inheritOpenMPVariantAttrs(FunctionDecl *FD,
+                                     const FunctionTemplateDecl &TD) {
+  copyAttrIfPresent<OMPDeclareVariantAttr>(*this, FD, *TD.getTemplatedDecl());
+}
Index: clang/lib/Sema/SemaDecl.cpp
===================================================================
--- clang/lib/Sema/SemaDecl.cpp
+++ clang/lib/Sema/SemaDecl.cpp
@@ -2349,8 +2349,7 @@
     if (!isa<TypedefNameDecl>(Old))
       return;
 
-    Diag(New->getLocation(), diag::err_redefinition)
-      << New->getDeclName();
+    Diag(New->getLocation(), diag::err_redefinition) << New->getDeclName();
     notePreviousDefinition(Old, New->getLocation());
     return New->setInvalidDecl();
   }
@@ -8654,6 +8653,14 @@
                                               isVirtualOkay);
   if (!NewFD) return nullptr;
 
+  if (getLangOpts().OpenMP && DeclareVariantScopeAttr) {
+    OMPDeclareVariantAttr *DeclVarAttr =
+        DeclareVariantScopeAttr->clone(getASTContext());
+    DeclVarAttr->setInherited(false);
+    NewFD->addAttr(DeclVarAttr);
+    NewFD->setIsMultiVersion();
+  }
+
   if (OriginalLexicalContext && OriginalLexicalContext->isObjCContainer())
     NewFD->setTopLevelDeclInObjCContainer();
 
@@ -9784,6 +9791,10 @@
       if (MVType != MultiVersionKind::Target)
         return true;
       break;
+    case attr::OMPDeclareVariant:
+      if (MVType != MultiVersionKind::OMPVariant)
+        return true;
+      break;
     default:
       return true;
     }
@@ -9797,7 +9808,8 @@
     const PartialDiagnosticAt &NoteCausedDiagIDAt,
     const PartialDiagnosticAt &NoSupportDiagIDAt,
     const PartialDiagnosticAt &DiffDiagIDAt, bool TemplatesSupported,
-    bool ConstexprSupported, bool CLinkageMayDiffer) {
+    bool ConstexprSupported, bool CLinkageMayDiffer, bool StorageClassMayDiffer,
+    bool ConstexprSpecMayDiffer, bool InlineSpecificationMayDiffer) {
   enum DoesntSupport {
     FuncTemplates = 0,
     VirtFuncs = 1,
@@ -9860,7 +9872,7 @@
 
   QualType NewQType = Context.getCanonicalType(NewFD->getType());
   const auto *NewType = cast<FunctionType>(NewQType);
-  QualType NewReturnType = NewType->getReturnType();
+  QualType NewReturnType = NewType->getReturnType().getUnqualifiedType();
 
   if (NewReturnType->isUndeducedType())
     return Diag(NoSupportDiagIDAt.first, NoSupportDiagIDAt.second)
@@ -9876,18 +9888,21 @@
     if (OldTypeInfo.getCC() != NewTypeInfo.getCC())
       return Diag(DiffDiagIDAt.first, DiffDiagIDAt.second) << CallingConv;
 
-    QualType OldReturnType = OldType->getReturnType();
+    QualType OldReturnType = OldType->getReturnType().getUnqualifiedType();
 
     if (OldReturnType != NewReturnType)
       return Diag(DiffDiagIDAt.first, DiffDiagIDAt.second) << ReturnType;
 
-    if (OldFD->getConstexprKind() != NewFD->getConstexprKind())
+    if (!ConstexprSpecMayDiffer &&
+        OldFD->getConstexprKind() != NewFD->getConstexprKind())
       return Diag(DiffDiagIDAt.first, DiffDiagIDAt.second) << ConstexprSpec;
 
-    if (OldFD->isInlineSpecified() != NewFD->isInlineSpecified())
+    if (!InlineSpecificationMayDiffer &&
+        OldFD->isInlineSpecified() != NewFD->isInlineSpecified())
       return Diag(DiffDiagIDAt.first, DiffDiagIDAt.second) << InlineSpec;
 
-    if (OldFD->getStorageClass() != NewFD->getStorageClass())
+    if (!StorageClassMayDiffer &&
+        OldFD->getStorageClass() != NewFD->getStorageClass())
       return Diag(DiffDiagIDAt.first, DiffDiagIDAt.second) << StorageClass;
 
     if (!CLinkageMayDiffer && OldFD->isExternC() != NewFD->isExternC())
@@ -9905,7 +9920,9 @@
                                              const FunctionDecl *NewFD,
                                              bool CausesMV,
                                              MultiVersionKind MVType) {
-  if (!S.getASTContext().getTargetInfo().supportsMultiVersioning()) {
+  bool IsOpenMPVariant = MVType == MultiVersionKind::OMPVariant;
+  if (!IsOpenMPVariant &&
+      !S.getASTContext().getTargetInfo().supportsMultiVersioning()) {
     S.Diag(NewFD->getLocation(), diag::err_multiversion_not_supported);
     if (OldFD)
       S.Diag(OldFD->getLocation(), diag::note_previous_declaration);
@@ -9918,19 +9935,20 @@
 
   // For now, disallow all other attributes.  These should be opt-in, but
   // an analysis of all of them is a future FIXME.
-  if (CausesMV && OldFD && HasNonMultiVersionAttributes(OldFD, MVType)) {
+  if (CausesMV && OldFD && !IsOpenMPVariant &&
+      HasNonMultiVersionAttributes(OldFD, MVType)) {
     S.Diag(OldFD->getLocation(), diag::err_multiversion_no_other_attrs)
         << IsCPUSpecificCPUDispatchMVType;
     S.Diag(NewFD->getLocation(), diag::note_multiversioning_caused_here);
     return true;
   }
 
-  if (HasNonMultiVersionAttributes(NewFD, MVType))
+  if (!IsOpenMPVariant && HasNonMultiVersionAttributes(NewFD, MVType))
     return S.Diag(NewFD->getLocation(), diag::err_multiversion_no_other_attrs)
            << IsCPUSpecificCPUDispatchMVType;
 
   // Only allow transition to MultiVersion if it hasn't been used.
-  if (OldFD && CausesMV && OldFD->isUsed(false))
+  if (OldFD && CausesMV && !IsOpenMPVariant && OldFD->isUsed(false))
     return S.Diag(NewFD->getLocation(), diag::err_multiversion_after_used);
 
   return S.areMultiversionVariantFunctionsCompatible(
@@ -9942,9 +9960,12 @@
                               << IsCPUSpecificCPUDispatchMVType),
       PartialDiagnosticAt(NewFD->getLocation(),
                           S.PDiag(diag::err_multiversion_diff)),
-      /*TemplatesSupported=*/false,
+      /*TemplatesSupported=*/IsOpenMPVariant,
       /*ConstexprSupported=*/!IsCPUSpecificCPUDispatchMVType,
-      /*CLinkageMayDiffer=*/false);
+      /*CLinkageMayDiffer=*/IsOpenMPVariant,
+      /*StorageClassMayDiffer=*/IsOpenMPVariant,
+      /*ConstexprSpecMayDiffer=*/IsOpenMPVariant,
+      /*InlineSpecificationMayDiffer=*/IsOpenMPVariant);
 }
 
 /// Check the validity of a multiversion function declaration that is the
@@ -9955,7 +9976,8 @@
 /// Returns true if there was an error, false otherwise.
 static bool CheckMultiVersionFirstFunction(Sema &S, FunctionDecl *FD,
                                            MultiVersionKind MVType,
-                                           const TargetAttr *TA) {
+                                           const TargetAttr *TA,
+                                           NamedDecl *OldDecl) {
   assert(MVType != MultiVersionKind::None &&
          "Function lacks multiversion attribute");
 
@@ -10075,8 +10097,8 @@
     Sema &S, FunctionDecl *OldFD, FunctionDecl *NewFD,
     MultiVersionKind NewMVType, const TargetAttr *NewTA,
     const CPUDispatchAttr *NewCPUDisp, const CPUSpecificAttr *NewCPUSpec,
-    bool &Redeclaration, NamedDecl *&OldDecl, bool &MergeTypeWithPrevious,
-    LookupResult &Previous) {
+    const OMPDeclareVariantAttr *NewOpenMPVariant, bool &Redeclaration,
+    NamedDecl *&OldDecl, bool &MergeTypeWithPrevious, LookupResult &Previous) {
 
   MultiVersionKind OldMVType = OldFD->getMultiVersionKind();
   // Disallow mixing of multiversioning types.
@@ -10090,6 +10112,18 @@
     return true;
   }
 
+  if (OldMVType == MultiVersionKind::OMPVariant &&
+      NewMVType == MultiVersionKind::None) {
+    assert(!NewOpenMPVariant && "Didn't expect variant attr!");
+    auto *OldOMPVariant = OldFD->getAttr<OMPDeclareVariantAttr>();
+    auto *NewOMPVariant = OldOMPVariant->clone(S.getASTContext());
+    NewOMPVariant->setInherited(true);
+    NewFD->addAttr(NewOMPVariant);
+    NewFD->setIsMultiVersion();
+    NewOpenMPVariant = NewOMPVariant;
+    NewMVType = MultiVersionKind::OMPVariant;
+  }
+
   TargetAttr::ParsedTargetAttr NewParsed;
   if (NewTA) {
     NewParsed = NewTA->parse();
@@ -10125,6 +10159,14 @@
         NewFD->setInvalidDecl();
         return true;
       }
+    } else if (NewMVType == MultiVersionKind::OMPVariant) {
+      auto *CurOMPVariant = CurFD->getAttr<OMPDeclareVariantAttr>();
+      if (!CurOMPVariant) {
+        CurOMPVariant = NewOpenMPVariant->clone(S.getASTContext());
+        CurOMPVariant->setInherited(true);
+        CurFD->addAttr(CurOMPVariant);
+        CurFD->setIsMultiVersion();
+      }
     } else {
       const auto *CurCPUSpec = CurFD->getAttr<CPUSpecificAttr>();
       const auto *CurCPUDisp = CurFD->getAttr<CPUDispatchAttr>();
@@ -10217,7 +10259,6 @@
   return false;
 }
 
-
 /// Check the validity of a mulitversion function declaration.
 /// Also sets the multiversion'ness' of the function itself.
 ///
@@ -10231,10 +10272,12 @@
   const auto *NewTA = NewFD->getAttr<TargetAttr>();
   const auto *NewCPUDisp = NewFD->getAttr<CPUDispatchAttr>();
   const auto *NewCPUSpec = NewFD->getAttr<CPUSpecificAttr>();
+  const auto *NewOpenMPVariant = NewFD->getAttr<OMPDeclareVariantAttr>();
+  unsigned NumMV = bool(NewTA) + bool(NewCPUDisp) + bool(NewCPUSpec) +
+                   bool(NewOpenMPVariant);
 
   // Mixing Multiversioning types is prohibited.
-  if ((NewTA && NewCPUDisp) || (NewTA && NewCPUSpec) ||
-      (NewCPUDisp && NewCPUSpec)) {
+  if (NumMV > 1) {
     S.Diag(NewFD->getLocation(), diag::err_multiversion_types_mixed);
     NewFD->setInvalidDecl();
     return true;
@@ -10255,14 +10298,18 @@
     return false;
   }
 
+  if (auto *USD = dyn_cast_or_null<UsingShadowDecl>(OldDecl))
+    OldDecl = USD->getTargetDecl();
+
   if (!OldDecl || !OldDecl->getAsFunction() ||
-      OldDecl->getDeclContext()->getRedeclContext() !=
-          NewFD->getDeclContext()->getRedeclContext()) {
+      (OldDecl->getDeclContext()->getRedeclContext() !=
+           NewFD->getDeclContext()->getRedeclContext() &&
+       !OldDecl->getAsFunction()->isOpenMPMultiVersion())) {
     // If there's no previous declaration, AND this isn't attempting to cause
     // multiversioning, this isn't an error condition.
     if (MVType == MultiVersionKind::None)
       return false;
-    return CheckMultiVersionFirstFunction(S, NewFD, MVType, NewTA);
+    return CheckMultiVersionFirstFunction(S, NewFD, MVType, NewTA, OldDecl);
   }
 
   FunctionDecl *OldFD = OldDecl->getAsFunction();
@@ -10270,7 +10317,8 @@
   if (!OldFD->isMultiVersion() && MVType == MultiVersionKind::None)
     return false;
 
-  if (OldFD->isMultiVersion() && MVType == MultiVersionKind::None) {
+  if (OldFD->isMultiVersion() && MVType == MultiVersionKind::None &&
+      !OldFD->isOpenMPMultiVersion()) {
     S.Diag(NewFD->getLocation(), diag::err_multiversion_required_in_redecl)
         << (OldFD->getMultiVersionKind() != MultiVersionKind::Target);
     NewFD->setInvalidDecl();
@@ -10287,8 +10335,8 @@
   // appropriate attribute in the current function decl.  Resolve that these are
   // still compatible with previous declarations.
   return CheckMultiVersionAdditionalDecl(
-      S, OldFD, NewFD, MVType, NewTA, NewCPUDisp, NewCPUSpec, Redeclaration,
-      OldDecl, MergeTypeWithPrevious, Previous);
+      S, OldFD, NewFD, MVType, NewTA, NewCPUDisp, NewCPUSpec, NewOpenMPVariant,
+      Redeclaration, OldDecl, MergeTypeWithPrevious, Previous);
 }
 
 /// Perform semantic checking of a new function declaration.
@@ -10317,8 +10365,8 @@
   // Determine whether the type of this function should be merged with
   // a previous visible declaration. This never happens for functions in C++,
   // and always happens in C if the previous declaration was visible.
-  bool MergeTypeWithPrevious = !getLangOpts().CPlusPlus &&
-                               !Previous.isShadowed();
+  bool MergeTypeWithPrevious =
+      !getLangOpts().CPlusPlus && !Previous.isShadowed();
 
   bool Redeclaration = false;
   NamedDecl *OldDecl = nullptr;
@@ -13582,6 +13630,16 @@
   else
     FD = cast<FunctionDecl>(D);
 
+  if (getLangOpts().OpenMP && DeclareVariantScopeAttr) {
+    OMPDeclareVariantAttr *DeclVarAttr = FD->getAttr<OMPDeclareVariantAttr>();
+    if (!DeclVarAttr) {
+      DeclVarAttr = DeclareVariantScopeAttr->clone(getASTContext());
+      FD->addAttr(DeclVarAttr);
+    }
+    DeclVarAttr->setInherited(false);
+    FD->setIsMultiVersion();
+  }
+
   // Do not push if it is a lambda because one is already pushed when building
   // the lambda in ActOnStartOfLambdaDefinition().
   if (!isLambdaCallOperator(FD))
Index: clang/lib/Parse/ParseOpenMP.cpp
===================================================================
--- clang/lib/Parse/ParseOpenMP.cpp
+++ clang/lib/Parse/ParseOpenMP.cpp
@@ -44,6 +44,8 @@
   OMPD_target_teams_distribute_parallel,
   OMPD_mapper,
   OMPD_variant,
+  OMPD_begin,
+  OMPD_begin_declare,
 };
 
 class DeclDirectiveListParserHelper final {
@@ -83,6 +85,7 @@
       .Case("update", OMPD_update)
       .Case("mapper", OMPD_mapper)
       .Case("variant", OMPD_variant)
+      .Case("begin", OMPD_begin)
       .Default(OMPD_unknown);
 }
 
@@ -91,18 +94,21 @@
   // E.g.: OMPD_for OMPD_simd ===> OMPD_for_simd
   // TODO: add other combined directives in topological order.
   static const unsigned F[][3] = {
+      {OMPD_begin, OMPD_declare, OMPD_begin_declare},
+      {OMPD_end, OMPD_declare, OMPD_end_declare},
       {OMPD_cancellation, OMPD_point, OMPD_cancellation_point},
       {OMPD_declare, OMPD_reduction, OMPD_declare_reduction},
       {OMPD_declare, OMPD_mapper, OMPD_declare_mapper},
       {OMPD_declare, OMPD_simd, OMPD_declare_simd},
       {OMPD_declare, OMPD_target, OMPD_declare_target},
       {OMPD_declare, OMPD_variant, OMPD_declare_variant},
+      {OMPD_begin_declare, OMPD_variant, OMPD_begin_declare_variant},
+      {OMPD_end_declare, OMPD_variant, OMPD_end_declare_variant},
       {OMPD_distribute, OMPD_parallel, OMPD_distribute_parallel},
       {OMPD_distribute_parallel, OMPD_for, OMPD_distribute_parallel_for},
       {OMPD_distribute_parallel_for, OMPD_simd,
        OMPD_distribute_parallel_for_simd},
       {OMPD_distribute, OMPD_simd, OMPD_distribute_simd},
-      {OMPD_end, OMPD_declare, OMPD_end_declare},
       {OMPD_end_declare, OMPD_target, OMPD_end_declare_target},
       {OMPD_target, OMPD_data, OMPD_target_data},
       {OMPD_target, OMPD_enter, OMPD_target_enter},
@@ -1046,37 +1052,8 @@
   return false;
 }
 
-/// Parse clauses for '#pragma omp declare variant ( variant-func-id ) clause'.
-void Parser::ParseOMPDeclareVariantClauses(Parser::DeclGroupPtrTy Ptr,
-                                           CachedTokens &Toks,
-                                           SourceLocation Loc) {
-  PP.EnterToken(Tok, /*IsReinject*/ true);
-  PP.EnterTokenStream(Toks, /*DisableMacroExpansion=*/true,
-                      /*IsReinject*/ true);
-  // Consume the previously pushed token.
-  ConsumeAnyToken(/*ConsumeCodeCompletionTok=*/true);
-  ConsumeAnyToken(/*ConsumeCodeCompletionTok=*/true);
-
-  FNContextRAII FnContext(*this, Ptr);
-  // Parse function declaration id.
-  SourceLocation RLoc;
-  // Parse with IsAddressOfOperand set to true to parse methods as DeclRefExprs
-  // instead of MemberExprs.
-  ExprResult AssociatedFunction =
-      ParseOpenMPParensExpr(getOpenMPDirectiveName(OMPD_declare_variant), RLoc,
-                            /*IsAddressOfOperand=*/true);
-  if (!AssociatedFunction.isUsable()) {
-    if (!Tok.is(tok::annot_pragma_openmp_end))
-      while (!SkipUntil(tok::annot_pragma_openmp_end, StopBeforeMatch))
-        ;
-    // Skip the last annot_pragma_openmp_end.
-    (void)ConsumeAnnotationToken();
-    return;
-  }
-  Optional<std::pair<FunctionDecl *, Expr *>> DeclVarData =
-      Actions.checkOpenMPDeclareVariantFunction(
-          Ptr, AssociatedFunction.get(), SourceRange(Loc, Tok.getLocation()));
-
+void Parser::ParseOMPDeclareVariantMatchClause(
+    SourceLocation Loc, SmallVectorImpl<Sema::OMPCtxSelectorData> &Data) {
   // Parse 'match'.
   OpenMPClauseKind CKind = Tok.isAnnotation()
                                ? OMPC_unknown
@@ -1103,7 +1080,6 @@
   }
 
   // Parse inner context selectors.
-  SmallVector<Sema::OMPCtxSelectorData, 4> Data;
   if (!parseOpenMPContextSelectors(Loc, Data)) {
     // Parse ')'.
     (void)T.consumeClose();
@@ -1113,6 +1089,41 @@
           << getOpenMPDirectiveName(OMPD_declare_variant);
     }
   }
+}
+
+/// Parse clauses for '#pragma omp declare variant ( variant-func-id ) clause'.
+void Parser::ParseOMPDeclareVariantClauses(Parser::DeclGroupPtrTy Ptr,
+                                           CachedTokens &Toks,
+                                           SourceLocation Loc) {
+  PP.EnterToken(Tok, /*IsReinject*/ true);
+  PP.EnterTokenStream(Toks, /*DisableMacroExpansion=*/true,
+                      /*IsReinject*/ true);
+  // Consume the previously pushed token.
+  ConsumeAnyToken(/*ConsumeCodeCompletionTok=*/true);
+  ConsumeAnyToken(/*ConsumeCodeCompletionTok=*/true);
+
+  FNContextRAII FnContext(*this, Ptr);
+  // Parse function declaration id.
+  SourceLocation RLoc;
+  // Parse with IsAddressOfOperand set to true to parse methods as DeclRefExprs
+  // instead of MemberExprs.
+  ExprResult AssociatedFunction =
+      ParseOpenMPParensExpr(getOpenMPDirectiveName(OMPD_declare_variant), RLoc,
+                            /*IsAddressOfOperand=*/true);
+  if (!AssociatedFunction.isUsable()) {
+    if (!Tok.is(tok::annot_pragma_openmp_end))
+      while (!SkipUntil(tok::annot_pragma_openmp_end, StopBeforeMatch))
+        ;
+    // Skip the last annot_pragma_openmp_end.
+    (void)ConsumeAnnotationToken();
+    return;
+  }
+  Optional<std::pair<FunctionDecl *, Expr *>> DeclVarData =
+      Actions.checkOpenMPDeclareVariantFunction(
+          Ptr, AssociatedFunction.get(), SourceRange(Loc, Tok.getLocation()));
+
+  SmallVector<Sema::OMPCtxSelectorData, 4> Data;
+  ParseOMPDeclareVariantMatchClause(Loc, Data);
 
   // Skip last tokens.
   while (Tok.isNot(tok::annot_pragma_openmp_end))
@@ -1446,6 +1457,46 @@
     }
     break;
   }
+  case OMPD_begin_declare_variant: {
+    // The syntax is:
+    // { #pragma omp begin declare variant clause }
+    // <function-declaration-or-definition-sequence>
+    // { #pragma omp end declare variant }
+    //
+    ConsumeToken();
+
+    SmallVector<Sema::OMPCtxSelectorData, 4> Data;
+    ParseOMPDeclareVariantMatchClause(Loc, Data);
+
+    // Skip last tokens.
+    while (Tok.isNot(tok::annot_pragma_openmp_end))
+      ConsumeAnyToken();
+
+    bool Elide = Actions.ActOnOpenMPDeclareVariantDirective(
+        nullptr, nullptr, SourceRange(Loc, Tok.getLocation()), Data);
+    if (!Elide)
+      break;
+
+    // Elide all the code till the matching end declare variant was found.
+    unsigned Nesting = 1;
+    do {
+      ConsumeAnyToken();
+      OpenMPDirectiveKind DK = parseOpenMPDirectiveKind(*this);
+      if (DK == OMPD_end_declare_variant)
+        --Nesting;
+      if (DK == OMPD_begin_declare_variant)
+        ++Nesting;
+    } while (Nesting);
+
+    LLVM_FALLTHROUGH;
+  }
+  case OMPD_end_declare_variant:
+    assert(getActions().DeclareVariantScopeAttr &&
+           "TODO error for unmatched end declare variant");
+    // TODO: verify DeclareVariantScopeAttr is null after parsing
+    // TODO: Make this a call in the SEMA
+    getActions().DeclareVariantScopeAttr = nullptr;
+    break;
   case OMPD_declare_variant:
   case OMPD_declare_simd: {
     // The syntax is:
@@ -1932,6 +1983,8 @@
   case OMPD_end_declare_target:
   case OMPD_requires:
   case OMPD_declare_variant:
+  case OMPD_begin_declare_variant:
+  case OMPD_end_declare_variant:
     Diag(Tok, diag::err_omp_unexpected_directive)
         << 1 << getOpenMPDirectiveName(DKind);
     SkipUntil(tok::annot_pragma_openmp_end);
Index: clang/lib/Headers/openmp_wrappers/math.h
===================================================================
--- clang/lib/Headers/openmp_wrappers/math.h
+++ clang/lib/Headers/openmp_wrappers/math.h
@@ -9,9 +9,4 @@
 
 #include <__clang_openmp_math.h>
 
-#ifndef __CLANG_NO_HOST_MATH__
 #include_next <math.h>
-#else
-#undef __CLANG_NO_HOST_MATH__
-#endif
-
Index: clang/lib/Headers/openmp_wrappers/cmath
===================================================================
--- clang/lib/Headers/openmp_wrappers/cmath
+++ clang/lib/Headers/openmp_wrappers/cmath
@@ -9,8 +9,4 @@
 
 #include <__clang_openmp_math.h>
 
-#ifndef __CLANG_NO_HOST_MATH__
 #include_next <cmath>
-#else
-#undef __CLANG_NO_HOST_MATH__
-#endif
Index: clang/lib/Headers/openmp_wrappers/__clang_openmp_math_declares.h
===================================================================
--- clang/lib/Headers/openmp_wrappers/__clang_openmp_math_declares.h
+++ clang/lib/Headers/openmp_wrappers/__clang_openmp_math_declares.h
@@ -18,6 +18,8 @@
 
 #define __CUDA__
 
+#pragma omp begin declare variant match(device = {kind(gpu)})
+
 #if defined(__cplusplus)
   #include <__clang_cuda_math_forward_declares.h>
 #endif
@@ -27,6 +29,8 @@
 /// Provide definitions for these functions.
 #include <__clang_cuda_device_functions.h>
 
+#pragma omp end declare variant
+
 #undef __CUDA__
 
 #endif
Index: clang/lib/Headers/openmp_wrappers/__clang_openmp_math.h
===================================================================
--- clang/lib/Headers/openmp_wrappers/__clang_openmp_math.h
+++ clang/lib/Headers/openmp_wrappers/__clang_openmp_math.h
@@ -8,17 +8,6 @@
  */
 
 #if defined(__NVPTX__) && defined(_OPENMP)
-/// TODO:
-/// We are currently reusing the functionality of the Clang-CUDA code path
-/// as an alternative to the host declarations provided by math.h and cmath.
-/// This is suboptimal.
-///
-/// We should instead declare the device functions in a similar way, e.g.,
-/// through OpenMP 5.0 variants, and afterwards populate the module with the
-/// host declarations by unconditionally including the host math.h or cmath,
-/// respectively. This is actually what the Clang-CUDA code path does, using
-/// __device__ instead of variants to avoid redeclarations and get the desired
-/// overload resolution.
 
 #define __CUDA__
 
@@ -28,8 +17,5 @@
 
 #undef __CUDA__
 
-/// Magic macro for stopping the math.h/cmath host header from being included.
-#define __CLANG_NO_HOST_MATH__
-
 #endif
 
Index: clang/lib/Headers/__clang_cuda_math_forward_declares.h
===================================================================
--- clang/lib/Headers/__clang_cuda_math_forward_declares.h
+++ clang/lib/Headers/__clang_cuda_math_forward_declares.h
@@ -27,30 +27,11 @@
   static __inline__ __attribute__((always_inline)) __attribute__((device))
 #endif
 
-// For C++ 17 we need to include noexcept attribute to be compatible
-// with the header-defined version. This may be removed once
-// variant is supported.
-#if defined(_OPENMP) && defined(__cplusplus) && __cplusplus >= 201703L
-#define __NOEXCEPT noexcept
-#else
-#define __NOEXCEPT
-#endif
-
-#if !(defined(_OPENMP) && defined(__cplusplus))
 __DEVICE__ long abs(long);
 __DEVICE__ long long abs(long long);
 __DEVICE__ double abs(double);
 __DEVICE__ float abs(float);
-#endif
-// While providing the CUDA declarations and definitions for math functions,
-// we may manually define additional functions.
-// TODO: Once variant is supported the additional functions will have
-// to be removed.
-#if defined(_OPENMP) && defined(__cplusplus)
-__DEVICE__ const double abs(const double);
-__DEVICE__ const float abs(const float);
-#endif
-__DEVICE__ int abs(int) __NOEXCEPT;
+__DEVICE__ int abs(int);
 __DEVICE__ double acos(double);
 __DEVICE__ float acos(float);
 __DEVICE__ double acosh(double);
@@ -85,8 +63,8 @@
 __DEVICE__ float exp(float);
 __DEVICE__ double expm1(double);
 __DEVICE__ float expm1(float);
-__DEVICE__ double fabs(double) __NOEXCEPT;
-__DEVICE__ float fabs(float) __NOEXCEPT;
+__DEVICE__ double fabs(double);
+__DEVICE__ float fabs(float);
 __DEVICE__ double fdim(double, double);
 __DEVICE__ float fdim(float, float);
 __DEVICE__ double floor(double);
@@ -99,8 +77,6 @@
 __DEVICE__ float fmin(float, float);
 __DEVICE__ double fmod(double, double);
 __DEVICE__ float fmod(float, float);
-__DEVICE__ int fpclassify(double);
-__DEVICE__ int fpclassify(float);
 __DEVICE__ double frexp(double, int *);
 __DEVICE__ float frexp(float, int *);
 __DEVICE__ double hypot(double, double);
@@ -136,12 +112,12 @@
 __DEVICE__ bool isnormal(float);
 __DEVICE__ bool isunordered(double, double);
 __DEVICE__ bool isunordered(float, float);
-__DEVICE__ long labs(long) __NOEXCEPT;
+__DEVICE__ long labs(long);
 __DEVICE__ double ldexp(double, int);
 __DEVICE__ float ldexp(float, int);
 __DEVICE__ double lgamma(double);
 __DEVICE__ float lgamma(float);
-__DEVICE__ long long llabs(long long) __NOEXCEPT;
+__DEVICE__ long long llabs(long long);
 __DEVICE__ long long llrint(double);
 __DEVICE__ long long llrint(float);
 __DEVICE__ double log10(double);
@@ -152,9 +128,7 @@
 __DEVICE__ float log2(float);
 __DEVICE__ double logb(double);
 __DEVICE__ float logb(float);
-#if defined(_OPENMP) && defined(__cplusplus)
 __DEVICE__ long double log(long double);
-#endif
 __DEVICE__ double log(double);
 __DEVICE__ float log(float);
 __DEVICE__ long lrint(double);
@@ -245,7 +219,6 @@
 using ::fmax;
 using ::fmin;
 using ::fmod;
-using ::fpclassify;
 using ::frexp;
 using ::hypot;
 using ::ilogb;
@@ -302,7 +275,6 @@
 } // namespace std
 #endif
 
-#undef __NOEXCEPT
 #pragma pop_macro("__DEVICE__")
 
 #endif
Index: clang/lib/Headers/__clang_cuda_device_functions.h
===================================================================
--- clang/lib/Headers/__clang_cuda_device_functions.h
+++ clang/lib/Headers/__clang_cuda_device_functions.h
@@ -37,15 +37,6 @@
 #define __FAST_OR_SLOW(fast, slow) slow
 #endif
 
-// For C++ 17 we need to include noexcept attribute to be compatible
-// with the header-defined version. This may be removed once
-// variant is supported.
-#if defined(_OPENMP) && defined(__cplusplus) && __cplusplus >= 201703L
-#define __NOEXCEPT noexcept
-#else
-#define __NOEXCEPT
-#endif
-
 __DEVICE__ int __all(int __a) { return __nvvm_vote_all(__a); }
 __DEVICE__ int __any(int __a) { return __nvvm_vote_any(__a); }
 __DEVICE__ unsigned int __ballot(int __a) { return __nvvm_vote_ballot(__a); }
@@ -53,13 +44,8 @@
 __DEVICE__ unsigned long long __brevll(unsigned long long __a) {
   return __nv_brevll(__a);
 }
-#if defined(__cplusplus)
 __DEVICE__ void __brkpt() { asm volatile("brkpt;"); }
 __DEVICE__ void __brkpt(int __a) { __brkpt(); }
-#else
-__DEVICE__ void __attribute__((overloadable)) __brkpt(void) { asm volatile("brkpt;"); }
-__DEVICE__ void __attribute__((overloadable)) __brkpt(int __a) { __brkpt(); }
-#endif
 __DEVICE__ unsigned int __byte_perm(unsigned int __a, unsigned int __b,
                                     unsigned int __c) {
   return __nv_byte_perm(__a, __b, __c);
@@ -1483,8 +1469,8 @@
   return r;
 }
 #endif // CUDA_VERSION >= 9020
-__DEVICE__ int abs(int __a) __NOEXCEPT { return __nv_abs(__a); }
-__DEVICE__ double fabs(double __a) __NOEXCEPT { return __nv_fabs(__a); }
+__DEVICE__ int abs(int __a) { return __nv_abs(__a); }
+__DEVICE__ double fabs(double __a) { return __nv_fabs(__a); }
 __DEVICE__ double acos(double __a) { return __nv_acos(__a); }
 __DEVICE__ float acosf(float __a) { return __nv_acosf(__a); }
 __DEVICE__ double acosh(double __a) { return __nv_acosh(__a); }
@@ -1581,15 +1567,15 @@
 __DEVICE__ double jn(int __n, double __a) { return __nv_jn(__n, __a); }
 __DEVICE__ float jnf(int __n, float __a) { return __nv_jnf(__n, __a); }
 #if defined(__LP64__) || defined(_WIN64)
-__DEVICE__ long labs(long __a) __NOEXCEPT { return __nv_llabs(__a); };
+__DEVICE__ long labs(long __a) { return __nv_llabs(__a); };
 #else
-__DEVICE__ long labs(long __a) __NOEXCEPT { return __nv_abs(__a); };
+__DEVICE__ long labs(long __a) { return __nv_abs(__a); };
 #endif
 __DEVICE__ double ldexp(double __a, int __b) { return __nv_ldexp(__a, __b); }
 __DEVICE__ float ldexpf(float __a, int __b) { return __nv_ldexpf(__a, __b); }
 __DEVICE__ double lgamma(double __a) { return __nv_lgamma(__a); }
 __DEVICE__ float lgammaf(float __a) { return __nv_lgammaf(__a); }
-__DEVICE__ long long llabs(long long __a) __NOEXCEPT { return __nv_llabs(__a); }
+__DEVICE__ long long llabs(long long __a) { return __nv_llabs(__a); }
 __DEVICE__ long long llmax(long long __a, long long __b) {
   return __nv_llmax(__a, __b);
 }
@@ -1719,23 +1705,6 @@
 __DEVICE__ float rsqrtf(float __a) { return __nv_rsqrtf(__a); }
 __DEVICE__ double scalbn(double __a, int __b) { return __nv_scalbn(__a, __b); }
 __DEVICE__ float scalbnf(float __a, int __b) { return __nv_scalbnf(__a, __b); }
-// TODO: remove once variant is supported
-#ifndef _OPENMP
-__DEVICE__ double scalbln(double __a, long __b) {
-  if (__b > INT_MAX)
-    return __a > 0 ? HUGE_VAL : -HUGE_VAL;
-  if (__b < INT_MIN)
-    return __a > 0 ? 0.0 : -0.0;
-  return scalbn(__a, (int)__b);
-}
-__DEVICE__ float scalblnf(float __a, long __b) {
-  if (__b > INT_MAX)
-    return __a > 0 ? HUGE_VALF : -HUGE_VALF;
-  if (__b < INT_MIN)
-    return __a > 0 ? 0.f : -0.f;
-  return scalbnf(__a, (int)__b);
-}
-#endif
 __DEVICE__ double sin(double __a) { return __nv_sin(__a); }
 __DEVICE__ void sincos(double __a, double *__s, double *__c) {
   return __nv_sincos(__a, __s, __c);
@@ -1787,7 +1756,7 @@
 __DEVICE__ double yn(int __a, double __b) { return __nv_yn(__a, __b); }
 __DEVICE__ float ynf(int __a, float __b) { return __nv_ynf(__a, __b); }
 
-#undef __NOEXCEPT
 #pragma pop_macro("__DEVICE__")
 #pragma pop_macro("__FAST_OR_SLOW")
+
 #endif // __CLANG_CUDA_DEVICE_FUNCTIONS_H__
Index: clang/lib/Headers/__clang_cuda_cmath.h
===================================================================
--- clang/lib/Headers/__clang_cuda_cmath.h
+++ clang/lib/Headers/__clang_cuda_cmath.h
@@ -32,30 +32,15 @@
 
 #ifdef _OPENMP
 #define __DEVICE__ static __attribute__((always_inline))
+#pragma omp begin declare variant match(device = {kind(gpu)})
 #else
 #define __DEVICE__ static __device__ __inline__ __attribute__((always_inline))
 #endif
 
-// For C++ 17 we need to include noexcept attribute to be compatible
-// with the header-defined version. This may be removed once
-// variant is supported.
-#if defined(_OPENMP) && defined(__cplusplus) && __cplusplus >= 201703L
-#define __NOEXCEPT noexcept
-#else
-#define __NOEXCEPT
-#endif
-
-#if !(defined(_OPENMP) && defined(__cplusplus))
 __DEVICE__ long long abs(long long __n) { return ::llabs(__n); }
 __DEVICE__ long abs(long __n) { return ::labs(__n); }
 __DEVICE__ float abs(float __x) { return ::fabsf(__x); }
 __DEVICE__ double abs(double __x) { return ::fabs(__x); }
-#endif
-// TODO: remove once variat is supported.
-#if defined(_OPENMP) && defined(__cplusplus)
-__DEVICE__ const float abs(const float __x) { return ::fabsf((float)__x); }
-__DEVICE__ const double abs(const double __x) { return ::fabs((double)__x); }
-#endif
 __DEVICE__ float acos(float __x) { return ::acosf(__x); }
 __DEVICE__ float asin(float __x) { return ::asinf(__x); }
 __DEVICE__ float atan(float __x) { return ::atanf(__x); }
@@ -64,20 +49,9 @@
 __DEVICE__ float cos(float __x) { return ::cosf(__x); }
 __DEVICE__ float cosh(float __x) { return ::coshf(__x); }
 __DEVICE__ float exp(float __x) { return ::expf(__x); }
-__DEVICE__ float fabs(float __x) __NOEXCEPT { return ::fabsf(__x); }
+__DEVICE__ float fabs(float __x) { return ::fabsf(__x); }
 __DEVICE__ float floor(float __x) { return ::floorf(__x); }
 __DEVICE__ float fmod(float __x, float __y) { return ::fmodf(__x, __y); }
-// TODO: remove when variant is supported
-#ifndef _OPENMP
-__DEVICE__ int fpclassify(float __x) {
-  return __builtin_fpclassify(FP_NAN, FP_INFINITE, FP_NORMAL, FP_SUBNORMAL,
-                              FP_ZERO, __x);
-}
-__DEVICE__ int fpclassify(double __x) {
-  return __builtin_fpclassify(FP_NAN, FP_INFINITE, FP_NORMAL, FP_SUBNORMAL,
-                              FP_ZERO, __x);
-}
-#endif
 __DEVICE__ float frexp(float __arg, int *__exp) {
   return ::frexpf(__arg, __exp);
 }
@@ -232,7 +206,6 @@
 __CUDA_CLANG_FN_INTEGER_OVERLOAD_2(double, fmax);
 __CUDA_CLANG_FN_INTEGER_OVERLOAD_2(double, fmin);
 __CUDA_CLANG_FN_INTEGER_OVERLOAD_2(double, fmod);
-__CUDA_CLANG_FN_INTEGER_OVERLOAD_1(int, fpclassify)
 __CUDA_CLANG_FN_INTEGER_OVERLOAD_2(double, hypot);
 __CUDA_CLANG_FN_INTEGER_OVERLOAD_1(int, ilogb)
 __CUDA_CLANG_FN_INTEGER_OVERLOAD_1(bool, isfinite)
@@ -360,7 +333,6 @@
 using ::fmax;
 using ::fmin;
 using ::fmod;
-using ::fpclassify;
 using ::frexp;
 using ::hypot;
 using ::ilogb;
@@ -457,10 +429,6 @@
 using ::remquof;
 using ::rintf;
 using ::roundf;
-// TODO: remove once variant is supported
-#ifndef _OPENMP
-using ::scalblnf;
-#endif
 using ::scalbnf;
 using ::sinf;
 using ::sinhf;
@@ -479,6 +447,9 @@
 } // namespace std
 #endif
 
+#ifdef _OPENMP
+#pragma omp end declare variant
+#endif
+
-#undef __NOEXCEPT
 #undef __DEVICE__
 
Index: clang/lib/CodeGen/CodeGenModule.cpp
===================================================================
--- clang/lib/CodeGen/CodeGenModule.cpp
+++ clang/lib/CodeGen/CodeGenModule.cpp
@@ -982,6 +982,18 @@
   }
 }
 
+/// Append the OpenMP multi-versioning suffix of \p FD to \p Out.
+/// One ".ompvariant" is emitted per OMPDeclareVariantAttr that is attached
+/// directly to \p FD; inherited attributes do not contribute.
+/// TODO: Mangle the name based on the context selector.
+static void AppendOpenMPVariantMangling(const CodeGenModule &CGM,
+                                        const FunctionDecl *FD,
+                                        raw_ostream &Out) {
+  for (const auto *VariantAttr : FD->specific_attrs<OMPDeclareVariantAttr>())
+    if (!VariantAttr->isInherited())
+      Out << ".ompvariant";
+}
+
 static std::string getMangledNameImpl(const CodeGenModule &CGM, GlobalDecl GD,
                                       const NamedDecl *ND,
                                       bool OmitMultiVersionMangling = false) {
@@ -1022,6 +1034,9 @@
       case MultiVersionKind::Target:
         AppendTargetMangling(CGM, FD->getAttr<TargetAttr>(), Out);
         break;
+      case MultiVersionKind::OMPVariant:
+        AppendOpenMPVariantMangling(CGM, FD, Out);
+        break;
       case MultiVersionKind::None:
         llvm_unreachable("None multiversion type isn't valid here");
       }
@@ -2534,7 +2549,8 @@
     return;
   }
 
-    // Check if this must be emitted as declare variant.
+  // Check if this must be emitted as declare variant.
+  // TODO: See the TODO at the other emitDeclareVariant call.
   if (LangOpts.OpenMP && isa<FunctionDecl>(Global) && OpenMPRuntime &&
       OpenMPRuntime->emitDeclareVariant(GD, /*IsForDefinition=*/false))
     return;
@@ -2856,6 +2872,10 @@
   for (GlobalDecl GD : MultiVersionFuncs) {
     SmallVector<CodeGenFunction::MultiVersionResolverOption, 10> Options;
     const FunctionDecl *FD = cast<FunctionDecl>(GD.getDecl());
+    // OpenMP multi versioning is (for now) resolved at compile time, no
+    // resolver function necessary (yet).
+    if (FD->isOpenMPMultiVersion())
+      continue;
     getContext().forEachMultiversionedFunctionVersion(
         FD, [this, &GD, &Options](const FunctionDecl *CurFD) {
           GlobalDecl CurGD{
@@ -3104,10 +3124,17 @@
     }
     // Check if this must be emitted as declare variant and emit reference to
     // the the declare variant function.
+    // TODO: We should introduce function aliases for `omp declare variant`
+    //       directives such that we can treat them through the same overload
+    //       resolution scheme (via multi-versioning) as `omp begin declare
+    //       variant` functions. For an `omp declare variant(VARIANT) ...`
+    //       that is attached to a BASE function we would create a global alias
+    //       VARIANT = BASE which will participate in the multi-version overload
+    //       resolution. If picked, there is no need to emit them explicitly.
     if (LangOpts.OpenMP && OpenMPRuntime)
       (void)OpenMPRuntime->emitDeclareVariant(GD, /*IsForDefinition=*/true);
 
-    if (FD->isMultiVersion()) {
+    if (FD->isMultiVersion() && !FD->isOpenMPMultiVersion()) {
       const auto *TA = FD->getAttr<TargetAttr>();
       if (TA && TA->isDefaultVersion())
         UpdateMultiVersionNames(GD, FD);
@@ -4385,6 +4412,7 @@
 void CodeGenModule::EmitGlobalFunctionDefinition(GlobalDecl GD,
                                                  llvm::GlobalValue *GV) {
   // Check if this must be emitted as declare variant.
+  // TODO: See the TODO at the other emitDeclareVariant call.
   if (LangOpts.OpenMP && OpenMPRuntime &&
       OpenMPRuntime->emitDeclareVariant(GD, /*IsForDefinition=*/true))
     return;
Index: clang/lib/CodeGen/CGOpenMPRuntime.cpp
===================================================================
--- clang/lib/CodeGen/CGOpenMPRuntime.cpp
+++ clang/lib/CodeGen/CGOpenMPRuntime.cpp
@@ -11042,231 +11042,6 @@
   return Address(Addr, Align);
 }
 
-namespace {
-using OMPContextSelectorData =
-    OpenMPCtxSelectorData<ArrayRef<StringRef>, llvm::APSInt>;
-using CompleteOMPContextSelectorData = SmallVector<OMPContextSelectorData, 4>;
-} // anonymous namespace
-
-/// Checks current context and returns true if it matches the context selector.
-template <OpenMPContextSelectorSetKind CtxSet, OpenMPContextSelectorKind Ctx,
-          typename... Arguments>
-static bool checkContext(const OMPContextSelectorData &Data,
-                         Arguments... Params) {
-  assert(Data.CtxSet != OMP_CTX_SET_unknown && Data.Ctx != OMP_CTX_unknown &&
-         "Unknown context selector or context selector set.");
-  return false;
-}
-
-/// Checks for implementation={vendor(<vendor>)} context selector.
-/// \returns true iff <vendor>="llvm", false otherwise.
-template <>
-bool checkContext<OMP_CTX_SET_implementation, OMP_CTX_vendor>(
-    const OMPContextSelectorData &Data) {
-  return llvm::all_of(Data.Names,
-                      [](StringRef S) { return !S.compare_lower("llvm"); });
-}
-
-/// Checks for device={kind(<kind>)} context selector.
-/// \returns true if <kind>="host" and compilation is for host.
-/// true if <kind>="nohost" and compilation is for device.
-/// true if <kind>="cpu" and compilation is for Arm, X86 or PPC CPU.
-/// true if <kind>="gpu" and compilation is for NVPTX or AMDGCN.
-/// false otherwise.
-template <>
-bool checkContext<OMP_CTX_SET_device, OMP_CTX_kind, CodeGenModule &>(
-    const OMPContextSelectorData &Data, CodeGenModule &CGM) {
-  for (StringRef Name : Data.Names) {
-    if (!Name.compare_lower("host")) {
-      if (CGM.getLangOpts().OpenMPIsDevice)
-        return false;
-      continue;
-    }
-    if (!Name.compare_lower("nohost")) {
-      if (!CGM.getLangOpts().OpenMPIsDevice)
-        return false;
-      continue;
-    }
-    switch (CGM.getTriple().getArch()) {
-    case llvm::Triple::arm:
-    case llvm::Triple::armeb:
-    case llvm::Triple::aarch64:
-    case llvm::Triple::aarch64_be:
-    case llvm::Triple::aarch64_32:
-    case llvm::Triple::ppc:
-    case llvm::Triple::ppc64:
-    case llvm::Triple::ppc64le:
-    case llvm::Triple::x86:
-    case llvm::Triple::x86_64:
-      if (Name.compare_lower("cpu"))
-        return false;
-      break;
-    case llvm::Triple::amdgcn:
-    case llvm::Triple::nvptx:
-    case llvm::Triple::nvptx64:
-      if (Name.compare_lower("gpu"))
-        return false;
-      break;
-    case llvm::Triple::UnknownArch:
-    case llvm::Triple::arc:
-    case llvm::Triple::avr:
-    case llvm::Triple::bpfel:
-    case llvm::Triple::bpfeb:
-    case llvm::Triple::hexagon:
-    case llvm::Triple::mips:
-    case llvm::Triple::mipsel:
-    case llvm::Triple::mips64:
-    case llvm::Triple::mips64el:
-    case llvm::Triple::msp430:
-    case llvm::Triple::r600:
-    case llvm::Triple::riscv32:
-    case llvm::Triple::riscv64:
-    case llvm::Triple::sparc:
-    case llvm::Triple::sparcv9:
-    case llvm::Triple::sparcel:
-    case llvm::Triple::systemz:
-    case llvm::Triple::tce:
-    case llvm::Triple::tcele:
-    case llvm::Triple::thumb:
-    case llvm::Triple::thumbeb:
-    case llvm::Triple::xcore:
-    case llvm::Triple::le32:
-    case llvm::Triple::le64:
-    case llvm::Triple::amdil:
-    case llvm::Triple::amdil64:
-    case llvm::Triple::hsail:
-    case llvm::Triple::hsail64:
-    case llvm::Triple::spir:
-    case llvm::Triple::spir64:
-    case llvm::Triple::kalimba:
-    case llvm::Triple::shave:
-    case llvm::Triple::lanai:
-    case llvm::Triple::wasm32:
-    case llvm::Triple::wasm64:
-    case llvm::Triple::renderscript32:
-    case llvm::Triple::renderscript64:
-      return false;
-    }
-  }
-  return true;
-}
-
-bool matchesContext(CodeGenModule &CGM,
-                    const CompleteOMPContextSelectorData &ContextData) {
-  for (const OMPContextSelectorData &Data : ContextData) {
-    switch (Data.Ctx) {
-    case OMP_CTX_vendor:
-      assert(Data.CtxSet == OMP_CTX_SET_implementation &&
-             "Expected implementation context selector set.");
-      if (!checkContext<OMP_CTX_SET_implementation, OMP_CTX_vendor>(Data))
-        return false;
-      break;
-    case OMP_CTX_kind:
-      assert(Data.CtxSet == OMP_CTX_SET_device &&
-             "Expected device context selector set.");
-      if (!checkContext<OMP_CTX_SET_device, OMP_CTX_kind, CodeGenModule &>(Data,
-                                                                           CGM))
-        return false;
-      break;
-    case OMP_CTX_unknown:
-      llvm_unreachable("Unknown context selector kind.");
-    }
-  }
-  return true;
-}
-
-static CompleteOMPContextSelectorData
-translateAttrToContextSelectorData(ASTContext &C,
-                                   const OMPDeclareVariantAttr *A) {
-  CompleteOMPContextSelectorData Data;
-  for (unsigned I = 0, E = A->scores_size(); I < E; ++I) {
-    Data.emplace_back();
-    auto CtxSet = static_cast<OpenMPContextSelectorSetKind>(
-        *std::next(A->ctxSelectorSets_begin(), I));
-    auto Ctx = static_cast<OpenMPContextSelectorKind>(
-        *std::next(A->ctxSelectors_begin(), I));
-    Data.back().CtxSet = CtxSet;
-    Data.back().Ctx = Ctx;
-    const Expr *Score = *std::next(A->scores_begin(), I);
-    Data.back().Score = Score->EvaluateKnownConstInt(C);
-    switch (Ctx) {
-    case OMP_CTX_vendor:
-      assert(CtxSet == OMP_CTX_SET_implementation &&
-             "Expected implementation context selector set.");
-      Data.back().Names =
-          llvm::makeArrayRef(A->implVendors_begin(), A->implVendors_end());
-      break;
-    case OMP_CTX_kind:
-      assert(CtxSet == OMP_CTX_SET_device &&
-             "Expected device context selector set.");
-      Data.back().Names =
-          llvm::makeArrayRef(A->deviceKinds_begin(), A->deviceKinds_end());
-      break;
-    case OMP_CTX_unknown:
-      llvm_unreachable("Unknown context selector kind.");
-    }
-  }
-  return Data;
-}
-
-static bool isStrictSubset(const CompleteOMPContextSelectorData &LHS,
-                           const CompleteOMPContextSelectorData &RHS) {
-  llvm::SmallDenseMap<std::pair<int, int>, llvm::StringSet<>, 4> RHSData;
-  for (const OMPContextSelectorData &D : RHS) {
-    auto &Pair = RHSData.FindAndConstruct(std::make_pair(D.CtxSet, D.Ctx));
-    Pair.getSecond().insert(D.Names.begin(), D.Names.end());
-  }
-  bool AllSetsAreEqual = true;
-  for (const OMPContextSelectorData &D : LHS) {
-    auto It = RHSData.find(std::make_pair(D.CtxSet, D.Ctx));
-    if (It == RHSData.end())
-      return false;
-    if (D.Names.size() > It->getSecond().size())
-      return false;
-    if (llvm::set_union(It->getSecond(), D.Names))
-      return false;
-    AllSetsAreEqual =
-        AllSetsAreEqual && (D.Names.size() == It->getSecond().size());
-  }
-
-  return LHS.size() != RHS.size() || !AllSetsAreEqual;
-}
-
-static bool greaterCtxScore(const CompleteOMPContextSelectorData &LHS,
-                            const CompleteOMPContextSelectorData &RHS) {
-  // Score is calculated as sum of all scores + 1.
-  llvm::APSInt LHSScore(llvm::APInt(64, 1), /*isUnsigned=*/false);
-  bool RHSIsSubsetOfLHS = isStrictSubset(RHS, LHS);
-  if (RHSIsSubsetOfLHS) {
-    LHSScore = llvm::APSInt::get(0);
-  } else {
-    for (const OMPContextSelectorData &Data : LHS) {
-      if (Data.Score.getBitWidth() > LHSScore.getBitWidth()) {
-        LHSScore = LHSScore.extend(Data.Score.getBitWidth()) + Data.Score;
-      } else if (Data.Score.getBitWidth() < LHSScore.getBitWidth()) {
-        LHSScore += Data.Score.extend(LHSScore.getBitWidth());
-      } else {
-        LHSScore += Data.Score;
-      }
-    }
-  }
-  llvm::APSInt RHSScore(llvm::APInt(64, 1), /*isUnsigned=*/false);
-  if (!RHSIsSubsetOfLHS && isStrictSubset(LHS, RHS)) {
-    RHSScore = llvm::APSInt::get(0);
-  } else {
-    for (const OMPContextSelectorData &Data : RHS) {
-      if (Data.Score.getBitWidth() > RHSScore.getBitWidth()) {
-        RHSScore = RHSScore.extend(Data.Score.getBitWidth()) + Data.Score;
-      } else if (Data.Score.getBitWidth() < RHSScore.getBitWidth()) {
-        RHSScore += Data.Score.extend(RHSScore.getBitWidth());
-      } else {
-        RHSScore += Data.Score;
-      }
-    }
-  }
-  return llvm::APSInt::compareValues(LHSScore, RHSScore) >= 0;
-}
-
 /// Finds the variant function that matches current context with its context
 /// selector.
 static const FunctionDecl *getDeclareVariantFunction(CodeGenModule &CGM,
@@ -11275,21 +11050,12 @@
     return FD;
   // Iterate through all DeclareVariant attributes and check context selectors.
   const OMPDeclareVariantAttr *TopMostAttr = nullptr;
-  CompleteOMPContextSelectorData TopMostData;
-  for (const auto *A : FD->specific_attrs<OMPDeclareVariantAttr>()) {
-    CompleteOMPContextSelectorData Data =
-        translateAttrToContextSelectorData(CGM.getContext(), A);
-    if (!matchesContext(CGM, Data))
-      continue;
-    // If the attribute matches the context, find the attribute with the highest
-    // score.
-    if (!TopMostAttr || !greaterCtxScore(TopMostData, Data)) {
-      TopMostAttr = A;
-      TopMostData.swap(Data);
-    }
-  }
+  for (const auto *A : FD->specific_attrs<OMPDeclareVariantAttr>())
+    TopMostAttr = getBetterOpenMPContextMatch(CGM.getContext(), TopMostAttr, A);
   if (!TopMostAttr)
     return FD;
+  if (!TopMostAttr->getVariantFuncRef())
+    return FD;
   return cast<FunctionDecl>(
       cast<DeclRefExpr>(TopMostAttr->getVariantFuncRef()->IgnoreParenImpCasts())
           ->getDecl());
Index: clang/lib/AST/StmtOpenMP.cpp
===================================================================
--- clang/lib/AST/StmtOpenMP.cpp
+++ clang/lib/AST/StmtOpenMP.cpp
@@ -13,6 +13,8 @@
 #include "clang/AST/StmtOpenMP.h"
 
 #include "clang/AST/ASTContext.h"
+#include "clang/AST/Attr.h"
+#include "llvm/ADT/SetOperations.h"
 
 using namespace clang;
 
@@ -2239,3 +2241,263 @@
   return new (Mem)
       OMPTargetTeamsDistributeSimdDirective(CollapsedNum, NumClauses);
 }
+
+// TODO: We have various representations for the same data; it might help to
+//       reuse some instead of converting them.
+// TODO: It is unclear where this checking code should live. It is used all over
+//       the place and would probably fit better in OMPDeclareVariantAttr.
+using OMPContextSelectorData =
+    OpenMPCtxSelectorData<ArrayRef<StringRef>, llvm::APSInt>;
+using CompleteOMPContextSelectorData = SmallVector<OMPContextSelectorData, 4>;
+
+/// Fallback for selectors without a specialization below; never matches.
+template <OpenMPContextSelectorSetKind CtxSet, OpenMPContextSelectorKind Ctx,
+          typename... Arguments>
+static bool checkContext(const OMPContextSelectorData &Data,
+                         Arguments... Params) {
+  assert(Data.CtxSet != OMP_CTX_SET_unknown && Data.Ctx != OMP_CTX_unknown &&
+         "Unknown context selector or context selector set.");
+  return false;
+}
+
+/// Checks for implementation={vendor(<vendor>)} context selector.
+/// \returns true iff every listed <vendor> is "llvm" (case-insensitively).
+template <>
+bool checkContext<OMP_CTX_SET_implementation, OMP_CTX_vendor>(
+    const OMPContextSelectorData &Data) {
+  auto IsLLVM = [](StringRef Vendor) { return Vendor.equals_lower("llvm"); };
+  return llvm::all_of(Data.Names, IsLLVM);
+}
+
+/// Checks for device={kind(<kind>, ...)} context selector. Every listed kind
+/// must hold: "host" when compiling for the host, "nohost" when compiling for
+/// an OpenMP device, "cpu" when targeting Arm, AArch64, PPC or X86, and "gpu"
+/// when targeting NVPTX or AMDGCN. Any other kind never holds.
+/// \returns true iff every listed kind holds (vacuously true for an empty
+/// list), false otherwise.
+template <>
+bool checkContext<OMP_CTX_SET_device, OMP_CTX_kind, const LangOptions &,
+                  const TargetInfo &>(const OMPContextSelectorData &Data,
+                                      const LangOptions &LO,
+                                      const TargetInfo &TI) {
+  for (StringRef Name : Data.Names) {
+    if (!Name.compare_lower("host")) {
+      if (LO.OpenMPIsDevice)
+        return false;
+      continue;
+    }
+    if (!Name.compare_lower("nohost")) {
+      if (!LO.OpenMPIsDevice)
+        return false;
+      continue;
+    }
+    switch (TI.getTriple().getArch()) {
+    case llvm::Triple::arm:
+    case llvm::Triple::armeb:
+    case llvm::Triple::aarch64:
+    case llvm::Triple::aarch64_be:
+    case llvm::Triple::aarch64_32:
+    case llvm::Triple::ppc:
+    case llvm::Triple::ppc64:
+    case llvm::Triple::ppc64le:
+    case llvm::Triple::x86:
+    case llvm::Triple::x86_64:
+      if (Name.compare_lower("cpu"))
+        return false;
+      break;
+    case llvm::Triple::amdgcn:
+    case llvm::Triple::nvptx:
+    case llvm::Triple::nvptx64:
+      if (Name.compare_lower("gpu"))
+        return false;
+      break;
+    case llvm::Triple::UnknownArch:
+    case llvm::Triple::arc:
+    case llvm::Triple::avr:
+    case llvm::Triple::bpfel:
+    case llvm::Triple::bpfeb:
+    case llvm::Triple::hexagon:
+    case llvm::Triple::mips:
+    case llvm::Triple::mipsel:
+    case llvm::Triple::mips64:
+    case llvm::Triple::mips64el:
+    case llvm::Triple::msp430:
+    case llvm::Triple::r600:
+    case llvm::Triple::riscv32:
+    case llvm::Triple::riscv64:
+    case llvm::Triple::sparc:
+    case llvm::Triple::sparcv9:
+    case llvm::Triple::sparcel:
+    case llvm::Triple::systemz:
+    case llvm::Triple::tce:
+    case llvm::Triple::tcele:
+    case llvm::Triple::thumb:
+    case llvm::Triple::thumbeb:
+    case llvm::Triple::xcore:
+    case llvm::Triple::le32:
+    case llvm::Triple::le64:
+    case llvm::Triple::amdil:
+    case llvm::Triple::amdil64:
+    case llvm::Triple::hsail:
+    case llvm::Triple::hsail64:
+    case llvm::Triple::spir:
+    case llvm::Triple::spir64:
+    case llvm::Triple::kalimba:
+    case llvm::Triple::shave:
+    case llvm::Triple::lanai:
+    case llvm::Triple::wasm32:
+    case llvm::Triple::wasm64:
+    case llvm::Triple::renderscript32:
+    case llvm::Triple::renderscript64:
+      return false;
+    }
+  }
+  return true;
+}
+
+/// Expands \p A into one entry per context selector; empty if \p A is null.
+static CompleteOMPContextSelectorData
+translateAttrToContextSelectorData(ASTContext &C,
+                                   const OMPDeclareVariantAttr *A) {
+  CompleteOMPContextSelectorData Result;
+  if (!A)
+    return Result;
+  for (unsigned Idx = 0, NumSel = A->scores_size(); Idx < NumSel; ++Idx) {
+    Result.emplace_back();
+    OMPContextSelectorData &Entry = Result.back();
+    Entry.CtxSet = static_cast<OpenMPContextSelectorSetKind>(
+        *std::next(A->ctxSelectorSets_begin(), Idx));
+    Entry.Ctx = static_cast<OpenMPContextSelectorKind>(
+        *std::next(A->ctxSelectors_begin(), Idx));
+    const Expr *ScoreExpr = *std::next(A->scores_begin(), Idx);
+    Entry.Score = ScoreExpr->EvaluateKnownConstInt(C);
+    switch (Entry.Ctx) {
+    case OMP_CTX_vendor:
+      assert(Entry.CtxSet == OMP_CTX_SET_implementation &&
+             "Expected implementation context selector set.");
+      Entry.Names =
+          llvm::makeArrayRef(A->implVendors_begin(), A->implVendors_end());
+      break;
+    case OMP_CTX_kind:
+      assert(Entry.CtxSet == OMP_CTX_SET_device &&
+             "Expected device context selector set.");
+      Entry.Names =
+          llvm::makeArrayRef(A->deviceKinds_begin(), A->deviceKinds_end());
+      break;
+    case OMP_CTX_unknown:
+      llvm_unreachable("Unknown context selector kind.");
+    }
+  }
+  return Result;
+}
+
+/// True iff all selectors in \p ContextData match the language/target options.
+/// An empty selector list trivially matches.
+static bool
+matchesOpenMPContextImpl(const CompleteOMPContextSelectorData &ContextData,
+                         const LangOptions &LO, const TargetInfo &TI) {
+  // Evaluates a single selector against LO and TI.
+  auto Holds = [&](const OMPContextSelectorData &Sel) {
+    switch (Sel.Ctx) {
+    case OMP_CTX_vendor:
+      assert(Sel.CtxSet == OMP_CTX_SET_implementation &&
+             "Expected implementation context selector set.");
+      return checkContext<OMP_CTX_SET_implementation, OMP_CTX_vendor>(Sel);
+    case OMP_CTX_kind:
+      assert(Sel.CtxSet == OMP_CTX_SET_device &&
+             "Expected device context selector set.");
+      return checkContext<OMP_CTX_SET_device, OMP_CTX_kind,
+                          const LangOptions &, const TargetInfo &>(Sel, LO, TI);
+    case OMP_CTX_unknown:
+      break;
+    }
+    llvm_unreachable("Unknown context selector kind.");
+  };
+  return llvm::all_of(ContextData, Holds);
+}
+
+/// Returns true if every (set, selector) pair of \p LHS also occurs in \p RHS
+/// with a superset of its names, and the two differ in selector count or in
+/// the size of some name set.
+static bool isStrictSubset(const CompleteOMPContextSelectorData &LHS,
+                           const CompleteOMPContextSelectorData &RHS) {
+  // Merge the names of all RHS selectors that share a (set, selector) key.
+  llvm::SmallDenseMap<std::pair<int, int>, llvm::StringSet<>, 4> RHSData;
+  for (const OMPContextSelectorData &Sel : RHS)
+    RHSData[std::make_pair(Sel.CtxSet, Sel.Ctx)].insert(Sel.Names.begin(),
+                                                        Sel.Names.end());
+  bool SameNames = true;
+  for (const OMPContextSelectorData &Sel : LHS) {
+    auto It = RHSData.find(std::make_pair(Sel.CtxSet, Sel.Ctx));
+    if (It == RHSData.end() || Sel.Names.size() > It->getSecond().size())
+      return false;
+    // set_union reports a change iff some LHS name is missing from RHS.
+    if (llvm::set_union(It->getSecond(), Sel.Names))
+      return false;
+    SameNames = SameNames && Sel.Names.size() == It->getSecond().size();
+  }
+  return LHS.size() != RHS.size() || !SameNames;
+}
+
+/// Returns 1 plus the sum of all selector scores in \p Data. Scores are made
+/// signed (after widening by one bit) and the accumulator grows by one bit per
+/// addition, so mixed signedness cannot trip APSInt's signedness assertion
+/// and large scores cannot wrap around.
+static llvm::APSInt
+sumContextScores(const CompleteOMPContextSelectorData &Data) {
+  llvm::APSInt Sum(llvm::APInt(64, 1), /*isUnsigned=*/false);
+  for (const OMPContextSelectorData &D : Data) {
+    // The extra bit preserves an unsigned score's value once made signed.
+    llvm::APSInt Score = D.Score.extend(D.Score.getBitWidth() + 1);
+    Score.setIsSigned(true);
+    unsigned Width = Sum.getBitWidth();
+    if (Score.getBitWidth() > Width)
+      Width = Score.getBitWidth();
+    // One bit wider than both operands, so the addition cannot wrap.
+    Sum = Sum.extend(Width + 1) + Score.extend(Width + 1);
+  }
+  return Sum;
+}
+
+/// Picks the better of two declare variant attributes for the current
+/// compilation context. Either input may be null; returns null when neither
+/// attribute is a matching, non-inherited candidate.
+const OMPDeclareVariantAttr *
+clang::getBetterOpenMPContextMatch(ASTContext &C,
+                                   const OMPDeclareVariantAttr *LHSAttr,
+                                   const OMPDeclareVariantAttr *RHSAttr) {
+  const CompleteOMPContextSelectorData LHS =
+      translateAttrToContextSelectorData(C, LHSAttr);
+  const CompleteOMPContextSelectorData RHS =
+      translateAttrToContextSelectorData(C, RHSAttr);
+  bool LHSMatch = LHSAttr && matchesOpenMPContextImpl(LHS, C.getLangOpts(),
+                                                      C.getTargetInfo());
+  bool RHSMatch = RHSAttr && matchesOpenMPContextImpl(RHS, C.getLangOpts(),
+                                                      C.getTargetInfo());
+  // Inherited attributes are never selected, even if their context matches.
+  bool LHSisOK = LHSMatch && !LHSAttr->isInherited();
+  bool RHSisOK = RHSMatch && !RHSAttr->isInherited();
+  // Unless both are viable, the single viable candidate (or null) wins.
+  if (!LHSisOK || !RHSisOK)
+    return LHSisOK ? LHSAttr : (RHSisOK ? RHSAttr : nullptr);
+
+  // A context scores 1 + the sum of its selector scores, or 0 when the other
+  // context is a strict subset of it (LHS is checked first). Ties favor LHS.
+  // NOTE(review): this zeroes the *superset*; confirm against the OpenMP 5.0
+  // scoring rules whether the strict subset should score 0 instead.
+  bool RHSIsSubsetOfLHS = isStrictSubset(RHS, LHS);
+  llvm::APSInt LHSScore =
+      RHSIsSubsetOfLHS ? llvm::APSInt::get(0) : sumContextScores(LHS);
+  llvm::APSInt RHSScore = !RHSIsSubsetOfLHS && isStrictSubset(LHS, RHS)
+                              ? llvm::APSInt::get(0)
+                              : sumContextScores(RHS);
+  return llvm::APSInt::compareValues(LHSScore, RHSScore) >= 0 ? LHSAttr
+                                                              : RHSAttr;
+}
+
+/// A null \p A has no selectors and therefore always matches.
+bool clang::isOpenMPContextMatch(ASTContext &C,
+                                 const OMPDeclareVariantAttr *A) {
+  return matchesOpenMPContextImpl(translateAttrToContextSelectorData(C, A),
+                                  C.getLangOpts(), C.getTargetInfo());
+}
Index: clang/lib/AST/Decl.cpp
===================================================================
--- clang/lib/AST/Decl.cpp
+++ clang/lib/AST/Decl.cpp
@@ -3075,6 +3075,9 @@
     return MultiVersionKind::CPUDispatch;
   if (hasAttr<CPUSpecificAttr>())
     return MultiVersionKind::CPUSpecific;
+  if (hasAttr<OMPDeclareVariantAttr>() &&
+      !getAttr<OMPDeclareVariantAttr>()->getVariantFuncRef())
+    return MultiVersionKind::OMPVariant;
   return MultiVersionKind::None;
 }
 
@@ -3090,6 +3093,11 @@
   return isMultiVersion() && hasAttr<TargetAttr>();
 }
 
+bool FunctionDecl::isOpenMPMultiVersion() const {
+  const auto *VariantAttr = getAttr<OMPDeclareVariantAttr>();
+  return isMultiVersion() && VariantAttr && !VariantAttr->getVariantFuncRef();
+}
+
 void
 FunctionDecl::setPreviousDeclaration(FunctionDecl *PrevDecl) {
   redeclarable_base::setPreviousDecl(PrevDecl);
Index: clang/include/clang/Sema/Sema.h
===================================================================
--- clang/include/clang/Sema/Sema.h
+++ clang/include/clang/Sema/Sema.h
@@ -9346,6 +9346,9 @@
   // OpenMP directives and clauses.
   //
 private:
+  /// Copies declare variant attributes from the template TD to the function FD.
+  void inheritOpenMPVariantAttrs(FunctionDecl *FD,
+                                 const FunctionTemplateDecl &TD);
   void *VarDataSharingAttributesStack;
   /// Number of nested '#pragma omp declare target' directives.
   unsigned DeclareTargetNestingLevel = 0;
@@ -9415,6 +9418,9 @@
   using OMPCtxSelectorData =
       OpenMPCtxSelectorData<SmallVector<OMPCtxStringType, 4>, ExprResult>;
 
+  /// Attribute of the enclosing begin/end declare variant scope, or nullptr.
+  OMPDeclareVariantAttr *DeclareVariantScopeAttr = nullptr;
+
   /// Checks if the variant/multiversion functions are compatible.
   bool areMultiversionVariantFunctionsCompatible(
       const FunctionDecl *OldFD, const FunctionDecl *NewFD,
@@ -9422,7 +9428,9 @@
       const PartialDiagnosticAt &NoteCausedDiagIDAt,
       const PartialDiagnosticAt &NoSupportDiagIDAt,
       const PartialDiagnosticAt &DiffDiagIDAt, bool TemplatesSupported,
-      bool ConstexprSupported, bool CLinkageMayDiffer);
+      bool ConstexprSupported, bool CLinkageMayDiffer,
+      bool StorageClassMayDiffer, bool ConstexprSpecMayDiffer,
+      bool InlineSpecificationMayDiffer);
 
   /// Function tries to capture lambda's captured variables in the OpenMP region
   /// before the original lambda is captured.
@@ -9897,7 +9905,7 @@
   /// must be used instead of the original one, specified in \p DG.
   /// \param Data Set of context-specific data for the specified context
   /// selector.
-  void ActOnOpenMPDeclareVariantDirective(FunctionDecl *FD, Expr *VariantRef,
+  bool ActOnOpenMPDeclareVariantDirective(FunctionDecl *FD, Expr *VariantRef,
                                           SourceRange SR,
                                           ArrayRef<OMPCtxSelectorData> Data);
 
Index: clang/include/clang/Parse/Parser.h
===================================================================
--- clang/include/clang/Parse/Parser.h
+++ clang/include/clang/Parse/Parser.h
@@ -2860,6 +2860,10 @@
   parseOpenMPContextSelectors(SourceLocation Loc,
                               SmallVectorImpl<Sema::OMPCtxSelectorData> &Data);
 
+  /// Parse match clause of '#pragma omp [begin] declare variant'.
+  void ParseOMPDeclareVariantMatchClause(
+      SourceLocation Loc, SmallVectorImpl<Sema::OMPCtxSelectorData> &Data);
+
   /// Parse clauses for '#pragma omp declare variant'.
   void ParseOMPDeclareVariantClauses(DeclGroupPtrTy Ptr, CachedTokens &Toks,
                                      SourceLocation Loc);
Index: clang/include/clang/Basic/OpenMPKinds.def
===================================================================
--- clang/include/clang/Basic/OpenMPKinds.def
+++ clang/include/clang/Basic/OpenMPKinds.def
@@ -292,6 +292,8 @@
 OPENMP_DIRECTIVE_EXT(parallel_master_taskloop, "parallel master taskloop")
 OPENMP_DIRECTIVE_EXT(master_taskloop_simd, "master taskloop simd")
 OPENMP_DIRECTIVE_EXT(parallel_master_taskloop_simd, "parallel master taskloop simd")
+OPENMP_DIRECTIVE_EXT(begin_declare_variant, "begin declare variant")
+OPENMP_DIRECTIVE_EXT(end_declare_variant, "end declare variant")
 
 // OpenMP clauses.
 OPENMP_CLAUSE(allocator, OMPAllocatorClause)
Index: clang/include/clang/AST/StmtOpenMP.h
===================================================================
--- clang/include/clang/AST/StmtOpenMP.h
+++ clang/include/clang/AST/StmtOpenMP.h
@@ -4556,6 +4556,18 @@
   }
 };
 
+class OMPDeclareVariantAttr;
+
+/// Helper to determine the better of two potential context matches. Either
+/// attribute may be nullptr; nullptr is returned if neither attribute is a
+/// non-inherited attribute describing a matching context.
+const OMPDeclareVariantAttr *
+getBetterOpenMPContextMatch(ASTContext &C, const OMPDeclareVariantAttr *LHSAttr,
+                            const OMPDeclareVariantAttr *RHSAttr);
+
+/// Returns true if the context selector of \p A matches this compilation.
+bool isOpenMPContextMatch(ASTContext &C, const OMPDeclareVariantAttr *A);
+
 } // end namespace clang
 
 #endif
Index: clang/include/clang/AST/Decl.h
===================================================================
--- clang/include/clang/AST/Decl.h
+++ clang/include/clang/AST/Decl.h
@@ -1775,7 +1775,8 @@
   None,
   Target,
   CPUSpecific,
-  CPUDispatch
+  CPUDispatch,
+  OMPVariant,
 };
 
 /// Represents a function declaration or definition.
@@ -2345,6 +2346,10 @@
   /// the target functionality.
   bool isTargetMultiVersion() const;
 
+  /// True if this function is a multiversioned function as a part of
+  /// the OpenMP begin/end declare variant functionality.
+  bool isOpenMPMultiVersion() const;
+
   void setPreviousDeclaration(FunctionDecl * PrevDecl);
 
   FunctionDecl *getCanonicalDecl() override;
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to