[Mlir-commits] [flang] [mlir] [OpenMP][mlir] Add DynGroupPrivateClause in omp dialect (PR #153562)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Apr 14 05:46:32 PDT 2026
https://github.com/skc7 updated https://github.com/llvm/llvm-project/pull/153562
>From c8fc5c1fcd3d5bcb061c54af9690f4b74ec5f4ee Mon Sep 17 00:00:00 2001
From: skc7 <Krishna.Sankisa at amd.com>
Date: Thu, 14 Aug 2025 09:46:31 +0530
Subject: [PATCH 1/3] [OpenMP][mlir] Add DynGroupPrivateClause in omp dialect
---
.../Optimizer/OpenMP/FunctionFiltering.cpp | 2 +
.../Optimizer/OpenMP/LowerWorkdistribute.cpp | 13 +-
.../mlir/Dialect/OpenMP/OpenMPClauses.td | 43 ++++++
.../mlir/Dialect/OpenMP/OpenMPEnums.td | 34 +++++
mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td | 5 +-
mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp | 142 +++++++++++++++++-
.../OpenMP/OpenMPToLLVMIRTranslation.cpp | 36 ++++-
mlir/test/Dialect/OpenMP/invalid.mlir | 58 ++++++-
mlir/test/Dialect/OpenMP/ops.mlir | 38 ++++-
mlir/test/Target/LLVMIR/openmp-todo.mlir | 11 ++
10 files changed, 368 insertions(+), 14 deletions(-)
diff --git a/flang/lib/Optimizer/OpenMP/FunctionFiltering.cpp b/flang/lib/Optimizer/OpenMP/FunctionFiltering.cpp
index 475ed35cac9fa..83fb61ee76ce1 100644
--- a/flang/lib/Optimizer/OpenMP/FunctionFiltering.cpp
+++ b/flang/lib/Optimizer/OpenMP/FunctionFiltering.cpp
@@ -333,6 +333,8 @@ class FunctionFilteringPass
collectRewrite(privateVar, rewriteValues);
for (Value threadLimit : targetOp.getThreadLimitVars())
collectRewrite(threadLimit, rewriteValues);
+ if (Value dynGpSize = targetOp.getDynGroupprivateSize())
+ collectRewrite(dynGpSize, rewriteValues);
}
// Move omp.map.info ops to the new block and collect dependencies.
diff --git a/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp b/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp
index 8a08f67006c0a..ed929880656f1 100644
--- a/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp
+++ b/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp
@@ -767,7 +767,9 @@ FailureOr<omp::TargetOp> splitTargetData(omp::TargetOp targetOp,
targetOp.getInReductionSymsAttr(), targetOp.getIsDevicePtrVars(),
innerMapInfos, targetOp.getNowaitAttr(), targetOp.getPrivateVars(),
targetOp.getPrivateSymsAttr(), targetOp.getPrivateNeedsBarrierAttr(),
- targetOp.getThreadLimitVars(), targetOp.getPrivateMapsAttr());
+ targetOp.getThreadLimitVars(), targetOp.getAccessGroupAttr(),
+ targetOp.getFallbackAttr(), targetOp.getDynGroupprivateSize(),
+ targetOp.getPrivateMapsAttr());
rewriter.inlineRegionBefore(targetOp.getRegion(), newTargetOp.getRegion(),
newTargetOp.getRegion().begin());
rewriter.replaceOp(targetOp, targetDataOp);
@@ -1488,7 +1490,8 @@ genPreTargetOp(omp::TargetOp targetOp, SmallVector<Value> &preMapOperands,
targetOp.getIsDevicePtrVars(), preMapOperands, targetOp.getNowaitAttr(),
targetOp.getPrivateVars(), targetOp.getPrivateSymsAttr(),
targetOp.getPrivateNeedsBarrierAttr(), targetOp.getThreadLimitVars(),
- targetOp.getPrivateMapsAttr());
+ targetOp.getAccessGroupAttr(), targetOp.getFallbackAttr(),
+ targetOp.getDynGroupprivateSize(), targetOp.getPrivateMapsAttr());
auto *preTargetBlock = rewriter.createBlock(
&preTargetOp.getRegion(), preTargetOp.getRegion().begin(), {}, {});
IRMapping preMapping;
@@ -1579,7 +1582,8 @@ genIsolatedTargetOp(omp::TargetOp targetOp, SmallVector<Value> &postMapOperands,
targetOp.getIsDevicePtrVars(), postMapOperands, targetOp.getNowaitAttr(),
targetOp.getPrivateVars(), targetOp.getPrivateSymsAttr(),
targetOp.getPrivateNeedsBarrierAttr(), targetOp.getThreadLimitVars(),
- targetOp.getPrivateMapsAttr());
+ targetOp.getAccessGroupAttr(), targetOp.getFallbackAttr(),
+ targetOp.getDynGroupprivateSize(), targetOp.getPrivateMapsAttr());
auto *isolatedTargetBlock =
rewriter.createBlock(&isolatedTargetOp.getRegion(),
isolatedTargetOp.getRegion().begin(), {}, {});
@@ -1660,7 +1664,8 @@ static omp::TargetOp genPostTargetOp(omp::TargetOp targetOp,
targetOp.getIsDevicePtrVars(), postMapOperands, targetOp.getNowaitAttr(),
targetOp.getPrivateVars(), targetOp.getPrivateSymsAttr(),
targetOp.getPrivateNeedsBarrierAttr(), targetOp.getThreadLimitVars(),
- targetOp.getPrivateMapsAttr());
+ targetOp.getAccessGroupAttr(), targetOp.getFallbackAttr(),
+ targetOp.getDynGroupprivateSize(), targetOp.getPrivateMapsAttr());
// Create the block for postTargetOp
auto *postTargetBlock = rewriter.createBlock(
&postTargetOp.getRegion(), postTargetOp.getRegion().begin(), {}, {});
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
index f24efd0d4fc42..3fd5060f38b4f 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
@@ -1755,4 +1755,47 @@ class OpenMP_UniformClauseSkip<
def OpenMP_UniformClause : OpenMP_UniformClauseSkip<>;
+//===----------------------------------------------------------------------===//
+// V6.1 `dyn_groupprivate` clause
+//===----------------------------------------------------------------------===//
+
+// Clause class for `dyn_groupprivate`. It contributes two optional modifier
+// attributes ($access_group, $fallback) and an optional integer size operand
+// ($dyn_groupprivate_size), all printed/parsed through the
+// DynGroupprivateClause custom directive. The bit template parameters are
+// forwarded unchanged to OpenMP_Clause.
+// NOTE(review): the `Syntax:` snippet in the description omits the
+// `: type` suffix required by the type($dyn_groupprivate_size) argument of
+// the custom directive; only the `Example:` snippet shows the accepted form.
+class OpenMP_DynGroupprivateClauseSkip<
+ bit traits = false, bit arguments = false, bit assemblyFormat = false,
+ bit description = false, bit extraClassDeclaration = false
+ > : OpenMP_Clause<traits, arguments, assemblyFormat, description,
+ extraClassDeclaration> {
+
+ let arguments = (ins
+ OptionalAttr<AccessGroupModifierAttr>:$access_group,
+ OptionalAttr<FallbackModifierAttr>:$fallback,
+ Optional<AnyInteger>:$dyn_groupprivate_size
+ );
+
+ let description = [{
+ The `dyn_groupprivate` clause allows you to dynamically allocate group-private
+ memory in OpenMP parallel regions, specifically for `target` and `teams` directives.
+ This clause enables runtime-sized private memory allocation and applicable to
+ target and teams ops.
+
+ Syntax:
+ ```
+ dyn_groupprivate(cgroup, fallback(abort), %size)
+ ```
+
+ Example:
+ ```
+ omp.target dyn_groupprivate(cgroup, fallback(default_mem), %size : i32)
+ ```
+ }];
+
+ let optAssemblyFormat = [{
+ `dyn_groupprivate` `(`
+ custom<DynGroupprivateClause>($access_group, $fallback,
+ $dyn_groupprivate_size, type($dyn_groupprivate_size))
+ `)`
+ }];
+}
+
+def OpenMP_DynGroupprivateClause : OpenMP_DynGroupprivateClauseSkip<>;
+
#endif // OPENMP_CLAUSES
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPEnums.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPEnums.td
index 06c5e0b89ff05..d2a4998bc19ac 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPEnums.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPEnums.td
@@ -347,4 +347,38 @@ def VariableCaptureKindAttr : OpenMP_EnumAttr<VariableCaptureKind,
let assemblyFormat = "`(` $value `)`";
}
+//===----------------------------------------------------------------------===//
+// dyn_groupprivate enums.
+//===----------------------------------------------------------------------===//
+
+// Access-group modifier of the `dyn_groupprivate` clause. `cgroup` (value 0)
+// is the only case defined here; the attribute form prints as `(cgroup)`.
+def AccessGroupCGroup : I32EnumAttrCase<"cgroup", 0>;
+
+def AccessGroupModifier : OpenMP_I32EnumAttr<
+ "AccessGroupModifier",
+ "access group modifier", [
+ AccessGroupCGroup
+ ]>;
+
+def AccessGroupModifierAttr : OpenMP_EnumAttr<AccessGroupModifier,
+ "access_group_modifier"> {
+ let assemblyFormat = "`(` $value `)`";
+}
+
+// Fallback modifier of the `dyn_groupprivate` clause: abort (0), null (1) or
+// default_mem (2).
+// NOTE(review): the clause parser maps these spellings by hand, so adding a
+// case here also requires updating that parser.
+def FallbackAbort : I32EnumAttrCase<"abort", 0>;
+def FallbackNull : I32EnumAttrCase<"null", 1>;
+def FallbackDefaultMem : I32EnumAttrCase<"default_mem", 2>;
+
+def FallbackModifier : OpenMP_I32EnumAttr<
+ "FallbackModifier",
+ "fallback modifier", [
+ FallbackAbort,
+ FallbackNull,
+ FallbackDefaultMem
+ ]>;
+
+def FallbackModifierAttr : OpenMP_EnumAttr<FallbackModifier,
+ "fallback_modifier"> {
+ let assemblyFormat = "`(` $value `)`";
+}
+
#endif // OPENMP_ENUMS
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
index 76294cb86574f..d83e0a2d81441 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
@@ -241,7 +241,8 @@ def TeamsOp : OpenMP_Op<"teams", traits = [
AttrSizedOperandSegments, RecursiveMemoryEffects, OutlineableOpenMPOpInterface
], clauses = [
OpenMP_AllocateClause, OpenMP_IfClause, OpenMP_NumTeamsClause,
- OpenMP_PrivateClause, OpenMP_ReductionClause, OpenMP_ThreadLimitClause
+ OpenMP_PrivateClause, OpenMP_ReductionClause, OpenMP_ThreadLimitClause,
+ OpenMP_DynGroupprivateClause
], singleRegion = true> {
let summary = "teams construct";
let description = [{
@@ -1547,7 +1548,7 @@ def TargetOp : OpenMP_Op<"target", traits = [
OpenMP_DeviceClause, OpenMP_HasDeviceAddrClause, OpenMP_HostEvalClause,
OpenMP_IfClause, OpenMP_InReductionClause, OpenMP_IsDevicePtrClause,
OpenMP_MapClauseSkip<assemblyFormat = true>, OpenMP_NowaitClause,
- OpenMP_PrivateClause, OpenMP_ThreadLimitClause
+ OpenMP_PrivateClause, OpenMP_ThreadLimitClause, OpenMP_DynGroupprivateClause
], singleRegion = true> {
let summary = "target construct";
let description = [{
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index dd2846f2cdb42..f3b0a18ab3a3f 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -873,6 +873,130 @@ static void printNumTasksClause(OpAsmPrinter &p, Operation *op,
p, op, numTasksMod, numTasks, numTasksType, &stringifyClauseNumTasksType);
}
+//===----------------------------------------------------------------------===//
+// Parser, printer and verify for dyn_groupprivate Clause
+//===----------------------------------------------------------------------===//
+
+/// Verifies the operands and attributes of a `dyn_groupprivate` clause.
+///
+/// \param op the operation carrying the clause; used for diagnostics.
+/// \param modifierFirst the optional access-group modifier (e.g. `cgroup`).
+/// \param modifierSecond the optional fallback modifier.
+/// \param dynGroupprivateSize the optional size operand.
+/// \returns failure if a modifier is present without a size operand, or if
+/// the size operand is not of integer/index type; success otherwise.
+static LogicalResult verifyDynGroupprivateClause(
+    Operation *op, AccessGroupModifierAttr modifierFirst,
+    FallbackModifierAttr modifierSecond, Value dynGroupprivateSize) {
+  if (!dynGroupprivateSize) {
+    // The custom clause parser requires a size operand, so modifiers without
+    // a size would produce IR that prints as `dyn_groupprivate(cgroup)` but
+    // cannot be parsed back. Reject that state up front.
+    if (modifierFirst || modifierSecond)
+      return op->emitOpError(
+          "dyn_groupprivate modifiers require a size operand");
+    return success();
+  }
+
+  // Check that the size operand has an integer (or index) type.
+  Type sizeType = dynGroupprivateSize.getType();
+  if (!sizeType.isIntOrIndex())
+    return op->emitOpError(
+               "dyn_groupprivate size must be an integer type, got ")
+           << sizeType;
+
+  return success();
+}
+
+/// Parses the body of a `dyn_groupprivate(...)` clause:
+///
+///   (modifier `,`)* ssa-operand `:` type
+///
+/// where a modifier is either `cgroup` or
+/// `fallback(abort | null | default_mem)`. Each modifier may appear at most
+/// once, in any order, and must be followed by a comma. The size operand and
+/// its type are mandatory.
+static ParseResult parseDynGroupprivateClause(
+    OpAsmParser &parser, AccessGroupModifierAttr &accessGroupAttr,
+    FallbackModifierAttr &fallbackAttr,
+    std::optional<OpAsmParser::UnresolvedOperand> &dynGroupprivateSize,
+    Type &sizeType) {
+  bool parsedAccessGroup = false;
+  bool parsedFallback = false;
+
+  // Parse the modifiers. Stop when no modifier keyword is found, or when a
+  // modifier is not followed by a comma; anything left over (stray commas,
+  // unseparated modifiers, an empty clause) is then diagnosed by the
+  // mandatory size-operand parse below.
+  while (true) {
+    llvm::SMLoc modifierLoc = parser.getCurrentLocation();
+    if (succeeded(parser.parseOptionalKeyword("cgroup"))) {
+      if (parsedAccessGroup)
+        return parser.emitError(modifierLoc, "duplicate access group modifier");
+      accessGroupAttr = AccessGroupModifierAttr::get(
+          parser.getContext(), AccessGroupModifier::cgroup);
+      parsedAccessGroup = true;
+    } else if (succeeded(parser.parseOptionalKeyword("fallback"))) {
+      if (parsedFallback)
+        return parser.emitError(modifierLoc, "duplicate fallback modifier");
+      // parseLParen/parseRParen emit their own diagnostic on failure.
+      if (parser.parseLParen())
+        return failure();
+      llvm::SMLoc kindLoc = parser.getCurrentLocation();
+      llvm::StringRef fbKind;
+      if (failed(parser.parseOptionalKeyword(&fbKind)))
+        return parser.emitError(
+            kindLoc, "expected fallback modifier (abort/null/default_mem)");
+      std::optional<FallbackModifier> fbEnum;
+      if (fbKind == "abort")
+        fbEnum = FallbackModifier::abort;
+      else if (fbKind == "null")
+        fbEnum = FallbackModifier::null;
+      else if (fbKind == "default_mem")
+        fbEnum = FallbackModifier::default_mem;
+      else
+        return parser.emitError(kindLoc,
+                                "invalid fallback modifier '" + fbKind + "'");
+      fallbackAttr = FallbackModifierAttr::get(parser.getContext(), *fbEnum);
+      if (parser.parseRParen())
+        return failure();
+      parsedFallback = true;
+    } else {
+      break;
+    }
+
+    // A modifier is always followed by another modifier or by the size
+    // operand, so it must be followed by a comma.
+    if (failed(parser.parseOptionalComma()))
+      break;
+  }
+
+  OpAsmParser::UnresolvedOperand operand;
+  if (failed(parser.parseOperand(operand)))
+    return parser.emitError(parser.getCurrentLocation(),
+                            "expected dyn_groupprivate_size operand");
+  dynGroupprivateSize = operand;
+  if (failed(parser.parseColon()) || failed(parser.parseType(sizeType)))
+    return parser.emitError(parser.getCurrentLocation(),
+                            "expected ':' and type after size operand");
+
+  return success();
+}
+
+/// Prints the body of a `dyn_groupprivate(...)` clause (without the keyword
+/// and parentheses, which come from the assembly format). Elements are
+/// emitted in a fixed order, separated by ", ": the access-group modifier,
+/// then `fallback(<kind>)`, then `<size> : <type>`. The order is independent
+/// of how the clause was originally written. `op` is unused.
+/// NOTE(review): if modifiers are set but the size is absent, the output
+/// (e.g. `cgroup`) is not accepted by parseDynGroupprivateClause, which
+/// requires the size operand.
+static void printDynGroupprivateClause(OpAsmPrinter &printer, Operation *op,
+ AccessGroupModifierAttr modifierFirst,
+ FallbackModifierAttr modifierSecond,
+ Value dynGroupprivateSize,
+ Type sizeType) {
+
+ // Tracks whether something was already printed, so that a ", " separator
+ // is emitted only between elements.
+ bool needsComma = false;
+
+ if (modifierFirst) {
+ printer << modifierFirst.getValue();
+ needsComma = true;
+ }
+
+ if (modifierSecond) {
+ if (needsComma)
+ printer << ", ";
+ printer << "fallback(";
+ printer << modifierSecond.getValue();
+ printer << ")";
+ needsComma = true;
+ }
+
+ if (dynGroupprivateSize) {
+ if (needsComma)
+ printer << ", ";
+ printer << dynGroupprivateSize << " : " << sizeType;
+ }
+}
+
//===----------------------------------------------------------------------===//
// Parsers for operations including clauses that define entry block arguments.
//===----------------------------------------------------------------------===//
@@ -2379,7 +2503,8 @@ void TargetOp::build(OpBuilder &builder, OperationState &state,
/*in_reduction_syms=*/nullptr, clauses.isDevicePtrVars, clauses.mapVars,
clauses.nowait, clauses.privateVars,
makeArrayAttr(ctx, clauses.privateSyms), clauses.privateNeedsBarrier,
- clauses.threadLimitVars,
+ clauses.threadLimitVars, clauses.accessGroup, clauses.fallback,
+ clauses.dynGroupprivateSize,
/*private_maps=*/nullptr);
}
@@ -2396,6 +2521,12 @@ LogicalResult TargetOp::verify() {
if (failed(verifyMapClause(*this, getMapVars())))
return failure();
+ // check dyn_groupprivate clause restrictions
+ if (failed(verifyDynGroupprivateClause(*this, getAccessGroupAttr(),
+ getFallbackAttr(),
+ getDynGroupprivateSize())))
+ return failure();
+
return verifyPrivateVarsMapping(*this);
}
@@ -2803,7 +2934,8 @@ void TeamsOp::build(OpBuilder &builder, OperationState &state,
/*private_needs_barrier=*/nullptr, clauses.reductionMod,
clauses.reductionVars,
makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
- makeArrayAttr(ctx, clauses.reductionSyms), clauses.threadLimitVars);
+ makeArrayAttr(ctx, clauses.reductionSyms), clauses.threadLimitVars,
+ clauses.accessGroup, clauses.fallback, clauses.dynGroupprivateSize);
}
// Verify num_teams clause
@@ -2846,6 +2978,12 @@ LogicalResult TeamsOp::verify() {
return emitError(
"expected equal sizes for allocate and allocator variables");
+ // check dyn_groupprivate clause restrictions
+ if (failed(verifyDynGroupprivateClause(op, getAccessGroupAttr(),
+ getFallbackAttr(),
+ getDynGroupprivateSize())))
+ return failure();
+
return verifyReductionVarList(*this, getReductionSyms(), getReductionVars(),
getReductionByref());
}
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index 2e15f4de4545d..f99b376e001e4 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -387,6 +387,11 @@ static LogicalResult checkImplementationStatus(Operation &op) {
result = todo("thread_limit with multi-dimensional values");
};
+ auto checkDynGroupprivate = [&todo](auto op, LogicalResult &result) {
+ if (op.getDynGroupprivateSize())
+ result = todo("dyn_groupprivate");
+ };
+
LogicalResult result = success();
llvm::TypeSwitch<Operation &>(op)
.Case([&](omp::DistributeOp op) {
@@ -407,6 +412,7 @@ static LogicalResult checkImplementationStatus(Operation &op) {
checkPrivate(op, result);
checkNumTeams(op, result);
checkThreadLimit(op, result);
+ checkDynGroupprivate(op, result);
})
.Case([&](omp::TaskOp op) {
checkAllocate(op, result);
@@ -6847,6 +6853,25 @@ initTargetRuntimeAttrs(llvm::IRBuilderBase &builder,
}
}
+/// Maps the `dyn_groupprivate` fallback modifier of `targetOp` to the
+/// corresponding OpenMPIRBuilder fallback type. When no fallback attribute is
+/// present, DefaultMem is returned.
+/// NOTE(review): confirm that DefaultMem is the intended default against the
+/// OpenMP specification and the Clang lowering of the same clause.
+static llvm::omp::OMPDynGroupprivateFallbackType
+getFallbackType(omp::TargetOp targetOp) {
+ if (!targetOp.getFallbackAttr())
+ return llvm::omp::OMPDynGroupprivateFallbackType::DefaultMem;
+
+ // Extract the FallbackModifier enum value.
+ mlir::omp::FallbackModifier fb = targetOp.getFallbackAttr().getValue();
+ switch (fb) {
+ case mlir::omp::FallbackModifier::abort:
+ return llvm::omp::OMPDynGroupprivateFallbackType::Abort;
+ case mlir::omp::FallbackModifier::null:
+ return llvm::omp::OMPDynGroupprivateFallbackType::Null;
+ case mlir::omp::FallbackModifier::default_mem:
+ return llvm::omp::OMPDynGroupprivateFallbackType::DefaultMem;
+ }
+
+ // Every enum case returns above; this only guards against corrupted values
+ // and keeps compilers from warning about a missing return.
+ llvm_unreachable("unexpected dyn_groupprivate fallback type");
+}
+
static LogicalResult
convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder,
LLVM::ModuleTranslation &moduleTranslation) {
@@ -7164,11 +7189,20 @@ convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder,
if (Value targetIfCond = targetOp.getIfExpr())
ifCond = moduleTranslation.lookupValue(targetIfCond);
+ mlir::Value dynGroupPrivateSize = targetOp.getDynGroupprivateSize();
+ llvm::Value *dynSizeVal = nullptr;
+ if (dynGroupPrivateSize)
+ dynSizeVal = moduleTranslation.lookupValue(dynGroupPrivateSize);
+
+ llvm::omp::OMPDynGroupprivateFallbackType fallbackType =
+ getFallbackType(targetOp);
+
llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
moduleTranslation.getOpenMPBuilder()->createTarget(
ompLoc, isOffloadEntry, allocaIP, builder.saveIP(), info, entryInfo,
defaultAttrs, runtimeAttrs, ifCond, kernelInput, genMapInfoCB, bodyCB,
- argAccessorCB, customMapperCB, dds, targetOp.getNowait());
+ argAccessorCB, customMapperCB, dds, targetOp.getNowait(), dynSizeVal,
+ fallbackType);
if (failed(handleError(afterIP, opInst)))
return failure();
diff --git a/mlir/test/Dialect/OpenMP/invalid.mlir b/mlir/test/Dialect/OpenMP/invalid.mlir
index c7f2416b4c293..4743a7fa6d9fe 100644
--- a/mlir/test/Dialect/OpenMP/invalid.mlir
+++ b/mlir/test/Dialect/OpenMP/invalid.mlir
@@ -1438,7 +1438,7 @@ func.func @omp_teams_allocate(%data_var : memref<i32>) {
// expected-error @below {{expected equal sizes for allocate and allocator variables}}
"omp.teams" (%data_var) ({
omp.terminator
- }) {operandSegmentSizes = array<i32: 1,0,0,0,0,0,0,0>} : (memref<i32>) -> ()
+ }) {operandSegmentSizes = array<i32: 1,0,0,0,0,0,0,0,0>} : (memref<i32>) -> ()
omp.terminator
}
return
@@ -1451,7 +1451,7 @@ func.func @omp_teams_num_teams1(%lb : i32) {
// expected-error @below {{expected exactly one num_teams upper bound when lower bound is specified}}
"omp.teams" (%lb) ({
omp.terminator
- }) {operandSegmentSizes = array<i32: 0,0,0,1,0,0,0,0>} : (i32) -> ()
+ }) {operandSegmentSizes = array<i32: 0,0,0,1,0,0,0,0,0>} : (i32) -> ()
omp.terminator
}
return
@@ -1468,7 +1468,7 @@ func.func @omp_teams_num_teams_multidim_with_bounds() {
// expected-error @below {{expected exactly one num_teams upper bound when lower bound is specified}}
"omp.teams" (%lb, %v0, %v1) ({
omp.terminator
- }) {operandSegmentSizes = array<i32: 0,0,0,1,2,0,0,0>} : (i32, i32, i32) -> ()
+ }) {operandSegmentSizes = array<i32: 0,0,0,1,2,0,0,0,0>} : (i32, i32, i32) -> ()
omp.terminator
}
return
@@ -1489,6 +1489,58 @@ func.func @omp_teams_num_teams2(%lb : i32, %ub : i16) {
// -----
+// A repeated `cgroup` access-group modifier must be rejected by the parser.
+func.func @test_teams_dyn_groupprivate_errors_1(%dyn_size: i32) {
+ // expected-error @below {{duplicate access group modifier}}
+ omp.teams dyn_groupprivate(cgroup, cgroup, %dyn_size : i32) {
+ omp.terminator
+ }
+ return
+}
+
+// -----
+
+// Two fallback modifiers in one clause must be rejected by the parser.
+func.func @test_teams_dyn_groupprivate_errors_2(%dyn_size: i32) {
+ // expected-error @below {{duplicate fallback modifier}}
+ omp.teams dyn_groupprivate(fallback(null), fallback(abort), %dyn_size : i32) {
+ omp.terminator
+ }
+ return
+}
+
+// -----
+
+// A fallback kind other than abort/null/default_mem must be rejected.
+func.func @test_teams_dyn_groupprivate_errors_3(%dyn_size: i32) {
+ // expected-error @below {{invalid fallback modifier 'no'}}
+ omp.teams dyn_groupprivate(fallback(no), %dyn_size : i32) {
+ omp.terminator
+ }
+ return
+}
+
+// -----
+
+// The size operand is mandatory even when a modifier is given. The parser
+// reports both the generic operand error and the clause-specific one.
+func.func @test_teams_dyn_groupprivate_errors_4(%dyn_size: i32) {
+ // expected-error @below {{custom op 'omp.teams' expected dyn_groupprivate_size operand}}
+ // expected-error @below {{expected SSA operand}}
+ omp.teams dyn_groupprivate(fallback(null)) {
+ omp.terminator
+ }
+ return
+}
+
+// -----
+
+// An empty clause must be rejected, because the size operand is mandatory.
+func.func @test_teams_dyn_groupprivate_errors_5() {
+ // expected-error @below {{expected dyn_groupprivate_size operand}}
+ // expected-error @below {{expected SSA operand}}
+ omp.teams dyn_groupprivate() {
+ omp.terminator
+ }
+ return
+}
+
+// -----
+
func.func @omp_sections(%data_var : memref<i32>) -> () {
// expected-error @below {{expected equal sizes for allocate and allocator variables}}
"omp.sections" (%data_var) ({
diff --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir
index 869f163cb4014..72290c541e3f9 100644
--- a/mlir/test/Dialect/OpenMP/ops.mlir
+++ b/mlir/test/Dialect/OpenMP/ops.mlir
@@ -871,7 +871,7 @@ func.func @omp_target(%if_cond : i1, %device : si32, %num_threads : i32, %devic
"omp.target"(%device, %if_cond, %num_threads) ({
// CHECK: omp.terminator
omp.terminator
- }) {nowait, operandSegmentSizes = array<i32: 0,0,0,0,1,0,0,1,0,0,0,0,1>} : ( si32, i1, i32 ) -> ()
+ }) {nowait, operandSegmentSizes = array<i32: 0,0,0,0,1,0,0,1,0,0,0,0,1,0>} : ( si32, i1, i32 ) -> ()
// Test with optional map clause.
// CHECK: %[[MAP_A:.*]] = omp.map.info var_ptr(%[[VAL_1:.*]] : memref<?xi32>, tensor<?xi32>) map_clauses(always, to) capture(ByRef) -> memref<?xi32> {name = ""}
@@ -1124,7 +1124,8 @@ func.func @parallel_wsloop_reduction(%lb : index, %ub : index, %step : index) {
// CHECK-LABEL: omp_teams
func.func @omp_teams(%lb : i32, %ub : i32, %if_cond : i1, %num_threads : i32,
- %data_var : memref<i32>, %ub64 : i64, %ub16 : i16) -> () {
+ %data_var : memref<i32>, %ub64 : i64, %ub16 : i16,
+ %dyn_size : i32) -> () {
// Test nesting inside of omp.target
omp.target {
// CHECK: omp.teams
@@ -1221,6 +1222,13 @@ func.func @omp_teams(%lb : i32, %ub : i32, %if_cond : i1, %num_threads : i32,
omp.terminator
}
+ // Test dyn_groupprivate
+ // CHECK: omp.teams dyn_groupprivate(cgroup, fallback(null), %{{.+}} : i32)
+ omp.teams dyn_groupprivate(cgroup, fallback(null), %dyn_size : i32) {
+ // CHECK: omp.terminator
+ omp.terminator
+ }
+
return
}
@@ -2314,6 +2322,32 @@ func.func @omp_target_depend(%arg0: memref<i32>, %arg1: memref<i32>) {
return
}
+// Round-trip tests for dyn_groupprivate on omp.target. They cover a size
+// alone (i32 and i64), cgroup alone, and cgroup combined with a fallback
+// modifier. Modifiers written as `fallback(null), cgroup` are expected to be
+// printed back in the canonical order `cgroup, fallback(null)`.
+// CHECK-LABEL: @omp_target_dyn_groupprivate
+func.func @omp_target_dyn_groupprivate(%dyn_size: i32, %large_size: i64) {
+ // CHECK: omp.target dyn_groupprivate(%{{.*}} : i32)
+ omp.target dyn_groupprivate(%dyn_size : i32) {
+ omp.terminator
+ }
+ // CHECK: omp.target dyn_groupprivate(cgroup, %{{.*}} : i64)
+ omp.target dyn_groupprivate(cgroup, %large_size : i64) {
+ omp.terminator
+ }
+ // CHECK: omp.target dyn_groupprivate(cgroup, fallback(abort), %{{.*}} : i32)
+ omp.target dyn_groupprivate(cgroup, fallback(abort), %dyn_size : i32) {
+ // CHECK: omp.terminator
+ omp.terminator
+ }
+ // CHECK: omp.target dyn_groupprivate(cgroup, fallback(null), %{{.*}} : i32)
+ omp.target dyn_groupprivate(fallback(null), cgroup, %dyn_size : i32) {
+ omp.terminator
+ }
+ // CHECK: omp.target dyn_groupprivate(%{{.*}} : i64)
+ omp.target dyn_groupprivate(%large_size : i64) {
+ omp.terminator
+ }
+ return
+}
+
func.func @omp_threadprivate() {
%0 = arith.constant 1 : i32
%1 = arith.constant 2 : i32
diff --git a/mlir/test/Target/LLVMIR/openmp-todo.mlir b/mlir/test/Target/LLVMIR/openmp-todo.mlir
index e0872226531e6..cfe013d98c415 100644
--- a/mlir/test/Target/LLVMIR/openmp-todo.mlir
+++ b/mlir/test/Target/LLVMIR/openmp-todo.mlir
@@ -479,6 +479,17 @@ llvm.func @teams_thread_limit_multi_dim(%lb : i32, %ub : i32) {
// -----
+// Translating omp.teams with a dyn_groupprivate clause to LLVM IR is not
+// implemented yet and must be reported as an unhandled clause.
+llvm.func @teams_dyn_groupprivate(%dyn_size : i32) {
+  // expected-error@below {{not yet implemented: Unhandled clause dyn_groupprivate in omp.teams operation}}
+  // expected-error@below {{LLVM Translation failed for operation: omp.teams}}
+  omp.teams dyn_groupprivate(%dyn_size : i32) {
+    omp.terminator
+  }
+  llvm.return
+}
+
+// -----
+
llvm.func @wsloop_allocate(%lb : i32, %ub : i32, %step : i32, %x : !llvm.ptr) {
// expected-error@below {{not yet implemented: Unhandled clause allocate in omp.wsloop operation}}
// expected-error@below {{LLVM Translation failed for operation: omp.wsloop}}
>From 087eb282b7eba31d464967eea57d65d014a0bcbf Mon Sep 17 00:00:00 2001
From: skc7 <Krishna.Sankisa at amd.com>
Date: Wed, 8 Apr 2026 16:43:06 +0530
Subject: [PATCH 2/3] changes as per review comments
---
.../Optimizer/OpenMP/FunctionFiltering.cpp | 3 +-
.../Optimizer/OpenMP/LowerWorkdistribute.cpp | 33 ++++---
.../mlir/Dialect/OpenMP/OpenMPClauses.td | 26 ++----
.../mlir/Dialect/OpenMP/OpenMPEnums.td | 6 +-
mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td | 11 ++-
mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp | 91 ++++++++-----------
.../OpenMP/OpenMPToLLVMIRTranslation.cpp | 12 +--
mlir/test/Dialect/OpenMP/invalid.mlir | 9 +-
mlir/test/Dialect/OpenMP/ops.mlir | 4 +-
9 files changed, 86 insertions(+), 109 deletions(-)
diff --git a/flang/lib/Optimizer/OpenMP/FunctionFiltering.cpp b/flang/lib/Optimizer/OpenMP/FunctionFiltering.cpp
index 83fb61ee76ce1..b78b33f31b142 100644
--- a/flang/lib/Optimizer/OpenMP/FunctionFiltering.cpp
+++ b/flang/lib/Optimizer/OpenMP/FunctionFiltering.cpp
@@ -313,6 +313,7 @@ class FunctionFilteringPass
targetOp.setDependIteratedKindsAttr(nullptr);
targetOp.getDeviceMutable().clear();
targetOp.getIfExprMutable().clear();
+ targetOp.getDynGroupprivateSizeMutable().clear();
// TODO: Clear some of these operands rather than rewriting them,
// depending on whether they are needed by device codegen once support for
@@ -333,8 +334,6 @@ class FunctionFilteringPass
collectRewrite(privateVar, rewriteValues);
for (Value threadLimit : targetOp.getThreadLimitVars())
collectRewrite(threadLimit, rewriteValues);
- if (Value dynGpSize = targetOp.getDynGroupprivateSize())
- collectRewrite(dynGpSize, rewriteValues);
}
// Move omp.map.info ops to the new block and collect dependencies.
diff --git a/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp b/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp
index ed929880656f1..1d98018ffbd05 100644
--- a/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp
+++ b/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp
@@ -761,15 +761,15 @@ FailureOr<omp::TargetOp> splitTargetData(omp::TargetOp targetOp,
targetOp.getAllocatorVars(), targetOp.getBareAttr(),
targetOp.getDependKindsAttr(), targetOp.getDependVars(),
targetOp.getDependIteratedKindsAttr(), targetOp.getDependIterated(),
- targetOp.getDevice(), targetOp.getHasDeviceAddrVars(),
+ targetOp.getDevice(), targetOp.getDynGroupprivateAccessGroupAttr(),
+ targetOp.getDynGroupprivateFallbackAttr(),
+ targetOp.getDynGroupprivateSize(), targetOp.getHasDeviceAddrVars(),
targetOp.getHostEvalVars(), targetOp.getIfExpr(),
targetOp.getInReductionVars(), targetOp.getInReductionByrefAttr(),
targetOp.getInReductionSymsAttr(), targetOp.getIsDevicePtrVars(),
innerMapInfos, targetOp.getNowaitAttr(), targetOp.getPrivateVars(),
targetOp.getPrivateSymsAttr(), targetOp.getPrivateNeedsBarrierAttr(),
- targetOp.getThreadLimitVars(), targetOp.getAccessGroupAttr(),
- targetOp.getFallbackAttr(), targetOp.getDynGroupprivateSize(),
- targetOp.getPrivateMapsAttr());
+ targetOp.getThreadLimitVars(), targetOp.getPrivateMapsAttr());
rewriter.inlineRegionBefore(targetOp.getRegion(), newTargetOp.getRegion(),
newTargetOp.getRegion().begin());
rewriter.replaceOp(targetOp, targetDataOp);
@@ -1484,14 +1484,15 @@ genPreTargetOp(omp::TargetOp targetOp, SmallVector<Value> &preMapOperands,
targetOp.getAllocatorVars(), targetOp.getBareAttr(),
targetOp.getDependKindsAttr(), targetOp.getDependVars(),
targetOp.getDependIteratedKindsAttr(), targetOp.getDependIterated(),
- targetOp.getDevice(), targetOp.getHasDeviceAddrVars(), preHostEvalVars,
- targetOp.getIfExpr(), targetOp.getInReductionVars(),
+ targetOp.getDevice(), targetOp.getDynGroupprivateAccessGroupAttr(),
+ targetOp.getDynGroupprivateFallbackAttr(),
+ targetOp.getDynGroupprivateSize(), targetOp.getHasDeviceAddrVars(),
+ preHostEvalVars, targetOp.getIfExpr(), targetOp.getInReductionVars(),
targetOp.getInReductionByrefAttr(), targetOp.getInReductionSymsAttr(),
targetOp.getIsDevicePtrVars(), preMapOperands, targetOp.getNowaitAttr(),
targetOp.getPrivateVars(), targetOp.getPrivateSymsAttr(),
targetOp.getPrivateNeedsBarrierAttr(), targetOp.getThreadLimitVars(),
- targetOp.getAccessGroupAttr(), targetOp.getFallbackAttr(),
- targetOp.getDynGroupprivateSize(), targetOp.getPrivateMapsAttr());
+ targetOp.getPrivateMapsAttr());
auto *preTargetBlock = rewriter.createBlock(
&preTargetOp.getRegion(), preTargetOp.getRegion().begin(), {}, {});
IRMapping preMapping;
@@ -1576,14 +1577,15 @@ genIsolatedTargetOp(omp::TargetOp targetOp, SmallVector<Value> &postMapOperands,
targetOp.getAllocatorVars(), targetOp.getBareAttr(),
targetOp.getDependKindsAttr(), targetOp.getDependVars(),
targetOp.getDependIteratedKindsAttr(), targetOp.getDependIterated(),
- targetOp.getDevice(), targetOp.getHasDeviceAddrVars(),
+ targetOp.getDevice(), targetOp.getDynGroupprivateAccessGroupAttr(),
+ targetOp.getDynGroupprivateFallbackAttr(),
+ targetOp.getDynGroupprivateSize(), targetOp.getHasDeviceAddrVars(),
isolatedHostEvalVars, targetOp.getIfExpr(), targetOp.getInReductionVars(),
targetOp.getInReductionByrefAttr(), targetOp.getInReductionSymsAttr(),
targetOp.getIsDevicePtrVars(), postMapOperands, targetOp.getNowaitAttr(),
targetOp.getPrivateVars(), targetOp.getPrivateSymsAttr(),
targetOp.getPrivateNeedsBarrierAttr(), targetOp.getThreadLimitVars(),
- targetOp.getAccessGroupAttr(), targetOp.getFallbackAttr(),
- targetOp.getDynGroupprivateSize(), targetOp.getPrivateMapsAttr());
+ targetOp.getPrivateMapsAttr());
auto *isolatedTargetBlock =
rewriter.createBlock(&isolatedTargetOp.getRegion(),
isolatedTargetOp.getRegion().begin(), {}, {});
@@ -1658,14 +1660,15 @@ static omp::TargetOp genPostTargetOp(omp::TargetOp targetOp,
targetOp.getAllocatorVars(), targetOp.getBareAttr(),
targetOp.getDependKindsAttr(), targetOp.getDependVars(),
targetOp.getDependIteratedKindsAttr(), targetOp.getDependIterated(),
- targetOp.getDevice(), targetOp.getHasDeviceAddrVars(), postHostEvalVars,
- targetOp.getIfExpr(), targetOp.getInReductionVars(),
+ targetOp.getDevice(), targetOp.getDynGroupprivateAccessGroupAttr(),
+ targetOp.getDynGroupprivateFallbackAttr(),
+ targetOp.getDynGroupprivateSize(), targetOp.getHasDeviceAddrVars(),
+ postHostEvalVars, targetOp.getIfExpr(), targetOp.getInReductionVars(),
targetOp.getInReductionByrefAttr(), targetOp.getInReductionSymsAttr(),
targetOp.getIsDevicePtrVars(), postMapOperands, targetOp.getNowaitAttr(),
targetOp.getPrivateVars(), targetOp.getPrivateSymsAttr(),
targetOp.getPrivateNeedsBarrierAttr(), targetOp.getThreadLimitVars(),
- targetOp.getAccessGroupAttr(), targetOp.getFallbackAttr(),
- targetOp.getDynGroupprivateSize(), targetOp.getPrivateMapsAttr());
+ targetOp.getPrivateMapsAttr());
// Create the block for postTargetOp
auto *postTargetBlock = rewriter.createBlock(
&postTargetOp.getRegion(), postTargetOp.getRegion().begin(), {}, {});
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
index 3fd5060f38b4f..8c8bf7b6d36ae 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
@@ -1766,31 +1766,23 @@ class OpenMP_DynGroupprivateClauseSkip<
extraClassDeclaration> {
let arguments = (ins
- OptionalAttr<AccessGroupModifierAttr>:$access_group,
- OptionalAttr<FallbackModifierAttr>:$fallback,
+ OptionalAttr<AccessGroupModifierAttr>:$dyn_groupprivate_access_group,
+ OptionalAttr<FallbackModifierAttr>:$dyn_groupprivate_fallback,
Optional<AnyInteger>:$dyn_groupprivate_size
);
let description = [{
- The `dyn_groupprivate` clause allows you to dynamically allocate group-private
- memory in OpenMP parallel regions, specifically for `target` and `teams` directives.
- This clause enables runtime-sized private memory allocation and applicable to
- target and teams ops.
-
- Syntax:
- ```
- dyn_groupprivate(cgroup, fallback(abort), %size)
- ```
-
- Example:
- ```
- omp.target dyn_groupprivate(cgroup, fallback(default_mem), %size : i32)
- ```
+ The `dyn_groupprivate_access_group` attribute specifies the access group
+ modifier for the dynamically allocated group-private memory. The
+ `dyn_groupprivate_fallback` attribute specifies the fallback behavior when
+ allocation fails. The `dyn_groupprivate_size` operand specifies the size in
+ bytes to allocate.
}];
let optAssemblyFormat = [{
`dyn_groupprivate` `(`
- custom<DynGroupprivateClause>($access_group, $fallback,
+ custom<DynGroupprivateClause>($dyn_groupprivate_access_group,
+ $dyn_groupprivate_fallback,
$dyn_groupprivate_size, type($dyn_groupprivate_size))
`)`
}];
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPEnums.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPEnums.td
index d2a4998bc19ac..3de879dc6028e 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPEnums.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPEnums.td
@@ -348,7 +348,7 @@ def VariableCaptureKindAttr : OpenMP_EnumAttr<VariableCaptureKind,
}
//===----------------------------------------------------------------------===//
-// dyn_groupprivate enums.
+// access_group_modifier enum.
//===----------------------------------------------------------------------===//
def AccessGroupCGroup : I32EnumAttrCase<"cgroup", 0>;
@@ -364,6 +364,10 @@ def AccessGroupModifierAttr : OpenMP_EnumAttr<AccessGroupModifier,
let assemblyFormat = "`(` $value `)`";
}
+//===----------------------------------------------------------------------===//
+// fallback_modifier enum.
+//===----------------------------------------------------------------------===//
+
def FallbackAbort : I32EnumAttrCase<"abort", 0>;
def FallbackNull : I32EnumAttrCase<"null", 1>;
def FallbackDefaultMem : I32EnumAttrCase<"default_mem", 2>;
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
index d83e0a2d81441..74a72fa2b68a7 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
@@ -240,9 +240,9 @@ def TerminatorOp : OpenMP_Op<"terminator", [Terminator, Pure]> {
def TeamsOp : OpenMP_Op<"teams", traits = [
AttrSizedOperandSegments, RecursiveMemoryEffects, OutlineableOpenMPOpInterface
], clauses = [
- OpenMP_AllocateClause, OpenMP_IfClause, OpenMP_NumTeamsClause,
- OpenMP_PrivateClause, OpenMP_ReductionClause, OpenMP_ThreadLimitClause,
- OpenMP_DynGroupprivateClause
+ OpenMP_AllocateClause, OpenMP_DynGroupprivateClause, OpenMP_IfClause,
+ OpenMP_NumTeamsClause, OpenMP_PrivateClause, OpenMP_ReductionClause,
+ OpenMP_ThreadLimitClause
], singleRegion = true> {
let summary = "teams construct";
let description = [{
@@ -1545,10 +1545,11 @@ def TargetOp : OpenMP_Op<"target", traits = [
], clauses = [
// TODO: Complete clause list (defaultmap, uses_allocators).
OpenMP_AllocateClause, OpenMP_BareClause, OpenMP_DependClause,
- OpenMP_DeviceClause, OpenMP_HasDeviceAddrClause, OpenMP_HostEvalClause,
+ OpenMP_DeviceClause, OpenMP_DynGroupprivateClause,
+ OpenMP_HasDeviceAddrClause, OpenMP_HostEvalClause,
OpenMP_IfClause, OpenMP_InReductionClause, OpenMP_IsDevicePtrClause,
OpenMP_MapClauseSkip<assemblyFormat = true>, OpenMP_NowaitClause,
- OpenMP_PrivateClause, OpenMP_ThreadLimitClause, OpenMP_DynGroupprivateClause
+ OpenMP_PrivateClause, OpenMP_ThreadLimitClause
], singleRegion = true> {
let summary = "target construct";
let description = [{
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index f3b0a18ab3a3f..74e2522a9e082 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -877,20 +877,12 @@ static void printNumTasksClause(OpAsmPrinter &p, Operation *op,
// Parser, printer and verify for dyn_groupprivate Clause
//===----------------------------------------------------------------------===//
-static LogicalResult verifyDynGroupprivateClause(
- Operation *op, AccessGroupModifierAttr modifierFirst,
- FallbackModifierAttr modifierSecond, Value dynGroupprivateSize) {
-
- // Verify the size
- if (dynGroupprivateSize) {
- Type size_type = dynGroupprivateSize.getType();
- // Check if the size type is an integer type
- if (!size_type.isIntOrIndex()) {
- return op->emitOpError(
- "dyn_groupprivate size must be an integer type, got ")
- << size_type;
- }
- }
+static LogicalResult
+verifyDynGroupprivateClause(Operation *op, AccessGroupModifierAttr accessGroup,
+ FallbackModifierAttr fallback,
+ Value dynGroupprivateSize) {
+ if (!dynGroupprivateSize && (accessGroup || fallback))
+ return op->emitOpError("dyn_groupprivate modifiers require a size operand");
return success();
}
@@ -904,9 +896,8 @@ static ParseResult parseDynGroupprivateClause(
bool parsedAccessGroup = false;
bool parsedFallback = false;
- // Parse modifiers separated by commas
- while (true) {
- // parse AccessGroupModifier
+ return parser.parseCommaSeparatedList([&]() -> ParseResult {
+ // Parse AccessGroupModifier.
if (succeeded(parser.parseOptionalKeyword("cgroup"))) {
if (parsedAccessGroup)
return parser.emitError(parser.getCurrentLocation(),
@@ -914,9 +905,10 @@ static ParseResult parseDynGroupprivateClause(
accessGroupAttr = AccessGroupModifierAttr::get(
parser.getContext(), AccessGroupModifier::cgroup);
parsedAccessGroup = true;
+ return success();
}
- // parse FallbackModifier
- else if (succeeded(parser.parseOptionalKeyword("fallback"))) {
+ // Parse FallbackModifier.
+ if (succeeded(parser.parseOptionalKeyword("fallback"))) {
if (parsedFallback)
return parser.emitError(parser.getCurrentLocation(),
"duplicate fallback modifier");
@@ -943,29 +935,20 @@ static ParseResult parseDynGroupprivateClause(
return parser.emitError(parser.getCurrentLocation(),
"expected ')' after fallback modifier");
parsedFallback = true;
- } else
- break;
-
- // Consume optional comma between modifiers
- (void)parser.parseOptionalComma();
- }
-
- // Consume comma after modifiers, if both modifiers are present
- (void)parser.parseOptionalComma();
-
- OpAsmParser::UnresolvedOperand operand;
- if (succeeded(parser.parseOperand(operand))) {
- dynGroupprivateSize = operand;
- if (failed(parser.parseColon()) || failed(parser.parseType(sizeType))) {
- return parser.emitError(parser.getCurrentLocation(),
- "expected ':' and type after size operand");
+ return success();
+ }
+ // Parse size operand.
+ OpAsmParser::UnresolvedOperand operand;
+ if (succeeded(parser.parseOperand(operand))) {
+ dynGroupprivateSize = operand;
+ if (failed(parser.parseColon()) || failed(parser.parseType(sizeType)))
+ return parser.emitError(parser.getCurrentLocation(),
+ "expected ':' and type after size operand");
+ return success();
}
- } else {
return parser.emitError(parser.getCurrentLocation(),
"expected dyn_groupprivate_size operand");
- }
-
- return success();
+ });
}
static void printDynGroupprivateClause(OpAsmPrinter &printer, Operation *op,
@@ -2497,14 +2480,14 @@ void TargetOp::build(OpBuilder &builder, OperationState &state,
builder, state, /*allocate_vars=*/{}, /*allocator_vars=*/{}, clauses.bare,
makeArrayAttr(ctx, clauses.dependKinds), clauses.dependVars,
makeArrayAttr(ctx, clauses.dependIteratedKinds), clauses.dependIterated,
- clauses.device, clauses.hasDeviceAddrVars, clauses.hostEvalVars,
- clauses.ifExpr,
+ clauses.device, clauses.dynGroupprivateAccessGroup,
+ clauses.dynGroupprivateFallback, clauses.dynGroupprivateSize,
+ clauses.hasDeviceAddrVars, clauses.hostEvalVars, clauses.ifExpr,
/*in_reduction_vars=*/{}, /*in_reduction_byref=*/nullptr,
/*in_reduction_syms=*/nullptr, clauses.isDevicePtrVars, clauses.mapVars,
clauses.nowait, clauses.privateVars,
makeArrayAttr(ctx, clauses.privateSyms), clauses.privateNeedsBarrier,
- clauses.threadLimitVars, clauses.accessGroup, clauses.fallback,
- clauses.dynGroupprivateSize,
+ clauses.threadLimitVars,
/*private_maps=*/nullptr);
}
@@ -2521,10 +2504,9 @@ LogicalResult TargetOp::verify() {
if (failed(verifyMapClause(*this, getMapVars())))
return failure();
- // check dyn_groupprivate clause restrictions
- if (failed(verifyDynGroupprivateClause(*this, getAccessGroupAttr(),
- getFallbackAttr(),
- getDynGroupprivateSize())))
+ if (failed(verifyDynGroupprivateClause(
+ *this, getDynGroupprivateAccessGroupAttr(),
+ getDynGroupprivateFallbackAttr(), getDynGroupprivateSize())))
return failure();
return verifyPrivateVarsMapping(*this);
@@ -2929,13 +2911,13 @@ void TeamsOp::build(OpBuilder &builder, OperationState &state,
// TODO Store clauses in op: privateVars, privateSyms, privateNeedsBarrier
TeamsOp::build(
builder, state, clauses.allocateVars, clauses.allocatorVars,
- clauses.ifExpr, clauses.numTeamsLower, clauses.numTeamsUpperVars,
- /*private_vars=*/{}, /*private_syms=*/nullptr,
+ clauses.dynGroupprivateAccessGroup, clauses.dynGroupprivateFallback,
+ clauses.dynGroupprivateSize, clauses.ifExpr, clauses.numTeamsLower,
+ clauses.numTeamsUpperVars, /*private_vars=*/{}, /*private_syms=*/nullptr,
/*private_needs_barrier=*/nullptr, clauses.reductionMod,
clauses.reductionVars,
makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
- makeArrayAttr(ctx, clauses.reductionSyms), clauses.threadLimitVars,
- clauses.accessGroup, clauses.fallback, clauses.dynGroupprivateSize);
+ makeArrayAttr(ctx, clauses.reductionSyms), clauses.threadLimitVars);
}
// Verify num_teams clause
@@ -2978,10 +2960,9 @@ LogicalResult TeamsOp::verify() {
return emitError(
"expected equal sizes for allocate and allocator variables");
- // check dyn_groupprivate clause restrictions
- if (failed(verifyDynGroupprivateClause(op, getAccessGroupAttr(),
- getFallbackAttr(),
- getDynGroupprivateSize())))
+ if (failed(verifyDynGroupprivateClause(
+ op, getDynGroupprivateAccessGroupAttr(),
+ getDynGroupprivateFallbackAttr(), getDynGroupprivateSize())))
return failure();
return verifyReductionVarList(*this, getReductionSyms(), getReductionVars(),
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index f99b376e001e4..3b4a490be5519 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -6854,12 +6854,10 @@ initTargetRuntimeAttrs(llvm::IRBuilderBase &builder,
}
static llvm::omp::OMPDynGroupprivateFallbackType
-getFallbackType(omp::TargetOp targetOp) {
- if (!targetOp.getFallbackAttr())
- return llvm::omp::OMPDynGroupprivateFallbackType::DefaultMem;
-
- // Extract the FallbackModifier enum value.
- mlir::omp::FallbackModifier fb = targetOp.getFallbackAttr().getValue();
+getDynGroupprivateFallbackType(omp::TargetOp targetOp) {
+ mlir::omp::FallbackModifier fb =
+ targetOp.getDynGroupprivateFallback().value_or(
+ mlir::omp::FallbackModifier::default_mem);
switch (fb) {
case mlir::omp::FallbackModifier::abort:
return llvm::omp::OMPDynGroupprivateFallbackType::Abort;
@@ -7195,7 +7193,7 @@ convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder,
dynSizeVal = moduleTranslation.lookupValue(dynGroupPrivateSize);
llvm::omp::OMPDynGroupprivateFallbackType fallbackType =
- getFallbackType(targetOp);
+ getDynGroupprivateFallbackType(targetOp);
llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
moduleTranslation.getOpenMPBuilder()->createTarget(
diff --git a/mlir/test/Dialect/OpenMP/invalid.mlir b/mlir/test/Dialect/OpenMP/invalid.mlir
index 4743a7fa6d9fe..6feedbc451ffa 100644
--- a/mlir/test/Dialect/OpenMP/invalid.mlir
+++ b/mlir/test/Dialect/OpenMP/invalid.mlir
@@ -1451,7 +1451,7 @@ func.func @omp_teams_num_teams1(%lb : i32) {
// expected-error @below {{expected exactly one num_teams upper bound when lower bound is specified}}
"omp.teams" (%lb) ({
omp.terminator
- }) {operandSegmentSizes = array<i32: 0,0,0,1,0,0,0,0,0>} : (i32) -> ()
+ }) {operandSegmentSizes = array<i32: 0,0,0,0,1,0,0,0,0>} : (i32) -> ()
omp.terminator
}
return
@@ -1468,7 +1468,7 @@ func.func @omp_teams_num_teams_multidim_with_bounds() {
// expected-error @below {{expected exactly one num_teams upper bound when lower bound is specified}}
"omp.teams" (%lb, %v0, %v1) ({
omp.terminator
- }) {operandSegmentSizes = array<i32: 0,0,0,1,2,0,0,0,0>} : (i32, i32, i32) -> ()
+ }) {operandSegmentSizes = array<i32: 0,0,0,0,1,2,0,0,0>} : (i32, i32, i32) -> ()
omp.terminator
}
return
@@ -1520,8 +1520,7 @@ func.func @test_teams_dyn_groupprivate_errors_3(%dyn_size: i32) {
// -----
func.func @test_teams_dyn_groupprivate_errors_4(%dyn_size: i32) {
- // expected-error @below {{custom op 'omp.teams' expected dyn_groupprivate_size operand}}
- // expected-error @below {{expected SSA operand}}
+ // expected-error @below {{'omp.teams' op dyn_groupprivate modifiers require a size operand}}
omp.teams dyn_groupprivate(fallback(null)) {
omp.terminator
}
@@ -2586,7 +2585,7 @@ func.func @omp_target_depend(%data_var: memref<i32>) {
// expected-error @below {{op expected as many depend values as depend variables}}
"omp.target"(%data_var) ({
"omp.terminator"() : () -> ()
- }) {depend_kinds = [], operandSegmentSizes = array<i32: 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>} : (memref<i32>) -> ()
+ }) {depend_kinds = [], operandSegmentSizes = array<i32: 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>} : (memref<i32>) -> ()
"func.return"() : () -> ()
}
diff --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir
index 72290c541e3f9..4ba739ce9459b 100644
--- a/mlir/test/Dialect/OpenMP/ops.mlir
+++ b/mlir/test/Dialect/OpenMP/ops.mlir
@@ -871,7 +871,7 @@ func.func @omp_target(%if_cond : i1, %device : si32, %num_threads : i32, %devic
"omp.target"(%device, %if_cond, %num_threads) ({
// CHECK: omp.terminator
omp.terminator
- }) {nowait, operandSegmentSizes = array<i32: 0,0,0,0,1,0,0,1,0,0,0,0,1,0>} : ( si32, i1, i32 ) -> ()
+ }) {nowait, operandSegmentSizes = array<i32: 0,0,0,0,1,0,0,0,1,0,0,0,0,1>} : ( si32, i1, i32 ) -> ()
// Test with optional map clause.
// CHECK: %[[MAP_A:.*]] = omp.map.info var_ptr(%[[VAL_1:.*]] : memref<?xi32>, tensor<?xi32>) map_clauses(always, to) capture(ByRef) -> memref<?xi32> {name = ""}
@@ -2318,7 +2318,7 @@ func.func @omp_target_depend(%arg0: memref<i32>, %arg1: memref<i32>) {
omp.target depend(taskdependin -> %arg0 : memref<i32>, taskdependin -> %arg1 : memref<i32>, taskdependinout -> %arg0 : memref<i32>) {
// CHECK: omp.terminator
omp.terminator
- } {operandSegmentSizes = array<i32: 0,0,3,0,0,0,0,0,0,0,0,0,0>}
+ } {operandSegmentSizes = array<i32: 0,0,3,0,0,0,0,0,0,0,0,0,0,0>}
return
}
>From 578fd41398507f182acf32df051a1229a3826b98 Mon Sep 17 00:00:00 2001
From: skc7 <Krishna.Sankisa at amd.com>
Date: Tue, 14 Apr 2026 18:13:47 +0530
Subject: [PATCH 3/3] make getDynGroupprivateFallbackType accept fallbackAttr
---
mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp | 5 +++++
.../OpenMP/OpenMPToLLVMIRTranslation.cpp | 17 ++++++++---------
mlir/test/Dialect/OpenMP/invalid.mlir | 10 ++++++++++
3 files changed, 23 insertions(+), 9 deletions(-)
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index 74e2522a9e082..dbc6fb0ccc223 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -895,6 +895,7 @@ static ParseResult parseDynGroupprivateClause(
bool parsedAccessGroup = false;
bool parsedFallback = false;
+ bool parsedSize = false;
return parser.parseCommaSeparatedList([&]() -> ParseResult {
// Parse AccessGroupModifier.
@@ -940,7 +941,11 @@ static ParseResult parseDynGroupprivateClause(
// Parse size operand.
OpAsmParser::UnresolvedOperand operand;
if (succeeded(parser.parseOperand(operand))) {
+ if (parsedSize)
+ return parser.emitError(parser.getCurrentLocation(),
+ "duplicate size operand");
dynGroupprivateSize = operand;
+ parsedSize = true;
if (failed(parser.parseColon()) || failed(parser.parseType(sizeType)))
return parser.emitError(parser.getCurrentLocation(),
"expected ':' and type after size operand");
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index 3b4a490be5519..7462ba3cbf942 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -6854,16 +6854,15 @@ initTargetRuntimeAttrs(llvm::IRBuilderBase &builder,
}
static llvm::omp::OMPDynGroupprivateFallbackType
-getDynGroupprivateFallbackType(omp::TargetOp targetOp) {
- mlir::omp::FallbackModifier fb =
- targetOp.getDynGroupprivateFallback().value_or(
- mlir::omp::FallbackModifier::default_mem);
+getDynGroupprivateFallbackType(omp::FallbackModifierAttr fallbackAttr) {
+ omp::FallbackModifier fb = fallbackAttr ? fallbackAttr.getValue()
+ : omp::FallbackModifier::default_mem;
switch (fb) {
- case mlir::omp::FallbackModifier::abort:
+ case omp::FallbackModifier::abort:
return llvm::omp::OMPDynGroupprivateFallbackType::Abort;
- case mlir::omp::FallbackModifier::null:
+ case omp::FallbackModifier::null:
return llvm::omp::OMPDynGroupprivateFallbackType::Null;
- case mlir::omp::FallbackModifier::default_mem:
+ case omp::FallbackModifier::default_mem:
return llvm::omp::OMPDynGroupprivateFallbackType::DefaultMem;
}
@@ -7187,13 +7186,13 @@ convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder,
if (Value targetIfCond = targetOp.getIfExpr())
ifCond = moduleTranslation.lookupValue(targetIfCond);
- mlir::Value dynGroupPrivateSize = targetOp.getDynGroupprivateSize();
+ Value dynGroupPrivateSize = targetOp.getDynGroupprivateSize();
llvm::Value *dynSizeVal = nullptr;
if (dynGroupPrivateSize)
dynSizeVal = moduleTranslation.lookupValue(dynGroupPrivateSize);
llvm::omp::OMPDynGroupprivateFallbackType fallbackType =
- getDynGroupprivateFallbackType(targetOp);
+ getDynGroupprivateFallbackType(targetOp.getDynGroupprivateFallbackAttr());
llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
moduleTranslation.getOpenMPBuilder()->createTarget(
diff --git a/mlir/test/Dialect/OpenMP/invalid.mlir b/mlir/test/Dialect/OpenMP/invalid.mlir
index 6feedbc451ffa..0b4e2fd96376e 100644
--- a/mlir/test/Dialect/OpenMP/invalid.mlir
+++ b/mlir/test/Dialect/OpenMP/invalid.mlir
@@ -1540,6 +1540,16 @@ func.func @test_teams_dyn_groupprivate_errors_5() {
// -----
+func.func @test_teams_dyn_groupprivate_errors_6(%s1: i32, %s2: i32) {
+ // expected-error @below {{duplicate size operand}}
+ omp.teams dyn_groupprivate(%s1 : i32, %s2 : i32) {
+ omp.terminator
+ }
+ return
+}
+
+// -----
+
func.func @omp_sections(%data_var : memref<i32>) -> () {
// expected-error @below {{expected equal sizes for allocate and allocator variables}}
"omp.sections" (%data_var) ({
More information about the Mlir-commits mailing list