skatrak updated this revision to Diff 531437.
skatrak added a comment.

Update the patch to integrate it with the related patch D149337
<https://reviews.llvm.org/D149337> and to address the review comments.


Repository:
  rG LLVM Github Monorepo

CHANGES SINCE LAST ACTION
  https://reviews.llvm.org/D147218/new/

https://reviews.llvm.org/D147218

Files:
  flang/include/flang/Lower/OpenMP.h
  flang/lib/Lower/Bridge.cpp
  flang/lib/Lower/OpenMP.cpp
  flang/test/Lower/OpenMP/requires-notarget.f90
  flang/test/Lower/OpenMP/requires.f90

Index: flang/test/Lower/OpenMP/requires.f90
===================================================================
--- /dev/null
+++ flang/test/Lower/OpenMP/requires.f90
@@ -0,0 +1,13 @@
+! RUN: %flang_fc1 -emit-fir -fopenmp %s -o - | FileCheck %s
+
+! Check that OpenMP requires clauses are lowered to the omp.requires attribute
+
+!CHECK:      module attributes {
+!CHECK-SAME: omp.requires = #omp<clause_requires reverse_offload|unified_shared_memory>
+program requires
+  !$omp requires unified_shared_memory reverse_offload atomic_default_mem_order(seq_cst)
+end program requires
+! This declare target subroutine is device code; without it no omp.requires
+subroutine f
+  !$omp declare target
+end subroutine f
Index: flang/test/Lower/OpenMP/requires-notarget.f90
===================================================================
--- /dev/null
+++ flang/test/Lower/OpenMP/requires-notarget.f90
@@ -0,0 +1,11 @@
+! RUN: %flang_fc1 -emit-fir -fopenmp %s -o - | FileCheck %s
+
+! Check that no omp.requires module attribute is emitted, even for
+! target-related clauses, when the compilation unit has no target regions and
+! no declare target functions that may run on the device.
+
+!CHECK:      module attributes {
+!CHECK-NOT:  omp.requires
+program requires
+  !$omp requires unified_shared_memory reverse_offload atomic_default_mem_order(seq_cst)
+end program requires
Index: flang/lib/Lower/OpenMP.cpp
===================================================================
--- flang/lib/Lower/OpenMP.cpp
+++ flang/lib/Lower/OpenMP.cpp
@@ -2594,16 +2594,14 @@
   converter.bindSymbol(sym, symThreadprivateExv);
 }
 
-void handleDeclareTarget(Fortran::lower::AbstractConverter &converter,
-                         Fortran::lower::pft::Evaluation &eval,
-                         const Fortran::parser::OpenMPDeclareTargetConstruct
-                             &declareTargetConstruct) {
-  llvm::SmallVector<std::pair<mlir::omp::DeclareTargetCaptureClause,
-                              Fortran::semantics::Symbol>,
-                    0>
-      symbolAndClause;
-  mlir::ModuleOp mod = converter.getFirOpBuilder().getModule();
-
+/// Collect in \p symbolAndClause the function and variable symbols affected by
+/// the given 'declare target' directive and return their intended device type.
+static mlir::omp::DeclareTargetDeviceType getDeclareTargetInfo(
+    Fortran::lower::pft::Evaluation &eval,
+    const Fortran::parser::OpenMPDeclareTargetConstruct &declareTargetConstruct,
+    SmallVectorImpl<std::pair<mlir::omp::DeclareTargetCaptureClause,
+                              Fortran::semantics::Symbol>> &symbolAndClause) {
+  // Gather the (capture clause, symbol) pairs named by the directive.
   auto findFuncAndVarSyms = [&](const Fortran::parser::OmpObjectList &objList,
                                 mlir::omp::DeclareTargetCaptureClause clause) {
     for (const Fortran::parser::OmpObject &ompObject : objList.v) {
@@ -2628,6 +2626,7 @@
       Fortran::parser::OmpDeviceTypeClause::Type::Any;
   const auto &spec = std::get<Fortran::parser::OmpDeclareTargetSpecifier>(
       declareTargetConstruct.t);
+
   if (const auto *objectList{
           Fortran::parser::Unwrap<Fortran::parser::OmpObjectList>(spec.u)}) {
     // Case: declare target(func, var1, var2)
@@ -2662,6 +2661,28 @@
     }
   }
 
+  switch (deviceType) {
+  case Fortran::parser::OmpDeviceTypeClause::Type::Any:
+    return mlir::omp::DeclareTargetDeviceType::any;
+  case Fortran::parser::OmpDeviceTypeClause::Type::Host:
+    return mlir::omp::DeclareTargetDeviceType::host;
+  case Fortran::parser::OmpDeviceTypeClause::Type::Nohost:
+    return mlir::omp::DeclareTargetDeviceType::nohost;
+  }
+  llvm_unreachable("unexpected OpenMP declare target device type");
+}
+
+static void genDeclareTarget(Fortran::lower::AbstractConverter &converter,
+                             Fortran::lower::pft::Evaluation &eval,
+                             const Fortran::parser::OpenMPDeclareTargetConstruct
+                                 &declareTargetConstruct) {
+  llvm::SmallVector<std::pair<mlir::omp::DeclareTargetCaptureClause,
+                              Fortran::semantics::Symbol>,
+                    0>
+      symbolAndClause;
+  mlir::ModuleOp mod = converter.getFirOpBuilder().getModule();
+  mlir::omp::DeclareTargetDeviceType deviceType =
+      getDeclareTargetInfo(eval, declareTargetConstruct, symbolAndClause);
   for (std::pair<mlir::omp::DeclareTargetCaptureClause,
                  Fortran::semantics::Symbol>
            symClause : symbolAndClause) {
@@ -2688,35 +2709,44 @@
           converter.getCurrentLocation(),
           "Attempt to apply declare target on unsupported operation");
 
-    mlir::omp::DeclareTargetDeviceType newDeviceType;
-    switch (deviceType) {
-    case Fortran::parser::OmpDeviceTypeClause::Type::Nohost:
-      newDeviceType = mlir::omp::DeclareTargetDeviceType::nohost;
-      break;
-    case Fortran::parser::OmpDeviceTypeClause::Type::Host:
-      newDeviceType = mlir::omp::DeclareTargetDeviceType::host;
-      break;
-    case Fortran::parser::OmpDeviceTypeClause::Type::Any:
-      newDeviceType = mlir::omp::DeclareTargetDeviceType::any;
-      break;
-    }
-
     // The function or global already has a declare target applied to it,
     // very likely through implicit capture (usage in another declare
     // target function/subroutine). It should be marked as any if it has
     // been assigned both host and nohost, else we skip, as there is no
     // change
     if (declareTargetOp.isDeclareTarget()) {
-      if (declareTargetOp.getDeclareTargetDeviceType() != newDeviceType)
+      if (declareTargetOp.getDeclareTargetDeviceType() != deviceType)
         declareTargetOp.setDeclareTarget(
             mlir::omp::DeclareTargetDeviceType::any, std::get<0>(symClause));
       continue;
     }
 
-    declareTargetOp.setDeclareTarget(newDeviceType, std::get<0>(symClause));
+    declareTargetOp.setDeclareTarget(deviceType, std::get<0>(symClause));
   }
 }
 
+/// Set \p ompDeviceCodeFound if \p ompDecl marks a function for the device.
+void Fortran::lower::analyzeOpenMPDeclarativeConstruct(
+    Fortran::lower::AbstractConverter &converter,
+    Fortran::lower::pft::Evaluation &eval,
+    const Fortran::parser::OpenMPDeclarativeConstruct &ompDecl,
+    bool &ompDeviceCodeFound) {
+  std::visit(
+      Fortran::common::visitors{
+          [&](const Fortran::parser::OpenMPDeclareTargetConstruct
+                  &declareTargetConstruct) {
+            mlir::omp::DeclareTargetDeviceType targetType =
+                Fortran::lower::getOpenMPDeclareTargetFunctionDevice(
+                    converter, eval, declareTargetConstruct)
+                    .value_or(mlir::omp::DeclareTargetDeviceType::host);
+            ompDeviceCodeFound |=
+                targetType != mlir::omp::DeclareTargetDeviceType::host;
+          },
+          [&](const auto &) {},
+      },
+      ompDecl.u);
+}
+
 void Fortran::lower::genOpenMPDeclarativeConstruct(
     Fortran::lower::AbstractConverter &converter,
     Fortran::lower::pft::Evaluation &eval,
@@ -2739,11 +2769,14 @@
           },
           [&](const Fortran::parser::OpenMPDeclareTargetConstruct
                   &declareTargetConstruct) {
-            handleDeclareTarget(converter, eval, declareTargetConstruct);
+            genDeclareTarget(converter, eval, declareTargetConstruct);
           },
           [&](const Fortran::parser::OpenMPRequiresConstruct
                   &requiresConstruct) {
-            TODO(converter.getCurrentLocation(), "OpenMPRequiresConstruct");
+            // Requires directives are collected in semantics (which is what
+            // allows them to come from modules) and combined by the lowering
+            // bridge into a single module attribute, so nothing needs to be
+            // generated for each individual occurrence here.
           },
           [&](const Fortran::parser::OpenMPThreadprivate &threadprivate) {
             // The directive is lowered when instantiating the variable to
@@ -2965,3 +2998,97 @@
     }
   }
 }
+
+/// Return the device type requested by a 'declare target' directive if at
+/// least one of the symbols it applies to is a function or subroutine, or
+/// std::nullopt otherwise.
+std::optional<mlir::omp::DeclareTargetDeviceType>
+Fortran::lower::getOpenMPDeclareTargetFunctionDevice(
+    Fortran::lower::AbstractConverter &converter,
+    Fortran::lower::pft::Evaluation &eval,
+    const Fortran::parser::OpenMPDeclareTargetConstruct
+        &declareTargetConstruct) {
+  llvm::SmallVector<std::pair<mlir::omp::DeclareTargetCaptureClause,
+                              Fortran::semantics::Symbol>,
+                    0>
+      symbolAndClause;
+  mlir::omp::DeclareTargetDeviceType deviceType =
+      getDeclareTargetInfo(eval, declareTargetConstruct, symbolAndClause);
+
+  mlir::ModuleOp mod = converter.getFirOpBuilder().getModule();
+  for (const auto &symClause : symbolAndClause) {
+    // The lookup yields null when no operation has the mangled name; such a
+    // symbol cannot be a function, so it is simply skipped.
+    mlir::Operation *op =
+        mod.lookupSymbol(converter.mangleName(std::get<1>(symClause)));
+    if (mlir::isa_and_nonnull<mlir::func::FuncOp>(op))
+      return deviceType;
+  }
+
+  return std::nullopt;
+}
+
+/// Return whether \p omp is a 'target' construct, including combined block
+/// (e.g. 'target teams') and loop (e.g. 'target parallel do') constructs.
+/// Data-mapping directives such as 'target data' are not included.
+bool Fortran::lower::isOpenMPTargetConstruct(
+    const Fortran::parser::OpenMPConstruct &omp) {
+  llvm::omp::Directive dir = llvm::omp::Directive::OMPD_unknown;
+  if (const auto *block =
+          std::get_if<Fortran::parser::OpenMPBlockConstruct>(&omp.u)) {
+    const auto &begin{
+        std::get<Fortran::parser::OmpBeginBlockDirective>(block->t)};
+    dir = std::get<Fortran::parser::OmpBlockDirective>(begin.t).v;
+  } else if (const auto *loop =
+                 std::get_if<Fortran::parser::OpenMPLoopConstruct>(&omp.u)) {
+    const auto &begin{
+        std::get<Fortran::parser::OmpBeginLoopDirective>(loop->t)};
+    dir = std::get<Fortran::parser::OmpLoopDirective>(begin.t).v;
+  }
+
+  switch (dir) {
+  case llvm::omp::Directive::OMPD_target:
+  case llvm::omp::Directive::OMPD_target_parallel:
+  case llvm::omp::Directive::OMPD_target_parallel_do:
+  case llvm::omp::Directive::OMPD_target_parallel_do_simd:
+  case llvm::omp::Directive::OMPD_target_simd:
+  case llvm::omp::Directive::OMPD_target_teams:
+  case llvm::omp::Directive::OMPD_target_teams_distribute:
+  case llvm::omp::Directive::OMPD_target_teams_distribute_parallel_do:
+  case llvm::omp::Directive::OMPD_target_teams_distribute_parallel_do_simd:
+  case llvm::omp::Directive::OMPD_target_teams_distribute_simd:
+    return true;
+  default:
+    return false;
+  }
+}
+
+/// Translate the clauses of an OpenMP 'requires' directive into the matching
+/// MLIR flags. Clauses with no MLIR counterpart, such as
+/// 'atomic_default_mem_order', are ignored.
+mlir::omp::ClauseRequires Fortran::lower::extractOpenMPRequiresClauses(
+    const Fortran::parser::OmpClauseList &clauseList) {
+  using mlir::omp::ClauseRequires, Fortran::parser::OmpClause;
+  auto requiresFlags = ClauseRequires::none;
+
+  for (const OmpClause &clause : clauseList.v) {
+    if (std::holds_alternative<OmpClause::DynamicAllocators>(clause.u))
+      requiresFlags = requiresFlags | ClauseRequires::dynamic_allocators;
+    else if (std::holds_alternative<OmpClause::ReverseOffload>(clause.u))
+      requiresFlags = requiresFlags | ClauseRequires::reverse_offload;
+    else if (std::holds_alternative<OmpClause::UnifiedAddress>(clause.u))
+      requiresFlags = requiresFlags | ClauseRequires::unified_address;
+    else if (std::holds_alternative<OmpClause::UnifiedSharedMemory>(clause.u))
+      requiresFlags = requiresFlags | ClauseRequires::unified_shared_memory;
+  }
+
+  return requiresFlags;
+}
+
+/// Attach \p flags to \p mod as its OpenMP 'requires' attribute, provided that
+/// \p mod implements the OpenMP offload module interface.
+void Fortran::lower::genOpenMPRequires(mlir::Operation *mod,
+                                       mlir::omp::ClauseRequires flags) {
+  if (auto offloadMod = llvm::dyn_cast<mlir::omp::OffloadModuleInterface>(mod))
+    offloadMod.setRequires(flags);
+}
Index: flang/lib/Lower/Bridge.cpp
===================================================================
--- flang/lib/Lower/Bridge.cpp
+++ flang/lib/Lower/Bridge.cpp
@@ -50,6 +50,8 @@
 #include "flang/Parser/parse-tree.h"
 #include "flang/Runtime/iostat.h"
 #include "flang/Semantics/runtime-type-info.h"
+#include "flang/Semantics/symbol.h"
 #include "flang/Semantics/tools.h"
 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
+#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
 #include "mlir/IR/PatternMatch.h"
@@ -288,20 +290,34 @@
     //    that they are available before lowering any function that may use
     //    them.
     bool hasMainProgram = false;
+    Fortran::semantics::OmpRequiresFlags ompRequiresFlags =
+        Fortran::semantics::OmpRequiresFlags::None;
+    std::optional<Fortran::parser::OmpAtomicDefaultMemOrderClause::Type>
+        ompAtomicDefaultMemOrder;
     for (Fortran::lower::pft::Program::Units &u : pft.getUnits()) {
       std::visit(Fortran::common::visitors{
                      [&](Fortran::lower::pft::FunctionLikeUnit &f) {
                        if (f.isMainProgram())
                          hasMainProgram = true;
                        declareFunction(f);
+                       ompProcessTopLevelSymbol(f.getScope().symbol(),
+                                                ompRequiresFlags,
+                                                ompAtomicDefaultMemOrder);
                      },
                      [&](Fortran::lower::pft::ModuleLikeUnit &m) {
                        lowerModuleDeclScope(m);
                        for (Fortran::lower::pft::FunctionLikeUnit &f :
                             m.nestedFunctions)
                          declareFunction(f);
+                       ompProcessTopLevelSymbol(m.getScope().symbol(),
+                                                ompRequiresFlags,
+                                                ompAtomicDefaultMemOrder);
+                     },
+                     [&](Fortran::lower::pft::BlockDataUnit &b) {
+                       ompProcessTopLevelSymbol(b.symTab.symbol(),
+                                                ompRequiresFlags,
+                                                ompAtomicDefaultMemOrder);
                      },
-                     [&](Fortran::lower::pft::BlockDataUnit &b) {},
                      [&](Fortran::lower::pft::CompilerDirectiveUnit &d) {},
                  },
                  u);
@@ -344,6 +360,24 @@
         fir::runtime::genEnvironmentDefaults(*builder, toLocation(),
                                              bridge.getEnvironmentDefaults());
       });
+
+    // Set the 'omp.requires' module attribute only if device code was found.
+    if (ompDeviceCodeFound) {
+      using MlirRequires = mlir::omp::ClauseRequires;
+      using SemaRequires = Fortran::semantics::OmpRequiresFlags;
+      MlirRequires flags = MlirRequires::none;
+
+      if (ompRequiresFlags & SemaRequires::ReverseOffload)
+        flags = flags | MlirRequires::reverse_offload;
+      if (ompRequiresFlags & SemaRequires::UnifiedAddress)
+        flags = flags | MlirRequires::unified_address;
+      if (ompRequiresFlags & SemaRequires::UnifiedSharedMemory)
+        flags = flags | MlirRequires::unified_shared_memory;
+      if (ompRequiresFlags & SemaRequires::DynamicAllocators)
+        flags = flags | MlirRequires::dynamic_allocators;
+
+      Fortran::lower::genOpenMPRequires(getModuleOp().getOperation(), flags);
+    }
   }
 
   /// Declare a function.
@@ -1191,6 +1225,47 @@
     activeConstructStack.pop_back();
   }
 
+  /// Merge the OpenMP 'requires' clauses recorded on top-level \p symbol into
+  /// \p ompRequiresFlags, checking atomic_default_mem_order for consistency.
+  void ompProcessTopLevelSymbol(
+      const Fortran::semantics::Symbol *symbol,
+      Fortran::semantics::OmpRequiresFlags &ompRequiresFlags,
+      std::optional<Fortran::parser::OmpAtomicDefaultMemOrderClause::Type>
+          &ompAtomicDefaultMemOrder) {
+    if (!symbol)
+      return;
+    Fortran::common::visit(
+        [&](const auto &details) {
+          if constexpr (std::is_base_of_v<
+                            Fortran::semantics::WithOmpDeclarative,
+                            std::decay_t<decltype(details)>>) {
+            // Collect OpenMP 'requires' clauses.
+            if (details.has_ompRequires())
+              ompRequiresFlags |= *details.ompRequires();
+
+            // Make sure any atomic_default_mem_order OpenMP 'requires' clauses
+            // obtained for different top-level symbols match.
+            if (details.has_ompAtomicDefaultMemOrder()) {
+              auto memOrder{*details.ompAtomicDefaultMemOrder()};
+              if (ompAtomicDefaultMemOrder &&
+                  memOrder != *ompAtomicDefaultMemOrder)
+                fir::emitFatalError(
+                    genLocation(symbol->name()),
+                    llvm::StringRef{
+                        "incompatible OpenMP requires atomic_default_mem_order "
+                        "clauses found: '"} +
+                        Fortran::parser::OmpAtomicDefaultMemOrderClause::
+                            EnumToString(memOrder) +
+                        llvm::StringRef{"' and '"} +
+                        Fortran::parser::OmpAtomicDefaultMemOrderClause::
+                            EnumToString(*ompAtomicDefaultMemOrder) + "'");
+              ompAtomicDefaultMemOrder = memOrder;
+            }
+          }
+        },
+        symbol->details());
+  }
+
   //===--------------------------------------------------------------------===//
   // Termination of symbolically referenced execution units
   //===--------------------------------------------------------------------===//
@@ -2201,10 +2276,16 @@
 
     localSymbols.popScope();
     builder->restoreInsertionPoint(insertPt);
+
+    // Record whether this construct offloads a region to a target device.
+    ompDeviceCodeFound =
+        ompDeviceCodeFound || Fortran::lower::isOpenMPTargetConstruct(omp);
   }
 
   void genFIR(const Fortran::parser::OpenMPDeclarativeConstruct &ompDecl) {
     mlir::OpBuilder::InsertPoint insertPt = builder->saveInsertionPoint();
+    analyzeOpenMPDeclarativeConstruct(*this, getEval(), ompDecl,
+                                      ompDeviceCodeFound);
     genOpenMPDeclarativeConstruct(*this, getEval(), ompDecl);
     for (Fortran::lower::pft::Evaluation &e : getEval().getNestedEvaluations())
       genFIR(e);
@@ -4530,6 +4611,10 @@
 
   /// A counter for uniquing names in `literalNamesMap`.
   std::uint64_t uniqueLitId = 0;
+
+  /// Set once lowering finds an OpenMP target region, or a declare target
+  /// function or subroutine not restricted to the host.
+  bool ompDeviceCodeFound = false;
 };
 
 } // namespace
Index: flang/include/flang/Lower/OpenMP.h
===================================================================
--- flang/include/flang/Lower/OpenMP.h
+++ flang/include/flang/Lower/OpenMP.h
@@ -13,13 +13,10 @@
 #ifndef FORTRAN_LOWER_OPENMP_H
 #define FORTRAN_LOWER_OPENMP_H
 
+#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
 #include <cinttypes>
+#include <optional>
 
-namespace mlir {
-class Value;
-class Operation;
-} // namespace mlir
-
 namespace fir {
 class FirOpBuilder;
 class ConvertOp;
@@ -29,6 +26,7 @@
 namespace parser {
 struct OpenMPConstruct;
 struct OpenMPDeclarativeConstruct;
+struct OpenMPDeclareTargetConstruct;
 struct OmpEndLoopDirective;
 struct OmpClauseList;
 } // namespace parser
@@ -44,6 +42,9 @@
 
 void genOpenMPConstruct(AbstractConverter &, pft::Evaluation &,
                         const parser::OpenMPConstruct &);
+void analyzeOpenMPDeclarativeConstruct(
+    AbstractConverter &, pft::Evaluation &,
+    const parser::OpenMPDeclarativeConstruct &, bool &ompDeviceCodeFound);
 void genOpenMPDeclarativeConstruct(AbstractConverter &, pft::Evaluation &,
                                    const parser::OpenMPDeclarativeConstruct &);
 int64_t getCollapseValue(const Fortran::parser::OmpClauseList &clauseList);
@@ -56,6 +57,22 @@
 void updateReduction(mlir::Operation *, fir::FirOpBuilder &, mlir::Value,
                      mlir::Value, fir::ConvertOp * = nullptr);
 void removeStoreOp(mlir::Operation *, mlir::Value);
+
+/// Return the device type of a 'declare target' directive if it applies to at
+/// least one function or subroutine, or std::nullopt otherwise.
+std::optional<mlir::omp::DeclareTargetDeviceType>
+getOpenMPDeclareTargetFunctionDevice(
+    AbstractConverter &, pft::Evaluation &,
+    const parser::OpenMPDeclareTargetConstruct &);
+/// Return whether the construct is a (possibly combined) 'target' construct.
+bool isOpenMPTargetConstruct(const parser::OpenMPConstruct &);
+
+/// Translate the clauses of an OpenMP 'requires' directive into MLIR flags.
+mlir::omp::ClauseRequires
+extractOpenMPRequiresClauses(const parser::OmpClauseList &);
+/// Attach the given OpenMP 'requires' flags to an offload module operation.
+void genOpenMPRequires(mlir::Operation *, mlir::omp::ClauseRequires);
+
 } // namespace lower
 } // namespace Fortran
 
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to