[Mlir-commits] [mlir] [MLIR][OpenMP]Add prescriptiveness-modifier support to grainsize and num_tasks clause (PR #128477)
Kaviya Rajendiran
llvmlistbot at llvm.org
Thu Feb 27 00:05:22 PST 2025
https://github.com/kaviya2510 updated https://github.com/llvm/llvm-project/pull/128477
>From 9bb39944890c3bc816945f23e6e61248c5854f02 Mon Sep 17 00:00:00 2001
From: Kaviya Rajendiran <kaviyara2000 at gmail.com>
Date: Sun, 23 Feb 2025 22:01:09 +0530
Subject: [PATCH 1/2] [MLIR][OpenMP]Add prescriptiveness-modifier support to
grainsize and num_tasks clause.
---
.../mlir/Dialect/OpenMP/OpenMPClauses.td | 14 +-
mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp | 122 ++++++++++++++++--
mlir/test/Dialect/OpenMP/invalid.mlir | 24 ++++
mlir/test/Dialect/OpenMP/ops.mlir | 16 +++
4 files changed, 159 insertions(+), 17 deletions(-)
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
index a8d97a36df79e..32c28f72ec8e5 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
@@ -436,12 +436,11 @@ class OpenMP_GrainsizeClauseSkip<
bit description = false, bit extraClassDeclaration = false
> : OpenMP_Clause<traits, arguments, assemblyFormat, description,
extraClassDeclaration> {
- let arguments = (ins
- Optional<IntLikeType>:$grainsize
- );
+ let arguments = (ins OptionalAttr<GrainsizeTypeAttr>:$grainsize_mod,
+ Optional<IntLikeType>:$grainsize);
let optAssemblyFormat = [{
- `grainsize` `(` $grainsize `:` type($grainsize) `)`
+ `grainsize` `(` custom<GrainsizeClause>($grainsize_mod , $grainsize, type($grainsize)) `)`
}];
let description = [{
@@ -895,12 +894,11 @@ class OpenMP_NumTasksClauseSkip<
bit description = false, bit extraClassDeclaration = false
> : OpenMP_Clause<traits, arguments, assemblyFormat, description,
extraClassDeclaration> {
- let arguments = (ins
- Optional<IntLikeType>:$num_tasks
- );
+ let arguments = (ins OptionalAttr<NumTasksTypeAttr>:$num_tasks_mod,
+ Optional<IntLikeType>:$num_tasks);
let optAssemblyFormat = [{
- `num_tasks` `(` $num_tasks `:` type($num_tasks) `)`
+ `num_tasks` `(` custom<NumTasksClause>($num_tasks_mod , $num_tasks, type($num_tasks)) `)`
}];
let description = [{
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index d725a457aeff6..f8b948ff98864 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -472,6 +472,108 @@ static void printOrderClause(OpAsmPrinter &p, Operation *op,
p << stringifyClauseOrderKind(order.getValue());
}
+//===----------------------------------------------------------------------===//
+// Parser and printer for grainsize Clause
+//===----------------------------------------------------------------------===//
+
+// grainsize ::= `grainsize` `(` [strict ':'] grain-size `)`
+static ParseResult
+parseGrainsizeClause(OpAsmParser &parser, ClauseGrainsizeTypeAttr &grainsizeMod,
+ std::optional<OpAsmParser::UnresolvedOperand> &grainsize,
+ Type &grainsizeType) {
+ SMLoc loc = parser.getCurrentLocation();
+ StringRef enumStr;
+
+ if (succeeded(parser.parseOptionalKeyword(&enumStr))) {
+ if (std::optional<ClauseGrainsizeType> enumValue =
+ symbolizeClauseGrainsizeType(enumStr)) {
+ grainsizeMod =
+ ClauseGrainsizeTypeAttr::get(parser.getContext(), *enumValue);
+ if (parser.parseColon())
+ return failure();
+ } else {
+ return parser.emitError(loc, "invalid grainsize modifier : '")
+ << enumStr << "'";
+ }
+ }
+
+ OpAsmParser::UnresolvedOperand operand;
+ if (succeeded(parser.parseOperand(operand))) {
+ grainsize = operand;
+ } else {
+ return parser.emitError(parser.getCurrentLocation())
+ << "expected grainsize operand";
+ }
+
+ if (grainsize.has_value()) {
+ if (parser.parseColonType(grainsizeType))
+ return failure();
+ }
+
+ return success();
+}
+
+static void printGrainsizeClause(OpAsmPrinter &p, Operation *op,
+ ClauseGrainsizeTypeAttr grainsizeMod,
+ Value grainsize, mlir::Type grainsizeType) {
+ if (grainsizeMod)
+ p << stringifyClauseGrainsizeType(grainsizeMod.getValue()) << ": ";
+
+ if (grainsize)
+ p << grainsize << ": " << grainsizeType;
+}
+
+//===----------------------------------------------------------------------===//
+// Parser and printer for num_tasks Clause
+//===----------------------------------------------------------------------===//
+
+// numtask ::= `num_tasks` `(` [strict ':'] num-tasks `)`
+static ParseResult
+parseNumTasksClause(OpAsmParser &parser, ClauseNumTasksTypeAttr &numTasksMod,
+ std::optional<OpAsmParser::UnresolvedOperand> &numTasks,
+ Type &numTasksType) {
+ SMLoc loc = parser.getCurrentLocation();
+ StringRef enumStr;
+
+ if (succeeded(parser.parseOptionalKeyword(&enumStr))) {
+ if (std::optional<ClauseNumTasksType> enumValue =
+ symbolizeClauseNumTasksType(enumStr)) {
+ numTasksMod =
+ ClauseNumTasksTypeAttr::get(parser.getContext(), *enumValue);
+ if (parser.parseColon())
+ return failure();
+ } else {
+ return parser.emitError(loc, "invalid numTasks modifier : '")
+ << enumStr << "'";
+ }
+ }
+
+ OpAsmParser::UnresolvedOperand operand;
+ if (succeeded(parser.parseOperand(operand))) {
+ numTasks = operand;
+ } else {
+ return parser.emitError(parser.getCurrentLocation())
+ << "expected num_tasks operand";
+ }
+
+ if (numTasks.has_value()) {
+ if (parser.parseColonType(numTasksType))
+ return failure();
+ }
+
+ return success();
+}
+
+static void printNumTasksClause(OpAsmPrinter &p, Operation *op,
+ ClauseNumTasksTypeAttr numTasksMod,
+ Value numTasks, mlir::Type numTasksType) {
+ if (numTasksMod)
+ p << stringifyClauseNumTasksType(numTasksMod.getValue()) << ": ";
+
+ if (numTasks)
+ p << numTasks << ": " << numTasksType;
+}
+
//===----------------------------------------------------------------------===//
// Parsers for operations including clauses that define entry block arguments.
//===----------------------------------------------------------------------===//
@@ -2593,15 +2695,17 @@ void TaskloopOp::build(OpBuilder &builder, OperationState &state,
const TaskloopOperands &clauses) {
MLIRContext *ctx = builder.getContext();
// TODO Store clauses in op: privateVars, privateSyms.
- TaskloopOp::build(
- builder, state, clauses.allocateVars, clauses.allocatorVars,
- clauses.final, clauses.grainsize, clauses.ifExpr, clauses.inReductionVars,
- makeDenseBoolArrayAttr(ctx, clauses.inReductionByref),
- makeArrayAttr(ctx, clauses.inReductionSyms), clauses.mergeable,
- clauses.nogroup, clauses.numTasks, clauses.priority, /*private_vars=*/{},
- /*private_syms=*/nullptr, clauses.reductionMod, clauses.reductionVars,
- makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
- makeArrayAttr(ctx, clauses.reductionSyms), clauses.untied);
+ TaskloopOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars,
+ clauses.final, clauses.grainsizeMod, clauses.grainsize,
+ clauses.ifExpr, clauses.inReductionVars,
+ makeDenseBoolArrayAttr(ctx, clauses.inReductionByref),
+ makeArrayAttr(ctx, clauses.inReductionSyms),
+ clauses.mergeable, clauses.nogroup, clauses.numTasksMod,
+ clauses.numTasks, clauses.priority, /*private_vars=*/{},
+ /*private_syms=*/nullptr, clauses.reductionMod,
+ clauses.reductionVars,
+ makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
+ makeArrayAttr(ctx, clauses.reductionSyms), clauses.untied);
}
SmallVector<Value> TaskloopOp::getAllReductionVars() {
diff --git a/mlir/test/Dialect/OpenMP/invalid.mlir b/mlir/test/Dialect/OpenMP/invalid.mlir
index d7f468bed3d3d..63ccd7957b492 100644
--- a/mlir/test/Dialect/OpenMP/invalid.mlir
+++ b/mlir/test/Dialect/OpenMP/invalid.mlir
@@ -2064,6 +2064,30 @@ func.func @taskloop(%lb: i32, %ub: i32, %step: i32) {
// -----
+func.func @taskloop(%lb: i32, %ub: i32, %step: i32) {
+ %testi64 = "test.i64"() : () -> (i64)
+ // expected-error @below {{invalid grainsize modifier : 'strict1'}}
+ omp.taskloop grainsize(strict1: %testi64: i64) {
+ omp.loop_nest (%i, %j) : i32 = (%lb, %ub) to (%ub, %lb) step (%step, %step) {
+ omp.yield
+ }
+ }
+ return
+}
+// -----
+
+func.func @taskloop(%lb: i32, %ub: i32, %step: i32) {
+ %testi64 = "test.i64"() : () -> (i64)
+ // expected-error @below {{invalid numTasks modifier : 'default'}}
+ omp.taskloop num_tasks(default: %testi64: i64) {
+ omp.loop_nest (%i, %j) : i32 = (%lb, %ub) to (%ub, %lb) step (%step, %step) {
+ omp.yield
+ }
+ }
+ return
+}
+// -----
+
func.func @taskloop(%lb: i32, %ub: i32, %step: i32) {
// expected-error @below {{op nested in loop wrapper is not another loop wrapper or `omp.loop_nest`}}
omp.taskloop {
diff --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir
index e318afbebbf0c..5d44dc1da503d 100644
--- a/mlir/test/Dialect/OpenMP/ops.mlir
+++ b/mlir/test/Dialect/OpenMP/ops.mlir
@@ -2417,6 +2417,22 @@ func.func @omp_taskloop(%lb: i32, %ub: i32, %step: i32) -> () {
}
}
+ // CHECK: omp.taskloop grainsize(strict: %{{[^:]+}}: i64) {
+ omp.taskloop grainsize(strict: %testi64: i64) {
+ omp.loop_nest (%i, %j) : i32 = (%lb, %ub) to (%ub, %lb) step (%step, %step) {
+ // CHECK: omp.yield
+ omp.yield
+ }
+ }
+
+ // CHECK: omp.taskloop num_tasks(strict: %{{[^:]+}}: i64) {
+ omp.taskloop num_tasks(strict: %testi64: i64) {
+ omp.loop_nest (%i, %j) : i32 = (%lb, %ub) to (%ub, %lb) step (%step, %step) {
+ // CHECK: omp.yield
+ omp.yield
+ }
+ }
+
// CHECK: omp.taskloop nogroup {
omp.taskloop nogroup {
omp.loop_nest (%i, %j) : i32 = (%lb, %ub) to (%ub, %lb) step (%step, %step) {
>From 4cb1d01964d5019eadf63fe29a6637bea25b5c29 Mon Sep 17 00:00:00 2001
From: Kaviya Rajendiran <kaviyara2000 at gmail.com>
Date: Wed, 26 Feb 2025 12:45:38 +0530
Subject: [PATCH 2/2] [MLIR][OpenMP]Refactored parser and printer function of
grainsize and numtasks clause.
---
mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp | 115 +++++++++----------
mlir/test/Dialect/OpenMP/invalid.mlir | 2 +-
2 files changed, 54 insertions(+), 63 deletions(-)
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index f8b948ff98864..bd82fe1f8ef39 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -472,55 +472,76 @@ static void printOrderClause(OpAsmPrinter &p, Operation *op,
p << stringifyClauseOrderKind(order.getValue());
}
-//===----------------------------------------------------------------------===//
-// Parser and printer for grainsize Clause
-//===----------------------------------------------------------------------===//
-
-// grainsize ::= `grainsize` `(` [strict ':'] grain-size `)`
+template <typename ClauseTypeAttr, typename ClauseType>
static ParseResult
-parseGrainsizeClause(OpAsmParser &parser, ClauseGrainsizeTypeAttr &grainsizeMod,
- std::optional<OpAsmParser::UnresolvedOperand> &grainsize,
- Type &grainsizeType) {
- SMLoc loc = parser.getCurrentLocation();
+parseGranularityClause(OpAsmParser &parser, ClauseTypeAttr &prescriptiveness,
+ std::optional<OpAsmParser::UnresolvedOperand> &operand,
+ Type &operandType,
+ std::optional<ClauseType> (*symbolizeClause)(StringRef),
+ StringRef clauseName) {
StringRef enumStr;
-
if (succeeded(parser.parseOptionalKeyword(&enumStr))) {
- if (std::optional<ClauseGrainsizeType> enumValue =
- symbolizeClauseGrainsizeType(enumStr)) {
- grainsizeMod =
- ClauseGrainsizeTypeAttr::get(parser.getContext(), *enumValue);
+ if (std::optional<ClauseType> enumValue = symbolizeClause(enumStr)) {
+ prescriptiveness = ClauseTypeAttr::get(parser.getContext(), *enumValue);
if (parser.parseColon())
return failure();
} else {
- return parser.emitError(loc, "invalid grainsize modifier : '")
- << enumStr << "'";
+ return parser.emitError(parser.getCurrentLocation())
+ << "invalid " << clauseName << " modifier : '" << enumStr << "'";
+ ;
}
}
- OpAsmParser::UnresolvedOperand operand;
- if (succeeded(parser.parseOperand(operand))) {
- grainsize = operand;
+ OpAsmParser::UnresolvedOperand var;
+ if (succeeded(parser.parseOperand(var))) {
+ operand = var;
} else {
return parser.emitError(parser.getCurrentLocation())
- << "expected grainsize operand";
+ << "expected " << clauseName << " operand";
}
- if (grainsize.has_value()) {
- if (parser.parseColonType(grainsizeType))
+ if (operand.has_value()) {
+ if (parser.parseColonType(operandType))
return failure();
}
return success();
}
+template <typename ClauseTypeAttr, typename ClauseType>
+static void
+printGranularityClause(OpAsmPrinter &p, Operation *op,
+ ClauseTypeAttr prescriptiveness, Value operand,
+ mlir::Type operandType,
+ StringRef (*stringifyClauseType)(ClauseType)) {
+
+ if (prescriptiveness)
+ p << stringifyClauseType(prescriptiveness.getValue()) << ": ";
+
+ if (operand)
+ p << operand << ": " << operandType;
+}
+
+//===----------------------------------------------------------------------===//
+// Parser and printer for grainsize Clause
+//===----------------------------------------------------------------------===//
+
+// grainsize ::= `grainsize` `(` [strict ':'] grain-size `)`
+static ParseResult
+parseGrainsizeClause(OpAsmParser &parser, ClauseGrainsizeTypeAttr &grainsizeMod,
+ std::optional<OpAsmParser::UnresolvedOperand> &grainsize,
+ Type &grainsizeType) {
+ return parseGranularityClause<ClauseGrainsizeTypeAttr, ClauseGrainsizeType>(
+ parser, grainsizeMod, grainsize, grainsizeType,
+ &symbolizeClauseGrainsizeType, "grainsize");
+}
+
static void printGrainsizeClause(OpAsmPrinter &p, Operation *op,
ClauseGrainsizeTypeAttr grainsizeMod,
Value grainsize, mlir::Type grainsizeType) {
- if (grainsizeMod)
- p << stringifyClauseGrainsizeType(grainsizeMod.getValue()) << ": ";
-
- if (grainsize)
- p << grainsize << ": " << grainsizeType;
+ printGranularityClause<ClauseGrainsizeTypeAttr, ClauseGrainsizeType>(
+ p, op, grainsizeMod, grainsize, grainsizeType,
+ &stringifyClauseGrainsizeType);
}
//===----------------------------------------------------------------------===//
@@ -532,46 +553,16 @@ static ParseResult
parseNumTasksClause(OpAsmParser &parser, ClauseNumTasksTypeAttr &numTasksMod,
std::optional<OpAsmParser::UnresolvedOperand> &numTasks,
Type &numTasksType) {
- SMLoc loc = parser.getCurrentLocation();
- StringRef enumStr;
-
- if (succeeded(parser.parseOptionalKeyword(&enumStr))) {
- if (std::optional<ClauseNumTasksType> enumValue =
- symbolizeClauseNumTasksType(enumStr)) {
- numTasksMod =
- ClauseNumTasksTypeAttr::get(parser.getContext(), *enumValue);
- if (parser.parseColon())
- return failure();
- } else {
- return parser.emitError(loc, "invalid numTasks modifier : '")
- << enumStr << "'";
- }
- }
-
- OpAsmParser::UnresolvedOperand operand;
- if (succeeded(parser.parseOperand(operand))) {
- numTasks = operand;
- } else {
- return parser.emitError(parser.getCurrentLocation())
- << "expected num_tasks operand";
- }
-
- if (numTasks.has_value()) {
- if (parser.parseColonType(numTasksType))
- return failure();
- }
-
- return success();
+ return parseGranularityClause<ClauseNumTasksTypeAttr, ClauseNumTasksType>(
+ parser, numTasksMod, numTasks, numTasksType, &symbolizeClauseNumTasksType,
+ "num_tasks");
}
static void printNumTasksClause(OpAsmPrinter &p, Operation *op,
ClauseNumTasksTypeAttr numTasksMod,
Value numTasks, mlir::Type numTasksType) {
- if (numTasksMod)
- p << stringifyClauseNumTasksType(numTasksMod.getValue()) << ": ";
-
- if (numTasks)
- p << numTasks << ": " << numTasksType;
+ printGranularityClause<ClauseNumTasksTypeAttr, ClauseNumTasksType>(
+ p, op, numTasksMod, numTasks, numTasksType, &stringifyClauseNumTasksType);
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/OpenMP/invalid.mlir b/mlir/test/Dialect/OpenMP/invalid.mlir
index 63ccd7957b492..f57ade0262f49 100644
--- a/mlir/test/Dialect/OpenMP/invalid.mlir
+++ b/mlir/test/Dialect/OpenMP/invalid.mlir
@@ -2078,7 +2078,7 @@ func.func @taskloop(%lb: i32, %ub: i32, %step: i32) {
func.func @taskloop(%lb: i32, %ub: i32, %step: i32) {
%testi64 = "test.i64"() : () -> (i64)
- // expected-error @below {{invalid numTasks modifier : 'default'}}
+ // expected-error @below {{invalid num_tasks modifier : 'default'}}
omp.taskloop num_tasks(default: %testi64: i64) {
omp.loop_nest (%i, %j) : i32 = (%lb, %ub) to (%ub, %lb) step (%step, %step) {
omp.yield
More information about the Mlir-commits mailing list