https://github.com/skatrak updated https://github.com/llvm/llvm-project/pull/110267
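At a high level, this update replaces the per-clause symbol/type/location bookkeeping with a single `EntryBlockArgs` structure plus two helpers introduced in the patch: `genEntryBlock`, which creates an operation's entry block with arguments ordered alphabetically by clause name as expected by the `BlockArgOpenMPOpInterface`, and `bindEntryBlockArgs`, which binds each Fortran symbol to its corresponding block argument. A minimal sketch of the resulting call pattern, modelled on the `genBodyOfTargetDataOp` changes below (the helper names come from the patch; the other variable names are illustrative):

    // Gather the clause-defined symbols and the base values of their map-like
    // operands into EntryBlockArgs, one entry per argument-generating clause.
    EntryBlockArgs args;
    args.useDeviceAddr.syms = useDeviceAddrSyms;    // from processUseDeviceAddr
    args.useDeviceAddr.vars = useDeviceAddrBaseOps; // from extractMapVarsBaseOps
    args.useDevicePtr.syms = useDevicePtrSyms;      // from processUseDevicePtr
    args.useDevicePtr.vars = useDevicePtrBaseOps;

    // Create the op's entry block from the collected values, then bind each
    // symbol to the matching block argument via BlockArgOpenMPOpInterface.
    genEntryBlock(converter, args, targetDataOp.getRegion());
    bindEntryBlockArgs(converter, targetDataOp, args);
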
>From 2c5d74d932797b916b5f0da6fb017b5f4af2b2b4 Mon Sep 17 00:00:00 2001 From: Sergio Afonso <safon...@amd.com> Date: Fri, 27 Sep 2024 13:51:27 +0100 Subject: [PATCH] [Flang][OpenMP] Improve entry block argument creation and binding The main purpose of this patch is to centralize the logic for creating MLIR operation entry blocks and for binding them to the corresponding symbols. This minimizes the chances of mixing arguments up for operations having multiple entry block argument-generating clauses and prevents divergence while binding arguments. Some changes implemented to this end are: - Split into two functions the creation of the entry block, and the binding of its arguments and the corresponding Fortran symbol. This enabled a significant simplification of the lowering of composite constructs, where it's no longer necessary to manually ensure the lists of arguments and symbols refer to the same variables in the same order and also match the expected order by the `BlockArgOpenMPOpInterface`. - Removed redundant and error-prone passing of types and locations from `ClauseProcessor` methods. Instead, these are obtained from the values in the appropriate clause operands structure. This also simplifies argument lists of several lowering functions. - Access block arguments of already created MLIR operations through the `BlockArgOpenMPOpInterface` instead of directly indexing the argument list of the operation, which is not scalable as more entry block argument-generating clauses are added to an operation. - Simplified the implementation of `genParallelOp` to no longer need to define different callbacks depending on whether delayed privatization is enabled. --- flang/lib/Lower/OpenMP/ClauseProcessor.cpp | 79 +- flang/lib/Lower/OpenMP/ClauseProcessor.h | 38 +- flang/lib/Lower/OpenMP/OpenMP.cpp | 1016 +++++++++-------- flang/lib/Lower/OpenMP/ReductionProcessor.cpp | 5 +- flang/lib/Lower/OpenMP/ReductionProcessor.h | 3 +- flang/lib/Lower/OpenMP/Utils.cpp | 9 +- flang/lib/Lower/OpenMP/Utils.h | 4 +- 7 files changed, 554 insertions(+), 600 deletions(-) diff --git a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp index e9ef8579100e93..44f5ca7f342707 100644 --- a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp +++ b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp @@ -166,15 +166,11 @@ getIfClauseOperand(lower::AbstractConverter &converter, static void addUseDeviceClause( lower::AbstractConverter &converter, const omp::ObjectList &objects, llvm::SmallVectorImpl<mlir::Value> &operands, - llvm::SmallVectorImpl<mlir::Type> &useDeviceTypes, - llvm::SmallVectorImpl<mlir::Location> &useDeviceLocs, llvm::SmallVectorImpl<const semantics::Symbol *> &useDeviceSyms) { genObjectList(objects, converter, operands); - for (mlir::Value &operand : operands) { + for (mlir::Value &operand : operands) checkMapType(operand.getLoc(), operand.getType()); - useDeviceTypes.push_back(operand.getType()); - useDeviceLocs.push_back(operand.getLoc()); - } + for (const omp::Object &object : objects) useDeviceSyms.push_back(object.sym()); } @@ -832,14 +828,12 @@ bool ClauseProcessor::processDepend(mlir::omp::DependClauseOps &result) const { bool ClauseProcessor::processHasDeviceAddr( mlir::omp::HasDeviceAddrClauseOps &result, - llvm::SmallVectorImpl<mlir::Type> &isDeviceTypes, - llvm::SmallVectorImpl<mlir::Location> &isDeviceLocs, - llvm::SmallVectorImpl<const semantics::Symbol *> &isDeviceSymbols) const { + llvm::SmallVectorImpl<const semantics::Symbol *> &isDeviceSyms) const { return 
findRepeatableClause<omp::clause::HasDeviceAddr>( [&](const omp::clause::HasDeviceAddr &devAddrClause, const parser::CharBlock &) { addUseDeviceClause(converter, devAddrClause.v, result.hasDeviceAddrVars, - isDeviceTypes, isDeviceLocs, isDeviceSymbols); + isDeviceSyms); }); } @@ -864,14 +858,12 @@ bool ClauseProcessor::processIf( bool ClauseProcessor::processIsDevicePtr( mlir::omp::IsDevicePtrClauseOps &result, - llvm::SmallVectorImpl<mlir::Type> &isDeviceTypes, - llvm::SmallVectorImpl<mlir::Location> &isDeviceLocs, - llvm::SmallVectorImpl<const semantics::Symbol *> &isDeviceSymbols) const { + llvm::SmallVectorImpl<const semantics::Symbol *> &isDeviceSyms) const { return findRepeatableClause<omp::clause::IsDevicePtr>( [&](const omp::clause::IsDevicePtr &devPtrClause, const parser::CharBlock &) { addUseDeviceClause(converter, devPtrClause.v, result.isDevicePtrVars, - isDeviceTypes, isDeviceLocs, isDeviceSymbols); + isDeviceSyms); }); } @@ -892,9 +884,7 @@ void ClauseProcessor::processMapObjects( std::map<const semantics::Symbol *, llvm::SmallVector<OmpMapMemberIndicesData>> &parentMemberIndices, llvm::SmallVectorImpl<mlir::Value> &mapVars, - llvm::SmallVectorImpl<const semantics::Symbol *> *mapSyms, - llvm::SmallVectorImpl<mlir::Location> *mapSymLocs, - llvm::SmallVectorImpl<mlir::Type> *mapSymTypes) const { + llvm::SmallVectorImpl<const semantics::Symbol *> &mapSyms) const { fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); for (const omp::Object &object : objects) { llvm::SmallVector<mlir::Value> bounds; @@ -927,12 +917,7 @@ void ClauseProcessor::processMapObjects( addChildIndexAndMapToParent(object, parentMemberIndices, mapOp, semaCtx); } else { mapVars.push_back(mapOp); - if (mapSyms) - mapSyms->push_back(object.sym()); - if (mapSymTypes) - mapSymTypes->push_back(baseOp.getType()); - if (mapSymLocs) - mapSymLocs->push_back(baseOp.getLoc()); + mapSyms.push_back(object.sym()); } } } @@ -940,9 +925,7 @@ void ClauseProcessor::processMapObjects( bool ClauseProcessor::processMap( mlir::Location currentLocation, lower::StatementContext &stmtCtx, mlir::omp::MapClauseOps &result, - llvm::SmallVectorImpl<const semantics::Symbol *> *mapSyms, - llvm::SmallVectorImpl<mlir::Location> *mapSymLocs, - llvm::SmallVectorImpl<mlir::Type> *mapSymTypes) const { + llvm::SmallVectorImpl<const semantics::Symbol *> *mapSyms) const { // We always require tracking of symbols, even if the caller does not, // so we create an optionally used local set of symbols when the mapSyms // argument is not present. 
@@ -999,12 +982,11 @@ bool ClauseProcessor::processMap( } processMapObjects(stmtCtx, clauseLocation, std::get<omp::ObjectList>(clause.t), mapTypeBits, - parentMemberIndices, result.mapVars, ptrMapSyms, - mapSymLocs, mapSymTypes); + parentMemberIndices, result.mapVars, *ptrMapSyms); }); insertChildMapInfoIntoParent(converter, parentMemberIndices, result.mapVars, - *ptrMapSyms, mapSymTypes, mapSymLocs); + *ptrMapSyms); return clauseFound; } @@ -1027,7 +1009,7 @@ bool ClauseProcessor::processMotionClauses(lower::StatementContext &stmtCtx, processMapObjects(stmtCtx, clauseLocation, std::get<ObjectList>(clause.t), mapTypeBits, parentMemberIndices, result.mapVars, - &mapSymbols); + mapSymbols); }; bool clauseFound = findRepeatableClause<omp::clause::To>(callbackFn); @@ -1035,8 +1017,7 @@ bool ClauseProcessor::processMotionClauses(lower::StatementContext &stmtCtx, findRepeatableClause<omp::clause::From>(callbackFn) || clauseFound; insertChildMapInfoIntoParent(converter, parentMemberIndices, result.mapVars, - mapSymbols, - /*mapSymTypes=*/nullptr, /*mapSymLocs=*/nullptr); + mapSymbols); return clauseFound; } @@ -1054,8 +1035,7 @@ bool ClauseProcessor::processNontemporal( bool ClauseProcessor::processReduction( mlir::Location currentLocation, mlir::omp::ReductionClauseOps &result, - llvm::SmallVectorImpl<mlir::Type> *outReductionTypes, - llvm::SmallVectorImpl<const semantics::Symbol *> *outReductionSyms) const { + llvm::SmallVectorImpl<const semantics::Symbol *> &outReductionSyms) const { return findRepeatableClause<omp::clause::Reduction>( [&](const omp::clause::Reduction &clause, const parser::CharBlock &) { llvm::SmallVector<mlir::Value> reductionVars; @@ -1063,25 +1043,16 @@ bool ClauseProcessor::processReduction( llvm::SmallVector<mlir::Attribute> reductionDeclSymbols; llvm::SmallVector<const semantics::Symbol *> reductionSyms; ReductionProcessor rp; - rp.addDeclareReduction( - currentLocation, converter, clause, reductionVars, reduceVarByRef, - reductionDeclSymbols, outReductionSyms ? &reductionSyms : nullptr); + rp.addDeclareReduction(currentLocation, converter, clause, + reductionVars, reduceVarByRef, + reductionDeclSymbols, reductionSyms); // Copy local lists into the output. 
llvm::copy(reductionVars, std::back_inserter(result.reductionVars)); llvm::copy(reduceVarByRef, std::back_inserter(result.reductionByref)); llvm::copy(reductionDeclSymbols, std::back_inserter(result.reductionSyms)); - - if (outReductionTypes) { - outReductionTypes->reserve(outReductionTypes->size() + - reductionVars.size()); - llvm::transform(reductionVars, std::back_inserter(*outReductionTypes), - [](mlir::Value v) { return v.getType(); }); - } - - if (outReductionSyms) - llvm::copy(reductionSyms, std::back_inserter(*outReductionSyms)); + llvm::copy(reductionSyms, std::back_inserter(outReductionSyms)); }); } @@ -1107,8 +1078,6 @@ bool ClauseProcessor::processEnter( bool ClauseProcessor::processUseDeviceAddr( lower::StatementContext &stmtCtx, mlir::omp::UseDeviceAddrClauseOps &result, - llvm::SmallVectorImpl<mlir::Type> &useDeviceTypes, - llvm::SmallVectorImpl<mlir::Location> &useDeviceLocs, llvm::SmallVectorImpl<const semantics::Symbol *> &useDeviceSyms) const { std::map<const semantics::Symbol *, llvm::SmallVector<OmpMapMemberIndicesData>> @@ -1122,19 +1091,16 @@ bool ClauseProcessor::processUseDeviceAddr( llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM; processMapObjects(stmtCtx, location, clause.v, mapTypeBits, parentMemberIndices, result.useDeviceAddrVars, - &useDeviceSyms, &useDeviceLocs, &useDeviceTypes); + useDeviceSyms); }); insertChildMapInfoIntoParent(converter, parentMemberIndices, - result.useDeviceAddrVars, useDeviceSyms, - &useDeviceTypes, &useDeviceLocs); + result.useDeviceAddrVars, useDeviceSyms); return clauseFound; } bool ClauseProcessor::processUseDevicePtr( lower::StatementContext &stmtCtx, mlir::omp::UseDevicePtrClauseOps &result, - llvm::SmallVectorImpl<mlir::Type> &useDeviceTypes, - llvm::SmallVectorImpl<mlir::Location> &useDeviceLocs, llvm::SmallVectorImpl<const semantics::Symbol *> &useDeviceSyms) const { std::map<const semantics::Symbol *, llvm::SmallVector<OmpMapMemberIndicesData>> @@ -1148,12 +1114,11 @@ bool ClauseProcessor::processUseDevicePtr( llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM; processMapObjects(stmtCtx, location, clause.v, mapTypeBits, parentMemberIndices, result.useDevicePtrVars, - &useDeviceSyms, &useDeviceLocs, &useDeviceTypes); + useDeviceSyms); }); insertChildMapInfoIntoParent(converter, parentMemberIndices, - result.useDevicePtrVars, useDeviceSyms, - &useDeviceTypes, &useDeviceLocs); + result.useDevicePtrVars, useDeviceSyms); return clauseFound; } diff --git a/flang/lib/Lower/OpenMP/ClauseProcessor.h b/flang/lib/Lower/OpenMP/ClauseProcessor.h index 0c8e7bd47ab5a6..f34121c70d0b44 100644 --- a/flang/lib/Lower/OpenMP/ClauseProcessor.h +++ b/flang/lib/Lower/OpenMP/ClauseProcessor.h @@ -68,9 +68,7 @@ class ClauseProcessor { mlir::omp::FinalClauseOps &result) const; bool processHasDeviceAddr( mlir::omp::HasDeviceAddrClauseOps &result, - llvm::SmallVectorImpl<mlir::Type> &isDeviceTypes, - llvm::SmallVectorImpl<mlir::Location> &isDeviceLocs, - llvm::SmallVectorImpl<const semantics::Symbol *> &isDeviceSymbols) const; + llvm::SmallVectorImpl<const semantics::Symbol *> &isDeviceSyms) const; bool processHint(mlir::omp::HintClauseOps &result) const; bool processMergeable(mlir::omp::MergeableClauseOps &result) const; bool processNowait(mlir::omp::NowaitClauseOps &result) const; @@ -104,43 +102,33 @@ class ClauseProcessor { mlir::omp::IfClauseOps &result) const; bool processIsDevicePtr( mlir::omp::IsDevicePtrClauseOps &result, - llvm::SmallVectorImpl<mlir::Type> &isDeviceTypes, - llvm::SmallVectorImpl<mlir::Location> &isDeviceLocs, - 
llvm::SmallVectorImpl<const semantics::Symbol *> &isDeviceSymbols) const; + llvm::SmallVectorImpl<const semantics::Symbol *> &isDeviceSyms) const; bool processLink(llvm::SmallVectorImpl<DeclareTargetCapturePair> &result) const; // This method is used to process a map clause. - // The optional parameters - mapSymTypes, mapSymLocs & mapSyms are used to - // store the original type, location and Fortran symbol for the map operands. - // They may be used later on to create the block_arguments for some of the - // target directives that require it. - bool processMap( - mlir::Location currentLocation, lower::StatementContext &stmtCtx, - mlir::omp::MapClauseOps &result, - llvm::SmallVectorImpl<const semantics::Symbol *> *mapSyms = nullptr, - llvm::SmallVectorImpl<mlir::Location> *mapSymLocs = nullptr, - llvm::SmallVectorImpl<mlir::Type> *mapSymTypes = nullptr) const; + // The optional parameter mapSyms is used to store the original Fortran symbol + // for the map operands. It may be used later on to create the block_arguments + // for some of the directives that require it. + bool processMap(mlir::Location currentLocation, + lower::StatementContext &stmtCtx, + mlir::omp::MapClauseOps &result, + llvm::SmallVectorImpl<const semantics::Symbol *> *mapSyms = + nullptr) const; bool processMotionClauses(lower::StatementContext &stmtCtx, mlir::omp::MapClauseOps &result); bool processNontemporal(mlir::omp::NontemporalClauseOps &result) const; bool processReduction( mlir::Location currentLocation, mlir::omp::ReductionClauseOps &result, - llvm::SmallVectorImpl<mlir::Type> *reductionTypes = nullptr, - llvm::SmallVectorImpl<const semantics::Symbol *> *reductionSyms = - nullptr) const; + llvm::SmallVectorImpl<const semantics::Symbol *> &reductionSyms) const; bool processTo(llvm::SmallVectorImpl<DeclareTargetCapturePair> &result) const; bool processUseDeviceAddr( lower::StatementContext &stmtCtx, mlir::omp::UseDeviceAddrClauseOps &result, - llvm::SmallVectorImpl<mlir::Type> &useDeviceTypes, - llvm::SmallVectorImpl<mlir::Location> &useDeviceLocs, llvm::SmallVectorImpl<const semantics::Symbol *> &useDeviceSyms) const; bool processUseDevicePtr( lower::StatementContext &stmtCtx, mlir::omp::UseDevicePtrClauseOps &result, - llvm::SmallVectorImpl<mlir::Type> &useDeviceTypes, - llvm::SmallVectorImpl<mlir::Location> &useDeviceLocs, llvm::SmallVectorImpl<const semantics::Symbol *> &useDeviceSyms) const; // Call this method for these clauses that should be supported but are not @@ -181,9 +169,7 @@ class ClauseProcessor { std::map<const semantics::Symbol *, llvm::SmallVector<OmpMapMemberIndicesData>> &parentMemberIndices, llvm::SmallVectorImpl<mlir::Value> &mapVars, - llvm::SmallVectorImpl<const semantics::Symbol *> *mapSyms, - llvm::SmallVectorImpl<mlir::Location> *mapSymLocs = nullptr, - llvm::SmallVectorImpl<mlir::Type> *mapSymTypes = nullptr) const; + llvm::SmallVectorImpl<const semantics::Symbol *> &mapSyms) const; lower::AbstractConverter &converter; semantics::SemanticsContext &semaCtx; diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp index 17ebf93edcce1f..456f0a267923df 100644 --- a/flang/lib/Lower/OpenMP/OpenMP.cpp +++ b/flang/lib/Lower/OpenMP/OpenMP.cpp @@ -45,6 +45,36 @@ using namespace Fortran::lower::omp; // Code generation helper functions //===----------------------------------------------------------------------===// +namespace { +struct EntryBlockArgsEntry { + llvm::ArrayRef<const semantics::Symbol *> syms; + llvm::ArrayRef<mlir::Value> vars; + + bool isValid() const { + 
// This check allows specifying a smaller number of symbols than values + // because in some cases a single symbol generates multiple block + // arguments. + return syms.size() <= vars.size(); + } +}; + +struct EntryBlockArgs { + EntryBlockArgsEntry inReduction; + EntryBlockArgsEntry map; + EntryBlockArgsEntry priv; + EntryBlockArgsEntry reduction; + EntryBlockArgsEntry taskReduction; + EntryBlockArgsEntry useDeviceAddr; + EntryBlockArgsEntry useDevicePtr; + + bool isValid() const { + return inReduction.isValid() && map.isValid() && priv.isValid() && + reduction.isValid() && taskReduction.isValid() && + useDeviceAddr.isValid() && useDevicePtr.isValid(); + } +}; +} // namespace + static void genOMPDispatch(lower::AbstractConverter &converter, lower::SymMap &symTable, semantics::SemanticsContext &semaCtx, @@ -52,6 +82,163 @@ static void genOMPDispatch(lower::AbstractConverter &converter, const ConstructQueue &queue, ConstructQueue::const_iterator item); +/// Bind symbols to their corresponding entry block arguments. +/// +/// The binding will be performed inside of the current block, which does not +/// necessarily have to be part of the operation for which the binding is done. +/// However, block arguments must be accessible. This enables controlling the +/// insertion point of any new MLIR operations related to the binding of +/// arguments of a loop wrapper operation. +/// +/// \param [in] converter - PFT to MLIR conversion interface. +/// \param [in] op - owner operation of the block arguments to bind. +/// \param [in] args - entry block arguments information for the given +/// operation. +static void bindEntryBlockArgs(lower::AbstractConverter &converter, + mlir::omp::BlockArgOpenMPOpInterface op, + const EntryBlockArgs &args) { + assert(op != nullptr && "invalid block argument-defining operation"); + assert(args.isValid() && "invalid args"); + fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); + + auto bindSingleMapLike = [&converter, + &firOpBuilder](const semantics::Symbol &sym, + const mlir::BlockArgument &arg) { + // Clones the `bounds` placing them inside the entry block and returns + // them.
+ auto cloneBound = [&](mlir::Value bound) { + if (mlir::isMemoryEffectFree(bound.getDefiningOp())) { + mlir::Operation *clonedOp = firOpBuilder.clone(*bound.getDefiningOp()); + return clonedOp->getResult(0); + } + TODO(converter.getCurrentLocation(), + "target map-like clause operand unsupported bound type"); + }; + + auto cloneBounds = [cloneBound](llvm::ArrayRef<mlir::Value> bounds) { + llvm::SmallVector<mlir::Value> clonedBounds; + llvm::transform(bounds, std::back_inserter(clonedBounds), + [&](mlir::Value bound) { return cloneBound(bound); }); + return clonedBounds; + }; + + fir::ExtendedValue extVal = converter.getSymbolExtendedValue(sym); + auto refType = mlir::dyn_cast<fir::ReferenceType>(arg.getType()); + if (refType && fir::isa_builtin_cptr_type(refType.getElementType())) { + converter.bindSymbol(sym, arg); + } else { + extVal.match( + [&](const fir::BoxValue &v) { + converter.bindSymbol(sym, + fir::BoxValue(arg, cloneBounds(v.getLBounds()), + v.getExplicitParameters(), + v.getExplicitExtents())); + }, + [&](const fir::MutableBoxValue &v) { + converter.bindSymbol( + sym, fir::MutableBoxValue(arg, cloneBounds(v.getLBounds()), + v.getMutableProperties())); + }, + [&](const fir::ArrayBoxValue &v) { + converter.bindSymbol( + sym, fir::ArrayBoxValue(arg, cloneBounds(v.getExtents()), + cloneBounds(v.getLBounds()), + v.getSourceBox())); + }, + [&](const fir::CharArrayBoxValue &v) { + converter.bindSymbol( + sym, fir::CharArrayBoxValue(arg, cloneBound(v.getLen()), + cloneBounds(v.getExtents()), + cloneBounds(v.getLBounds()))); + }, + [&](const fir::CharBoxValue &v) { + converter.bindSymbol( + sym, fir::CharBoxValue(arg, cloneBound(v.getLen()))); + }, + [&](const fir::UnboxedValue &v) { converter.bindSymbol(sym, arg); }, + [&](const auto &) { + TODO(converter.getCurrentLocation(), + "target map clause operand unsupported type"); + }); + } + }; + + auto bindMapLike = + [&bindSingleMapLike](llvm::ArrayRef<const semantics::Symbol *> syms, + llvm::ArrayRef<mlir::BlockArgument> args) { + // Structure component symbols don't have bindings, and can only be + // explicitly mapped individually. If a member is captured implicitly + // we map the entirety of the derived type when we find its symbol. + llvm::SmallVector<const semantics::Symbol *> processedSyms; + llvm::copy_if(syms, std::back_inserter(processedSyms), + [](auto *sym) { return !sym->owner().IsDerivedType(); }); + + for (auto [sym, arg] : llvm::zip_equal(processedSyms, args)) + bindSingleMapLike(*sym, arg); + }; + + auto bindPrivateLike = [&converter, &firOpBuilder]( + llvm::ArrayRef<const semantics::Symbol *> syms, + llvm::ArrayRef<mlir::Value> vars, + llvm::ArrayRef<mlir::BlockArgument> args) { + llvm::SmallVector<const semantics::Symbol *> processedSyms; + for (auto *sym : syms) { + if (const auto *commonDet = + sym->detailsIf<semantics::CommonBlockDetails>()) { + llvm::transform(commonDet->objects(), std::back_inserter(processedSyms), + [&](const auto &mem) { return &*mem; }); + } else { + processedSyms.push_back(sym); + } + } + + for (auto [sym, var, arg] : llvm::zip_equal(processedSyms, vars, args)) + converter.bindSymbol( + *sym, + hlfir::translateToExtendedValue( + var.getLoc(), firOpBuilder, hlfir::Entity{arg}, + /*contiguousHint=*/ + evaluate::IsSimplyContiguous(*sym, converter.getFoldingContext())) + .first); + }; + + // Process in clause name alphabetical order to match block arguments order. 
+ bindPrivateLike(args.inReduction.syms, args.inReduction.vars, + op.getInReductionBlockArgs()); + bindMapLike(args.map.syms, op.getMapBlockArgs()); + bindPrivateLike(args.priv.syms, args.priv.vars, op.getPrivateBlockArgs()); + bindPrivateLike(args.reduction.syms, args.reduction.vars, + op.getReductionBlockArgs()); + bindPrivateLike(args.taskReduction.syms, args.taskReduction.vars, + op.getTaskReductionBlockArgs()); + bindMapLike(args.useDeviceAddr.syms, op.getUseDeviceAddrBlockArgs()); + bindMapLike(args.useDevicePtr.syms, op.getUseDevicePtrBlockArgs()); +} + +/// Get the list of base values that the specified map-like variables point to. +/// +/// This function must be kept in sync with changes to the `createMapInfoOp` +/// utility function, since it must take into account the potential introduction +/// of levels of indirection (i.e. intermediate ops). +/// +/// \param [in] vars - list of values passed to map-like clauses, returned +/// by an `omp.map.info` operation. +/// \param [out] baseOps - populated with the `var_ptr` values of the +/// corresponding defining operations. +static void extractMapVarsBaseOps(llvm::ArrayRef<mlir::Value> vars, + llvm::SmallVectorImpl<mlir::Value> &baseOps) { + llvm::transform(vars, std::back_inserter(baseOps), [](mlir::Value map) { + auto mapInfo = map.getDefiningOp<mlir::omp::MapInfoOp>(); + assert(mapInfo && "expected all map vars to be defined by omp.map.info"); + + mlir::Value varPtr = mapInfo.getVarPtr(); + if (auto boxAddr = varPtr.getDefiningOp<fir::BoxAddrOp>()) + return boxAddr.getVal(); + + return varPtr; + }); +} + static lower::pft::Evaluation * getCollapsedLoopEval(lower::pft::Evaluation &eval, int collapseValue) { // Return the Evaluation of the innermost collapsed loop, or the current one @@ -226,55 +413,41 @@ createAndSetPrivatizedLoopVar(lower::AbstractConverter &converter, return storeOp; } -// This helper function implements the functionality of "promoting" -// non-CPTR arguments of use_device_ptr to use_device_addr -// arguments (automagic conversion of use_device_ptr -> -// use_device_addr in these cases). The way we do so currently is -// through the shuffling of operands from the devicePtrOperands to -// deviceAddrOperands where neccesary and re-organizing the types, -// locations and symbols to maintain the correct ordering of ptr/addr -// input -> BlockArg. +// This helper function implements the functionality of "promoting" non-CPTR +// arguments of use_device_ptr to use_device_addr arguments (automagic +// conversion of use_device_ptr -> use_device_addr in these cases). The way we +// do so currently is through the shuffling of operands from the +// devicePtrOperands to deviceAddrOperands, as well as the types, locations and +// symbols. // -// This effectively implements some deprecated OpenMP functionality -// that some legacy applications unfortunately depend on -// (deprecated in specification version 5.2): +// This effectively implements some deprecated OpenMP functionality that some +// legacy applications unfortunately depend on (deprecated in specification +// version 5.2): // -// "If a list item in a use_device_ptr clause is not of type C_PTR, -// the behavior is as if the list item appeared in a use_device_addr -// clause. Support for such list items in a use_device_ptr clause -// is deprecated." +// "If a list item in a use_device_ptr clause is not of type C_PTR, the behavior +// is as if the list item appeared in a use_device_addr clause. Support for +// such list items in a use_device_ptr clause is deprecated." 
static void promoteNonCPtrUseDevicePtrArgsToUseDeviceAddr( llvm::SmallVectorImpl<mlir::Value> &useDeviceAddrVars, + llvm::SmallVectorImpl<const semantics::Symbol *> &useDeviceAddrSyms, llvm::SmallVectorImpl<mlir::Value> &useDevicePtrVars, - llvm::SmallVectorImpl<mlir::Type> &useDeviceTypes, - llvm::SmallVectorImpl<mlir::Location> &useDeviceLocs, - llvm::SmallVectorImpl<const semantics::Symbol *> &useDeviceSymbols) { - auto moveElementToBack = [](size_t idx, auto &vector) { - auto *iter = std::next(vector.begin(), idx); - vector.push_back(*iter); - vector.erase(iter); - }; - + llvm::SmallVectorImpl<const semantics::Symbol *> &useDevicePtrSyms) { // Iterate over our use_device_ptr list and shift all non-cptr arguments into // use_device_addr. - for (auto *it = useDevicePtrVars.begin(); it != useDevicePtrVars.end();) { - if (!fir::isa_builtin_cptr_type(fir::unwrapRefType(it->getType()))) { - useDeviceAddrVars.push_back(*it); - // We have to shuffle the symbols around as well, to maintain - // the correct Input -> BlockArg for use_device_ptr/use_device_addr. - // NOTE: However, as map's do not seem to be included currently - // this isn't as pertinent, but we must try to maintain for - // future alterations. I believe the reason they are not currently - // is that the BlockArg assign/lowering needs to be extended - // to a greater set of types. - auto idx = std::distance(useDevicePtrVars.begin(), it); - moveElementToBack(idx, useDeviceTypes); - moveElementToBack(idx, useDeviceLocs); - moveElementToBack(idx, useDeviceSymbols); - it = useDevicePtrVars.erase(it); + auto *varIt = useDevicePtrVars.begin(); + auto *symIt = useDevicePtrSyms.begin(); + while (varIt != useDevicePtrVars.end()) { + if (fir::isa_builtin_cptr_type(fir::unwrapRefType(varIt->getType()))) { + ++varIt; + ++symIt; continue; } - ++it; + + useDeviceAddrVars.push_back(*varIt); + useDeviceAddrSyms.push_back(*symIt); + + varIt = useDevicePtrVars.erase(varIt); + symIt = useDevicePtrSyms.erase(symIt); } } @@ -380,14 +553,14 @@ getDeclareTargetFunctionDevice( /// \param [in] converter - PFT to MLIR conversion interface. /// \param [in] loc - location. /// \param [in] args - symbols of induction variables. -/// \param [in] wrapperSyms - symbols of variables to be mapped to loop wrapper +/// \param [in] wrapperArgs - list of parent loop wrappers and their associated /// entry block arguments. -/// \param [in] wrapperArgs - entry block arguments of parent loop wrappers. -static void -genLoopVars(mlir::Operation *op, lower::AbstractConverter &converter, - mlir::Location &loc, llvm::ArrayRef<const semantics::Symbol *> args, - llvm::ArrayRef<const semantics::Symbol *> wrapperSyms = {}, - llvm::ArrayRef<mlir::BlockArgument> wrapperArgs = {}) { +static void genLoopVars( + mlir::Operation *op, lower::AbstractConverter &converter, + mlir::Location &loc, llvm::ArrayRef<const semantics::Symbol *> args, + llvm::ArrayRef< + std::pair<mlir::omp::BlockArgOpenMPOpInterface, const EntryBlockArgs &>> + wrapperArgs = {}) { fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); auto ®ion = op->getRegion(0); @@ -401,8 +574,8 @@ genLoopVars(mlir::Operation *op, lower::AbstractConverter &converter, // Bind the entry block arguments of parent wrappers to the corresponding // symbols. 
- for (auto [arg, prv] : llvm::zip_equal(wrapperSyms, wrapperArgs)) - converter.bindSymbol(*arg, prv); + for (auto [argGeneratingOp, args] : wrapperArgs) + bindEntryBlockArgs(converter, argGeneratingOp, args); // The argument is not currently in memory, so make a temporary for the // argument, and store it there, then bind that location to the argument. @@ -415,22 +588,47 @@ genLoopVars(mlir::Operation *op, lower::AbstractConverter &converter, firOpBuilder.setInsertionPointAfter(storeOp); } -static void -genReductionVars(mlir::Operation *op, lower::AbstractConverter &converter, - mlir::Location &loc, - llvm::ArrayRef<const semantics::Symbol *> reductionArgs, - llvm::ArrayRef<mlir::Type> reductionTypes) { +/// Create an entry block for the given region, including the clause-defined +/// arguments specified. +/// +/// \param [in] converter - PFT to MLIR conversion interface. +/// \param [in] args - entry block arguments information for the given +/// operation. +/// \param [in] region - Empty region in which to create the entry block. +static mlir::Block *genEntryBlock(lower::AbstractConverter &converter, + const EntryBlockArgs &args, + mlir::Region ®ion) { + assert(args.isValid() && "invalid args"); + assert(region.empty() && "non-empty region"); fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); - llvm::SmallVector<mlir::Location> blockArgLocs(reductionArgs.size(), loc); - mlir::Block *entryBlock = firOpBuilder.createBlock( - &op->getRegion(0), {}, reductionTypes, blockArgLocs); + llvm::SmallVector<mlir::Type> types; + llvm::SmallVector<mlir::Location> locs; + unsigned numVars = args.inReduction.vars.size() + args.map.vars.size() + + args.priv.vars.size() + args.reduction.vars.size() + + args.taskReduction.vars.size() + + args.useDeviceAddr.vars.size(); + types.reserve(numVars); + locs.reserve(numVars); + + auto extractTypeLoc = [&types, &locs](llvm::ArrayRef<mlir::Value> vals) { + llvm::transform(vals, std::back_inserter(types), + [](mlir::Value v) { return v.getType(); }); + llvm::transform(vals, std::back_inserter(locs), + [](mlir::Value v) { return v.getLoc(); }); + }; + + // Populate block arguments in clause name alphabetical order to match + // expected order by the BlockArgOpenMPOpInterface. + extractTypeLoc(args.inReduction.vars); + extractTypeLoc(args.map.vars); + extractTypeLoc(args.priv.vars); + extractTypeLoc(args.reduction.vars); + extractTypeLoc(args.taskReduction.vars); + extractTypeLoc(args.useDeviceAddr.vars); + extractTypeLoc(args.useDevicePtr.vars); - // Bind the reduction arguments to their block arguments. - for (auto [arg, prv] : - llvm::zip_equal(reductionArgs, entryBlock->getArguments())) { - converter.bindSymbol(*arg, prv); - } + return firOpBuilder.createBlock(®ion, {}, types, locs); } static void @@ -458,42 +656,6 @@ markDeclareTarget(mlir::Operation *op, lower::AbstractConverter &converter, declareTargetOp.setDeclareTarget(deviceType, captureClause); } -/// For an operation that takes `omp.private` values as region args, this util -/// merges the private vars info into the region arguments list. -/// -/// \tparam OMPOP - the OpenMP op that takes `omp.private` inputs. -/// \tparam InfoTy - the type of private info we want to merge; e.g. mlir::Type -/// or mlir::Location fields of the private var list. -/// -/// \param [in] op - the op accepting `omp.private` inputs. -/// \param [in] currentList - the current list of region info that we -/// want to merge private info with. 
For example this could be the list of types -/// or locations of previous arguments to \op's region. -/// \param [in] infoAccessor - for a private variable, this returns the -/// data we want to merge: type or location. -/// \param [out] allRegionArgsInfo - the merged list of region info. -/// \param [in] addBeforePrivate - `true` if the passed information goes before -/// private information. -template <typename OMPOp, typename InfoTy> -static void -mergePrivateVarsInfo(OMPOp op, llvm::ArrayRef<InfoTy> currentList, - llvm::function_ref<InfoTy(mlir::Value)> infoAccessor, - llvm::SmallVectorImpl<InfoTy> &allRegionArgsInfo, - bool addBeforePrivate) { - mlir::OperandRange privateVars = op.getPrivateVars(); - - if (addBeforePrivate) - llvm::transform(currentList, std::back_inserter(allRegionArgsInfo), - [](InfoTy i) { return i; }); - - llvm::transform(privateVars, std::back_inserter(allRegionArgsInfo), - infoAccessor); - - if (!addBeforePrivate) - llvm::transform(currentList, std::back_inserter(allRegionArgsInfo), - [](InfoTy i) { return i; }); -} - //===----------------------------------------------------------------------===// // Op body generation helper structures and functions //===----------------------------------------------------------------------===// @@ -711,94 +873,16 @@ static void createBodyOfOp(mlir::Operation &op, const OpWithBodyGenInfo &info, marker->erase(); } -void mapBodySymbols(lower::AbstractConverter &converter, mlir::Region ®ion, - llvm::ArrayRef<const semantics::Symbol *> mapSyms) { - assert(region.hasOneBlock() && "target must have single region"); - mlir::Block ®ionBlock = region.front(); - // Clones the `bounds` placing them inside the target region and returns them. - auto cloneBound = [&](mlir::Value bound) { - if (mlir::isMemoryEffectFree(bound.getDefiningOp())) { - mlir::Operation *clonedOp = bound.getDefiningOp()->clone(); - regionBlock.push_back(clonedOp); - return clonedOp->getResult(0); - } - TODO(converter.getCurrentLocation(), - "target map clause operand unsupported bound type"); - }; - - auto cloneBounds = [cloneBound](llvm::ArrayRef<mlir::Value> bounds) { - llvm::SmallVector<mlir::Value> clonedBounds; - for (mlir::Value bound : bounds) - clonedBounds.emplace_back(cloneBound(bound)); - return clonedBounds; - }; - - // Bind the symbols to their corresponding block arguments. - for (auto [argIndex, argSymbol] : llvm::enumerate(mapSyms)) { - const mlir::BlockArgument &arg = region.getArgument(argIndex); - // Avoid capture of a reference to a structured binding. - const semantics::Symbol *sym = argSymbol; - // Structure component symbols don't have bindings. 
- if (sym->owner().IsDerivedType()) - continue; - fir::ExtendedValue extVal = converter.getSymbolExtendedValue(*sym); - auto refType = mlir::dyn_cast<fir::ReferenceType>(arg.getType()); - if (refType && fir::isa_builtin_cptr_type(refType.getElementType())) { - converter.bindSymbol(*argSymbol, arg); - } else { - extVal.match( - [&](const fir::BoxValue &v) { - converter.bindSymbol(*sym, - fir::BoxValue(arg, cloneBounds(v.getLBounds()), - v.getExplicitParameters(), - v.getExplicitExtents())); - }, - [&](const fir::MutableBoxValue &v) { - converter.bindSymbol( - *sym, fir::MutableBoxValue(arg, cloneBounds(v.getLBounds()), - v.getMutableProperties())); - }, - [&](const fir::ArrayBoxValue &v) { - converter.bindSymbol( - *sym, fir::ArrayBoxValue(arg, cloneBounds(v.getExtents()), - cloneBounds(v.getLBounds()), - v.getSourceBox())); - }, - [&](const fir::CharArrayBoxValue &v) { - converter.bindSymbol( - *sym, fir::CharArrayBoxValue(arg, cloneBound(v.getLen()), - cloneBounds(v.getExtents()), - cloneBounds(v.getLBounds()))); - }, - [&](const fir::CharBoxValue &v) { - converter.bindSymbol( - *sym, fir::CharBoxValue(arg, cloneBound(v.getLen()))); - }, - [&](const fir::UnboxedValue &v) { converter.bindSymbol(*sym, arg); }, - [&](const auto &) { - TODO(converter.getCurrentLocation(), - "target map clause operand unsupported type"); - }); - } - } -} - static void genBodyOfTargetDataOp( lower::AbstractConverter &converter, lower::SymMap &symTable, semantics::SemanticsContext &semaCtx, lower::pft::Evaluation &eval, - mlir::omp::TargetDataOp &dataOp, - llvm::ArrayRef<const semantics::Symbol *> useDeviceSymbols, - llvm::ArrayRef<mlir::Location> useDeviceLocs, - llvm::ArrayRef<mlir::Type> useDeviceTypes, + mlir::omp::TargetDataOp &dataOp, const EntryBlockArgs &args, const mlir::Location ¤tLocation, const ConstructQueue &queue, ConstructQueue::const_iterator item) { - assert(useDeviceTypes.size() == useDeviceLocs.size()); - fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); - mlir::Region ®ion = dataOp.getRegion(); - firOpBuilder.createBlock(®ion, {}, useDeviceTypes, useDeviceLocs); - mapBodySymbols(converter, region, useDeviceSymbols); + genEntryBlock(converter, args, dataOp.getRegion()); + bindEntryBlockArgs(converter, dataOp, args); // Insert dummy instruction to remember the insertion position. The // marker will be deleted by clean up passes since there are no uses. @@ -839,19 +923,25 @@ static void genBodyOfTargetDataOp( // This is for utilisation with TargetOp. static void genIntermediateCommonBlockAccessors( Fortran::lower::AbstractConverter &converter, - const mlir::Location ¤tLocation, mlir::Region ®ion, + const mlir::Location ¤tLocation, + llvm::ArrayRef<mlir::BlockArgument> mapBlockArgs, llvm::ArrayRef<const Fortran::semantics::Symbol *> mapSyms) { - for (auto [argIndex, argSymbol] : llvm::enumerate(mapSyms)) { - if (auto *details = - argSymbol->detailsIf<Fortran::semantics::CommonBlockDetails>()) { - for (auto obj : details->objects()) { - auto targetCBMemberBind = Fortran::lower::genCommonBlockMember( - converter, currentLocation, *obj, region.getArgument(argIndex)); - fir::ExtendedValue sexv = converter.getSymbolExtendedValue(*obj); - fir::ExtendedValue targetCBExv = - getExtendedValue(sexv, targetCBMemberBind); - converter.bindSymbol(*obj, targetCBExv); - } + // Iterate over the symbol list, which will be shorter than the list of + // arguments if new entry block arguments were introduced to implicitly map + // outside values used by the bounds cloned into the target region. 
In that + // case, the additional block arguments do not need processing here. + for (auto [mapSym, mapArg] : llvm::zip_first(mapSyms, mapBlockArgs)) { + auto *details = mapSym->detailsIf<Fortran::semantics::CommonBlockDetails>(); + if (!details) + continue; + + for (auto obj : details->objects()) { + auto targetCBMemberBind = Fortran::lower::genCommonBlockMember( + converter, currentLocation, *obj, mapArg); + fir::ExtendedValue sexv = converter.getSymbolExtendedValue(*obj); + fir::ExtendedValue targetCBExv = + getExtendedValue(sexv, targetCBMemberBind); + converter.bindSymbol(*obj, targetCBExv); } } } @@ -861,47 +951,15 @@ static void genIntermediateCommonBlockAccessors( static void genBodyOfTargetOp( lower::AbstractConverter &converter, lower::SymMap &symTable, semantics::SemanticsContext &semaCtx, lower::pft::Evaluation &eval, - mlir::omp::TargetOp &targetOp, - llvm::ArrayRef<const semantics::Symbol *> mapSyms, - llvm::ArrayRef<mlir::Location> mapSymLocs, - llvm::ArrayRef<mlir::Type> mapSymTypes, + mlir::omp::TargetOp &targetOp, const EntryBlockArgs &args, const mlir::Location ¤tLocation, const ConstructQueue &queue, ConstructQueue::const_iterator item, DataSharingProcessor &dsp) { - assert(mapSymTypes.size() == mapSymLocs.size()); - fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); - mlir::Region ®ion = targetOp.getRegion(); + auto argIface = llvm::cast<mlir::omp::BlockArgOpenMPOpInterface>(*targetOp); - llvm::SmallVector<mlir::Type> allRegionArgTypes; - llvm::SmallVector<mlir::Location> allRegionArgLocs; - mergePrivateVarsInfo(targetOp, mapSymTypes, - llvm::function_ref<mlir::Type(mlir::Value)>{ - [](mlir::Value v) { return v.getType(); }}, - allRegionArgTypes, /*addBeforePrivate=*/true); - - mergePrivateVarsInfo(targetOp, mapSymLocs, - llvm::function_ref<mlir::Location(mlir::Value)>{ - [](mlir::Value v) { return v.getLoc(); }}, - allRegionArgLocs, /*addBeforePrivate=*/true); - - mlir::Block *regionBlock = firOpBuilder.createBlock( - ®ion, {}, allRegionArgTypes, allRegionArgLocs); - - mapBodySymbols(converter, region, mapSyms); - - for (auto [argIndex, argSymbol] : - llvm::enumerate(dsp.getAllSymbolsToPrivatize())) { - argIndex = mapSyms.size() + argIndex; - - const mlir::BlockArgument &arg = region.getArgument(argIndex); - converter.bindSymbol(*argSymbol, - hlfir::translateToExtendedValue( - currentLocation, firOpBuilder, hlfir::Entity{arg}, - /*contiguousHint=*/ - evaluate::IsSimplyContiguous( - *argSymbol, converter.getFoldingContext())) - .first); - } + mlir::Region ®ion = targetOp.getRegion(); + mlir::Block *entryBlock = genEntryBlock(converter, args, region); + bindEntryBlockArgs(converter, targetOp, args); // Check if cloning the bounds introduced any dependency on the outer region. 
// If so, then either clone them as well if they are MemoryEffectFree, or else @@ -914,11 +972,11 @@ static void genBodyOfTargetOp( mlir::Operation *valOp = val.getDefiningOp(); if (mlir::isMemoryEffectFree(valOp)) { mlir::Operation *clonedOp = valOp->clone(); - regionBlock->push_front(clonedOp); - val.replaceUsesWithIf( - clonedOp->getResult(0), [regionBlock](mlir::OpOperand &use) { - return use.getOwner()->getBlock() == regionBlock; - }); + entryBlock->push_front(clonedOp); + val.replaceUsesWithIf(clonedOp->getResult(0), + [entryBlock](mlir::OpOperand &use) { + return use.getOwner()->getBlock() == entryBlock; + }); } else { auto savedIP = firOpBuilder.getInsertionPoint(); firOpBuilder.setInsertionPointAfter(valOp); @@ -939,18 +997,23 @@ static void genBodyOfTargetOp( llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_IMPLICIT), mlir::omp::VariableCaptureKind::ByCopy, copyVal.getType()); + // Get the index of the first non-map argument before modifying mapVars, + // then append an element to mapVars and an associated entry block + // argument at that index. + unsigned insertIndex = + argIface.getMapBlockArgsStart() + argIface.numMapBlockArgs(); targetOp.getMapVarsMutable().append(mapOp); + mlir::Value clonedValArg = region.insertArgument( + insertIndex, copyVal.getType(), copyVal.getLoc()); - mlir::Value clonedValArg = - region.addArgument(copyVal.getType(), copyVal.getLoc()); - firOpBuilder.setInsertionPointToStart(regionBlock); + firOpBuilder.setInsertionPointToStart(entryBlock); auto loadOp = firOpBuilder.create<fir::LoadOp>(clonedValArg.getLoc(), clonedValArg); - val.replaceUsesWithIf( - loadOp->getResult(0), [regionBlock](mlir::OpOperand &use) { - return use.getOwner()->getBlock() == regionBlock; - }); - firOpBuilder.setInsertionPoint(regionBlock, savedIP); + val.replaceUsesWithIf(loadOp->getResult(0), + [entryBlock](mlir::OpOperand &use) { + return use.getOwner()->getBlock() == entryBlock; + }); + firOpBuilder.setInsertionPoint(entryBlock, savedIP); } } valuesDefinedAbove.clear(); @@ -977,14 +1040,14 @@ static void genBodyOfTargetOp( firOpBuilder.setInsertionPointAfter(undefMarker.getDefiningOp()); // If we map a common block using it's symbol e.g. map(tofrom: /common_block/) - // and accessing it's members within the target region, there is a large + // and accessing its members within the target region, there is a large // chance we will end up with uses external to the region accessing the common // resolve these, we do so by generating new common block member accesses // within the region, binding them to the member symbol for the scope of the // region so that subsequent code generation within the region will utilise // our new member accesses we have created. 
- genIntermediateCommonBlockAccessors(converter, currentLocation, region, - mapSyms); + genIntermediateCommonBlockAccessors( + converter, currentLocation, argIface.getMapBlockArgs(), args.map.syms); if (ConstructQueue::const_iterator next = std::next(item); next != queue.end()) { @@ -1010,7 +1073,7 @@ static OpTy genOpWithBody(const OpWithBodyGenInfo &info, template <typename OpTy, typename ClauseOpsTy> static OpTy genWrapperOp(lower::AbstractConverter &converter, mlir::Location loc, const ClauseOpsTy &clauseOps, - llvm::ArrayRef<mlir::Type> blockArgTypes) { + const EntryBlockArgs &args) { static_assert( OpTy::template hasTrait<mlir::omp::LoopWrapperInterface::Trait>(), "expected a loop wrapper"); @@ -1020,9 +1083,7 @@ static OpTy genWrapperOp(lower::AbstractConverter &converter, auto op = firOpBuilder.create<OpTy>(loc, clauseOps); // Create entry block with arguments. - llvm::SmallVector<mlir::Location> locs(blockArgTypes.size(), loc); - firOpBuilder.createBlock(&op.getRegion(), /*insertPt=*/{}, blockArgTypes, - locs); + genEntryBlock(converter, args, op.getRegion()); firOpBuilder.setInsertionPoint( lower::genOpenMPTerminator(firOpBuilder, op, loc)); @@ -1102,39 +1163,38 @@ static void genParallelClauses( lower::AbstractConverter &converter, semantics::SemanticsContext &semaCtx, lower::StatementContext &stmtCtx, const List<Clause> &clauses, mlir::Location loc, mlir::omp::ParallelOperands &clauseOps, - llvm::SmallVectorImpl<mlir::Type> &reductionTypes, llvm::SmallVectorImpl<const semantics::Symbol *> &reductionSyms) { ClauseProcessor cp(converter, semaCtx, clauses); cp.processAllocate(clauseOps); cp.processIf(llvm::omp::Directive::OMPD_parallel, clauseOps); cp.processNumThreads(stmtCtx, clauseOps); cp.processProcBind(clauseOps); - cp.processReduction(loc, clauseOps, &reductionTypes, &reductionSyms); + cp.processReduction(loc, clauseOps, reductionSyms); } static void genSectionsClauses( lower::AbstractConverter &converter, semantics::SemanticsContext &semaCtx, const List<Clause> &clauses, mlir::Location loc, mlir::omp::SectionsOperands &clauseOps, - llvm::SmallVectorImpl<mlir::Type> &reductionTypes, llvm::SmallVectorImpl<const semantics::Symbol *> &reductionSyms) { ClauseProcessor cp(converter, semaCtx, clauses); cp.processAllocate(clauseOps); cp.processNowait(clauseOps); - cp.processReduction(loc, clauseOps, &reductionTypes, &reductionSyms); + cp.processReduction(loc, clauseOps, reductionSyms); // TODO Support delayed privatization. 
} -static void genSimdClauses(lower::AbstractConverter &converter, - semantics::SemanticsContext &semaCtx, - const List<Clause> &clauses, mlir::Location loc, - mlir::omp::SimdOperands &clauseOps) { +static void genSimdClauses( + lower::AbstractConverter &converter, semantics::SemanticsContext &semaCtx, + const List<Clause> &clauses, mlir::Location loc, + mlir::omp::SimdOperands &clauseOps, + llvm::SmallVectorImpl<const semantics::Symbol *> &reductionSyms) { ClauseProcessor cp(converter, semaCtx, clauses); cp.processAligned(clauseOps); cp.processIf(llvm::omp::Directive::OMPD_simd, clauseOps); cp.processNontemporal(clauseOps); cp.processOrder(clauseOps); - cp.processReduction(loc, clauseOps); + cp.processReduction(loc, clauseOps, reductionSyms); cp.processSafelen(clauseOps); cp.processSimdlen(clauseOps); @@ -1157,24 +1217,16 @@ static void genTargetClauses( lower::StatementContext &stmtCtx, const List<Clause> &clauses, mlir::Location loc, bool processHostOnlyClauses, mlir::omp::TargetOperands &clauseOps, - llvm::SmallVectorImpl<const semantics::Symbol *> &mapSyms, - llvm::SmallVectorImpl<mlir::Location> &mapLocs, - llvm::SmallVectorImpl<mlir::Type> &mapTypes, - llvm::SmallVectorImpl<const semantics::Symbol *> &deviceAddrSyms, - llvm::SmallVectorImpl<mlir::Location> &deviceAddrLocs, - llvm::SmallVectorImpl<mlir::Type> &deviceAddrTypes, - llvm::SmallVectorImpl<const semantics::Symbol *> &devicePtrSyms, - llvm::SmallVectorImpl<mlir::Location> &devicePtrLocs, - llvm::SmallVectorImpl<mlir::Type> &devicePtrTypes) { + llvm::SmallVectorImpl<const semantics::Symbol *> &hasDeviceAddrSyms, + llvm::SmallVectorImpl<const semantics::Symbol *> &isDevicePtrSyms, + llvm::SmallVectorImpl<const semantics::Symbol *> &mapSyms) { ClauseProcessor cp(converter, semaCtx, clauses); cp.processDepend(clauseOps); cp.processDevice(stmtCtx, clauseOps); - cp.processHasDeviceAddr(clauseOps, deviceAddrTypes, deviceAddrLocs, - deviceAddrSyms); + cp.processHasDeviceAddr(clauseOps, hasDeviceAddrSyms); cp.processIf(llvm::omp::Directive::OMPD_target, clauseOps); - cp.processIsDevicePtr(clauseOps, devicePtrTypes, devicePtrLocs, - devicePtrSyms); - cp.processMap(loc, stmtCtx, clauseOps, &mapSyms, &mapLocs, &mapTypes); + cp.processIsDevicePtr(clauseOps, isDevicePtrSyms); + cp.processMap(loc, stmtCtx, clauseOps, &mapSyms); if (processHostOnlyClauses) cp.processNowait(clauseOps); @@ -1194,32 +1246,26 @@ static void genTargetDataClauses( lower::AbstractConverter &converter, semantics::SemanticsContext &semaCtx, lower::StatementContext &stmtCtx, const List<Clause> &clauses, mlir::Location loc, mlir::omp::TargetDataOperands &clauseOps, - llvm::SmallVectorImpl<mlir::Type> &useDeviceTypes, - llvm::SmallVectorImpl<mlir::Location> &useDeviceLocs, - llvm::SmallVectorImpl<const semantics::Symbol *> &useDeviceSyms) { + llvm::SmallVectorImpl<const semantics::Symbol *> &useDeviceAddrSyms, + llvm::SmallVectorImpl<const semantics::Symbol *> &useDevicePtrSyms) { ClauseProcessor cp(converter, semaCtx, clauses); cp.processDevice(stmtCtx, clauseOps); cp.processIf(llvm::omp::Directive::OMPD_target_data, clauseOps); cp.processMap(loc, stmtCtx, clauseOps); - cp.processUseDeviceAddr(stmtCtx, clauseOps, useDeviceTypes, useDeviceLocs, - useDeviceSyms); - cp.processUseDevicePtr(stmtCtx, clauseOps, useDeviceTypes, useDeviceLocs, - useDeviceSyms); + cp.processUseDeviceAddr(stmtCtx, clauseOps, useDeviceAddrSyms); + cp.processUseDevicePtr(stmtCtx, clauseOps, useDevicePtrSyms); // This function implements the deprecated functionality of use_device_ptr // that 
allows users to provide non-CPTR arguments to it with the caveat // that the compiler will treat them as use_device_addr. A lot of legacy // code may still depend on this functionality, so we should support it // in some manner. We do so currently by simply shifting non-cptr operands - // from the use_device_ptr list into the front of the use_device_addr list - // whilst maintaining the ordering of useDeviceLocs, useDeviceSyms and - // useDeviceTypes to use_device_ptr/use_device_addr input for BlockArg - // ordering. + // from the use_device_ptr lists into the use_device_addr lists. // TODO: Perhaps create a user provideable compiler option that will // re-introduce a hard-error rather than a warning in these cases. promoteNonCPtrUseDevicePtrArgsToUseDeviceAddr( - clauseOps.useDeviceAddrVars, clauseOps.useDevicePtrVars, useDeviceTypes, - useDeviceLocs, useDeviceSyms); + clauseOps.useDeviceAddrVars, useDeviceAddrSyms, + clauseOps.useDevicePtrVars, useDevicePtrSyms); } static void genTargetEnterExitUpdateDataClauses( @@ -1297,13 +1343,12 @@ static void genWsloopClauses( lower::AbstractConverter &converter, semantics::SemanticsContext &semaCtx, lower::StatementContext &stmtCtx, const List<Clause> &clauses, mlir::Location loc, mlir::omp::WsloopOperands &clauseOps, - llvm::SmallVectorImpl<mlir::Type> &reductionTypes, llvm::SmallVectorImpl<const semantics::Symbol *> &reductionSyms) { ClauseProcessor cp(converter, semaCtx, clauses); cp.processNowait(clauseOps); cp.processOrder(clauseOps); cp.processOrdered(clauseOps); - cp.processReduction(loc, clauseOps, &reductionTypes, &reductionSyms); + cp.processReduction(loc, clauseOps, reductionSyms); cp.processSchedule(stmtCtx, clauseOps); cp.processTODO<clause::Allocate, clause::Linear>( @@ -1366,21 +1411,18 @@ genFlushOp(lower::AbstractConverter &converter, lower::SymMap &symTable, converter.getCurrentLocation(), operandRange); } -static mlir::omp::LoopNestOp -genLoopNestOp(lower::AbstractConverter &converter, lower::SymMap &symTable, - semantics::SemanticsContext &semaCtx, - lower::pft::Evaluation &eval, mlir::Location loc, - const ConstructQueue &queue, ConstructQueue::const_iterator item, - mlir::omp::LoopNestOperands &clauseOps, - llvm::ArrayRef<const semantics::Symbol *> iv, - llvm::ArrayRef<const semantics::Symbol *> wrapperSyms, - llvm::ArrayRef<mlir::BlockArgument> wrapperArgs, - llvm::omp::Directive directive, DataSharingProcessor &dsp) { - assert(wrapperSyms.size() == wrapperArgs.size() && - "Number of symbols and wrapper block arguments must match"); - +static mlir::omp::LoopNestOp genLoopNestOp( + lower::AbstractConverter &converter, lower::SymMap &symTable, + semantics::SemanticsContext &semaCtx, lower::pft::Evaluation &eval, + mlir::Location loc, const ConstructQueue &queue, + ConstructQueue::const_iterator item, mlir::omp::LoopNestOperands &clauseOps, + llvm::ArrayRef<const semantics::Symbol *> iv, + llvm::ArrayRef< + std::pair<mlir::omp::BlockArgOpenMPOpInterface, const EntryBlockArgs &>> + wrapperArgs, + llvm::omp::Directive directive, DataSharingProcessor &dsp) { auto ivCallback = [&](mlir::Operation *op) { - genLoopVars(op, converter, loc, iv, wrapperSyms, wrapperArgs); + genLoopVars(op, converter, loc, iv, wrapperArgs); return llvm::SmallVector<const semantics::Symbol *>(iv); }; @@ -1452,83 +1494,26 @@ genParallelOp(lower::AbstractConverter &converter, lower::SymMap &symTable, lower::pft::Evaluation &eval, mlir::Location loc, const ConstructQueue &queue, ConstructQueue::const_iterator item, mlir::omp::ParallelOperands &clauseOps, - 
llvm::ArrayRef<const semantics::Symbol *> reductionSyms, - llvm::ArrayRef<mlir::Type> reductionTypes, - DataSharingProcessor *dsp, bool isComposite = false) { - fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); - - auto reductionCallback = [&](mlir::Operation *op) { - genReductionVars(op, converter, loc, reductionSyms, reductionTypes); - return llvm::SmallVector<const semantics::Symbol *>(reductionSyms); + const EntryBlockArgs &args, DataSharingProcessor *dsp, + bool isComposite = false) { + auto genRegionEntryCB = [&](mlir::Operation *op) { + genEntryBlock(converter, args, op->getRegion(0)); + bindEntryBlockArgs( + converter, llvm::cast<mlir::omp::BlockArgOpenMPOpInterface>(op), args); + return llvm::to_vector(llvm::concat<const semantics::Symbol *const>( + args.priv.syms, args.reduction.syms)); }; + assert((!enableDelayedPrivatization || dsp) && + "expected valid DataSharingProcessor"); OpWithBodyGenInfo genInfo = OpWithBodyGenInfo(converter, symTable, semaCtx, loc, eval, llvm::omp::Directive::OMPD_parallel) .setClauses(&item->clauses) - .setGenRegionEntryCb(reductionCallback) - .setGenSkeletonOnly(isComposite); - - if (!enableDelayedPrivatization) { - auto parallelOp = - genOpWithBody<mlir::omp::ParallelOp>(genInfo, queue, item, clauseOps); - parallelOp.setComposite(isComposite); - return parallelOp; - } + .setGenRegionEntryCb(genRegionEntryCB) + .setGenSkeletonOnly(isComposite) + .setDataSharingProcessor(dsp); - assert(dsp && "expected valid DataSharingProcessor"); - auto genRegionEntryCB = [&](mlir::Operation *op) { - auto parallelOp = llvm::cast<mlir::omp::ParallelOp>(op); - - llvm::SmallVector<mlir::Location> reductionLocs( - clauseOps.reductionVars.size(), loc); - - llvm::SmallVector<mlir::Type> allRegionArgTypes; - mergePrivateVarsInfo(parallelOp, reductionTypes, - llvm::function_ref<mlir::Type(mlir::Value)>{ - [](mlir::Value v) { return v.getType(); }}, - allRegionArgTypes, /*addBeforePrivate=*/false); - - llvm::SmallVector<mlir::Location> allRegionArgLocs; - mergePrivateVarsInfo(parallelOp, llvm::ArrayRef(reductionLocs), - llvm::function_ref<mlir::Location(mlir::Value)>{ - [](mlir::Value v) { return v.getLoc(); }}, - allRegionArgLocs, /*addBeforePrivate=*/false); - - mlir::Region ®ion = parallelOp.getRegion(); - firOpBuilder.createBlock(®ion, /*insertPt=*/{}, allRegionArgTypes, - allRegionArgLocs); - - llvm::SmallVector<const semantics::Symbol *> allSymbols( - dsp->getDelayedPrivSymbols()); - allSymbols.append(reductionSyms.begin(), reductionSyms.end()); - - unsigned argIdx = 0; - for (const semantics::Symbol *arg : allSymbols) { - auto bind = [&](const semantics::Symbol *sym) { - mlir::BlockArgument blockArg = region.getArgument(argIdx); - ++argIdx; - converter.bindSymbol(*sym, - hlfir::translateToExtendedValue( - loc, firOpBuilder, hlfir::Entity{blockArg}, - /*contiguousHint=*/ - evaluate::IsSimplyContiguous( - *sym, converter.getFoldingContext())) - .first); - }; - - if (const auto *commonDet = - arg->detailsIf<semantics::CommonBlockDetails>()) { - for (const auto &mem : commonDet->objects()) - bind(&*mem); - } else - bind(arg); - } - - return allSymbols; - }; - - genInfo.setGenRegionEntryCb(genRegionEntryCB).setDataSharingProcessor(dsp); auto parallelOp = genOpWithBody<mlir::omp::ParallelOp>(genInfo, queue, item, clauseOps); parallelOp.setComposite(isComposite); @@ -1544,11 +1529,10 @@ genSectionsOp(lower::AbstractConverter &converter, lower::SymMap &symTable, lower::pft::Evaluation &eval, mlir::Location loc, const ConstructQueue &queue, 
ConstructQueue::const_iterator item, const parser::OmpSectionBlocks §ionBlocks) { - llvm::SmallVector<mlir::Type> reductionTypes; - llvm::SmallVector<const semantics::Symbol *> reductionSyms; mlir::omp::SectionsOperands clauseOps; + llvm::SmallVector<const semantics::Symbol *> reductionSyms; genSectionsClauses(converter, semaCtx, item->clauses, loc, clauseOps, - reductionTypes, reductionSyms); + reductionSyms); auto &builder = converter.getFirOpBuilder(); @@ -1579,15 +1563,20 @@ genSectionsOp(lower::AbstractConverter &converter, lower::SymMap &symTable, // SECTIONS construct. auto sectionsOp = builder.create<mlir::omp::SectionsOp>(loc, clauseOps); - // create entry block with reduction variables as arguments - llvm::SmallVector<mlir::Location> blockArgLocs(reductionSyms.size(), loc); - builder.createBlock(§ionsOp->getRegion(0), {}, reductionTypes, - blockArgLocs); + // Create entry block with reduction variables as arguments. + EntryBlockArgs args; + // TODO: Add private syms and vars. + args.reduction.syms = reductionSyms; + args.reduction.vars = clauseOps.reductionVars; + + genEntryBlock(converter, args, sectionsOp.getRegion()); mlir::Operation *terminator = lower::genOpenMPTerminator(builder, sectionsOp, loc); auto reductionCallback = [&](mlir::Operation *op) { - genReductionVars(op, converter, loc, reductionSyms, reductionTypes); + genEntryBlock(converter, args, op->getRegion(0)); + bindEntryBlockArgs( + converter, llvm::cast<mlir::omp::BlockArgOpenMPOpInterface>(op), args); return reductionSyms; }; @@ -1681,14 +1670,11 @@ genTargetOp(lower::AbstractConverter &converter, lower::SymMap &symTable, .getIsTargetDevice(); mlir::omp::TargetOperands clauseOps; - llvm::SmallVector<const semantics::Symbol *> mapSyms, devicePtrSyms, - deviceAddrSyms; - llvm::SmallVector<mlir::Location> mapLocs, devicePtrLocs, deviceAddrLocs; - llvm::SmallVector<mlir::Type> mapTypes, devicePtrTypes, deviceAddrTypes; + llvm::SmallVector<const semantics::Symbol *> mapSyms, isDevicePtrSyms, + hasDeviceAddrSyms; genTargetClauses(converter, semaCtx, stmtCtx, item->clauses, loc, - processHostOnlyClauses, clauseOps, mapSyms, mapLocs, - mapTypes, deviceAddrSyms, deviceAddrLocs, deviceAddrTypes, - devicePtrSyms, devicePtrLocs, devicePtrTypes); + processHostOnlyClauses, clauseOps, hasDeviceAddrSyms, + isDevicePtrSyms, mapSyms); DataSharingProcessor dsp(converter, semaCtx, item->clauses, eval, /*shouldCollectPreDeterminedSymbols=*/ @@ -1795,15 +1781,24 @@ genTargetOp(lower::AbstractConverter &converter, lower::SymMap &symTable, clauseOps.mapVars.push_back(mapOp); mapSyms.push_back(&sym); - mapLocs.push_back(baseOp.getLoc()); - mapTypes.push_back(baseOp.getType()); } }; lower::pft::visitAllSymbols(eval, captureImplicitMap); auto targetOp = firOpBuilder.create<mlir::omp::TargetOp>(loc, clauseOps); - genBodyOfTargetOp(converter, symTable, semaCtx, eval, targetOp, mapSyms, - mapLocs, mapTypes, loc, queue, item, dsp); + + llvm::SmallVector<mlir::Value> mapVars; + extractMapVarsBaseOps(clauseOps.mapVars, mapVars); + + EntryBlockArgs args; + // TODO: Add in_reduction syms and vars. 
+ args.map.syms = mapSyms; + args.map.vars = mapVars; + args.priv.syms = dsp.getDelayedPrivSymbols(); + args.priv.vars = clauseOps.privateVars; + + genBodyOfTargetOp(converter, symTable, semaCtx, eval, targetOp, args, loc, + queue, item, dsp); return targetOp; } @@ -1815,18 +1810,27 @@ genTargetDataOp(lower::AbstractConverter &converter, lower::SymMap &symTable, ConstructQueue::const_iterator item) { lower::StatementContext stmtCtx; mlir::omp::TargetDataOperands clauseOps; - llvm::SmallVector<mlir::Type> useDeviceTypes; - llvm::SmallVector<mlir::Location> useDeviceLocs; - llvm::SmallVector<const semantics::Symbol *> useDeviceSyms; + llvm::SmallVector<const semantics::Symbol *> useDeviceAddrSyms, + useDevicePtrSyms; genTargetDataClauses(converter, semaCtx, stmtCtx, item->clauses, loc, - clauseOps, useDeviceTypes, useDeviceLocs, useDeviceSyms); + clauseOps, useDeviceAddrSyms, useDevicePtrSyms); auto targetDataOp = converter.getFirOpBuilder().create<mlir::omp::TargetDataOp>(loc, clauseOps); - genBodyOfTargetDataOp(converter, symTable, semaCtx, eval, targetDataOp, - useDeviceSyms, useDeviceLocs, useDeviceTypes, loc, - queue, item); + + llvm::SmallVector<mlir::Value> useDeviceAddrVars, useDevicePtrVars; + extractMapVarsBaseOps(clauseOps.useDeviceAddrVars, useDeviceAddrVars); + extractMapVarsBaseOps(clauseOps.useDevicePtrVars, useDevicePtrVars); + + EntryBlockArgs args; + args.useDeviceAddr.syms = useDeviceAddrSyms; + args.useDeviceAddr.vars = useDeviceAddrVars; + args.useDevicePtr.syms = useDevicePtrSyms; + args.useDevicePtr.vars = useDevicePtrVars; + + genBodyOfTargetDataOp(converter, symTable, semaCtx, eval, targetDataOp, args, + loc, queue, item); return targetDataOp; } @@ -1948,22 +1952,20 @@ static void genStandaloneDistribute(lower::AbstractConverter &converter, /*shouldCollectPreDeterminedSymbols=*/true, enableDelayedPrivatizationStaging, &symTable); dsp.processStep1(&distributeClauseOps); - llvm::SmallVector<mlir::Type> privateVarTypes{}; - - for (mlir::Value privateVar : distributeClauseOps.privateVars) - privateVarTypes.push_back(privateVar.getType()); mlir::omp::LoopNestOperands loopNestClauseOps; llvm::SmallVector<const semantics::Symbol *> iv; genLoopNestClauses(converter, semaCtx, eval, item->clauses, loc, loopNestClauseOps, iv); + EntryBlockArgs distributeArgs; + distributeArgs.priv.syms = dsp.getDelayedPrivSymbols(); + distributeArgs.priv.vars = distributeClauseOps.privateVars; auto distributeOp = genWrapperOp<mlir::omp::DistributeOp>( - converter, loc, distributeClauseOps, privateVarTypes); + converter, loc, distributeClauseOps, distributeArgs); genLoopNestOp(converter, symTable, semaCtx, eval, loc, queue, item, - loopNestClauseOps, iv, dsp.getDelayedPrivSymbols(), - distributeOp.getRegion().getArguments(), + loopNestClauseOps, iv, {{distributeOp, distributeArgs}}, llvm::omp::Directive::OMPD_distribute, dsp); } @@ -1976,10 +1978,9 @@ static void genStandaloneDo(lower::AbstractConverter &converter, lower::StatementContext stmtCtx; mlir::omp::WsloopOperands wsloopClauseOps; - llvm::SmallVector<const semantics::Symbol *> reductionSyms; - llvm::SmallVector<mlir::Type> reductionTypes; + llvm::SmallVector<const semantics::Symbol *> wsloopReductionSyms; genWsloopClauses(converter, semaCtx, stmtCtx, item->clauses, loc, - wsloopClauseOps, reductionTypes, reductionSyms); + wsloopClauseOps, wsloopReductionSyms); // TODO: Support delayed privatization. 
DataSharingProcessor dsp(converter, semaCtx, item->clauses, eval, @@ -1992,13 +1993,15 @@ static void genStandaloneDo(lower::AbstractConverter &converter, genLoopNestClauses(converter, semaCtx, eval, item->clauses, loc, loopNestClauseOps, iv); - // TODO: Add private variables to entry block arguments. + EntryBlockArgs wsloopArgs; + // TODO: Add private syms and vars. + wsloopArgs.reduction.syms = wsloopReductionSyms; + wsloopArgs.reduction.vars = wsloopClauseOps.reductionVars; auto wsloopOp = genWrapperOp<mlir::omp::WsloopOp>( - converter, loc, wsloopClauseOps, reductionTypes); + converter, loc, wsloopClauseOps, wsloopArgs); genLoopNestOp(converter, symTable, semaCtx, eval, loc, queue, item, - loopNestClauseOps, iv, reductionSyms, - wsloopOp.getRegion().getArguments(), + loopNestClauseOps, iv, {{wsloopOp, wsloopArgs}}, llvm::omp::Directive::OMPD_do, dsp); } @@ -2011,21 +2014,27 @@ static void genStandaloneParallel(lower::AbstractConverter &converter, ConstructQueue::const_iterator item) { lower::StatementContext stmtCtx; - mlir::omp::ParallelOperands clauseOps; - llvm::SmallVector<const semantics::Symbol *> reductionSyms; - llvm::SmallVector<mlir::Type> reductionTypes; - genParallelClauses(converter, semaCtx, stmtCtx, item->clauses, loc, clauseOps, - reductionTypes, reductionSyms); + mlir::omp::ParallelOperands parallelClauseOps; + llvm::SmallVector<const semantics::Symbol *> parallelReductionSyms; + genParallelClauses(converter, semaCtx, stmtCtx, item->clauses, loc, + parallelClauseOps, parallelReductionSyms); std::optional<DataSharingProcessor> dsp; if (enableDelayedPrivatization) { dsp.emplace(converter, semaCtx, item->clauses, eval, lower::omp::isLastItemInQueue(item, queue), /*useDelayedPrivatization=*/true, &symTable); - dsp->processStep1(&clauseOps); + dsp->processStep1(&parallelClauseOps); } - genParallelOp(converter, symTable, semaCtx, eval, loc, queue, item, clauseOps, - reductionSyms, reductionTypes, + + EntryBlockArgs parallelArgs; + if (dsp) + parallelArgs.priv.syms = dsp->getDelayedPrivSymbols(); + parallelArgs.priv.vars = parallelClauseOps.privateVars; + parallelArgs.reduction.syms = parallelReductionSyms; + parallelArgs.reduction.vars = parallelClauseOps.reductionVars; + genParallelOp(converter, symTable, semaCtx, eval, loc, queue, item, + parallelClauseOps, parallelArgs, enableDelayedPrivatization ? &dsp.value() : nullptr); } @@ -2036,7 +2045,9 @@ static void genStandaloneSimd(lower::AbstractConverter &converter, const ConstructQueue &queue, ConstructQueue::const_iterator item) { mlir::omp::SimdOperands simdClauseOps; - genSimdClauses(converter, semaCtx, item->clauses, loc, simdClauseOps); + llvm::SmallVector<const semantics::Symbol *> simdReductionSyms; + genSimdClauses(converter, semaCtx, item->clauses, loc, simdClauseOps, + simdReductionSyms); // TODO: Support delayed privatization. DataSharingProcessor dsp(converter, semaCtx, item->clauses, eval, @@ -2049,13 +2060,15 @@ static void genStandaloneSimd(lower::AbstractConverter &converter, genLoopNestClauses(converter, semaCtx, eval, item->clauses, loc, loopNestClauseOps, iv); - // TODO: Populate entry block arguments with reduction and private variables. - auto simdOp = genWrapperOp<mlir::omp::SimdOp>(converter, loc, simdClauseOps, - /*blockArgTypes=*/{}); + EntryBlockArgs simdArgs; + // TODO: Add private syms and vars.
+ simdArgs.reduction.syms = simdReductionSyms; + simdArgs.reduction.vars = simdClauseOps.reductionVars; + auto simdOp = + genWrapperOp<mlir::omp::SimdOp>(converter, loc, simdClauseOps, simdArgs); genLoopNestOp(converter, symTable, semaCtx, eval, loc, queue, item, - loopNestClauseOps, iv, - /*wrapperSyms=*/{}, simdOp.getRegion().getArguments(), + loopNestClauseOps, iv, {{simdOp, simdArgs}}, llvm::omp::Directive::OMPD_simd, dsp); } @@ -2088,19 +2101,21 @@ static void genCompositeDistributeParallelDo( // Create parent omp.parallel first. mlir::omp::ParallelOperands parallelClauseOps; llvm::SmallVector<const semantics::Symbol *> parallelReductionSyms; - llvm::SmallVector<mlir::Type> parallelReductionTypes; genParallelClauses(converter, semaCtx, stmtCtx, parallelItem->clauses, loc, - parallelClauseOps, parallelReductionTypes, - parallelReductionSyms); + parallelClauseOps, parallelReductionSyms); DataSharingProcessor dsp(converter, semaCtx, doItem->clauses, eval, /*shouldCollectPreDeterminedSymbols=*/true, /*useDelayedPrivatization=*/true, &symTable); dsp.processStep1(&parallelClauseOps); + EntryBlockArgs parallelArgs; + parallelArgs.priv.syms = dsp.getDelayedPrivSymbols(); + parallelArgs.priv.vars = parallelClauseOps.privateVars; + parallelArgs.reduction.syms = parallelReductionSyms; + parallelArgs.reduction.vars = parallelClauseOps.reductionVars; genParallelOp(converter, symTable, semaCtx, eval, loc, queue, parallelItem, - parallelClauseOps, parallelReductionSyms, - parallelReductionTypes, &dsp, /*isComposite=*/true); + parallelClauseOps, parallelArgs, &dsp, /*isComposite=*/true); // Clause processing. mlir::omp::DistributeOperands distributeClauseOps; @@ -2109,9 +2124,8 @@ static void genCompositeDistributeParallelDo( mlir::omp::WsloopOperands wsloopClauseOps; llvm::SmallVector<const semantics::Symbol *> wsloopReductionSyms; - llvm::SmallVector<mlir::Type> wsloopReductionTypes; genWsloopClauses(converter, semaCtx, stmtCtx, doItem->clauses, loc, - wsloopClauseOps, wsloopReductionTypes, wsloopReductionSyms); + wsloopClauseOps, wsloopReductionSyms); mlir::omp::LoopNestOperands loopNestClauseOps; llvm::SmallVector<const semantics::Symbol *> iv; @@ -2119,27 +2133,23 @@ static void genCompositeDistributeParallelDo( loopNestClauseOps, iv); // Operation creation. - // TODO: Populate entry block arguments with private variables. + EntryBlockArgs distributeArgs; + // TODO: Add private syms and vars. auto distributeOp = genWrapperOp<mlir::omp::DistributeOp>( - converter, loc, distributeClauseOps, /*blockArgTypes=*/{}); + converter, loc, distributeClauseOps, distributeArgs); distributeOp.setComposite(/*val=*/true); - // TODO: Add private variables to entry block arguments. + EntryBlockArgs wsloopArgs; + // TODO: Add private syms and vars. + wsloopArgs.reduction.syms = wsloopReductionSyms; + wsloopArgs.reduction.vars = wsloopClauseOps.reductionVars; auto wsloopOp = genWrapperOp<mlir::omp::WsloopOp>( - converter, loc, wsloopClauseOps, wsloopReductionTypes); + converter, loc, wsloopClauseOps, wsloopArgs); wsloopOp.setComposite(/*val=*/true); - // Construct wrapper entry block list and associated symbols. It is important - // that the symbol order and the block argument order match, so that the - // symbol-value bindings created are correct.
- auto &wrapperSyms = wsloopReductionSyms; - - auto wrapperArgs = llvm::to_vector( - llvm::concat<mlir::BlockArgument>(distributeOp.getRegion().getArguments(), - wsloopOp.getRegion().getArguments())); - genLoopNestOp(converter, symTable, semaCtx, eval, loc, queue, doItem, - loopNestClauseOps, iv, wrapperSyms, wrapperArgs, + loopNestClauseOps, iv, + {{distributeOp, distributeArgs}, {wsloopOp, wsloopArgs}}, llvm::omp::Directive::OMPD_distribute_parallel_do, dsp); } @@ -2159,19 +2169,21 @@ static void genCompositeDistributeParallelDoSimd( // Create parent omp.parallel first. mlir::omp::ParallelOperands parallelClauseOps; llvm::SmallVector<const semantics::Symbol *> parallelReductionSyms; - llvm::SmallVector<mlir::Type> parallelReductionTypes; genParallelClauses(converter, semaCtx, stmtCtx, parallelItem->clauses, loc, - parallelClauseOps, parallelReductionTypes, - parallelReductionSyms); + parallelClauseOps, parallelReductionSyms); DataSharingProcessor dsp(converter, semaCtx, simdItem->clauses, eval, /*shouldCollectPreDeterminedSymbols=*/true, /*useDelayedPrivatization=*/true, &symTable); dsp.processStep1(&parallelClauseOps); + EntryBlockArgs parallelArgs; + parallelArgs.priv.syms = dsp.getDelayedPrivSymbols(); + parallelArgs.priv.vars = parallelClauseOps.privateVars; + parallelArgs.reduction.syms = parallelReductionSyms; + parallelArgs.reduction.vars = parallelClauseOps.reductionVars; genParallelOp(converter, symTable, semaCtx, eval, loc, queue, parallelItem, - parallelClauseOps, parallelReductionSyms, - parallelReductionTypes, &dsp, /*isComposite=*/true); + parallelClauseOps, parallelArgs, &dsp, /*isComposite=*/true); // Clause processing. mlir::omp::DistributeOperands distributeClauseOps; @@ -2180,12 +2192,13 @@ static void genCompositeDistributeParallelDoSimd( mlir::omp::WsloopOperands wsloopClauseOps; llvm::SmallVector<const semantics::Symbol *> wsloopReductionSyms; - llvm::SmallVector<mlir::Type> wsloopReductionTypes; genWsloopClauses(converter, semaCtx, stmtCtx, doItem->clauses, loc, - wsloopClauseOps, wsloopReductionTypes, wsloopReductionSyms); + wsloopClauseOps, wsloopReductionSyms); mlir::omp::SimdOperands simdClauseOps; - genSimdClauses(converter, semaCtx, simdItem->clauses, loc, simdClauseOps); + llvm::SmallVector<const semantics::Symbol *> simdReductionSyms; + genSimdClauses(converter, semaCtx, simdItem->clauses, loc, simdClauseOps, + simdReductionSyms); mlir::omp::LoopNestOperands loopNestClauseOps; llvm::SmallVector<const semantics::Symbol *> iv; @@ -2193,32 +2206,33 @@ static void genCompositeDistributeParallelDoSimd( loopNestClauseOps, iv); // Operation creation. - // TODO: Populate entry block arguments with private variables. + EntryBlockArgs distributeArgs; + // TODO: Add private syms and vars. auto distributeOp = genWrapperOp<mlir::omp::DistributeOp>( - converter, loc, distributeClauseOps, /*blockArgTypes=*/{}); + converter, loc, distributeClauseOps, distributeArgs); distributeOp.setComposite(/*val=*/true); - // TODO: Add private variables to entry block arguments. + EntryBlockArgs wsloopArgs; + // TODO: Add private syms and vars. + wsloopArgs.reduction.syms = wsloopReductionSyms; + wsloopArgs.reduction.vars = wsloopClauseOps.reductionVars; auto wsloopOp = genWrapperOp<mlir::omp::WsloopOp>( - converter, loc, wsloopClauseOps, wsloopReductionTypes); + converter, loc, wsloopClauseOps, wsloopArgs); wsloopOp.setComposite(/*val=*/true); - // TODO: Populate entry block arguments with reduction and private variables.
- auto simdOp = genWrapperOp<mlir::omp::SimdOp>(converter, loc, simdClauseOps, - /*blockArgTypes=*/{}); + EntryBlockArgs simdArgs; + // TODO: Add private syms and vars. + simdArgs.reduction.syms = simdReductionSyms; + simdArgs.reduction.vars = simdClauseOps.reductionVars; + auto simdOp = + genWrapperOp<mlir::omp::SimdOp>(converter, loc, simdClauseOps, simdArgs); simdOp.setComposite(/*val=*/true); - // Construct wrapper entry block list and associated symbols. It is important - // that the symbol order and the block argument order match, so that the - // symbol-value bindings created are correct. - auto &wrapperSyms = wsloopReductionSyms; - - auto wrapperArgs = llvm::to_vector(llvm::concat<mlir::BlockArgument>( - distributeOp.getRegion().getArguments(), - wsloopOp.getRegion().getArguments(), simdOp.getRegion().getArguments())); - genLoopNestOp(converter, symTable, semaCtx, eval, loc, queue, simdItem, - loopNestClauseOps, iv, wrapperSyms, wrapperArgs, + loopNestClauseOps, iv, + {{distributeOp, distributeArgs}, + {wsloopOp, wsloopArgs}, + {simdOp, simdArgs}}, llvm::omp::Directive::OMPD_distribute_parallel_do_simd, dsp); } @@ -2241,7 +2255,9 @@ static void genCompositeDistributeSimd(lower::AbstractConverter &converter, loc, distributeClauseOps); mlir::omp::SimdOperands simdClauseOps; - genSimdClauses(converter, semaCtx, simdItem->clauses, loc, simdClauseOps); + llvm::SmallVector<const semantics::Symbol *> simdReductionSyms; + genSimdClauses(converter, semaCtx, simdItem->clauses, loc, simdClauseOps, + simdReductionSyms); // TODO: Support delayed privatization. DataSharingProcessor dsp(converter, semaCtx, simdItem->clauses, eval, @@ -2257,26 +2273,23 @@ static void genCompositeDistributeSimd(lower::AbstractConverter &converter, loopNestClauseOps, iv); // Operation creation. - // TODO: Populate entry block arguments with private variables. + EntryBlockArgs distributeArgs; + // TODO: Add private syms and vars. auto distributeOp = genWrapperOp<mlir::omp::DistributeOp>( - converter, loc, distributeClauseOps, /*blockArgTypes=*/{}); + converter, loc, distributeClauseOps, distributeArgs); distributeOp.setComposite(/*val=*/true); - // TODO: Populate entry block arguments with reduction and private variables. - auto simdOp = genWrapperOp<mlir::omp::SimdOp>(converter, loc, simdClauseOps, - /*blockArgTypes=*/{}); + EntryBlockArgs simdArgs; + // TODO: Add private syms and vars. + simdArgs.reduction.syms = simdReductionSyms; + simdArgs.reduction.vars = simdClauseOps.reductionVars; + auto simdOp = + genWrapperOp<mlir::omp::SimdOp>(converter, loc, simdClauseOps, simdArgs); simdOp.setComposite(/*val=*/true); - // Construct wrapper entry block list and associated symbols. It is important - // that the symbol order and the block argument order match, so that the - // symbol-value bindings created are correct. - // TODO: Add omp.distribute private and omp.simd private and reduction args. - auto wrapperArgs = llvm::to_vector( - llvm::concat<mlir::BlockArgument>(distributeOp.getRegion().getArguments(), - simdOp.getRegion().getArguments())); - genLoopNestOp(converter, symTable, semaCtx, eval, loc, queue, simdItem, - loopNestClauseOps, iv, /*wrapperSyms=*/{}, wrapperArgs, + loopNestClauseOps, iv, + {{distributeOp, distributeArgs}, {simdOp, simdArgs}}, llvm::omp::Directive::OMPD_distribute_simd, dsp); } @@ -2295,12 +2308,13 @@ static void genCompositeDoSimd(lower::AbstractConverter &converter, // Clause processing. 
mlir::omp::WsloopOperands wsloopClauseOps; llvm::SmallVector<const semantics::Symbol *> wsloopReductionSyms; - llvm::SmallVector<mlir::Type> wsloopReductionTypes; genWsloopClauses(converter, semaCtx, stmtCtx, doItem->clauses, loc, - wsloopClauseOps, wsloopReductionTypes, wsloopReductionSyms); + wsloopClauseOps, wsloopReductionSyms); mlir::omp::SimdOperands simdClauseOps; - genSimdClauses(converter, semaCtx, simdItem->clauses, loc, simdClauseOps); + llvm::SmallVector<const semantics::Symbol *> simdReductionSyms; + genSimdClauses(converter, semaCtx, simdItem->clauses, loc, simdClauseOps, + simdReductionSyms); // TODO: Support delayed privatization. DataSharingProcessor dsp(converter, semaCtx, simdItem->clauses, eval, @@ -2316,25 +2330,25 @@ static void genCompositeDoSimd(lower::AbstractConverter &converter, loopNestClauseOps, iv); // Operation creation. - // TODO: Add private variables to entry block arguments. + EntryBlockArgs wsloopArgs; + // TODO: Add private syms and vars. + wsloopArgs.reduction.syms = wsloopReductionSyms; + wsloopArgs.reduction.vars = wsloopClauseOps.reductionVars; auto wsloopOp = genWrapperOp<mlir::omp::WsloopOp>( - converter, loc, wsloopClauseOps, wsloopReductionTypes); + converter, loc, wsloopClauseOps, wsloopArgs); wsloopOp.setComposite(/*val=*/true); - // TODO: Populate entry block arguments with reduction and private variables. - auto simdOp = genWrapperOp<mlir::omp::SimdOp>(converter, loc, simdClauseOps, - /*blockArgTypes=*/{}); + EntryBlockArgs simdArgs; + // TODO: Add private syms and vars. + simdArgs.reduction.syms = simdReductionSyms; + simdArgs.reduction.vars = simdClauseOps.reductionVars; + auto simdOp = + genWrapperOp<mlir::omp::SimdOp>(converter, loc, simdClauseOps, simdArgs); simdOp.setComposite(/*val=*/true); - // Construct wrapper entry block list and associated symbols. It is important - // that the symbol and block argument order match, so that the symbol-value - // bindings created are correct. - // TODO: Add omp.wsloop private and omp.simd private and reduction args. 
- auto wrapperArgs = llvm::to_vector(llvm::concat<mlir::BlockArgument>( - wsloopOp.getRegion().getArguments(), simdOp.getRegion().getArguments())); - genLoopNestOp(converter, symTable, semaCtx, eval, loc, queue, simdItem, - loopNestClauseOps, iv, wsloopReductionSyms, wrapperArgs, + loopNestClauseOps, iv, + {{wsloopOp, wsloopArgs}, {simdOp, simdArgs}}, llvm::omp::Directive::OMPD_do_simd, dsp); } diff --git a/flang/lib/Lower/OpenMP/ReductionProcessor.cpp b/flang/lib/Lower/OpenMP/ReductionProcessor.cpp index c87182abe3d187..f35e425777141d 100644 --- a/flang/lib/Lower/OpenMP/ReductionProcessor.cpp +++ b/flang/lib/Lower/OpenMP/ReductionProcessor.cpp @@ -723,7 +723,7 @@ void ReductionProcessor::addDeclareReduction( llvm::SmallVectorImpl<mlir::Value> &reductionVars, llvm::SmallVectorImpl<bool> &reduceVarByRef, llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols, - llvm::SmallVectorImpl<const semantics::Symbol *> *reductionSymbols) { + llvm::SmallVectorImpl<const semantics::Symbol *> &reductionSymbols) { fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); if (std::get<std::optional<omp::clause::Reduction::ReductionModifier>>( @@ -754,8 +754,7 @@ void ReductionProcessor::addDeclareReduction( fir::FirOpBuilder &builder = converter.getFirOpBuilder(); for (const Object &object : objectList) { const semantics::Symbol *symbol = object.sym(); - if (reductionSymbols) - reductionSymbols->push_back(symbol); + reductionSymbols.push_back(symbol); mlir::Value symVal = converter.getSymbolAddress(*symbol); mlir::Type eleType; auto refType = mlir::dyn_cast_or_null<fir::ReferenceType>(symVal.getType()); diff --git a/flang/lib/Lower/OpenMP/ReductionProcessor.h b/flang/lib/Lower/OpenMP/ReductionProcessor.h index 0ed5782e5da1b7..5f4d742b62cb10 100644 --- a/flang/lib/Lower/OpenMP/ReductionProcessor.h +++ b/flang/lib/Lower/OpenMP/ReductionProcessor.h @@ -126,8 +126,7 @@ class ReductionProcessor { llvm::SmallVectorImpl<mlir::Value> &reductionVars, llvm::SmallVectorImpl<bool> &reduceVarByRef, llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols, - llvm::SmallVectorImpl<const semantics::Symbol *> *reductionSymbols = - nullptr); + llvm::SmallVectorImpl<const semantics::Symbol *> &reductionSymbols); }; template <typename FloatOp, typename IntegerOp> diff --git a/flang/lib/Lower/OpenMP/Utils.cpp b/flang/lib/Lower/OpenMP/Utils.cpp index 8073b24a1d5b45..e34e2fbcd51f69 100644 --- a/flang/lib/Lower/OpenMP/Utils.cpp +++ b/flang/lib/Lower/OpenMP/Utils.cpp @@ -263,9 +263,7 @@ void insertChildMapInfoIntoParent( std::map<const semantics::Symbol *, llvm::SmallVector<OmpMapMemberIndicesData>> &parentMemberIndices, llvm::SmallVectorImpl<mlir::Value> &mapOperands, - llvm::SmallVectorImpl<const semantics::Symbol *> &mapSyms, - llvm::SmallVectorImpl<mlir::Type> *mapSymTypes, - llvm::SmallVectorImpl<mlir::Location> *mapSymLocs) { + llvm::SmallVectorImpl<const semantics::Symbol *> &mapSyms) { for (auto indices : parentMemberIndices) { bool parentExists = false; size_t parentIdx; @@ -321,11 +319,6 @@ void insertChildMapInfoIntoParent( mapOperands.push_back(mapOp); mapSyms.push_back(indices.first); - - if (mapSymTypes) - mapSymTypes->push_back(mapOp.getType()); - if (mapSymLocs) - mapSymLocs->push_back(mapOp.getLoc()); } } } diff --git a/flang/lib/Lower/OpenMP/Utils.h b/flang/lib/Lower/OpenMP/Utils.h index 0b4fe9044bfa7b..298a26239475a3 100644 --- a/flang/lib/Lower/OpenMP/Utils.h +++ b/flang/lib/Lower/OpenMP/Utils.h @@ -78,9 +78,7 @@ void insertChildMapInfoIntoParent( std::map<const semantics::Symbol *, 
llvm::SmallVector<OmpMapMemberIndicesData>> &parentMemberIndices, llvm::SmallVectorImpl<mlir::Value> &mapOperands, - llvm::SmallVectorImpl<const semantics::Symbol *> &mapSyms, - llvm::SmallVectorImpl<mlir::Type> *mapSymTypes, - llvm::SmallVectorImpl<mlir::Location> *mapSymLocs); + llvm::SmallVectorImpl<const semantics::Symbol *> &mapSyms); mlir::Type getLoopVarType(lower::AbstractConverter &converter, std::size_t loopVarTypeSize);
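Putting the hunks together, the new shape of a standalone worksharing-loop lowering looks roughly like the following. This is a condensed illustration rather than code from the patch, and it assumes converter, symTable, semaCtx, eval, stmtCtx, item, queue, loc, iv, loopNestClauseOps and dsp are set up exactly as in genStandaloneDo above:

  // Clause processing now returns the reduction symbols next to the operands.
  mlir::omp::WsloopOperands wsloopClauseOps;
  llvm::SmallVector<const semantics::Symbol *> wsloopReductionSyms;
  genWsloopClauses(converter, semaCtx, stmtCtx, item->clauses, loc,
                   wsloopClauseOps, wsloopReductionSyms);

  // Symbols and values for each argument-generating clause travel together,
  // so their relative order can no longer diverge.
  EntryBlockArgs wsloopArgs;
  wsloopArgs.reduction.syms = wsloopReductionSyms;
  wsloopArgs.reduction.vars = wsloopClauseOps.reductionVars;

  // The wrapper's entry block is created from the args structure instead of
  // manually assembled type/location lists.
  auto wsloopOp = genWrapperOp<mlir::omp::WsloopOp>(converter, loc,
                                                    wsloopClauseOps,
                                                    wsloopArgs);

  // genLoopNestOp receives (wrapper op, args) pairs, keeping wrapper entry
  // blocks and symbol bindings in sync.
  genLoopNestOp(converter, symTable, semaCtx, eval, loc, queue, item,
                loopNestClauseOps, iv, {{wsloopOp, wsloopArgs}},
                llvm::omp::Directive::OMPD_do, dsp);

The same EntryBlockArgs structure drives genEntryBlock and bindEntryBlockArgs in genParallelOp and genSectionsOp above, which is what lets the composite lowerings pass {op, args} pairs to genLoopNestOp instead of hand-merging wrapper block arguments and symbol lists.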