jdoerfert created this revision.
jdoerfert added reviewers: kiranchandramohan, ABataev, RaviNarayanaswamy, 
gtbercea, grokos, sdmitriev, JonChesterfield, hfinkel, fghanim.
Herald added subscribers: s.egerton, guansong, bollu, simoncook, fedor.sergeev, 
aheejin, rampitec.
Herald added a project: clang.

NOTE: This is a WIP patch meant to foster discussion. Please keep
      that in mind when browsing the code. Details will be discussed in
      individual commits once we have agreed on the overall model. This
      is also why test coverage and documentation are incomplete and
      why several TODOs remain.

This patch provides initial support for `omp begin/end declare variant`,
as defined in OpenMP technical report 8 (TR8).

A major purpose of this patch is to provide proper math.h/cmath support
for OpenMP target offloading (see PR42061, PR42798, and PR42799).
The three tests included in this patch show that these bugs are (or
should be) fixed by this scheme.

In contrast to the existing declare variant handling, this patch
makes use of clang's function multi-versioning support. This is especially
useful because the variants have the same name as the base functions. We
should try to port all OpenMP variant handling to this scheme; see the
TODO in CodeGenModule for a proposed path toward this goal. Beyond
that, we tried to reuse the existing multi-version and OpenMP variant
handling as much as possible.

NOTE: There are various TODOs that need fixing and switches that need
      additional cases.


Repository:
  rG LLVM Github Monorepo

https://reviews.llvm.org/D71179

Files:
  clang/include/clang/AST/Decl.h
  clang/include/clang/AST/StmtOpenMP.h
  clang/include/clang/Basic/OpenMPKinds.def
  clang/include/clang/Parse/Parser.h
  clang/include/clang/Sema/Sema.h
  clang/lib/AST/Decl.cpp
  clang/lib/AST/StmtOpenMP.cpp
  clang/lib/CodeGen/CGOpenMPRuntime.cpp
  clang/lib/CodeGen/CodeGenModule.cpp
  clang/lib/Headers/__clang_cuda_cmath.h
  clang/lib/Headers/__clang_cuda_device_functions.h
  clang/lib/Headers/__clang_cuda_math_forward_declares.h
  clang/lib/Headers/openmp_wrappers/__clang_openmp_math.h
  clang/lib/Headers/openmp_wrappers/__clang_openmp_math_declares.h
  clang/lib/Headers/openmp_wrappers/cmath
  clang/lib/Headers/openmp_wrappers/math.h
  clang/lib/Parse/ParseOpenMP.cpp
  clang/lib/Sema/SemaDecl.cpp
  clang/lib/Sema/SemaOpenMP.cpp
  clang/lib/Sema/SemaOverload.cpp
  clang/lib/Sema/SemaTemplate.cpp
  clang/lib/Sema/SemaTemplateInstantiateDecl.cpp
  clang/test/AST/ast-dump-openmp-begin-declare-variant.c
  clang/test/OpenMP/begin_declare_variant_codegen.cpp
  clang/test/OpenMP/math_fp_macro.cpp

Index: clang/test/OpenMP/math_fp_macro.cpp
===================================================================
--- /dev/null
+++ clang/test/OpenMP/math_fp_macro.cpp
@@ -0,0 +1,9 @@
+// RUN: %clang_cc1 -verify -fopenmp -fopenmp-targets=nvptx64-nvidia-cuda -x c++ -emit-llvm %s -triple %itanium_abi_triple -fexceptions -fcxx-exceptions -o - | FileCheck %s
+// expected-no-diagnostics
+
+#include <cmath>
+
+int main() {
+  double a(0);
+  return (std::fpclassify(a) != FP_ZERO);
+}
Index: clang/test/OpenMP/begin_declare_variant_codegen.cpp
===================================================================
--- /dev/null
+++ clang/test/OpenMP/begin_declare_variant_codegen.cpp
@@ -0,0 +1,134 @@
+// RUN: %clang_cc1 -verify -fopenmp -x c++ -emit-llvm %s -triple %itanium_abi_triple -fexceptions -fcxx-exceptions -o - | FileCheck %s
+// expected-no-diagnostics
+
+int bar(void) {
+  return 0;
+}
+
+template <typename T>
+T baz(void) { return 0; }
+
+#pragma omp begin declare variant match(device={kind(cpu)})
+int foo(void) {
+  return 1;
+}
+int bar(void) {
+  return 1;
+}
+template <typename T>
+T baz(void) { return 1; }
+
+template <typename T>
+T biz(void) { return 1; }
+
+template <typename T>
+T buz(void) { return 3; }
+
+template <>
+char buz(void) { return 1; }
+
+template <typename T>
+T bez(void) { return 3; }
+#pragma omp end declare variant
+
+#pragma omp begin declare variant match(device={kind(gpu)})
+int foo(void) {
+  return 2;
+}
+int bar(void) {
+  return 2;
+}
+#pragma omp end declare variant
+
+
+#pragma omp begin declare variant match(device={kind(fpga)})
+
+This text is never parsed!
+
+#pragma omp end declare variant
+
+int foo(void) {
+  return 0;
+}
+
+template <typename T>
+T biz(void) { return 0; }
+
+template <>
+char buz(void) { return 0; }
+
+template <>
+long bez(void) { return 0; }
+
+#pragma omp begin declare variant match(device = {kind(cpu)})
+template <>
+long bez(void) { return 1; }
+#pragma omp end declare variant
+
+int test() {
+  return foo() + bar() + baz<int>() + biz<short>() + buz<char>() + bez<long>();
+}
+
+// Make sure all ompvariant functions return 1 and all others return 0.
+
+// CHECK:       ; Function Attrs:
+// CHECK-NEXT:  define i32 @_Z3barv()
+// CHECK-NEXT:  entry:
+// CHECK-NEXT:    ret i32 0
+// CHECK-NEXT:  }
+// CHECK:       ; Function Attrs:
+// CHECK-NEXT:  define i32 @_Z3foov.ompvariant()
+// CHECK-NEXT:  entry:
+// CHECK-NEXT:    ret i32 1
+// CHECK-NEXT:  }
+// CHECK:       ; Function Attrs:
+// CHECK-NEXT:  define i32 @_Z3barv.ompvariant()
+// CHECK-NEXT:  entry:
+// CHECK-NEXT:    ret i32 1
+// CHECK-NEXT:  }
+// CHECK:       ; Function Attrs:
+// CHECK-NEXT:  define signext i8 @_Z3buzIcET_v.ompvariant()
+// CHECK-NEXT:  entry:
+// CHECK-NEXT:    ret i8 1
+// CHECK-NEXT:  }
+// CHECK:       ; Function Attrs:
+// CHECK-NEXT:  define i32 @_Z3foov()
+// CHECK-NEXT:  entry:
+// CHECK-NEXT:    ret i32 0
+// CHECK-NEXT:  }
+// CHECK:       ; Function Attrs:
+// CHECK-NEXT:  define signext i8 @_Z3buzIcET_v()
+// CHECK-NEXT:  entry:
+// CHECK-NEXT:    ret i8 0
+// CHECK-NEXT:  }
+// CHECK:       ; Function Attrs:
+// CHECK-NEXT:  define i64 @_Z3bezIlET_v()
+// CHECK-NEXT:  entry:
+// CHECK-NEXT:    ret i64 0
+// CHECK-NEXT:  }
+// CHECK:       ; Function Attrs:
+// CHECK-NEXT:  define i64 @_Z3bezIlET_v.ompvariant()
+// CHECK-NEXT:  entry:
+// CHECK-NEXT:    ret i64 1
+// CHECK-NEXT:  }
+
+// Make sure we call only ompvariant functions
+
+// CHECK:  define i32 @_Z4testv()
+// CHECK:    %call = call i32 @_Z3foov.ompvariant()
+// CHECK:    %call1 = call i32 @_Z3barv.ompvariant()
+// CHECK:    %call2 = call i32 @_Z3bazIiET_v.ompvariant()
+// CHECK:    %call4 = call signext i16 @_Z3bizIsET_v.ompvariant()
+// CHECK:    %call6 = call signext i8 @_Z3buzIcET_v.ompvariant()
+// CHECK:    %call10 = call i64 @_Z3bezIlET_v.ompvariant()
+
+// CHECK:       ; Function Attrs:
+// CHECK-NEXT:  define linkonce_odr i32 @_Z3bazIiET_v.ompvariant()
+// CHECK-NEXT:  entry:
+// CHECK-NEXT:    ret i32 1
+// CHECK-NEXT:  }
+// CHECK:       ; Function Attrs:
+// CHECK-NEXT:  define linkonce_odr signext i16 @_Z3bizIsET_v.ompvariant()
+// CHECK-NEXT:  entry:
+// CHECK-NEXT:    ret i16 1
+// CHECK-NEXT:  }
Index: clang/test/AST/ast-dump-openmp-begin-declare-variant.c
===================================================================
--- /dev/null
+++ clang/test/AST/ast-dump-openmp-begin-declare-variant.c
@@ -0,0 +1,83 @@
+// RUN: %clang_cc1 -triple x86_64-unknown-unknown -fopenmp -ast-dump %s | FileCheck %s
+
+int also_before(void) {
+  return 0;
+}
+
+#pragma omp begin declare variant match(device={kind(cpu)})
+int also_after(void) {
+  return 1;
+}
+int also_before(void) {
+  return 1;
+}
+#pragma omp end declare variant
+
+#pragma omp begin declare variant match(device={kind(gpu)})
+int also_after(void) {
+  return 2;
+}
+int also_before(void) {
+  return 2;
+}
+#pragma omp end declare variant
+
+
+#pragma omp begin declare variant match(device={kind(fpga)})
+
+This text is never parsed!
+
+#pragma omp end declare variant
+
+int also_after(void) {
+  return 0;
+}
+
+int test() {
+  return also_after() + also_before();
+}
+
+// Make sure:
+// 1) we pick the right versions, that is, the test should reference the kind(cpu) versions.
+// 2) we do not see the ast nodes for the gpu kind
+// 3) we do not choke on the text in the kind(fpga) guarded scope.
+
+// CHECK:       -FunctionDecl {{.*}} <{{.*}}3:1, line:{{.*}}:1> line:{{.*}}:5 also_before 'int (void)'
+// CHECK-NEXT:  | |-CompoundStmt {{.*}} <col:23, line:{{.*}}:1>
+// CHECK-NEXT:  | | `-ReturnStmt {{.*}} <line:{{.*}}:3, col:10>
+// CHECK-NEXT:  | |   `-IntegerLiteral {{.*}} <col:10> 'int' 0
+// CHECK-NEXT:  | `-OMPDeclareVariantAttr {{.*}} <line:{{.*}}:1, col:60> Inherited Implicit 1 1 cpu
+// CHECK-NEXT:  |   |-<<<NULL>>>
+// CHECK-NEXT:  |   `-IntegerLiteral {{.*}} <<invalid sloc>> 'int' 1
+// CHECK-NEXT:  |-FunctionDecl [[GOOD_ALSO_AFTER:0x[a-z0-9]*]] <line:{{.*}}:1, line:{{.*}}:1> line:{{.*}}:5 used also_after 'int (void)'
+// CHECK-NEXT:  | |-CompoundStmt {{.*}} <col:22, line:{{.*}}:1>
+// CHECK-NEXT:  | | `-ReturnStmt {{.*}} <line:{{.*}}:3, col:10>
+// CHECK-NEXT:  | |   `-IntegerLiteral {{.*}} <col:10> 'int' 1
+// CHECK-NEXT:  | `-OMPDeclareVariantAttr {{.*}} <line:{{.*}}:1, col:60> Implicit 1 1 cpu
+// CHECK-NEXT:  |   |-<<<NULL>>>
+// CHECK-NEXT:  |   `-IntegerLiteral {{.*}} <<invalid sloc>> 'int' 1
+// CHECK-NEXT:  |-FunctionDecl [[GOOD_ALSO_BEFORE:0x[a-z0-9]*]] <line:{{.*}}:1, line:{{.*}}:1> line:{{.*}}:5 used also_before 'int (void)'
+// CHECK-NEXT:  | |-CompoundStmt {{.*}} <col:23, line:{{.*}}:1>
+// CHECK-NEXT:  | | `-ReturnStmt {{.*}} <line:{{.*}}:3, col:10>
+// CHECK-NEXT:  | |   `-IntegerLiteral {{.*}} <col:10> 'int' 1
+// CHECK-NEXT:  | `-OMPDeclareVariantAttr {{.*}} <line:{{.*}}:1, col:60> Implicit 1 1 cpu
+// CHECK-NEXT:  |   |-<<<NULL>>>
+// CHECK-NEXT:  |   `-IntegerLiteral {{.*}} <<invalid sloc>> 'int' 1
+// CHECK-NEXT:  |-FunctionDecl {{.*}} <line:{{.*}}:1, line:{{.*}}:1> line:{{.*}}:5 also_after 'int (void)'
+// CHECK-NEXT:  | |-CompoundStmt {{.*}} <col:22, line:{{.*}}:1>
+// CHECK-NEXT:  | | `-ReturnStmt {{.*}} <line:{{.*}}:3, col:10>
+// CHECK-NEXT:  | |   `-IntegerLiteral {{.*}} <col:10> 'int' 0
+// CHECK-NEXT:  | `-OMPDeclareVariantAttr {{.*}} <line:{{.*}}:1, col:60> Inherited Implicit 1 1 cpu
+// CHECK-NEXT:  |   |-<<<NULL>>>
+// CHECK-NEXT:  |   `-IntegerLiteral {{.*}} <<invalid sloc>> 'int' 1
+// CHECK-NEXT:  `-FunctionDecl {{.*}} <line:{{.*}}:1, line:{{.*}}:1> line:{{.*}}:5 test 'int ()'
+// CHECK-NEXT:    `-CompoundStmt {{.*}} <col:12, line:{{.*}}:1>
+// CHECK-NEXT:      `-ReturnStmt {{.*}} <line:{{.*}}:3, col:37>
+// CHECK-NEXT:        `-BinaryOperator {{.*}} <col:10, col:37> 'int' '+'
+// CHECK-NEXT:          |-CallExpr {{.*}} <col:10, col:21> 'int'
+// CHECK-NEXT:          | `-ImplicitCastExpr {{.*}} <col:10> 'int (*)(void)' <FunctionToPointerDecay>
+// CHECK-NEXT:          |   `-DeclRefExpr {{.*}} <col:10> 'int (void)' lvalue Function [[GOOD_ALSO_AFTER]] 'also_after' 'int (void)'
+// CHECK-NEXT:          `-CallExpr {{.*}} <col:25, col:37> 'int'
+// CHECK-NEXT:            `-ImplicitCastExpr {{.*}} <col:25> 'int (*)(void)' <FunctionToPointerDecay>
+// CHECK-NEXT:              `-DeclRefExpr {{.*}} <col:25> 'int (void)' lvalue Function [[GOOD_ALSO_BEFORE]] 'also_before' 'int (void)'
+
Index: clang/lib/Sema/SemaTemplateInstantiateDecl.cpp
===================================================================
--- clang/lib/Sema/SemaTemplateInstantiateDecl.cpp
+++ clang/lib/Sema/SemaTemplateInstantiateDecl.cpp
@@ -429,9 +429,17 @@
       llvm_unreachable("Unexpected context selector set kind.");
     }
   }
-  S.ActOnOpenMPDeclareVariantDirective(DeclVarData.getValue().first,
-                                       DeclVarData.getValue().second,
-                                       Attr.getRange(), Data);
+  FunctionDecl *VariantFD =
+      DeclVarData.hasValue() ? DeclVarData.getValue().first : FD;
+  Expr *VariantExpr =
+      DeclVarData.hasValue() ? DeclVarData.getValue().second : nullptr;
+  S.ActOnOpenMPDeclareVariantDirective(VariantFD, VariantExpr, Attr.getRange(),
+                                       Data);
+
+  // The new attribute on the instantiation is inherited if the template
+  // attribute was.
+  if (Attr.isInherited())
+    VariantFD->getAttr<OMPDeclareVariantAttr>()->setInherited(true);
 }
 
 static void instantiateDependentAMDGPUFlatWorkGroupSizeAttr(
Index: clang/lib/Sema/SemaTemplate.cpp
===================================================================
--- clang/lib/Sema/SemaTemplate.cpp
+++ clang/lib/Sema/SemaTemplate.cpp
@@ -8596,6 +8596,9 @@
   if (LangOpts.CUDA)
     inheritCUDATargetAttrs(FD, *Specialization->getPrimaryTemplate());
 
+  if (LangOpts.OpenMP)
+    inheritOpenMPVariantAttrs(FD, *Specialization->getPrimaryTemplate());
+
   // The "previous declaration" for this function template specialization is
   // the prior function template specialization.
   Previous.clear();
Index: clang/lib/Sema/SemaOverload.cpp
===================================================================
--- clang/lib/Sema/SemaOverload.cpp
+++ clang/lib/Sema/SemaOverload.cpp
@@ -17,6 +17,7 @@
 #include "clang/AST/Expr.h"
 #include "clang/AST/ExprCXX.h"
 #include "clang/AST/ExprObjC.h"
+#include "clang/AST/StmtOpenMP.h"
 #include "clang/AST/TypeOrdering.h"
 #include "clang/Basic/Diagnostic.h"
 #include "clang/Basic/DiagnosticOptions.h"
@@ -9264,7 +9265,8 @@
   return Comparison::Equal;
 }
 
-static bool isBetterMultiversionCandidate(const OverloadCandidate &Cand1,
+static bool isBetterMultiversionCandidate(ASTContext &ASTCtx,
+                                          const OverloadCandidate &Cand1,
                                           const OverloadCandidate &Cand2) {
   if (!Cand1.Function || !Cand1.Function->isMultiVersion() || !Cand2.Function ||
       !Cand2.Function->isMultiVersion())
@@ -9275,6 +9277,18 @@
   if (Cand1.Function->isInvalidDecl()) return false;
   if (Cand2.Function->isInvalidDecl()) return true;
 
+  // If we have an OpenMP declare variant attribute on either candidate we use
+  // it to order the candidates. The first is only better if it has an
+  // attribute that is considered better, or if it has no attribute and the
+  // one on the second candidate is not a match.
+  auto *OMPVariantAttr1 = Cand1.Function->getAttr<OMPDeclareVariantAttr>();
+  auto *OMPVariantAttr2 = Cand2.Function->getAttr<OMPDeclareVariantAttr>();
+  if (OMPVariantAttr1 || OMPVariantAttr2) {
+    auto *OMPVariantAttrBest =
+        getBetterOpenMPContextMatch(ASTCtx, OMPVariantAttr1, OMPVariantAttr2);
+    return OMPVariantAttrBest == OMPVariantAttr1;
+  }
+
   // If this is a cpu_dispatch/cpu_specific multiversion situation, prefer
   // cpu_dispatch, else arbitrarily based on the identifiers.
   bool Cand1CPUDisp = Cand1.Function->hasAttr<CPUDispatchAttr>();
@@ -9543,7 +9557,7 @@
   if (HasPS1 != HasPS2 && HasPS1)
     return true;
 
-  return isBetterMultiversionCandidate(Cand1, Cand2);
+  return isBetterMultiversionCandidate(S.getASTContext(), Cand1, Cand2);
 }
 
 /// Determine whether two declarations are "equivalent" for the purposes of
@@ -9657,6 +9671,20 @@
     }
   }
 
+  // [OpenMP] Similar to the CUDA code above, OpenMP declare variants might not
+  // be eligible at all so we need to filter them out early.
+  if (S.getLangOpts().OpenMP) {
+    // TODO use context information
+    auto IsNonMatchVariant = [&](OverloadCandidate *Cand) {
+      if (!Cand->Viable || !Cand->Function)
+        return false;
+      auto *OMPVariantAttr = Cand->Function->getAttr<OMPDeclareVariantAttr>();
+      return OMPVariantAttr &&
+             !isOpenMPContextMatch(S.getASTContext(), OMPVariantAttr);
+    };
+    llvm::erase_if(Candidates, IsNonMatchVariant);
+  }
+
   // Find the best viable function.
   Best = end();
   for (auto *Cand : Candidates) {
Index: clang/lib/Sema/SemaOpenMP.cpp
===================================================================
--- clang/lib/Sema/SemaOpenMP.cpp
+++ clang/lib/Sema/SemaOpenMP.cpp
@@ -5217,8 +5217,8 @@
 
   // The VariantRef must point to function.
   if (!VariantRef) {
-    Diag(SR.getBegin(), diag::err_omp_function_expected) << VariantId;
-    return None;
+    // Diag(SR.getBegin(), diag::err_omp_function_expected) << VariantId;
+    return std::make_pair(FD, VariantRef);
   }
 
   // Do not check templates, wait until instantiation.
@@ -5315,6 +5315,7 @@
     return None;
   }
 
+  // TODO check these for missing VariantRef as well
   enum DoesntSupport {
     VirtFuncs = 1,
     Constructors = 3,
@@ -5375,16 +5376,19 @@
                               PDiag(diag::err_omp_declare_variant_diff)
                                   << FD->getLocation()),
           /*TemplatesSupported=*/true, /*ConstexprSupported=*/false,
-          /*CLinkageMayDiffer=*/true))
+          /*CLinkageMayDiffer=*/true,
+          /*StorageClassMayDiffer=*/true,
+          /*ConstexprSpecMayDiffer=*/true,
+          /*InlineSpecificationMayDiffer=*/true))
     return None;
   return std::make_pair(FD, cast<Expr>(DRE));
 }
 
-void Sema::ActOnOpenMPDeclareVariantDirective(
+bool Sema::ActOnOpenMPDeclareVariantDirective(
     FunctionDecl *FD, Expr *VariantRef, SourceRange SR,
     ArrayRef<OMPCtxSelectorData> Data) {
   if (Data.empty())
-    return;
+    return false;
   SmallVector<Expr *, 4> CtxScores;
   SmallVector<unsigned, 4> CtxSets;
   SmallVector<unsigned, 4> Ctxs;
@@ -5394,7 +5398,7 @@
     OpenMPContextSelectorSetKind CtxSet = D.CtxSet;
     OpenMPContextSelectorKind Ctx = D.Ctx;
     if (CtxSet == OMP_CTX_SET_unknown || Ctx == OMP_CTX_unknown)
-      return;
+      return false;
     Expr *Score = nullptr;
     if (D.Score.isUsable()) {
       Score = D.Score.get();
@@ -5447,8 +5451,16 @@
         CtxSets.begin(), CtxSets.size(), Ctxs.begin(), Ctxs.size(),
         ImplVendors.begin(), ImplVendors.size(), DeviceKinds.begin(),
         DeviceKinds.size(), SR);
-    FD->addAttr(NewAttr);
+    if (FD) {
+      FD->addAttr(NewAttr);
+    } else {
+      assert(!DeclareVariantScopeAttr &&
+             "TODO nested begin/end declare varinat");
+      DeclareVariantScopeAttr = NewAttr;
+      return !isOpenMPContextMatch(getASTContext(), DeclareVariantScopeAttr);
+    }
   }
+  return false;
 }
 
 void Sema::markOpenMPDeclareVariantFuncsReferenced(SourceLocation Loc,
@@ -5461,11 +5473,15 @@
          Func->specific_attrs<OMPDeclareVariantAttr>()) {
       // TODO: add checks for active OpenMP context where possible.
       Expr *VariantRef = A->getVariantFuncRef();
-      auto *DRE = cast<DeclRefExpr>(VariantRef->IgnoreParenImpCasts());
-      auto *F = cast<FunctionDecl>(DRE->getDecl());
+      FunctionDecl *F = Func;
+      if (VariantRef) {
+        auto *DRE = cast<DeclRefExpr>(VariantRef->IgnoreParenImpCasts());
+        F = cast<FunctionDecl>(DRE->getDecl());
+      }
       if (!F->isDefined() && F->isTemplateInstantiation())
         InstantiateFunctionDefinition(Loc, F->getFirstDecl());
-      MarkFunctionReferenced(Loc, F, MightBeOdrUse);
+      if (F != Func)
+        MarkFunctionReferenced(Loc, F, MightBeOdrUse);
     }
   }
 }
@@ -17032,3 +17048,20 @@
   return OMPAllocateClause::Create(Context, StartLoc, LParenLoc, Allocator,
                                    ColonLoc, EndLoc, Vars);
 }
+
+template <typename AttrTy>
+static void copyAttrIfPresent(Sema &S, FunctionDecl *FD,
+                              const FunctionDecl &TemplateFD) {
+  if (!FD->hasAttr<AttrTy>())
+    if (AttrTy *Attribute = TemplateFD.getAttr<AttrTy>()) {
+      AttrTy *Clone = Attribute->clone(S.Context);
+      Clone->setInherited(true);
+      FD->addAttr(Clone);
+    }
+}
+
+void Sema::inheritOpenMPVariantAttrs(FunctionDecl *FD,
+                                     const FunctionTemplateDecl &TD) {
+  const FunctionDecl &TemplateFD = *TD.getTemplatedDecl();
+  copyAttrIfPresent<OMPDeclareVariantAttr>(*this, FD, TemplateFD);
+}
Index: clang/lib/Sema/SemaDecl.cpp
===================================================================
--- clang/lib/Sema/SemaDecl.cpp
+++ clang/lib/Sema/SemaDecl.cpp
@@ -2349,8 +2349,7 @@
     if (!isa<TypedefNameDecl>(Old))
       return;
 
-    Diag(New->getLocation(), diag::err_redefinition)
-      << New->getDeclName();
+    Diag(New->getLocation(), diag::err_redefinition) << New->getDeclName();
     notePreviousDefinition(Old, New->getLocation());
     return New->setInvalidDecl();
   }
@@ -8654,6 +8653,14 @@
                                               isVirtualOkay);
   if (!NewFD) return nullptr;
 
+  if (getLangOpts().OpenMP && DeclareVariantScopeAttr) {
+    OMPDeclareVariantAttr *DeclVarAttr =
+        DeclareVariantScopeAttr->clone(getASTContext());
+    DeclVarAttr->setInherited(false);
+    NewFD->addAttr(DeclVarAttr);
+    NewFD->setIsMultiVersion();
+  }
+
   if (OriginalLexicalContext && OriginalLexicalContext->isObjCContainer())
     NewFD->setTopLevelDeclInObjCContainer();
 
@@ -9784,6 +9791,10 @@
       if (MVType != MultiVersionKind::Target)
         return true;
       break;
+    case attr::OMPDeclareVariant:
+      if (MVType != MultiVersionKind::OMPVariant)
+        return true;
+      break;
     default:
       return true;
     }
@@ -9797,7 +9808,8 @@
     const PartialDiagnosticAt &NoteCausedDiagIDAt,
     const PartialDiagnosticAt &NoSupportDiagIDAt,
     const PartialDiagnosticAt &DiffDiagIDAt, bool TemplatesSupported,
-    bool ConstexprSupported, bool CLinkageMayDiffer) {
+    bool ConstexprSupported, bool CLinkageMayDiffer, bool StorageClassMayDiffer,
+    bool ConstexprSpecMayDiffer, bool InlineSpecificationMayDiffer) {
   enum DoesntSupport {
     FuncTemplates = 0,
     VirtFuncs = 1,
@@ -9860,7 +9872,7 @@
 
   QualType NewQType = Context.getCanonicalType(NewFD->getType());
   const auto *NewType = cast<FunctionType>(NewQType);
-  QualType NewReturnType = NewType->getReturnType();
+  QualType NewReturnType = NewType->getReturnType().getUnqualifiedType();
 
   if (NewReturnType->isUndeducedType())
     return Diag(NoSupportDiagIDAt.first, NoSupportDiagIDAt.second)
@@ -9876,18 +9888,21 @@
     if (OldTypeInfo.getCC() != NewTypeInfo.getCC())
       return Diag(DiffDiagIDAt.first, DiffDiagIDAt.second) << CallingConv;
 
-    QualType OldReturnType = OldType->getReturnType();
+    QualType OldReturnType = OldType->getReturnType().getUnqualifiedType();
 
     if (OldReturnType != NewReturnType)
       return Diag(DiffDiagIDAt.first, DiffDiagIDAt.second) << ReturnType;
 
-    if (OldFD->getConstexprKind() != NewFD->getConstexprKind())
+    if (!ConstexprSpecMayDiffer &&
+        OldFD->getConstexprKind() != NewFD->getConstexprKind())
       return Diag(DiffDiagIDAt.first, DiffDiagIDAt.second) << ConstexprSpec;
 
-    if (OldFD->isInlineSpecified() != NewFD->isInlineSpecified())
+    if (!InlineSpecificationMayDiffer &&
+        OldFD->isInlineSpecified() != NewFD->isInlineSpecified())
       return Diag(DiffDiagIDAt.first, DiffDiagIDAt.second) << InlineSpec;
 
-    if (OldFD->getStorageClass() != NewFD->getStorageClass())
+    if (!StorageClassMayDiffer &&
+        OldFD->getStorageClass() != NewFD->getStorageClass())
       return Diag(DiffDiagIDAt.first, DiffDiagIDAt.second) << StorageClass;
 
     if (!CLinkageMayDiffer && OldFD->isExternC() != NewFD->isExternC())
@@ -9905,7 +9920,9 @@
                                              const FunctionDecl *NewFD,
                                              bool CausesMV,
                                              MultiVersionKind MVType) {
-  if (!S.getASTContext().getTargetInfo().supportsMultiVersioning()) {
+  bool IsOpenMPVariant = MVType == MultiVersionKind::OMPVariant;
+  if (!IsOpenMPVariant &&
+      !S.getASTContext().getTargetInfo().supportsMultiVersioning()) {
     S.Diag(NewFD->getLocation(), diag::err_multiversion_not_supported);
     if (OldFD)
       S.Diag(OldFD->getLocation(), diag::note_previous_declaration);
@@ -9918,19 +9935,20 @@
 
   // For now, disallow all other attributes.  These should be opt-in, but
   // an analysis of all of them is a future FIXME.
-  if (CausesMV && OldFD && HasNonMultiVersionAttributes(OldFD, MVType)) {
+  if (CausesMV && OldFD && !IsOpenMPVariant &&
+      HasNonMultiVersionAttributes(OldFD, MVType)) {
     S.Diag(OldFD->getLocation(), diag::err_multiversion_no_other_attrs)
         << IsCPUSpecificCPUDispatchMVType;
     S.Diag(NewFD->getLocation(), diag::note_multiversioning_caused_here);
     return true;
   }
 
-  if (HasNonMultiVersionAttributes(NewFD, MVType))
+  if (!IsOpenMPVariant && HasNonMultiVersionAttributes(NewFD, MVType))
     return S.Diag(NewFD->getLocation(), diag::err_multiversion_no_other_attrs)
            << IsCPUSpecificCPUDispatchMVType;
 
   // Only allow transition to MultiVersion if it hasn't been used.
-  if (OldFD && CausesMV && OldFD->isUsed(false))
+  if (OldFD && CausesMV && !IsOpenMPVariant && OldFD->isUsed(false))
     return S.Diag(NewFD->getLocation(), diag::err_multiversion_after_used);
 
   return S.areMultiversionVariantFunctionsCompatible(
@@ -9942,9 +9960,12 @@
                               << IsCPUSpecificCPUDispatchMVType),
       PartialDiagnosticAt(NewFD->getLocation(),
                           S.PDiag(diag::err_multiversion_diff)),
-      /*TemplatesSupported=*/false,
+      /*TemplatesSupported=*/IsOpenMPVariant,
       /*ConstexprSupported=*/!IsCPUSpecificCPUDispatchMVType,
-      /*CLinkageMayDiffer=*/false);
+      /*CLinkageMayDiffer=*/IsOpenMPVariant,
+      /*StorageClassMayDiffer=*/IsOpenMPVariant,
+      /*ConstexprSpecMayDiffer=*/IsOpenMPVariant,
+      /*InlineSpecificationMayDiffer=*/IsOpenMPVariant);
 }
 
 /// Check the validity of a multiversion function declaration that is the
@@ -9955,7 +9976,8 @@
 /// Returns true if there was an error, false otherwise.
 static bool CheckMultiVersionFirstFunction(Sema &S, FunctionDecl *FD,
                                            MultiVersionKind MVType,
-                                           const TargetAttr *TA) {
+                                           const TargetAttr *TA,
+                                           NamedDecl *OldDecl) {
   assert(MVType != MultiVersionKind::None &&
          "Function lacks multiversion attribute");
 
@@ -10075,8 +10097,8 @@
     Sema &S, FunctionDecl *OldFD, FunctionDecl *NewFD,
     MultiVersionKind NewMVType, const TargetAttr *NewTA,
     const CPUDispatchAttr *NewCPUDisp, const CPUSpecificAttr *NewCPUSpec,
-    bool &Redeclaration, NamedDecl *&OldDecl, bool &MergeTypeWithPrevious,
-    LookupResult &Previous) {
+    const OMPDeclareVariantAttr *NewOpenMPVariant, bool &Redeclaration,
+    NamedDecl *&OldDecl, bool &MergeTypeWithPrevious, LookupResult &Previous) {
 
   MultiVersionKind OldMVType = OldFD->getMultiVersionKind();
   // Disallow mixing of multiversioning types.
@@ -10090,6 +10112,18 @@
     return true;
   }
 
+  if (OldMVType == MultiVersionKind::OMPVariant &&
+      NewMVType == MultiVersionKind::None) {
+    assert(!NewOpenMPVariant && "Didn't expect variant attr!");
+    auto *OldOMPVariant = OldFD->getAttr<OMPDeclareVariantAttr>();
+    auto *NewOMPVariant = OldOMPVariant->clone(S.getASTContext());
+    NewOMPVariant->setInherited(true);
+    NewFD->addAttr(NewOMPVariant);
+    NewFD->setIsMultiVersion();
+    NewOpenMPVariant = NewOMPVariant;
+    NewMVType = MultiVersionKind::OMPVariant;
+  }
+
   TargetAttr::ParsedTargetAttr NewParsed;
   if (NewTA) {
     NewParsed = NewTA->parse();
@@ -10125,6 +10159,14 @@
         NewFD->setInvalidDecl();
         return true;
       }
+    } else if (NewMVType == MultiVersionKind::OMPVariant) {
+      auto *CurOMPVariant = CurFD->getAttr<OMPDeclareVariantAttr>();
+      if (!CurOMPVariant) {
+        CurOMPVariant = NewOpenMPVariant->clone(S.getASTContext());
+        CurOMPVariant->setInherited(true);
+        CurFD->addAttr(CurOMPVariant);
+        CurFD->setIsMultiVersion();
+      }
     } else {
       const auto *CurCPUSpec = CurFD->getAttr<CPUSpecificAttr>();
       const auto *CurCPUDisp = CurFD->getAttr<CPUDispatchAttr>();
@@ -10217,7 +10259,6 @@
   return false;
 }
 
-
 /// Check the validity of a mulitversion function declaration.
 /// Also sets the multiversion'ness' of the function itself.
 ///
@@ -10231,10 +10272,12 @@
   const auto *NewTA = NewFD->getAttr<TargetAttr>();
   const auto *NewCPUDisp = NewFD->getAttr<CPUDispatchAttr>();
   const auto *NewCPUSpec = NewFD->getAttr<CPUSpecificAttr>();
+  const auto *NewOpenMPVariant = NewFD->getAttr<OMPDeclareVariantAttr>();
+  unsigned NumMV = bool(NewTA) + bool(NewCPUDisp) + bool(NewCPUSpec) +
+                   bool(NewOpenMPVariant);
 
   // Mixing Multiversioning types is prohibited.
-  if ((NewTA && NewCPUDisp) || (NewTA && NewCPUSpec) ||
-      (NewCPUDisp && NewCPUSpec)) {
+  if (NumMV > 1) {
     S.Diag(NewFD->getLocation(), diag::err_multiversion_types_mixed);
     NewFD->setInvalidDecl();
     return true;
@@ -10255,14 +10298,18 @@
     return false;
   }
 
+  if (auto *USD = dyn_cast_or_null<UsingShadowDecl>(OldDecl))
+    OldDecl = USD->getTargetDecl();
+
   if (!OldDecl || !OldDecl->getAsFunction() ||
-      OldDecl->getDeclContext()->getRedeclContext() !=
-          NewFD->getDeclContext()->getRedeclContext()) {
+      (OldDecl->getDeclContext()->getRedeclContext() !=
+           NewFD->getDeclContext()->getRedeclContext() &&
+       !OldDecl->getAsFunction()->isOpenMPMultiVersion())) {
     // If there's no previous declaration, AND this isn't attempting to cause
     // multiversioning, this isn't an error condition.
     if (MVType == MultiVersionKind::None)
       return false;
-    return CheckMultiVersionFirstFunction(S, NewFD, MVType, NewTA);
+    return CheckMultiVersionFirstFunction(S, NewFD, MVType, NewTA, OldDecl);
   }
 
   FunctionDecl *OldFD = OldDecl->getAsFunction();
@@ -10270,7 +10317,8 @@
   if (!OldFD->isMultiVersion() && MVType == MultiVersionKind::None)
     return false;
 
-  if (OldFD->isMultiVersion() && MVType == MultiVersionKind::None) {
+  if (OldFD->isMultiVersion() && MVType == MultiVersionKind::None &&
+      !OldFD->isOpenMPMultiVersion()) {
     S.Diag(NewFD->getLocation(), diag::err_multiversion_required_in_redecl)
         << (OldFD->getMultiVersionKind() != MultiVersionKind::Target);
     NewFD->setInvalidDecl();
@@ -10287,8 +10335,8 @@
   // appropriate attribute in the current function decl.  Resolve that these are
   // still compatible with previous declarations.
   return CheckMultiVersionAdditionalDecl(
-      S, OldFD, NewFD, MVType, NewTA, NewCPUDisp, NewCPUSpec, Redeclaration,
-      OldDecl, MergeTypeWithPrevious, Previous);
+      S, OldFD, NewFD, MVType, NewTA, NewCPUDisp, NewCPUSpec, NewOpenMPVariant,
+      Redeclaration, OldDecl, MergeTypeWithPrevious, Previous);
 }
 
 /// Perform semantic checking of a new function declaration.
@@ -10317,8 +10365,8 @@
   // Determine whether the type of this function should be merged with
   // a previous visible declaration. This never happens for functions in C++,
   // and always happens in C if the previous declaration was visible.
-  bool MergeTypeWithPrevious = !getLangOpts().CPlusPlus &&
-                               !Previous.isShadowed();
+  bool MergeTypeWithPrevious =
+      !getLangOpts().CPlusPlus && !Previous.isShadowed();
 
   bool Redeclaration = false;
   NamedDecl *OldDecl = nullptr;
@@ -13582,6 +13630,16 @@
   else
     FD = cast<FunctionDecl>(D);
 
+  if (getLangOpts().OpenMP && DeclareVariantScopeAttr) {
+    OMPDeclareVariantAttr *DeclVarAttr = FD->getAttr<OMPDeclareVariantAttr>();
+    if (!DeclVarAttr) {
+      DeclVarAttr = DeclareVariantScopeAttr->clone(getASTContext());
+      FD->addAttr(DeclVarAttr);
+    }
+    DeclVarAttr->setInherited(false);
+    FD->setIsMultiVersion();
+  }
+
   // Do not push if it is a lambda because one is already pushed when building
   // the lambda in ActOnStartOfLambdaDefinition().
   if (!isLambdaCallOperator(FD))
Index: clang/lib/Parse/ParseOpenMP.cpp
===================================================================
--- clang/lib/Parse/ParseOpenMP.cpp
+++ clang/lib/Parse/ParseOpenMP.cpp
@@ -44,6 +44,8 @@
   OMPD_target_teams_distribute_parallel,
   OMPD_mapper,
   OMPD_variant,
+  OMPD_begin,
+  OMPD_begin_declare,
 };
 
 class DeclDirectiveListParserHelper final {
@@ -83,6 +85,7 @@
       .Case("update", OMPD_update)
       .Case("mapper", OMPD_mapper)
       .Case("variant", OMPD_variant)
+      .Case("begin", OMPD_begin)
       .Default(OMPD_unknown);
 }
 
@@ -91,18 +94,21 @@
   // E.g.: OMPD_for OMPD_simd ===> OMPD_for_simd
   // TODO: add other combined directives in topological order.
   static const unsigned F[][3] = {
+      {OMPD_begin, OMPD_declare, OMPD_begin_declare},
+      {OMPD_end, OMPD_declare, OMPD_end_declare},
       {OMPD_cancellation, OMPD_point, OMPD_cancellation_point},
       {OMPD_declare, OMPD_reduction, OMPD_declare_reduction},
       {OMPD_declare, OMPD_mapper, OMPD_declare_mapper},
       {OMPD_declare, OMPD_simd, OMPD_declare_simd},
       {OMPD_declare, OMPD_target, OMPD_declare_target},
       {OMPD_declare, OMPD_variant, OMPD_declare_variant},
+      {OMPD_begin_declare, OMPD_variant, OMPD_begin_declare_variant},
+      {OMPD_end_declare, OMPD_variant, OMPD_end_declare_variant},
       {OMPD_distribute, OMPD_parallel, OMPD_distribute_parallel},
       {OMPD_distribute_parallel, OMPD_for, OMPD_distribute_parallel_for},
       {OMPD_distribute_parallel_for, OMPD_simd,
        OMPD_distribute_parallel_for_simd},
       {OMPD_distribute, OMPD_simd, OMPD_distribute_simd},
-      {OMPD_end, OMPD_declare, OMPD_end_declare},
       {OMPD_end_declare, OMPD_target, OMPD_end_declare_target},
       {OMPD_target, OMPD_data, OMPD_target_data},
       {OMPD_target, OMPD_enter, OMPD_target_enter},
@@ -1046,37 +1052,8 @@
   return false;
 }
 
-/// Parse clauses for '#pragma omp declare variant ( variant-func-id ) clause'.
-void Parser::ParseOMPDeclareVariantClauses(Parser::DeclGroupPtrTy Ptr,
-                                           CachedTokens &Toks,
-                                           SourceLocation Loc) {
-  PP.EnterToken(Tok, /*IsReinject*/ true);
-  PP.EnterTokenStream(Toks, /*DisableMacroExpansion=*/true,
-                      /*IsReinject*/ true);
-  // Consume the previously pushed token.
-  ConsumeAnyToken(/*ConsumeCodeCompletionTok=*/true);
-  ConsumeAnyToken(/*ConsumeCodeCompletionTok=*/true);
-
-  FNContextRAII FnContext(*this, Ptr);
-  // Parse function declaration id.
-  SourceLocation RLoc;
-  // Parse with IsAddressOfOperand set to true to parse methods as DeclRefExprs
-  // instead of MemberExprs.
-  ExprResult AssociatedFunction =
-      ParseOpenMPParensExpr(getOpenMPDirectiveName(OMPD_declare_variant), RLoc,
-                            /*IsAddressOfOperand=*/true);
-  if (!AssociatedFunction.isUsable()) {
-    if (!Tok.is(tok::annot_pragma_openmp_end))
-      while (!SkipUntil(tok::annot_pragma_openmp_end, StopBeforeMatch))
-        ;
-    // Skip the last annot_pragma_openmp_end.
-    (void)ConsumeAnnotationToken();
-    return;
-  }
-  Optional<std::pair<FunctionDecl *, Expr *>> DeclVarData =
-      Actions.checkOpenMPDeclareVariantFunction(
-          Ptr, AssociatedFunction.get(), SourceRange(Loc, Tok.getLocation()));
-
+void Parser::ParseOMPDeclareVariantMatchClause(
+    SourceLocation Loc, SmallVectorImpl<Sema::OMPCtxSelectorData> &Data) {
   // Parse 'match'.
   OpenMPClauseKind CKind = Tok.isAnnotation()
                                ? OMPC_unknown
@@ -1103,7 +1080,6 @@
   }
 
   // Parse inner context selectors.
-  SmallVector<Sema::OMPCtxSelectorData, 4> Data;
   if (!parseOpenMPContextSelectors(Loc, Data)) {
     // Parse ')'.
     (void)T.consumeClose();
@@ -1113,6 +1089,41 @@
           << getOpenMPDirectiveName(OMPD_declare_variant);
     }
   }
+}
+
+/// Parse clauses for '#pragma omp declare variant ( variant-func-id ) clause'.
+void Parser::ParseOMPDeclareVariantClauses(Parser::DeclGroupPtrTy Ptr,
+                                           CachedTokens &Toks,
+                                           SourceLocation Loc) {
+  PP.EnterToken(Tok, /*IsReinject*/ true);
+  PP.EnterTokenStream(Toks, /*DisableMacroExpansion=*/true,
+                      /*IsReinject*/ true);
+  // Consume the previously pushed token.
+  ConsumeAnyToken(/*ConsumeCodeCompletionTok=*/true);
+  ConsumeAnyToken(/*ConsumeCodeCompletionTok=*/true);
+
+  FNContextRAII FnContext(*this, Ptr);
+  // Parse function declaration id.
+  SourceLocation RLoc;
+  // Parse with IsAddressOfOperand set to true to parse methods as DeclRefExprs
+  // instead of MemberExprs.
+  ExprResult AssociatedFunction =
+      ParseOpenMPParensExpr(getOpenMPDirectiveName(OMPD_declare_variant), RLoc,
+                            /*IsAddressOfOperand=*/true);
+  if (!AssociatedFunction.isUsable()) {
+    if (!Tok.is(tok::annot_pragma_openmp_end))
+      while (!SkipUntil(tok::annot_pragma_openmp_end, StopBeforeMatch))
+        ;
+    // Skip the last annot_pragma_openmp_end.
+    (void)ConsumeAnnotationToken();
+    return;
+  }
+  Optional<std::pair<FunctionDecl *, Expr *>> DeclVarData =
+      Actions.checkOpenMPDeclareVariantFunction(
+          Ptr, AssociatedFunction.get(), SourceRange(Loc, Tok.getLocation()));
+
+  SmallVector<Sema::OMPCtxSelectorData, 4> Data;
+  ParseOMPDeclareVariantMatchClause(Loc, Data);
 
   // Skip last tokens.
   while (Tok.isNot(tok::annot_pragma_openmp_end))
@@ -1446,6 +1457,70 @@
     }
     break;
   }
+  case OMPD_begin_declare_variant: {
+    // The syntax is:
+    // { #pragma omp begin declare variant clause }
+    // <function-declaration-or-definition-sequence>
+    // { #pragma omp end declare variant }
+    //
+    ConsumeToken();
+
+    SmallVector<Sema::OMPCtxSelectorData, 4> Data;
+    ParseOMPDeclareVariantMatchClause(Loc, Data);
+
+    // Skip last tokens.
+    while (Tok.isNot(tok::annot_pragma_openmp_end))
+      ConsumeAnyToken();
+
+    bool Elide = Actions.ActOnOpenMPDeclareVariantDirective(
+        nullptr, nullptr, SourceRange(Loc, Tok.getLocation()), Data);
+    if (!Elide)
+      break;
+
+    // Elide all the code till the matching end declare variant was found.
+    // Directive names are only looked for right after an OpenMP pragma
+    // annotation so that ordinary identifiers such as 'begin' or 'end' in
+    // the elided code are not mistaken for directives. Stop at the end of
+    // the translation unit so a missing 'end declare variant' cannot make
+    // the parser loop forever.
+    unsigned Nesting = 1;
+    while (Nesting && Tok.isNot(tok::eof)) {
+      bool AtOpenMPPragma = Tok.is(tok::annot_pragma_openmp);
+      ConsumeAnyToken();
+      if (!AtOpenMPPragma)
+        continue;
+      OpenMPDirectiveKind DK = parseOpenMPDirectiveKind(*this);
+      if (DK == OMPD_end_declare_variant)
+        --Nesting;
+      else if (DK == OMPD_begin_declare_variant)
+        ++Nesting;
+    }
+
+    // The elided scope is closed (or can never be closed); drop it.
+    // TODO: Make this a call in the SEMA
+    getActions().DeclareVariantScopeAttr = nullptr;
+    if (Nesting) {
+      // TODO: Emit a dedicated "missing end declare variant" diagnostic.
+      Diag(Loc, diag::err_omp_unexpected_directive)
+          << 1 << getOpenMPDirectiveName(DKind);
+      // Tok is 'eof' here; the generic skip-to-annot_pragma_openmp_end loop
+      // after this switch would never terminate, so return directly.
+      return nullptr;
+    }
+    break;
+  }
+  case OMPD_end_declare_variant:
+    if (!getActions().DeclareVariantScopeAttr) {
+      // An 'end declare variant' without an open 'begin declare variant'.
+      // TODO: Emit a dedicated "unmatched end declare variant" diagnostic.
+      Diag(Loc, diag::err_omp_unexpected_directive)
+          << 1 << getOpenMPDirectiveName(DKind);
+      break;
+    }
+    // TODO: verify DeclareVariantScopeAttr is null after parsing
+    // TODO: Make this a call in the SEMA
+    getActions().DeclareVariantScopeAttr = nullptr;
+    break;
   case OMPD_declare_variant:
   case OMPD_declare_simd: {
     // The syntax is:
@@ -1932,6 +1983,8 @@
   case OMPD_end_declare_target:
   case OMPD_requires:
   case OMPD_declare_variant:
+  case OMPD_begin_declare_variant:
+  case OMPD_end_declare_variant:
     Diag(Tok, diag::err_omp_unexpected_directive)
         << 1 << getOpenMPDirectiveName(DKind);
     SkipUntil(tok::annot_pragma_openmp_end);
Index: clang/lib/Headers/openmp_wrappers/math.h
===================================================================
--- clang/lib/Headers/openmp_wrappers/math.h
+++ clang/lib/Headers/openmp_wrappers/math.h
@@ -9,9 +9,4 @@
 
 #include <__clang_openmp_math.h>
 
-#ifndef __CLANG_NO_HOST_MATH__
 #include_next <math.h>
-#else
-#undef __CLANG_NO_HOST_MATH__
-#endif
-
Index: clang/lib/Headers/openmp_wrappers/cmath
===================================================================
--- clang/lib/Headers/openmp_wrappers/cmath
+++ clang/lib/Headers/openmp_wrappers/cmath
@@ -9,8 +9,4 @@
 
 #include <__clang_openmp_math.h>
 
-#ifndef __CLANG_NO_HOST_MATH__
 #include_next <cmath>
-#else
-#undef __CLANG_NO_HOST_MATH__
-#endif
Index: clang/lib/Headers/openmp_wrappers/__clang_openmp_math_declares.h
===================================================================
--- clang/lib/Headers/openmp_wrappers/__clang_openmp_math_declares.h
+++ clang/lib/Headers/openmp_wrappers/__clang_openmp_math_declares.h
@@ -18,6 +18,8 @@
 
 #define __CUDA__
 
+#pragma omp begin declare variant match(device = {kind(gpu)})
+
 #if defined(__cplusplus)
   #include <__clang_cuda_math_forward_declares.h>
 #endif
@@ -27,6 +29,8 @@
 /// Provide definitions for these functions.
 #include <__clang_cuda_device_functions.h>
 
+#pragma omp end declare variant
+
 #undef __CUDA__
 
 #endif
Index: clang/lib/Headers/openmp_wrappers/__clang_openmp_math.h
===================================================================
--- clang/lib/Headers/openmp_wrappers/__clang_openmp_math.h
+++ clang/lib/Headers/openmp_wrappers/__clang_openmp_math.h
@@ -8,17 +8,6 @@
  */
 
 #if defined(__NVPTX__) && defined(_OPENMP)
-/// TODO:
-/// We are currently reusing the functionality of the Clang-CUDA code path
-/// as an alternative to the host declarations provided by math.h and cmath.
-/// This is suboptimal.
-///
-/// We should instead declare the device functions in a similar way, e.g.,
-/// through OpenMP 5.0 variants, and afterwards populate the module with the
-/// host declarations by unconditionally including the host math.h or cmath,
-/// respectively. This is actually what the Clang-CUDA code path does, using
-/// __device__ instead of variants to avoid redeclarations and get the desired
-/// overload resolution.
 
 #define __CUDA__
 
@@ -28,8 +17,5 @@
 
 #undef __CUDA__
 
-/// Magic macro for stopping the math.h/cmath host header from being included.
-#define __CLANG_NO_HOST_MATH__
-
 #endif
 
Index: clang/lib/Headers/__clang_cuda_math_forward_declares.h
===================================================================
--- clang/lib/Headers/__clang_cuda_math_forward_declares.h
+++ clang/lib/Headers/__clang_cuda_math_forward_declares.h
@@ -27,30 +27,11 @@
   static __inline__ __attribute__((always_inline)) __attribute__((device))
 #endif
 
-// For C++ 17 we need to include noexcept attribute to be compatible
-// with the header-defined version. This may be removed once
-// variant is supported.
-#if defined(_OPENMP) && defined(__cplusplus) && __cplusplus >= 201703L
-#define __NOEXCEPT noexcept
-#else
-#define __NOEXCEPT
-#endif
-
-#if !(defined(_OPENMP) && defined(__cplusplus))
 __DEVICE__ long abs(long);
 __DEVICE__ long long abs(long long);
 __DEVICE__ double abs(double);
 __DEVICE__ float abs(float);
-#endif
-// While providing the CUDA declarations and definitions for math functions,
-// we may manually define additional functions.
-// TODO: Once variant is supported the additional functions will have
-// to be removed.
-#if defined(_OPENMP) && defined(__cplusplus)
-__DEVICE__ const double abs(const double);
-__DEVICE__ const float abs(const float);
-#endif
-__DEVICE__ int abs(int) __NOEXCEPT;
+__DEVICE__ int abs(int);
 __DEVICE__ double acos(double);
 __DEVICE__ float acos(float);
 __DEVICE__ double acosh(double);
@@ -85,8 +66,8 @@
 __DEVICE__ float exp(float);
 __DEVICE__ double expm1(double);
 __DEVICE__ float expm1(float);
-__DEVICE__ double fabs(double) __NOEXCEPT;
-__DEVICE__ float fabs(float) __NOEXCEPT;
+__DEVICE__ double fabs(double);
+__DEVICE__ float fabs(float);
 __DEVICE__ double fdim(double, double);
 __DEVICE__ float fdim(float, float);
 __DEVICE__ double floor(double);
@@ -136,12 +117,12 @@
 __DEVICE__ bool isnormal(float);
 __DEVICE__ bool isunordered(double, double);
 __DEVICE__ bool isunordered(float, float);
-__DEVICE__ long labs(long) __NOEXCEPT;
+__DEVICE__ long labs(long);
 __DEVICE__ double ldexp(double, int);
 __DEVICE__ float ldexp(float, int);
 __DEVICE__ double lgamma(double);
 __DEVICE__ float lgamma(float);
-__DEVICE__ long long llabs(long long) __NOEXCEPT;
+__DEVICE__ long long llabs(long long);
 __DEVICE__ long long llrint(double);
 __DEVICE__ long long llrint(float);
 __DEVICE__ double log10(double);
@@ -152,9 +133,6 @@
 __DEVICE__ float log2(float);
 __DEVICE__ double logb(double);
 __DEVICE__ float logb(float);
-#if defined(_OPENMP) && defined(__cplusplus)
-__DEVICE__ long double log(long double);
-#endif
 __DEVICE__ double log(double);
 __DEVICE__ float log(float);
 __DEVICE__ long lrint(double);
@@ -302,7 +280,6 @@
 } // namespace std
 #endif
 
-#undef __NOEXCEPT
 #pragma pop_macro("__DEVICE__")
 
 #endif
Index: clang/lib/Headers/__clang_cuda_device_functions.h
===================================================================
--- clang/lib/Headers/__clang_cuda_device_functions.h
+++ clang/lib/Headers/__clang_cuda_device_functions.h
@@ -37,15 +37,6 @@
 #define __FAST_OR_SLOW(fast, slow) slow
 #endif
 
-// For C++ 17 we need to include noexcept attribute to be compatible
-// with the header-defined version. This may be removed once
-// variant is supported.
-#if defined(_OPENMP) && defined(__cplusplus) && __cplusplus >= 201703L
-#define __NOEXCEPT noexcept
-#else
-#define __NOEXCEPT
-#endif
-
 __DEVICE__ int __all(int __a) { return __nvvm_vote_all(__a); }
 __DEVICE__ int __any(int __a) { return __nvvm_vote_any(__a); }
 __DEVICE__ unsigned int __ballot(int __a) { return __nvvm_vote_ballot(__a); }
@@ -1483,8 +1474,8 @@
   return r;
 }
 #endif // CUDA_VERSION >= 9020
-__DEVICE__ int abs(int __a) __NOEXCEPT { return __nv_abs(__a); }
-__DEVICE__ double fabs(double __a) __NOEXCEPT { return __nv_fabs(__a); }
+__DEVICE__ int abs(int __a) { return __nv_abs(__a); }
+__DEVICE__ double fabs(double __a) { return __nv_fabs(__a); }
 __DEVICE__ double acos(double __a) { return __nv_acos(__a); }
 __DEVICE__ float acosf(float __a) { return __nv_acosf(__a); }
 __DEVICE__ double acosh(double __a) { return __nv_acosh(__a); }
@@ -1581,15 +1572,15 @@
 __DEVICE__ double jn(int __n, double __a) { return __nv_jn(__n, __a); }
 __DEVICE__ float jnf(int __n, float __a) { return __nv_jnf(__n, __a); }
 #if defined(__LP64__) || defined(_WIN64)
-__DEVICE__ long labs(long __a) __NOEXCEPT { return __nv_llabs(__a); };
+__DEVICE__ long labs(long __a) { return __nv_llabs(__a); };
 #else
-__DEVICE__ long labs(long __a) __NOEXCEPT { return __nv_abs(__a); };
+__DEVICE__ long labs(long __a) { return __nv_abs(__a); };
 #endif
 __DEVICE__ double ldexp(double __a, int __b) { return __nv_ldexp(__a, __b); }
 __DEVICE__ float ldexpf(float __a, int __b) { return __nv_ldexpf(__a, __b); }
 __DEVICE__ double lgamma(double __a) { return __nv_lgamma(__a); }
 __DEVICE__ float lgammaf(float __a) { return __nv_lgammaf(__a); }
-__DEVICE__ long long llabs(long long __a) __NOEXCEPT { return __nv_llabs(__a); }
+__DEVICE__ long long llabs(long long __a) { return __nv_llabs(__a); }
 __DEVICE__ long long llmax(long long __a, long long __b) {
   return __nv_llmax(__a, __b);
 }
@@ -1787,7 +1778,7 @@
 __DEVICE__ double yn(int __a, double __b) { return __nv_yn(__a, __b); }
 __DEVICE__ float ynf(int __a, float __b) { return __nv_ynf(__a, __b); }
 
-#undef __NOEXCEPT
 #pragma pop_macro("__DEVICE__")
 #pragma pop_macro("__FAST_OR_SLOW")
+
 #endif // __CLANG_CUDA_DEVICE_FUNCTIONS_H__
Index: clang/lib/Headers/__clang_cuda_cmath.h
===================================================================
--- clang/lib/Headers/__clang_cuda_cmath.h
+++ clang/lib/Headers/__clang_cuda_cmath.h
@@ -32,30 +32,15 @@
 
 #ifdef _OPENMP
 #define __DEVICE__ static __attribute__((always_inline))
+#pragma omp begin declare variant match(device = {kind(gpu)})
 #else
 #define __DEVICE__ static __device__ __inline__ __attribute__((always_inline))
 #endif
 
-// For C++ 17 we need to include noexcept attribute to be compatible
-// with the header-defined version. This may be removed once
-// variant is supported.
-#if defined(_OPENMP) && defined(__cplusplus) && __cplusplus >= 201703L
-#define __NOEXCEPT noexcept
-#else
-#define __NOEXCEPT
-#endif
-
-#if !(defined(_OPENMP) && defined(__cplusplus))
 __DEVICE__ long long abs(long long __n) { return ::llabs(__n); }
 __DEVICE__ long abs(long __n) { return ::labs(__n); }
 __DEVICE__ float abs(float __x) { return ::fabsf(__x); }
 __DEVICE__ double abs(double __x) { return ::fabs(__x); }
-#endif
-// TODO: remove once variat is supported.
-#if defined(_OPENMP) && defined(__cplusplus)
-__DEVICE__ const float abs(const float __x) { return ::fabsf((float)__x); }
-__DEVICE__ const double abs(const double __x) { return ::fabs((double)__x); }
-#endif
 __DEVICE__ float acos(float __x) { return ::acosf(__x); }
 __DEVICE__ float asin(float __x) { return ::asinf(__x); }
 __DEVICE__ float atan(float __x) { return ::atanf(__x); }
@@ -64,7 +49,7 @@
 __DEVICE__ float cos(float __x) { return ::cosf(__x); }
 __DEVICE__ float cosh(float __x) { return ::coshf(__x); }
 __DEVICE__ float exp(float __x) { return ::expf(__x); }
-__DEVICE__ float fabs(float __x) __NOEXCEPT { return ::fabsf(__x); }
+__DEVICE__ float fabs(float __x) { return ::fabsf(__x); }
 __DEVICE__ float floor(float __x) { return ::floorf(__x); }
 __DEVICE__ float fmod(float __x, float __y) { return ::fmodf(__x, __y); }
 // TODO: remove when variant is supported
@@ -479,6 +464,10 @@
 } // namespace std
 #endif
 
+#ifdef _OPENMP
+#pragma omp end declare variant
+#endif
+
 #undef __NOEXCEPT
 #undef __DEVICE__
 
Index: clang/lib/CodeGen/CodeGenModule.cpp
===================================================================
--- clang/lib/CodeGen/CodeGenModule.cpp
+++ clang/lib/CodeGen/CodeGenModule.cpp
@@ -982,6 +982,18 @@
   }
 }
 
+static void AppendOpenMPVariantMangling(const CodeGenModule &CGM,
+                                        const FunctionDecl *FD,
+                                        raw_ostream &Out) {
+  for (const OMPDeclareVariantAttr *Attr :
+       FD->specific_attrs<OMPDeclareVariantAttr>()) {
+    if (Attr->isInherited())
+      continue;
+    // TODO: Mangle the name based on the context
+    Out << ".ompvariant";
+  }
+}
+
 static std::string getMangledNameImpl(const CodeGenModule &CGM, GlobalDecl GD,
                                       const NamedDecl *ND,
                                       bool OmitMultiVersionMangling = false) {
@@ -1022,6 +1034,9 @@
       case MultiVersionKind::Target:
         AppendTargetMangling(CGM, FD->getAttr<TargetAttr>(), Out);
         break;
+      case MultiVersionKind::OMPVariant:
+        AppendOpenMPVariantMangling(CGM, FD, Out);
+        break;
       case MultiVersionKind::None:
         llvm_unreachable("None multiversion type isn't valid here");
       }
@@ -2534,7 +2549,8 @@
     return;
   }
 
-    // Check if this must be emitted as declare variant.
+  // Check if this must be emitted as declare variant.
+  // TODO: See the TODO at the other emitDeclareVariant call.
   if (LangOpts.OpenMP && isa<FunctionDecl>(Global) && OpenMPRuntime &&
       OpenMPRuntime->emitDeclareVariant(GD, /*IsForDefinition=*/false))
     return;
@@ -2856,6 +2872,10 @@
   for (GlobalDecl GD : MultiVersionFuncs) {
     SmallVector<CodeGenFunction::MultiVersionResolverOption, 10> Options;
     const FunctionDecl *FD = cast<FunctionDecl>(GD.getDecl());
+    // OpenMP multi versioning is (for now) resolved at compile time, no
+    // resolver function necessary (yet).
+    if (FD->isOpenMPMultiVersion())
+      continue;
     getContext().forEachMultiversionedFunctionVersion(
         FD, [this, &GD, &Options](const FunctionDecl *CurFD) {
           GlobalDecl CurGD{
@@ -3104,10 +3124,17 @@
     }
     // Check if this must be emitted as declare variant and emit reference to
     // the the declare variant function.
+    // TODO: We should introduce function aliases for `omp declare variant`
+    //       directives such that we can treat them through the same overload
+    //       resolution scheme (via multi versioning) as `omp begin declare
+    //       variant` functions. For an `omp declare variant(VARIANT) ...`
+    //       that is attached to a BASE function we would create a global alias
+    //       VARIANT = BASE which will participate in the multi version overload
+    //       resolution. If picked, there is no need to emit them explicitly.
     if (LangOpts.OpenMP && OpenMPRuntime)
       (void)OpenMPRuntime->emitDeclareVariant(GD, /*IsForDefinition=*/true);
 
-    if (FD->isMultiVersion()) {
+    if (FD->isMultiVersion() && !FD->isOpenMPMultiVersion()) {
       const auto *TA = FD->getAttr<TargetAttr>();
       if (TA && TA->isDefaultVersion())
         UpdateMultiVersionNames(GD, FD);
@@ -4385,6 +4412,7 @@
 void CodeGenModule::EmitGlobalFunctionDefinition(GlobalDecl GD,
                                                  llvm::GlobalValue *GV) {
   // Check if this must be emitted as declare variant.
+  // TODO: See the TODO at the other emitDeclareVariant call.
   if (LangOpts.OpenMP && OpenMPRuntime &&
       OpenMPRuntime->emitDeclareVariant(GD, /*IsForDefinition=*/true))
     return;
Index: clang/lib/CodeGen/CGOpenMPRuntime.cpp
===================================================================
--- clang/lib/CodeGen/CGOpenMPRuntime.cpp
+++ clang/lib/CodeGen/CGOpenMPRuntime.cpp
@@ -11042,231 +11042,6 @@
   return Address(Addr, Align);
 }
 
-namespace {
-using OMPContextSelectorData =
-    OpenMPCtxSelectorData<ArrayRef<StringRef>, llvm::APSInt>;
-using CompleteOMPContextSelectorData = SmallVector<OMPContextSelectorData, 4>;
-} // anonymous namespace
-
-/// Checks current context and returns true if it matches the context selector.
-template <OpenMPContextSelectorSetKind CtxSet, OpenMPContextSelectorKind Ctx,
-          typename... Arguments>
-static bool checkContext(const OMPContextSelectorData &Data,
-                         Arguments... Params) {
-  assert(Data.CtxSet != OMP_CTX_SET_unknown && Data.Ctx != OMP_CTX_unknown &&
-         "Unknown context selector or context selector set.");
-  return false;
-}
-
-/// Checks for implementation={vendor(<vendor>)} context selector.
-/// \returns true iff <vendor>="llvm", false otherwise.
-template <>
-bool checkContext<OMP_CTX_SET_implementation, OMP_CTX_vendor>(
-    const OMPContextSelectorData &Data) {
-  return llvm::all_of(Data.Names,
-                      [](StringRef S) { return !S.compare_lower("llvm"); });
-}
-
-/// Checks for device={kind(<kind>)} context selector.
-/// \returns true if <kind>="host" and compilation is for host.
-/// true if <kind>="nohost" and compilation is for device.
-/// true if <kind>="cpu" and compilation is for Arm, X86 or PPC CPU.
-/// true if <kind>="gpu" and compilation is for NVPTX or AMDGCN.
-/// false otherwise.
-template <>
-bool checkContext<OMP_CTX_SET_device, OMP_CTX_kind, CodeGenModule &>(
-    const OMPContextSelectorData &Data, CodeGenModule &CGM) {
-  for (StringRef Name : Data.Names) {
-    if (!Name.compare_lower("host")) {
-      if (CGM.getLangOpts().OpenMPIsDevice)
-        return false;
-      continue;
-    }
-    if (!Name.compare_lower("nohost")) {
-      if (!CGM.getLangOpts().OpenMPIsDevice)
-        return false;
-      continue;
-    }
-    switch (CGM.getTriple().getArch()) {
-    case llvm::Triple::arm:
-    case llvm::Triple::armeb:
-    case llvm::Triple::aarch64:
-    case llvm::Triple::aarch64_be:
-    case llvm::Triple::aarch64_32:
-    case llvm::Triple::ppc:
-    case llvm::Triple::ppc64:
-    case llvm::Triple::ppc64le:
-    case llvm::Triple::x86:
-    case llvm::Triple::x86_64:
-      if (Name.compare_lower("cpu"))
-        return false;
-      break;
-    case llvm::Triple::amdgcn:
-    case llvm::Triple::nvptx:
-    case llvm::Triple::nvptx64:
-      if (Name.compare_lower("gpu"))
-        return false;
-      break;
-    case llvm::Triple::UnknownArch:
-    case llvm::Triple::arc:
-    case llvm::Triple::avr:
-    case llvm::Triple::bpfel:
-    case llvm::Triple::bpfeb:
-    case llvm::Triple::hexagon:
-    case llvm::Triple::mips:
-    case llvm::Triple::mipsel:
-    case llvm::Triple::mips64:
-    case llvm::Triple::mips64el:
-    case llvm::Triple::msp430:
-    case llvm::Triple::r600:
-    case llvm::Triple::riscv32:
-    case llvm::Triple::riscv64:
-    case llvm::Triple::sparc:
-    case llvm::Triple::sparcv9:
-    case llvm::Triple::sparcel:
-    case llvm::Triple::systemz:
-    case llvm::Triple::tce:
-    case llvm::Triple::tcele:
-    case llvm::Triple::thumb:
-    case llvm::Triple::thumbeb:
-    case llvm::Triple::xcore:
-    case llvm::Triple::le32:
-    case llvm::Triple::le64:
-    case llvm::Triple::amdil:
-    case llvm::Triple::amdil64:
-    case llvm::Triple::hsail:
-    case llvm::Triple::hsail64:
-    case llvm::Triple::spir:
-    case llvm::Triple::spir64:
-    case llvm::Triple::kalimba:
-    case llvm::Triple::shave:
-    case llvm::Triple::lanai:
-    case llvm::Triple::wasm32:
-    case llvm::Triple::wasm64:
-    case llvm::Triple::renderscript32:
-    case llvm::Triple::renderscript64:
-      return false;
-    }
-  }
-  return true;
-}
-
-bool matchesContext(CodeGenModule &CGM,
-                    const CompleteOMPContextSelectorData &ContextData) {
-  for (const OMPContextSelectorData &Data : ContextData) {
-    switch (Data.Ctx) {
-    case OMP_CTX_vendor:
-      assert(Data.CtxSet == OMP_CTX_SET_implementation &&
-             "Expected implementation context selector set.");
-      if (!checkContext<OMP_CTX_SET_implementation, OMP_CTX_vendor>(Data))
-        return false;
-      break;
-    case OMP_CTX_kind:
-      assert(Data.CtxSet == OMP_CTX_SET_device &&
-             "Expected device context selector set.");
-      if (!checkContext<OMP_CTX_SET_device, OMP_CTX_kind, CodeGenModule &>(Data,
-                                                                           CGM))
-        return false;
-      break;
-    case OMP_CTX_unknown:
-      llvm_unreachable("Unknown context selector kind.");
-    }
-  }
-  return true;
-}
-
-static CompleteOMPContextSelectorData
-translateAttrToContextSelectorData(ASTContext &C,
-                                   const OMPDeclareVariantAttr *A) {
-  CompleteOMPContextSelectorData Data;
-  for (unsigned I = 0, E = A->scores_size(); I < E; ++I) {
-    Data.emplace_back();
-    auto CtxSet = static_cast<OpenMPContextSelectorSetKind>(
-        *std::next(A->ctxSelectorSets_begin(), I));
-    auto Ctx = static_cast<OpenMPContextSelectorKind>(
-        *std::next(A->ctxSelectors_begin(), I));
-    Data.back().CtxSet = CtxSet;
-    Data.back().Ctx = Ctx;
-    const Expr *Score = *std::next(A->scores_begin(), I);
-    Data.back().Score = Score->EvaluateKnownConstInt(C);
-    switch (Ctx) {
-    case OMP_CTX_vendor:
-      assert(CtxSet == OMP_CTX_SET_implementation &&
-             "Expected implementation context selector set.");
-      Data.back().Names =
-          llvm::makeArrayRef(A->implVendors_begin(), A->implVendors_end());
-      break;
-    case OMP_CTX_kind:
-      assert(CtxSet == OMP_CTX_SET_device &&
-             "Expected device context selector set.");
-      Data.back().Names =
-          llvm::makeArrayRef(A->deviceKinds_begin(), A->deviceKinds_end());
-      break;
-    case OMP_CTX_unknown:
-      llvm_unreachable("Unknown context selector kind.");
-    }
-  }
-  return Data;
-}
-
-static bool isStrictSubset(const CompleteOMPContextSelectorData &LHS,
-                           const CompleteOMPContextSelectorData &RHS) {
-  llvm::SmallDenseMap<std::pair<int, int>, llvm::StringSet<>, 4> RHSData;
-  for (const OMPContextSelectorData &D : RHS) {
-    auto &Pair = RHSData.FindAndConstruct(std::make_pair(D.CtxSet, D.Ctx));
-    Pair.getSecond().insert(D.Names.begin(), D.Names.end());
-  }
-  bool AllSetsAreEqual = true;
-  for (const OMPContextSelectorData &D : LHS) {
-    auto It = RHSData.find(std::make_pair(D.CtxSet, D.Ctx));
-    if (It == RHSData.end())
-      return false;
-    if (D.Names.size() > It->getSecond().size())
-      return false;
-    if (llvm::set_union(It->getSecond(), D.Names))
-      return false;
-    AllSetsAreEqual =
-        AllSetsAreEqual && (D.Names.size() == It->getSecond().size());
-  }
-
-  return LHS.size() != RHS.size() || !AllSetsAreEqual;
-}
-
-static bool greaterCtxScore(const CompleteOMPContextSelectorData &LHS,
-                            const CompleteOMPContextSelectorData &RHS) {
-  // Score is calculated as sum of all scores + 1.
-  llvm::APSInt LHSScore(llvm::APInt(64, 1), /*isUnsigned=*/false);
-  bool RHSIsSubsetOfLHS = isStrictSubset(RHS, LHS);
-  if (RHSIsSubsetOfLHS) {
-    LHSScore = llvm::APSInt::get(0);
-  } else {
-    for (const OMPContextSelectorData &Data : LHS) {
-      if (Data.Score.getBitWidth() > LHSScore.getBitWidth()) {
-        LHSScore = LHSScore.extend(Data.Score.getBitWidth()) + Data.Score;
-      } else if (Data.Score.getBitWidth() < LHSScore.getBitWidth()) {
-        LHSScore += Data.Score.extend(LHSScore.getBitWidth());
-      } else {
-        LHSScore += Data.Score;
-      }
-    }
-  }
-  llvm::APSInt RHSScore(llvm::APInt(64, 1), /*isUnsigned=*/false);
-  if (!RHSIsSubsetOfLHS && isStrictSubset(LHS, RHS)) {
-    RHSScore = llvm::APSInt::get(0);
-  } else {
-    for (const OMPContextSelectorData &Data : RHS) {
-      if (Data.Score.getBitWidth() > RHSScore.getBitWidth()) {
-        RHSScore = RHSScore.extend(Data.Score.getBitWidth()) + Data.Score;
-      } else if (Data.Score.getBitWidth() < RHSScore.getBitWidth()) {
-        RHSScore += Data.Score.extend(RHSScore.getBitWidth());
-      } else {
-        RHSScore += Data.Score;
-      }
-    }
-  }
-  return llvm::APSInt::compareValues(LHSScore, RHSScore) >= 0;
-}
-
 /// Finds the variant function that matches current context with its context
 /// selector.
 static const FunctionDecl *getDeclareVariantFunction(CodeGenModule &CGM,
@@ -11275,19 +11050,8 @@
     return FD;
   // Iterate through all DeclareVariant attributes and check context selectors.
   const OMPDeclareVariantAttr *TopMostAttr = nullptr;
-  CompleteOMPContextSelectorData TopMostData;
-  for (const auto *A : FD->specific_attrs<OMPDeclareVariantAttr>()) {
-    CompleteOMPContextSelectorData Data =
-        translateAttrToContextSelectorData(CGM.getContext(), A);
-    if (!matchesContext(CGM, Data))
-      continue;
-    // If the attribute matches the context, find the attribute with the highest
-    // score.
-    if (!TopMostAttr || !greaterCtxScore(TopMostData, Data)) {
-      TopMostAttr = A;
-      TopMostData.swap(Data);
-    }
-  }
+  for (const auto *A : FD->specific_attrs<OMPDeclareVariantAttr>())
+    TopMostAttr = getBetterOpenMPContextMatch(CGM.getContext(), TopMostAttr, A);
   if (!TopMostAttr)
     return FD;
   return cast<FunctionDecl>(
Index: clang/lib/AST/StmtOpenMP.cpp
===================================================================
--- clang/lib/AST/StmtOpenMP.cpp
+++ clang/lib/AST/StmtOpenMP.cpp
@@ -13,6 +13,7 @@
 #include "clang/AST/StmtOpenMP.h"
 
 #include "clang/AST/ASTContext.h"
+#include "llvm/ADT/SetOperations.h"
 
 using namespace clang;
 
@@ -2239,3 +2240,267 @@
   return new (Mem)
       OMPTargetTeamsDistributeSimdDirective(CollapsedNum, NumClauses);
 }
+
+// TODO: The same context selector data exists in several representations;
+//       reusing one of them would avoid the conversions done below.
+// TODO: It is unclear where this checking code should live. It is used in
+//       many places and would probably fit better in OMPDeclareVariantAttr.
+
+/// One (selector set, selector) entry: its names and its evaluated score.
+using OMPContextSelectorData =
+    OpenMPCtxSelectorData<llvm::ArrayRef<llvm::StringRef>, llvm::APSInt>;
+/// Every selector entry extracted from a single OMPDeclareVariantAttr.
+using CompleteOMPContextSelectorData =
+    llvm::SmallVector<OMPContextSelectorData, 4>;
+
+/// Fallback for every (selector set, selector) combination that has no
+/// dedicated specialization below: such a selector never matches, so this
+/// always returns false. Extra arguments are accepted and ignored.
+template <OpenMPContextSelectorSetKind CtxSet, OpenMPContextSelectorKind Ctx,
+          typename... Arguments>
+static bool checkContext(const OMPContextSelectorData &Data,
+                         Arguments... /*Params*/) {
+  const bool IsKnownSelector =
+      Data.CtxSet != OMP_CTX_SET_unknown && Data.Ctx != OMP_CTX_unknown;
+  (void)IsKnownSelector; // Only read by the assertion below.
+  assert(IsKnownSelector &&
+         "Unknown context selector or context selector set.");
+  return false;
+}
+
+/// Checks the implementation={vendor(<vendor>...)} context selector.
+/// \returns true iff every listed vendor equals "llvm", ignoring case. An
+/// empty vendor list is accepted trivially.
+template <>
+bool checkContext<OMP_CTX_SET_implementation, OMP_CTX_vendor>(
+    const OMPContextSelectorData &Data) {
+  for (StringRef Vendor : Data.Names)
+    if (Vendor.compare_lower("llvm") != 0)
+      return false;
+  return true;
+}
+
+/// Checks the device={kind(<kind>...)} context selector against the current
+/// compilation. Every listed kind must hold for the selector to match:
+///  - "host" holds when not compiling for an OpenMP device,
+///  - "nohost" holds when compiling for an OpenMP device,
+///  - "cpu" holds on ARM/AArch64 (arm, armeb, aarch64*), PPC and X86 targets,
+///  - "gpu" holds on AMDGCN and NVPTX targets,
+///  - no other kind holds, and no kind besides "host"/"nohost" holds on a
+///    target outside the two lists above.
+/// An empty kind list is accepted trivially.
+/// NOTE(review): thumb/thumbeb are ARM targets but are not treated as "cpu"
+/// here; confirm that this is intended.
+template <>
+bool checkContext<OMP_CTX_SET_device, OMP_CTX_kind, const LangOptions &,
+                  const TargetInfo &>(const OMPContextSelectorData &Data,
+                                      const LangOptions &LO,
+                                      const TargetInfo &TI) {
+  const bool IsDeviceCompilation = LO.OpenMPIsDevice;
+  const auto Arch = TI.getTriple().getArch();
+  for (StringRef Kind : Data.Names) {
+    // "host" and "nohost" depend only on which side is being compiled.
+    if (Kind.compare_lower("host") == 0) {
+      if (IsDeviceCompilation)
+        return false;
+      continue;
+    }
+    if (Kind.compare_lower("nohost") == 0) {
+      if (!IsDeviceCompilation)
+        return false;
+      continue;
+    }
+    // Every other kind is decided by the target architecture. The switch
+    // names each architecture and has no default label, so -Wswitch flags
+    // any newly added one.
+    switch (Arch) {
+    case llvm::Triple::arm:
+    case llvm::Triple::armeb:
+    case llvm::Triple::aarch64:
+    case llvm::Triple::aarch64_be:
+    case llvm::Triple::aarch64_32:
+    case llvm::Triple::ppc:
+    case llvm::Triple::ppc64:
+    case llvm::Triple::ppc64le:
+    case llvm::Triple::x86:
+    case llvm::Triple::x86_64:
+      if (Kind.compare_lower("cpu") != 0)
+        return false;
+      break;
+    case llvm::Triple::amdgcn:
+    case llvm::Triple::nvptx:
+    case llvm::Triple::nvptx64:
+      if (Kind.compare_lower("gpu") != 0)
+        return false;
+      break;
+    case llvm::Triple::UnknownArch:
+    case llvm::Triple::arc:
+    case llvm::Triple::avr:
+    case llvm::Triple::bpfel:
+    case llvm::Triple::bpfeb:
+    case llvm::Triple::hexagon:
+    case llvm::Triple::mips:
+    case llvm::Triple::mipsel:
+    case llvm::Triple::mips64:
+    case llvm::Triple::mips64el:
+    case llvm::Triple::msp430:
+    case llvm::Triple::r600:
+    case llvm::Triple::riscv32:
+    case llvm::Triple::riscv64:
+    case llvm::Triple::sparc:
+    case llvm::Triple::sparcv9:
+    case llvm::Triple::sparcel:
+    case llvm::Triple::systemz:
+    case llvm::Triple::tce:
+    case llvm::Triple::tcele:
+    case llvm::Triple::thumb:
+    case llvm::Triple::thumbeb:
+    case llvm::Triple::xcore:
+    case llvm::Triple::le32:
+    case llvm::Triple::le64:
+    case llvm::Triple::amdil:
+    case llvm::Triple::amdil64:
+    case llvm::Triple::hsail:
+    case llvm::Triple::hsail64:
+    case llvm::Triple::spir:
+    case llvm::Triple::spir64:
+    case llvm::Triple::kalimba:
+    case llvm::Triple::shave:
+    case llvm::Triple::lanai:
+    case llvm::Triple::wasm32:
+    case llvm::Triple::wasm64:
+    case llvm::Triple::renderscript32:
+    case llvm::Triple::renderscript64:
+      return false;
+    }
+  }
+  return true;
+}
+
+/// Converts the selector data stored in a declare variant attribute into one
+/// OMPContextSelectorData entry per scored selector. Each entry's score is the
+/// constant-evaluated score expression, and its Names are an ArrayRef view of
+/// the attribute's vendor or device-kind list. A null \p A yields an empty
+/// list.
+static CompleteOMPContextSelectorData
+translateAttrToContextSelectorData(ASTContext &C,
+                                   const OMPDeclareVariantAttr *A) {
+  CompleteOMPContextSelectorData Result;
+  if (A == nullptr)
+    return Result;
+  const unsigned NumSelectors = A->scores_size();
+  for (unsigned Idx = 0; Idx != NumSelectors; ++Idx) {
+    const auto SetKind = static_cast<OpenMPContextSelectorSetKind>(
+        *std::next(A->ctxSelectorSets_begin(), Idx));
+    const auto SelKind = static_cast<OpenMPContextSelectorKind>(
+        *std::next(A->ctxSelectors_begin(), Idx));
+    const Expr *ScoreExpr = *std::next(A->scores_begin(), Idx);
+
+    Result.emplace_back();
+    OMPContextSelectorData &Entry = Result.back();
+    Entry.CtxSet = SetKind;
+    Entry.Ctx = SelKind;
+    Entry.Score = ScoreExpr->EvaluateKnownConstInt(C);
+    switch (SelKind) {
+    case OMP_CTX_vendor:
+      assert(SetKind == OMP_CTX_SET_implementation &&
+             "Expected implementation context selector set.");
+      Entry.Names =
+          llvm::makeArrayRef(A->implVendors_begin(), A->implVendors_end());
+      break;
+    case OMP_CTX_kind:
+      assert(SetKind == OMP_CTX_SET_device &&
+             "Expected device context selector set.");
+      Entry.Names =
+          llvm::makeArrayRef(A->deviceKinds_begin(), A->deviceKinds_end());
+      break;
+    case OMP_CTX_unknown:
+      llvm_unreachable("Unknown context selector kind.");
+    }
+  }
+  return Result;
+}
+
+/// Returns true iff every selector entry in \p ContextData holds for the
+/// compilation described by \p LO and \p TI. An empty list matches trivially.
+static bool
+matchesOpenMPContextImpl(const CompleteOMPContextSelectorData &ContextData,
+                         const LangOptions &LO, const TargetInfo &TI) {
+  auto SelectorHolds = [&LO, &TI](const OMPContextSelectorData &Sel) -> bool {
+    switch (Sel.Ctx) {
+    case OMP_CTX_vendor:
+      assert(Sel.CtxSet == OMP_CTX_SET_implementation &&
+             "Expected implementation context selector set.");
+      return checkContext<OMP_CTX_SET_implementation, OMP_CTX_vendor>(Sel);
+    case OMP_CTX_kind:
+      assert(Sel.CtxSet == OMP_CTX_SET_device &&
+             "Expected device context selector set.");
+      return checkContext<OMP_CTX_SET_device, OMP_CTX_kind,
+                          const LangOptions &, const TargetInfo &>(Sel, LO,
+                                                                   TI);
+    case OMP_CTX_unknown:
+      llvm_unreachable("Unknown context selector kind.");
+    }
+    return true;
+  };
+  return llvm::all_of(ContextData, SelectorHolds);
+}
+
+/// Returns true iff \p LHS is a strict subset of \p RHS. This requires that
+/// every (selector set, selector) entry of LHS also appears in RHS, and that
+/// RHS contains all of the entry's names. In addition, either the entry
+/// counts differ or some RHS name set is larger than the matching LHS entry.
+static bool isStrictSubset(const CompleteOMPContextSelectorData &LHS,
+                           const CompleteOMPContextSelectorData &RHS) {
+  // Collect the RHS names per (selector set, selector) pair.
+  llvm::SmallDenseMap<std::pair<int, int>, llvm::StringSet<>, 4> NamesInRHS;
+  for (const OMPContextSelectorData &Sel : RHS)
+    NamesInRHS[std::make_pair(Sel.CtxSet, Sel.Ctx)].insert(Sel.Names.begin(),
+                                                           Sel.Names.end());
+
+  bool EverySetEqual = true;
+  for (const OMPContextSelectorData &Sel : LHS) {
+    auto Found = NamesInRHS.find(std::make_pair(Sel.CtxSet, Sel.Ctx));
+    if (Found == NamesInRHS.end())
+      return false;
+    const llvm::StringSet<> &Superset = Found->getSecond();
+    if (Sel.Names.size() > Superset.size())
+      return false;
+    // Every LHS name must already be present among the RHS names.
+    bool AllNamesPresent = llvm::all_of(Sel.Names, [&Superset](StringRef N) {
+      return Superset.count(N) != 0;
+    });
+    if (!AllNamesPresent)
+      return false;
+    if (Sel.Names.size() != Superset.size())
+      EverySetEqual = false;
+  }
+
+  return !EverySetEqual || LHS.size() != RHS.size();
+}
+
+/// Sums the scores of all selector entries in \p Data and adds one. The
+/// accumulator starts as a 64-bit signed value and is widened whenever a
+/// selector score is wider than it.
+/// NOTE(review): llvm::APSInt's +/+= assert that both operands have the same
+/// signedness; confirm that score expressions always evaluate as signed.
+static llvm::APSInt
+sumOpenMPContextScores(const CompleteOMPContextSelectorData &Data) {
+  llvm::APSInt Score(llvm::APInt(64, 1), /*isUnsigned=*/false);
+  for (const OMPContextSelectorData &Sel : Data) {
+    unsigned SelWidth = Sel.Score.getBitWidth();
+    if (SelWidth > Score.getBitWidth())
+      Score = Score.extend(SelWidth) + Sel.Score;
+    else if (SelWidth < Score.getBitWidth())
+      Score += Sel.Score.extend(Score.getBitWidth());
+    else
+      Score += Sel.Score;
+  }
+  return Score;
+}
+
+/// Picks the better of two declare variant attributes for the current
+/// compilation context. Either input may be nullptr. The result is nullptr
+/// when no usable candidate exists. Rules, in order:
+///  1. If one input is null, the other is returned iff its context matches.
+///  2. If both are inherited, nullptr is returned; otherwise an inherited
+///     attribute loses against the non-inherited one.
+///  3. A matching attribute wins over a non-matching one; if neither
+///     matches, nullptr is returned.
+///  4. Otherwise the higher score wins, and ties go to \p LHSAttr.
+const OMPDeclareVariantAttr *
+clang::getBetterOpenMPContextMatch(ASTContext &C,
+                                   const OMPDeclareVariantAttr *LHSAttr,
+                                   const OMPDeclareVariantAttr *RHSAttr) {
+  const LangOptions &LO = C.getLangOpts();
+  const TargetInfo &TI = C.getTargetInfo();
+
+  // Callers may fold over a list of attributes starting from nullptr, so a
+  // null input must not be dereferenced. It also cannot be passed on:
+  // translateAttrToContextSelectorData maps nullptr to an empty selector
+  // list, and matchesOpenMPContextImpl treats an empty list as a match.
+  if (!LHSAttr || !RHSAttr) {
+    const OMPDeclareVariantAttr *Candidate = LHSAttr ? LHSAttr : RHSAttr;
+    if (!Candidate ||
+        !matchesOpenMPContextImpl(
+            translateAttrToContextSelectorData(C, Candidate), LO, TI))
+      return nullptr;
+    return Candidate;
+  }
+
+  // NOTE(review): an inherited attribute loses here even when the other
+  // attribute does not match the current context; confirm this is intended.
+  if (LHSAttr->isInherited() && RHSAttr->isInherited())
+    return nullptr;
+  if (LHSAttr->isInherited())
+    return RHSAttr;
+  if (RHSAttr->isInherited())
+    return LHSAttr;
+
+  const CompleteOMPContextSelectorData LHS =
+      translateAttrToContextSelectorData(C, LHSAttr);
+  const CompleteOMPContextSelectorData RHS =
+      translateAttrToContextSelectorData(C, RHSAttr);
+  bool LHSMatch = matchesOpenMPContextImpl(LHS, LO, TI);
+  bool RHSMatch = matchesOpenMPContextImpl(RHS, LO, TI);
+  if (!LHSMatch && !RHSMatch)
+    return nullptr;
+  if (LHSMatch && !RHSMatch)
+    return LHSAttr;
+  if (!LHSMatch && RHSMatch)
+    return RHSAttr;
+  assert(LHSMatch && RHSMatch && "broken invariant");
+
+  // Score is calculated as sum of all scores + 1. If the other selector is a
+  // strict subset of this one, the score is 0 instead.
+  // NOTE(review): when RHS is a strict subset of LHS this zeroes LHS (the
+  // superset), so the less specific selector wins. Confirm this direction
+  // against the OpenMP scoring rules.
+  bool RHSIsSubsetOfLHS = isStrictSubset(RHS, LHS);
+  llvm::APSInt LHSScore =
+      RHSIsSubsetOfLHS ? llvm::APSInt::get(0) : sumOpenMPContextScores(LHS);
+  llvm::APSInt RHSScore = (!RHSIsSubsetOfLHS && isStrictSubset(LHS, RHS))
+                              ? llvm::APSInt::get(0)
+                              : sumOpenMPContextScores(RHS);
+  return llvm::APSInt::compareValues(LHSScore, RHSScore) >= 0 ? LHSAttr
+                                                              : RHSAttr;
+}
+
+/// Returns true iff every selector of \p A holds for the current compilation
+/// (language options and target of \p C). A null \p A has no selectors and
+/// therefore yields true.
+bool clang::isOpenMPContextMatch(ASTContext &C,
+                                 const OMPDeclareVariantAttr *A) {
+  return matchesOpenMPContextImpl(translateAttrToContextSelectorData(C, A),
+                                  C.getLangOpts(), C.getTargetInfo());
+}
Index: clang/lib/AST/Decl.cpp
===================================================================
--- clang/lib/AST/Decl.cpp
+++ clang/lib/AST/Decl.cpp
@@ -3074,6 +3074,9 @@
     return MultiVersionKind::CPUDispatch;
   if (hasAttr<CPUSpecificAttr>())
     return MultiVersionKind::CPUSpecific;
+  if (hasAttr<OMPDeclareVariantAttr>() &&
+      !getAttr<OMPDeclareVariantAttr>()->getVariantFuncRef())
+    return MultiVersionKind::OMPVariant;
   return MultiVersionKind::None;
 }
 
@@ -3089,6 +3092,11 @@
   return isMultiVersion() && hasAttr<TargetAttr>();
 }
 
+/// A function is an OpenMP begin/end declare variant multiversion if it is
+/// multiversioned and its declare variant attribute carries no variant
+/// function reference.
+bool FunctionDecl::isOpenMPMultiVersion() const {
+  if (!isMultiVersion())
+    return false;
+  const auto *VariantAttr = getAttr<OMPDeclareVariantAttr>();
+  return VariantAttr && !VariantAttr->getVariantFuncRef();
+}
+
 void
 FunctionDecl::setPreviousDeclaration(FunctionDecl *PrevDecl) {
   redeclarable_base::setPreviousDecl(PrevDecl);
Index: clang/include/clang/Sema/Sema.h
===================================================================
--- clang/include/clang/Sema/Sema.h
+++ clang/include/clang/Sema/Sema.h
@@ -9332,6 +9332,9 @@
   // OpenMP directives and clauses.
   //
 private:
+  /// Copies declare variant attributes from the template TD to the function FD.
+  void inheritOpenMPVariantAttrs(FunctionDecl *FD,
+                                 const FunctionTemplateDecl &TD);
   void *VarDataSharingAttributesStack;
   /// Number of nested '#pragma omp declare target' directives.
   unsigned DeclareTargetNestingLevel = 0;
@@ -9401,6 +9404,9 @@
   using OMPCtxSelectorData =
       OpenMPCtxSelectorData<SmallVector<OMPCtxStringType, 4>, ExprResult>;
 
+  /// The declare variant attribute of the enclosing 'begin declare variant' /
+  /// 'end declare variant' region, if we are inside one.
+  OMPDeclareVariantAttr *DeclareVariantScopeAttr = nullptr;
+
   /// Checks if the variant/multiversion functions are compatible.
   bool areMultiversionVariantFunctionsCompatible(
       const FunctionDecl *OldFD, const FunctionDecl *NewFD,
@@ -9408,7 +9414,9 @@
       const PartialDiagnosticAt &NoteCausedDiagIDAt,
       const PartialDiagnosticAt &NoSupportDiagIDAt,
       const PartialDiagnosticAt &DiffDiagIDAt, bool TemplatesSupported,
-      bool ConstexprSupported, bool CLinkageMayDiffer);
+      bool ConstexprSupported, bool CLinkageMayDiffer,
+      bool StorageClassMayDiffer, bool ConstexprSpecMayDiffer,
+      bool InlineSpecificationMayDiffer);
 
   /// Function tries to capture lambda's captured variables in the OpenMP region
   /// before the original lambda is captured.
@@ -9883,7 +9891,7 @@
   /// must be used instead of the original one, specified in \p DG.
   /// \param Data Set of context-specific data for the specified context
   /// selector.
-  void ActOnOpenMPDeclareVariantDirective(FunctionDecl *FD, Expr *VariantRef,
+  bool ActOnOpenMPDeclareVariantDirective(FunctionDecl *FD, Expr *VariantRef,
                                           SourceRange SR,
                                           ArrayRef<OMPCtxSelectorData> Data);
 
Index: clang/include/clang/Parse/Parser.h
===================================================================
--- clang/include/clang/Parse/Parser.h
+++ clang/include/clang/Parse/Parser.h
@@ -2860,6 +2860,10 @@
   parseOpenMPContextSelectors(SourceLocation Loc,
                               SmallVectorImpl<Sema::OMPCtxSelectorData> &Data);
 
+  /// Parse match clause of '#pragma omp [begin] declare variant'.
+  void ParseOMPDeclareVariantMatchClause(
+      SourceLocation Loc, SmallVectorImpl<Sema::OMPCtxSelectorData> &Data);
+
   /// Parse clauses for '#pragma omp declare variant'.
   void ParseOMPDeclareVariantClauses(DeclGroupPtrTy Ptr, CachedTokens &Toks,
                                      SourceLocation Loc);
Index: clang/include/clang/Basic/OpenMPKinds.def
===================================================================
--- clang/include/clang/Basic/OpenMPKinds.def
+++ clang/include/clang/Basic/OpenMPKinds.def
@@ -292,6 +292,8 @@
 OPENMP_DIRECTIVE_EXT(parallel_master_taskloop, "parallel master taskloop")
 OPENMP_DIRECTIVE_EXT(master_taskloop_simd, "master taskloop simd")
 OPENMP_DIRECTIVE_EXT(parallel_master_taskloop_simd, "parallel master taskloop simd")
+OPENMP_DIRECTIVE_EXT(begin_declare_variant, "begin declare variant")
+OPENMP_DIRECTIVE_EXT(end_declare_variant, "end declare variant")
 
 // OpenMP clauses.
 OPENMP_CLAUSE(allocator, OMPAllocatorClause)
Index: clang/include/clang/AST/StmtOpenMP.h
===================================================================
--- clang/include/clang/AST/StmtOpenMP.h
+++ clang/include/clang/AST/StmtOpenMP.h
@@ -4556,6 +4556,18 @@
   }
 };
 
+class OMPDeclareVariantAttr;
+
+/// Helper to determine the better of two potential context matches. Note that
+/// nullptr is a valid input as well as a valid output, e.g., when neither
+/// attribute describes a matching context.
+const OMPDeclareVariantAttr *
+getBetterOpenMPContextMatch(ASTContext &C, const OMPDeclareVariantAttr *LHSAttr,
+                            const OMPDeclareVariantAttr *RHSAttr);
+
+/// Return true if the context described by \p A matches.
+bool isOpenMPContextMatch(ASTContext &C, const OMPDeclareVariantAttr *A);
+
 } // end namespace clang
 
 #endif
Index: clang/include/clang/AST/Decl.h
===================================================================
--- clang/include/clang/AST/Decl.h
+++ clang/include/clang/AST/Decl.h
@@ -1775,7 +1775,8 @@
   None,
   Target,
   CPUSpecific,
-  CPUDispatch
+  CPUDispatch,
+  OMPVariant,
 };
 
 /// Represents a function declaration or definition.
@@ -2335,6 +2336,10 @@
   /// the target functionality.
   bool isTargetMultiVersion() const;
 
+  /// True if this function is a multiversioned function as a part of
+  /// the OpenMP begin/end declare variant functionality.
+  bool isOpenMPMultiVersion() const;
+
   void setPreviousDeclaration(FunctionDecl * PrevDecl);
 
   FunctionDecl *getCanonicalDecl() override;
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to