[Mlir-commits] [mlir] [OpenMP][MLIR] Add num_teams clause with dims modifier support (PR #169883)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Dec 9 23:55:14 PST 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-openmp
Author: Chaitanya (skc7)
<details>
<summary>Changes</summary>
This is a work-in-progress PR adding support for the OpenMP 6.1 `num_teams` clause with the `dims` modifier.
---
Full diff: https://github.com/llvm/llvm-project/pull/169883.diff
4 Files Affected:
- (modified) mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td (+48-8)
- (modified) mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp (+182-8)
- (modified) mlir/test/Dialect/OpenMP/invalid.mlir (+77-2)
- (modified) mlir/test/Dialect/OpenMP/ops.mlir (+6)
``````````diff
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
index 8e43c4284d078..1b44873ea99b1 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
@@ -974,22 +974,62 @@ class OpenMP_NumTeamsClauseSkip<
> : OpenMP_Clause<traits, arguments, assemblyFormat, description,
extraClassDeclaration> {
let arguments = (ins
+ ConfinedAttr<OptionalAttr<I64Attr>, [IntPositive]>:$num_teams_dims,
+ Variadic<AnyInteger>:$num_teams_values,
Optional<AnyInteger>:$num_teams_lower,
Optional<AnyInteger>:$num_teams_upper
);
let optAssemblyFormat = [{
- `num_teams` `(` ( $num_teams_lower^ `:` type($num_teams_lower) )? `to`
- $num_teams_upper `:` type($num_teams_upper) `)`
+ `num_teams` `(` custom<NumTeamsClause>(
+ $num_teams_dims, $num_teams_values, type($num_teams_values),
+ $num_teams_lower, type($num_teams_lower),
+ $num_teams_upper, type($num_teams_upper)
+ ) `)`
}];
let description = [{
- The optional `num_teams_upper` and `num_teams_lower` arguments specify the
- limit on the number of teams to be created. If only the upper bound is
- specified, it acts as if the lower bound was set to the same value. It is
- not allowed to set `num_teams_lower` if `num_teams_upper` is not specified.
- They define a closed range, where both the lower and upper bounds are
- included.
+ The `num_teams` clause specifies the bounds on the league space formed by the
+ construct on which it appears.
+
+ With dims modifier: (OpenMP 6.1 requirement)
+ - Uses `num_teams_dims` (dimension count) and `num_teams_values` (upper bounds list)
+ - Specifies upper bounds for each dimension (all must have same type)
+ - Format: `num_teams(dims(N): upper_bound_0, ..., upper_bound_N-1 : type)`
+ - Example: `num_teams(dims(3): %ub0, %ub1, %ub2 : i32)`
+
+ Without dims modifier:
+ - Uses `num_teams_upper` and optional `num_teams_lower`
+ - If lower bound not specified, it defaults to upper bound value
+ - Format: `num_teams(lower : type to upper : type)` or `num_teams(to upper : type)`
+ - Example: `num_teams(%lb : i32 to %ub : i32)` or `num_teams(to %ub : i32)`
+ }];
+
+ let extraClassDeclaration = [{
+ // NOTE(review): clause extraClassDeclarations are spliced into every op that
+ // uses this clause, so generic names such as getNumDimensions() may collide
+ // with accessors contributed by other clauses; consider num_teams-specific
+ // names before this lands.
+
+ /// Returns true if the `dims` modifier is present, i.e. the optional
+ /// `num_teams_dims` attribute is set.
+ bool hasDimsModifier() {
+ return getNumTeamsDims().has_value();
+ }
+
+ /// Returns N from `dims(N)`. When the `dims` modifier is absent this
+ /// returns 1 (the unidimensional `lower to upper` form). It also returns 1
+ /// when the op carries no num_teams clause at all.
+ unsigned getNumDimensions() {
+ if (!hasDimsModifier())
+ return 1;
+ return static_cast<unsigned>(*getNumTeamsDims());
+ }
+
+ /// Returns the per-dimension upper bounds given with `dims(N)`. On a
+ /// verified op this range is empty when the `dims` modifier is absent.
+ ::mlir::OperandRange getDimensionValues() {
+ return getNumTeamsValues();
+ }
+
+ /// Returns the upper-bound value for dimension `index`.
+ /// `index` must be less than getDimensionValues().size(). This is NOT the
+ /// same bound as getNumDimensions(): without the `dims` modifier that
+ /// returns 1 while there are no values, so every index asserts.
+ ::mlir::Value getDimensionValue(unsigned index) {
+ assert(index < getDimensionValues().size() &&
+ "Dimension index out of bounds");
+ return getDimensionValues()[index];
+ }
}];
}
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index 0d6b2870c625a..6f56833b3b76a 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -2620,7 +2620,8 @@ void TeamsOp::build(OpBuilder &builder, OperationState &state,
MLIRContext *ctx = builder.getContext();
// TODO Store clauses in op: privateVars, privateSyms, privateNeedsBarrier
TeamsOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars,
- clauses.ifExpr, clauses.numTeamsLower, clauses.numTeamsUpper,
+ clauses.ifExpr, clauses.numTeamsDims, clauses.numTeamsValues,
+ clauses.numTeamsLower, clauses.numTeamsUpper,
/*private_vars=*/{}, /*private_syms=*/nullptr,
/*private_needs_barrier=*/nullptr, clauses.reductionMod,
clauses.reductionVars,
@@ -2642,14 +2643,57 @@ LogicalResult TeamsOp::verify() {
"in any OpenMP dialect operations");
// Check for num_teams clause restrictions
- if (auto numTeamsLowerBound = getNumTeamsLower()) {
- auto numTeamsUpperBound = getNumTeamsUpper();
- if (!numTeamsUpperBound)
- return emitError("expected num_teams upper bound to be defined if the "
- "lower bound is defined");
- if (numTeamsLowerBound.getType() != numTeamsUpperBound.getType())
+ auto numTeamsDims = getNumTeamsDims();
+ auto numTeamsValues = getNumTeamsValues();
+ auto numTeamsLower = getNumTeamsLower();
+ auto numTeamsUpper = getNumTeamsUpper();
+
+ // Cannot use both dims modifier and unidimensional style
+ if (numTeamsDims.has_value() && (numTeamsLower || numTeamsUpper)) {
+ return emitError(
+ "num_teams with dims modifier cannot be used together with "
+ "lower/upper bounds (unidimensional style)");
+ }
+
+ // With dims modifier (multidimensional)
+ if (numTeamsDims.has_value()) {
+ if (numTeamsValues.empty()) {
+ return emitError(
+ "num_teams dims modifier requires values to be specified");
+ }
+
+ if (numTeamsValues.size() != static_cast<size_t>(*numTeamsDims)) {
+ return emitError("num_teams dims(")
+ << *numTeamsDims << ") specified but " << numTeamsValues.size()
+ << " values provided";
+ }
+
+ // All values must have the same type
+ if (!numTeamsValues.empty()) {
+ Type firstType = numTeamsValues.front().getType();
+ for (auto value : numTeamsValues) {
+ if (value.getType() != firstType) {
+ return emitError(
+ "num_teams dims modifier requires all values to have "
+ "the same type");
+ }
+ }
+ }
+ } else {
+ // Without dims modifier
+ if (!numTeamsValues.empty()) {
return emitError(
- "expected num_teams upper bound and lower bound to be the same type");
+ "num_teams values can only be specified with dims modifier");
+ }
+
+ if (numTeamsLower) {
+ if (!numTeamsUpper)
+ return emitError("expected num_teams upper bound to be defined if the "
+ "lower bound is defined");
+ if (numTeamsLower.getType() != numTeamsUpper.getType())
+ return emitError("expected num_teams upper bound and lower bound to be "
+ "the same type");
+ }
}
// Check for allocate clause restrictions
@@ -4453,6 +4497,136 @@ LogicalResult WorkdistributeOp::verify() {
return success();
}
+//===----------------------------------------------------------------------===//
+// Helper: Parse dims modifier with values
+//===----------------------------------------------------------------------===//
+// Parses the optional multidimensional form of the num_teams clause:
+//   dims(N): value_0, ..., value_{N-1} : type
+// A single type is parsed and replicated for every value.
+//
+// Returns std::nullopt (consuming nothing) when the `dims` keyword is absent,
+// so the caller can try another form. Once `dims` has been consumed, any error
+// is reported and failure() is returned. The caller must propagate that result
+// instead of trying another alternative, which would only add cascading
+// diagnostics.
+static OptionalParseResult parseDimsModifierWithValues(
+    OpAsmParser &parser, IntegerAttr &dimsAttr,
+    SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values,
+    SmallVectorImpl<Type> &types) {
+  if (failed(parser.parseOptionalKeyword("dims")))
+    return std::nullopt;
+
+  // Parse `(N):`.
+  int64_t dimsValue;
+  if (parser.parseLParen() || parser.parseInteger(dimsValue) ||
+      parser.parseRParen() || parser.parseColon())
+    return failure();
+
+  // Parse `value_0, ..., value_{N-1} : type`.
+  Type valueType;
+  if (parser.parseOperandList(values) || parser.parseColon() ||
+      parser.parseType(valueType))
+    return failure();
+
+  // Every value gets the single parsed type. The op verifier (see
+  // TeamsOp::verify) checks that the value count matches N.
+  types.assign(values.size(), valueType);
+  dimsAttr = parser.getBuilder().getI64IntegerAttr(dimsValue);
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// Parser and printer for num_teams clause with dims modifier
+//===----------------------------------------------------------------------===//
+// Parses the contents of `num_teams(...)` in one of three forms:
+//   dims(N): v_0, ..., v_{N-1} : type   (multidimensional)
+//   to upper : type                      (upper bound only)
+//   lower : type to upper : type         (both bounds)
+// On success, exactly one of {dimsAttr, values, types} or
+// {lowerBound (optional), upperBound} is populated.
+static ParseResult
+parseNumTeamsClause(OpAsmParser &parser, IntegerAttr &dimsAttr,
+                    SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values,
+                    SmallVectorImpl<Type> &types,
+                    std::optional<OpAsmParser::UnresolvedOperand> &lowerBound,
+                    Type &lowerBoundType,
+                    std::optional<OpAsmParser::UnresolvedOperand> &upperBound,
+                    Type &upperBoundType) {
+  // Multidimensional form. If `dims` was present, its result is final: a
+  // malformed dims list must not fall through to the bound-based forms.
+  OptionalParseResult dimsResult =
+      parseDimsModifierWithValues(parser, dimsAttr, values, types);
+  if (dimsResult.has_value())
+    return *dimsResult;
+
+  // Upper-bound-only form: `to upper : type`.
+  if (succeeded(parser.parseOptionalKeyword("to"))) {
+    OpAsmParser::UnresolvedOperand upperOperand;
+    if (parser.parseOperand(upperOperand) || parser.parseColon() ||
+        parser.parseType(upperBoundType))
+      return failure();
+    upperBound = upperOperand;
+    return success();
+  }
+
+  // Two-bound form: `lower : type to upper : type`.
+  OpAsmParser::UnresolvedOperand lowerOperand;
+  if (parser.parseOperand(lowerOperand) || parser.parseColon() ||
+      parser.parseType(lowerBoundType))
+    return failure();
+
+  // parseOptionalKeyword emits nothing on failure, so a missing `to` yields
+  // exactly one diagnostic. parseKeyword would already have emitted its own
+  // "expected 'to'" error before this one.
+  if (failed(parser.parseOptionalKeyword("to")))
+    return parser.emitError(parser.getCurrentLocation())
+           << "expected 'to' keyword in num_teams clause";
+
+  OpAsmParser::UnresolvedOperand upperOperand;
+  if (parser.parseOperand(upperOperand) || parser.parseColon() ||
+      parser.parseType(upperBoundType))
+    return failure();
+
+  lowerBound = lowerOperand;
+  upperBound = upperOperand;
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// Helper: Print dims modifier with values
+//===----------------------------------------------------------------------===//
+// Emits the multidimensional num_teams form `dims(N): v_0, ..., v_{N-1} : type`.
+// The `dims(N): ` prefix is written only when `dimsAttr` is set. Only the first
+// entry of `types` is printed, because the values share a single type.
+static void printDimsModifierWithValues(OpAsmPrinter &p, IntegerAttr dimsAttr,
+                                        OperandRange values, TypeRange types) {
+  if (dimsAttr)
+    p << "dims(" << dimsAttr.getInt() << "): ";
+  p.printOperands(values);
+  p << " : ";
+  if (types.empty())
+    return;
+  p << types.front();
+}
+
+// Prints the contents of `num_teams(...)`, mirroring parseNumTeamsClause:
+//   dims(N): v_0, ..., v_{N-1} : type   when num_teams values are present
+//   [lower : type] to upper : type      when an upper bound is present
+// Nothing is printed when neither values nor an upper bound exist.
+static void printNumTeamsClause(OpAsmPrinter &p, Operation *op,
+                                IntegerAttr dimsAttr, OperandRange values,
+                                TypeRange types, Value lowerBound,
+                                Type lowerBoundType, Value upperBound,
+                                Type upperBoundType) {
+  if (!values.empty()) {
+    printDimsModifierWithValues(p, dimsAttr, values, types);
+    return;
+  }
+  if (!upperBound)
+    return;
+  // The optional lower bound precedes the shared ` to upper : type` tail.
+  if (lowerBound) {
+    p.printOperand(lowerBound);
+    p << " : " << lowerBoundType;
+  }
+  p << " to ";
+  p.printOperand(upperBound);
+  p << " : " << upperBoundType;
+}
+
#define GET_ATTRDEF_CLASSES
#include "mlir/Dialect/OpenMP/OpenMPOpsAttributes.cpp.inc"
diff --git a/mlir/test/Dialect/OpenMP/invalid.mlir b/mlir/test/Dialect/OpenMP/invalid.mlir
index af24d969064ab..836cac9a53707 100644
--- a/mlir/test/Dialect/OpenMP/invalid.mlir
+++ b/mlir/test/Dialect/OpenMP/invalid.mlir
@@ -1438,7 +1438,7 @@ func.func @omp_teams_allocate(%data_var : memref<i32>) {
// expected-error @below {{expected equal sizes for allocate and allocator variables}}
"omp.teams" (%data_var) ({
omp.terminator
- }) {operandSegmentSizes = array<i32: 1,0,0,0,0,0,0,0>} : (memref<i32>) -> ()
+ }) {operandSegmentSizes = array<i32: 1,0,0,0,0,0,0,0,0>} : (memref<i32>) -> ()
omp.terminator
}
return
@@ -1451,7 +1451,82 @@ func.func @omp_teams_num_teams1(%lb : i32) {
// expected-error @below {{expected num_teams upper bound to be defined if the lower bound is defined}}
"omp.teams" (%lb) ({
omp.terminator
- }) {operandSegmentSizes = array<i32: 0,0,0,1,0,0,0,0>} : (i32) -> ()
+ }) {operandSegmentSizes = array<i32: 0,0,0,0,1,0,0,0,0>} : (i32) -> ()
+ omp.terminator
+ }
+ return
+}
+
+// -----
+
+// The dims count (3) must equal the number of num_teams values supplied (2).
+func.func @omp_teams_num_teams_dims_mismatch() {
+ omp.target {
+ %v0 = arith.constant 1 : i32
+ %v1 = arith.constant 2 : i32
+ // expected-error @below {{num_teams dims(3) specified but 2 values provided}}
+ "omp.teams" (%v0, %v1) ({
+ omp.terminator
+ }) {num_teams_dims = 3 : i64, operandSegmentSizes = array<i32: 0,0,0,2,0,0,0,0,0>} : (i32, i32) -> ()
+ omp.terminator
+ }
+ return
+}
+
+// -----
+
+// The dims modifier cannot be combined with num_teams lower/upper bound operands.
+func.func @omp_teams_num_teams_dims_with_bounds() {
+ omp.target {
+ %v0 = arith.constant 1 : i32
+ %v1 = arith.constant 2 : i32
+ %lb = arith.constant 3 : i32
+ %ub = arith.constant 4 : i32
+ // expected-error @below {{num_teams with dims modifier cannot be used together with lower/upper bounds (unidimensional style)}}
+ "omp.teams" (%v0, %v1, %lb, %ub) ({
+ omp.terminator
+ }) {num_teams_dims = 2 : i64, operandSegmentSizes = array<i32: 0,0,0,2,1,1,0,0,0>} : (i32, i32, i32, i32) -> ()
+ omp.terminator
+ }
+ return
+}
+
+// -----
+
+// num_teams values are rejected when no num_teams_dims attribute is present.
+func.func @omp_teams_num_teams_values_without_dims() {
+ omp.target {
+ %v0 = arith.constant 1 : i32
+ %v1 = arith.constant 2 : i32
+ // expected-error @below {{num_teams values can only be specified with dims modifier}}
+ "omp.teams" (%v0, %v1) ({
+ omp.terminator
+ }) {operandSegmentSizes = array<i32: 0,0,0,2,0,0,0,0,0>} : (i32, i32) -> ()
+ omp.terminator
+ }
+ return
+}
+
+// -----
+
+// A num_teams_dims attribute with no accompanying values is rejected.
+func.func @omp_teams_num_teams_dims_no_values() {
+ omp.target {
+ // expected-error @below {{num_teams dims modifier requires values to be specified}}
+ "omp.teams" () ({
+ omp.terminator
+ }) {num_teams_dims = 2 : i64, operandSegmentSizes = array<i32: 0,0,0,0,0,0,0,0,0>} : () -> ()
+ omp.terminator
+ }
+ return
+}
+
+// -----
+
+func.func @omp_teams_num_teams_dims_type_mismatch() {
+ omp.target {
+ %v0 = arith.constant 1 : i32
+ %v1 = arith.constant 2 : i64
+ // expected-error @below {{num_teams dims modifier requires all values to have the same type}}
+ "omp.teams" (%v0, %v1) ({
+ omp.terminator
+ }) {num_teams_dims = 2 : i64, operandSegmentSizes = array<i32: 0,0,0,2,0,0,0,0,0>} : (i32, i64) -> ()
omp.terminator
}
return
diff --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir
index ac29e20907b55..3633a4be1eb62 100644
--- a/mlir/test/Dialect/OpenMP/ops.mlir
+++ b/mlir/test/Dialect/OpenMP/ops.mlir
@@ -1108,6 +1108,12 @@ func.func @omp_teams(%lb : i32, %ub : i32, %if_cond : i1, %num_threads : i32,
omp.terminator
}
+ // CHECK: omp.teams num_teams(dims(3): %{{.*}}, %{{.*}}, %{{.*}} : i32)
+ omp.teams num_teams(dims(3): %lb, %ub, %ub : i32) {
+ // CHECK: omp.terminator
+ omp.terminator
+ }
+
// Test if.
// CHECK: omp.teams if(%{{.+}})
omp.teams if(%if_cond) {
``````````
</details>
https://github.com/llvm/llvm-project/pull/169883
More information about the Mlir-commits mailing list