https://github.com/skc7 updated https://github.com/llvm/llvm-project/pull/171767
>From 77e758855c0e5cf3072704bbe461682ab192a84a Mon Sep 17 00:00:00 2001 From: skc7 <[email protected]> Date: Fri, 28 Nov 2025 13:37:14 +0530 Subject: [PATCH 01/11] [OpenMP][MLIR] Add num_teams clause with dims modifier support --- .../mlir/Dialect/OpenMP/OpenMPClauses.td | 72 +++++++++++++++++++ mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td | 3 +- mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp | 5 ++ mlir/test/Dialect/OpenMP/invalid.mlir | 19 +---- 4 files changed, 80 insertions(+), 19 deletions(-) diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td index b612d4e136baf..ed24530464ea4 100644 --- a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td +++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td @@ -1567,4 +1567,76 @@ class OpenMP_UseDevicePtrClauseSkip< def OpenMP_UseDevicePtrClause : OpenMP_UseDevicePtrClauseSkip<>; +//===----------------------------------------------------------------------===// +// V6.2: Multidimensional `num_teams` clause with dims modifier +//===----------------------------------------------------------------------===// + +class OpenMP_NumTeamsMultiDimClauseSkip< + bit traits = false, bit arguments = false, bit assemblyFormat = false, + bit description = false, bit extraClassDeclaration = false + > : OpenMP_Clause<traits, arguments, assemblyFormat, description, + extraClassDeclaration> { + let arguments = (ins + ConfinedAttr<OptionalAttr<I64Attr>, [IntPositive]>:$num_teams_dims, + Variadic<AnyInteger>:$num_teams_values + ); + + let optAssemblyFormat = [{ + `num_teams_multi_dim` `(` custom<NumTeamsMultiDimClause>($num_teams_dims, + $num_teams_values, + type($num_teams_values)) `)` + }]; + + let description = [{ + The `num_teams_multi_dim` clause with dims modifier support specifies the limit on + the number of teams to be created in a multidimensional team space. 
+ + The dims modifier for the num_teams_multi_dim clause specifies the number of + dimensions for the league space (team space) that the clause arranges. + The dimensions argument in the dims modifier specifies the number of + dimensions and determines the length of the list argument. The list items + are specified in ascending order according to the ordinal number of the + dimensions (dimension 0, 1, 2, ..., N-1). + + - If `dims` is not specified: The space is unidimensional (1D) with a single value + - If `dims(1)` is specified: The space is explicitly unidimensional (1D) + - If `dims(N)` where N > 1: The space is strictly multidimensional (N-D) + + **Examples:** + - `num_teams_multi_dim(dims(3): %nt0, %nt1, %nt2 : i32, i32, i32)` creates a + 3-dimensional team space with limits nt0, nt1, nt2 for dimensions 0, 1, 2. + - `num_teams_multi_dim(%nt : i32)` creates a unidimensional team space with limit nt. + }]; + + let extraClassDeclaration = [{ + /// Returns true if the dims modifier is explicitly present + bool hasDimsModifier() { + return getNumTeamsDims().has_value(); + } + + /// Returns the number of dimensions specified by dims modifier + /// Returns 1 if dims modifier is not present (unidimensional by default) + unsigned getNumDimensions() { + if (!hasDimsModifier()) + return 1; + return static_cast<unsigned>(*getNumTeamsDims()); + } + + /// Returns all dimension values as an operand range + ::mlir::OperandRange getDimensionValues() { + return getNumTeamsValues(); + } + + /// Returns the value for a specific dimension index + /// Index must be less than getNumDimensions() + ::mlir::Value getDimensionValue(unsigned index) { + assert(index < getDimensionValues().size() && + "Dimension index out of bounds"); + return getDimensionValues()[index]; + } + }]; +} + +def OpenMP_NumTeamsMultiDimClause : OpenMP_NumTeamsMultiDimClauseSkip<>; + #endif // OPENMP_CLAUSES diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td 
b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td index d4e8cecda2601..76eeb0bd70ec3 100644 --- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td +++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td @@ -241,7 +241,8 @@ def TeamsOp : OpenMP_Op<"teams", traits = [ AttrSizedOperandSegments, RecursiveMemoryEffects, OutlineableOpenMPOpInterface ], clauses = [ OpenMP_AllocateClause, OpenMP_IfClause, OpenMP_NumTeamsClause, - OpenMP_PrivateClause, OpenMP_ReductionClause, OpenMP_ThreadLimitClause + OpenMP_NumTeamsMultiDimClause, OpenMP_PrivateClause, OpenMP_ReductionClause, + OpenMP_ThreadLimitClause ], singleRegion = true> { let summary = "teams construct"; let description = [{ diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp index 25bf4e70d9a83..7a9a45b160ba3 100644 --- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp +++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp @@ -2625,8 +2625,13 @@ void TeamsOp::build(OpBuilder &builder, OperationState &state, MLIRContext *ctx = builder.getContext(); // TODO Store clauses in op: privateVars, privateSyms, privateNeedsBarrier TeamsOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars, +<<<<<<< HEAD clauses.ifExpr, clauses.numTeamsVals, clauses.numTeamsLower, clauses.numTeamsUpper, +======= + clauses.ifExpr, clauses.numTeamsLower, clauses.numTeamsUpper, + clauses.numTeamsDims, clauses.numTeamsValues, +>>>>>>> [OpenMP][MLIR] Add num_teams clause with dims modifier support /*private_vars=*/{}, /*private_syms=*/nullptr, /*private_needs_barrier=*/nullptr, clauses.reductionMod, clauses.reductionVars, diff --git a/mlir/test/Dialect/OpenMP/invalid.mlir b/mlir/test/Dialect/OpenMP/invalid.mlir index d451b14e8bfc9..cd06011c2cbc4 100644 --- a/mlir/test/Dialect/OpenMP/invalid.mlir +++ b/mlir/test/Dialect/OpenMP/invalid.mlir @@ -1451,24 +1451,7 @@ func.func @omp_teams_num_teams1(%lb : i32) { // expected-error @below {{expected num_teams upper bound to be defined if the lower bound 
is defined}} "omp.teams" (%lb) ({ omp.terminator - }) {operandSegmentSizes = array<i32: 0,0,0,0,1,0,0,0,0>} : (i32) -> () - omp.terminator - } - return -} - -// ----- - -func.func @omp_teams_num_teams_multidim_with_bounds() { - omp.target { - %v0 = arith.constant 1 : i32 - %v1 = arith.constant 2 : i32 - %lb = arith.constant 3 : i32 - %ub = arith.constant 4 : i32 - // expected-error @below {{num_teams multi-dimensional values cannot be used together with legacy lower/upper bounds}} - "omp.teams" (%v0, %v1, %lb, %ub) ({ - omp.terminator - }) {operandSegmentSizes = array<i32: 0,0,0,2,1,1,0,0,0>} : (i32, i32, i32, i32) -> () + }) {operandSegmentSizes = array<i32: 0,0,0,1,0,0,0,0,0>} : (i32) -> () omp.terminator } return >From 2b170fd042b2bd0b2a147fe413780c45a0e4fbdb Mon Sep 17 00:00:00 2001 From: skc7 <[email protected]> Date: Thu, 11 Dec 2025 11:56:58 +0530 Subject: [PATCH 02/11] [OpenMP][MLIR] Add num_threads clause with dims modifier support --- .../mlir/Dialect/OpenMP/OpenMPClauses.td | 50 +++++++++++- .../Conversion/SCFToOpenMP/SCFToOpenMP.cpp | 2 + mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp | 79 +++++++++++++++++-- mlir/test/Dialect/OpenMP/invalid.mlir | 33 +++++++- mlir/test/Dialect/OpenMP/ops.mlir | 15 ++-- 5 files changed, 163 insertions(+), 16 deletions(-) diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td index ed24530464ea4..8826c15a15191 100644 --- a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td +++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td @@ -1069,16 +1069,60 @@ class OpenMP_NumThreadsClauseSkip< > : OpenMP_Clause<traits, arguments, assemblyFormat, description, extraClassDeclaration> { let arguments = (ins + ConfinedAttr<OptionalAttr<I64Attr>, [IntPositive]>:$num_threads_dims, + Variadic<AnyInteger>:$num_threads_values, Optional<IntLikeType>:$num_threads ); let optAssemblyFormat = [{ - `num_threads` `(` $num_threads `:` type($num_threads) `)` + `num_threads` `(` 
custom<NumThreadsClause>( + $num_threads_dims, $num_threads_values, type($num_threads_values), + $num_threads, type($num_threads) + ) `)` }]; let description = [{ - The optional `num_threads` parameter specifies the number of threads which - should be used to execute the parallel region. + num_threads clause specifies the desired number of threads in the team + space formed by the construct on which it appears. + + With dims modifier: + - Uses `num_threads_dims` (dimension count) and `num_threads_values` (upper bounds list) + - Specifies upper bounds for each dimension (all must have same type) + - Format: `num_threads(dims(N): upper_bound_0, ..., upper_bound_N-1 : type)` + - Example: `num_threads(dims(3): %ub0, %ub1, %ub2 : i32)` + + Without dims modifier: + - Uses `num_threads` + - If lower bound not specified, it defaults to upper bound value + - Format: `num_threads(bounds : type)` + - Example: `num_threads(%ub : i32)` + }]; + + let extraClassDeclaration = [{ + /// Returns true if the dims modifier is explicitly present + bool hasDimsModifier() { + return getNumThreadsDims().has_value(); + } + + /// Returns the number of dimensions specified by dims modifier + unsigned getNumDimensions() { + if (!hasDimsModifier()) + return 1; + return static_cast<unsigned>(*getNumThreadsDims()); + } + + /// Returns all dimension values as an operand range + ::mlir::OperandRange getDimensionValues() { + return getNumThreadsValues(); + } + + /// Returns the value for a specific dimension index + /// Index must be less than getNumDimensions() + ::mlir::Value getDimensionValue(unsigned index) { + assert(index < getDimensionValues().size() && + "Dimension index out of bounds"); + return getDimensionValues()[index]; + } }]; } diff --git a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp index 5fcaea7f39c3c..c749106b925f7 100644 --- a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp +++ b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp 
@@ -497,6 +497,8 @@ struct ParallelOpLowering : public OpRewritePattern<scf::ParallelOp> { /* allocate_vars = */ llvm::SmallVector<Value>{}, /* allocator_vars = */ llvm::SmallVector<Value>{}, /* if_expr = */ Value{}, + /* num_threads_dims = */ nullptr, + /* num_threads_values = */ llvm::SmallVector<Value>{}, /* num_threads = */ numThreadsVar, /* private_vars = */ ValueRange(), /* private_syms = */ nullptr, diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp index 7a9a45b160ba3..d75b9e17f1e98 100644 --- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp +++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp @@ -2504,6 +2504,8 @@ void ParallelOp::build(OpBuilder &builder, OperationState &state, ArrayRef<NamedAttribute> attributes) { ParallelOp::build(builder, state, /*allocate_vars=*/ValueRange(), /*allocator_vars=*/ValueRange(), /*if_expr=*/nullptr, + /*num_threads_dims=*/nullptr, + /*num_threads_values=*/ValueRange(), /*num_threads=*/nullptr, /*private_vars=*/ValueRange(), /*private_syms=*/nullptr, /*private_needs_barrier=*/nullptr, /*proc_bind_kind=*/nullptr, @@ -2515,13 +2517,14 @@ void ParallelOp::build(OpBuilder &builder, OperationState &state, void ParallelOp::build(OpBuilder &builder, OperationState &state, const ParallelOperands &clauses) { MLIRContext *ctx = builder.getContext(); - ParallelOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars, - clauses.ifExpr, clauses.numThreads, clauses.privateVars, - makeArrayAttr(ctx, clauses.privateSyms), - clauses.privateNeedsBarrier, clauses.procBindKind, - clauses.reductionMod, clauses.reductionVars, - makeDenseBoolArrayAttr(ctx, clauses.reductionByref), - makeArrayAttr(ctx, clauses.reductionSyms)); + ParallelOp::build( + builder, state, clauses.allocateVars, clauses.allocatorVars, + clauses.ifExpr, clauses.numThreadsDims, clauses.numThreadsValues, + clauses.numThreads, clauses.privateVars, + makeArrayAttr(ctx, clauses.privateSyms), 
clauses.privateNeedsBarrier, + clauses.procBindKind, clauses.reductionMod, clauses.reductionVars, + makeDenseBoolArrayAttr(ctx, clauses.reductionByref), + makeArrayAttr(ctx, clauses.reductionSyms)); } template <typename OpType> @@ -2568,13 +2571,40 @@ static LogicalResult verifyPrivateVarList(OpType &op) { } LogicalResult ParallelOp::verify() { + // verify num_threads clause restrictions + auto numThreadsDims = getNumThreadsDims(); + auto numThreadsValues = getNumThreadsValues(); + auto numThreads = getNumThreads(); + + // num_threads with dims modifier + if (numThreadsDims.has_value() && numThreadsValues.empty()) { + return emitError( + "num_threads dims modifier requires values to be specified"); + } + + if (numThreadsDims.has_value() && + numThreadsValues.size() != static_cast<size_t>(*numThreadsDims)) { + return emitError("num_threads dims(") + << *numThreadsDims << ") specified but " << numThreadsValues.size() + << " values provided"; + } + + // num_threads dims and number of threads cannot be used together + if (numThreadsDims.has_value() && numThreads) { + return emitError( + "num_threads dims and number of threads cannot be used together"); + } + + // verify allocate clause restrictions if (getAllocateVars().size() != getAllocatorVars().size()) return emitError( "expected equal sizes for allocate and allocator variables"); + // verify private variables restrictions if (failed(verifyPrivateVarList(*this))) return failure(); + // verify reduction variables restrictions return verifyReductionVarList(*this, getReductionSyms(), getReductionVars(), getReductionByref()); } @@ -4623,6 +4653,41 @@ static void printNumTeamsClause(OpAsmPrinter &p, Operation *op, } } +//===----------------------------------------------------------------------===// +// Parser and printer for num_threads clause +//===----------------------------------------------------------------------===// +static ParseResult +parseNumThreadsClause(OpAsmParser &parser, IntegerAttr &dimsAttr, + 
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values, + SmallVectorImpl<Type> &types, + std::optional<OpAsmParser::UnresolvedOperand> &bounds, + Type &boundsType) { + if (succeeded(parseDimsModifierWithValues(parser, dimsAttr, values, types))) { + return success(); + } + + OpAsmParser::UnresolvedOperand boundsOperand; + if (parser.parseOperand(boundsOperand) || parser.parseColon() || + parser.parseType(boundsType)) { + return failure(); + } + bounds = boundsOperand; + return success(); +} + +static void printNumThreadsClause(OpAsmPrinter &p, Operation *op, + IntegerAttr dimsAttr, OperandRange values, + TypeRange types, Value bounds, + Type boundsType) { + if (!values.empty()) { + printDimsModifierWithValues(p, dimsAttr, values, types); + } + if (bounds) { + p.printOperand(bounds); + p << " : " << boundsType; + } +} + #define GET_ATTRDEF_CLASSES #include "mlir/Dialect/OpenMP/OpenMPOpsAttributes.cpp.inc" diff --git a/mlir/test/Dialect/OpenMP/invalid.mlir b/mlir/test/Dialect/OpenMP/invalid.mlir index cd06011c2cbc4..e55fe3d0a1aec 100644 --- a/mlir/test/Dialect/OpenMP/invalid.mlir +++ b/mlir/test/Dialect/OpenMP/invalid.mlir @@ -30,6 +30,37 @@ func.func @num_threads_once(%n : si32) { // ----- +func.func @num_threads_dims_no_values() { + // expected-error@+1 {{num_threads dims modifier requires values to be specified}} + "omp.parallel"() ({ + omp.terminator + }) {operandSegmentSizes = array<i32: 0,0,0,0,0,0,0>, num_threads_dims = 2 : i64} : () -> () + return +} + +// ----- + +func.func @num_threads_dims_mismatch(%n : i64) { + // expected-error@+1 {{num_threads dims(2) specified but 1 values provided}} + omp.parallel num_threads(dims(2): %n : i64) { + omp.terminator + } + + return +} + +// ----- + +func.func @num_threads_dims_and_scalar(%n : i64, %m: i64) { + // expected-error@+1 {{num_threads dims and number of threads cannot be used together}} + "omp.parallel"(%n, %n, %m) ({ + omp.terminator + }) {operandSegmentSizes = array<i32: 0,0,0,2,1,0,0>, num_threads_dims = 2 : 
i64} : (i64, i64, i64) -> () + return +} + +// ----- + func.func @nowait_not_allowed(%n : memref<i32>) { // expected-error@+1 {{expected '{' to begin a region}} omp.parallel nowait {} @@ -2691,7 +2722,7 @@ func.func @undefined_privatizer(%arg0: index) { // ----- func.func @undefined_privatizer(%arg0: !llvm.ptr) { // expected-error @below {{inconsistent number of private variables and privatizer op symbols, private vars: 1 vs. privatizer op symbols: 2}} - "omp.parallel"(%arg0) <{operandSegmentSizes = array<i32: 0, 0, 0, 0, 1, 0>, private_syms = [@x.privatizer, @y.privatizer]}> ({ + "omp.parallel"(%arg0) <{operandSegmentSizes = array<i32: 0, 0, 0, 0, 0, 1, 0>, private_syms = [@x.privatizer, @y.privatizer]}> ({ ^bb0(%arg2: !llvm.ptr): omp.terminator }) : (!llvm.ptr) -> () diff --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir index 49a88e0443e60..f9cfd400387a5 100644 --- a/mlir/test/Dialect/OpenMP/ops.mlir +++ b/mlir/test/Dialect/OpenMP/ops.mlir @@ -73,7 +73,7 @@ func.func @omp_parallel(%data_var : memref<i32>, %if_cond : i1, %num_threads : i // CHECK: omp.parallel allocate(%{{.*}} : memref<i32> -> %{{.*}} : memref<i32>) num_threads(%{{.*}} : i32) "omp.parallel"(%data_var, %data_var, %num_threads) ({ omp.terminator - }) {operandSegmentSizes = array<i32: 1,1,0,1,0,0>} : (memref<i32>, memref<i32>, i32) -> () + }) {operandSegmentSizes = array<i32: 1,1,0,0,1,0,0>} : (memref<i32>, memref<i32>, i32) -> () // CHECK: omp.barrier omp.barrier @@ -82,22 +82,22 @@ func.func @omp_parallel(%data_var : memref<i32>, %if_cond : i1, %num_threads : i // CHECK: omp.parallel allocate(%{{.*}} : memref<i32> -> %{{.*}} : memref<i32>) if(%{{.*}}) "omp.parallel"(%data_var, %data_var, %if_cond) ({ omp.terminator - }) {operandSegmentSizes = array<i32: 1,1,1,0,0,0>} : (memref<i32>, memref<i32>, i1) -> () + }) {operandSegmentSizes = array<i32: 1,1,1,0,0,0,0>} : (memref<i32>, memref<i32>, i1) -> () // test without allocate // CHECK: omp.parallel if(%{{.*}}) 
num_threads(%{{.*}} : i32) "omp.parallel"(%if_cond, %num_threads) ({ omp.terminator - }) {operandSegmentSizes = array<i32: 0,0,1,1,0,0>} : (i1, i32) -> () + }) {operandSegmentSizes = array<i32: 0,0,1,0,1,0,0>} : (i1, i32) -> () omp.terminator - }) {operandSegmentSizes = array<i32: 1,1,1,1,0,0>, proc_bind_kind = #omp<procbindkind spread>} : (memref<i32>, memref<i32>, i1, i32) -> () + }) {operandSegmentSizes = array<i32: 1,1,1,0,1,0,0>, proc_bind_kind = #omp<procbindkind spread>} : (memref<i32>, memref<i32>, i1, i32) -> () // test with multiple parameters for single variadic argument // CHECK: omp.parallel allocate(%{{.*}} : memref<i32> -> %{{.*}} : memref<i32>) "omp.parallel" (%data_var, %data_var) ({ omp.terminator - }) {operandSegmentSizes = array<i32: 1,1,0,0,0,0>} : (memref<i32>, memref<i32>) -> () + }) {operandSegmentSizes = array<i32: 1,1,0,0,0,0,0>} : (memref<i32>, memref<i32>) -> () // CHECK: omp.parallel omp.parallel { @@ -160,6 +160,11 @@ func.func @omp_parallel_pretty(%data_var : memref<i32>, %if_cond : i1, %num_thre omp.terminator } + // CHECK: omp.parallel num_threads(dims(2): %{{.*}}, %{{.*}} : i64) + omp.parallel num_threads(dims(2): %n_i64, %n_i64 : i64) { + omp.terminator + } + // CHECK: omp.parallel allocate(%{{.*}} : memref<i32> -> %{{.*}} : memref<i32>) omp.parallel allocate(%data_var : memref<i32> -> %data_var : memref<i32>) { omp.terminator >From eb17b261fd452be4edcd90ed460675d317ef9f79 Mon Sep 17 00:00:00 2001 From: skc7 <[email protected]> Date: Thu, 11 Dec 2025 12:11:49 +0530 Subject: [PATCH 03/11] Mark mlir->llvmir translation for num_threads with dims as NYI --- .../Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp index 0b7bf64cefe4c..9c176b56a4d5d 100644 --- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp 
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp @@ -3268,6 +3268,10 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder, if (auto ifVar = opInst.getIfExpr()) ifCond = moduleTranslation.lookupValue(ifVar); llvm::Value *numThreads = nullptr; + // num_threads dims and values are not yet supported + assert(!opInst.getNumThreadsDims().has_value() && + opInst.getNumThreadsValues().empty() && + "Lowering of num_threads with dims modifier is NYI."); if (auto numThreadsVar = opInst.getNumThreads()) numThreads = moduleTranslation.lookupValue(numThreadsVar); auto pbKind = llvm::omp::OMP_PROC_BIND_default; @@ -6050,6 +6054,10 @@ extractHostEvalClauses(omp::TargetOp targetOp, Value &numThreads, llvm_unreachable("unsupported host_eval use"); }) .Case([&](omp::ParallelOp parallelOp) { + // num_threads dims and values are not yet supported + assert(!parallelOp.getNumThreadsDims().has_value() && + parallelOp.getNumThreadsValues().empty() && + "Lowering of num_threads with dims modifier is NYI."); if (parallelOp.getNumThreads() == blockArg) numThreads = hostEvalVar; else @@ -6167,8 +6175,13 @@ initTargetDefaultAttrs(omp::TargetOp targetOp, Operation *capturedOp, threadLimit = teamsOp.getThreadLimit(); } - if (auto parallelOp = castOrGetParentOfType<omp::ParallelOp>(capturedOp)) + if (auto parallelOp = castOrGetParentOfType<omp::ParallelOp>(capturedOp)) { + // num_threads dims and values are not yet supported + assert(!parallelOp.getNumThreadsDims().has_value() && + parallelOp.getNumThreadsValues().empty() && + "Lowering of num_threads with dims modifier is NYI."); numThreads = parallelOp.getNumThreads(); + } } // Handle clauses impacting the number of teams. 
>From 207eca2b904e1844d82bcffe1baccf690a7f4a1f Mon Sep 17 00:00:00 2001 From: skc7 <[email protected]> Date: Thu, 11 Dec 2025 17:37:52 +0530 Subject: [PATCH 04/11] few more fixes --- .../mlir/Dialect/OpenMP/OpenMPClauses.td | 33 ++++++-------- .../Conversion/SCFToOpenMP/SCFToOpenMP.cpp | 4 +- mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp | 44 +++++++++---------- .../OpenMP/OpenMPToLLVMIRTranslation.cpp | 9 ++-- mlir/test/Dialect/OpenMP/invalid.mlir | 10 ++--- 5 files changed, 45 insertions(+), 55 deletions(-) diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td index 8826c15a15191..8d8db94630f84 100644 --- a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td +++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td @@ -1069,14 +1069,14 @@ class OpenMP_NumThreadsClauseSkip< > : OpenMP_Clause<traits, arguments, assemblyFormat, description, extraClassDeclaration> { let arguments = (ins - ConfinedAttr<OptionalAttr<I64Attr>, [IntPositive]>:$num_threads_dims, - Variadic<AnyInteger>:$num_threads_values, + ConfinedAttr<OptionalAttr<I64Attr>, [IntPositive]>:$num_threads_num_dims, + Variadic<AnyInteger>:$num_threads_dims_values, Optional<IntLikeType>:$num_threads ); let optAssemblyFormat = [{ `num_threads` `(` custom<NumThreadsClause>( - $num_threads_dims, $num_threads_values, type($num_threads_values), + $num_threads_num_dims, $num_threads_dims_values, type($num_threads_dims_values), $num_threads, type($num_threads) ) `)` }]; @@ -1086,7 +1086,7 @@ class OpenMP_NumThreadsClauseSkip< space formed by the construct on which it appears. 
With dims modifier: - - Uses `num_threads_dims` (dimension count) and `num_threads_values` (upper bounds list) + - Uses `num_threads_num_dims` (dimension count) and `num_threads_dims_values` (upper bounds list) - Specifies upper bounds for each dimension (all must have same type) - Format: `num_threads(dims(N): upper_bound_0, ..., upper_bound_N-1 : type)` - Example: `num_threads(dims(3): %ub0, %ub1, %ub2 : i32)` @@ -1100,28 +1100,23 @@ class OpenMP_NumThreadsClauseSkip< let extraClassDeclaration = [{ /// Returns true if the dims modifier is explicitly present - bool hasDimsModifier() { - return getNumThreadsDims().has_value(); + bool hasNumThreadsDimsModifier() { + return getNumThreadsNumDims().has_value() && getNumThreadsNumDims().value(); } /// Returns the number of dimensions specified by dims modifier - unsigned getNumDimensions() { - if (!hasDimsModifier()) + unsigned getNumThreadsDimsCount() { + if (!hasNumThreadsDimsModifier()) return 1; - return static_cast<unsigned>(*getNumThreadsDims()); - } - - /// Returns all dimension values as an operand range - ::mlir::OperandRange getDimensionValues() { - return getNumThreadsValues(); + return static_cast<unsigned>(*getNumThreadsNumDims()); } /// Returns the value for a specific dimension index - /// Index must be less than getNumDimensions() - ::mlir::Value getDimensionValue(unsigned index) { - assert(index < getDimensionValues().size() && - "Dimension index out of bounds"); - return getDimensionValues()[index]; + /// Index must be less than getNumThreadsDimsCount() + ::mlir::Value getNumThreadsDimsValue(unsigned index) { + assert(index < getNumThreadsDimsCount() && + "Num threads dims index out of bounds"); + return getNumThreadsDimsValues()[index]; } }]; } diff --git a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp index c749106b925f7..f9c8cab9b3d7b 100644 --- a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp +++ b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp 
@@ -497,8 +497,8 @@ struct ParallelOpLowering : public OpRewritePattern<scf::ParallelOp> { /* allocate_vars = */ llvm::SmallVector<Value>{}, /* allocator_vars = */ llvm::SmallVector<Value>{}, /* if_expr = */ Value{}, - /* num_threads_dims = */ nullptr, - /* num_threads_values = */ llvm::SmallVector<Value>{}, + /* num_threads_num_dims = */ nullptr, + /* num_threads_dims_values = */ llvm::SmallVector<Value>{}, /* num_threads = */ numThreadsVar, /* private_vars = */ ValueRange(), /* private_syms = */ nullptr, diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp index d75b9e17f1e98..c2aca0887e38d 100644 --- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp +++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp @@ -2519,7 +2519,7 @@ void ParallelOp::build(OpBuilder &builder, OperationState &state, MLIRContext *ctx = builder.getContext(); ParallelOp::build( builder, state, clauses.allocateVars, clauses.allocatorVars, - clauses.ifExpr, clauses.numThreadsDims, clauses.numThreadsValues, + clauses.ifExpr, clauses.numThreadsNumDims, clauses.numThreadsDimsValues, clauses.numThreads, clauses.privateVars, makeArrayAttr(ctx, clauses.privateSyms), clauses.privateNeedsBarrier, clauses.procBindKind, clauses.reductionMod, clauses.reductionVars, @@ -2570,30 +2570,28 @@ static LogicalResult verifyPrivateVarList(OpType &op) { return success(); } -LogicalResult ParallelOp::verify() { - // verify num_threads clause restrictions - auto numThreadsDims = getNumThreadsDims(); - auto numThreadsValues = getNumThreadsValues(); - auto numThreads = getNumThreads(); - - // num_threads with dims modifier - if (numThreadsDims.has_value() && numThreadsValues.empty()) { - return emitError( - "num_threads dims modifier requires values to be specified"); - } - - if (numThreadsDims.has_value() && - numThreadsValues.size() != static_cast<size_t>(*numThreadsDims)) { - return emitError("num_threads dims(") - << *numThreadsDims << ") specified but " << 
numThreadsValues.size() - << " values provided"; +// Helper: Verify num_threads clause +LogicalResult +verifyNumThreadsClause(Operation *op, + std::optional<IntegerAttr> numThreadsNumDims, + OperandRange numThreadsDimsValues, Value numThreads) { + bool hasDimsModifier = + numThreadsNumDims.has_value() && numThreadsNumDims.value(); + if (hasDimsModifier && numThreads) { + return op->emitError("num_threads with dims modifier cannot be used " + "together with number of threads"); } + if (failed(verifyDimsModifier(op, numThreadsNumDims, numThreadsDimsValues))) + return failure(); + return success(); +} - // num_threads dims and number of threads cannot be used together - if (numThreadsDims.has_value() && numThreads) { - return emitError( - "num_threads dims and number of threads cannot be used together"); - } +LogicalResult ParallelOp::verify() { + // verify num_threads clause restrictions + if (failed(verifyNumThreadsClause( + getOperation(), this->getNumThreadsNumDimsAttr(), + this->getNumThreadsDimsValues(), this->getNumThreads()))) + return failure(); // verify allocate clause restrictions if (getAllocateVars().size() != getAllocatorVars().size()) diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp index 9c176b56a4d5d..2d71910e27a52 100644 --- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp @@ -3269,8 +3269,7 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder, ifCond = moduleTranslation.lookupValue(ifVar); llvm::Value *numThreads = nullptr; // num_threads dims and values are not yet supported - assert(!opInst.getNumThreadsDims().has_value() && - opInst.getNumThreadsValues().empty() && + assert(!opInst.hasNumThreadsDimsModifier() && "Lowering of num_threads with dims modifier is NYI."); if (auto numThreadsVar = opInst.getNumThreads()) numThreads = 
moduleTranslation.lookupValue(numThreadsVar); @@ -6055,8 +6054,7 @@ extractHostEvalClauses(omp::TargetOp targetOp, Value &numThreads, }) .Case([&](omp::ParallelOp parallelOp) { // num_threads dims and values are not yet supported - assert(!parallelOp.getNumThreadsDims().has_value() && - parallelOp.getNumThreadsValues().empty() && + assert(!parallelOp.hasNumThreadsDimsModifier() && "Lowering of num_threads with dims modifier is NYI."); if (parallelOp.getNumThreads() == blockArg) numThreads = hostEvalVar; @@ -6177,8 +6175,7 @@ initTargetDefaultAttrs(omp::TargetOp targetOp, Operation *capturedOp, if (auto parallelOp = castOrGetParentOfType<omp::ParallelOp>(capturedOp)) { // num_threads dims and values are not yet supported - assert(!parallelOp.getNumThreadsDims().has_value() && - parallelOp.getNumThreadsValues().empty() && + assert(!parallelOp.hasNumThreadsDimsModifier() && "Lowering of num_threads with dims modifier is NYI."); numThreads = parallelOp.getNumThreads(); } diff --git a/mlir/test/Dialect/OpenMP/invalid.mlir b/mlir/test/Dialect/OpenMP/invalid.mlir index e55fe3d0a1aec..17985651a1286 100644 --- a/mlir/test/Dialect/OpenMP/invalid.mlir +++ b/mlir/test/Dialect/OpenMP/invalid.mlir @@ -31,17 +31,17 @@ func.func @num_threads_once(%n : si32) { // ----- func.func @num_threads_dims_no_values() { - // expected-error@+1 {{num_threads dims modifier requires values to be specified}} + // expected-error@+1 {{dims modifier requires values to be specified}} "omp.parallel"() ({ omp.terminator - }) {operandSegmentSizes = array<i32: 0,0,0,0,0,0,0>, num_threads_dims = 2 : i64} : () -> () + }) {operandSegmentSizes = array<i32: 0,0,0,0,0,0,0>, num_threads_num_dims = 2 : i64} : () -> () return } // ----- func.func @num_threads_dims_mismatch(%n : i64) { - // expected-error@+1 {{num_threads dims(2) specified but 1 values provided}} + // expected-error@+1 {{dims(2) specified but 1 values provided}} omp.parallel num_threads(dims(2): %n : i64) { omp.terminator } @@ -52,10 +52,10 @@ 
func.func @num_threads_dims_mismatch(%n : i64) { // ----- func.func @num_threads_dims_and_scalar(%n : i64, %m: i64) { - // expected-error@+1 {{num_threads dims and number of threads cannot be used together}} + // expected-error@+1 {{num_threads with dims modifier cannot be used together with number of threads}} "omp.parallel"(%n, %n, %m) ({ omp.terminator - }) {operandSegmentSizes = array<i32: 0,0,0,2,1,0,0>, num_threads_dims = 2 : i64} : (i64, i64, i64) -> () + }) {operandSegmentSizes = array<i32: 0,0,0,2,1,0,0>, num_threads_num_dims = 2 : i64} : (i64, i64, i64) -> () return } >From cad7b45c8ba7d37198eb2a79d98288d5cb0ed45b Mon Sep 17 00:00:00 2001 From: skc7 <[email protected]> Date: Fri, 19 Dec 2025 12:27:38 +0530 Subject: [PATCH 05/11] Use num_threads_dims_values only --- flang/lib/Lower/OpenMP/ClauseProcessor.cpp | 4 +- flang/lib/Lower/OpenMP/OpenMP.cpp | 15 ++--- .../mlir/Dialect/OpenMP/OpenMPClauses.td | 15 +++-- .../Conversion/SCFToOpenMP/SCFToOpenMP.cpp | 5 +- mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp | 62 ++++++++----------- .../OpenMP/OpenMPToLLVMIRTranslation.cpp | 16 ++--- mlir/test/Dialect/OpenMP/invalid.mlir | 12 ++-- mlir/test/Dialect/OpenMP/ops.mlir | 10 +-- 8 files changed, 66 insertions(+), 73 deletions(-) diff --git a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp index 2f531efaf09aa..8a96872294124 100644 --- a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp +++ b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp @@ -516,8 +516,8 @@ bool ClauseProcessor::processNumThreads( mlir::omp::NumThreadsClauseOps &result) const { if (auto *clause = findUniqueClause<omp::clause::NumThreads>()) { // OMPIRBuilder expects `NUM_THREADS` clause as a `Value`. 
- result.numThreads = - fir::getBase(converter.genExprValue(clause->v, stmtCtx)); + result.numThreadsDimsValues.push_back( + fir::getBase(converter.genExprValue(clause->v, stmtCtx))); return true; } return false; diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp index 0764693f748a5..7b12750eebb4f 100644 --- a/flang/lib/Lower/OpenMP/OpenMP.cpp +++ b/flang/lib/Lower/OpenMP/OpenMP.cpp @@ -99,8 +99,8 @@ class HostEvalInfo { if (ops.numTeamsUpper) vars.push_back(ops.numTeamsUpper); - if (ops.numThreads) - vars.push_back(ops.numThreads); + for (auto numThreads : ops.numThreadsDimsValues) + vars.push_back(numThreads); if (ops.threadLimit) vars.push_back(ops.threadLimit); @@ -115,7 +115,8 @@ class HostEvalInfo { assert(args.size() == ops.loopLowerBounds.size() + ops.loopUpperBounds.size() + ops.loopSteps.size() + (ops.numTeamsLower ? 1 : 0) + - (ops.numTeamsUpper ? 1 : 0) + (ops.numThreads ? 1 : 0) + + (ops.numTeamsUpper ? 1 : 0) + + ops.numThreadsDimsValues.size() + (ops.threadLimit ? 1 : 0) && "invalid block argument list"); int argIndex = 0; @@ -134,8 +135,8 @@ class HostEvalInfo { if (ops.numTeamsUpper) ops.numTeamsUpper = args[argIndex++]; - if (ops.numThreads) - ops.numThreads = args[argIndex++]; + for (size_t i = 0; i < ops.numThreadsDimsValues.size(); ++i) + ops.numThreadsDimsValues[i] = args[argIndex++]; if (ops.threadLimit) ops.threadLimit = args[argIndex++]; @@ -169,13 +170,13 @@ class HostEvalInfo { /// \returns whether an update was performed. If not, these clauses were not /// evaluated in the host device. 
bool apply(mlir::omp::ParallelOperands &clauseOps) { - if (!ops.numThreads || parallelApplied) { + if (ops.numThreadsDimsValues.empty() || parallelApplied) { parallelApplied = true; return false; } parallelApplied = true; - clauseOps.numThreads = ops.numThreads; + clauseOps.numThreadsDimsValues = ops.numThreadsDimsValues; return true; } diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td index 8d8db94630f84..10aaab4b6f21c 100644 --- a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td +++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td @@ -1070,14 +1070,12 @@ class OpenMP_NumThreadsClauseSkip< extraClassDeclaration> { let arguments = (ins ConfinedAttr<OptionalAttr<I64Attr>, [IntPositive]>:$num_threads_num_dims, - Variadic<AnyInteger>:$num_threads_dims_values, - Optional<IntLikeType>:$num_threads + Variadic<IntLikeType>:$num_threads_dims_values ); let optAssemblyFormat = [{ `num_threads` `(` custom<NumThreadsClause>( - $num_threads_num_dims, $num_threads_dims_values, type($num_threads_dims_values), - $num_threads, type($num_threads) + $num_threads_num_dims, $num_threads_dims_values, type($num_threads_dims_values) ) `)` }]; @@ -1092,10 +1090,9 @@ class OpenMP_NumThreadsClauseSkip< - Example: `num_threads(dims(3): %ub0, %ub1, %ub2 : i32)` Without dims modifier: - - Uses `num_threads` - - If lower bound not specified, it defaults to upper bound value - - Format: `num_threads(bounds : type)` - - Example: `num_threads(%ub : i32)` + - The number of threads is specified by single value in `num_threads_dims_values` + - Format: `num_threads(value : type)` + - Example: `num_threads(%n : i32)` }]; let extraClassDeclaration = [{ @@ -1116,6 +1113,8 @@ class OpenMP_NumThreadsClauseSkip< ::mlir::Value getNumThreadsDimsValue(unsigned index) { assert(index < getNumThreadsDimsCount() && "Num threads dims index out of bounds"); + if(getNumThreadsDimsValues().empty()) + return nullptr; return 
getNumThreadsDimsValues()[index]; } }]; diff --git a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp index f9c8cab9b3d7b..3a1f311dd63f0 100644 --- a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp +++ b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp @@ -487,9 +487,11 @@ struct ParallelOpLowering : public OpRewritePattern<scf::ParallelOp> { rewriter.eraseOp(reduce); Value numThreadsVar; + SmallVector<Value> numThreadsValues; if (numThreads > 0) { numThreadsVar = LLVM::ConstantOp::create( rewriter, loc, rewriter.getI32IntegerAttr(numThreads)); + numThreadsValues.push_back(numThreadsVar); } // Create the parallel wrapper. auto ompParallel = omp::ParallelOp::create( @@ -498,8 +500,7 @@ struct ParallelOpLowering : public OpRewritePattern<scf::ParallelOp> { /* allocator_vars = */ llvm::SmallVector<Value>{}, /* if_expr = */ Value{}, /* num_threads_num_dims = */ nullptr, - /* num_threads_dims_values = */ llvm::SmallVector<Value>{}, - /* num_threads = */ numThreadsVar, + /* num_threads_dims_values = */ numThreadsValues, /* private_vars = */ ValueRange(), /* private_syms = */ nullptr, /* private_needs_barrier = */ nullptr, diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp index c2aca0887e38d..9366f04e51629 100644 --- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp +++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp @@ -2252,7 +2252,8 @@ LogicalResult TargetOp::verifyRegions() { if (auto parallelOp = dyn_cast<ParallelOp>(user)) { if (bitEnumContainsAny(execFlags, TargetRegionFlags::spmd) && parallelOp->isAncestor(capturedOp) && - hostEvalArg == parallelOp.getNumThreads()) + llvm::is_contained(parallelOp.getNumThreadsDimsValues(), + hostEvalArg)) continue; return emitOpError() @@ -2506,7 +2507,7 @@ void ParallelOp::build(OpBuilder &builder, OperationState &state, /*allocator_vars=*/ValueRange(), /*if_expr=*/nullptr, /*num_threads_dims=*/nullptr, 
/*num_threads_values=*/ValueRange(), - /*num_threads=*/nullptr, /*private_vars=*/ValueRange(), + /*private_vars=*/ValueRange(), /*private_syms=*/nullptr, /*private_needs_barrier=*/nullptr, /*proc_bind_kind=*/nullptr, /*reduction_mod =*/nullptr, /*reduction_vars=*/ValueRange(), @@ -2517,14 +2518,14 @@ void ParallelOp::build(OpBuilder &builder, OperationState &state, void ParallelOp::build(OpBuilder &builder, OperationState &state, const ParallelOperands &clauses) { MLIRContext *ctx = builder.getContext(); - ParallelOp::build( - builder, state, clauses.allocateVars, clauses.allocatorVars, - clauses.ifExpr, clauses.numThreadsNumDims, clauses.numThreadsDimsValues, - clauses.numThreads, clauses.privateVars, - makeArrayAttr(ctx, clauses.privateSyms), clauses.privateNeedsBarrier, - clauses.procBindKind, clauses.reductionMod, clauses.reductionVars, - makeDenseBoolArrayAttr(ctx, clauses.reductionByref), - makeArrayAttr(ctx, clauses.reductionSyms)); + ParallelOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars, + clauses.ifExpr, clauses.numThreadsNumDims, + clauses.numThreadsDimsValues, clauses.privateVars, + makeArrayAttr(ctx, clauses.privateSyms), + clauses.privateNeedsBarrier, clauses.procBindKind, + clauses.reductionMod, clauses.reductionVars, + makeDenseBoolArrayAttr(ctx, clauses.reductionByref), + makeArrayAttr(ctx, clauses.reductionSyms)); } template <typename OpType> @@ -2574,13 +2575,7 @@ static LogicalResult verifyPrivateVarList(OpType &op) { LogicalResult verifyNumThreadsClause(Operation *op, std::optional<IntegerAttr> numThreadsNumDims, - OperandRange numThreadsDimsValues, Value numThreads) { - bool hasDimsModifier = - numThreadsNumDims.has_value() && numThreadsNumDims.value(); - if (hasDimsModifier && numThreads) { - return op->emitError("num_threads with dims modifier cannot be used " - "together with number of threads"); - } + OperandRange numThreadsDimsValues) { if (failed(verifyDimsModifier(op, numThreadsNumDims, numThreadsDimsValues))) 
return failure(); return success(); @@ -2588,9 +2583,9 @@ verifyNumThreadsClause(Operation *op, LogicalResult ParallelOp::verify() { // verify num_threads clause restrictions - if (failed(verifyNumThreadsClause( - getOperation(), this->getNumThreadsNumDimsAttr(), - this->getNumThreadsDimsValues(), this->getNumThreads()))) + if (failed(verifyNumThreadsClause(getOperation(), + this->getNumThreadsNumDimsAttr(), + this->getNumThreadsDimsValues()))) return failure(); // verify allocate clause restrictions @@ -4657,33 +4652,28 @@ static void printNumTeamsClause(OpAsmPrinter &p, Operation *op, static ParseResult parseNumThreadsClause(OpAsmParser &parser, IntegerAttr &dimsAttr, SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values, - SmallVectorImpl<Type> &types, - std::optional<OpAsmParser::UnresolvedOperand> &bounds, - Type &boundsType) { + SmallVectorImpl<Type> &types) { if (succeeded(parseDimsModifierWithValues(parser, dimsAttr, values, types))) { return success(); } - OpAsmParser::UnresolvedOperand boundsOperand; - if (parser.parseOperand(boundsOperand) || parser.parseColon() || - parser.parseType(boundsType)) { + // Without dims modifier: value : type + OpAsmParser::UnresolvedOperand singleValue; + Type singleType; + if (parser.parseOperand(singleValue) || parser.parseColon() || + parser.parseType(singleType)) { return failure(); } - bounds = boundsOperand; + values.push_back(singleValue); + types.push_back(singleType); return success(); } static void printNumThreadsClause(OpAsmPrinter &p, Operation *op, IntegerAttr dimsAttr, OperandRange values, - TypeRange types, Value bounds, - Type boundsType) { - if (!values.empty()) { - printDimsModifierWithValues(p, dimsAttr, values, types); - } - if (bounds) { - p.printOperand(bounds); - p << " : " << boundsType; - } + TypeRange types) { + // Multidimensional: dims(N): values : type + printDimsModifierWithValues(p, dimsAttr, values, types); } #define GET_ATTRDEF_CLASSES diff --git 
a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp index 2d71910e27a52..d4aaa832636d1 100644 --- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp @@ -3270,8 +3270,8 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder, llvm::Value *numThreads = nullptr; // num_threads dims and values are not yet supported assert(!opInst.hasNumThreadsDimsModifier() && - "Lowering of num_threads with dims modifier is NYI."); - if (auto numThreadsVar = opInst.getNumThreads()) + "Lowering of num_threads with dims modifier is not yet implemented."); + if (auto numThreadsVar = opInst.getNumThreadsDimsValue(0)) numThreads = moduleTranslation.lookupValue(numThreadsVar); auto pbKind = llvm::omp::OMP_PROC_BIND_default; if (auto bind = opInst.getProcBindKind()) @@ -6055,8 +6055,9 @@ extractHostEvalClauses(omp::TargetOp targetOp, Value &numThreads, .Case([&](omp::ParallelOp parallelOp) { // num_threads dims and values are not yet supported assert(!parallelOp.hasNumThreadsDimsModifier() && - "Lowering of num_threads with dims modifier is NYI."); - if (parallelOp.getNumThreads() == blockArg) + "Lowering of num_threads with dims modifier is not yet " + "implemented."); + if (parallelOp.getNumThreadsDimsValue(0) == blockArg) numThreads = hostEvalVar; else llvm_unreachable("unsupported host_eval use"); @@ -6175,9 +6176,10 @@ initTargetDefaultAttrs(omp::TargetOp targetOp, Operation *capturedOp, if (auto parallelOp = castOrGetParentOfType<omp::ParallelOp>(capturedOp)) { // num_threads dims and values are not yet supported - assert(!parallelOp.hasNumThreadsDimsModifier() && - "Lowering of num_threads with dims modifier is NYI."); - numThreads = parallelOp.getNumThreads(); + assert( + !parallelOp.hasNumThreadsDimsModifier() && + "Lowering of num_threads with dims modifier is not yet implemented."); + 
numThreads = parallelOp.getNumThreadsDimsValue(0); } } diff --git a/mlir/test/Dialect/OpenMP/invalid.mlir b/mlir/test/Dialect/OpenMP/invalid.mlir index 17985651a1286..b2e20b4c5ee5a 100644 --- a/mlir/test/Dialect/OpenMP/invalid.mlir +++ b/mlir/test/Dialect/OpenMP/invalid.mlir @@ -34,7 +34,7 @@ func.func @num_threads_dims_no_values() { // expected-error@+1 {{dims modifier requires values to be specified}} "omp.parallel"() ({ omp.terminator - }) {operandSegmentSizes = array<i32: 0,0,0,0,0,0,0>, num_threads_num_dims = 2 : i64} : () -> () + }) {operandSegmentSizes = array<i32: 0,0,0,0,0,0>, num_threads_num_dims = 2 : i64} : () -> () return } @@ -51,11 +51,11 @@ func.func @num_threads_dims_mismatch(%n : i64) { // ----- -func.func @num_threads_dims_and_scalar(%n : i64, %m: i64) { - // expected-error@+1 {{num_threads with dims modifier cannot be used together with number of threads}} - "omp.parallel"(%n, %n, %m) ({ +func.func @num_threads_multiple_values_without_dims(%n : i64, %m: i64) { + // expected-error@+1 {{dims values can only be specified with dims modifier}} + "omp.parallel"(%n, %m) ({ omp.terminator - }) {operandSegmentSizes = array<i32: 0,0,0,2,1,0,0>, num_threads_num_dims = 2 : i64} : (i64, i64, i64) -> () + }) {operandSegmentSizes = array<i32: 0,0,0,2,0,0>} : (i64, i64) -> () return } @@ -2722,7 +2722,7 @@ func.func @undefined_privatizer(%arg0: index) { // ----- func.func @undefined_privatizer(%arg0: !llvm.ptr) { // expected-error @below {{inconsistent number of private variables and privatizer op symbols, private vars: 1 vs. 
privatizer op symbols: 2}} - "omp.parallel"(%arg0) <{operandSegmentSizes = array<i32: 0, 0, 0, 0, 0, 1, 0>, private_syms = [@x.privatizer, @y.privatizer]}> ({ + "omp.parallel"(%arg0) <{operandSegmentSizes = array<i32: 0, 0, 0, 0, 1, 0>, private_syms = [@x.privatizer, @y.privatizer]}> ({ ^bb0(%arg2: !llvm.ptr): omp.terminator }) : (!llvm.ptr) -> () diff --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir index f9cfd400387a5..e2a3f8fbe2d5f 100644 --- a/mlir/test/Dialect/OpenMP/ops.mlir +++ b/mlir/test/Dialect/OpenMP/ops.mlir @@ -73,7 +73,7 @@ func.func @omp_parallel(%data_var : memref<i32>, %if_cond : i1, %num_threads : i // CHECK: omp.parallel allocate(%{{.*}} : memref<i32> -> %{{.*}} : memref<i32>) num_threads(%{{.*}} : i32) "omp.parallel"(%data_var, %data_var, %num_threads) ({ omp.terminator - }) {operandSegmentSizes = array<i32: 1,1,0,0,1,0,0>} : (memref<i32>, memref<i32>, i32) -> () + }) {operandSegmentSizes = array<i32: 1,1,0,1,0,0>} : (memref<i32>, memref<i32>, i32) -> () // CHECK: omp.barrier omp.barrier @@ -82,22 +82,22 @@ func.func @omp_parallel(%data_var : memref<i32>, %if_cond : i1, %num_threads : i // CHECK: omp.parallel allocate(%{{.*}} : memref<i32> -> %{{.*}} : memref<i32>) if(%{{.*}}) "omp.parallel"(%data_var, %data_var, %if_cond) ({ omp.terminator - }) {operandSegmentSizes = array<i32: 1,1,1,0,0,0,0>} : (memref<i32>, memref<i32>, i1) -> () + }) {operandSegmentSizes = array<i32: 1,1,1,0,0,0>} : (memref<i32>, memref<i32>, i1) -> () // test without allocate // CHECK: omp.parallel if(%{{.*}}) num_threads(%{{.*}} : i32) "omp.parallel"(%if_cond, %num_threads) ({ omp.terminator - }) {operandSegmentSizes = array<i32: 0,0,1,0,1,0,0>} : (i1, i32) -> () + }) {operandSegmentSizes = array<i32: 0,0,1,1,0,0>} : (i1, i32) -> () omp.terminator - }) {operandSegmentSizes = array<i32: 1,1,1,0,1,0,0>, proc_bind_kind = #omp<procbindkind spread>} : (memref<i32>, memref<i32>, i1, i32) -> () + }) {operandSegmentSizes = array<i32: 1,1,1,1,0,0>, 
proc_bind_kind = #omp<procbindkind spread>} : (memref<i32>, memref<i32>, i1, i32) -> () // test with multiple parameters for single variadic argument // CHECK: omp.parallel allocate(%{{.*}} : memref<i32> -> %{{.*}} : memref<i32>) "omp.parallel" (%data_var, %data_var) ({ omp.terminator - }) {operandSegmentSizes = array<i32: 1,1,0,0,0,0,0>} : (memref<i32>, memref<i32>) -> () + }) {operandSegmentSizes = array<i32: 1,1,0,0,0,0>} : (memref<i32>, memref<i32>) -> () // CHECK: omp.parallel omp.parallel { >From e2b12cf37fea11b6c129b40db66f8524a7cb467c Mon Sep 17 00:00:00 2001 From: skc7 <[email protected]> Date: Wed, 14 Jan 2026 12:07:56 +0530 Subject: [PATCH 06/11] fix adding numThreadsNumDims to ParallelOperands apply method --- flang/lib/Lower/OpenMP/OpenMP.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp index 7b12750eebb4f..8d03a04d87a21 100644 --- a/flang/lib/Lower/OpenMP/OpenMP.cpp +++ b/flang/lib/Lower/OpenMP/OpenMP.cpp @@ -177,6 +177,7 @@ class HostEvalInfo { parallelApplied = true; clauseOps.numThreadsDimsValues = ops.numThreadsDimsValues; + clauseOps.numThreadsNumDims = ops.numThreadsNumDims; return true; } >From 76d39229e04f303dc6b847c434e0904bece00f1a Mon Sep 17 00:00:00 2001 From: skc7 <[email protected]> Date: Fri, 16 Jan 2026 12:32:56 +0530 Subject: [PATCH 07/11] Remove dims(N) syntax and use list of vals for num_threads --- flang/lib/Lower/OpenMP/ClauseProcessor.cpp | 2 +- flang/lib/Lower/OpenMP/OpenMP.cpp | 14 +++-- .../mlir/Dialect/OpenMP/OpenMPClauses.td | 53 +++++++++---------- .../Conversion/SCFToOpenMP/SCFToOpenMP.cpp | 3 +- mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp | 51 +++++------------- .../OpenMP/OpenMPToLLVMIRTranslation.cpp | 26 ++++----- mlir/test/Dialect/OpenMP/invalid.mlir | 31 ----------- mlir/test/Dialect/OpenMP/ops.mlir | 11 +++- mlir/test/Target/LLVMIR/openmp-todo.mlir | 11 ++++ 9 files changed, 76 insertions(+), 126 deletions(-) diff --git 
a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp index 8a96872294124..e33bdcc5c4dbd 100644 --- a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp +++ b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp @@ -516,7 +516,7 @@ bool ClauseProcessor::processNumThreads( mlir::omp::NumThreadsClauseOps &result) const { if (auto *clause = findUniqueClause<omp::clause::NumThreads>()) { // OMPIRBuilder expects `NUM_THREADS` clause as a `Value`. - result.numThreadsDimsValues.push_back( + result.numThreadsVals.push_back( fir::getBase(converter.genExprValue(clause->v, stmtCtx))); return true; } diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp index 8d03a04d87a21..9947dcc8d5ebc 100644 --- a/flang/lib/Lower/OpenMP/OpenMP.cpp +++ b/flang/lib/Lower/OpenMP/OpenMP.cpp @@ -99,7 +99,7 @@ class HostEvalInfo { if (ops.numTeamsUpper) vars.push_back(ops.numTeamsUpper); - for (auto numThreads : ops.numThreadsDimsValues) + for (auto numThreads : ops.numThreadsVals) vars.push_back(numThreads); if (ops.threadLimit) @@ -115,8 +115,7 @@ class HostEvalInfo { assert(args.size() == ops.loopLowerBounds.size() + ops.loopUpperBounds.size() + ops.loopSteps.size() + (ops.numTeamsLower ? 1 : 0) + - (ops.numTeamsUpper ? 1 : 0) + - ops.numThreadsDimsValues.size() + + (ops.numTeamsUpper ? 1 : 0) + ops.numThreadsVals.size() + (ops.threadLimit ? 1 : 0) && "invalid block argument list"); int argIndex = 0; @@ -135,8 +134,8 @@ class HostEvalInfo { if (ops.numTeamsUpper) ops.numTeamsUpper = args[argIndex++]; - for (size_t i = 0; i < ops.numThreadsDimsValues.size(); ++i) - ops.numThreadsDimsValues[i] = args[argIndex++]; + for (size_t i = 0; i < ops.numThreadsVals.size(); ++i) + ops.numThreadsVals[i] = args[argIndex++]; if (ops.threadLimit) ops.threadLimit = args[argIndex++]; @@ -170,14 +169,13 @@ class HostEvalInfo { /// \returns whether an update was performed. If not, these clauses were not /// evaluated in the host device. 
bool apply(mlir::omp::ParallelOperands &clauseOps) { - if (ops.numThreadsDimsValues.empty() || parallelApplied) { + if (ops.numThreadsVals.empty() || parallelApplied) { parallelApplied = true; return false; } parallelApplied = true; - clauseOps.numThreadsDimsValues = ops.numThreadsDimsValues; - clauseOps.numThreadsNumDims = ops.numThreadsNumDims; + clauseOps.numThreadsVals = ops.numThreadsVals; return true; } diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td index 10aaab4b6f21c..cda6906d46965 100644 --- a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td +++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td @@ -1069,53 +1069,48 @@ class OpenMP_NumThreadsClauseSkip< > : OpenMP_Clause<traits, arguments, assemblyFormat, description, extraClassDeclaration> { let arguments = (ins - ConfinedAttr<OptionalAttr<I64Attr>, [IntPositive]>:$num_threads_num_dims, - Variadic<IntLikeType>:$num_threads_dims_values + Variadic<IntLikeType>:$num_threads_vals ); let optAssemblyFormat = [{ `num_threads` `(` custom<NumThreadsClause>( - $num_threads_num_dims, $num_threads_dims_values, type($num_threads_dims_values) + $num_threads_vals, type($num_threads_vals) ) `)` }]; let description = [{ - num_threads clause specifies the desired number of threads in the team - space formed by the construct on which it appears. - - With dims modifier: - - Uses `num_threads_num_dims` (dimension count) and `num_threads_dims_values` (upper bounds list) - - Specifies upper bounds for each dimension (all must have same type) - - Format: `num_threads(dims(N): upper_bound_0, ..., upper_bound_N-1 : type)` - - Example: `num_threads(dims(3): %ub0, %ub1, %ub2 : i32)` - - Without dims modifier: - - The number of threads is specified by single value in `num_threads_dims_values` - - Format: `num_threads(value : type)` + The `num_threads` clause specifies the number of threads. 
+ + Multi-dimensional format (dims modifier): + - Multiple values can be specified for multi-dimensional thread counts. + - The number of dimensions is derived from the number of values. + - Values can have different integer types. + - Format: `num_threads(%v1, %v2, ... : type1, type2, ...)` + - Example: `num_threads(%n, %m : i32, i64)` + + Single value format: + - A single value specifies the number of threads. + - Format: `num_threads(%value : type)` - Example: `num_threads(%n : i32)` }]; let extraClassDeclaration = [{ - /// Returns true if the dims modifier is explicitly present - bool hasNumThreadsDimsModifier() { - return getNumThreadsNumDims().has_value() && getNumThreadsNumDims().value(); + /// Returns true if using multi-dimensional values (more than one value) + bool hasNumThreadsMultiDim() { + return getNumThreadsVals().size() > 1; } - /// Returns the number of dimensions specified by dims modifier + /// Returns the number of dimensions specified for num_threads unsigned getNumThreadsDimsCount() { - if (!hasNumThreadsDimsModifier()) - return 1; - return static_cast<unsigned>(*getNumThreadsNumDims()); + return getNumThreadsVals().size(); } /// Returns the value for a specific dimension index - /// Index must be less than getNumThreadsDimsCount() - ::mlir::Value getNumThreadsDimsValue(unsigned index) { - assert(index < getNumThreadsDimsCount() && - "Num threads dims index out of bounds"); - if(getNumThreadsDimsValues().empty()) - return nullptr; - return getNumThreadsDimsValues()[index]; + /// Index must be less than getNumThreadsVals().size() + ::mlir::Value getNumThreadsVal(unsigned index) { + assert(index < getNumThreadsVals().size() && + "Num threads index out of bounds"); + return getNumThreadsVals()[index]; } }]; } diff --git a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp index 3a1f311dd63f0..35288687a7eac 100644 --- a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp +++ 
b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp @@ -499,8 +499,7 @@ struct ParallelOpLowering : public OpRewritePattern<scf::ParallelOp> { /* allocate_vars = */ llvm::SmallVector<Value>{}, /* allocator_vars = */ llvm::SmallVector<Value>{}, /* if_expr = */ Value{}, - /* num_threads_num_dims = */ nullptr, - /* num_threads_dims_values = */ numThreadsValues, + /* num_threads_vals = */ numThreadsValues, /* private_vars = */ ValueRange(), /* private_syms = */ nullptr, /* private_needs_barrier = */ nullptr, diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp index 9366f04e51629..65a006b48f480 100644 --- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp +++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp @@ -2252,8 +2252,7 @@ LogicalResult TargetOp::verifyRegions() { if (auto parallelOp = dyn_cast<ParallelOp>(user)) { if (bitEnumContainsAny(execFlags, TargetRegionFlags::spmd) && parallelOp->isAncestor(capturedOp) && - llvm::is_contained(parallelOp.getNumThreadsDimsValues(), - hostEvalArg)) + llvm::is_contained(parallelOp.getNumThreadsVals(), hostEvalArg)) continue; return emitOpError() @@ -2505,8 +2504,7 @@ void ParallelOp::build(OpBuilder &builder, OperationState &state, ArrayRef<NamedAttribute> attributes) { ParallelOp::build(builder, state, /*allocate_vars=*/ValueRange(), /*allocator_vars=*/ValueRange(), /*if_expr=*/nullptr, - /*num_threads_dims=*/nullptr, - /*num_threads_values=*/ValueRange(), + /*num_threads_vals=*/ValueRange(), /*private_vars=*/ValueRange(), /*private_syms=*/nullptr, /*private_needs_barrier=*/nullptr, /*proc_bind_kind=*/nullptr, @@ -2519,8 +2517,7 @@ void ParallelOp::build(OpBuilder &builder, OperationState &state, const ParallelOperands &clauses) { MLIRContext *ctx = builder.getContext(); ParallelOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars, - clauses.ifExpr, clauses.numThreadsNumDims, - clauses.numThreadsDimsValues, clauses.privateVars, + clauses.ifExpr, 
clauses.numThreadsVals, clauses.privateVars, makeArrayAttr(ctx, clauses.privateSyms), clauses.privateNeedsBarrier, clauses.procBindKind, clauses.reductionMod, clauses.reductionVars, @@ -2571,23 +2568,7 @@ static LogicalResult verifyPrivateVarList(OpType &op) { return success(); } -// Helper: Verify num_threads clause -LogicalResult -verifyNumThreadsClause(Operation *op, - std::optional<IntegerAttr> numThreadsNumDims, - OperandRange numThreadsDimsValues) { - if (failed(verifyDimsModifier(op, numThreadsNumDims, numThreadsDimsValues))) - return failure(); - return success(); -} - LogicalResult ParallelOp::verify() { - // verify num_threads clause restrictions - if (failed(verifyNumThreadsClause(getOperation(), - this->getNumThreadsNumDimsAttr(), - this->getNumThreadsDimsValues()))) - return failure(); - // verify allocate clause restrictions if (getAllocateVars().size() != getAllocatorVars().size()) return emitError( @@ -4650,30 +4631,24 @@ static void printNumTeamsClause(OpAsmPrinter &p, Operation *op, // Parser and printer for num_threads clause //===----------------------------------------------------------------------===// static ParseResult -parseNumThreadsClause(OpAsmParser &parser, IntegerAttr &dimsAttr, +parseNumThreadsClause(OpAsmParser &parser, SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values, SmallVectorImpl<Type> &types) { - if (succeeded(parseDimsModifierWithValues(parser, dimsAttr, values, types))) { - return success(); - } - - // Without dims modifier: value : type - OpAsmParser::UnresolvedOperand singleValue; - Type singleType; - if (parser.parseOperand(singleValue) || parser.parseColon() || - parser.parseType(singleType)) { + // Parse comma-separated list of values with their types + // Format: %v1, %v2, ... : type1, type2, ... 
+ if (parser.parseOperandList(values) || parser.parseColon() || + parser.parseTypeList(types)) { return failure(); } - values.push_back(singleValue); - types.push_back(singleType); return success(); } static void printNumThreadsClause(OpAsmPrinter &p, Operation *op, - IntegerAttr dimsAttr, OperandRange values, - TypeRange types) { - // Multidimensional: dims(N): values : type - printDimsModifierWithValues(p, dimsAttr, values, types); + OperandRange values, TypeRange types) { + // Print values with their types + llvm::interleaveComma(values, p, [&](Value v) { p << v; }); + p << " : "; + llvm::interleaveComma(types, p, [&](Type t) { p << t; }); } #define GET_ATTRDEF_CLASSES diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp index d4aaa832636d1..a1cb06254f4b0 100644 --- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp @@ -380,6 +380,10 @@ static LogicalResult checkImplementationStatus(Operation &op) { if (op.hasNumTeamsMultiDim()) result = todo("num_teams with multi-dimensional values"); }; + auto checkNumThreadsMultiDim = [&todo](auto op, LogicalResult &result) { + if (op.hasNumThreadsMultiDim()) + result = todo("num_threads with multi-dimensional values"); + }; LogicalResult result = success(); llvm::TypeSwitch<Operation &>(op) @@ -431,6 +435,7 @@ static LogicalResult checkImplementationStatus(Operation &op) { .Case([&](omp::ParallelOp op) { checkAllocate(op, result); checkReduction(op, result); + checkNumThreadsMultiDim(op, result); }) .Case([&](omp::SimdOp op) { checkReduction(op, result); }) .Case<omp::AtomicReadOp, omp::AtomicWriteOp, omp::AtomicUpdateOp, @@ -3268,11 +3273,8 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder, if (auto ifVar = opInst.getIfExpr()) ifCond = moduleTranslation.lookupValue(ifVar); llvm::Value *numThreads = nullptr; - // 
num_threads dims and values are not yet supported - assert(!opInst.hasNumThreadsDimsModifier() && - "Lowering of num_threads with dims modifier is not yet implemented."); - if (auto numThreadsVar = opInst.getNumThreadsDimsValue(0)) - numThreads = moduleTranslation.lookupValue(numThreadsVar); + if (!opInst.getNumThreadsVals().empty()) + numThreads = moduleTranslation.lookupValue(opInst.getNumThreadsVal(0)); auto pbKind = llvm::omp::OMP_PROC_BIND_default; if (auto bind = opInst.getProcBindKind()) pbKind = getProcBindKind(*bind); @@ -6053,11 +6055,8 @@ extractHostEvalClauses(omp::TargetOp targetOp, Value &numThreads, llvm_unreachable("unsupported host_eval use"); }) .Case([&](omp::ParallelOp parallelOp) { - // num_threads dims and values are not yet supported - assert(!parallelOp.hasNumThreadsDimsModifier() && - "Lowering of num_threads with dims modifier is not yet " - "implemented."); - if (parallelOp.getNumThreadsDimsValue(0) == blockArg) + if (!parallelOp.getNumThreadsVals().empty() && + parallelOp.getNumThreadsVal(0) == blockArg) numThreads = hostEvalVar; else llvm_unreachable("unsupported host_eval use"); @@ -6175,11 +6174,8 @@ initTargetDefaultAttrs(omp::TargetOp targetOp, Operation *capturedOp, } if (auto parallelOp = castOrGetParentOfType<omp::ParallelOp>(capturedOp)) { - // num_threads dims and values are not yet supported - assert( - !parallelOp.hasNumThreadsDimsModifier() && - "Lowering of num_threads with dims modifier is not yet implemented."); - numThreads = parallelOp.getNumThreadsDimsValue(0); + if (!parallelOp.getNumThreadsVals().empty()) + numThreads = parallelOp.getNumThreadsVal(0); } } diff --git a/mlir/test/Dialect/OpenMP/invalid.mlir b/mlir/test/Dialect/OpenMP/invalid.mlir index b2e20b4c5ee5a..cd06011c2cbc4 100644 --- a/mlir/test/Dialect/OpenMP/invalid.mlir +++ b/mlir/test/Dialect/OpenMP/invalid.mlir @@ -30,37 +30,6 @@ func.func @num_threads_once(%n : si32) { // ----- -func.func @num_threads_dims_no_values() { - // expected-error@+1 {{dims 
modifier requires values to be specified}} - "omp.parallel"() ({ - omp.terminator - }) {operandSegmentSizes = array<i32: 0,0,0,0,0,0>, num_threads_num_dims = 2 : i64} : () -> () - return -} - -// ----- - -func.func @num_threads_dims_mismatch(%n : i64) { - // expected-error@+1 {{dims(2) specified but 1 values provided}} - omp.parallel num_threads(dims(2): %n : i64) { - omp.terminator - } - - return -} - -// ----- - -func.func @num_threads_multiple_values_without_dims(%n : i64, %m: i64) { - // expected-error@+1 {{dims values can only be specified with dims modifier}} - "omp.parallel"(%n, %m) ({ - omp.terminator - }) {operandSegmentSizes = array<i32: 0,0,0,2,0,0>} : (i64, i64) -> () - return -} - -// ----- - func.func @nowait_not_allowed(%n : memref<i32>) { // expected-error@+1 {{expected '{' to begin a region}} omp.parallel nowait {} diff --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir index e2a3f8fbe2d5f..1700ad696f86f 100644 --- a/mlir/test/Dialect/OpenMP/ops.mlir +++ b/mlir/test/Dialect/OpenMP/ops.mlir @@ -160,8 +160,15 @@ func.func @omp_parallel_pretty(%data_var : memref<i32>, %if_cond : i1, %num_thre omp.terminator } - // CHECK: omp.parallel num_threads(dims(2): %{{.*}}, %{{.*}} : i64) - omp.parallel num_threads(dims(2): %n_i64, %n_i64 : i64) { + // CHECK: omp.parallel num_threads(%{{.*}}, %{{.*}} : i64, i64) + omp.parallel num_threads(%n_i64, %n_i64 : i64, i64) { + omp.terminator + } + + %n_i16 = arith.constant 8 : i16 + // Test num_threads with mixed types. 
+ // CHECK: omp.parallel num_threads(%{{.*}}, %{{.*}}, %{{.*}} : i32, i64, i16) + omp.parallel num_threads(%num_threads, %n_i64, %n_i16 : i32, i64, i16) { omp.terminator } diff --git a/mlir/test/Target/LLVMIR/openmp-todo.mlir b/mlir/test/Target/LLVMIR/openmp-todo.mlir index 1ea56fdd0bf16..e4c47aae9b485 100644 --- a/mlir/test/Target/LLVMIR/openmp-todo.mlir +++ b/mlir/test/Target/LLVMIR/openmp-todo.mlir @@ -443,6 +443,17 @@ llvm.func @teams_num_teams_multi_dim(%lb : i32, %ub : i32) { // ----- +llvm.func @parallel_num_threads_multi_dim(%lb : i32, %ub : i32) { + // expected-error@below {{not yet implemented: Unhandled clause num_threads with multi-dimensional values in omp.parallel operation}} + // expected-error@below {{LLVM Translation failed for operation: omp.parallel}} + omp.parallel num_threads(%lb, %ub : i32, i32) { + omp.terminator + } + llvm.return +} + +// ----- + llvm.func @wsloop_allocate(%lb : i32, %ub : i32, %step : i32, %x : !llvm.ptr) { // expected-error@below {{not yet implemented: Unhandled clause allocate in omp.wsloop operation}} // expected-error@below {{LLVM Translation failed for operation: omp.wsloop}} >From e1fc5f168954e48730538dd83a58a2603ce2ab3b Mon Sep 17 00:00:00 2001 From: skc7 <[email protected]> Date: Sat, 17 Jan 2026 10:37:09 +0530 Subject: [PATCH 08/11] remove custom parser printer for num_threads --- .../mlir/Dialect/OpenMP/OpenMPClauses.td | 82 +------------------ mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td | 3 +- mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp | 29 ------- .../OpenMP/OpenMPToLLVMIRTranslation.cpp | 10 +-- mlir/test/Dialect/OpenMP/invalid.mlir | 19 ++++- 5 files changed, 28 insertions(+), 115 deletions(-) diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td index cda6906d46965..228e6e2deb1fb 100644 --- a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td +++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td @@ -1073,9 +1073,7 @@ class 
OpenMP_NumThreadsClauseSkip< ); let optAssemblyFormat = [{ - `num_threads` `(` custom<NumThreadsClause>( - $num_threads_vals, type($num_threads_vals) - ) `)` + `num_threads` `(` $num_threads_vals `:` type($num_threads_vals) `)` }]; let description = [{ @@ -1107,10 +1105,10 @@ class OpenMP_NumThreadsClauseSkip< /// Returns the value for a specific dimension index /// Index must be less than getNumThreadsVals().size() - ::mlir::Value getNumThreadsVal(unsigned index) { - assert(index < getNumThreadsVals().size() && + ::mlir::Value getNumThreads(unsigned dim = 0) { + assert(dim < getNumThreadsDimsCount() && "Num threads index out of bounds"); - return getNumThreadsVals()[index]; + return getNumThreadsVals()[dim]; } }]; } @@ -1600,76 +1598,4 @@ class OpenMP_UseDevicePtrClauseSkip< def OpenMP_UseDevicePtrClause : OpenMP_UseDevicePtrClauseSkip<>; -//===----------------------------------------------------------------------===// -// V6.2: Multidimensional `num_teams` clause with dims modifier -//===----------------------------------------------------------------------===// - -class OpenMP_NumTeamsMultiDimClauseSkip< - bit traits = false, bit arguments = false, bit assemblyFormat = false, - bit description = false, bit extraClassDeclaration = false - > : OpenMP_Clause<traits, arguments, assemblyFormat, description, - extraClassDeclaration> { - let arguments = (ins - ConfinedAttr<OptionalAttr<I64Attr>, [IntPositive]>:$num_teams_dims, - Variadic<AnyInteger>:$num_teams_values - ); - - let optAssemblyFormat = [{ - `num_teams_multi_dim` `(` custom<NumTeamsMultiDimClause>($num_teams_dims, - $num_teams_values, - type($num_teams_values)) `)` - }]; - - let description = [{ - The `num_teams_multi_dim` clause with dims modifier support specifies the limit on - the number of teams to be created in a multidimensional team space. - - The dims modifier for the num_teams_multi_dim clause specifies the number of - dimensions for the league space (team space) that the clause arranges. 
- The dimensions argument in the dims modifier specifies the number of - dimensions and determines the length of the list argument. The list items - are specified in ascending order according to the ordinal number of the - dimensions (dimension 0, 1, 2, ..., N-1). - - - If `dims` is not specified: The space is unidimensional (1D) with a single value - - If `dims(1)` is specified: The space is explicitly unidimensional (1D) - - If `dims(N)` where N > 1: The space is strictly multidimensional (N-D) - - **Examples:** - - `num_teams_multi_dim(dims(3): %nt0, %nt1, %nt2 : i32, i32, i32)` creates a - 3-dimensional team space with limits nt0, nt1, nt2 for dimensions 0, 1, 2. - - `num_teams_multi_dim(%nt : i32)` creates a unidimensional team space with limit nt. - }]; - - let extraClassDeclaration = [{ - /// Returns true if the dims modifier is explicitly present - bool hasDimsModifier() { - return getNumTeamsDims().has_value(); - } - - /// Returns the number of dimensions specified by dims modifier - /// Returns 1 if dims modifier is not present (unidimensional by default) - unsigned getNumDimensions() { - if (!hasDimsModifier()) - return 1; - return static_cast<unsigned>(*getNumTeamsDims()); - } - - /// Returns all dimension values as an operand range - ::mlir::OperandRange getDimensionValues() { - return getNumTeamsValues(); - } - - /// Returns the value for a specific dimension index - /// Index must be less than getNumDimensions() - ::mlir::Value getDimensionValue(unsigned index) { - assert(index < getDimensionValues().size() && - "Dimension index out of bounds"); - return getDimensionValues()[index]; - } - }]; -} - -def OpenMP_NumTeamsMultiDimClause : OpenMP_NumTeamsMultiDimClauseSkip<>; - #endif // OPENMP_CLAUSES diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td index 76eeb0bd70ec3..d4e8cecda2601 100644 --- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td +++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td @@ 
-241,8 +241,7 @@ def TeamsOp : OpenMP_Op<"teams", traits = [ AttrSizedOperandSegments, RecursiveMemoryEffects, OutlineableOpenMPOpInterface ], clauses = [ OpenMP_AllocateClause, OpenMP_IfClause, OpenMP_NumTeamsClause, - OpenMP_NumTeamsMultiDimClause, OpenMP_PrivateClause, OpenMP_ReductionClause, - OpenMP_ThreadLimitClause + OpenMP_PrivateClause, OpenMP_ReductionClause, OpenMP_ThreadLimitClause ], singleRegion = true> { let summary = "teams construct"; let description = [{ diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp index 65a006b48f480..4cdeaa0bc8e87 100644 --- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp +++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp @@ -2629,13 +2629,8 @@ void TeamsOp::build(OpBuilder &builder, OperationState &state, MLIRContext *ctx = builder.getContext(); // TODO Store clauses in op: privateVars, privateSyms, privateNeedsBarrier TeamsOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars, -<<<<<<< HEAD clauses.ifExpr, clauses.numTeamsVals, clauses.numTeamsLower, clauses.numTeamsUpper, -======= - clauses.ifExpr, clauses.numTeamsLower, clauses.numTeamsUpper, - clauses.numTeamsDims, clauses.numTeamsValues, ->>>>>>> [OpenMP][MLIR] Add num_teams clause with dims modifier support /*private_vars=*/{}, /*private_syms=*/nullptr, /*private_needs_barrier=*/nullptr, clauses.reductionMod, clauses.reductionVars, @@ -4627,30 +4622,6 @@ static void printNumTeamsClause(OpAsmPrinter &p, Operation *op, } } -//===----------------------------------------------------------------------===// -// Parser and printer for num_threads clause -//===----------------------------------------------------------------------===// -static ParseResult -parseNumThreadsClause(OpAsmParser &parser, - SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values, - SmallVectorImpl<Type> &types) { - // Parse comma-separated list of values with their types - // Format: %v1, %v2, ... : type1, type2, ... 
- if (parser.parseOperandList(values) || parser.parseColon() || - parser.parseTypeList(types)) { - return failure(); - } - return success(); -} - -static void printNumThreadsClause(OpAsmPrinter &p, Operation *op, - OperandRange values, TypeRange types) { - // Print values with their types - llvm::interleaveComma(values, p, [&](Value v) { p << v; }); - p << " : "; - llvm::interleaveComma(types, p, [&](Type t) { p << t; }); -} - #define GET_ATTRDEF_CLASSES #include "mlir/Dialect/OpenMP/OpenMPOpsAttributes.cpp.inc" diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp index a1cb06254f4b0..73a91b3707c57 100644 --- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp @@ -380,7 +380,7 @@ static LogicalResult checkImplementationStatus(Operation &op) { if (op.hasNumTeamsMultiDim()) result = todo("num_teams with multi-dimensional values"); }; - auto checkNumThreadsMultiDim = [&todo](auto op, LogicalResult &result) { + auto checkNumThreads = [&todo](auto op, LogicalResult &result) { if (op.hasNumThreadsMultiDim()) result = todo("num_threads with multi-dimensional values"); }; @@ -435,7 +435,7 @@ static LogicalResult checkImplementationStatus(Operation &op) { .Case([&](omp::ParallelOp op) { checkAllocate(op, result); checkReduction(op, result); - checkNumThreadsMultiDim(op, result); + checkNumThreads(op, result); }) .Case([&](omp::SimdOp op) { checkReduction(op, result); }) .Case<omp::AtomicReadOp, omp::AtomicWriteOp, omp::AtomicUpdateOp, @@ -3274,7 +3274,7 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder, ifCond = moduleTranslation.lookupValue(ifVar); llvm::Value *numThreads = nullptr; if (!opInst.getNumThreadsVals().empty()) - numThreads = moduleTranslation.lookupValue(opInst.getNumThreadsVal(0)); + numThreads = 
moduleTranslation.lookupValue(opInst.getNumThreads(0)); auto pbKind = llvm::omp::OMP_PROC_BIND_default; if (auto bind = opInst.getProcBindKind()) pbKind = getProcBindKind(*bind); @@ -6056,7 +6056,7 @@ extractHostEvalClauses(omp::TargetOp targetOp, Value &numThreads, }) .Case([&](omp::ParallelOp parallelOp) { if (!parallelOp.getNumThreadsVals().empty() && - parallelOp.getNumThreadsVal(0) == blockArg) + parallelOp.getNumThreads(0) == blockArg) numThreads = hostEvalVar; else llvm_unreachable("unsupported host_eval use"); @@ -6175,7 +6175,7 @@ initTargetDefaultAttrs(omp::TargetOp targetOp, Operation *capturedOp, if (auto parallelOp = castOrGetParentOfType<omp::ParallelOp>(capturedOp)) { if (!parallelOp.getNumThreadsVals().empty()) - numThreads = parallelOp.getNumThreadsVal(0); + numThreads = parallelOp.getNumThreads(0); } } diff --git a/mlir/test/Dialect/OpenMP/invalid.mlir b/mlir/test/Dialect/OpenMP/invalid.mlir index cd06011c2cbc4..d451b14e8bfc9 100644 --- a/mlir/test/Dialect/OpenMP/invalid.mlir +++ b/mlir/test/Dialect/OpenMP/invalid.mlir @@ -1451,7 +1451,24 @@ func.func @omp_teams_num_teams1(%lb : i32) { // expected-error @below {{expected num_teams upper bound to be defined if the lower bound is defined}} "omp.teams" (%lb) ({ omp.terminator - }) {operandSegmentSizes = array<i32: 0,0,0,1,0,0,0,0,0>} : (i32) -> () + }) {operandSegmentSizes = array<i32: 0,0,0,0,1,0,0,0,0>} : (i32) -> () + omp.terminator + } + return +} + +// ----- + +func.func @omp_teams_num_teams_multidim_with_bounds() { + omp.target { + %v0 = arith.constant 1 : i32 + %v1 = arith.constant 2 : i32 + %lb = arith.constant 3 : i32 + %ub = arith.constant 4 : i32 + // expected-error @below {{num_teams multi-dimensional values cannot be used together with legacy lower/upper bounds}} + "omp.teams" (%v0, %v1, %lb, %ub) ({ + omp.terminator + }) {operandSegmentSizes = array<i32: 0,0,0,2,1,1,0,0,0>} : (i32, i32, i32, i32) -> () omp.terminator } return >From 01db9bf197fac18cdded2e52e96b72c6492f3189 Mon Sep 17 
00:00:00 2001 From: skc7 <[email protected]> Date: Tue, 20 Jan 2026 11:52:42 +0530 Subject: [PATCH 09/11] rename num_threads_vals to num_threads_vars --- flang/lib/Lower/OpenMP/ClauseProcessor.cpp | 2 +- flang/lib/Lower/OpenMP/OpenMP.cpp | 12 ++++++------ mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td | 12 ++++++------ mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp | 6 +++--- mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp | 6 +++--- .../Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp | 6 +++--- 6 files changed, 22 insertions(+), 22 deletions(-) diff --git a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp index e33bdcc5c4dbd..159fe371c7ef3 100644 --- a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp +++ b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp @@ -516,7 +516,7 @@ bool ClauseProcessor::processNumThreads( mlir::omp::NumThreadsClauseOps &result) const { if (auto *clause = findUniqueClause<omp::clause::NumThreads>()) { // OMPIRBuilder expects `NUM_THREADS` clause as a `Value`. - result.numThreadsVals.push_back( + result.numThreadsVars.push_back( fir::getBase(converter.genExprValue(clause->v, stmtCtx))); return true; } diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp index 9947dcc8d5ebc..d2548df804929 100644 --- a/flang/lib/Lower/OpenMP/OpenMP.cpp +++ b/flang/lib/Lower/OpenMP/OpenMP.cpp @@ -99,7 +99,7 @@ class HostEvalInfo { if (ops.numTeamsUpper) vars.push_back(ops.numTeamsUpper); - for (auto numThreads : ops.numThreadsVals) + for (auto numThreads : ops.numThreadsVars) vars.push_back(numThreads); if (ops.threadLimit) @@ -115,7 +115,7 @@ class HostEvalInfo { assert(args.size() == ops.loopLowerBounds.size() + ops.loopUpperBounds.size() + ops.loopSteps.size() + (ops.numTeamsLower ? 1 : 0) + - (ops.numTeamsUpper ? 1 : 0) + ops.numThreadsVals.size() + + (ops.numTeamsUpper ? 1 : 0) + ops.numThreadsVars.size() + (ops.threadLimit ? 
1 : 0) && "invalid block argument list"); int argIndex = 0; @@ -134,8 +134,8 @@ class HostEvalInfo { if (ops.numTeamsUpper) ops.numTeamsUpper = args[argIndex++]; - for (size_t i = 0; i < ops.numThreadsVals.size(); ++i) - ops.numThreadsVals[i] = args[argIndex++]; + for (size_t i = 0; i < ops.numThreadsVars.size(); ++i) + ops.numThreadsVars[i] = args[argIndex++]; if (ops.threadLimit) ops.threadLimit = args[argIndex++]; @@ -169,13 +169,13 @@ class HostEvalInfo { /// \returns whether an update was performed. If not, these clauses were not /// evaluated in the host device. bool apply(mlir::omp::ParallelOperands &clauseOps) { - if (ops.numThreadsVals.empty() || parallelApplied) { + if (ops.numThreadsVars.empty() || parallelApplied) { parallelApplied = true; return false; } parallelApplied = true; - clauseOps.numThreadsVals = ops.numThreadsVals; + clauseOps.numThreadsVars = ops.numThreadsVars; return true; } diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td index 228e6e2deb1fb..fb1fb2b4461f1 100644 --- a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td +++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td @@ -1069,11 +1069,11 @@ class OpenMP_NumThreadsClauseSkip< > : OpenMP_Clause<traits, arguments, assemblyFormat, description, extraClassDeclaration> { let arguments = (ins - Variadic<IntLikeType>:$num_threads_vals + Variadic<IntLikeType>:$num_threads_vars ); let optAssemblyFormat = [{ - `num_threads` `(` $num_threads_vals `:` type($num_threads_vals) `)` + `num_threads` `(` $num_threads_vars `:` type($num_threads_vars) `)` }]; let description = [{ @@ -1095,20 +1095,20 @@ class OpenMP_NumThreadsClauseSkip< let extraClassDeclaration = [{ /// Returns true if using multi-dimensional values (more than one value) bool hasNumThreadsMultiDim() { - return getNumThreadsVals().size() > 1; + return getNumThreadsVars().size() > 1; } /// Returns the number of dimensions specified for num_threads unsigned 
getNumThreadsDimsCount() { - return getNumThreadsVals().size(); + return getNumThreadsVars().size(); } /// Returns the value for a specific dimension index - /// Index must be less than getNumThreadsVals().size() + /// Index must be less than getNumThreadsVars().size() ::mlir::Value getNumThreads(unsigned dim = 0) { assert(dim < getNumThreadsDimsCount() && "Num threads index out of bounds"); - return getNumThreadsVals()[dim]; + return getNumThreadsVars()[dim]; } }]; } diff --git a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp index 35288687a7eac..d410cc80d1fef 100644 --- a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp +++ b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp @@ -487,11 +487,11 @@ struct ParallelOpLowering : public OpRewritePattern<scf::ParallelOp> { rewriter.eraseOp(reduce); Value numThreadsVar; - SmallVector<Value> numThreadsValues; + SmallVector<Value> numThreadsVars; if (numThreads > 0) { numThreadsVar = LLVM::ConstantOp::create( rewriter, loc, rewriter.getI32IntegerAttr(numThreads)); - numThreadsValues.push_back(numThreadsVar); + numThreadsVars.push_back(numThreadsVar); } // Create the parallel wrapper. 
auto ompParallel = omp::ParallelOp::create( @@ -499,7 +499,7 @@ struct ParallelOpLowering : public OpRewritePattern<scf::ParallelOp> { /* allocate_vars = */ llvm::SmallVector<Value>{}, /* allocator_vars = */ llvm::SmallVector<Value>{}, /* if_expr = */ Value{}, - /* num_threads_vals = */ numThreadsValues, + /* num_threads_vars = */ numThreadsVars, /* private_vars = */ ValueRange(), /* private_syms = */ nullptr, /* private_needs_barrier = */ nullptr, diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp index 4cdeaa0bc8e87..d9eb604a31811 100644 --- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp +++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp @@ -2252,7 +2252,7 @@ LogicalResult TargetOp::verifyRegions() { if (auto parallelOp = dyn_cast<ParallelOp>(user)) { if (bitEnumContainsAny(execFlags, TargetRegionFlags::spmd) && parallelOp->isAncestor(capturedOp) && - llvm::is_contained(parallelOp.getNumThreadsVals(), hostEvalArg)) + llvm::is_contained(parallelOp.getNumThreadsVars(), hostEvalArg)) continue; return emitOpError() @@ -2504,7 +2504,7 @@ void ParallelOp::build(OpBuilder &builder, OperationState &state, ArrayRef<NamedAttribute> attributes) { ParallelOp::build(builder, state, /*allocate_vars=*/ValueRange(), /*allocator_vars=*/ValueRange(), /*if_expr=*/nullptr, - /*num_threads_vals=*/ValueRange(), + /*num_threads_vars=*/ValueRange(), /*private_vars=*/ValueRange(), /*private_syms=*/nullptr, /*private_needs_barrier=*/nullptr, /*proc_bind_kind=*/nullptr, @@ -2517,7 +2517,7 @@ void ParallelOp::build(OpBuilder &builder, OperationState &state, const ParallelOperands &clauses) { MLIRContext *ctx = builder.getContext(); ParallelOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars, - clauses.ifExpr, clauses.numThreadsVals, clauses.privateVars, + clauses.ifExpr, clauses.numThreadsVars, clauses.privateVars, makeArrayAttr(ctx, clauses.privateSyms), clauses.privateNeedsBarrier, clauses.procBindKind, 
clauses.reductionMod, clauses.reductionVars, diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp index 73a91b3707c57..98c7c8abcccbc 100644 --- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp @@ -3273,7 +3273,7 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder, if (auto ifVar = opInst.getIfExpr()) ifCond = moduleTranslation.lookupValue(ifVar); llvm::Value *numThreads = nullptr; - if (!opInst.getNumThreadsVals().empty()) + if (!opInst.getNumThreadsVars().empty()) numThreads = moduleTranslation.lookupValue(opInst.getNumThreads(0)); auto pbKind = llvm::omp::OMP_PROC_BIND_default; if (auto bind = opInst.getProcBindKind()) @@ -6055,7 +6055,7 @@ extractHostEvalClauses(omp::TargetOp targetOp, Value &numThreads, llvm_unreachable("unsupported host_eval use"); }) .Case([&](omp::ParallelOp parallelOp) { - if (!parallelOp.getNumThreadsVals().empty() && + if (!parallelOp.getNumThreadsVars().empty() && parallelOp.getNumThreads(0) == blockArg) numThreads = hostEvalVar; else @@ -6174,7 +6174,7 @@ initTargetDefaultAttrs(omp::TargetOp targetOp, Operation *capturedOp, } if (auto parallelOp = castOrGetParentOfType<omp::ParallelOp>(capturedOp)) { - if (!parallelOp.getNumThreadsVals().empty()) + if (!parallelOp.getNumThreadsVars().empty()) numThreads = parallelOp.getNumThreads(0); } } >From 0e262214fa25cd5c9055e52d1c815d2bd985704b Mon Sep 17 00:00:00 2001 From: skc7 <[email protected]> Date: Wed, 21 Jan 2026 20:58:53 +0530 Subject: [PATCH 10/11] fix numThreads in clauseProcessor and comments fixes --- flang/lib/Lower/OpenMP/ClauseProcessor.cpp | 8 +++++--- flang/lib/Lower/OpenMP/Clauses.cpp | 13 ++++++++++--- llvm/include/llvm/Frontend/OpenMP/ClauseT.h | 4 +++- mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp | 3 +-- mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp | 3 --- 5 
files changed, 19 insertions(+), 12 deletions(-) diff --git a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp index 159fe371c7ef3..cbc38ec865440 100644 --- a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp +++ b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp @@ -515,9 +515,11 @@ bool ClauseProcessor::processNumThreads( lower::StatementContext &stmtCtx, mlir::omp::NumThreadsClauseOps &result) const { if (auto *clause = findUniqueClause<omp::clause::NumThreads>()) { - // OMPIRBuilder expects `NUM_THREADS` clause as a `Value`. - result.numThreadsVars.push_back( - fir::getBase(converter.genExprValue(clause->v, stmtCtx))); + // OMPIRBuilder expects `NUM_THREADS` clause as a list of Values. + for (const ExprTy &expr : clause->v) { + result.numThreadsVars.push_back( + fir::getBase(converter.genExprValue(expr, stmtCtx))); + } return true; } return false; diff --git a/flang/lib/Lower/OpenMP/Clauses.cpp b/flang/lib/Lower/OpenMP/Clauses.cpp index c739249bff211..668483f021f4c 100644 --- a/flang/lib/Lower/OpenMP/Clauses.cpp +++ b/flang/lib/Lower/OpenMP/Clauses.cpp @@ -1290,9 +1290,16 @@ NumTeams make(const parser::OmpClause::NumTeams &inp, NumThreads make(const parser::OmpClause::NumThreads &inp, semantics::SemanticsContext &semaCtx) { // inp.v -> parser::OmpNumThreadsClause - auto &t1 = std::get<std::list<parser::ScalarIntExpr>>(inp.v.t); - assert(!t1.empty()); - return NumThreads{/*Nthreads=*/makeExpr(t1.front(), semaCtx)}; + // With dims modifier (OpenMP 6.1): multiple values + // Without dims modifier: single value + auto &values = std::get<std::list<parser::ScalarIntExpr>>(inp.v.t); + assert(!values.empty()); + + List<NumThreads::Nthreads> v; + for (const auto &val : values) { + v.push_back(makeExpr(val, semaCtx)); + } + return NumThreads{/*Nthreads=*/v}; } // OmpxAttribute: empty diff --git a/llvm/include/llvm/Frontend/OpenMP/ClauseT.h b/llvm/include/llvm/Frontend/OpenMP/ClauseT.h index 05ee1ae36a23d..3e4a6b31bf8eb 100644 --- 
a/llvm/include/llvm/Frontend/OpenMP/ClauseT.h +++ b/llvm/include/llvm/Frontend/OpenMP/ClauseT.h @@ -1011,11 +1011,13 @@ struct NumTeamsT { }; // V5.2: [10.1.2] `num_threads` clause +// V6.1: Extended with dims modifier support template <typename T, typename I, typename E> // struct NumThreadsT { using Nthreads = E; + using List = ListT<Nthreads>; using WrapperTrait = std::true_type; - Nthreads v; + List v; }; template <typename T, typename I, typename E> // diff --git a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp index d410cc80d1fef..48845734e9547 100644 --- a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp +++ b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp @@ -486,10 +486,9 @@ struct ParallelOpLowering : public OpRewritePattern<scf::ParallelOp> { } rewriter.eraseOp(reduce); - Value numThreadsVar; SmallVector<Value> numThreadsVars; if (numThreads > 0) { - numThreadsVar = LLVM::ConstantOp::create( + Value numThreadsVar = LLVM::ConstantOp::create( rewriter, loc, rewriter.getI32IntegerAttr(numThreads)); numThreadsVars.push_back(numThreadsVar); } diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp index d9eb604a31811..13bd64ab6966d 100644 --- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp +++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp @@ -2569,16 +2569,13 @@ static LogicalResult verifyPrivateVarList(OpType &op) { } LogicalResult ParallelOp::verify() { - // verify allocate clause restrictions if (getAllocateVars().size() != getAllocatorVars().size()) return emitError( "expected equal sizes for allocate and allocator variables"); - // verify private variables restrictions if (failed(verifyPrivateVarList(*this))) return failure(); - // verify reduction variables restrictions return verifyReductionVarList(*this, getReductionSyms(), getReductionVars(), getReductionByref()); } >From aab3083ea3d9d8d3424cf797bc829fece9efa787 Mon Sep 17 00:00:00 2001 From: skc7 
<[email protected]> Date: Thu, 22 Jan 2026 11:10:19 +0530 Subject: [PATCH 11/11] Add a test for flang to mlir lowering for num_threads --- flang/test/Lower/OpenMP/num-threads-dims.f90 | 61 ++++++++++++++++++++ 1 file changed, 61 insertions(+) create mode 100644 flang/test/Lower/OpenMP/num-threads-dims.f90 diff --git a/flang/test/Lower/OpenMP/num-threads-dims.f90 b/flang/test/Lower/OpenMP/num-threads-dims.f90 new file mode 100644 index 0000000000000..f3a8d706b7283 --- /dev/null +++ b/flang/test/Lower/OpenMP/num-threads-dims.f90 @@ -0,0 +1,61 @@ +! RUN: %flang_fc1 -emit-hlfir %openmp_flags -fopenmp-version=61 %s -o - | FileCheck %s + +!=============================================================================== +! `num_threads` clause with dims modifier (OpenMP 6.1) +!=============================================================================== + +! CHECK-LABEL: func @_QPparallel_numthreads_dims4 +subroutine parallel_numthreads_dims4() + ! CHECK: omp.parallel + ! CHECK-SAME: num_threads(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : i32, i32, i32, i32) + !$omp parallel num_threads(dims(4): 4, 5, 6, 7) + call f1() + ! CHECK: omp.terminator + !$omp end parallel +end subroutine parallel_numthreads_dims4 + +! CHECK-LABEL: func @_QPparallel_numthreads_dims2 +subroutine parallel_numthreads_dims2() + ! CHECK: omp.parallel + ! CHECK-SAME: num_threads(%{{.*}}, %{{.*}} : i32, i32) + !$omp parallel num_threads(dims(2): 8, 4) + call f1() + ! CHECK: omp.terminator + !$omp end parallel +end subroutine parallel_numthreads_dims2 + +! CHECK-LABEL: func @_QPparallel_numthreads_dims_var +subroutine parallel_numthreads_dims_var(a, b, c) + integer, intent(in) :: a, b, c + ! CHECK: omp.parallel + ! CHECK-SAME: num_threads(%{{.*}}, %{{.*}}, %{{.*}} : i32, i32, i32) + !$omp parallel num_threads(dims(3): a, b, c) + call f1() + !
CHECK: omp.terminator + !$omp end parallel +end subroutine parallel_numthreads_dims_var + +!=============================================================================== +! `num_threads` clause without dims modifier (legacy) +!=============================================================================== + +! CHECK-LABEL: func @_QPparallel_numthreads_legacy +subroutine parallel_numthreads_legacy(n) + integer, intent(in) :: n + ! CHECK: omp.parallel + ! CHECK-SAME: num_threads(%{{.*}} : i32) + !$omp parallel num_threads(n) + call f1() + ! CHECK: omp.terminator + !$omp end parallel +end subroutine parallel_numthreads_legacy + +! CHECK-LABEL: func @_QPparallel_numthreads_const +subroutine parallel_numthreads_const() + ! CHECK: omp.parallel + ! CHECK-SAME: num_threads(%{{.*}} : i32) + !$omp parallel num_threads(16) + call f1() + ! CHECK: omp.terminator + !$omp end parallel +end subroutine parallel_numthreads_const _______________________________________________ llvm-branch-commits mailing list [email protected] https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits
